Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support CopyPaste for RotatedBoxes #657

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mmrotate/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .anchor import * # noqa: F401, F403
from .bbox import * # noqa: F401, F403
from .evaluation import * # noqa: F401, F403
from .mask import * # noqa: F401, F403
from .patch import * # noqa: F401, F403
from .post_processing import * # noqa: F401, F403
from .visualization import * # noqa: F401, F403
4 changes: 4 additions & 0 deletions mmrotate/core/mask/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .structures import RBitmapMasks

__all__ = ['RBitmapMasks']
33 changes: 33 additions & 0 deletions mmrotate/core/mask/structures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright (c) OpenMMLab. All rights reserved.
import cv2
import numpy as np
from mmdet.core.mask import BitmapMasks


class RBitmapMasks(BitmapMasks):
"""This class represents masks in the form of bitmaps. Compared to the
original class, this class supports getting the minimum area rectangles
from masks.

Args:
masks (ndarray): ndarray of masks in shape (N, H, W), where N is
the number of objects.
height (int): height of masks
width (int): width of masks
"""

def get_rbboxes(self):
num_masks = len(self)
rboxes = np.zeros((num_masks, 5), dtype=np.float32)
x_any = self.masks.any(axis=1)
y_any = self.masks.any(axis=2)
for idx in range(num_masks):
x = np.where(x_any[idx, :])[0]
y = np.where(y_any[idx, :])[0]
if len(x) > 0 and len(y) > 0:
contours = cv2.findContours(self.masks[idx], cv2.RETR_TREE,
cv2.CHAIN_APPROX_SIMPLE)[0][0]
(cx, cy), (w, h), a = cv2.minAreaRect(contours)
rboxes[idx, :] = np.array(
[cx, cy, w, h, np.radians(a)], dtype=np.float32)
Comment on lines +28 to +32
Copy link
Contributor Author

@nijkah nijkah Dec 14, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, we need to support mmrotate.core.bbox.poly2obb_np for various angle versions.
However, poly2obb supposes that polygons have 4 vertices.
How can we support this?

Copy link
Collaborator

@zytx121 zytx121 Dec 14, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In mmrotate1.x, we suppose all rotated boxes use oc angle version during data transforms.
We only need to distinguish different angle versions within the box_head.
So I think it's OK to write like this.

return rboxes
5 changes: 3 additions & 2 deletions mmrotate/datasets/pipelines/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .loading import LoadPatchFromImage
from .transforms import PolyRandomRotate, RMosaic, RRandomFlip, RResize
from .transforms import (PolyRandomRotate, RCopyPaste, RMosaic, RRandomFlip,
RResize)

__all__ = [
'LoadPatchFromImage', 'RResize', 'RRandomFlip', 'PolyRandomRotate',
'RMosaic'
'RMosaic', 'RCopyPaste'
]
141 changes: 140 additions & 1 deletion mmrotate/datasets/pipelines/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
import numpy as np
import torch
from mmcv.ops import box_iou_rotated
from mmdet.datasets.pipelines.transforms import (Mosaic, RandomCrop,
from mmdet.datasets.pipelines.transforms import (CopyPaste, Mosaic, RandomCrop,
RandomFlip, Resize)
from numpy import random

from mmrotate.core import norm_angle, obb2poly_np, poly2obb_np
from mmrotate.core.mask import RBitmapMasks
from ..builder import ROTATED_PIPELINES


Expand Down Expand Up @@ -554,3 +555,141 @@ def _filter_box_candidates(self, bboxes, labels, w, h):
(bbox_h > self.min_bbox_size)
valid_inds = np.nonzero(valid_inds)[0]
return bboxes[valid_inds], labels[valid_inds]


@ROTATED_PIPELINES.register_module()
class RCopyPaste(CopyPaste):
"""Simple Copy-Paste is a Strong Data Augmentation Method for Instance
Segmentation The simple copy-paste transform steps are as follows:

1. The destination image is already resized with aspect ratio kept,
cropped and padded.
2. Randomly select a source image, which is also already resized
with aspect ratio kept, cropped and padded in a similar way
as the destination image.
3. Randomly select some objects from the source image.
4. Paste these source objects to the destination image directly,
due to the source and destination image have the same size.
5. Update object masks of the destination image, for some origin objects
may be occluded.
6. Generate bboxes from the updated destination masks and
filter some objects which are totally occluded, and adjust bboxes
which are partly occluded.
7. Append selected source bboxes, masks, and labels.
Args:
max_num_pasted (int): The maximum number of pasted objects.
Default: 100.
rbbox_occluded_iou_thr (int): The threshold of occluded rbboxes.
Default: 0.3.
mask_occluded_thr (int): The threshold of occluded mask.
Default: 300.
selected (bool): Whether select objects or not. If select is False,
all objects of the source image will be pasted to the
destination image.
Default: True.
version (str, optional): Angle representations. Defaults to `oc`.
"""

def __init__(
self,
max_num_pasted=100,
rbbox_occluded_iou_thr=0.3,
mask_occluded_thr=300,
selected=True,
version='le90',
):
self.max_num_pasted = max_num_pasted
self.rbbox_occluded_iou_thr = rbbox_occluded_iou_thr
self.mask_occluded_thr = mask_occluded_thr
self.selected = selected
self.paste_by_box = False
self.version = version

def gen_masks_from_bboxes(self, bboxes, img_shape):
"""Generate gt_masks based on gt_bboxes.

Args:
bboxes (list): The bboxes's list.
img_shape (tuple): The shape of image.
Returns:
RBitmapMasks
"""
self.paste_by_box = True
img_h, img_w = img_shape[:2]
gt_masks = np.zeros((len(bboxes), img_h, img_w), dtype=np.uint8)
bboxes = np.concatenate(
[bboxes, np.zeros((bboxes.shape[0], 1))], axis=-1)
polys = obb2poly_np(bboxes,
self.version)[:, :-1].reshape(-1, 4,
2).astype(np.int0)

for i, poly in enumerate(polys):
cv2.drawContours(gt_masks[i], [poly], 0, 1, -1)
return RBitmapMasks(gt_masks, img_h, img_w)

def _copy_paste(self, dst_results, src_results):
"""CopyPaste transform function.

Args:
dst_results (dict): Result dict of the destination image.
src_results (dict): Result dict of the source image.
Returns:
dict: Updated result dict.
"""

dst_img = dst_results['img']
dst_bboxes = dst_results['gt_bboxes']
dst_labels = dst_results['gt_labels']
dst_masks = dst_results['gt_masks']

src_img = src_results['img']
src_bboxes = src_results['gt_bboxes']
src_labels = src_results['gt_labels']
src_masks = src_results['gt_masks']

if len(src_bboxes) == 0:
if self.paste_by_box:
dst_results.pop('gt_masks')
return dst_results

# update masks and generate bboxes from updated masks
composed_mask = np.where(np.any(src_masks.masks, axis=0), 1, 0)
updated_dst_masks = self.get_updated_masks(dst_masks, composed_mask)
updated_dst_bboxes = updated_dst_masks.get_rbboxes()
assert len(updated_dst_bboxes) == len(updated_dst_masks)

# filter totally occluded objects
bboxes_inds = box_iou_rotated(
torch.tensor(dst_bboxes), torch.tensor(updated_dst_bboxes)).numpy(
).max(-1) <= self.rbbox_occluded_iou_thr
masks_inds = updated_dst_masks.masks.sum(
axis=(1, 2)) > self.mask_occluded_thr
valid_inds = bboxes_inds | masks_inds

# Paste source objects to destination image directly
img = dst_img * (1 - composed_mask[..., np.newaxis]
) + src_img * composed_mask[..., np.newaxis]
bboxes = np.concatenate([updated_dst_bboxes[valid_inds], src_bboxes])
labels = np.concatenate([dst_labels[valid_inds], src_labels])
masks = np.concatenate(
[updated_dst_masks.masks[valid_inds], src_masks.masks])

dst_results['img'] = img
dst_results['gt_bboxes'] = bboxes
dst_results['gt_labels'] = labels
if self.paste_by_box:
dst_results.pop('gt_masks')
else:
dst_results['gt_masks'] = RBitmapMasks(masks, masks.shape[1],
masks.shape[2])

return dst_results

def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'max_num_pasted={self.max_num_pasted}, '
repr_str += f'rbox_occluded_iou_thr={self.rbox_occluded_iou_thr}, '
repr_str += f'mask_occluded_thr={self.mask_occluded_thr}, '
repr_str += f'selected={self.selected}, '
repr_str += f'version={self.version}, '
return repr_str