# [Feature] Support depth metrics (#3297)
Thanks for your contribution; we appreciate it a lot. The following
instructions will help make your pull request healthier and easier to
review. If you do not understand some items, don't worry: just open the
pull request and ask the maintainers for help.

## Motivation

Please describe the motivation for this PR and the goal you want to
achieve with it.

Support metrics for the depth estimation task, including RMSE, AbsRel,
and others.
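
For reference, these metrics follow the standard definitions from the depth-estimation literature, which is also how `_calc_all_metrics` in the new `DepthMetric` computes them; writing $d_i$ for the predicted depth, $d_i^*$ for the ground-truth depth, and $N$ for the number of valid pixels:

```latex
\mathrm{AbsRel} = \frac{1}{N}\sum_i \frac{|d_i - d_i^*|}{d_i^*}, \qquad
\mathrm{SqRel}  = \frac{1}{N}\sum_i \frac{(d_i - d_i^*)^2}{d_i^*}

\mathrm{RMSE} = \sqrt{\frac{1}{N}\sum_i (d_i - d_i^*)^2}, \qquad
\mathrm{RMSE}_{\log} = \sqrt{\frac{1}{N}\sum_i (\log d_i - \log d_i^*)^2}

\mathrm{log10} = \frac{1}{N}\sum_i \left|\log_{10} d_i - \log_{10} d_i^*\right|

\delta_k = \frac{1}{N}\,\Big|\{\, i : \max(d_i/d_i^*,\; d_i^*/d_i) < 1.25^k \,\}\Big|,
\qquad k \in \{1, 2, 3\}

\mathrm{SILog} = \sqrt{\frac{1}{N}\sum_i e_i^2 \;-\; \frac{1}{2}\Big(\frac{1}{N}\sum_i e_i\Big)^2},
\qquad e_i = \log d_i - \log d_i^*
```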

## Modification

Please briefly describe the modifications made in this PR.

## BC-breaking (Optional)

Does the modification introduce changes that break backward
compatibility for downstream repositories?
If so, please describe how the compatibility is broken and how the
downstream projects should modify their code to remain compatible with
this PR.

## Use cases (Optional)

Use the following configuration to compute depth metrics on the NYU dataset:

```python
dataset_type = 'NYUDataset'
data_root = 'data/nyu'

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadDepthAnnotation', depth_rescale_factor=1e-3),
    dict(
        type='PackSegInputs',
        meta_keys=('img_path', 'depth_map_path', 'ori_shape', 'img_shape',
                   'pad_shape', 'scale_factor', 'flip', 'flip_direction',
                   'category_id'))
]

val_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        test_mode=True,
        data_prefix=dict(
            img_path='images/test', depth_map_path='annotations/test'),
        pipeline=test_pipeline))
test_dataloader = val_dataloader

val_evaluator = dict(
    type='DepthMetric', max_depth_eval=10.0, crop_type='nyu_crop')
test_evaluator = val_evaluator
```
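
Independently of the full config, here is a minimal standalone sketch of feeding the metric by hand; the tensor shapes, the random values, and the `demo.jpg` path are made up for illustration and are not part of the PR:

```python
import torch

from mmseg.evaluation import DepthMetric

metric = DepthMetric(max_depth_eval=10.0, crop_type='nyu_crop')

# One fake sample shaped like the dicts produced by the model and pipeline:
# depth maps are (1, H, W) tensors with strictly positive values.
data_sample = dict(
    pred_depth_map=dict(data=torch.rand(1, 480, 640) * 9 + 1),
    gt_depth_map=dict(data=torch.rand(1, 480, 640) * 9 + 1),
    img_path='demo.jpg')

metric.process(data_batch=[0], data_samples=[data_sample])
print(metric.compute_metrics(metric.results))  # dict with d1, abs_rel, rmse, ...
```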

Example log:

![image](https://github.com/open-mmlab/mmsegmentation/assets/26127467/8101d65c-dee6-48de-916c-818659947b59)


## Checklist

1. Pre-commit or other linting tools are used to fix the potential lint
issues.
2. The modification is covered by complete unit tests. If not, please
add more unit tests to ensure correctness.
3. If the modification has potential influence on downstream projects,
this PR should be tested with downstream projects, like MMDet or
MMDet3D.
4. The documentation has been modified accordingly, like docstring or
example tutorials.
Ben-Louis committed Aug 31, 2023
1 parent 8233e64 commit 35ff78a
Showing 5 changed files with 303 additions and 5 deletions.
4 changes: 2 additions & 2 deletions mmseg/evaluation/__init__.py
```diff
@@ -1,4 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .metrics import CityscapesMetric, IoUMetric
+from .metrics import CityscapesMetric, DepthMetric, IoUMetric

-__all__ = ['IoUMetric', 'CityscapesMetric']
+__all__ = ['IoUMetric', 'CityscapesMetric', 'DepthMetric']
```
3 changes: 2 additions & 1 deletion mmseg/evaluation/metrics/__init__.py
```diff
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .citys_metric import CityscapesMetric
+from .depth_metric import DepthMetric
 from .iou_metric import IoUMetric

-__all__ = ['IoUMetric', 'CityscapesMetric']
+__all__ = ['IoUMetric', 'CityscapesMetric', 'DepthMetric']
```
212 changes: 212 additions & 0 deletions mmseg/evaluation/metrics/depth_metric.py
@@ -0,0 +1,212 @@
```python
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from collections import OrderedDict, defaultdict
from typing import Dict, List, Optional, Sequence

import cv2
import numpy as np
import torch
from mmengine.dist import is_main_process
from mmengine.evaluator import BaseMetric
from mmengine.logging import MMLogger, print_log
from mmengine.utils import mkdir_or_exist
from prettytable import PrettyTable
from torch import Tensor

from mmseg.registry import METRICS


@METRICS.register_module()
class DepthMetric(BaseMetric):
    """Depth estimation evaluation metric.

    Args:
        depth_metrics (List[str], optional): List of metrics to compute. If
            not specified, defaults to all metrics in self.METRICS.
        min_depth_eval (float): Minimum depth value for evaluation.
            Defaults to 0.0.
        max_depth_eval (float): Maximum depth value for evaluation.
            Defaults to infinity.
        crop_type (str, optional): Specifies the type of cropping to be used
            during evaluation. This option can affect how the evaluation mask
            is generated. Currently, 'nyu_crop' is supported, but other
            types can be added in the future. Defaults to None, in which case
            no cropping is applied.
        depth_scale_factor (float): Factor to scale the depth values.
            Defaults to 1.0.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        output_dir (str): The directory for output prediction. Defaults to
            None.
        format_only (bool): Only format results for submission without
            performing evaluation. It is useful when you want to save the
            results in a specific format and submit them to the test server.
            Defaults to False.
        prefix (str, optional): The prefix that will be added to the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.
    """
    METRICS = ('d1', 'd2', 'd3', 'abs_rel', 'sq_rel', 'rmse', 'rmse_log',
               'log10', 'silog')

    def __init__(self,
                 depth_metrics: Optional[List[str]] = None,
                 min_depth_eval: float = 0.0,
                 max_depth_eval: float = float('inf'),
                 crop_type: Optional[str] = None,
                 depth_scale_factor: float = 1.0,
                 collect_device: str = 'cpu',
                 output_dir: Optional[str] = None,
                 format_only: bool = False,
                 prefix: Optional[str] = None,
                 **kwargs) -> None:
        super().__init__(collect_device=collect_device, prefix=prefix)

        if depth_metrics is None:
            self.metrics = self.METRICS
        elif isinstance(depth_metrics, (tuple, list)):
            for metric in depth_metrics:
                assert metric in self.METRICS, f'the metric {metric} is not ' \
                    f'supported. Please use metrics in {self.METRICS}'
            self.metrics = depth_metrics

        # Validate crop_type, if provided
        assert crop_type in [
            None, 'nyu_crop'
        ], (f'Invalid value for crop_type: {crop_type}. Supported values are '
            'None or \'nyu_crop\'.')
        self.crop_type = crop_type
        self.min_depth_eval = min_depth_eval
        self.max_depth_eval = max_depth_eval
        self.output_dir = output_dir
        if self.output_dir and is_main_process():
            mkdir_or_exist(self.output_dir)
        self.format_only = format_only
        self.depth_scale_factor = depth_scale_factor

    def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
        """Process one batch of data and data_samples.

        The processed results should be stored in ``self.results``, which will
        be used to compute the metrics when all batches have been processed.

        Args:
            data_batch (dict): A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
        """
        for data_sample in data_samples:
            pred_label = data_sample['pred_depth_map']['data'].squeeze()
            # format_only always for test dataset without ground truth
            if not self.format_only:
                gt_depth = data_sample['gt_depth_map']['data'].squeeze().to(
                    pred_label)

                eval_mask = self._get_eval_mask(gt_depth)
                self.results.append(
                    (gt_depth[eval_mask], pred_label[eval_mask]))
            # format_result
            if self.output_dir is not None:
                basename = osp.splitext(osp.basename(
                    data_sample['img_path']))[0]
                png_filename = osp.abspath(
                    osp.join(self.output_dir, f'{basename}.png'))
                output_mask = pred_label.cpu().numpy(
                ) * self.depth_scale_factor

                cv2.imwrite(png_filename, output_mask.astype(np.uint16),
                            [cv2.IMWRITE_PNG_COMPRESSION, 0])

    def _get_eval_mask(self, gt_depth: Tensor):
        """Generates an evaluation mask based on ground truth depth and
        cropping.

        Args:
            gt_depth (Tensor): Ground truth depth map.

        Returns:
            Tensor: Boolean mask where evaluation should be performed.
        """
        valid_mask = torch.logical_and(gt_depth > self.min_depth_eval,
                                       gt_depth < self.max_depth_eval)

        if self.crop_type == 'nyu_crop':
            # this implementation is adapted from
            # https://github.com/zhyever/Monocular-Depth-Estimation-Toolbox/blob/main/depth/datasets/nyu.py  # noqa
            crop_mask = torch.zeros_like(valid_mask)
            crop_mask[45:471, 41:601] = 1
        else:
            crop_mask = torch.ones_like(valid_mask)

        eval_mask = torch.logical_and(valid_mask, crop_mask)
        return eval_mask

    @staticmethod
    def _calc_all_metrics(gt_depth, pred_depth):
        """Computes final evaluation metrics based on accumulated results."""
        assert gt_depth.shape == pred_depth.shape

        thresh = torch.max((gt_depth / pred_depth), (pred_depth / gt_depth))
        diff = pred_depth - gt_depth
        diff_log = torch.log(pred_depth) - torch.log(gt_depth)

        d1 = torch.sum(thresh < 1.25).float() / len(thresh)
        d2 = torch.sum(thresh < 1.25**2).float() / len(thresh)
        d3 = torch.sum(thresh < 1.25**3).float() / len(thresh)

        abs_rel = torch.mean(torch.abs(diff) / gt_depth)
        sq_rel = torch.mean(torch.pow(diff, 2) / gt_depth)

        rmse = torch.sqrt(torch.mean(torch.pow(diff, 2)))
        rmse_log = torch.sqrt(torch.mean(torch.pow(diff_log, 2)))

        log10 = torch.mean(
            torch.abs(torch.log10(pred_depth) - torch.log10(gt_depth)))
        silog = torch.sqrt(
            torch.pow(diff_log, 2).mean() -
            0.5 * torch.pow(diff_log.mean(), 2))

        return {
            'd1': d1.item(),
            'd2': d2.item(),
            'd3': d3.item(),
            'abs_rel': abs_rel.item(),
            'sq_rel': sq_rel.item(),
            'rmse': rmse.item(),
            'rmse_log': rmse_log.item(),
            'log10': log10.item(),
            'silog': silog.item()
        }

    def compute_metrics(self, results: list) -> Dict[str, float]:
        """Compute the metrics from processed results.

        Args:
            results (list): The processed results of each batch.

        Returns:
            Dict[str, float]: The computed metrics. The keys are the names of
                the metrics, and the values are the corresponding results. The
                keys are identical to self.metrics.
        """
        logger: MMLogger = MMLogger.get_current_instance()
        if self.format_only:
            logger.info(f'results are saved to {osp.dirname(self.output_dir)}')
            return OrderedDict()

        metrics = defaultdict(list)
        for gt_depth, pred_depth in results:
            for key, value in self._calc_all_metrics(gt_depth,
                                                     pred_depth).items():
                metrics[key].append(value)
        metrics = {k: sum(metrics[k]) / len(metrics[k]) for k in self.metrics}

        table_data = PrettyTable()
        for key, val in metrics.items():
            table_data.add_column(key, [round(val, 5)])

        print_log('results:', logger)
        print_log('\n' + table_data.get_string(), logger=logger)

        return metrics
```
4 changes: 2 additions & 2 deletions mmseg/utils/io.py
```diff
@@ -35,8 +35,8 @@ def datafrombytes(content: bytes, backend: str = 'numpy') -> np.ndarray:
     elif backend == 'numpy':
         data = np.load(f)
     elif backend == 'cv2':
-        data = np.frombuffer(f.read(), dtype=np.uint16)
-        data = cv2.imdecode(data, 2)
+        data = np.frombuffer(f.read(), dtype=np.uint8)
+        data = cv2.imdecode(data, cv2.IMREAD_UNCHANGED)
     else:
         raise ValueError
     return data
```
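
For context, `cv2.imdecode` expects the encoded byte stream as a `uint8` buffer, and `cv2.IMREAD_UNCHANGED` keeps the original bit depth, so 16-bit depth PNGs survive decoding. A minimal round-trip sketch (not taken from the repository) illustrating the corrected call pattern:

```python
import cv2
import numpy as np

# Encode a synthetic 16-bit depth map (e.g. in millimetres) to PNG bytes ...
depth = (np.random.rand(480, 640) * 10_000).astype(np.uint16)
ok, encoded = cv2.imencode('.png', depth)
assert ok

# ... then decode it the way datafrombytes(backend='cv2') now does.
buf = np.frombuffer(encoded.tobytes(), dtype=np.uint8)
decoded = cv2.imdecode(buf, cv2.IMREAD_UNCHANGED)
assert decoded.dtype == np.uint16 and np.array_equal(decoded, depth)
```
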
85 changes: 85 additions & 0 deletions tests/test_evaluation/test_metrics/test_depth_metric.py
@@ -0,0 +1,85 @@
```python
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import shutil
from unittest import TestCase

import torch
from mmengine.structures import PixelData

from mmseg.evaluation import DepthMetric
from mmseg.structures import SegDataSample


class TestDepthMetric(TestCase):

    def _demo_mm_inputs(self,
                        batch_size=2,
                        image_shapes=(3, 64, 64),
                        num_classes=5):
        """Create a superset of inputs needed to run test or train batches.

        Args:
            batch_size (int): batch size. Defaults to 2.
            image_shapes (List[tuple], optional): image shape.
                Defaults to (3, 64, 64).
            num_classes (int): number of different classes.
                Defaults to 5.
        """
        if isinstance(image_shapes, list):
            assert len(image_shapes) == batch_size
        else:
            image_shapes = [image_shapes] * batch_size

        data_samples = []
        for idx in range(batch_size):
            image_shape = image_shapes[idx]
            _, h, w = image_shape

            data_sample = SegDataSample()
            gt_depth_map = torch.rand((1, h, w)) * 10
            data_sample.gt_depth_map = PixelData(data=gt_depth_map)

            data_samples.append(data_sample.to_dict())

        return data_samples

    def _demo_mm_model_output(self,
                              data_samples,
                              batch_size=2,
                              image_shapes=(3, 64, 64),
                              num_classes=5):

        _, h, w = image_shapes

        for data_sample in data_samples:
            data_sample['pred_depth_map'] = dict(data=torch.randn(1, h, w))

            data_sample[
                'img_path'] = 'tests/data/pseudo_dataset/imgs/00000_img.jpg'
        return data_samples

    def test_evaluate(self):
        """Test using the metric in the same way as the Evaluator."""

        data_samples = self._demo_mm_inputs()
        data_samples = self._demo_mm_model_output(data_samples)

        depth_metric = DepthMetric()
        depth_metric.process([0] * len(data_samples), data_samples)
        res = depth_metric.compute_metrics(depth_metric.results)
        self.assertIsInstance(res, dict)

        # test saving depth map files in output_dir
        depth_metric = DepthMetric(output_dir='tmp')
        depth_metric.process([0] * len(data_samples), data_samples)
        assert osp.exists('tmp')
        assert osp.isfile('tmp/00000_img.png')
        shutil.rmtree('tmp')

        # test format_only
        depth_metric = DepthMetric(output_dir='tmp', format_only=True)
        depth_metric.process([0] * len(data_samples), data_samples)
        assert depth_metric.results == []
        assert osp.exists('tmp')
        assert osp.isfile('tmp/00000_img.png')
        shutil.rmtree('tmp')
```
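
The new tests can be run locally with, for example, `pytest tests/test_evaluation/test_metrics/test_depth_metric.py` (assuming pytest and a working mmsegmentation development install).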
