[CodeCamp2023-608] Add Adabins model (#3257)
Yang-Changhui committed Sep 13, 2023
1 parent c46cc85 commit b6090a1
Showing 11 changed files with 483 additions and 0 deletions.
Binary file added demo/classroom__rgb_00283.jpg
46 changes: 46 additions & 0 deletions projects/Adabins/README.md
@@ -0,0 +1,46 @@
# AdaBins: Depth Estimation Using Adaptive Bins

## Reference

> [AdaBins: Depth Estimation Using Adaptive Bins](https://arxiv.org/abs/2011.14141)
## Introduction

<a href="https://github.com/shariqfarooq123/AdaBins">Official Repo</a>

<a href="https://github.com/open-mmlab/mmsegmentation/tree/dev-1.x/projects/Adabins">Code Snippet</a>

## <img src="https://user-images.githubusercontent.com/34859558/190043857-bfbdaf8b-d2dc-4fff-81c7-e0aac50851f9.png" width="25"/> Abstract

We address the problem of estimating a high quality dense depth map from a single RGB input image. We start out with a baseline encoder-decoder convolutional neural network architecture and pose the question of how the global processing of information can help improve overall depth estimation. To this end, we propose a transformer-based architecture block that divides the depth range into bins whose center value is estimated adaptively per image. The final depth values are estimated as linear combinations of the bin centers. We call our new building block AdaBins. Our results show a decisive improvement over the state-of-the-art on several popular depth datasets across all metrics. We also validate the effectiveness of the proposed block with an ablation study and provide the code and corresponding pre-trained weights of the new state-of-the-art model.

Our main contributions are the following:

- We propose an architecture building block that performs global processing of the scene's information. We propose to divide the predicted depth range into bins whose widths change per image. The final depth estimation is a linear combination of the bin center values (see the sketch after this list).
- We show a decisive improvement for supervised single image depth estimation across all metrics for the two most popular datasets, NYU and KITTI.
- We analyze our findings and investigate different modifications on the proposed AdaBins block and study their effect on the accuracy of the depth estimation.
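To make the bin-center combination concrete, here is a minimal PyTorch sketch of recovering a depth map from predicted normalized bin widths and per-pixel bin probabilities. It follows the formulation above; the function name and tensor layout are illustrative, not this project's API.

```python
import torch
import torch.nn.functional as F


def depth_from_bins(bin_widths_normed, probs, min_val=1e-3, max_val=10.0):
    """bin_widths_normed: (B, n_bins) normalized per-image bin widths.
    probs: (B, n_bins, H, W) per-pixel probabilities over the bins."""
    # Scale the normalized widths to [min_val, max_val] and prepend
    # min_val so that cumulative sums yield the bin edges.
    widths = (max_val - min_val) * bin_widths_normed
    widths = F.pad(widths, (1, 0), value=min_val)
    edges = torch.cumsum(widths, dim=1)             # (B, n_bins + 1)
    centers = 0.5 * (edges[:, :-1] + edges[:, 1:])  # (B, n_bins)
    # Final depth: per-pixel linear combination of the bin centers.
    return torch.sum(probs * centers[:, :, None, None], dim=1, keepdim=True)
```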

<div align="center">
<img src="https://github.com/open-mmlab/mmsegmentation/assets/15952744/915bcd5a-9dc2-4602-a6e7-055ff5d4889f" width = "1000" />
</div>

## <img src="https://user-images.githubusercontent.com/34859558/190044217-8f6befc2-7f20-473d-b356-148e06265205.png" width="25"/> Performance

### NYU and KITTI

| Model         | Encoder         | Training Epochs | Batch Size | Train Resolution | δ1    | δ2    | δ3    | REL   | RMS   | RMS log | Params (M) | Links                                                                                                                   |
| ------------- | --------------- | --------------- | ---------- | ---------------- | ----- | ----- | ----- | ----- | ----- | ------- | ---------- | ----------------------------------------------------------------------------------------------------------------------- |
| AdaBins_nyu   | EfficientNet-B5 | 25              | 16         | 416x544          | 0.903 | 0.984 | 0.997 | 0.103 | 0.364 | 0.044   | 78         | [model](https://download.openmmlab.com/mmsegmentation/v0.5/adabins/adabins_efficient_b5_nyu_third-party-f68d6bd3.pth)   |
| AdaBins_kitti | EfficientNet-B5 | 25              | 16         | 352x704          | 0.964 | 0.995 | 0.999 | 0.058 | 2.360 | 0.088   | 78         | [model](https://download.openmmlab.com/mmsegmentation/v0.5/adabins/adabins_efficient-b5_kitty_third-party-a1aa6f36.pth) |
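
A minimal inference sketch with the mmseg 1.x Python API, assuming the checkpoint above has been downloaded. The config path below is a placeholder for wherever the project config lives in your checkout.

```python
from mmseg.apis import inference_model, init_model

# Placeholder paths: point these at your local config and checkpoint.
config = 'projects/Adabins/configs/adabins_efficient_b5_nyu.py'  # hypothetical name
checkpoint = 'adabins_efficient_b5_nyu_third-party-f68d6bd3.pth'

model = init_model(config, checkpoint, device='cuda:0')
result = inference_model(model, 'demo/classroom__rgb_00283.jpg')
```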

## Citation

```bibtex
@inproceedings{bhat2021adabins,
  author    = {Bhat, Shariq Farooq and Alhashim, Ibraheem and Wonka, Peter},
  title     = {AdaBins: Depth Estimation Using Adaptive Bins},
  booktitle = {2021 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  year      = {2021},
  doi       = {10.1109/CVPR46437.2021.00400}
}
```
4 changes: 4 additions & 0 deletions projects/Adabins/backbones/__init__.py
@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .adabins_backbone import AdabinsBackbone

__all__ = ['AdabinsBackbone']
141 changes: 141 additions & 0 deletions projects/Adabins/backbones/adabins_backbone.py
@@ -0,0 +1,141 @@
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, build_conv_layer
from mmengine.model import BaseModule

from mmseg.registry import MODELS


class UpSampleBN(nn.Module):
""" UpSample module
Args:
skip_input (int): the input feature
output_features (int): the output feature
norm_cfg (dict, optional): Config dict for normalization layer.
Default: dict(type='BN', requires_grad=True).
act_cfg (dict, optional): The activation layer of AAM:
Aggregate Attention Module.
"""

def __init__(self,
skip_input,
output_features,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='LeakyReLU')):
super().__init__()

self._net = nn.Sequential(
ConvModule(
in_channels=skip_input,
out_channels=output_features,
kernel_size=3,
stride=1,
padding=1,
bias=True,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
),
ConvModule(
in_channels=output_features,
out_channels=output_features,
kernel_size=3,
stride=1,
padding=1,
bias=True,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
))

def forward(self, x, concat_with):
up_x = F.interpolate(
x,
size=[concat_with.size(2),
concat_with.size(3)],
mode='bilinear',
align_corners=True)
f = torch.cat([up_x, concat_with], dim=1)
return self._net(f)


class Encoder(nn.Module):
""" the efficientnet_b5 model
Args:
basemodel_name (str): the name of base model
"""

def __init__(self, basemodel_name):
super().__init__()
self.original_model = timm.create_model(
basemodel_name, pretrained=True)
        # Drop the classification head; only feature maps are needed.
self.original_model.global_pool = nn.Identity()
self.original_model.classifier = nn.Identity()

    def forward(self, x):
        features = [x]
        for k, v in self.original_model._modules.items():
            if k == 'blocks':
                # Record the output of every EfficientNet block so the
                # decoder can pick skip connections at multiple scales.
                for ki, vi in v._modules.items():
                    features.append(vi(features[-1]))
            else:
                features.append(v(features[-1]))
        return features


@MODELS.register_module()
class AdabinsBackbone(BaseModule):
""" the backbone of the adabins
Args:
basemodel_name (str):the name of base model
num_features (int): the middle feature
num_classes (int): the classes number
bottleneck_features (int): the bottleneck features
conv_cfg (dict): Config dict for convolution layer.
"""

def __init__(self,
basemodel_name,
num_features=2048,
num_classes=128,
bottleneck_features=2048,
conv_cfg=dict(type='Conv')):
super().__init__()
self.encoder = Encoder(basemodel_name)
features = int(num_features)
self.conv2 = build_conv_layer(
conv_cfg,
bottleneck_features,
features,
kernel_size=1,
stride=1,
padding=1)
self.up1 = UpSampleBN(
skip_input=features // 1 + 112 + 64, output_features=features // 2)
self.up2 = UpSampleBN(
skip_input=features // 2 + 40 + 24, output_features=features // 4)
self.up3 = UpSampleBN(
skip_input=features // 4 + 24 + 16, output_features=features // 8)
self.up4 = UpSampleBN(
skip_input=features // 8 + 16 + 8, output_features=features // 16)

self.conv3 = build_conv_layer(
conv_cfg,
features // 16,
num_classes,
kernel_size=3,
stride=1,
padding=1)

    def forward(self, x):
        features = self.encoder(x)
        # Skip features at strides 2, 4, 8 and 16, plus the stride-32
        # bottleneck from the encoder's conv head.
        x_block0, x_block1, x_block2, x_block3, x_block4 = features[
            3], features[4], features[5], features[7], features[10]
x_d0 = self.conv2(x_block4)
x_d1 = self.up1(x_d0, x_block3)
x_d2 = self.up2(x_d1, x_block2)
x_d3 = self.up3(x_d2, x_block1)
x_d4 = self.up4(x_d3, x_block0)
out = self.conv3(x_d4)
return out
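
A hedged smoke test for the backbone above (a sketch assuming timm can load the pretrained EfficientNet-B5 weights; the output sits at half the input resolution after the four decoder upsamplings):

```python
import torch

backbone = AdabinsBackbone(basemodel_name='tf_efficientnet_b5_ap')
x = torch.randn(1, 3, 416, 544)   # NYU training resolution
out = backbone(x)
print(out.shape)                  # expected: torch.Size([1, 128, 208, 272])
```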
32 changes: 32 additions & 0 deletions projects/Adabins/configs/_base_/datasets/nyu.py
@@ -0,0 +1,32 @@
dataset_type = 'NYUDataset'
data_root = 'data/nyu'

test_pipeline = [
    dict(type='LoadImageFromFile', to_float32=True),
    dict(type='LoadDepthAnnotation', depth_rescale_factor=1e-3),
dict(
type='PackSegInputs',
meta_keys=('img_path', 'depth_map_path', 'ori_shape', 'img_shape',
'pad_shape', 'scale_factor', 'flip', 'flip_direction',
'category_id'))
]

val_dataloader = dict(
batch_size=1,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
test_mode=True,
data_prefix=dict(
img_path='images/test', depth_map_path='annotations/test'),
pipeline=test_pipeline))
test_dataloader = val_dataloader

val_evaluator = dict(
type='DepthMetric', max_depth_eval=10.0, crop_type='nyu_crop')
test_evaluator = val_evaluator
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
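
A sketch of materializing the val dataloader from this config (assumes the NYU data is prepared under `data/nyu` and the repo root is on `PYTHONPATH`):

```python
from mmengine.config import Config
from mmengine.registry import init_default_scope
from mmengine.runner import Runner

init_default_scope('mmseg')  # so 'NYUDataset' resolves in mmseg's registry
cfg = Config.fromfile('projects/Adabins/configs/_base_/datasets/nyu.py')
val_loader = Runner.build_dataloader(cfg.val_dataloader)
batch = next(iter(val_loader))   # dict with 'inputs' and 'data_samples'
```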
15 changes: 15 additions & 0 deletions projects/Adabins/configs/_base_/default_runtime.py
@@ -0,0 +1,15 @@
default_scope = 'mmseg'
env_cfg = dict(
cudnn_benchmark=True,
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
dist_cfg=dict(backend='nccl'),
)
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
type='SegLocalVisualizer', vis_backends=vis_backends, name='visualizer')
log_processor = dict(by_epoch=False)
log_level = 'INFO'
load_from = None
resume = False

tta_model = dict(type='SegTTAModel')
35 changes: 35 additions & 0 deletions projects/Adabins/configs/_base_/models/Adabins.py
@@ -0,0 +1,35 @@
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
data_preprocessor = dict(
type='SegDataPreProcessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True,
pad_val=0,
seg_pad_val=255)
model = dict(
type='DepthEstimator',
data_preprocessor=data_preprocessor,
# pretrained='open-mmlab://resnet50_v1c',
backbone=dict(
type='AdabinsBackbone',
basemodel_name='tf_efficientnet_b5_ap',
num_features=2048,
num_classes=128,
bottleneck_features=2048,
),
decode_head=dict(
type='AdabinsHead',
in_channels=128,
n_query_channels=128,
patch_size=16,
embedding_dim=128,
num_heads=4,
n_bins=256,
min_val=0.001,
max_val=10,
norm='linear'),

# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))
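
A sketch of building this model through MMEngine's registry (it mirrors the `custom_imports` used by the full configs below and assumes the repo root is the working directory):

```python
from mmengine.config import Config
from mmengine.registry import init_default_scope

from mmseg.registry import MODELS

# Register the project's custom backbone and head before building.
import projects.Adabins.backbones    # noqa: F401
import projects.Adabins.decode_head  # noqa: F401

init_default_scope('mmseg')
cfg = Config.fromfile('projects/Adabins/configs/_base_/models/Adabins.py')
model = MODELS.build(cfg.model)
```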
@@ -0,0 +1,15 @@
_base_ = [
'../_base_/models/Adabins.py', '../_base_/datasets/nyu.py',
'../_base_/default_runtime.py'
]
custom_imports = dict(
imports=['projects.Adabins.backbones', 'projects.Adabins.decode_head'],
allow_failed_imports=False)
crop_size = (416, 544)
data_preprocessor = dict(size=crop_size)
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
data_preprocessor=data_preprocessor,
backbone=dict(),
decode_head=dict(),
)
@@ -0,0 +1,12 @@
_base_ = ['../_base_/models/Adabins.py']
custom_imports = dict(
imports=['projects.Adabins.backbones', 'projects.Adabins.decode_head'],
allow_failed_imports=False)
crop_size = (352, 704)
data_preprocessor = dict(size=crop_size)
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
data_preprocessor=data_preprocessor,
backbone=dict(),
decode_head=dict(min_val=0.001, max_val=80),
)
4 changes: 4 additions & 0 deletions projects/Adabins/decode_head/__init__.py
@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .adabins_head import AdabinsHead

__all__ = ['AdabinsHead']