diff --git a/mmseg/models/decode_heads/__init__.py b/mmseg/models/decode_heads/__init__.py
index 4229763816..530ecf8916 100644
--- a/mmseg/models/decode_heads/__init__.py
+++ b/mmseg/models/decode_heads/__init__.py
@@ -4,7 +4,6 @@
 from .aspp_head import ASPPHead
 from .cc_head import CCHead
 from .da_head import DAHead
-from .ddr_head import DDRHead
 from .dm_head import DMHead
 from .dnl_head import DNLHead
 from .dpt_head import DPTHead
@@ -13,19 +12,14 @@
 from .fcn_head import FCNHead
 from .fpn_head import FPNHead
 from .gc_head import GCHead
-from .ham_head import LightHamHead
 from .isa_head import ISAHead
 from .knet_head import IterativeDecodeHead, KernelUpdateHead, KernelUpdator
 from .lraspp_head import LRASPPHead
-from .mask2former_head import Mask2FormerHead
-from .maskformer_head import MaskFormerHead
 from .nl_head import NLHead
 from .ocr_head import OCRHead
-from .pid_head import PIDHead
 from .point_head import PointHead
 from .psa_head import PSAHead
 from .psp_head import PSPHead
-from .san_head import SideAdapterCLIPHead
 from .segformer_head import SegformerHead
 from .segmenter_mask_head import SegmenterMaskTransformerHead
 from .sep_aspp_head import DepthwiseSeparableASPPHead
@@ -34,7 +28,7 @@
 from .setr_up_head import SETRUPHead
 from .stdc_head import STDCHead
 from .uper_head import UPerHead
-from .vpd_depth_head import VPDDepthHead
+from .atm_head import ATMHead
 
 __all__ = [
     'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead',
@@ -43,6 +37,5 @@
     'PointHead', 'APCHead', 'DMHead', 'LRASPPHead', 'SETRUPHead',
     'SETRMLAHead', 'DPTHead', 'SETRMLAHead', 'SegmenterMaskTransformerHead',
     'SegformerHead', 'ISAHead', 'STDCHead', 'IterativeDecodeHead',
-    'KernelUpdateHead', 'KernelUpdator', 'MaskFormerHead', 'Mask2FormerHead',
-    'LightHamHead', 'PIDHead', 'DDRHead', 'VPDDepthHead', 'SideAdapterCLIPHead'
+    'KernelUpdateHead', 'ATMHead', 'KernelUpdator'
 ]
diff --git a/mmseg/models/decode_heads/atm_head.py b/mmseg/models/decode_heads/atm_head.py
new file mode 100644
index 0000000000..e948e63368
--- /dev/null
+++ b/mmseg/models/decode_heads/atm_head.py
@@ -0,0 +1,318 @@
import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.runner import force_fp32
from timm.models.layers import trunc_normal_
from torch import Tensor
from torch.nn import TransformerDecoder, TransformerDecoderLayer

from mmseg.models.builder import HEADS
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.models.losses import accuracy


def trunc_normal_init(module: nn.Module,
                      mean: float = 0,
                      std: float = 1,
                      a: float = -2,
                      b: float = 2,
                      bias: float = 0) -> None:
    """Initialize weights with a truncated normal and biases with a constant."""
    if hasattr(module, 'weight') and module.weight is not None:
        trunc_normal_(module.weight, mean, std, a, b)  # type: ignore
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)  # type: ignore


def constant_init(module, val, bias=0):
    """Initialize weights and biases with constant values."""
    if hasattr(module, 'weight') and module.weight is not None:
        nn.init.constant_(module.weight, val)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)
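
# TPN_Decoder / TPN_DecoderLayer (below) mirror torch.nn.TransformerDecoder and
# TransformerDecoderLayer, except that each layer also returns its cross-attention
# map and the decoder passes the map of its last layer back to the caller; ATMHead
# reuses that map as per-class mask logits (the "attention-to-mask" idea the head
# is named after).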
class TPN_Decoder(TransformerDecoder):
    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        output = tgt
        for mod in self.layers:
            output, attn = mod(output, memory, tgt_mask=tgt_mask,
                               memory_mask=memory_mask,
                               tgt_key_padding_mask=tgt_key_padding_mask,
                               memory_key_padding_mask=memory_key_padding_mask)

        if self.norm is not None:
            output = self.norm(output)

        # Return the refined queries together with the cross-attention map
        # of the last decoder layer.
        return output, attn


class TPN_DecoderLayer(TransformerDecoderLayer):
    def __init__(self, **kwargs):
        super(TPN_DecoderLayer, self).__init__(**kwargs)
        # Replace the stock cross-attention with a variant that exposes its
        # (pre-softmax, head-averaged) attention map.
        del self.multihead_attn
        self.multihead_attn = Attention(
            kwargs['d_model'], num_heads=kwargs['nhead'], qkv_bias=True, attn_drop=0.1)

    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        tgt2 = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        tgt2, attn2 = self.multihead_attn(
            tgt.transpose(0, 1), memory.transpose(0, 1), memory.transpose(0, 1))
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt, attn2


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None,
                 attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version; can be set manually
        # to stay compatible with previous weights.
        self.scale = qk_scale or head_dim ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.k = nn.Linear(dim, dim, bias=qkv_bias)
        self.v = nn.Linear(dim, dim, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, xq, xk, xv):
        B, Nq, C = xq.size()
        Nk = xk.size()[1]
        Nv = xv.size()[1]

        q = self.q(xq).reshape(B, Nq, self.num_heads,
                               C // self.num_heads).permute(0, 2, 1, 3)
        k = self.k(xk).reshape(B, Nk, self.num_heads,
                               C // self.num_heads).permute(0, 2, 1, 3)
        v = self.v(xv).reshape(B, Nv, self.num_heads,
                               C // self.num_heads).permute(0, 2, 1, 3)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        # Keep the pre-softmax attention logits: they become the mask predictions.
        attn_save = attn.clone()
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, Nq, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        # Output is sequence-first to match nn.TransformerDecoderLayer; the
        # returned attention map is averaged over heads.
        return x.transpose(0, 1), attn_save.sum(dim=1) / self.num_heads


class MLP(nn.Module):
    """A very simple multi-layer perceptron (also called FFN)."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x
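
# Shape walk-through for ATMHead.forward below (labels read off the code, not
# from a spec): B = batch, C = embed_dims, L_i = tokens of stage i, N = num_classes.
#   lateral_i = norm_i(proj_i(x_i))            -> (B, L_i, C)
#   q, attn_i = decoder_i(q, lateral_i)        -> q: (N, B, C), attn_i: (B, N, L_i)
#   attn_i.transpose(-1, -2), then d3_to_d4    -> (B, N, h_i, w_i) mask logits
#   class_embed(q)                             -> (B, N, num_classes + 1)
#   pred_masks = running sum of per-stage maps, upsampled to (img_size, img_size).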
@HEADS.register_module()
class ATMHead(BaseDecodeHead):
    """Attention-to-mask (ATM) decode head.

    A set of learnable class queries is refined by per-stage transformer
    decoders over the backbone tokens; the cross-attention maps double as
    mask predictions, and a linear head classifies each query.
    """

    def __init__(
            self,
            img_size,
            in_channels,
            embed_dims=768,
            num_layers=3,
            num_heads=8,
            use_stages=3,
            use_proj=True,
            CE_loss=False,
            crop_train=False,
            shrink_ratio=None,
            **kwargs,
    ):
        super(ATMHead, self).__init__(
            in_channels=in_channels, **kwargs)

        self.image_size = img_size
        self.use_stages = use_stages
        self.crop_train = crop_train
        nhead = num_heads
        dim = embed_dims
        input_proj = []
        proj_norm = []
        atm_decoders = []
        for i in range(self.use_stages):
            # FC layer to change the channel dimension
            if use_proj:
                proj = nn.Linear(self.in_channels, dim)
                trunc_normal_(proj.weight, std=.02)
            else:
                proj = nn.Identity()
            self.add_module("input_proj_{}".format(i + 1), proj)
            input_proj.append(proj)
            # norm layer
            if use_proj:
                norm = nn.LayerNorm(dim)
            else:
                norm = nn.Identity()
            self.add_module("proj_norm_{}".format(i + 1), norm)
            proj_norm.append(norm)
            # decoder layer
            decoder_layer = TPN_DecoderLayer(
                d_model=dim, nhead=nhead, dim_feedforward=dim * 4)
            decoder = TPN_Decoder(decoder_layer, num_layers)
            self.add_module("decoder_{}".format(i + 1), decoder)
            atm_decoders.append(decoder)

        self.input_proj = input_proj
        self.proj_norm = proj_norm
        self.decoder = atm_decoders
        # One learnable query per class.
        self.q = nn.Embedding(self.num_classes, dim)

        # num_classes + 1: the extra logit is the "no object" class.
        self.class_embed = nn.Linear(dim, self.num_classes + 1)
        self.CE_loss = CE_loss
        # Masks come from the attention maps, so the default conv head is unused.
        delattr(self, 'conv_seg')

    def init_weights(self):
        for n, m in self.named_modules():
            if isinstance(m, nn.Linear):
                trunc_normal_init(m, std=.02, bias=0)
            elif isinstance(m, nn.LayerNorm):
                constant_init(m, val=1.0, bias=0.0)

    def forward(self, inputs):
        x = []
        for stage_ in inputs[:self.use_stages]:
            x.append(self.d4_to_d3(stage_) if stage_.dim() > 3 else stage_)
        x.reverse()  # process the deepest stage first
        bs = x[0].size()[0]

        attns = []
        maps_size = []
        qs = []
        # (num_classes, bs, dim) class queries, refined stage by stage.
        q = self.q.weight.repeat(bs, 1, 1).transpose(0, 1)

        for idx, (x_, proj_, norm_, decoder_) in enumerate(
                zip(x, self.input_proj, self.proj_norm, self.decoder)):
            lateral = norm_(proj_(x_))

            q, attn = decoder_(q, lateral.transpose(0, 1))
            attn = attn.transpose(-1, -2)
            if self.crop_train and self.training:
                # Scatter the attention of the kept (cropped) tokens back onto
                # the full token grid; inputs[-1] holds the kept token indices.
                blank_attn = torch.zeros_like(attn)
                blank_attn = blank_attn[:, 0].unsqueeze(1).repeat(
                    1, (self.image_size // 16) ** 2, 1)
                blank_attn[:, inputs[-1]] = attn
                attn = blank_attn
                self.crop_idx = inputs[-1]
            attn = self.d3_to_d4(attn)
            maps_size.append(attn.size()[-2:])
            qs.append(q.transpose(0, 1))
            attns.append(attn)
        qs = torch.stack(qs, dim=0)
        outputs_class = self.class_embed(qs)
        out = {"pred_logits": outputs_class[-1]}

        outputs_seg_masks = []
        size = maps_size[-1]

        for i_attn, attn in enumerate(attns):
            if i_attn == 0:
                outputs_seg_masks.append(
                    F.interpolate(attn, size=size, mode='bilinear',
                                  align_corners=False))
            else:
                # Each stage adds its map to the running sum of earlier masks.
                outputs_seg_masks.append(
                    outputs_seg_masks[i_attn - 1] +
                    F.interpolate(attn, size=size, mode='bilinear',
                                  align_corners=False))

        out["pred_masks"] = F.interpolate(outputs_seg_masks[-1],
                                          size=(self.image_size, self.image_size),
                                          mode='bilinear', align_corners=False)

        out["pred"] = self.semantic_inference(out["pred_logits"], out["pred_masks"])

        if not self.training:
            return out["pred"]

        # (num_stages, bs, num_classes, h, w)
        outputs_seg_masks = torch.stack(outputs_seg_masks, dim=0)
        out["aux_outputs"] = self._set_aux_loss(outputs_class, outputs_seg_masks)
        return out
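
    # During training the head returns the raw output dict so a DETR-style
    # set-prediction criterion (not part of this diff) can supervise every
    # stage: "pred_logits" and "pred_masks" for the final stage plus per-stage
    # "aux_outputs" from _set_aux_loss below. At eval time only the fused
    # "pred" map is returned.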
    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_seg_masks):
        # This is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionaries with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        return [
            {"pred_logits": a, "pred_masks": b}
            for a, b in zip(outputs_class[:-1], outputs_seg_masks[:-1])
        ]

    def semantic_inference(self, mask_cls, mask_pred):
        # Drop the "no object" logit, then combine per-query class scores
        # with per-query masks into a per-class segmentation map.
        mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
        mask_pred = mask_pred.sigmoid()
        semseg = torch.einsum("bqc,bqhw->bchw", mask_cls, mask_pred)
        return semseg

    def d3_to_d4(self, t):
        n, hw, c = t.size()
        if hw % 2 != 0:
            # An odd token count is taken to mean a leading class token: drop it.
            t = t[:, 1:]
        h = w = int(math.sqrt(hw))
        return t.transpose(1, 2).reshape(n, c, h, w)

    def d4_to_d3(self, t):
        return t.flatten(-2).transpose(-1, -2)

    @force_fp32(apply_to=('seg_logit', ))
    def losses(self, seg_logit, seg_label):
        """Compute segmentation loss."""
        if self.CE_loss:
            return super().losses(seg_logit["pred"], seg_label)

        if isinstance(seg_logit, dict):
            # ATM (set prediction) loss on the full output dict
            seg_label = seg_label.squeeze(1)
            if self.crop_train:
                # mask seg_label by crop_idx
                bs, h, w = seg_label.size()
                mask_label = seg_label.reshape(bs, h // 16, 16, w // 16, 16)\
                    .permute(0, 1, 3, 2, 4).reshape(bs, h * w // 256, 256)
                empty_label = torch.zeros_like(mask_label) + self.ignore_index
                empty_label[:, self.crop_idx] = mask_label[:, self.crop_idx]
                seg_label = empty_label.reshape(bs, h // 16, w // 16, 16, 16)\
                    .permute(0, 1, 3, 2, 4).reshape(bs, h, w)
            loss = self.loss_decode(
                seg_logit,
                seg_label,
                ignore_index=self.ignore_index)

            loss['acc_seg'] = accuracy(
                seg_logit["pred"], seg_label, ignore_index=self.ignore_index)
            return loss
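
A minimal decode_head config sketch for trying the new head (values are illustrative
assumptions, not taken from any released config; CE_loss=True is used here because
the set-prediction loss that consumes the head's full output dict is not included in
this diff):

    decode_head=dict(
        type='ATMHead',
        img_size=512,          # assumed crop size
        in_channels=768,       # assumed ViT-B token width
        channels=768,          # required by BaseDecodeHead
        embed_dims=768,
        num_layers=3,
        num_heads=8,
        use_stages=3,
        num_classes=150,       # e.g. ADE20K
        CE_loss=True,          # route losses() through the standard CE path
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))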