diff --git a/mmrotate/models/dense_heads/oriented_reppoints_head.py b/mmrotate/models/dense_heads/oriented_reppoints_head.py index bfe18ff12..0e5b3ec21 100644 --- a/mmrotate/models/dense_heads/oriented_reppoints_head.py +++ b/mmrotate/models/dense_heads/oriented_reppoints_head.py @@ -715,7 +715,11 @@ def _point_target_single(self, # map up to original set of proposals if unmap_outputs: num_total_proposals = flat_proposals.size(0) - labels = unmap(labels, num_total_proposals, inside_flags) + labels = unmap( + labels, + num_total_proposals, + inside_flags, + fill=self.num_classes) # fill bg label label_weights = unmap(label_weights, num_total_proposals, inside_flags) bbox_gt = unmap(bbox_gt, num_total_proposals, inside_flags) diff --git a/mmrotate/models/dense_heads/rotated_reppoints_head.py b/mmrotate/models/dense_heads/rotated_reppoints_head.py index 8af17ae53..6660b192b 100644 --- a/mmrotate/models/dense_heads/rotated_reppoints_head.py +++ b/mmrotate/models/dense_heads/rotated_reppoints_head.py @@ -354,7 +354,11 @@ def _point_target_single(self, # map up to original set of proposals if unmap_outputs: num_total_proposals = flat_proposals.size(0) - labels = unmap(labels, num_total_proposals, inside_flags) + labels = unmap( + labels, + num_total_proposals, + inside_flags, + fill=self.num_classes) # fill bg label label_weights = unmap(label_weights, num_total_proposals, inside_flags) bbox_gt = unmap(bbox_gt, num_total_proposals, inside_flags) diff --git a/mmrotate/models/dense_heads/sam_reppoints_head.py b/mmrotate/models/dense_heads/sam_reppoints_head.py index 91ec6592e..1f434e012 100644 --- a/mmrotate/models/dense_heads/sam_reppoints_head.py +++ b/mmrotate/models/dense_heads/sam_reppoints_head.py @@ -414,7 +414,11 @@ def _point_target_single(self, # map up to original set of proposals if unmap_outputs: num_total_proposals = flat_proposals.size(0) - labels = unmap(labels, num_total_proposals, inside_flags) + labels = unmap( + labels, + num_total_proposals, + inside_flags, + fill=self.num_classes) # fill bg label label_weights = unmap(label_weights, num_total_proposals, inside_flags) bbox_gt = unmap(bbox_gt, num_total_proposals, inside_flags)