From 8015ad502feef1118fc82b722dc30b39d2db726c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9E=97=E6=97=BB=E4=BD=91?= Date: Sun, 3 Aug 2025 22:16:15 +0800 Subject: [PATCH 1/2] BUGFIX: support NDHWC input in sliding_window_inference and DiceMetric --- monai/inferers/utils.py | 9 +++++++++ monai/metrics/meandice.py | 6 ++++++ 2 files changed, 15 insertions(+) diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index 8adba8fa25..6ccad36fef 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -60,6 +60,7 @@ def sliding_window_inference( *args: Any, **kwargs: Any, ) -> torch.Tensor | tuple[torch.Tensor, ...] | dict[Any, torch.Tensor]: + """ Sliding window inference on `inputs` with `predictor`. @@ -134,6 +135,14 @@ def sliding_window_inference( - input must be channel-first and have a batch dim, supports N-D sliding window. """ + + # auto transform (N,D,H,W,C) → (N,C,D,H,W) + if isinstance(inputs, torch.Tensor) and inputs.ndim == 5 and inputs.shape[-1] in (1, 3, 4): + inputs = inputs.permute(0, 4, 1, 2, 3).contiguous() + + + + buffered = buffer_steps is not None and buffer_steps > 0 num_spatial_dims = len(inputs.shape) - 2 if buffered: diff --git a/monai/metrics/meandice.py b/monai/metrics/meandice.py index 0802cc3364..980552e628 100644 --- a/monai/metrics/meandice.py +++ b/monai/metrics/meandice.py @@ -134,6 +134,12 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor Raises: ValueError: when `y_pred` has fewer than three dimensions. """ + + if isinstance(y_pred, torch.Tensor) and y_pred.ndim == 5 and y_pred.shape[-1] in (1, 3, 4): + y_pred = y_pred.permute(0, 4, 1, 2, 3).contiguous() + if isinstance(y, torch.Tensor) and y.ndim == 5 and y.shape[-1] in (1, 3, 4): + y = y.permute(0, 4, 1, 2, 3).contiguous() + dims = y_pred.ndimension() if dims < 3: raise ValueError(f"y_pred should have at least 3 dimensions (batch, channel, spatial), got {dims}.") From 9c2881b51e59b918ae24f635368f673f974853ca Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 29 Aug 2025 07:34:46 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/inferers/utils.py | 2 +- monai/metrics/meandice.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index 6ccad36fef..00ead9a2d5 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -60,7 +60,7 @@ def sliding_window_inference( *args: Any, **kwargs: Any, ) -> torch.Tensor | tuple[torch.Tensor, ...] | dict[Any, torch.Tensor]: - + """ Sliding window inference on `inputs` with `predictor`. diff --git a/monai/metrics/meandice.py b/monai/metrics/meandice.py index 980552e628..841585897f 100644 --- a/monai/metrics/meandice.py +++ b/monai/metrics/meandice.py @@ -139,7 +139,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor y_pred = y_pred.permute(0, 4, 1, 2, 3).contiguous() if isinstance(y, torch.Tensor) and y.ndim == 5 and y.shape[-1] in (1, 3, 4): y = y.permute(0, 4, 1, 2, 3).contiguous() - + dims = y_pred.ndimension() if dims < 3: raise ValueError(f"y_pred should have at least 3 dimensions (batch, channel, spatial), got {dims}.")