RTFormer pretraining #3788

Open · 1 task done
mandylyin opened this issue Aug 29, 2024 · 3 comments
Labels: question (Further information is requested)

Comments


mandylyin commented Aug 29, 2024

Search before asking

  • I have searched the question and found no related answer.

Please ask your question

Hi, how can I reproduce the paper's ImageNet pretraining results? I copied RTFormer into PaddleClas, added a classification head, and modified the forward function; nothing else was changed. With only the RandCropImage and RandFlipImage augmentations, top-1 accuracy reaches about 66%; after adding the other augmentations, top-1 stays below 20%. What could be causing this? I am training on 4 GPUs.

The classification head code is as follows:

# Assumes the usual PaddleClas imports:
#   import paddle.nn as nn
#   from paddle import ParamAttr
#   from paddle.nn import AdaptiveAvgPool2D
self.pool2d_avg = AdaptiveAvgPool2D(1)  # global average pool to 1x1
self.fc = nn.Linear(
    base_chs * 4,  # 4c channels: concat of x5_ (2c) and x6 (2c) in forward
    class_num,
    weight_attr=ParamAttr(name="fc_weights"),
    bias_attr=ParamAttr(name="fc_offset"))

The forward code:

def forward(self, x):
    x1 = self.layer1(self.conv1(x))  # c, 1/4
    x2 = self.layer2(self.relu(x1))  # 2c, 1/8
    x3 = self.layer3(self.relu(x2))  # 4c, 1/16
    x3_ = x2 + F.interpolate(
        self.compression3(x3), size=paddle.shape(x2)[2:], mode='bilinear')
    x3_ = self.layer3_(self.relu(x3_))  # 2c, 1/8

    x4_, x4 = self.layer4(
        [self.relu(x3_), self.relu(x3)])  # 2c, 1/8; 8c, 1/16
    x5_, x5 = self.layer5(
        [self.relu(x4_), self.relu(x4)])  # 2c, 1/8; 8c, 1/32

    x6 = self.spp(x5)
    x6 = F.interpolate(
        x6, size=paddle.shape(x5_)[2:], mode='bilinear')  # 2c, 1/8
    x7 = paddle.concat([x5_, x6], axis=1)  # 4c, 1/8
    x8 = self.pool2d_avg(x7)  # global average pool -> (N, 4c, 1, 1)
    x9 = paddle.flatten(x8, start_axis=1, stop_axis=-1)  # (N, 4c)
    out = self.fc(x9)  # classification logits

    return out
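A quick shape sanity check for the modified model (a sketch; the constructor call is assumed, and class_num / image_shape follow the config below):

import paddle

# Feed a dummy ImageNet-sized batch through the modified model and
# confirm the logits come out as (batch, class_num).
model = RTFormer_slim(class_num=1000)  # constructor signature assumed
x = paddle.randn([2, 3, 224, 224])     # matches image_shape in the config
out = model(x)
print(out.shape)                       # expected: [2, 1000]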
The full training config used is:

# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: ./output/
  device: gpu
  save_interval: 1
  eval_during_train: True
  eval_interval: 1
  epochs: 300
  print_batch_step: 10
  use_visualdl: True
  # used for static mode and model export
  image_shape: [3, 224, 224]
  save_inference_dir: ./inference
  # use_dali: True

# model architecture
Arch:
  name: RTFormer_slim
  class_num: 1000
 
# loss function config for training/eval process
Loss:
  Train:
    - CELoss:
        weight: 1.0
        epsilon: 0.1
  Eval:
    - CELoss:
        weight: 1.0

Optimizer:
  name: AdamW
  beta1: 0.9
  beta2: 0.999
  epsilon: 1e-8
  weight_decay: 0.04
  no_weight_decay_name: .bias norm
  one_dim_param_no_weight_decay: True
  lr:
    name: Cosine
    learning_rate: 0.0005
    eta_min: 0.0002
    warmup_epoch: 5
    warmup_start_lr: 5e-7

# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: ImageNetDataset
      image_root: ../../datasets/ILSVRC2012/
      cls_label_path: ../../datasets/ILSVRC2012/train.txt
      relabel: True
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - RandCropImage:
            size: 224
            interpolation: bicubic
            backend: pil
        - RandFlipImage:
            flip_code: 1
        # - TimmAutoAugment:
        #     config_str: rand-m9-mstd0.5-inc1
        #     interpolation: bicubic
        #     img_size: 224
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
        - RandomErasing:
            EPSILON: 0.25
            sl: 0.02
            sh: 1.0/3.0
            r1: 0.3
            attempt: 10
            use_log_aspect: True
            mode: pixel
      # batch_transform_ops:
      #   - OpSampler:
      #       MixupOperator:
      #         alpha: 0.2
      #         prob: 0.5
      #       CutmixOperator:
      #         alpha: 1.0
      #         prob: 0.5

    sampler:
      name: DistributedBatchSampler
      batch_size: 256
      drop_last: False
      shuffle: True
    loader:
      num_workers: 8
      use_shared_memory: True

  Eval:
    dataset: 
      name: ImageNetDataset
      image_root: ../../datasets/ILSVRC2012/
      cls_label_path: ../../datasets/ILSVRC2012/val.txt
      relabel: True
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - ResizeImage:
            resize_short: 256
        - CropImage:
            size: 224
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 256
      drop_last: False
      shuffle: False
    loader:
      num_workers: 4
      use_shared_memory: True

Infer:
  infer_imgs: docs/images/inference_deployment/whl_demo.jpg
  batch_size: 10
  transforms:
    - DecodeImage:
        to_rgb: True
        channel_first: False
    - ResizeImage:
        resize_short: 256
    - CropImage:
        size: 224
    - NormalizeImage:
        scale: 1.0/255.0
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
        order: ''
    - ToCHWImage:
  PostProcess:
    name: Topk
    topk: 5
    class_id_map_file: ppcls/utils/imagenet1k_label_list.txt

Metric:
  Train:
    - TopkAcc:
        topk: [1, 5]
  Eval:
    - TopkAcc:
        topk: [1, 5]

mandylyin added the question (Further information is requested) label on Aug 29, 2024
zhang-prog (Collaborator) commented

[screenshot: training hyperparameters table from the RTFormer paper]
The paper's minimum lr is 5e-6, while your config sets eta_min: 0.0002, which is quite far off. Change that and try again.
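A minimal sketch of the corrected lr block, keeping the other values from the posted config (only eta_min changes, to the paper's minimum lr):

  lr:
    name: Cosine
    learning_rate: 0.0005
    eta_min: 5e-6        # the paper's minimum lr; was 0.0002
    warmup_epoch: 5
    warmup_start_lr: 5e-7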

mandylyin (Author) commented

I used 5e-6 at the very beginning; after adding the other augmentations, top-1 just would not go up. I then changed things all over the place and never figured it out. I will change it back and try again.

zhang-prog (Collaborator) commented

OK, try again using the parameters from the paper as a reference.
