diff --git a/configs/_base_/models/san_vit-b16.py b/configs/_base_/models/san_vit-b16.py new file mode 100644 index 0000000000..96ac41b8da --- /dev/null +++ b/configs/_base_/models/san_vit-b16.py @@ -0,0 +1,137 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) + +data_preprocessor = dict( + type='SegDataPreProcessor', + mean=[122.7709, 116.7460, 104.0937], + std=[68.5005, 66.6322, 70.3232], + bgr_to_rgb=True, + pad_val=0, + seg_pad_val=255, + size_divisor=640, + test_cfg=dict(size_divisor=32)) + +num_classes = 171 +model = dict( + type='MultimodalEncoderDecoder', + data_preprocessor=data_preprocessor, + pretrained='pretrain/clip_vit_base_patch16_224.pth', + asymetric_input=True, + encoder_resolution=0.5, + image_encoder=dict( + type='VisionTransformer', + img_size=(224, 224), + patch_size=16, + patch_pad=0, + in_channels=3, + embed_dims=768, + num_layers=9, + num_heads=12, + mlp_ratio=4, + out_origin=True, + out_indices=(2, 5, 8), + qkv_bias=True, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.0, + with_cls_token=True, + output_cls_token=True, + patch_bias=False, + pre_norm=True, + norm_cfg=dict(type='LN', eps=1e-5), + act_cfg=dict(type='QuickGELU'), + norm_eval=False, + interpolate_mode='bicubic', + frozen_exclude=['pos_embed']), + text_encoder=dict( + type='CLIPTextEncoder', + dataset_name=None, + templates='vild', + embed_dims=512, + num_layers=12, + num_heads=8, + mlp_ratio=4, + output_dims=512, + cache_feature=True, + cat_bg=True, + norm_cfg=dict(type='LN', eps=1e-5) + ), + decode_head=dict( + type='SideAdapterCLIPHead', + num_classes=num_classes, + deep_supervision_idxs=[7], + san_cfg=dict( + in_channels=3, + clip_channels=768, + embed_dims=240, + patch_size=16, + patch_bias=True, + num_queries=100, + cfg_encoder=dict( + num_encode_layer=8, + num_heads=6, + mlp_ratio=4 + ), + fusion_index=[0, 1, 2, 3], + cfg_decoder=dict( + num_heads=12, + num_layers=1, + embed_channels=256, + mlp_channels=256, + num_mlp=3, + rescale=True), + norm_cfg=dict(type='LN', eps=1e-6), + ), + maskgen_cfg=dict( + sos_token_format='cls_token', + sos_token_num=100, + cross_attn=False, + num_layers=3, + embed_dims=768, + num_heads=12, + mlp_ratio=4, + qkv_bias=True, + out_dims=512, + final_norm=True, + act_cfg=dict(type='QuickGELU'), + norm_cfg=dict(type='LN', eps=1e-5), + frozen_exclude=[] + ), + align_corners=False, + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type='HungarianAssigner', + match_costs=[ + dict(type='ClassificationCost', weight=2.0), + dict( + type='CrossEntropyLossCost', + weight=5.0, + use_sigmoid=True), + dict( + type='DiceCost', + weight=5.0, + pred_act=True, + eps=1.0) + ])), + loss_decode=[dict(type='CrossEntropyLoss', + loss_name='loss_cls_ce', + loss_weight=2.0, + class_weight=[1.0] * num_classes + [0.1]), + dict(type='CrossEntropyLoss', + use_sigmoid=True, + loss_name='loss_mask_ce', + loss_weight=5.0), + dict(type='DiceLoss', + ignore_index=None, + naive_dice=True, + eps=1, + loss_name='loss_mask_dice', + loss_weight=5.0) + ]), + + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) # yapf: disable diff --git a/configs/san/README.md b/configs/san/README.md new file mode 100644 index 0000000000..23e72aa65f --- /dev/null +++ b/configs/san/README.md @@ -0,0 +1,47 @@ +# SAN + +> [Side Adapter Network for Open-Vocabulary Semantic Segmentation](https://arxiv.org/abs/2302.12242) + +## Introduction + + + +Official Repo + +## Abstract + + + +This paper presents a new framework for open-vocabulary semantic segmentation with the pre-trained vision-language model, named Side Adapter Network (SAN). Our approach models the semantic segmentation task as a region recognition problem. A side network is attached to a frozen CLIP model with two branches: one for predicting mask proposals, and the other for predicting attention bias which is applied in the CLIP model to recognize the class of masks. This decoupled design has the benefit CLIP in recognizing the class of mask proposals. Since the attached side network can reuse CLIP features, it can be very light. In addition, the entire network can be trained end-to-end, allowing the side network to be adapted to the frozen CLIP model, which makes the predicted mask proposals CLIP-aware. Our approach is fast, accurate, and only adds a few additional trainable parameters. We evaluate our approach on multiple semantic segmentation benchmarks. Our method significantly outperforms other counterparts, with up to 18 times fewer trainable parameters and 19 times faster inference speed. We hope our approach will serve as a solid baseline and help ease future research in open-vocabulary semantic segmentation. + + + +
+ +
+ +## Results and models + +### COCO-Stuff164k + +| Method | Backbone | Pretrained | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | Device | mIoU | mIoU(ms+flip) | config | download | +| ------ | -------- | ------------ | --------- | ------- | -------- | -------------- | ------ | ----- | ------------- | ------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| SAN | ViT-B_16 | CLIP_ViT-B16 | 640x640 | 60000 | 12.61 | - | V100 | 41.93 | 41.77 | - | [model](https://download.openmmlab.com/mmsegmentation/v0.5/san/san-vit-b16_20230906-fd0a7684.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/san/san-vit-b16_20230906.log) | +| SAN | ViT-L_14 | CLIP_ViT-L14 | 640x640 | 60000 | 22.84 | - | V100 | 45.78 | 43.99 | - | [model](https://download.openmmlab.com/mmsegmentation/v0.5/san/san-vit-l14_20230907-a11e098f.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/san/san-vit-l14_20230907.log) | + +## Notes + +git push +The pretrained weights in config files are converted from open_clip models using tools/model_converters/clip2mmseg.py. + +## Citation + +```bibtex +@inproceedings{xu2023side, + title={Side adapter network for open-vocabulary semantic segmentation}, + author={Xu, Mengde and Zhang, Zheng and Wei, Fangyun and Hu, Han and Bai, Xiang}, + booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, + pages={2945--2954}, + year={2023} +} +``` diff --git a/configs/san/san-vit-b16_coco-stuff164k-640x640.py b/configs/san/san-vit-b16_coco-stuff164k-640x640.py new file mode 100644 index 0000000000..40592486d1 --- /dev/null +++ b/configs/san/san-vit-b16_coco-stuff164k-640x640.py @@ -0,0 +1,82 @@ +_base_ = [ + '../_base_/models/san_vit-b16.py', '../_base_/datasets/coco-stuff164k.py', + '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py' +] +crop_size = (640, 640) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict( + type='RandomChoiceResize', + scales=[int(640 * x * 0.1) for x in range(5, 16)], + resize_type='ResizeShortestEdge', + max_size=2560), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=1.0), + dict(type='PhotoMetricDistortion'), + dict(type='RandomFlip', prob=0.5), + dict(type='PackSegInputs') +] + +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='ResizeShortestEdge', scale=crop_size, max_size=2560), + dict(type='LoadAnnotations'), + dict(type='PackSegInputs') +] + +# By default, models are trained on 4 GPUs with 8 images per GPU +train_dataloader = dict(batch_size=8, dataset=dict(pipeline=train_pipeline)) +val_dataloader = dict(batch_size=1, dataset=dict(pipeline=test_pipeline)) +test_dataloader = val_dataloader + +pretrained = 'https://download.openmmlab.com/mmsegmentation/v0.5/san/clip_vit-base-patch16-224_3rdparty-d08f8887.pth' # noqa +data_preprocessor = dict( + mean=[122.7709, 116.7460, 104.0937], + std=[68.5005, 66.6322, 70.3232], + size_divisor=640, + test_cfg=dict(size_divisor=32)) +model = dict( + pretrained=pretrained, + text_encoder=dict(dataset_name='coco-stuff164k'), + decode_head=dict(num_classes=171)) + +# training schedule for 60k +train_cfg = dict( + type='IterBasedTrainLoop', + max_iters=60000, + val_interval=500, + val_begin=55000) +default_hooks = dict( + checkpoint=dict( + type='CheckpointHook', + by_epoch=False, + interval=10000, + save_best='mIoU')) + +# AdamW optimizer, no weight decay for position embedding & layer norm +# in backbone +optim_wrapper = dict( + _delete_=True, + type='AmpOptimWrapper', + optimizer=dict( + type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.0001), + paramwise_cfg=dict( + custom_keys={ + 'img_encoder': dict(lr_mult=0.1, decay_mult=1.0), + 'pos_embed': dict(decay_mult=0.), + 'cls_token': dict(decay_mult=0.), + 'norm': dict(decay_mult=0.) + }), + loss_scale='dynamic', + clip_grad=dict(max_norm=0.01, norm_type=2)) + +param_scheduler = [ + dict( + type='PolyLR', + eta_min=0.0, + power=1.0, + begin=0, + end=60000, + by_epoch=False, + ) +] diff --git a/configs/san/san-vit-b16_pascal_context-640x640.py b/configs/san/san-vit-b16_pascal_context-640x640.py new file mode 100644 index 0000000000..b164fe41fd --- /dev/null +++ b/configs/san/san-vit-b16_pascal_context-640x640.py @@ -0,0 +1,56 @@ +_base_ = [ + '../_base_/models/san_vit-b16.py', + '../_base_/datasets/pascal_context_59.py', '../_base_/default_runtime.py', + '../_base_/schedules/schedule_160k.py' +] +crop_size = (640, 640) + +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='ResizeShortestEdge', scale=crop_size, max_size=2560), + dict(type='LoadAnnotations'), + dict(type='PackSegInputs') +] + +# By default, models are trained on 8 GPUs with 2 images per GPU +train_dataloader = dict(batch_size=2) +val_dataloader = dict(batch_size=1, dataset=dict(pipeline=test_pipeline)) +test_dataloader = val_dataloader + +data_preprocessor = dict( + mean=[122.7709, 116.7460, 104.0937], + std=[68.5005, 66.6322, 70.3232], + size_divisor=640, + test_cfg=dict(size_divisor=32)) +model = dict( + data_preprocessor=data_preprocessor, + pretrained='pretrain/vit_base_patch16_224.pth', + text_encoder=dict(dataset_name='pascal_context'), + decode_head=dict(num_classes=59)) + +# AdamW optimizer, no weight decay for position embedding & layer norm +# in backbone +optim_wrapper = dict( + _delete_=True, + type='OptimWrapper', + optimizer=dict( + type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01), + paramwise_cfg=dict( + custom_keys={ + 'pos_embed': dict(decay_mult=0.), + 'cls_token': dict(decay_mult=0.), + 'norm': dict(decay_mult=0.) + })) + +param_scheduler = [ + dict( + type='LinearLR', start_factor=1e-6, by_epoch=False, begin=0, end=1500), + dict( + type='PolyLR', + eta_min=0.0, + power=1.0, + begin=1500, + end=160000, + by_epoch=False, + ) +] diff --git a/configs/san/san-vit-b16_voc12aug-640x640.py b/configs/san/san-vit-b16_voc12aug-640x640.py new file mode 100644 index 0000000000..62e9b26f0a --- /dev/null +++ b/configs/san/san-vit-b16_voc12aug-640x640.py @@ -0,0 +1,65 @@ +_base_ = [ + '../_base_/models/san_vit-b16.py', + '../_base_/datasets/pascal_voc12_aug.py', '../_base_/default_runtime.py', + '../_base_/schedules/schedule_160k.py' +] +crop_size = (640, 640) + +metainfo = dict( + classes=('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', + 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', + 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'), + palette=[[128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128], + [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0], + [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128], + [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0], + [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]]) +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='ResizeShortestEdge', scale=crop_size, max_size=2560), + dict(type='LoadAnnotations'), + dict(type='PackSegInputs') +] +# By default, models are trained on 8 GPUs with 2 images per GPU +train_dataloader = dict(batch_size=2) +val_dataloader = dict( + batch_size=1, dataset=dict(metainfo=metainfo, pipeline=test_pipeline)) +test_dataloader = val_dataloader + +data_preprocessor = dict( + mean=[122.7709, 116.7460, 104.0937], + std=[68.5005, 66.6322, 70.3232], + size_divisor=640, + test_cfg=dict(size_divisor=32)) +model = dict( + data_preprocessor=data_preprocessor, + pretrained='pretrain/vit_base_patch16_224.pth', + text_encoder=dict(dataset_name='voc'), + decode_head=dict(num_classes=20)) + +# AdamW optimizer, no weight decay for position embedding & layer norm +# in backbone +optim_wrapper = dict( + _delete_=True, + type='OptimWrapper', + optimizer=dict( + type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01), + paramwise_cfg=dict( + custom_keys={ + 'pos_embed': dict(decay_mult=0.), + 'cls_token': dict(decay_mult=0.), + 'norm': dict(decay_mult=0.) + })) + +param_scheduler = [ + dict( + type='LinearLR', start_factor=1e-6, by_epoch=False, begin=0, end=1500), + dict( + type='PolyLR', + eta_min=0.0, + power=1.0, + begin=1500, + end=160000, + by_epoch=False, + ) +] diff --git a/configs/san/san-vit-l14_coco-stuff164k-640x640.py b/configs/san/san-vit-l14_coco-stuff164k-640x640.py new file mode 100644 index 0000000000..c34328db3f --- /dev/null +++ b/configs/san/san-vit-l14_coco-stuff164k-640x640.py @@ -0,0 +1,36 @@ +_base_ = ['./san-vit-b16_coco-stuff164k-640x640.py'] + +pretrained = 'https://download.openmmlab.com/mmsegmentation/v0.5/san/clip_vit-large-patch14-336_3rdparty-0b5df9cb.pth' # noqa +model = dict( + type='MultimodalEncoderDecoder', + pretrained=pretrained, + encoder_resolution=0.7, + image_encoder=dict( + type='VisionTransformer', + img_size=(336, 336), + patch_size=14, + patch_pad=0, + embed_dims=1024, + num_layers=18, + num_heads=16, + out_indices=(5, 11, 17), + ), + text_encoder=dict( + type='CLIPTextEncoder', + embed_dims=768, + num_layers=12, + num_heads=12, + output_dims=768, + ), + decode_head=dict( + type='SideAdapterCLIPHead', + san_cfg=dict(clip_channels=1024, cfg_decoder=dict(num_heads=16)), + maskgen_cfg=dict( + num_layers=6, + embed_dims=1024, + num_heads=16, + out_dims=768, + ))) + +# By default, models are trained on 8 GPUs with 4 images per GPU +train_dataloader = dict(batch_size=4) diff --git a/configs/san/san-vit-l14_pascal_context-640x640.py b/configs/san/san-vit-l14_pascal_context-640x640.py new file mode 100644 index 0000000000..a9545fac8e --- /dev/null +++ b/configs/san/san-vit-l14_pascal_context-640x640.py @@ -0,0 +1,32 @@ +_base_ = ['./san-vit-b16_pascal_context-640x640.py'] + +model = dict( + type='MultimodalEncoderDecoder', + pretrained='pretrain/jx_vit_base_p16_224-80ecf9dd.pth', + encoder_resolution=0.7, + image_encoder=dict( + type='VisionTransformer', + img_size=(336, 336), + patch_size=14, + patch_pad=0, + embed_dims=1024, + num_layers=18, + num_heads=16, + out_indices=(5, 11, 17), + ), + text_encoder=dict( + type='CLIPTextEncoder', + embed_dims=768, + num_layers=12, + num_heads=12, + output_dims=768, + ), + decode_head=dict( + type='SideAdapterCLIPHead', + san_cfg=dict(clip_channels=1024, cfg_decoder=dict(num_heads=16)), + maskgen_cfg=dict( + num_layers=6, + embed_dims=1024, + num_heads=16, + out_dims=768, + ))) diff --git a/configs/san/san-vit-l14_voc12aug-640x640.py b/configs/san/san-vit-l14_voc12aug-640x640.py new file mode 100644 index 0000000000..2f37715039 --- /dev/null +++ b/configs/san/san-vit-l14_voc12aug-640x640.py @@ -0,0 +1,32 @@ +_base_ = ['./san-vit-b16_voc12aug-640x640.py'] + +model = dict( + type='MultimodalEncoderDecoder', + pretrained='pretrain/jx_vit_base_p16_224-80ecf9dd.pth', + encoder_resolution=0.7, + image_encoder=dict( + type='VisionTransformer', + img_size=(336, 336), + patch_size=14, + patch_pad=0, + embed_dims=1024, + num_layers=18, + num_heads=16, + out_indices=(5, 11, 17), + ), + text_encoder=dict( + type='CLIPTextEncoder', + embed_dims=768, + num_layers=12, + num_heads=12, + output_dims=768, + ), + decode_head=dict( + type='SideAdapterCLIPHead', + san_cfg=dict(clip_channels=1024, cfg_decoder=dict(num_heads=16)), + maskgen_cfg=dict( + num_layers=6, + embed_dims=1024, + num_heads=16, + out_dims=768, + ))) diff --git a/mmseg/models/__init__.py b/mmseg/models/__init__.py index 7a520fb2fa..a98951283c 100644 --- a/mmseg/models/__init__.py +++ b/mmseg/models/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .assigners import * # noqa: F401,F403 from .backbones import * # noqa: F401,F403 from .builder import (BACKBONES, HEADS, LOSSES, SEGMENTORS, build_backbone, build_head, build_loss, build_segmentor) @@ -7,6 +8,7 @@ from .losses import * # noqa: F401,F403 from .necks import * # noqa: F401,F403 from .segmentors import * # noqa: F401,F403 +from .text_encoder import * # noqa: F401,F403 __all__ = [ 'BACKBONES', 'HEADS', 'LOSSES', 'SEGMENTORS', 'build_backbone', diff --git a/mmseg/models/assigners/__init__.py b/mmseg/models/assigners/__init__.py new file mode 100644 index 0000000000..d49b1b18b9 --- /dev/null +++ b/mmseg/models/assigners/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .base_assigner import BaseAssigner +from .hungarian_assigner import HungarianAssigner +from .match_cost import ClassificationCost, CrossEntropyLossCost, DiceCost + +__all__ = [ + 'BaseAssigner', + 'HungarianAssigner', + 'ClassificationCost', + 'CrossEntropyLossCost', + 'DiceCost', +] diff --git a/mmseg/models/assigners/base_assigner.py b/mmseg/models/assigners/base_assigner.py new file mode 100644 index 0000000000..97895cdac2 --- /dev/null +++ b/mmseg/models/assigners/base_assigner.py @@ -0,0 +1,18 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod +from typing import Optional + +from mmengine.structures import InstanceData + + +class BaseAssigner(metaclass=ABCMeta): + """Base assigner that assigns masks to ground truth class labels.""" + + @abstractmethod + def assign(self, + pred_instances: InstanceData, + gt_instances: InstanceData, + gt_instances_ignore: Optional[InstanceData] = None, + **kwargs): + """Assign masks to either a ground truth class label or a negative + label.""" diff --git a/mmseg/models/assigners/hungarian_assigner.py b/mmseg/models/assigners/hungarian_assigner.py new file mode 100644 index 0000000000..28868f0a04 --- /dev/null +++ b/mmseg/models/assigners/hungarian_assigner.py @@ -0,0 +1,86 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Union + +import torch +from mmengine import ConfigDict +from mmengine.structures import InstanceData +from scipy.optimize import linear_sum_assignment +from torch.cuda.amp import autocast + +from mmseg.registry import TASK_UTILS +from .base_assigner import BaseAssigner + + +@TASK_UTILS.register_module() +class HungarianAssigner(BaseAssigner): + """Computes one-to-one matching between prediction masks and ground truth. + + This class uses bipartite matching-based assignment to computes an + assignment between the prediction masks and the ground truth. The + assignment result is based on the weighted sum of match costs. The + Hungarian algorithm is used to calculate the best matching with the + minimum cost. The prediction masks that are not matched are classified + as background. + + Args: + match_costs (ConfigDict|List[ConfigDict]): Match cost configs. + """ + + def __init__( + self, match_costs: Union[List[Union[dict, ConfigDict]], dict, + ConfigDict] + ) -> None: + + if isinstance(match_costs, dict): + match_costs = [match_costs] + elif isinstance(match_costs, list): + assert len(match_costs) > 0, \ + 'match_costs must not be a empty list.' + + self.match_costs = [ + TASK_UTILS.build(match_cost) for match_cost in match_costs + ] + + def assign(self, pred_instances: InstanceData, gt_instances: InstanceData, + **kwargs): + """Computes one-to-one matching based on the weighted costs. + + This method assign each query prediction to a ground truth or + background. The assignment first calculates the cost for each + category assigned to each query mask, and then uses the + Hungarian algorithm to calculate the minimum cost as the best + match. + + Args: + pred_instances (InstanceData): Instances of model + predictions. It includes "masks", with shape + (n, h, w) or (n, l), and "cls", with shape (n, num_classes+1) + gt_instances (InstanceData): Ground truth of instance + annotations. It includes "labels", with shape (k, ), + and "masks", with shape (k, h, w) or (k, l). + + Returns: + matched_quiery_inds (Tensor): The indexes of matched quieres. + matched_label_inds (Tensor): The indexes of matched labels. + """ + # compute weighted cost + cost_list = [] + with autocast(enabled=False): + for match_cost in self.match_costs: + cost = match_cost( + pred_instances=pred_instances, gt_instances=gt_instances) + cost_list.append(cost) + cost = torch.stack(cost_list).sum(dim=0) + + device = cost.device + # do Hungarian matching on CPU using linear_sum_assignment + cost = cost.detach().cpu() + if linear_sum_assignment is None: + raise ImportError('Please run "pip install scipy" ' + 'to install scipy first.') + + matched_quiery_inds, matched_label_inds = linear_sum_assignment(cost) + matched_quiery_inds = torch.from_numpy(matched_quiery_inds).to(device) + matched_label_inds = torch.from_numpy(matched_label_inds).to(device) + + return matched_quiery_inds, matched_label_inds diff --git a/mmseg/models/assigners/match_cost.py b/mmseg/models/assigners/match_cost.py new file mode 100644 index 0000000000..560df85290 --- /dev/null +++ b/mmseg/models/assigners/match_cost.py @@ -0,0 +1,231 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import abstractmethod +from typing import Union + +import torch +import torch.nn.functional as F +from mmengine.structures import InstanceData +from torch import Tensor + +from mmseg.registry import TASK_UTILS + + +class BaseMatchCost: + """Base match cost class. + + Args: + weight (Union[float, int]): Cost weight. Defaults to 1. + """ + + def __init__(self, weight: Union[float, int] = 1.) -> None: + self.weight = weight + + @abstractmethod + def __call__(self, pred_instances: InstanceData, + gt_instances: InstanceData, **kwargs) -> Tensor: + """Compute match cost. + + Args: + pred_instances (InstanceData): Instances of model predictions. + It often includes "labels" and "scores". + gt_instances (InstanceData): Ground truth of instance + annotations. It usually includes "labels". + + Returns: + Tensor: Match Cost matrix of shape (num_preds, num_gts). + """ + pass + + +@TASK_UTILS.register_module() +class ClassificationCost(BaseMatchCost): + """ClsSoftmaxCost. + + Args: + weight (Union[float, int]): Cost weight. Defaults to 1. + + Examples: + >>> from mmseg.models.assigners import ClassificationCost + >>> import torch + >>> self = ClassificationCost() + >>> cls_pred = torch.rand(4, 3) + >>> gt_labels = torch.tensor([0, 1, 2]) + >>> factor = torch.tensor([10, 8, 10, 8]) + >>> self(cls_pred, gt_labels) + tensor([[-0.3430, -0.3525, -0.3045], + [-0.3077, -0.2931, -0.3992], + [-0.3664, -0.3455, -0.2881], + [-0.3343, -0.2701, -0.3956]]) + """ + + def __init__(self, weight: Union[float, int] = 1) -> None: + super().__init__(weight=weight) + + def __call__(self, pred_instances: InstanceData, + gt_instances: InstanceData, **kwargs) -> Tensor: + """Compute match cost. + + Args: + pred_instances (InstanceData): "scores" inside is + predicted classification logits, of shape + (num_queries, num_class). + gt_instances (InstanceData): "labels" inside should have + shape (num_gt, ). + + Returns: + Tensor: Match Cost matrix of shape (num_preds, num_gts). + """ + assert hasattr(pred_instances, 'scores'), \ + "pred_instances must contain 'scores'" + assert hasattr(gt_instances, 'labels'), \ + "gt_instances must contain 'labels'" + pred_scores = pred_instances.scores + gt_labels = gt_instances.labels + + pred_scores = pred_scores.softmax(-1) + cls_cost = -pred_scores[:, gt_labels] + + return cls_cost * self.weight + + +@TASK_UTILS.register_module() +class DiceCost(BaseMatchCost): + """Cost of mask assignments based on dice losses. + + Args: + pred_act (bool): Whether to apply sigmoid to mask_pred. + Defaults to False. + eps (float): Defaults to 1e-3. + naive_dice (bool): If True, use the naive dice loss + in which the power of the number in the denominator is + the first power. If False, use the second power that + is adopted by K-Net and SOLO. Defaults to True. + weight (Union[float, int]): Cost weight. Defaults to 1. + """ + + def __init__(self, + pred_act: bool = False, + eps: float = 1e-3, + naive_dice: bool = True, + weight: Union[float, int] = 1.) -> None: + super().__init__(weight=weight) + self.pred_act = pred_act + self.eps = eps + self.naive_dice = naive_dice + + def _binary_mask_dice_loss(self, mask_preds: Tensor, + gt_masks: Tensor) -> Tensor: + """ + Args: + mask_preds (Tensor): Mask prediction in shape (num_queries, *). + gt_masks (Tensor): Ground truth in shape (num_gt, *) + store 0 or 1, 0 for negative class and 1 for + positive class. + + Returns: + Tensor: Dice cost matrix in shape (num_queries, num_gt). + """ + mask_preds = mask_preds.flatten(1) + gt_masks = gt_masks.flatten(1).float() + numerator = 2 * torch.einsum('nc,mc->nm', mask_preds, gt_masks) + if self.naive_dice: + denominator = mask_preds.sum(-1)[:, None] + \ + gt_masks.sum(-1)[None, :] + else: + denominator = mask_preds.pow(2).sum(1)[:, None] + \ + gt_masks.pow(2).sum(1)[None, :] + loss = 1 - (numerator + self.eps) / (denominator + self.eps) + return loss + + def __call__(self, pred_instances: InstanceData, + gt_instances: InstanceData, **kwargs) -> Tensor: + """Compute match cost. + + Args: + pred_instances (InstanceData): Predicted instances which + must contain "masks". + gt_instances (InstanceData): Ground truth which must contain + "mask". + + Returns: + Tensor: Match Cost matrix of shape (num_preds, num_gts). + """ + assert hasattr(pred_instances, 'masks'), \ + "pred_instances must contain 'masks'" + assert hasattr(gt_instances, 'masks'), \ + "gt_instances must contain 'masks'" + pred_masks = pred_instances.masks + gt_masks = gt_instances.masks + + if self.pred_act: + pred_masks = pred_masks.sigmoid() + dice_cost = self._binary_mask_dice_loss(pred_masks, gt_masks) + return dice_cost * self.weight + + +@TASK_UTILS.register_module() +class CrossEntropyLossCost(BaseMatchCost): + """CrossEntropyLossCost. + + Args: + use_sigmoid (bool): Whether the prediction uses sigmoid + of softmax. Defaults to True. + weight (Union[float, int]): Cost weight. Defaults to 1. + """ + + def __init__(self, + use_sigmoid: bool = True, + weight: Union[float, int] = 1.) -> None: + super().__init__(weight=weight) + self.use_sigmoid = use_sigmoid + + def _binary_cross_entropy(self, cls_pred: Tensor, + gt_labels: Tensor) -> Tensor: + """ + Args: + cls_pred (Tensor): The prediction with shape (num_queries, 1, *) or + (num_queries, *). + gt_labels (Tensor): The learning label of prediction with + shape (num_gt, *). + + Returns: + Tensor: Cross entropy cost matrix in shape (num_queries, num_gt). + """ + cls_pred = cls_pred.flatten(1).float() + gt_labels = gt_labels.flatten(1).float() + n = cls_pred.shape[1] + pos = F.binary_cross_entropy_with_logits( + cls_pred, torch.ones_like(cls_pred), reduction='none') + neg = F.binary_cross_entropy_with_logits( + cls_pred, torch.zeros_like(cls_pred), reduction='none') + cls_cost = torch.einsum('nc,mc->nm', pos, gt_labels) + \ + torch.einsum('nc,mc->nm', neg, 1 - gt_labels) + cls_cost = cls_cost / n + + return cls_cost + + def __call__(self, pred_instances: InstanceData, + gt_instances: InstanceData, **kwargs) -> Tensor: + """Compute match cost. + + Args: + pred_instances (:obj:`InstanceData`): Predicted instances which + must contain ``masks``. + gt_instances (:obj:`InstanceData`): Ground truth which must contain + ``masks``. + + Returns: + Tensor: Match Cost matrix of shape (num_preds, num_gts). + """ + assert hasattr(pred_instances, 'masks'), \ + "pred_instances must contain 'masks'" + assert hasattr(gt_instances, 'masks'), \ + "gt_instances must contain 'masks'" + pred_masks = pred_instances.masks + gt_masks = gt_instances.masks + if self.use_sigmoid: + cls_cost = self._binary_cross_entropy(pred_masks, gt_masks) + else: + raise NotImplementedError + + return cls_cost * self.weight diff --git a/mmseg/models/backbones/vit.py b/mmseg/models/backbones/vit.py index 3c96f65493..dd0f688fcc 100644 --- a/mmseg/models/backbones/vit.py +++ b/mmseg/models/backbones/vit.py @@ -132,12 +132,16 @@ class VisionTransformer(BaseModule): Args: img_size (int | tuple): Input image size. Default: 224. patch_size (int): The patch size. Default: 16. + patch_pad (str | int | None): The padding method in patch embedding. + Default: 'corner'. in_channels (int): Number of input channels. Default: 3. embed_dims (int): embedding dimension. Default: 768. num_layers (int): depth of transformer. Default: 12. num_heads (int): number of attention heads. Default: 12. mlp_ratio (int): ratio of mlp hidden dim to embedding dim. Default: 4. + out_origin (bool): Whether to output the original input embedding. + Default: False out_indices (list | tuple | int): Output from which stages. Default: -1. qkv_bias (bool): enable bias for qkv if True. Default: True. @@ -154,8 +158,12 @@ class VisionTransformer(BaseModule): Default: dict(type='LN') act_cfg (dict): The activation config for FFNs. Default: dict(type='GELU'). + patch_bias (dict): Whether use bias in convolution of PatchEmbed Block. + Default: True. patch_norm (bool): Whether to add a norm in PatchEmbed Block. Default: False. + pre_norm (bool): Whether to add a norm before Transformer Layers. + Default: False. final_norm (bool): Whether to add a additional layer to normalize final feature map. Default: False. interpolate_mode (str): Select the interpolate mode for position @@ -167,6 +175,8 @@ class VisionTransformer(BaseModule): and its variants only. Default: False. with_cp (bool): Use checkpoint or not. Using checkpoint will save some memory while slowing down the training speed. Default: False. + frozen_exclude (List): List of parameters that are not to be frozen. + Default: ["all"], "all" means there are no frozen parameters. pretrained (str, optional): model pretrained path. Default: None. init_cfg (dict or list[dict], optional): Initialization config dict. Default: None. @@ -175,11 +185,13 @@ class VisionTransformer(BaseModule): def __init__(self, img_size=224, patch_size=16, + patch_pad='corner', in_channels=3, embed_dims=768, num_layers=12, num_heads=12, mlp_ratio=4, + out_origin=False, out_indices=-1, qkv_bias=True, drop_rate=0., @@ -190,11 +202,14 @@ def __init__(self, norm_cfg=dict(type='LN'), act_cfg=dict(type='GELU'), patch_norm=False, + patch_bias=False, + pre_norm=False, final_norm=False, interpolate_mode='bicubic', num_fcs=2, norm_eval=False, with_cp=False, + frozen_exclude=['all'], pretrained=None, init_cfg=None): super().__init__(init_cfg=init_cfg) @@ -227,6 +242,8 @@ def __init__(self, self.norm_eval = norm_eval self.with_cp = with_cp self.pretrained = pretrained + self.out_origin = out_origin + self.frozen_exclude = frozen_exclude self.patch_embed = PatchEmbed( in_channels=in_channels, @@ -234,7 +251,8 @@ def __init__(self, conv_type='Conv2d', kernel_size=patch_size, stride=patch_size, - padding='corner', + padding=patch_pad, + bias=patch_bias, norm_cfg=norm_cfg if patch_norm else None, init_cfg=None, ) @@ -248,6 +266,12 @@ def __init__(self, self.pos_embed = nn.Parameter( torch.zeros(1, num_patches + 1, embed_dims)) self.drop_after_pos = nn.Dropout(p=drop_rate) + self.pre_norm = pre_norm + + if self.pre_norm: + self.pre_ln_name, pre_ln = build_norm_layer( + norm_cfg, embed_dims, postfix='_pre') + self.add_module(self.pre_ln_name, pre_ln) if isinstance(out_indices, int): if out_indices == -1: @@ -285,20 +309,36 @@ def __init__(self, norm_cfg, embed_dims, postfix=1) self.add_module(self.norm1_name, norm1) + self._freeze() + + @property + def pre_ln(self): + return getattr(self, self.pre_ln_name) + @property def norm1(self): return getattr(self, self.norm1_name) def init_weights(self): - if (isinstance(self.init_cfg, dict) - and self.init_cfg.get('type') == 'Pretrained'): + if isinstance(self.init_cfg, dict) and \ + self.init_cfg.get('type') in ['Pretrained', 'Pretrained_Part']: checkpoint = CheckpointLoader.load_checkpoint( self.init_cfg['checkpoint'], logger=None, map_location='cpu') - if 'state_dict' in checkpoint: - state_dict = checkpoint['state_dict'] - else: - state_dict = checkpoint + if self.init_cfg.get('type') == 'Pretrained': + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + + elif self.init_cfg.get('type') == 'Pretrained_Part': + state_dict = checkpoint.copy() + para_prefix = 'image_encoder' + prefix_len = len(para_prefix) + 1 + for k, v in checkpoint.items(): + state_dict.pop(k) + if para_prefix in k: + state_dict[k[prefix_len:]] = v if 'pos_embed' in state_dict.keys(): if self.pos_embed.shape != state_dict['pos_embed'].shape: @@ -334,6 +374,13 @@ def init_weights(self): elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)): constant_init(m, val=1.0, bias=0.) + def _freeze(self): + if 'all' in self.frozen_exclude: + return + for name, param in self.named_parameters(): + if not any([exclude in name for exclude in self.frozen_exclude]): + param.requires_grad = False + def _pos_embeding(self, patched_img, hw_shape, pos_embed): """Positioning embeding method. @@ -409,7 +456,23 @@ def forward(self, inputs): # Remove class token for transformer encoder input x = x[:, 1:] + if self.pre_norm: + x = self.pre_ln(x) + outs = [] + if self.out_origin: + if self.with_cls_token: + # Remove class token and reshape token for decoder head + out = x[:, 1:] + else: + out = x + B, _, C = out.shape + out = out.reshape(B, hw_shape[0], hw_shape[1], + C).permute(0, 3, 1, 2).contiguous() + if self.output_cls_token: + out = [out, x[:, 0]] + outs.append(out) + for i, layer in enumerate(self.layers): x = layer(x) if i == len(self.layers) - 1: diff --git a/mmseg/models/decode_heads/__init__.py b/mmseg/models/decode_heads/__init__.py index b63cdc3e2c..4229763816 100644 --- a/mmseg/models/decode_heads/__init__.py +++ b/mmseg/models/decode_heads/__init__.py @@ -25,6 +25,7 @@ from .point_head import PointHead from .psa_head import PSAHead from .psp_head import PSPHead +from .san_head import SideAdapterCLIPHead from .segformer_head import SegformerHead from .segmenter_mask_head import SegmenterMaskTransformerHead from .sep_aspp_head import DepthwiseSeparableASPPHead @@ -43,5 +44,5 @@ 'SETRMLAHead', 'DPTHead', 'SETRMLAHead', 'SegmenterMaskTransformerHead', 'SegformerHead', 'ISAHead', 'STDCHead', 'IterativeDecodeHead', 'KernelUpdateHead', 'KernelUpdator', 'MaskFormerHead', 'Mask2FormerHead', - 'LightHamHead', 'PIDHead', 'DDRHead', 'VPDDepthHead' + 'LightHamHead', 'PIDHead', 'DDRHead', 'VPDDepthHead', 'SideAdapterCLIPHead' ] diff --git a/mmseg/models/decode_heads/san_head.py b/mmseg/models/decode_heads/san_head.py new file mode 100644 index 0000000000..03dedf2e49 --- /dev/null +++ b/mmseg/models/decode_heads/san_head.py @@ -0,0 +1,733 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from functools import partial +from typing import Dict, List, Tuple + +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule, build_norm_layer +from mmcv.cnn.bricks.transformer import BaseTransformerLayer +from mmcv.ops import point_sample +from mmengine.dist import all_reduce +from mmengine.model.weight_init import (caffe2_xavier_init, normal_init, + trunc_normal_) +from mmengine.runner.checkpoint import CheckpointLoader, load_state_dict +from mmengine.structures import InstanceData +from torch import Tensor +from torch.nn import functional as F + +from mmseg.models.backbones.vit import TransformerEncoderLayer +from mmseg.registry import MODELS +from mmseg.utils import (ConfigType, MatchMasks, SampleList, + seg_data_to_instance_data) +from ..utils import (MLP, LayerNorm2d, PatchEmbed, cross_attn_layer, + get_uncertain_point_coords_with_randomness, resize) +from .decode_head import BaseDecodeHead + + +class MLPMaskDecoder(nn.Module): + """Module for decoding query and visual features with MLP layers to + generate the attention biases and the mask proposals.""" + + def __init__( + self, + *, + in_channels: int, + total_heads: int = 1, + total_layers: int = 1, + embed_channels: int = 256, + mlp_channels: int = 256, + mlp_num_layers: int = 3, + rescale_attn_bias: bool = False, + ): + super().__init__() + self.total_heads = total_heads + self.total_layers = total_layers + + dense_affine_func = partial(nn.Conv2d, kernel_size=1) + # Query Branch + self.query_mlp = MLP(in_channels, mlp_channels, embed_channels, + mlp_num_layers) + # Pixel Branch + self.pix_mlp = MLP( + in_channels, + mlp_channels, + embed_channels, + mlp_num_layers, + affine_func=dense_affine_func, + ) + # Attention Bias Branch + self.attn_mlp = MLP( + in_channels, + mlp_channels, + embed_channels * self.total_heads * self.total_layers, + mlp_num_layers, + affine_func=dense_affine_func, + ) + if rescale_attn_bias: + self.bias_scaling = nn.Linear(1, 1) + else: + self.bias_scaling = nn.Identity() + + def forward(self, query: torch.Tensor, + x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """Forward function. + Args: + query (Tensor): Query Tokens [B,N,C]. + x (Tensor): Visual features [B,C,H,W] + + Return: + mask_preds (Tensor): Mask proposals. + attn_bias (List[Tensor]): List of attention bias. + """ + query = self.query_mlp(query) + pix = self.pix_mlp(x) + b, c, h, w = pix.shape + # preidict mask + mask_preds = torch.einsum('bqc,bchw->bqhw', query, pix) + # generate attn bias + attn = self.attn_mlp(x) + attn = attn.reshape(b, self.total_layers, self.total_heads, c, h, w) + attn_bias = torch.einsum('bqc,blnchw->blnqhw', query, attn) + attn_bias = self.bias_scaling(attn_bias[..., None]).squeeze(-1) + attn_bias = attn_bias.chunk(self.total_layers, dim=1) + attn_bias = [attn.squeeze(1) for attn in attn_bias] + return mask_preds, attn_bias + + +class SideAdapterNetwork(nn.Module): + """Side Adapter Network for predicting mask proposals and attention bias. + + Args: + in_channels (int): Number of input channels. Default: 3. + clip_channels (int): Number of channels of visual features. + Default: 768. + embed_dims (int): embedding dimension. Default: 240. + patch_size (int): The patch size. Default: 16. + patch_bias (bool): Whether use bias in patch embedding. + Default: True. + num_queries (int): Number of queries for mask proposals. + Default: 100. + fusion_index (List[int]): The layer number of the encode + transformer to fuse with the CLIP feature. + Default: [0, 1, 2, 3]. + cfg_encoder (ConfigType): Configs for the encode layers. + cfg_decoder (ConfigType): Configs for the decode layers. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + """ + + def __init__( + self, + in_channels: int = 3, + clip_channels: int = 768, + embed_dims: int = 240, + patch_size: int = 16, + patch_bias: bool = True, + num_queries: int = 100, + fusion_index: list = [0, 1, 2, 3], + cfg_encoder: ConfigType = ..., + cfg_decoder: ConfigType = ..., + norm_cfg: dict = dict(type='LN'), + ): + super().__init__() + + self.patch_embed = PatchEmbed( + in_channels=in_channels, + embed_dims=embed_dims, + conv_type='Conv2d', + kernel_size=patch_size, + stride=patch_size, + padding=0, + input_size=(640, 640), + bias=patch_bias, + norm_cfg=None, + init_cfg=None, + ) + ori_h, ori_w = self.patch_embed.init_out_size + num_patches = ori_h * ori_w + self.pos_embed = nn.Parameter( + torch.randn(1, num_patches, embed_dims) * .02) + self.query_pos_embed = nn.Parameter( + torch.zeros(1, num_queries, embed_dims)) + self.query_embed = nn.Parameter( + torch.zeros(1, num_queries, embed_dims)) + encode_layers = [] + for i in range(cfg_encoder.num_encode_layer): + encode_layers.append( + TransformerEncoderLayer( + embed_dims=embed_dims, + num_heads=cfg_encoder.num_heads, + feedforward_channels=cfg_encoder.mlp_ratio * embed_dims, + norm_cfg=norm_cfg)) + self.encode_layers = nn.ModuleList(encode_layers) + conv_clips = [] + for i in range(len(fusion_index)): + conv_clips.append( + nn.Sequential( + LayerNorm2d(clip_channels), + ConvModule( + clip_channels, + embed_dims, + kernel_size=1, + norm_cfg=None, + act_cfg=None))) + self.conv_clips = nn.ModuleList(conv_clips) + self.fusion_index = fusion_index + self.mask_decoder = MLPMaskDecoder( + in_channels=embed_dims, + total_heads=cfg_decoder.num_heads, + total_layers=cfg_decoder.num_layers, + embed_channels=cfg_decoder.embed_channels, + mlp_channels=cfg_decoder.mlp_channels, + mlp_num_layers=cfg_decoder.num_mlp, + rescale_attn_bias=cfg_decoder.rescale) + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.query_embed, std=0.02) + nn.init.normal_(self.query_pos_embed, std=0.02) + for i in range(len(self.conv_clips)): + caffe2_xavier_init(self.conv_clips[i][1].conv) + + def fuse_clip(self, fused_index: int, x: torch.Tensor, + clip_feature: torch.Tensor, hwshape: Tuple[int, + int], L: int): + """Fuse CLIP feature and visual tokens.""" + fused_clip = (resize( + self.conv_clips[fused_index](clip_feature.contiguous()), + size=hwshape, + mode='bilinear', + align_corners=False)).permute(0, 2, 3, 1).reshape(x[:, -L:, + ...].shape) + x = torch.cat([x[:, :-L, ...], x[:, -L:, ...] + fused_clip], dim=1) + return x + + def encode_feature(self, image: torch.Tensor, + clip_features: List[torch.Tensor], + deep_supervision_idxs: List[int]) -> List[List]: + """Encode images by a lightweight vision transformer.""" + assert len(self.fusion_index) == len(clip_features) + x, hwshape = self.patch_embed(image) + ori_h, ori_w = self.patch_embed.init_out_size + pos_embed = self.pos_embed + if self.pos_embed.shape[1] != x.shape[1]: + # resize the position embedding + pos_embed = ( + resize( + self.pos_embed.reshape(1, ori_h, ori_w, + -1).permute(0, 3, 1, 2), + size=hwshape, + mode='bicubic', + align_corners=False, + ).flatten(2).permute(0, 2, 1)) + pos_embed = torch.cat([ + self.query_pos_embed.expand(pos_embed.shape[0], -1, -1), pos_embed + ], + dim=1) + x = torch.cat([self.query_embed.expand(x.shape[0], -1, -1), x], dim=1) + x = x + pos_embed + L = hwshape[0] * hwshape[1] + fused_index = 0 + if self.fusion_index[fused_index] == 0: + x = self.fuse_clip(fused_index, x, clip_features[0][0], hwshape, L) + fused_index += 1 + outs = [] + for index, block in enumerate(self.encode_layers, start=1): + x = block(x) + if index < len(self.fusion_index + ) and index == self.fusion_index[fused_index]: + x = self.fuse_clip(fused_index, x, + clip_features[fused_index][0], hwshape, L) + fused_index += 1 + x_query = x[:, :-L, ...] + x_feat = x[:, -L:, ...].permute(0, 2, 1)\ + .reshape(x.shape[0], x.shape[-1], hwshape[0], hwshape[1]) + + if index in deep_supervision_idxs or index == len( + self.encode_layers): + outs.append({'query': x_query, 'x': x_feat}) + + if index < len(self.encode_layers): + x = x + pos_embed + return outs + + def decode_feature(self, features): + mask_embeds = [] + attn_biases = [] + for feature in features: + mask_embed, attn_bias = self.mask_decoder(**feature) + mask_embeds.append(mask_embed) + attn_biases.append(attn_bias) + return mask_embeds, attn_biases + + def forward( + self, image: torch.Tensor, clip_features: List[torch.Tensor], + deep_supervision_idxs: List[int] + ) -> Tuple[List[torch.Tensor], List[List[torch.Tensor]]]: + """Forward function.""" + features = self.encode_feature(image, clip_features, + deep_supervision_idxs) + mask_embeds, attn_biases = self.decode_feature(features) + return mask_embeds, attn_biases + + +class RecWithAttnbias(nn.Module): + """Mask recognition module by applying the attention biases to rest deeper + CLIP layers. + + Args: + sos_token_format (str): The format of sos token. It should be + chosen from ["cls_token", "learnable_token", "pos_embedding"]. + Default: 'cls_token'. + sos_token_num (int): Number of sos token. It should be equal to + the number of quries. Default: 100. + num_layers (int): Number of rest CLIP layers for mask recognition. + Default: 3. + cross_attn (bool): Whether use cross attention to update sos token. + Default: False. + embed_dims (int): The feature dimension of CLIP layers. + Default: 768. + num_heads (int): Parallel attention heads of CLIP layers. + Default: 768. + mlp_ratio (int): Ratio of mlp hidden dim to embedding dim. + Default: 4. + qkv_bias (bool): Whether to use bias in multihead-attention. + Default: True. + out_dims (int): Number of channels of the output mask proposals. + It should be equal to the out_dims of text_encoder. + Default: 512. + final_norm (True): Whether use norm layer for sos token. + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + frozen_exclude (List): List of parameters that are not to be frozen. + """ + + def __init__(self, + sos_token_format: str = 'cls_token', + sos_token_num: int = 100, + num_layers: int = 3, + cross_attn: bool = False, + embed_dims: int = 768, + num_heads: int = 12, + mlp_ratio: int = 4, + num_fcs: int = 2, + qkv_bias: bool = True, + out_dims: int = 512, + final_norm: bool = True, + act_cfg: dict = dict(type='GELU'), + norm_cfg: dict = dict(type='LN'), + frozen_exclude: List = []): + super().__init__() + + assert sos_token_format in [ + 'cls_token', 'learnable_token', 'pos_embedding' + ] + self.sos_token_format = sos_token_format + self.sos_token_num = sos_token_num + self.frozen_exclude = frozen_exclude + self.cross_attn = cross_attn + self.num_layers = num_layers + self.num_heads = num_heads + if sos_token_format in ['learnable_token', 'pos_embedding']: + self.sos_token = nn.Parameter( + torch.randn(sos_token_num, 1, self.proj.shape[0])) + self.frozen.append('sos_token') + + layers = [] + for i in range(num_layers): + layers.append( + BaseTransformerLayer( + attn_cfgs=dict( + type='MultiheadAttention', + embed_dims=embed_dims, + num_heads=num_heads, + batch_first=False, + bias=qkv_bias), + ffn_cfgs=dict( + type='FFN', + embed_dims=embed_dims, + feedforward_channels=mlp_ratio * embed_dims, + act_cfg=act_cfg), + operation_order=('norm', 'self_attn', 'norm', 'ffn'))) + self.layers = nn.ModuleList(layers) + + self.ln_post = build_norm_layer(norm_cfg, embed_dims)[1] + self.proj = nn.Linear(embed_dims, out_dims, bias=False) + + self.final_norm = final_norm + self._freeze() + + def init_weights(self, rec_state_dict): + if hasattr(self, 'sos_token'): + normal_init(self.sos_token, std=0.02) + if rec_state_dict is not None: + load_state_dict(self, rec_state_dict, strict=False, logger=None) + else: + super().init_weights() + + def _freeze(self): + if 'all' in self.frozen_exclude: + return + for name, param in self.named_parameters(): + if not any([exclude in name for exclude in self.frozen_exclude]): + param.requires_grad = False + + def _build_attn_biases(self, attn_biases, target_shape): + formatted_attn_biases = [] + for attn_bias in attn_biases: + # convert it to proper format: N*num_head,L,L + # attn_bias: [N, num_head/1, num_sos,H,W] + n, num_head, num_sos, h, w = attn_bias.shape + # reshape and downsample + attn_bias = F.adaptive_max_pool2d( + attn_bias.reshape(n, num_head * num_sos, h, w), + output_size=target_shape) + attn_bias = attn_bias.reshape(n, num_head, num_sos, *target_shape) + + true_num_head = self.num_heads + assert (num_head == 1 or num_head + == true_num_head), f'num_head={num_head} is not supported.' + if num_head == 1: + attn_bias = attn_bias.repeat(1, true_num_head, 1, 1, 1) + attn_bias = attn_bias.reshape(n * true_num_head, num_sos, -1) + L = attn_bias.shape[-1] + if self.cross_attn: + # [n*num_head, num_sos, L] + formatted_attn_biases.append(attn_bias) + else: + # [n*num_head, num_sos+1+L, num_sos+1+L] + new_attn_bias = attn_bias.new_zeros(num_sos + 1 + L, + num_sos + 1 + L) + new_attn_bias[:, :num_sos] = -100 + new_attn_bias[torch.arange(num_sos), torch.arange(num_sos)] = 0 + new_attn_bias[:num_sos, num_sos] = -100 + new_attn_bias = ( + new_attn_bias[None, ...].expand(n * true_num_head, -1, + -1).clone()) + new_attn_bias[..., :num_sos, -L:] = attn_bias + formatted_attn_biases.append(new_attn_bias) + + if len(formatted_attn_biases) == 1: + formatted_attn_biases = [ + formatted_attn_biases[0] for _ in range(self.num_layers) + ] + return formatted_attn_biases + + def forward(self, bias: List[Tensor], feature: List[Tensor]): + """Forward function to recognize the category of masks + Args: + bias (List[Tensor]): Attention bias for transformer layers + feature (List[Tensor]): Output of the image encoder, + including cls_token and img_feature. + """ + cls_token = feature[1].unsqueeze(0) + img_feature = feature[0] + b, c, h, w = img_feature.shape + # construct clip shadow features + x = torch.cat( + [cls_token, + img_feature.reshape(b, c, -1).permute(2, 0, 1)]) + + # construct sos token + if self.sos_token_format == 'cls_token': + sos_token = cls_token.repeat(self.sos_token_num, 1, 1) + elif self.sos_token_format == 'learnable_token': + sos_token = self.sos_token.expand(-1, b, -1) + elif self.sos_token_format == 'pos_embedding': + sos_token = self.sos_token.expand(-1, b, -1) + cls_token + + # construct attn bias + attn_biases = self._build_attn_biases(bias, target_shape=(h, w)) + + if self.cross_attn: + for i, block in enumerate(self.layers): + if self.cross_attn: + sos_token = cross_attn_layer( + block, + sos_token, + x[1:, ], + attn_biases[i], + ) + if i < len(self.layers) - 1: + x = block(x) + else: + x = torch.cat([sos_token, x], dim=0) + for i, block in enumerate(self.layers): + x = block(x, attn_masks=[attn_biases[i]]) + sos_token = x[:self.sos_token_num] + + sos_token = sos_token.permute(1, 0, 2) # LND -> NLD + sos_token = self.ln_post(sos_token) + sos_token = self.proj(sos_token) + if self.final_norm: + sos_token = F.normalize(sos_token, dim=-1) + return sos_token + + +@MODELS.register_module() +class SideAdapterCLIPHead(BaseDecodeHead): + """Side Adapter Network (SAN) for open-vocabulary semantic segmentation + with pre-trained vision-language model. + + This decode head is the implementation of `Side Adapter Network + for Open-Vocabulary Semantic Segmentation` + . + Modified from https://github.com/MendelXu/SAN/blob/main/san/model/side_adapter/side_adapter.py # noqa:E501 + Copyright (c) 2023 MendelXu. + Licensed under the MIT License + + Args: + num_classes (int): the number of classes. + san_cfg (ConfigType): Configs for SideAdapterNetwork module + maskgen_cfg (ConfigType): Configs for RecWithAttnbias module + """ + + def __init__(self, num_classes: int, san_cfg: ConfigType, + maskgen_cfg: ConfigType, deep_supervision_idxs: List[int], + train_cfg: ConfigType, **kwargs): + super().__init__( + in_channels=san_cfg.in_channels, + channels=san_cfg.embed_dims, + num_classes=num_classes, + **kwargs) + assert san_cfg.num_queries == maskgen_cfg.sos_token_num, \ + 'num_queries in san_cfg should be equal to sos_token_num ' \ + 'in maskgen_cfg' + del self.conv_seg + self.side_adapter_network = SideAdapterNetwork(**san_cfg) + self.rec_with_attnbias = RecWithAttnbias(**maskgen_cfg) + self.deep_supervision_idxs = deep_supervision_idxs + self.train_cfg = train_cfg + if train_cfg: + self.match_masks = MatchMasks( + num_points=train_cfg.num_points, + num_queries=san_cfg.num_queries, + num_classes=num_classes, + assigner=train_cfg.assigner) + + def init_weights(self): + + rec_state_dict = None + if isinstance(self.init_cfg, dict) and \ + self.init_cfg.get('type') == 'Pretrained_Part': + checkpoint = CheckpointLoader.load_checkpoint( + self.init_cfg['checkpoint'], logger=None, map_location='cpu') + + rec_state_dict = checkpoint.copy() + para_prefix = 'decode_head.rec_with_attnbias' + prefix_len = len(para_prefix) + 1 + for k, v in checkpoint.items(): + rec_state_dict.pop(k) + if para_prefix in k: + rec_state_dict[k[prefix_len:]] = v + + self.side_adapter_network.init_weights() + self.rec_with_attnbias.init_weights(rec_state_dict) + + def forward(self, inputs: Tuple[Tensor], + deep_supervision_idxs) -> Tuple[List]: + """Forward function. + + Args: + inputs (Tuple[Tensor]): A triplet including images, + list of multi-level visual features from image encoder and + class embeddings from text_encoder. + + Returns: + mask_props (List[Tensor]): Mask proposals predicted by SAN. + mask_logits (List[Tensor]): Class logits of mask proposals. + """ + imgs, clip_feature, class_embeds = inputs + # predict mask proposals and attention bias + mask_props, attn_biases = self.side_adapter_network( + imgs, clip_feature, deep_supervision_idxs) + + # mask recognition with attention bias + mask_embeds = [ + self.rec_with_attnbias(att_bias, clip_feature[-1]) + for att_bias in attn_biases + ] + # Obtain class prediction of masks by comparing the similarity + # between the image token and the text embedding of class names. + mask_logits = [ + torch.einsum('bqc,nc->bqn', mask_embed, class_embeds) + for mask_embed in mask_embeds + ] + return mask_props, mask_logits + + def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict], + test_cfg: ConfigType) -> Tensor: + """Forward function for prediction. + + Args: + inputs (Tuple[Tensor]): Images, visual features from image encoder + and class embedding from text encoder. + batch_img_metas (dict): List Image info where each dict may also + contain: 'img_shape', 'scale_factor', 'flip', 'img_path', + 'ori_shape', and 'pad_shape'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:PackSegInputs`. + test_cfg (dict): The testing config. + + Returns: + Tensor: Outputs segmentation logits map. + """ + mask_props, mask_logits = self.forward(inputs, []) + + return self.predict_by_feat([mask_props[-1], mask_logits[-1]], + batch_img_metas) + + def predict_by_feat(self, seg_logits: List[Tensor], + batch_img_metas: List[dict]) -> Tensor: + """1. Transform a batch of mask proposals to the input shape. + 2. Generate segmentation map with mask proposals and class logits. + """ + mask_pred = seg_logits[0] + cls_score = seg_logits[1] + if 'pad_shape' in batch_img_metas[0]: + size = batch_img_metas[0]['pad_shape'] + else: + size = batch_img_metas[0]['img_shape'] + # upsample mask + mask_pred = F.interpolate( + mask_pred, size=size, mode='bilinear', align_corners=False) + + mask_cls = F.softmax(cls_score, dim=-1)[..., :-1] + mask_pred = mask_pred.sigmoid() + seg_logits = torch.einsum('bqc,bqhw->bchw', mask_cls, mask_pred) + return seg_logits + + def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList, + train_cfg: ConfigType) -> dict: + """Perform forward propagation and loss calculation of the decoder head + on the features of the upstream network. + + Args: + x (tuple[Tensor]): Multi-level features from the upstream + network, each is a 4D-tensor. + batch_data_samples (List[:obj:`SegDataSample`]): The Data + Samples. It usually includes information such as + `gt_sem_seg`. + train_cfg (ConfigType): Training config. + + Returns: + dict[str, Tensor]: a dictionary of loss components. + """ + # batch SegDataSample to InstanceDataSample + batch_gt_instances = seg_data_to_instance_data(self.ignore_index, + batch_data_samples) + + # forward + all_mask_props, all_mask_logits = self.forward( + x, self.deep_supervision_idxs) + + # loss + losses = self.loss_by_feat(all_mask_logits, all_mask_props, + batch_gt_instances) + + return losses + + def loss_by_feat( + self, all_cls_scores: Tensor, all_mask_preds: Tensor, + batch_gt_instances: List[InstanceData]) -> Dict[str, Tensor]: + """Loss function. + + Args: + all_cls_scores (Tensor): Classification scores for all decoder + layers with shape (num_decoder, batch_size, num_queries, + cls_out_channels). Note `cls_out_channels` should includes + background. + all_mask_preds (Tensor): Mask scores for all decoder layers with + shape (num_decoder, batch_size, num_queries, h, w). + batch_gt_instances (list[obj:`InstanceData`]): each contains + ``labels`` and ``masks``. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + num_dec_layers = len(all_cls_scores) + batch_gt_instances_list = [ + batch_gt_instances for _ in range(num_dec_layers) + ] + + losses = [] + for i in range(num_dec_layers): + cls_scores = all_cls_scores[i] + mask_preds = all_mask_preds[i] + # matching N mask predictions to K category labels + (labels, mask_targets, mask_weights, + avg_factor) = self.match_masks.get_targets( + cls_scores, mask_preds, batch_gt_instances_list[i]) + cls_scores = cls_scores.flatten(0, 1) + labels = labels.flatten(0, 1) + num_total_masks = cls_scores.new_tensor([avg_factor], + dtype=torch.float) + all_reduce(num_total_masks, op='mean') + num_total_masks = max(num_total_masks, 1) + + # extract positive ones + # shape (batch_size, num_queries, h, w) -> (num_total_gts, h, w) + mask_preds = mask_preds[mask_weights > 0] + + if mask_targets.shape[0] != 0: + with torch.no_grad(): + points_coords = get_uncertain_point_coords_with_randomness( + mask_preds.unsqueeze(1), None, + self.train_cfg.num_points, + self.train_cfg.oversample_ratio, + self.train_cfg.importance_sample_ratio) + # shape (num_total_gts, h, w) + # -> (num_total_gts, num_points) + mask_point_targets = point_sample( + mask_targets.unsqueeze(1).float(), + points_coords).squeeze(1) + # shape (num_queries, h, w) -> (num_queries, num_points) + mask_point_preds = point_sample( + mask_preds.unsqueeze(1), points_coords).squeeze(1) + + if not isinstance(self.loss_decode, nn.ModuleList): + losses_decode = [self.loss_decode] + else: + losses_decode = self.loss_decode + loss = dict() + for loss_decode in losses_decode: + if 'loss_cls' in loss_decode.loss_name: + if loss_decode.loss_name == 'loss_cls_ce': + loss[loss_decode.loss_name] = loss_decode( + cls_scores, labels) + else: + assert False, "Only support 'CrossEntropyLoss' in" \ + ' classification loss' + + elif 'loss_mask' in loss_decode.loss_name: + if mask_targets.shape[0] == 0: + loss[loss_decode.loss_name] = mask_preds.sum() + elif loss_decode.loss_name == 'loss_mask_ce': + loss[loss_decode.loss_name] = loss_decode( + mask_point_preds, + mask_point_targets, + avg_factor=num_total_masks * + self.train_cfg.num_points) + elif loss_decode.loss_name == 'loss_mask_dice': + loss[loss_decode.loss_name] = loss_decode( + mask_point_preds, + mask_point_targets, + avg_factor=num_total_masks) + else: + assert False, "Only support 'CrossEntropyLoss' and" \ + " 'DiceLoss' in mask loss" + else: + assert False, "Only support for 'loss_cls' and 'loss_mask'" + + losses.append(loss) + + loss_dict = dict() + # loss from the last decoder layer + loss_dict.update(losses[-1]) + # loss from other decoder layers + for i, loss in enumerate(losses[:-1]): + for k, v in loss.items(): + loss_dict[f'd{self.deep_supervision_idxs[i]}.{k}'] = v + return loss_dict diff --git a/mmseg/models/losses/cross_entropy_loss.py b/mmseg/models/losses/cross_entropy_loss.py index caad344109..65553472c0 100644 --- a/mmseg/models/losses/cross_entropy_loss.py +++ b/mmseg/models/losses/cross_entropy_loss.py @@ -53,8 +53,22 @@ def cross_entropy(pred, # average loss over non-ignored elements # pytorch's official cross_entropy average loss over non-ignored elements # refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660 # noqa - if (avg_factor is None) and avg_non_ignore and reduction == 'mean': - avg_factor = label.numel() - (label == ignore_index).sum().item() + if (avg_factor is None) and reduction == 'mean': + if class_weight is None: + if avg_non_ignore: + avg_factor = label.numel() - (label + == ignore_index).sum().item() + else: + avg_factor = label.numel() + + else: + # the average factor should take the class weights into account + label_weights = torch.tensor([class_weight[cls] for cls in label], + device=class_weight.device) + if avg_non_ignore: + label_weights[label == ignore_index] = 0 + avg_factor = label_weights.sum() + if weight is not None: weight = weight.float() loss = weight_reduce_loss( diff --git a/mmseg/models/losses/dice_loss.py b/mmseg/models/losses/dice_loss.py index 65eae8aebc..fb2ffdba8d 100644 --- a/mmseg/models/losses/dice_loss.py +++ b/mmseg/models/losses/dice_loss.py @@ -66,10 +66,11 @@ def dice_loss(pred: torch.Tensor, ignore_index (int, optional): The label index to be ignored. Defaults to 255. """ - num_classes = pred.shape[1] - pred = pred[:, torch.arange(num_classes) != ignore_index, :, :] - target = target[:, torch.arange(num_classes) != ignore_index, :, :] - assert pred.shape[1] != 0 # if the ignored index is the only class + if ignore_index is not None: + num_classes = pred.shape[1] + pred = pred[:, torch.arange(num_classes) != ignore_index, :, :] + target = target[:, torch.arange(num_classes) != ignore_index, :, :] + assert pred.shape[1] != 0 # if the ignored index is the only class input = pred.flatten(1) target = target.flatten(1).float() a = torch.sum(input * target, 1) diff --git a/mmseg/models/segmentors/__init__.py b/mmseg/models/segmentors/__init__.py index ac63c73f74..59b012f417 100644 --- a/mmseg/models/segmentors/__init__.py +++ b/mmseg/models/segmentors/__init__.py @@ -3,9 +3,10 @@ from .cascade_encoder_decoder import CascadeEncoderDecoder from .depth_estimator import DepthEstimator from .encoder_decoder import EncoderDecoder +from .multimodal_encoder_decoder import MultimodalEncoderDecoder from .seg_tta import SegTTAModel __all__ = [ 'BaseSegmentor', 'EncoderDecoder', 'CascadeEncoderDecoder', 'SegTTAModel', - 'DepthEstimator' + 'MultimodalEncoderDecoder', 'DepthEstimator' ] diff --git a/mmseg/models/segmentors/multimodal_encoder_decoder.py b/mmseg/models/segmentors/multimodal_encoder_decoder.py new file mode 100644 index 0000000000..75aa8b9b17 --- /dev/null +++ b/mmseg/models/segmentors/multimodal_encoder_decoder.py @@ -0,0 +1,350 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional + +import torch.nn.functional as F +from torch import Tensor + +from mmseg.registry import MODELS +from mmseg.utils import (ConfigType, OptConfigType, OptMultiConfig, + OptSampleList, SampleList, add_prefix) +from .base import BaseSegmentor + + +@MODELS.register_module() +class MultimodalEncoderDecoder(BaseSegmentor): + """Multimodal Encoder-Decoder segmentors. + + Multimodal segmentation architecture is used for open-vocabulary + semantic segmentation with combining the visual and language + pretrain models. It consists of a image_encoder (backbone) to extract + visual feature, a text encoder to extract text feature, and a decode + head to generate semantic maps. + Note that the deep supervision during training is implemented in decode head. + + 1. The ``loss`` method is used to calculate the loss of model, + which includes two steps: (1) Extracts features to obtain the feature maps + (2) Call the decode head loss function to forward decode head model and + calculate losses. + + .. code:: text + + loss(): extract_feat() -> _decode_head_forward_train() + _decode_head_forward_train(): decode_head.loss() + + 2. The ``predict`` method is used to predict segmentation results, + which includes two steps: (1) Run inference function to obtain the list of + seg_logits (2) Call post-processing function to obtain list of + ``SegDataSampel`` including ``pred_sem_seg`` and ``seg_logits``. + + .. code:: text + + predict(): inference() -> postprocess_result() + inference(): whole_inference()/slide_inference() + whole_inference()/slide_inference(): encoder_decoder() + encoder_decoder(): extract_feat() -> decode_head.predict() + + 3. The ``_forward`` method is used to output the tensor by running the model, + which includes two steps: (1) Extracts features to obtain the feature maps + (2)Call the decode head forward function to forward decode head model. + + .. code:: text + + _forward(): extract_feat() -> _decode_head.forward() + + Args: + + image_encoder (ConfigType): The config for the visual encoder of segmentor. + text_encoder ((ConfigType): The config for the text encoder of segmentor. + decode_head (ConfigType): The config for the decode head of segmentor. + train_cfg (OptConfigType): The config for training. Defaults to None. + test_cfg (OptConfigType): The config for testing. Defaults to None. + data_preprocessor (dict, optional): The pre-process config of + :class:`BaseDataPreprocessor`. + pretrained (str, optional): The path for pretrained model. + Defaults to None. + asymetric_input (bool): whether to use different size of input for image encoder + and decode head. Defaults to False. + encoder_resolution (float): resize scale of input images for image encoder. + Defaults to None. + init_cfg (dict, optional): The weight initialized config for + :class:`BaseModule`. + """ # noqa: E501 + + def __init__(self, + image_encoder: ConfigType, + text_encoder: ConfigType, + decode_head: ConfigType, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + pretrained: Optional[str] = None, + asymetric_input: bool = True, + encoder_resolution: float = None, + init_cfg: OptMultiConfig = None): + super().__init__( + data_preprocessor=data_preprocessor, init_cfg=init_cfg) + if pretrained is not None: + image_encoder.init_cfg = dict( + type='Pretrained_Part', checkpoint=pretrained) + text_encoder.init_cfg = dict( + type='Pretrained_Part', checkpoint=pretrained) + decode_head.init_cfg = dict( + type='Pretrained_Part', checkpoint=pretrained) + + if asymetric_input: + assert encoder_resolution is not None, \ + 'if asymetric_input set True, ' \ + 'clip_resolution must be a certain value' + self.asymetric_input = asymetric_input + self.encoder_resolution = encoder_resolution + self.image_encoder = MODELS.build(image_encoder) + self.text_encoder = MODELS.build(text_encoder) + self._init_decode_head(decode_head) + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + assert self.with_decode_head + + def _init_decode_head(self, decode_head: ConfigType) -> None: + """Initialize ``decode_head``""" + self.decode_head = MODELS.build(decode_head) + self.align_corners = self.decode_head.align_corners + self.num_classes = self.decode_head.num_classes + self.out_channels = self.decode_head.out_channels + + def extract_feat(self, inputs: Tensor) -> List[Tensor]: + """Extract visual features from images.""" + x = self.image_encoder(inputs) + return x + + def encode_decode(self, inputs: Tensor, + batch_img_metas: List[dict]) -> Tensor: + """Encode the name of classes with text_encoder and encode images with + image_encoder. + + Then decode the class embedding and visual feature into a semantic + segmentation map of the same size as input. + """ + classifier_embeds = self.text_encoder() + clip_inputs = inputs + if self.asymetric_input: + clip_inputs = F.interpolate( + inputs, scale_factor=self.encoder_resolution, mode='bilinear') + x = self.image_encoder(clip_inputs) + seg_logits = self.decode_head.predict([inputs, x, classifier_embeds], + batch_img_metas, self.test_cfg) + + return seg_logits + + def _decode_head_forward_train(self, inputs: List[Tensor], + data_samples: SampleList) -> dict: + """Run forward function and calculate loss for decode head in + training.""" + losses = dict() + loss_decode = self.decode_head.loss(inputs, data_samples, + self.train_cfg) + + losses.update(add_prefix(loss_decode, 'decode')) + return losses + + def loss(self, inputs: Tensor, data_samples: SampleList) -> dict: + """Calculate losses from a batch of inputs and data samples. + + Args: + inputs (Tensor): Input images. + data_samples (list[:obj:`SegDataSample`]): The seg data samples. + It usually includes information such as `metainfo` and + `gt_sem_seg`. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + classifier_embeds = self.text_encoder() + clip_inputs = inputs + if self.asymetric_input: + clip_inputs = F.interpolate( + inputs, scale_factor=self.encoder_resolution, mode='bilinear') + x = self.image_encoder(clip_inputs) + + losses = dict() + + loss_decode = self._decode_head_forward_train( + [inputs, x, classifier_embeds], data_samples) + losses.update(loss_decode) + + return losses + + def predict(self, + inputs: Tensor, + data_samples: OptSampleList = None) -> SampleList: + """Predict results from a batch of inputs and data samples with post- + processing. + + Args: + inputs (Tensor): Inputs with shape (N, C, H, W). + data_samples (List[:obj:`SegDataSample`], optional): The seg data + samples. It usually includes information such as `metainfo` + and `gt_sem_seg`. + + Returns: + list[:obj:`SegDataSample`]: Segmentation results of the + input images. Each SegDataSample usually contain: + + - ``pred_sem_seg``(PixelData): Prediction of semantic segmentation. + - ``seg_logits``(PixelData): Predicted logits of semantic + segmentation before normalization. + """ + if data_samples is not None: + batch_img_metas = [ + data_sample.metainfo for data_sample in data_samples + ] + else: + batch_img_metas = [ + dict( + ori_shape=inputs.shape[2:], + img_shape=inputs.shape[2:], + pad_shape=inputs.shape[2:], + padding_size=[0, 0, 0, 0]) + ] * inputs.shape[0] + + seg_logits = self.inference(inputs, batch_img_metas) + + return self.postprocess_result(seg_logits, data_samples) + + def _forward(self, + inputs: Tensor, + data_samples: OptSampleList = None) -> Tensor: + """Network forward process. + + Args: + inputs (Tensor): Inputs with shape (N, C, H, W). + data_samples (List[:obj:`SegDataSample`]): The seg + data samples. It usually includes information such + as `metainfo` and `gt_sem_seg`. + + Returns: + Tensor: Forward output of model without any post-processes. + """ + x = self.extract_feat(inputs) + return self.decode_head.forward(x) + + def slide_inference(self, inputs: Tensor, + batch_img_metas: List[dict]) -> Tensor: + """Inference by sliding-window with overlap. + + If h_crop > h_img or w_crop > w_img, the small patch will be used to + decode without padding. + + Args: + inputs (tensor): the tensor should have a shape NxCxHxW, + which contains all images in the batch. + batch_img_metas (List[dict]): List of image metainfo where each may + also contain: 'img_shape', 'scale_factor', 'flip', 'img_path', + 'ori_shape', and 'pad_shape'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:PackSegInputs`. + + Returns: + Tensor: The segmentation results, seg_logits from model of each + input image. + """ + + h_stride, w_stride = self.test_cfg.stride + h_crop, w_crop = self.test_cfg.crop_size + batch_size, _, h_img, w_img = inputs.size() + out_channels = self.out_channels + h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 + w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 + preds = inputs.new_zeros((batch_size, out_channels, h_img, w_img)) + count_mat = inputs.new_zeros((batch_size, 1, h_img, w_img)) + for h_idx in range(h_grids): + for w_idx in range(w_grids): + y1 = h_idx * h_stride + x1 = w_idx * w_stride + y2 = min(y1 + h_crop, h_img) + x2 = min(x1 + w_crop, w_img) + y1 = max(y2 - h_crop, 0) + x1 = max(x2 - w_crop, 0) + crop_img = inputs[:, :, y1:y2, x1:x2] + # change the image shape to patch shape + batch_img_metas[0]['img_shape'] = crop_img.shape[2:] + # the output of encode_decode is seg logits tensor map + # with shape [N, C, H, W] + crop_seg_logit = self.encode_decode(crop_img, batch_img_metas) + preds += F.pad(crop_seg_logit, + (int(x1), int(preds.shape[3] - x2), int(y1), + int(preds.shape[2] - y2))) + + count_mat[:, :, y1:y2, x1:x2] += 1 + assert (count_mat == 0).sum() == 0 + seg_logits = preds / count_mat + + return seg_logits + + def whole_inference(self, inputs: Tensor, + batch_img_metas: List[dict]) -> Tensor: + """Inference with full image. + + Args: + inputs (Tensor): The tensor should have a shape NxCxHxW, which + contains all images in the batch. + batch_img_metas (List[dict]): List of image metainfo where each may + also contain: 'img_shape', 'scale_factor', 'flip', 'img_path', + 'ori_shape', and 'pad_shape'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:PackSegInputs`. + + Returns: + Tensor: The segmentation results, seg_logits from model of each + input image. + """ + + seg_logits = self.encode_decode(inputs, batch_img_metas) + + return seg_logits + + def inference(self, inputs: Tensor, batch_img_metas: List[dict]) -> Tensor: + """Inference with slide/whole style. + + Args: + inputs (Tensor): The input image of shape (N, 3, H, W). + batch_img_metas (List[dict]): List of image metainfo where each may + also contain: 'img_shape', 'scale_factor', 'flip', 'img_path', + 'ori_shape', 'pad_shape', and 'padding_size'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:PackSegInputs`. + + Returns: + Tensor: The segmentation results, seg_logits from model of each + input image. + """ + + assert self.test_cfg.mode in ['slide', 'whole'] + ori_shape = batch_img_metas[0]['ori_shape'] + assert all(_['ori_shape'] == ori_shape for _ in batch_img_metas) + if self.test_cfg.mode == 'slide': + seg_logit = self.slide_inference(inputs, batch_img_metas) + else: + seg_logit = self.whole_inference(inputs, batch_img_metas) + + return seg_logit + + def aug_test(self, inputs, batch_img_metas, rescale=True): + """Test with augmentations. + + Only rescale=True is supported. + """ + # aug_test rescale all imgs back to ori_shape for now + assert rescale + # to save memory, we get augmented seg logit inplace + seg_logit = self.inference(inputs[0], batch_img_metas[0], rescale) + for i in range(1, len(inputs)): + cur_seg_logit = self.inference(inputs[i], batch_img_metas[i], + rescale) + seg_logit += cur_seg_logit + seg_logit /= len(inputs) + seg_pred = seg_logit.argmax(dim=1) + # unravel batch dim + seg_pred = list(seg_pred) + return seg_pred diff --git a/mmseg/models/text_encoder/__init__.py b/mmseg/models/text_encoder/__init__.py new file mode 100644 index 0000000000..199856d9d7 --- /dev/null +++ b/mmseg/models/text_encoder/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .clip_text_encoder import CLIPTextEncoder + +__all__ = ['CLIPTextEncoder'] diff --git a/mmseg/models/text_encoder/clip_text_encoder.py b/mmseg/models/text_encoder/clip_text_encoder.py new file mode 100644 index 0000000000..1a18b86395 --- /dev/null +++ b/mmseg/models/text_encoder/clip_text_encoder.py @@ -0,0 +1,229 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import numpy as np +import torch +import torch.nn as nn +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.transformer import BaseTransformerLayer +from mmengine.model import BaseModule, ModuleList +from mmengine.runner.checkpoint import CheckpointLoader, load_state_dict +from torch.nn import functional as F + +from mmseg.registry import MODELS +from mmseg.utils import get_classes, get_predefined_templates, tokenizer + + +@MODELS.register_module() +class CLIPTextEncoder(BaseModule): + """A text encoder with transformer architecture to encode the label text. + + Modified from https://github.com/MendelXu/SAN/blob/main/san/model/clip_utils/classifier.py # noqa:E501 + Copyright (c) 2023 MendelXu. + Licensed under the MIT License + + Args: + dataset_name: (str|None): The name of the dataset to which + the data belongs. + vocabulary: (List[str]|None): The list of class names. Default: None. + templates: (List[str]|None): The prompt template used for labels. + Default: None. + total_vocab_size: (int): Number of all words used by the pre-trained + model. Default: 49408 (CLIP). + context_length: (int): The max length of prompt text. + Default: 77 (CLIP). + embed_dims: (int): Width of transformer model. Default: 512. + num_layers: (int): Depth of transformer. Default: 12, + num_heads: (int): Number of attention heads in transformer. + Default: 8, + mlp_ratio: (int) Ratio of mlp hidden dim to embedding dim in + transformer. Default: 4, + output_dims: (int) Dim of output text embeddings. Default: 512, + cache_feature: (bool) Whether to save class embeddings in cache. + Default: True, + cat_bg: (bool) Whether to add background embedding. Default: True. + norm_cfg (dict|None): Config for norm layer. Default: dict(type='LN') + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + dataset_name: str = None, + vocabulary: List[str] = None, + templates: str = 'vild', + total_vocab_size: int = 49408, + context_length: int = 77, + embed_dims: int = 512, + num_layers: int = 12, + num_heads: int = 8, + mlp_ratio: int = 4, + output_dims: int = 512, + cache_feature: bool = True, + cat_bg: bool = True, + norm_cfg: dict = dict(type='LN'), + init_cfg: dict = None): + super().__init__(init_cfg) + if isinstance(templates, List): + self.templates = templates + else: + self.templates = get_predefined_templates(templates) + + assert dataset_name is not None or vocabulary is not None, \ + "text_encoder required either 'dataset_name' or 'vocabulary'" + assert dataset_name is None or vocabulary is None, \ + "there is conflict between 'dataset_name' and 'vocabulary'" + self.dataset_name = dataset_name + self.vocabulary = vocabulary + self.num_pos = context_length + self.token_embedding = nn.Embedding(total_vocab_size, embed_dims) + self.positional_embedding = nn.Parameter( + torch.empty(context_length, embed_dims)) + self.text_projection = nn.Parameter( + torch.empty(embed_dims, output_dims)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + self.transformer = ModuleList() + self.register_buffer( + 'attn_mask', self.build_attention_mask(), persistent=False) + for i in range(num_layers): + self.transformer.append( + BaseTransformerLayer( + attn_cfgs=dict( + type='MultiheadAttention', + embed_dims=embed_dims, + num_heads=num_heads, + batch_first=False, + bias=True), + ffn_cfgs=dict( + type='FFN', + embed_dims=embed_dims, + feedforward_channels=mlp_ratio * embed_dims, + act_cfg=dict(type='QuickGELU')), + operation_order=('norm', 'self_attn', 'norm', 'ffn'))) + self.ln_final = build_norm_layer( + norm_cfg, embed_dims, postfix='_final')[1] + + self.cache_feature = cache_feature + if self.cache_feature: + self.cache = {} + + self._freeze() + + self.cat_bg = cat_bg + if self.cat_bg: + self.bg_embed = nn.Parameter( + torch.randn(1, self.text_projection.shape[1])) + + @property + def ln_final(self): + return getattr(self, self.final_name) + + def build_attention_mask(self): + """lazily create causal attention mask, with full attention between the + tokens. + + pytorch uses additive attention mask; fill with -inf + """ + mask = torch.empty(self.num_pos, self.num_pos) + mask.fill_(float('-inf')) + mask.triu_(1) # zero out the lower diagonal + return mask + + def _freeze(self): + for param in self.parameters(): + param.requires_grad = False + + def init_weights(self): + if self.cat_bg: + nn.init.normal_( + self.bg_embed, + std=self.bg_embed.shape[1]**-0.5, + ) + if isinstance(self.init_cfg, dict) and \ + self.init_cfg.get('type') == 'Pretrained_Part': + checkpoint = CheckpointLoader.load_checkpoint( + self.init_cfg['checkpoint'], logger=None, map_location='cpu') + + state_dict = checkpoint.copy() + para_prefix = 'text_encoder' + prefix_len = len(para_prefix) + 1 + for k, v in checkpoint.items(): + state_dict.pop(k) + if para_prefix in k: + state_dict[k[prefix_len:]] = v + + load_state_dict(self, state_dict, strict=False, logger=None) + + else: + super().init_weights() + + @torch.no_grad() + def encode_text(self, text, normalize=False): + """encode class token.""" + + embed_device = self.token_embedding.weight.device + x = self.token_embedding( + text.to(embed_device)) # [batch_size, n_ctx, d_model] + x = x + self.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + for block in self.transformer: + x = block(query=x, attn_masks=self.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) # [batch_size, n_ctx, transformer.width] + # take features from the eot embedding + # (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), + text.argmax(dim=-1)] @ self.text_projection + return F.normalize(x, dim=-1) if normalize else x + + def template_encode(self, vocabulary): + """Prompt engineering.""" + text_embed_bucket = [] + for template in self.templates: + text_inputs = tokenizer.tokenize( + [template.format(noun) for noun in vocabulary]) + text_embed = self.encode_text(text_inputs, normalize=True) + text_embed_bucket.append(text_embed) + text_embed = torch.stack(text_embed_bucket).mean(dim=0) + text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True) + return text_embed + + def forward(self): + """Forward function.""" + if self.dataset_name is None: # encoding vocabulary directly + class_names = self.vocabulary + if self.cache_feature: + new_classes = [ + word for word in class_names if word not in self.cache + ] + if len(new_classes) > 0: + class_embeds = self.template_encode(new_classes) + self.cache.update(dict(zip(new_classes, class_embeds))) + class_embeds = torch.stack( + [self.cache[word] for word in class_names]) + else: + class_embeds = self.template_encode(class_names) + + else: # encoding the classes of the dataset + class_names = get_classes(self.dataset_name) + if class_names[0] == 'background': + class_names = class_names[1:] + if self.cache_feature: + if self.dataset_name not in self.cache: + class_embeds = self.template_encode(class_names) + self.cache[self.dataset_name] = class_embeds + else: + class_embeds = self.cache[self.dataset_name] + else: + class_embeds = self.template_encode(class_names) + + if self.cat_bg: + class_embeds = torch.cat([class_embeds, self.bg_embed]) + class_embeds = F.normalize(class_embeds, p=2, dim=-1) + return self.logit_scale.exp() * class_embeds + + +@MODELS.register_module() +class QuickGELU(nn.Module): + # From https://github.com/openai/CLIP/blob/main/clip/model.py + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) diff --git a/mmseg/models/utils/__init__.py b/mmseg/models/utils/__init__.py index fc142f16fc..c0751b17c0 100644 --- a/mmseg/models/utils/__init__.py +++ b/mmseg/models/utils/__init__.py @@ -4,6 +4,7 @@ from .encoding import Encoding from .inverted_residual import InvertedResidual, InvertedResidualV3 from .make_divisible import make_divisible +from .point_sample import get_uncertain_point_coords_with_randomness from .ppm import DAPPM, PAPPM from .res_layer import ResLayer from .se_layer import SELayer @@ -11,11 +12,16 @@ from .shape_convert import (nchw2nlc2nchw, nchw_to_nlc, nlc2nchw2nlc, nlc_to_nchw) from .up_conv_block import UpConvBlock + +# isort: off from .wrappers import Upsample, resize +from .san_layers import MLP, LayerNorm2d, cross_attn_layer __all__ = [ 'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual', 'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'PatchEmbed', 'nchw_to_nlc', 'nlc_to_nchw', 'nchw2nlc2nchw', 'nlc2nchw2nlc', 'Encoding', - 'Upsample', 'resize', 'DAPPM', 'PAPPM', 'BasicBlock', 'Bottleneck' + 'Upsample', 'resize', 'DAPPM', 'PAPPM', 'BasicBlock', 'Bottleneck', + 'cross_attn_layer', 'LayerNorm2d', 'MLP', + 'get_uncertain_point_coords_with_randomness' ] diff --git a/mmseg/models/utils/point_sample.py b/mmseg/models/utils/point_sample.py new file mode 100644 index 0000000000..1afc957f3d --- /dev/null +++ b/mmseg/models/utils/point_sample.py @@ -0,0 +1,88 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmcv.ops import point_sample +from torch import Tensor + + +def get_uncertainty(mask_preds: Tensor, labels: Tensor) -> Tensor: + """Estimate uncertainty based on pred logits. + + We estimate uncertainty as L1 distance between 0.0 and the logits + prediction in 'mask_preds' for the foreground class in `classes`. + + Args: + mask_preds (Tensor): mask predication logits, shape (num_rois, + num_classes, mask_height, mask_width). + + labels (Tensor): Either predicted or ground truth label for + each predicted mask, of length num_rois. + + Returns: + scores (Tensor): Uncertainty scores with the most uncertain + locations having the highest uncertainty score, + shape (num_rois, 1, mask_height, mask_width) + """ + if mask_preds.shape[1] == 1: + gt_class_logits = mask_preds.clone() + else: + inds = torch.arange(mask_preds.shape[0], device=mask_preds.device) + gt_class_logits = mask_preds[inds, labels].unsqueeze(1) + return -torch.abs(gt_class_logits) + + +def get_uncertain_point_coords_with_randomness( + mask_preds: Tensor, labels: Tensor, num_points: int, + oversample_ratio: float, importance_sample_ratio: float) -> Tensor: + """Get ``num_points`` most uncertain points with random points during + train. + + Sample points in [0, 1] x [0, 1] coordinate space based on their + uncertainty. The uncertainties are calculated for each point using + 'get_uncertainty()' function that takes point's logit prediction as + input. + + Args: + mask_preds (Tensor): A tensor of shape (num_rois, num_classes, + mask_height, mask_width) for class-specific or class-agnostic + prediction. + labels (Tensor): The ground truth class for each instance. + num_points (int): The number of points to sample. + oversample_ratio (float): Oversampling parameter. + importance_sample_ratio (float): Ratio of points that are sampled + via importnace sampling. + + Returns: + point_coords (Tensor): A tensor of shape (num_rois, num_points, 2) + that contains the coordinates sampled points. + """ + assert oversample_ratio >= 1 + assert 0 <= importance_sample_ratio <= 1 + batch_size = mask_preds.shape[0] + num_sampled = int(num_points * oversample_ratio) + point_coords = torch.rand( + batch_size, num_sampled, 2, device=mask_preds.device) + point_logits = point_sample(mask_preds, point_coords) + # It is crucial to calculate uncertainty based on the sampled + # prediction value for the points. Calculating uncertainties of the + # coarse predictions first and sampling them for points leads to + # incorrect results. To illustrate this: assume uncertainty func( + # logits)=-abs(logits), a sampled point between two coarse + # predictions with -1 and 1 logits has 0 logits, and therefore 0 + # uncertainty value. However, if we calculate uncertainties for the + # coarse predictions first, both will have -1 uncertainty, + # and sampled point will get -1 uncertainty. + point_uncertainties = get_uncertainty(point_logits, labels) + num_uncertain_points = int(importance_sample_ratio * num_points) + num_random_points = num_points - num_uncertain_points + idx = torch.topk( + point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1] + shift = num_sampled * torch.arange( + batch_size, dtype=torch.long, device=mask_preds.device) + idx += shift[:, None] + point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view( + batch_size, num_uncertain_points, 2) + if num_random_points > 0: + rand_roi_coords = torch.rand( + batch_size, num_random_points, 2, device=mask_preds.device) + point_coords = torch.cat((point_coords, rand_roi_coords), dim=1) + return point_coords diff --git a/mmseg/models/utils/san_layers.py b/mmseg/models/utils/san_layers.py new file mode 100644 index 0000000000..2267686daf --- /dev/null +++ b/mmseg/models/utils/san_layers.py @@ -0,0 +1,418 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Modified from https://github.com/MendelXu/SAN/blob/main/san/model/attn_helper.py # noqa: E501 +# Copyright (c) 2023 MendelXu. +# Licensed under the MIT License + +import warnings +from typing import Optional + +import torch +from mmcv.cnn.bricks.transformer import BaseTransformerLayer +from torch import Tensor, nn +from torch.nn import functional as F + + +def cross_attn_with_self_bias( + query: Tensor, + key: Tensor, + value: Tensor, + embed_dim_to_check: int, + num_heads: int, + in_proj_weight: Tensor, + in_proj_bias: Tensor, + bias_k: Optional[Tensor], + bias_v: Optional[Tensor], + add_zero_attn: bool, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Tensor, + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + use_separate_proj_weight: bool = False, + q_proj_weight: Optional[Tensor] = None, + k_proj_weight: Optional[Tensor] = None, + v_proj_weight: Optional[Tensor] = None, + static_k: Optional[Tensor] = None, + static_v: Optional[Tensor] = None, +): + """Forward function of multi-head attention. Modified from + multi_head_attention_forward in + https://github.com/pytorch/pytorch/blob/main/torch/nn/functional.py. + + Args: + query, key, value: map a query and a set of key-value pairs to an output. + See "Attention Is All You Need" for more details. + embed_dim_to_check: total dimension of the model. + num_heads: parallel attention heads. + in_proj_weight, in_proj_bias: input projection weight and bias. + bias_k, bias_v: bias of the key and value sequences to be added at dim=0. + add_zero_attn: add a new batch of zeros to the key and + value sequences at dim=1. + dropout_p: probability of an element to be zeroed. + out_proj_weight, out_proj_bias: the output projection weight and bias. + training: apply dropout if is ``True``. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + need_weights: output attn_output_weights. + Default: `True` + Note: `needs_weight` defaults to `True`, but should be set to `False` + For best performance when attention weights are not needed. + *Setting needs_weights to `True` + leads to a significant performance degradation.* + attn_mask: 2D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + use_separate_proj_weight: the function accept the proj. weights for query, key, + and value in different forms. If false, in_proj_weight will be used, which is + a combination of q_proj_weight, k_proj_weight, v_proj_weight. + q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias. + static_k, static_v: static key and value used for attention operators. + """ # noqa: E501 + tgt_len, bsz, embed_dim = query.size() + assert embed_dim == embed_dim_to_check + # allow MHA to have different sizes for the feature dimension + assert key.size(0) == value.size(0) and key.size(1) == value.size(1) + + head_dim = embed_dim // num_heads + assert head_dim * num_heads == embed_dim, \ + 'embed_dim must be divisible by num_heads' + scaling = float(head_dim)**-0.5 + + if not use_separate_proj_weight: + if (query is key or torch.equal( + query, key)) and (key is value or torch.equal(key, value)): + # self-attention + raise NotImplementedError('self-attention is not implemented') + + elif key is value or torch.equal(key, value): + # encoder-decoder attention + # This is inline in_proj function + # with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = F.linear(query, _w, _b) + + if key is None: + assert value is None + k = None + v = None + q_k = None + q_v = None + else: + # This is inline in_proj function with + # in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + k, v = F.linear(key, _w, _b).chunk(2, dim=-1) + q_k, q_v = F.linear(query, _w, _b).chunk(2, dim=-1) + else: + # This is inline in_proj function with + # in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = F.linear(query, _w, _b) + + # This is inline in_proj function with + # in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = embed_dim * 2 + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + k = F.linear(key, _w, _b) + q_k = F.linear(query, _w, _b) + # This is inline in_proj function with + # in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim * 2 + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + v = F.linear(value, _w, _b) + q_v = F.linear(query, _w, _b) + else: + q_proj_weight_non_opt = \ + torch.jit._unwrap_optional(q_proj_weight) + len1, len2 = q_proj_weight_non_opt.size() + assert len1 == embed_dim and len2 == query.size(-1) + + k_proj_weight_non_opt = \ + torch.jit._unwrap_optional(k_proj_weight) + len1, len2 = k_proj_weight_non_opt.size() + assert len1 == embed_dim and len2 == key.size(-1) + + v_proj_weight_non_opt = \ + torch.jit._unwrap_optional(v_proj_weight) + len1, len2 = v_proj_weight_non_opt.size() + assert len1 == embed_dim and len2 == value.size(-1) + + if in_proj_bias is not None: + q = F.linear(query, q_proj_weight_non_opt, + in_proj_bias[0:embed_dim]) + k = F.linear(key, k_proj_weight_non_opt, + in_proj_bias[embed_dim:(embed_dim * 2)]) + v = F.linear(value, v_proj_weight_non_opt, + in_proj_bias[(embed_dim * 2):]) + else: + q = F.linear(query, q_proj_weight_non_opt, in_proj_bias) + k = F.linear(key, k_proj_weight_non_opt, in_proj_bias) + v = F.linear(value, v_proj_weight_non_opt, in_proj_bias) + q = q * scaling + + if attn_mask is not None: + assert ( + attn_mask.dtype == torch.float32 + or attn_mask.dtype == torch.float64 + or attn_mask.dtype == torch.float16 + or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool + ), 'Only float, byte, and bool types are supported for ' \ + 'attn_mask, not {}'.format(attn_mask.dtype) + if attn_mask.dtype == torch.uint8: + warnings.warn('Byte tensor for attn_mask in nn.MultiheadAttention ' + 'is deprecated. Use bool tensor instead.') + attn_mask = attn_mask.to(torch.bool) + + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0) + if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: + raise RuntimeError( + 'The size of the 2D attn_mask is not correct.') + elif attn_mask.dim() == 3: + if list(attn_mask.size()) != [ + bsz * num_heads, + query.size(0), key.size(0) + ]: + raise RuntimeError( + 'The size of the 3D attn_mask is not correct.') + else: + raise RuntimeError( + "attn_mask's dimension {} is not supported".format( + attn_mask.dim())) + # attn_mask's dim is 3 now. + + # convert ByteTensor key_padding_mask to bool + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + warnings.warn( + 'Byte tensor for key_padding_mask in nn.MultiheadAttention ' + 'is deprecated. Use bool tensor instead.') + key_padding_mask = key_padding_mask.to(torch.bool) + + if bias_k is not None and bias_v is not None: + if static_k is None and static_v is None: + k = torch.cat([k, bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = F.pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = F.pad(key_padding_mask, (0, 1)) + else: + assert static_k is None, 'bias cannot be added to static key.' + assert static_v is None, 'bias cannot be added to static value.' + else: + assert bias_k is None + assert bias_v is None + + q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1) + if k is not None: + k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + q_k = q_k.contiguous().view(tgt_len, bsz * num_heads, + head_dim).transpose(0, 1) + if v is not None: + v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + q_v = q_v.contiguous().view(tgt_len, bsz * num_heads, + head_dim).transpose(0, 1) + + if static_k is not None: + assert static_k.size(0) == bsz * num_heads + assert static_k.size(2) == head_dim + k = static_k + + if static_v is not None: + assert static_v.size(0) == bsz * num_heads + assert static_v.size(2) == head_dim + v = static_v + + src_len = k.size(1) + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if add_zero_attn: + src_len += 1 + k = torch.cat( + [ + k, + torch.zeros( + (k.size(0), 1) + k.size()[2:], + dtype=k.dtype, + device=k.device), + ], + dim=1, + ) + v = torch.cat( + [ + v, + torch.zeros( + (v.size(0), 1) + v.size()[2:], + dtype=v.dtype, + device=v.device), + ], + dim=1, + ) + if attn_mask is not None: + attn_mask = F.pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = F.pad(key_padding_mask, (0, 1)) + + attn_output_weights = torch.bmm(q, k.transpose(1, 2)) + assert list( + attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len] + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_output_weights.masked_fill_(attn_mask, float('-inf')) + else: + attn_output_weights += attn_mask + + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, + src_len) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float('-inf'), + ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, + tgt_len, src_len) + # attn_out_weights: [bsz * num_heads, tgt_len, src_len] + # ->[bsz * num_heads, tgt_len, src_len+1] + self_weight = (q * q_k).sum( + dim=-1, keepdim=True) # [bsz * num_heads, tgt_len, 1] + total_attn_output_weights = torch.cat([attn_output_weights, self_weight], + dim=-1) + total_attn_output_weights = F.softmax(total_attn_output_weights, dim=-1) + total_attn_output_weights = F.dropout( + total_attn_output_weights, p=dropout_p, training=training) + attn_output_weights = \ + total_attn_output_weights[:, :, : -1] + # [bsz * num_heads, tgt_len, src_len] + self_weight = \ + total_attn_output_weights[:, :, -1:] # [bsz * num_heads, tgt_len, 1] + + attn_output = torch.bmm(attn_output_weights, + v) # [bsz * num_heads, tgt_len, head_dim] + attn_output = (attn_output + self_weight * q_v + ) # [bsz * num_heads, tgt_len, head_dim] + assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] + attn_output = attn_output.transpose(0, 1).contiguous().view( + tgt_len, bsz, embed_dim) + attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias) + + if need_weights: + # average attention weights over heads + attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, + src_len) + return attn_output, attn_output_weights # .sum(dim=1) / num_heads + else: + return attn_output, None + + +def cross_attn_layer(tf_layer: BaseTransformerLayer, x, mem, attn_bias): + """Implementation of transformer layer with cross attention. The cross + attention shares the embedding weights with self-attention of tf_layer. + Args: + tf_layer: (TransformerEncoderLayer): The Module of transformer layer. + x (Tensor): query [K,N,C] + mem (Tensor): key and value [L,N,C] + attn_bias (Tensor): attention bias [N*num_head,K,L] + + Return: + x (Tensor): cross attention output [K,N,C] + """ + self_attn_layer = tf_layer.attentions[0].attn + attn_layer_paras = { + 'embed_dim_to_check': self_attn_layer.embed_dim, + 'num_heads': self_attn_layer.num_heads, + 'in_proj_weight': self_attn_layer.in_proj_weight, + 'in_proj_bias': self_attn_layer.in_proj_bias, + 'bias_k': self_attn_layer.bias_k, + 'bias_v': self_attn_layer.bias_v, + 'add_zero_attn': self_attn_layer.add_zero_attn, + 'dropout_p': self_attn_layer.dropout, + 'out_proj_weight': self_attn_layer.out_proj.weight, + 'out_proj_bias': self_attn_layer.out_proj.bias, + 'training': self_attn_layer.training + } + + q_x = tf_layer.norms[0](x) + k_x = v_x = tf_layer.norms[0](mem) + x = x + cross_attn_with_self_bias( + q_x, + k_x, + v_x, + attn_mask=attn_bias, + need_weights=False, + **attn_layer_paras)[0] + x = tf_layer.ffns[0](tf_layer.norms[1](x), identity=x) + return x + + +class LayerNorm2d(nn.Module): + """A LayerNorm variant, popularized by Transformers, that performs point- + wise mean and variance normalization over the channel dimension for inputs + that have shape (batch_size, channels, height, width). + + https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa B950 + """ + + def __init__(self, normalized_shape, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.normalized_shape = (normalized_shape, ) + + def forward(self, x: torch.Tensor): + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +class MLP(nn.Module): + """Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, + input_dim, + hidden_dim, + output_dim, + num_layers, + affine_func=nn.Linear): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + affine_func(n, k) + for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x: torch.Tensor): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x diff --git a/mmseg/utils/__init__.py b/mmseg/utils/__init__.py index f69043764a..0a2af58c6e 100644 --- a/mmseg/utils/__init__.py +++ b/mmseg/utils/__init__.py @@ -11,23 +11,60 @@ vaihingen_palette, voc_classes, voc_palette) # yapf: enable from .collect_env import collect_env +from .get_templates import get_predefined_templates from .io import datafrombytes from .misc import add_prefix, stack_batch from .set_env import register_all_modules +from .tokenizer import tokenize from .typing_utils import (ConfigType, ForwardResults, MultiConfig, OptConfigType, OptMultiConfig, OptSampleList, SampleList, TensorDict, TensorList) +# isort: off +from .mask_classification import MatchMasks, seg_data_to_instance_data + __all__ = [ - 'collect_env', 'register_all_modules', 'stack_batch', 'add_prefix', - 'ConfigType', 'OptConfigType', 'MultiConfig', 'OptMultiConfig', - 'SampleList', 'OptSampleList', 'TensorDict', 'TensorList', - 'ForwardResults', 'cityscapes_classes', 'ade_classes', 'voc_classes', - 'cocostuff_classes', 'loveda_classes', 'potsdam_classes', - 'vaihingen_classes', 'isaid_classes', 'stare_classes', - 'cityscapes_palette', 'ade_palette', 'voc_palette', 'cocostuff_palette', - 'loveda_palette', 'potsdam_palette', 'vaihingen_palette', 'isaid_palette', - 'stare_palette', 'dataset_aliases', 'get_classes', 'get_palette', - 'datafrombytes', 'synapse_palette', 'synapse_classes', 'bdd100k_classes', - 'bdd100k_palette' + 'collect_env', + 'register_all_modules', + 'stack_batch', + 'add_prefix', + 'ConfigType', + 'OptConfigType', + 'MultiConfig', + 'OptMultiConfig', + 'SampleList', + 'OptSampleList', + 'TensorDict', + 'TensorList', + 'ForwardResults', + 'cityscapes_classes', + 'ade_classes', + 'voc_classes', + 'cocostuff_classes', + 'loveda_classes', + 'potsdam_classes', + 'vaihingen_classes', + 'isaid_classes', + 'stare_classes', + 'cityscapes_palette', + 'ade_palette', + 'voc_palette', + 'cocostuff_palette', + 'loveda_palette', + 'potsdam_palette', + 'vaihingen_palette', + 'isaid_palette', + 'stare_palette', + 'dataset_aliases', + 'get_classes', + 'get_palette', + 'datafrombytes', + 'synapse_palette', + 'synapse_classes', + 'get_predefined_templates', + 'tokenize', + 'seg_data_to_instance_data', + 'MatchMasks', + 'bdd100k_classes', + 'bdd100k_palette', ] diff --git a/mmseg/utils/bpe_simple_vocab_16e6.txt.gz b/mmseg/utils/bpe_simple_vocab_16e6.txt.gz new file mode 100644 index 0000000000..7b5088a527 Binary files /dev/null and b/mmseg/utils/bpe_simple_vocab_16e6.txt.gz differ diff --git a/mmseg/utils/class_names.py b/mmseg/utils/class_names.py index 122e63fcc4..5ab35f99dc 100644 --- a/mmseg/utils/class_names.py +++ b/mmseg/utils/class_names.py @@ -52,6 +52,21 @@ def voc_classes(): ] +def pcontext_classes(): + """Pascal Context class names for external use.""" + return [ + 'aeroplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle', 'bird', + 'boat', 'book', 'bottle', 'building', 'bus', 'cabinet', 'car', 'cat', + 'ceiling', 'chair', 'cloth', 'computer', 'cow', 'cup', 'curtain', + 'dog', 'door', 'fence', 'floor', 'flower', 'food', 'grass', 'ground', + 'horse', 'keyboard', 'light', 'motorbike', 'mountain', 'mouse', + 'person', 'plate', 'platform', 'pottedplant', 'road', 'rock', 'sheep', + 'shelves', 'sidewalk', 'sign', 'sky', 'snow', 'sofa', 'table', 'track', + 'train', 'tree', 'truck', 'tvmonitor', 'wall', 'water', 'window', + 'wood' + ] + + def cocostuff_classes(): """CocoStuff class names for external use.""" return [ @@ -306,6 +321,25 @@ def voc_palette(): [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]] +def pcontext_palette(): + """Pascal Context palette for external use.""" + return [[180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3], + [120, 120, 80], [140, 140, 140], [204, 5, 255], [230, 230, 230], + [4, 250, 7], [224, 5, 255], [235, 255, 7], [150, 5, 61], + [120, 120, 70], [8, 255, 51], [255, 6, 82], [143, 255, 140], + [204, 255, 4], [255, 51, 7], [204, 70, 3], [0, 102, 200], + [61, 230, 250], [255, 6, 51], [11, 102, 255], [255, 7, 71], + [255, 9, 224], [9, 7, 230], [220, 220, 220], [255, 9, 92], + [112, 9, 255], [8, 255, 214], [7, 255, 224], [255, 184, 6], + [10, 255, 71], [255, 41, 10], [7, 255, 255], [224, 255, 8], + [102, 8, 255], [255, 61, 6], [255, 194, 7], [255, 122, 8], + [0, 255, 20], [255, 8, 41], [255, 5, 153], [6, 51, 255], + [235, 12, 255], [160, 150, 20], [0, 163, 255], [140, 140, 140], + [250, 10, 15], [20, 255, 0], [31, 255, 0], [255, 31, 0], + [255, 224, 0], [153, 255, 0], [0, 0, 255], [255, 71, 0], + [0, 235, 255], [0, 173, 255], [31, 0, 255]] + + def cocostuff_palette(): """CocoStuff palette for external use.""" return [[0, 192, 64], [0, 192, 64], [0, 64, 96], [128, 192, 192], @@ -443,6 +477,7 @@ def bdd100k_palette(): 'cityscapes': ['cityscapes'], 'ade': ['ade', 'ade20k'], 'voc': ['voc', 'pascal_voc', 'voc12', 'voc12aug'], + 'pcontext': ['pcontext', 'pascal_context', 'voc2010'], 'loveda': ['loveda'], 'potsdam': ['potsdam'], 'vaihingen': ['vaihingen'], diff --git a/mmseg/utils/get_templates.py b/mmseg/utils/get_templates.py new file mode 100644 index 0000000000..7e9032ba96 --- /dev/null +++ b/mmseg/utils/get_templates.py @@ -0,0 +1,109 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +PREDEFINED_TEMPLATES = { + 'imagenet': [ + 'a bad photo of a {}.', + 'a photo of many {}.', + 'a sculpture of a {}.', + 'a photo of the hard to see {}.', + 'a low resolution photo of the {}.', + 'a rendering of a {}.', + 'graffiti of a {}.', + 'a bad photo of the {}.', + 'a cropped photo of the {}.', + 'a tattoo of a {}.', + 'the embroidered {}.', + 'a photo of a hard to see {}.', + 'a bright photo of a {}.', + 'a photo of a clean {}.', + 'a photo of a dirty {}.', + 'a dark photo of the {}.', + 'a drawing of a {}.', + 'a photo of my {}.', + 'the plastic {}.', + 'a photo of the cool {}.', + 'a close-up photo of a {}.', + 'a black and white photo of the {}.', + 'a painting of the {}.', + 'a painting of a {}.', + 'a pixelated photo of the {}.', + 'a sculpture of the {}.', + 'a bright photo of the {}.', + 'a cropped photo of a {}.', + 'a plastic {}.', + 'a photo of the dirty {}.', + 'a jpeg corrupted photo of a {}.', + 'a blurry photo of the {}.', + 'a photo of the {}.', + 'a good photo of the {}.', + 'a rendering of the {}.', + 'a {} in a video game.', + 'a photo of one {}.', + 'a doodle of a {}.', + 'a close-up photo of the {}.', + 'a photo of a {}.', + 'the origami {}.', + 'the {} in a video game.', + 'a sketch of a {}.', + 'a doodle of the {}.', + 'a origami {}.', + 'a low resolution photo of a {}.', + 'the toy {}.', + 'a rendition of the {}.', + 'a photo of the clean {}.', + 'a photo of a large {}.', + 'a rendition of a {}.', + 'a photo of a nice {}.', + 'a photo of a weird {}.', + 'a blurry photo of a {}.', + 'a cartoon {}.', + 'art of a {}.', + 'a sketch of the {}.', + 'a embroidered {}.', + 'a pixelated photo of a {}.', + 'itap of the {}.', + 'a jpeg corrupted photo of the {}.', + 'a good photo of a {}.', + 'a plushie {}.', + 'a photo of the nice {}.', + 'a photo of the small {}.', + 'a photo of the weird {}.', + 'the cartoon {}.', + 'art of the {}.', + 'a drawing of the {}.', + 'a photo of the large {}.', + 'a black and white photo of a {}.', + 'the plushie {}.', + 'a dark photo of a {}.', + 'itap of a {}.', + 'graffiti of the {}.', + 'a toy {}.', + 'itap of my {}.', + 'a photo of a cool {}.', + 'a photo of a small {}.', + 'a tattoo of the {}.', + ], + 'vild': [ + 'a photo of a {}.', + 'This is a photo of a {}', + 'There is a {} in the scene', + 'There is the {} in the scene', + 'a photo of a {} in the scene', + 'a photo of a small {}.', + 'a photo of a medium {}.', + 'a photo of a large {}.', + 'This is a photo of a small {}.', + 'This is a photo of a medium {}.', + 'This is a photo of a large {}.', + 'There is a small {} in the scene.', + 'There is a medium {} in the scene.', + 'There is a large {} in the scene.', + ], +} + + +def get_predefined_templates(template_set_name: str) -> List[str]: + if template_set_name not in PREDEFINED_TEMPLATES: + raise ValueError(f'Template set {template_set_name} not found') + return PREDEFINED_TEMPLATES[template_set_name] diff --git a/mmseg/utils/mask_classification.py b/mmseg/utils/mask_classification.py new file mode 100644 index 0000000000..205d525975 --- /dev/null +++ b/mmseg/utils/mask_classification.py @@ -0,0 +1,205 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Tuple + +import torch +from mmcv.ops import point_sample +from mmengine.structures import InstanceData +from torch import Tensor + +from mmseg.registry import TASK_UTILS +from mmseg.utils import ConfigType, SampleList + + +def seg_data_to_instance_data(ignore_index: int, + batch_data_samples: SampleList): + """Convert the paradigm of ground truth from semantic segmentation to + instance segmentation. + + Args: + ignore_index (int): The label index to be ignored. + batch_data_samples (List[SegDataSample]): The Data + Samples. It usually includes information such as + `gt_sem_seg`. + + Returns: + tuple[Tensor]: A tuple contains two lists. + - batch_gt_instances (List[InstanceData]): Batch of + gt_instance. It usually includes ``labels``, each is + unique ground truth label id of images, with + shape (num_gt, ) and ``masks``, each is ground truth + masks of each instances of a image, shape (num_gt, h, w). + - batch_img_metas (List[Dict]): List of image meta information. + """ + batch_gt_instances = [] + + for data_sample in batch_data_samples: + gt_sem_seg = data_sample.gt_sem_seg.data + classes = torch.unique( + gt_sem_seg, + sorted=False, + return_inverse=False, + return_counts=False) + + # remove ignored region + gt_labels = classes[classes != ignore_index] + + masks = [] + for class_id in gt_labels: + masks.append(gt_sem_seg == class_id) + + if len(masks) == 0: + gt_masks = torch.zeros( + (0, gt_sem_seg.shape[-2], + gt_sem_seg.shape[-1])).to(gt_sem_seg).long() + else: + gt_masks = torch.stack(masks).squeeze(1).long() + + instance_data = InstanceData(labels=gt_labels, masks=gt_masks) + batch_gt_instances.append(instance_data) + return batch_gt_instances + + +class MatchMasks: + """Match the predictions to category labels. + + Args: + num_points (int): the number of sampled points to compute cost. + num_queries (int): the number of prediction masks. + num_classes (int): the number of classes. + assigner (BaseAssigner): the assigner to compute matching. + """ + + def __init__(self, + num_points: int, + num_queries: int, + num_classes: int, + assigner: ConfigType = None): + assert assigner is not None, "\'assigner\' in decode_head.train_cfg" \ + 'cannot be None' + assert num_points > 0, 'num_points should be a positive integer.' + self.num_points = num_points + self.num_queries = num_queries + self.num_classes = num_classes + self.assigner = TASK_UTILS.build(assigner) + + def get_targets(self, cls_scores: List[Tensor], mask_preds: List[Tensor], + batch_gt_instances: List[InstanceData]) -> Tuple: + """Compute best mask matches for all images for a decoder layer. + + Args: + cls_scores (List[Tensor]): Mask score logits from a single + decoder layer for all images. Each with shape (num_queries, + cls_out_channels). + mask_preds (List[Tensor]): Mask logits from a single decoder + layer for all images. Each with shape (num_queries, h, w). + batch_gt_instances (List[InstanceData]): each contains + ``labels`` and ``masks``. + + Returns: + tuple: a tuple containing the following targets. + + - labels (List[Tensor]): Labels of all images.\ + Each with shape (num_queries, ). + - mask_targets (List[Tensor]): Mask targets of\ + all images. Each with shape (num_queries, h, w). + - mask_weights (List[Tensor]): Mask weights of\ + all images. Each with shape (num_queries, ). + - avg_factor (int): Average factor that is used to + average the loss. `avg_factor` is usually equal + to the number of positive priors. + """ + batch_size = cls_scores.shape[0] + results = dict({ + 'labels': [], + 'mask_targets': [], + 'mask_weights': [], + }) + for i in range(batch_size): + labels, mask_targets, mask_weights\ + = self._get_targets_single(cls_scores[i], + mask_preds[i], + batch_gt_instances[i]) + results['labels'].append(labels) + results['mask_targets'].append(mask_targets) + results['mask_weights'].append(mask_weights) + + # shape (batch_size, num_queries) + labels = torch.stack(results['labels'], dim=0) + # shape (batch_size, num_gts, h, w) + mask_targets = torch.cat(results['mask_targets'], dim=0) + # shape (batch_size, num_queries) + mask_weights = torch.stack(results['mask_weights'], dim=0) + + avg_factor = sum( + [len(gt_instances.labels) for gt_instances in batch_gt_instances]) + + res = (labels, mask_targets, mask_weights, avg_factor) + + return res + + def _get_targets_single(self, cls_score: Tensor, mask_pred: Tensor, + gt_instances: InstanceData) \ + -> Tuple[Tensor, Tensor, Tensor]: + """Compute a set of best mask matches for one image. + + Args: + cls_score (Tensor): Mask score logits from a single decoder layer + for one image. Shape (num_queries, cls_out_channels). + mask_pred (Tensor): Mask logits for a single decoder layer for one + image. Shape (num_queries, h, w). + gt_instances (:obj:`InstanceData`): It contains ``labels`` and + ``masks``. + + Returns: + tuple[Tensor]: A tuple containing the following for one image. + + - labels (Tensor): Labels of each image. \ + shape (num_queries, ). + - mask_targets (Tensor): Mask targets of each image. \ + shape (num_queries, h, w). + - mask_weights (Tensor): Mask weights of each image. \ + shape (num_queries, ). + """ + gt_labels = gt_instances.labels + gt_masks = gt_instances.masks + # when "gt_labels" is empty, classify all queries to background + if len(gt_labels) == 0: + labels = gt_labels.new_full((self.num_queries, ), + self.num_classes, + dtype=torch.long) + mask_targets = gt_labels + mask_weights = gt_labels.new_zeros((self.num_queries, )) + return labels, mask_targets, mask_weights + # sample points + num_queries = cls_score.shape[0] + num_gts = gt_labels.shape[0] + + point_coords = torch.rand((1, self.num_points, 2), + device=cls_score.device) + # shape (num_queries, num_points) + mask_points_pred = point_sample( + mask_pred.unsqueeze(1), point_coords.repeat(num_queries, 1, + 1)).squeeze(1) + # shape (num_gts, num_points) + gt_points_masks = point_sample( + gt_masks.unsqueeze(1).float(), point_coords.repeat(num_gts, 1, + 1)).squeeze(1) + + sampled_gt_instances = InstanceData( + labels=gt_labels, masks=gt_points_masks) + sampled_pred_instances = InstanceData( + scores=cls_score, masks=mask_points_pred) + # assign and sample + matched_quiery_inds, matched_label_inds = self.assigner.assign( + pred_instances=sampled_pred_instances, + gt_instances=sampled_gt_instances) + labels = gt_labels.new_full((self.num_queries, ), + self.num_classes, + dtype=torch.long) + labels[matched_quiery_inds] = gt_labels[matched_label_inds] + + mask_weights = gt_labels.new_zeros((self.num_queries, )) + mask_weights[matched_quiery_inds] = 1 + mask_targets = gt_masks[matched_label_inds] + + return labels, mask_targets, mask_weights diff --git a/mmseg/utils/tokenizer.py b/mmseg/utils/tokenizer.py new file mode 100644 index 0000000000..d56f5fae60 --- /dev/null +++ b/mmseg/utils/tokenizer.py @@ -0,0 +1,240 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""CLIP tokenizer. + +Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright +(c) 2021 OpenAI. +""" +import gzip +import html +import os +from functools import lru_cache +from typing import List, Union + +import ftfy +import regex as re +import torch + +os.environ['TOKENIZERS_PARALLELISM'] = 'false' + + +@lru_cache() +def default_bpe(): + return os.path.join( + os.path.dirname(os.path.abspath(__file__)), + 'bpe_simple_vocab_16e6.txt.gz') + + +@lru_cache() +def bytes_to_unicode(): + """Returns list of utf-8 byte and a corresponding list of unicode strings. + + The reversible bpe codes work on unicode strings. This means you need a + large # of unicode characters in your vocab if you want to avoid UNKs. When + you're at something like a 10B token dataset you end up needing around 5K + for decent coverage. This is a significant percentage of your normal, say, + 32K bpe vocab. To avoid that, we want lookup tables between utf-8 bytes and + unicode strings. And avoids mapping to whitespace/control characters the + bpe code barfs on. + """ + bs = list(range(ord('!'), + ord('~') + 1)) + list(range( + ord('¡'), + ord('¬') + 1)) + list(range(ord('®'), + ord('ÿ') + 1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + + Word is represented as tuple of symbols (symbols being variable-length + strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer: + + def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode('utf-8').split('\n') + merges = merges[1:49152 - 256 - 2 + 1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v + '' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + if not special_tokens: + special_tokens = ['', ''] + else: + special_tokens = ['', '' + ] + special_tokens + vocab.extend(special_tokens) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {t: t for t in special_tokens} + special = '|'.join(special_tokens) + self.pat = re.compile( + special + + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", + re.IGNORECASE) + + self.vocab_size = len(self.encoder) + self.all_special_ids = [self.encoder[t] for t in special_tokens] + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + (token[-1] + '', ) + pairs = get_pairs(word) + + if not pairs: + return token + '' + + while True: + bigram = min( + pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: # noqa: E722, E261 + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 1 and word[ + i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] + for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] + for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode( + 'utf-8', errors='replace').replace('', ' ') + return text + + +_tokenizer = SimpleTokenizer() + + +def decode(output_ids: torch.Tensor): + output_ids = output_ids.cpu().numpy() + return _tokenizer.decode(output_ids) + + +def tokenize(texts: Union[str, List[str]], + context_length: int = 77) -> torch.LongTensor: + """Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + context_length : int + The context length to use; all CLIP models use 77 as the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, + shape = [number of input strings, context_length] + """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder[''] + eot_token = _tokenizer.encoder[''] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] + for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + tokens = tokens[:context_length] # Truncate + tokens[-1] = eot_token + result[i, :len(tokens)] = torch.tensor(tokens) + + return result + + +class HFTokenizer: + """HuggingFace tokenizer wrapper.""" + + def __init__(self, tokenizer_name: str): + from transformers import AutoTokenizer + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + + def save_pretrained(self, dest): + self.tokenizer.save_pretrained(dest) + + def __call__(self, + texts: Union[str, List[str]], + context_length: int = 77) -> torch.Tensor: + # same cleaning as for default tokenizer, except lowercasing + # adding lower (for case-sensitive tokenizers) will make it + # more robust but less sensitive to nuance + if isinstance(texts, str): + texts = [texts] + texts = [whitespace_clean(basic_clean(text)) for text in texts] + input_ids = self.tokenizer( + texts, + return_tensors='pt', + max_length=context_length, + padding='max_length', + truncation=True, + ).input_ids + return input_ids diff --git a/requirements/multimodal.txt b/requirements/multimodal.txt new file mode 100644 index 0000000000..2195d0d9ef --- /dev/null +++ b/requirements/multimodal.txt @@ -0,0 +1,2 @@ +ftfy +regex diff --git a/requirements/tests.txt b/requirements/tests.txt index 74fc76146d..3fff2520d7 100644 --- a/requirements/tests.txt +++ b/requirements/tests.txt @@ -1,6 +1,8 @@ codecov flake8 +ftfy interrogate pytest +regex xdoctest>=0.10.0 yapf diff --git a/setup.py b/setup.py index 7316ed1c84..45d923db60 100755 --- a/setup.py +++ b/setup.py @@ -194,6 +194,7 @@ def add_mim_extension(): 'tests': parse_requirements('requirements/tests.txt'), 'optional': parse_requirements('requirements/optional.txt'), 'mim': parse_requirements('requirements/mminstall.txt'), + 'multimodal': parse_requirements('requirements/multimodal.txt'), }, ext_modules=[], zip_safe=False) diff --git a/tests/test_models/test_assigners/test_hungarian_assigner.py b/tests/test_models/test_assigners/test_hungarian_assigner.py new file mode 100644 index 0000000000..2cdb1de839 --- /dev/null +++ b/tests/test_models/test_assigners/test_hungarian_assigner.py @@ -0,0 +1,77 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import torch +from mmengine.structures import InstanceData + +from mmseg.models.assigners import HungarianAssigner + + +class TestHungarianAssigner(TestCase): + + def test_init(self): + with self.assertRaises(AssertionError): + HungarianAssigner([]) + + def test_hungarian_match_assigner(self): + assigner = HungarianAssigner([ + dict(type='ClassificationCost', weight=2.0), + dict(type='CrossEntropyLossCost', weight=5.0, use_sigmoid=True), + dict(type='DiceCost', weight=5.0, pred_act=True, eps=1.0) + ]) + num_classes = 3 + num_masks = 10 + num_points = 20 + gt_instances = InstanceData() + gt_instances.labels = torch.randint(0, num_classes, (num_classes, )) + gt_instances.masks = torch.randint(0, 2, (num_classes, num_points)) + pred_instances = InstanceData() + pred_instances.scores = torch.rand((num_masks, num_classes)) + pred_instances.masks = torch.rand((num_masks, num_points)) + + matched_quiery_inds, matched_label_inds = \ + assigner.assign(pred_instances, gt_instances) + unique_quiery_inds = torch.unique(matched_quiery_inds) + unique_label_inds = torch.unique(matched_label_inds) + self.assertTrue(len(unique_quiery_inds) == len(matched_quiery_inds)) + self.assertTrue( + torch.equal(unique_label_inds, torch.arange(0, num_classes))) + + def test_cls_match_cost(self): + num_classes = 3 + num_masks = 10 + gt_instances = InstanceData() + gt_instances.labels = torch.randint(0, num_classes, (num_classes, )) + pred_instances = InstanceData() + pred_instances.scores = torch.rand((num_masks, num_classes)) + + # test ClassificationCost + assigner = HungarianAssigner(dict(type='ClassificationCost')) + matched_quiery_inds, matched_label_inds = \ + assigner.assign(pred_instances, gt_instances) + unique_quiery_inds = torch.unique(matched_quiery_inds) + unique_label_inds = torch.unique(matched_label_inds) + self.assertTrue(len(unique_quiery_inds) == len(matched_quiery_inds)) + self.assertTrue( + torch.equal(unique_label_inds, torch.arange(0, num_classes))) + + def test_mask_match_cost(self): + num_classes = 3 + num_masks = 10 + num_points = 20 + gt_instances = InstanceData() + gt_instances.masks = torch.randint(0, 2, (num_classes, num_points)) + pred_instances = InstanceData() + pred_instances.masks = torch.rand((num_masks, num_points)) + + # test DiceCost + assigner = HungarianAssigner( + dict(type='DiceCost', pred_act=True, eps=1.0)) + assign_result = assigner.assign(pred_instances, gt_instances) + self.assertTrue(len(assign_result[0]) == len(assign_result[1])) + + # test CrossEntropyLossCost + assigner = HungarianAssigner( + dict(type='CrossEntropyLossCost', use_sigmoid=True)) + assign_result = assigner.assign(pred_instances, gt_instances) + self.assertTrue(len(assign_result[0]) == len(assign_result[1])) diff --git a/tests/test_models/test_backbones/test_clip_text_encoder.py b/tests/test_models/test_backbones/test_clip_text_encoder.py new file mode 100644 index 0000000000..ea06c5b5b3 --- /dev/null +++ b/tests/test_models/test_backbones/test_clip_text_encoder.py @@ -0,0 +1,43 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine import Config +from mmengine.registry import init_default_scope + +from mmseg.models.text_encoder import CLIPTextEncoder +from mmseg.utils import get_classes + + +def test_clip_text_encoder(): + init_default_scope('mmseg') + # test vocabulary + output_dims = 8 + embed_dims = 32 + vocabulary = ['cat', 'dog', 'bird', 'car', 'bike'] + cfg = dict( + vocabulary=vocabulary, + templates=['a photo of a {}.'], + embed_dims=embed_dims, + output_dims=output_dims) + cfg = Config(cfg) + + text_encoder = CLIPTextEncoder(**cfg) + if torch.cuda.is_available(): + text_encoder = text_encoder.cuda() + + with torch.no_grad(): + class_embeds = text_encoder() + assert class_embeds.shape == (len(vocabulary) + 1, output_dims) + + # test dataset name + cfg = dict( + dataset_name='vaihingen', + templates=['a photo of a {}.'], + embed_dims=embed_dims, + output_dims=output_dims) + cfg = Config(cfg) + + text_encoder = CLIPTextEncoder(**cfg) + with torch.no_grad(): + class_embeds = text_encoder() + class_nums = len(get_classes('vaihingen')) + assert class_embeds.shape == (class_nums + 1, output_dims) diff --git a/tests/test_models/test_heads/test_san_head.py b/tests/test_models/test_heads/test_san_head.py new file mode 100644 index 0000000000..af85a6e2ca --- /dev/null +++ b/tests/test_models/test_heads/test_san_head.py @@ -0,0 +1,126 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine import Config +from mmengine.structures import PixelData + +from mmseg.models.decode_heads import SideAdapterCLIPHead +from mmseg.structures import SegDataSample +from .utils import list_to_cuda + + +def test_san_head(): + H, W = (64, 64) + clip_channels = 64 + img_channels = 4 + num_queries = 40 + out_dims = 64 + num_classes = 19 + cfg = dict( + num_classes=num_classes, + deep_supervision_idxs=[4], + san_cfg=dict( + in_channels=img_channels, + embed_dims=128, + clip_channels=clip_channels, + num_queries=num_queries, + cfg_encoder=dict(num_encode_layer=4, mlp_ratio=2, num_heads=2), + cfg_decoder=dict( + num_heads=4, + num_layers=1, + embed_channels=32, + mlp_channels=32, + num_mlp=2, + rescale=True)), + maskgen_cfg=dict( + sos_token_num=num_queries, + embed_dims=clip_channels, + out_dims=out_dims, + num_heads=4, + mlp_ratio=2), + train_cfg=dict( + num_points=100, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type='HungarianAssigner', + match_costs=[ + dict(type='ClassificationCost', weight=2.0), + dict( + type='CrossEntropyLossCost', + weight=5.0, + use_sigmoid=True), + dict(type='DiceCost', weight=5.0, pred_act=True, eps=1.0) + ])), + loss_decode=[ + dict( + type='CrossEntropyLoss', + loss_name='loss_cls_ce', + loss_weight=2.0, + class_weight=[1.0] * num_classes + [0.1]), + dict( + type='CrossEntropyLoss', + use_sigmoid=True, + loss_name='loss_mask_ce', + loss_weight=5.0), + dict( + type='DiceLoss', + ignore_index=None, + naive_dice=True, + eps=1, + loss_name='loss_mask_dice', + loss_weight=5.0) + ]) + + cfg = Config(cfg) + head = SideAdapterCLIPHead(**cfg) + + inputs = torch.rand((2, img_channels, H, W)) + clip_feature = [[ + torch.rand((2, clip_channels, H // 2, W // 2)), + torch.rand((2, clip_channels)) + ], + [ + torch.rand((2, clip_channels, H // 2, W // 2)), + torch.rand((2, clip_channels)) + ], + [ + torch.rand((2, clip_channels, H // 2, W // 2)), + torch.rand((2, clip_channels)) + ], + [ + torch.rand((2, clip_channels, H // 2, W // 2)), + torch.rand((2, clip_channels)) + ]] + class_embed = torch.rand((num_classes + 1, out_dims)) + + data_samples = [] + for i in range(2): + data_sample = SegDataSample() + img_meta = {} + img_meta['img_shape'] = (H, W) + img_meta['ori_shape'] = (H, W) + data_sample.gt_sem_seg = PixelData( + data=torch.randint(0, num_classes, (1, H, W))) + data_sample.set_metainfo(img_meta) + data_samples.append(data_sample) + + batch_img_metas = [] + for data_sample in data_samples: + batch_img_metas.append(data_sample.metainfo) + + if torch.cuda.is_available(): + head = head.cuda() + data = list_to_cuda([inputs, clip_feature, class_embed]) + for data_sample in data_samples: + data_sample.gt_sem_seg.data = data_sample.gt_sem_seg.data.cuda() + else: + data = [inputs, clip_feature, class_embed] + + # loss test + loss_dict = head.loss(data, data_samples, None) + assert isinstance(loss_dict, dict) + + # prediction test + with torch.no_grad(): + seg_logits = head.predict(data, batch_img_metas, None) + assert seg_logits.shape == torch.Size((2, num_classes, H, W)) diff --git a/tests/test_models/test_heads/utils.py b/tests/test_models/test_heads/utils.py index 335e261a5e..7282340155 100644 --- a/tests/test_models/test_heads/utils.py +++ b/tests/test_models/test_heads/utils.py @@ -20,3 +20,12 @@ def to_cuda(module, data): for i in range(len(data)): data[i] = data[i].cuda() return module, data + + +def list_to_cuda(data): + if isinstance(data, list): + for i in range(len(data)): + data[i] = list_to_cuda(data[i]) + return data + else: + return data.cuda() diff --git a/tests/test_models/test_segmentors/test_multimodal_encoder_decoder.py b/tests/test_models/test_segmentors/test_multimodal_encoder_decoder.py new file mode 100644 index 0000000000..75258d89a7 --- /dev/null +++ b/tests/test_models/test_segmentors/test_multimodal_encoder_decoder.py @@ -0,0 +1,24 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine import ConfigDict + +from mmseg.models import build_segmentor +from tests.test_models.test_segmentors.utils import \ + _segmentor_forward_train_test + + +def test_multimodal_encoder_decoder(): + + cfg = ConfigDict( + type='MultimodalEncoderDecoder', + asymetric_input=False, + image_encoder=dict(type='ExampleBackbone', out_indices=[1, 2, 3, 4]), + text_encoder=dict( + type='ExampleTextEncoder', + vocabulary=['A', 'B', 'C'], + output_dims=3), + decode_head=dict( + type='ExampleDecodeHead', out_channels=1, num_classes=2), + train_cfg=None, + test_cfg=dict(mode='whole')) + segmentor = build_segmentor(cfg) + _segmentor_forward_train_test(segmentor) diff --git a/tests/test_models/test_segmentors/utils.py b/tests/test_models/test_segmentors/utils.py index 6b440df906..ac31e2b277 100644 --- a/tests/test_models/test_segmentors/utils.py +++ b/tests/test_models/test_segmentors/utils.py @@ -52,15 +52,22 @@ def _demo_mm_inputs(input_shape=(1, 3, 8, 16), num_classes=10): @MODELS.register_module() class ExampleBackbone(nn.Module): - def __init__(self): + def __init__(self, out_indices=None): super().__init__() self.conv = nn.Conv2d(3, 3, 3) + self.out_indices = out_indices def init_weights(self, pretrained=None): pass def forward(self, x): - return [self.conv(x)] + if self.out_indices is None: + return [self.conv(x)] + else: + outs = [] + for i in self.out_indices: + outs.append(self.conv(x)) + return outs @MODELS.register_module() @@ -74,6 +81,18 @@ def forward(self, inputs): return self.cls_seg(inputs[0]) +@MODELS.register_module() +class ExampleTextEncoder(nn.Module): + + def __init__(self, vocabulary=None, output_dims=None): + super().__init__() + self.vocabulary = vocabulary + self.output_dims = output_dims + + def forward(self): + return torch.randn((len(self.vocabulary), self.output_dims)) + + @MODELS.register_module() class ExampleCascadeDecodeHead(BaseCascadeDecodeHead): @@ -132,3 +151,32 @@ def _segmentor_forward_train_test(segmentor): data_batch = dict(inputs=imgs, data_samples=data_samples) results = segmentor.forward(imgs, data_samples, mode='tensor') assert isinstance(results, torch.Tensor) + + +def _segmentor_predict(segmentor): + if isinstance(segmentor.decode_head, nn.ModuleList): + num_classes = segmentor.decode_head[-1].num_classes + else: + num_classes = segmentor.decode_head.num_classes + # batch_size=2 for BatchNorm + mm_inputs = _demo_mm_inputs(num_classes=num_classes) + + # convert to cuda Tensor if applicable + if torch.cuda.is_available(): + segmentor = segmentor.cuda() + + # check data preprocessor + if not hasattr(segmentor, + 'data_preprocessor') or segmentor.data_preprocessor is None: + segmentor.data_preprocessor = SegDataPreProcessor() + + mm_inputs = segmentor.data_preprocessor(mm_inputs, True) + imgs = mm_inputs.pop('imgs') + data_samples = mm_inputs.pop('data_samples') + + # Test predict + with torch.no_grad(): + segmentor.eval() + data_batch = dict(inputs=imgs, data_samples=data_samples) + outputs = segmentor.predict(**data_batch) + assert isinstance(outputs, list) diff --git a/tools/model_converters/clip2mmseg.py b/tools/model_converters/clip2mmseg.py new file mode 100644 index 0000000000..9a97e4b04a --- /dev/null +++ b/tools/model_converters/clip2mmseg.py @@ -0,0 +1,163 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os.path as osp +from collections import OrderedDict + +import mmengine +import torch +from mmengine.runner import CheckpointLoader + + +def convert_vitlayer(paras): + new_para_name = '' + if paras[0] == 'ln_1': + new_para_name = '.'.join(['ln1'] + paras[1:]) + elif paras[0] == 'attn': + new_para_name = '.'.join(['attn.attn'] + paras[1:]) + elif paras[0] == 'ln_2': + new_para_name = '.'.join(['ln2'] + paras[1:]) + elif paras[0] == 'mlp': + if paras[1] == 'c_fc': + new_para_name = '.'.join(['ffn.layers.0.0'] + paras[-1:]) + else: + new_para_name = '.'.join(['ffn.layers.1'] + paras[-1:]) + else: + print(f'Wrong for {paras}') + return new_para_name + + +def convert_translayer(paras): + new_para_name = '' + if paras[0] == 'attn': + new_para_name = '.'.join(['attentions.0.attn'] + paras[1:]) + elif paras[0] == 'ln_1': + new_para_name = '.'.join(['norms.0'] + paras[1:]) + elif paras[0] == 'ln_2': + new_para_name = '.'.join(['norms.1'] + paras[1:]) + elif paras[0] == 'mlp': + if paras[1] == 'c_fc': + new_para_name = '.'.join(['ffns.0.layers.0.0'] + paras[2:]) + elif paras[1] == 'c_proj': + new_para_name = '.'.join(['ffns.0.layers.1'] + paras[2:]) + else: + print(f'Wrong for {paras}') + else: + print(f'Wrong for {paras}') + return new_para_name + + +def convert_key_name(ckpt, visual_split): + new_ckpt = OrderedDict() + for k, v in ckpt.items(): + key_list = k.split('.') + if key_list[0] == 'visual': + new_transform_name = 'image_encoder' + if key_list[1] == 'class_embedding': + new_name = '.'.join([new_transform_name, 'cls_token']) + elif key_list[1] == 'positional_embedding': + new_name = '.'.join([new_transform_name, 'pos_embed']) + elif key_list[1] == 'conv1': + new_name = '.'.join([ + new_transform_name, 'patch_embed.projection', key_list[2] + ]) + elif key_list[1] == 'ln_pre': + new_name = '.'.join( + [new_transform_name, key_list[1], key_list[2]]) + elif key_list[1] == 'transformer': + new_layer_name = 'layers' + layer_index = key_list[3] + paras = key_list[4:] + if int(layer_index) < visual_split: + new_para_name = convert_vitlayer(paras) + new_name = '.'.join([ + new_transform_name, new_layer_name, layer_index, + new_para_name + ]) + else: + new_para_name = convert_translayer(paras) + new_transform_name = 'decode_head.rec_with_attnbias' + new_layer_name = 'layers' + layer_index = str(int(layer_index) - visual_split) + new_name = '.'.join([ + new_transform_name, new_layer_name, layer_index, + new_para_name + ]) + elif key_list[1] == 'proj': + new_name = 'decode_head.rec_with_attnbias.proj.weight' + elif key_list[1] == 'ln_post': + new_name = k.replace('visual', 'decode_head.rec_with_attnbias') + else: + print(f'pop parameter: {k}') + continue + else: + text_encoder_name = 'text_encoder' + if key_list[0] == 'transformer': + layer_name = 'transformer' + layer_index = key_list[2] + paras = key_list[3:] + new_para_name = convert_translayer(paras) + new_name = '.'.join([ + text_encoder_name, layer_name, layer_index, new_para_name + ]) + elif key_list[0] in [ + 'positional_embedding', 'text_projection', 'bg_embed', + 'attn_mask', 'logit_scale', 'token_embedding', 'ln_final' + ]: + new_name = 'text_encoder.' + k + else: + print(f'pop parameter: {k}') + continue + new_ckpt[new_name] = v + + return new_ckpt + + +def convert_tensor(ckpt): + cls_token = ckpt['image_encoder.cls_token'] + new_cls_token = cls_token.unsqueeze(0).unsqueeze(0) + ckpt['image_encoder.cls_token'] = new_cls_token + pos_embed = ckpt['image_encoder.pos_embed'] + new_pos_embed = pos_embed.unsqueeze(0) + ckpt['image_encoder.pos_embed'] = new_pos_embed + proj_weight = ckpt['decode_head.rec_with_attnbias.proj.weight'] + new_proj_weight = proj_weight.transpose(1, 0) + ckpt['decode_head.rec_with_attnbias.proj.weight'] = new_proj_weight + return ckpt + + +def main(): + parser = argparse.ArgumentParser( + description='Convert keys in timm pretrained vit models to ' + 'MMSegmentation style.') + parser.add_argument('src', help='src model path or url') + # The dst path must be a full path of the new checkpoint. + parser.add_argument('dst', help='save path') + args = parser.parse_args() + + if any([s in args.src for s in ['B-16', 'b16', 'base_patch16']]): + visual_split = 9 + elif any([s in args.src for s in ['L-14', 'l14', 'large_patch14']]): + visual_split = 18 + else: + print('Make sure the clip model is ViT-B/16 or ViT-L/14!') + visual_split = -1 + checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') + if isinstance(checkpoint, torch.jit.RecursiveScriptModule): + state_dict = checkpoint.state_dict() + else: + if 'state_dict' in checkpoint: + # timm checkpoint + state_dict = checkpoint['state_dict'] + elif 'model' in checkpoint: + # deit checkpoint + state_dict = checkpoint['model'] + else: + state_dict = checkpoint + weight = convert_key_name(state_dict, visual_split) + weight = convert_tensor(weight) + mmengine.mkdir_or_exist(osp.dirname(args.dst)) + torch.save(weight, args.dst) + + +if __name__ == '__main__': + main() diff --git a/tools/model_converters/san2mmseg.py b/tools/model_converters/san2mmseg.py new file mode 100644 index 0000000000..301a46608e --- /dev/null +++ b/tools/model_converters/san2mmseg.py @@ -0,0 +1,220 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os.path as osp +from collections import OrderedDict + +import mmengine +import torch +from mmengine.runner import CheckpointLoader + + +def convert_key_name(ckpt): + new_ckpt = OrderedDict() + + for k, v in ckpt.items(): + key_list = k.split('.') + if key_list[0] == 'clip_visual_extractor': + new_transform_name = 'image_encoder' + if key_list[1] == 'class_embedding': + new_name = '.'.join([new_transform_name, 'cls_token']) + elif key_list[1] == 'positional_embedding': + new_name = '.'.join([new_transform_name, 'pos_embed']) + elif key_list[1] == 'conv1': + new_name = '.'.join([ + new_transform_name, 'patch_embed.projection', key_list[2] + ]) + elif key_list[1] == 'ln_pre': + new_name = '.'.join( + [new_transform_name, key_list[1], key_list[2]]) + elif key_list[1] == 'resblocks': + new_layer_name = 'layers' + layer_index = key_list[2] + paras = key_list[3:] + if paras[0] == 'ln_1': + new_para_name = '.'.join(['ln1'] + key_list[4:]) + elif paras[0] == 'attn': + new_para_name = '.'.join(['attn.attn'] + key_list[4:]) + elif paras[0] == 'ln_2': + new_para_name = '.'.join(['ln2'] + key_list[4:]) + elif paras[0] == 'mlp': + if paras[1] == 'c_fc': + new_para_name = '.'.join(['ffn.layers.0.0'] + + key_list[-1:]) + else: + new_para_name = '.'.join(['ffn.layers.1'] + + key_list[-1:]) + new_name = '.'.join([ + new_transform_name, new_layer_name, layer_index, + new_para_name + ]) + elif key_list[0] == 'side_adapter_network': + decode_head_name = 'decode_head' + module_name = 'side_adapter_network' + if key_list[1] == 'vit_model': + if key_list[2] == 'blocks': + layer_name = 'encode_layers' + layer_index = key_list[3] + paras = key_list[4:] + if paras[0] == 'norm1': + new_para_name = '.'.join(['ln1'] + key_list[5:]) + elif paras[0] == 'attn': + new_para_name = '.'.join(key_list[4:]) + new_para_name = new_para_name.replace( + 'attn.qkv.', 'attn.attn.in_proj_') + new_para_name = new_para_name.replace( + 'attn.proj', 'attn.attn.out_proj') + elif paras[0] == 'norm2': + new_para_name = '.'.join(['ln2'] + key_list[5:]) + elif paras[0] == 'mlp': + new_para_name = '.'.join(['ffn'] + key_list[5:]) + new_para_name = new_para_name.replace( + 'fc1', 'layers.0.0') + new_para_name = new_para_name.replace( + 'fc2', 'layers.1') + else: + print(f'Wrong for {k}') + new_name = '.'.join([ + decode_head_name, module_name, layer_name, layer_index, + new_para_name + ]) + elif key_list[2] == 'pos_embed': + new_name = '.'.join( + [decode_head_name, module_name, 'pos_embed']) + elif key_list[2] == 'patch_embed': + new_name = '.'.join([ + decode_head_name, module_name, 'patch_embed', + 'projection', key_list[4] + ]) + else: + print(f'Wrong for {k}') + elif key_list[1] == 'query_embed' or key_list[ + 1] == 'query_pos_embed': + new_name = '.'.join( + [decode_head_name, module_name, key_list[1]]) + elif key_list[1] == 'fusion_layers': + layer_name = 'conv_clips' + layer_index = key_list[2][-1] + paras = '.'.join(key_list[3:]) + new_para_name = paras.replace('input_proj.0', '0') + new_para_name = new_para_name.replace('input_proj.1', '1.conv') + new_name = '.'.join([ + decode_head_name, module_name, layer_name, layer_index, + new_para_name + ]) + elif key_list[1] == 'mask_decoder': + new_name = 'decode_head.' + k + else: + print(f'Wrong for {k}') + elif key_list[0] == 'clip_rec_head': + module_name = 'rec_with_attnbias' + if key_list[1] == 'proj': + new_name = '.'.join( + [decode_head_name, module_name, 'proj.weight']) + elif key_list[1] == 'ln_post': + new_name = '.'.join( + [decode_head_name, module_name, 'ln_post', key_list[2]]) + elif key_list[1] == 'resblocks': + new_layer_name = 'layers' + layer_index = key_list[2] + paras = key_list[3:] + if paras[0] == 'ln_1': + new_para_name = '.'.join(['norms.0'] + paras[1:]) + elif paras[0] == 'attn': + new_para_name = '.'.join(['attentions.0.attn'] + paras[1:]) + elif paras[0] == 'ln_2': + new_para_name = '.'.join(['norms.1'] + paras[1:]) + elif paras[0] == 'mlp': + if paras[1] == 'c_fc': + new_para_name = '.'.join(['ffns.0.layers.0.0'] + + paras[2:]) + elif paras[1] == 'c_proj': + new_para_name = '.'.join(['ffns.0.layers.1'] + + paras[2:]) + else: + print(f'Wrong for {k}') + new_name = '.'.join([ + decode_head_name, module_name, new_layer_name, layer_index, + new_para_name + ]) + else: + print(f'Wrong for {k}') + elif key_list[0] == 'ov_classifier': + text_encoder_name = 'text_encoder' + if key_list[1] == 'transformer': + layer_name = 'transformer' + layer_index = key_list[3] + paras = key_list[4:] + if paras[0] == 'attn': + new_para_name = '.'.join(['attentions.0.attn'] + paras[1:]) + elif paras[0] == 'ln_1': + new_para_name = '.'.join(['norms.0'] + paras[1:]) + elif paras[0] == 'ln_2': + new_para_name = '.'.join(['norms.1'] + paras[1:]) + elif paras[0] == 'mlp': + if paras[1] == 'c_fc': + new_para_name = '.'.join(['ffns.0.layers.0.0'] + + paras[2:]) + elif paras[1] == 'c_proj': + new_para_name = '.'.join(['ffns.0.layers.1'] + + paras[2:]) + else: + print(f'Wrong for {k}') + else: + print(f'Wrong for {k}') + new_name = '.'.join([ + text_encoder_name, layer_name, layer_index, new_para_name + ]) + elif key_list[1] in [ + 'positional_embedding', 'text_projection', 'bg_embed', + 'attn_mask', 'logit_scale', 'token_embedding', 'ln_final' + ]: + new_name = k.replace('ov_classifier', 'text_encoder') + else: + print(f'Wrong for {k}') + elif key_list[0] == 'criterion': + new_name = k + else: + print(f'Wrong for {k}') + new_ckpt[new_name] = v + return new_ckpt + + +def convert_tensor(ckpt): + cls_token = ckpt['image_encoder.cls_token'] + new_cls_token = cls_token.unsqueeze(0).unsqueeze(0) + ckpt['image_encoder.cls_token'] = new_cls_token + pos_embed = ckpt['image_encoder.pos_embed'] + new_pos_embed = pos_embed.unsqueeze(0) + ckpt['image_encoder.pos_embed'] = new_pos_embed + proj_weight = ckpt['decode_head.rec_with_attnbias.proj.weight'] + new_proj_weight = proj_weight.transpose(1, 0) + ckpt['decode_head.rec_with_attnbias.proj.weight'] = new_proj_weight + return ckpt + + +def main(): + parser = argparse.ArgumentParser( + description='Convert keys in timm pretrained vit models to ' + 'MMSegmentation style.') + parser.add_argument('src', help='src model path or url') + # The dst path must be a full path of the new checkpoint. + parser.add_argument('dst', help='save path') + args = parser.parse_args() + + checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') + if 'state_dict' in checkpoint: + # timm checkpoint + state_dict = checkpoint['state_dict'] + elif 'model' in checkpoint: + # deit checkpoint + state_dict = checkpoint['model'] + else: + state_dict = checkpoint + weight = convert_key_name(state_dict) + weight = convert_tensor(weight) + mmengine.mkdir_or_exist(osp.dirname(args.dst)) + torch.save(weight, args.dst) + + +if __name__ == '__main__': + main()