Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support TTA #771

Open
wants to merge 1 commit into
base: dev-1.x
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions configs/rotated_rtmdet/_base_/tta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
tta_model = dict(
type='RotatedTTAModel',
tta_cfg=dict(nms=dict(type='nms_rotated', iou_threshold=0.1), max_per_img=2000))

img_scales = [(1024, 1024), (800, 800), (1200, 1200)]
tta_pipeline = [
dict(type='mmdet.LoadImageFromFile', file_client_args=dict(backend='disk')),
dict(
type='mmdet.TestTimeAug',
transforms=[
[
dict(type='mmdet.Resize', scale=s, keep_ratio=True)
for s in img_scales
],
[
# ``RandomFlip`` must be placed before ``Pad``, otherwise
# bounding box coordinates after flipping cannot be
# recovered correctly.
dict(type='mmdet.RandomFlip', prob=1.),
dict(type='mmdet.RandomFlip', prob=0.)
],
[
dict(
type='mmdet.Pad',
size=(1200, 1200),
pad_val=dict(img=(114, 114, 114))),
],
[
dict(
type='mmdet.PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor', 'flip', 'flip_direction'))
]
])
]
2 changes: 1 addition & 1 deletion configs/rotated_rtmdet/rotated_rtmdet_l-3x-dota.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
_base_ = [
'./_base_/default_runtime.py', './_base_/schedule_3x.py',
'./_base_/dota_rr.py'
'./_base_/dota_rr.py', './_base_/tta.py'
]
checkpoint = 'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/cspnext_rsb_pretrain/cspnext-l_8xb256-rsb-a1-600e_in1k-6a760974.pth' # noqa

Expand Down
1 change: 1 addition & 0 deletions mmrotate/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@
from .necks import * # noqa: F401, F403
from .roi_heads import * # noqa: F401, F403
from .task_modules import * # noqa: F401,F403
from .test_time_augs import * # noqa: F401, F403
from .utils import * # noqa: F401, F403
7 changes: 7 additions & 0 deletions mmrotate/models/test_time_augs/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .rotated_tta import RotatedTTAModel


__all__ = [
'RotatedTTAModel'
]
71 changes: 71 additions & 0 deletions mmrotate/models/test_time_augs/rotated_tta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple

import torch

from torch import Tensor

from mmdet.models.test_time_augs import DetTTAModel
from mmrotate.registry import MODELS


def bbox_flip(bboxes: Tensor,
img_shape: Tuple[int],
direction: str = 'horizontal') -> Tensor:
"""Flip bboxes horizontally or vertically.

Args:
bboxes (Tensor): Shape (..., 5*k)
img_shape (Tuple[int]): Image shape.
direction (str): Flip direction, options are "horizontal", "vertical",
"diagonal". Default: "horizontal"

Returns:
Tensor: Flipped bboxes.
"""
assert bboxes.shape[-1] % 5 == 0
assert direction in ['horizontal', 'vertical', 'diagonal']
flipped = bboxes.clone()
if direction == 'horizontal':
flipped[..., 0] = img_shape[1] - flipped[..., 0]
flipped[..., 4] = -flipped[..., 4]
elif direction == 'vertical':
flipped[..., 1] = img_shape[0] - flipped[..., 1]
flipped[..., 4] = -flipped[..., 4]
else:
flipped[..., 0] = img_shape[1] - flipped[..., 0]
flipped[..., 1] = img_shape[0] - flipped[..., 1]
return flipped

@MODELS.register_module()
class RotatedTTAModel(DetTTAModel):

def merge_aug_bboxes(self, aug_bboxes: List[Tensor],
aug_scores: List[Tensor],
img_metas: List[str]) -> Tuple[Tensor, Tensor]:
"""Merge augmented detection bboxes and scores.
Args:
aug_bboxes (list[Tensor]): shape (n, 5*#class)
aug_scores (list[Tensor] or None): shape (n, #class)
Returns:
tuple[Tensor]: ``bboxes`` with shape (n,5), where
4 represent (x, y, w, h, t)
and ``scores`` with shape (n,).
"""
recovered_bboxes = []
for bboxes, img_info in zip(aug_bboxes, img_metas):
ori_shape = img_info['ori_shape']
flip = img_info['flip']
flip_direction = img_info['flip_direction']
if flip:
bboxes = bbox_flip(
bboxes=bboxes,
img_shape=ori_shape,
direction=flip_direction)
recovered_bboxes.append(bboxes)
bboxes = torch.cat(recovered_bboxes, dim=0)
if aug_scores is None:
return bboxes
else:
scores = torch.cat(aug_scores, dim=0)
return bboxes, scores
19 changes: 18 additions & 1 deletion tools/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os.path as osp

from mmdet.utils import register_all_modules as register_all_modules_mmdet
from mmengine.config import Config, DictAction
from mmengine.config import Config, ConfigDict, DictAction
from mmengine.evaluator import DumpResults
from mmengine.registry import RUNNERS
from mmengine.runner import Runner
Expand All @@ -24,6 +24,10 @@ def parse_args():
'--out',
type=str,
help='dump predictions to a pickle file for offline evaluation')
parser.add_argument(
'--tta',
action='store_true',
help='Whether to use test time augmentation')
parser.add_argument(
'--show', action='store_true', help='show prediction results')
parser.add_argument(
Expand Down Expand Up @@ -103,6 +107,19 @@ def main():
if args.show or args.show_dir:
cfg = trigger_visualization_hook(cfg, args)

if args.tta:
assert 'tta_model' in cfg, 'Cannot find ``tta_model`` in config.' \
" Can't use tta !"
assert 'tta_pipeline' in cfg, 'Cannot find ``tta_pipeline`` ' \
"in config. Can't use tta !"

cfg.model = ConfigDict(**cfg.tta_model, module=cfg.model)
test_data_cfg = cfg.test_dataloader.dataset
while 'dataset' in test_data_cfg:
test_data_cfg = test_data_cfg['dataset']

test_data_cfg.pipeline = cfg.tta_pipeline

# build the runner from config
if 'runner_type' not in cfg:
# build the default runner
Expand Down