diff --git a/mmseg/models/decode_heads/mask2former_head.py b/mmseg/models/decode_heads/mask2former_head.py index 0135af0645..a8f738e667 100644 --- a/mmseg/models/decode_heads/mask2former_head.py +++ b/mmseg/models/decode_heads/mask2former_head.py @@ -150,8 +150,11 @@ def predict(self, x: Tuple[Tensor], batch_img_metas: List[dict], all_cls_scores, all_mask_preds = self(x, batch_data_samples) mask_cls_results = all_cls_scores[-1] mask_pred_results = all_mask_preds[-1] - if 'pad_shape' in batch_img_metas[0]: - size = batch_img_metas[0]['pad_shape'] + if isinstance(batch_img_metas[0]['img_shape'], torch.Size): + # slide inference + size = batch_img_metas[0]['img_shape'] + elif 'pad_shape' in batch_img_metas[0]: + size = batch_img_metas[0]['pad_shape'][:2] else: size = batch_img_metas[0]['img_shape'] # upsample mask