Skip to content

Commit

Permalink
[Feature] Support inference and visualization of VPD (#3331)
Browse files Browse the repository at this point in the history
Thanks for your contribution and we appreciate it a lot. The following
instructions would make your pull request more healthy and more easily
get feedback. If you do not understand some items, don't worry, just
make the pull request and seek help from maintainers.

## Motivation

Support inference and visualization of VPD

## Modification

1. add a new VPD model that does not generate black border in
predictions
2. update `SegLocalVisualizer` to support depth visualization
3. update `MMSegInferencer` to support save predictions of depth
estimation in method `postprocess`

## BC-breaking (Optional)

Does the modification introduce changes that break the
backward-compatibility of the downstream repos?
If so, please describe how it breaks the compatibility and how the
downstream projects should modify their code to keep compatibility with
this PR.

## Use cases (Optional)

Run inference with VPD using the this command

```sh
python demo/image_demo_with_inferencer.py demo/classroom__rgb_00283.jpg vpd_depth --out-dir vis_results
```

The following image will be saved under `vis_results/vis`


![classroom__rgb_00283](https://github.com/open-mmlab/mmsegmentation/assets/26127467/051e8c4b-8f92-495f-8c3e-f249aac888e3)




## Checklist

1. Pre-commit or other linting tools are used to fix the potential lint
issues.
4. The modification is covered by complete unit tests. If not, please
add more unit test to ensure the correctness.
5. If the modification has potential influence on downstream projects,
this PR should be tested with downstream projects, like MMDet or
MMDet3D.
6. The documentation has been modified accordingly, like docstring or
example tutorials.
  • Loading branch information
Ben-Louis committed Sep 18, 2023
1 parent f1fa61a commit 743171d
Show file tree
Hide file tree
Showing 15 changed files with 366 additions and 36 deletions.
6 changes: 5 additions & 1 deletion .readthedocs.yml
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
version: 2

build:
os: ubuntu-22.04
tools:
python: "3.7"

formats:
- epub

python:
version: 3.7
install:
- requirements: requirements/docs.txt
- requirements: requirements/readthedocs.txt
1 change: 1 addition & 0 deletions configs/_base_/datasets/nyu.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=(2000, 480), keep_ratio=True),
dict(dict(type='LoadDepthAnnotation', depth_rescale_factor=1e-3)),
dict(
type='PackSegInputs',
Expand Down
72 changes: 72 additions & 0 deletions configs/_base_/datasets/nyu_512x512.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# dataset settings
dataset_type = 'NYUDataset'
data_root = 'data/nyu'

train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadDepthAnnotation', depth_rescale_factor=1e-3),
dict(type='RandomDepthMix', prob=0.25),
dict(type='RandomFlip', prob=0.5),
dict(
type='RandomResize',
scale=(768, 512),
ratio_range=(0.8, 1.5),
keep_ratio=True),
dict(type='RandomCrop', crop_size=(512, 512)),
dict(
type='Albu',
transforms=[
dict(type='RandomBrightnessContrast'),
dict(type='RandomGamma'),
dict(type='HueSaturationValue'),
]),
dict(
type='PackSegInputs',
meta_keys=('img_path', 'depth_map_path', 'ori_shape', 'img_shape',
'pad_shape', 'scale_factor', 'flip', 'flip_direction',
'category_id')),
]

test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=(2048, 512), keep_ratio=True),
dict(dict(type='LoadDepthAnnotation', depth_rescale_factor=1e-3)),
dict(
type='PackSegInputs',
meta_keys=('img_path', 'depth_map_path', 'ori_shape', 'img_shape',
'pad_shape', 'scale_factor', 'flip', 'flip_direction',
'category_id'))
]

train_dataloader = dict(
batch_size=8,
num_workers=8,
persistent_workers=True,
sampler=dict(type='InfiniteSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='images/train', depth_map_path='annotations/train'),
pipeline=train_pipeline))

val_dataloader = dict(
batch_size=1,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
test_mode=True,
data_prefix=dict(
img_path='images/test', depth_map_path='annotations/test'),
pipeline=test_pipeline))
test_dataloader = val_dataloader

val_evaluator = dict(
type='DepthMetric',
min_depth_eval=0.001,
max_depth_eval=10.0,
crop_type='nyu_crop')
test_evaluator = val_evaluator
2 changes: 1 addition & 1 deletion configs/_base_/models/vpd_sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
std=[127.5, 127.5, 127.5],
bgr_to_rgb=True,
pad_val=0,
seg_pad_val=255)
seg_pad_val=0)

# adapted from stable-diffusion/configs/stable-diffusion/v1-inference.yaml
stable_diffusion_cfg = dict(
Expand Down
3 changes: 2 additions & 1 deletion configs/vpd/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
## Introduction

<!-- [BACKBONE] -->
<!-- [ALGORITHM] -->

<a href = "https://github.com/wl-zhao/VPD">Official Repo</a>

Expand Down Expand Up @@ -36,6 +36,7 @@ pip install -r requirements/optional.txt
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | Device | RMSE | d1 | d2 | d3 | REL | log_10 | config | download |
| ------ | --------------------- | --------- | ------- | -------- | -------------- | ------ | ----- | ----- | ----- | ----- | ----- | ------ | ----------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| VPD | Stable-Diffusion-v1-5 | 480x480 | 25000 | - | - | A100 | 0.253 | 0.964 | 0.995 | 0.999 | 0.069 | 0.030 | [config](https://github.com/open-mmlab/mmsegmentation/tree/main/configs/vpd/vpd_sd_4xb8-25k_nyu-480x480.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/vpd/vpd_sd_4xb8-25k_nyu-480x480_20230908-66144bc4.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/vpd/vpd_sd_4xb8-25k_nyu-480x480_20230908.json) |
| VPD | Stable-Diffusion-v1-5 | 512x512 | 25000 | - | - | A100 | 0.258 | 0.963 | 0.995 | 0.999 | 0.072 | 0.031 | [config](https://github.com/open-mmlab/mmsegmentation/tree/main/configs/vpd/vpd_sd_4xb8-25k_nyu-512x512.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/vpd/vpd_sd_4xb8-25k_nyu-512x512_20230918-60cefcff.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/vpd/vpd_sd_4xb8-25k_nyu-512x512_20230918.json) |

## Citation

Expand Down
22 changes: 22 additions & 0 deletions configs/vpd/metafile.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,25 @@ Models:
URL: https://arxiv.org/abs/2112.10752
Code: https://github.com/open-mmlab/mmsegmentation/tree/main/mmseg/models/backbones/vpd.py#L333
Framework: PyTorch
- Name: vpd_sd_4xb8-25k_nyu-512x512
In Collection: VPD
Alias: vpd_depth
Results:
Task: Depth Estimation
Dataset: NYU
Metrics:
RMSE: 0.258
Config: configs/vpd/vpd_sd_4xb8-25k_nyu-512x512.py
Metadata:
Training Data: NYU
Batch Size: 32
Architecture:
- Stable-Diffusion
Training Resources: 8x A100 GPUS
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/vpd/vpd_sd_4xb8-25k_nyu-512x512_20230918-60cefcff.pth
Training log: https://download.openmmlab.com/mmsegmentation/v0.5/vpd/vpd_sd_4xb8-25k_nyu-512x512_20230918.json
Paper:
Title: 'High-Resolution Image Synthesis with Latent Diffusion Models'
URL: https://arxiv.org/abs/2112.10752
Code: https://github.com/open-mmlab/mmsegmentation/tree/main/mmseg/models/backbones/vpd.py#L333
Framework: PyTorch
3 changes: 2 additions & 1 deletion configs/vpd/vpd_sd_4xb8-25k_nyu-480x480.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
),
test_cfg=dict(mode='slide_flip', crop_size=crop_size, stride=(160, 160)))

default_hooks = dict(checkpoint=dict(save_best='rmse', rule='less'))
default_hooks = dict(
checkpoint=dict(save_best='rmse', rule='less', max_keep_ckpts=1))

# custom optimizer
optim_wrapper = dict(
Expand Down
37 changes: 37 additions & 0 deletions configs/vpd/vpd_sd_4xb8-25k_nyu-512x512.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
_base_ = [
'../_base_/models/vpd_sd.py', '../_base_/datasets/nyu_512x512.py',
'../_base_/default_runtime.py', '../_base_/schedules/schedule_25k.py'
]

crop_size = (512, 512)

model = dict(
type='DepthEstimator',
data_preprocessor=dict(size=crop_size),
backbone=dict(
class_embed_path='https://download.openmmlab.com/mmsegmentation/'
'v0.5/vpd/nyu_class_embeddings.pth',
class_embed_select=True,
pad_shape=512,
unet_cfg=dict(use_attn=False),
),
decode_head=dict(
type='VPDDepthHead',
in_channels=[320, 640, 1280, 1280],
max_depth=10,
),
test_cfg=dict(mode='slide_flip', crop_size=crop_size, stride=(128, 128)))

default_hooks = dict(
checkpoint=dict(save_best='rmse', rule='less', max_keep_ckpts=1))

# custom optimizer
optim_wrapper = dict(
constructor='ForceDefaultOptimWrapperConstructor',
paramwise_cfg=dict(
bias_decay_mult=0,
force_default_settings=True,
custom_keys={
'backbone.encoder_vq': dict(lr_mult=0),
'backbone.unet': dict(lr_mult=0.01),
}))
32 changes: 22 additions & 10 deletions mmseg/apis/mmseg_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,17 +306,28 @@ def postprocess(self,
results_dict['visualization'] = []

for i, pred in enumerate(preds):
pred_data = pred.pred_sem_seg.numpy().data[0]
results_dict['predictions'].append(pred_data)
pred_data = dict()
if 'pred_sem_seg' in pred.keys():
pred_data['sem_seg'] = pred.pred_sem_seg.numpy().data[0]
elif 'pred_depth_map' in pred.keys():
pred_data['depth_map'] = pred.pred_depth_map.numpy().data[0]

if visualization is not None:
vis = visualization[i]
results_dict['visualization'].append(vis)
if pred_out_dir != '':
mmengine.mkdir_or_exist(pred_out_dir)
img_name = str(self.num_pred_imgs).zfill(8) + '_pred.png'
img_path = osp.join(pred_out_dir, img_name)
output = Image.fromarray(pred_data.astype(np.uint8))
output.save(img_path)
for key, data in pred_data.items():
post_fix = '_pred.png' if key == 'sem_seg' else '_pred.npy'
img_name = str(self.num_pred_imgs).zfill(8) + post_fix
img_path = osp.join(pred_out_dir, img_name)
if key == 'sem_seg':
output = Image.fromarray(data.astype(np.uint8))
output.save(img_path)
else:
np.save(img_path, data)
pred_data = next(iter(pred_data.values()))
results_dict['predictions'].append(pred_data)
self.num_pred_imgs += 1

if len(results_dict['predictions']) == 1:
Expand Down Expand Up @@ -344,12 +355,13 @@ def preprocess(self, inputs, batch_size, **kwargs):
"""
pipeline_cfg = cfg.test_dataloader.dataset.pipeline
# Loading annotations is also not applicable
idx = self._get_transform_idx(pipeline_cfg, 'LoadAnnotations')
if idx != -1:
del pipeline_cfg[idx]
for transform in ('LoadAnnotations', 'LoadDepthAnnotation'):
idx = self._get_transform_idx(pipeline_cfg, transform)
if idx != -1:
del pipeline_cfg[idx]

load_img_idx = self._get_transform_idx(pipeline_cfg,
'LoadImageFromFile')

if load_img_idx == -1:
raise ValueError(
'LoadImageFromFile is not found in the test pipeline')
Expand Down
4 changes: 2 additions & 2 deletions mmseg/datasets/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
BioMedicalRandomGamma, ConcatCDInput, GenerateEdge,
PhotoMetricDistortion, RandomCrop, RandomCutOut,
RandomDepthMix, RandomFlip, RandomMosaic,
RandomRotate, RandomRotFlip, Rerange,
RandomRotate, RandomRotFlip, Rerange, Resize,
ResizeShortestEdge, ResizeToMultiple, RGB2Gray,
SegRescale)

Expand All @@ -26,5 +26,5 @@
'BioMedical3DRandomFlip', 'BioMedicalRandomGamma', 'BioMedical3DPad',
'RandomRotFlip', 'Albu', 'LoadSingleRSImageFromFile', 'ConcatCDInput',
'LoadMultipleRSImageFromFile', 'LoadDepthAnnotation', 'RandomDepthMix',
'RandomFlip'
'RandomFlip', 'Resize'
]
67 changes: 67 additions & 0 deletions mmseg/datasets/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import mmengine
import numpy as np
from mmcv.transforms import RandomFlip as MMCV_RandomFlip
from mmcv.transforms import Resize as MMCV_Resize
from mmcv.transforms.base import BaseTransform
from mmcv.transforms.utils import cache_randomness
from mmengine.utils import is_tuple_of
Expand Down Expand Up @@ -1031,6 +1032,72 @@ def _flip(self, results: dict) -> None:
results['swap_seg_labels'] = self.swap_seg_labels


@TRANSFORMS.register_module()
class Resize(MMCV_Resize):
"""Resize images & seg & depth map.
This transform resizes the input image according to ``scale`` or
``scale_factor``. Seg map, depth map and other relative annotations are
then resized with the same scale factor.
if ``scale`` and ``scale_factor`` are both set, it will use ``scale`` to
resize.
Required Keys:
- img
- gt_seg_map (optional)
- gt_depth_map (optional)
Modified Keys:
- img
- gt_seg_map
- gt_depth_map
Added Keys:
- scale
- scale_factor
- keep_ratio
Args:
scale (int or tuple): Images scales for resizing. Defaults to None
scale_factor (float or tuple[float]): Scale factors for resizing.
Defaults to None.
keep_ratio (bool): Whether to keep the aspect ratio when resizing the
image. Defaults to False.
clip_object_border (bool): Whether to clip the objects
outside the border of the image. In some dataset like MOT17, the gt
bboxes are allowed to cross the border of images. Therefore, we
don't need to clip the gt bboxes in these cases. Defaults to True.
backend (str): Image resize backend, choices are 'cv2' and 'pillow'.
These two backends generates slightly different results. Defaults
to 'cv2'.
interpolation (str): Interpolation method, accepted values are
"nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
backend, "nearest", "bilinear" for 'pillow' backend. Defaults
to 'bilinear'.
"""

def _resize_seg(self, results: dict) -> None:
"""Resize semantic segmentation map with ``results['scale']``."""
for seg_key in results.get('seg_fields', []):
if results.get(seg_key, None) is not None:
if self.keep_ratio:
gt_seg = mmcv.imrescale(
results[seg_key],
results['scale'],
interpolation='nearest',
backend=self.backend)
else:
gt_seg = mmcv.imresize(
results[seg_key],
results['scale'],
interpolation='nearest',
backend=self.backend)
results[seg_key] = gt_seg


@TRANSFORMS.register_module()
class RandomMosaic(BaseTransform):
"""Mosaic augmentation. Given 4 images, mosaic transform combines them into
Expand Down
5 changes: 5 additions & 0 deletions mmseg/models/losses/silog_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ def silog_loss(pred: Tensor,

diff_log = torch.log(target.clamp(min=eps)) - torch.log(
pred.clamp(min=eps))

valid_mask = (target > eps).detach() & (~torch.isnan(diff_log))
diff_log[~valid_mask] = 0.0
valid_mask = valid_mask.float()

diff_log_sq_mean = (diff_log.pow(2) * valid_mask).sum(
dim=1) / valid_mask.sum(dim=1).clamp(min=eps)
diff_log_mean = (diff_log * valid_mask).sum(dim=1) / valid_mask.sum(
Expand Down
Loading

0 comments on commit 743171d

Please sign in to comment.