[New Configs] Add new config file for ade20k and maskformer #3722

Open · wants to merge 11 commits into dev-1.x
80 changes: 80 additions & 0 deletions mmseg/configs/_base_/datasets/ade20k.py
@@ -0,0 +1,80 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.transforms.loading import LoadImageFromFile
from mmcv.transforms.processing import (RandomFlip, RandomResize, Resize,
                                        TestTimeAug)
from mmengine.dataset.sampler import DefaultSampler, InfiniteSampler

from mmseg.datasets.ade import ADE20KDataset
from mmseg.datasets.transforms.formatting import PackSegInputs
from mmseg.datasets.transforms.loading import LoadAnnotations
from mmseg.datasets.transforms.transforms import (PhotoMetricDistortion,
                                                  RandomCrop)
from mmseg.evaluation import IoUMetric

# dataset settings
dataset_type = ADE20KDataset
data_root = 'data/ade/ADEChallengeData2016'
crop_size = (512, 512)
train_pipeline = [
    dict(type=LoadImageFromFile),
    dict(type=LoadAnnotations, reduce_zero_label=True),
    dict(
        type=RandomResize,
        scale=(2048, 512),
        ratio_range=(0.5, 2.0),
        keep_ratio=True),
    dict(type=RandomCrop, crop_size=crop_size, cat_max_ratio=0.75),
    dict(type=RandomFlip, prob=0.5),
    dict(type=PhotoMetricDistortion),
    dict(type=PackSegInputs)
]
test_pipeline = [
    dict(type=LoadImageFromFile),
    dict(type=Resize, scale=(2048, 512), keep_ratio=True),
    # load annotations after ``Resize`` because the ground truth
    # does not need the resize transform
    dict(type=LoadAnnotations, reduce_zero_label=True),
    dict(type=PackSegInputs)
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type=LoadImageFromFile, backend_args=None),
    dict(
        type=TestTimeAug,
        transforms=[
            [
                dict(type=Resize, scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type=RandomFlip, prob=0., direction='horizontal'),
                dict(type=RandomFlip, prob=1., direction='horizontal')
            ],
            [dict(type=LoadAnnotations)],
            [dict(type=PackSegInputs)],
        ])
]
train_dataloader = dict(
    batch_size=4,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type=InfiniteSampler, shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='images/training', seg_map_path='annotations/training'),
        pipeline=train_pipeline))
val_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type=DefaultSampler, shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='images/validation',
            seg_map_path='annotations/validation'),
        pipeline=test_pipeline))
test_dataloader = val_dataloader

val_evaluator = dict(type=IoUMetric, iou_metrics=['mIoU'])
test_evaluator = val_evaluator
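
As a quick sanity check, a new-style pure-Python config like the file above can be loaded directly with mmengine. A minimal sketch, not part of this diff, assuming a recent mmengine with lazy-import config support and the file location shown in the header:

from mmengine.config import Config

cfg = Config.fromfile('mmseg/configs/_base_/datasets/ade20k.py')
# scalar fields come back as plain Python values,
# e.g. the training batch size of 4 defined above
print(cfg.train_dataloader['batch_size'])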
@@ -0,0 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.config import read_base

with read_base():
    from .maskformer_r50_d32_8xb2_160k_ade20k_512x512 import *

model.update(
    dict(
        backbone=dict(
            depth=101,
            init_cfg=dict(
                type=PretrainedInit, checkpoint='torchvision://resnet101'))))
@@ -0,0 +1,156 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.models.layers import PixelDecoder
from mmdet.models.losses import CrossEntropyLoss
from mmdet.models.losses.dice_loss import DiceLoss
from mmdet.models.losses.focal_loss import FocalLoss
from mmdet.models.task_modules.assigners import (ClassificationCost,
                                                 HungarianAssigner)
from mmdet.models.task_modules.assigners.match_cost import (DiceCost,
                                                            FocalLossCost)
from mmdet.models.task_modules.samplers.mask_pseudo_sampler import \
    MaskPseudoSampler
from mmengine.config import read_base
from mmengine.model.weight_init import PretrainedInit
from mmengine.optim.scheduler.lr_scheduler import PolyLR
from torch.nn.modules.activation import ReLU
from torch.nn.modules.batchnorm import SyncBatchNorm as SyncBN
from torch.nn.modules.normalization import GroupNorm as GN
from torch.optim.adamw import AdamW

from mmseg.models.backbones import ResNet
from mmseg.models.data_preprocessor import SegDataPreProcessor
from mmseg.models.decode_heads import MaskFormerHead
from mmseg.models.segmentors import EncoderDecoder

with read_base():
    from .._base_.datasets.ade20k import *
    from .._base_.default_runtime import *
    from .._base_.schedules.schedule_160k import *

norm_cfg = dict(type=SyncBN, requires_grad=True)
crop_size = (512, 512)
data_preprocessor = dict(
    type=SegDataPreProcessor,
    size=crop_size,
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    bgr_to_rgb=True,
    pad_val=0,
    seg_pad_val=255)
# model_cfg
num_classes = 150
model = dict(
    type=EncoderDecoder,
    data_preprocessor=data_preprocessor,
    backbone=dict(
        type=ResNet,
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        dilations=(1, 1, 1, 1),
        strides=(1, 2, 2, 2),
        norm_cfg=norm_cfg,
        norm_eval=True,
        style='pytorch',
        contract_dilation=True,
        init_cfg=dict(
            type=PretrainedInit, checkpoint='torchvision://resnet50')),
    decode_head=dict(
        type=MaskFormerHead,
        # input channels of the pixel_decoder modules
        in_channels=[256, 512, 1024, 2048],
        feat_channels=256,
        in_index=[0, 1, 2, 3],
        num_classes=150,
        out_channels=256,
        num_queries=100,
        pixel_decoder=dict(
            type=PixelDecoder,
            norm_cfg=dict(type=GN, num_groups=32),
            act_cfg=dict(type=ReLU)),
        enforce_decoder_input_project=False,
        positional_encoding=dict(  # SinePositionalEncoding
            num_feats=128, normalize=True),
        transformer_decoder=dict(  # DetrTransformerDecoder
            return_intermediate=True,
            num_layers=6,
            layer_cfg=dict(  # DetrTransformerDecoderLayer
                self_attn_cfg=dict(  # MultiheadAttention
                    embed_dims=256,
                    num_heads=8,
                    attn_drop=0.1,
                    proj_drop=0.1,
                    dropout_layer=None,
                    batch_first=True),
                cross_attn_cfg=dict(  # MultiheadAttention
                    embed_dims=256,
                    num_heads=8,
                    attn_drop=0.1,
                    proj_drop=0.1,
                    dropout_layer=None,
                    batch_first=True),
                ffn_cfg=dict(
                    embed_dims=256,
                    feedforward_channels=2048,
                    num_fcs=2,
                    act_cfg=dict(type=ReLU, inplace=True),
                    ffn_drop=0.1,
                    dropout_layer=None,
                    add_identity=True)),
            init_cfg=None),
        loss_cls=dict(
            type=CrossEntropyLoss,
            use_sigmoid=False,
            loss_weight=1.0,
            reduction='mean',
            class_weight=[1.0] * num_classes + [0.1]),
        loss_mask=dict(
            type=FocalLoss,
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            reduction='mean',
            loss_weight=20.0),
        loss_dice=dict(
            type=DiceLoss,
            use_sigmoid=True,
            activate=True,
            reduction='mean',
            naive_dice=True,
            eps=1.0,
            loss_weight=1.0),
        train_cfg=dict(
            assigner=dict(
                type=HungarianAssigner,
                match_costs=[
                    dict(type=ClassificationCost, weight=1.0),
                    dict(type=FocalLossCost, weight=20.0, binary_input=True),
                    dict(type=DiceCost, weight=1.0, pred_act=True, eps=1.0)
                ]),
            sampler=dict(type=MaskPseudoSampler))),
    # training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'),
)
# optimizer
optimizer.update(
    dict(type=AdamW, lr=0.0001, betas=(0.9, 0.999), weight_decay=0.0001))
optim_wrapper.merge(
    dict(
        _delete_=True,
        type=OptimWrapper,
        optimizer=optimizer,
        clip_grad=dict(max_norm=0.01, norm_type=2),
        paramwise_cfg=dict(custom_keys={
            'backbone': dict(lr_mult=0.1),
        })))
# learning policy
param_scheduler = [
    dict(
        type=PolyLR, eta_min=0, power=0.9, begin=0, end=160000, by_epoch=False)
]

# In the MaskFormer implementation, the default batch size is 2 per GPU
train_dataloader.update(dict(batch_size=2, num_workers=2))
val_dataloader.update(dict(batch_size=1, num_workers=4))
test_dataloader = val_dataloader
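
Once merged, training with this config should go through mmengine's standard Runner. A minimal sketch, not a verified command from this PR; the directory mmseg/configs/maskformer/ and the work_dir value below are assumptions:

from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile(
    'mmseg/configs/maskformer/'
    'maskformer_r50_d32_8xb2_160k_ade20k_512x512.py')  # assumed path
cfg.work_dir = './work_dirs/maskformer_r50_ade20k'  # required by Runner
runner = Runner.from_cfg(cfg)  # builds model, dataloaders and hooks
runner.train()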
@@ -0,0 +1,91 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.config import read_base
from mmengine.optim.scheduler.lr_scheduler import LinearLR
from torch.nn.modules.activation import GELU
from torch.nn.modules.normalization import LayerNorm as LN

from mmseg.models.backbones import SwinTransformer

with read_base():
    from .maskformer_r50_d32_8xb2_160k_ade20k_512x512 import *

checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_small_patch4_window7_224_20220317-7ba6d6dd.pth' # noqa

backbone_norm_cfg = dict(type=LN, requires_grad=True)
depths = [2, 2, 18, 2]
model.merge(
    dict(
        backbone=dict(
            _delete_=True,
            type=SwinTransformer,
            pretrain_img_size=224,
            embed_dims=96,
            patch_size=4,
            window_size=7,
            mlp_ratio=4,
            depths=depths,
            num_heads=[3, 6, 12, 24],
            strides=(4, 2, 2, 2),
            out_indices=(0, 1, 2, 3),
            qkv_bias=True,
            qk_scale=None,
            patch_norm=True,
            drop_rate=0.,
            attn_drop_rate=0.,
            drop_path_rate=0.3,
            use_abs_pos_embed=False,
            act_cfg=dict(type=GELU),
            norm_cfg=backbone_norm_cfg,
            init_cfg=dict(type=PretrainedInit, checkpoint=checkpoint_file)),
        decode_head=dict(
            type=MaskFormerHead,
            # input channels of the pixel_decoder modules
            in_channels=[96, 192, 384, 768])))

# optimizer
optimizer.update(
    dict(type=AdamW, lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01))
# set all layers in backbone to lr_mult=1.0
# set all norm layers, position_embedding,
# query_embedding to decay_mult=0.0
backbone_norm_multi = dict(lr_mult=1.0, decay_mult=0.0)
backbone_embed_multi = dict(lr_mult=1.0, decay_mult=0.0)
embed_multi = dict(decay_mult=0.0)
custom_keys = {
    'backbone': dict(lr_mult=1.0),
    'backbone.patch_embed.norm': backbone_norm_multi,
    'backbone.norm': backbone_norm_multi,
    'relative_position_bias_table': backbone_embed_multi,
    'query_embed': embed_multi,
}
custom_keys.update({
    f'backbone.stages.{stage_id}.blocks.{block_id}.norm': backbone_norm_multi
    for stage_id, num_blocks in enumerate(depths)
    for block_id in range(num_blocks)
})
custom_keys.update({
    f'backbone.stages.{stage_id}.downsample.norm': backbone_norm_multi
    for stage_id in range(len(depths) - 1)
})
# optimizer wrapper
optim_wrapper.merge(
    dict(
        _delete_=True,
        type=OptimWrapper,
        optimizer=optimizer,
        clip_grad=dict(max_norm=0.01, norm_type=2),
        paramwise_cfg=dict(custom_keys=custom_keys)))

# learning policy
param_scheduler = [
    dict(type=LinearLR, start_factor=1e-6, by_epoch=False, begin=0, end=1500),
    dict(
        type=PolyLR,
        eta_min=0.0,
        power=1.0,
        begin=1500,
        end=160000,
        by_epoch=False,
    )
]
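
For readers verifying the paramwise_cfg above: the two comprehensions expand to one norm key per Swin block plus one downsample norm key per stage except the last. A standalone illustration, assuming the depths defined in this file:

depths = [2, 2, 18, 2]
block_norm_keys = [
    f'backbone.stages.{stage_id}.blocks.{block_id}.norm'
    for stage_id, num_blocks in enumerate(depths)
    for block_id in range(num_blocks)
]
downsample_norm_keys = [
    f'backbone.stages.{stage_id}.downsample.norm'
    for stage_id in range(len(depths) - 1)
]
print(len(block_norm_keys))       # 24 = 2 + 2 + 18 + 2
print(len(downsample_norm_keys))  # 3 (stages 0-2 have a downsample layer)
print(block_norm_keys[-1])        # backbone.stages.3.blocks.17.norm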