diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 045a4ce4e..ad6d8f40a 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -4,7 +4,8 @@ aiohttp>=3.8.1 albumentations>=1.3.0 bokeh>=3.1.1, <3.6.0 Click>=8.1.3, <8.2.0 -dask>=2025.10.0 +dask[array]>=2025.10.0 +dask[dataframe]>=2025.10.0 defusedxml>=0.7.1 filelock>=3.9.0 flask>=2.2.2 diff --git a/tests/engines/test_nucleus_instance_segmentor.py b/tests/engines/test_nucleus_instance_segmentor.py new file mode 100644 index 000000000..b17c98f2a --- /dev/null +++ b/tests/engines/test_nucleus_instance_segmentor.py @@ -0,0 +1,139 @@ +"""Test tiatoolbox.models.engine.nucleus_instance_segmentor.""" + +from collections.abc import Callable +from pathlib import Path +from typing import Final + +import numpy as np +import torch +import zarr + +from tiatoolbox.models import NucleusInstanceSegmentor +from tiatoolbox.wsicore import WSIReader + +device = "cuda:0" if torch.cuda.is_available() else "cpu" + + +def test_functionality_patch_mode( + remote_sample: Callable, track_tmp_path: Path +) -> None: + """Patch mode functionality test for nuclei instance segmentor.""" + mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs")) + mini_wsi = WSIReader.open(mini_wsi_svs) + size = (256, 256) + resolution = 0.25 + units: Final = "mpp" + patch1 = mini_wsi.read_rect( + location=(0, 0), + size=size, + resolution=resolution, + units=units, + ) + patch2 = mini_wsi.read_rect( + location=(512, 512), + size=size, + resolution=resolution, + units=units, + ) + + # Test dummy input, should result in no output segmentation + patch3 = np.zeros_like(patch1) + + patches = np.stack(arrays=[patch1, patch2, patch3], axis=0) + + inst_segmentor = NucleusInstanceSegmentor( + batch_size=1, + num_workers=0, + model="hovernet_fast-pannuke", + ) + output = inst_segmentor.run( + images=patches, + patch_mode=True, + device=device, + output_type="dict", + ) + + assert np.max(output["predictions"][0][:]) == 41 + assert np.max(output["predictions"][1][:]) == 17 + assert np.max(output["predictions"][2][:]) == 0 + + assert len(output["box"][0]) == 41 + assert len(output["box"][1]) == 17 + assert len(output["box"][2]) == 0 + + assert len(output["centroid"][0]) == 41 + assert len(output["centroid"][1]) == 17 + assert len(output["centroid"][2]) == 0 + + assert len(output["contour"][0]) == 41 + assert len(output["contour"][1]) == 17 + assert len(output["contour"][2]) == 0 + + assert len(output["prob"][0]) == 41 + assert len(output["prob"][1]) == 17 + assert len(output["prob"][2]) == 0 + + assert len(output["type"][0]) == 41 + assert len(output["type"][1]) == 17 + assert len(output["type"][2]) == 0 + + output_ = output + + output = inst_segmentor.run( + images=patches, + patch_mode=True, + device=device, + output_type="zarr", + save_dir=track_tmp_path / "patch_output_zarr", + ) + + output = zarr.open(output, mode="r") + + assert np.max(output["predictions"][0][:]) == 41 + assert np.max(output["predictions"][1][:]) == 17 + + assert all( + np.array_equal(a, b) + for a, b in zip(output["box"][0], output_["box"][0], strict=False) + ) + assert all( + np.array_equal(a, b) + for a, b in zip(output["box"][1], output_["box"][1], strict=False) + ) + assert len(output["box"][2]) == 0 + + assert all( + np.array_equal(a, b) + for a, b in zip(output["centroid"][0], output_["centroid"][0], strict=False) + ) + assert all( + np.array_equal(a, b) + for a, b in zip(output["centroid"][1], output_["centroid"][1], strict=False) + ) + + assert all( + np.array_equal(a, b) + for a, b in zip(output["contour"][0], output_["contour"][0], strict=False) + ) + assert all( + np.array_equal(a, b) + for a, b in zip(output["contour"][1], output_["contour"][1], strict=False) + ) + + assert all( + np.array_equal(a, b) + for a, b in zip(output["prob"][0], output_["prob"][0], strict=False) + ) + assert all( + np.array_equal(a, b) + for a, b in zip(output["prob"][1], output_["prob"][1], strict=False) + ) + + assert all( + np.array_equal(a, b) + for a, b in zip(output["type"][0], output_["type"][0], strict=False) + ) + assert all( + np.array_equal(a, b) + for a, b in zip(output["type"][1], output_["type"][1], strict=False) + ) diff --git a/tiatoolbox/models/architecture/hovernet.py b/tiatoolbox/models/architecture/hovernet.py index 19d02e7a5..6aa592ad0 100644 --- a/tiatoolbox/models/architecture/hovernet.py +++ b/tiatoolbox/models/architecture/hovernet.py @@ -6,7 +6,11 @@ from collections import OrderedDict import cv2 +import dask +import dask.array as da +import dask.dataframe as dd import numpy as np +import pandas as pd import torch import torch.nn.functional as F # noqa: N812 from scipy import ndimage @@ -22,6 +26,8 @@ from tiatoolbox.models.models_abc import ModelABC from tiatoolbox.utils.misc import get_bounding_box +dask.config.set({"dataframe.convert-string": False}) + class TFSamepaddingLayer(nn.Module): """To align with tensorflow `same` padding. @@ -776,11 +782,34 @@ def postproc(raw_maps: list[np.ndarray]) -> tuple[np.ndarray, dict]: tp_map = None np_map, hv_map = raw_maps - pred_type = tp_map + np_map = np_map.compute() if isinstance(np_map, dask.array.Array) else np_map + hv_map = hv_map.compute() if isinstance(hv_map, dask.array.Array) else hv_map + pred_type = tp_map.compute() if isinstance(tp_map, dask.array.Array) else tp_map pred_inst = HoVerNet._proc_np_hv(np_map, hv_map) nuc_inst_info_dict = HoVerNet.get_instance_info(pred_inst, pred_type) - return pred_inst, nuc_inst_info_dict + if not nuc_inst_info_dict: + nuc_inst_info_dict = { # inst_id should start at 1 + "box": da.empty(shape=0), + "centroid": da.empty(shape=0), + "contour": da.empty(shape=0), + "prob": da.empty(shape=0), + "type": da.empty(shape=0), + } + return pred_inst, nuc_inst_info_dict + + # dask dataframe does not support transpose + nuc_inst_info_df = pd.DataFrame(nuc_inst_info_dict).transpose() + + # create dask dataframe + nuc_inst_info_dd = dd.from_pandas(nuc_inst_info_df) + + # reinitialize nuc_inst_info_dict + nuc_inst_info_dict_ = {} + for key in nuc_inst_info_df.columns: + nuc_inst_info_dict_[key] = nuc_inst_info_dd[key].to_dask_array(lengths=True) + + return pred_inst, nuc_inst_info_dict_ @staticmethod def infer_batch( # skipcq: PYL-W0221 diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 73b4ca1c1..5cf2601b2 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -46,6 +46,7 @@ import zarr from dask import compute from dask.diagnostics import ProgressBar +from numcodecs import Pickle from torch import nn from typing_extensions import Unpack @@ -71,6 +72,8 @@ from tiatoolbox.models.models_abc import ModelABC from tiatoolbox.type_hints import IntPair, Resolution, Units +dask.config.set({"dataframe.convert-string": False}) + class EngineABCRunParams(TypedDict, total=False): """Parameters for configuring the :func:`EngineABC.run()` method. @@ -524,7 +527,7 @@ def infer_patches( coordinates = [] # Main output dictionary - raw_predictions = dict(zip(keys, [[]] * len(keys), strict=False)) + raw_predictions = {key: [] for key in keys} # Inference loop tqdm = get_tqdm() @@ -645,13 +648,29 @@ def save_predictions( keys_to_compute = [k for k in keys_to_compute if k not in zarr_group] write_tasks = [] for key in keys_to_compute: - dask_array = processed_predictions[key].rechunk("auto") - task = dask_array.to_zarr( - url=save_path, - component=key, - compute=False, - ) - write_tasks.append(task) + dask_output = processed_predictions[key] + if isinstance(dask_output, da.Array): + dask_output = dask_output.rechunk("auto") + task = dask_output.to_zarr( + url=save_path, component=key, compute=False, object_codec=None + ) + write_tasks.append(task) + + if isinstance(dask_output, list) and all( + isinstance(dask_array, da.Array) for dask_array in dask_output + ): + for i, dask_array in enumerate(dask_output): + object_codec = ( + Pickle() if dask_array.dtype == "object" else None + ) + task = dask_array.to_zarr( + url=save_path, + component=f"{key}/{i}", + compute=False, + object_codec=object_codec, + ) + write_tasks.append(task) + msg = f"Saving output to {save_path}." logger.info(msg=msg) with ProgressBar(): diff --git a/tiatoolbox/models/engine/nucleus_instance_segmentor.py b/tiatoolbox/models/engine/nucleus_instance_segmentor.py index ce74355ae..c15322718 100644 --- a/tiatoolbox/models/engine/nucleus_instance_segmentor.py +++ b/tiatoolbox/models/engine/nucleus_instance_segmentor.py @@ -4,8 +4,11 @@ import uuid from collections import deque +from pathlib import Path from typing import TYPE_CHECKING +import dask.array as da + # replace with the sql database once the PR in place import joblib import numpy as np @@ -13,14 +16,27 @@ import tqdm from shapely.geometry import box as shapely_box from shapely.strtree import STRtree +from typing_extensions import Unpack -from tiatoolbox.models.dataset.dataset_abc import WSIStreamDataset -from tiatoolbox.models.engine.semantic_segmentor import SemanticSegmentor +from tiatoolbox import DuplicateFilter, logger +from tiatoolbox.models.engine.semantic_segmentor import ( + SemanticSegmentor, + SemanticSegmentorRunParams, +) from tiatoolbox.tools.patchextraction import PatchExtractor +from tiatoolbox.utils.misc import get_tqdm if TYPE_CHECKING: # pragma: no cover + import os from collections.abc import Callable + from torch.utils.data import DataLoader + + from tiatoolbox.annotation import AnnotationStore + from tiatoolbox.models.models_abc import ModelABC + from tiatoolbox.wsicore import WSIReader + + from .engine_abc import EngineABCRunParams from .io_config import IOInstanceSegmentorConfig, IOSegmentorConfig @@ -372,38 +388,241 @@ class NucleusInstanceSegmentor(SemanticSegmentor): def __init__( self: NucleusInstanceSegmentor, + model: str | ModelABC, batch_size: int = 8, - num_loader_workers: int = 0, - num_postproc_workers: int = 0, - model: torch.nn.Module | None = None, - pretrained_model: str | None = None, - pretrained_weights: str | None = None, - dataset_class: Callable = WSIStreamDataset, + num_workers: int = 0, + weights: str | Path | None = None, *, + device: str = "cpu", verbose: bool = True, - auto_generate_mask: bool = False, ) -> None: """Initialize :class:`NucleusInstanceSegmentor`.""" super().__init__( - batch_size=batch_size, - num_loader_workers=num_loader_workers, - num_postproc_workers=num_postproc_workers, model=model, - pretrained_model=pretrained_model, - pretrained_weights=pretrained_weights, + batch_size=batch_size, + num_workers=num_workers, + weights=weights, + device=device, verbose=verbose, - auto_generate_mask=auto_generate_mask, - dataset_class=dataset_class, ) - # default is None in base class and is un-settable - # hence we redefine the namespace here - self.num_postproc_workers = ( - num_postproc_workers if num_postproc_workers > 0 else None + + def infer_patches( + self: NucleusInstanceSegmentor, + dataloader: DataLoader, + *, + return_coordinates: bool = False, + ) -> dict[str, list[da.Array]]: + """Run model inference on image patches and return predictions. + + This method performs batched inference using a PyTorch DataLoader, + and accumulates predictions in Dask arrays. It supports optional inclusion + of coordinates and labels in the output. + + Args: + dataloader (DataLoader): + PyTorch DataLoader containing image patches for inference. + return_coordinates (bool): + Whether to include coordinates in the output. Required when + called by `infer_wsi` and `patch_mode` is False. + + Returns: + dict[str, dask.array.Array]: + Dictionary containing prediction results as Dask arrays. + Keys include: + - "probabilities": Model output probabilities. + - "labels": Ground truth labels (if `return_labels` is True). + - "coordinates": Patch coordinates (if `return_coordinates` is + True). + + """ + keys = ["probabilities"] + labels, coordinates = [], [] + + # Expected number of outputs from the model + batch_output = self.model.infer_batch( + self.model, + torch.Tensor(dataloader.dataset[0]["image"][np.newaxis, ...]), + device=self.device, ) - # adding more runtime placeholder - self._wsi_inst_info = None - self._futures = [] + num_expected_output = len(batch_output) + probabilities = [[] for _ in range(num_expected_output)] + + if return_coordinates: + keys.append("coordinates") + coordinates = [] + + # Main output dictionary + raw_predictions = {key: [] for key in keys} + raw_predictions["probabilities"] = [[] for _ in range(num_expected_output)] + + # Inference loop + tqdm = get_tqdm() + tqdm_loop = ( + tqdm(dataloader, leave=False, desc="Inferring patches") + if self.verbose + else self.dataloader + ) + + for batch_data in tqdm_loop: + batch_output = self.model.infer_batch( + self.model, + batch_data["image"], + device=self.device, + ) + + for i in range(num_expected_output): + probabilities[i].append( + da.from_array( + batch_output[i], # probabilities + ) + ) + + if return_coordinates: + coordinates.append( + da.from_array( + self._get_coordinates(batch_data), + ) + ) + + if self.return_labels: + labels.append(da.from_array(np.array(batch_data["label"]))) + + for i in range(num_expected_output): + raw_predictions["probabilities"][i] = da.concatenate( + probabilities[i], axis=0 + ) + + if return_coordinates: + raw_predictions["coordinates"] = da.concatenate(coordinates, axis=0) + + return raw_predictions + + def _run_patch_mode( + self: NucleusInstanceSegmentor, + output_type: str, + save_dir: Path, + **kwargs: EngineABCRunParams, + ) -> dict | AnnotationStore | Path: + """Run the engine in patch mode. + + This method performs inference on image patches, post-processes the predictions, + and saves the output in the specified format. + + Args: + output_type (str): + Desired output format. Supported values are "dict", "zarr", + and "annotationstore". + save_dir (Path): + Directory to save the output files. + **kwargs (EngineABCRunParams): + Additional runtime parameters including: + - output_file: Name of the output file. + - scale_factor: Scaling factor for annotations. + - class_dict: Mapping of class indices to names. + + Returns: + dict | AnnotationStore | Path: + - If output_type is "dict": returns predictions as a dictionary. + - If output_type is "zarr": returns path to saved zarr file. + - If output_type is "annotationstore": returns an AnnotationStore + or path to .db file. + + """ + save_path = None + if save_dir: + output_file = Path(kwargs.get("output_file", "output.zarr")) + save_path = save_dir / (str(output_file.stem) + ".zarr") + + duplicate_filter = DuplicateFilter() + logger.addFilter(duplicate_filter) + + self.dataloader = self.get_dataloader( + images=self.images, + masks=self.masks, + labels=self.labels, + patch_mode=True, + ioconfig=self._ioconfig, + ) + raw_predictions = self.infer_patches( + dataloader=self.dataloader, + return_coordinates=output_type == "annotationstore", + ) + + raw_predictions = self.post_process_patches( + raw_predictions=raw_predictions, + prediction_shape=None, + prediction_dtype=None, + **kwargs, + ) + + logger.removeFilter(duplicate_filter) + + out = self.save_predictions( + processed_predictions=raw_predictions, + output_type=output_type, + save_path=save_path, + **kwargs, + ) + + if save_path: + msg = f"Output file saved at {out}." + logger.info(msg=msg) + return out + + def post_process_patches( # skipcq: PYL-R0201 + self: NucleusInstanceSegmentor, + raw_predictions: dict, + prediction_shape: tuple[int, ...], # noqa: ARG002 + prediction_dtype: type, # noqa: ARG002 + **kwargs: Unpack[EngineABCRunParams], # noqa: ARG002 + ) -> dict: + """Post-process raw patch predictions from inference. + + This method applies a post-processing function (e.g., smoothing, filtering) + to the raw model predictions. It supports delayed execution using Dask + and returns a Dask array for efficient computation. + + Args: + raw_predictions (dask.array.Array): + Raw model predictions as a dask array. + prediction_shape (tuple[int, ...]): + Shape of the prediction output. + prediction_dtype (type): + Data type of the prediction output. + **kwargs (EngineABCRunParams): + Additional runtime parameters used for post-processing. + + Returns: + dask.array.Array: + Post-processed predictions as a Dask array. + + """ + probabilities = raw_predictions["probabilities"] + predictions = [[] for _ in range(probabilities[0].shape[0])] + inst_dict = [[{}] for _ in range(probabilities[0].shape[0])] + for idx in range(probabilities[0].shape[0]): + predictions[idx], inst_dict[idx] = self.model.postproc_func( + [probabilities[0][idx], probabilities[1][idx], probabilities[2][idx]] + ) + + raw_predictions["predictions"] = da.stack(predictions, axis=0) + for key in inst_dict[0]: + raw_predictions[key] = [d[key] for d in inst_dict] + + return raw_predictions + + def save_predictions( + self: SemanticSegmentor, + processed_predictions: dict, + output_type: str, + save_path: Path | None = None, + **kwargs: Unpack[SemanticSegmentorRunParams], + ) -> dict | AnnotationStore | Path: + """Save semantic segmentation predictions to disk or return them in memory.""" + return super().save_predictions( + processed_predictions, output_type, save_path=save_path, **kwargs + ) @staticmethod def _get_tile_info( @@ -812,3 +1031,29 @@ def callback(new_inst_dict: dict, remove_uuid_list: list) -> None: # manually call the callback rather than # attaching it when receiving/creating the future callback(*result) + + def run( + self: NucleusInstanceSegmentor, + images: list[os.PathLike | Path | WSIReader] | np.ndarray, + masks: list[os.PathLike | Path] | np.ndarray | None = None, + labels: list | None = None, + ioconfig: IOSegmentorConfig | None = None, + *, + patch_mode: bool = True, + save_dir: os.PathLike | Path | None = None, + overwrite: bool = False, + output_type: str = "dict", + **kwargs: Unpack[SemanticSegmentorRunParams], + ) -> AnnotationStore | Path | str | dict | list[Path]: + """Run the nucleus instance segmentor engine on input images.""" + return super().run( + images=images, + masks=masks, + labels=labels, + ioconfig=ioconfig, + patch_mode=patch_mode, + save_dir=save_dir, + overwrite=overwrite, + output_type=output_type, + **kwargs, + )