Skip to content

Commit

Permalink
[Feature] Support Side Adapter Network (#3232)
Browse files Browse the repository at this point in the history
## Motivation
Support SAN for Open-Vocabulary Semantic Segmentation
Paper: [Side Adapter Network for Open-Vocabulary Semantic
Segmentation](https://arxiv.org/abs/2302.12242)
official Code: [SAN](https://github.com/MendelXu/SAN)

## Modification
- Added the parameters of backbone vit for implementing the image
encoder of CLIP.
- Added text encoder code.
- Added segmentor multimodel encoder-decoder code for open-vocabulary
semantic segmentation.
- Added SideAdapterNetwork decode head code.
- Added config files for train and inference.
- Added tools for converting pretrained models.
- Added loss implementation for mask classification model, such as SAN,
Maskformer and remove dependency on mmdetection.
- Added test units for text encoder, multimodel encoder-decoder, san
decode head and hungarian_assigner.

## Use cases
### Convert Models
**pretrained SAN model**
The official pretrained model can be downloaded from
[san_clip_vit_b_16.pth](https://huggingface.co/Mendel192/san/blob/main/san_vit_b_16.pth)
and
[san_clip_vit_large_14.pth](https://huggingface.co/Mendel192/san/blob/main/san_vit_large_14.pth).
Use tools/model_converters/san2mmseg.py to convert offcial model into
mmseg style.
`python tools/model_converters/san2mmseg.py <MODEL_PATH> <OUTPUT_PATH>`

**pretrained CLIP model**
Use the CLIP model provided by openai to train SAN. The CLIP model can
be download from
[ViT-B-16.pt](https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt)
and
[ViT-L-14-336px.pt](https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt).
Use tools/model_converters/clip2mmseg.py to convert model into mmseg
style.
`python tools/model_converters/clip2mmseg.py <MODEL_PATH> <OUTPUT_PATH>`

### Inference
test san_vit-base-16 model on coco-stuff164k dataset
`python tools/test.py
./configs/san/san-vit-b16_coco-stuff164k-640x640.py
<TRAINED_MODEL_PATH>`

### Train
test san_vit-base-16 model on coco-stuff164k dataset
`python tools/train.py
./configs/san/san-vit-b16_coco-stuff164k-640x640.py --cfg-options
model.pretrained=<PRETRAINED_MODEL_PATH>`

## Comparision Results
### Train on COCO-Stuff164k
|                 |       | mIoU  | mAcc  | pAcc  |
| --------------- | ----- | ----- | ----- | ----- |
| san-vit-base16  | official  | 41.93 | 56.73 | 67.69 |
|                 | mmseg | 41.93 | 56.84 | 67.84 |
| san-vit-large14 | official  | 45.57 | 59.52 | 69.76 |
|                 | mmseg | 45.78 | 59.61 | 69.21 |

### Evaluate on Pascal Context
|                 |       | mIoU  | mAcc  | pAcc  |
| --------------- | ----- | ----- | ----- | ----- |
| san-vit-base16  | official  | 54.05 | 72.96 | 77.77 |
|                 | mmseg | 54.04 | 73.74 | 77.71 |
| san-vit-large14 | official  | 57.53 | 77.56 | 78.89 |
|                 | mmseg | 56.89 | 76.96 | 78.74 |

### Evaluate on Voc12Aug
|                 |       | mIoU  | mAcc  | pAcc  |
| --------------- | ----- | ----- | ----- | ----- |
| san-vit-base16  | official  | 93.86 | 96.61 | 97.11 |
|                 | mmseg | 94.58 | 97.01 | 97.38 |
| san-vit-large14 | official  | 95.17 | 97.61 | 97.63 |
|                 | mmseg | 95.58 | 97.75 | 97.79 |

---------

Co-authored-by: CastleDream <[email protected]>
Co-authored-by: yeedrag <[email protected]>
Co-authored-by: Yang-ChangHui <[email protected]>
Co-authored-by: Xu CAO <[email protected]>
Co-authored-by: xiexinch <[email protected]>
Co-authored-by: 小飞猪 <[email protected]>
  • Loading branch information
7 people committed Sep 20, 2023
1 parent 1471d1e commit 608e319
Show file tree
Hide file tree
Showing 42 changed files with 4,114 additions and 29 deletions.
137 changes: 137 additions & 0 deletions configs/_base_/models/san_vit-b16.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)

data_preprocessor = dict(
type='SegDataPreProcessor',
mean=[122.7709, 116.7460, 104.0937],
std=[68.5005, 66.6322, 70.3232],
bgr_to_rgb=True,
pad_val=0,
seg_pad_val=255,
size_divisor=640,
test_cfg=dict(size_divisor=32))

num_classes = 171
model = dict(
type='MultimodalEncoderDecoder',
data_preprocessor=data_preprocessor,
pretrained='pretrain/clip_vit_base_patch16_224.pth',
asymetric_input=True,
encoder_resolution=0.5,
image_encoder=dict(
type='VisionTransformer',
img_size=(224, 224),
patch_size=16,
patch_pad=0,
in_channels=3,
embed_dims=768,
num_layers=9,
num_heads=12,
mlp_ratio=4,
out_origin=True,
out_indices=(2, 5, 8),
qkv_bias=True,
drop_rate=0.0,
attn_drop_rate=0.0,
drop_path_rate=0.0,
with_cls_token=True,
output_cls_token=True,
patch_bias=False,
pre_norm=True,
norm_cfg=dict(type='LN', eps=1e-5),
act_cfg=dict(type='QuickGELU'),
norm_eval=False,
interpolate_mode='bicubic',
frozen_exclude=['pos_embed']),
text_encoder=dict(
type='CLIPTextEncoder',
dataset_name=None,
templates='vild',
embed_dims=512,
num_layers=12,
num_heads=8,
mlp_ratio=4,
output_dims=512,
cache_feature=True,
cat_bg=True,
norm_cfg=dict(type='LN', eps=1e-5)
),
decode_head=dict(
type='SideAdapterCLIPHead',
num_classes=num_classes,
deep_supervision_idxs=[7],
san_cfg=dict(
in_channels=3,
clip_channels=768,
embed_dims=240,
patch_size=16,
patch_bias=True,
num_queries=100,
cfg_encoder=dict(
num_encode_layer=8,
num_heads=6,
mlp_ratio=4
),
fusion_index=[0, 1, 2, 3],
cfg_decoder=dict(
num_heads=12,
num_layers=1,
embed_channels=256,
mlp_channels=256,
num_mlp=3,
rescale=True),
norm_cfg=dict(type='LN', eps=1e-6),
),
maskgen_cfg=dict(
sos_token_format='cls_token',
sos_token_num=100,
cross_attn=False,
num_layers=3,
embed_dims=768,
num_heads=12,
mlp_ratio=4,
qkv_bias=True,
out_dims=512,
final_norm=True,
act_cfg=dict(type='QuickGELU'),
norm_cfg=dict(type='LN', eps=1e-5),
frozen_exclude=[]
),
align_corners=False,
train_cfg=dict(
num_points=12544,
oversample_ratio=3.0,
importance_sample_ratio=0.75,
assigner=dict(
type='HungarianAssigner',
match_costs=[
dict(type='ClassificationCost', weight=2.0),
dict(
type='CrossEntropyLossCost',
weight=5.0,
use_sigmoid=True),
dict(
type='DiceCost',
weight=5.0,
pred_act=True,
eps=1.0)
])),
loss_decode=[dict(type='CrossEntropyLoss',
loss_name='loss_cls_ce',
loss_weight=2.0,
class_weight=[1.0] * num_classes + [0.1]),
dict(type='CrossEntropyLoss',
use_sigmoid=True,
loss_name='loss_mask_ce',
loss_weight=5.0),
dict(type='DiceLoss',
ignore_index=None,
naive_dice=True,
eps=1,
loss_name='loss_mask_dice',
loss_weight=5.0)
]),

# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole')) # yapf: disable
47 changes: 47 additions & 0 deletions configs/san/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# SAN

> [Side Adapter Network for Open-Vocabulary Semantic Segmentation](https://arxiv.org/abs/2302.12242)
## Introduction

<!-- [ALGORITHM] -->

<a href="https://github.com/MendelXu/SAN">Official Repo</a>

## Abstract

<!-- [ABSTRACT] -->

This paper presents a new framework for open-vocabulary semantic segmentation with the pre-trained vision-language model, named Side Adapter Network (SAN). Our approach models the semantic segmentation task as a region recognition problem. A side network is attached to a frozen CLIP model with two branches: one for predicting mask proposals, and the other for predicting attention bias which is applied in the CLIP model to recognize the class of masks. This decoupled design has the benefit CLIP in recognizing the class of mask proposals. Since the attached side network can reuse CLIP features, it can be very light. In addition, the entire network can be trained end-to-end, allowing the side network to be adapted to the frozen CLIP model, which makes the predicted mask proposals CLIP-aware. Our approach is fast, accurate, and only adds a few additional trainable parameters. We evaluate our approach on multiple semantic segmentation benchmarks. Our method significantly outperforms other counterparts, with up to 18 times fewer trainable parameters and 19 times faster inference speed. We hope our approach will serve as a solid baseline and help ease future research in open-vocabulary semantic segmentation.

<!-- [IMAGE] -->

<div align=center>
<img src="https://github.com/MendelXu/SAN/blob/main/resources/arch.png" width="800"/>
</div>

## Results and models

### COCO-Stuff164k

| Method | Backbone | Pretrained | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | Device | mIoU | mIoU(ms+flip) | config | download |
| ------ | -------- | ------------ | --------- | ------- | -------- | -------------- | ------ | ----- | ------------- | ------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| SAN | ViT-B_16 | CLIP_ViT-B16 | 640x640 | 60000 | 12.61 | - | V100 | 41.93 | 41.77 | - | [model](https://download.openmmlab.com/mmsegmentation/v0.5/san/san-vit-b16_20230906-fd0a7684.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/san/san-vit-b16_20230906.log) |
| SAN | ViT-L_14 | CLIP_ViT-L14 | 640x640 | 60000 | 22.84 | - | V100 | 45.78 | 43.99 | - | [model](https://download.openmmlab.com/mmsegmentation/v0.5/san/san-vit-l14_20230907-a11e098f.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/san/san-vit-l14_20230907.log) |

## Notes

git push
The pretrained weights in config files are converted from open_clip models using tools/model_converters/clip2mmseg.py.

## Citation

```bibtex
@inproceedings{xu2023side,
title={Side adapter network for open-vocabulary semantic segmentation},
author={Xu, Mengde and Zhang, Zheng and Wei, Fangyun and Hu, Han and Bai, Xiang},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={2945--2954},
year={2023}
}
```
82 changes: 82 additions & 0 deletions configs/san/san-vit-b16_coco-stuff164k-640x640.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
_base_ = [
'../_base_/models/san_vit-b16.py', '../_base_/datasets/coco-stuff164k.py',
'../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]
crop_size = (640, 640)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(
type='RandomChoiceResize',
scales=[int(640 * x * 0.1) for x in range(5, 16)],
resize_type='ResizeShortestEdge',
max_size=2560),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=1.0),
dict(type='PhotoMetricDistortion'),
dict(type='RandomFlip', prob=0.5),
dict(type='PackSegInputs')
]

test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='ResizeShortestEdge', scale=crop_size, max_size=2560),
dict(type='LoadAnnotations'),
dict(type='PackSegInputs')
]

# By default, models are trained on 4 GPUs with 8 images per GPU
train_dataloader = dict(batch_size=8, dataset=dict(pipeline=train_pipeline))
val_dataloader = dict(batch_size=1, dataset=dict(pipeline=test_pipeline))
test_dataloader = val_dataloader

pretrained = 'https://download.openmmlab.com/mmsegmentation/v0.5/san/clip_vit-base-patch16-224_3rdparty-d08f8887.pth' # noqa
data_preprocessor = dict(
mean=[122.7709, 116.7460, 104.0937],
std=[68.5005, 66.6322, 70.3232],
size_divisor=640,
test_cfg=dict(size_divisor=32))
model = dict(
pretrained=pretrained,
text_encoder=dict(dataset_name='coco-stuff164k'),
decode_head=dict(num_classes=171))

# training schedule for 60k
train_cfg = dict(
type='IterBasedTrainLoop',
max_iters=60000,
val_interval=500,
val_begin=55000)
default_hooks = dict(
checkpoint=dict(
type='CheckpointHook',
by_epoch=False,
interval=10000,
save_best='mIoU'))

# AdamW optimizer, no weight decay for position embedding & layer norm
# in backbone
optim_wrapper = dict(
_delete_=True,
type='AmpOptimWrapper',
optimizer=dict(
type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.0001),
paramwise_cfg=dict(
custom_keys={
'img_encoder': dict(lr_mult=0.1, decay_mult=1.0),
'pos_embed': dict(decay_mult=0.),
'cls_token': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)
}),
loss_scale='dynamic',
clip_grad=dict(max_norm=0.01, norm_type=2))

param_scheduler = [
dict(
type='PolyLR',
eta_min=0.0,
power=1.0,
begin=0,
end=60000,
by_epoch=False,
)
]
56 changes: 56 additions & 0 deletions configs/san/san-vit-b16_pascal_context-640x640.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
_base_ = [
'../_base_/models/san_vit-b16.py',
'../_base_/datasets/pascal_context_59.py', '../_base_/default_runtime.py',
'../_base_/schedules/schedule_160k.py'
]
crop_size = (640, 640)

test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='ResizeShortestEdge', scale=crop_size, max_size=2560),
dict(type='LoadAnnotations'),
dict(type='PackSegInputs')
]

# By default, models are trained on 8 GPUs with 2 images per GPU
train_dataloader = dict(batch_size=2)
val_dataloader = dict(batch_size=1, dataset=dict(pipeline=test_pipeline))
test_dataloader = val_dataloader

data_preprocessor = dict(
mean=[122.7709, 116.7460, 104.0937],
std=[68.5005, 66.6322, 70.3232],
size_divisor=640,
test_cfg=dict(size_divisor=32))
model = dict(
data_preprocessor=data_preprocessor,
pretrained='pretrain/vit_base_patch16_224.pth',
text_encoder=dict(dataset_name='pascal_context'),
decode_head=dict(num_classes=59))

# AdamW optimizer, no weight decay for position embedding & layer norm
# in backbone
optim_wrapper = dict(
_delete_=True,
type='OptimWrapper',
optimizer=dict(
type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01),
paramwise_cfg=dict(
custom_keys={
'pos_embed': dict(decay_mult=0.),
'cls_token': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)
}))

param_scheduler = [
dict(
type='LinearLR', start_factor=1e-6, by_epoch=False, begin=0, end=1500),
dict(
type='PolyLR',
eta_min=0.0,
power=1.0,
begin=1500,
end=160000,
by_epoch=False,
)
]
65 changes: 65 additions & 0 deletions configs/san/san-vit-b16_voc12aug-640x640.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
_base_ = [
'../_base_/models/san_vit-b16.py',
'../_base_/datasets/pascal_voc12_aug.py', '../_base_/default_runtime.py',
'../_base_/schedules/schedule_160k.py'
]
crop_size = (640, 640)

metainfo = dict(
classes=('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'),
palette=[[128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
[128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0],
[192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128],
[192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0],
[128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]])
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='ResizeShortestEdge', scale=crop_size, max_size=2560),
dict(type='LoadAnnotations'),
dict(type='PackSegInputs')
]
# By default, models are trained on 8 GPUs with 2 images per GPU
train_dataloader = dict(batch_size=2)
val_dataloader = dict(
batch_size=1, dataset=dict(metainfo=metainfo, pipeline=test_pipeline))
test_dataloader = val_dataloader

data_preprocessor = dict(
mean=[122.7709, 116.7460, 104.0937],
std=[68.5005, 66.6322, 70.3232],
size_divisor=640,
test_cfg=dict(size_divisor=32))
model = dict(
data_preprocessor=data_preprocessor,
pretrained='pretrain/vit_base_patch16_224.pth',
text_encoder=dict(dataset_name='voc'),
decode_head=dict(num_classes=20))

# AdamW optimizer, no weight decay for position embedding & layer norm
# in backbone
optim_wrapper = dict(
_delete_=True,
type='OptimWrapper',
optimizer=dict(
type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01),
paramwise_cfg=dict(
custom_keys={
'pos_embed': dict(decay_mult=0.),
'cls_token': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)
}))

param_scheduler = [
dict(
type='LinearLR', start_factor=1e-6, by_epoch=False, begin=0, end=1500),
dict(
type='PolyLR',
eta_min=0.0,
power=1.0,
begin=1500,
end=160000,
by_epoch=False,
)
]
Loading

0 comments on commit 608e319

Please sign in to comment.