From 1e937961b38d0c4b4d6616f58f94d6e092eba252 Mon Sep 17 00:00:00 2001 From: Yang-ChangHui <71805205+Yang-Changhui@users.noreply.github.com> Date: Wed, 9 Aug 2023 23:57:01 +0800 Subject: [PATCH] [CodeCamp2023-367] Add pp_mobileseg model (#3239) --- projects/pp_mobileseg/README.md | 58 ++ projects/pp_mobileseg/backbones/__init__.py | 4 + .../pp_mobileseg/backbones/strideformer.py | 958 ++++++++++++++++++ .../configs/_base_/datasets/ade20k.py | 68 ++ .../configs/_base_/default_runtime.py | 15 + .../configs/_base_/models/pp_mobile.py | 47 + .../configs/_base_/schedules/schedule_80k.py | 24 + ...obilenetv3_2x16_80k_ade20k_512x512_base.py | 13 + ...obilenetv3_2x16_80k_ade20k_512x512_tiny.py | 39 + projects/pp_mobileseg/decode_head/__init__.py | 6 + .../decode_head/pp_mobileseg_head.py | 94 ++ 11 files changed, 1326 insertions(+) create mode 100644 projects/pp_mobileseg/README.md create mode 100644 projects/pp_mobileseg/backbones/__init__.py create mode 100644 projects/pp_mobileseg/backbones/strideformer.py create mode 100644 projects/pp_mobileseg/configs/_base_/datasets/ade20k.py create mode 100644 projects/pp_mobileseg/configs/_base_/default_runtime.py create mode 100644 projects/pp_mobileseg/configs/_base_/models/pp_mobile.py create mode 100644 projects/pp_mobileseg/configs/_base_/schedules/schedule_80k.py create mode 100644 projects/pp_mobileseg/configs/pp_mobileseg/pp_mobileseg_mobilenetv3_2x16_80k_ade20k_512x512_base.py create mode 100644 projects/pp_mobileseg/configs/pp_mobileseg/pp_mobileseg_mobilenetv3_2x16_80k_ade20k_512x512_tiny.py create mode 100644 projects/pp_mobileseg/decode_head/__init__.py create mode 100644 projects/pp_mobileseg/decode_head/pp_mobileseg_head.py diff --git a/projects/pp_mobileseg/README.md b/projects/pp_mobileseg/README.md new file mode 100644 index 0000000000..effb5950ad --- /dev/null +++ b/projects/pp_mobileseg/README.md @@ -0,0 +1,58 @@ +# PP-MobileSeg: Exploring Transformer Blocks for Efficient Mobile Segmentation. + +## Reference + +> [PP-MobileSeg: Explore the Fast and Accurate Semantic Segmentation Model on Mobile Devices. ](https://arxiv.org/abs/2304.05152) + +## Introduction + +Official Repo + +Code Snippet + +## Abstract + +With the success of transformers in computer vision, several attempts have been made to adapt transformers to mobile devices. However, their performance is not satisfied for some real world applications. Therefore, we propose PP-MobileSeg, a SOTA semantic segmentation model for mobile devices. + +It is composed of three newly proposed parts, the strideformer backbone, the Aggregated Attention Module(AAM), and the Valid Interpolate Module(VIM): + +- With the four-stage MobileNetV3 block as the feature extractor, we manage to extract rich local features of different receptive fields with little parameter overhead. Also, we further efficiently empower features from the last two stages with the global view using strided sea attention. +- To effectively fuse the features, we use AAM to filter the detail features with ensemble voting and add the semantic feature to it to enhance the semantic information to the most content. +- At last, we use VIM to upsample the downsampled feature to the original resolution and significantly decrease latency in model inference stage. It only interpolates classes present in the final prediction which only takes around 10% in the ADE20K dataset. This is a common scenario for datasets with large classes. Therefore it significantly decreases the latency of the final upsample process which takes the greatest part of the model's overall latency. + +Extensive experiments show that PP-MobileSeg achieves a superior params-accuracy-latency tradeoff compared to other SOTA methods. + +
+ +
+ +## Performance + +### ADE20K + +| Model | Backbone | Training Iters | Batchsize | Train Resolution | mIoU(%) | latency(ms)\* | params(M) | config | Links | +| ----------------- | ----------------- | -------------- | --------- | ---------------- | ------- | ------------- | --------- | ------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| PP-MobileSeg-Base | StrideFormer-Base | 80000 | 32 | 512x512 | 41.57% | 265.5 | 5.62 | [config](https://github.com/Yang-Changhui/mmsegmentation/tree/add_ppmobileseg/projects/pp_mobileseg/configs/pp_mobileseg) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/pp_mobileseg/pp_mobileseg_mobilenetv3_2x16_80k_ade20k_512x512_base-ed0be681.pth)\|[log](https://bj.bcebos.com/paddleseg/dygraph/ade20k/pp_mobileseg_base/train.log) | +| PP-MobileSeg-Tiny | StrideFormer-Tiny | 80000 | 32 | 512x512 | 36.39% | 215.3 | 1.61 | [config](https://github.com/Yang-Changhui/mmsegmentation/tree/add_ppmobileseg/projects/pp_mobileseg/configs/pp_mobileseg) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/pp_mobileseg/pp_mobileseg_mobilenetv3_2x16_80k_ade20k_512x512_tiny-e4b35e96.pth)\|[log](https://bj.bcebos.com/paddleseg/dygraph/ade20k/pp_mobileseg_tiny/train.log) | + +## Citation + +If you find our project useful in your research, please consider citing: + +``` +@misc{liu2021paddleseg, + title={PaddleSeg: A High-Efficient Development Toolkit for Image Segmentation}, + author={Yi Liu and Lutao Chu and Guowei Chen and Zewu Wu and Zeyu Chen and Baohua Lai and Yuying Hao}, + year={2021}, + eprint={2101.06175}, + archivePrefix={arXiv}, + primaryClass={cs.CV} +} + +@misc{paddleseg2019, + title={PaddleSeg, End-to-end image segmentation kit based on PaddlePaddle}, + author={PaddlePaddle Contributors}, + howpublished = {\url{https://github.com/PaddlePaddle/PaddleSeg}}, + year={2019} +} +``` diff --git a/projects/pp_mobileseg/backbones/__init__.py b/projects/pp_mobileseg/backbones/__init__.py new file mode 100644 index 0000000000..244b33d37a --- /dev/null +++ b/projects/pp_mobileseg/backbones/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .strideformer import StrideFormer + +__all__ = ['StrideFormer'] diff --git a/projects/pp_mobileseg/backbones/strideformer.py b/projects/pp_mobileseg/backbones/strideformer.py new file mode 100644 index 0000000000..3f09be5225 --- /dev/null +++ b/projects/pp_mobileseg/backbones/strideformer.py @@ -0,0 +1,958 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +import warnings + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule, build_activation_layer +from mmcv.cnn.bricks.transformer import build_dropout +from mmengine.logging import print_log +from mmengine.model import BaseModule +from mmengine.runner.checkpoint import CheckpointLoader, load_state_dict + +from mmseg.registry import MODELS + + +@MODELS.register_module() +class StrideFormer(BaseModule): + """The StrideFormer implementation based on torch. + + The original article refers to:https://arxiv.org/abs/2304.05152 + Args: + mobileV3_cfg(list): Each sublist describe the config for a + MobileNetV3 block. + channels(list): The input channels for each MobileNetV3 block. + embed_dims(list): The channels of the features input to the sea + attention block. + key_dims(list, optional): The embeding dims for each head in + attention. + depths(list, optional): describes the depth of the attention block. + i,e: M,N. + num_heads(int, optional): The number of heads of the attention + blocks. + attn_ratios(int, optional): The expand ratio of V. + mlp_ratios(list, optional): The ratio of mlp blocks. + drop_path_rate(float, optional): The drop path rate in attention + block. + act_cfg(dict, optional): The activation layer of AAM: + Aggregate Attention Module. + inj_type(string, optional): The type of injection/AAM. + out_channels(int, optional): The output channels of the AAM. + dims(list, optional): The dimension of the fusion block. + out_feat_chs(list, optional): The input channels of the AAM. + stride_attention(bool, optional): whether to stride attention in + each attention layer. + pretrained(str, optional): the path of pretrained model. + """ + + def __init__( + self, + mobileV3_cfg, + channels, + embed_dims, + key_dims=[16, 24], + depths=[2, 2], + num_heads=8, + attn_ratios=2, + mlp_ratios=[2, 4], + drop_path_rate=0.1, + act_cfg=dict(type='ReLU'), + inj_type='AAM', + out_channels=256, + dims=(128, 160), + out_feat_chs=None, + stride_attention=True, + pretrained=None, + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) + assert not (init_cfg and pretrained + ), 'init_cfg and pretrained cannot be set at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is not None: + raise TypeError('pretrained must be a str or None') + + self.depths = depths + self.cfgs = mobileV3_cfg + self.dims = dims + for i in range(len(self.cfgs)): + smb = StackedMV3Block( + cfgs=self.cfgs[i], + stem=True if i == 0 else False, + in_channels=channels[i], + ) + setattr(self, f'smb{i + 1}', smb) + for i in range(len(depths)): + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depths[i]) + ] + trans = BasicLayer( + block_num=depths[i], + embedding_dim=embed_dims[i], + key_dim=key_dims[i], + num_heads=num_heads, + mlp_ratio=mlp_ratios[i], + attn_ratio=attn_ratios, + drop=0, + attn_drop=0.0, + drop_path=dpr, + act_cfg=act_cfg, + stride_attention=stride_attention, + ) + setattr(self, f'trans{i + 1}', trans) + + self.inj_type = inj_type + if self.inj_type == 'AAM': + self.inj_module = InjectionMultiSumallmultiallsum( + in_channels=out_feat_chs, out_channels=out_channels) + self.feat_channels = [ + out_channels, + ] + elif self.inj_type == 'AAMSx8': + self.inj_module = InjectionMultiSumallmultiallsumSimpx8( + in_channels=out_feat_chs, out_channels=out_channels) + self.feat_channels = [ + out_channels, + ] + elif self.inj_type == 'origin': + for i in range(len(dims)): + fuse = FusionBlock( + out_feat_chs[0] if i == 0 else dims[i - 1], + out_feat_chs[i + 1], + embed_dim=dims[i], + act_cfg=None, + ) + setattr(self, f'fuse{i + 1}', fuse) + self.feat_channels = [ + dims[i], + ] + else: + raise NotImplementedError(self.inj_module + ' is not implemented') + + self.pretrained = pretrained + # self.init_weights() + + def init_weights(self): + if (isinstance(self.init_cfg, dict) + and self.init_cfg.get('type') == 'Pretrained'): + checkpoint = CheckpointLoader.load_checkpoint( + self.init_cfg['checkpoint'], logger=None, map_location='cpu') + + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + + if 'pos_embed' in state_dict.keys(): + if self.pos_embed.shape != state_dict['pos_embed'].shape: + print_log(msg=f'Resize the pos_embed shape from ' + f'{state_dict["pos_embed"].shape} to ' + f'{self.pos_embed.shape}') + h, w = self.img_size + pos_size = int( + math.sqrt(state_dict['pos_embed'].shape[1] - 1)) + state_dict['pos_embed'] = self.resize_pos_embed( + state_dict['pos_embed'], + (h // self.patch_size, w // self.patch_size), + (pos_size, pos_size), + self.interpolate_mode, + ) + + load_state_dict(self, state_dict, strict=False, logger=None) + + def forward(self, x): + x_hw = x.shape[2:] + outputs = [] + num_smb_stage = len(self.cfgs) + num_trans_stage = len(self.depths) + + for i in range(num_smb_stage): + smb = getattr(self, f'smb{i + 1}') + x = smb(x) + + # 1/8 shared feat + if i == 1: + outputs.append(x) + if num_trans_stage + i >= num_smb_stage: + trans = getattr( + self, f'trans{i + num_trans_stage - num_smb_stage + 1}') + x = trans(x) + outputs.append(x) + if self.inj_type == 'origin': + x_detail = outputs[0] + for i in range(len(self.dims)): + fuse = getattr(self, f'fuse{i + 1}') + + x_detail = fuse(x_detail, outputs[i + 1]) + output = x_detail + else: + output = self.inj_module(outputs) + + return [output, x_hw] + + +class StackedMV3Block(nn.Module): + """The MobileNetV3 block. + + Args: + cfgs (list): The MobileNetV3 config list of a stage. + stem (bool): Whether is the first stage or not. + in_channels (int, optional): The channels of input image. Default: 3. + scale: float=1.0. + The coefficient that controls the size of network parameters. + + Returns: + model: nn.Module. + A stage of specific MobileNetV3 model depends on args. + """ + + def __init__(self, + cfgs, + stem, + in_channels, + scale=1.0, + norm_cfg=dict(type='BN')): + super().__init__() + + self.scale = scale + self.stem = stem + + if self.stem: + self.conv = ConvModule( + in_channels=3, + out_channels=_make_divisible(in_channels * self.scale), + kernel_size=3, + stride=2, + padding=1, + groups=1, + bias=False, + norm_cfg=norm_cfg, + act_cfg=dict(type='HSwish'), + ) + + self.blocks = nn.ModuleList() + for i, (k, exp, c, se, act, s) in enumerate(cfgs): + self.blocks.append( + ResidualUnit( + in_channel=_make_divisible(in_channels * self.scale), + mid_channel=_make_divisible(self.scale * exp), + out_channel=_make_divisible(self.scale * c), + kernel_size=k, + stride=s, + use_se=se, + act=act, + dilation=1, + )) + in_channels = _make_divisible(self.scale * c) + + def forward(self, x): + if self.stem: + x = self.conv(x) + for i, block in enumerate(self.blocks): + x = block(x) + + return x + + +class ResidualUnit(nn.Module): + """The Residual module. + + Args: + in_channel (int, optional): The channels of input feature. + mid_channel (int, optional): The channels of middle process. + out_channel (int, optional): The channels of output feature. + kernel_size (int, optional): The size of the convolving kernel. + stride (int, optional): The stride size. + use_se (bool, optional): if to use the SEModule. + act (string, optional): activation layer. + dilation (int, optional): The dilation size. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN', requires_grad=True). + """ + + def __init__( + self, + in_channel, + mid_channel, + out_channel, + kernel_size, + stride, + use_se, + act=None, + dilation=1, + norm_cfg=dict(type='BN'), + ): + super().__init__() + self.if_shortcut = stride == 1 and in_channel == out_channel + self.if_se = use_se + self.expand_conv = ConvModule( + in_channels=in_channel, + out_channels=mid_channel, + kernel_size=1, + bias=False, + norm_cfg=norm_cfg, + act_cfg=dict(type=act) if act is not None else None, + ) + self.bottleneck_conv = ConvModule( + in_channels=mid_channel, + out_channels=mid_channel, + kernel_size=kernel_size, + stride=stride, + padding=int((kernel_size - 1) // 2) * dilation, + bias=False, + groups=mid_channel, + dilation=dilation, + norm_cfg=norm_cfg, + act_cfg=dict(type=act) if act is not None else None, + ) + if self.if_se: + self.mid_se = SEModule(mid_channel) + self.linear_conv = ConvModule( + in_channels=mid_channel, + out_channels=out_channel, + kernel_size=1, + bias=False, + norm_cfg=norm_cfg, + act_cfg=None, + ) + + def forward(self, x): + identity = x + x = self.expand_conv(x) + x = self.bottleneck_conv(x) + if self.if_se: + x = self.mid_se(x) + x = self.linear_conv(x) + if self.if_shortcut: + x = torch.add(identity, x) + return x + + +class SEModule(nn.Module): + """SE Module. + + Args: + channel (int, optional): The channels of input feature. + reduction (int, optional): The channel reduction rate. + act_cfg (dict): Config dict for activation layer. + Default: dict(type='PReLU'). + """ + + def __init__(self, channel, reduction=4, act_cfg=dict(type='ReLU')): + super().__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.conv_act1 = ConvModule( + in_channels=channel, + out_channels=channel // reduction, + kernel_size=1, + norm_cfg=None, + act_cfg=act_cfg, + ) + + self.conv_act2 = ConvModule( + in_channels=channel // reduction, + out_channels=channel, + kernel_size=1, + norm_cfg=None, + act_cfg=dict(type='Hardsigmoid', slope=0.2, offset=0.5), + ) + + def forward(self, x): + identity = x + x = self.avg_pool(x) + x = self.conv_act1(x) + x = self.conv_act2(x) + return torch.mul(identity, x) + + +class BasicLayer(nn.Module): + """The transformer basic layer. + + Args: + block_num (int): the block nums of the transformer basic layer. + embedding_dim (int): The feature dimension. + key_dim (int): the key dim. + num_heads (int): Parallel attention heads. + mlp_ratio (float): the mlp ratio. + attn_ratio (float): the attention ratio. + drop (float): Probability of an element to be zeroed + after the feed forward layer.Default: 0.0. + attn_drop (float): The drop out rate for attention layer. + Default: 0.0. + drop_path (float): stochastic depth rate. Default 0.0. + act_cfg (dict): Config dict for activation layer. + Default: dict(type='PReLU'). + stride_attention (bool, optional): whether to stride attention in + each attention layer. + """ + + def __init__( + self, + block_num, + embedding_dim, + key_dim, + num_heads, + mlp_ratio=4.0, + attn_ratio=2.0, + drop=0.0, + attn_drop=0.0, + drop_path=None, + act_cfg=None, + stride_attention=None, + ): + super().__init__() + self.block_num = block_num + + self.transformer_blocks = nn.ModuleList() + for i in range(self.block_num): + self.transformer_blocks.append( + Block( + embedding_dim, + key_dim=key_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + attn_ratio=attn_ratio, + drop=drop, + drop_path=drop_path[i] + if isinstance(drop_path, list) else drop_path, + act_cfg=act_cfg, + stride_attention=stride_attention, + )) + + def forward(self, x): + for i in range(self.block_num): + x = self.transformer_blocks[i](x) + return x + + +class Block(nn.Module): + """the block of the transformer basic layer. + + Args: + dim (int): The feature dimension. + key_dim (int): The key dimension. + num_heads (int): Parallel attention heads. + mlp_ratio (float): the mlp ratio. + attn_ratio (float): the attention ratio. + drop (float): Probability of an element to be zeroed + after the feed forward layer.Default: 0.0. + drop_path (float): stochastic depth rate. Default 0.0. + act_cfg (dict): Config dict for activation layer. + Default: dict(type='PReLU'). + stride_attention (bool, optional): whether to stride attention in + each attention layer. + """ + + def __init__( + self, + dim, + key_dim, + num_heads, + mlp_ratio=4.0, + attn_ratio=2.0, + drop=0.0, + drop_path=0.0, + act_cfg=None, + stride_attention=None, + ): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.mlp_ratio = mlp_ratio + self.attn = SeaAttention( + dim, + key_dim=key_dim, + num_heads=num_heads, + attn_ratio=attn_ratio, + act_cfg=act_cfg, + stride_attention=stride_attention, + ) + self.drop_path = ( + build_dropout(dict(type='DropPath', drop_prob=drop_path)) + if drop_path > 0.0 else nn.Identity()) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = MLP( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_cfg=act_cfg, + drop=drop, + ) + + def forward(self, x1): + x1 = x1 + self.drop_path(self.attn(x1)) + x1 = x1 + self.drop_path(self.mlp(x1)) + + return x1 + + +class SqueezeAxialPositionalEmbedding(nn.Module): + """the Squeeze Axial Positional Embedding. + + Args: + dim (int): The feature dimension. + shape (int): The patch size. + """ + + def __init__(self, dim, shape): + super().__init__() + self.pos_embed = nn.init.normal_( + nn.Parameter(torch.zeros(1, dim, shape))) + + def forward(self, x): + B, C, N = x.shape + x = x + F.interpolate( + self.pos_embed, size=(N, ), mode='linear', align_corners=False) + return x + + +class SeaAttention(nn.Module): + """The sea attention. + + Args: + dim (int): The feature dimension. + key_dim (int): The key dimension. + num_heads (int): number of attention heads. + attn_ratio (float): the attention ratio. + act_cfg (dict): Config dict for activation layer. + Default: dict(type='PReLU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN') + stride_attention (bool, optional): whether to stride attention in + each attention layer. + """ + + def __init__( + self, + dim, + key_dim, + num_heads, + attn_ratio=4.0, + act_cfg=None, + norm_cfg=dict(type='BN'), + stride_attention=False, + ): + + super().__init__() + self.num_heads = num_heads + self.scale = key_dim**-0.5 + self.nh_kd = nh_kd = key_dim * num_heads + self.d = int(attn_ratio * key_dim) + self.dh = int(attn_ratio * key_dim) * num_heads + self.attn_ratio = attn_ratio + + self.to_q = ConvModule( + dim, nh_kd, 1, bias=False, norm_cfg=norm_cfg, act_cfg=None) + self.to_k = ConvModule( + dim, nh_kd, 1, bias=False, norm_cfg=norm_cfg, act_cfg=None) + + self.to_v = ConvModule( + dim, self.dh, 1, bias=False, norm_cfg=norm_cfg, act_cfg=None) + self.stride_attention = stride_attention + if self.stride_attention: + self.stride_conv = ConvModule( + dim, + dim, + kernel_size=3, + stride=2, + padding=1, + bias=True, + groups=dim, + norm_cfg=norm_cfg, + act_cfg=None, + ) + + self.proj = ConvModule( + self.dh, + dim, + 1, + bias=False, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + order=('act', 'conv', 'norm'), + ) + self.proj_encode_row = ConvModule( + self.dh, + self.dh, + 1, + bias=False, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + order=('act', 'conv', 'norm'), + ) + self.pos_emb_rowq = SqueezeAxialPositionalEmbedding(nh_kd, 16) + self.pos_emb_rowk = SqueezeAxialPositionalEmbedding(nh_kd, 16) + self.proj_encode_column = ConvModule( + self.dh, + self.dh, + 1, + bias=False, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + order=('act', 'conv', 'norm'), + ) + self.pos_emb_columnq = SqueezeAxialPositionalEmbedding(nh_kd, 16) + self.pos_emb_columnk = SqueezeAxialPositionalEmbedding(nh_kd, 16) + self.dwconv = ConvModule( + 2 * self.dh, + 2 * self.dh, + 3, + padding=1, + groups=2 * self.dh, + bias=False, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + ) + self.pwconv = ConvModule( + 2 * self.dh, dim, 1, bias=False, norm_cfg=norm_cfg, act_cfg=None) + self.sigmoid = build_activation_layer(dict(type='HSigmoid')) + + def forward(self, x): + B, C, H_ori, W_ori = x.shape + if self.stride_attention: + x = self.stride_conv(x) + B, C, H, W = x.shape + + q = self.to_q(x) # [B, nhead*dim, H, W] + k = self.to_k(x) + v = self.to_v(x) + + qkv = torch.cat([q, k, v], dim=1) + qkv = self.dwconv(qkv) + qkv = self.pwconv(qkv) + + qrow = (self.pos_emb_rowq(q.mean(-1)).reshape( + [B, self.num_heads, -1, H]).permute( + (0, 1, 3, 2))) # [B, nhead, H, dim] + krow = self.pos_emb_rowk(k.mean(-1)).reshape( + [B, self.num_heads, -1, H]) # [B, nhead, dim, H] + vrow = (v.mean(-1).reshape([B, self.num_heads, -1, + H]).permute([0, 1, 3, 2]) + ) # [B, nhead, H, dim*attn_ratio] + + attn_row = torch.matmul(qrow, krow) * self.scale # [B, nhead, H, H] + attn_row = nn.functional.softmax(attn_row, dim=-1) + + xx_row = torch.matmul(attn_row, vrow) # [B, nhead, H, dim*attn_ratio] + xx_row = self.proj_encode_row( + xx_row.permute([0, 1, 3, 2]).reshape([B, self.dh, H, 1])) + + # squeeze column + qcolumn = ( + self.pos_emb_columnq(q.mean(-2)).reshape( + [B, self.num_heads, -1, W]).permute([0, 1, 3, 2])) + kcolumn = self.pos_emb_columnk(k.mean(-2)).reshape( + [B, self.num_heads, -1, W]) + vcolumn = ( + torch.mean(v, -2).reshape([B, self.num_heads, -1, + W]).permute([0, 1, 3, 2])) + + attn_column = torch.matmul(qcolumn, kcolumn) * self.scale + attn_column = nn.functional.softmax(attn_column, dim=-1) + + xx_column = torch.matmul(attn_column, vcolumn) # B nH W C + xx_column = self.proj_encode_column( + xx_column.permute([0, 1, 3, 2]).reshape([B, self.dh, 1, W])) + + xx = torch.add(xx_row, xx_column) # [B, self.dh, H, W] + xx = torch.add(v, xx) + + xx = self.proj(xx) + xx = self.sigmoid(xx) * qkv + if self.stride_attention: + xx = F.interpolate(xx, size=(H_ori, W_ori), mode='bilinear') + + return xx + + +class MLP(nn.Module): + """the Multilayer Perceptron. + + Args: + in_features (int): the input feature. + hidden_features (int): the hidden feature. + out_features (int): the output feature. + act_cfg (dict): Config dict for activation layer. + Default: dict(type='PReLU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN') + drop (float): Probability of an element to be zeroed. + Default 0.0 + """ + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_cfg=None, + norm_cfg=dict(type='BN'), + drop=0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = ConvModule( + in_features, + hidden_features, + kernel_size=1, + bias=False, + norm_cfg=norm_cfg, + act_cfg=None, + ) + self.dwconv = ConvModule( + hidden_features, + hidden_features, + kernel_size=3, + padding=1, + groups=hidden_features, + norm_cfg=None, + act_cfg=act_cfg, + ) + + self.fc2 = ConvModule( + hidden_features, + out_features, + 1, + bias=False, + norm_cfg=norm_cfg, + act_cfg=None, + ) + self.drop = build_dropout(dict(type='Dropout', drop_prob=drop)) + + def forward(self, x): + x = self.fc1(x) + x = self.dwconv(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class FusionBlock(nn.Module): + """The feature fusion block. + + Args: + in_channel (int): the input channel. + out_channel (int): the output channel. + embed_dim (int): embedding dimension. + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN') + """ + + def __init__( + self, + in_channel, + out_channel, + embed_dim, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + ) -> None: + super().__init__() + self.local_embedding = ConvModule( + in_channels=in_channel, + out_channels=embed_dim, + kernel_size=1, + bias=False, + norm_cfg=norm_cfg, + act_cfg=None, + ) + + self.global_act = ConvModule( + in_channels=out_channel, + out_channels=embed_dim, + kernel_size=1, + bias=False, + norm_cfg=norm_cfg, + act_cfg=act_cfg if act_cfg is not None else None, + ) + + def forward(self, x_l, x_g): + """ + x_g: global features + x_l: local features + """ + B, C, H, W = x_l.shape + + local_feat = self.local_embedding(x_l) + global_act = self.global_act(x_g) + sig_act = F.interpolate( + global_act, size=(H, W), mode='bilinear', align_corners=False) + + out = local_feat * sig_act + + return out + + +class InjectionMultiSumallmultiallsum(nn.Module): + """the Aggregate Attention Module. + + Args: + in_channels (tuple): the input channel. + out_channels (int): the output channel. + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN') + """ + + def __init__( + self, + in_channels=(64, 128, 256, 384), + out_channels=256, + act_cfg=dict(type='Sigmoid'), + norm_cfg=dict(type='BN'), + ): + super().__init__() + self.embedding_list = nn.ModuleList() + self.act_embedding_list = nn.ModuleList() + self.act_list = nn.ModuleList() + for i in range(len(in_channels)): + self.embedding_list.append( + ConvModule( + in_channels=in_channels[i], + out_channels=out_channels, + kernel_size=1, + bias=False, + norm_cfg=norm_cfg, + act_cfg=None, + )) + self.act_embedding_list.append( + ConvModule( + in_channels=in_channels[i], + out_channels=out_channels, + kernel_size=1, + bias=False, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + )) + + def forward(self, inputs): # x_x8, x_x16, x_x32, x_x64 + low_feat1 = F.interpolate(inputs[0], scale_factor=0.5, mode='bilinear') + low_feat1_act = self.act_embedding_list[0](low_feat1) + low_feat1 = self.embedding_list[0](low_feat1) + + low_feat2 = F.interpolate( + inputs[1], size=low_feat1.shape[-2:], mode='bilinear') + low_feat2_act = self.act_embedding_list[1](low_feat2) # x16 + low_feat2 = self.embedding_list[1](low_feat2) + + high_feat_act = F.interpolate( + self.act_embedding_list[2](inputs[2]), + size=low_feat2.shape[2:], + mode='bilinear', + ) + high_feat = F.interpolate( + self.embedding_list[2](inputs[2]), + size=low_feat2.shape[2:], + mode='bilinear') + + res = ( + low_feat1_act * low_feat2_act * high_feat_act * + (low_feat1 + low_feat2) + high_feat) + + return res + + +class InjectionMultiSumallmultiallsumSimpx8(nn.Module): + """the Aggregate Attention Module. + + Args: + in_channels (tuple): the input channel. + out_channels (int): the output channel. + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN') + """ + + def __init__( + self, + in_channels=(64, 128, 256, 384), + out_channels=256, + act_cfg=dict(type='Sigmoid'), + norm_cfg=dict(type='BN'), + ): + super().__init__() + self.embedding_list = nn.ModuleList() + self.act_embedding_list = nn.ModuleList() + self.act_list = nn.ModuleList() + for i in range(len(in_channels)): + if i != 1: + self.embedding_list.append( + ConvModule( + in_channels=in_channels[i], + out_channels=out_channels, + kernel_size=1, + bias=False, + norm_cfg=norm_cfg, + act_cfg=None, + )) + if i != 0: + self.act_embedding_list.append( + ConvModule( + in_channels=in_channels[i], + out_channels=out_channels, + kernel_size=1, + bias=False, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + )) + + def forward(self, inputs): + # x_x8, x_x16, x_x32 + low_feat1 = self.embedding_list[0](inputs[0]) + + low_feat2 = F.interpolate( + inputs[1], size=low_feat1.shape[-2:], mode='bilinear') + low_feat2_act = self.act_embedding_list[0](low_feat2) + + high_feat_act = F.interpolate( + self.act_embedding_list[1](inputs[2]), + size=low_feat2.shape[2:], + mode='bilinear', + ) + high_feat = F.interpolate( + self.embedding_list[1](inputs[2]), + size=low_feat2.shape[2:], + mode='bilinear') + + res = low_feat2_act * high_feat_act * low_feat1 + high_feat + + return res + + +def _make_divisible(v, divisor=8, min_value=None): + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +@MODELS.register_module() +class Hardsigmoid(nn.Module): + """the hardsigmoid activation. + + Args: + slope (float, optional): The slope of hardsigmoid function. + Default is 0.1666667. + offset (float, optional): The offset of hardsigmoid function. + Default is 0.5. + inplace (bool): can optionally do the operation in-place. + Default: ``False`` + """ + + def __init__(self, slope=0.1666667, offset=0.5, inplace=False): + super().__init__() + self.slope = slope + self.offset = offset + + def forward(self, x): + return (x * self.slope + self.offset).clamp(0, 1) diff --git a/projects/pp_mobileseg/configs/_base_/datasets/ade20k.py b/projects/pp_mobileseg/configs/_base_/datasets/ade20k.py new file mode 100644 index 0000000000..48340d11ee --- /dev/null +++ b/projects/pp_mobileseg/configs/_base_/datasets/ade20k.py @@ -0,0 +1,68 @@ +# dataset settings +dataset_type = 'ADE20KDataset' +data_root = 'data/ade/ADEChallengeData2016' +crop_size = (512, 512) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', reduce_zero_label=True), + dict( + type='RandomResize', + scale=(2048, 512), + ratio_range=(0.5, 2.0), + keep_ratio=True), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='PackSegInputs') +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='Resize', scale=(2048, 512), keep_ratio=True), + # add loading annotation after ``Resize`` because ground truth + # does not need to do resize data transform + dict(type='LoadAnnotations', reduce_zero_label=True), + dict(type='PackSegInputs') +] +img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] +tta_pipeline = [ + dict(type='LoadImageFromFile', backend_args=None), + dict( + type='TestTimeAug', + transforms=[ + [ + dict(type='Resize', scale_factor=r, keep_ratio=True) + for r in img_ratios + ], + [ + dict(type='RandomFlip', prob=0., direction='horizontal'), + dict(type='RandomFlip', prob=1., direction='horizontal') + ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')] + ]) +] +train_dataloader = dict( + batch_size=4, + num_workers=4, + persistent_workers=True, + sampler=dict(type='InfiniteSampler', shuffle=True), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict( + img_path='images/training', seg_map_path='annotations/training'), + pipeline=train_pipeline)) +val_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict( + img_path='images/validation', + seg_map_path='annotations/validation'), + pipeline=test_pipeline)) +test_dataloader = val_dataloader + +val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU']) +test_evaluator = val_evaluator diff --git a/projects/pp_mobileseg/configs/_base_/default_runtime.py b/projects/pp_mobileseg/configs/_base_/default_runtime.py new file mode 100644 index 0000000000..272b4d2467 --- /dev/null +++ b/projects/pp_mobileseg/configs/_base_/default_runtime.py @@ -0,0 +1,15 @@ +default_scope = 'mmseg' +env_cfg = dict( + cudnn_benchmark=True, + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + dist_cfg=dict(backend='nccl'), +) +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict( + type='SegLocalVisualizer', vis_backends=vis_backends, name='visualizer') +log_processor = dict(by_epoch=False) +log_level = 'INFO' +load_from = None +resume = False + +tta_model = dict(type='SegTTAModel') diff --git a/projects/pp_mobileseg/configs/_base_/models/pp_mobile.py b/projects/pp_mobileseg/configs/_base_/models/pp_mobile.py new file mode 100644 index 0000000000..0c7695636f --- /dev/null +++ b/projects/pp_mobileseg/configs/_base_/models/pp_mobile.py @@ -0,0 +1,47 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +data_preprocessor = dict( + type='SegDataPreProcessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_val=0, + seg_pad_val=255) + +model = dict( + type='EncoderDecoder', + data_preprocessor=data_preprocessor, + # pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='StrideFormer', + mobileV3_cfg=[ + # k t c, s + [[3, 16, 16, True, 'ReLU', 1], [3, 64, 32, False, 'ReLU', 2], + [3, 96, 32, False, 'ReLU', 1]], # cfg1 + [[5, 128, 64, True, 'HSwish', 2], [5, 240, 64, True, 'HSwish', + 1]], # cfg2 + [[5, 384, 128, True, 'HSwish', 2], + [5, 384, 128, True, 'HSwish', 1]], # cfg3 + [[5, 768, 192, True, 'HSwish', 2], + [5, 768, 192, True, 'HSwish', 1]], # cfg4 + ], + channels=[16, 32, 64, 128, 192], + depths=[3, 3], + embed_dims=[128, 192], + num_heads=8, + inj_type='AAMSx8', + out_feat_chs=[64, 128, 192], + act_cfg=dict(type='ReLU6'), + ), + decode_head=dict( + type='PPMobileSegHead', + num_classes=150, + in_channels=256, + dropout_ratio=0.1, + use_dw=True, + act_cfg=dict(type='ReLU'), + align_corners=False), + + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/projects/pp_mobileseg/configs/_base_/schedules/schedule_80k.py b/projects/pp_mobileseg/configs/_base_/schedules/schedule_80k.py new file mode 100644 index 0000000000..0dcd6c4d1b --- /dev/null +++ b/projects/pp_mobileseg/configs/_base_/schedules/schedule_80k.py @@ -0,0 +1,24 @@ +# optimizer +optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) +optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer, clip_grad=None) +# learning policy +param_scheduler = [ + dict( + type='PolyLR', + eta_min=1e-4, + power=0.9, + begin=0, + end=80000, + by_epoch=False) +] +# training schedule for 80k +train_cfg = dict(type='IterBasedTrainLoop', max_iters=80000, val_interval=8000) +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') +default_hooks = dict( + timer=dict(type='IterTimerHook'), + logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False), + param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=8000), + sampler_seed=dict(type='DistSamplerSeedHook'), + visualization=dict(type='SegVisualizationHook')) diff --git a/projects/pp_mobileseg/configs/pp_mobileseg/pp_mobileseg_mobilenetv3_2x16_80k_ade20k_512x512_base.py b/projects/pp_mobileseg/configs/pp_mobileseg/pp_mobileseg_mobilenetv3_2x16_80k_ade20k_512x512_base.py new file mode 100644 index 0000000000..d43b007466 --- /dev/null +++ b/projects/pp_mobileseg/configs/pp_mobileseg/pp_mobileseg_mobilenetv3_2x16_80k_ade20k_512x512_base.py @@ -0,0 +1,13 @@ +_base_ = [ + '../_base_/models/pp_mobile.py', '../_base_/datasets/ade20k.py', + '../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py' +] +checkpoint = './models/pp_mobile_base.pth' +crop_size = (512, 512) +data_preprocessor = dict(size=crop_size, test_cfg=dict(size_divisor=32)) +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + data_preprocessor=data_preprocessor, + backbone=dict(init_cfg=dict(type='Pretrained', checkpoint=checkpoint), ), + decode_head=dict(num_classes=150, upsample='intepolate'), +) diff --git a/projects/pp_mobileseg/configs/pp_mobileseg/pp_mobileseg_mobilenetv3_2x16_80k_ade20k_512x512_tiny.py b/projects/pp_mobileseg/configs/pp_mobileseg/pp_mobileseg_mobilenetv3_2x16_80k_ade20k_512x512_tiny.py new file mode 100644 index 0000000000..20b1189cd4 --- /dev/null +++ b/projects/pp_mobileseg/configs/pp_mobileseg/pp_mobileseg_mobilenetv3_2x16_80k_ade20k_512x512_tiny.py @@ -0,0 +1,39 @@ +_base_ = [ + '../_base_/models/pp_mobile.py', '../_base_/datasets/ade20k.py', + '../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py' +] +checkpoint = './models/pp_mobile_tiny.pth' +crop_size = (512, 512) +data_preprocessor = dict(size=crop_size, test_cfg=dict(size_divisor=32)) +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + data_preprocessor=data_preprocessor, + backbone=dict( + init_cfg=dict(type='Pretrained', checkpoint=checkpoint), + type='StrideFormer', + mobileV3_cfg=[ + # k t c, s + [[3, 16, 16, True, 'ReLU', 1], [3, 64, 32, False, 'ReLU', 2], + [3, 48, 24, False, 'ReLU', 1]], # cfg1 + [[5, 96, 32, True, 'HSwish', 2], [5, 96, 32, True, 'HSwish', + 1]], # cfg2 + [[5, 160, 64, True, 'HSwish', 2], [5, 160, 64, True, 'HSwish', + 1]], # cfg3 + [[3, 384, 128, True, 'HSwish', 2], + [3, 384, 128, True, 'HSwish', 1]], # cfg4 + ], + channels=[16, 24, 32, 64, 128], + depths=[2, 2], + embed_dims=[64, 128], + num_heads=4, + inj_type='AAM', + out_feat_chs=[32, 64, 128], + act_cfg=dict(type='ReLU6'), + ), + decode_head=dict( + num_classes=150, + in_channels=256, + use_dw=True, + act_cfg=dict(type='ReLU'), + upsample='intepolate'), +) diff --git a/projects/pp_mobileseg/decode_head/__init__.py b/projects/pp_mobileseg/decode_head/__init__.py new file mode 100644 index 0000000000..6f71b784e1 --- /dev/null +++ b/projects/pp_mobileseg/decode_head/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .pp_mobileseg_head import PPMobileSegHead + +__all__ = [ + 'PPMobileSegHead', +] diff --git a/projects/pp_mobileseg/decode_head/pp_mobileseg_head.py b/projects/pp_mobileseg/decode_head/pp_mobileseg_head.py new file mode 100644 index 0000000000..243f026372 --- /dev/null +++ b/projects/pp_mobileseg/decode_head/pp_mobileseg_head.py @@ -0,0 +1,94 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule, build_conv_layer +from torch import Tensor + +from mmseg.registry import MODELS + + +@MODELS.register_module() +class PPMobileSegHead(nn.Module): + """the segmentation head. + + Args: + num_classes (int): the classes num. + in_channels (int): the input channels. + use_dw (bool): if to use deepwith convolution. + dropout_ratio (float): Probability of an element to be zeroed. + Default 0.0。 + align_corners (bool, optional): Geometrically, we consider the pixels + of the input and output as squares rather than points. + upsample (str): the upsample method. + out_channels (int): the output channel. + conv_cfg (dict): Config dict for convolution layer. + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + """ + + def __init__(self, + num_classes, + in_channels, + use_dw=True, + dropout_ratio=0.1, + align_corners=False, + upsample='intepolate', + out_channels=None, + conv_cfg=dict(type='Conv'), + act_cfg=dict(type='ReLU'), + norm_cfg=dict(type='BN')): + + super().__init__() + self.align_corners = align_corners + self.last_channels = in_channels + self.upsample = upsample + self.num_classes = num_classes + self.out_channels = out_channels + self.linear_fuse = ConvModule( + in_channels=self.last_channels, + out_channels=self.last_channels, + kernel_size=1, + bias=False, + groups=self.last_channels if use_dw else 1, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.dropout = nn.Dropout2d(dropout_ratio) + self.conv_seg = build_conv_layer( + conv_cfg, self.last_channels, self.num_classes, kernel_size=1) + + def forward(self, x): + x, x_hw = x[0], x[1] + x = self.linear_fuse(x) + x = self.dropout(x) + x = self.conv_seg(x) + if self.upsample == 'intepolate' or self.training or \ + self.num_classes < 30: + x = F.interpolate( + x, x_hw, mode='bilinear', align_corners=self.align_corners) + elif self.upsample == 'vim': + labelset = torch.unique(torch.argmax(x, 1)) + x = torch.gather(x, 1, labelset) + x = F.interpolate( + x, x_hw, mode='bilinear', align_corners=self.align_corners) + + pred = torch.argmax(x, 1) + pred_retrieve = torch.zeros(pred.shape, dtype=torch.int32) + for i, val in enumerate(labelset): + pred_retrieve[pred == i] = labelset[i].cast('int32') + + x = pred_retrieve + else: + raise NotImplementedError(self.upsample, ' is not implemented') + + return [x] + + def predict(self, inputs, batch_img_metas: List[dict], test_cfg, + **kwargs) -> List[Tensor]: + """Forward function for testing, only ``pam_cam`` is used.""" + seg_logits = self.forward(inputs)[0] + return seg_logits