diff --git a/mmseg/models/decode_heads/decode_head.py b/mmseg/models/decode_heads/decode_head.py index 4faf54559d..179d871fd1 100644 --- a/mmseg/models/decode_heads/decode_head.py +++ b/mmseg/models/decode_heads/decode_head.py @@ -350,9 +350,17 @@ def predict_by_feat(self, seg_logits: Tensor, Tensor: Outputs segmentation logits map. """ + if isinstance(batch_img_metas[0]['img_shape'], torch.Size): + # slide inference + size = batch_img_metas[0]['img_shape'] + elif 'pad_shape' in batch_img_metas[0]: + size = batch_img_metas[0]['pad_shape'][:2] + else: + size = batch_img_metas[0]['img_shape'] + seg_logits = resize( input=seg_logits, - size=batch_img_metas[0]['img_shape'], + size=size, mode='bilinear', align_corners=self.align_corners) return seg_logits diff --git a/mmseg/models/decode_heads/san_head.py b/mmseg/models/decode_heads/san_head.py index 03dedf2e49..d20da80192 100644 --- a/mmseg/models/decode_heads/san_head.py +++ b/mmseg/models/decode_heads/san_head.py @@ -586,8 +586,11 @@ def predict_by_feat(self, seg_logits: List[Tensor], """ mask_pred = seg_logits[0] cls_score = seg_logits[1] - if 'pad_shape' in batch_img_metas[0]: - size = batch_img_metas[0]['pad_shape'] + if isinstance(batch_img_metas[0]['img_shape'], torch.Size): + # slide inference + size = batch_img_metas[0]['img_shape'] + elif 'pad_shape' in batch_img_metas[0]: + size = batch_img_metas[0]['pad_shape'][:2] else: size = batch_img_metas[0]['img_shape'] # upsample mask diff --git a/projects/hssn/decode_head/sep_aspp_contrast_head.py b/projects/hssn/decode_head/sep_aspp_contrast_head.py index d1d087362c..331af30de4 100644 --- a/projects/hssn/decode_head/sep_aspp_contrast_head.py +++ b/projects/hssn/decode_head/sep_aspp_contrast_head.py @@ -127,10 +127,17 @@ def predict_by_feat(self, seg_logits: Tuple[Tensor], # elif seg_logit.size(1) == 144 # For Mapillary dataset, 124+16+4 # unofficial repository not release mapillary until 2023/2/6 + if isinstance(batch_img_metas[0]['img_shape'], torch.Size): + # slide inference + size = batch_img_metas[0]['img_shape'] + elif 'pad_shape' in batch_img_metas[0]: + size = batch_img_metas[0]['pad_shape'][:2] + else: + size = batch_img_metas[0]['img_shape'] seg_logit = seg_logit[:, :-hiera_num_classes] seg_logit = resize( input=seg_logit, - size=batch_img_metas[0]['img_shape'], + size=size, mode='bilinear', align_corners=self.align_corners)