diff --git a/docs/pretrained.rst b/docs/pretrained.rst index 70c8319b5..f6aeb0a6d 100644 --- a/docs/pretrained.rst +++ b/docs/pretrained.rst @@ -326,7 +326,7 @@ The input output configuration is as follows: ioconfig = IOPatchPredictorConfig( patch_input_shape=(31, 31), stride_shape=(8, 8), - input_resolutions=[{"resolution": 0.25, "units": "mpp"}] + input_resolutions=[{"resolution": 0.5, "units": "mpp"}] ) @@ -342,7 +342,7 @@ The input output configuration is as follows: ioconfig = IOPatchPredictorConfig( patch_input_shape=(252, 252), stride_shape=(150, 150), - input_resolutions=[{"resolution": 0.25, "units": "mpp"}] + input_resolutions=[{"resolution": 0.5, "units": "mpp"}] ) @@ -366,7 +366,7 @@ The input output configuration is as follows: ioconfig = IOPatchPredictorConfig( patch_input_shape=(31, 31), stride_shape=(8, 8), - input_resolutions=[{"resolution": 0.25, "units": "mpp"}] + input_resolutions=[{"resolution": 0.5, "units": "mpp"}] ) @@ -382,7 +382,7 @@ The input output configuration is as follows: ioconfig = IOPatchPredictorConfig( patch_input_shape=(252, 252), stride_shape=(150, 150), - input_resolutions=[{"resolution": 0.25, "units": "mpp"}] + input_resolutions=[{"resolution": 0.5, "units": "mpp"}] ) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 8195c2f32..719d9551a 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,5 +1,5 @@ # torch installation ---extra-index-url https://download.pytorch.org/whl/cu117; sys_platform != "darwin" +--extra-index-url https://download.pytorch.org/whl/cu118; sys_platform != "darwin" albumentations>=1.3.0 Click>=8.1.3 defusedxml>=0.7.1 @@ -25,7 +25,7 @@ shapely>=2.0.0 SimpleITK>=2.2.1 sphinx>=5.3.0 tifffile>=2022.10.10 -torch>=2.0.0 +torch>=1.13.0 torchvision>=0.14.1 tqdm>=4.64.1 umap-learn>=0.5.3 diff --git a/tests/models/test_nucleus_detection_engine.py b/tests/models/test_nucleus_detection_engine.py new file mode 100644 index 000000000..e9f4e1aec --- /dev/null +++ b/tests/models/test_nucleus_detection_engine.py @@ -0,0 +1,67 @@ +"""Tests for NucleusDetector.""" + +import pathlib +import shutil + +import pandas as pd +import pytest + +from tiatoolbox.models.engine.nucleus_detector import ( + IONucleusDetectorConfig, + NucleusDetector, +) +from tiatoolbox.utils import env_detection as toolbox_env + +ON_GPU = not toolbox_env.running_on_ci() and toolbox_env.has_gpu() + + +def _rm_dir(path): + """Helper func to remove directory.""" + if pathlib.Path(path).exists(): + shutil.rmtree(path, ignore_errors=True) + + +def check_output(path): + """Check NucleusDetector output.""" + coordinates = pd.read_csv(path) + assert coordinates.x[0] == pytest.approx(53, abs=2) + assert coordinates.x[1] == pytest.approx(55, abs=2) + assert coordinates.y[0] == pytest.approx(107, abs=2) + assert coordinates.y[1] == pytest.approx(127, abs=2) + + +def test_nucleus_detector_engine(remote_sample, tmp_path): + """Test for nucleus detection engine.""" + mini_wsi_svs = pathlib.Path(remote_sample("wsi4_512_512_svs")) + + nucleus_detector = NucleusDetector(pretrained_model="mapde-conic") + _ = nucleus_detector.predict( + [mini_wsi_svs], + mode="wsi", + save_dir=tmp_path / "output", + on_gpu=ON_GPU, + ) + + check_output(tmp_path / "output" / "0.locations.0.csv") + + _rm_dir(tmp_path / "output") + + ioconfig = IONucleusDetectorConfig( + input_resolutions=[{"units": "mpp", "resolution": 0.5}], + output_resolutions=[{"units": "mpp", "resolution": 0.5}], + save_resolution=None, + patch_input_shape=[252, 252], + patch_output_shape=[252, 252], + stride_shape=[150, 150], + ) + + nucleus_detector = NucleusDetector(pretrained_model="mapde-conic") + _ = nucleus_detector.predict( + [mini_wsi_svs], + mode="wsi", + save_dir=tmp_path / "output", + on_gpu=ON_GPU, + ioconfig=ioconfig, + ) + + check_output(tmp_path / "output" / "0.locations.0.csv") diff --git a/tests/models/test_patch_predictor.py b/tests/models/test_patch_predictor.py index f69274623..352782c10 100644 --- a/tests/models/test_patch_predictor.py +++ b/tests/models/test_patch_predictor.py @@ -548,22 +548,22 @@ def test_io_config_delegation(remote_sample, tmp_path): predictor = PatchPredictor(pretrained_model="resnet18-kather100k", batch_size=1) predictor.predict( [mini_wsi_svs], - patch_input_shape=[300, 300], + patch_input_shape=(300, 300), mode="wsi", on_gpu=ON_GPU, save_dir=f"{tmp_path}/dump", ) - assert predictor._ioconfig.patch_input_shape == [300, 300] + assert predictor._ioconfig.patch_input_shape == (300, 300) _rm_dir(f"{tmp_path}/dump") predictor.predict( [mini_wsi_svs], - stride_shape=[300, 300], + stride_shape=(300, 300), mode="wsi", on_gpu=ON_GPU, save_dir=f"{tmp_path}/dump", ) - assert predictor._ioconfig.stride_shape == [300, 300] + assert predictor._ioconfig.stride_shape == (300, 300) _rm_dir(f"{tmp_path}/dump") predictor.predict( diff --git a/tiatoolbox/data/pretrained_model.yaml b/tiatoolbox/data/pretrained_model.yaml index 2243cebb0..9eb539efc 100644 --- a/tiatoolbox/data/pretrained_model.yaml +++ b/tiatoolbox/data/pretrained_model.yaml @@ -765,7 +765,6 @@ micronet-consep: - {"units": "mpp", "resolution": 0.25} output_resolutions: - {"units": "mpp", "resolution": 0.25} - margin: 128 tile_shape: [2048, 2048] patch_input_shape: [252, 252] patch_output_shape: [252, 252] @@ -777,66 +776,90 @@ mapde-crchisto: architecture: class: mapde.MapDe kwargs: - input_resolutions: - - { "units": "mpp", "resolution": 0.25 } num_input_channels: 3 min_distance: 4 threshold_abs: 250 num_classes: 1 + ioconfig: + class: semantic_segmentor.IOSegmentorConfig + kwargs: + input_resolutions: + - { "units": "mpp", "resolution": 0.5 } + output_resolutions: + - { "units": "mpp", "resolution": 0.5 } tile_shape: [ 2048, 2048 ] patch_input_shape: [ 252, 252 ] + patch_output_shape: [ 252, 252 ] stride_shape: [ 150, 150 ] + save_resolution: { 'units': 'mpp', 'resolution': 0.5 } mapde-conic: url: https://tiatoolbox.dcs.warwick.ac.uk/models/detection/mapde-conic.pth architecture: class: mapde.MapDe kwargs: - input_resolutions: - - { "units": "mpp", "resolution": 0.25 } num_input_channels: 3 min_distance: 3 threshold_abs: 205 num_classes: 1 + ioconfig: + class: semantic_segmentor.IOSegmentorConfig + kwargs: + input_resolutions: + - { "units": "mpp", "resolution": 0.5 } + output_resolutions: + - { "units": "mpp", "resolution": 0.5 } tile_shape: [ 2048, 2048 ] patch_input_shape: [ 252, 252 ] + patch_output_shape: [ 252, 252 ] stride_shape: [ 150, 150 ] + save_resolution: { 'units': 'mpp', 'resolution': 0.5 } sccnn-crchisto: url: https://tiatoolbox.dcs.warwick.ac.uk/models/detection/sccnn-crchisto.pth architecture: class: sccnn.SCCNN kwargs: - input_resolutions: - - { "units": "mpp", "resolution": 0.25 } num_input_channels: 3 - out_height: 13 - out_width: 13 radius: 12 min_distance: 6 threshold_abs: 0.20 + patch_output_shape: [ 13, 13 ] + ioconfig: + class: semantic_segmentor.IOSegmentorConfig + kwargs: + input_resolutions: + - { "units": "mpp", "resolution": 0.5 } + output_resolutions: + - { "units": "mpp", "resolution": 0.5 } tile_shape: [ 2048, 2048 ] patch_input_shape: [ 31, 31 ] patch_output_shape: [ 13, 13 ] stride_shape: [ 8, 8 ] + save_resolution: { 'units': 'mpp', 'resolution': 0.5 } sccnn-conic: url: https://tiatoolbox.dcs.warwick.ac.uk/models/detection/sccnn-conic.pth architecture: class: sccnn.SCCNN kwargs: - input_resolutions: - - { "units": "mpp", "resolution": 0.25 } num_input_channels: 3 - out_height: 13 - out_width: 13 radius: 12 min_distance: 5 threshold_abs: 0.05 + patch_output_shape: [ 13, 13 ] + ioconfig: + class: semantic_segmentor.IOSegmentorConfig + kwargs: + input_resolutions: + - { "units": "mpp", "resolution": 0.5 } + output_resolutions: + - { "units": "mpp", "resolution": 0.5 } tile_shape: [ 2048, 2048 ] patch_input_shape: [ 31, 31 ] patch_output_shape: [ 13, 13 ] stride_shape: [ 8, 8 ] + save_resolution: { 'units': 'mpp', 'resolution': 0.5 } nuclick_original-pannuke: url: https://tiatoolbox.dcs.warwick.ac.uk/models/seg/nuclick_original-pannuke.pth diff --git a/tiatoolbox/models/engine/__init__.py b/tiatoolbox/models/engine/__init__.py index 2cba98a32..e7936a1c1 100644 --- a/tiatoolbox/models/engine/__init__.py +++ b/tiatoolbox/models/engine/__init__.py @@ -1,5 +1,6 @@ """Engines to run models implemented in tiatoolbox.""" from tiatoolbox.models.engine import ( + nucleus_detector, nucleus_instance_segmentor, patch_predictor, semantic_segmentor, diff --git a/tiatoolbox/models/engine/nucleus_detector.py b/tiatoolbox/models/engine/nucleus_detector.py new file mode 100644 index 000000000..6e8e8615b --- /dev/null +++ b/tiatoolbox/models/engine/nucleus_detector.py @@ -0,0 +1,253 @@ +"""This module implements nucleus detection engine.""" + + +from typing import List, Union + +import numpy as np +import pandas as pd + +from tiatoolbox.models.engine.semantic_segmentor import ( + IOSegmentorConfig, + SemanticSegmentor, +) + + +class IONucleusDetectorConfig(IOSegmentorConfig): + """Contains NucleusDetector input and output information. + + Args: + input_resolutions (list): + Resolution of each input head of model inference, must be in + the same order as `target model.forward()`. + output_resolutions (list): + Resolution of each output head from model inference, must be + in the same order as target model.infer_batch(). + patch_input_shape (:class:`numpy.ndarray`, list(int)): + Shape of the largest input in (height, width). + patch_output_shape (:class:`numpy.ndarray`, list(int)): + Shape of the largest output in (height, width). + save_resolution (dict): + Resolution to save all output. + + Examples: + >>> # Defining io for a network having 1 input and 1 output at the + >>> # same resolution + >>> ioconfig = IONucleusDetectorConfig( + ... input_resolutions=[{"units": "baseline", "resolution": 1.0}], + ... output_resolutions=[{"units": "baseline", "resolution": 1.0}], + ... patch_input_shape=[2048, 2048], + ... patch_output_shape=[1024, 1024], + ... stride_shape=[512, 512], + ... ) + + """ + + def __init__( + self, + input_resolutions: List[dict], + output_resolutions: List[dict], + patch_input_shape: Union[List[int], np.ndarray], + patch_output_shape: Union[List[int], np.ndarray], + save_resolution: dict = None, + **kwargs, + ): + super().__init__( + input_resolutions=input_resolutions, + output_resolutions=output_resolutions, + patch_input_shape=patch_input_shape, + patch_output_shape=patch_output_shape, + save_resolution=save_resolution, + **kwargs, + ) + + +class NucleusDetector(SemanticSegmentor): + r"""Nucleus detection engine. + + The models provided by tiatoolbox should give the following results: + + .. list-table:: Nucleus detection performance on the (add models list here) + :widths: 15 15 + :header-rows: 1 + + Args: + model (nn.Module): + Use externally defined PyTorch model for prediction with. + weights already loaded. Default is `None`. If provided, + `pretrained_model` argument is ignored. + pretrained_model (str): + Name of the existing models support by tiatoolbox for + processing the data. For a full list of pretrained models, + refer to the `docs + `_ + By default, the corresponding pretrained weights will also + be downloaded. However, you can override with your own set + of weights via the `pretrained_weights` argument. Argument + is case-insensitive. + pretrained_weights (str): + Path to the weight of the corresponding `pretrained_model`. + + >>> predictor = NucleusDetector( + ... pretrained_model="mapde-conic", + ... pretrained_weights="mapde_local_weight") + + batch_size (int): + Number of images fed into the model each time. + num_loader_workers (int): + Number of workers to load the data. Take note that they will + also perform preprocessing. + verbose (bool): + Whether to output logging information. default=False. + auto_generate_mask (bool): + To automatically generate tile/WSI tissue mask if is not + provided. default=False. + + Attributes: + imgs (:obj:`str` or :obj:`pathlib.Path` or :obj:`numpy.ndarray`): + A HWC image or a path to WSI. + model (nn.Module): + Defined PyTorch model. + pretrained_model (str): + Name of the existing models support by tiatoolbox for + processing the data e.g., mapde-conic, sccnn-conic. + For a full list of pretrained models, please refer to the `docs + `_ + By default, the corresponding pretrained weights will also + be downloaded. However, you can override with your own set + of weights via the `pretrained_weights` argument. Argument + is case insensitive. + batch_size (int): + Number of images fed into the model each time. + num_loader_workers (int): + Number of workers used in torch.utils.data.DataLoader. + verbose (bool): + Whether to output logging information. + + Examples: + >>> # list of 2 image patches as input + >>> data = [img1, img2] + >>> nucleus_detector = NucleusDetector(pretrained_model="mapde-conic") + >>> output = nucleus_detector.predict(data, mode='patch') + + >>> # array of list of 2 image patches as input + >>> data = np.array([img1, img2]) + >>> nucleus_detector = NucleusDetector(pretrained_model="mapde-conic") + >>> output = nucleus_detector.predict(data, mode='patch') + + >>> # list of 2 image patch files as input + >>> data = ['path/img.png', 'path/img.png'] + >>> nucleus_detector = NucleusDetector(pretrained_model="mapde-conic") + >>> output = nucleus_detector.predict(data, mode='patch') + + >>> # list of 2 image tile files as input + >>> tile_file = ['path/tile1.png', 'path/tile2.png'] + >>> nucleus_detector = NucleusDetector(pretraind_model="mapde-conic") + >>> output = nucleus_detector.predict(tile_file, mode='tile') + + >>> # list of 2 wsi files as input + >>> wsi_file = ['path/wsi1.svs', 'path/wsi2.svs'] + >>> nucleus_detector = NucleusDetector(pretraind_model="mapde-conic") + >>> output = nucleus_detector.predict(wsi_file, mode='wsi') + + References: + [1] Raza, Shan E. Ahmed, et al. "Deconvolving convolutional neural network + for cell detection." 2019 IEEE 16th International Symposium on Biomedical + Imaging (ISBI 2019). IEEE, 2019. + + [2] Sirinukunwattana, Korsuk, et al. + "Locality sensitive deep learning for detection and classification + of nuclei in routine colon cancer histology images." + IEEE transactions on medical imaging 35.5 (2016): 1196-1206. + + """ # noqa: W605 + + from tiatoolbox.wsicore.wsireader import WSIReader + + def __init__( + self, + batch_size=8, + num_loader_workers=0, + model=None, + pretrained_model=None, + pretrained_weights=None, + verbose: bool = False, + auto_generate_mask: bool = False, + ): + super().__init__( + batch_size=batch_size, + num_loader_workers=num_loader_workers, + model=model, + pretrained_model=pretrained_model, + pretrained_weights=pretrained_weights, + verbose=verbose, + auto_generate_mask=auto_generate_mask, + ) + + def _process_predictions( + self, + cum_batch_predictions: List, + wsi_reader: WSIReader, + ioconfig: IOSegmentorConfig, + save_path: str, + cache_dir: str, + ): + """Define how the aggregated predictions are processed. + + This includes merging the prediction if necessary and also saving the + locations afterwards. Note that items within `cum_batch_predictions` will + be consumed during the operation. + + Args: + cum_batch_predictions (list): + List of batch predictions. Each item within the list + should be of (location, patch_predictions). + wsi_reader (:class:`WSIReader`): + A reader for the image where the predictions come from. + ioconfig (:class:`IOSegmentorConfig`): + A configuration object contains input and output + information. + save_path (str): + Root path to save current WSI predictions. + cache_dir (str): + Root path to cache current WSI data. + + """ + if len(cum_batch_predictions) == 0: + return + + # assume predictions is N, each item has L output element + locations, predictions = list(zip(*cum_batch_predictions)) + # Nx4 (N x [tl_x, tl_y, br_x, br_y), denotes the location of + # output patch this can exceed the image bound at the requested + # resolution remove singleton due to split. + locations = np.array([v[0] for v in locations]) + for index, output_resolution in enumerate(ioconfig.output_resolutions): + # assume resolution index to be in the same order as L + merged_resolution = ioconfig.highest_input_resolution + merged_locations = locations + # ! location is w.r.t the highest resolution, hence still need conversion + if ioconfig.save_resolution is not None: + merged_resolution = ioconfig.save_resolution + output_shape = wsi_reader.slide_dimensions(**output_resolution) + merged_shape = wsi_reader.slide_dimensions(**merged_resolution) + fx = merged_shape[0] / output_shape[0] + merged_locations = np.ceil(locations * fx).astype(np.int64) + merged_shape = wsi_reader.slide_dimensions(**merged_resolution) + # 0 idx is to remove singleton without removing other axes singleton + to_merge_predictions = [v[index][0] for v in predictions] + sub_save_path = f"{save_path}.raw.{index}.npy" + sub_count_path = f"{cache_dir}/count.{index}.npy" + cum_canvas = self.merge_prediction( + merged_shape[::-1], # XY to YX + to_merge_predictions, + merged_locations, + save_path=sub_save_path, + cache_count_path=sub_count_path, + ) + + # Coordinates in output resolution for the current canvas. + cum_canvas = np.expand_dims(cum_canvas, axis=0) + coordinates_canvas = pd.DataFrame( + self.model.postproc_func(cum_canvas), columns=["x", "y"] + ) + coordinates_canvas.to_csv(f"{save_path}.locations.{index}.csv", index=False) diff --git a/tiatoolbox/models/engine/patch_predictor.py b/tiatoolbox/models/engine/patch_predictor.py index 2584f3bc0..7ee21604d 100644 --- a/tiatoolbox/models/engine/patch_predictor.py +++ b/tiatoolbox/models/engine/patch_predictor.py @@ -158,7 +158,7 @@ class PatchPredictor: Whether to output logging information. Attributes: - img (:obj:`str` or :obj:`pathlib.Path` or :obj:`numpy.ndarray`): + imgs (:obj:`str` or :obj:`pathlib.Path` or :obj:`numpy.ndarray`): A HWC image or a path to WSI. mode (str): Type of input to process. Choose from either `patch`, `tile` @@ -246,7 +246,7 @@ def __init__( self.model = model # for runtime, such as after wrapping with nn.DataParallel self.pretrained_model = pretrained_model self.batch_size = batch_size - self.num_loader_worker = num_loader_workers + self.num_loader_workers = num_loader_workers self.verbose = verbose @staticmethod @@ -397,7 +397,7 @@ def _predict_engine( # preprocessing must be defined with the dataset dataloader = torch.utils.data.DataLoader( dataset, - num_workers=self.num_loader_worker, + num_workers=self.num_loader_workers, batch_size=self.batch_size, drop_last=False, shuffle=False, @@ -462,20 +462,24 @@ def _update_ioconfig( Args: ioconfig (IOPatchPredictorConfig): - patch_input_shape (tuple): - Size of patches input to the model. Patches are at - requested read resolution, not with respect to level 0, - and must be positive. - stride_shape (tuple): - Stride using during tile and WSI processing. Stride is - at requested read resolution, not with respect to - level 0, and must be positive. If not provided, - `stride_shape=patch_input_shape`. - resolution (Resolution): - Resolution used for reading the image. Please see - :obj:`WSIReader` for details. - units (Units): - Units of resolution used for reading the image. + Object defines information about input and output placement + of patches for patch prediction. + patch_input_shape (tuple): + Size of patches input to the model. Patches are at + requested read resolution, not with respect to level 0, + and must be positive. + stride_shape (tuple): + Stride using during tile and WSI processing. Stride is + at requested read resolution, not with respect to + level 0, and must be positive. If not provided, + `stride_shape=patch_input_shape`. + resolution (Resolution): + Resolution used for reading the image. Please see + :obj:`WSIReader` for details. + units (Units): + Units of resolution used for reading the image. Choose + from either `level`, `power` or `mpp`. Please see + :obj:`WSIReader` for details. Returns: Updated Patch Predictor IO configuration. diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index 620bdcaba..3a7970f87 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -1537,7 +1537,7 @@ def predict( Resolution used for reading the image. units (Units): Units of resolution used for reading the image. - save_dir (str): + save_dir (str or pathlib.Path): Output directory when processing multiple tiles and whole-slide images. By default, it is folder `output` where the running script is invoked.