3 changes: 2 additions & 1 deletion requirements/requirements.txt
@@ -4,7 +4,8 @@ aiohttp>=3.8.1
albumentations>=1.3.0
bokeh>=3.1.1, <3.6.0
Click>=8.1.3, <8.2.0
dask>=2025.10.0
dask[array]>=2025.10.0
dask[dataframe]>=2025.10.0
defusedxml>=0.7.1
filelock>=3.9.0
flask>=2.2.2
139 changes: 139 additions & 0 deletions tests/engines/test_nucleus_instance_segmentor.py
@@ -0,0 +1,139 @@
"""Test tiatoolbox.models.engine.nucleus_instance_segmentor."""

from collections.abc import Callable
from pathlib import Path
from typing import Final

import numpy as np
import torch
import zarr

from tiatoolbox.models import NucleusInstanceSegmentor
from tiatoolbox.wsicore import WSIReader

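# Run inference on a GPU when one is available; otherwise fall back to CPU.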
device = "cuda:0" if torch.cuda.is_available() else "cpu"


def test_functionality_patch_mode(
    remote_sample: Callable, track_tmp_path: Path
) -> None:
"""Patch mode functionality test for nuclei instance segmentor."""
mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs"))
mini_wsi = WSIReader.open(mini_wsi_svs)
size = (256, 256)
resolution = 0.25
units: Final = "mpp"
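    # Read two 256x256 patches at 0.25 mpp from different locations of the sample WSI.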
    patch1 = mini_wsi.read_rect(
        location=(0, 0),
        size=size,
        resolution=resolution,
        units=units,
    )
    patch2 = mini_wsi.read_rect(
        location=(512, 512),
        size=size,
        resolution=resolution,
        units=units,
    )

    # A dummy all-zero input should produce no output segmentation.
    patch3 = np.zeros_like(patch1)

    patches = np.stack(arrays=[patch1, patch2, patch3], axis=0)

    inst_segmentor = NucleusInstanceSegmentor(
        batch_size=1,
        num_workers=0,
        model="hovernet_fast-pannuke",
    )
    output = inst_segmentor.run(
        images=patches,
        patch_mode=True,
        device=device,
        output_type="dict",
    )

    assert np.max(output["predictions"][0][:]) == 41
    assert np.max(output["predictions"][1][:]) == 17
    assert np.max(output["predictions"][2][:]) == 0

    assert len(output["box"][0]) == 41
    assert len(output["box"][1]) == 17
    assert len(output["box"][2]) == 0

    assert len(output["centroid"][0]) == 41
    assert len(output["centroid"][1]) == 17
    assert len(output["centroid"][2]) == 0

    assert len(output["contour"][0]) == 41
    assert len(output["contour"][1]) == 17
    assert len(output["contour"][2]) == 0

    assert len(output["prob"][0]) == 41
    assert len(output["prob"][1]) == 17
    assert len(output["prob"][2]) == 0

    assert len(output["type"][0]) == 41
    assert len(output["type"][1]) == 17
    assert len(output["type"][2]) == 0

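    # Keep the in-memory dict output to compare against the zarr-backed run below.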
    output_ = output

    output = inst_segmentor.run(
        images=patches,
        patch_mode=True,
        device=device,
        output_type="zarr",
        save_dir=track_tmp_path / "patch_output_zarr",
    )

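    # The zarr run returns the saved store; reopen it read-only to validate its contents.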
    output = zarr.open(output, mode="r")

    assert np.max(output["predictions"][0][:]) == 41
    assert np.max(output["predictions"][1][:]) == 17

    assert all(
        np.array_equal(a, b)
        for a, b in zip(output["box"][0], output_["box"][0], strict=False)
    )
    assert all(
        np.array_equal(a, b)
        for a, b in zip(output["box"][1], output_["box"][1], strict=False)
    )
    assert len(output["box"][2]) == 0

    assert all(
        np.array_equal(a, b)
        for a, b in zip(output["centroid"][0], output_["centroid"][0], strict=False)
    )
    assert all(
        np.array_equal(a, b)
        for a, b in zip(output["centroid"][1], output_["centroid"][1], strict=False)
    )

    assert all(
        np.array_equal(a, b)
        for a, b in zip(output["contour"][0], output_["contour"][0], strict=False)
    )
    assert all(
        np.array_equal(a, b)
        for a, b in zip(output["contour"][1], output_["contour"][1], strict=False)
    )

    assert all(
        np.array_equal(a, b)
        for a, b in zip(output["prob"][0], output_["prob"][0], strict=False)
    )
    assert all(
        np.array_equal(a, b)
        for a, b in zip(output["prob"][1], output_["prob"][1], strict=False)
    )

    assert all(
        np.array_equal(a, b)
        for a, b in zip(output["type"][0], output_["type"][0], strict=False)
    )
    assert all(
        np.array_equal(a, b)
        for a, b in zip(output["type"][1], output_["type"][1], strict=False)
    )
33 changes: 31 additions & 2 deletions tiatoolbox/models/architecture/hovernet.py
@@ -6,7 +6,11 @@
from collections import OrderedDict

import cv2
import dask
import dask.array as da
import dask.dataframe as dd
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F # noqa: N812
from scipy import ndimage
@@ -22,6 +26,8 @@
from tiatoolbox.models.models_abc import ModelABC
from tiatoolbox.utils.misc import get_bounding_box

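# Keep object-dtype dataframe columns as Python objects; the instance-info columns
# hold coordinate arrays, not strings, so pyarrow string conversion must not apply.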
dask.config.set({"dataframe.convert-string": False})


class TFSamepaddingLayer(nn.Module):
"""To align with tensorflow `same` padding.
@@ -776,11 +782,34 @@ def postproc(raw_maps: list[np.ndarray]) -> tuple[np.ndarray, dict]:
        tp_map = None
        np_map, hv_map = raw_maps

        pred_type = tp_map
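        # Materialize any dask-backed maps first; the post-processing below operates on NumPy arrays.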
        np_map = np_map.compute() if isinstance(np_map, dask.array.Array) else np_map
        hv_map = hv_map.compute() if isinstance(hv_map, dask.array.Array) else hv_map
        pred_type = tp_map.compute() if isinstance(tp_map, dask.array.Array) else tp_map
        pred_inst = HoVerNet._proc_np_hv(np_map, hv_map)
        nuc_inst_info_dict = HoVerNet.get_instance_info(pred_inst, pred_type)

        return pred_inst, nuc_inst_info_dict
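        # With no detected nuclei, return empty dask arrays so every expected key still exists.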
        if not nuc_inst_info_dict:
            nuc_inst_info_dict = {  # inst_id should start at 1
                "box": da.empty(shape=0),
                "centroid": da.empty(shape=0),
                "contour": da.empty(shape=0),
                "prob": da.empty(shape=0),
                "type": da.empty(shape=0),
            }
            return pred_inst, nuc_inst_info_dict

        # A dask DataFrame does not support transpose, so transpose in pandas first.
        nuc_inst_info_df = pd.DataFrame(nuc_inst_info_dict).transpose()

        # Create a dask DataFrame from the transposed pandas DataFrame.
        nuc_inst_info_dd = dd.from_pandas(nuc_inst_info_df)

        # Rebuild the info dict with one dask array per column.
        nuc_inst_info_dict_ = {}
        for key in nuc_inst_info_df.columns:
            nuc_inst_info_dict_[key] = nuc_inst_info_dd[key].to_dask_array(lengths=True)

        return pred_inst, nuc_inst_info_dict_

    @staticmethod
    def infer_batch(  # skipcq: PYL-W0221
35 changes: 27 additions & 8 deletions tiatoolbox/models/engine/engine_abc.py
@@ -46,6 +46,7 @@
import zarr
from dask import compute
from dask.diagnostics import ProgressBar
from numcodecs import Pickle
from torch import nn
from typing_extensions import Unpack

@@ -71,6 +72,8 @@
from tiatoolbox.models.models_abc import ModelABC
from tiatoolbox.type_hints import IntPair, Resolution, Units

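# Disable dask's automatic conversion of object-dtype dataframe columns to pyarrow strings.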
dask.config.set({"dataframe.convert-string": False})


class EngineABCRunParams(TypedDict, total=False):
"""Parameters for configuring the :func:`EngineABC.run()` method.
@@ -524,7 +527,7 @@ def infer_patches(
        coordinates = []

        # Main output dictionary
        raw_predictions = dict(zip(keys, [[]] * len(keys), strict=False))
        raw_predictions = {key: [] for key in keys}
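        # The dict comprehension gives each key its own list; the zip/[[]] * len(keys)
        # version shared a single list object across all keys.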

        # Inference loop
        tqdm = get_tqdm()
@@ -645,13 +648,29 @@ def save_predictions(
        keys_to_compute = [k for k in keys_to_compute if k not in zarr_group]
        write_tasks = []
        for key in keys_to_compute:
            dask_array = processed_predictions[key].rechunk("auto")
            task = dask_array.to_zarr(
                url=save_path,
                component=key,
                compute=False,
            )
            write_tasks.append(task)
            dask_output = processed_predictions[key]
            if isinstance(dask_output, da.Array):
                dask_output = dask_output.rechunk("auto")
                task = dask_output.to_zarr(
                    url=save_path, component=key, compute=False, object_codec=None
                )
                write_tasks.append(task)

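            # Ragged per-image outputs arrive as lists of dask arrays; write each
            # one to its own zarr component under the key.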
            if isinstance(dask_output, list) and all(
                isinstance(dask_array, da.Array) for dask_array in dask_output
            ):
                for i, dask_array in enumerate(dask_output):
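                    # Object-dtype arrays (e.g. contours) need a pickle codec to be serialized into zarr.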
                    object_codec = (
                        Pickle() if dask_array.dtype == "object" else None
                    )
                    task = dask_array.to_zarr(
                        url=save_path,
                        component=f"{key}/{i}",
                        compute=False,
                        object_codec=object_codec,
                    )
                    write_tasks.append(task)

msg = f"Saving output to {save_path}."
logger.info(msg=msg)
with ProgressBar():