From 7d50a0151f480ec1b00485ded43dde1bd0c281fd Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Thu, 23 Oct 2025 13:01:50 +0100 Subject: [PATCH 1/7] :new: Define `NucleusInstanceSegmentor` --- .../engine/nucleus_instance_segmentor.py | 37 ++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/tiatoolbox/models/engine/nucleus_instance_segmentor.py b/tiatoolbox/models/engine/nucleus_instance_segmentor.py index ce74355ae..37d3c0820 100644 --- a/tiatoolbox/models/engine/nucleus_instance_segmentor.py +++ b/tiatoolbox/models/engine/nucleus_instance_segmentor.py @@ -13,13 +13,22 @@ import tqdm from shapely.geometry import box as shapely_box from shapely.strtree import STRtree +from typing_extensions import Unpack from tiatoolbox.models.dataset.dataset_abc import WSIStreamDataset -from tiatoolbox.models.engine.semantic_segmentor import SemanticSegmentor +from tiatoolbox.models.engine.semantic_segmentor import ( + SemanticSegmentor, + SemanticSegmentorRunParams, +) from tiatoolbox.tools.patchextraction import PatchExtractor if TYPE_CHECKING: # pragma: no cover + import os from collections.abc import Callable + from pathlib import Path + + from tiatoolbox.annotation import AnnotationStore + from tiatoolbox.wsicore import WSIReader from .io_config import IOInstanceSegmentorConfig, IOSegmentorConfig @@ -812,3 +821,29 @@ def callback(new_inst_dict: dict, remove_uuid_list: list) -> None: # manually call the callback rather than # attaching it when receiving/creating the future callback(*result) + + def run( + self: NucleusInstanceSegmentor, + images: list[os.PathLike | Path | WSIReader] | np.ndarray, + masks: list[os.PathLike | Path] | np.ndarray | None = None, + labels: list | None = None, + ioconfig: IOSegmentorConfig | None = None, + *, + patch_mode: bool = True, + save_dir: os.PathLike | Path | None = None, + overwrite: bool = False, + output_type: str = "dict", + **kwargs: Unpack[SemanticSegmentorRunParams], + ) -> AnnotationStore | Path | str | dict | list[Path]: + """Run the nucleus instance segmentor engine on input images.""" + return super().run( + images=images, + masks=masks, + labels=labels, + ioconfig=ioconfig, + patch_mode=patch_mode, + save_dir=save_dir, + overwrite=overwrite, + output_type=output_type, + **kwargs, + ) From 03b296476f2f7038765f2b4de24ca14503e88d7f Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Mon, 17 Nov 2025 11:36:20 +0000 Subject: [PATCH 2/7] :test_tube: Add initial test for nucleus instance segmentor --- .../test_nucleus_instance_segmentor.py | 59 +++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 tests/engines/test_nucleus_instance_segmentor.py diff --git a/tests/engines/test_nucleus_instance_segmentor.py b/tests/engines/test_nucleus_instance_segmentor.py new file mode 100644 index 000000000..5de438b4d --- /dev/null +++ b/tests/engines/test_nucleus_instance_segmentor.py @@ -0,0 +1,59 @@ +"""Test tiatoolbox.models.engine.nucleus_instance_segmentor.""" + +import gc +import shutil +from collections.abc import Callable +from pathlib import Path + +import torch + +from tiatoolbox.models import IOSegmentorConfig, NucleusInstanceSegmentor +from tiatoolbox.utils import imwrite +from tiatoolbox.wsicore import WSIReader + +device = "cuda:0" if torch.cuda.is_available() else "cpu" + + +def test_functionality_ci(remote_sample: Callable, track_tmp_path: Path) -> None: + """Functionality test for nuclei instance segmentor.""" + gc.collect() + mini_wsi_svs = Path(remote_sample("wsi4_512_512_svs")) + + resolution = 2.0 + + reader = WSIReader.open(mini_wsi_svs) + thumb = reader.slide_thumbnail(resolution=resolution, units="mpp") + mini_wsi_jpg = f"{track_tmp_path}/mini_svs.jpg" + imwrite(mini_wsi_jpg, thumb) + + # * test run on wsi, test run with worker + # resolution for travis testing, not the correct ones + ioconfig = IOSegmentorConfig( + input_resolutions=[{"units": "mpp", "resolution": resolution}], + output_resolutions=[ + {"units": "mpp", "resolution": resolution}, + {"units": "mpp", "resolution": resolution}, + ], + margin=128, + tile_shape=[1024, 1024], + patch_input_shape=[256, 256], + patch_output_shape=[164, 164], + stride_shape=[164, 164], + ) + + save_dir = track_tmp_path / "instance" + shutil.rmtree(save_dir, ignore_errors=True) + + inst_segmentor = NucleusInstanceSegmentor( + batch_size=1, + num_loader_workers=0, + num_postproc_workers=0, + pretrained_model="hovernet_fast-pannuke", + ) + inst_segmentor.run( + [mini_wsi_svs], + patch_mode=False, + ioconfig=ioconfig, + device=device, + save_dir=save_dir, + ) From dae9213b5e5bd86579fd18cde5f14c274730a4db Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Wed, 19 Nov 2025 17:24:58 +0000 Subject: [PATCH 3/7] :test_tube: Test issues with raw output in patch mode --- .../test_nucleus_instance_segmentor.py | 55 ++--- tiatoolbox/models/architecture/hovernet.py | 6 +- tiatoolbox/models/engine/engine_abc.py | 2 +- .../engine/nucleus_instance_segmentor.py | 231 ++++++++++++++++-- 4 files changed, 234 insertions(+), 60 deletions(-) diff --git a/tests/engines/test_nucleus_instance_segmentor.py b/tests/engines/test_nucleus_instance_segmentor.py index 5de438b4d..f55225a36 100644 --- a/tests/engines/test_nucleus_instance_segmentor.py +++ b/tests/engines/test_nucleus_instance_segmentor.py @@ -1,6 +1,5 @@ """Test tiatoolbox.models.engine.nucleus_instance_segmentor.""" -import gc import shutil from collections.abc import Callable from pathlib import Path @@ -8,52 +7,40 @@ import torch from tiatoolbox.models import IOSegmentorConfig, NucleusInstanceSegmentor -from tiatoolbox.utils import imwrite -from tiatoolbox.wsicore import WSIReader device = "cuda:0" if torch.cuda.is_available() else "cpu" -def test_functionality_ci(remote_sample: Callable, track_tmp_path: Path) -> None: - """Functionality test for nuclei instance segmentor.""" - gc.collect() - mini_wsi_svs = Path(remote_sample("wsi4_512_512_svs")) - - resolution = 2.0 - - reader = WSIReader.open(mini_wsi_svs) - thumb = reader.slide_thumbnail(resolution=resolution, units="mpp") - mini_wsi_jpg = f"{track_tmp_path}/mini_svs.jpg" - imwrite(mini_wsi_jpg, thumb) - - # * test run on wsi, test run with worker - # resolution for travis testing, not the correct ones - ioconfig = IOSegmentorConfig( - input_resolutions=[{"units": "mpp", "resolution": resolution}], - output_resolutions=[ - {"units": "mpp", "resolution": resolution}, - {"units": "mpp", "resolution": resolution}, - ], - margin=128, - tile_shape=[1024, 1024], - patch_input_shape=[256, 256], - patch_output_shape=[164, 164], - stride_shape=[164, 164], +def test_functionality_tile(source_image: Path, track_tmp_path: Path) -> None: + inst_segmentor = NucleusInstanceSegmentor( + batch_size=1, + num_workers=0, + model="hovernet_fast-pannuke", + ) + output = inst_segmentor.run( + [source_image], + patch_mode=True, + device=device, + save_dir=track_tmp_path / "hovernet_fast-pannuke", ) - save_dir = track_tmp_path / "instance" - shutil.rmtree(save_dir, ignore_errors=True) +def test_functionality_wsi(remote_sample: Callable, track_tmp_path: Path) -> None: + """Local functionality test for nuclei instance segmentor.""" + root_save_dir = Path(track_tmp_path) + save_dir = Path(f"{track_tmp_path}/output") + mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs")) + + # * generate full output w/o parallel post-processing worker first + shutil.rmtree(save_dir, ignore_errors=True) inst_segmentor = NucleusInstanceSegmentor( - batch_size=1, - num_loader_workers=0, + batch_size=8, num_postproc_workers=0, pretrained_model="hovernet_fast-pannuke", ) - inst_segmentor.run( + output = inst_segmentor.run( [mini_wsi_svs], patch_mode=False, - ioconfig=ioconfig, device=device, save_dir=save_dir, ) diff --git a/tiatoolbox/models/architecture/hovernet.py b/tiatoolbox/models/architecture/hovernet.py index 19d02e7a5..af2186c00 100644 --- a/tiatoolbox/models/architecture/hovernet.py +++ b/tiatoolbox/models/architecture/hovernet.py @@ -4,7 +4,7 @@ import math from collections import OrderedDict - +import dask import cv2 import numpy as np import torch @@ -776,7 +776,9 @@ def postproc(raw_maps: list[np.ndarray]) -> tuple[np.ndarray, dict]: tp_map = None np_map, hv_map = raw_maps - pred_type = tp_map + np_map = np_map.compute() if isinstance(np_map, dask.array.Array) else np_map + hv_map = hv_map.compute() if isinstance(hv_map, dask.array.Array) else hv_map + pred_type = tp_map.compute() if isinstance(tp_map, dask.array.Array) else tp_map pred_inst = HoVerNet._proc_np_hv(np_map, hv_map) nuc_inst_info_dict = HoVerNet.get_instance_info(pred_inst, pred_type) diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 73b4ca1c1..01cb4e1a0 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -524,7 +524,7 @@ def infer_patches( coordinates = [] # Main output dictionary - raw_predictions = dict(zip(keys, [[]] * len(keys), strict=False)) + raw_predictions = {key: [] for key in keys} # Inference loop tqdm = get_tqdm() diff --git a/tiatoolbox/models/engine/nucleus_instance_segmentor.py b/tiatoolbox/models/engine/nucleus_instance_segmentor.py index 37d3c0820..cca443d78 100644 --- a/tiatoolbox/models/engine/nucleus_instance_segmentor.py +++ b/tiatoolbox/models/engine/nucleus_instance_segmentor.py @@ -6,26 +6,34 @@ from collections import deque from typing import TYPE_CHECKING +import dask # replace with the sql database once the PR in place import joblib import numpy as np import torch import tqdm +import dask.array as da from shapely.geometry import box as shapely_box from shapely.strtree import STRtree +from torch.utils.data import DataLoader from typing_extensions import Unpack -from tiatoolbox.models.dataset.dataset_abc import WSIStreamDataset from tiatoolbox.models.engine.semantic_segmentor import ( SemanticSegmentor, SemanticSegmentorRunParams, ) from tiatoolbox.tools.patchextraction import PatchExtractor +from tiatoolbox.models.models_abc import ModelABC +from tiatoolbox.utils.misc import get_tqdm +from .engine_abc import EngineABCRunParams +from tiatoolbox import DuplicateFilter, logger +from pathlib import Path + if TYPE_CHECKING: # pragma: no cover import os from collections.abc import Callable - from pathlib import Path + from tiatoolbox.annotation import AnnotationStore from tiatoolbox.wsicore import WSIReader @@ -381,38 +389,215 @@ class NucleusInstanceSegmentor(SemanticSegmentor): def __init__( self: NucleusInstanceSegmentor, + model: str | ModelABC, batch_size: int = 8, - num_loader_workers: int = 0, - num_postproc_workers: int = 0, - model: torch.nn.Module | None = None, - pretrained_model: str | None = None, - pretrained_weights: str | None = None, - dataset_class: Callable = WSIStreamDataset, + num_workers: int = 0, + weights: str | Path | None = None, *, + device: str = "cpu", verbose: bool = True, - auto_generate_mask: bool = False, ) -> None: """Initialize :class:`NucleusInstanceSegmentor`.""" super().__init__( - batch_size=batch_size, - num_loader_workers=num_loader_workers, - num_postproc_workers=num_postproc_workers, model=model, - pretrained_model=pretrained_model, - pretrained_weights=pretrained_weights, + batch_size=batch_size, + num_workers=num_workers, + weights=weights, + device=device, verbose=verbose, - auto_generate_mask=auto_generate_mask, - dataset_class=dataset_class, ) - # default is None in base class and is un-settable - # hence we redefine the namespace here - self.num_postproc_workers = ( - num_postproc_workers if num_postproc_workers > 0 else None + + def infer_patches( + self: NucleusInstanceSegmentor, + dataloader: DataLoader, + *, + return_coordinates: bool = False, + ) -> dict[str, list[da.Array]]: + """Run model inference on image patches and return predictions. + + This method performs batched inference using a PyTorch DataLoader, + and accumulates predictions in Dask arrays. It supports optional inclusion + of coordinates and labels in the output. + + Args: + dataloader (DataLoader): + PyTorch DataLoader containing image patches for inference. + return_coordinates (bool): + Whether to include coordinates in the output. Required when + called by `infer_wsi` and `patch_mode` is False. + + Returns: + dict[str, dask.array.Array]: + Dictionary containing prediction results as Dask arrays. + Keys include: + - "probabilities": Model output probabilities. + - "labels": Ground truth labels (if `return_labels` is True). + - "coordinates": Patch coordinates (if `return_coordinates` is + True). + + """ + keys = ["probabilities"] + labels, coordinates = [], [] + + # Expected number of outputs from the model + batch_output = self.model.infer_batch( + self.model, + torch.Tensor(dataloader.dataset[0]["image"][np.newaxis, ...]), + device=self.device, ) - # adding more runtime placeholder - self._wsi_inst_info = None - self._futures = [] + num_expected_output = len(batch_output) + probabilities = [[] for _ in range(num_expected_output)] + + if return_coordinates: + keys.append("coordinates") + coordinates = [] + + # Main output dictionary + raw_predictions = {key: [] for key in keys} + raw_predictions["probabilities"] = [[] for _ in range(num_expected_output)] + + # Inference loop + tqdm = get_tqdm() + tqdm_loop = ( + tqdm(dataloader, leave=False, desc="Inferring patches") + if self.verbose + else self.dataloader + ) + + for batch_data in tqdm_loop: + batch_output = self.model.infer_batch( + self.model, + batch_data["image"], + device=self.device, + ) + + for i in range(num_expected_output): + probabilities[i].append( + da.from_array( + batch_output[i], # probabilities + ) + ) + + if return_coordinates: + coordinates.append( + da.from_array( + self._get_coordinates(batch_data), + ) + ) + + if self.return_labels: + labels.append(da.from_array(np.array(batch_data["label"]))) + + for i in range(num_expected_output): + raw_predictions["probabilities"][i] = da.concatenate(probabilities[i], axis=0) + + if return_coordinates: + raw_predictions["coordinates"] = da.concatenate(coordinates, axis=0) + + return raw_predictions + + def _run_patch_mode( + self: NucleusInstanceSegmentor, + output_type: str, + save_dir: Path, + **kwargs: EngineABCRunParams, + ) -> dict | AnnotationStore | Path: + """Run the engine in patch mode. + + This method performs inference on image patches, post-processes the predictions, + and saves the output in the specified format. + + Args: + output_type (str): + Desired output format. Supported values are "dict", "zarr", + and "annotationstore". + save_dir (Path): + Directory to save the output files. + **kwargs (EngineABCRunParams): + Additional runtime parameters including: + - output_file: Name of the output file. + - scale_factor: Scaling factor for annotations. + - class_dict: Mapping of class indices to names. + + Returns: + dict | AnnotationStore | Path: + - If output_type is "dict": returns predictions as a dictionary. + - If output_type is "zarr": returns path to saved zarr file. + - If output_type is "annotationstore": returns an AnnotationStore + or path to .db file. + + """ + save_path = None + if save_dir: + output_file = Path(kwargs.get("output_file", "output.zarr")) + save_path = save_dir / (str(output_file.stem) + ".zarr") + + duplicate_filter = DuplicateFilter() + logger.addFilter(duplicate_filter) + + self.dataloader = self.get_dataloader( + images=self.images, + masks=self.masks, + labels=self.labels, + patch_mode=True, + ioconfig=self._ioconfig, + ) + raw_predictions = self.infer_patches( + dataloader=self.dataloader, + return_coordinates=output_type == "annotationstore", + ) + + raw_predictions["predictions"] = self.post_process_patches( + raw_predictions=raw_predictions["probabilities"], + prediction_shape=None, + prediction_dtype=None, + **kwargs, + ) + + logger.removeFilter(duplicate_filter) + + out = self.save_predictions( + processed_predictions=raw_predictions, + output_type=output_type, + save_path=save_path, + **kwargs, + ) + + msg = f"Output file saved at {out}." + logger.info(msg=msg) + return out + + def post_process_patches( # skipcq: PYL-R0201 + self: NucleusInstanceSegmentor, + raw_predictions: da.Array, + prediction_shape: tuple[int, ...], # noqa: ARG002 + prediction_dtype: type, # noqa: ARG002 + **kwargs: Unpack[EngineABCRunParams], # noqa: ARG002 + ) -> dask.array.Array: + """Post-process raw patch predictions from inference. + + This method applies a post-processing function (e.g., smoothing, filtering) + to the raw model predictions. It supports delayed execution using Dask + and returns a Dask array for efficient computation. + + Args: + raw_predictions (dask.array.Array): + Raw model predictions as a dask array. + prediction_shape (tuple[int, ...]): + Shape of the prediction output. + prediction_dtype (type): + Data type of the prediction output. + **kwargs (EngineABCRunParams): + Additional runtime parameters used for post-processing. + + Returns: + dask.array.Array: + Post-processed predictions as a Dask array. + + """ + raw_predictions = self.model.postproc_func(raw_predictions) + return raw_predictions @staticmethod def _get_tile_info( From 4bc33b724e5548a6da35199a09f149b4b0ce9940 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Thu, 20 Nov 2025 10:59:13 +0000 Subject: [PATCH 4/7] :test_tube: Test issues with raw output in patch mode --- .../test_nucleus_instance_segmentor.py | 34 +++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/tests/engines/test_nucleus_instance_segmentor.py b/tests/engines/test_nucleus_instance_segmentor.py index f55225a36..debb8a44f 100644 --- a/tests/engines/test_nucleus_instance_segmentor.py +++ b/tests/engines/test_nucleus_instance_segmentor.py @@ -3,27 +3,57 @@ import shutil from collections.abc import Callable from pathlib import Path +from typing import Literal, Final import torch +import numpy as np from tiatoolbox.models import IOSegmentorConfig, NucleusInstanceSegmentor +from tiatoolbox.wsicore import WSIReader device = "cuda:0" if torch.cuda.is_available() else "cpu" -def test_functionality_tile(source_image: Path, track_tmp_path: Path) -> None: +def test_functionality_patch_mode(remote_sample: Callable, track_tmp_path: Path) -> None: + """Patch mode functionality test for nuclei instance segmentor.""" + mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs")) + mini_wsi = WSIReader.open(mini_wsi_svs) + size = (256, 256) + resolution = 0.25 + units: Final = "mpp" + patch1 = mini_wsi.read_rect( + location=(0, 0), + size=size, + resolution=resolution, + units=units, + ) + patch2 = mini_wsi.read_rect( + location=(512, 512), + size=size, + resolution=resolution, + units=units, + ) + + patches = np.stack( + arrays=[patch1, patch2], + axis=0 + ) + inst_segmentor = NucleusInstanceSegmentor( batch_size=1, num_workers=0, model="hovernet_fast-pannuke", ) output = inst_segmentor.run( - [source_image], + images=patches, patch_mode=True, device=device, save_dir=track_tmp_path / "hovernet_fast-pannuke", + output_type="dict", ) + assert output + def test_functionality_wsi(remote_sample: Callable, track_tmp_path: Path) -> None: """Local functionality test for nuclei instance segmentor.""" From 2797ff9612f10d0136f22e8f9c42a71f992dfd1c Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Thu, 20 Nov 2025 13:09:30 +0000 Subject: [PATCH 5/7] :white_check_mark: Test patch mode with dict output --- requirements/requirements.txt | 3 +- .../test_nucleus_instance_segmentor.py | 41 +++--------- .../engine/nucleus_instance_segmentor.py | 66 +++++++++++++++---- 3 files changed, 63 insertions(+), 47 deletions(-) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 045a4ce4e..ad6d8f40a 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -4,7 +4,8 @@ aiohttp>=3.8.1 albumentations>=1.3.0 bokeh>=3.1.1, <3.6.0 Click>=8.1.3, <8.2.0 -dask>=2025.10.0 +dask[array]>=2025.10.0 +dask[dataframe]>=2025.10.0 defusedxml>=0.7.1 filelock>=3.9.0 flask>=2.2.2 diff --git a/tests/engines/test_nucleus_instance_segmentor.py b/tests/engines/test_nucleus_instance_segmentor.py index debb8a44f..11d71ab2c 100644 --- a/tests/engines/test_nucleus_instance_segmentor.py +++ b/tests/engines/test_nucleus_instance_segmentor.py @@ -1,20 +1,19 @@ """Test tiatoolbox.models.engine.nucleus_instance_segmentor.""" -import shutil from collections.abc import Callable from pathlib import Path -from typing import Literal, Final +from typing import Final -import torch import numpy as np +import torch -from tiatoolbox.models import IOSegmentorConfig, NucleusInstanceSegmentor +from tiatoolbox.models import NucleusInstanceSegmentor from tiatoolbox.wsicore import WSIReader device = "cuda:0" if torch.cuda.is_available() else "cpu" -def test_functionality_patch_mode(remote_sample: Callable, track_tmp_path: Path) -> None: +def test_functionality_patch_mode(remote_sample: Callable) -> None: """Patch mode functionality test for nuclei instance segmentor.""" mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs")) mini_wsi = WSIReader.open(mini_wsi_svs) @@ -34,10 +33,7 @@ def test_functionality_patch_mode(remote_sample: Callable, track_tmp_path: Path) units=units, ) - patches = np.stack( - arrays=[patch1, patch2], - axis=0 - ) + patches = np.stack(arrays=[patch1, patch2], axis=0) inst_segmentor = NucleusInstanceSegmentor( batch_size=1, @@ -48,29 +44,10 @@ def test_functionality_patch_mode(remote_sample: Callable, track_tmp_path: Path) images=patches, patch_mode=True, device=device, - save_dir=track_tmp_path / "hovernet_fast-pannuke", output_type="dict", ) - assert output - - -def test_functionality_wsi(remote_sample: Callable, track_tmp_path: Path) -> None: - """Local functionality test for nuclei instance segmentor.""" - root_save_dir = Path(track_tmp_path) - save_dir = Path(f"{track_tmp_path}/output") - mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs")) - - # * generate full output w/o parallel post-processing worker first - shutil.rmtree(save_dir, ignore_errors=True) - inst_segmentor = NucleusInstanceSegmentor( - batch_size=8, - num_postproc_workers=0, - pretrained_model="hovernet_fast-pannuke", - ) - output = inst_segmentor.run( - [mini_wsi_svs], - patch_mode=False, - device=device, - save_dir=save_dir, - ) + assert np.max(output["predictions"][0][:]) == 41 + assert np.max(output["predictions"][1][:]) == 17 + assert len(output["inst_dict"][0].columns) == 41 + assert len(output["inst_dict"][1].columns) == 17 diff --git a/tiatoolbox/models/engine/nucleus_instance_segmentor.py b/tiatoolbox/models/engine/nucleus_instance_segmentor.py index cca443d78..6bfa3c6ad 100644 --- a/tiatoolbox/models/engine/nucleus_instance_segmentor.py +++ b/tiatoolbox/models/engine/nucleus_instance_segmentor.py @@ -4,40 +4,41 @@ import uuid from collections import deque +from pathlib import Path from typing import TYPE_CHECKING -import dask +import dask.array as da +import dask.dataframe as dd + # replace with the sql database once the PR in place import joblib import numpy as np +import pandas as pd import torch import tqdm -import dask.array as da from shapely.geometry import box as shapely_box from shapely.strtree import STRtree -from torch.utils.data import DataLoader from typing_extensions import Unpack +from tiatoolbox import DuplicateFilter, logger from tiatoolbox.models.engine.semantic_segmentor import ( SemanticSegmentor, SemanticSegmentorRunParams, ) from tiatoolbox.tools.patchextraction import PatchExtractor -from tiatoolbox.models.models_abc import ModelABC from tiatoolbox.utils.misc import get_tqdm -from .engine_abc import EngineABCRunParams -from tiatoolbox import DuplicateFilter, logger -from pathlib import Path - if TYPE_CHECKING: # pragma: no cover import os from collections.abc import Callable + from torch.utils.data import DataLoader from tiatoolbox.annotation import AnnotationStore + from tiatoolbox.models.models_abc import ModelABC from tiatoolbox.wsicore import WSIReader + from .engine_abc import EngineABCRunParams from .io_config import IOInstanceSegmentorConfig, IOSegmentorConfig @@ -490,7 +491,9 @@ def infer_patches( labels.append(da.from_array(np.array(batch_data["label"]))) for i in range(num_expected_output): - raw_predictions["probabilities"][i] = da.concatenate(probabilities[i], axis=0) + raw_predictions["probabilities"][i] = da.concatenate( + probabilities[i], axis=0 + ) if return_coordinates: raw_predictions["coordinates"] = da.concatenate(coordinates, axis=0) @@ -548,8 +551,8 @@ def _run_patch_mode( return_coordinates=output_type == "annotationstore", ) - raw_predictions["predictions"] = self.post_process_patches( - raw_predictions=raw_predictions["probabilities"], + raw_predictions = self.post_process_patches( + raw_predictions=raw_predictions, prediction_shape=None, prediction_dtype=None, **kwargs, @@ -570,11 +573,11 @@ def _run_patch_mode( def post_process_patches( # skipcq: PYL-R0201 self: NucleusInstanceSegmentor, - raw_predictions: da.Array, + raw_predictions: dict, prediction_shape: tuple[int, ...], # noqa: ARG002 prediction_dtype: type, # noqa: ARG002 **kwargs: Unpack[EngineABCRunParams], # noqa: ARG002 - ) -> dask.array.Array: + ) -> dict: """Post-process raw patch predictions from inference. This method applies a post-processing function (e.g., smoothing, filtering) @@ -596,9 +599,44 @@ def post_process_patches( # skipcq: PYL-R0201 Post-processed predictions as a Dask array. """ - raw_predictions = self.model.postproc_func(raw_predictions) + probabilities = raw_predictions["probabilities"] + predictions = [[] for _ in range(probabilities[0].shape[0])] + inst_dict = [[] for _ in range(probabilities[0].shape[0])] + for idx in range(probabilities[0].shape[0]): + predictions[idx], inst_dict[idx] = self.model.postproc_func( + [probabilities[0][idx], probabilities[1][idx], probabilities[2][idx]] + ) + inst_dict[idx] = dd.from_pandas(pd.DataFrame(inst_dict[idx])) + + raw_predictions["predictions"] = da.stack(predictions, axis=0) + raw_predictions["inst_dict"] = inst_dict + return raw_predictions + def save_predictions( + self: SemanticSegmentor, + processed_predictions: dict, + output_type: str, + save_path: Path | None = None, + **kwargs: Unpack[SemanticSegmentorRunParams], + ) -> dict | AnnotationStore | Path: + """Save semantic segmentation predictions to disk or return them in memory.""" + # Conversion to annotationstore uses a different function for SemanticSegmentor + inst_dict: list[dd.DataFrame] | None = processed_predictions.pop( + "inst_dict", None + ) + out = super().save_predictions( + processed_predictions, output_type, save_path=save_path, **kwargs + ) + + if isinstance(out, dict): + out["inst_dict"] = [[] for _ in range(len(inst_dict))] + for idx in range(len(inst_dict)): + out["inst_dict"][idx] = inst_dict[idx].compute() + return out + + return out + @staticmethod def _get_tile_info( image_shape: list[int] | np.ndarray, From da6a1eac24a41151c498c2d8d57878e57fa8d593 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Mon, 24 Nov 2025 23:11:05 +0000 Subject: [PATCH 6/7] :white_check_mark: Test patch mode with dict and zarr output --- .../test_nucleus_instance_segmentor.py | 94 ++++++++++++++++++- tiatoolbox/models/architecture/hovernet.py | 31 +++++- tiatoolbox/models/engine/engine_abc.py | 33 +++++-- .../engine/nucleus_instance_segmentor.py | 22 +---- 4 files changed, 149 insertions(+), 31 deletions(-) diff --git a/tests/engines/test_nucleus_instance_segmentor.py b/tests/engines/test_nucleus_instance_segmentor.py index 11d71ab2c..b17c98f2a 100644 --- a/tests/engines/test_nucleus_instance_segmentor.py +++ b/tests/engines/test_nucleus_instance_segmentor.py @@ -6,6 +6,7 @@ import numpy as np import torch +import zarr from tiatoolbox.models import NucleusInstanceSegmentor from tiatoolbox.wsicore import WSIReader @@ -13,7 +14,9 @@ device = "cuda:0" if torch.cuda.is_available() else "cpu" -def test_functionality_patch_mode(remote_sample: Callable) -> None: +def test_functionality_patch_mode( + remote_sample: Callable, track_tmp_path: Path +) -> None: """Patch mode functionality test for nuclei instance segmentor.""" mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs")) mini_wsi = WSIReader.open(mini_wsi_svs) @@ -33,7 +36,10 @@ def test_functionality_patch_mode(remote_sample: Callable) -> None: units=units, ) - patches = np.stack(arrays=[patch1, patch2], axis=0) + # Test dummy input, should result in no output segmentation + patch3 = np.zeros_like(patch1) + + patches = np.stack(arrays=[patch1, patch2, patch3], axis=0) inst_segmentor = NucleusInstanceSegmentor( batch_size=1, @@ -49,5 +55,85 @@ def test_functionality_patch_mode(remote_sample: Callable) -> None: assert np.max(output["predictions"][0][:]) == 41 assert np.max(output["predictions"][1][:]) == 17 - assert len(output["inst_dict"][0].columns) == 41 - assert len(output["inst_dict"][1].columns) == 17 + assert np.max(output["predictions"][2][:]) == 0 + + assert len(output["box"][0]) == 41 + assert len(output["box"][1]) == 17 + assert len(output["box"][2]) == 0 + + assert len(output["centroid"][0]) == 41 + assert len(output["centroid"][1]) == 17 + assert len(output["centroid"][2]) == 0 + + assert len(output["contour"][0]) == 41 + assert len(output["contour"][1]) == 17 + assert len(output["contour"][2]) == 0 + + assert len(output["prob"][0]) == 41 + assert len(output["prob"][1]) == 17 + assert len(output["prob"][2]) == 0 + + assert len(output["type"][0]) == 41 + assert len(output["type"][1]) == 17 + assert len(output["type"][2]) == 0 + + output_ = output + + output = inst_segmentor.run( + images=patches, + patch_mode=True, + device=device, + output_type="zarr", + save_dir=track_tmp_path / "patch_output_zarr", + ) + + output = zarr.open(output, mode="r") + + assert np.max(output["predictions"][0][:]) == 41 + assert np.max(output["predictions"][1][:]) == 17 + + assert all( + np.array_equal(a, b) + for a, b in zip(output["box"][0], output_["box"][0], strict=False) + ) + assert all( + np.array_equal(a, b) + for a, b in zip(output["box"][1], output_["box"][1], strict=False) + ) + assert len(output["box"][2]) == 0 + + assert all( + np.array_equal(a, b) + for a, b in zip(output["centroid"][0], output_["centroid"][0], strict=False) + ) + assert all( + np.array_equal(a, b) + for a, b in zip(output["centroid"][1], output_["centroid"][1], strict=False) + ) + + assert all( + np.array_equal(a, b) + for a, b in zip(output["contour"][0], output_["contour"][0], strict=False) + ) + assert all( + np.array_equal(a, b) + for a, b in zip(output["contour"][1], output_["contour"][1], strict=False) + ) + + assert all( + np.array_equal(a, b) + for a, b in zip(output["prob"][0], output_["prob"][0], strict=False) + ) + assert all( + np.array_equal(a, b) + for a, b in zip(output["prob"][1], output_["prob"][1], strict=False) + ) + + assert all( + np.array_equal(a, b) + for a, b in zip(output["type"][0], output_["type"][0], strict=False) + ) + assert all( + np.array_equal(a, b) + for a, b in zip(output["type"][1], output_["type"][1], strict=False) + ) diff --git a/tiatoolbox/models/architecture/hovernet.py b/tiatoolbox/models/architecture/hovernet.py index af2186c00..6aa592ad0 100644 --- a/tiatoolbox/models/architecture/hovernet.py +++ b/tiatoolbox/models/architecture/hovernet.py @@ -4,9 +4,13 @@ import math from collections import OrderedDict -import dask + import cv2 +import dask +import dask.array as da +import dask.dataframe as dd import numpy as np +import pandas as pd import torch import torch.nn.functional as F # noqa: N812 from scipy import ndimage @@ -22,6 +26,8 @@ from tiatoolbox.models.models_abc import ModelABC from tiatoolbox.utils.misc import get_bounding_box +dask.config.set({"dataframe.convert-string": False}) + class TFSamepaddingLayer(nn.Module): """To align with tensorflow `same` padding. @@ -782,7 +788,28 @@ def postproc(raw_maps: list[np.ndarray]) -> tuple[np.ndarray, dict]: pred_inst = HoVerNet._proc_np_hv(np_map, hv_map) nuc_inst_info_dict = HoVerNet.get_instance_info(pred_inst, pred_type) - return pred_inst, nuc_inst_info_dict + if not nuc_inst_info_dict: + nuc_inst_info_dict = { # inst_id should start at 1 + "box": da.empty(shape=0), + "centroid": da.empty(shape=0), + "contour": da.empty(shape=0), + "prob": da.empty(shape=0), + "type": da.empty(shape=0), + } + return pred_inst, nuc_inst_info_dict + + # dask dataframe does not support transpose + nuc_inst_info_df = pd.DataFrame(nuc_inst_info_dict).transpose() + + # create dask dataframe + nuc_inst_info_dd = dd.from_pandas(nuc_inst_info_df) + + # reinitialize nuc_inst_info_dict + nuc_inst_info_dict_ = {} + for key in nuc_inst_info_df.columns: + nuc_inst_info_dict_[key] = nuc_inst_info_dd[key].to_dask_array(lengths=True) + + return pred_inst, nuc_inst_info_dict_ @staticmethod def infer_batch( # skipcq: PYL-W0221 diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 01cb4e1a0..5cf2601b2 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -46,6 +46,7 @@ import zarr from dask import compute from dask.diagnostics import ProgressBar +from numcodecs import Pickle from torch import nn from typing_extensions import Unpack @@ -71,6 +72,8 @@ from tiatoolbox.models.models_abc import ModelABC from tiatoolbox.type_hints import IntPair, Resolution, Units +dask.config.set({"dataframe.convert-string": False}) + class EngineABCRunParams(TypedDict, total=False): """Parameters for configuring the :func:`EngineABC.run()` method. @@ -645,13 +648,29 @@ def save_predictions( keys_to_compute = [k for k in keys_to_compute if k not in zarr_group] write_tasks = [] for key in keys_to_compute: - dask_array = processed_predictions[key].rechunk("auto") - task = dask_array.to_zarr( - url=save_path, - component=key, - compute=False, - ) - write_tasks.append(task) + dask_output = processed_predictions[key] + if isinstance(dask_output, da.Array): + dask_output = dask_output.rechunk("auto") + task = dask_output.to_zarr( + url=save_path, component=key, compute=False, object_codec=None + ) + write_tasks.append(task) + + if isinstance(dask_output, list) and all( + isinstance(dask_array, da.Array) for dask_array in dask_output + ): + for i, dask_array in enumerate(dask_output): + object_codec = ( + Pickle() if dask_array.dtype == "object" else None + ) + task = dask_array.to_zarr( + url=save_path, + component=f"{key}/{i}", + compute=False, + object_codec=object_codec, + ) + write_tasks.append(task) + msg = f"Saving output to {save_path}." logger.info(msg=msg) with ProgressBar(): diff --git a/tiatoolbox/models/engine/nucleus_instance_segmentor.py b/tiatoolbox/models/engine/nucleus_instance_segmentor.py index 6bfa3c6ad..5ee5b138f 100644 --- a/tiatoolbox/models/engine/nucleus_instance_segmentor.py +++ b/tiatoolbox/models/engine/nucleus_instance_segmentor.py @@ -8,12 +8,10 @@ from typing import TYPE_CHECKING import dask.array as da -import dask.dataframe as dd # replace with the sql database once the PR in place import joblib import numpy as np -import pandas as pd import torch import tqdm from shapely.geometry import box as shapely_box @@ -601,15 +599,15 @@ def post_process_patches( # skipcq: PYL-R0201 """ probabilities = raw_predictions["probabilities"] predictions = [[] for _ in range(probabilities[0].shape[0])] - inst_dict = [[] for _ in range(probabilities[0].shape[0])] + inst_dict = [[{}] for _ in range(probabilities[0].shape[0])] for idx in range(probabilities[0].shape[0]): predictions[idx], inst_dict[idx] = self.model.postproc_func( [probabilities[0][idx], probabilities[1][idx], probabilities[2][idx]] ) - inst_dict[idx] = dd.from_pandas(pd.DataFrame(inst_dict[idx])) raw_predictions["predictions"] = da.stack(predictions, axis=0) - raw_predictions["inst_dict"] = inst_dict + for key in inst_dict[0]: + raw_predictions[key] = [d[key] for d in inst_dict] return raw_predictions @@ -621,22 +619,10 @@ def save_predictions( **kwargs: Unpack[SemanticSegmentorRunParams], ) -> dict | AnnotationStore | Path: """Save semantic segmentation predictions to disk or return them in memory.""" - # Conversion to annotationstore uses a different function for SemanticSegmentor - inst_dict: list[dd.DataFrame] | None = processed_predictions.pop( - "inst_dict", None - ) - out = super().save_predictions( + return super().save_predictions( processed_predictions, output_type, save_path=save_path, **kwargs ) - if isinstance(out, dict): - out["inst_dict"] = [[] for _ in range(len(inst_dict))] - for idx in range(len(inst_dict)): - out["inst_dict"][idx] = inst_dict[idx].compute() - return out - - return out - @staticmethod def _get_tile_info( image_shape: list[int] | np.ndarray, From 5e14877a25c1444e3415ddea5a2638638fd1e512 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Mon, 24 Nov 2025 23:16:20 +0000 Subject: [PATCH 7/7] :lipstick: log output if save path is requested --- tiatoolbox/models/engine/nucleus_instance_segmentor.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tiatoolbox/models/engine/nucleus_instance_segmentor.py b/tiatoolbox/models/engine/nucleus_instance_segmentor.py index 5ee5b138f..c15322718 100644 --- a/tiatoolbox/models/engine/nucleus_instance_segmentor.py +++ b/tiatoolbox/models/engine/nucleus_instance_segmentor.py @@ -565,8 +565,9 @@ def _run_patch_mode( **kwargs, ) - msg = f"Output file saved at {out}." - logger.info(msg=msg) + if save_path: + msg = f"Output file saved at {out}." + logger.info(msg=msg) return out def post_process_patches( # skipcq: PYL-R0201