Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Move some predict code to BaseAngleDenseHead #796

Open
wants to merge 5 commits into
base: dev-1.x
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mmrotate/models/dense_heads/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .angle_branch_retina_head import AngleBranchRetinaHead
from .base_angle_dense_head import BaseAngleDenseHead
from .cfa_head import CFAHead
from .h2rbox_head import H2RBoxHead
from .oriented_reppoints_head import OrientedRepPointsHead
Expand All @@ -18,5 +19,5 @@
'SAMRepPointsHead', 'AngleBranchRetinaHead', 'RotatedATSSHead',
'RotatedFCOSHead', 'OrientedRepPointsHead', 'R3Head', 'R3RefineHead',
'S2AHead', 'S2ARefineHead', 'CFAHead', 'H2RBoxHead', 'RotatedRTMDetHead',
'RotatedRTMDetSepBNHead'
'RotatedRTMDetSepBNHead', 'BaseAngleDenseHead'
]
299 changes: 299 additions & 0 deletions mmrotate/models/dense_heads/base_angle_dense_head.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,299 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import List, Optional

import torch
from mmdet.models.dense_heads.base_dense_head import BaseDenseHead
from mmdet.models.utils import filter_scores_and_topk, select_single_mlvl
from mmdet.structures.bbox import cat_boxes
from mmdet.utils import ConfigType, InstanceList, OptConfigType
from mmengine import ConfigDict
from mmengine.structures import InstanceData
from torch import Tensor

from mmrotate.registry import MODELS, TASK_UTILS
from mmrotate.structures import RotatedBoxes


@MODELS.register_module()
class BaseAngleDenseHead(BaseDenseHead):
"""Base class for dense heads with angle. Commonly, BaseAngleDenseHead will
be used with other head.

Args:
angle_version (str, optional): The version of angle. Defaults to
'le90'.
use_hbbox_loss (bool, optional): Whether to use the loss of
horizontal bboxes. Defaults to False.
angle_coder (dict, optional): Config dict for angle coder.
Defaults to dict(type='PseudoAngleCoder').
loss_angle (dict, optional): Config dict for angle loss.
Defaults to None.
"""

def __init__(self,
angle_version: str = 'le90',
use_hbbox_loss: bool = False,
angle_coder: ConfigType = dict(type='PseudoAngleCoder'),
loss_angle: OptConfigType = None,
*args,
**kwargs):
self.angle_version = angle_version
self.use_hbbox_loss = use_hbbox_loss
self.angle_coder = TASK_UTILS.build(angle_coder)
# Commonly, BaseAngleDenseHead will be used with other head.
# So we call super here to init the other head.
# For example, RotatedFCOSHead will used with FCOSHead,
# so super here will call the init function of FCOSHead.
super().__init__(*args, **kwargs)

if loss_angle is not None:
self.loss_angle = MODELS.build(loss_angle)
else:
self.loss_angle = None
if self.use_hbbox_loss:
assert self.loss_angle is not None

def predict_by_feat(self,
cls_scores: List[Tensor],
bbox_preds: List[Tensor],
angle_preds: List[Tensor],
score_factors: Optional[List[Tensor]] = None,
batch_img_metas: Optional[List[dict]] = None,
cfg: Optional[ConfigDict] = None,
rescale: bool = False,
with_nms: bool = True) -> InstanceList:
"""Transform a batch of output features extracted from the head into
bbox results.
Note: When score_factors is not None, the cls_scores are
usually multiplied by it then obtain the real score used in NMS,
such as CenterNess in FCOS, IoU branch in ATSS.
Args:
cls_scores (list[Tensor]): Classification scores for all
scale levels, each is a 4D-tensor, has shape
(batch_size, num_priors * num_classes, H, W).
bbox_preds (list[Tensor]): Box energies / deltas for all
scale levels, each is a 4D-tensor, has shape
(batch_size, num_priors * 4, H, W).
angle_preds (list[Tensor]): Box angle for each scale level
with shape (N, num_points * encode_size, H, W)
score_factors (list[Tensor], optional): Score factor for
all scale level, each is a 4D-tensor, has shape
(batch_size, num_priors * 1, H, W). Defaults to None.
batch_img_metas (list[dict], Optional): Batch image meta info.
Defaults to None.
cfg (ConfigDict, optional): Test / postprocessing
configuration, if None, test_cfg would be used.
Defaults to None.
rescale (bool): If True, return boxes in original image space.
Defaults to False.
with_nms (bool): If True, do nms before return boxes.
Defaults to True.
Returns:
list[:obj:`InstanceData`]: Object detection results of each image
after the post process. Each item usually contains following keys.
- scores (Tensor): Classification scores, has a shape
(num_instance, )
- labels (Tensor): Labels of bboxes, has a shape
(num_instances, ).
- bboxes (Tensor): Has a shape (num_instances, 5),
the last dimension 5 arrange as (x, y, w, h, t).
"""
assert len(cls_scores) == len(bbox_preds)

if score_factors is None:
# e.g. Retina, FreeAnchor, Foveabox, etc.
with_score_factors = False
else:
# e.g. FCOS, PAA, ATSS, AutoAssign, etc.
with_score_factors = True
assert len(cls_scores) == len(score_factors)

num_levels = len(cls_scores)

featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
mlvl_priors = self.prior_generator.grid_priors(
featmap_sizes,
dtype=cls_scores[0].dtype,
device=cls_scores[0].device)

result_list = []

for img_id in range(len(batch_img_metas)):
img_meta = batch_img_metas[img_id]
cls_score_list = select_single_mlvl(
cls_scores, img_id, detach=True)
bbox_pred_list = select_single_mlvl(
bbox_preds, img_id, detach=True)
angle_pred_list = select_single_mlvl(
angle_preds, img_id, detach=True)
if with_score_factors:
score_factor_list = select_single_mlvl(
score_factors, img_id, detach=True)
else:
score_factor_list = [None for _ in range(num_levels)]

results = self._predict_by_feat_single(
cls_score_list=cls_score_list,
bbox_pred_list=bbox_pred_list,
angle_pred_list=angle_pred_list,
score_factor_list=score_factor_list,
mlvl_priors=mlvl_priors,
img_meta=img_meta,
cfg=cfg,
rescale=rescale,
with_nms=with_nms)
result_list.append(results)
return result_list

def _predict_by_feat_single(self,
cls_score_list: List[Tensor],
bbox_pred_list: List[Tensor],
angle_pred_list: List[Tensor],
score_factor_list: List[Tensor],
mlvl_priors: List[Tensor],
img_meta: dict,
cfg: ConfigDict,
rescale: bool = False,
with_nms: bool = True) -> InstanceData:
"""Transform a single image's features extracted from the head into
bbox results.
Args:
cls_score_list (list[Tensor]): Box scores from all scale
levels of a single image, each item has shape
(num_priors * num_classes, H, W).
bbox_pred_list (list[Tensor]): Box energies / deltas from
all scale levels of a single image, each item has shape
(num_priors * 4, H, W).
angle_pred_list (list[Tensor]): Box angle for a single scale
level with shape (N, num_points * encode_size, H, W).
score_factor_list (list[Tensor]): Score factor from all scale
levels of a single image, each item has shape
(num_priors * 1, H, W).
mlvl_priors (list[Tensor]): Each element in the list is
the priors of a single level in feature pyramid. In all
anchor-based methods, it has shape (num_priors, 4). In
all anchor-free methods, it has shape (num_priors, 2)
when `with_stride=True`, otherwise it still has shape
(num_priors, 4).
img_meta (dict): Image meta info.
cfg (mmengine.Config): Test / postprocessing configuration,
if None, test_cfg would be used.
rescale (bool): If True, return boxes in original image space.
Defaults to False.
with_nms (bool): If True, do nms before return boxes.
Defaults to True.
Returns:
:obj:`InstanceData`: Detection results of each image
after the post process.
Each item usually contains following keys.
- scores (Tensor): Classification scores, has a shape
(num_instance, )
- labels (Tensor): Labels of bboxes, has a shape
(num_instances, ).
- bboxes (Tensor): Has a shape (num_instances, 5),
the last dimension 5 arrange as (x, y, w, h, t).
"""
if score_factor_list[0] is None:
# e.g. Retina, FreeAnchor, etc.
with_score_factors = False
else:
# e.g. FCOS, PAA, ATSS, etc.
with_score_factors = True

cfg = self.test_cfg if cfg is None else cfg
cfg = copy.deepcopy(cfg)
img_shape = img_meta['img_shape']
nms_pre = cfg.get('nms_pre', -1)

mlvl_bbox_preds = []
mlvl_decoded_angles = []
mlvl_valid_priors = []
mlvl_scores = []
mlvl_labels = []
if with_score_factors:
mlvl_score_factors = []
else:
mlvl_score_factors = None
for level_idx, (
cls_score, bbox_pred, angle_pred, score_factor, priors) in \
enumerate(zip(cls_score_list, bbox_pred_list, angle_pred_list,
score_factor_list, mlvl_priors)):

assert cls_score.size()[-2:] == bbox_pred.size()[-2:]

# dim = self.bbox_coder.encode_size
bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
angle_pred = angle_pred.permute(1, 2, 0).reshape(
-1, self.angle_coder.encode_size)
if with_score_factors:
score_factor = score_factor.permute(1, 2,
0).reshape(-1).sigmoid()
cls_score = cls_score.permute(1, 2,
0).reshape(-1, self.cls_out_channels)
if self.use_sigmoid_cls:
scores = cls_score.sigmoid()
else:
# remind that we set FG labels to [0, num_class-1]
# since mmdet v2.0
# BG cat_id: num_class
scores = cls_score.softmax(-1)[:, :-1]

# After https://github.com/open-mmlab/mmdetection/pull/6268/,
# this operation keeps fewer bboxes under the same `nms_pre`.
# There is no difference in performance for most models. If you
# find a slight drop in performance, you can set a larger
# `nms_pre` than before.
score_thr = cfg.get('score_thr', 0)

results = filter_scores_and_topk(
scores, score_thr, nms_pre,
dict(
bbox_pred=bbox_pred, angle_pred=angle_pred, priors=priors))
scores, labels, keep_idxs, filtered_results = results

bbox_pred = filtered_results['bbox_pred']
angle_pred = filtered_results['angle_pred']
priors = filtered_results['priors']

decoded_angle = self.angle_coder.decode(angle_pred, keepdim=True)

if with_score_factors:
score_factor = score_factor[keep_idxs]

mlvl_bbox_preds.append(bbox_pred)
mlvl_decoded_angles.append(decoded_angle)
mlvl_valid_priors.append(priors)
mlvl_scores.append(scores)
mlvl_labels.append(labels)

if with_score_factors:
mlvl_score_factors.append(score_factor)

bbox_pred = torch.cat(mlvl_bbox_preds)
decoded_angle = torch.cat(mlvl_decoded_angles)
priors = cat_boxes(mlvl_valid_priors)

decode_with_angle = cfg.get('decode_with_angle', True)
if decode_with_angle:
bbox_pred = torch.cat([bbox_pred, decoded_angle], dim=-1)
bboxes = self.bbox_coder.decode(
priors, bbox_pred, max_shape=img_shape)
else:
bboxes = self.bbox_coder.decode(
priors, bbox_pred, max_shape=img_shape)
bboxes = torch.cat([bboxes[..., :4], decoded_angle], dim=-1)

results = InstanceData()
results.bboxes = RotatedBoxes(bboxes)
results.scores = torch.cat(mlvl_scores)
results.labels = torch.cat(mlvl_labels)
if with_score_factors:
results.score_factors = torch.cat(mlvl_score_factors)

return self._bbox_post_process(
results=results,
cfg=cfg,
rescale=rescale,
with_nms=with_nms,
img_meta=img_meta)
Loading