Skip to content

Commit

Permalink
[Enhance] Support mask in merge_results and huge_image_demo.py. (#280)
Browse files Browse the repository at this point in the history
* Support masks mergeing

* Update error report
  • Loading branch information
jbwang1997 committed May 11, 2022
1 parent d80310a commit 0edd257
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 20 deletions.
6 changes: 5 additions & 1 deletion mmrotate/apis/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,5 +85,9 @@ def inference_detector_by_patches(model,
start += bs

results = merge_results(
results, windows[:, :2], iou_thr=merge_iou_thr, device=device)
results,
windows[:, :2],
img_shape=(width, height),
iou_thr=merge_iou_thr,
device=device)
return results
128 changes: 109 additions & 19 deletions mmrotate/core/patch/merge_results.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,78 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
from mmcv.ops import nms_rotated
from mmcv.ops import nms, nms_rotated


def merge_results(results, offsets, iou_thr=0.1, device='cpu'):
def translate_bboxes(bboxes, offset):
"""Translate bboxes according to its shape.
If the bbox shape is (n, 5), the bboxes are regarded as horizontal bboxes
and in (x, y, x, y, score) format. If the bbox shape is (n, 6), the bboxes
are regarded as rotated bboxes and in (x, y, w, h, theta, score) format.
Args:
bboxes (np.ndarray): The bboxes need to be translated. Its shape can
only be (n, 5) and (n, 6).
offset (np.ndarray): The offset to translate with shape being (2, ).
Returns:
np.ndarray: Translated bboxes.
"""
if bboxes.shape[1] == 5:
bboxes[:, :4] = bboxes[:, :4] + np.tile(offset, 2)
elif bboxes.shape[1] == 6:
bboxes[:, :2] = bboxes[:, :2] + offset
else:
raise TypeError('Require the shape of `bboxes` to be (n, 5) or (n, 6),'
f' but get `bboxes` with shape being {bboxes.shape}.')
return bboxes


def map_masks(masks, offset, new_shape):
"""Map masks to the huge image.
Args:
masks (list[np.ndarray]): masks need to be mapped.
offset (np.ndarray): The offset to translate with shape being (2, ).
new_shape (tuple): A tuple of the huge image's width and height.
Returns:
list[np.ndarray]: Mapped masks.
"""
if not masks:
return masks

new_width, new_height = new_shape
x_start, y_start = offset
mapped = []
for mask in masks:
ori_height, ori_width = mask.shape[:2]

x_end = x_start + ori_width
if x_end > new_width:
ori_width -= x_end - new_width
x_end = new_width

y_end = y_start + ori_height
if y_end > new_height:
ori_height -= y_end - new_height
y_end = new_height

extended_mask = np.zeros((new_height, new_width), dtype=np.bool)
extended_mask[y_start:y_end,
x_start:x_end] = mask[:ori_height, :ori_width]
mapped.append(extended_mask)
return mapped


def merge_results(results, offsets, img_shape, iou_thr=0.1, device='cpu'):
"""Merge patch results via nms.
Args:
results (list[np.ndarray]): A list of patches results.
results (list[np.ndarray] | list[tuple]): A list of patches results.
offsets (np.ndarray): Positions of the left top points of patches.
img_shape (tuple): A tuple of the huge image's width and height.
iou_thr (float): The IoU threshold of NMS.
device (str): The device to call nms.
Expand All @@ -18,20 +81,47 @@ def merge_results(results, offsets, iou_thr=0.1, device='cpu'):
"""
assert len(results) == offsets.shape[0], 'The `results` should has the ' \
'same length with `offsets`.'
merged_results = []
for results_pre_cls in zip(*results):
tran_dets = []
for dets, offset in zip(results_pre_cls, offsets):
dets[:, :2] += offset
tran_dets.append(dets)
tran_dets = np.concatenate(tran_dets, axis=0)

if tran_dets.size == 0:
merged_results.append(tran_dets)
with_mask = isinstance(results[0], tuple)
num_patches = len(results)
num_classes = len(results[0][0]) if with_mask else len(results[0])

merged_bboxes = []
merged_masks = []
for cls in range(num_classes):
if with_mask:
dets_per_cls = [results[i][0][cls] for i in range(num_patches)]
masks_per_cls = [results[i][1][cls] for i in range(num_patches)]
else:
tran_dets = torch.from_numpy(tran_dets)
tran_dets = tran_dets.to(device)
nms_dets, _ = nms_rotated(tran_dets[:, :5], tran_dets[:, -1],
iou_thr)
merged_results.append(nms_dets.cpu().numpy())
return merged_results
dets_per_cls = [results[i][cls] for i in range(num_patches)]
masks_per_cls = None

dets_per_cls = [
translate_bboxes(dets_per_cls[i], offsets[i])
for i in range(num_patches)
]
dets_per_cls = np.concatenate(dets_per_cls, axis=0)
if with_mask:
masks_placeholder = []
for i, masks in enumerate(masks_per_cls):
translated = map_masks(masks, offsets[i], img_shape)
masks_placeholder.extend(translated)
masks_per_cls = masks_placeholder

if dets_per_cls.size == 0:
merged_bboxes.append(dets_per_cls)
if with_mask:
merged_masks.append(masks_per_cls)
else:
dets_per_cls = torch.from_numpy(dets_per_cls).to(device)
nms_func = nms if dets_per_cls.size(1) == 5 else nms_rotated
nms_dets, keeps = nms_func(dets_per_cls[:, :-1],
dets_per_cls[:, -1], iou_thr)
merged_bboxes.append(nms_dets.cpu().numpy())
if with_mask:
keeps = keeps.cpu().numpy()
merged_masks.append([masks_per_cls[i] for i in keeps])

if with_mask:
return merged_bboxes, merged_masks
else:
return merged_bboxes

0 comments on commit 0edd257

Please sign in to comment.