From 3eb4ab33d8b42d30adb1b74039ceae7e17762377 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Thu, 8 Jun 2023 18:10:53 +0100 Subject: [PATCH 001/112] :recycle: Refactor base code from IOSegmentorConfig to IOConfigABC - Refactor base code from IOSegmentorConfig to IOConfigABC --- tiatoolbox/models/engine/engine_abc.py | 90 +++++++++++++++---- .../models/engine/semantic_segmentor.py | 75 +++------------- whitelist.txt | 1 + 3 files changed, 86 insertions(+), 80 deletions(-) diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 3f51e5681..b0622e8eb 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -3,6 +3,8 @@ import numpy as np +from tiatoolbox.wsicore.wsimeta import Units + class IOConfigABC(ABC): """Define an abstract class for holding predictor I/O information. @@ -10,8 +12,20 @@ class IOConfigABC(ABC): Enforcing such that following attributes must always be defined by the subclass. + Args: + input_resolutions (list): + Resolution of each input head of model inference, must be in + the same order as `target model.forward()`. + stride_shape (:class:`numpy.ndarray`, list(int)): + Stride in (x, y) direction for patch extraction. + patch_input_shape (:class:`numpy.ndarray`, list(int)): + Shape of the largest input in (height, width). + """ + # We pre-define to follow enforcement, actual initialisation in init + input_resolutions = None + def __init__( self, input_resolutions: List[dict], @@ -20,20 +34,18 @@ def __init__( **kwargs, ): self._kwargs = kwargs - self.resolution_unit = input_resolutions[0]["units"] self.patch_input_shape = patch_input_shape self.stride_shape = stride_shape + self.input_resolutions = input_resolutions + self.output_resolutions = [] + # output_resolutions are equal to input resolutions by default + # but these are customizable. + self.resolution_unit = input_resolutions[0]["units"] - self._validate() + for variable, value in kwargs.items(): + self.__setattr__(variable, value) - if self.resolution_unit == "mpp": - self.highest_input_resolution = min( - self.input_resolutions, key=lambda x: x["resolution"] - ) - else: - self.highest_input_resolution = max( - self.input_resolutions, key=lambda x: x["resolution"] - ) + self._validate() def _validate(self): """Validate the data format.""" @@ -47,14 +59,58 @@ def _validate(self): ]: raise ValueError(f"Invalid resolution units `{units[0]}`.") - @property - @abstractmethod - def input_resolutions(self): - raise NotImplementedError + def _set_highest_input_resolution(self): + """Identifies and sets highest input resolution available.""" + if self.resolution_unit == "mpp": + self.highest_input_resolution = min( + self.input_resolutions, key=lambda x: x["resolution"] + ) + else: + self.highest_input_resolution = max( + self.input_resolutions, key=lambda x: x["resolution"] + ) - @property - @abstractmethod - def output_resolutions(self): + @staticmethod + def scale_to_highest(resolutions: List[dict], units: Units): + """Get the scaling factor from input resolutions. + + This will convert resolutions to a scaling factor with respect to + the highest resolution found in the input resolutions list. + + Args: + resolutions (list): + A list of resolutions where one is defined as + `{'resolution': value, 'unit': value}` + units (Units): + Units that the resolutions are at. + + Returns: + :class:`numpy.ndarray`: + A 1D array of scaling factors having the same length as + `resolutions` + + """ + old_val = [v["resolution"] for v in resolutions] + if units not in ["baseline", "mpp", "power"]: + raise ValueError( + f"Unknown units `{units}`. " + "Units should be one of 'baseline', 'mpp' or 'power'." + ) + if units == "baseline": + return old_val + if units == "mpp": + return np.min(old_val) / np.array(old_val) + return np.array(old_val) / np.max(old_val) + + def to_baseline(self): + """Return a new config object converted to baseline form. + + This will return a new :class:`IOSegmentorConfig` where + resolutions have been converted to baseline format with the + highest possible resolution found in both input and output as + reference. + + """ raise NotImplementedError diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index ea8738983..8155b9c81 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -20,12 +20,13 @@ from tiatoolbox import logger from tiatoolbox.models.architecture import get_pretrained_model -from tiatoolbox.models.engine.engine_abc import IOConfigABC from tiatoolbox.tools.patchextraction import PatchExtractor from tiatoolbox.utils import imread, misc from tiatoolbox.wsicore.wsimeta import Resolution, Units from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIMeta, WSIReader +from .engine_abc import IOConfigABC + def _estimate_canvas_parameters(sample_prediction, canvas_shape): """Estimates canvas parameters. @@ -106,6 +107,8 @@ class IOSegmentorConfig(IOConfigABC): output_resolutions (list): Resolution of each output head from model inference, must be in the same order as target model.infer_batch(). + stride_shape (:class:`numpy.ndarray`, list(int)): + Stride in (x, y) direction for patch extraction. patch_input_shape (:class:`numpy.ndarray`, list(int)): Shape of the largest input in (height, width). patch_output_shape (:class:`numpy.ndarray`, list(int)): @@ -155,77 +158,23 @@ def __init__( input_resolutions: List[dict], output_resolutions: List[dict], patch_input_shape: Union[List[int], np.ndarray], + stride_shape: Union[List[int], np.ndarray, Tuple[int]], patch_output_shape: Union[List[int], np.ndarray], save_resolution: dict = None, **kwargs, ): - self._kwargs = kwargs - self.patch_input_shape = patch_input_shape + super().__init__( + input_resolutions=input_resolutions, + patch_input_shape=patch_input_shape, + stride_shape=stride_shape, + **kwargs, + ) self.patch_output_shape = patch_output_shape - self.stride_shape = None - self.input_resolutions = input_resolutions self.output_resolutions = output_resolutions - - self.resolution_unit = input_resolutions[0]["units"] self.save_resolution = save_resolution - for variable, value in kwargs.items(): - self.__setattr__(variable, value) - self._validate() - - if self.resolution_unit == "mpp": - self.highest_input_resolution = min( - self.input_resolutions, key=lambda x: x["resolution"] - ) - else: - self.highest_input_resolution = max( - self.input_resolutions, key=lambda x: x["resolution"] - ) - - def _validate(self): - """Validate the data format.""" - resolutions = self.input_resolutions + self.output_resolutions - units = [v["units"] for v in resolutions] - units = np.unique(units) - if len(units) != 1 or units[0] not in [ - "power", - "baseline", - "mpp", - ]: - raise ValueError(f"Invalid resolution units `{units[0]}`.") - - @staticmethod - def scale_to_highest(resolutions: List[dict], units: Units): - """Get the scaling factor from input resolutions. - - This will convert resolutions to a scaling factor with respect to - the highest resolution found in the input resolutions list. - - Args: - resolutions (list): - A list of resolutions where one is defined as - `{'resolution': value, 'unit': value}` - units (Units): - Units that the resolutions are at. - - Returns: - :class:`numpy.ndarray`: - A 1D array of scaling factors having the same length as - `resolutions` - - """ - old_val = [v["resolution"] for v in resolutions] - if units not in ["baseline", "mpp", "power"]: - raise ValueError( - f"Unknown units `{units}`. " - "Units should be one of 'baseline', 'mpp' or 'power'." - ) - if units == "baseline": - return old_val - if units == "mpp": - return np.min(old_val) / np.array(old_val) - return np.array(old_val) / np.max(old_val) + self._set_highest_input_resolution() def to_baseline(self): """Return a new config object converted to baseline form. diff --git a/whitelist.txt b/whitelist.txt index 2f15e7d4b..1d3b79a12 100644 --- a/whitelist.txt +++ b/whitelist.txt @@ -94,6 +94,7 @@ coord coords csv cuda +customizable cv2 dataframe dataset From c8d3458ad7bc1b2cb34b484c30a2403be3b59cbe Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Thu, 8 Jun 2023 18:15:38 +0100 Subject: [PATCH 002/112] :recycle: No need for a separate function - No need for a separate function --- tiatoolbox/models/engine/engine_abc.py | 20 ++++++++----------- .../models/engine/semantic_segmentor.py | 1 - 2 files changed, 8 insertions(+), 13 deletions(-) diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index b0622e8eb..64ab9c433 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -45,7 +45,14 @@ def __init__( for variable, value in kwargs.items(): self.__setattr__(variable, value) - self._validate() + if self.resolution_unit == "mpp": + self.highest_input_resolution = min( + self.input_resolutions, key=lambda x: x["resolution"] + ) + else: + self.highest_input_resolution = max( + self.input_resolutions, key=lambda x: x["resolution"] + ) def _validate(self): """Validate the data format.""" @@ -59,17 +66,6 @@ def _validate(self): ]: raise ValueError(f"Invalid resolution units `{units[0]}`.") - def _set_highest_input_resolution(self): - """Identifies and sets highest input resolution available.""" - if self.resolution_unit == "mpp": - self.highest_input_resolution = min( - self.input_resolutions, key=lambda x: x["resolution"] - ) - else: - self.highest_input_resolution = max( - self.input_resolutions, key=lambda x: x["resolution"] - ) - @staticmethod def scale_to_highest(resolutions: List[dict], units: Units): """Get the scaling factor from input resolutions. diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index 8155b9c81..1d198e0db 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -174,7 +174,6 @@ def __init__( self.save_resolution = save_resolution self._validate() - self._set_highest_input_resolution() def to_baseline(self): """Return a new config object converted to baseline form. From 9623f860b66c5284e73dd6af7fff88845aa9e993 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Thu, 8 Jun 2023 18:17:22 +0100 Subject: [PATCH 003/112] :technologist: Enable git workflow on this PR - Enable git workflow on this PR --- .github/workflows/python-package.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index aee579e5a..aa809fd22 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -8,7 +8,7 @@ on: branches: [ develop, pre-release, master, main ] tags: v* pull_request: - branches: [ develop, pre-release, master, main ] + branches: [ develop, pre-release, master, main, dev-define-engines-abc] jobs: build: From 4c8f91f73df7b14ceb8bd6798529509cf979ad4e Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 9 Jun 2023 09:33:35 +0100 Subject: [PATCH 004/112] :bug: Add missing variable - Add missing variable --- tiatoolbox/models/engine/semantic_segmentor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index 1d198e0db..a1eea1319 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -210,6 +210,7 @@ def to_baseline(self): output_resolutions=output_resolutions, patch_input_shape=self.patch_input_shape, patch_output_shape=self.patch_output_shape, + stride_shape=self.stride_shape, save_resolution=save_resolution, **self._kwargs, ) From 972d81d3c260ea35a446c3d879be732a541a78c3 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 9 Jun 2023 11:25:27 +0100 Subject: [PATCH 005/112] :art: Move `to_baseline` code to ABC - Move `to_baseline` code to ABC --- tiatoolbox/models/engine/engine_abc.py | 42 ++++++++++++------- .../models/engine/semantic_segmentor.py | 17 ++++---- 2 files changed, 35 insertions(+), 24 deletions(-) diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 64ab9c433..d589c6488 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -6,26 +6,23 @@ from tiatoolbox.wsicore.wsimeta import Units -class IOConfigABC(ABC): - """Define an abstract class for holding predictor I/O information. +class ModelIOConfigABC(ABC): + """Defines an abstract class for holding a CNN model's I/O information. Enforcing such that following attributes must always be defined by the subclass. Args: - input_resolutions (list): + input_resolutions (list(dict)): Resolution of each input head of model inference, must be in the same order as `target model.forward()`. - stride_shape (:class:`numpy.ndarray`, list(int)): + stride_shape (:class:`numpy.ndarray`, list(int), tuple(int)): Stride in (x, y) direction for patch extraction. - patch_input_shape (:class:`numpy.ndarray`, list(int)): + patch_input_shape (:class:`numpy.ndarray`, list(int), tuple(int)): Shape of the largest input in (height, width). """ - # We pre-define to follow enforcement, actual initialisation in init - input_resolutions = None - def __init__( self, input_resolutions: List[dict], @@ -38,8 +35,6 @@ def __init__( self.stride_shape = stride_shape self.input_resolutions = input_resolutions self.output_resolutions = [] - # output_resolutions are equal to input resolutions by default - # but these are customizable. self.resolution_unit = input_resolutions[0]["units"] for variable, value in kwargs.items(): @@ -78,7 +73,7 @@ def scale_to_highest(resolutions: List[dict], units: Units): A list of resolutions where one is defined as `{'resolution': value, 'unit': value}` units (Units): - Units that the resolutions are at. + Resolution units. Returns: :class:`numpy.ndarray`: @@ -98,16 +93,35 @@ def scale_to_highest(resolutions: List[dict], units: Units): return np.min(old_val) / np.array(old_val) return np.array(old_val) / np.max(old_val) + @abstractmethod def to_baseline(self): - """Return a new config object converted to baseline form. + """Returns a new config object converted to baseline form. - This will return a new :class:`IOSegmentorConfig` where + This will return a new :class:`ModelIOConfigABC` where resolutions have been converted to baseline format with the highest possible resolution found in both input and output as reference. """ - raise NotImplementedError + resolutions = self.input_resolutions + self.output_resolutions + save_resolution = getattr(self, "save_resolution", None) + if save_resolution is not None: + resolutions.append(save_resolution) + + scale_factors = self.scale_to_highest(resolutions, self.resolution_unit) + num_input_resolutions = len(self.input_resolutions) + + end_idx = num_input_resolutions + input_resolutions = [ + {"units": "baseline", "resolution": v} for v in scale_factors[:end_idx] + ] + + return ModelIOConfigABC( + input_resolutions=input_resolutions, + patch_input_shape=self.patch_input_shape, + stride_shape=self.stride_shape, + **self._kwargs, + ) class EngineABC(ABC): diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index a1eea1319..4d6af6338 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -25,7 +25,7 @@ from tiatoolbox.wsicore.wsimeta import Resolution, Units from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIMeta, WSIReader -from .engine_abc import IOConfigABC +from .engine_abc import ModelIOConfigABC def _estimate_canvas_parameters(sample_prediction, canvas_shape): @@ -97,7 +97,7 @@ def _prepare_save_output( return is_on_drive, count_canvas, cum_canvas -class IOSegmentorConfig(IOConfigABC): +class IOSegmentorConfig(ModelIOConfigABC): """Contain semantic segmentor input and output information. Args: @@ -176,7 +176,7 @@ def __init__( self._validate() def to_baseline(self): - """Return a new config object converted to baseline form. + """Returns a new config object converted to baseline form. This will return a new :class:`IOSegmentorConfig` where resolutions have been converted to baseline format with the @@ -184,18 +184,15 @@ def to_baseline(self): reference. """ - resolutions = self.input_resolutions + self.output_resolutions + new_config = super().to_baseline() + resolutions = new_config.input_resolutions + self.output_resolutions if self.save_resolution is not None: resolutions.append(self.save_resolution) scale_factors = self.scale_to_highest(resolutions, self.resolution_unit) - num_input_resolutions = len(self.input_resolutions) + num_input_resolutions = len(new_config.input_resolutions) num_output_resolutions = len(self.output_resolutions) - end_idx = num_input_resolutions - input_resolutions = [ - {"units": "baseline", "resolution": v} for v in scale_factors[:end_idx] - ] end_idx = num_input_resolutions + num_output_resolutions output_resolutions = [ {"units": "baseline", "resolution": v} @@ -206,7 +203,7 @@ def to_baseline(self): if self.save_resolution is not None: save_resolution = {"units": "baseline", "resolution": scale_factors[-1]} return IOSegmentorConfig( - input_resolutions=input_resolutions, + input_resolutions=new_config.input_resolutions, output_resolutions=output_resolutions, patch_input_shape=self.patch_input_shape, patch_output_shape=self.patch_output_shape, From 5f81f9baecaf1939358deb62cac30d2be9f06a2d Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 9 Jun 2023 13:39:56 +0100 Subject: [PATCH 006/112] :art: Move `to_baseline` code to ABC - Move `to_baseline` code to ABC --- tiatoolbox/models/engine/engine_abc.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index d589c6488..9f04e0a82 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -6,7 +6,7 @@ from tiatoolbox.wsicore.wsimeta import Units -class ModelIOConfigABC(ABC): +class ModelIOConfigABC(ABC): # noqa: B024 """Defines an abstract class for holding a CNN model's I/O information. Enforcing such that following attributes must always be defined by @@ -93,7 +93,6 @@ def scale_to_highest(resolutions: List[dict], units: Units): return np.min(old_val) / np.array(old_val) return np.array(old_val) / np.max(old_val) - @abstractmethod def to_baseline(self): """Returns a new config object converted to baseline form. From 8270c4ee4d855bf82fafe13734c2b9838cc85de8 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 9 Jun 2023 16:18:20 +0100 Subject: [PATCH 007/112] :bug: Fix incorrect calculations for output_resolutions - Fix incorrect calculations for output_resolutions --- tiatoolbox/models/engine/semantic_segmentor.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index 4d6af6338..d286e5a77 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -185,12 +185,12 @@ def to_baseline(self): """ new_config = super().to_baseline() - resolutions = new_config.input_resolutions + self.output_resolutions + resolutions = self.input_resolutions + self.output_resolutions if self.save_resolution is not None: resolutions.append(self.save_resolution) scale_factors = self.scale_to_highest(resolutions, self.resolution_unit) - num_input_resolutions = len(new_config.input_resolutions) + num_input_resolutions = len(self.input_resolutions) num_output_resolutions = len(self.output_resolutions) end_idx = num_input_resolutions + num_output_resolutions @@ -202,6 +202,7 @@ def to_baseline(self): save_resolution = None if self.save_resolution is not None: save_resolution = {"units": "baseline", "resolution": scale_factors[-1]} + return IOSegmentorConfig( input_resolutions=new_config.input_resolutions, output_resolutions=output_resolutions, From f3b6b963248b7d63f25dd2657b162b8e6faf8295 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Thu, 15 Jun 2023 13:51:59 +0100 Subject: [PATCH 008/112] :wrench: Update `PatchPredictor` to use `ModelIOConfigABC` - Update `PatchPredictor` to use `ModelIOConfigABC` Signed-off-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> --- tiatoolbox/models/engine/patch_predictor.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tiatoolbox/models/engine/patch_predictor.py b/tiatoolbox/models/engine/patch_predictor.py index 2584f3bc0..07c69c839 100644 --- a/tiatoolbox/models/engine/patch_predictor.py +++ b/tiatoolbox/models/engine/patch_predictor.py @@ -13,30 +13,28 @@ from tiatoolbox import logger from tiatoolbox.models.architecture import get_pretrained_model from tiatoolbox.models.dataset.classification import PatchDataset, WSIPatchDataset -from tiatoolbox.models.engine.semantic_segmentor import IOSegmentorConfig from tiatoolbox.utils import misc, save_as_json from tiatoolbox.wsicore.wsimeta import Resolution, Units from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIReader +from .engine_abc import ModelIOConfigABC -class IOPatchPredictorConfig(IOSegmentorConfig): + +class IOPatchPredictorConfig(ModelIOConfigABC): """Contains patch predictor input and output information.""" def __init__( self, - patch_input_shape=None, input_resolutions=None, + patch_input_shape=None, stride_shape=None, **kwargs, ): stride_shape = patch_input_shape if stride_shape is None else stride_shape super().__init__( input_resolutions=input_resolutions, - output_resolutions=[], stride_shape=stride_shape, patch_input_shape=patch_input_shape, - patch_output_shape=patch_input_shape, - save_resolution=None, **kwargs, ) From 285d99aa15c252868ae6a7193c7ea25a749730bc Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 16 Jun 2023 09:46:02 +0100 Subject: [PATCH 009/112] :bricks: Define `ModelIOConfigABC` as a dataclass - Define `ModelIOConfigABC` as a dataclass Signed-off-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> --- tiatoolbox/models/engine/engine_abc.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 9f04e0a82..0986e53c2 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from dataclasses import dataclass from typing import List, Tuple, Union import numpy as np @@ -6,8 +7,9 @@ from tiatoolbox.wsicore.wsimeta import Units -class ModelIOConfigABC(ABC): # noqa: B024 - """Defines an abstract class for holding a CNN model's I/O information. +@dataclass +class ModelIOConfigABC: + """Defines a data class for holding a deep learning model's I/O information. Enforcing such that following attributes must always be defined by the subclass. From 0d270bfff4e822266046efba453abaa0ff9172f8 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 16 Jun 2023 11:24:22 +0100 Subject: [PATCH 010/112] :boom: Remove kwargs - Remove kwargs. - Define dataclass attributes. Signed-off-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> --- tiatoolbox/models/engine/engine_abc.py | 15 ++++++++------- tiatoolbox/models/engine/patch_predictor.py | 2 -- tiatoolbox/models/engine/semantic_segmentor.py | 3 --- 3 files changed, 8 insertions(+), 12 deletions(-) diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 0986e53c2..2ee2d5707 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import List, Tuple, Union import numpy as np @@ -25,23 +25,25 @@ class ModelIOConfigABC: """ + input_resolutions: List[dict] + patch_input_shape: Union[List[int], np.ndarray, Tuple[int]] + stride_shape: Union[List[int], np.ndarray, Tuple[int]] + highest_input_resolution: dict + output_resolutions: List[dict] = field(default_factory=list) + resolution_unit: Units = "mpp" + def __init__( self, input_resolutions: List[dict], patch_input_shape: Union[List[int], np.ndarray, Tuple[int]], stride_shape: Union[List[int], np.ndarray, Tuple[int]], - **kwargs, ): - self._kwargs = kwargs self.patch_input_shape = patch_input_shape self.stride_shape = stride_shape self.input_resolutions = input_resolutions self.output_resolutions = [] self.resolution_unit = input_resolutions[0]["units"] - for variable, value in kwargs.items(): - self.__setattr__(variable, value) - if self.resolution_unit == "mpp": self.highest_input_resolution = min( self.input_resolutions, key=lambda x: x["resolution"] @@ -121,7 +123,6 @@ def to_baseline(self): input_resolutions=input_resolutions, patch_input_shape=self.patch_input_shape, stride_shape=self.stride_shape, - **self._kwargs, ) diff --git a/tiatoolbox/models/engine/patch_predictor.py b/tiatoolbox/models/engine/patch_predictor.py index 07c69c839..5ece0bb87 100644 --- a/tiatoolbox/models/engine/patch_predictor.py +++ b/tiatoolbox/models/engine/patch_predictor.py @@ -28,14 +28,12 @@ def __init__( input_resolutions=None, patch_input_shape=None, stride_shape=None, - **kwargs, ): stride_shape = patch_input_shape if stride_shape is None else stride_shape super().__init__( input_resolutions=input_resolutions, stride_shape=stride_shape, patch_input_shape=patch_input_shape, - **kwargs, ) diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index d286e5a77..835c91095 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -161,13 +161,11 @@ def __init__( stride_shape: Union[List[int], np.ndarray, Tuple[int]], patch_output_shape: Union[List[int], np.ndarray], save_resolution: dict = None, - **kwargs, ): super().__init__( input_resolutions=input_resolutions, patch_input_shape=patch_input_shape, stride_shape=stride_shape, - **kwargs, ) self.patch_output_shape = patch_output_shape self.output_resolutions = output_resolutions @@ -210,7 +208,6 @@ def to_baseline(self): patch_output_shape=self.patch_output_shape, stride_shape=self.stride_shape, save_resolution=save_resolution, - **self._kwargs, ) From 9e4266aa5badf90cded48e5aa94d3db6d1fc9b7b Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 16 Jun 2023 13:27:25 +0100 Subject: [PATCH 011/112] :sparkles: Add `IOInstanceSegmentorConfig` - Add `IOInstanceSegmentorConfig` Signed-off-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> --- .../models/engine/multi_task_segmentor.py | 9 +- .../engine/nucleus_instance_segmentor.py | 82 ++++++++++++++++++- .../models/engine/semantic_segmentor.py | 5 +- 3 files changed, 87 insertions(+), 9 deletions(-) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 4ba11774b..04e84803b 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -34,10 +34,9 @@ NucleusInstanceSegmentor, _process_instance_predictions, ) -from tiatoolbox.models.engine.semantic_segmentor import ( - IOSegmentorConfig, - WSIStreamDataset, -) + +from .nucleus_instance_segmentor import IOInstanceSegmentorConfig +from .semantic_segmentor import WSIStreamDataset # Python is yet to be able to natively pickle Object method/static method. @@ -279,7 +278,7 @@ def __init__( def _predict_one_wsi( self, wsi_idx: int, - ioconfig: IOSegmentorConfig, + ioconfig: IOInstanceSegmentorConfig, save_path: str, mode: str, ): diff --git a/tiatoolbox/models/engine/nucleus_instance_segmentor.py b/tiatoolbox/models/engine/nucleus_instance_segmentor.py index 2696acb44..e4b920f65 100644 --- a/tiatoolbox/models/engine/nucleus_instance_segmentor.py +++ b/tiatoolbox/models/engine/nucleus_instance_segmentor.py @@ -2,7 +2,7 @@ import uuid from collections import deque -from typing import Callable, List, Union +from typing import Callable, List, Tuple, Union # replace with the sql database once the PR in place import joblib @@ -298,6 +298,86 @@ def _process_tile_predictions( return new_inst_dict, remove_insts_in_orig +class IOInstanceSegmentorConfig(IOSegmentorConfig): + """Contain instance segmentor input and output information. + + Args: + input_resolutions (list): + Resolution of each input head of model inference, must be in + the same order as `target model.forward()`. + output_resolutions (list): + Resolution of each output head from model inference, must be + in the same order as target model.infer_batch(). + stride_shape (:class:`numpy.ndarray`, list(int)): + Stride in (x, y) direction for patch extraction. + patch_input_shape (:class:`numpy.ndarray`, list(int)): + Shape of the largest input in (height, width). + patch_output_shape (:class:`numpy.ndarray`, list(int)): + Shape of the largest output in (height, width). + save_resolution (dict): + Resolution to save all output. + margin (int): + Tile margin to accumulate the output. + + + Examples: + >>> # Defining io for a network having 1 input and 1 output at the + >>> # same resolution + >>> ioconfig = IOSegmentorConfig( + ... input_resolutions=[{"units": "baseline", "resolution": 1.0}], + ... output_resolutions=[{"units": "baseline", "resolution": 1.0}], + ... patch_input_shape=[2048, 2048], + ... patch_output_shape=[1024, 1024], + ... stride_shape=[512, 512], + ... ) + + Examples: + >>> # Defining io for a network having 3 input and 2 output + >>> # at the same resolution, the output is then merged at a + >>> # different resolution. + >>> ioconfig = IOSegmentorConfig( + ... input_resolutions=[ + ... {"units": "mpp", "resolution": 0.25}, + ... {"units": "mpp", "resolution": 0.50}, + ... {"units": "mpp", "resolution": 0.75}, + ... ], + ... output_resolutions=[ + ... {"units": "mpp", "resolution": 0.25}, + ... {"units": "mpp", "resolution": 0.50}, + ... ], + ... patch_input_shape=[2048, 2048], + ... patch_output_shape=[1024, 1024], + ... stride_shape=[512, 512], + ... save_resolution={"units": "mpp", "resolution": 4.0}, + ... ) + + """ + + margin: int + + def __init__( + self, + input_resolutions: List[dict], + output_resolutions: List[dict], + patch_input_shape: Union[List[int], np.ndarray], + stride_shape: Union[List[int], np.ndarray, Tuple[int]], + patch_output_shape: Union[List[int], np.ndarray], + save_resolution: dict = None, + margin: int = None, + ): + super().__init__( + input_resolutions=input_resolutions, + output_resolutions=output_resolutions, + patch_input_shape=patch_input_shape, + stride_shape=stride_shape, + patch_output_shape=patch_output_shape, + save_resolution=save_resolution, + ) + self.margin = margin + + self._validate() + + class NucleusInstanceSegmentor(SemanticSegmentor): """An engine specifically designed to handle tiles or WSIs inference. diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index 835c91095..9a87bc4d6 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -149,9 +149,8 @@ class IOSegmentorConfig(ModelIOConfigABC): """ - # We pre-define to follow enforcement, actual initialisation in init - input_resolutions = None - output_resolutions = None + patch_output_shape: Union[List[int], np.ndarray] + save_resolution: dict = None def __init__( self, From d0ea031be79b9564a89209a7fa6cb8e067beea30 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 16 Jun 2023 14:10:59 +0100 Subject: [PATCH 012/112] :wrench: Update yaml for `IOInstanceSegmentorConfig` - Update yaml for `IOInstanceSegmentorConfig` Signed-off-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> --- tiatoolbox/data/pretrained_model.yaml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tiatoolbox/data/pretrained_model.yaml b/tiatoolbox/data/pretrained_model.yaml index 2243cebb0..208e0c4c0 100644 --- a/tiatoolbox/data/pretrained_model.yaml +++ b/tiatoolbox/data/pretrained_model.yaml @@ -644,7 +644,7 @@ hovernet_fast-pannuke: num_types: 6 mode: "fast" ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: nucleus_instance_segmentor.IOInstanceSegmentorConfig kwargs: input_resolutions: - {"units": "mpp", "resolution": 0.25} @@ -667,7 +667,7 @@ hovernet_fast-monusac: num_types: 5 mode: "fast" ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: nucleus_instance_segmentor.IOInstanceSegmentorConfig kwargs: input_resolutions: - {"units": "mpp", "resolution": 0.25} @@ -690,7 +690,7 @@ hovernet_original-consep: num_types: 5 mode: "original" ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: nucleus_instance_segmentor.IOInstanceSegmentorConfig kwargs: input_resolutions: - {"units": "mpp", "resolution": 0.25} @@ -713,7 +713,7 @@ hovernet_original-kumar: num_types: null # None in python ?, only do instance segmentation mode: "original" ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: nucleus_instance_segmentor.IOInstanceSegmentorConfig kwargs: input_resolutions: - {"units": "mpp", "resolution": 0.25} @@ -735,7 +735,7 @@ hovernetplus-oed: num_types: 3 num_layers: 5 ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: nucleus_instance_segmentor.IOInstanceSegmentorConfig kwargs: input_resolutions: - {"units": "mpp", "resolution": 0.50} From fa70c40fb24876109bea02798e037a50a23566b2 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 16 Jun 2023 14:17:56 +0100 Subject: [PATCH 013/112] :wrench: Update yaml for `IOInstanceSegmentorConfig` - Update yaml for `IOInstanceSegmentorConfig` Signed-off-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> --- tiatoolbox/models/engine/nucleus_instance_segmentor.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tiatoolbox/models/engine/nucleus_instance_segmentor.py b/tiatoolbox/models/engine/nucleus_instance_segmentor.py index e4b920f65..0b891990a 100644 --- a/tiatoolbox/models/engine/nucleus_instance_segmentor.py +++ b/tiatoolbox/models/engine/nucleus_instance_segmentor.py @@ -318,7 +318,8 @@ class IOInstanceSegmentorConfig(IOSegmentorConfig): Resolution to save all output. margin (int): Tile margin to accumulate the output. - + tile_shape (tuple(int, int)): + Tile shape to process the WSI. Examples: >>> # Defining io for a network having 1 input and 1 output at the @@ -354,6 +355,7 @@ class IOInstanceSegmentorConfig(IOSegmentorConfig): """ margin: int + tile_shape: Tuple[int, int] def __init__( self, @@ -364,6 +366,7 @@ def __init__( patch_output_shape: Union[List[int], np.ndarray], save_resolution: dict = None, margin: int = None, + tile_shape: Tuple[int, int] = None, ): super().__init__( input_resolutions=input_resolutions, @@ -374,6 +377,7 @@ def __init__( save_resolution=save_resolution, ) self.margin = margin + self.tile_shape = tile_shape self._validate() From 70da57ec3ed7e0346d600285fefc8df2652f225e Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 16 Jun 2023 16:06:57 +0100 Subject: [PATCH 014/112] :wrench: Update `to_baseline` for `IOInstanceSegmentorConfig` - Update `to_baseline` for `IOInstanceSegmentorConfig` Signed-off-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> --- tests/models/test_multi_task_segmentor.py | 14 +++++---- tiatoolbox/models/__init__.py | 6 +++- .../models/engine/multi_task_segmentor.py | 6 ++-- .../engine/nucleus_instance_segmentor.py | 30 ++++++++++++++++--- 4 files changed, 43 insertions(+), 13 deletions(-) diff --git a/tests/models/test_multi_task_segmentor.py b/tests/models/test_multi_task_segmentor.py index fd2e06aa4..c0c10d15b 100644 --- a/tests/models/test_multi_task_segmentor.py +++ b/tests/models/test_multi_task_segmentor.py @@ -13,7 +13,11 @@ import numpy as np import pytest -from tiatoolbox.models import IOSegmentorConfig, MultiTaskSegmentor, SemanticSegmentor +from tiatoolbox.models import ( + IOInstanceSegmentorConfig, + MultiTaskSegmentor, + SemanticSegmentor, +) from tiatoolbox.utils import env_detection as toolbox_env from tiatoolbox.utils import imwrite from tiatoolbox.utils.metrics import f1_detection @@ -178,7 +182,7 @@ def test_masked_segmentor(remote_sample, tmp_path): # resolution for travis testing, not the correct ones resolution = 4.0 - ioconfig = IOSegmentorConfig( + ioconfig = IOInstanceSegmentorConfig( input_resolutions=[{"units": "mpp", "resolution": resolution}], output_resolutions=[ {"units": "mpp", "resolution": resolution}, @@ -302,7 +306,7 @@ def test_empty_image(tmp_path): output_types=["semantic"], ) - bcc_wsi_ioconfig = IOSegmentorConfig( + bcc_wsi_ioconfig = IOInstanceSegmentorConfig( input_resolutions=[{"units": "mpp", "resolution": 0.25}], output_resolutions=[{"units": "mpp", "resolution": 0.25}], tile_shape=2048, @@ -350,7 +354,7 @@ def test_functionality_semantic(remote_sample, tmp_path): output_types=["semantic"], ) - bcc_wsi_ioconfig = IOSegmentorConfig( + bcc_wsi_ioconfig = IOInstanceSegmentorConfig( input_resolutions=[{"units": "mpp", "resolution": 0.25}], output_resolutions=[{"units": "mpp", "resolution": 0.25}], tile_shape=2048, @@ -391,7 +395,7 @@ def test_crash_segmentor(remote_sample, tmp_path): # resolution for travis testing, not the correct ones resolution = 4.0 - ioconfig = IOSegmentorConfig( + ioconfig = IOInstanceSegmentorConfig( input_resolutions=[{"units": "mpp", "resolution": resolution}], output_resolutions=[ {"units": "mpp", "resolution": resolution}, diff --git a/tiatoolbox/models/__init__.py b/tiatoolbox/models/__init__.py index a1057495d..b76831085 100644 --- a/tiatoolbox/models/__init__.py +++ b/tiatoolbox/models/__init__.py @@ -9,7 +9,10 @@ from .architecture.nuclick import NuClick from .architecture.sccnn import SCCNN from .engine.multi_task_segmentor import MultiTaskSegmentor -from .engine.nucleus_instance_segmentor import NucleusInstanceSegmentor +from .engine.nucleus_instance_segmentor import ( + IOInstanceSegmentorConfig, + NucleusInstanceSegmentor, +) from .engine.patch_predictor import ( IOPatchPredictorConfig, PatchDataset, @@ -35,4 +38,5 @@ "NucleusInstanceSegmentor", "PatchPredictor", "SemanticSegmentor", + "IOInstanceSegmentorConfig", ] diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 04e84803b..a5f763869 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -59,7 +59,7 @@ def _process_tile_predictions( using the output from each task. Args: - ioconfig (:class:`IOSegmentorConfig`): Object defines information + ioconfig (:class:`IOInstanceSegmentorConfig`): Object defines information about input and output placement of patches. tile_bounds (:class:`numpy.array`): Boundary of the current tile, defined as (top_left_x, top_left_y, bottom_x, bottom_y). @@ -286,8 +286,8 @@ def _predict_one_wsi( Args: wsi_idx (int): Index of the tile/wsi to be processed within `self`. - ioconfig (IOSegmentorConfig): Object which defines I/O placement during - inference and when assembling back to full tile/wsi. + ioconfig (IOInstanceSegmentorConfig): Object which defines I/O placement + during inference and when assembling back to full tile/wsi. loader (torch.Dataloader): The loader object which return batch of data to be input to model. save_path (str): Location to save output prediction as well as possible diff --git a/tiatoolbox/models/engine/nucleus_instance_segmentor.py b/tiatoolbox/models/engine/nucleus_instance_segmentor.py index 0b891990a..be258ae4f 100644 --- a/tiatoolbox/models/engine/nucleus_instance_segmentor.py +++ b/tiatoolbox/models/engine/nucleus_instance_segmentor.py @@ -381,6 +381,28 @@ def __init__( self._validate() + def to_baseline(self): + """Returns a new config object converted to baseline form. + + This will return a new :class:`IOSegmentorConfig` where + resolutions have been converted to baseline format with the + highest possible resolution found in both input and output as + reference. + + """ + new_config = super().to_baseline() + + return IOInstanceSegmentorConfig( + input_resolutions=new_config.input_resolutions, + output_resolutions=new_config.output_resolutions, + patch_input_shape=self.patch_input_shape, + patch_output_shape=self.patch_output_shape, + stride_shape=self.stride_shape, + save_resolution=new_config.save_resolution, + margin=self.margin, + tile_shape=self.tile_shape, + ) + class NucleusInstanceSegmentor(SemanticSegmentor): """An engine specifically designed to handle tiles or WSIs inference. @@ -485,7 +507,7 @@ def __init__( @staticmethod def _get_tile_info( image_shape: Union[List[int], np.ndarray], - ioconfig: IOSegmentorConfig, + ioconfig: IOInstanceSegmentorConfig, ): """Generating tile information. @@ -503,7 +525,7 @@ def _get_tile_info( image_shape (:class:`numpy.ndarray`, list(int)): The shape of WSI to extract the tile from, assumed to be in `[width, height]`. - ioconfig (:obj:IOSegmentorConfig): + ioconfig (:obj:IOInstanceSegmentorConfig): The input and output configuration objects. Returns: @@ -740,7 +762,7 @@ def _infer_once(self): def _predict_one_wsi( self, wsi_idx: int, - ioconfig: IOSegmentorConfig, + ioconfig: IOInstanceSegmentorConfig, save_path: str, mode: str, ): @@ -749,7 +771,7 @@ def _predict_one_wsi( Args: wsi_idx (int): Index of the tile/wsi to be processed within `self`. - ioconfig (IOSegmentorConfig): + ioconfig (IOInstanceSegmentorConfig): Object which defines I/O placement during inference and when assembling back to full tile/wsi. save_path (str): From 17e9a1f36a7f2fb6abd70f8857ee9eaf2adc2931 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 16 Jun 2023 16:34:11 +0100 Subject: [PATCH 015/112] :wrench: Update IOInstanceSegmentorConfig in test_nucleus_instance_segmentor.py - Update IOInstanceSegmentorConfig in test_nucleus_instance_segmentor.py Signed-off-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> --- .../models/test_nucleus_instance_segmentor.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/models/test_nucleus_instance_segmentor.py b/tests/models/test_nucleus_instance_segmentor.py index d7f711dea..8da120df1 100644 --- a/tests/models/test_nucleus_instance_segmentor.py +++ b/tests/models/test_nucleus_instance_segmentor.py @@ -15,7 +15,7 @@ from tiatoolbox import cli from tiatoolbox.models import ( - IOSegmentorConfig, + IOInstanceSegmentorConfig, NucleusInstanceSegmentor, SemanticSegmentor, ) @@ -62,7 +62,7 @@ def helper_tile_info(): # | 12 | 13 | 14 | 15 | # --------------------- # ! assume flag index ordering: left right top bottom - ioconfig = IOSegmentorConfig( + ioconfig = IOInstanceSegmentorConfig( input_resolutions=[{"units": "mpp", "resolution": 0.25}], output_resolutions=[ {"units": "mpp", "resolution": 0.25}, @@ -70,7 +70,7 @@ def helper_tile_info(): {"units": "mpp", "resolution": 0.25}, ], margin=1, - tile_shape=[4, 4], + tile_shape=(4, 4), stride_shape=[4, 4], patch_input_shape=[4, 4], patch_output_shape=[4, 4], @@ -245,7 +245,7 @@ def test_crash_segmentor(remote_sample, tmp_path): # resolution for travis testing, not the correct ones resolution = 4.0 - ioconfig = IOSegmentorConfig( + ioconfig = IOInstanceSegmentorConfig( input_resolutions=[{"units": "mpp", "resolution": resolution}], output_resolutions=[ {"units": "mpp", "resolution": resolution}, @@ -253,7 +253,7 @@ def test_crash_segmentor(remote_sample, tmp_path): {"units": "mpp", "resolution": resolution}, ], margin=128, - tile_shape=[512, 512], + tile_shape=(512, 512), patch_input_shape=[256, 256], patch_output_shape=[164, 164], stride_shape=[164, 164], @@ -297,14 +297,14 @@ def test_functionality_travis(remote_sample, tmp_path): # * test run on wsi, test run with worker # resolution for travis testing, not the correct ones - ioconfig = IOSegmentorConfig( + ioconfig = IOInstanceSegmentorConfig( input_resolutions=[{"units": "mpp", "resolution": resolution}], output_resolutions=[ {"units": "mpp", "resolution": resolution}, {"units": "mpp", "resolution": resolution}, ], margin=128, - tile_shape=[1024, 1024], + tile_shape=(1024, 1024), patch_input_shape=[256, 256], patch_output_shape=[164, 164], stride_shape=[164, 164], @@ -338,7 +338,7 @@ def test_functionality_merge_tile_predictions_travis(remote_sample, tmp_path): mini_wsi_svs = pathlib.Path(remote_sample("wsi4_512_512_svs")) resolution = 0.5 - ioconfig = IOSegmentorConfig( + ioconfig = IOInstanceSegmentorConfig( input_resolutions=[{"units": "mpp", "resolution": resolution}], output_resolutions=[ {"units": "mpp", "resolution": resolution}, @@ -346,7 +346,7 @@ def test_functionality_merge_tile_predictions_travis(remote_sample, tmp_path): {"units": "mpp", "resolution": resolution}, ], margin=128, - tile_shape=[512, 512], + tile_shape=(512, 512), patch_input_shape=[256, 256], patch_output_shape=[164, 164], stride_shape=[164, 164], From a976abee679deb00a3a769cf25d4761e8f685df8 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 16 Jun 2023 22:30:51 +0100 Subject: [PATCH 016/112] :bug: Update cli for IOInstanceSegmentorConfig - Update cli for IOInstanceSegmentorConfig Signed-off-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> --- tiatoolbox/cli/nucleus_instance_segment.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tiatoolbox/cli/nucleus_instance_segment.py b/tiatoolbox/cli/nucleus_instance_segment.py index 4909d4dd6..5d6fff055 100644 --- a/tiatoolbox/cli/nucleus_instance_segment.py +++ b/tiatoolbox/cli/nucleus_instance_segment.py @@ -63,7 +63,7 @@ def nucleus_instance_segment( verbose, ): """Process an image/directory of input images with a patch classification CNN.""" - from tiatoolbox.models import IOSegmentorConfig, NucleusInstanceSegmentor + from tiatoolbox.models import IOInstanceSegmentorConfig, NucleusInstanceSegmentor from tiatoolbox.utils import save_as_json files_all, masks_all, output_path = prepare_model_cli( @@ -74,7 +74,7 @@ def nucleus_instance_segment( ) ioconfig = prepare_ioconfig_seg( - IOSegmentorConfig, pretrained_weights, yaml_config_path + IOInstanceSegmentorConfig, pretrained_weights, yaml_config_path ) predictor = NucleusInstanceSegmentor( From 4330e87cc081006fc1d025574f7c5ddfd68472d4 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 16 Jun 2023 22:33:22 +0100 Subject: [PATCH 017/112] :fire: Remove unnecessary test after kwargs removal - Remove unnecessary test after kwargs removal Signed-off-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> --- tests/models/test_patch_predictor.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/tests/models/test_patch_predictor.py b/tests/models/test_patch_predictor.py index f69274623..5e3ebea2d 100644 --- a/tests/models/test_patch_predictor.py +++ b/tests/models/test_patch_predictor.py @@ -443,19 +443,6 @@ def __getitem__(self, idx): # ------------------------------------------------------------------------------------- -def test_io_patch_predictor_config(): - """Test for IOConfig.""" - # test for creating - cfg = IOPatchPredictorConfig( - patch_input_shape=[224, 224], - stride_shape=[224, 224], - input_resolutions=[{"resolution": 0.5, "units": "mpp"}], - # test adding random kwarg and they should be accessible as kwargs - crop_from_source=True, - ) - assert cfg.crop_from_source - - # ------------------------------------------------------------------------------------- # Engine # ------------------------------------------------------------------------------------- From 1963923bc359a35db01a1435435ad897e6d42154 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 16 Jun 2023 23:17:59 +0100 Subject: [PATCH 018/112] :truck: Move ioconfigs to io_config.py - Move ioconfigs to io_config.py Signed-off-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> --- tests/models/test_feature_extractor.py | 6 +- tests/models/test_semantic_segmentation.py | 7 +- tiatoolbox/data/pretrained_model.yaml | 100 ++--- tiatoolbox/models/__init__.py | 16 +- tiatoolbox/models/engine/engine_abc.py | 125 ------ tiatoolbox/models/engine/io_config.py | 361 ++++++++++++++++++ .../models/engine/multi_task_segmentor.py | 2 +- .../engine/nucleus_instance_segmentor.py | 110 +----- tiatoolbox/models/engine/patch_predictor.py | 21 +- .../models/engine/semantic_segmentor.py | 117 +----- 10 files changed, 429 insertions(+), 436 deletions(-) create mode 100644 tiatoolbox/models/engine/io_config.py diff --git a/tests/models/test_feature_extractor.py b/tests/models/test_feature_extractor.py index a38eddaab..d7b89a02a 100644 --- a/tests/models/test_feature_extractor.py +++ b/tests/models/test_feature_extractor.py @@ -7,11 +7,9 @@ import numpy as np import torch +from tiatoolbox.models import IOSegmentorConfig from tiatoolbox.models.architecture.vanilla import CNNBackbone -from tiatoolbox.models.engine.semantic_segmentor import ( - DeepFeatureExtractor, - IOSegmentorConfig, -) +from tiatoolbox.models.engine.semantic_segmentor import DeepFeatureExtractor from tiatoolbox.utils import env_detection as toolbox_env from tiatoolbox.wsicore.wsireader import WSIReader diff --git a/tests/models/test_semantic_segmentation.py b/tests/models/test_semantic_segmentation.py index cde7221e3..586ce806e 100644 --- a/tests/models/test_semantic_segmentation.py +++ b/tests/models/test_semantic_segmentation.py @@ -19,13 +19,10 @@ from click.testing import CliRunner from tiatoolbox import cli -from tiatoolbox.models import SemanticSegmentor +from tiatoolbox.models import IOSegmentorConfig, SemanticSegmentor from tiatoolbox.models.architecture import fetch_pretrained_weights from tiatoolbox.models.architecture.utils import centre_crop -from tiatoolbox.models.engine.semantic_segmentor import ( - IOSegmentorConfig, - WSIStreamDataset, -) +from tiatoolbox.models.engine.semantic_segmentor import WSIStreamDataset from tiatoolbox.models.models_abc import ModelABC from tiatoolbox.utils import env_detection as toolbox_env from tiatoolbox.utils import imread, imwrite diff --git a/tiatoolbox/data/pretrained_model.yaml b/tiatoolbox/data/pretrained_model.yaml index 208e0c4c0..6caf4a35b 100644 --- a/tiatoolbox/data/pretrained_model.yaml +++ b/tiatoolbox/data/pretrained_model.yaml @@ -6,7 +6,7 @@ alexnet-kather100k: backbone: alexnet num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -20,7 +20,7 @@ resnet18-kather100k: backbone: resnet18 num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -34,7 +34,7 @@ resnet34-kather100k: backbone: resnet34 num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -48,7 +48,7 @@ resnet50-kather100k: backbone: resnet50 num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -62,7 +62,7 @@ resnet101-kather100k: backbone: resnet101 num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -76,7 +76,7 @@ resnext50_32x4d-kather100k: backbone: resnext50_32x4d num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -90,7 +90,7 @@ resnext101_32x8d-kather100k: backbone: resnext101_32x8d num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -104,7 +104,7 @@ wide_resnet50_2-kather100k: backbone: wide_resnet50_2 num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -118,7 +118,7 @@ wide_resnet101_2-kather100k: backbone: wide_resnet101_2 num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -132,7 +132,7 @@ densenet121-kather100k: backbone: densenet121 num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -146,7 +146,7 @@ densenet161-kather100k: backbone: densenet161 num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -160,7 +160,7 @@ densenet169-kather100k: backbone: densenet169 num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -174,7 +174,7 @@ densenet201-kather100k: backbone: densenet201 num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -188,7 +188,7 @@ mobilenet_v2-kather100k: backbone: mobilenet_v2 num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -202,7 +202,7 @@ mobilenet_v3_large-kather100k: backbone: mobilenet_v3_large num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -216,7 +216,7 @@ mobilenet_v3_small-kather100k: backbone: mobilenet_v3_small num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -230,7 +230,7 @@ googlenet-kather100k: backbone: googlenet num_classes: 9 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -245,7 +245,7 @@ alexnet-pcam: backbone: alexnet num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -259,7 +259,7 @@ resnet18-pcam: backbone: resnet18 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -273,7 +273,7 @@ resnet34-pcam: backbone: resnet34 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -287,7 +287,7 @@ resnet50-pcam: backbone: resnet50 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -301,7 +301,7 @@ resnet101-pcam: backbone: resnet101 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -315,7 +315,7 @@ resnext50_32x4d-pcam: backbone: resnext50_32x4d num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -329,7 +329,7 @@ resnext101_32x8d-pcam: backbone: resnext101_32x8d num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -343,7 +343,7 @@ wide_resnet50_2-pcam: backbone: wide_resnet50_2 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -357,7 +357,7 @@ wide_resnet101_2-pcam: backbone: wide_resnet101_2 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -371,7 +371,7 @@ densenet121-pcam: backbone: densenet121 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -385,7 +385,7 @@ densenet161-pcam: backbone: densenet161 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -399,7 +399,7 @@ densenet169-pcam: backbone: densenet169 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -413,7 +413,7 @@ densenet201-pcam: backbone: densenet201 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -427,7 +427,7 @@ mobilenet_v2-pcam: backbone: mobilenet_v2 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -441,7 +441,7 @@ mobilenet_v3_large-pcam: backbone: mobilenet_v3_large num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -455,7 +455,7 @@ mobilenet_v3_small-pcam: backbone: mobilenet_v3_small num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -469,7 +469,7 @@ googlenet-pcam: backbone: googlenet num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [96, 96] stride_shape: [96, 96] @@ -484,7 +484,7 @@ resnet18-idars-tumour: backbone: resnet18 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [512, 512] stride_shape: [512, 512] @@ -497,7 +497,7 @@ resnet34-idars-msi: backbone: resnet34 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -510,7 +510,7 @@ resnet34-idars-braf: backbone: resnet34 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -523,7 +523,7 @@ resnet34-idars-cimp: backbone: resnet34 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -536,7 +536,7 @@ resnet34-idars-cin: backbone: resnet34 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -549,7 +549,7 @@ resnet34-idars-tp53: backbone: resnet34 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -562,7 +562,7 @@ resnet34-idars-hm: backbone: resnet34 num_classes: 2 ioconfig: - class: patch_predictor.IOPatchPredictorConfig + class: io_config.IOPatchPredictorConfig kwargs: patch_input_shape: [224, 224] stride_shape: [224, 224] @@ -579,7 +579,7 @@ fcn-tissue_mask: encoder: "resnet50" decoder_block: [3] ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: io_config.IOSegmentorConfig kwargs: input_resolutions: - {'units': 'mpp', 'resolution': 2.0} @@ -600,7 +600,7 @@ fcn_resnet50_unet-bcss: encoder: "resnet50" decoder_block: [3, 3] ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: io_config.IOSegmentorConfig kwargs: input_resolutions: - {'units': 'mpp', 'resolution': 0.25} @@ -625,7 +625,7 @@ unet_tissue_mask_tsef: encoder: "resnet50" decoder_block: [3, 3] ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: io_config.IOSegmentorConfig kwargs: input_resolutions: - {'units': 'baseline', 'resolution': 1.0} @@ -644,7 +644,7 @@ hovernet_fast-pannuke: num_types: 6 mode: "fast" ioconfig: - class: nucleus_instance_segmentor.IOInstanceSegmentorConfig + class: io_config.IOInstanceSegmentorConfig kwargs: input_resolutions: - {"units": "mpp", "resolution": 0.25} @@ -667,7 +667,7 @@ hovernet_fast-monusac: num_types: 5 mode: "fast" ioconfig: - class: nucleus_instance_segmentor.IOInstanceSegmentorConfig + class: io_config.IOInstanceSegmentorConfig kwargs: input_resolutions: - {"units": "mpp", "resolution": 0.25} @@ -690,7 +690,7 @@ hovernet_original-consep: num_types: 5 mode: "original" ioconfig: - class: nucleus_instance_segmentor.IOInstanceSegmentorConfig + class: io_config.IOInstanceSegmentorConfig kwargs: input_resolutions: - {"units": "mpp", "resolution": 0.25} @@ -713,7 +713,7 @@ hovernet_original-kumar: num_types: null # None in python ?, only do instance segmentation mode: "original" ioconfig: - class: nucleus_instance_segmentor.IOInstanceSegmentorConfig + class: io_config.IOInstanceSegmentorConfig kwargs: input_resolutions: - {"units": "mpp", "resolution": 0.25} @@ -735,7 +735,7 @@ hovernetplus-oed: num_types: 3 num_layers: 5 ioconfig: - class: nucleus_instance_segmentor.IOInstanceSegmentorConfig + class: io_config.IOInstanceSegmentorConfig kwargs: input_resolutions: - {"units": "mpp", "resolution": 0.50} @@ -759,7 +759,7 @@ micronet-consep: num_input_channels: 3 num_output_channels: 2 ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: io_config.IOSegmentorConfig kwargs: input_resolutions: - {"units": "mpp", "resolution": 0.25} diff --git a/tiatoolbox/models/__init__.py b/tiatoolbox/models/__init__.py index b76831085..76c4389ab 100644 --- a/tiatoolbox/models/__init__.py +++ b/tiatoolbox/models/__init__.py @@ -8,20 +8,16 @@ from .architecture.micronet import MicroNet from .architecture.nuclick import NuClick from .architecture.sccnn import SCCNN -from .engine.multi_task_segmentor import MultiTaskSegmentor -from .engine.nucleus_instance_segmentor import ( +from .engine.io_config import ( IOInstanceSegmentorConfig, - NucleusInstanceSegmentor, -) -from .engine.patch_predictor import ( IOPatchPredictorConfig, - PatchDataset, - PatchPredictor, - WSIPatchDataset, + IOSegmentorConfig, ) +from .engine.multi_task_segmentor import MultiTaskSegmentor +from .engine.nucleus_instance_segmentor import NucleusInstanceSegmentor +from .engine.patch_predictor import PatchDataset, PatchPredictor, WSIPatchDataset from .engine.semantic_segmentor import ( DeepFeatureExtractor, - IOSegmentorConfig, SemanticSegmentor, WSIStreamDataset, ) @@ -38,5 +34,7 @@ "NucleusInstanceSegmentor", "PatchPredictor", "SemanticSegmentor", + "IOPatchPredictorConfig", + "IOSegmentorConfig", "IOInstanceSegmentorConfig", ] diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 2ee2d5707..661c882f2 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -1,129 +1,4 @@ from abc import ABC, abstractmethod -from dataclasses import dataclass, field -from typing import List, Tuple, Union - -import numpy as np - -from tiatoolbox.wsicore.wsimeta import Units - - -@dataclass -class ModelIOConfigABC: - """Defines a data class for holding a deep learning model's I/O information. - - Enforcing such that following attributes must always be defined by - the subclass. - - Args: - input_resolutions (list(dict)): - Resolution of each input head of model inference, must be in - the same order as `target model.forward()`. - stride_shape (:class:`numpy.ndarray`, list(int), tuple(int)): - Stride in (x, y) direction for patch extraction. - patch_input_shape (:class:`numpy.ndarray`, list(int), tuple(int)): - Shape of the largest input in (height, width). - - """ - - input_resolutions: List[dict] - patch_input_shape: Union[List[int], np.ndarray, Tuple[int]] - stride_shape: Union[List[int], np.ndarray, Tuple[int]] - highest_input_resolution: dict - output_resolutions: List[dict] = field(default_factory=list) - resolution_unit: Units = "mpp" - - def __init__( - self, - input_resolutions: List[dict], - patch_input_shape: Union[List[int], np.ndarray, Tuple[int]], - stride_shape: Union[List[int], np.ndarray, Tuple[int]], - ): - self.patch_input_shape = patch_input_shape - self.stride_shape = stride_shape - self.input_resolutions = input_resolutions - self.output_resolutions = [] - self.resolution_unit = input_resolutions[0]["units"] - - if self.resolution_unit == "mpp": - self.highest_input_resolution = min( - self.input_resolutions, key=lambda x: x["resolution"] - ) - else: - self.highest_input_resolution = max( - self.input_resolutions, key=lambda x: x["resolution"] - ) - - def _validate(self): - """Validate the data format.""" - resolutions = self.input_resolutions + self.output_resolutions - units = [v["units"] for v in resolutions] - units = np.unique(units) - if len(units) != 1 or units[0] not in [ - "power", - "baseline", - "mpp", - ]: - raise ValueError(f"Invalid resolution units `{units[0]}`.") - - @staticmethod - def scale_to_highest(resolutions: List[dict], units: Units): - """Get the scaling factor from input resolutions. - - This will convert resolutions to a scaling factor with respect to - the highest resolution found in the input resolutions list. - - Args: - resolutions (list): - A list of resolutions where one is defined as - `{'resolution': value, 'unit': value}` - units (Units): - Resolution units. - - Returns: - :class:`numpy.ndarray`: - A 1D array of scaling factors having the same length as - `resolutions` - - """ - old_val = [v["resolution"] for v in resolutions] - if units not in ["baseline", "mpp", "power"]: - raise ValueError( - f"Unknown units `{units}`. " - "Units should be one of 'baseline', 'mpp' or 'power'." - ) - if units == "baseline": - return old_val - if units == "mpp": - return np.min(old_val) / np.array(old_val) - return np.array(old_val) / np.max(old_val) - - def to_baseline(self): - """Returns a new config object converted to baseline form. - - This will return a new :class:`ModelIOConfigABC` where - resolutions have been converted to baseline format with the - highest possible resolution found in both input and output as - reference. - - """ - resolutions = self.input_resolutions + self.output_resolutions - save_resolution = getattr(self, "save_resolution", None) - if save_resolution is not None: - resolutions.append(save_resolution) - - scale_factors = self.scale_to_highest(resolutions, self.resolution_unit) - num_input_resolutions = len(self.input_resolutions) - - end_idx = num_input_resolutions - input_resolutions = [ - {"units": "baseline", "resolution": v} for v in scale_factors[:end_idx] - ] - - return ModelIOConfigABC( - input_resolutions=input_resolutions, - patch_input_shape=self.patch_input_shape, - stride_shape=self.stride_shape, - ) class EngineABC(ABC): diff --git a/tiatoolbox/models/engine/io_config.py b/tiatoolbox/models/engine/io_config.py new file mode 100644 index 000000000..ee70786e9 --- /dev/null +++ b/tiatoolbox/models/engine/io_config.py @@ -0,0 +1,361 @@ +from dataclasses import dataclass, field +from typing import List, Tuple, Union + +import numpy as np + +from tiatoolbox.wsicore.wsimeta import Units + + +@dataclass +class ModelIOConfigABC: + """Defines a data class for holding a deep learning model's I/O information. + + Enforcing such that following attributes must always be defined by + the subclass. + + Args: + input_resolutions (list(dict)): + Resolution of each input head of model inference, must be in + the same order as `target model.forward()`. + stride_shape (:class:`numpy.ndarray`, list(int), tuple(int)): + Stride in (x, y) direction for patch extraction. + patch_input_shape (:class:`numpy.ndarray`, list(int), tuple(int)): + Shape of the largest input in (height, width). + + """ + + input_resolutions: List[dict] + patch_input_shape: Union[List[int], np.ndarray, Tuple[int]] + stride_shape: Union[List[int], np.ndarray, Tuple[int]] + highest_input_resolution: dict + output_resolutions: List[dict] = field(default_factory=list) + resolution_unit: Units = "mpp" + + def __init__( + self, + input_resolutions: List[dict], + patch_input_shape: Union[List[int], np.ndarray, Tuple[int]], + stride_shape: Union[List[int], np.ndarray, Tuple[int]], + ): + self.patch_input_shape = patch_input_shape + self.stride_shape = stride_shape + self.input_resolutions = input_resolutions + self.output_resolutions = [] + self.resolution_unit = input_resolutions[0]["units"] + + if self.resolution_unit == "mpp": + self.highest_input_resolution = min( + self.input_resolutions, key=lambda x: x["resolution"] + ) + else: + self.highest_input_resolution = max( + self.input_resolutions, key=lambda x: x["resolution"] + ) + + def _validate(self): + """Validate the data format.""" + resolutions = self.input_resolutions + self.output_resolutions + units = [v["units"] for v in resolutions] + units = np.unique(units) + if len(units) != 1 or units[0] not in [ + "power", + "baseline", + "mpp", + ]: + raise ValueError(f"Invalid resolution units `{units[0]}`.") + + @staticmethod + def scale_to_highest(resolutions: List[dict], units: Units): + """Get the scaling factor from input resolutions. + + This will convert resolutions to a scaling factor with respect to + the highest resolution found in the input resolutions list. + + Args: + resolutions (list): + A list of resolutions where one is defined as + `{'resolution': value, 'unit': value}` + units (Units): + Resolution units. + + Returns: + :class:`numpy.ndarray`: + A 1D array of scaling factors having the same length as + `resolutions` + + """ + old_val = [v["resolution"] for v in resolutions] + if units not in ["baseline", "mpp", "power"]: + raise ValueError( + f"Unknown units `{units}`. " + "Units should be one of 'baseline', 'mpp' or 'power'." + ) + if units == "baseline": + return old_val + if units == "mpp": + return np.min(old_val) / np.array(old_val) + return np.array(old_val) / np.max(old_val) + + def to_baseline(self): + """Returns a new config object converted to baseline form. + + This will return a new :class:`ModelIOConfigABC` where + resolutions have been converted to baseline format with the + highest possible resolution found in both input and output as + reference. + + """ + resolutions = self.input_resolutions + self.output_resolutions + save_resolution = getattr(self, "save_resolution", None) + if save_resolution is not None: + resolutions.append(save_resolution) + + scale_factors = self.scale_to_highest(resolutions, self.resolution_unit) + num_input_resolutions = len(self.input_resolutions) + + end_idx = num_input_resolutions + input_resolutions = [ + {"units": "baseline", "resolution": v} for v in scale_factors[:end_idx] + ] + + return ModelIOConfigABC( + input_resolutions=input_resolutions, + patch_input_shape=self.patch_input_shape, + stride_shape=self.stride_shape, + ) + + +class IOSegmentorConfig(ModelIOConfigABC): + """Contain semantic segmentor input and output information. + + Args: + input_resolutions (list): + Resolution of each input head of model inference, must be in + the same order as `target model.forward()`. + output_resolutions (list): + Resolution of each output head from model inference, must be + in the same order as target model.infer_batch(). + stride_shape (:class:`numpy.ndarray`, list(int)): + Stride in (x, y) direction for patch extraction. + patch_input_shape (:class:`numpy.ndarray`, list(int)): + Shape of the largest input in (height, width). + patch_output_shape (:class:`numpy.ndarray`, list(int)): + Shape of the largest output in (height, width). + save_resolution (dict): + Resolution to save all output. + + Examples: + >>> # Defining io for a network having 1 input and 1 output at the + >>> # same resolution + >>> ioconfig = IOSegmentorConfig( + ... input_resolutions=[{"units": "baseline", "resolution": 1.0}], + ... output_resolutions=[{"units": "baseline", "resolution": 1.0}], + ... patch_input_shape=[2048, 2048], + ... patch_output_shape=[1024, 1024], + ... stride_shape=[512, 512], + ... ) + + Examples: + >>> # Defining io for a network having 3 input and 2 output + >>> # at the same resolution, the output is then merged at a + >>> # different resolution. + >>> ioconfig = IOSegmentorConfig( + ... input_resolutions=[ + ... {"units": "mpp", "resolution": 0.25}, + ... {"units": "mpp", "resolution": 0.50}, + ... {"units": "mpp", "resolution": 0.75}, + ... ], + ... output_resolutions=[ + ... {"units": "mpp", "resolution": 0.25}, + ... {"units": "mpp", "resolution": 0.50}, + ... ], + ... patch_input_shape=[2048, 2048], + ... patch_output_shape=[1024, 1024], + ... stride_shape=[512, 512], + ... save_resolution={"units": "mpp", "resolution": 4.0}, + ... ) + + """ + + patch_output_shape: Union[List[int], np.ndarray] + save_resolution: dict = None + + def __init__( + self, + input_resolutions: List[dict], + output_resolutions: List[dict], + patch_input_shape: Union[List[int], np.ndarray], + stride_shape: Union[List[int], np.ndarray, Tuple[int]], + patch_output_shape: Union[List[int], np.ndarray], + save_resolution: dict = None, + ): + super().__init__( + input_resolutions=input_resolutions, + patch_input_shape=patch_input_shape, + stride_shape=stride_shape, + ) + self.patch_output_shape = patch_output_shape + self.output_resolutions = output_resolutions + self.save_resolution = save_resolution + + self._validate() + + def to_baseline(self): + """Returns a new config object converted to baseline form. + + This will return a new :class:`IOSegmentorConfig` where + resolutions have been converted to baseline format with the + highest possible resolution found in both input and output as + reference. + + """ + new_config = super().to_baseline() + resolutions = self.input_resolutions + self.output_resolutions + if self.save_resolution is not None: + resolutions.append(self.save_resolution) + + scale_factors = self.scale_to_highest(resolutions, self.resolution_unit) + num_input_resolutions = len(self.input_resolutions) + num_output_resolutions = len(self.output_resolutions) + + end_idx = num_input_resolutions + num_output_resolutions + output_resolutions = [ + {"units": "baseline", "resolution": v} + for v in scale_factors[num_input_resolutions:end_idx] + ] + + save_resolution = None + if self.save_resolution is not None: + save_resolution = {"units": "baseline", "resolution": scale_factors[-1]} + + return IOSegmentorConfig( + input_resolutions=new_config.input_resolutions, + output_resolutions=output_resolutions, + patch_input_shape=self.patch_input_shape, + patch_output_shape=self.patch_output_shape, + stride_shape=self.stride_shape, + save_resolution=save_resolution, + ) + + +class IOPatchPredictorConfig(ModelIOConfigABC): + """Contains patch predictor input and output information.""" + + def __init__( + self, + input_resolutions=None, + patch_input_shape=None, + stride_shape=None, + ): + stride_shape = patch_input_shape if stride_shape is None else stride_shape + super().__init__( + input_resolutions=input_resolutions, + stride_shape=stride_shape, + patch_input_shape=patch_input_shape, + ) + + +class IOInstanceSegmentorConfig(IOSegmentorConfig): + """Contain instance segmentor input and output information. + + Args: + input_resolutions (list): + Resolution of each input head of model inference, must be in + the same order as `target model.forward()`. + output_resolutions (list): + Resolution of each output head from model inference, must be + in the same order as target model.infer_batch(). + stride_shape (:class:`numpy.ndarray`, list(int)): + Stride in (x, y) direction for patch extraction. + patch_input_shape (:class:`numpy.ndarray`, list(int)): + Shape of the largest input in (height, width). + patch_output_shape (:class:`numpy.ndarray`, list(int)): + Shape of the largest output in (height, width). + save_resolution (dict): + Resolution to save all output. + margin (int): + Tile margin to accumulate the output. + tile_shape (tuple(int, int)): + Tile shape to process the WSI. + + Examples: + >>> # Defining io for a network having 1 input and 1 output at the + >>> # same resolution + >>> ioconfig = IOSegmentorConfig( + ... input_resolutions=[{"units": "baseline", "resolution": 1.0}], + ... output_resolutions=[{"units": "baseline", "resolution": 1.0}], + ... patch_input_shape=[2048, 2048], + ... patch_output_shape=[1024, 1024], + ... stride_shape=[512, 512], + ... ) + + Examples: + >>> # Defining io for a network having 3 input and 2 output + >>> # at the same resolution, the output is then merged at a + >>> # different resolution. + >>> ioconfig = IOSegmentorConfig( + ... input_resolutions=[ + ... {"units": "mpp", "resolution": 0.25}, + ... {"units": "mpp", "resolution": 0.50}, + ... {"units": "mpp", "resolution": 0.75}, + ... ], + ... output_resolutions=[ + ... {"units": "mpp", "resolution": 0.25}, + ... {"units": "mpp", "resolution": 0.50}, + ... ], + ... patch_input_shape=[2048, 2048], + ... patch_output_shape=[1024, 1024], + ... stride_shape=[512, 512], + ... save_resolution={"units": "mpp", "resolution": 4.0}, + ... ) + + """ + + margin: int + tile_shape: Tuple[int, int] + + def __init__( + self, + input_resolutions: List[dict], + output_resolutions: List[dict], + patch_input_shape: Union[List[int], np.ndarray], + stride_shape: Union[List[int], np.ndarray, Tuple[int]], + patch_output_shape: Union[List[int], np.ndarray], + save_resolution: dict = None, + margin: int = None, + tile_shape: Tuple[int, int] = None, + ): + super().__init__( + input_resolutions=input_resolutions, + output_resolutions=output_resolutions, + patch_input_shape=patch_input_shape, + stride_shape=stride_shape, + patch_output_shape=patch_output_shape, + save_resolution=save_resolution, + ) + self.margin = margin + self.tile_shape = tile_shape + + self._validate() + + def to_baseline(self): + """Returns a new config object converted to baseline form. + + This will return a new :class:`IOSegmentorConfig` where + resolutions have been converted to baseline format with the + highest possible resolution found in both input and output as + reference. + + """ + new_config = super().to_baseline() + + return IOInstanceSegmentorConfig( + input_resolutions=new_config.input_resolutions, + output_resolutions=new_config.output_resolutions, + patch_input_shape=self.patch_input_shape, + patch_output_shape=self.patch_output_shape, + stride_shape=self.stride_shape, + save_resolution=new_config.save_resolution, + margin=self.margin, + tile_shape=self.tile_shape, + ) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index a5f763869..6c0654960 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -35,7 +35,7 @@ _process_instance_predictions, ) -from .nucleus_instance_segmentor import IOInstanceSegmentorConfig +from .. import IOInstanceSegmentorConfig from .semantic_segmentor import WSIStreamDataset diff --git a/tiatoolbox/models/engine/nucleus_instance_segmentor.py b/tiatoolbox/models/engine/nucleus_instance_segmentor.py index be258ae4f..a389a5932 100644 --- a/tiatoolbox/models/engine/nucleus_instance_segmentor.py +++ b/tiatoolbox/models/engine/nucleus_instance_segmentor.py @@ -2,7 +2,7 @@ import uuid from collections import deque -from typing import Callable, List, Tuple, Union +from typing import Callable, List, Union # replace with the sql database once the PR in place import joblib @@ -12,8 +12,8 @@ from shapely.geometry import box as shapely_box from shapely.strtree import STRtree +from tiatoolbox.models import IOInstanceSegmentorConfig from tiatoolbox.models.engine.semantic_segmentor import ( - IOSegmentorConfig, SemanticSegmentor, WSIStreamDataset, ) @@ -298,112 +298,6 @@ def _process_tile_predictions( return new_inst_dict, remove_insts_in_orig -class IOInstanceSegmentorConfig(IOSegmentorConfig): - """Contain instance segmentor input and output information. - - Args: - input_resolutions (list): - Resolution of each input head of model inference, must be in - the same order as `target model.forward()`. - output_resolutions (list): - Resolution of each output head from model inference, must be - in the same order as target model.infer_batch(). - stride_shape (:class:`numpy.ndarray`, list(int)): - Stride in (x, y) direction for patch extraction. - patch_input_shape (:class:`numpy.ndarray`, list(int)): - Shape of the largest input in (height, width). - patch_output_shape (:class:`numpy.ndarray`, list(int)): - Shape of the largest output in (height, width). - save_resolution (dict): - Resolution to save all output. - margin (int): - Tile margin to accumulate the output. - tile_shape (tuple(int, int)): - Tile shape to process the WSI. - - Examples: - >>> # Defining io for a network having 1 input and 1 output at the - >>> # same resolution - >>> ioconfig = IOSegmentorConfig( - ... input_resolutions=[{"units": "baseline", "resolution": 1.0}], - ... output_resolutions=[{"units": "baseline", "resolution": 1.0}], - ... patch_input_shape=[2048, 2048], - ... patch_output_shape=[1024, 1024], - ... stride_shape=[512, 512], - ... ) - - Examples: - >>> # Defining io for a network having 3 input and 2 output - >>> # at the same resolution, the output is then merged at a - >>> # different resolution. - >>> ioconfig = IOSegmentorConfig( - ... input_resolutions=[ - ... {"units": "mpp", "resolution": 0.25}, - ... {"units": "mpp", "resolution": 0.50}, - ... {"units": "mpp", "resolution": 0.75}, - ... ], - ... output_resolutions=[ - ... {"units": "mpp", "resolution": 0.25}, - ... {"units": "mpp", "resolution": 0.50}, - ... ], - ... patch_input_shape=[2048, 2048], - ... patch_output_shape=[1024, 1024], - ... stride_shape=[512, 512], - ... save_resolution={"units": "mpp", "resolution": 4.0}, - ... ) - - """ - - margin: int - tile_shape: Tuple[int, int] - - def __init__( - self, - input_resolutions: List[dict], - output_resolutions: List[dict], - patch_input_shape: Union[List[int], np.ndarray], - stride_shape: Union[List[int], np.ndarray, Tuple[int]], - patch_output_shape: Union[List[int], np.ndarray], - save_resolution: dict = None, - margin: int = None, - tile_shape: Tuple[int, int] = None, - ): - super().__init__( - input_resolutions=input_resolutions, - output_resolutions=output_resolutions, - patch_input_shape=patch_input_shape, - stride_shape=stride_shape, - patch_output_shape=patch_output_shape, - save_resolution=save_resolution, - ) - self.margin = margin - self.tile_shape = tile_shape - - self._validate() - - def to_baseline(self): - """Returns a new config object converted to baseline form. - - This will return a new :class:`IOSegmentorConfig` where - resolutions have been converted to baseline format with the - highest possible resolution found in both input and output as - reference. - - """ - new_config = super().to_baseline() - - return IOInstanceSegmentorConfig( - input_resolutions=new_config.input_resolutions, - output_resolutions=new_config.output_resolutions, - patch_input_shape=self.patch_input_shape, - patch_output_shape=self.patch_output_shape, - stride_shape=self.stride_shape, - save_resolution=new_config.save_resolution, - margin=self.margin, - tile_shape=self.tile_shape, - ) - - class NucleusInstanceSegmentor(SemanticSegmentor): """An engine specifically designed to handle tiles or WSIs inference. diff --git a/tiatoolbox/models/engine/patch_predictor.py b/tiatoolbox/models/engine/patch_predictor.py index 5ece0bb87..c697cf494 100644 --- a/tiatoolbox/models/engine/patch_predictor.py +++ b/tiatoolbox/models/engine/patch_predictor.py @@ -17,24 +17,7 @@ from tiatoolbox.wsicore.wsimeta import Resolution, Units from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIReader -from .engine_abc import ModelIOConfigABC - - -class IOPatchPredictorConfig(ModelIOConfigABC): - """Contains patch predictor input and output information.""" - - def __init__( - self, - input_resolutions=None, - patch_input_shape=None, - stride_shape=None, - ): - stride_shape = patch_input_shape if stride_shape is None else stride_shape - super().__init__( - input_resolutions=input_resolutions, - stride_shape=stride_shape, - patch_input_shape=patch_input_shape, - ) +from .. import IOPatchPredictorConfig class PatchPredictor: @@ -457,7 +440,7 @@ def _update_ioconfig( """ Args: - ioconfig (IOPatchPredictorConfig): + ioconfig (tiatoolbox.models.IOPatchPredictorConfig): patch_input_shape (tuple): Size of patches input to the model. Patches are at requested read resolution, not with respect to level 0, diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index 9a87bc4d6..2c1ec6c17 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -25,7 +25,7 @@ from tiatoolbox.wsicore.wsimeta import Resolution, Units from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIMeta, WSIReader -from .engine_abc import ModelIOConfigABC +from .. import IOSegmentorConfig def _estimate_canvas_parameters(sample_prediction, canvas_shape): @@ -97,119 +97,6 @@ def _prepare_save_output( return is_on_drive, count_canvas, cum_canvas -class IOSegmentorConfig(ModelIOConfigABC): - """Contain semantic segmentor input and output information. - - Args: - input_resolutions (list): - Resolution of each input head of model inference, must be in - the same order as `target model.forward()`. - output_resolutions (list): - Resolution of each output head from model inference, must be - in the same order as target model.infer_batch(). - stride_shape (:class:`numpy.ndarray`, list(int)): - Stride in (x, y) direction for patch extraction. - patch_input_shape (:class:`numpy.ndarray`, list(int)): - Shape of the largest input in (height, width). - patch_output_shape (:class:`numpy.ndarray`, list(int)): - Shape of the largest output in (height, width). - save_resolution (dict): - Resolution to save all output. - - Examples: - >>> # Defining io for a network having 1 input and 1 output at the - >>> # same resolution - >>> ioconfig = IOSegmentorConfig( - ... input_resolutions=[{"units": "baseline", "resolution": 1.0}], - ... output_resolutions=[{"units": "baseline", "resolution": 1.0}], - ... patch_input_shape=[2048, 2048], - ... patch_output_shape=[1024, 1024], - ... stride_shape=[512, 512], - ... ) - - Examples: - >>> # Defining io for a network having 3 input and 2 output - >>> # at the same resolution, the output is then merged at a - >>> # different resolution. - >>> ioconfig = IOSegmentorConfig( - ... input_resolutions=[ - ... {"units": "mpp", "resolution": 0.25}, - ... {"units": "mpp", "resolution": 0.50}, - ... {"units": "mpp", "resolution": 0.75}, - ... ], - ... output_resolutions=[ - ... {"units": "mpp", "resolution": 0.25}, - ... {"units": "mpp", "resolution": 0.50}, - ... ], - ... patch_input_shape=[2048, 2048], - ... patch_output_shape=[1024, 1024], - ... stride_shape=[512, 512], - ... save_resolution={"units": "mpp", "resolution": 4.0}, - ... ) - - """ - - patch_output_shape: Union[List[int], np.ndarray] - save_resolution: dict = None - - def __init__( - self, - input_resolutions: List[dict], - output_resolutions: List[dict], - patch_input_shape: Union[List[int], np.ndarray], - stride_shape: Union[List[int], np.ndarray, Tuple[int]], - patch_output_shape: Union[List[int], np.ndarray], - save_resolution: dict = None, - ): - super().__init__( - input_resolutions=input_resolutions, - patch_input_shape=patch_input_shape, - stride_shape=stride_shape, - ) - self.patch_output_shape = patch_output_shape - self.output_resolutions = output_resolutions - self.save_resolution = save_resolution - - self._validate() - - def to_baseline(self): - """Returns a new config object converted to baseline form. - - This will return a new :class:`IOSegmentorConfig` where - resolutions have been converted to baseline format with the - highest possible resolution found in both input and output as - reference. - - """ - new_config = super().to_baseline() - resolutions = self.input_resolutions + self.output_resolutions - if self.save_resolution is not None: - resolutions.append(self.save_resolution) - - scale_factors = self.scale_to_highest(resolutions, self.resolution_unit) - num_input_resolutions = len(self.input_resolutions) - num_output_resolutions = len(self.output_resolutions) - - end_idx = num_input_resolutions + num_output_resolutions - output_resolutions = [ - {"units": "baseline", "resolution": v} - for v in scale_factors[num_input_resolutions:end_idx] - ] - - save_resolution = None - if self.save_resolution is not None: - save_resolution = {"units": "baseline", "resolution": scale_factors[-1]} - - return IOSegmentorConfig( - input_resolutions=new_config.input_resolutions, - output_resolutions=output_resolutions, - patch_input_shape=self.patch_input_shape, - patch_output_shape=self.patch_output_shape, - stride_shape=self.stride_shape, - save_resolution=save_resolution, - ) - - class WSIStreamDataset(torch_data.Dataset): """Reading a wsi in parallel mode with persistent workers. @@ -225,7 +112,7 @@ class WSIStreamDataset(torch_data.Dataset): mp_shared_space (:class:`Namespace`): A shared multiprocessing space, must be from `torch.multiprocessing`. - ioconfig (:class:`IOSegmentorConfig`): + ioconfig (:class:`tiatoolbox.models.IOSegmentorConfig`): An object which contains I/O placement for patches. wsi_paths (list): List of paths pointing to a WSI or tiles. preproc (Callable): From a8725a3c1117978f115305a1e3e154e85fc7a354 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 16 Jun 2023 23:30:16 +0100 Subject: [PATCH 019/112] :bug: Fix circular import - Fix circular import Signed-off-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> --- tiatoolbox/models/engine/nucleus_instance_segmentor.py | 3 ++- tiatoolbox/models/engine/patch_predictor.py | 2 +- tiatoolbox/models/engine/semantic_segmentor.py | 4 ++-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tiatoolbox/models/engine/nucleus_instance_segmentor.py b/tiatoolbox/models/engine/nucleus_instance_segmentor.py index a389a5932..ebdb9c1db 100644 --- a/tiatoolbox/models/engine/nucleus_instance_segmentor.py +++ b/tiatoolbox/models/engine/nucleus_instance_segmentor.py @@ -12,13 +12,14 @@ from shapely.geometry import box as shapely_box from shapely.strtree import STRtree -from tiatoolbox.models import IOInstanceSegmentorConfig from tiatoolbox.models.engine.semantic_segmentor import ( SemanticSegmentor, WSIStreamDataset, ) from tiatoolbox.tools.patchextraction import PatchExtractor +from .io_config import IOInstanceSegmentorConfig + def _process_instance_predictions( inst_dict, diff --git a/tiatoolbox/models/engine/patch_predictor.py b/tiatoolbox/models/engine/patch_predictor.py index c697cf494..5f4babc0f 100644 --- a/tiatoolbox/models/engine/patch_predictor.py +++ b/tiatoolbox/models/engine/patch_predictor.py @@ -17,7 +17,7 @@ from tiatoolbox.wsicore.wsimeta import Resolution, Units from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIReader -from .. import IOPatchPredictorConfig +from .io_config import IOPatchPredictorConfig class PatchPredictor: diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index 2c1ec6c17..b997f928f 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -25,7 +25,7 @@ from tiatoolbox.wsicore.wsimeta import Resolution, Units from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIMeta, WSIReader -from .. import IOSegmentorConfig +from .io_config import IOSegmentorConfig def _estimate_canvas_parameters(sample_prediction, canvas_shape): @@ -112,7 +112,7 @@ class WSIStreamDataset(torch_data.Dataset): mp_shared_space (:class:`Namespace`): A shared multiprocessing space, must be from `torch.multiprocessing`. - ioconfig (:class:`tiatoolbox.models.IOSegmentorConfig`): + ioconfig (:class:`IOSegmentorConfig`): An object which contains I/O placement for patches. wsi_paths (list): List of paths pointing to a WSI or tiles. preproc (Callable): From b226202c5740eaeed6f7eda8816a2a8d06709a08 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Sat, 17 Jun 2023 00:01:10 +0100 Subject: [PATCH 020/112] :recycle: Refactor to reduce code in inheriting class - Refactor to reduce code in inheriting class Signed-off-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> --- tiatoolbox/models/engine/io_config.py | 96 +++++++-------------------- 1 file changed, 25 insertions(+), 71 deletions(-) diff --git a/tiatoolbox/models/engine/io_config.py b/tiatoolbox/models/engine/io_config.py index ee70786e9..7ffbc428f 100644 --- a/tiatoolbox/models/engine/io_config.py +++ b/tiatoolbox/models/engine/io_config.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import List, Tuple, Union import numpy as np @@ -28,20 +28,14 @@ class ModelIOConfigABC: patch_input_shape: Union[List[int], np.ndarray, Tuple[int]] stride_shape: Union[List[int], np.ndarray, Tuple[int]] highest_input_resolution: dict - output_resolutions: List[dict] = field(default_factory=list) - resolution_unit: Units = "mpp" - - def __init__( - self, - input_resolutions: List[dict], - patch_input_shape: Union[List[int], np.ndarray, Tuple[int]], - stride_shape: Union[List[int], np.ndarray, Tuple[int]], - ): - self.patch_input_shape = patch_input_shape - self.stride_shape = stride_shape - self.input_resolutions = input_resolutions - self.output_resolutions = [] - self.resolution_unit = input_resolutions[0]["units"] + output_resolutions: List[dict] + resolution_unit: Units + + def __post_init__(self): + if not self.output_resolutions: + self.output_resolutions = [] + + self.resolution_unit = self.input_resolutions[0]["units"] if self.resolution_unit == "mpp": self.highest_input_resolution = min( @@ -52,6 +46,8 @@ def __init__( self.input_resolutions, key=lambda x: x["resolution"] ) + self._validate() + def _validate(self): """Validate the data format.""" resolutions = self.input_resolutions + self.output_resolutions @@ -122,9 +118,13 @@ def to_baseline(self): input_resolutions=input_resolutions, patch_input_shape=self.patch_input_shape, stride_shape=self.stride_shape, + highest_input_resolution=self.highest_input_resolution, + resolution_unit=self.resolution_unit, + output_resolutions=self.output_resolutions, ) +@dataclass class IOSegmentorConfig(ModelIOConfigABC): """Contain semantic segmentor input and output information. @@ -178,27 +178,7 @@ class IOSegmentorConfig(ModelIOConfigABC): """ patch_output_shape: Union[List[int], np.ndarray] - save_resolution: dict = None - - def __init__( - self, - input_resolutions: List[dict], - output_resolutions: List[dict], - patch_input_shape: Union[List[int], np.ndarray], - stride_shape: Union[List[int], np.ndarray, Tuple[int]], - patch_output_shape: Union[List[int], np.ndarray], - save_resolution: dict = None, - ): - super().__init__( - input_resolutions=input_resolutions, - patch_input_shape=patch_input_shape, - stride_shape=stride_shape, - ) - self.patch_output_shape = patch_output_shape - self.output_resolutions = output_resolutions - self.save_resolution = save_resolution - - self._validate() + save_resolution: dict def to_baseline(self): """Returns a new config object converted to baseline form. @@ -235,26 +215,22 @@ def to_baseline(self): patch_output_shape=self.patch_output_shape, stride_shape=self.stride_shape, save_resolution=save_resolution, + highest_input_resolution=self.highest_input_resolution, + resolution_unit=self.resolution_unit, ) class IOPatchPredictorConfig(ModelIOConfigABC): """Contains patch predictor input and output information.""" - def __init__( - self, - input_resolutions=None, - patch_input_shape=None, - stride_shape=None, - ): - stride_shape = patch_input_shape if stride_shape is None else stride_shape - super().__init__( - input_resolutions=input_resolutions, - stride_shape=stride_shape, - patch_input_shape=patch_input_shape, + def __post_init__(self): + self.stride_shape = ( + self.patch_input_shape if self.stride_shape is None else self.stride_shape ) + super().__post_init__() +@dataclass class IOInstanceSegmentorConfig(IOSegmentorConfig): """Contain instance segmentor input and output information. @@ -314,30 +290,6 @@ class IOInstanceSegmentorConfig(IOSegmentorConfig): margin: int tile_shape: Tuple[int, int] - def __init__( - self, - input_resolutions: List[dict], - output_resolutions: List[dict], - patch_input_shape: Union[List[int], np.ndarray], - stride_shape: Union[List[int], np.ndarray, Tuple[int]], - patch_output_shape: Union[List[int], np.ndarray], - save_resolution: dict = None, - margin: int = None, - tile_shape: Tuple[int, int] = None, - ): - super().__init__( - input_resolutions=input_resolutions, - output_resolutions=output_resolutions, - patch_input_shape=patch_input_shape, - stride_shape=stride_shape, - patch_output_shape=patch_output_shape, - save_resolution=save_resolution, - ) - self.margin = margin - self.tile_shape = tile_shape - - self._validate() - def to_baseline(self): """Returns a new config object converted to baseline form. @@ -358,4 +310,6 @@ def to_baseline(self): save_resolution=new_config.save_resolution, margin=self.margin, tile_shape=self.tile_shape, + highest_input_resolution=self.highest_input_resolution, + resolution_unit=self.resolution_unit, ) From 2fffb61086141f8f8dfa67588ad65e1eb86a4705 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Sat, 17 Jun 2023 00:13:49 +0100 Subject: [PATCH 021/112] :bug: Fix deepsource errors - Fix deepsource errors Signed-off-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> --- tests/models/test_patch_predictor.py | 3 +++ tiatoolbox/models/engine/patch_predictor.py | 5 ++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/models/test_patch_predictor.py b/tests/models/test_patch_predictor.py index 5e3ebea2d..599f72e86 100644 --- a/tests/models/test_patch_predictor.py +++ b/tests/models/test_patch_predictor.py @@ -516,6 +516,9 @@ def test_io_config_delegation(remote_sample, tmp_path): patch_input_shape=[512, 512], stride_shape=[256, 256], input_resolutions=[{"resolution": 1.35, "units": "mpp"}], + highest_input_resolution={}, + output_resolutions=[], + resolution_unit="mpp", ) predictor.predict( [mini_wsi_svs], diff --git a/tiatoolbox/models/engine/patch_predictor.py b/tiatoolbox/models/engine/patch_predictor.py index 5f4babc0f..a49f1531d 100644 --- a/tiatoolbox/models/engine/patch_predictor.py +++ b/tiatoolbox/models/engine/patch_predictor.py @@ -440,7 +440,7 @@ def _update_ioconfig( """ Args: - ioconfig (tiatoolbox.models.IOPatchPredictorConfig): + ioconfig (:class:`IOPatchPredictorConfig`): patch_input_shape (tuple): Size of patches input to the model. Patches are at requested read resolution, not with respect to level 0, @@ -495,6 +495,9 @@ def _update_ioconfig( input_resolutions=[{"resolution": resolution, "units": units}], patch_input_shape=patch_input_shape, stride_shape=stride_shape, + highest_input_resolution={}, + output_resolutions=[], + resolution_unit="mpp", ) @staticmethod From 70da5d33d5639f5adb510e81310f80632719dd84 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Sat, 17 Jun 2023 00:19:23 +0100 Subject: [PATCH 022/112] :recycle: Fix ioconfig import - Fix ioconfig import to avoid circular imports Signed-off-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> --- tiatoolbox/models/engine/multi_task_segmentor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 6c0654960..e7e501fe3 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -35,7 +35,7 @@ _process_instance_predictions, ) -from .. import IOInstanceSegmentorConfig +from .io_config import IOInstanceSegmentorConfig from .semantic_segmentor import WSIStreamDataset From cfc955f12336544081f93f5e7c3135cad0bb5dec Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Sat, 17 Jun 2023 07:25:18 +0100 Subject: [PATCH 023/112] :bug: Fix missing input arguments - Fix missing input arguments Signed-off-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> --- tests/models/test_patch_predictor.py | 2 -- tiatoolbox/models/engine/io_config.py | 20 ++++++-------------- tiatoolbox/models/engine/patch_predictor.py | 2 -- 3 files changed, 6 insertions(+), 18 deletions(-) diff --git a/tests/models/test_patch_predictor.py b/tests/models/test_patch_predictor.py index 599f72e86..c9b77390d 100644 --- a/tests/models/test_patch_predictor.py +++ b/tests/models/test_patch_predictor.py @@ -516,9 +516,7 @@ def test_io_config_delegation(remote_sample, tmp_path): patch_input_shape=[512, 512], stride_shape=[256, 256], input_resolutions=[{"resolution": 1.35, "units": "mpp"}], - highest_input_resolution={}, output_resolutions=[], - resolution_unit="mpp", ) predictor.predict( [mini_wsi_svs], diff --git a/tiatoolbox/models/engine/io_config.py b/tiatoolbox/models/engine/io_config.py index 7ffbc428f..f13ef059b 100644 --- a/tiatoolbox/models/engine/io_config.py +++ b/tiatoolbox/models/engine/io_config.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import List, Tuple, Union import numpy as np @@ -27,9 +27,7 @@ class ModelIOConfigABC: input_resolutions: List[dict] patch_input_shape: Union[List[int], np.ndarray, Tuple[int]] stride_shape: Union[List[int], np.ndarray, Tuple[int]] - highest_input_resolution: dict - output_resolutions: List[dict] - resolution_unit: Units + output_resolutions: List[dict] = field(default_factory=list) def __post_init__(self): if not self.output_resolutions: @@ -118,8 +116,6 @@ def to_baseline(self): input_resolutions=input_resolutions, patch_input_shape=self.patch_input_shape, stride_shape=self.stride_shape, - highest_input_resolution=self.highest_input_resolution, - resolution_unit=self.resolution_unit, output_resolutions=self.output_resolutions, ) @@ -177,8 +173,8 @@ class IOSegmentorConfig(ModelIOConfigABC): """ - patch_output_shape: Union[List[int], np.ndarray] - save_resolution: dict + patch_output_shape: Union[List[int], np.ndarray] = None + save_resolution: dict = None def to_baseline(self): """Returns a new config object converted to baseline form. @@ -215,8 +211,6 @@ def to_baseline(self): patch_output_shape=self.patch_output_shape, stride_shape=self.stride_shape, save_resolution=save_resolution, - highest_input_resolution=self.highest_input_resolution, - resolution_unit=self.resolution_unit, ) @@ -287,8 +281,8 @@ class IOInstanceSegmentorConfig(IOSegmentorConfig): """ - margin: int - tile_shape: Tuple[int, int] + margin: int = None + tile_shape: Tuple[int, int] = None def to_baseline(self): """Returns a new config object converted to baseline form. @@ -310,6 +304,4 @@ def to_baseline(self): save_resolution=new_config.save_resolution, margin=self.margin, tile_shape=self.tile_shape, - highest_input_resolution=self.highest_input_resolution, - resolution_unit=self.resolution_unit, ) diff --git a/tiatoolbox/models/engine/patch_predictor.py b/tiatoolbox/models/engine/patch_predictor.py index a49f1531d..e1417044e 100644 --- a/tiatoolbox/models/engine/patch_predictor.py +++ b/tiatoolbox/models/engine/patch_predictor.py @@ -495,9 +495,7 @@ def _update_ioconfig( input_resolutions=[{"resolution": resolution, "units": units}], patch_input_shape=patch_input_shape, stride_shape=stride_shape, - highest_input_resolution={}, output_resolutions=[], - resolution_unit="mpp", ) @staticmethod From b3ecadfe5a12790636d494a896370565bc953e88 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Sat, 17 Jun 2023 09:14:52 +0100 Subject: [PATCH 024/112] :goal_net: Improve error catching - Improve error catching Signed-off-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> --- tests/models/test_multi_task_segmentor.py | 4 ++-- tiatoolbox/models/engine/io_config.py | 9 ++++++++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/models/test_multi_task_segmentor.py b/tests/models/test_multi_task_segmentor.py index c0c10d15b..2d292b39c 100644 --- a/tests/models/test_multi_task_segmentor.py +++ b/tests/models/test_multi_task_segmentor.py @@ -190,7 +190,7 @@ def test_masked_segmentor(remote_sample, tmp_path): {"units": "mpp", "resolution": resolution}, ], margin=128, - tile_shape=[512, 512], + tile_shape=(512, 512), patch_input_shape=[256, 256], patch_output_shape=[164, 164], stride_shape=[164, 164], @@ -309,7 +309,7 @@ def test_empty_image(tmp_path): bcc_wsi_ioconfig = IOInstanceSegmentorConfig( input_resolutions=[{"units": "mpp", "resolution": 0.25}], output_resolutions=[{"units": "mpp", "resolution": 0.25}], - tile_shape=2048, + tile_shape=(2048, 2048), patch_input_shape=[1024, 1024], patch_output_shape=[512, 512], stride_shape=[512, 512], diff --git a/tiatoolbox/models/engine/io_config.py b/tiatoolbox/models/engine/io_config.py index f13ef059b..dc0d2ef06 100644 --- a/tiatoolbox/models/engine/io_config.py +++ b/tiatoolbox/models/engine/io_config.py @@ -51,7 +51,14 @@ def _validate(self): resolutions = self.input_resolutions + self.output_resolutions units = [v["units"] for v in resolutions] units = np.unique(units) - if len(units) != 1 or units[0] not in [ + + if len(units) != 1: + raise ValueError( + f"Invalid resolution units `{units}`. " + f"The resolution units must be unique." + ) + + if units[0] not in [ "power", "baseline", "mpp", From 09b8cc4a338ae040763cb8a589b32fe47c072fd9 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Sat, 17 Jun 2023 09:28:25 +0100 Subject: [PATCH 025/112] :goal_net: Default configuration is not correct - Default configuration is not correct Signed-off-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> --- tiatoolbox/models/engine/semantic_segmentor.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index b997f928f..0df4092f3 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -1022,8 +1022,8 @@ def predict( patch_input_shape=None, patch_output_shape=None, stride_shape=None, - resolution=1.0, - units="baseline", + resolution=None, + units=None, save_dir=None, crash_on_exception=False, ): @@ -1111,6 +1111,17 @@ def predict( save_dir, self._cache_dir = self._prepare_save_dir(save_dir) + if not ioconfig: + ioconfig = self.ioconfig + + if not resolution and not units and ioconfig.input_resolutions: + resolution = ioconfig.input_resolutions[0]["resolution"] + units = ioconfig.resolution_unit[0]["units"] + else: + raise ValueError( + f"Invalid resolution: `{resolution}` and units: `{units}`. " + ) + ioconfig = self._update_ioconfig( ioconfig, mode, From c2e7aae6dd8c029f2aa4da30d8756741beb282b8 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Sat, 17 Jun 2023 10:08:13 +0100 Subject: [PATCH 026/112] :goal_net: Fix error check - Fix error check Signed-off-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> --- tiatoolbox/models/engine/io_config.py | 3 --- tiatoolbox/models/engine/semantic_segmentor.py | 17 +++++++++-------- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/tiatoolbox/models/engine/io_config.py b/tiatoolbox/models/engine/io_config.py index dc0d2ef06..a6386801f 100644 --- a/tiatoolbox/models/engine/io_config.py +++ b/tiatoolbox/models/engine/io_config.py @@ -30,9 +30,6 @@ class ModelIOConfigABC: output_resolutions: List[dict] = field(default_factory=list) def __post_init__(self): - if not self.output_resolutions: - self.output_resolutions = [] - self.resolution_unit = self.input_resolutions[0]["units"] if self.resolution_unit == "mpp": diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index 0df4092f3..11b24ce18 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -1114,13 +1114,14 @@ def predict( if not ioconfig: ioconfig = self.ioconfig - if not resolution and not units and ioconfig.input_resolutions: - resolution = ioconfig.input_resolutions[0]["resolution"] - units = ioconfig.resolution_unit[0]["units"] - else: - raise ValueError( - f"Invalid resolution: `{resolution}` and units: `{units}`. " - ) + if not resolution and not units: + if ioconfig.input_resolutions: + resolution = ioconfig.input_resolutions[0]["resolution"] + units = ioconfig.resolution_unit[0]["units"] + elif not ioconfig: + raise ValueError( + f"Invalid resolution: `{resolution}` and units: `{units}`. " + ) ioconfig = self._update_ioconfig( ioconfig, @@ -1378,7 +1379,7 @@ def predict( Resolution used for reading the image. units (Units): Units of resolution used for reading the image. - save_dir (str): + save_dir (str or pathlib.Path): Output directory when processing multiple tiles and whole-slide images. By default, it is folder `output` where the running script is invoked. From e02fd4766bd4f91e17bbd021fd7159aa9b65ec25 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Sat, 17 Jun 2023 10:39:27 +0100 Subject: [PATCH 027/112] :pencil2: Fix typo - Fix typo Signed-off-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> --- tiatoolbox/models/engine/semantic_segmentor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index 11b24ce18..501fc9d31 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -1117,7 +1117,7 @@ def predict( if not resolution and not units: if ioconfig.input_resolutions: resolution = ioconfig.input_resolutions[0]["resolution"] - units = ioconfig.resolution_unit[0]["units"] + units = ioconfig.input_resolutions[0]["units"] elif not ioconfig: raise ValueError( f"Invalid resolution: `{resolution}` and units: `{units}`. " From ec6ff9618d039dc9907faed84fd28eea86252d1a Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Sat, 17 Jun 2023 11:22:23 +0100 Subject: [PATCH 028/112] :bug: Fix unnecessary update - Fix unnecessary update Signed-off-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> --- tiatoolbox/models/engine/io_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tiatoolbox/models/engine/io_config.py b/tiatoolbox/models/engine/io_config.py index a6386801f..ad0333ba7 100644 --- a/tiatoolbox/models/engine/io_config.py +++ b/tiatoolbox/models/engine/io_config.py @@ -120,7 +120,7 @@ def to_baseline(self): input_resolutions=input_resolutions, patch_input_shape=self.patch_input_shape, stride_shape=self.stride_shape, - output_resolutions=self.output_resolutions, + output_resolutions=[], ) From e24ca47cf4d91eb8592d366e6726a05b6cd104f6 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Sat, 17 Jun 2023 14:50:29 +0100 Subject: [PATCH 029/112] :recycle: Refactor code to improve behaviour - Refactor code to improve behaviour Signed-off-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> --- .../models/engine/semantic_segmentor.py | 34 +++++++++---------- 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index 501fc9d31..dc188a93e 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -893,14 +893,7 @@ def _update_ioconfig( if stride_shape is None: stride_shape = patch_output_shape - if ioconfig is None and patch_input_shape is None: - if self.ioconfig is None: - raise ValueError( - "Must provide either `ioconfig` or " - "`patch_input_shape` and `patch_output_shape`" - ) - ioconfig = copy.deepcopy(self.ioconfig) - elif ioconfig is None: + if ioconfig is None: ioconfig = IOSegmentorConfig( input_resolutions=[{"resolution": resolution, "units": units}], output_resolutions=[{"resolution": resolution, "units": units}], @@ -1111,17 +1104,22 @@ def predict( save_dir, self._cache_dir = self._prepare_save_dir(save_dir) - if not ioconfig: - ioconfig = self.ioconfig + if ioconfig is None: + ioconfig = copy.deepcopy(self.ioconfig) - if not resolution and not units: - if ioconfig.input_resolutions: - resolution = ioconfig.input_resolutions[0]["resolution"] - units = ioconfig.input_resolutions[0]["units"] - elif not ioconfig: - raise ValueError( - f"Invalid resolution: `{resolution}` and units: `{units}`. " - ) + if ioconfig is None and patch_input_shape is None: + raise ValueError( + "Must provide either `ioconfig` or " + "`patch_input_shape` and `patch_output_shape`" + ) + + if ioconfig is None and resolution is None and units is None: + raise ValueError( + f"Invalid resolution: `{resolution}` and units: `{units}`. " + ) + + resolution = ioconfig.input_resolutions[0]["resolution"] + units = ioconfig.input_resolutions[0]["units"] ioconfig = self._update_ioconfig( ioconfig, From 3fe5d13ea677f3d1aca8e0df6904b4870d8984ae Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Sat, 17 Jun 2023 14:54:00 +0100 Subject: [PATCH 030/112] :bug: Fix Consider decorating method with @staticmethod PYL-R0201 - Fix Consider decorating method with @staticmethod PYL-R0201 Signed-off-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> --- tiatoolbox/models/engine/semantic_segmentor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index dc188a93e..ff68392ad 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -844,8 +844,8 @@ def _prepare_save_dir(save_dir): return save_dir, cache_dir + @staticmethod def _update_ioconfig( - self, ioconfig, mode, patch_input_shape, From 4d035332af24a64a537e28b83e9fdf00686ecc3b Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Sat, 17 Jun 2023 19:26:40 +0100 Subject: [PATCH 031/112] :bug: Fix update config logic - Fix update config logic Signed-off-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> --- tests/models/test_semantic_segmentation.py | 20 ++++++++++--------- .../models/engine/semantic_segmentor.py | 13 ++++++------ 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/tests/models/test_semantic_segmentation.py b/tests/models/test_semantic_segmentation.py index 586ce806e..36e5975e0 100644 --- a/tests/models/test_semantic_segmentation.py +++ b/tests/models/test_semantic_segmentation.py @@ -265,8 +265,8 @@ def test_crash_segmentor(remote_sample): model = _CNNTo1() semantic_segmentor = SemanticSegmentor(batch_size=BATCH_SIZE, model=model) # fake injection to trigger Segmentor to create parallel - # post processing workers because baseline Semantic Segmentor does not support - # post processing out of the box. It only contains condition to create it + # post-processing workers because baseline Semantic Segmentor does not support + # post-processing out of the box. It only contains condition to create it # for any subclass semantic_segmentor.num_postproc_workers = 1 @@ -297,7 +297,7 @@ def test_crash_segmentor(remote_sample): crash_on_exception=True, ) with pytest.raises(ValueError, match=r".*already exists.*"): - semantic_segmentor.predict([], mode="tile", patch_input_shape=[2048, 2048]) + semantic_segmentor.predict([], mode="tile", patch_input_shape=(2048, 2048)) _rm_dir("output") # default output dir test # * test not providing any io_config info when not using pretrained model @@ -310,23 +310,25 @@ def test_crash_segmentor(remote_sample): ) _rm_dir("output") # default output dir test - # * Test crash propagation when parallelize post processing + # * Test crash propagation when parallelize post-processing _rm_dir("output") semantic_segmentor.num_postproc_workers = 2 semantic_segmentor.model.forward = _crash_func with pytest.raises(ValueError, match=r"Propagation Crash."): semantic_segmentor.predict( [mini_wsi_svs], - patch_input_shape=[2048, 2048], + patch_input_shape=(2048, 2048), mode="wsi", on_gpu=ON_GPU, crash_on_exception=True, + resolution=1.0, + units="baseline", ) _rm_dir("output") # test ignore crash semantic_segmentor.predict( [mini_wsi_svs], - patch_input_shape=[2048, 2048], + patch_input_shape=(2048, 2048), mode="wsi", on_gpu=ON_GPU, crash_on_exception=False, @@ -461,8 +463,8 @@ def test_functional_segmentor(remote_sample, tmp_path): model = _CNNTo1() semantic_segmentor = SemanticSegmentor(batch_size=BATCH_SIZE, model=model) # fake injection to trigger Segmentor to create parallel - # post processing workers because baseline Semantic Segmentor does not support - # post processing out of the box. It only contains condition to create it + # post-processing workers because baseline Semantic Segmentor does not support + # post-processing out of the box. It only contains condition to create it # for any subclass semantic_segmentor.num_postproc_workers = 1 @@ -482,7 +484,7 @@ def test_functional_segmentor(remote_sample, tmp_path): [mini_wsi_jpg], mode="tile", on_gpu=ON_GPU, - patch_input_shape=[512, 512], + patch_input_shape=(512, 512), resolution=1 / resolution, units="baseline", crash_on_exception=True, diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index ff68392ad..12844c893 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -1113,13 +1113,14 @@ def predict( "`patch_input_shape` and `patch_output_shape`" ) - if ioconfig is None and resolution is None and units is None: - raise ValueError( - f"Invalid resolution: `{resolution}` and units: `{units}`. " - ) + if resolution is None and units is None: + if ioconfig is None: + raise ValueError( + f"Invalid resolution: `{resolution}` and units: `{units}`. " + ) - resolution = ioconfig.input_resolutions[0]["resolution"] - units = ioconfig.input_resolutions[0]["units"] + resolution = ioconfig.input_resolutions[0]["resolution"] + units = ioconfig.input_resolutions[0]["units"] ioconfig = self._update_ioconfig( ioconfig, From 45a1442635c6172d4e302f25181708769f0ea16d Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Sat, 17 Jun 2023 19:33:42 +0100 Subject: [PATCH 032/112] :bug: Fix update config logic - Fix update config logic Signed-off-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> --- tests/models/test_semantic_segmentation.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/models/test_semantic_segmentation.py b/tests/models/test_semantic_segmentation.py index 36e5975e0..97dfea147 100644 --- a/tests/models/test_semantic_segmentation.py +++ b/tests/models/test_semantic_segmentation.py @@ -332,6 +332,8 @@ def test_crash_segmentor(remote_sample): mode="wsi", on_gpu=ON_GPU, crash_on_exception=False, + resolution=1.0, + units="baseline", ) _rm_dir("output") @@ -427,7 +429,7 @@ def test_functional_segmentor_merging(tmp_path): _rm_dir(save_dir) os.mkdir(save_dir) - # * with out of bound location + # * without of bound location canvas = semantic_segmentor.merge_prediction( [4, 4], [ From c6542a4456f928c12281e355f17edb2f6f90dcfb Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Mon, 19 Jun 2023 11:15:09 +0100 Subject: [PATCH 033/112] :memo: Update docstring - Update docstring --- tiatoolbox/models/engine/io_config.py | 151 +++++++++++++++++++++----- 1 file changed, 122 insertions(+), 29 deletions(-) diff --git a/tiatoolbox/models/engine/io_config.py b/tiatoolbox/models/engine/io_config.py index ad0333ba7..b488c8c67 100644 --- a/tiatoolbox/models/engine/io_config.py +++ b/tiatoolbox/models/engine/io_config.py @@ -17,16 +17,35 @@ class ModelIOConfigABC: input_resolutions (list(dict)): Resolution of each input head of model inference, must be in the same order as `target model.forward()`. + patch_input_shape (:class:`numpy.ndarray`, list(int), tuple(int, int)): + Shape of the largest input in (height, width). stride_shape (:class:`numpy.ndarray`, list(int), tuple(int)): Stride in (x, y) direction for patch extraction. - patch_input_shape (:class:`numpy.ndarray`, list(int), tuple(int)): + output_resolutions (list(dict)): + Resolution of each output head from model inference, must be + in the same order as target model.infer_batch(). + + Attributes: + input_resolutions (list(dict)): + Resolution of each input head of model inference, must be in + the same order as `target model.forward()`. + patch_input_shape (:class:`numpy.ndarray`, list(int), tuple(int, int)): Shape of the largest input in (height, width). + stride_shape (:class:`numpy.ndarray`, list(int), tuple(int, int)): + Stride in (x, y) direction for patch extraction. + output_resolutions (list(dict)): + Resolution of each output head from model inference, must be + in the same order as target model.infer_batch(). + highest_input_resolution (dict): + Highest resolution to process the image based on input and + output resolutions. This helps to read the image at the optimal + resolution and improves performance. """ input_resolutions: List[dict] - patch_input_shape: Union[List[int], np.ndarray, Tuple[int]] - stride_shape: Union[List[int], np.ndarray, Tuple[int]] + patch_input_shape: Union[List[int], np.ndarray, Tuple[int, int]] + stride_shape: Union[List[int], np.ndarray, Tuple[int, int]] output_resolutions: List[dict] = field(default_factory=list) def __post_init__(self): @@ -126,23 +145,43 @@ def to_baseline(self): @dataclass class IOSegmentorConfig(ModelIOConfigABC): - """Contain semantic segmentor input and output information. + """Contains semantic segmentor input and output information. Args: - input_resolutions (list): + input_resolutions (list(dict)): Resolution of each input head of model inference, must be in the same order as `target model.forward()`. - output_resolutions (list): + patch_input_shape (:class:`numpy.ndarray`, list(int), tuple(int, int)): + Shape of the largest input in (height, width). + stride_shape (:class:`numpy.ndarray`, list(int), tuple(int)): + Stride in (x, y) direction for patch extraction. + output_resolutions (list(dict)): Resolution of each output head from model inference, must be in the same order as target model.infer_batch(). - stride_shape (:class:`numpy.ndarray`, list(int)): - Stride in (x, y) direction for patch extraction. - patch_input_shape (:class:`numpy.ndarray`, list(int)): + patch_output_shape (:class:`numpy.ndarray`, list(int)): + Shape of the largest output in (height, width). + save_resolution (dict): + Resolution to save all output. + + Attributes: + input_resolutions (list(dict)): + Resolution of each input head of model inference, must be in + the same order as `target model.forward()`. + patch_input_shape (:class:`numpy.ndarray`, list(int), tuple(int, int)): Shape of the largest input in (height, width). + stride_shape (:class:`numpy.ndarray`, list(int), tuple(int, int)): + Stride in (x, y) direction for patch extraction. + output_resolutions (list(dict)): + Resolution of each output head from model inference, must be + in the same order as target model.infer_batch(). patch_output_shape (:class:`numpy.ndarray`, list(int)): Shape of the largest output in (height, width). save_resolution (dict): Resolution to save all output. + highest_input_resolution (dict): + Highest resolution to process the image based on input and + output resolutions. This helps to read the image at the optimal + resolution and improves performance. Examples: >>> # Defining io for a network having 1 input and 1 output at the @@ -150,9 +189,9 @@ class IOSegmentorConfig(ModelIOConfigABC): >>> ioconfig = IOSegmentorConfig( ... input_resolutions=[{"units": "baseline", "resolution": 1.0}], ... output_resolutions=[{"units": "baseline", "resolution": 1.0}], - ... patch_input_shape=[2048, 2048], - ... patch_output_shape=[1024, 1024], - ... stride_shape=[512, 512], + ... patch_input_shape=(2048, 2048), + ... patch_output_shape=(1024, 1024), + ... stride_shape=(512, 512), ... ) Examples: @@ -169,15 +208,15 @@ class IOSegmentorConfig(ModelIOConfigABC): ... {"units": "mpp", "resolution": 0.25}, ... {"units": "mpp", "resolution": 0.50}, ... ], - ... patch_input_shape=[2048, 2048], - ... patch_output_shape=[1024, 1024], - ... stride_shape=[512, 512], + ... patch_input_shape=(2048, 2048), + ... patch_output_shape=(1024, 1024), + ... stride_shape=(512, 512), ... save_resolution={"units": "mpp", "resolution": 4.0}, ... ) """ - patch_output_shape: Union[List[int], np.ndarray] = None + patch_output_shape: Union[List[int], np.ndarray, Tuple[int, int]] = None save_resolution: dict = None def to_baseline(self): @@ -219,7 +258,37 @@ def to_baseline(self): class IOPatchPredictorConfig(ModelIOConfigABC): - """Contains patch predictor input and output information.""" + """Contains patch predictor input and output information. + + Args: + input_resolutions (list(dict)): + Resolution of each input head of model inference, must be in + the same order as `target model.forward()`. + patch_input_shape (:class:`numpy.ndarray`, list(int), tuple(int)): + Shape of the largest input in (height, width). + stride_shape (:class:`numpy.ndarray`, list(int), tuple(int)): + Stride in (x, y) direction for patch extraction. + output_resolutions (list(dict)): + Resolution of each output head from model inference, must be + in the same order as target model.infer_batch(). + + Attributes: + input_resolutions (list(dict)): + Resolution of each input head of model inference, must be in + the same order as `target model.forward()`. + patch_input_shape (:class:`numpy.ndarray`, list(int), tuple(int)): + Shape of the largest input in (height, width). + stride_shape (:class:`numpy.ndarray`, list(int), tuple(int)): + Stride in (x, y) direction for patch extraction. + output_resolutions (list(dict)): + Resolution of each output head from model inference, must be + in the same order as target model.infer_batch(). + highest_input_resolution (dict): + Highest resolution to process the image based on input and + output resolutions. This helps to read the image at the optimal + resolution and improves performance. + + """ def __post_init__(self): self.stride_shape = ( @@ -230,23 +299,47 @@ def __post_init__(self): @dataclass class IOInstanceSegmentorConfig(IOSegmentorConfig): - """Contain instance segmentor input and output information. + """Contains instance segmentor input and output information. Args: - input_resolutions (list): + input_resolutions (list(dict)): Resolution of each input head of model inference, must be in the same order as `target model.forward()`. - output_resolutions (list): + patch_input_shape (:class:`numpy.ndarray`, list(int), tuple(int, int)): + Shape of the largest input in (height, width). + stride_shape (:class:`numpy.ndarray`, list(int), tuple(int, int)): + Stride in (x, y) direction for patch extraction. + output_resolutions (list(dict)): Resolution of each output head from model inference, must be in the same order as target model.infer_batch(). - stride_shape (:class:`numpy.ndarray`, list(int)): - Stride in (x, y) direction for patch extraction. - patch_input_shape (:class:`numpy.ndarray`, list(int)): + patch_output_shape (:class:`numpy.ndarray`, list(int)): + Shape of the largest output in (height, width). + save_resolution (dict): + Resolution to save all output. + margin (int): + Tile margin to accumulate the output. + tile_shape (tuple(int, int)): + Tile shape to process the WSI. + + Attributes: + input_resolutions (list(dict)): + Resolution of each input head of model inference, must be in + the same order as `target model.forward()`. + patch_input_shape (:class:`numpy.ndarray`, list(int), tuple(int, int)): Shape of the largest input in (height, width). + stride_shape (:class:`numpy.ndarray`, list(int), tuple(int)): + Stride in (x, y) direction for patch extraction. + output_resolutions (list(dict)): + Resolution of each output head from model inference, must be + in the same order as target model.infer_batch(). patch_output_shape (:class:`numpy.ndarray`, list(int)): Shape of the largest output in (height, width). save_resolution (dict): Resolution to save all output. + highest_input_resolution (dict): + Highest resolution to process the image based on input and + output resolutions. This helps to read the image at the optimal + resolution and improves performance. margin (int): Tile margin to accumulate the output. tile_shape (tuple(int, int)): @@ -258,9 +351,9 @@ class IOInstanceSegmentorConfig(IOSegmentorConfig): >>> ioconfig = IOSegmentorConfig( ... input_resolutions=[{"units": "baseline", "resolution": 1.0}], ... output_resolutions=[{"units": "baseline", "resolution": 1.0}], - ... patch_input_shape=[2048, 2048], - ... patch_output_shape=[1024, 1024], - ... stride_shape=[512, 512], + ... patch_input_shape=(2048, 2048), + ... patch_output_shape=(1024, 1024), + ... stride_shape=(512, 512), ... ) Examples: @@ -277,9 +370,9 @@ class IOInstanceSegmentorConfig(IOSegmentorConfig): ... {"units": "mpp", "resolution": 0.25}, ... {"units": "mpp", "resolution": 0.50}, ... ], - ... patch_input_shape=[2048, 2048], - ... patch_output_shape=[1024, 1024], - ... stride_shape=[512, 512], + ... patch_input_shape=(2048, 2048), + ... patch_output_shape=(1024, 1024), + ... stride_shape=(512, 512), ... save_resolution={"units": "mpp", "resolution": 4.0}, ... ) From 2d7f2d96d5d0fba20ac43463580d889337ac8596 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Mon, 19 Jun 2023 11:32:24 +0100 Subject: [PATCH 034/112] :white_check_mark: Add validation checks - Add validation checks --- tests/engines/__init__.py | 1 + tests/engines/test_ioconfig.py | 22 ++++++++++++++++++++++ tiatoolbox/models/__init__.py | 2 ++ tiatoolbox/models/engine/io_config.py | 5 ++++- 4 files changed, 29 insertions(+), 1 deletion(-) create mode 100644 tests/engines/__init__.py create mode 100644 tests/engines/test_ioconfig.py diff --git a/tests/engines/__init__.py b/tests/engines/__init__.py new file mode 100644 index 000000000..193a523c1 --- /dev/null +++ b/tests/engines/__init__.py @@ -0,0 +1 @@ +"""Unit test package for tiatoolbox engines.""" diff --git a/tests/engines/test_ioconfig.py b/tests/engines/test_ioconfig.py new file mode 100644 index 000000000..27add18df --- /dev/null +++ b/tests/engines/test_ioconfig.py @@ -0,0 +1,22 @@ +"""Tests for IOconfig.""" + +import pytest + +from tiatoolbox.models import ModelIOConfigABC + + +def test_validation_error_io_config(): + with pytest.raises(ValueError, match=r".*The resolution units must be unique.*"): + ModelIOConfigABC( + input_resolutions=[ + {"units": "baseline", "resolution": 1.0}, + {"units": "mpp", "resolution": 0.25}, + ], + patch_input_shape=(224, 224), + ) + + with pytest.raises(ValueError, match=r"Invalid resolution units.*"): + ModelIOConfigABC( + input_resolutions=[{"units": "level", "resolution": 1.0}], + patch_input_shape=(224, 224), + ) diff --git a/tiatoolbox/models/__init__.py b/tiatoolbox/models/__init__.py index 76c4389ab..db809725e 100644 --- a/tiatoolbox/models/__init__.py +++ b/tiatoolbox/models/__init__.py @@ -12,6 +12,7 @@ IOInstanceSegmentorConfig, IOPatchPredictorConfig, IOSegmentorConfig, + ModelIOConfigABC, ) from .engine.multi_task_segmentor import MultiTaskSegmentor from .engine.nucleus_instance_segmentor import NucleusInstanceSegmentor @@ -37,4 +38,5 @@ "IOPatchPredictorConfig", "IOSegmentorConfig", "IOInstanceSegmentorConfig", + "ModelIOConfigABC", ] diff --git a/tiatoolbox/models/engine/io_config.py b/tiatoolbox/models/engine/io_config.py index b488c8c67..5eb55dc38 100644 --- a/tiatoolbox/models/engine/io_config.py +++ b/tiatoolbox/models/engine/io_config.py @@ -45,10 +45,13 @@ class ModelIOConfigABC: input_resolutions: List[dict] patch_input_shape: Union[List[int], np.ndarray, Tuple[int, int]] - stride_shape: Union[List[int], np.ndarray, Tuple[int, int]] + stride_shape: Union[List[int], np.ndarray, Tuple[int, int]] = None output_resolutions: List[dict] = field(default_factory=list) def __post_init__(self): + if self.stride_shape is None: + self.stride_shape = self.patch_input_shape + self.resolution_unit = self.input_resolutions[0]["units"] if self.resolution_unit == "mpp": From 2a5d38c31179d170808248b0a8ea0dbaa08ea3e4 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Mon, 19 Jun 2023 12:07:16 +0100 Subject: [PATCH 035/112] :white_check_mark: Add validation checks - Add validation checks --- tests/models/test_semantic_segmentation.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/models/test_semantic_segmentation.py b/tests/models/test_semantic_segmentation.py index 97dfea147..218b63173 100644 --- a/tests/models/test_semantic_segmentation.py +++ b/tests/models/test_semantic_segmentation.py @@ -324,6 +324,18 @@ def test_crash_segmentor(remote_sample): resolution=1.0, units="baseline", ) + + _rm_dir("output") + + with pytest.raises(ValueError, match=r"Invalid resolution.*"): + semantic_segmentor.predict( + [mini_wsi_svs], + patch_input_shape=(2048, 2048), + mode="wsi", + on_gpu=ON_GPU, + crash_on_exception=True, + ) + _rm_dir("output") # test ignore crash semantic_segmentor.predict( From cd4acbb590f4021399117ccf9d90ff3963225818 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Mon, 19 Jun 2023 16:20:25 +0100 Subject: [PATCH 036/112] :technologist: Improve error message - Improve error message --- tests/engines/test_ioconfig.py | 2 +- tests/models/test_semantic_segmentation.py | 2 +- tiatoolbox/models/engine/io_config.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/engines/test_ioconfig.py b/tests/engines/test_ioconfig.py index 27add18df..245797a73 100644 --- a/tests/engines/test_ioconfig.py +++ b/tests/engines/test_ioconfig.py @@ -6,7 +6,7 @@ def test_validation_error_io_config(): - with pytest.raises(ValueError, match=r".*The resolution units must be unique.*"): + with pytest.raises(ValueError, match=r".*Multiple resolution units found.*"): ModelIOConfigABC( input_resolutions=[ {"units": "baseline", "resolution": 1.0}, diff --git a/tests/models/test_semantic_segmentation.py b/tests/models/test_semantic_segmentation.py index 218b63173..cbd0222d5 100644 --- a/tests/models/test_semantic_segmentation.py +++ b/tests/models/test_semantic_segmentation.py @@ -441,7 +441,7 @@ def test_functional_segmentor_merging(tmp_path): _rm_dir(save_dir) os.mkdir(save_dir) - # * without of bound location + # * with an out of bound location canvas = semantic_segmentor.merge_prediction( [4, 4], [ diff --git a/tiatoolbox/models/engine/io_config.py b/tiatoolbox/models/engine/io_config.py index 5eb55dc38..71aa51287 100644 --- a/tiatoolbox/models/engine/io_config.py +++ b/tiatoolbox/models/engine/io_config.py @@ -73,8 +73,8 @@ def _validate(self): if len(units) != 1: raise ValueError( - f"Invalid resolution units `{units}`. " - f"The resolution units must be unique." + f"Multiple resolution units found: `{units}`. " + f"Mixing resolution units is not allowed." ) if units[0] not in [ From caab9a924a0a2bd114aac9040ffb1768282ebc0a Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Mon, 19 Jun 2023 17:09:01 +0100 Subject: [PATCH 037/112] :art: Improve code structure - Improve code structure --- tiatoolbox/models/engine/io_config.py | 29 ++++++--------------------- 1 file changed, 6 insertions(+), 23 deletions(-) diff --git a/tiatoolbox/models/engine/io_config.py b/tiatoolbox/models/engine/io_config.py index 71aa51287..ed78fe4b1 100644 --- a/tiatoolbox/models/engine/io_config.py +++ b/tiatoolbox/models/engine/io_config.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass, field +from dataclasses import dataclass, field, replace from typing import List, Tuple, Union import numpy as np @@ -138,12 +138,7 @@ def to_baseline(self): {"units": "baseline", "resolution": v} for v in scale_factors[:end_idx] ] - return ModelIOConfigABC( - input_resolutions=input_resolutions, - patch_input_shape=self.patch_input_shape, - stride_shape=self.stride_shape, - output_resolutions=[], - ) + return replace(self, input_resolutions=input_resolutions, output_resolutions=[]) @dataclass @@ -250,12 +245,10 @@ def to_baseline(self): if self.save_resolution is not None: save_resolution = {"units": "baseline", "resolution": scale_factors[-1]} - return IOSegmentorConfig( + return replace( + self, input_resolutions=new_config.input_resolutions, output_resolutions=output_resolutions, - patch_input_shape=self.patch_input_shape, - patch_output_shape=self.patch_output_shape, - stride_shape=self.stride_shape, save_resolution=save_resolution, ) @@ -293,12 +286,6 @@ class IOPatchPredictorConfig(ModelIOConfigABC): """ - def __post_init__(self): - self.stride_shape = ( - self.patch_input_shape if self.stride_shape is None else self.stride_shape - ) - super().__post_init__() - @dataclass class IOInstanceSegmentorConfig(IOSegmentorConfig): @@ -395,13 +382,9 @@ def to_baseline(self): """ new_config = super().to_baseline() - return IOInstanceSegmentorConfig( + return replace( + self, input_resolutions=new_config.input_resolutions, output_resolutions=new_config.output_resolutions, - patch_input_shape=self.patch_input_shape, - patch_output_shape=self.patch_output_shape, - stride_shape=self.stride_shape, save_resolution=new_config.save_resolution, - margin=self.margin, - tile_shape=self.tile_shape, ) From 7bebf9a2d08b4a853cad72524bd1246156d1ac68 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Mon, 19 Jun 2023 18:47:00 +0100 Subject: [PATCH 038/112] :fire: Remove invalid tests - Remove invalid tests --- tests/models/test_semantic_segmentation.py | 36 ---------------------- 1 file changed, 36 deletions(-) diff --git a/tests/models/test_semantic_segmentation.py b/tests/models/test_semantic_segmentation.py index cbd0222d5..143634962 100644 --- a/tests/models/test_semantic_segmentation.py +++ b/tests/models/test_semantic_segmentation.py @@ -1,7 +1,4 @@ """Tests for Semantic Segmentor.""" - -import copy - # ! The garbage collector import gc import multiprocessing @@ -110,39 +107,6 @@ def infer_batch(model, batch_data, on_gpu): def test_segmentor_ioconfig(): """Test for IOConfig.""" - default_config = { - "input_resolutions": [ - {"units": "mpp", "resolution": 0.25}, - {"units": "mpp", "resolution": 0.50}, - {"units": "mpp", "resolution": 0.75}, - ], - "output_resolutions": [ - {"units": "mpp", "resolution": 0.25}, - {"units": "mpp", "resolution": 0.50}, - ], - "patch_input_shape": [2048, 2048], - "patch_output_shape": [1024, 1024], - "stride_shape": [512, 512], - } - - # error when uniform resolution units are not uniform - xconfig = copy.deepcopy(default_config) - xconfig["input_resolutions"] = [ - {"units": "mpp", "resolution": 0.25}, - {"units": "power", "resolution": 0.50}, - ] - with pytest.raises(ValueError, match=r".*Invalid resolution units.*"): - _ = IOSegmentorConfig(**xconfig) - - # error when uniform resolution units are not supported - xconfig = copy.deepcopy(default_config) - xconfig["input_resolutions"] = [ - {"units": "alpha", "resolution": 0.25}, - {"units": "alpha", "resolution": 0.50}, - ] - with pytest.raises(ValueError, match=r".*Invalid resolution units.*"): - _ = IOSegmentorConfig(**xconfig) - ioconfig = IOSegmentorConfig( input_resolutions=[ {"units": "mpp", "resolution": 0.25}, From 37d826a3197027e41e88904e5324131631a41085 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Thu, 6 Jul 2023 13:53:34 +0100 Subject: [PATCH 039/112] :art: New `EngineABC` design - New `EngineABC` design --- tiatoolbox/models/engine/engine_abc.py | 167 +++++++++++++++++- tiatoolbox/models/engine/patch_predictor.py | 81 ++++++--- .../models/engine/semantic_segmentor.py | 8 +- 3 files changed, 215 insertions(+), 41 deletions(-) diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 661c882f2..05e3a81bd 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -1,18 +1,169 @@ from abc import ABC, abstractmethod +import torch.nn as nn + +from tiatoolbox.models.architecture import get_pretrained_model + class EngineABC(ABC): - """Abstract base class for engines used in tiatoolbox.""" + """Abstract base class for engines used in tiatoolbox. + + Args: + model (nn.Module): + Use externally defined PyTorch model for prediction with. + weights already loaded. Default is `None`. If provided, + `pretrained_model` argument is ignored. + pretrained_model (str): + Name of the existing models support by tiatoolbox for + processing the data. For a full list of pretrained models, + refer to the `docs + `_ + By default, the corresponding pretrained weights will also + be downloaded. However, you can override with your own set + of weights via the `pretrained_weights` argument. Argument + is case-insensitive. + pretrained_weights (str): + Path to the weight of the corresponding `pretrained_model`. + + >>> engine = EngineABC( + ... pretrained_model="pretrained-model-name", + ... pretrained_weights="pretrained-local-weights.pth") + + batch_size (int): + Number of images fed into the model each time. + num_loader_workers (int): + Number of workers to load the data using :class:`torch.utils.data.Dataset`. + Please note that they will also perform preprocessing. default = 0 + num_postproc_workers (int): + Number of workers to postprocess the results of the model. default = 0 + verbose (bool): + Whether to output logging information. + + Attributes: + images (str or :obj:`pathlib.Path` or :obj:`numpy.ndarray`): + A HWC image or a path to WSI. + mode (str): + Type of input to process. Choose from either `patch`, `tile` + or `wsi`. + model (nn.Module): + Defined PyTorch model. + pretrained_model (str): + Name of the existing models support by tiatoolbox for + processing the data. For a full list of pretrained models, + refer to the `docs + `_ + By default, the corresponding pretrained weights will also + be downloaded. However, you can override with your own set + of weights via the `pretrained_weights` argument. Argument + is case insensitive. + batch_size (int): + Number of images fed into the model each time. + num_loader_workers (int): + Number of workers used in torch.utils.data.DataLoader. + verbose (bool): + Whether to output logging information. + + Examples: + >>> # list of 2 image patches as input + >>> data = ["path/to/image1.svs", "path/to/image2.svs"] + >>> engine = EngineABC(pretrained_model="resnet18-kather100k") + >>> output = engine.predict(data, mode='patch') + + >>> # array of list of 2 image patches as input + >>> data = np.array([img1, img2]) + >>> engine = EngineABC(pretrained_model="resnet18-kather100k") + >>> output = engine.predict(data, mode='patch') + + >>> # list of 2 image patch files as input + >>> data = ['path/img.png', 'path/img.png'] + >>> engine = EngineABC(pretrained_model="resnet18-kather100k") + >>> output = engine.predict(data, mode='patch') + + >>> # list of 2 image tile files as input + >>> tile_file = ['path/tile1.png', 'path/tile2.png'] + >>> engine = EngineABC(pretraind_model="resnet18-kather100k") + >>> output = engine.predict(tile_file, mode='tile') + + >>> # list of 2 wsi files as input + >>> wsi_file = ['path/wsi1.svs', 'path/wsi2.svs'] + >>> engine = EngineABC(pretraind_model="resnet18-kather100k") + >>> output = engine.predict(wsi_file, mode='wsi') - def __init__(self): + """ + + def __init__( + self, + batch_size: int = 8, + num_loader_workers: int = 0, + num_postproc_workers: int = 0, + model: nn.Module = None, + pretrained_model: str = None, + pretrained_weights: str = None, + verbose: bool = False, + ): super().__init__() + self.images = None + self.mode = None + + if model is None and pretrained_model is None: + raise ValueError("Must provide either `model` or `pretrained_model`.") + + if model is not None: + self.model = model + ioconfig = None # retrieve ioconfig from provided model. + else: + model, ioconfig = get_pretrained_model(pretrained_model, pretrained_weights) + + self.ioconfig = ioconfig # for storing original + self.model = model # for runtime, such as after wrapping with nn.DataParallel + self.pretrained_model = pretrained_model + self.batch_size = batch_size + self.num_loader_workers = num_loader_workers + self.num_postproc_workers = num_postproc_workers + self.verbose = verbose + + @abstractmethod + @property + def _ioconfig(self): # runtime ioconfig + return self.ioconfig + + @abstractmethod + def pre_process_patch(self): + raise NotImplementedError + + @abstractmethod + def pre_process_tile(self): + raise NotImplementedError + + @abstractmethod + def pre_process_wsi(self): + raise NotImplementedError + + @abstractmethod + def infer_patch(self): + raise NotImplementedError + + @abstractmethod + def infer_tile(self): + raise NotImplementedError + @abstractmethod - def process_patch(self): + def infer_wsi(self): raise NotImplementedError - # how to deal with patches, list of patches/numpy arrays, WSIs - # how to communicate with sub-processes. - # define how to deal with patches as numpy/zarr arrays. - # convert list of patches/numpy arrays to zarr and then pass to each sub-processes. - # define how to read WSIs, read the image and convert to zarr array. + @abstractmethod + def post_process_patch(self): + raise NotImplementedError + + @abstractmethod + def post_process_tile(self): + raise NotImplementedError + + @abstractmethod + def post_process_wsi(self): + raise NotImplementedError + + @abstractmethod + def run_pipeline(self): + raise NotImplementedError diff --git a/tiatoolbox/models/engine/patch_predictor.py b/tiatoolbox/models/engine/patch_predictor.py index e1417044e..05d998f6f 100644 --- a/tiatoolbox/models/engine/patch_predictor.py +++ b/tiatoolbox/models/engine/patch_predictor.py @@ -9,9 +9,9 @@ import numpy as np import torch import tqdm +from engine_abc import EngineABC from tiatoolbox import logger -from tiatoolbox.models.architecture import get_pretrained_model from tiatoolbox.models.dataset.classification import PatchDataset, WSIPatchDataset from tiatoolbox.utils import misc, save_as_json from tiatoolbox.wsicore.wsimeta import Resolution, Units @@ -20,8 +20,8 @@ from .io_config import IOPatchPredictorConfig -class PatchPredictor: - r"""Patch level predictor. +class PatchPredictor(EngineABC): + r"""Patch level predictor for digital histology images. The models provided by tiatoolbox should give the following results: @@ -137,7 +137,7 @@ class PatchPredictor: Whether to output logging information. Attributes: - img (:obj:`str` or :obj:`pathlib.Path` or :obj:`numpy.ndarray`): + images (str or :obj:`pathlib.Path` or :obj:`numpy.ndarray`): A HWC image or a path to WSI. mode (str): Type of input to process. Choose from either `patch`, `tile` @@ -162,7 +162,7 @@ class PatchPredictor: Examples: >>> # list of 2 image patches as input - >>> data = [img1, img2] + >>> data = ['path/img.svs', 'path/img.svs'] >>> predictor = PatchPredictor(pretrained_model="resnet18-kather100k") >>> output = predictor.predict(data, mode='patch') @@ -199,34 +199,57 @@ class PatchPredictor: def __init__( self, - batch_size=8, - num_loader_workers=0, - model=None, - pretrained_model=None, - pretrained_weights=None, - verbose=True, + batch_size: int = 8, + num_loader_workers: int = 0, + num_postproc_workers: int = 0, + model: torch.nn.Module = None, + pretrained_model: str = None, + pretrained_weights: str = None, + verbose: bool = False, ): - super().__init__() + super().__init__( + batch_size=batch_size, + num_loader_workers=num_loader_workers, + num_postproc_workers=num_postproc_workers, + model=model, + pretrained_model=pretrained_model, + pretrained_weights=pretrained_weights, + verbose=verbose, + ) - self.imgs = None - self.mode = None + @property + def _ioconfig(self): # runtime ioconfig + return self.ioconfig - if model is None and pretrained_model is None: - raise ValueError("Must provide either `model` or `pretrained_model`.") + def pre_process_patch(self): + pass - if model is not None: - self.model = model - ioconfig = None # retrieve iostate from provided model ? - else: - model, ioconfig = get_pretrained_model(pretrained_model, pretrained_weights) + def pre_process_tile(self): + pass + + def pre_process_wsi(self): + pass + + def post_process_patch(self): + pass + + def post_process_tile(self): + pass + + def post_process_wsi(self): + pass + + def infer_patch(self): + pass + + def infer_tile(self): + pass + + def infer_wsi(self): + pass - self.ioconfig = ioconfig # for storing original - self._ioconfig = None # for storing runtime - self.model = model # for runtime, such as after wrapping with nn.DataParallel - self.pretrained_model = pretrained_model - self.batch_size = batch_size - self.num_loader_worker = num_loader_workers - self.verbose = verbose + def run_pipeline(self): + pass @staticmethod def merge_predictions( @@ -376,7 +399,7 @@ def _predict_engine( # preprocessing must be defined with the dataset dataloader = torch.utils.data.DataLoader( dataset, - num_workers=self.num_loader_worker, + num_workers=self.num_loader_workers, batch_size=self.batch_size, drop_last=False, shuffle=False, diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index 12844c893..e23cd4619 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -340,13 +340,13 @@ def __init__( raise ValueError("Must provide either of `model` or `pretrained_model`") if model is not None: - self.model = model # template ioconfig, usually coming from pretrained - self.ioconfig = None + ioconfig = None else: model, ioconfig = get_pretrained_model(pretrained_model, pretrained_weights) - self.ioconfig = ioconfig - self.model = model + + self.ioconfig = ioconfig + self.model = model # local variables for flagging mode within class, # subclass should have overwritten to alter some specific behavior From 2df715f557b964f86a798f791d8d48553ab4b433 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Thu, 6 Jul 2023 14:30:03 +0100 Subject: [PATCH 040/112] :bug: Fix import - Fix import --- tiatoolbox/models/engine/patch_predictor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tiatoolbox/models/engine/patch_predictor.py b/tiatoolbox/models/engine/patch_predictor.py index 05d998f6f..ea0905017 100644 --- a/tiatoolbox/models/engine/patch_predictor.py +++ b/tiatoolbox/models/engine/patch_predictor.py @@ -9,7 +9,6 @@ import numpy as np import torch import tqdm -from engine_abc import EngineABC from tiatoolbox import logger from tiatoolbox.models.dataset.classification import PatchDataset, WSIPatchDataset @@ -17,6 +16,7 @@ from tiatoolbox.wsicore.wsimeta import Resolution, Units from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIReader +from .engine_abc import EngineABC from .io_config import IOPatchPredictorConfig From 027a6e57e4ec849921f934a649f86209d03fb603 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 7 Jul 2023 14:47:44 +0100 Subject: [PATCH 041/112] :recycle: Rename `run_pipeline` to `run` - Rename `run_pipeline` to `run` --- tiatoolbox/models/engine/engine_abc.py | 2 +- tiatoolbox/models/engine/patch_predictor.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 05e3a81bd..c1c278c49 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -165,5 +165,5 @@ def post_process_wsi(self): raise NotImplementedError @abstractmethod - def run_pipeline(self): + def run(self): raise NotImplementedError diff --git a/tiatoolbox/models/engine/patch_predictor.py b/tiatoolbox/models/engine/patch_predictor.py index ea0905017..f81629b99 100644 --- a/tiatoolbox/models/engine/patch_predictor.py +++ b/tiatoolbox/models/engine/patch_predictor.py @@ -248,7 +248,7 @@ def infer_tile(self): def infer_wsi(self): pass - def run_pipeline(self): + def run(self): pass @staticmethod From de48fce886204372ddd1ee1b30167a96745f9dd9 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 7 Jul 2023 14:48:51 +0100 Subject: [PATCH 042/112] :bug: Remove `abstractmethod` from `property` - Remove `abstractmethod` from `property` --- tiatoolbox/models/engine/engine_abc.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index c1c278c49..f279d4582 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -123,7 +123,6 @@ def __init__( self.num_postproc_workers = num_postproc_workers self.verbose = verbose - @abstractmethod @property def _ioconfig(self): # runtime ioconfig return self.ioconfig From 2ed5c385e53c934d2da7ddad9c041f928097e95f Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 7 Jul 2023 15:48:28 +0100 Subject: [PATCH 043/112] :bug: Remove `_ioconfig` - Remove `_ioconfig` --- tiatoolbox/models/engine/engine_abc.py | 5 +---- tiatoolbox/models/engine/patch_predictor.py | 4 ---- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index f279d4582..4b7da423e 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -116,6 +116,7 @@ def __init__( model, ioconfig = get_pretrained_model(pretrained_model, pretrained_weights) self.ioconfig = ioconfig # for storing original + self._ioconfig = self.ioconfig # runtime ioconfig self.model = model # for runtime, such as after wrapping with nn.DataParallel self.pretrained_model = pretrained_model self.batch_size = batch_size @@ -123,10 +124,6 @@ def __init__( self.num_postproc_workers = num_postproc_workers self.verbose = verbose - @property - def _ioconfig(self): # runtime ioconfig - return self.ioconfig - @abstractmethod def pre_process_patch(self): raise NotImplementedError diff --git a/tiatoolbox/models/engine/patch_predictor.py b/tiatoolbox/models/engine/patch_predictor.py index f81629b99..dd3079674 100644 --- a/tiatoolbox/models/engine/patch_predictor.py +++ b/tiatoolbox/models/engine/patch_predictor.py @@ -217,10 +217,6 @@ def __init__( verbose=verbose, ) - @property - def _ioconfig(self): # runtime ioconfig - return self.ioconfig - def pre_process_patch(self): pass From 3c963cac82e9c1cd94d0378c8e51369a863cee6f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 14 Jul 2023 12:27:22 +0000 Subject: [PATCH 044/112] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tiatoolbox/models/engine/multi_task_segmentor.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 1b5a0348f..d8309d257 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -285,15 +285,15 @@ def _predict_one_wsi( """Make a prediction on tile/wsi. Args: - wsi_idx (int): + wsi_idx (int): Index of the tile/wsi to be processed within `self`. - ioconfig (IOInstanceSegmentorConfig): + ioconfig (IOInstanceSegmentorConfig): Object which defines I/O placement during inference and when assembling back to full tile/wsi. - save_path (str): + save_path (str): Location to save output prediction as well as possible intermediate results. - mode (str): + mode (str): `tile` or `wsi` to indicate run mode. """ From 8938ea91c6c032e14942995b766319bddc6cd8d1 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Wed, 2 Aug 2023 12:46:41 +0100 Subject: [PATCH 045/112] :rotating_light: Fix linter errors Signed-off-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> --- tiatoolbox/models/engine/engine_abc.py | 42 ++++++++++----- tiatoolbox/models/engine/patch_predictor.py | 58 ++++++++++++--------- 2 files changed, 62 insertions(+), 38 deletions(-) diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 6c743a670..56b90cce1 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -1,11 +1,14 @@ """Defines Abstract Base Class for TIAToolbox Model Engines.""" -from abc import ABC, abstractmethod -from typing import Optional +from __future__ import annotations -import torch.nn as nn +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING from tiatoolbox.models.architecture import get_pretrained_model +if TYPE_CHECKING: + from torch import nn + class EngineABC(ABC): """Abstract base class for engines used in tiatoolbox. @@ -36,7 +39,7 @@ class EngineABC(ABC): num_loader_workers (int): Number of workers to load the data using :class:`torch.utils.data.Dataset`. Please note that they will also perform preprocessing. default = 0 - num_postproc_workers (int): + num_post_proc_workers (int): Number of workers to postprocess the results of the model. default = 0 verbose (bool): Whether to output logging information. @@ -69,27 +72,28 @@ class EngineABC(ABC): >>> # list of 2 image patches as input >>> data = ["path/to/image1.svs", "path/to/image2.svs"] >>> engine = EngineABC(pretrained_model="resnet18-kather100k") - >>> output = engine.predict(data, mode='patch') + >>> output = engine.run(data, mode='patch') >>> # array of list of 2 image patches as input + >>> import numpy as np >>> data = np.array([img1, img2]) >>> engine = EngineABC(pretrained_model="resnet18-kather100k") - >>> output = engine.predict(data, mode='patch') + >>> output = engine.run(data, mode='patch') >>> # list of 2 image patch files as input >>> data = ['path/img.png', 'path/img.png'] >>> engine = EngineABC(pretrained_model="resnet18-kather100k") - >>> output = engine.predict(data, mode='patch') + >>> output = engine.run(data, mode='patch') >>> # list of 2 image tile files as input >>> tile_file = ['path/tile1.png', 'path/tile2.png'] >>> engine = EngineABC(pretraind_model="resnet18-kather100k") - >>> output = engine.predict(tile_file, mode='tile') + >>> output = engine.run(tile_file, mode='tile') >>> # list of 2 wsi files as input >>> wsi_file = ['path/wsi1.svs', 'path/wsi2.svs'] >>> engine = EngineABC(pretraind_model="resnet18-kather100k") - >>> output = engine.predict(wsi_file, mode='wsi') + >>> output = engine.run(wsi_file, mode='wsi') """ @@ -97,12 +101,12 @@ def __init__( self, batch_size: int = 8, num_loader_workers: int = 0, - num_postproc_workers: int = 0, + num_post_proc_workers: int = 0, model: nn.Module = None, - pretrained_model: Optional[str] = None, - pretrained_weights: Optional[str] = None, + pretrained_model: str | None = None, + pretrained_weights: str | None = None, verbose: bool = False, - ): + ) -> None: """Initialize Engine.""" super().__init__() @@ -125,45 +129,55 @@ def __init__( self.pretrained_model = pretrained_model self.batch_size = batch_size self.num_loader_workers = num_loader_workers - self.num_postproc_workers = num_postproc_workers + self.num_post_proc_workers = num_post_proc_workers self.verbose = verbose @abstractmethod def pre_process_patch(self): + """Pre-process an image patch.""" raise NotImplementedError @abstractmethod def pre_process_tile(self): + """Pre-process an image tile.""" raise NotImplementedError @abstractmethod def pre_process_wsi(self): + """Pre-process a WSI.""" raise NotImplementedError @abstractmethod def infer_patch(self): + """Model inference on an image patch.""" raise NotImplementedError @abstractmethod def infer_tile(self): + """Model inference on an image tile.""" raise NotImplementedError @abstractmethod def infer_wsi(self): + """Model inference on a WSI.""" raise NotImplementedError @abstractmethod def post_process_patch(self): + """Post-process an image patch.""" raise NotImplementedError @abstractmethod def post_process_tile(self): + """Post-process an image tile.""" raise NotImplementedError @abstractmethod def post_process_wsi(self): + """Post-process a WSI.""" raise NotImplementedError @abstractmethod def run(self): + """Run engine.""" raise NotImplementedError diff --git a/tiatoolbox/models/engine/patch_predictor.py b/tiatoolbox/models/engine/patch_predictor.py index 1696cac5a..75d8e9b0d 100644 --- a/tiatoolbox/models/engine/patch_predictor.py +++ b/tiatoolbox/models/engine/patch_predictor.py @@ -203,7 +203,7 @@ def __init__( self, batch_size: int = 8, num_loader_workers: int = 0, - num_postproc_workers: int = 0, + num_post_proc_workers: int = 0, model: torch.nn.Module = None, pretrained_model: str | None = None, pretrained_weights: str | None = None, @@ -213,7 +213,7 @@ def __init__( super().__init__( batch_size=batch_size, num_loader_workers=num_loader_workers, - num_postproc_workers=num_postproc_workers, + num_post_proc_workers=num_post_proc_workers, model=model, pretrained_model=pretrained_model, pretrained_weights=pretrained_weights, @@ -221,34 +221,44 @@ def __init__( ) def pre_process_patch(self): - pass + """Pre-process an image patch.""" + raise NotImplementedError def pre_process_tile(self): - pass + """Pre-process an image tile.""" + raise NotImplementedError def pre_process_wsi(self): - pass - - def post_process_patch(self): - pass - - def post_process_tile(self): - pass - - def post_process_wsi(self): - pass + """Pre-process a WSI.""" + raise NotImplementedError def infer_patch(self): - pass + """Model inference on an image patch.""" + raise NotImplementedError def infer_tile(self): - pass + """Model inference on an image tile.""" + raise NotImplementedError def infer_wsi(self): - pass + """Model inference on a WSI.""" + raise NotImplementedError + + def post_process_patch(self): + """Post-process an image patch.""" + raise NotImplementedError + + def post_process_tile(self): + """Post-process an image tile.""" + raise NotImplementedError + + def post_process_wsi(self): + """Post-process a WSI.""" + raise NotImplementedError def run(self): - pass + """Run engine.""" + raise NotImplementedError @staticmethod def merge_predictions( @@ -256,7 +266,7 @@ def merge_predictions( output: dict, resolution: Resolution | None = None, units: Units | None = None, - postproc_func: Callable | None = None, + post_proc_func: Callable | None = None, return_raw: bool = False, ): """Merge patch level predictions to form a 2-dimensional prediction map. @@ -277,7 +287,7 @@ def merge_predictions( units (Units): Units of resolution used when merging predictions. This must be the same `units` used when processing the data. - postproc_func (callable): + post_proc_func (callable): A function to post-process raw prediction from model. By default, internal code uses the `np.argmax` function. return_raw (bool): @@ -359,8 +369,8 @@ def merge_predictions( output = output / (np.expand_dims(denominator, -1) + 1.0e-8) if not return_raw: # convert raw probabilities to predictions - if postproc_func is not None: - output = postproc_func(output) + if post_proc_func is not None: + output = post_proc_func(output) else: output = np.argmax(output, axis=-1) # to make sure background is 0 while class will be 1...N @@ -639,7 +649,7 @@ def _predict_tile_wsi( file paths or a numpy array of an image list. When using `tile` or `wsi` mode, the input must be a list of file paths. - masks (list): + masks (list or None): List of masks. Only utilised when processing image tiles and whole-slide images. Patches are only processed if they are within a masked area. If not provided, then a @@ -740,7 +750,7 @@ def _predict_tile_wsi( output_model, resolution=output_model["resolution"], units=output_model["units"], - postproc_func=self.model.postproc, + post_proc_func=self.model.postproc, ) outputs.append(merged_prediction) From b4dceeafbc0d54d82d630ab328ecc9c8fdeff79c Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Thu, 10 Aug 2023 14:57:26 +0100 Subject: [PATCH 046/112] :recycle: Remove tests for engines and move them to tests/engine --- .../_test_multi_task_segmentor.py} | 0 .../_test_nucleus_instance_segmentor.py} | 0 .../test_patch_predictor.py => engines/_test_patch_predictor.py} | 0 .../_test_semantic_segmentation.py} | 0 4 files changed, 0 insertions(+), 0 deletions(-) rename tests/{models/test_multi_task_segmentor.py => engines/_test_multi_task_segmentor.py} (100%) rename tests/{models/test_nucleus_instance_segmentor.py => engines/_test_nucleus_instance_segmentor.py} (100%) rename tests/{models/test_patch_predictor.py => engines/_test_patch_predictor.py} (100%) rename tests/{models/test_semantic_segmentation.py => engines/_test_semantic_segmentation.py} (100%) diff --git a/tests/models/test_multi_task_segmentor.py b/tests/engines/_test_multi_task_segmentor.py similarity index 100% rename from tests/models/test_multi_task_segmentor.py rename to tests/engines/_test_multi_task_segmentor.py diff --git a/tests/models/test_nucleus_instance_segmentor.py b/tests/engines/_test_nucleus_instance_segmentor.py similarity index 100% rename from tests/models/test_nucleus_instance_segmentor.py rename to tests/engines/_test_nucleus_instance_segmentor.py diff --git a/tests/models/test_patch_predictor.py b/tests/engines/_test_patch_predictor.py similarity index 100% rename from tests/models/test_patch_predictor.py rename to tests/engines/_test_patch_predictor.py diff --git a/tests/models/test_semantic_segmentation.py b/tests/engines/_test_semantic_segmentation.py similarity index 100% rename from tests/models/test_semantic_segmentation.py rename to tests/engines/_test_semantic_segmentation.py From 08039ded9a7e1f78db5c3ec1375248387c9d7391 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Thu, 10 Aug 2023 15:00:41 +0100 Subject: [PATCH 047/112] :bug: Fix linter errors --- tiatoolbox/models/dataset/dataset_abc.py | 3 ++- tiatoolbox/models/engine/engine_abc.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/tiatoolbox/models/dataset/dataset_abc.py b/tiatoolbox/models/dataset/dataset_abc.py index c74eefbbe..3fba84d8d 100644 --- a/tiatoolbox/models/dataset/dataset_abc.py +++ b/tiatoolbox/models/dataset/dataset_abc.py @@ -349,9 +349,10 @@ def __init__( stride_shape=None, resolution=None, units=None, - auto_get_mask=True, min_mask_ratio=0, preproc_func=None, + *, + auto_get_mask=True, ) -> None: """Create a WSI-level patch dataset. diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 56b90cce1..0f6c7e896 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -105,6 +105,7 @@ def __init__( model: nn.Module = None, pretrained_model: str | None = None, pretrained_weights: str | None = None, + *, verbose: bool = False, ) -> None: """Initialize Engine.""" From 3e8bcb4d951fbab15a2c42293f930a785047d820 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Thu, 10 Aug 2023 15:40:50 +0100 Subject: [PATCH 048/112] :sparkles: Add basic structure for `run` --- tests/engines/test_engine_abc.py | 11 ++ tiatoolbox/models/engine/__init__.py | 4 +- tiatoolbox/models/engine/engine_abc.py | 184 ++++++++++++++++++-- tiatoolbox/models/engine/patch_predictor.py | 45 +---- 4 files changed, 183 insertions(+), 61 deletions(-) create mode 100644 tests/engines/test_engine_abc.py diff --git a/tests/engines/test_engine_abc.py b/tests/engines/test_engine_abc.py new file mode 100644 index 000000000..92d0afd65 --- /dev/null +++ b/tests/engines/test_engine_abc.py @@ -0,0 +1,11 @@ +"""Test tiatoolbox.models.engine.engine_abc.""" +import pytest + +from tiatoolbox.models.engine.engine_abc import EngineABC + + +def test_engine_abc(): + """Test EngineABC initialization.""" + with pytest.raises(TypeError): + # Can't instantiate abstract class with abstract methods + EngineABC() # skipcq diff --git a/tiatoolbox/models/engine/__init__.py b/tiatoolbox/models/engine/__init__.py index 0a5968b44..7d0dfe0e1 100644 --- a/tiatoolbox/models/engine/__init__.py +++ b/tiatoolbox/models/engine/__init__.py @@ -1,11 +1,13 @@ """Engines to run models implemented in tiatoolbox.""" -from tiatoolbox.models.engine import ( +from . import ( + engine_abc, nucleus_instance_segmentor, patch_predictor, semantic_segmentor, ) __all__ = [ + "engine_abc", "nucleus_instance_segmentor", "patch_predictor", "semantic_segmentor", diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 0f6c7e896..8b756f3bf 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -2,13 +2,20 @@ from __future__ import annotations from abc import ABC, abstractmethod +from pathlib import Path from typing import TYPE_CHECKING +from tiatoolbox import logger from tiatoolbox.models.architecture import get_pretrained_model if TYPE_CHECKING: + import os + + import numpy as np from torch import nn + from .io_config import IOPatchPredictorConfig + class EngineABC(ABC): """Abstract base class for engines used in tiatoolbox. @@ -60,7 +67,7 @@ class EngineABC(ABC): By default, the corresponding pretrained weights will also be downloaded. However, you can override with your own set of weights via the `pretrained_weights` argument. Argument - is case insensitive. + is case-insensitive. batch_size (int): Number of images fed into the model each time. num_loader_workers (int): @@ -138,11 +145,6 @@ def pre_process_patch(self): """Pre-process an image patch.""" raise NotImplementedError - @abstractmethod - def pre_process_tile(self): - """Pre-process an image tile.""" - raise NotImplementedError - @abstractmethod def pre_process_wsi(self): """Pre-process a WSI.""" @@ -153,11 +155,6 @@ def infer_patch(self): """Model inference on an image patch.""" raise NotImplementedError - @abstractmethod - def infer_tile(self): - """Model inference on an image tile.""" - raise NotImplementedError - @abstractmethod def infer_wsi(self): """Model inference on a WSI.""" @@ -168,17 +165,166 @@ def post_process_patch(self): """Post-process an image patch.""" raise NotImplementedError - @abstractmethod - def post_process_tile(self): - """Post-process an image tile.""" - raise NotImplementedError - @abstractmethod def post_process_wsi(self): """Post-process a WSI.""" raise NotImplementedError + @staticmethod + def _prepare_save_dir(save_dir: os | Path, images: list | np.ndarray) -> Path: + """Create directory if not defined and number of images is more than 1. + + Args: + save_dir (str or Path): + Path to output directory. + images (list, ndarray): + List of inputs to process. + + Returns: + :class:`Path`: + Path to output directory. + + """ + if save_dir is None and len(images) > 1: + logger.warning( + "More than 1 WSIs detected but there is no save directory set." + "All subsequent output will be saved to current runtime" + "location under folder 'output'. Overwriting may happen!", + stacklevel=2, + ) + save_dir = Path.cwd() / "output" + elif save_dir is not None and len(images) > 1: + logger.warning( + "When providing multiple whole-slide images / tiles, " + "the outputs will be saved and the locations of outputs" + "will be returned" + "to the calling function.", + stacklevel=2, + ) + + if save_dir is not None: + save_dir = Path(save_dir) + save_dir.mkdir(parents=True, exist_ok=False) + return save_dir + + return Path.cwd() / "output" + @abstractmethod - def run(self): - """Run engine.""" - raise NotImplementedError + def run( + self, + images: list[os | Path] | np.ndarray, + masks: list[os | Path] | np.ndarray | None = None, + labels: list | None = None, + mode: str = "patch", + ioconfig: IOPatchPredictorConfig | None = None, + patch_input_shape: tuple[int, int] | None = None, + stride_shape: tuple[int, int] | None = None, + resolution=None, + units=None, + *, + return_probabilities=False, + return_labels=False, + on_gpu=True, + merge_predictions=False, + save_dir=None, + save_output=False, + **kwargs: dict, + ) -> np.ndarray | dict: + """Run the engine on input images. + + Args: + images (list, ndarray): + List of inputs to process. when using `patch` mode, the + input must be either a list of images, a list of image + file paths or a numpy array of an image list. When using + `tile` or `wsi` mode, the input must be a list of file + paths. + masks (list): + List of masks. Only utilised when processing image tiles + and whole-slide images. Patches are only processed if + they are within a masked area. If not provided, then a + tissue mask will be automatically generated for + whole-slide images or the entire image is processed for + image tiles. + labels: + List of labels. If using `tile` or `wsi` mode, then only + a single label per image tile or whole-slide image is + supported. + mode (str): + Type of input to process. Choose from either `patch`, + `tile` or `wsi`. + return_probabilities (bool): + Whether to return per-class probabilities. + return_labels (bool): + Whether to return the labels with the predictions. + on_gpu (bool): + Whether to run model on the GPU. + ioconfig (IOPatchPredictorConfig): + Patch Predictor IO configuration. + patch_input_shape (tuple): + Size of patches input to the model. Patches are at + requested read resolution, not with respect to level 0, + and must be positive. + stride_shape (tuple): + Stride using during tile and WSI processing. Stride is + at requested read resolution, not with respect to + level 0, and must be positive. If not provided, + `stride_shape=patch_input_shape`. + resolution (Resolution): + Resolution used for reading the image. Please see + :obj:`WSIReader` for details. + units (Units): + Units of resolution used for reading the image. Choose + from either `level`, `power` or `mpp`. Please see + :obj:`WSIReader` for details. + merge_predictions (bool): + Whether to merge the predictions to form a 2-dimensional + map. This is only applicable for `mode='wsi'` or + `mode='tile'`. + save_dir (str or pathlib.Path): + Output directory when processing multiple tiles and + whole-slide images. By default, it is folder `output` + where the running script is invoked. + save_output (bool): + Whether to save output for a single file. default=False + **kwargs (dict): + Keyword Args for ... + + Returns: + (:class:`numpy.ndarray`, dict): + Model predictions of the input dataset. If multiple + image tiles or whole-slide images are provided as input, + or save_output is True, then results are saved to + `save_dir` and a dictionary indicating save location for + each input is returned. + + The dict has the following format: + + - img_path: path of the input image. + - raw: path to save location for raw prediction, + saved in .json. + - merged: path to .npy contain merged + predictions if `merge_predictions` is `True`. + + Examples: + >>> wsis = ['wsi1.svs', 'wsi2.svs'] + >>> predictor = EngineABC( + ... pretrained_model="resnet18-kather100k") + >>> output = predictor.run(wsis, mode="wsi") + >>> output.keys() + ... ['wsi1.svs', 'wsi2.svs'] + >>> output['wsi1.svs'] + ... {'raw': '0.raw.json', 'merged': '0.merged.npy'} + >>> output['wsi2.svs'] + ... {'raw': '1.raw.json', 'merged': '1.merged.npy'} + + """ + if mode not in ["patch", "wsi"]: + msg = f"{mode} is not a valid mode. Use either `patch` or `wsi`." + raise ValueError( + msg, + ) + + save_dir = self._prepare_save_dir(save_dir, images) + + return {"save_dir": save_dir} diff --git a/tiatoolbox/models/engine/patch_predictor.py b/tiatoolbox/models/engine/patch_predictor.py index ede723af4..aa23e6b3d 100644 --- a/tiatoolbox/models/engine/patch_predictor.py +++ b/tiatoolbox/models/engine/patch_predictor.py @@ -257,10 +257,6 @@ def post_process_wsi(self): """Post-process a WSI.""" raise NotImplementedError - def run(self): - """Run engine.""" - raise NotImplementedError - @staticmethod def merge_predictions( img: str | Path | np.ndarray, @@ -545,43 +541,6 @@ def _update_ioconfig( output_resolutions=[], ) - @staticmethod - def _prepare_save_dir(save_dir, imgs): - """Create directory if not defined and number of images is more than 1. - - Args: - save_dir (str or pathlib.Path): - Path to output directory. - imgs (list, ndarray): - List of inputs to process. - - Returns: - :class:`pathlib.Path`: - Path to output directory. - - """ - if save_dir is None and len(imgs) > 1: - logger.warning( - "More than 1 WSIs detected but there is no save directory set." - "All subsequent output will be saved to current runtime" - "location under folder 'output'. Overwriting may happen!", - stacklevel=2, - ) - save_dir = Path.cwd() / "output" - elif save_dir is not None and len(imgs) > 1: - logger.warning( - "When providing multiple whole-slide images / tiles, " - "we save the outputs and return the locations " - "to the corresponding files.", - stacklevel=2, - ) - - if save_dir is not None: - save_dir = Path(save_dir) - save_dir.mkdir(parents=True, exist_ok=False) - - return save_dir - def _predict_patch(self, imgs, labels, return_probabilities, return_labels, on_gpu): """Process patch mode. @@ -774,6 +733,10 @@ def _predict_tile_wsi( return file_dict if save_output else outputs + def run(self): + """Run engine.""" + super().run() + def predict( self, imgs, From ee076cb5182d08043cc1100469358111ea5ad45b Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 11 Aug 2023 11:10:22 +0100 Subject: [PATCH 049/112] :memo: Add suggestions --- tests/engines/_test_patch_predictor.py | 2 +- tiatoolbox/cli/patch_predictor.py | 2 +- tiatoolbox/models/engine/engine_abc.py | 60 ++++++++++----------- tiatoolbox/models/engine/patch_predictor.py | 12 ++--- 4 files changed, 36 insertions(+), 40 deletions(-) diff --git a/tests/engines/_test_patch_predictor.py b/tests/engines/_test_patch_predictor.py index 2425bc0f0..f88fb9fe0 100644 --- a/tests/engines/_test_patch_predictor.py +++ b/tests/engines/_test_patch_predictor.py @@ -674,7 +674,7 @@ def test_patch_predictor_api(sample_patch1, sample_patch2, tmp_path: Path) -> No _ = PatchPredictor( pretrained_model="resnet18-kather100k", - pretrained_weights=pretrained_weights, + weights=pretrained_weights, batch_size=1, ) diff --git a/tiatoolbox/cli/patch_predictor.py b/tiatoolbox/cli/patch_predictor.py index 2c754d1c6..8c6128e8c 100644 --- a/tiatoolbox/cli/patch_predictor.py +++ b/tiatoolbox/cli/patch_predictor.py @@ -83,7 +83,7 @@ def patch_predictor( predictor = PatchPredictor( pretrained_model=pretrained_model, - pretrained_weights=pretrained_weights, + weights=pretrained_weights, batch_size=batch_size, num_loader_workers=num_loader_workers, verbose=verbose, diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 8b756f3bf..4cd4751f2 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -14,6 +14,8 @@ import numpy as np from torch import nn + from tiatoolbox.annotation.storage import Annotation + from .io_config import IOPatchPredictorConfig @@ -25,7 +27,6 @@ class EngineABC(ABC): Use externally defined PyTorch model for prediction with. weights already loaded. Default is `None`. If provided, `pretrained_model` argument is ignored. - pretrained_model (str): Name of the existing models support by tiatoolbox for processing the data. For a full list of pretrained models, refer to the `docs @@ -34,12 +35,12 @@ class EngineABC(ABC): be downloaded. However, you can override with your own set of weights via the `pretrained_weights` argument. Argument is case-insensitive. - pretrained_weights (str): + weights (str): Path to the weight of the corresponding `pretrained_model`. >>> engine = EngineABC( ... pretrained_model="pretrained-model-name", - ... pretrained_weights="pretrained-local-weights.pth") + ... weights="pretrained-local-weights.pth") batch_size (int): Number of images fed into the model each time. @@ -57,9 +58,8 @@ class EngineABC(ABC): mode (str): Type of input to process. Choose from either `patch`, `tile` or `wsi`. - model (nn.Module): + model (str | nn.Module): Defined PyTorch model. - pretrained_model (str): Name of the existing models support by tiatoolbox for processing the data. For a full list of pretrained models, refer to the `docs @@ -106,12 +106,11 @@ class EngineABC(ABC): def __init__( self, + model: nn.Module, batch_size: int = 8, num_loader_workers: int = 0, num_post_proc_workers: int = 0, - model: nn.Module = None, - pretrained_model: str | None = None, - pretrained_weights: str | None = None, + weights: str | None = None, *, verbose: bool = False, ) -> None: @@ -121,20 +120,15 @@ def __init__( self.images = None self.mode = None - if model is None and pretrained_model is None: - msg = "Must provide either `model` or `pretrained_model`." - raise ValueError(msg) - if model is not None: self.model = model ioconfig = None # retrieve ioconfig from provided model. else: - model, ioconfig = get_pretrained_model(pretrained_model, pretrained_weights) + model, ioconfig = get_pretrained_model(model, weights) self.ioconfig = ioconfig # for storing original self._ioconfig = self.ioconfig # runtime ioconfig self.model = model # for runtime, such as after wrapping with nn.DataParallel - self.pretrained_model = pretrained_model self.batch_size = batch_size self.num_loader_workers = num_loader_workers self.num_post_proc_workers = num_post_proc_workers @@ -214,20 +208,22 @@ def run( self, images: list[os | Path] | np.ndarray, masks: list[os | Path] | np.ndarray | None = None, - labels: list | None = None, - mode: str = "patch", + labels: list | None = None, # kwargs ioconfig: IOPatchPredictorConfig | None = None, - patch_input_shape: tuple[int, int] | None = None, - stride_shape: tuple[int, int] | None = None, - resolution=None, - units=None, + patch_input_shape: tuple[int, int] | None = None, # kwargs (ioconfig) + stride_shape: tuple[int, int] | None = None, # kwargs (ioconfig) + resolution=None, # kwargs (pass to wsireader) + units=None, # kwargs (pass to wsireader) *, - return_probabilities=False, - return_labels=False, + patch_mode: bool = False, + return_probabilities=False, # kwargs + return_labels=False, # kwargs on_gpu=True, - merge_predictions=False, + merge_predictions=False, # kwargs save_dir=None, - save_output=False, + # None will not save output + # save_output can be np.ndarray + save_output: np.ndarray | Annotation | str | None = True, **kwargs: dict, ) -> np.ndarray | dict: """Run the engine on input images. @@ -250,9 +246,9 @@ def run( List of labels. If using `tile` or `wsi` mode, then only a single label per image tile or whole-slide image is supported. - mode (str): - Type of input to process. Choose from either `patch`, - `tile` or `wsi`. + patch_mode (bool): + Whether to treat input image as a patch or WSI. + default = False. return_probabilities (bool): Whether to return per-class probabilities. return_labels (bool): @@ -319,11 +315,11 @@ def run( ... {'raw': '1.raw.json', 'merged': '1.merged.npy'} """ - if mode not in ["patch", "wsi"]: - msg = f"{mode} is not a valid mode. Use either `patch` or `wsi`." - raise ValueError( - msg, - ) + self._update(kwargs) # prefer kwargs as attribute and update if required. + + # if mode not in ["patch", "wsi"]: + # raise ValueError( + # msg, save_dir = self._prepare_save_dir(save_dir, images) diff --git a/tiatoolbox/models/engine/patch_predictor.py b/tiatoolbox/models/engine/patch_predictor.py index aa23e6b3d..ae62554bf 100644 --- a/tiatoolbox/models/engine/patch_predictor.py +++ b/tiatoolbox/models/engine/patch_predictor.py @@ -123,12 +123,12 @@ class PatchPredictor(EngineABC): be downloaded. However, you can override with your own set of weights via the `pretrained_weights` argument. Argument is case-insensitive. - pretrained_weights (str): + weights (str): Path to the weight of the corresponding `pretrained_model`. >>> predictor = PatchPredictor( ... pretrained_model="resnet18-kather100k", - ... pretrained_weights="resnet18_local_weight") + ... weights="resnet18_local_weight") batch_size (int): Number of images fed into the model each time. @@ -146,7 +146,7 @@ class PatchPredictor(EngineABC): or `wsi`. model (nn.Module): Defined PyTorch model. - pretrained_model (str): + model (str): Name of the existing models support by tiatoolbox for processing the data. For a full list of pretrained models, refer to the `docs @@ -206,7 +206,7 @@ def __init__( num_post_proc_workers: int = 0, model: torch.nn.Module = None, pretrained_model: str | None = None, - pretrained_weights: str | None = None, + weights: str | None = None, *, verbose=True, ) -> None: @@ -217,7 +217,7 @@ def __init__( num_post_proc_workers=num_post_proc_workers, model=model, pretrained_model=pretrained_model, - pretrained_weights=pretrained_weights, + weights=weights, verbose=verbose, ) @@ -700,7 +700,7 @@ def _predict_tile_wsi( ) output_model["label"] = img_label # add extra information useful for downstream analysis - output_model["pretrained_model"] = self.pretrained_model + output_model["pretrained_model"] = self.model output_model["resolution"] = highest_input_resolution["resolution"] output_model["units"] = highest_input_resolution["units"] From e8b0f9050763eaf4d89e8afb89885660d336dc8a Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Mon, 14 Aug 2023 15:52:42 +0100 Subject: [PATCH 050/112] :art: Remove extra input arguments for `run`. Use some variables as attributes and update them with `setattr`. --- tiatoolbox/models/engine/engine_abc.py | 53 ++++++-------------------- 1 file changed, 11 insertions(+), 42 deletions(-) diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 4cd4751f2..a94326b34 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -133,6 +133,14 @@ def __init__( self.num_loader_workers = num_loader_workers self.num_post_proc_workers = num_post_proc_workers self.verbose = verbose + self.return_probabilities = False + self.return_labels = False + self.merge_predictions = False + self.units = "baseline" + self.resolution = 1.0 + self.patch_input_shape = None + self.stride_shape = None + self.labels = None @abstractmethod def pre_process_patch(self): @@ -208,21 +216,13 @@ def run( self, images: list[os | Path] | np.ndarray, masks: list[os | Path] | np.ndarray | None = None, - labels: list | None = None, # kwargs ioconfig: IOPatchPredictorConfig | None = None, - patch_input_shape: tuple[int, int] | None = None, # kwargs (ioconfig) - stride_shape: tuple[int, int] | None = None, # kwargs (ioconfig) - resolution=None, # kwargs (pass to wsireader) - units=None, # kwargs (pass to wsireader) *, patch_mode: bool = False, - return_probabilities=False, # kwargs - return_labels=False, # kwargs on_gpu=True, - merge_predictions=False, # kwargs save_dir=None, # None will not save output - # save_output can be np.ndarray + # save_output can be np.ndarray, Annotation or Json str save_output: np.ndarray | Annotation | str | None = True, **kwargs: dict, ) -> np.ndarray | dict: @@ -242,41 +242,13 @@ def run( tissue mask will be automatically generated for whole-slide images or the entire image is processed for image tiles. - labels: - List of labels. If using `tile` or `wsi` mode, then only - a single label per image tile or whole-slide image is - supported. patch_mode (bool): Whether to treat input image as a patch or WSI. default = False. - return_probabilities (bool): - Whether to return per-class probabilities. - return_labels (bool): - Whether to return the labels with the predictions. on_gpu (bool): Whether to run model on the GPU. ioconfig (IOPatchPredictorConfig): Patch Predictor IO configuration. - patch_input_shape (tuple): - Size of patches input to the model. Patches are at - requested read resolution, not with respect to level 0, - and must be positive. - stride_shape (tuple): - Stride using during tile and WSI processing. Stride is - at requested read resolution, not with respect to - level 0, and must be positive. If not provided, - `stride_shape=patch_input_shape`. - resolution (Resolution): - Resolution used for reading the image. Please see - :obj:`WSIReader` for details. - units (Units): - Units of resolution used for reading the image. Choose - from either `level`, `power` or `mpp`. Please see - :obj:`WSIReader` for details. - merge_predictions (bool): - Whether to merge the predictions to form a 2-dimensional - map. This is only applicable for `mode='wsi'` or - `mode='tile'`. save_dir (str or pathlib.Path): Output directory when processing multiple tiles and whole-slide images. By default, it is folder `output` @@ -315,11 +287,8 @@ def run( ... {'raw': '1.raw.json', 'merged': '1.merged.npy'} """ - self._update(kwargs) # prefer kwargs as attribute and update if required. - - # if mode not in ["patch", "wsi"]: - # raise ValueError( - # msg, + for key in kwargs: + setattr(self, key, kwargs[key]) save_dir = self._prepare_save_dir(save_dir, images) From 7bb62c1b8370790412edcc9360dfa7304e87b676 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Mon, 14 Aug 2023 16:25:28 +0100 Subject: [PATCH 051/112] :memo: Update documentation. --- tiatoolbox/models/engine/engine_abc.py | 37 +++++++++++++++++++++++--- 1 file changed, 34 insertions(+), 3 deletions(-) diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index a94326b34..9ffdb27b5 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -16,7 +16,7 @@ from tiatoolbox.annotation.storage import Annotation - from .io_config import IOPatchPredictorConfig + from .io_config import ModelIOConfigABC class EngineABC(ABC): @@ -68,8 +68,39 @@ class EngineABC(ABC): be downloaded. However, you can override with your own set of weights via the `pretrained_weights` argument. Argument is case-insensitive. + ioconfig (ModelIOConfigABC): + Input IO configuration to run the Engine. + _ioconfig (): + Runtime ioconfig. + return_probabilities (bool): + Whether to return per-class probabilities. + return_labels (bool): + Whether to return the labels with the predictions. + merge_predictions (bool): + Whether to merge the predictions to form a 2-dimensional + map. This is only applicable `patch_mode` is False in inference. + resolution (Resolution): + Resolution used for reading the image. Please see + :obj:`WSIReader` for details. + units (Units): + Units of resolution used for reading the image. Choose + from either `level`, `power` or `mpp`. Please see + :obj:`WSIReader` for details. + patch_input_shape (tuple): + Size of patches input to the model. Patches are at + requested read resolution, not with respect to level 0, + and must be positive. + stride_shape (tuple): + Stride using during tile and WSI processing. Stride is + at requested read resolution, not with respect to + level 0, and must be positive. If not provided, + `stride_shape=patch_input_shape`. batch_size (int): Number of images fed into the model each time. + labels: + List of labels. If using `tile` or `wsi` mode, then only + a single label per image tile or whole-slide image is + supported. num_loader_workers (int): Number of workers used in torch.utils.data.DataLoader. verbose (bool): @@ -216,7 +247,7 @@ def run( self, images: list[os | Path] | np.ndarray, masks: list[os | Path] | np.ndarray | None = None, - ioconfig: IOPatchPredictorConfig | None = None, + ioconfig: ModelIOConfigABC | None = None, *, patch_mode: bool = False, on_gpu=True, @@ -248,7 +279,7 @@ def run( on_gpu (bool): Whether to run model on the GPU. ioconfig (IOPatchPredictorConfig): - Patch Predictor IO configuration. + IO configuration. save_dir (str or pathlib.Path): Output directory when processing multiple tiles and whole-slide images. By default, it is folder `output` From 8866110343c999083d02825106e182e8702849f6 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Mon, 14 Aug 2023 16:27:52 +0100 Subject: [PATCH 052/112] :bug: Remove _rm_dir --- tests/engines/_test_semantic_segmentation.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/engines/_test_semantic_segmentation.py b/tests/engines/_test_semantic_segmentation.py index f923cb65b..61a51d1e2 100644 --- a/tests/engines/_test_semantic_segmentation.py +++ b/tests/engines/_test_semantic_segmentation.py @@ -289,8 +289,6 @@ def test_crash_segmentor(remote_sample: Callable) -> None: units="baseline", ) - _rm_dir("output") - with pytest.raises(ValueError, match=r"Invalid resolution.*"): semantic_segmentor.predict( [mini_wsi_svs], From c6444e469247f376e641f4756bc085471322f0ec Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Thu, 24 Aug 2023 14:13:04 +0100 Subject: [PATCH 053/112] :rewind: Remove flake8-annotations check. --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 0bead17ab..91551031f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,7 +95,7 @@ select = [ "C90", # mccabe "T10", # flake8-debugger "T20", # flake8-print - "ANN", # flake8-annotations + # "ANN", # flake8-annotations "ARG", # flake8-unused-arguments "BLE", # flake8-blind-except "COM", # flake8-commas From 8faa23fa779975d14417da7f181ef2b17062f647 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Thu, 24 Aug 2023 17:03:48 +0100 Subject: [PATCH 054/112] :white_check_mark: Add default model parameter checks. --- tests/engines/test_engine_abc.py | 60 +++++++++++++++++++++- tiatoolbox/models/architecture/__init__.py | 34 ++++++------ tiatoolbox/models/engine/engine_abc.py | 39 +++++++------- 3 files changed, 97 insertions(+), 36 deletions(-) diff --git a/tests/engines/test_engine_abc.py b/tests/engines/test_engine_abc.py index 92d0afd65..7950308a6 100644 --- a/tests/engines/test_engine_abc.py +++ b/tests/engines/test_engine_abc.py @@ -1,11 +1,69 @@ """Test tiatoolbox.models.engine.engine_abc.""" +from __future__ import annotations + +from typing import TYPE_CHECKING, NoReturn + import pytest from tiatoolbox.models.engine.engine_abc import EngineABC +if TYPE_CHECKING: + import torch.nn + + +class TestEngineABC(EngineABC): + """Test EngineABC.""" + + def __init__(self: TestEngineABC, model: str | torch.nn.Module) -> NoReturn: + """Test EngineABC init.""" + super().__init__(model=model) + + def infer_patch(self: EngineABC) -> NoReturn: + """Test infer_patch.""" + ... + + def infer_wsi(self: EngineABC) -> NoReturn: + """Test infer_wsi.""" + ... + + def post_process_patch(self: EngineABC) -> NoReturn: + """Test post_process_patch.""" + ... + + def post_process_wsi(self: EngineABC) -> NoReturn: + """Test post_process_wsi.""" + ... + + def pre_process_patch(self: EngineABC) -> NoReturn: + """Test pre_process_patch.""" + ... + + def pre_process_wsi(self: EngineABC) -> NoReturn: + """Test pre_process_wsi.""" + ... + def test_engine_abc(): """Test EngineABC initialization.""" - with pytest.raises(TypeError): + with pytest.raises( + TypeError, + match=r".*Can't instantiate abstract class EngineABC with abstract methods*", + ): # Can't instantiate abstract class with abstract methods EngineABC() # skipcq + + +def test_engine_abc_incorrect_model_type(): + """Test EngineABC initialization with incorrect model type.""" + with pytest.raises( + TypeError, + match=r".*missing 1 required positional argument: 'model'", + ): + TestEngineABC() # skipcq + + with pytest.raises( + TypeError, + match="Input model must be a string or 'torch.nn.Module'.", + ): + # Can't instantiate abstract class with abstract methods + TestEngineABC(model=1) # skipcq diff --git a/tiatoolbox/models/architecture/__init__.py b/tiatoolbox/models/architecture/__init__.py index d37ad5c80..b60f22080 100644 --- a/tiatoolbox/models/architecture/__init__.py +++ b/tiatoolbox/models/architecture/__init__.py @@ -59,15 +59,15 @@ def fetch_pretrained_weights( def get_pretrained_model( - pretrained_model: str | None = None, - pretrained_weights: str | Path | None = None, + model: str | None = None, + weights: str | Path | None = None, *, overwrite: bool = False, ) -> tuple[torch.nn.Module, IOConfigABC]: """Load a predefined PyTorch model with the appropriate pretrained weights. Args: - pretrained_model (str): + model (str): Name of the existing models support by tiatoolbox for processing the data. The models currently supported: @@ -99,7 +99,7 @@ def get_pretrained_model( downloaded. However, you can override with your own set of weights via the `pretrained_weights` argument. Argument is case-insensitive. - pretrained_weights (str): + weights (str): Path to the weight of the corresponding `pretrained_model`. overwrite (bool): @@ -107,25 +107,25 @@ def get_pretrained_model( Examples: >>> # get mobilenet pretrained on Kather100K dataset by the TIA team - >>> model = get_pretrained_model(pretrained_model='mobilenet_v2-kather100k') + >>> model = get_pretrained_model(model='mobilenet_v2-kather100k') >>> # get mobilenet defined by TIA team, but loaded with user defined weights >>> model = get_pretrained_model( - ... pretrained_model='mobilenet_v2-kather100k', - ... pretrained_weights='/A/B/C/my_weights.tar', + ... model='mobilenet_v2-kather100k', + ... weights='/A/B/C/my_weights.tar', ... ) >>> # get resnet34 pretrained on PCam dataset by TIA team - >>> model = get_pretrained_model(pretrained_model='resnet34-pcam') + >>> model = get_pretrained_model(model='resnet34-pcam') """ - if not isinstance(pretrained_model, str): - msg = "pretrained_model must be a string." + if not isinstance(model, str): + msg = "Input model must be a string." raise TypeError(msg) - if pretrained_model not in PRETRAINED_INFO: - msg = f"Pretrained model `{pretrained_model}` does not exist." + if model not in PRETRAINED_INFO: + msg = f"Pretrained model `{model}` does not exist." raise ValueError(msg) - info = PRETRAINED_INFO[pretrained_model] + info = PRETRAINED_INFO[model] arch_info = info["architecture"] creator = locate(f"tiatoolbox.models.architecture.{arch_info['class']}") @@ -137,15 +137,15 @@ def get_pretrained_model( # ! associated pre-processing coming from dataset (Kumar, Kather, etc.) model.preproc_func = predefined_preproc_func(info["dataset"]) - if pretrained_weights is None: - pretrained_weights = fetch_pretrained_weights( - pretrained_model, + if weights is None: + weights = fetch_pretrained_weights( + model, overwrite=overwrite, ) # ! assume to be saved in single GPU mode # always load on to the CPU - saved_state_dict = torch.load(pretrained_weights, map_location="cpu") + saved_state_dict = torch.load(weights, map_location="cpu") model.load_state_dict(saved_state_dict, strict=True) # ! diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 5a24b7e54..478fc3468 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -5,6 +5,8 @@ from pathlib import Path from typing import TYPE_CHECKING, NoReturn +from torch import nn + from tiatoolbox import logger from tiatoolbox.models.architecture import get_pretrained_model @@ -12,11 +14,6 @@ import os import numpy as np - from torch import nn - - from tiatoolbox.annotation.storage import Annotation - - from .io_config import ModelIOConfigABC class EngineABC(ABC): @@ -137,7 +134,7 @@ class EngineABC(ABC): def __init__( self: EngineABC, - model: nn.Module, + model: str | nn.Module, batch_size: int = 8, num_loader_workers: int = 0, num_post_proc_workers: int = 0, @@ -151,15 +148,22 @@ def __init__( self.images = None self.mode = None - if model is not None: - self.model = model - ioconfig = None # retrieve ioconfig from provided model. - else: - model, ioconfig = get_pretrained_model(model, weights) + ioconfig = None # retrieve ioconfig from provided model. + + if not isinstance(model, (str, nn.Module)): + msg = "Input model must be a string or 'torch.nn.Module'." + raise TypeError(msg) + + if isinstance(model, str): + self.model, ioconfig = get_pretrained_model(model, weights) + + if isinstance(model, nn.Module): + self.model = ( + model # for runtime, such as after wrapping with nn.DataParallel + ) self.ioconfig = ioconfig # for storing original self._ioconfig = self.ioconfig # runtime ioconfig - self.model = model # for runtime, such as after wrapping with nn.DataParallel self.batch_size = batch_size self.num_loader_workers = num_loader_workers self.num_post_proc_workers = num_post_proc_workers @@ -245,19 +249,18 @@ def _prepare_save_dir( return Path.cwd() / "output" - @abstractmethod def run( self: EngineABC, images: list[os | Path] | np.ndarray, - masks: list[os | Path] | np.ndarray | None = None, - ioconfig: ModelIOConfigABC | None = None, + # masks: list[os | Path] | np.ndarray | None = None, # noqa: ERA001 + # ioconfig: ModelIOConfigABC | None = None, # noqa: ERA001 *, - patch_mode: bool = False, - on_gpu: bool = True, + # patch_mode: bool = False, # noqa: ERA001 + # on_gpu: bool = True, # noqa: ERA001 save_dir: os | Path | None = None, # None will not save output # save_output can be np.ndarray, Annotation or Json str - save_output: np.ndarray | Annotation | str | None = True, + # save_output: np.ndarray | Annotation | str | None = True, # noqa: ERA001 **kwargs: dict, ) -> np.ndarray | dict: """Run the engine on input images. From 7e1ed8a9891aacf48fd847ae5ba5c9a9c5e93357 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 25 Aug 2023 09:20:45 +0100 Subject: [PATCH 055/112] :recycle: Remove test_feature_extractor.py --- .../_test_feature_extractor.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/{models/test_feature_extractor.py => engines/_test_feature_extractor.py} (100%) diff --git a/tests/models/test_feature_extractor.py b/tests/engines/_test_feature_extractor.py similarity index 100% rename from tests/models/test_feature_extractor.py rename to tests/engines/_test_feature_extractor.py From d44bddd527bb93699db5d3c18687ab4df0fa900e Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 25 Aug 2023 09:53:06 +0100 Subject: [PATCH 056/112] :sparkles: Add _initialize_model_ioconfig --- tiatoolbox/models/__init__.py | 30 ++++++++++ tiatoolbox/models/architecture/__init__.py | 11 ++-- tiatoolbox/models/engine/engine_abc.py | 69 ++++++++++++++++------ 3 files changed, 85 insertions(+), 25 deletions(-) diff --git a/tiatoolbox/models/__init__.py b/tiatoolbox/models/__init__.py index ecd173ced..9e90832d4 100644 --- a/tiatoolbox/models/__init__.py +++ b/tiatoolbox/models/__init__.py @@ -1,4 +1,10 @@ """Models package for the models implemented in tiatoolbox.""" +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch + from . import architecture, dataset, engine, models_abc from .architecture.hovernet import HoVerNet from .architecture.hovernetplus import HoVerNetPlus @@ -19,6 +25,9 @@ from .engine.patch_predictor import PatchPredictor from .engine.semantic_segmentor import DeepFeatureExtractor, SemanticSegmentor +if TYPE_CHECKING: # pragma: no cover + from pathlib import Path + __all__ = [ "architecture", "dataset", @@ -43,4 +52,25 @@ "WSIStreamDataset", "WSIPatchDataset", "PatchDataset", + "load_torch_model", ] + + +def load_torch_model(model: torch.nn.Module, weights: str | Path) -> torch.nn.Module: + """Helper function to load a torch model. + + Args: + model (torch.nn.Module): + A torch model. + weights (str or Path): + Path to pretrained weights. + + Returns: + torch.nn.Module: + Torch model with pretrained weights loaded on CPU. + + """ + # ! assume to be saved in single GPU mode + # always load on to the CPU + saved_state_dict = torch.load(weights, map_location="cpu") + return model.load_state_dict(saved_state_dict, strict=True) diff --git a/tiatoolbox/models/architecture/__init__.py b/tiatoolbox/models/architecture/__init__.py index b60f22080..04a418fa8 100644 --- a/tiatoolbox/models/architecture/__init__.py +++ b/tiatoolbox/models/architecture/__init__.py @@ -1,19 +1,19 @@ """Define a set of models to be used within tiatoolbox.""" from __future__ import annotations -import os from pydoc import locate from typing import TYPE_CHECKING, Optional, Union -import torch - from tiatoolbox import rcParam +from tiatoolbox.models import load_torch_model from tiatoolbox.models.dataset.classification import predefined_preproc_func from tiatoolbox.utils import download_data if TYPE_CHECKING: # pragma: no cover from pathlib import Path + import torch + from tiatoolbox.models.models_abc import IOConfigABC @@ -143,10 +143,7 @@ def get_pretrained_model( overwrite=overwrite, ) - # ! assume to be saved in single GPU mode - # always load on to the CPU - saved_state_dict = torch.load(weights, map_location="cpu") - model.load_state_dict(saved_state_dict, strict=True) + model = load_torch_model(model=model, weights=weights) # ! io_info = info["ioconfig"] diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 478fc3468..9661e709e 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -8,6 +8,7 @@ from torch import nn from tiatoolbox import logger +from tiatoolbox.models import load_torch_model from tiatoolbox.models.architecture import get_pretrained_model if TYPE_CHECKING: @@ -32,7 +33,7 @@ class EngineABC(ABC): be downloaded. However, you can override with your own set of weights via the `pretrained_weights` argument. Argument is case-insensitive. - weights (str): + weights (str or Path): Path to the weight of the corresponding `pretrained_model`. >>> engine = EngineABC( @@ -138,7 +139,7 @@ def __init__( batch_size: int = 8, num_loader_workers: int = 0, num_post_proc_workers: int = 0, - weights: str | None = None, + weights: str | Path | None = None, *, verbose: bool = False, ) -> None: @@ -148,22 +149,6 @@ def __init__( self.images = None self.mode = None - ioconfig = None # retrieve ioconfig from provided model. - - if not isinstance(model, (str, nn.Module)): - msg = "Input model must be a string or 'torch.nn.Module'." - raise TypeError(msg) - - if isinstance(model, str): - self.model, ioconfig = get_pretrained_model(model, weights) - - if isinstance(model, nn.Module): - self.model = ( - model # for runtime, such as after wrapping with nn.DataParallel - ) - - self.ioconfig = ioconfig # for storing original - self._ioconfig = self.ioconfig # runtime ioconfig self.batch_size = batch_size self.num_loader_workers = num_loader_workers self.num_post_proc_workers = num_post_proc_workers @@ -177,6 +162,54 @@ def __init__( self.stride_shape = None self.labels = None + # Initialize model with specified weights and ioconfig. + self._initialize_model_ioconfig(model=model, weights=weights) + + def _initialize_model_ioconfig( + self: EngineABC, + model: str | nn.Module, + weights: str | Path | None, + ) -> NoReturn: + """Helper function to initialize model and ioconfig attributes. + + If a pretrained model provided by the TIAToolbox is requested. The model + can be specified as a string otherwise torch.nn.Module is required. + This function also loads the :class:`ModelIOConfigABC` using the information + from the pretrained models in TIAToolbox. If ioconfig is not available then it + should be provided in the :func:`run` function. + + Args: + model (str | nn.Module): + A torch model which should be run by the engine. + + weights (str | Path | None): + Path to pretrained weights. If no pretrained weights are provided + and the model is provided by TIAToolbox, then pretrained weights will + be automatically loaded from the TIA servers. + + """ + if not isinstance(model, (str, nn.Module)): + msg = "Input model must be a string or 'torch.nn.Module'." + raise TypeError(msg) + + if isinstance(model, nn.Module): + self.model = ( + model # for runtime, such as after wrapping with nn.DataParallel + ) + + if weights is not None: + self.model = load_torch_model(model=self.model, weights=weights) + + ioconfig = None # requires ioconfig to be provided in EngineABC.run(). + + if isinstance(model, str): + # ioconfig is retrieved from the pretrained model in the toolbox. + # no need to provide ioconfig in EngineABC.run() this case. + self.model, ioconfig = get_pretrained_model(model, weights) + + self.ioconfig = ioconfig # for storing original + self._ioconfig = self.ioconfig # runtime ioconfig + @abstractmethod def pre_process_patch(self: EngineABC) -> NoReturn: """Pre-process an image patch.""" From 105f32e3d70ad0c1381a2d819ef907211822e621 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 25 Aug 2023 12:05:48 +0100 Subject: [PATCH 057/112] :recycle: Refactor `load_torch_model` and `model_to` to models_abc.py --- tests/models/test_abc.py | 22 +++++++- tests/models/test_arch_vanilla.py | 2 +- tests/test_utils.py | 18 ------- tiatoolbox/models/__init__.py | 28 ---------- tiatoolbox/models/architecture/__init__.py | 15 +++--- tiatoolbox/models/engine/engine_abc.py | 10 ++-- tiatoolbox/models/engine/patch_predictor.py | 5 +- .../models/engine/semantic_segmentor.py | 8 ++- tiatoolbox/models/models_abc.py | 53 +++++++++++++++++-- tiatoolbox/utils/misc.py | 20 ------- 10 files changed, 94 insertions(+), 87 deletions(-) diff --git a/tests/models/test_abc.py b/tests/models/test_abc.py index c097499f0..52fa4a6e8 100644 --- a/tests/models/test_abc.py +++ b/tests/models/test_abc.py @@ -1,8 +1,10 @@ """Unit test package for ABC and __init__ .""" +from __future__ import annotations import pytest -from tiatoolbox import rcParam +import tiatoolbox.models +from tiatoolbox import rcParam, utils from tiatoolbox.models.architecture import get_pretrained_model from tiatoolbox.models.models_abc import ModelABC from tiatoolbox.utils import env_detection as toolbox_env @@ -105,3 +107,21 @@ def infer_batch() -> None: # coverage setter check model.postproc_func = None # skipcq: PYL-W0201 assert model.postproc_func(2) == 0 + + +def test_model_to() -> None: + """Test for placing model on device.""" + import torchvision.models as torch_models + from torch import nn + + # Test on GPU + # no GPU on Travis so this will crash + if not utils.env_detection.has_gpu(): + model = torch_models.resnet18() + with pytest.raises((AssertionError, RuntimeError)): + _ = tiatoolbox.models.models_abc.model_to(on_gpu=True, model=model) + + # Test on CPU + model = torch_models.resnet18() + model = tiatoolbox.models.models_abc.model_to(on_gpu=False, model=model) + assert isinstance(model, nn.Module) diff --git a/tests/models/test_arch_vanilla.py b/tests/models/test_arch_vanilla.py index 26020aa07..a2b1ac5c9 100644 --- a/tests/models/test_arch_vanilla.py +++ b/tests/models/test_arch_vanilla.py @@ -5,7 +5,7 @@ import torch from tiatoolbox.models.architecture.vanilla import CNNModel -from tiatoolbox.utils.misc import model_to +from tiatoolbox.models.models_abc import model_to ON_GPU = False RNG = np.random.default_rng() # Numpy Random Generator diff --git a/tests/test_utils.py b/tests/test_utils.py index 4eb2d42cd..ef6afa734 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1312,24 +1312,6 @@ def test_select_device() -> None: assert device == "cpu" -def test_model_to() -> None: - """Test for placing model on device.""" - import torchvision.models as torch_models - from torch import nn - - # Test on GPU - # no GPU on Travis so this will crash - if not utils.env_detection.has_gpu(): - model = torch_models.resnet18() - with pytest.raises((AssertionError, RuntimeError)): - _ = misc.model_to(on_gpu=True, model=model) - - # Test on CPU - model = torch_models.resnet18() - model = misc.model_to(on_gpu=False, model=model) - assert isinstance(model, nn.Module) - - def test_save_as_json(tmp_path: Path) -> None: """Test save data to json.""" # This should be broken up into separate tests! diff --git a/tiatoolbox/models/__init__.py b/tiatoolbox/models/__init__.py index 9e90832d4..e91a3b68c 100644 --- a/tiatoolbox/models/__init__.py +++ b/tiatoolbox/models/__init__.py @@ -1,10 +1,6 @@ """Models package for the models implemented in tiatoolbox.""" from __future__ import annotations -from typing import TYPE_CHECKING - -import torch - from . import architecture, dataset, engine, models_abc from .architecture.hovernet import HoVerNet from .architecture.hovernetplus import HoVerNetPlus @@ -25,9 +21,6 @@ from .engine.patch_predictor import PatchPredictor from .engine.semantic_segmentor import DeepFeatureExtractor, SemanticSegmentor -if TYPE_CHECKING: # pragma: no cover - from pathlib import Path - __all__ = [ "architecture", "dataset", @@ -52,25 +45,4 @@ "WSIStreamDataset", "WSIPatchDataset", "PatchDataset", - "load_torch_model", ] - - -def load_torch_model(model: torch.nn.Module, weights: str | Path) -> torch.nn.Module: - """Helper function to load a torch model. - - Args: - model (torch.nn.Module): - A torch model. - weights (str or Path): - Path to pretrained weights. - - Returns: - torch.nn.Module: - Torch model with pretrained weights loaded on CPU. - - """ - # ! assume to be saved in single GPU mode - # always load on to the CPU - saved_state_dict = torch.load(weights, map_location="cpu") - return model.load_state_dict(saved_state_dict, strict=True) diff --git a/tiatoolbox/models/architecture/__init__.py b/tiatoolbox/models/architecture/__init__.py index 04a418fa8..852c3e2b5 100644 --- a/tiatoolbox/models/architecture/__init__.py +++ b/tiatoolbox/models/architecture/__init__.py @@ -2,11 +2,11 @@ from __future__ import annotations from pydoc import locate -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING from tiatoolbox import rcParam -from tiatoolbox.models import load_torch_model from tiatoolbox.models.dataset.classification import predefined_preproc_func +from tiatoolbox.models.models_abc import load_torch_model from tiatoolbox.utils import download_data if TYPE_CHECKING: # pragma: no cover @@ -14,8 +14,7 @@ import torch - from tiatoolbox.models.models_abc import IOConfigABC - + from tiatoolbox.models.engine.io_config import ModelIOConfigABC __all__ = ["get_pretrained_model", "fetch_pretrained_weights"] PRETRAINED_INFO = rcParam["pretrained_model_info"] @@ -63,7 +62,7 @@ def get_pretrained_model( weights: str | Path | None = None, *, overwrite: bool = False, -) -> tuple[torch.nn.Module, IOConfigABC]: +) -> tuple[torch.nn.Module, ModelIOConfigABC]: """Load a predefined PyTorch model with the appropriate pretrained weights. Args: @@ -107,14 +106,14 @@ def get_pretrained_model( Examples: >>> # get mobilenet pretrained on Kather100K dataset by the TIA team - >>> model = get_pretrained_model(model='mobilenet_v2-kather100k') + >>> model, ioconfig = get_pretrained_model(model='mobilenet_v2-kather100k') >>> # get mobilenet defined by TIA team, but loaded with user defined weights - >>> model = get_pretrained_model( + >>> model, ioconfig = get_pretrained_model( ... model='mobilenet_v2-kather100k', ... weights='/A/B/C/my_weights.tar', ... ) >>> # get resnet34 pretrained on PCam dataset by TIA team - >>> model = get_pretrained_model(model='resnet34-pcam') + >>> model, ioconfig = get_pretrained_model(model='resnet34-pcam') """ if not isinstance(model, str): diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 9661e709e..dbedd954d 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -8,14 +8,16 @@ from torch import nn from tiatoolbox import logger -from tiatoolbox.models import load_torch_model from tiatoolbox.models.architecture import get_pretrained_model +from tiatoolbox.models.models_abc import load_torch_model if TYPE_CHECKING: import os import numpy as np + from tiatoolbox.annotation import AnnotationStore + class EngineABC(ABC): """Abstract base class for engines used in tiatoolbox. @@ -292,10 +294,10 @@ def run( # on_gpu: bool = True, # noqa: ERA001 save_dir: os | Path | None = None, # None will not save output - # save_output can be np.ndarray, Annotation or Json str - # save_output: np.ndarray | Annotation | str | None = True, # noqa: ERA001 + # output_type can be np.ndarray, Annotation or Json str + # output_type: np.ndarray | Annotation | str = Annotation, # noqa: ERA001 **kwargs: dict, - ) -> np.ndarray | dict: + ) -> AnnotationStore | np.ndarray | dict | str: """Run the engine on input images. Args: diff --git a/tiatoolbox/models/engine/patch_predictor.py b/tiatoolbox/models/engine/patch_predictor.py index d111dc274..58c18a63f 100644 --- a/tiatoolbox/models/engine/patch_predictor.py +++ b/tiatoolbox/models/engine/patch_predictor.py @@ -10,9 +10,10 @@ import torch import tqdm +import tiatoolbox.models.models_abc from tiatoolbox import logger from tiatoolbox.models.dataset.dataset_abc import PatchDataset, WSIPatchDataset -from tiatoolbox.utils import misc, save_as_json +from tiatoolbox.utils import save_as_json from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIReader if TYPE_CHECKING: # pragma: no cover @@ -425,7 +426,7 @@ def _predict_engine( ) # use external for testing - model = misc.model_to(model=self.model, on_gpu=on_gpu) + model = tiatoolbox.models.models_abc.model_to(model=self.model, on_gpu=on_gpu) cum_output = { "probabilities": [], diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index 33acd4bd5..e1341c640 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -16,11 +16,12 @@ import torch.utils.data as torch_data import tqdm +import tiatoolbox.models.models_abc from tiatoolbox import logger from tiatoolbox.models.architecture import get_pretrained_model from tiatoolbox.models.dataset.dataset_abc import WSIStreamDataset from tiatoolbox.tools.patchextraction import PatchExtractor -from tiatoolbox.utils import imread, misc +from tiatoolbox.utils import imread from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIReader from .io_config import IOSegmentorConfig @@ -1049,7 +1050,10 @@ def predict( # noqa: PLR0913 # use external for testing self._on_gpu = on_gpu - self._model = misc.model_to(model=self.model, on_gpu=on_gpu) + self._model = tiatoolbox.models.models_abc.model_to( + model=self.model, + on_gpu=on_gpu, + ) # workers should be > 0 else Value Error will be thrown self._prepare_workers() diff --git a/tiatoolbox/models/models_abc.py b/tiatoolbox/models/models_abc.py index 4edc5defa..e8d2eba3d 100644 --- a/tiatoolbox/models/models_abc.py +++ b/tiatoolbox/models/models_abc.py @@ -1,9 +1,56 @@ """Define Abstract Base Class for Models defined in tiatoolbox.""" +from __future__ import annotations + from abc import ABC, abstractmethod +from typing import TYPE_CHECKING -import numpy as np +import torch from torch import nn +if TYPE_CHECKING: # pragma: no cover + from pathlib import Path + + import numpy as np + + +def load_torch_model(model: nn.Module, weights: str | Path) -> nn.Module: + """Helper function to load a torch model. + + Args: + model (torch.nn.Module): + A torch model. + weights (str or Path): + Path to pretrained weights. + + Returns: + torch.nn.Module: + Torch model with pretrained weights loaded on CPU. + + """ + # ! assume to be saved in single GPU mode + # always load on to the CPU + saved_state_dict = torch.load(weights, map_location="cpu") + return model.load_state_dict(saved_state_dict, strict=True) + + +def model_to(model: torch.nn.Module, *, on_gpu: bool) -> torch.nn.Module: + """Transfers model to cpu/gpu. + + Args: + model (torch.nn.Module): PyTorch defined model. + on_gpu (bool): Transfers model to gpu if True otherwise to cpu. + + Returns: + torch.nn.Module: + The model after being moved to cpu/gpu. + + """ + if on_gpu: # DataParallel work only for cuda + model = torch.nn.DataParallel(model) + return model.to("cuda") + + return model.to("cpu") + class ModelABC(ABC, nn.Module): """Abstract base class for models used in tiatoolbox.""" @@ -67,7 +114,7 @@ def preproc_func(self, func): >>> # `func` is a user defined function >>> model = ModelABC() >>> model.preproc_func = func - >>> transformed_img = model.preproc_func(img) + >>> transformed_img = model.preproc_func(image=np.ndarray) """ if func is not None and not callable(func): @@ -98,7 +145,7 @@ def postproc_func(self, func): >>> # `func` is a user defined function >>> model = ModelABC() >>> model.postproc_func = func - >>> transformed_img = model.postproc_func(img) + >>> transformed_img = model.postproc_func(image=np.ndarray) """ if func is not None and not callable(func): diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py index 89e60970e..98a6fe7ec 100644 --- a/tiatoolbox/utils/misc.py +++ b/tiatoolbox/utils/misc.py @@ -14,7 +14,6 @@ import numpy as np import pandas as pd import requests -import torch import yaml from filelock import FileLock from shapely.affinity import translate @@ -873,25 +872,6 @@ def select_device(*, on_gpu: bool) -> str: return "cpu" -def model_to(model: torch.nn.Module, *, on_gpu: bool) -> torch.nn.Module: - """Transfers model to cpu/gpu. - - Args: - model (torch.nn.Module): PyTorch defined model. - on_gpu (bool): Transfers model to gpu if True otherwise to cpu. - - Returns: - torch.nn.Module: - The model after being moved to cpu/gpu. - - """ - if on_gpu: # DataParallel work only for cuda - model = torch.nn.DataParallel(model) - return model.to("cuda") - - return model.to("cpu") - - def get_bounding_box(img: np.ndarray) -> np.ndarray: """Get bounding box coordinate information. From 79999441d12f88b339029a5deaa2153fe07e4316 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 25 Aug 2023 12:17:33 +0100 Subject: [PATCH 058/112] :bug: Fix deepsource errors --- tests/engines/test_engine_abc.py | 12 ++++++------ tests/test_annotation_tilerendering.py | 2 +- tests/test_utils.py | 5 ++++- tiatoolbox/annotation/storage.py | 15 +++++++++++++++ 4 files changed, 26 insertions(+), 8 deletions(-) diff --git a/tests/engines/test_engine_abc.py b/tests/engines/test_engine_abc.py index 7950308a6..f89e27331 100644 --- a/tests/engines/test_engine_abc.py +++ b/tests/engines/test_engine_abc.py @@ -20,27 +20,27 @@ def __init__(self: TestEngineABC, model: str | torch.nn.Module) -> NoReturn: def infer_patch(self: EngineABC) -> NoReturn: """Test infer_patch.""" - ... + ... # dummy function for tests. def infer_wsi(self: EngineABC) -> NoReturn: """Test infer_wsi.""" - ... + ... # dummy function for tests. def post_process_patch(self: EngineABC) -> NoReturn: """Test post_process_patch.""" - ... + ... # dummy function for tests. def post_process_wsi(self: EngineABC) -> NoReturn: """Test post_process_wsi.""" - ... + ... # dummy function for tests. def pre_process_patch(self: EngineABC) -> NoReturn: """Test pre_process_patch.""" - ... + ... # dummy function for tests. def pre_process_wsi(self: EngineABC) -> NoReturn: """Test pre_process_wsi.""" - ... + ... # dummy function for tests. def test_engine_abc(): diff --git a/tests/test_annotation_tilerendering.py b/tests/test_annotation_tilerendering.py index 0950698bc..e53fd0373 100644 --- a/tests/test_annotation_tilerendering.py +++ b/tests/test_annotation_tilerendering.py @@ -448,7 +448,7 @@ def test_function_mapper(fill_store, tmp_path: Path) -> None: _, store = fill_store(SQLiteStore, tmp_path / "test.db") def color_fn(props): - # simple test function that returns red for cells, otherwise green. + """Tests Red for cells, otherwise green.""" if props["type"] == "cell": return 1, 0, 0 return 0, 1, 0 diff --git a/tests/test_utils.py b/tests/test_utils.py index ef6afa734..c95f3acb0 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -561,7 +561,8 @@ def test_sub_pixel_read_bad_read_func() -> None: bounds = (0, 0, 8, 8) def bad_read_func(img, bounds, *kwargs): # noqa: ARG001 - return None + """Dummy read function for tests.""" + return with pytest.raises(ValueError, match="None"): utils.image.sub_pixel_read( @@ -719,6 +720,7 @@ def test_sub_pixel_read_incorrect_read_func_return() -> None: image = np.ones((10, 10)) def read_func(*args, **kwargs): # noqa: ARG001 + """Dummy read function for tests.""" return np.ones((5, 5)) with pytest.raises(ValueError, match="incorrect size"): @@ -737,6 +739,7 @@ def test_sub_pixel_read_empty_read_func_return() -> None: image = np.ones((10, 10)) def read_func(*args, **kwargs): # noqa: ARG001 + """Dummy read function for tests.""" return np.ones((0, 0)) with pytest.raises(ValueError, match="is empty"): diff --git a/tiatoolbox/annotation/storage.py b/tiatoolbox/annotation/storage.py index 7a3fa18af..f55c6e19f 100644 --- a/tiatoolbox/annotation/storage.py +++ b/tiatoolbox/annotation/storage.py @@ -2296,6 +2296,21 @@ def _unpack_wkb( cx: float, cy: float, ) -> bytes: + """Return the geometry as bytes using WKB. + + Args: + data (bytes or str): + The WKB/WKT data to be unpacked. + cx (int): + The X coordinate of the centroid/representative point. + cy (float): + The Y coordinate of the centroid/representative point. + + Returns: + bytes: + The geometry as bytes. + + """ return ( self._decompress_data(data) if data From cb5c9862d3d406f79b7883ddfa10467d68b0a70a Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 25 Aug 2023 12:39:18 +0100 Subject: [PATCH 059/112] :sparkles: Add `_load_ioconfig` --- tiatoolbox/models/engine/engine_abc.py | 64 ++++++++++++++++++++------ 1 file changed, 50 insertions(+), 14 deletions(-) diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index dbedd954d..9afda979b 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -18,7 +18,10 @@ from tiatoolbox.annotation import AnnotationStore + from .io_config import ModelIOConfigABC + +# noinspection PyUnreachableCode class EngineABC(ABC): """Abstract base class for engines used in tiatoolbox. @@ -107,31 +110,26 @@ class EngineABC(ABC): Whether to output logging information. Examples: - >>> # list of 2 image patches as input - >>> data = ["path/to/image1.svs", "path/to/image2.svs"] - >>> engine = EngineABC(pretrained_model="resnet18-kather100k") - >>> output = engine.run(data, mode='patch') - >>> # array of list of 2 image patches as input >>> import numpy as np - >>> data = np.array([img1, img2]) + >>> data = np.array([np.ndarray, np.ndarray]) >>> engine = EngineABC(pretrained_model="resnet18-kather100k") - >>> output = engine.run(data, mode='patch') + >>> output = engine.run(data, patch_mode=True) >>> # list of 2 image patch files as input >>> data = ['path/img.png', 'path/img.png'] >>> engine = EngineABC(pretrained_model="resnet18-kather100k") - >>> output = engine.run(data, mode='patch') + >>> output = engine.run(data, patch_mode=False) >>> # list of 2 image tile files as input >>> tile_file = ['path/tile1.png', 'path/tile2.png'] >>> engine = EngineABC(pretraind_model="resnet18-kather100k") - >>> output = engine.run(tile_file, mode='tile') + >>> output = engine.run(tile_file, patch_mode=False) >>> # list of 2 wsi files as input >>> wsi_file = ['path/wsi1.svs', 'path/wsi2.svs'] >>> engine = EngineABC(pretraind_model="resnet18-kather100k") - >>> output = engine.run(wsi_file, mode='wsi') + >>> output = engine.run(wsi_file, patch_mode=True) """ @@ -148,9 +146,11 @@ def __init__( """Initialize Engine.""" super().__init__() + self.masks = None self.images = None self.mode = None - + self.ioconfig = None + self._ioconfig = None # runtime ioconfig self.batch_size = batch_size self.num_loader_workers = num_loader_workers self.num_post_proc_workers = num_post_proc_workers @@ -284,11 +284,43 @@ def _prepare_save_dir( return Path.cwd() / "output" + def _load_ioconfig(self: EngineABC, ioconfig: ModelIOConfigABC) -> ModelIOConfigABC: + """Helper function to load ioconfig. + + If the model is provided by TIAToolbox it will load the default ioconfig. + Otherwise, ioconfig must be specified. + + Args: + ioconfig (ModelIOConfigABC): + IO configuration to run the engines. + + Raises: + ValueError: + If no io configuration is provided or found in the pretrained TIAToolbox + models. + + Returns: + ModelIOConfigABC: + The ioconfig used for the run. + + """ + if self.ioconfig is None and ioconfig is None: + msg = ( + "Please provide a valid ModelIOConfigABC. " + "No default ModelIOConfigABC found." + ) + raise ValueError(msg) + + if ioconfig is not None: + self.ioconfig = ioconfig + + return self.ioconfig + def run( self: EngineABC, images: list[os | Path] | np.ndarray, - # masks: list[os | Path] | np.ndarray | None = None, # noqa: ERA001 - # ioconfig: ModelIOConfigABC | None = None, # noqa: ERA001 + masks: list[os | Path] | np.ndarray | None = None, + ioconfig: ModelIOConfigABC | None = None, *, # patch_mode: bool = False, # noqa: ERA001 # on_gpu: bool = True, # noqa: ERA001 @@ -350,7 +382,7 @@ def run( >>> wsis = ['wsi1.svs', 'wsi2.svs'] >>> predictor = EngineABC( ... pretrained_model="resnet18-kather100k") - >>> output = predictor.run(wsis, mode="wsi") + >>> output = predictor.run(wsis, patch_mode=False) >>> output.keys() ... ['wsi1.svs', 'wsi2.svs'] >>> output['wsi1.svs'] @@ -362,6 +394,10 @@ def run( for key in kwargs: setattr(self, key, kwargs[key]) + self.images = images + self.masks = masks + self._ioconfig = self._load_ioconfig(ioconfig=ioconfig) + save_dir = self._prepare_save_dir(save_dir, images) return {"save_dir": save_dir} From 43b3239c68343b1ad9b3c81c8b443349d6f52f93 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 25 Aug 2023 12:48:52 +0100 Subject: [PATCH 060/112] :white_check_mark: Add test for ioconfig load error. --- tests/engines/test_engine_abc.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/tests/engines/test_engine_abc.py b/tests/engines/test_engine_abc.py index f89e27331..904520e40 100644 --- a/tests/engines/test_engine_abc.py +++ b/tests/engines/test_engine_abc.py @@ -66,4 +66,17 @@ def test_engine_abc_incorrect_model_type(): match="Input model must be a string or 'torch.nn.Module'.", ): # Can't instantiate abstract class with abstract methods - TestEngineABC(model=1) # skipcq + TestEngineABC(model=1) + + +def test_incorrect_ioconfig(): + """Test EngineABC initialization with incorrect ioconfig.""" + import torchvision.models as torch_models + + model = torch_models.resnet18() + engine = TestEngineABC(model=model) + with pytest.raises( + ValueError, + match=r".*provide a valid ModelIOConfigABC.*", + ): + engine.run(images=[], masks=[], ioconfig=None) From 7a70ad8184a36925f0488af1c8d182bb42dca635 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 25 Aug 2023 14:39:04 +0100 Subject: [PATCH 061/112] :sparkles: Add `model_to` in EngineABC and initialize variables in __init__ --- tests/test_wsimeta.py | 1 - tiatoolbox/models/engine/engine_abc.py | 135 +++++++++++++------------ 2 files changed, 69 insertions(+), 67 deletions(-) diff --git a/tests/test_wsimeta.py b/tests/test_wsimeta.py index bc3555e36..01b1cac8b 100644 --- a/tests/test_wsimeta.py +++ b/tests/test_wsimeta.py @@ -8,7 +8,6 @@ from tiatoolbox.wsicore import WSIMeta, wsimeta, wsireader -# noinspection PyTypeChecker def test_wsimeta_init_fail() -> None: """Test incorrect init for WSIMeta raises TypeError.""" with pytest.raises(TypeError): diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 9afda979b..daab37a07 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -9,7 +9,7 @@ from tiatoolbox import logger from tiatoolbox.models.architecture import get_pretrained_model -from tiatoolbox.models.models_abc import load_torch_model +from tiatoolbox.models.models_abc import load_torch_model, model_to if TYPE_CHECKING: import os @@ -21,7 +21,48 @@ from .io_config import ModelIOConfigABC -# noinspection PyUnreachableCode +def _prepare_save_dir( + save_dir: os | Path | None, + images: list | np.ndarray, +) -> Path: + """Create directory if not defined and number of images is more than 1. + + Args: + save_dir (str or Path): + Path to output directory. + images (list, ndarray): + List of inputs to process. + + Returns: + :class:`Path`: + Path to output directory. + + """ + if save_dir is None and len(images) > 1: + logger.warning( + "More than 1 WSIs detected but there is no save directory provided." + "All subsequent output will be saved to current runtime" + "location under folder 'Path.cwd() / output'. Overwrite may happen!", + stacklevel=2, + ) + save_dir = Path.cwd() / "output" + elif save_dir is not None and len(images) > 1: + logger.warning( + "When providing multiple whole-slide images / tiles, " + "the outputs will be saved and the locations of outputs" + "will be returned" + "to the calling function.", + stacklevel=2, + ) + + if save_dir is not None: + save_dir = Path(save_dir) + save_dir.mkdir(parents=True, exist_ok=False) + return save_dir + + return Path.cwd() / "output" + + class EngineABC(ABC): """Abstract base class for engines used in tiatoolbox. @@ -149,8 +190,14 @@ def __init__( self.masks = None self.images = None self.mode = None - self.ioconfig = None - self._ioconfig = None # runtime ioconfig + + # Initialize model with specified weights and ioconfig. + self.model, self.ioconfig = self._initialize_model_ioconfig( + model=model, + weights=weights, + ) + self._ioconfig = self.ioconfig # runtime ioconfig + self.batch_size = batch_size self.num_loader_workers = num_loader_workers self.num_post_proc_workers = num_post_proc_workers @@ -164,14 +211,11 @@ def __init__( self.stride_shape = None self.labels = None - # Initialize model with specified weights and ioconfig. - self._initialize_model_ioconfig(model=model, weights=weights) - + @staticmethod def _initialize_model_ioconfig( - self: EngineABC, model: str | nn.Module, weights: str | Path | None, - ) -> NoReturn: + ) -> tuple[nn.Module, ModelIOConfigABC | None]: """Helper function to initialize model and ioconfig attributes. If a pretrained model provided by the TIAToolbox is requested. The model @@ -189,28 +233,28 @@ def _initialize_model_ioconfig( and the model is provided by TIAToolbox, then pretrained weights will be automatically loaded from the TIA servers. + Returns: + nn.Module: + The requested PyTorch model. + + ModelIOConfigABC | None: + The model io configuration for TIAToolbox pretrained models. + Otherwise, None. + """ if not isinstance(model, (str, nn.Module)): msg = "Input model must be a string or 'torch.nn.Module'." raise TypeError(msg) - if isinstance(model, nn.Module): - self.model = ( - model # for runtime, such as after wrapping with nn.DataParallel - ) - - if weights is not None: - self.model = load_torch_model(model=self.model, weights=weights) - - ioconfig = None # requires ioconfig to be provided in EngineABC.run(). - if isinstance(model, str): # ioconfig is retrieved from the pretrained model in the toolbox. # no need to provide ioconfig in EngineABC.run() this case. - self.model, ioconfig = get_pretrained_model(model, weights) + return get_pretrained_model(model, weights) - self.ioconfig = ioconfig # for storing original - self._ioconfig = self.ioconfig # runtime ioconfig + if weights is not None: + model = load_torch_model(model=model, weights=weights) + + return model, None @abstractmethod def pre_process_patch(self: EngineABC) -> NoReturn: @@ -242,48 +286,6 @@ def post_process_wsi(self: EngineABC) -> NoReturn: """Post-process a WSI.""" raise NotImplementedError - @staticmethod - def _prepare_save_dir( - save_dir: os | Path | None, - images: list | np.ndarray, - ) -> Path: - """Create directory if not defined and number of images is more than 1. - - Args: - save_dir (str or Path): - Path to output directory. - images (list, ndarray): - List of inputs to process. - - Returns: - :class:`Path`: - Path to output directory. - - """ - if save_dir is None and len(images) > 1: - logger.warning( - "More than 1 WSIs detected but there is no save directory set." - "All subsequent output will be saved to current runtime" - "location under folder 'output'. Overwriting may happen!", - stacklevel=2, - ) - save_dir = Path.cwd() / "output" - elif save_dir is not None and len(images) > 1: - logger.warning( - "When providing multiple whole-slide images / tiles, " - "the outputs will be saved and the locations of outputs" - "will be returned" - "to the calling function.", - stacklevel=2, - ) - - if save_dir is not None: - save_dir = Path(save_dir) - save_dir.mkdir(parents=True, exist_ok=False) - return save_dir - - return Path.cwd() / "output" - def _load_ioconfig(self: EngineABC, ioconfig: ModelIOConfigABC) -> ModelIOConfigABC: """Helper function to load ioconfig. @@ -323,7 +325,7 @@ def run( ioconfig: ModelIOConfigABC | None = None, *, # patch_mode: bool = False, # noqa: ERA001 - # on_gpu: bool = True, # noqa: ERA001 + on_gpu: bool = False, # model runs on CPU by default. save_dir: os | Path | None = None, # None will not save output # output_type can be np.ndarray, Annotation or Json str @@ -397,7 +399,8 @@ def run( self.images = images self.masks = masks self._ioconfig = self._load_ioconfig(ioconfig=ioconfig) + self.model = model_to(model=self.model, on_gpu=on_gpu) - save_dir = self._prepare_save_dir(save_dir, images) + save_dir = _prepare_save_dir(save_dir, images) return {"save_dir": save_dir} From 748e0eb559736be5ee4c0c57d6e8ff912221863d Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 25 Aug 2023 15:45:36 +0100 Subject: [PATCH 062/112] :sparkles: Add preprocessing logic to engine_abc.py --- tiatoolbox/models/dataset/dataset_abc.py | 8 ++- tiatoolbox/models/engine/engine_abc.py | 87 ++++++++++++++++++------ 2 files changed, 73 insertions(+), 22 deletions(-) diff --git a/tiatoolbox/models/dataset/dataset_abc.py b/tiatoolbox/models/dataset/dataset_abc.py index 70f203c2d..9cc8bb96c 100644 --- a/tiatoolbox/models/dataset/dataset_abc.py +++ b/tiatoolbox/models/dataset/dataset_abc.py @@ -566,7 +566,11 @@ class PatchDataset(PatchDatasetABC): """ - def __init__(self, inputs, labels=None) -> None: + def __init__( + self: PatchDataset, + inputs: np.ndarray | list, + labels: list | None = None, + ) -> None: """Initialize :class:`PatchDataset`.""" super().__init__() @@ -578,7 +582,7 @@ def __init__(self, inputs, labels=None) -> None: # perform check on the input self._check_input_integrity(mode="patch") - def __getitem__(self, idx): + def __getitem__(self: PatchDataset, idx: int) -> dict: """Get an item from the dataset.""" patch = self.inputs[idx] diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index daab37a07..098cb1950 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -5,10 +5,12 @@ from pathlib import Path from typing import TYPE_CHECKING, NoReturn +import torch from torch import nn from tiatoolbox import logger from tiatoolbox.models.architecture import get_pretrained_model +from tiatoolbox.models.dataset.dataset_abc import PatchDataset from tiatoolbox.models.models_abc import load_torch_model, model_to if TYPE_CHECKING: @@ -23,35 +25,39 @@ def _prepare_save_dir( save_dir: os | Path | None, - images: list | np.ndarray, + len_images: int, + *, + patch_mode: bool, ) -> Path: """Create directory if not defined and number of images is more than 1. Args: save_dir (str or Path): Path to output directory. - images (list, ndarray): + len_images (int): List of inputs to process. + patch_mode(bool): + Whether to treat input image as a patch or WSI. Returns: :class:`Path`: Path to output directory. """ - if save_dir is None and len(images) > 1: + if patch_mode is False and save_dir is None and len_images > 1: logger.warning( "More than 1 WSIs detected but there is no save directory provided." "All subsequent output will be saved to current runtime" - "location under folder 'Path.cwd() / output'. Overwrite may happen!", + "location under folder 'Path.cwd() / output'. " + "The output might be overwritten!", stacklevel=2, ) save_dir = Path.cwd() / "output" - elif save_dir is not None and len(images) > 1: + elif save_dir is not None and len_images > 1: logger.warning( "When providing multiple whole-slide images / tiles, " "the outputs will be saved and the locations of outputs" - "will be returned" - "to the calling function.", + "will be returned to the calling function.", stacklevel=2, ) @@ -93,6 +99,8 @@ class EngineABC(ABC): Please note that they will also perform preprocessing. default = 0 num_post_proc_workers (int): Number of workers to postprocess the results of the model. default = 0 + on_gpu (bool): + verbose (bool): Whether to output logging information. @@ -145,6 +153,8 @@ class EngineABC(ABC): List of labels. If using `tile` or `wsi` mode, then only a single label per image tile or whole-slide image is supported. + on_gpu (bool): + Whether to run model on the GPU. Default is False. num_loader_workers (int): Number of workers used in torch.utils.data.DataLoader. verbose (bool): @@ -182,6 +192,7 @@ def __init__( num_post_proc_workers: int = 0, weights: str | Path | None = None, *, + on_gpu: bool = False, verbose: bool = False, ) -> None: """Initialize Engine.""" @@ -190,12 +201,14 @@ def __init__( self.masks = None self.images = None self.mode = None + self.on_gpu = on_gpu # Initialize model with specified weights and ioconfig. self.model, self.ioconfig = self._initialize_model_ioconfig( model=model, weights=weights, ) + self.model = model_to(model=self.model, on_gpu=self.on_gpu) self._ioconfig = self.ioconfig # runtime ioconfig self.batch_size = batch_size @@ -257,9 +270,33 @@ def _initialize_model_ioconfig( return model, None @abstractmethod - def pre_process_patch(self: EngineABC) -> NoReturn: + def pre_process_patch( + self: EngineABC, + images: np.ndarray | list, + labels: list, + ) -> torch.utils.data.DataLoader: """Pre-process an image patch.""" - raise NotImplementedError + if labels: + # if a labels is provided, then return with the prediction + self.return_labels = bool(labels) + + if labels and len(labels) != len(images): + msg = f"len(labels) != len(imgs) : {len(labels)} != {len(images)}" + raise ValueError( + msg, + ) + + dataset = PatchDataset(inputs=images, labels=labels) + dataset.preproc_func = self.model.preproc_func + + # preprocessing must be defined with the dataset + return torch.utils.data.DataLoader( + dataset, + num_workers=self.num_loader_workers, + batch_size=self.batch_size, + drop_last=False, + shuffle=False, + ) @abstractmethod def pre_process_wsi(self: EngineABC) -> NoReturn: @@ -322,14 +359,14 @@ def run( self: EngineABC, images: list[os | Path] | np.ndarray, masks: list[os | Path] | np.ndarray | None = None, + labels: list | None = None, ioconfig: ModelIOConfigABC | None = None, *, - # patch_mode: bool = False, # noqa: ERA001 - on_gpu: bool = False, # model runs on CPU by default. + patch_mode: bool = True, save_dir: os | Path | None = None, # None will not save output # output_type can be np.ndarray, Annotation or Json str - # output_type: np.ndarray | Annotation | str = Annotation, # noqa: ERA001 + # output_type: str = "Annotation", # noqa: ERA001 **kwargs: dict, ) -> AnnotationStore | np.ndarray | dict | str: """Run the engine on input images. @@ -341,28 +378,30 @@ def run( file paths or a numpy array of an image list. When using `tile` or `wsi` mode, the input must be a list of file paths. - masks (list): + masks (list | None): List of masks. Only utilised when processing image tiles and whole-slide images. Patches are only processed if they are within a masked area. If not provided, then a tissue mask will be automatically generated for whole-slide images or the entire image is processed for image tiles. + labels (list | None): + List of labels. If using `tile` or `wsi` mode, then only + a single label per image tile or whole-slide image is + supported. patch_mode (bool): Whether to treat input image as a patch or WSI. - default = False. - on_gpu (bool): - Whether to run model on the GPU. + default = True. ioconfig (IOPatchPredictorConfig): IO configuration. save_dir (str or pathlib.Path): Output directory when processing multiple tiles and whole-slide images. By default, it is folder `output` where the running script is invoked. - save_output (bool): + output_type (str): Whether to save output for a single file. default=False **kwargs (dict): - Keyword Args for ... + Keyword Args to update :class:`EngineABC` attributes. Returns: (:class:`numpy.ndarray`, dict): @@ -398,9 +437,17 @@ def run( self.images = images self.masks = masks + self.labels = labels + self._ioconfig = self._load_ioconfig(ioconfig=ioconfig) - self.model = model_to(model=self.model, on_gpu=on_gpu) + self.model = model_to(model=self.model, on_gpu=self.on_gpu) - save_dir = _prepare_save_dir(save_dir, images) + save_dir = _prepare_save_dir(save_dir, len(self.images), patch_mode=patch_mode) + + if patch_mode: + _ = self.pre_process_patch( + self.images, + self.labels, + ) return {"save_dir": save_dir} From 8619719ea5e04a7274cf14da0729bf2b194c1fc4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 29 Aug 2023 16:34:41 +0000 Subject: [PATCH 063/112] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tiatoolbox/models/engine/patch_predictor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tiatoolbox/models/engine/patch_predictor.py b/tiatoolbox/models/engine/patch_predictor.py index f401113f0..58c18a63f 100644 --- a/tiatoolbox/models/engine/patch_predictor.py +++ b/tiatoolbox/models/engine/patch_predictor.py @@ -13,7 +13,6 @@ import tiatoolbox.models.models_abc from tiatoolbox import logger from tiatoolbox.models.dataset.dataset_abc import PatchDataset, WSIPatchDataset - from tiatoolbox.utils import save_as_json from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIReader From e4e78c50cf30368d007e670e880bbe6f9be733d3 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Wed, 30 Aug 2023 10:50:58 +0100 Subject: [PATCH 064/112] :bug: F811 Redefinition of unused `model_to` --- tiatoolbox/models/models_abc.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/tiatoolbox/models/models_abc.py b/tiatoolbox/models/models_abc.py index b4772b287..e8d2eba3d 100644 --- a/tiatoolbox/models/models_abc.py +++ b/tiatoolbox/models/models_abc.py @@ -156,22 +156,3 @@ def postproc_func(self, func): self._postproc = self.postproc else: self._postproc = func - - -def model_to(model: torch.nn.Module, *, on_gpu: bool) -> torch.nn.Module: - """Transfers model to cpu/gpu. - - Args: - model (torch.nn.Module): PyTorch defined model. - on_gpu (bool): Transfers model to gpu if True otherwise to cpu. - - Returns: - torch.nn.Module: - The model after being moved to cpu/gpu. - - """ - if on_gpu: # DataParallel work only for cuda - model = torch.nn.DataParallel(model) - return model.to("cuda") - - return model.to("cpu") From b54f3a3cbea375449f38f91366886ced4c5b560b Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Thu, 31 Aug 2023 12:56:02 +0100 Subject: [PATCH 065/112] :recycle: Fix `_prepare_save_dir` logic --- tiatoolbox/models/engine/engine_abc.py | 53 ++++++++++++++------------ 1 file changed, 29 insertions(+), 24 deletions(-) diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 098cb1950..719eace12 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -28,7 +28,7 @@ def _prepare_save_dir( len_images: int, *, patch_mode: bool, -) -> Path: +) -> Path | None: """Create directory if not defined and number of images is more than 1. Args: @@ -44,29 +44,34 @@ def _prepare_save_dir( Path to output directory. """ - if patch_mode is False and save_dir is None and len_images > 1: - logger.warning( - "More than 1 WSIs detected but there is no save directory provided." - "All subsequent output will be saved to current runtime" - "location under folder 'Path.cwd() / output'. " - "The output might be overwritten!", - stacklevel=2, - ) - save_dir = Path.cwd() / "output" - elif save_dir is not None and len_images > 1: - logger.warning( + if patch_mode is True: + return save_dir + + if save_dir is None: + if len_images > 1: + msg = ( + "More than 1 WSIs detected but there is no save directory provided." + "Please provide a 'save_dir'." + "All subsequent output will be saved to current runtime" + "location under folder 'Path.cwd() / output'. " + "The output might be overwritten!", + ) + raise OSError(msg) + return ( + Path.cwd() + ) # save the output to current working directory and return save_dir + + if len_images > 1: + logger.info( "When providing multiple whole-slide images / tiles, " "the outputs will be saved and the locations of outputs" "will be returned to the calling function.", - stacklevel=2, ) - if save_dir is not None: - save_dir = Path(save_dir) - save_dir.mkdir(parents=True, exist_ok=False) - return save_dir + save_dir = Path(save_dir) + save_dir.mkdir(parents=True, exist_ok=False) - return Path.cwd() / "output" + return save_dir class EngineABC(ABC): @@ -276,16 +281,16 @@ def pre_process_patch( labels: list, ) -> torch.utils.data.DataLoader: """Pre-process an image patch.""" - if labels: - # if a labels is provided, then return with the prediction - self.return_labels = bool(labels) - if labels and len(labels) != len(images): msg = f"len(labels) != len(imgs) : {len(labels)} != {len(images)}" raise ValueError( msg, ) + if labels: + # if a labels is provided, then return with the prediction + self.return_labels = bool(labels) + dataset = PatchDataset(inputs=images, labels=labels) dataset.preproc_func = self.model.preproc_func @@ -363,8 +368,7 @@ def run( ioconfig: ModelIOConfigABC | None = None, *, patch_mode: bool = True, - save_dir: os | Path | None = None, - # None will not save output + save_dir: os | Path | None = None, # None will not save output # output_type can be np.ndarray, Annotation or Json str # output_type: str = "Annotation", # noqa: ERA001 **kwargs: dict, @@ -439,6 +443,7 @@ def run( self.masks = masks self.labels = labels + # if necessary Move model parameters to "cpu" or "gpu" and update ioconfig self._ioconfig = self._load_ioconfig(ioconfig=ioconfig) self.model = model_to(model=self.model, on_gpu=self.on_gpu) From 36ed0edfbea0113dd2cac30cee2c362d667e9f7c Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Thu, 31 Aug 2023 14:14:06 +0100 Subject: [PATCH 066/112] :white_check_mark: Add tests for `_prepare_save_dir` --- tests/engines/test_engine_abc.py | 63 ++++++++++++++++-- tiatoolbox/models/engine/engine_abc.py | 74 ++++++++++++++++++--- tiatoolbox/models/engine/patch_predictor.py | 4 +- 3 files changed, 125 insertions(+), 16 deletions(-) diff --git a/tests/engines/test_engine_abc.py b/tests/engines/test_engine_abc.py index 904520e40..a5ca87ca1 100644 --- a/tests/engines/test_engine_abc.py +++ b/tests/engines/test_engine_abc.py @@ -5,10 +5,12 @@ import pytest -from tiatoolbox.models.engine.engine_abc import EngineABC +from tiatoolbox.models.engine.engine_abc import EngineABC, prepare_engines_save_dir if TYPE_CHECKING: + import numpy as np import torch.nn + from torch.utils.data import DataLoader class TestEngineABC(EngineABC): @@ -18,7 +20,7 @@ def __init__(self: TestEngineABC, model: str | torch.nn.Module) -> NoReturn: """Test EngineABC init.""" super().__init__(model=model) - def infer_patch(self: EngineABC) -> NoReturn: + def infer_patches(self: EngineABC, data_loader: DataLoader) -> NoReturn: """Test infer_patch.""" ... # dummy function for tests. @@ -34,7 +36,11 @@ def post_process_wsi(self: EngineABC) -> NoReturn: """Test post_process_wsi.""" ... # dummy function for tests. - def pre_process_patch(self: EngineABC) -> NoReturn: + def pre_process_patches( + self: EngineABC, + images: np.ndarray, + labels: list, + ) -> NoReturn: """Test pre_process_patch.""" ... # dummy function for tests. @@ -43,7 +49,7 @@ def pre_process_wsi(self: EngineABC) -> NoReturn: ... # dummy function for tests. -def test_engine_abc(): +def test_engine_abc() -> NoReturn: """Test EngineABC initialization.""" with pytest.raises( TypeError, @@ -53,7 +59,7 @@ def test_engine_abc(): EngineABC() # skipcq -def test_engine_abc_incorrect_model_type(): +def test_engine_abc_incorrect_model_type() -> NoReturn: """Test EngineABC initialization with incorrect model type.""" with pytest.raises( TypeError, @@ -69,7 +75,7 @@ def test_engine_abc_incorrect_model_type(): TestEngineABC(model=1) -def test_incorrect_ioconfig(): +def test_incorrect_ioconfig() -> NoReturn: """Test EngineABC initialization with incorrect ioconfig.""" import torchvision.models as torch_models @@ -80,3 +86,48 @@ def test_incorrect_ioconfig(): match=r".*provide a valid ModelIOConfigABC.*", ): engine.run(images=[], masks=[], ioconfig=None) + + +def test_prepare_engines_save_dir( + tmp_path: pytest.TempPathFactory, + caplog: pytest.LogCaptureFixture, +) -> NoReturn: + """Test prepare save directory for engines.""" + out_dir = prepare_engines_save_dir( + save_dir=tmp_path / "patch_output", + patch_mode=True, + len_images=1, + ) + + assert out_dir == tmp_path / "patch_output" + assert out_dir.exists() + + with pytest.raises( + OSError, + match=r".*More than 1 WSIs detected but there is no save directory provided.*", + ): + _ = prepare_engines_save_dir( + save_dir=None, + patch_mode=False, + len_images=2, + ) + + out_dir = prepare_engines_save_dir( + save_dir=tmp_path / "wsi_single_output", + patch_mode=False, + len_images=1, + ) + + assert out_dir == tmp_path / "wsi_single_output" + assert out_dir.exists() + assert r"When providing multiple whole-slide images / tiles" not in caplog.text + + out_dir = prepare_engines_save_dir( + save_dir=tmp_path / "wsi_multiple_output", + patch_mode=False, + len_images=2, + ) + + assert out_dir == tmp_path / "wsi_multiple_output" + assert out_dir.exists() + assert r"When providing multiple whole-slide images / tiles" in caplog.text diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 719eace12..2bb269dcc 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, NoReturn import torch +import tqdm from torch import nn from tiatoolbox import logger @@ -13,17 +14,18 @@ from tiatoolbox.models.dataset.dataset_abc import PatchDataset from tiatoolbox.models.models_abc import load_torch_model, model_to -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: no cover import os import numpy as np + from torch.utils.data import DataLoader from tiatoolbox.annotation import AnnotationStore from .io_config import ModelIOConfigABC -def _prepare_save_dir( +def prepare_engines_save_dir( save_dir: os | Path | None, len_images: int, *, @@ -45,6 +47,7 @@ def _prepare_save_dir( """ if patch_mode is True: + save_dir.mkdir(parents=True, exist_ok=False) return save_dir if save_dir is None: @@ -64,7 +67,7 @@ def _prepare_save_dir( if len_images > 1: logger.info( "When providing multiple whole-slide images / tiles, " - "the outputs will be saved and the locations of outputs" + "the outputs will be saved and the locations of outputs " "will be returned to the calling function.", ) @@ -275,7 +278,7 @@ def _initialize_model_ioconfig( return model, None @abstractmethod - def pre_process_patch( + def pre_process_patches( self: EngineABC, images: np.ndarray | list, labels: list, @@ -309,9 +312,57 @@ def pre_process_wsi(self: EngineABC) -> NoReturn: raise NotImplementedError @abstractmethod - def infer_patch(self: EngineABC) -> NoReturn: + def infer_patches( + self: EngineABC, + data_loader: DataLoader, + ) -> AnnotationStore | np.ndarray | dict | str: """Model inference on an image patch.""" - raise NotImplementedError + progress_bar = None + + if self.verbose: + progress_bar = tqdm.tqdm( + total=int(len(data_loader)), + leave=True, + ncols=80, + ascii=True, + position=0, + ) + output = { + "predictions": [], + "labels": [], + } + if self.return_probabilities: + output["probabilities"] = [] + + for _, batch_data in enumerate(data_loader): + batch_output_probabilities = self.model.infer_batch( + self.model, + batch_data["image"], + on_gpu=self.on_gpu, + ) + # We get the index of the class with the maximum probability + batch_output_predictions = self.model.postproc_func( + batch_output_probabilities, + ) + + output["predictions"].extend(batch_output_predictions.tolist()) + + # tolist might be very expensive + if self.return_probabilities: + output["probabilities"].extend(batch_output_probabilities.tolist()) + + if self.return_labels: # be careful of `s` + # We do not use tolist here because label may be of mixed types + # and hence collated as list by torch + output["labels"].extend(list(batch_data["label"])) + + if progress_bar: + progress_bar.update() + + if progress_bar: + progress_bar.close() + + return output @abstractmethod def infer_wsi(self: EngineABC) -> NoReturn: @@ -447,12 +498,19 @@ def run( self._ioconfig = self._load_ioconfig(ioconfig=ioconfig) self.model = model_to(model=self.model, on_gpu=self.on_gpu) - save_dir = _prepare_save_dir(save_dir, len(self.images), patch_mode=patch_mode) + save_dir = prepare_engines_save_dir( + save_dir, + len(self.images), + patch_mode=patch_mode, + ) if patch_mode: - _ = self.pre_process_patch( + data_loader = self.pre_process_patches( self.images, self.labels, ) + return self.infer_patches( + data_loader=data_loader, + ) return {"save_dir": save_dir} diff --git a/tiatoolbox/models/engine/patch_predictor.py b/tiatoolbox/models/engine/patch_predictor.py index 58c18a63f..48c7ab60b 100644 --- a/tiatoolbox/models/engine/patch_predictor.py +++ b/tiatoolbox/models/engine/patch_predictor.py @@ -222,7 +222,7 @@ def __init__( verbose=verbose, ) - def pre_process_patch(self): + def pre_process_patches(self): """Pre-process an image patch.""" raise NotImplementedError @@ -234,7 +234,7 @@ def pre_process_wsi(self): """Pre-process a WSI.""" raise NotImplementedError - def infer_patch(self): + def infer_patches(self): """Model inference on an image patch.""" raise NotImplementedError From 87ba44c65c48f278a23788b1225ebe6f2770d8bd Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Thu, 31 Aug 2023 14:50:55 +0100 Subject: [PATCH 067/112] :white_check_mark: Add tests for EngineABC --- tests/engines/test_engine_abc.py | 15 +++++++++++++++ tests/models/{test_abc.py => test_models_abc.py} | 0 tiatoolbox/models/engine/engine_abc.py | 10 +++++----- 3 files changed, 20 insertions(+), 5 deletions(-) rename tests/models/{test_abc.py => test_models_abc.py} (100%) diff --git a/tests/engines/test_engine_abc.py b/tests/engines/test_engine_abc.py index a5ca87ca1..d6390c996 100644 --- a/tests/engines/test_engine_abc.py +++ b/tests/engines/test_engine_abc.py @@ -1,6 +1,7 @@ """Test tiatoolbox.models.engine.engine_abc.""" from __future__ import annotations +from pathlib import Path from typing import TYPE_CHECKING, NoReturn import pytest @@ -112,6 +113,14 @@ def test_prepare_engines_save_dir( len_images=2, ) + out_dir = prepare_engines_save_dir( + save_dir=None, + patch_mode=False, + len_images=1, + ) + + assert out_dir == Path.cwd() + out_dir = prepare_engines_save_dir( save_dir=tmp_path / "wsi_single_output", patch_mode=False, @@ -131,3 +140,9 @@ def test_prepare_engines_save_dir( assert out_dir == tmp_path / "wsi_multiple_output" assert out_dir.exists() assert r"When providing multiple whole-slide images / tiles" in caplog.text + + +def test_engine_initalization() -> NoReturn: + """Test engine initialization.""" + eng = TestEngineABC(model="alexnet-kather100k") + assert isinstance(eng, EngineABC) diff --git a/tests/models/test_abc.py b/tests/models/test_models_abc.py similarity index 100% rename from tests/models/test_abc.py rename to tests/models/test_models_abc.py diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 2bb269dcc..9cfadf431 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -306,11 +306,6 @@ def pre_process_patches( shuffle=False, ) - @abstractmethod - def pre_process_wsi(self: EngineABC) -> NoReturn: - """Pre-process a WSI.""" - raise NotImplementedError - @abstractmethod def infer_patches( self: EngineABC, @@ -364,6 +359,11 @@ def infer_patches( return output + @abstractmethod + def pre_process_wsi(self: EngineABC) -> NoReturn: + """Pre-process a WSI.""" + raise NotImplementedError + @abstractmethod def infer_wsi(self: EngineABC) -> NoReturn: """Model inference on a WSI.""" From 382c265bfc6041731c22f7f17434039d0b60cda7 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Thu, 31 Aug 2023 15:24:24 +0100 Subject: [PATCH 068/112] :white_check_mark: Add tests for EngineABC.run() --- tests/engines/test_engine_abc.py | 27 +++++++++++++++- tiatoolbox/models/architecture/__init__.py | 36 +++++++++++----------- tiatoolbox/models/engine/engine_abc.py | 3 +- tiatoolbox/models/models_abc.py | 3 +- 4 files changed, 48 insertions(+), 21 deletions(-) diff --git a/tests/engines/test_engine_abc.py b/tests/engines/test_engine_abc.py index d6390c996..f925051e2 100644 --- a/tests/engines/test_engine_abc.py +++ b/tests/engines/test_engine_abc.py @@ -4,12 +4,13 @@ from pathlib import Path from typing import TYPE_CHECKING, NoReturn +import numpy as np import pytest +from tiatoolbox.models.architecture.vanilla import CNNModel from tiatoolbox.models.engine.engine_abc import EngineABC, prepare_engines_save_dir if TYPE_CHECKING: - import numpy as np import torch.nn from torch.utils.data import DataLoader @@ -103,6 +104,13 @@ def test_prepare_engines_save_dir( assert out_dir == tmp_path / "patch_output" assert out_dir.exists() + out_dir = prepare_engines_save_dir( + save_dir=None, + patch_mode=True, + len_images=1, + ) + assert out_dir is None + with pytest.raises( OSError, match=r".*More than 1 WSIs detected but there is no save directory provided.*", @@ -144,5 +152,22 @@ def test_prepare_engines_save_dir( def test_engine_initalization() -> NoReturn: """Test engine initialization.""" + with pytest.raises( + TypeError, + match="Input model must be a string or 'torch.nn.Module'.", + ): + _ = TestEngineABC(model=0) + + eng = TestEngineABC(model="alexnet-kather100k") + assert isinstance(eng, EngineABC) + model = CNNModel("alexnet", num_classes=1) + eng = TestEngineABC(model=model) + assert isinstance(eng, EngineABC) + + +def test_engine_run() -> NoReturn: + """Test engine run.""" eng = TestEngineABC(model="alexnet-kather100k") assert isinstance(eng, EngineABC) + + eng.run(images=np.zeros((10, 10, 10, 10)), on_gpu=False) diff --git a/tiatoolbox/models/architecture/__init__.py b/tiatoolbox/models/architecture/__init__.py index 852c3e2b5..e5853676e 100644 --- a/tiatoolbox/models/architecture/__init__.py +++ b/tiatoolbox/models/architecture/__init__.py @@ -58,15 +58,15 @@ def fetch_pretrained_weights( def get_pretrained_model( - model: str | None = None, - weights: str | Path | None = None, + requested_model: str | None = None, + pretrained_weights: str | Path | None = None, *, overwrite: bool = False, ) -> tuple[torch.nn.Module, ModelIOConfigABC]: """Load a predefined PyTorch model with the appropriate pretrained weights. Args: - model (str): + requested_model (str): Name of the existing models support by tiatoolbox for processing the data. The models currently supported: @@ -98,7 +98,7 @@ def get_pretrained_model( downloaded. However, you can override with your own set of weights via the `pretrained_weights` argument. Argument is case-insensitive. - weights (str): + pretrained_weights (str): Path to the weight of the corresponding `pretrained_model`. overwrite (bool): @@ -106,25 +106,25 @@ def get_pretrained_model( Examples: >>> # get mobilenet pretrained on Kather100K dataset by the TIA team - >>> model, ioconfig = get_pretrained_model(model='mobilenet_v2-kather100k') + >>> model = get_pretrained_model(requested_model='mobilenet_v2-kather100k') >>> # get mobilenet defined by TIA team, but loaded with user defined weights - >>> model, ioconfig = get_pretrained_model( - ... model='mobilenet_v2-kather100k', - ... weights='/A/B/C/my_weights.tar', + >>> model = get_pretrained_model( + ... requested_model='mobilenet_v2-kather100k', + ... pretrained_weights='/A/B/C/my_weights.tar', ... ) >>> # get resnet34 pretrained on PCam dataset by TIA team - >>> model, ioconfig = get_pretrained_model(model='resnet34-pcam') + >>> model = get_pretrained_model(requested_model='resnet34-pcam') """ - if not isinstance(model, str): - msg = "Input model must be a string." + if not isinstance(requested_model, str): + msg = "pretrained_model must be a string." raise TypeError(msg) - if model not in PRETRAINED_INFO: - msg = f"Pretrained model `{model}` does not exist." + if requested_model not in PRETRAINED_INFO: + msg = f"Pretrained model `{requested_model}` does not exist." raise ValueError(msg) - info = PRETRAINED_INFO[model] + info = PRETRAINED_INFO[requested_model] arch_info = info["architecture"] creator = locate(f"tiatoolbox.models.architecture.{arch_info['class']}") @@ -136,13 +136,13 @@ def get_pretrained_model( # ! associated pre-processing coming from dataset (Kumar, Kather, etc.) model.preproc_func = predefined_preproc_func(info["dataset"]) - if weights is None: - weights = fetch_pretrained_weights( - model, + if pretrained_weights is None: + pretrained_weights = fetch_pretrained_weights( + requested_model, overwrite=overwrite, ) - model = load_torch_model(model=model, weights=weights) + model = load_torch_model(model=model, weights=pretrained_weights) # ! io_info = info["ioconfig"] diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 9cfadf431..9d9703b02 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -47,7 +47,8 @@ def prepare_engines_save_dir( """ if patch_mode is True: - save_dir.mkdir(parents=True, exist_ok=False) + if save_dir is not None: + save_dir.mkdir(parents=True, exist_ok=False) return save_dir if save_dir is None: diff --git a/tiatoolbox/models/models_abc.py b/tiatoolbox/models/models_abc.py index e8d2eba3d..a93fccae3 100644 --- a/tiatoolbox/models/models_abc.py +++ b/tiatoolbox/models/models_abc.py @@ -30,7 +30,8 @@ def load_torch_model(model: nn.Module, weights: str | Path) -> nn.Module: # ! assume to be saved in single GPU mode # always load on to the CPU saved_state_dict = torch.load(weights, map_location="cpu") - return model.load_state_dict(saved_state_dict, strict=True) + model.load_state_dict(saved_state_dict, strict=True) + return model def model_to(model: torch.nn.Module, *, on_gpu: bool) -> torch.nn.Module: From f75babd4cb1916109f1be159ee5722099ba6aa59 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Thu, 31 Aug 2023 16:25:16 +0100 Subject: [PATCH 069/112] :bug: Fix Tests --- tests/engines/test_engine_abc.py | 15 +-------------- tiatoolbox/models/engine/engine_abc.py | 2 -- 2 files changed, 1 insertion(+), 16 deletions(-) diff --git a/tests/engines/test_engine_abc.py b/tests/engines/test_engine_abc.py index f925051e2..30113ca77 100644 --- a/tests/engines/test_engine_abc.py +++ b/tests/engines/test_engine_abc.py @@ -12,7 +12,6 @@ if TYPE_CHECKING: import torch.nn - from torch.utils.data import DataLoader class TestEngineABC(EngineABC): @@ -22,10 +21,6 @@ def __init__(self: TestEngineABC, model: str | torch.nn.Module) -> NoReturn: """Test EngineABC init.""" super().__init__(model=model) - def infer_patches(self: EngineABC, data_loader: DataLoader) -> NoReturn: - """Test infer_patch.""" - ... # dummy function for tests. - def infer_wsi(self: EngineABC) -> NoReturn: """Test infer_wsi.""" ... # dummy function for tests. @@ -38,14 +33,6 @@ def post_process_wsi(self: EngineABC) -> NoReturn: """Test post_process_wsi.""" ... # dummy function for tests. - def pre_process_patches( - self: EngineABC, - images: np.ndarray, - labels: list, - ) -> NoReturn: - """Test pre_process_patch.""" - ... # dummy function for tests. - def pre_process_wsi(self: EngineABC) -> NoReturn: """Test pre_process_wsi.""" ... # dummy function for tests. @@ -170,4 +157,4 @@ def test_engine_run() -> NoReturn: eng = TestEngineABC(model="alexnet-kather100k") assert isinstance(eng, EngineABC) - eng.run(images=np.zeros((10, 10, 10, 10)), on_gpu=False) + eng.run(images=np.zeros((10, 3, 224, 224)), on_gpu=False) diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 9d9703b02..d065e117e 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -278,7 +278,6 @@ def _initialize_model_ioconfig( return model, None - @abstractmethod def pre_process_patches( self: EngineABC, images: np.ndarray | list, @@ -307,7 +306,6 @@ def pre_process_patches( shuffle=False, ) - @abstractmethod def infer_patches( self: EngineABC, data_loader: DataLoader, From a948826a45f32783b4e578965688789b71e1d8ad Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Thu, 31 Aug 2023 17:08:56 +0100 Subject: [PATCH 070/112] :white_check_mark: Add Tests for `run()` --- tests/engines/test_engine_abc.py | 54 +++++++++++++++++++++++++- tiatoolbox/models/engine/engine_abc.py | 33 +++++++++++++--- 2 files changed, 81 insertions(+), 6 deletions(-) diff --git a/tests/engines/test_engine_abc.py b/tests/engines/test_engine_abc.py index 30113ca77..188ce6027 100644 --- a/tests/engines/test_engine_abc.py +++ b/tests/engines/test_engine_abc.py @@ -157,4 +157,56 @@ def test_engine_run() -> NoReturn: eng = TestEngineABC(model="alexnet-kather100k") assert isinstance(eng, EngineABC) - eng.run(images=np.zeros((10, 3, 224, 224)), on_gpu=False) + eng = TestEngineABC(model="alexnet-kather100k") + with pytest.raises( + ValueError, + match=r".*The input numpy array should be four dimensional.*", + ): + eng.run(images=np.zeros((10, 10))) + + eng = TestEngineABC(model="alexnet-kather100k") + with pytest.raises( + TypeError, + match=r"Input must be a list of file paths or a numpy array.", + ): + eng.run(images=1) + + eng = TestEngineABC(model="alexnet-kather100k") + with pytest.raises(ValueError, match=r".* is not equal to len(imgs)*"): + eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + labels=list(range(1)), + on_gpu=False, + ) + + eng = TestEngineABC(model="alexnet-kather100k") + out = eng.run(images=np.zeros((10, 224, 224, 3), dtype=np.uint8), on_gpu=False) + assert "probabilities" not in out + assert "labels" not in out + + eng = TestEngineABC(model="alexnet-kather100k") + out = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + on_gpu=False, + verbose=False, + ) + assert "probabilities" not in out + assert "labels" not in out + + eng = TestEngineABC(model="alexnet-kather100k") + out = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + labels=list(range(10)), + on_gpu=False, + ) + assert "probabilities" not in out + assert "labels" in out + + eng = TestEngineABC(model="alexnet-kather100k") + out = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + return_probabilities=True, + on_gpu=False, + ) + assert "probabilities" in out + assert "labels" not in out diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index d065e117e..87a420ffd 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -5,6 +5,7 @@ from pathlib import Path from typing import TYPE_CHECKING, NoReturn +import numpy as np import torch import tqdm from torch import nn @@ -17,7 +18,6 @@ if TYPE_CHECKING: # pragma: no cover import os - import numpy as np from torch.utils.data import DataLoader from tiatoolbox.annotation import AnnotationStore @@ -115,7 +115,7 @@ class EngineABC(ABC): Attributes: images (str or :obj:`pathlib.Path` or :obj:`numpy.ndarray`): - A HWC image or a path to WSI. + A NHWC image or a path to WSI. mode (str): Type of input to process. Choose from either `patch`, `tile` or `wsi`. @@ -285,7 +285,10 @@ def pre_process_patches( ) -> torch.utils.data.DataLoader: """Pre-process an image patch.""" if labels and len(labels) != len(images): - msg = f"len(labels) != len(imgs) : {len(labels)} != {len(images)}" + msg = ( + f"len(labels) is not equal to len(imgs) " + f": {len(labels)} != {len(images)}" + ) raise ValueError( msg, ) @@ -323,11 +326,13 @@ def infer_patches( ) output = { "predictions": [], - "labels": [], } if self.return_probabilities: output["probabilities"] = [] + if self.return_labels: + output["labels"] = [] + for _, batch_data in enumerate(data_loader): batch_output_probabilities = self.model.infer_batch( self.model, @@ -410,6 +415,24 @@ def _load_ioconfig(self: EngineABC, ioconfig: ModelIOConfigABC) -> ModelIOConfig return self.ioconfig + @staticmethod + def _validate_images(images: list | np.ndarray) -> NoReturn: + """Validate input images for a run.""" + if not isinstance(images, (list, np.ndarray)): + msg = "Input must be a list of file paths or a numpy array." + raise TypeError( + msg, + ) + + if isinstance(images, np.ndarray) and images.ndim != 4: # noqa: PLR2004 + msg = ( + "The input numpy array should be four dimensional." + "The shape of the numpy array should be NHWC." + ) + raise ValueError(msg) + + return images + def run( self: EngineABC, images: list[os | Path] | np.ndarray, @@ -489,7 +512,7 @@ def run( for key in kwargs: setattr(self, key, kwargs[key]) - self.images = images + self.images = self._validate_images(images=images) self.masks = masks self.labels = labels From a9bf7474036f35d2c84d0c0a0c12997105448706 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 1 Sep 2023 11:41:05 +0100 Subject: [PATCH 071/112] :white_check_mark: Add checks for length of images, labels, masks - Add checks for length of images, labels, masks - Allow overwrite - Allow output_type to be "dict", "array", "AnnotationStore", "DataFrame" or "json". --- tests/engines/test_engine_abc.py | 25 ++++- tiatoolbox/models/architecture/__init__.py | 4 +- tiatoolbox/models/engine/engine_abc.py | 102 +++++++++++++++----- tiatoolbox/models/engine/patch_predictor.py | 38 ++------ 4 files changed, 114 insertions(+), 55 deletions(-) diff --git a/tests/engines/test_engine_abc.py b/tests/engines/test_engine_abc.py index 188ce6027..dbbb3b9e0 100644 --- a/tests/engines/test_engine_abc.py +++ b/tests/engines/test_engine_abc.py @@ -172,13 +172,36 @@ def test_engine_run() -> NoReturn: eng.run(images=1) eng = TestEngineABC(model="alexnet-kather100k") - with pytest.raises(ValueError, match=r".* is not equal to len(imgs)*"): + with pytest.raises( + ValueError, + match=r".*len\(labels\) is not equal to len(images)*", + ): eng.run( images=np.zeros((10, 224, 224, 3), dtype=np.uint8), labels=list(range(1)), on_gpu=False, ) + with pytest.raises( + ValueError, + match=r".*len\(masks\) is not equal to len(images)*", + ): + eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + masks=np.zeros((1, 224, 224, 3)), + on_gpu=False, + ) + + with pytest.raises( + ValueError, + match=r".*The shape of the numpy array should be NHWC*", + ): + eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + masks=np.zeros((10, 3)), + on_gpu=False, + ) + eng = TestEngineABC(model="alexnet-kather100k") out = eng.run(images=np.zeros((10, 224, 224, 3), dtype=np.uint8), on_gpu=False) assert "probabilities" not in out diff --git a/tiatoolbox/models/architecture/__init__.py b/tiatoolbox/models/architecture/__init__.py index e5853676e..b7ddeb62d 100644 --- a/tiatoolbox/models/architecture/__init__.py +++ b/tiatoolbox/models/architecture/__init__.py @@ -148,5 +148,5 @@ def get_pretrained_model( io_info = info["ioconfig"] creator = locate(f"tiatoolbox.models.engine.{io_info['class']}") - iostate = creator(**io_info["kwargs"]) - return model, iostate + ioconfig = creator(**io_info["kwargs"]) + return model, ioconfig diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 87a420ffd..dc4c7402e 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -1,11 +1,13 @@ """Defines Abstract Base Class for TIAToolbox Model Engines.""" from __future__ import annotations +import json from abc import ABC, abstractmethod from pathlib import Path from typing import TYPE_CHECKING, NoReturn import numpy as np +import pandas as pd import torch import tqdm from torch import nn @@ -30,6 +32,7 @@ def prepare_engines_save_dir( len_images: int, *, patch_mode: bool, + overwrite: bool, ) -> Path | None: """Create directory if not defined and number of images is more than 1. @@ -40,6 +43,8 @@ def prepare_engines_save_dir( List of inputs to process. patch_mode(bool): Whether to treat input image as a patch or WSI. + overwrite (bool): + Whether to overwrite the results. Default = False. Returns: :class:`Path`: @@ -56,9 +61,6 @@ def prepare_engines_save_dir( msg = ( "More than 1 WSIs detected but there is no save directory provided." "Please provide a 'save_dir'." - "All subsequent output will be saved to current runtime" - "location under folder 'Path.cwd() / output'. " - "The output might be overwritten!", ) raise OSError(msg) return ( @@ -73,7 +75,7 @@ def prepare_engines_save_dir( ) save_dir = Path(save_dir) - save_dir.mkdir(parents=True, exist_ok=False) + save_dir.mkdir(parents=True, exist_ok=overwrite) return save_dir @@ -209,7 +211,7 @@ def __init__( self.masks = None self.images = None - self.mode = None + self.patch_mode = None self.on_gpu = on_gpu # Initialize model with specified weights and ioconfig. @@ -284,15 +286,6 @@ def pre_process_patches( labels: list, ) -> torch.utils.data.DataLoader: """Pre-process an image patch.""" - if labels and len(labels) != len(images): - msg = ( - f"len(labels) is not equal to len(imgs) " - f": {len(labels)} != {len(images)}" - ) - raise ValueError( - msg, - ) - if labels: # if a labels is provided, then return with the prediction self.return_labels = bool(labels) @@ -309,10 +302,29 @@ def pre_process_patches( shuffle=False, ) + @staticmethod + def _convert_output_to_requested_type( + output: dict, + output_type: str, + ) -> AnnotationStore | np.ndarray | pd.DataFrame | dict | str: + """Converts inference output to requested type.""" + # function convert output to output_type + if output_type.lower() == "array": + return np.array(output["predictions"]) + + if output_type.lower() == "json": + return json.dumps(output, indent=4) + + if output_type.lower() == "dataframe": + return pd.DataFrame.from_dict(data=output) + + return output + def infer_patches( self: EngineABC, data_loader: DataLoader, - ) -> AnnotationStore | np.ndarray | dict | str: + output_type: str, + ) -> AnnotationStore | np.ndarray | pd.DataFrame | dict | str: """Model inference on an image patch.""" progress_bar = None @@ -361,7 +373,10 @@ def infer_patches( if progress_bar: progress_bar.close() - return output + return self._convert_output_to_requested_type( + output=output, + output_type=output_type, + ) @abstractmethod def pre_process_wsi(self: EngineABC) -> NoReturn: @@ -416,7 +431,7 @@ def _load_ioconfig(self: EngineABC, ioconfig: ModelIOConfigABC) -> ModelIOConfig return self.ioconfig @staticmethod - def _validate_images(images: list | np.ndarray) -> NoReturn: + def _validate_images_masks(images: list | np.ndarray) -> list | np.ndarray: """Validate input images for a run.""" if not isinstance(images, (list, np.ndarray)): msg = "Input must be a list of file paths or a numpy array." @@ -433,6 +448,37 @@ def _validate_images(images: list | np.ndarray) -> NoReturn: return images + @staticmethod + def _validate_input_numbers( + images: list | np.ndarray, + masks: list[os | Path] | np.ndarray | None = None, + labels: list | None = None, + ) -> NoReturn: + """Validates number of input images, masks and labels.""" + if masks is None and labels is None: + return + + len_images = len(images) + + if masks is not None and len_images != len(masks): + msg = ( + f"len(masks) is not equal to len(images) " + f": {len(masks)} != {len(images)}" + ) + raise ValueError( + msg, + ) + + if labels is not None and len_images != len(labels): + msg = ( + f"len(labels) is not equal to len(images) " + f": {len(labels)} != {len(images)}" + ) + raise ValueError( + msg, + ) + return + def run( self: EngineABC, images: list[os | Path] | np.ndarray, @@ -442,10 +488,10 @@ def run( *, patch_mode: bool = True, save_dir: os | Path | None = None, # None will not save output - # output_type can be np.ndarray, Annotation or Json str - # output_type: str = "Annotation", # noqa: ERA001 + overwrite: bool = False, + output_type: str = "dict", **kwargs: dict, - ) -> AnnotationStore | np.ndarray | dict | str: + ) -> AnnotationStore | np.ndarray | pd.DataFrame | dict | str: """Run the engine on input images. Args: @@ -475,8 +521,12 @@ def run( Output directory when processing multiple tiles and whole-slide images. By default, it is folder `output` where the running script is invoked. + overwrite (bool): + Whether to overwrite the results. Default = False. output_type (str): - Whether to save output for a single file. default=False + The format of the output type. "output_type" can be + "dict", "array", "AnnotationStore", "DataFrame" or "json". + Default is "AnnotationStore". **kwargs (dict): Keyword Args to update :class:`EngineABC` attributes. @@ -512,8 +562,12 @@ def run( for key in kwargs: setattr(self, key, kwargs[key]) - self.images = self._validate_images(images=images) - self.masks = masks + self._validate_input_numbers(images=images, masks=masks, labels=labels) + self.images = self._validate_images_masks(images=images) + + if masks is not None: + self.masks = self._validate_images_masks(images=masks) + self.labels = labels # if necessary Move model parameters to "cpu" or "gpu" and update ioconfig @@ -524,6 +578,7 @@ def run( save_dir, len(self.images), patch_mode=patch_mode, + overwrite=overwrite, ) if patch_mode: @@ -533,6 +588,7 @@ def run( ) return self.infer_patches( data_loader=data_loader, + output_type=output_type, ) return {"save_dir": save_dir} diff --git a/tiatoolbox/models/engine/patch_predictor.py b/tiatoolbox/models/engine/patch_predictor.py index 48c7ab60b..244984507 100644 --- a/tiatoolbox/models/engine/patch_predictor.py +++ b/tiatoolbox/models/engine/patch_predictor.py @@ -4,7 +4,7 @@ import copy from collections import OrderedDict from pathlib import Path -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING, Callable, NoReturn import numpy as np import torch @@ -222,41 +222,21 @@ def __init__( verbose=verbose, ) - def pre_process_patches(self): - """Pre-process an image patch.""" - raise NotImplementedError - - def pre_process_tile(self): - """Pre-process an image tile.""" - raise NotImplementedError - - def pre_process_wsi(self): + def pre_process_wsi(self: PatchPredictor) -> NoReturn: """Pre-process a WSI.""" - raise NotImplementedError - - def infer_patches(self): - """Model inference on an image patch.""" - raise NotImplementedError + ... - def infer_tile(self): - """Model inference on an image tile.""" - raise NotImplementedError - - def infer_wsi(self): + def infer_wsi(self: PatchPredictor) -> NoReturn: """Model inference on a WSI.""" - raise NotImplementedError + ... - def post_process_patch(self): + def post_process_patch(self: PatchPredictor) -> NoReturn: """Post-process an image patch.""" - raise NotImplementedError - - def post_process_tile(self): - """Post-process an image tile.""" - raise NotImplementedError + ... - def post_process_wsi(self): + def post_process_wsi(self: PatchPredictor) -> NoReturn: """Post-process a WSI.""" - raise NotImplementedError + ... @staticmethod def merge_predictions( From b78d38604b92eb905650c7c9c942d308e429f64b Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 1 Sep 2023 13:03:05 +0100 Subject: [PATCH 072/112] :bug: Fix tests --- tests/engines/test_engine_abc.py | 16 ++++++++++++++++ tiatoolbox/models/engine/engine_abc.py | 2 +- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/tests/engines/test_engine_abc.py b/tests/engines/test_engine_abc.py index dbbb3b9e0..6e6725c69 100644 --- a/tests/engines/test_engine_abc.py +++ b/tests/engines/test_engine_abc.py @@ -86,6 +86,17 @@ def test_prepare_engines_save_dir( save_dir=tmp_path / "patch_output", patch_mode=True, len_images=1, + overwrite=False, + ) + + assert out_dir == tmp_path / "patch_output" + assert out_dir.exists() + + out_dir = prepare_engines_save_dir( + save_dir=tmp_path / "patch_output", + patch_mode=True, + len_images=1, + overwrite=True, ) assert out_dir == tmp_path / "patch_output" @@ -95,6 +106,7 @@ def test_prepare_engines_save_dir( save_dir=None, patch_mode=True, len_images=1, + overwrite=False, ) assert out_dir is None @@ -106,12 +118,14 @@ def test_prepare_engines_save_dir( save_dir=None, patch_mode=False, len_images=2, + overwrite=False, ) out_dir = prepare_engines_save_dir( save_dir=None, patch_mode=False, len_images=1, + overwrite=False, ) assert out_dir == Path.cwd() @@ -120,6 +134,7 @@ def test_prepare_engines_save_dir( save_dir=tmp_path / "wsi_single_output", patch_mode=False, len_images=1, + overwrite=False, ) assert out_dir == tmp_path / "wsi_single_output" @@ -130,6 +145,7 @@ def test_prepare_engines_save_dir( save_dir=tmp_path / "wsi_multiple_output", patch_mode=False, len_images=2, + overwrite=False, ) assert out_dir == tmp_path / "wsi_multiple_output" diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index dc4c7402e..8369b1441 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -53,7 +53,7 @@ def prepare_engines_save_dir( """ if patch_mode is True: if save_dir is not None: - save_dir.mkdir(parents=True, exist_ok=False) + save_dir.mkdir(parents=True, exist_ok=overwrite) return save_dir if save_dir is None: From 84f69a7ad5e483c980ac823d99c0e01306e30d28 Mon Sep 17 00:00:00 2001 From: measty <20169086+measty@users.noreply.github.com> Date: Tue, 5 Sep 2023 22:51:20 +0100 Subject: [PATCH 073/112] add patch store --- tests/test_utils.py | 27 +++++++++++++++++++++++++++ tiatoolbox/utils/misc.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+) diff --git a/tests/test_utils.py b/tests/test_utils.py index 320e91964..fc7f4c52e 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -18,6 +18,7 @@ from tests.test_annotation_stores import cell_polygon from tiatoolbox import utils +from tiatoolbox.annotation.storage import SQLiteStore from tiatoolbox.models.architecture import fetch_pretrained_weights from tiatoolbox.utils import misc from tiatoolbox.utils.exceptions import FileNotSupportedError @@ -1626,3 +1627,29 @@ def test_imwrite(tmp_path: Path) -> NoReturn: tmp_path / "thisfolderdoesnotexist" / "test_imwrite.jpg", img, ) + + +def test_patch_pred_store(): + # Define a mock patch_output + patch_output = { + "predictions": [1, 0, 1], + "coordinates": [(0, 0, 1, 1), (1, 1, 2, 2), (2, 2, 3, 3)], + "probabilities": [0.9, 0.1, 0.8], + "labels": [1, 0, 1], + "other": "other", + } + + store = misc.patch_pred_store(patch_output) + + # Check that its an SQLiteStore containing the expected annotations + assert isinstance(store, SQLiteStore) + assert len(store) == 3 + for annotation in store.values(): + assert annotation.geometry.area == 1 + assert annotation.properties["label"] in [0, 1] + assert "other" not in annotation.properties + + patch_output.pop("coordinates") + # check correct error is raised if coordinates are missing + with pytest.raises(ValueError, match="coordinates"): + misc.patch_pred_store(patch_output) diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py index 98a6fe7ec..d57ebe14b 100644 --- a/tiatoolbox/utils/misc.py +++ b/tiatoolbox/utils/misc.py @@ -17,6 +17,7 @@ import yaml from filelock import FileLock from shapely.affinity import translate +from shapely.geometry import box from shapely.geometry import shape as feature2geometry from skimage import exposure @@ -1174,3 +1175,34 @@ def add_from_dat( logger.info("Added %d annotations.", len(anns)) store.append_many(anns) + + +def patch_pred_store(patch_output: dict) -> AnnotationStore: + """Create an SQLiteStore containing Annotations for each patch. + + Args: + patch_output (dict): A dictionary of patch prediction information. Important keys are + "probabilities", "predictions", "coordinates", and "labels". + + Returns: + SQLiteStore: An SQLiteStore containing Annotations for each patch. + """ + store = SQLiteStore() + annotations = [] + # find what keys we need to save + keys = ["predictions"] + if "coordinates" not in patch_output: + # we cant create annotations without coordinates + msg = "Patch output must contain coordinates." + raise ValueError(msg) + keys = keys + [key for key in ["probabilities", "labels"] if key in patch_output] + + for idx, coordinate in enumerate(patch_output["coordinates"]): + properties = {key: patch_output[key][idx] for key in keys} + # Create a square annotation for the patch + annotations.append(Annotation(geometry=box(*coordinate), properties=properties)) + + # Add annotations to store + store.append_many(annotations) + + return store From 11eb57f909fe426b25c6325b86be8e2a07f271d3 Mon Sep 17 00:00:00 2001 From: measty <20169086+measty@users.noreply.github.com> Date: Thu, 7 Sep 2023 18:44:02 +0100 Subject: [PATCH 074/112] add store conversion for patch pred --- tests/test_utils.py | 50 ++++++++++++++++++++++++++++---- tiatoolbox/utils/misc.py | 61 ++++++++++++++++++++++++++++++---------- 2 files changed, 91 insertions(+), 20 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index fc7f4c52e..478cde2be 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1630,26 +1630,66 @@ def test_imwrite(tmp_path: Path) -> NoReturn: def test_patch_pred_store(): + """Test patch_pred_store.""" # Define a mock patch_output patch_output = { "predictions": [1, 0, 1], "coordinates": [(0, 0, 1, 1), (1, 1, 2, 2), (2, 2, 3, 3)], - "probabilities": [0.9, 0.1, 0.8], - "labels": [1, 0, 1], "other": "other", } - store = misc.patch_pred_store(patch_output) + store = misc.patch_pred_store(patch_output, (1.0, 1.0)) # Check that its an SQLiteStore containing the expected annotations assert isinstance(store, SQLiteStore) assert len(store) == 3 for annotation in store.values(): assert annotation.geometry.area == 1 - assert annotation.properties["label"] in [0, 1] + assert annotation.properties["type"] in [0, 1] assert "other" not in annotation.properties patch_output.pop("coordinates") # check correct error is raised if coordinates are missing with pytest.raises(ValueError, match="coordinates"): - misc.patch_pred_store(patch_output) + misc.patch_pred_store(patch_output, (1.0, 1.0)) + + +def test_patch_pred_store_cdict(): + """Test patch_pred_store with a class dict.""" + # Define a mock patch_output + patch_output = { + "predictions": [1, 0, 1], + "coordinates": [(0, 0, 1, 1), (1, 1, 2, 2), (2, 2, 3, 3)], + "probabilities": [[0.1, 0.9], [0.9, 0.1], [0.4, 0.6]], + "labels": [1, 0, 1], + "other": "other", + } + class_dict = {0: "class0", 1: "class1"} + store = misc.patch_pred_store(patch_output, (1.0, 1.0), class_dict=class_dict) + + # Check that its an SQLiteStore containing the expected annotations + assert isinstance(store, SQLiteStore) + assert len(store) == 3 + for annotation in store.values(): + assert annotation.geometry.area == 1 + assert annotation.properties["label"] in ["class0", "class1"] + assert annotation.properties["type"] in ["class0", "class1"] + assert "other" not in annotation.properties + + +def test_patch_pred_store_sf(): + """Test patch_pred_store with scale factor.""" + # Define a mock patch_output + patch_output = { + "predictions": [1, 0, 1], + "coordinates": [(0, 0, 1, 1), (1, 1, 2, 2), (2, 2, 3, 3)], + "probabilities": [[0.1, 0.9], [0.9, 0.1], [0.4, 0.6]], + "labels": [1, 0, 1], + } + store = misc.patch_pred_store(patch_output, (2.0, 2.0)) + + # Check that its an SQLiteStore containing the expected annotations + assert isinstance(store, SQLiteStore) + assert len(store) == 3 + for annotation in store.values(): + assert annotation.geometry.area == 4 diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py index d57ebe14b..fe6c30f19 100644 --- a/tiatoolbox/utils/misc.py +++ b/tiatoolbox/utils/misc.py @@ -17,7 +17,7 @@ import yaml from filelock import FileLock from shapely.affinity import translate -from shapely.geometry import box +from shapely.geometry import Polygon from shapely.geometry import shape as feature2geometry from skimage import exposure @@ -1177,32 +1177,63 @@ def add_from_dat( store.append_many(anns) -def patch_pred_store(patch_output: dict) -> AnnotationStore: +def patch_pred_store( + patch_output: dict, + scale_factor: tuple[int, int], + class_dict: dict | None = None, +) -> AnnotationStore: """Create an SQLiteStore containing Annotations for each patch. Args: - patch_output (dict): A dictionary of patch prediction information. Important keys are - "probabilities", "predictions", "coordinates", and "labels". + patch_output (dict): A dictionary of patch prediction information. Important + keys are "probabilities", "predictions", "coordinates", and "labels". + scale_factor (tuple[int, int]): The scale factor to use when loading the + annotations. All coordinates will be multiplied by this factor to allow + conversion of annotations saved at non-baseline resolution to baseline. + Should be model_mpp/slide_mpp. + class_dict (dict): Optional dictionary mapping class indices to class names. Returns: SQLiteStore: An SQLiteStore containing Annotations for each patch. """ - store = SQLiteStore() - annotations = [] - # find what keys we need to save - keys = ["predictions"] if "coordinates" not in patch_output: # we cant create annotations without coordinates msg = "Patch output must contain coordinates." raise ValueError(msg) + # get relevant keys + class_probs = patch_output.get("probabilities", []) + preds = patch_output.get("predictions", []) + patch_coords = np.array(patch_output.get("coordinates", [])) + if not np.all(scale_factor == 1): + patch_coords = patch_coords * (np.tile(scale_factor, 2)) # to baseline mpp + labels = patch_output.get("labels", []) + # get classes to consider + if len(class_probs) == 0: + classes_predicted = np.unique(preds).tolist() + else: + classes_predicted = range(len(class_probs[0])) + if class_dict is None: + # if no class dict create a default one + class_dict = {i: i for i in np.unique(preds + labels).tolist()} + annotations = [] + # find what keys we need to save + keys = ["predictions"] keys = keys + [key for key in ["probabilities", "labels"] if key in patch_output] - for idx, coordinate in enumerate(patch_output["coordinates"]): - properties = {key: patch_output[key][idx] for key in keys} - # Create a square annotation for the patch - annotations.append(Annotation(geometry=box(*coordinate), properties=properties)) - - # Add annotations to store - store.append_many(annotations) + # put patch predictions into a store + annotations = [] + for i in range(len(preds)): + if "probabilities" in keys: + props = { + f"prob_{class_dict[j]}": class_probs[i][j] for j in classes_predicted + } + else: + props = {} + if "labels" in keys: + props["label"] = class_dict[labels[i]] + props["type"] = class_dict[preds[i]] + annotations.append(Annotation(Polygon.from_bounds(*patch_coords[i]), props)) + store = SQLiteStore() + keys = store.append_many(annotations, [str(i) for i in range(len(annotations))]) return store From d598fa9d11ddbf7298f99ba3830dddd6a7c6eb1a Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 8 Sep 2023 11:01:14 +0100 Subject: [PATCH 075/112] :recycle: Refactor code to include post_process_patches --- tests/engines/test_engine_abc.py | 2 +- tiatoolbox/models/engine/engine_abc.py | 49 +++++++++------------ tiatoolbox/models/engine/patch_predictor.py | 2 +- 3 files changed, 24 insertions(+), 29 deletions(-) diff --git a/tests/engines/test_engine_abc.py b/tests/engines/test_engine_abc.py index 6e6725c69..6bc4059e5 100644 --- a/tests/engines/test_engine_abc.py +++ b/tests/engines/test_engine_abc.py @@ -25,7 +25,7 @@ def infer_wsi(self: EngineABC) -> NoReturn: """Test infer_wsi.""" ... # dummy function for tests. - def post_process_patch(self: EngineABC) -> NoReturn: + def post_process_patches(self: EngineABC) -> NoReturn: """Test post_process_patch.""" ... # dummy function for tests. diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 8369b1441..a5d220bf0 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -323,8 +323,7 @@ def _convert_output_to_requested_type( def infer_patches( self: EngineABC, data_loader: DataLoader, - output_type: str, - ) -> AnnotationStore | np.ndarray | pd.DataFrame | dict | str: + ) -> dict: """Model inference on an image patch.""" progress_bar = None @@ -336,36 +335,26 @@ def infer_patches( ascii=True, position=0, ) - output = { + raw_predictions = { "predictions": [], } - if self.return_probabilities: - output["probabilities"] = [] if self.return_labels: - output["labels"] = [] + raw_predictions["labels"] = [] for _, batch_data in enumerate(data_loader): - batch_output_probabilities = self.model.infer_batch( + batch_output_predictions = self.model.infer_batch( self.model, batch_data["image"], on_gpu=self.on_gpu, ) - # We get the index of the class with the maximum probability - batch_output_predictions = self.model.postproc_func( - batch_output_probabilities, - ) - output["predictions"].extend(batch_output_predictions.tolist()) - - # tolist might be very expensive - if self.return_probabilities: - output["probabilities"].extend(batch_output_probabilities.tolist()) + raw_predictions["predictions"].extend(batch_output_predictions.tolist()) if self.return_labels: # be careful of `s` # We do not use tolist here because label may be of mixed types # and hence collated as list by torch - output["labels"].extend(list(batch_data["label"])) + raw_predictions["labels"].extend(list(batch_data["label"])) if progress_bar: progress_bar.update() @@ -373,10 +362,7 @@ def infer_patches( if progress_bar: progress_bar.close() - return self._convert_output_to_requested_type( - output=output, - output_type=output_type, - ) + return raw_predictions @abstractmethod def pre_process_wsi(self: EngineABC) -> NoReturn: @@ -388,10 +374,16 @@ def infer_wsi(self: EngineABC) -> NoReturn: """Model inference on a WSI.""" raise NotImplementedError - @abstractmethod - def post_process_patch(self: EngineABC) -> NoReturn: - """Post-process an image patch.""" - raise NotImplementedError + def post_process_patches( + self: EngineABC, + raw_predictions: dict, + output_type: str, + ) -> AnnotationStore | np.ndarray | pd.DataFrame | dict | str: + """Post-process an image patches.""" + return self._convert_output_to_requested_type( + output=raw_predictions, + output_type=output_type, + ) @abstractmethod def post_process_wsi(self: EngineABC) -> NoReturn: @@ -453,7 +445,7 @@ def _validate_input_numbers( images: list | np.ndarray, masks: list[os | Path] | np.ndarray | None = None, labels: list | None = None, - ) -> NoReturn: + ) -> None: """Validates number of input images, masks and labels.""" if masks is None and labels is None: return @@ -586,8 +578,11 @@ def run( self.images, self.labels, ) - return self.infer_patches( + raw_predictions = self.infer_patches( data_loader=data_loader, + ) + return self.post_process_patches( + raw_predictions=raw_predictions, output_type=output_type, ) diff --git a/tiatoolbox/models/engine/patch_predictor.py b/tiatoolbox/models/engine/patch_predictor.py index 244984507..fb24322eb 100644 --- a/tiatoolbox/models/engine/patch_predictor.py +++ b/tiatoolbox/models/engine/patch_predictor.py @@ -230,7 +230,7 @@ def infer_wsi(self: PatchPredictor) -> NoReturn: """Model inference on a WSI.""" ... - def post_process_patch(self: PatchPredictor) -> NoReturn: + def post_process_patches(self: PatchPredictor) -> NoReturn: """Post-process an image patch.""" ... From 409f128de3b0191aedb20b09f020ea83fde16d95 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 8 Sep 2023 11:38:03 +0100 Subject: [PATCH 076/112] :bug: test initialization was overwriting the function. --- tests/engines/test_engine_abc.py | 19 +++---------------- tiatoolbox/models/engine/engine_abc.py | 23 ++++++++++------------- 2 files changed, 13 insertions(+), 29 deletions(-) diff --git a/tests/engines/test_engine_abc.py b/tests/engines/test_engine_abc.py index 6bc4059e5..1c2ab2370 100644 --- a/tests/engines/test_engine_abc.py +++ b/tests/engines/test_engine_abc.py @@ -25,10 +25,6 @@ def infer_wsi(self: EngineABC) -> NoReturn: """Test infer_wsi.""" ... # dummy function for tests. - def post_process_patches(self: EngineABC) -> NoReturn: - """Test post_process_patch.""" - ... # dummy function for tests. - def post_process_wsi(self: EngineABC) -> NoReturn: """Test post_process_wsi.""" ... # dummy function for tests. @@ -220,7 +216,7 @@ def test_engine_run() -> NoReturn: eng = TestEngineABC(model="alexnet-kather100k") out = eng.run(images=np.zeros((10, 224, 224, 3), dtype=np.uint8), on_gpu=False) - assert "probabilities" not in out + assert "predictions" in out assert "labels" not in out eng = TestEngineABC(model="alexnet-kather100k") @@ -229,7 +225,7 @@ def test_engine_run() -> NoReturn: on_gpu=False, verbose=False, ) - assert "probabilities" not in out + assert "predictions" in out assert "labels" not in out eng = TestEngineABC(model="alexnet-kather100k") @@ -238,14 +234,5 @@ def test_engine_run() -> NoReturn: labels=list(range(10)), on_gpu=False, ) - assert "probabilities" not in out + assert "predictions" in out assert "labels" in out - - eng = TestEngineABC(model="alexnet-kather100k") - out = eng.run( - images=np.zeros((10, 224, 224, 3), dtype=np.uint8), - return_probabilities=True, - on_gpu=False, - ) - assert "probabilities" in out - assert "labels" not in out diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index a5d220bf0..41ff46377 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -135,8 +135,6 @@ class EngineABC(ABC): Input IO configuration to run the Engine. _ioconfig (): Runtime ioconfig. - return_probabilities (bool): - Whether to return per-class probabilities. return_labels (bool): Whether to return the labels with the predictions. merge_predictions (bool): @@ -226,7 +224,6 @@ def __init__( self.num_loader_workers = num_loader_workers self.num_post_proc_workers = num_post_proc_workers self.verbose = verbose - self.return_probabilities = False self.return_labels = False self.merge_predictions = False self.units = "baseline" @@ -364,16 +361,6 @@ def infer_patches( return raw_predictions - @abstractmethod - def pre_process_wsi(self: EngineABC) -> NoReturn: - """Pre-process a WSI.""" - raise NotImplementedError - - @abstractmethod - def infer_wsi(self: EngineABC) -> NoReturn: - """Model inference on a WSI.""" - raise NotImplementedError - def post_process_patches( self: EngineABC, raw_predictions: dict, @@ -385,6 +372,16 @@ def post_process_patches( output_type=output_type, ) + @abstractmethod + def pre_process_wsi(self: EngineABC) -> NoReturn: + """Pre-process a WSI.""" + raise NotImplementedError + + @abstractmethod + def infer_wsi(self: EngineABC) -> NoReturn: + """Model inference on a WSI.""" + raise NotImplementedError + @abstractmethod def post_process_wsi(self: EngineABC) -> NoReturn: """Post-process a WSI.""" From d970f49baeb08badedd2d1d214354348b1ffd6cd Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 22 Sep 2023 10:05:16 +0100 Subject: [PATCH 077/112] :recycle: Address comments from the previous discussion. --- tests/engines/_test_feature_extractor.py | 99 -- tests/engines/_test_multi_task_segmentor.py | 423 --------- .../_test_nucleus_instance_segmentor.py | 596 ------------ tests/engines/_test_patch_predictor.py | 763 ---------------- tests/engines/_test_semantic_segmentation.py | 853 ------------------ tests/models/test_arch_vanilla.py | 3 +- tests/models/test_models_abc.py | 4 +- tiatoolbox/models/engine/engine_abc.py | 71 +- tiatoolbox/models/engine/patch_predictor.py | 10 +- .../models/engine/semantic_segmentor.py | 18 +- tiatoolbox/models/models_abc.py | 11 +- 11 files changed, 57 insertions(+), 2794 deletions(-) delete mode 100644 tests/engines/_test_feature_extractor.py delete mode 100644 tests/engines/_test_multi_task_segmentor.py delete mode 100644 tests/engines/_test_nucleus_instance_segmentor.py delete mode 100644 tests/engines/_test_patch_predictor.py delete mode 100644 tests/engines/_test_semantic_segmentation.py diff --git a/tests/engines/_test_feature_extractor.py b/tests/engines/_test_feature_extractor.py deleted file mode 100644 index 3315cf0c3..000000000 --- a/tests/engines/_test_feature_extractor.py +++ /dev/null @@ -1,99 +0,0 @@ -"""Test for feature extractor.""" - -import shutil -from pathlib import Path -from typing import Callable - -import numpy as np -import torch - -from tiatoolbox.models import IOSegmentorConfig -from tiatoolbox.models.architecture.vanilla import CNNBackbone -from tiatoolbox.models.engine.semantic_segmentor import DeepFeatureExtractor -from tiatoolbox.utils import env_detection as toolbox_env -from tiatoolbox.wsicore.wsireader import WSIReader - -ON_GPU = not toolbox_env.running_on_ci() and toolbox_env.has_gpu() - -# ------------------------------------------------------------------------------------- -# Engine -# ------------------------------------------------------------------------------------- - - -def test_functional(remote_sample: Callable, tmp_path: Path) -> None: - """Test for feature extraction.""" - save_dir = tmp_path / "output" - # # convert to pathlib Path to prevent wsireader complaint - mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs")) - - # * test providing pretrained from torch vs pretrained_model.yaml - shutil.rmtree(save_dir, ignore_errors=True) # default output dir test - extractor = DeepFeatureExtractor(batch_size=1, pretrained_model="fcn-tissue_mask") - output_list = extractor.predict( - [mini_wsi_svs], - mode="wsi", - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=save_dir, - ) - wsi_0_root_path = output_list[0][1] - positions = np.load(f"{wsi_0_root_path}.position.npy") - features = np.load(f"{wsi_0_root_path}.features.0.npy") - assert len(features.shape) == 4 - - # * test same output between full infer and engine - # pre-emptive clean up - shutil.rmtree(save_dir, ignore_errors=True) # default output dir test - - ioconfig = IOSegmentorConfig( - input_resolutions=[ - {"units": "mpp", "resolution": 0.25}, - ], - output_resolutions=[ - {"units": "mpp", "resolution": 0.25}, - ], - patch_input_shape=[512, 512], - patch_output_shape=[512, 512], - stride_shape=[256, 256], - save_resolution={"units": "mpp", "resolution": 8.0}, - ) - - model = CNNBackbone("resnet50") - extractor = DeepFeatureExtractor(batch_size=4, model=model) - # should still run because we skip exception - output_list = extractor.predict( - [mini_wsi_svs], - mode="wsi", - ioconfig=ioconfig, - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=save_dir, - ) - wsi_0_root_path = output_list[0][1] - positions = np.load(f"{wsi_0_root_path}.position.npy") - features = np.load(f"{wsi_0_root_path}.features.0.npy") - - reader = WSIReader.open(mini_wsi_svs) - patches = [ - reader.read_bounds( - positions[patch_idx], - resolution=0.25, - units="mpp", - pad_constant_values=0, - coord_space="resolution", - ) - for patch_idx in range(4) - ] - patches = np.array(patches) - patches = torch.from_numpy(patches) # NHWC - patches = patches.permute(0, 3, 1, 2) # NCHW - patches = patches.type(torch.float32) - model = model.to("cpu") - # Inference mode - model.eval() - with torch.inference_mode(): - _features = model(patches).numpy() - # ! must maintain same batch size and likely same ordering - # ! else the output values will not exactly be the same (still < 1.0e-4 - # ! of epsilon though) - assert np.mean(np.abs(features[:4] - _features)) < 1.0e-1 diff --git a/tests/engines/_test_multi_task_segmentor.py b/tests/engines/_test_multi_task_segmentor.py deleted file mode 100644 index c3cc85cea..000000000 --- a/tests/engines/_test_multi_task_segmentor.py +++ /dev/null @@ -1,423 +0,0 @@ -"""Unit test package for HoVerNet+.""" - -import copy - -# ! The garbage collector -import gc -import multiprocessing -import shutil -from pathlib import Path -from typing import Callable - -import joblib -import numpy as np -import pytest - -from tiatoolbox.models import ( - IOInstanceSegmentorConfig, - MultiTaskSegmentor, - SemanticSegmentor, -) -from tiatoolbox.utils import env_detection as toolbox_env -from tiatoolbox.utils import imwrite -from tiatoolbox.utils.metrics import f1_detection - -ON_GPU = toolbox_env.has_gpu() -BATCH_SIZE = 1 if not ON_GPU else 8 # 16 -try: - NUM_POSTPROC_WORKERS = multiprocessing.cpu_count() -except NotImplementedError: - NUM_POSTPROC_WORKERS = 2 - -# ---------------------------------------------------- - - -def _crash_func(_: object) -> None: - """Helper to induce crash.""" - msg = "Propagation Crash." - raise ValueError(msg) - - -def semantic_postproc_func(raw_output: np.ndarray) -> np.ndarray: - """Function to post process semantic segmentations. - - Post processes semantic segmentation to form one map output. - - """ - return np.argmax(raw_output, axis=-1) - - -@pytest.mark.skipif( - toolbox_env.running_on_ci() or not ON_GPU, - reason="Local test on machine with GPU.", -) -def test_functionality_local(remote_sample: Callable, tmp_path: Path) -> None: - """Local functionality test for multi task segmentor.""" - gc.collect() - root_save_dir = Path(tmp_path) - mini_wsi_svs = Path(remote_sample("svs-1-small")) - save_dir = root_save_dir / "multitask" - shutil.rmtree(save_dir, ignore_errors=True) - - # * generate full output w/o parallel post-processing worker first - multi_segmentor = MultiTaskSegmentor( - pretrained_model="hovernetplus-oed", - batch_size=BATCH_SIZE, - num_postproc_workers=0, - ) - output = multi_segmentor.predict( - [mini_wsi_svs], - mode="wsi", - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=save_dir, - ) - - inst_dict_a = joblib.load(f"{output[0][1]}.0.dat") - - # * then test run when using workers, will then compare results - # * to ensure the predictions are the same - shutil.rmtree(save_dir, ignore_errors=True) - multi_segmentor = MultiTaskSegmentor( - pretrained_model="hovernetplus-oed", - batch_size=BATCH_SIZE, - num_postproc_workers=NUM_POSTPROC_WORKERS, - ) - assert multi_segmentor.num_postproc_workers == NUM_POSTPROC_WORKERS - output = multi_segmentor.predict( - [mini_wsi_svs], - mode="wsi", - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=save_dir, - ) - - inst_dict_b = joblib.load(f"{output[0][1]}.0.dat") - layer_map_b = np.load(f"{output[0][1]}.1.npy") - assert len(inst_dict_b) > 0, "Must have some nuclei" - assert layer_map_b is not None, "Must have some layers." - - inst_coords_a = np.array([v["centroid"] for v in inst_dict_a.values()]) - inst_coords_b = np.array([v["centroid"] for v in inst_dict_b.values()]) - score = f1_detection(inst_coords_b, inst_coords_a, radius=1.0) - assert score > 0.95, "Heavy loss of precision!" - - -def test_functionality_hovernetplus(remote_sample: Callable, tmp_path: Path) -> None: - """Functionality test for multitask segmentor.""" - root_save_dir = Path(tmp_path) - mini_wsi_svs = Path(remote_sample("wsi4_512_512_svs")) - required_dims = (258, 258) - # above image is 512 x 512 at 0.252 mpp resolution. This is 258 x 258 at 0.500 mpp. - - save_dir = f"{root_save_dir}/multi/" - shutil.rmtree(save_dir, ignore_errors=True) - - multi_segmentor = MultiTaskSegmentor( - pretrained_model="hovernetplus-oed", - batch_size=BATCH_SIZE, - num_postproc_workers=NUM_POSTPROC_WORKERS, - ) - output = multi_segmentor.predict( - [mini_wsi_svs], - mode="wsi", - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=save_dir, - ) - - inst_dict = joblib.load(f"{output[0][1]}.0.dat") - layer_map = np.load(f"{output[0][1]}.1.npy") - - assert len(inst_dict) > 0, "Must have some nuclei." - assert layer_map is not None, "Must have some layers." - assert ( - layer_map.shape == required_dims - ), "Output layer map dimensions must be same as the expected output shape" - - -def test_functionality_hovernet(remote_sample: Callable, tmp_path: Path) -> None: - """Functionality test for multitask segmentor.""" - root_save_dir = Path(tmp_path) - mini_wsi_svs = Path(remote_sample("wsi4_512_512_svs")) - - save_dir = root_save_dir / "multi" - shutil.rmtree(save_dir, ignore_errors=True) - - multi_segmentor = MultiTaskSegmentor( - pretrained_model="hovernet_fast-pannuke", - batch_size=BATCH_SIZE, - num_postproc_workers=NUM_POSTPROC_WORKERS, - ) - output = multi_segmentor.predict( - [mini_wsi_svs], - mode="wsi", - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=save_dir, - ) - - inst_dict = joblib.load(f"{output[0][1]}.0.dat") - - assert len(inst_dict) > 0, "Must have some nuclei." - - -def test_masked_segmentor(remote_sample: Callable, tmp_path: Path) -> None: - """Test segmentor when image is masked.""" - root_save_dir = Path(tmp_path) - sample_wsi_svs = Path(remote_sample("svs-1-small")) - sample_wsi_msk = remote_sample("small_svs_tissue_mask") - sample_wsi_msk = np.load(sample_wsi_msk).astype(np.uint8) - imwrite(f"{tmp_path}/small_svs_tissue_mask.jpg", sample_wsi_msk) - sample_wsi_msk = tmp_path.joinpath("small_svs_tissue_mask.jpg") - - save_dir = root_save_dir / "instance" - - # resolution for travis testing, not the correct ones - resolution = 4.0 - ioconfig = IOInstanceSegmentorConfig( - input_resolutions=[{"units": "mpp", "resolution": resolution}], - output_resolutions=[ - {"units": "mpp", "resolution": resolution}, - {"units": "mpp", "resolution": resolution}, - {"units": "mpp", "resolution": resolution}, - ], - margin=128, - tile_shape=(512, 512), - patch_input_shape=[256, 256], - patch_output_shape=[164, 164], - stride_shape=[164, 164], - ) - multi_segmentor = MultiTaskSegmentor( - batch_size=BATCH_SIZE, - num_postproc_workers=2, - pretrained_model="hovernet_fast-pannuke", - ) - - output = multi_segmentor.predict( - [sample_wsi_svs], - masks=[sample_wsi_msk], - mode="wsi", - ioconfig=ioconfig, - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=save_dir, - ) - - inst_dict = joblib.load(f"{output[0][1]}.0.dat") - - assert len(inst_dict) > 0, "Must have some nuclei." - - -def test_functionality_process_instance_predictions( - remote_sample: Callable, - tmp_path: Path, -) -> None: - """Test the functionality of instance predictions processing.""" - root_save_dir = Path(tmp_path) - mini_wsi_svs = Path(remote_sample("wsi4_512_512_svs")) - - save_dir = root_save_dir / "semantic" - shutil.rmtree(save_dir, ignore_errors=True) - - semantic_segmentor = SemanticSegmentor( - pretrained_model="hovernetplus-oed", - batch_size=BATCH_SIZE, - num_postproc_workers=0, - ) - multi_segmentor = MultiTaskSegmentor( - pretrained_model="hovernetplus-oed", - batch_size=BATCH_SIZE, - num_postproc_workers=0, - ) - - output = semantic_segmentor.predict( - [mini_wsi_svs], - mode="wsi", - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=save_dir, - ) - raw_maps = [np.load(f"{output[0][1]}.raw.{head_idx}.npy") for head_idx in range(4)] - - dummy_reference = [{i: {"box": np.array([0, 0, 32, 32])} for i in range(1000)}] - - dummy_tiles = [np.zeros((512, 512))] - dummy_bounds = np.array([0, 0, 512, 512]) - - multi_segmentor.wsi_layers = [np.zeros_like(raw_maps[0][..., 0])] - multi_segmentor._wsi_inst_info = copy.deepcopy(dummy_reference) - multi_segmentor._futures = [ - [dummy_reference, [dummy_reference[0].keys()], dummy_tiles, dummy_bounds], - ] - multi_segmentor._merge_post_process_results() - assert len(multi_segmentor._wsi_inst_info[0]) == 0 - - -def test_empty_image(tmp_path: Path) -> None: - """Test MultiTaskSegmentor for an empty image.""" - root_save_dir = Path(tmp_path) - sample_patch = np.ones((256, 256, 3), dtype="uint8") * 255 - sample_patch_path = root_save_dir / "sample_tile.png" - imwrite(sample_patch_path, sample_patch) - - save_dir = root_save_dir / "hovernetplus" - - multi_segmentor = MultiTaskSegmentor( - pretrained_model="hovernetplus-oed", - batch_size=BATCH_SIZE, - num_postproc_workers=0, - ) - - _ = multi_segmentor.predict( - [sample_patch_path], - mode="tile", - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=save_dir, - ) - - save_dir = root_save_dir / "hovernet" - - multi_segmentor = MultiTaskSegmentor( - pretrained_model="hovernet_fast-pannuke", - batch_size=BATCH_SIZE, - num_postproc_workers=0, - ) - - _ = multi_segmentor.predict( - [sample_patch_path], - mode="tile", - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=save_dir, - ) - - save_dir = root_save_dir / "semantic" - - multi_segmentor = MultiTaskSegmentor( - pretrained_model="fcn_resnet50_unet-bcss", - batch_size=BATCH_SIZE, - num_postproc_workers=0, - output_types=["semantic"], - ) - - bcc_wsi_ioconfig = IOInstanceSegmentorConfig( - input_resolutions=[{"units": "mpp", "resolution": 0.25}], - output_resolutions=[{"units": "mpp", "resolution": 0.25}], - tile_shape=(2048, 2048), - patch_input_shape=[1024, 1024], - patch_output_shape=[512, 512], - stride_shape=[512, 512], - margin=128, - save_resolution={"units": "mpp", "resolution": 2}, - ) - - _ = multi_segmentor.predict( - [sample_patch_path], - mode="tile", - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=save_dir, - ioconfig=bcc_wsi_ioconfig, - ) - - -def test_functionality_semantic(remote_sample: Callable, tmp_path: Path) -> None: - """Functionality test for multitask segmentor.""" - root_save_dir = Path(tmp_path) - - save_dir = root_save_dir / "multi" - shutil.rmtree(save_dir, ignore_errors=True) - with pytest.raises( - ValueError, - match=r"Output type must be specified for instance or semantic segmentation.", - ): - MultiTaskSegmentor( - pretrained_model="fcn_resnet50_unet-bcss", - batch_size=BATCH_SIZE, - num_postproc_workers=NUM_POSTPROC_WORKERS, - ) - - mini_wsi_svs = Path(remote_sample("wsi4_512_512_svs")) - save_dir = f"{root_save_dir}/multi/" - - multi_segmentor = MultiTaskSegmentor( - pretrained_model="fcn_resnet50_unet-bcss", - batch_size=BATCH_SIZE, - num_postproc_workers=NUM_POSTPROC_WORKERS, - output_types=["semantic"], - ) - - bcc_wsi_ioconfig = IOInstanceSegmentorConfig( - input_resolutions=[{"units": "mpp", "resolution": 0.25}], - output_resolutions=[{"units": "mpp", "resolution": 0.25}], - tile_shape=2048, - patch_input_shape=[1024, 1024], - patch_output_shape=[512, 512], - stride_shape=[512, 512], - margin=128, - save_resolution={"units": "mpp", "resolution": 2}, - ) - - multi_segmentor.model.postproc_func = semantic_postproc_func - - output = multi_segmentor.predict( - [mini_wsi_svs], - mode="wsi", - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=save_dir, - ioconfig=bcc_wsi_ioconfig, - ) - - layer_map = np.load(f"{output[0][1]}.0.npy") - - assert layer_map is not None, "Must have some segmentations." - - -def test_crash_segmentor(remote_sample: Callable, tmp_path: Path) -> None: - """Test engine crash when given malformed input.""" - root_save_dir = Path(tmp_path) - sample_wsi_svs = Path(remote_sample("svs-1-small")) - sample_wsi_msk = remote_sample("small_svs_tissue_mask") - sample_wsi_msk = np.load(sample_wsi_msk).astype(np.uint8) - imwrite(f"{tmp_path}/small_svs_tissue_mask.jpg", sample_wsi_msk) - sample_wsi_msk = tmp_path.joinpath("small_svs_tissue_mask.jpg") - - save_dir = f"{root_save_dir}/multi/" - - # resolution for travis testing, not the correct ones - resolution = 4.0 - ioconfig = IOInstanceSegmentorConfig( - input_resolutions=[{"units": "mpp", "resolution": resolution}], - output_resolutions=[ - {"units": "mpp", "resolution": resolution}, - {"units": "mpp", "resolution": resolution}, - {"units": "mpp", "resolution": resolution}, - ], - margin=128, - tile_shape=[512, 512], - patch_input_shape=[256, 256], - patch_output_shape=[164, 164], - stride_shape=[164, 164], - ) - multi_segmentor = MultiTaskSegmentor( - batch_size=BATCH_SIZE, - num_postproc_workers=2, - pretrained_model="hovernetplus-oed", - ) - - # * Test crash propagation when parallelize post-processing - shutil.rmtree(save_dir, ignore_errors=True) - multi_segmentor.model.postproc_func = _crash_func - with pytest.raises(ValueError, match=r"Crash."): - multi_segmentor.predict( - [sample_wsi_svs], - masks=[sample_wsi_msk], - mode="wsi", - ioconfig=ioconfig, - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=save_dir, - ) diff --git a/tests/engines/_test_nucleus_instance_segmentor.py b/tests/engines/_test_nucleus_instance_segmentor.py deleted file mode 100644 index 6d3ea2f67..000000000 --- a/tests/engines/_test_nucleus_instance_segmentor.py +++ /dev/null @@ -1,596 +0,0 @@ -"""Test for Nucleus Instance Segmentor.""" - -import copy - -# ! The garbage collector -import gc -import shutil -from pathlib import Path -from typing import Callable - -import joblib -import numpy as np -import pytest -import yaml -from click.testing import CliRunner - -from tiatoolbox import cli -from tiatoolbox.models import ( - IOInstanceSegmentorConfig, - NucleusInstanceSegmentor, - SemanticSegmentor, -) -from tiatoolbox.models.architecture import fetch_pretrained_weights -from tiatoolbox.models.engine.nucleus_instance_segmentor import ( - _process_tile_predictions, -) -from tiatoolbox.utils import env_detection as toolbox_env -from tiatoolbox.utils import imwrite -from tiatoolbox.utils.metrics import f1_detection -from tiatoolbox.wsicore.wsireader import WSIReader - -ON_GPU = toolbox_env.has_gpu() -# The value is based on 2 TitanXP each with 12GB -BATCH_SIZE = 1 if not ON_GPU else 16 - -# ---------------------------------------------------- - - -def _crash_func(_x: object) -> None: - """Helper to induce crash.""" - msg = "Propagation Crash." - raise ValueError(msg) - - -def helper_tile_info() -> list: - """Helper function for tile information.""" - predictor = NucleusInstanceSegmentor(model="A") - # ! assuming the tiles organized as follows (coming out from - # ! PatchExtractor). If this is broken, need to check back - # ! PatchExtractor output ordering first - # left to right, top to bottom - # --------------------- - # | 0 | 1 | 2 | 3 | - # --------------------- - # | 4 | 5 | 6 | 7 | - # --------------------- - # | 8 | 9 | 10 | 11 | - # --------------------- - # | 12 | 13 | 14 | 15 | - # --------------------- - # ! assume flag index ordering: left right top bottom - ioconfig = IOInstanceSegmentorConfig( - input_resolutions=[{"units": "mpp", "resolution": 0.25}], - output_resolutions=[ - {"units": "mpp", "resolution": 0.25}, - {"units": "mpp", "resolution": 0.25}, - {"units": "mpp", "resolution": 0.25}, - ], - margin=1, - tile_shape=(4, 4), - stride_shape=[4, 4], - patch_input_shape=[4, 4], - patch_output_shape=[4, 4], - ) - - return predictor._get_tile_info([16, 16], ioconfig) - - -# ---------------------------------------------------- - - -def test_get_tile_info() -> None: - """Test for getting tile info.""" - info = helper_tile_info() - _, flag = info[0] # index 0 should be full grid, removal - # removal flag at top edges - assert ( - np.sum( - np.nonzero(flag[:, 0]) - != np.array([4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]), - ) - == 0 - ), "Fail Top" - # removal flag at bottom edges - assert ( - np.sum( - np.nonzero(flag[:, 1]) != np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]), - ) - == 0 - ), "Fail Bottom" - # removal flag at left edges - assert ( - np.sum( - np.nonzero(flag[:, 2]) - != np.array([1, 2, 3, 5, 6, 7, 9, 10, 11, 13, 14, 15]), - ) - == 0 - ), "Fail Left" - # removal flag at right edges - assert ( - np.sum( - np.nonzero(flag[:, 3]) - != np.array([0, 1, 2, 4, 5, 6, 8, 9, 10, 12, 13, 14]), - ) - == 0 - ), "Fail Right" - - -def test_vertical_boundary_boxes() -> None: - """Test for vertical boundary boxes.""" - info = helper_tile_info() - _boxes = np.array( - [ - [3, 0, 5, 4], - [7, 0, 9, 4], - [11, 0, 13, 4], - [3, 4, 5, 8], - [7, 4, 9, 8], - [11, 4, 13, 8], - [3, 8, 5, 12], - [7, 8, 9, 12], - [11, 8, 13, 12], - [3, 12, 5, 16], - [7, 12, 9, 16], - [11, 12, 13, 16], - ], - ) - _flag = np.array( - [ - [0, 1, 0, 0], - [0, 1, 0, 0], - [0, 1, 0, 0], - [1, 1, 0, 0], - [1, 1, 0, 0], - [1, 1, 0, 0], - [1, 1, 0, 0], - [1, 1, 0, 0], - [1, 1, 0, 0], - [1, 0, 0, 0], - [1, 0, 0, 0], - [1, 0, 0, 0], - ], - ) - boxes, flag = info[1] - assert np.sum(_boxes - boxes) == 0, "Wrong Vertical Bounds" - assert np.sum(flag - _flag) == 0, "Fail Vertical Flag" - - -def test_horizontal_boundary_boxes() -> None: - """Test for horizontal boundary boxes.""" - info = helper_tile_info() - _boxes = np.array( - [ - [0, 3, 4, 5], - [4, 3, 8, 5], - [8, 3, 12, 5], - [12, 3, 16, 5], - [0, 7, 4, 9], - [4, 7, 8, 9], - [8, 7, 12, 9], - [12, 7, 16, 9], - [0, 11, 4, 13], - [4, 11, 8, 13], - [8, 11, 12, 13], - [12, 11, 16, 13], - ], - ) - _flag = np.array( - [ - [0, 0, 0, 1], - [0, 0, 1, 1], - [0, 0, 1, 1], - [0, 0, 1, 0], - [0, 0, 0, 1], - [0, 0, 1, 1], - [0, 0, 1, 1], - [0, 0, 1, 0], - [0, 0, 0, 1], - [0, 0, 1, 1], - [0, 0, 1, 1], - [0, 0, 1, 0], - ], - ) - boxes, flag = info[2] - assert np.sum(_boxes - boxes) == 0, "Wrong Horizontal Bounds" - assert np.sum(flag - _flag) == 0, "Fail Horizontal Flag" - - -def test_cross_section_boundary_boxes() -> None: - """Test for cross-section boundary boxes.""" - info = helper_tile_info() - _boxes = np.array( - [ - [2, 2, 6, 6], - [6, 2, 10, 6], - [10, 2, 14, 6], - [2, 6, 6, 10], - [6, 6, 10, 10], - [10, 6, 14, 10], - [2, 10, 6, 14], - [6, 10, 10, 14], - [10, 10, 14, 14], - ], - ) - _flag = np.array( - [ - [1, 1, 1, 1], - [1, 1, 1, 1], - [1, 1, 1, 1], - [1, 1, 1, 1], - [1, 1, 1, 1], - [1, 1, 1, 1], - [1, 1, 1, 1], - [1, 1, 1, 1], - [1, 1, 1, 1], - ], - ) - boxes, flag = info[3] - assert np.sum(boxes - _boxes) == 0, "Wrong Cross Section Bounds" - assert np.sum(flag - _flag) == 0, "Fail Cross Section Flag" - - -def test_crash_segmentor(remote_sample: Callable, tmp_path: Path) -> None: - """Test engine crash when given malformed input.""" - root_save_dir = Path(tmp_path) - sample_wsi_svs = Path(remote_sample("svs-1-small")) - sample_wsi_msk = remote_sample("small_svs_tissue_mask") - sample_wsi_msk = np.load(sample_wsi_msk).astype(np.uint8) - imwrite(f"{tmp_path}/small_svs_tissue_mask.jpg", sample_wsi_msk) - sample_wsi_msk = tmp_path.joinpath("small_svs_tissue_mask.jpg") - - save_dir = f"{root_save_dir}/instance/" - - # resolution for travis testing, not the correct ones - resolution = 4.0 - ioconfig = IOInstanceSegmentorConfig( - input_resolutions=[{"units": "mpp", "resolution": resolution}], - output_resolutions=[ - {"units": "mpp", "resolution": resolution}, - {"units": "mpp", "resolution": resolution}, - {"units": "mpp", "resolution": resolution}, - ], - margin=128, - tile_shape=(512, 512), - patch_input_shape=[256, 256], - patch_output_shape=[164, 164], - stride_shape=[164, 164], - ) - instance_segmentor = NucleusInstanceSegmentor( - batch_size=BATCH_SIZE, - num_postproc_workers=2, - pretrained_model="hovernet_fast-pannuke", - ) - - # * Test crash propagation when parallelize post-processing - shutil.rmtree("output", ignore_errors=True) - shutil.rmtree(save_dir, ignore_errors=True) - instance_segmentor.model.postproc_func = _crash_func - with pytest.raises(ValueError, match=r"Propagation Crash."): - instance_segmentor.predict( - [sample_wsi_svs], - masks=[sample_wsi_msk], - mode="wsi", - ioconfig=ioconfig, - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=save_dir, - ) - - -def test_functionality_ci(remote_sample: Callable, tmp_path: Path) -> None: - """Functionality test for nuclei instance segmentor.""" - gc.collect() - root_save_dir = Path(tmp_path) - mini_wsi_svs = Path(remote_sample("wsi4_512_512_svs")) - - resolution = 2.0 - - reader = WSIReader.open(mini_wsi_svs) - thumb = reader.slide_thumbnail(resolution=resolution, units="mpp") - mini_wsi_jpg = f"{tmp_path}/mini_svs.jpg" - imwrite(mini_wsi_jpg, thumb) - - save_dir = f"{root_save_dir}/instance/" - - # * test run on wsi, test run with worker - # resolution for travis testing, not the correct ones - ioconfig = IOInstanceSegmentorConfig( - input_resolutions=[{"units": "mpp", "resolution": resolution}], - output_resolutions=[ - {"units": "mpp", "resolution": resolution}, - {"units": "mpp", "resolution": resolution}, - ], - margin=128, - tile_shape=(1024, 1024), - patch_input_shape=[256, 256], - patch_output_shape=[164, 164], - stride_shape=[164, 164], - ) - - shutil.rmtree(save_dir, ignore_errors=True) - - inst_segmentor = NucleusInstanceSegmentor( - batch_size=1, - num_loader_workers=0, - num_postproc_workers=0, - pretrained_model="hovernet_fast-pannuke", - ) - inst_segmentor.predict( - [mini_wsi_svs], - mode="wsi", - ioconfig=ioconfig, - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=save_dir, - ) - - -def test_functionality_merge_tile_predictions_ci( - remote_sample: Callable, - tmp_path: Path, -) -> None: - """Functional tests for merging tile predictions.""" - gc.collect() # Force clean up everything on hold - save_dir = Path(f"{tmp_path}/output") - mini_wsi_svs = Path(remote_sample("wsi4_512_512_svs")) - - resolution = 0.5 - ioconfig = IOInstanceSegmentorConfig( - input_resolutions=[{"units": "mpp", "resolution": resolution}], - output_resolutions=[ - {"units": "mpp", "resolution": resolution}, - {"units": "mpp", "resolution": resolution}, - {"units": "mpp", "resolution": resolution}, - ], - margin=128, - tile_shape=(512, 512), - patch_input_shape=[256, 256], - patch_output_shape=[164, 164], - stride_shape=[164, 164], - ) - - # mainly to hook the merge prediction function - inst_segmentor = NucleusInstanceSegmentor( - batch_size=BATCH_SIZE, - num_postproc_workers=0, - pretrained_model="hovernet_fast-pannuke", - ) - - shutil.rmtree(save_dir, ignore_errors=True) - semantic_segmentor = SemanticSegmentor( - pretrained_model="hovernet_fast-pannuke", - batch_size=BATCH_SIZE, - num_postproc_workers=0, - ) - - output = semantic_segmentor.predict( - [mini_wsi_svs], - mode="wsi", - on_gpu=ON_GPU, - ioconfig=ioconfig, - crash_on_exception=True, - save_dir=save_dir, - ) - raw_maps = [np.load(f"{output[0][1]}.raw.{head_idx}.npy") for head_idx in range(3)] - raw_maps = [[v] for v in raw_maps] # mask it as patch output - - dummy_reference = {i: {"box": np.array([0, 0, 32, 32])} for i in range(1000)} - dummy_flag_mode_list = [ - [[1, 1, 0, 0], 1], - [[0, 0, 1, 1], 2], - [[1, 1, 1, 1], 3], - [[0, 0, 0, 0], 0], - ] - - inst_segmentor._wsi_inst_info = copy.deepcopy(dummy_reference) - inst_segmentor._futures = [[dummy_reference, dummy_reference.keys()]] - inst_segmentor._merge_post_process_results() - assert len(inst_segmentor._wsi_inst_info) == 0 - - blank_raw_maps = [np.zeros_like(v) for v in raw_maps] - _process_tile_predictions( - ioconfig=ioconfig, - tile_bounds=np.array([0, 0, 512, 512]), - tile_flag=dummy_flag_mode_list[0][0], - tile_mode=dummy_flag_mode_list[0][1], - tile_output=[[np.array([0, 0, 512, 512]), blank_raw_maps]], - ref_inst_dict=dummy_reference, - postproc=semantic_segmentor.model.postproc_func, - merge_predictions=semantic_segmentor.merge_prediction, - ) - - for tile_flag, tile_mode in dummy_flag_mode_list: - _process_tile_predictions( - ioconfig=ioconfig, - tile_bounds=np.array([0, 0, 512, 512]), - tile_flag=tile_flag, - tile_mode=tile_mode, - tile_output=[[np.array([0, 0, 512, 512]), raw_maps]], - ref_inst_dict=dummy_reference, - postproc=semantic_segmentor.model.postproc_func, - merge_predictions=semantic_segmentor.merge_prediction, - ) - - # test exception flag - tile_flag = [0, 0, 0, 0] - with pytest.raises(ValueError, match=r".*Unknown tile mode.*"): - _process_tile_predictions( - ioconfig=ioconfig, - tile_bounds=np.array([0, 0, 512, 512]), - tile_flag=tile_flag, - tile_mode=-1, - tile_output=[[np.array([0, 0, 512, 512]), raw_maps]], - ref_inst_dict=dummy_reference, - postproc=semantic_segmentor.model.postproc_func, - merge_predictions=semantic_segmentor.merge_prediction, - ) - - -@pytest.mark.skipif( - toolbox_env.running_on_ci() or not ON_GPU, - reason="Local test on machine with GPU.", -) -def test_functionality_local(remote_sample: Callable, tmp_path: Path) -> None: - """Local functionality test for nuclei instance segmentor.""" - root_save_dir = Path(tmp_path) - save_dir = Path(f"{tmp_path}/output") - mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs")) - - # * generate full output w/o parallel post-processing worker first - shutil.rmtree(save_dir, ignore_errors=True) - inst_segmentor = NucleusInstanceSegmentor( - batch_size=8, - num_postproc_workers=0, - pretrained_model="hovernet_fast-pannuke", - ) - output = inst_segmentor.predict( - [mini_wsi_svs], - mode="wsi", - on_gpu=True, - crash_on_exception=True, - save_dir=save_dir, - ) - inst_dict_a = joblib.load(f"{output[0][1]}.dat") - - # * then test run when using workers, will then compare results - # * to ensure the predictions are the same - shutil.rmtree(save_dir, ignore_errors=True) - inst_segmentor = NucleusInstanceSegmentor( - pretrained_model="hovernet_fast-pannuke", - batch_size=BATCH_SIZE, - num_postproc_workers=2, - ) - assert inst_segmentor.num_postproc_workers == 2 - output = inst_segmentor.predict( - [mini_wsi_svs], - mode="wsi", - on_gpu=True, - crash_on_exception=True, - save_dir=save_dir, - ) - inst_dict_b = joblib.load(f"{output[0][1]}.dat") - inst_coords_a = np.array([v["centroid"] for v in inst_dict_a.values()]) - inst_coords_b = np.array([v["centroid"] for v in inst_dict_b.values()]) - score = f1_detection(inst_coords_b, inst_coords_a, radius=1.0) - assert score > 0.95, "Heavy loss of precision!" - - # ** - # To evaluate the precision of doing post-processing on tile - # then re-assemble without using full image prediction maps, - # we compare its output with the output when doing - # post-processing on the entire images. - save_dir = root_save_dir / "semantic" - shutil.rmtree(save_dir, ignore_errors=True) - semantic_segmentor = SemanticSegmentor( - pretrained_model="hovernet_fast-pannuke", - batch_size=BATCH_SIZE, - num_postproc_workers=2, - ) - output = semantic_segmentor.predict( - [mini_wsi_svs], - mode="wsi", - on_gpu=True, - crash_on_exception=True, - save_dir=save_dir, - ) - raw_maps = [np.load(f"{output[0][1]}.raw.{head_idx}.npy") for head_idx in range(3)] - _, inst_dict_b = semantic_segmentor.model.postproc(raw_maps) - - inst_coords_a = np.array([v["centroid"] for v in inst_dict_a.values()]) - inst_coords_b = np.array([v["centroid"] for v in inst_dict_b.values()]) - score = f1_detection(inst_coords_b, inst_coords_a, radius=1.0) - assert score > 0.9, "Heavy loss of precision!" - - -def test_cli_nucleus_instance_segment_ioconfig( - remote_sample: Callable, - tmp_path: Path, -) -> None: - """Test for nucleus segmentation with IOConfig.""" - mini_wsi_svs = Path(remote_sample("wsi4_512_512_svs")) - output_path = tmp_path / "output" - - resolution = 2.0 - - reader = WSIReader.open(mini_wsi_svs) - thumb = reader.slide_thumbnail(resolution=resolution, units="mpp") - mini_wsi_jpg = f"{tmp_path}/mini_svs.jpg" - imwrite(mini_wsi_jpg, thumb) - - pretrained_weights = fetch_pretrained_weights("hovernet_fast-pannuke") - - # resolution for travis testing, not the correct ones - config = { - "input_resolutions": [{"units": "mpp", "resolution": resolution}], - "output_resolutions": [ - {"units": "mpp", "resolution": resolution}, - {"units": "mpp", "resolution": resolution}, - {"units": "mpp", "resolution": resolution}, - ], - "margin": 128, - "tile_shape": [512, 512], - "patch_input_shape": [256, 256], - "patch_output_shape": [164, 164], - "stride_shape": [164, 164], - "save_resolution": {"units": "mpp", "resolution": 8.0}, - } - - with Path.open(tmp_path / "config.yaml", "w") as fptr: - yaml.dump(config, fptr) - - runner = CliRunner() - nucleus_instance_segment_result = runner.invoke( - cli.main, - [ - "nucleus-instance-segment", - "--img-input", - str(mini_wsi_jpg), - "--pretrained-weights", - str(pretrained_weights), - "--num-loader-workers", - str(0), - "--num-postproc-workers", - str(0), - "--mode", - "tile", - "--output-path", - str(output_path), - "--yaml-config-path", - str(tmp_path.joinpath("config.yaml")), - ], - ) - - assert nucleus_instance_segment_result.exit_code == 0 - assert output_path.joinpath("0.dat").exists() - assert output_path.joinpath("file_map.dat").exists() - assert output_path.joinpath("results.json").exists() - - -def test_cli_nucleus_instance_segment(remote_sample: Callable, tmp_path: Path) -> None: - """Test for nucleus segmentation.""" - mini_wsi_svs = Path(remote_sample("wsi4_512_512_svs")) - output_path = tmp_path / "output" - - runner = CliRunner() - nucleus_instance_segment_result = runner.invoke( - cli.main, - [ - "nucleus-instance-segment", - "--img-input", - str(mini_wsi_svs), - "--mode", - "wsi", - "--num-loader-workers", - str(0), - "--num-postproc-workers", - str(0), - "--output-path", - str(output_path), - ], - ) - - assert nucleus_instance_segment_result.exit_code == 0 - assert output_path.joinpath("0.dat").exists() - assert output_path.joinpath("file_map.dat").exists() - assert output_path.joinpath("results.json").exists() diff --git a/tests/engines/_test_patch_predictor.py b/tests/engines/_test_patch_predictor.py deleted file mode 100644 index 0138bab48..000000000 --- a/tests/engines/_test_patch_predictor.py +++ /dev/null @@ -1,763 +0,0 @@ -"""Test for Patch Predictor.""" -from __future__ import annotations - -import copy -import shutil -from pathlib import Path -from typing import Callable - -import numpy as np -import pytest -from click.testing import CliRunner - -from tiatoolbox import cli -from tiatoolbox.models import IOPatchPredictorConfig, PatchPredictor -from tiatoolbox.models.architecture.vanilla import CNNModel -from tiatoolbox.utils import download_data, imwrite -from tiatoolbox.utils import env_detection as toolbox_env - -ON_GPU = toolbox_env.has_gpu() - -# ------------------------------------------------------------------------------------- -# Engine -# ------------------------------------------------------------------------------------- - - -def test_predictor_crash(tmp_path: Path) -> None: - """Test for crash when making predictor.""" - # without providing any model - with pytest.raises(ValueError, match=r"Must provide.*"): - PatchPredictor() - - # provide wrong unknown pretrained model - with pytest.raises(ValueError, match=r"Pretrained .* does not exist"): - PatchPredictor(pretrained_model="secret_model-kather100k") - - # provide wrong model of unknown type, deprecated later with type hint - with pytest.raises(TypeError, match=r".*must be a string.*"): - PatchPredictor(pretrained_model=123) - - # test predict crash - predictor = PatchPredictor(pretrained_model="resnet18-kather100k", batch_size=32) - - with pytest.raises(ValueError, match=r".*not a valid mode.*"): - predictor.predict("aaa", mode="random", save_dir=tmp_path) - # remove previously generated data - shutil.rmtree(tmp_path / "output", ignore_errors=True) - with pytest.raises(TypeError, match=r".*must be a list of file paths.*"): - predictor.predict("aaa", mode="wsi", save_dir=tmp_path) - # remove previously generated data - shutil.rmtree(tmp_path / "output", ignore_errors=True) - with pytest.raises(ValueError, match=r".*masks.*!=.*imgs.*"): - predictor.predict([1, 2, 3], masks=[1, 2], mode="wsi", save_dir=tmp_path) - with pytest.raises(ValueError, match=r".*labels.*!=.*imgs.*"): - predictor.predict([1, 2, 3], labels=[1, 2], mode="patch", save_dir=tmp_path) - # remove previously generated data - shutil.rmtree(tmp_path / "output", ignore_errors=True) - - -def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None: - """Test for delegating args to io config.""" - mini_wsi_svs = Path(remote_sample("wsi2_4k_4k_svs")) - - # test not providing config / full input info for not pretrained models - model = CNNModel("resnet50") - predictor = PatchPredictor(model=model) - with pytest.raises(ValueError, match=r".*Must provide.*`ioconfig`.*"): - predictor.predict([mini_wsi_svs], mode="wsi", save_dir=tmp_path / "dump") - shutil.rmtree(tmp_path / "dump", ignore_errors=True) - - kwargs = { - "patch_input_shape": [512, 512], - "resolution": 1.75, - "units": "mpp", - } - for key in kwargs: - _kwargs = copy.deepcopy(kwargs) - _kwargs.pop(key) - with pytest.raises(ValueError, match=r".*Must provide.*`ioconfig`.*"): - predictor.predict( - [mini_wsi_svs], - mode="wsi", - save_dir=f"{tmp_path}/dump", - on_gpu=ON_GPU, - **_kwargs, - ) - shutil.rmtree(tmp_path / "dump", ignore_errors=True) - - # test providing config / full input info for not pretrained models - ioconfig = IOPatchPredictorConfig( - patch_input_shape=(512, 512), - stride_shape=(256, 256), - input_resolutions=[{"resolution": 1.35, "units": "mpp"}], - output_resolutions=[], - ) - predictor.predict( - [mini_wsi_svs], - ioconfig=ioconfig, - mode="wsi", - save_dir=f"{tmp_path}/dump", - on_gpu=ON_GPU, - ) - shutil.rmtree(tmp_path / "dump", ignore_errors=True) - - predictor.predict( - [mini_wsi_svs], - mode="wsi", - save_dir=f"{tmp_path}/dump", - on_gpu=ON_GPU, - **kwargs, - ) - shutil.rmtree(tmp_path / "dump", ignore_errors=True) - - # test overwriting pretrained ioconfig - predictor = PatchPredictor(pretrained_model="resnet18-kather100k", batch_size=1) - predictor.predict( - [mini_wsi_svs], - patch_input_shape=(300, 300), - mode="wsi", - on_gpu=ON_GPU, - save_dir=f"{tmp_path}/dump", - ) - assert predictor._ioconfig.patch_input_shape == (300, 300) - shutil.rmtree(tmp_path / "dump", ignore_errors=True) - - predictor.predict( - [mini_wsi_svs], - stride_shape=(300, 300), - mode="wsi", - on_gpu=ON_GPU, - save_dir=f"{tmp_path}/dump", - ) - assert predictor._ioconfig.stride_shape == (300, 300) - shutil.rmtree(tmp_path / "dump", ignore_errors=True) - - predictor.predict( - [mini_wsi_svs], - resolution=1.99, - mode="wsi", - on_gpu=ON_GPU, - save_dir=f"{tmp_path}/dump", - ) - assert predictor._ioconfig.input_resolutions[0]["resolution"] == 1.99 - shutil.rmtree(tmp_path / "dump", ignore_errors=True) - - predictor.predict( - [mini_wsi_svs], - units="baseline", - mode="wsi", - on_gpu=ON_GPU, - save_dir=f"{tmp_path}/dump", - ) - assert predictor._ioconfig.input_resolutions[0]["units"] == "baseline" - shutil.rmtree(tmp_path / "dump", ignore_errors=True) - - predictor = PatchPredictor(pretrained_model="resnet18-kather100k") - predictor.predict( - [mini_wsi_svs], - mode="wsi", - merge_predictions=True, - save_dir=f"{tmp_path}/dump", - on_gpu=ON_GPU, - ) - shutil.rmtree(tmp_path / "dump", ignore_errors=True) - - -def test_patch_predictor_api( - sample_patch1: Path, - sample_patch2: Path, - tmp_path: Path, -) -> None: - """Helper function to get the model output using API 1.""" - save_dir_path = tmp_path - - # convert to pathlib Path to prevent reader complaint - inputs = [Path(sample_patch1), Path(sample_patch2)] - predictor = PatchPredictor(pretrained_model="resnet18-kather100k", batch_size=1) - # don't run test on GPU - output = predictor.predict( - inputs, - on_gpu=ON_GPU, - save_dir=save_dir_path, - ) - assert sorted(output.keys()) == ["predictions"] - assert len(output["predictions"]) == 2 - shutil.rmtree(save_dir_path, ignore_errors=True) - - output = predictor.predict( - inputs, - labels=[1, "a"], - return_labels=True, - on_gpu=ON_GPU, - save_dir=save_dir_path, - ) - assert sorted(output.keys()) == sorted(["labels", "predictions"]) - assert len(output["predictions"]) == len(output["labels"]) - assert output["labels"] == [1, "a"] - shutil.rmtree(save_dir_path, ignore_errors=True) - - output = predictor.predict( - inputs, - return_probabilities=True, - on_gpu=ON_GPU, - save_dir=save_dir_path, - ) - assert sorted(output.keys()) == sorted(["predictions", "probabilities"]) - assert len(output["predictions"]) == len(output["probabilities"]) - shutil.rmtree(save_dir_path, ignore_errors=True) - - output = predictor.predict( - inputs, - return_probabilities=True, - labels=[1, "a"], - return_labels=True, - on_gpu=ON_GPU, - save_dir=save_dir_path, - ) - assert sorted(output.keys()) == sorted(["labels", "predictions", "probabilities"]) - assert len(output["predictions"]) == len(output["labels"]) - assert len(output["predictions"]) == len(output["probabilities"]) - - # test saving output, should have no effect - _ = predictor.predict( - inputs, - on_gpu=ON_GPU, - save_dir="special_dir_not_exist", - ) - assert not Path.is_dir(Path("special_dir_not_exist")) - - # test loading user weight - pretrained_weights_url = ( - "https://tiatoolbox.dcs.warwick.ac.uk/models/pc/resnet18-kather100k.pth" - ) - - # remove prev generated data - shutil.rmtree(save_dir_path, ignore_errors=True) - save_dir_path.mkdir(parents=True) - pretrained_weights = ( - save_dir_path / "tmp_pretrained_weigths" / "resnet18-kather100k.pth" - ) - - download_data(pretrained_weights_url, pretrained_weights) - - _ = PatchPredictor( - pretrained_model="resnet18-kather100k", - weights=pretrained_weights, - batch_size=1, - ) - - # --- test different using user model - model = CNNModel(backbone="resnet18", num_classes=9) - # test prediction - predictor = PatchPredictor(model=model, batch_size=1, verbose=False) - output = predictor.predict( - inputs, - return_probabilities=True, - labels=[1, "a"], - return_labels=True, - on_gpu=ON_GPU, - save_dir=save_dir_path, - ) - assert sorted(output.keys()) == sorted(["labels", "predictions", "probabilities"]) - assert len(output["predictions"]) == len(output["labels"]) - assert len(output["predictions"]) == len(output["probabilities"]) - - -def test_wsi_predictor_api( - sample_wsi_dict: dict, - tmp_path: Path, - chdir: Callable, -) -> None: - """Test normal run of wsi predictor.""" - save_dir_path = tmp_path - - # convert to pathlib Path to prevent wsireader complaint - mini_wsi_svs = Path(sample_wsi_dict["wsi2_4k_4k_svs"]) - mini_wsi_jpg = Path(sample_wsi_dict["wsi2_4k_4k_jpg"]) - mini_wsi_msk = Path(sample_wsi_dict["wsi2_4k_4k_msk"]) - - patch_size = np.array([224, 224]) - predictor = PatchPredictor(pretrained_model="resnet18-kather100k", batch_size=32) - - save_dir = f"{save_dir_path}/model_wsi_output" - - # wrapper to make this more clean - kwargs = { - "return_probabilities": True, - "return_labels": True, - "on_gpu": ON_GPU, - "patch_input_shape": patch_size, - "stride_shape": patch_size, - "resolution": 1.0, - "units": "baseline", - "save_dir": save_dir, - } - # ! add this test back once the read at `baseline` is fixed - # sanity check, both output should be the same with same resolution read args - wsi_output = predictor.predict( - [mini_wsi_svs], - masks=[mini_wsi_msk], - mode="wsi", - **kwargs, - ) - - shutil.rmtree(save_dir, ignore_errors=True) - - tile_output = predictor.predict( - [mini_wsi_jpg], - masks=[mini_wsi_msk], - mode="tile", - **kwargs, - ) - - wpred = np.array(wsi_output[0]["predictions"]) - tpred = np.array(tile_output[0]["predictions"]) - diff = tpred == wpred - accuracy = np.sum(diff) / np.size(wpred) - assert accuracy > 0.9, np.nonzero(~diff) - - # remove previously generated data - shutil.rmtree(save_dir, ignore_errors=True) - - kwargs = { - "return_probabilities": True, - "return_labels": True, - "on_gpu": ON_GPU, - "patch_input_shape": patch_size, - "stride_shape": patch_size, - "resolution": 0.5, - "save_dir": save_dir, - "merge_predictions": True, # to test the api coverage - "units": "mpp", - } - - _kwargs = copy.deepcopy(kwargs) - _kwargs["merge_predictions"] = False - # test reading of multiple whole-slide images - output = predictor.predict( - [mini_wsi_svs, mini_wsi_svs], - masks=[mini_wsi_msk, mini_wsi_msk], - mode="wsi", - **_kwargs, - ) - for output_info in output.values(): - assert Path(output_info["raw"]).exists() - assert "merged" not in output_info - shutil.rmtree(_kwargs["save_dir"], ignore_errors=True) - - # coverage test - _kwargs = copy.deepcopy(kwargs) - _kwargs["merge_predictions"] = True - # test reading of multiple whole-slide images - predictor.predict( - [mini_wsi_svs, mini_wsi_svs], - masks=[mini_wsi_msk, mini_wsi_msk], - mode="wsi", - **_kwargs, - ) - _kwargs = copy.deepcopy(kwargs) - with pytest.raises(FileExistsError): - predictor.predict( - [mini_wsi_svs, mini_wsi_svs], - masks=[mini_wsi_msk, mini_wsi_msk], - mode="wsi", - **_kwargs, - ) - # remove previously generated data - shutil.rmtree(_kwargs["save_dir"], ignore_errors=True) - - with chdir(save_dir_path): - # test reading of multiple whole-slide images - _kwargs = copy.deepcopy(kwargs) - _kwargs["save_dir"] = None # default coverage - _kwargs["return_probabilities"] = False - output = predictor.predict( - [mini_wsi_svs, mini_wsi_svs], - masks=[mini_wsi_msk, mini_wsi_msk], - mode="wsi", - **_kwargs, - ) - assert Path.exists(Path("output")) - for output_info in output.values(): - assert Path(output_info["raw"]).exists() - assert "merged" in output_info - assert Path(output_info["merged"]).exists() - - # remove previously generated data - shutil.rmtree("output", ignore_errors=True) - - -def test_wsi_predictor_merge_predictions(sample_wsi_dict: dict) -> None: - """Test normal run of wsi predictor with merge predictions option.""" - # convert to pathlib Path to prevent reader complaint - mini_wsi_svs = Path(sample_wsi_dict["wsi2_4k_4k_svs"]) - mini_wsi_jpg = Path(sample_wsi_dict["wsi2_4k_4k_jpg"]) - mini_wsi_msk = Path(sample_wsi_dict["wsi2_4k_4k_msk"]) - - # blind test - # pseudo output dict from model with 2 patches - output = { - "resolution": 1.0, - "units": "baseline", - "probabilities": [[0.45, 0.55], [0.90, 0.10]], - "predictions": [1, 0], - "coordinates": [[0, 0, 2, 2], [2, 2, 4, 4]], - } - merged = PatchPredictor.merge_predictions( - np.zeros([4, 4]), - output, - resolution=1.0, - units="baseline", - ) - _merged = np.array([[2, 2, 0, 0], [2, 2, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]) - assert np.sum(merged - _merged) == 0 - - # blind test for merging probabilities - merged = PatchPredictor.merge_predictions( - np.zeros([4, 4]), - output, - resolution=1.0, - units="baseline", - return_raw=True, - ) - _merged = np.array( - [ - [0.45, 0.45, 0, 0], - [0.45, 0.45, 0, 0], - [0, 0, 0.90, 0.90], - [0, 0, 0.90, 0.90], - ], - ) - assert merged.shape == (4, 4, 2) - assert np.mean(np.abs(merged[..., 0] - _merged)) < 1.0e-6 - - # integration test - predictor = PatchPredictor(pretrained_model="resnet18-kather100k", batch_size=1) - - kwargs = { - "return_probabilities": True, - "return_labels": True, - "on_gpu": ON_GPU, - "patch_input_shape": np.array([224, 224]), - "stride_shape": np.array([224, 224]), - "resolution": 1.0, - "units": "baseline", - "merge_predictions": True, - } - # sanity check, both output should be the same with same resolution read args - wsi_output = predictor.predict( - [mini_wsi_svs], - masks=[mini_wsi_msk], - mode="wsi", - **kwargs, - ) - - # mock up to change the preproc func and - # force to use the default in merge function - # still should have the same results - kwargs["merge_predictions"] = False - tile_output = predictor.predict( - [mini_wsi_jpg], - masks=[mini_wsi_msk], - mode="tile", - **kwargs, - ) - merged_tile_output = predictor.merge_predictions( - mini_wsi_jpg, - tile_output[0], - resolution=kwargs["resolution"], - units=kwargs["units"], - ) - tile_output.append(merged_tile_output) - - # first make sure nothing breaks with predictions - wpred = np.array(wsi_output[0]["predictions"]) - tpred = np.array(tile_output[0]["predictions"]) - diff = tpred == wpred - accuracy = np.sum(diff) / np.size(wpred) - assert accuracy > 0.9, np.nonzero(~diff) - - merged_wsi = wsi_output[1] - merged_tile = tile_output[1] - # ensure shape of merged predictions of tile and wsi input are the same - assert merged_wsi.shape == merged_tile.shape - # ensure consistent predictions between tile and wsi mode - diff = merged_tile == merged_wsi - accuracy = np.sum(diff) / np.size(merged_wsi) - assert accuracy > 0.9, np.nonzero(~diff) - - -def _test_predictor_output( - inputs: list, - pretrained_model: str, - probabilities_check: list | None = None, - predictions_check: list | None = None, - *, - on_gpu: bool = ON_GPU, -) -> None: - """Test the predictions of multiple models included in tiatoolbox.""" - predictor = PatchPredictor( - pretrained_model=pretrained_model, - batch_size=32, - verbose=False, - ) - # don't run test on GPU - output = predictor.predict( - inputs, - return_probabilities=True, - return_labels=False, - on_gpu=on_gpu, - ) - predictions = output["predictions"] - probabilities = output["probabilities"] - for idx, probabilities_ in enumerate(probabilities): - probabilities_max = max(probabilities_) - assert np.abs(probabilities_max - probabilities_check[idx]) <= 1e-3, ( - pretrained_model, - probabilities_max, - probabilities_check[idx], - predictions[idx], - predictions_check[idx], - ) - assert predictions[idx] == predictions_check[idx], ( - pretrained_model, - probabilities_max, - probabilities_check[idx], - predictions[idx], - predictions_check[idx], - ) - - -def test_patch_predictor_kather100k_output( - sample_patch1: Path, - sample_patch2: Path, -) -> None: - """Test the output of patch prediction models on Kather100K dataset.""" - inputs = [Path(sample_patch1), Path(sample_patch2)] - pretrained_info = { - "alexnet-kather100k": [1.0, 0.9999735355377197], - "resnet18-kather100k": [1.0, 0.9999911785125732], - "resnet34-kather100k": [1.0, 0.9979840517044067], - "resnet50-kather100k": [1.0, 0.9999986886978149], - "resnet101-kather100k": [1.0, 0.9999932050704956], - "resnext50_32x4d-kather100k": [1.0, 0.9910059571266174], - "resnext101_32x8d-kather100k": [1.0, 0.9999971389770508], - "wide_resnet50_2-kather100k": [1.0, 0.9953408241271973], - "wide_resnet101_2-kather100k": [1.0, 0.9999831914901733], - "densenet121-kather100k": [1.0, 1.0], - "densenet161-kather100k": [1.0, 0.9999959468841553], - "densenet169-kather100k": [1.0, 0.9999934434890747], - "densenet201-kather100k": [1.0, 0.9999983310699463], - "mobilenet_v2-kather100k": [0.9999998807907104, 0.9999126195907593], - "mobilenet_v3_large-kather100k": [0.9999996423721313, 0.9999878406524658], - "mobilenet_v3_small-kather100k": [0.9999998807907104, 0.9999997615814209], - "googlenet-kather100k": [1.0, 0.9999639987945557], - } - for pretrained_model, expected_prob in pretrained_info.items(): - _test_predictor_output( - inputs, - pretrained_model, - probabilities_check=expected_prob, - predictions_check=[6, 3], - on_gpu=ON_GPU, - ) - # only test 1 on travis to limit runtime - if toolbox_env.running_on_ci(): - break - - -def test_patch_predictor_pcam_output(sample_patch3: Path, sample_patch4: Path) -> None: - """Test the output of patch prediction models on PCam dataset.""" - inputs = [Path(sample_patch3), Path(sample_patch4)] - pretrained_info = { - "alexnet-pcam": [0.999980092048645, 0.9769067168235779], - "resnet18-pcam": [0.999992847442627, 0.9466130137443542], - "resnet34-pcam": [1.0, 0.9976525902748108], - "resnet50-pcam": [0.9999270439147949, 0.9999996423721313], - "resnet101-pcam": [1.0, 0.9997289776802063], - "resnext50_32x4d-pcam": [0.9999996423721313, 0.9984435439109802], - "resnext101_32x8d-pcam": [0.9997072815895081, 0.9969086050987244], - "wide_resnet50_2-pcam": [0.9999837875366211, 0.9959040284156799], - "wide_resnet101_2-pcam": [1.0, 0.9945427179336548], - "densenet121-pcam": [0.9999251365661621, 0.9997479319572449], - "densenet161-pcam": [0.9999969005584717, 0.9662821292877197], - "densenet169-pcam": [0.9999998807907104, 0.9993504881858826], - "densenet201-pcam": [0.9999942779541016, 0.9950824975967407], - "mobilenet_v2-pcam": [0.9999876022338867, 0.9942564368247986], - "mobilenet_v3_large-pcam": [0.9999922513961792, 0.9719613790512085], - "mobilenet_v3_small-pcam": [0.9999963045120239, 0.9747149348258972], - "googlenet-pcam": [0.9999929666519165, 0.8701475858688354], - } - for pretrained_model, expected_prob in pretrained_info.items(): - _test_predictor_output( - inputs, - pretrained_model, - probabilities_check=expected_prob, - predictions_check=[1, 0], - on_gpu=ON_GPU, - ) - # only test 1 on travis to limit runtime - if toolbox_env.running_on_ci(): - break - - -# ------------------------------------------------------------------------------------- -# Command Line Interface -# ------------------------------------------------------------------------------------- - - -def test_command_line_models_file_not_found(sample_svs: Path, tmp_path: Path) -> None: - """Test for models CLI file not found error.""" - runner = CliRunner() - model_file_not_found_result = runner.invoke( - cli.main, - [ - "patch-predictor", - "--img-input", - str(sample_svs)[:-1], - "--file-types", - '"*.ndpi, *.svs"', - "--output-path", - str(tmp_path.joinpath("output")), - ], - ) - - assert model_file_not_found_result.output == "" - assert model_file_not_found_result.exit_code == 1 - assert isinstance(model_file_not_found_result.exception, FileNotFoundError) - - -def test_command_line_models_incorrect_mode(sample_svs: Path, tmp_path: Path) -> None: - """Test for models CLI mode not in wsi, tile.""" - runner = CliRunner() - mode_not_in_wsi_tile_result = runner.invoke( - cli.main, - [ - "patch-predictor", - "--img-input", - str(sample_svs), - "--file-types", - '"*.ndpi, *.svs"', - "--mode", - '"patch"', - "--output-path", - str(tmp_path.joinpath("output")), - ], - ) - - assert "Invalid value for '--mode'" in mode_not_in_wsi_tile_result.output - assert mode_not_in_wsi_tile_result.exit_code != 0 - assert isinstance(mode_not_in_wsi_tile_result.exception, SystemExit) - - -def test_cli_model_single_file(sample_svs: Path, tmp_path: Path) -> None: - """Test for models CLI single file.""" - runner = CliRunner() - models_wsi_result = runner.invoke( - cli.main, - [ - "patch-predictor", - "--img-input", - str(sample_svs), - "--mode", - "wsi", - "--output-path", - str(tmp_path.joinpath("output")), - ], - ) - - assert models_wsi_result.exit_code == 0 - assert tmp_path.joinpath("output/0.merged.npy").exists() - assert tmp_path.joinpath("output/0.raw.json").exists() - assert tmp_path.joinpath("output/results.json").exists() - - -def test_cli_model_single_file_mask(remote_sample: Callable, tmp_path: Path) -> None: - """Test for models CLI single file with mask.""" - mini_wsi_svs = Path(remote_sample("svs-1-small")) - sample_wsi_msk = remote_sample("small_svs_tissue_mask") - sample_wsi_msk = np.load(sample_wsi_msk).astype(np.uint8) - imwrite(f"{tmp_path}/small_svs_tissue_mask.jpg", sample_wsi_msk) - sample_wsi_msk = f"{tmp_path}/small_svs_tissue_mask.jpg" - - runner = CliRunner() - models_tiles_result = runner.invoke( - cli.main, - [ - "patch-predictor", - "--img-input", - str(mini_wsi_svs), - "--mode", - "wsi", - "--masks", - str(sample_wsi_msk), - "--output-path", - str(tmp_path.joinpath("output")), - ], - ) - - assert models_tiles_result.exit_code == 0 - assert tmp_path.joinpath("output/0.merged.npy").exists() - assert tmp_path.joinpath("output/0.raw.json").exists() - assert tmp_path.joinpath("output/results.json").exists() - - -def test_cli_model_multiple_file_mask(remote_sample: Callable, tmp_path: Path) -> None: - """Test for models CLI multiple file with mask.""" - mini_wsi_svs = Path(remote_sample("svs-1-small")) - sample_wsi_msk = remote_sample("small_svs_tissue_mask") - sample_wsi_msk = np.load(sample_wsi_msk).astype(np.uint8) - imwrite(f"{tmp_path}/small_svs_tissue_mask.jpg", sample_wsi_msk) - mini_wsi_msk = tmp_path.joinpath("small_svs_tissue_mask.jpg") - - # Make multiple copies for test - dir_path = tmp_path.joinpath("new_copies") - dir_path.mkdir() - - dir_path_masks = tmp_path.joinpath("new_copies_masks") - dir_path_masks.mkdir() - - try: - dir_path.joinpath("1_" + mini_wsi_svs.name).symlink_to(mini_wsi_svs) - dir_path.joinpath("2_" + mini_wsi_svs.name).symlink_to(mini_wsi_svs) - dir_path.joinpath("3_" + mini_wsi_svs.name).symlink_to(mini_wsi_svs) - except OSError: - shutil.copy(mini_wsi_svs, dir_path.joinpath("1_" + mini_wsi_svs.name)) - shutil.copy(mini_wsi_svs, dir_path.joinpath("2_" + mini_wsi_svs.name)) - shutil.copy(mini_wsi_svs, dir_path.joinpath("3_" + mini_wsi_svs.name)) - - try: - dir_path_masks.joinpath("1_" + mini_wsi_msk.name).symlink_to(mini_wsi_msk) - dir_path_masks.joinpath("2_" + mini_wsi_msk.name).symlink_to(mini_wsi_msk) - dir_path_masks.joinpath("3_" + mini_wsi_msk.name).symlink_to(mini_wsi_msk) - except OSError: - shutil.copy(mini_wsi_msk, dir_path_masks.joinpath("1_" + mini_wsi_msk.name)) - shutil.copy(mini_wsi_msk, dir_path_masks.joinpath("2_" + mini_wsi_msk.name)) - shutil.copy(mini_wsi_msk, dir_path_masks.joinpath("3_" + mini_wsi_msk.name)) - - tmp_path = tmp_path.joinpath("output") - - runner = CliRunner() - models_tiles_result = runner.invoke( - cli.main, - [ - "patch-predictor", - "--img-input", - str(dir_path), - "--mode", - "wsi", - "--masks", - str(dir_path_masks), - "--output-path", - str(tmp_path), - ], - ) - - assert models_tiles_result.exit_code == 0 - assert tmp_path.joinpath("0.merged.npy").exists() - assert tmp_path.joinpath("0.raw.json").exists() - assert tmp_path.joinpath("1.merged.npy").exists() - assert tmp_path.joinpath("1.raw.json").exists() - assert tmp_path.joinpath("2.merged.npy").exists() - assert tmp_path.joinpath("2.raw.json").exists() - assert tmp_path.joinpath("results.json").exists() diff --git a/tests/engines/_test_semantic_segmentation.py b/tests/engines/_test_semantic_segmentation.py deleted file mode 100644 index abcc99e8d..000000000 --- a/tests/engines/_test_semantic_segmentation.py +++ /dev/null @@ -1,853 +0,0 @@ -"""Test for Semantic Segmentor.""" -from __future__ import annotations - -# ! The garbage collector -import gc -import multiprocessing -import shutil -from pathlib import Path -from typing import Callable - -import numpy as np -import pytest -import torch -import torch.multiprocessing as torch_mp -import torch.nn.functional as F # noqa: N812 -import yaml -from click.testing import CliRunner -from torch import nn - -from tiatoolbox import cli -from tiatoolbox.models import IOSegmentorConfig, SemanticSegmentor -from tiatoolbox.models.architecture import fetch_pretrained_weights -from tiatoolbox.models.architecture.utils import centre_crop -from tiatoolbox.models.engine.semantic_segmentor import WSIStreamDataset -from tiatoolbox.models.models_abc import ModelABC -from tiatoolbox.utils import env_detection as toolbox_env -from tiatoolbox.utils import imread, imwrite -from tiatoolbox.wsicore.wsireader import WSIReader - -ON_GPU = toolbox_env.has_gpu() -# The value is based on 2 TitanXP each with 12GB -BATCH_SIZE = 1 if not ON_GPU else 16 -try: - NUM_POSTPROC_WORKERS = multiprocessing.cpu_count() -except NotImplementedError: - NUM_POSTPROC_WORKERS = 2 - -# ---------------------------------------------------- - - -def _crash_func(_x: object) -> None: - """Helper to induce crash.""" - msg = "Propagation Crash." - raise ValueError(msg) - - -class _CNNTo1(ModelABC): - """Contains a convolution. - - Simple model to test functionality, this contains a single - convolution layer which has weight=0 and bias=1. - - """ - - def __init__(self: _CNNTo1) -> None: - super().__init__() - self.conv = nn.Conv2d(3, 1, 3, padding=1) - self.conv.weight.data.fill_(0) - self.conv.bias.data.fill_(1) - - def forward(self: _CNNTo1, img: np.ndarray) -> torch.Tensor: - """Define how to use layer.""" - return self.conv(img) - - @staticmethod - def infer_batch( - model: nn.Module, - batch_data: torch.Tensor, - *, - on_gpu: bool, - ) -> list: - """Run inference on an input batch. - - Contains logic for forward operation as well as i/o - aggregation for a single data batch. - - Args: - model (nn.Module): PyTorch defined model. - batch_data (torch.Tensor): A batch of data generated by - torch.utils.data.DataLoader. - on_gpu (bool): Whether to run inference on a GPU. - - """ - device = "cuda" if on_gpu else "cpu" - #### - model.eval() # infer mode - - #### - img_list = batch_data - - img_list = img_list.to(device).type(torch.float32) - img_list = img_list.permute(0, 3, 1, 2) # to NCHW - - hw = np.array(img_list.shape[2:]) - with torch.inference_mode(): # do not compute gradient - logit_list = model(img_list) - logit_list = centre_crop(logit_list, hw // 2) - logit_list = logit_list.permute(0, 2, 3, 1) # to NHWC - prob_list = F.relu(logit_list) - - prob_list = prob_list.cpu().numpy() - return [prob_list] - - -# ------------------------------------------------------------------------------------- -# IOConfig -# ------------------------------------------------------------------------------------- - - -def test_segmentor_ioconfig() -> None: - """Test for IOConfig.""" - ioconfig = IOSegmentorConfig( - input_resolutions=[ - {"units": "mpp", "resolution": 0.25}, - {"units": "mpp", "resolution": 0.50}, - {"units": "mpp", "resolution": 0.75}, - ], - output_resolutions=[ - {"units": "mpp", "resolution": 0.25}, - {"units": "mpp", "resolution": 0.50}, - ], - patch_input_shape=[2048, 2048], - patch_output_shape=[1024, 1024], - stride_shape=[512, 512], - ) - assert ioconfig.highest_input_resolution == {"units": "mpp", "resolution": 0.25} - ioconfig = ioconfig.to_baseline() - assert ioconfig.input_resolutions[0]["resolution"] == 1.0 - assert ioconfig.input_resolutions[1]["resolution"] == 0.5 - assert ioconfig.input_resolutions[2]["resolution"] == 1 / 3 - - ioconfig = IOSegmentorConfig( - input_resolutions=[ - {"units": "power", "resolution": 20}, - {"units": "power", "resolution": 40}, - ], - output_resolutions=[ - {"units": "power", "resolution": 20}, - {"units": "power", "resolution": 40}, - ], - patch_input_shape=[2048, 2048], - patch_output_shape=[1024, 1024], - stride_shape=[512, 512], - save_resolution={"units": "power", "resolution": 8.0}, - ) - assert ioconfig.highest_input_resolution == {"units": "power", "resolution": 40} - ioconfig = ioconfig.to_baseline() - assert ioconfig.input_resolutions[0]["resolution"] == 0.5 - assert ioconfig.input_resolutions[1]["resolution"] == 1.0 - assert ioconfig.save_resolution["resolution"] == 8.0 / 40.0 - - resolutions = [ - {"units": "mpp", "resolution": 0.25}, - {"units": "mpp", "resolution": 0.50}, - {"units": "mpp", "resolution": 0.75}, - ] - with pytest.raises(ValueError, match=r".*Unknown units.*"): - ioconfig.scale_to_highest(resolutions, "axx") - - -# ------------------------------------------------------------------------------------- -# Dataset -# ------------------------------------------------------------------------------------- - - -def test_functional_wsi_stream_dataset(remote_sample: Callable) -> None: - """Functional test for WSIStreamDataset.""" - gc.collect() # Force clean up everything on hold - mini_wsi_svs = Path(remote_sample("wsi4_512_512_svs")) - - ioconfig = IOSegmentorConfig( - input_resolutions=[ - {"units": "mpp", "resolution": 0.25}, - {"units": "mpp", "resolution": 0.50}, - {"units": "mpp", "resolution": 0.75}, - ], - output_resolutions=[ - {"units": "mpp", "resolution": 0.25}, - {"units": "mpp", "resolution": 0.50}, - ], - patch_input_shape=[2048, 2048], - patch_output_shape=[1024, 1024], - stride_shape=[512, 512], - ) - mp_manager = torch_mp.Manager() - mp_shared_space = mp_manager.Namespace() - - sds = WSIStreamDataset(ioconfig, [mini_wsi_svs], mp_shared_space) - # test for collate - out = sds.collate_fn([None, 1, 2, 3]) - assert np.sum(out.numpy() != np.array([1, 2, 3])) == 0 - - # artificial data injection - mp_shared_space.wsi_idx = torch.tensor(0) # a scalar - mp_shared_space.patch_inputs = torch.from_numpy( - np.array( - [ - [0, 0, 256, 256], - [256, 256, 512, 512], - ], - ), - ) - mp_shared_space.patch_outputs = torch.from_numpy( - np.array( - [ - [0, 0, 256, 256], - [256, 256, 512, 512], - ], - ), - ) - # test read - for _, sample in enumerate(sds): - patch_data, _ = sample - (patch_resolution1, patch_resolution2, patch_resolution3) = patch_data - assert np.round(patch_resolution1.shape[0] / patch_resolution2.shape[0]) == 2 - assert np.round(patch_resolution1.shape[0] / patch_resolution3.shape[0]) == 3 - - -# ------------------------------------------------------------------------------------- -# Engine -# ------------------------------------------------------------------------------------- - - -def test_crash_segmentor(remote_sample: Callable) -> None: - """Functional crash tests for segmentor.""" - # # convert to pathlib Path to prevent wsireader complaint - mini_wsi_svs = Path(remote_sample("wsi2_4k_4k_svs")) - mini_wsi_jpg = Path(remote_sample("wsi2_4k_4k_jpg")) - mini_wsi_msk = Path(remote_sample("wsi2_4k_4k_msk")) - - model = _CNNTo1() - semantic_segmentor = SemanticSegmentor(batch_size=BATCH_SIZE, model=model) - # fake injection to trigger Segmentor to create parallel - # post-processing workers because baseline Semantic Segmentor does not support - # post-processing out of the box. It only contains condition to create it - # for any subclass - semantic_segmentor.num_postproc_workers = 1 - - # * test basic crash - shutil.rmtree("output", ignore_errors=True) # default output dir test - with pytest.raises(TypeError, match=r".*`mask_reader`.*"): - semantic_segmentor.filter_coordinates(mini_wsi_msk, np.array(["a", "b", "c"])) - with pytest.raises(ValueError, match=r".*ndarray.*integer.*"): - semantic_segmentor.filter_coordinates( - WSIReader.open(mini_wsi_msk), - np.array([1.0, 2.0]), - ) - semantic_segmentor.get_reader(mini_wsi_svs, None, "wsi", auto_get_mask=True) - with pytest.raises(ValueError, match=r".*must be a valid file path.*"): - semantic_segmentor.get_reader( - mini_wsi_msk, - "not_exist", - "wsi", - auto_get_mask=True, - ) - - shutil.rmtree("output", ignore_errors=True) # default output dir test - with pytest.raises(ValueError, match=r".*provide.*"): - SemanticSegmentor() - with pytest.raises(ValueError, match=r".*valid mode.*"): - semantic_segmentor.predict([], mode="abc") - - # * test not providing any io_config info when not using pretrained model - with pytest.raises(ValueError, match=r".*provide either `ioconfig`.*"): - semantic_segmentor.predict( - [mini_wsi_jpg], - mode="tile", - on_gpu=ON_GPU, - crash_on_exception=True, - ) - with pytest.raises(ValueError, match=r".*already exists.*"): - semantic_segmentor.predict([], mode="tile", patch_input_shape=(2048, 2048)) - shutil.rmtree("output", ignore_errors=True) # default output dir test - - # * test not providing any io_config info when not using pretrained model - with pytest.raises(ValueError, match=r".*provide either `ioconfig`.*"): - semantic_segmentor.predict( - [mini_wsi_jpg], - mode="tile", - on_gpu=ON_GPU, - crash_on_exception=True, - ) - shutil.rmtree("output", ignore_errors=True) # default output dir test - - # * Test crash propagation when parallelize post-processing - semantic_segmentor.num_postproc_workers = 2 - semantic_segmentor.model.forward = _crash_func - with pytest.raises(ValueError, match=r"Propagation Crash."): - semantic_segmentor.predict( - [mini_wsi_svs], - patch_input_shape=(2048, 2048), - mode="wsi", - on_gpu=ON_GPU, - crash_on_exception=True, - resolution=1.0, - units="baseline", - ) - shutil.rmtree("output", ignore_errors=True) - - with pytest.raises(ValueError, match=r"Invalid resolution.*"): - semantic_segmentor.predict( - [mini_wsi_svs], - patch_input_shape=(2048, 2048), - mode="wsi", - on_gpu=ON_GPU, - crash_on_exception=True, - ) - shutil.rmtree("output", ignore_errors=True) - # test ignore crash - semantic_segmentor.predict( - [mini_wsi_svs], - patch_input_shape=(2048, 2048), - mode="wsi", - on_gpu=ON_GPU, - crash_on_exception=False, - resolution=1.0, - units="baseline", - ) - shutil.rmtree("output", ignore_errors=True) - - -def test_functional_segmentor_merging(tmp_path: Path) -> None: - """Functional test for assmebling output.""" - save_dir = Path(tmp_path) - - model = _CNNTo1() - semantic_segmentor = SemanticSegmentor(batch_size=BATCH_SIZE, model=model) - - shutil.rmtree(save_dir, ignore_errors=True) - save_dir.mkdir() - # predictions with HW - _output = np.array( - [ - [1, 1, 0, 0], - [1, 1, 0, 0], - [0, 0, 2, 2], - [0, 0, 2, 2], - ], - ) - canvas = semantic_segmentor.merge_prediction( - [4, 4], - [np.full((2, 2), 1), np.full((2, 2), 2)], - [[0, 0, 2, 2], [2, 2, 4, 4]], - save_path=f"{save_dir}/raw.py", - cache_count_path=f"{save_dir}/count.py", - ) - assert np.sum(canvas - _output) < 1.0e-8 - # a second rerun to test overlapping count, - # should still maintain same result - canvas = semantic_segmentor.merge_prediction( - [4, 4], - [np.full((2, 2), 1), np.full((2, 2), 2)], - [[0, 0, 2, 2], [2, 2, 4, 4]], - save_path=f"{save_dir}/raw.py", - cache_count_path=f"{save_dir}/count.py", - ) - assert np.sum(canvas - _output) < 1.0e-8 - # else will leave hanging file pointer - # and hence cant remove its folder later - del canvas # skipcq - - # * predictions with HWC - shutil.rmtree(save_dir, ignore_errors=True) - save_dir.mkdir() - _ = semantic_segmentor.merge_prediction( - [4, 4], - [np.full((2, 2, 1), 1), np.full((2, 2, 1), 2)], - [[0, 0, 2, 2], [2, 2, 4, 4]], - save_path=f"{save_dir}/raw.py", - cache_count_path=f"{save_dir}/count.py", - ) - del _ # skipcq - - # * test crashing when switch to image having larger - # * shape but still provide old links - semantic_segmentor.merge_prediction( - [8, 8], - [np.full((2, 2, 1), 1), np.full((2, 2, 1), 2)], - [[0, 0, 2, 2], [2, 2, 4, 4]], - save_path=f"{save_dir}/raw.1.py", - cache_count_path=f"{save_dir}/count.1.py", - ) - with pytest.raises(ValueError, match=r".*`save_path` does not match.*"): - semantic_segmentor.merge_prediction( - [4, 4], - [np.full((2, 2, 1), 1), np.full((2, 2, 1), 2)], - [[0, 0, 2, 2], [2, 2, 4, 4]], - save_path=f"{save_dir}/raw.1.py", - cache_count_path=f"{save_dir}/count.py", - ) - - with pytest.raises(ValueError, match=r".*`cache_count_path` does not match.*"): - semantic_segmentor.merge_prediction( - [4, 4], - [np.full((2, 2, 1), 1), np.full((2, 2, 1), 2)], - [[0, 0, 2, 2], [2, 2, 4, 4]], - save_path=f"{save_dir}/raw.py", - cache_count_path=f"{save_dir}/count.1.py", - ) - # * test non HW predictions - with pytest.raises(ValueError, match=r".*Prediction is no HW or HWC.*"): - semantic_segmentor.merge_prediction( - [4, 4], - [np.full((2,), 1), np.full((2,), 2)], - [[0, 0, 2, 2], [2, 2, 4, 4]], - save_path=f"{save_dir}/raw.py", - cache_count_path=f"{save_dir}/count.1.py", - ) - - shutil.rmtree(save_dir, ignore_errors=True) - save_dir.mkdir() - - # * with an out of bound location - canvas = semantic_segmentor.merge_prediction( - [4, 4], - [ - np.full((2, 2), 1), - np.full((2, 2), 2), - np.full((2, 2), 3), - np.full((2, 2), 4), - ], - [[0, 0, 2, 2], [2, 2, 4, 4], [0, 4, 2, 6], [4, 0, 6, 2]], - save_path=None, - ) - assert np.sum(canvas - _output) < 1.0e-8 - del canvas # skipcq - - -def test_functional_segmentor(remote_sample: Callable, tmp_path: Path) -> None: - """Functional test for segmentor.""" - save_dir = tmp_path / "dump" - # # convert to pathlib Path to prevent wsireader complaint - resolution = 2.0 - mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs")) - reader = WSIReader.open(mini_wsi_svs) - thumb = reader.slide_thumbnail(resolution=resolution, units="baseline") - mini_wsi_jpg = f"{tmp_path}/mini_svs.jpg" - imwrite(mini_wsi_jpg, thumb) - mini_wsi_msk = f"{tmp_path}/mini_mask.jpg" - imwrite(mini_wsi_msk, (thumb > 0).astype(np.uint8)) - - # preemptive clean up - shutil.rmtree("output", ignore_errors=True) # default output dir test - model = _CNNTo1() - semantic_segmentor = SemanticSegmentor(batch_size=BATCH_SIZE, model=model) - # fake injection to trigger Segmentor to create parallel - # post-processing workers because baseline Semantic Segmentor does not support - # post-processing out of the box. It only contains condition to create it - # for any subclass - semantic_segmentor.num_postproc_workers = 1 - - # should still run because we skip exception - semantic_segmentor.predict( - [mini_wsi_jpg], - mode="tile", - on_gpu=ON_GPU, - patch_input_shape=(512, 512), - resolution=resolution, - units="mpp", - crash_on_exception=False, - ) - - shutil.rmtree("output", ignore_errors=True) # default output dir test - semantic_segmentor.predict( - [mini_wsi_jpg], - mode="tile", - on_gpu=ON_GPU, - patch_input_shape=(512, 512), - resolution=1 / resolution, - units="baseline", - crash_on_exception=True, - ) - shutil.rmtree("output", ignore_errors=True) # default output dir test - - # * check exception bypass in the log - # there should be no exception, but how to check the log? - semantic_segmentor.predict( - [mini_wsi_jpg], - mode="tile", - on_gpu=ON_GPU, - patch_input_shape=(512, 512), - patch_output_shape=(512, 512), - stride_shape=(512, 512), - resolution=1 / resolution, - units="baseline", - crash_on_exception=False, - ) - shutil.rmtree("output", ignore_errors=True) # default output dir test - - # * test basic running and merging prediction - # * should dumping all 1 in the output - ioconfig = IOSegmentorConfig( - input_resolutions=[{"units": "baseline", "resolution": 1.0}], - output_resolutions=[{"units": "baseline", "resolution": 1.0}], - patch_input_shape=[512, 512], - patch_output_shape=[512, 512], - stride_shape=[512, 512], - ) - - shutil.rmtree(save_dir, ignore_errors=True) - file_list = [ - mini_wsi_jpg, - mini_wsi_jpg, - ] - output_list = semantic_segmentor.predict( - file_list, - mode="tile", - on_gpu=ON_GPU, - ioconfig=ioconfig, - crash_on_exception=True, - save_dir=f"{save_dir}/raw/", - ) - pred_1 = np.load(output_list[0][1] + ".raw.0.npy") - pred_2 = np.load(output_list[1][1] + ".raw.0.npy") - assert len(output_list) == 2 - assert np.sum(pred_1 - pred_2) == 0 - # due to overlapping merge and division, will not be - # exactly 1, but should be approximately so - assert np.sum((pred_1 - 1) > 1.0e-6) == 0 - shutil.rmtree(save_dir, ignore_errors=True) - - # * test running with mask and svs - # * also test merging prediction at designated resolution - ioconfig = IOSegmentorConfig( - input_resolutions=[{"units": "mpp", "resolution": resolution}], - output_resolutions=[{"units": "mpp", "resolution": resolution}], - save_resolution={"units": "mpp", "resolution": resolution}, - patch_input_shape=[512, 512], - patch_output_shape=[256, 256], - stride_shape=[512, 512], - ) - shutil.rmtree(save_dir, ignore_errors=True) - output_list = semantic_segmentor.predict( - [mini_wsi_svs], - masks=[mini_wsi_msk], - mode="wsi", - on_gpu=ON_GPU, - ioconfig=ioconfig, - crash_on_exception=True, - save_dir=f"{save_dir}/raw/", - ) - reader = WSIReader.open(mini_wsi_svs) - expected_shape = reader.slide_dimensions(**ioconfig.save_resolution) - expected_shape = np.array(expected_shape)[::-1] # to YX - pred_1 = np.load(output_list[0][1] + ".raw.0.npy") - saved_shape = np.array(pred_1.shape[:2]) - assert np.sum(expected_shape - saved_shape) == 0 - assert np.sum((pred_1 - 1) > 1.0e-6) == 0 - shutil.rmtree(save_dir, ignore_errors=True) - - # check normal run with auto get mask - semantic_segmentor = SemanticSegmentor( - batch_size=BATCH_SIZE, - model=model, - auto_generate_mask=True, - ) - _ = semantic_segmentor.predict( - [mini_wsi_svs], - masks=[mini_wsi_msk], - mode="wsi", - on_gpu=ON_GPU, - ioconfig=ioconfig, - crash_on_exception=True, - save_dir=f"{save_dir}/raw/", - ) - - -def test_subclass(remote_sample: Callable, tmp_path: Path) -> None: - """Create subclass and test parallel processing setup.""" - save_dir = Path(tmp_path) - mini_wsi_jpg = Path(remote_sample("wsi2_4k_4k_jpg")) - - model = _CNNTo1() - - class XSegmentor(SemanticSegmentor): - """Dummy class to test subclassing.""" - - def __init__(self: XSegmentor) -> None: - super().__init__(model=model) - self.num_postproc_worker = 2 - - semantic_segmentor = XSegmentor() - shutil.rmtree(save_dir, ignore_errors=True) # default output dir test - semantic_segmentor.predict( - [mini_wsi_jpg], - mode="tile", - on_gpu=ON_GPU, - patch_input_shape=(1024, 1024), - patch_output_shape=(512, 512), - stride_shape=(256, 256), - resolution=1.0, - units="baseline", - crash_on_exception=False, - save_dir=save_dir / "raw", - ) - - -# specifically designed for travis -def test_functional_pretrained(remote_sample: Callable, tmp_path: Path) -> None: - """Test for load up pretrained and over-writing tile mode ioconfig.""" - save_dir = Path(f"{tmp_path}/output") - mini_wsi_svs = Path(remote_sample("wsi4_512_512_svs")) - reader = WSIReader.open(mini_wsi_svs) - thumb = reader.slide_thumbnail(resolution=1.0, units="baseline") - mini_wsi_jpg = f"{tmp_path}/mini_svs.jpg" - imwrite(mini_wsi_jpg, thumb) - - semantic_segmentor = SemanticSegmentor( - batch_size=BATCH_SIZE, - pretrained_model="fcn-tissue_mask", - ) - - shutil.rmtree(save_dir, ignore_errors=True) - semantic_segmentor.predict( - [mini_wsi_svs], - mode="wsi", - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=f"{save_dir}/raw/", - ) - - shutil.rmtree(save_dir, ignore_errors=True) - - # mainly to test prediction on tile - semantic_segmentor.predict( - [mini_wsi_jpg], - mode="tile", - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=f"{save_dir}/raw/", - ) - - assert save_dir.joinpath("raw/0.raw.0.npy").exists() - assert save_dir.joinpath("raw/file_map.dat").exists() - - -@pytest.mark.skipif( - toolbox_env.running_on_ci() or not ON_GPU, - reason="Local test on machine with GPU.", -) -def test_behavior_tissue_mask_local(remote_sample: Callable, tmp_path: Path) -> None: - """Contain test for behavior of the segmentor and pretrained models.""" - save_dir = tmp_path - wsi_with_artifacts = Path(remote_sample("wsi3_20k_20k_svs")) - mini_wsi_jpg = Path(remote_sample("wsi2_4k_4k_jpg")) - - semantic_segmentor = SemanticSegmentor( - batch_size=BATCH_SIZE, - pretrained_model="fcn-tissue_mask", - ) - shutil.rmtree(save_dir, ignore_errors=True) - semantic_segmentor.predict( - [wsi_with_artifacts], - mode="wsi", - on_gpu=True, - crash_on_exception=True, - save_dir=save_dir / "raw", - ) - # load up the raw prediction and perform precision check - _cache_pred = imread(Path(remote_sample("wsi3_20k_20k_pred"))) - _test_pred = np.load(str(save_dir / "raw" / "0.raw.0.npy")) - _test_pred = (_test_pred[..., 1] > 0.75) * 255 - # divide 255 to binarize - assert np.mean(_cache_pred[..., 0] == _test_pred) > 0.99 - - shutil.rmtree(save_dir, ignore_errors=True) - # mainly to test prediction on tile - semantic_segmentor.predict( - [mini_wsi_jpg], - mode="tile", - on_gpu=True, - crash_on_exception=True, - save_dir=f"{save_dir}/raw/", - ) - - -@pytest.mark.skipif( - toolbox_env.running_on_ci() or not ON_GPU, - reason="Local test on machine with GPU.", -) -def test_behavior_bcss_local(remote_sample: Callable, tmp_path: Path) -> None: - """Contain test for behavior of the segmentor and pretrained models.""" - save_dir = tmp_path - - wsi_breast = Path(remote_sample("wsi4_4k_4k_svs")) - semantic_segmentor = SemanticSegmentor( - num_loader_workers=4, - batch_size=BATCH_SIZE, - pretrained_model="fcn_resnet50_unet-bcss", - ) - semantic_segmentor.predict( - [wsi_breast], - mode="wsi", - on_gpu=True, - crash_on_exception=True, - save_dir=save_dir / "raw", - ) - # load up the raw prediction and perform precision check - _cache_pred = np.load(Path(remote_sample("wsi4_4k_4k_pred"))) - _test_pred = np.load(f"{save_dir}/raw/0.raw.0.npy") - _test_pred = np.argmax(_test_pred, axis=-1) - assert np.mean(np.abs(_cache_pred - _test_pred)) < 1.0e-2 - - -# ------------------------------------------------------------------------------------- -# Command Line Interface -# ------------------------------------------------------------------------------------- - - -def test_cli_semantic_segment_out_exists_error( - remote_sample: Callable, - tmp_path: Path, -) -> None: - """Test for semantic segmentation if output path exists.""" - mini_wsi_svs = Path(remote_sample("svs-1-small")) - sample_wsi_msk = remote_sample("small_svs_tissue_mask") - sample_wsi_msk = np.load(sample_wsi_msk).astype(np.uint8) - imwrite(f"{tmp_path}/small_svs_tissue_mask.jpg", sample_wsi_msk) - sample_wsi_msk = f"{tmp_path}/small_svs_tissue_mask.jpg" - runner = CliRunner() - semantic_segment_result = runner.invoke( - cli.main, - [ - "semantic-segment", - "--img-input", - str(mini_wsi_svs), - "--mode", - "wsi", - "--masks", - str(sample_wsi_msk), - "--output-path", - tmp_path, - ], - ) - - assert semantic_segment_result.output == "" - assert semantic_segment_result.exit_code == 1 - assert isinstance(semantic_segment_result.exception, FileExistsError) - - -def test_cli_semantic_segmentation_ioconfig( - remote_sample: Callable, - tmp_path: Path, -) -> None: - """Test for semantic segmentation single file custom ioconfig.""" - mini_wsi_svs = Path(remote_sample("svs-1-small")) - sample_wsi_msk = remote_sample("small_svs_tissue_mask") - sample_wsi_msk = np.load(sample_wsi_msk).astype(np.uint8) - imwrite(f"{tmp_path}/small_svs_tissue_mask.jpg", sample_wsi_msk) - sample_wsi_msk = f"{tmp_path}/small_svs_tissue_mask.jpg" - - pretrained_weights = fetch_pretrained_weights("fcn-tissue_mask") - - config = { - "input_resolutions": [{"units": "mpp", "resolution": 2.0}], - "output_resolutions": [{"units": "mpp", "resolution": 2.0}], - "patch_input_shape": [1024, 1024], - "patch_output_shape": [512, 512], - "stride_shape": [256, 256], - "save_resolution": {"units": "mpp", "resolution": 8.0}, - } - with Path.open(tmp_path.joinpath("config.yaml"), "w") as fptr: - yaml.dump(config, fptr) - - runner = CliRunner() - - semantic_segment_result = runner.invoke( - cli.main, - [ - "semantic-segment", - "--img-input", - str(mini_wsi_svs), - "--pretrained-weights", - str(pretrained_weights), - "--mode", - "wsi", - "--masks", - str(sample_wsi_msk), - "--output-path", - tmp_path.joinpath("output"), - "--yaml-config-path", - tmp_path.joinpath("config.yaml"), - ], - ) - - assert semantic_segment_result.exit_code == 0 - assert tmp_path.joinpath("output/0.raw.0.npy").exists() - assert tmp_path.joinpath("output/file_map.dat").exists() - assert tmp_path.joinpath("output/results.json").exists() - - -def test_cli_semantic_segmentation_multi_file( - remote_sample: Callable, - tmp_path: Path, -) -> None: - """Test for models CLI multiple file with mask.""" - mini_wsi_svs = Path(remote_sample("svs-1-small")) - sample_wsi_msk = remote_sample("small_svs_tissue_mask") - sample_wsi_msk = np.load(sample_wsi_msk).astype(np.uint8) - imwrite(f"{tmp_path}/small_svs_tissue_mask.jpg", sample_wsi_msk) - sample_wsi_msk = tmp_path / "small_svs_tissue_mask.jpg" - - # Make multiple copies for test - dir_path = tmp_path / "new_copies" - dir_path.mkdir() - - dir_path_masks = tmp_path / "new_copies_masks" - dir_path_masks.mkdir() - - try: - dir_path.joinpath("1_" + mini_wsi_svs.name).symlink_to(mini_wsi_svs) - dir_path.joinpath("2_" + mini_wsi_svs.name).symlink_to(mini_wsi_svs) - except OSError: - shutil.copy(mini_wsi_svs, dir_path.joinpath("1_" + mini_wsi_svs.name)) - shutil.copy(mini_wsi_svs, dir_path.joinpath("2_" + mini_wsi_svs.name)) - - try: - dir_path_masks.joinpath("1_" + sample_wsi_msk.name).symlink_to(sample_wsi_msk) - dir_path_masks.joinpath("2_" + sample_wsi_msk.name).symlink_to(sample_wsi_msk) - except OSError: - shutil.copy(sample_wsi_msk, dir_path_masks.joinpath("1_" + sample_wsi_msk.name)) - shutil.copy(sample_wsi_msk, dir_path_masks.joinpath("2_" + sample_wsi_msk.name)) - - tmp_path = tmp_path / "output" - - runner = CliRunner() - semantic_segment_result = runner.invoke( - cli.main, - [ - "semantic-segment", - "--img-input", - str(dir_path), - "--mode", - "wsi", - "--masks", - str(dir_path_masks), - "--output-path", - str(tmp_path), - ], - ) - - assert semantic_segment_result.exit_code == 0 - assert tmp_path.joinpath("0.raw.0.npy").exists() - assert tmp_path.joinpath("1.raw.0.npy").exists() - assert tmp_path.joinpath("file_map.dat").exists() - assert tmp_path.joinpath("results.json").exists() - - # load up the raw prediction and perform precision check - _cache_pred = imread(Path(remote_sample("small_svs_tissue_mask"))) - _test_pred = np.load(str(tmp_path.joinpath("0.raw.0.npy"))) - _test_pred = (_test_pred[..., 1] > 0.50) * 255 - - assert np.mean(np.abs(_cache_pred - _test_pred) / 255) < 1e-3 diff --git a/tests/models/test_arch_vanilla.py b/tests/models/test_arch_vanilla.py index a2b1ac5c9..788144034 100644 --- a/tests/models/test_arch_vanilla.py +++ b/tests/models/test_arch_vanilla.py @@ -9,6 +9,7 @@ ON_GPU = False RNG = np.random.default_rng() # Numpy Random Generator +device = "cuda" if ON_GPU else "cpu" def test_functional() -> None: @@ -43,7 +44,7 @@ def test_functional() -> None: try: for backbone in backbones: model = CNNModel(backbone, num_classes=1) - model_ = model_to(on_gpu=ON_GPU, model=model) + model_ = model_to(device=device, model=model) model.infer_batch(model_, samples, on_gpu=ON_GPU) except ValueError as exc: msg = f"Model {backbone} failed." diff --git a/tests/models/test_models_abc.py b/tests/models/test_models_abc.py index 3537735ce..635b13be1 100644 --- a/tests/models/test_models_abc.py +++ b/tests/models/test_models_abc.py @@ -124,9 +124,9 @@ def test_model_to() -> None: if not utils.env_detection.has_gpu(): model = torch_models.resnet18() with pytest.raises((AssertionError, RuntimeError)): - _ = tiatoolbox.models.models_abc.model_to(on_gpu=True, model=model) + _ = tiatoolbox.models.models_abc.model_to(device="cuda", model=model) # Test on CPU model = torch_models.resnet18() - model = tiatoolbox.models.models_abc.model_to(on_gpu=False, model=model) + model = tiatoolbox.models.models_abc.model_to(device="cpu", model=model) assert isinstance(model, nn.Module) diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 41ff46377..9b149bf8e 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -23,6 +23,7 @@ from torch.utils.data import DataLoader from tiatoolbox.annotation import AnnotationStore + from tiatoolbox.wsicore.wsireader import WSIReader from .io_config import ModelIOConfigABC @@ -69,7 +70,7 @@ def prepare_engines_save_dir( if len_images > 1: logger.info( - "When providing multiple whole-slide images / tiles, " + "When providing multiple whole slide images, " "the outputs will be saved and the locations of outputs " "will be returned to the calling function.", ) @@ -110,17 +111,17 @@ class EngineABC(ABC): Please note that they will also perform preprocessing. default = 0 num_post_proc_workers (int): Number of workers to postprocess the results of the model. default = 0 - on_gpu (bool): - + device (str): + Select the device to run the model. Default is "cpu". verbose (bool): Whether to output logging information. Attributes: images (str or :obj:`pathlib.Path` or :obj:`numpy.ndarray`): A NHWC image or a path to WSI. - mode (str): - Type of input to process. Choose from either `patch`, `tile` - or `wsi`. + patch_mode (str): + Whether to treat input image as a patch or WSI. + default = True. model (str | nn.Module): Defined PyTorch model. Name of the existing models support by tiatoolbox for @@ -152,18 +153,16 @@ class EngineABC(ABC): requested read resolution, not with respect to level 0, and must be positive. stride_shape (tuple): - Stride using during tile and WSI processing. Stride is + Stride used during WSI processing. Stride is at requested read resolution, not with respect to level 0, and must be positive. If not provided, `stride_shape=patch_input_shape`. batch_size (int): Number of images fed into the model each time. - labels: - List of labels. If using `tile` or `wsi` mode, then only - a single label per image tile or whole-slide image is - supported. - on_gpu (bool): - Whether to run model on the GPU. Default is False. + labels (list | None): + List of labels. Only a single label per image is supported. + device (str): + Select the device to run the model. Default is "cpu". num_loader_workers (int): Number of workers used in torch.utils.data.DataLoader. verbose (bool): @@ -181,10 +180,10 @@ class EngineABC(ABC): >>> engine = EngineABC(pretrained_model="resnet18-kather100k") >>> output = engine.run(data, patch_mode=False) - >>> # list of 2 image tile files as input - >>> tile_file = ['path/tile1.png', 'path/tile2.png'] + >>> # list of 2 image files as input + >>> image = ['path/image1.png', 'path/image2.png'] >>> engine = EngineABC(pretraind_model="resnet18-kather100k") - >>> output = engine.run(tile_file, patch_mode=False) + >>> output = engine.run(image, patch_mode=False) >>> # list of 2 wsi files as input >>> wsi_file = ['path/wsi1.svs', 'path/wsi2.svs'] @@ -201,7 +200,7 @@ def __init__( num_post_proc_workers: int = 0, weights: str | Path | None = None, *, - on_gpu: bool = False, + device: str = "str", verbose: bool = False, ) -> None: """Initialize Engine.""" @@ -210,14 +209,14 @@ def __init__( self.masks = None self.images = None self.patch_mode = None - self.on_gpu = on_gpu + self.device = device # Initialize model with specified weights and ioconfig. self.model, self.ioconfig = self._initialize_model_ioconfig( model=model, weights=weights, ) - self.model = model_to(model=self.model, on_gpu=self.on_gpu) + self.model = model_to(model=self.model, device=self.device) self._ioconfig = self.ioconfig # runtime ioconfig self.batch_size = batch_size @@ -343,7 +342,7 @@ def infer_patches( batch_output_predictions = self.model.infer_batch( self.model, batch_data["image"], - on_gpu=self.on_gpu, + on_gpu=self.device, ) raw_predictions["predictions"].extend(batch_output_predictions.tolist()) @@ -470,7 +469,7 @@ def _validate_input_numbers( def run( self: EngineABC, - images: list[os | Path] | np.ndarray, + images: list[os | Path | WSIReader] | np.ndarray, masks: list[os | Path] | np.ndarray | None = None, labels: list | None = None, ioconfig: ModelIOConfigABC | None = None, @@ -487,29 +486,25 @@ def run( images (list, ndarray): List of inputs to process. when using `patch` mode, the input must be either a list of images, a list of image - file paths or a numpy array of an image list. When using - `tile` or `wsi` mode, the input must be a list of file - paths. + file paths or a numpy array of an image list. masks (list | None): - List of masks. Only utilised when processing image tiles - and whole-slide images. Patches are only processed if - they are within a masked area. If not provided, then a - tissue mask will be automatically generated for - whole-slide images or the entire image is processed for - image tiles. + List of masks. Only utilised when patch_mode is False. + Patches are only generated within a masked area. + If not provided, then a tissue mask will be automatically + generated for whole slide images. labels (list | None): - List of labels. If using `tile` or `wsi` mode, then only - a single label per image tile or whole-slide image is - supported. + List of labels. Only a single label per image is supported. patch_mode (bool): Whether to treat input image as a patch or WSI. default = True. ioconfig (IOPatchPredictorConfig): IO configuration. save_dir (str or pathlib.Path): - Output directory when processing multiple tiles and - whole-slide images. By default, it is folder `output` - where the running script is invoked. + Output directory to save the results. + If save_dir is not provided when patch_mode is False, + then for a single image the output is created in the current directory. + If there are multiple WSIs as input then the user must provide + path to save directory otherwise an OSError will be raised. overwrite (bool): Whether to overwrite the results. Default = False. output_type (str): @@ -522,7 +517,7 @@ def run( Returns: (:class:`numpy.ndarray`, dict): Model predictions of the input dataset. If multiple - image tiles or whole-slide images are provided as input, + whole slide images are provided as input, or save_output is True, then results are saved to `save_dir` and a dictionary indicating save location for each input is returned. @@ -561,7 +556,7 @@ def run( # if necessary Move model parameters to "cpu" or "gpu" and update ioconfig self._ioconfig = self._load_ioconfig(ioconfig=ioconfig) - self.model = model_to(model=self.model, on_gpu=self.on_gpu) + self.model = model_to(model=self.model, device=self.device) save_dir = prepare_engines_save_dir( save_dir, diff --git a/tiatoolbox/models/engine/patch_predictor.py b/tiatoolbox/models/engine/patch_predictor.py index fb24322eb..7b03fd975 100644 --- a/tiatoolbox/models/engine/patch_predictor.py +++ b/tiatoolbox/models/engine/patch_predictor.py @@ -363,7 +363,7 @@ def _predict_engine( return_probabilities=False, return_labels=False, return_coordinates=False, - on_gpu=True, + device="cpu", ): """Make a prediction on a dataset. The dataset may be mutated. @@ -377,8 +377,8 @@ def _predict_engine( Whether to return labels. return_coordinates (bool): Whether to return patch coordinates. - on_gpu (bool): - Whether to run model on the GPU. + device (str): + Select the device to run the model. Default is "cpu". Returns: :class:`numpy.ndarray`: @@ -406,7 +406,7 @@ def _predict_engine( ) # use external for testing - model = tiatoolbox.models.models_abc.model_to(model=self.model, on_gpu=on_gpu) + model = tiatoolbox.models.models_abc.model_to(model=self.model, device=device) cum_output = { "probabilities": [], @@ -418,7 +418,7 @@ def _predict_engine( batch_output_probabilities = self.model.infer_batch( model, batch_data["image"], - on_gpu=on_gpu, + device=device, ) # We get the index of the class with the maximum probability batch_output_predictions = self.model.postproc_func( diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index e1341c640..8ea14e447 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -928,7 +928,7 @@ def predict( # noqa: PLR0913 units=None, save_dir=None, *, - on_gpu=True, + device="cpu", crash_on_exception=False, ): """Make a prediction for a list of input data. @@ -966,8 +966,8 @@ def predict( # noqa: PLR0913 `stride_shape`, `resolution`, and `units` arguments are ignored. Otherwise, those arguments will be internally converted to a :class:`IOSegmentorConfig` object. - on_gpu (bool): - Whether to run the model on the GPU. + device (str): + Select the device to run the model. Default is "cpu". patch_input_shape (tuple): Size of patches input to the model. The values are at requested read resolution and must be positive. @@ -1049,10 +1049,10 @@ def predict( # noqa: PLR0913 ) # use external for testing - self._on_gpu = on_gpu + self._device = device self._model = tiatoolbox.models.models_abc.model_to( model=self.model, - on_gpu=on_gpu, + device=device, ) # workers should be > 0 else Value Error will be thrown @@ -1253,7 +1253,7 @@ def predict( # noqa: PLR0913 units="baseline", save_dir=None, *, - on_gpu=True, + device=True, crash_on_exception=False, ): """Make a prediction for a list of input data. @@ -1291,8 +1291,8 @@ def predict( # noqa: PLR0913 `stride_shape`, `resolution`, and `units` arguments are ignored. Otherwise, those arguments will be internally converted to a :class:`IOSegmentorConfig` object. - on_gpu (bool): - Whether to run the model on the GPU. + device (str): + Select the device to run the model. Default is "cpu". patch_input_shape (tuple): Size of patches input to the model. The values are at requested read resolution and must be positive. @@ -1348,7 +1348,7 @@ def predict( # noqa: PLR0913 imgs=imgs, masks=masks, mode=mode, - on_gpu=on_gpu, + device=device, ioconfig=ioconfig, patch_input_shape=patch_input_shape, patch_output_shape=patch_output_shape, diff --git a/tiatoolbox/models/models_abc.py b/tiatoolbox/models/models_abc.py index a93fccae3..93d338f93 100644 --- a/tiatoolbox/models/models_abc.py +++ b/tiatoolbox/models/models_abc.py @@ -34,23 +34,24 @@ def load_torch_model(model: nn.Module, weights: str | Path) -> nn.Module: return model -def model_to(model: torch.nn.Module, *, on_gpu: bool) -> torch.nn.Module: +def model_to(model: torch.nn.Module, device: str = "cpu") -> torch.nn.Module: """Transfers model to cpu/gpu. Args: model (torch.nn.Module): PyTorch defined model. - on_gpu (bool): Transfers model to gpu if True otherwise to cpu. + device (str): Transfers model to the specified device. Default is "cpu". Returns: torch.nn.Module: The model after being moved to cpu/gpu. """ - if on_gpu: # DataParallel work only for cuda + if device != "cpu": + # DataParallel work only for cuda model = torch.nn.DataParallel(model) - return model.to("cuda") - return model.to("cpu") + device = torch.device(device) + return model.to(device) class ModelABC(ABC, nn.Module): From bd892fda67a4e4f3d1cd4edca19dc7dd6f211c70 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 22 Sep 2023 11:40:00 +0100 Subject: [PATCH 078/112] :bug: Fix breaking changes due to use of device instead of `on_gpu`. --- tests/engines/test_engine_abc.py | 2 +- tests/models/test_arch_mapde.py | 4 ++-- tests/models/test_arch_micronet.py | 2 +- tiatoolbox/models/architecture/hovernet.py | 8 +++----- tiatoolbox/models/architecture/hovernetplus.py | 8 +++----- tiatoolbox/models/architecture/mapde.py | 8 +++----- tiatoolbox/models/architecture/micronet.py | 8 +++----- tiatoolbox/models/architecture/nuclick.py | 11 +++++------ tiatoolbox/models/architecture/sccnn.py | 8 +++----- tiatoolbox/models/architecture/unet.py | 8 +++----- tiatoolbox/models/architecture/vanilla.py | 17 ++++++++--------- tiatoolbox/models/engine/engine_abc.py | 13 +++++++++---- tiatoolbox/models/models_abc.py | 12 +++++++----- tiatoolbox/utils/misc.py | 3 ++- 14 files changed, 53 insertions(+), 59 deletions(-) diff --git a/tests/engines/test_engine_abc.py b/tests/engines/test_engine_abc.py index 1c2ab2370..8c7b60a6c 100644 --- a/tests/engines/test_engine_abc.py +++ b/tests/engines/test_engine_abc.py @@ -146,7 +146,7 @@ def test_prepare_engines_save_dir( assert out_dir == tmp_path / "wsi_multiple_output" assert out_dir.exists() - assert r"When providing multiple whole-slide images / tiles" in caplog.text + assert r"When providing multiple whole slide images" in caplog.text def test_engine_initalization() -> NoReturn: diff --git a/tests/models/test_arch_mapde.py b/tests/models/test_arch_mapde.py index df60d3b47..f0142406d 100644 --- a/tests/models/test_arch_mapde.py +++ b/tests/models/test_arch_mapde.py @@ -44,7 +44,7 @@ def test_functionality(remote_sample: Callable) -> None: model = _load_mapde(name="mapde-conic") patch = model.preproc(patch) batch = torch.from_numpy(patch)[None] - model = model.to(select_device(on_gpu=ON_GPU)) - output = model.infer_batch(model, batch, on_gpu=ON_GPU) + model = model.to() + output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU)) output = model.postproc(output[0]) assert np.all(output[0:2] == [[19, 171], [53, 89]]) diff --git a/tests/models/test_arch_micronet.py b/tests/models/test_arch_micronet.py index cd4bd0833..e7aa23d5b 100644 --- a/tests/models/test_arch_micronet.py +++ b/tests/models/test_arch_micronet.py @@ -39,7 +39,7 @@ def test_functionality( model = model.to(map_location) pretrained = torch.load(weights_path, map_location=map_location) model.load_state_dict(pretrained) - output = model.infer_batch(model, batch, on_gpu=ON_GPU) + output = model.infer_batch(model, batch, device=map_location) output, _ = model.postproc(output[0]) assert np.max(np.unique(output)) == 46 diff --git a/tiatoolbox/models/architecture/hovernet.py b/tiatoolbox/models/architecture/hovernet.py index cad29fe83..216e06ee5 100644 --- a/tiatoolbox/models/architecture/hovernet.py +++ b/tiatoolbox/models/architecture/hovernet.py @@ -19,7 +19,6 @@ centre_crop_to_shape, ) from tiatoolbox.models.models_abc import ModelABC -from tiatoolbox.utils import misc from tiatoolbox.utils.misc import get_bounding_box @@ -781,7 +780,7 @@ def postproc(raw_maps: list[np.ndarray]) -> tuple[np.ndarray, dict]: return pred_inst, nuc_inst_info_dict @staticmethod - def infer_batch(model: nn.Module, batch_data: np.ndarray, *, on_gpu: bool) -> tuple: + def infer_batch(model: nn.Module, batch_data: np.ndarray, *, device: str) -> tuple: """Run inference on an input batch. This contains logic for forward operation as well as batch i/o @@ -793,8 +792,8 @@ def infer_batch(model: nn.Module, batch_data: np.ndarray, *, on_gpu: bool) -> tu batch_data (ndarray): A batch of data generated by `torch.utils.data.DataLoader`. - on_gpu (bool): - Whether to run inference on a GPU. + device (str): + Transfers model to the specified device. Default is "cpu". Returns: tuple: @@ -806,7 +805,6 @@ def infer_batch(model: nn.Module, batch_data: np.ndarray, *, on_gpu: bool) -> tu """ patch_imgs = batch_data - device = misc.select_device(on_gpu=on_gpu) patch_imgs_gpu = patch_imgs.to(device).type(torch.float32) # to NCHW patch_imgs_gpu = patch_imgs_gpu.permute(0, 3, 1, 2).contiguous() diff --git a/tiatoolbox/models/architecture/hovernetplus.py b/tiatoolbox/models/architecture/hovernetplus.py index 59135a350..ddcce67ea 100644 --- a/tiatoolbox/models/architecture/hovernetplus.py +++ b/tiatoolbox/models/architecture/hovernetplus.py @@ -12,7 +12,6 @@ from tiatoolbox.models.architecture.hovernet import HoVerNet from tiatoolbox.models.architecture.utils import UpSample2x -from tiatoolbox.utils import misc class HoVerNetPlus(HoVerNet): @@ -320,7 +319,7 @@ def postproc(raw_maps: list[np.ndarray]) -> tuple: return pred_inst, nuc_inst_info_dict, pred_layer, layer_info_dict @staticmethod - def infer_batch(model: nn.Module, batch_data: np.ndarray, *, on_gpu: bool) -> tuple: + def infer_batch(model: nn.Module, batch_data: np.ndarray, *, device: str) -> tuple: """Run inference on an input batch. This contains logic for forward operation as well as batch i/o @@ -332,13 +331,12 @@ def infer_batch(model: nn.Module, batch_data: np.ndarray, *, on_gpu: bool) -> tu batch_data (ndarray): A batch of data generated by `torch.utils.data.DataLoader`. - on_gpu (bool): - Whether to run inference on a GPU. + device (str): + Transfers model to the specified device. Default is "cpu". """ patch_imgs = batch_data - device = misc.select_device(on_gpu=on_gpu) patch_imgs_gpu = patch_imgs.to(device).type(torch.float32) # to NCHW patch_imgs_gpu = patch_imgs_gpu.permute(0, 3, 1, 2).contiguous() diff --git a/tiatoolbox/models/architecture/mapde.py b/tiatoolbox/models/architecture/mapde.py index 21c588c29..863ce985d 100644 --- a/tiatoolbox/models/architecture/mapde.py +++ b/tiatoolbox/models/architecture/mapde.py @@ -13,7 +13,6 @@ from skimage.feature import peak_local_max from tiatoolbox.models.architecture.micronet import MicroNet -from tiatoolbox.utils.misc import select_device class MapDe(MicroNet): @@ -258,7 +257,7 @@ def infer_batch( model: torch.nn.Module, batch_data: torch.Tensor, *, - on_gpu: bool, + device: str, ) -> list[np.ndarray]: """Run inference on an input batch. @@ -271,8 +270,8 @@ def infer_batch( batch_data (:class:`numpy.ndarray`): A batch of data generated by `torch.utils.data.DataLoader`. - on_gpu (bool): - Whether to run inference on a GPU. + device (str): + Transfers model to the specified device. Default is "cpu". Returns: list(np.ndarray): @@ -281,7 +280,6 @@ def infer_batch( """ patch_imgs = batch_data - device = select_device(on_gpu=on_gpu) patch_imgs_gpu = patch_imgs.to(device).type(torch.float32) # to NCHW patch_imgs_gpu = patch_imgs_gpu.permute(0, 3, 1, 2).contiguous() diff --git a/tiatoolbox/models/architecture/micronet.py b/tiatoolbox/models/architecture/micronet.py index 69daa120f..c18e51e6b 100644 --- a/tiatoolbox/models/architecture/micronet.py +++ b/tiatoolbox/models/architecture/micronet.py @@ -18,7 +18,6 @@ from tiatoolbox.models.architecture.hovernet import HoVerNet from tiatoolbox.models.models_abc import ModelABC -from tiatoolbox.utils import misc def group1_forward_branch( @@ -628,7 +627,7 @@ def infer_batch( model: torch.nn.Module, batch_data: torch.Tensor, *, - on_gpu: bool, + device: str, ) -> list[np.ndarray]: """Run inference on an input batch. @@ -641,8 +640,8 @@ def infer_batch( batch_data (:class:`torch.Tensor`): A batch of data generated by `torch.utils.data.DataLoader`. - on_gpu (bool): - Whether to run inference on a GPU. + device (str): + Transfers model to the specified device. Default is "cpu". Returns: list(np.ndarray): @@ -651,7 +650,6 @@ def infer_batch( """ patch_imgs = batch_data - device = misc.select_device(on_gpu=on_gpu) patch_imgs_gpu = patch_imgs.to(device).type(torch.float32) # to NCHW patch_imgs_gpu = patch_imgs_gpu.permute(0, 3, 1, 2).contiguous() diff --git a/tiatoolbox/models/architecture/nuclick.py b/tiatoolbox/models/architecture/nuclick.py index 85a759bb6..cb5f52509 100644 --- a/tiatoolbox/models/architecture/nuclick.py +++ b/tiatoolbox/models/architecture/nuclick.py @@ -21,7 +21,6 @@ from tiatoolbox import logger from tiatoolbox.models.models_abc import ModelABC -from tiatoolbox.utils import misc if TYPE_CHECKING: # pragma: no cover from tiatoolbox.typing import IntPair @@ -646,7 +645,7 @@ def infer_batch( model: nn.Module, batch_data: torch.Tensor, *, - on_gpu: bool, + device: str, ) -> np.ndarray: """Run inference on an input batch. @@ -655,16 +654,16 @@ def infer_batch( Args: model (nn.Module): PyTorch defined model. - batch_data (torch.Tensor): a batch of data generated by - torch.utils.data.DataLoader. - on_gpu (bool): Whether to run inference on a GPU. + batch_data (torch.Tensor): + A batch of data generated by torch.utils.data.DataLoader. + device (str): + Transfers model to the specified device. Default is "cpu". Returns: Pixel-wise nuclei prediction for each patch, shape: (no.patch, h, w). """ model.eval() - device = misc.select_device(on_gpu=on_gpu) # Assume batch_data is NCHW batch_data = batch_data.to(device).type(torch.float32) diff --git a/tiatoolbox/models/architecture/sccnn.py b/tiatoolbox/models/architecture/sccnn.py index bbeb58094..bdb8926e3 100644 --- a/tiatoolbox/models/architecture/sccnn.py +++ b/tiatoolbox/models/architecture/sccnn.py @@ -16,7 +16,6 @@ from torch import nn from tiatoolbox.models.models_abc import ModelABC -from tiatoolbox.utils import misc class SCCNN(ModelABC): @@ -354,7 +353,7 @@ def infer_batch( model: nn.Module, batch_data: np.ndarray | torch.Tensor, *, - on_gpu: bool, + device: str, ) -> list[np.ndarray]: """Run inference on an input batch. @@ -367,8 +366,8 @@ def infer_batch( batch_data (:class:`numpy.ndarray` or :class:`torch.Tensor`): A batch of data generated by `torch.utils.data.DataLoader`. - on_gpu (bool): - Whether to run inference on a GPU. + device (str): + Transfers model to the specified device. Default is "cpu". Returns: list of :class:`numpy.ndarray`: @@ -377,7 +376,6 @@ def infer_batch( """ patch_imgs = batch_data - device = misc.select_device(on_gpu=on_gpu) patch_imgs_gpu = patch_imgs.to(device).type(torch.float32) # to NCHW patch_imgs_gpu = patch_imgs_gpu.permute(0, 3, 1, 2).contiguous() diff --git a/tiatoolbox/models/architecture/unet.py b/tiatoolbox/models/architecture/unet.py index 8f628fb52..7e2e35c02 100644 --- a/tiatoolbox/models/architecture/unet.py +++ b/tiatoolbox/models/architecture/unet.py @@ -11,7 +11,6 @@ from tiatoolbox.models.architecture.utils import UpSample2x, centre_crop from tiatoolbox.models.models_abc import ModelABC -from tiatoolbox.utils import misc class ResNetEncoder(ResNet): @@ -415,7 +414,7 @@ def infer_batch( model: nn.Module, batch_data: torch.Tensor, *, - on_gpu: bool, + device: str, ) -> list: """Run inference on an input batch. @@ -428,8 +427,8 @@ def infer_batch( batch_data (:class:`torch.Tensor`): A batch of data generated by `torch.utils.data.DataLoader`. - on_gpu (bool): - Whether to run inference on a GPU. + device (str): + Transfers model to the specified device. Default is "cpu". Returns: list: @@ -438,7 +437,6 @@ def infer_batch( """ model.eval() - device = misc.select_device(on_gpu=on_gpu) #### imgs = batch_data diff --git a/tiatoolbox/models/architecture/vanilla.py b/tiatoolbox/models/architecture/vanilla.py index 5855971d5..2ecbd5b86 100644 --- a/tiatoolbox/models/architecture/vanilla.py +++ b/tiatoolbox/models/architecture/vanilla.py @@ -9,7 +9,6 @@ from torch import nn from tiatoolbox.models.models_abc import ModelABC -from tiatoolbox.utils.misc import select_device if TYPE_CHECKING: # pragma: no cover from torchvision.models import WeightsEnum @@ -142,7 +141,7 @@ def infer_batch( model: nn.Module, batch_data: torch.Tensor, *, - on_gpu: bool, + device: str = "cpu", ) -> np.ndarray: """Run inference on an input batch. @@ -154,11 +153,11 @@ def infer_batch( batch_data (torch.Tensor): A batch of data generated by `torch.utils.data.DataLoader`. - on_gpu (bool): - Whether to run inference on a GPU. + device (str): + Transfers model to the specified device. Default is "cpu". """ - img_patches_device = batch_data.to(select_device(on_gpu=on_gpu)).type( + img_patches_device = batch_data.to(device).type( torch.float32, ) # to NCHW img_patches_device = img_patches_device.permute(0, 3, 1, 2).contiguous() @@ -239,7 +238,7 @@ def infer_batch( model: nn.Module, batch_data: torch.Tensor, *, - on_gpu: bool, + device: str, ) -> list[np.ndarray, ...]: """Run inference on an input batch. @@ -251,11 +250,11 @@ def infer_batch( batch_data (torch.Tensor): A batch of data generated by `torch.utils.data.DataLoader`. - on_gpu (bool): - Whether to run inference on a GPU. + device (str): + Transfers model to the specified device. Default is "cpu". """ - img_patches_device = batch_data.to(select_device(on_gpu=on_gpu)).type( + img_patches_device = batch_data.to(device).type( torch.float32, ) # to NCHW img_patches_device = img_patches_device.permute(0, 3, 1, 2).contiguous() diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 9b149bf8e..b58dce0ee 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -200,7 +200,7 @@ def __init__( num_post_proc_workers: int = 0, weights: str | Path | None = None, *, - device: str = "str", + device: str = "cpu", verbose: bool = False, ) -> None: """Initialize Engine.""" @@ -342,7 +342,7 @@ def infer_patches( batch_output_predictions = self.model.infer_batch( self.model, batch_data["image"], - on_gpu=self.device, + device=self.device, ) raw_predictions["predictions"].extend(batch_output_predictions.tolist()) @@ -509,8 +509,13 @@ def run( Whether to overwrite the results. Default = False. output_type (str): The format of the output type. "output_type" can be - "dict", "array", "AnnotationStore", "DataFrame" or "json". - Default is "AnnotationStore". + "zarr", "AnnotationStore". Default is "zarr". + When saving in the zarr format the output is saved using the + `python zarr library `__ + as a zarr group. If the required output type is an "AnnotationStore" + then the output will be intermediately saved as zarr but converted + to :class:`AnnotationStore` and saved as a `.db` file + at the end of the loop. **kwargs (dict): Keyword Args to update :class:`EngineABC` attributes. diff --git a/tiatoolbox/models/models_abc.py b/tiatoolbox/models/models_abc.py index 93d338f93..1da0342a0 100644 --- a/tiatoolbox/models/models_abc.py +++ b/tiatoolbox/models/models_abc.py @@ -38,8 +38,10 @@ def model_to(model: torch.nn.Module, device: str = "cpu") -> torch.nn.Module: """Transfers model to cpu/gpu. Args: - model (torch.nn.Module): PyTorch defined model. - device (str): Transfers model to the specified device. Default is "cpu". + model (torch.nn.Module): + PyTorch defined model. + device (str): + Transfers model to the specified device. Default is "cpu". Returns: torch.nn.Module: @@ -71,7 +73,7 @@ def forward(self, *args, **kwargs): @staticmethod @abstractmethod - def infer_batch(model: nn.Module, batch_data: np.ndarray, *, on_gpu: bool): + def infer_batch(model: nn.Module, batch_data: np.ndarray, *, device: str): """Run inference on an input batch. Contains logic for forward operation as well as I/O aggregation. @@ -82,8 +84,8 @@ def infer_batch(model: nn.Module, batch_data: np.ndarray, *, on_gpu: bool): batch_data (np.ndarray): A batch of data generated by `torch.utils.data.DataLoader`. - on_gpu (bool): - Whether to run inference on a GPU. + device (str): + Transfers model to the specified device. Default is "cpu". """ ... # pragma: no cover diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py index fe6c30f19..6c2aa2e74 100644 --- a/tiatoolbox/utils/misc.py +++ b/tiatoolbox/utils/misc.py @@ -860,7 +860,8 @@ def select_device(*, on_gpu: bool) -> str: """Selects the appropriate device as requested. Args: - on_gpu (bool): Selects gpu if True. + on_gpu (bool): + Selects gpu if True. Returns: str: From 4ca18be8ede2c75bd4549061a82d10118c508d94 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 22 Sep 2023 13:16:07 +0100 Subject: [PATCH 079/112] :memo: Fix docstrings. --- tiatoolbox/models/engine/engine_abc.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index b58dce0ee..95a14aa54 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -85,20 +85,16 @@ class EngineABC(ABC): """Abstract base class for engines used in tiatoolbox. Args: - model (nn.Module): - Use externally defined PyTorch model for prediction with. - weights already loaded. Default is `None`. If provided, - `pretrained_model` argument is ignored. - Name of the existing models support by tiatoolbox for - processing the data. For a full list of pretrained models, - refer to the `docs + model (str | nn.Module): + A PyTorch model. Default is `None`. + The user can request pretrained models from the toolbox using + the list of pretrained models available at this `link `_ By default, the corresponding pretrained weights will also be downloaded. However, you can override with your own set - of weights via the `pretrained_weights` argument. Argument - is case-insensitive. + of weights. weights (str or Path): - Path to the weight of the corresponding `pretrained_model`. + Path to the weight of the corresponding `model`. >>> engine = EngineABC( ... pretrained_model="pretrained-model-name", @@ -124,13 +120,13 @@ class EngineABC(ABC): default = True. model (str | nn.Module): Defined PyTorch model. - Name of the existing models support by tiatoolbox for + Name of the existing models supported by the TIAToolbox for processing the data. For a full list of pretrained models, refer to the `docs `_ By default, the corresponding pretrained weights will also be downloaded. However, you can override with your own set - of weights via the `pretrained_weights` argument. Argument + of weights via the `weights` argument. Argument is case-insensitive. ioconfig (ModelIOConfigABC): Input IO configuration to run the Engine. From e7a6d1e0f061f92f3da907b1d455a763dbf55b8c Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 22 Sep 2023 13:17:48 +0100 Subject: [PATCH 080/112] :bug: Fix test_arch_nuclick.py --- tests/models/test_arch_nuclick.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/models/test_arch_nuclick.py b/tests/models/test_arch_nuclick.py index fda0c01a6..b84516125 100644 --- a/tests/models/test_arch_nuclick.py +++ b/tests/models/test_arch_nuclick.py @@ -10,6 +10,7 @@ from tiatoolbox.models import NuClick from tiatoolbox.models.architecture import fetch_pretrained_weights from tiatoolbox.utils import imread +from tiatoolbox.utils.misc import select_device ON_GPU = False @@ -53,7 +54,7 @@ def test_functional_nuclick( model = NuClick(num_input_channels=5, num_output_channels=1) pretrained = torch.load(weights_path, map_location="cpu") model.load_state_dict(pretrained) - output = model.infer_batch(model, batch, on_gpu=ON_GPU) + output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU)) postproc_masks = model.postproc( output, do_reconstruction=True, From 40846ffb933a62610cc8eff60c10a9d5701af184 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 22 Sep 2023 13:21:42 +0100 Subject: [PATCH 081/112] :rewind: Revert back to pretrained_model --- tiatoolbox/models/architecture/__init__.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tiatoolbox/models/architecture/__init__.py b/tiatoolbox/models/architecture/__init__.py index b7ddeb62d..7776cdb60 100644 --- a/tiatoolbox/models/architecture/__init__.py +++ b/tiatoolbox/models/architecture/__init__.py @@ -58,7 +58,7 @@ def fetch_pretrained_weights( def get_pretrained_model( - requested_model: str | None = None, + pretrained_model: str | None = None, pretrained_weights: str | Path | None = None, *, overwrite: bool = False, @@ -66,7 +66,7 @@ def get_pretrained_model( """Load a predefined PyTorch model with the appropriate pretrained weights. Args: - requested_model (str): + pretrained_model (str): Name of the existing models support by tiatoolbox for processing the data. The models currently supported: @@ -106,25 +106,25 @@ def get_pretrained_model( Examples: >>> # get mobilenet pretrained on Kather100K dataset by the TIA team - >>> model = get_pretrained_model(requested_model='mobilenet_v2-kather100k') + >>> model = get_pretrained_model(pretrained_model='mobilenet_v2-kather100k') >>> # get mobilenet defined by TIA team, but loaded with user defined weights >>> model = get_pretrained_model( - ... requested_model='mobilenet_v2-kather100k', + ... pretrained_model='mobilenet_v2-kather100k', ... pretrained_weights='/A/B/C/my_weights.tar', ... ) >>> # get resnet34 pretrained on PCam dataset by TIA team - >>> model = get_pretrained_model(requested_model='resnet34-pcam') + >>> model = get_pretrained_model(pretrained_model='resnet34-pcam') """ - if not isinstance(requested_model, str): + if not isinstance(pretrained_model, str): msg = "pretrained_model must be a string." raise TypeError(msg) - if requested_model not in PRETRAINED_INFO: - msg = f"Pretrained model `{requested_model}` does not exist." + if pretrained_model not in PRETRAINED_INFO: + msg = f"Pretrained model `{pretrained_model}` does not exist." raise ValueError(msg) - info = PRETRAINED_INFO[requested_model] + info = PRETRAINED_INFO[pretrained_model] arch_info = info["architecture"] creator = locate(f"tiatoolbox.models.architecture.{arch_info['class']}") @@ -138,7 +138,7 @@ def get_pretrained_model( if pretrained_weights is None: pretrained_weights = fetch_pretrained_weights( - requested_model, + pretrained_model, overwrite=overwrite, ) From c513ddc82c4714289154f60f54f3c1de66c9f320 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 22 Sep 2023 13:43:25 +0100 Subject: [PATCH 082/112] :bug: Fix issues with on_gpu --- tests/models/test_arch_sccnn.py | 11 ++++++++--- tests/models/test_arch_unet.py | 5 +++-- tests/models/test_arch_vanilla.py | 3 ++- tests/models/test_hovernet.py | 9 +++++---- tests/models/test_hovernetplus.py | 3 ++- 5 files changed, 20 insertions(+), 11 deletions(-) diff --git a/tests/models/test_arch_sccnn.py b/tests/models/test_arch_sccnn.py index bdec99e0b..7a809432c 100644 --- a/tests/models/test_arch_sccnn.py +++ b/tests/models/test_arch_sccnn.py @@ -4,9 +4,10 @@ import numpy as np import torch -from tiatoolbox import utils from tiatoolbox.models import SCCNN from tiatoolbox.models.architecture import fetch_pretrained_weights +from tiatoolbox.utils import env_detection +from tiatoolbox.utils.misc import select_device from tiatoolbox.wsicore.wsireader import WSIReader @@ -14,7 +15,7 @@ def _load_sccnn(name: str) -> torch.nn.Module: """Loads SCCNN model with specified weights.""" model = SCCNN() weights_path = fetch_pretrained_weights(name) - map_location = utils.misc.select_device(on_gpu=utils.env_detection.has_gpu()) + map_location = select_device(on_gpu=env_detection.has_gpu()) pretrained = torch.load(weights_path, map_location=map_location) model.load_state_dict(pretrained) @@ -44,6 +45,10 @@ def test_functionality(remote_sample: Callable) -> None: assert np.all(output == [[8, 7]]) model = _load_sccnn(name="sccnn-conic") - output = model.infer_batch(model, batch, on_gpu=False) + output = model.infer_batch( + model, + batch, + device=select_device(on_gpu=env_detection.has_gpu()), + ) output = model.postproc(output[0]) assert np.all(output == [[7, 8]]) diff --git a/tests/models/test_arch_unet.py b/tests/models/test_arch_unet.py index f15a5dc71..69496c7aa 100644 --- a/tests/models/test_arch_unet.py +++ b/tests/models/test_arch_unet.py @@ -8,6 +8,7 @@ from tiatoolbox.models.architecture import fetch_pretrained_weights from tiatoolbox.models.architecture.unet import UNetModel +from tiatoolbox.utils.misc import select_device from tiatoolbox.wsicore.wsireader import WSIReader ON_GPU = False @@ -47,7 +48,7 @@ def test_functional_unet(remote_sample: Callable) -> None: model = UNetModel(3, 2, encoder="resnet50", decoder_block=[3]) pretrained = torch.load(pretrained_weights, map_location="cpu") model.load_state_dict(pretrained) - output = model.infer_batch(model, batch, on_gpu=ON_GPU) + output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU)) _ = output[0] # run untrained network to test for architecture @@ -59,4 +60,4 @@ def test_functional_unet(remote_sample: Callable) -> None: encoder_levels=[32, 64], skip_type="concat", ) - _ = model.infer_batch(model, batch, on_gpu=ON_GPU) + _ = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU)) diff --git a/tests/models/test_arch_vanilla.py b/tests/models/test_arch_vanilla.py index 788144034..cfae665b2 100644 --- a/tests/models/test_arch_vanilla.py +++ b/tests/models/test_arch_vanilla.py @@ -6,6 +6,7 @@ from tiatoolbox.models.architecture.vanilla import CNNModel from tiatoolbox.models.models_abc import model_to +from tiatoolbox.utils.misc import select_device ON_GPU = False RNG = np.random.default_rng() # Numpy Random Generator @@ -45,7 +46,7 @@ def test_functional() -> None: for backbone in backbones: model = CNNModel(backbone, num_classes=1) model_ = model_to(device=device, model=model) - model.infer_batch(model_, samples, on_gpu=ON_GPU) + model.infer_batch(model_, samples, device=select_device(on_gpu=ON_GPU)) except ValueError as exc: msg = f"Model {backbone} failed." raise AssertionError(msg) from exc diff --git a/tests/models/test_hovernet.py b/tests/models/test_hovernet.py index bf77b46ba..dcf2251ac 100644 --- a/tests/models/test_hovernet.py +++ b/tests/models/test_hovernet.py @@ -14,6 +14,7 @@ ResidualBlock, TFSamepaddingLayer, ) +from tiatoolbox.utils.misc import select_device from tiatoolbox.wsicore.wsireader import WSIReader @@ -34,7 +35,7 @@ def test_functionality(remote_sample: Callable) -> None: weights_path = fetch_pretrained_weights("hovernet_fast-pannuke") pretrained = torch.load(weights_path) model.load_state_dict(pretrained) - output = model.infer_batch(model, batch, on_gpu=False) + output = model.infer_batch(model, batch, device=select_device(on_gpu=False)) output = [v[0] for v in output] output = model.postproc(output) assert len(output[1]) > 0, "Must have some nuclei." @@ -51,7 +52,7 @@ def test_functionality(remote_sample: Callable) -> None: weights_path = fetch_pretrained_weights("hovernet_fast-monusac") pretrained = torch.load(weights_path) model.load_state_dict(pretrained) - output = model.infer_batch(model, batch, on_gpu=False) + output = model.infer_batch(model, batch, device=select_device(on_gpu=False)) output = [v[0] for v in output] output = model.postproc(output) assert len(output[1]) > 0, "Must have some nuclei." @@ -68,7 +69,7 @@ def test_functionality(remote_sample: Callable) -> None: weights_path = fetch_pretrained_weights("hovernet_original-consep") pretrained = torch.load(weights_path) model.load_state_dict(pretrained) - output = model.infer_batch(model, batch, on_gpu=False) + output = model.infer_batch(model, batch, device=select_device(on_gpu=False)) output = [v[0] for v in output] output = model.postproc(output) assert len(output[1]) > 0, "Must have some nuclei." @@ -85,7 +86,7 @@ def test_functionality(remote_sample: Callable) -> None: weights_path = fetch_pretrained_weights("hovernet_original-kumar") pretrained = torch.load(weights_path) model.load_state_dict(pretrained) - output = model.infer_batch(model, batch, on_gpu=False) + output = model.infer_batch(model, batch, device=select_device(on_gpu=False)) output = [v[0] for v in output] output = model.postproc(output) assert len(output[1]) > 0, "Must have some nuclei." diff --git a/tests/models/test_hovernetplus.py b/tests/models/test_hovernetplus.py index 96d0f9d23..1377fdd82 100644 --- a/tests/models/test_hovernetplus.py +++ b/tests/models/test_hovernetplus.py @@ -7,6 +7,7 @@ from tiatoolbox.models import HoVerNetPlus from tiatoolbox.models.architecture import fetch_pretrained_weights from tiatoolbox.utils import imread +from tiatoolbox.utils.misc import select_device from tiatoolbox.utils.transforms import imresize @@ -28,7 +29,7 @@ def test_functionality(remote_sample: Callable) -> None: weights_path = fetch_pretrained_weights("hovernetplus-oed") pretrained = torch.load(weights_path) model.load_state_dict(pretrained) - output = model.infer_batch(model, batch, on_gpu=False) + output = model.infer_batch(model, batch, device=select_device(on_gpu=False)) assert len(output) == 4, "Must contain predictions for: np, hv, tp and ls branches." output = [v[0] for v in output] output = model.postproc(output) From e829967ba5ba68cfc0e3d320c9097cf26be70af4 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Tue, 26 Sep 2023 10:56:42 +0100 Subject: [PATCH 083/112] :bug: Fix issues with on_gpu --- tests/models/test_arch_sccnn.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/models/test_arch_sccnn.py b/tests/models/test_arch_sccnn.py index 7a809432c..53e05545a 100644 --- a/tests/models/test_arch_sccnn.py +++ b/tests/models/test_arch_sccnn.py @@ -40,7 +40,11 @@ def test_functionality(remote_sample: Callable) -> None: ) batch = torch.from_numpy(patch)[None] model = _load_sccnn(name="sccnn-crchisto") - output = model.infer_batch(model, batch, on_gpu=False) + output = model.infer_batch( + model, + batch, + select_device(on_gpu=env_detection.has_gpu()), + ) output = model.postproc(output[0]) assert np.all(output == [[8, 7]]) From 52c441516737bbd394a85c0f8079299a0adc7b88 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Tue, 26 Sep 2023 11:18:26 +0100 Subject: [PATCH 084/112] :bug: Fix infer_batch() takes 2 positional arguments but 3 were given Signed-off-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> --- tests/models/test_arch_sccnn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/test_arch_sccnn.py b/tests/models/test_arch_sccnn.py index 53e05545a..58d3f67d0 100644 --- a/tests/models/test_arch_sccnn.py +++ b/tests/models/test_arch_sccnn.py @@ -43,7 +43,7 @@ def test_functionality(remote_sample: Callable) -> None: output = model.infer_batch( model, batch, - select_device(on_gpu=env_detection.has_gpu()), + device=select_device(on_gpu=env_detection.has_gpu()), ) output = model.postproc(output[0]) assert np.all(output == [[8, 7]]) From 263d869e3a65a273a2a7574d6645a27d290e604c Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Tue, 10 Oct 2023 15:42:38 +0100 Subject: [PATCH 085/112] :bug: Fix flake8-annotations Signed-off-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> --- tests/test_utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 478cde2be..f4f81ac92 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1629,7 +1629,7 @@ def test_imwrite(tmp_path: Path) -> NoReturn: ) -def test_patch_pred_store(): +def test_patch_pred_store() -> None: """Test patch_pred_store.""" # Define a mock patch_output patch_output = { @@ -1654,7 +1654,7 @@ def test_patch_pred_store(): misc.patch_pred_store(patch_output, (1.0, 1.0)) -def test_patch_pred_store_cdict(): +def test_patch_pred_store_cdict() -> None: """Test patch_pred_store with a class dict.""" # Define a mock patch_output patch_output = { @@ -1667,7 +1667,7 @@ def test_patch_pred_store_cdict(): class_dict = {0: "class0", 1: "class1"} store = misc.patch_pred_store(patch_output, (1.0, 1.0), class_dict=class_dict) - # Check that its an SQLiteStore containing the expected annotations + # Check that it's an SQLiteStore containing the expected annotations assert isinstance(store, SQLiteStore) assert len(store) == 3 for annotation in store.values(): @@ -1677,7 +1677,7 @@ def test_patch_pred_store_cdict(): assert "other" not in annotation.properties -def test_patch_pred_store_sf(): +def test_patch_pred_store_sf() -> None: """Test patch_pred_store with scale factor.""" # Define a mock patch_output patch_output = { @@ -1688,7 +1688,7 @@ def test_patch_pred_store_sf(): } store = misc.patch_pred_store(patch_output, (2.0, 2.0)) - # Check that its an SQLiteStore containing the expected annotations + # Check that it's an SQLiteStore containing the expected annotations assert isinstance(store, SQLiteStore) assert len(store) == 3 for annotation in store.values(): From 0358b99093f15d180a614c49db5cf03fd119ca6e Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Tue, 10 Oct 2023 15:57:39 +0100 Subject: [PATCH 086/112] :bug: Fix flake8-annotations for multi_task_segmentor.py Signed-off-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> --- .../models/engine/multi_task_segmentor.py | 57 +++++++++++-------- 1 file changed, 32 insertions(+), 25 deletions(-) diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 7fbae30bb..287fc456b 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -39,32 +39,34 @@ if TYPE_CHECKING: # pragma: no cover import torch - from .io_config import IOInstanceSegmentorConfig + from tiatoolbox.typing import IntBounds + + from .io_config import IOInstanceSegmentorConfig, IOSegmentorConfig # Python is yet to be able to natively pickle Object method/static method. # Only top-level function is passable to multi-processing as caller. # May need 3rd party libraries to use method/static method otherwise. def _process_tile_predictions( - ioconfig, - tile_bounds, - tile_flag, - tile_mode, - tile_output, + ioconfig: IOSegmentorConfig, + tile_bounds: IntBounds, + tile_flag: list, + tile_mode: int, + tile_output: list, # this would be replaced by annotation store # in the future - ref_inst_dict, - postproc, - merge_predictions, - model_name, -): + ref_inst_dict: dict, + postproc: Callable, + merge_predictions: Callable, + model_name: str, +) -> tuple: """Process Tile Predictions. Function to merge new tile prediction with existing prediction, using the output from each task. Args: - ioconfig (:class:`IOInstanceSegmentorConfig`): Object defines information + ioconfig (:class:`IOSegmentorConfig`): Object defines information about input and output placement of patches. tile_bounds (:class:`numpy.array`): Boundary of the current tile, defined as (top_left_x, top_left_y, bottom_x, bottom_y). @@ -239,7 +241,7 @@ class MultiTaskSegmentor(NucleusInstanceSegmentor): """ def __init__( # noqa: PLR0913 - self, + self: MultiTaskSegmentor, batch_size: int = 8, num_loader_workers: int = 0, num_postproc_workers: int = 0, @@ -286,12 +288,12 @@ def __init__( # noqa: PLR0913 ) def _predict_one_wsi( - self, + self: MultiTaskSegmentor, wsi_idx: int, ioconfig: IOInstanceSegmentorConfig, save_path: str, mode: str, - ): + ) -> None: """Make a prediction on tile/wsi. Args: @@ -393,13 +395,13 @@ def _predict_one_wsi( # may need to chain it with parents def _process_tile_predictions( - self, - ioconfig, - tile_bounds, - tile_flag, - tile_mode, - tile_output, - ): + self: MultiTaskSegmentor, + ioconfig: IOSegmentorConfig, + tile_bounds: IntBounds, + tile_flag: list, + tile_mode: int, + tile_output: list, + ) -> None: """Function to dispatch parallel post processing.""" args = [ ioconfig, @@ -418,10 +420,15 @@ def _process_tile_predictions( future = _process_tile_predictions(*args) self._futures.append(future) - def _merge_post_process_results(self): + def _merge_post_process_results(self: MultiTaskSegmentor) -> None: """Helper to aggregate results from parallel workers.""" - def callback(new_inst_dicts, remove_uuid_lists, tiles, bounds): + def callback( + new_inst_dicts: dict, + remove_uuid_lists: list, + tiles: dict, + bounds: IntBounds, + ) -> None: """Helper to aggregate worker's results.""" # ! DEPRECATION: # ! will be deprecated upon finalization of SQL annotation store @@ -444,7 +451,7 @@ def callback(new_inst_dicts, remove_uuid_lists, tiles, bounds): callback(*future) continue # some errors happen, log it and propagate exception - # ! this will lead to discard a bunch of + # ! this will lead to discard a whole bunch of # ! inferred tiles within this current WSI if future.exception() is not None: raise future.exception() # noqa: RSE102 From 040cb7c3e13d72f064e5f43d2ba16a885769ef4f Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Tue, 10 Oct 2023 16:04:51 +0100 Subject: [PATCH 087/112] :bug: Fix flake8-annotations for nucleus_instance_segmentor.py Signed-off-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> --- .../engine/nucleus_instance_segmentor.py | 85 ++++++++++--------- 1 file changed, 45 insertions(+), 40 deletions(-) diff --git a/tiatoolbox/models/engine/nucleus_instance_segmentor.py b/tiatoolbox/models/engine/nucleus_instance_segmentor.py index 4156e2c2a..d59e209da 100644 --- a/tiatoolbox/models/engine/nucleus_instance_segmentor.py +++ b/tiatoolbox/models/engine/nucleus_instance_segmentor.py @@ -18,18 +18,18 @@ from tiatoolbox.tools.patchextraction import PatchExtractor if TYPE_CHECKING: # pragma: no cover - from .io_config import IOInstanceSegmentorConfig + from .io_config import IOInstanceSegmentorConfig, IOSegmentorConfig def _process_instance_predictions( - inst_dict, - ioconfig, - tile_shape, - tile_flag, - tile_mode, - tile_tl, - ref_inst_dict, -): + inst_dict: dict, + ioconfig: IOSegmentorConfig, + tile_shape: list, + tile_flag: list, + tile_mode: int, + tile_tl: tuple, + ref_inst_dict: dict, +) -> list | tuple: """Function to merge new tile prediction with existing prediction. Args: @@ -50,12 +50,12 @@ def _process_instance_predictions( an overlapping tile from tile generation. The predicted instances are immediately added to accumulated output. - 1: Vertical tile strip that stands between two normal tiles - (flag 0). It has the the same height as normal tile but + (flag 0). It has the same height as normal tile but less width (hence vertical strip). - 2: Horizontal tile strip that stands between two normal tiles - (flag 0). It has the the same width as normal tile but + (flag 0). It has the same width as normal tile but less height (hence horizontal strip). - - 3: tile strip stands at the cross section of four normal tiles + - 3: tile strip stands at the cross-section of four normal tiles (flag 0). tile_tl (tuple): Top left coordinates of the current tile. ref_inst_dict (dict): Dictionary contains accumulated output. The @@ -144,7 +144,7 @@ def _process_instance_predictions( msg = f"Unknown tile mode {tile_mode}." raise ValueError(msg) - def retrieve_sel_uids(sel_indices, inst_dict): + def retrieve_sel_uids(sel_indices: list, inst_dict: dict) -> list: """Helper to retrieved selected instance uids.""" if len(sel_indices) > 0: # not sure how costly this is in large dict @@ -153,7 +153,7 @@ def retrieve_sel_uids(sel_indices, inst_dict): remove_insts_in_tile = retrieve_sel_uids(sel_indices, inst_dict) - # external removal only for tile at cross sections + # external removal only for tile at cross-sections # this one should contain UUID with the reference database remove_insts_in_orig = [] if tile_mode == 3: # noqa: PLR2004 @@ -186,17 +186,17 @@ def retrieve_sel_uids(sel_indices, inst_dict): # caller. May need 3rd party libraries to use method/static method # otherwise. def _process_tile_predictions( - ioconfig, - tile_bounds, - tile_flag, - tile_mode, - tile_output, + ioconfig: IOSegmentorConfig, + tile_bounds: np.ndarray, + tile_flag: list, + tile_mode: int, + tile_output: list, # this would be replaced by annotation store # in the future - ref_inst_dict, - postproc, - merge_predictions, -): + ref_inst_dict: dict, + postproc: Callable, + merge_predictions: Callable, +) -> tuple[dict, list]: """Function to merge new tile prediction with existing prediction. Args: @@ -368,7 +368,7 @@ class NucleusInstanceSegmentor(SemanticSegmentor): """ def __init__( - self, + self: NucleusInstanceSegmentor, batch_size: int = 8, num_loader_workers: int = 0, num_postproc_workers: int = 0, @@ -406,7 +406,7 @@ def __init__( def _get_tile_info( image_shape: list[int] | np.ndarray, ioconfig: IOInstanceSegmentorConfig, - ): + ) -> list[list, ...]: """Generating tile information. To avoid out of memory problem when processing WSI-scale in @@ -467,7 +467,7 @@ def _get_tile_info( # * remove all sides for boxes # unset for those lie within the selection - def unset_removal_flag(boxes, removal_flag): + def unset_removal_flag(boxes: tuple, removal_flag: np.ndarray) -> np.ndarray: """Unset removal flags for tiles intersecting image boundaries.""" sel_boxes = [ shapely_box(0, 0, w, 0), # top edge @@ -581,7 +581,12 @@ def unset_removal_flag(boxes, removal_flag): return info - def _to_shared_space(self, wsi_idx, patch_inputs, patch_outputs): + def _to_shared_space( + self: NucleusInstanceSegmentor, + wsi_idx: int, + patch_inputs: list, + patch_outputs: list, + ) -> None: """Helper functions to transfer variable to shared space. We modify the shared space so that we can update worker info @@ -613,7 +618,7 @@ def _to_shared_space(self, wsi_idx, patch_inputs, patch_outputs): self._mp_shared_space.patch_outputs = patch_outputs self._mp_shared_space.wsi_idx = torch.Tensor([wsi_idx]).share_memory_() - def _infer_once(self): + def _infer_once(self: NucleusInstanceSegmentor) -> list: """Running the inference only once for the currently active dataloader.""" num_steps = len(self._loader) @@ -658,12 +663,12 @@ def _infer_once(self): return cum_output def _predict_one_wsi( - self, + self: NucleusInstanceSegmentor, wsi_idx: int, - ioconfig: IOInstanceSegmentorConfig, + ioconfig: IOSegmentorConfig, save_path: str, mode: str, - ): + ) -> None: """Make a prediction on tile/wsi. Args: @@ -751,13 +756,13 @@ def _predict_one_wsi( self._wsi_inst_info = None # clean up def _process_tile_predictions( - self, - ioconfig, - tile_bounds, - tile_flag, - tile_mode, - tile_output, - ): + self: NucleusInstanceSegmentor, + ioconfig: IOSegmentorConfig, + tile_bounds: np.ndarray, + tile_flag: list, + tile_mode: int, + tile_output: list, + ) -> None: """Function to dispatch parallel post processing.""" args = [ ioconfig, @@ -775,10 +780,10 @@ def _process_tile_predictions( future = _process_tile_predictions(*args) self._futures.append(future) - def _merge_post_process_results(self): + def _merge_post_process_results(self: NucleusInstanceSegmentor) -> None: """Helper to aggregate results from parallel workers.""" - def callback(new_inst_dict, remove_uuid_list): + def callback(new_inst_dict: dict, remove_uuid_list: list) -> None: """Helper to aggregate worker's results.""" # ! DEPRECATION: # ! will be deprecated upon finalization of SQL annotation store From b99314692f93032cb7bb85212427aa50593372ca Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Tue, 10 Oct 2023 16:20:32 +0100 Subject: [PATCH 088/112] :bug: Fix flake8-annotations for semantic_segmentor.py Signed-off-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> --- .../models/engine/semantic_segmentor.py | 142 +++++++++--------- 1 file changed, 73 insertions(+), 69 deletions(-) diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index 8ea14e447..237d032f1 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -27,10 +27,13 @@ from .io_config import IOSegmentorConfig if TYPE_CHECKING: # pragma: no cover - from tiatoolbox.typing import Resolution, Units + from tiatoolbox.typing import IntPair, Resolution, Units -def _estimate_canvas_parameters(sample_prediction, canvas_shape): +def _estimate_canvas_parameters( + sample_prediction: np.ndarray, + canvas_shape: np.ndarray, +) -> tuple[tuple, tuple, bool]: """Estimates canvas parameters. Args: @@ -58,11 +61,11 @@ def _estimate_canvas_parameters(sample_prediction, canvas_shape): def _prepare_save_output( - save_path, - cache_count_path, - canvas_cum_shape_, - canvas_count_shape_, -): + save_path: str | Path, + cache_count_path: str | Path, + canvas_cum_shape_: tuple[int, ...], + canvas_count_shape_: tuple[int, ...], +) -> tuple: """Prepares for saving the cached output.""" if save_path is not None: save_path = Path(save_path) @@ -193,7 +196,7 @@ class SemanticSegmentor: """ def __init__( - self, + self: SemanticSegmentor, batch_size: int = 8, num_loader_workers: int = 0, num_postproc_workers: int = 0, @@ -251,7 +254,7 @@ def __init__( def get_coordinates( image_shape: list[int] | np.ndarray, ioconfig: IOSegmentorConfig, - ): + ) -> tuple[list, list]: """Calculate patch tiling coordinates. By default, internally, it will call the @@ -309,7 +312,7 @@ def filter_coordinates( bounds: np.ndarray, resolution: Resolution | None = None, units: Units | None = None, - ): + ) -> np.ndarray: """Indicates which coordinate is valid basing on the mask. To use your own approaches, either subclass to overwrite or @@ -369,7 +372,7 @@ def filter_coordinates( scale_factor = mask_real_shape / mask_resolution_shape scale_factor = scale_factor[0] # what if ratio x != y - def sel_func(coord: np.ndarray): + def sel_func(coord: np.ndarray) -> bool: """Accept coord as long as its box contains part of mask.""" coord_in_real_mask = np.ceil(scale_factor * coord).astype(np.int32) start_x, start_y, end_x, end_y = coord_in_real_mask @@ -386,7 +389,7 @@ def get_reader( mode: str, *, auto_get_mask: bool, - ): + ) -> tuple[WSIReader, WSIReader]: """Define how to get reader for mask and source image.""" img_path = Path(img_path) reader = WSIReader.open(img_path) @@ -411,12 +414,12 @@ def get_reader( return reader, mask_reader def _predict_one_wsi( - self, + self: SemanticSegmentor, wsi_idx: int, ioconfig: IOSegmentorConfig, save_path: str, mode: str, - ): + ) -> None: """Make a prediction on tile/wsi. Args: @@ -527,13 +530,13 @@ def _predict_one_wsi( shutil.rmtree(cache_dir) def _process_predictions( - self, + self: SemanticSegmentor, cum_batch_predictions: list, wsi_reader: WSIReader, ioconfig: IOSegmentorConfig, save_path: str, cache_dir: str, - ): + ) -> None: """Define how the aggregated predictions are processed. This includes merging the prediction if necessary and also saving afterwards. @@ -595,7 +598,7 @@ def merge_prediction( locations: list | np.ndarray, save_path: str | Path | None = None, cache_count_path: str | Path | None = None, - ): + ) -> np.ndarray: """Merge patch-level predictions to form a 2-dimensional prediction map. When accumulating the raw prediction onto a same canvas (via @@ -665,7 +668,7 @@ def merge_prediction( canvas_count_shape_, ) - def index(arr, tl, br): + def index(arr: np.ndarray, tl: np.ndarray, br: np.ndarray) -> np.ndarray: """Helper to shorten indexing.""" return arr[tl[0] : br[0], tl[1] : br[1]] @@ -726,7 +729,7 @@ def index(arr, tl, br): return cum_canvas @staticmethod - def _prepare_save_dir(save_dir): + def _prepare_save_dir(save_dir: str | Path | None) -> tuple[Path, Path]: """Prepare save directory and cache.""" if save_dir is None: logger.warning( @@ -749,14 +752,14 @@ def _prepare_save_dir(save_dir): @staticmethod def _update_ioconfig( - ioconfig, - mode, - patch_input_shape, - patch_output_shape, - stride_shape, - resolution, - units, - ): + ioconfig: IOSegmentorConfig, + mode: str, + patch_input_shape: IntPair, + patch_output_shape: IntPair, + stride_shape: IntPair, + resolution: Resolution, + units: Units, + ) -> IOSegmentorConfig: """Update ioconfig according to input parameters. Args: @@ -815,7 +818,7 @@ def _update_ioconfig( return ioconfig - def _prepare_workers(self): + def _prepare_workers(self: SemanticSegmentor) -> None: """Prepare number of workers.""" self._postproc_workers = None if self.num_postproc_workers is not None: @@ -823,7 +826,7 @@ def _prepare_workers(self): max_workers=self.num_postproc_workers, ) - def _memory_cleanup(self): + def _memory_cleanup(self: SemanticSegmentor) -> None: """Memory clean up.""" self.imgs = None self.masks = None @@ -838,15 +841,16 @@ def _memory_cleanup(self): self._postproc_workers = None def _predict_wsi_handle_exception( - self, - imgs, - wsi_idx, - img_path, - mode, - ioconfig, - save_dir, - crash_on_exception, - ): + self: SemanticSegmentor, + imgs: list, + wsi_idx: int, + img_path: str | Path, + mode: str, + ioconfig: IOSegmentorConfig, + save_dir: str | Path, + *, + crash_on_exception: bool, + ) -> None: """Predict on multiple WSIs. Args: @@ -916,21 +920,21 @@ def _predict_wsi_handle_exception( logging.exception("Crashed on %s", wsi_save_path) def predict( # noqa: PLR0913 - self, - imgs, - masks=None, - mode="tile", - ioconfig=None, - patch_input_shape=None, - patch_output_shape=None, - stride_shape=None, - resolution=None, - units=None, - save_dir=None, + self: SemanticSegmentor, + imgs: list, + masks: list | None = None, + mode: str = "tile", + ioconfig: IOSegmentorConfig = None, + patch_input_shape: IntPair = None, + patch_output_shape: IntPair = None, + stride_shape: IntPair = None, + resolution: Resolution = 1.0, + units: Units = "baseline", + save_dir: str | Path | None = None, *, - device="cpu", - crash_on_exception=False, - ): + device: str = "cpu", + crash_on_exception: bool = False, + ) -> list[tuple[Path, Path]]: """Make a prediction for a list of input data. By default, if the input model at the object instantiation time @@ -1170,7 +1174,7 @@ class DeepFeatureExtractor(SemanticSegmentor): """ def __init__( - self, + self: DeepFeatureExtractor, batch_size: int = 8, num_loader_workers: int = 0, num_postproc_workers: int = 0, @@ -1197,13 +1201,13 @@ def __init__( self.process_prediction_per_batch = False def _process_predictions( - self, + self: DeepFeatureExtractor, cum_batch_predictions: list, wsi_reader: WSIReader, # skipcq: PYL-W0613 # noqa: ARG002 ioconfig: IOSegmentorConfig, save_path: str, cache_dir: str, # skipcq: PYL-W0613 # noqa: ARG002 - ): + ) -> None: """Define how the aggregated predictions are processed. This includes merging the prediction if necessary and also @@ -1241,21 +1245,21 @@ def _process_predictions( np.save(f"{save_path}.features.{idx}.npy", prediction_list) def predict( # noqa: PLR0913 - self, - imgs, - masks=None, - mode="tile", - ioconfig=None, - patch_input_shape=None, - patch_output_shape=None, - stride_shape=None, - resolution=1.0, - units="baseline", - save_dir=None, + self: DeepFeatureExtractor, + imgs: list, + masks: list | None = None, + mode: str = "tile", + ioconfig: IOSegmentorConfig | None = None, + patch_input_shape: IntPair | None = None, + patch_output_shape: IntPair | None = None, + stride_shape: IntPair = None, + resolution: Resolution = 1.0, + units: Units = "baseline", + save_dir: str | Path | None = None, *, - device=True, - crash_on_exception=False, - ): + device: str = "cpu", + crash_on_exception: bool = False, + ) -> list[tuple[Path, Path]]: """Make a prediction for a list of input data. By default, if the input model at the time of object From 9c68cc34eee5682ece10f3e80865909e6204226d Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Tue, 10 Oct 2023 16:31:29 +0100 Subject: [PATCH 089/112] :bug: Fix flake8-annotations for patch_predictor.py Signed-off-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> --- tiatoolbox/models/engine/engine_abc.py | 2 +- .../engine/nucleus_instance_segmentor.py | 2 +- tiatoolbox/models/engine/patch_predictor.py | 144 +++++++++++------- 3 files changed, 94 insertions(+), 54 deletions(-) diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 95a14aa54..83932d74a 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -475,7 +475,7 @@ def run( overwrite: bool = False, output_type: str = "dict", **kwargs: dict, - ) -> AnnotationStore | np.ndarray | pd.DataFrame | dict | str: + ) -> AnnotationStore | str: """Run the engine on input images. Args: diff --git a/tiatoolbox/models/engine/nucleus_instance_segmentor.py b/tiatoolbox/models/engine/nucleus_instance_segmentor.py index d59e209da..9aac3b8f5 100644 --- a/tiatoolbox/models/engine/nucleus_instance_segmentor.py +++ b/tiatoolbox/models/engine/nucleus_instance_segmentor.py @@ -645,7 +645,7 @@ def _infer_once(self: NucleusInstanceSegmentor) -> list: sample_outputs = self.model.infer_batch( self._model, sample_datas, - on_gpu=self._on_gpu, + device=self._device, ) # repackage so that it's a N list, each contains # L x etc. output diff --git a/tiatoolbox/models/engine/patch_predictor.py b/tiatoolbox/models/engine/patch_predictor.py index 7b03fd975..3092f827b 100644 --- a/tiatoolbox/models/engine/patch_predictor.py +++ b/tiatoolbox/models/engine/patch_predictor.py @@ -17,7 +17,12 @@ from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIReader if TYPE_CHECKING: # pragma: no cover - from tiatoolbox.typing import Resolution, Units + import os + + from tiatoolbox.annotation import AnnotationStore + from tiatoolbox.typing import IntPair, Resolution, Units + + from .io_config import ModelIOConfigABC from .engine_abc import EngineABC from .io_config import IOPatchPredictorConfig @@ -201,7 +206,7 @@ class PatchPredictor(EngineABC): """ def __init__( - self, + self: PatchPredictor, batch_size: int = 8, num_loader_workers: int = 0, num_post_proc_workers: int = 0, @@ -209,7 +214,7 @@ def __init__( pretrained_model: str | None = None, weights: str | None = None, *, - verbose=True, + verbose: bool = True, ) -> None: """Initialize :class:`PatchPredictor`.""" super().__init__( @@ -230,7 +235,11 @@ def infer_wsi(self: PatchPredictor) -> NoReturn: """Model inference on a WSI.""" ... - def post_process_patches(self: PatchPredictor) -> NoReturn: + def post_process_patches( + self: PatchPredictor, + raw_predictions: dict, + output_type: str, + ) -> None: """Post-process an image patch.""" ... @@ -247,7 +256,7 @@ def merge_predictions( post_proc_func: Callable | None = None, *, return_raw: bool = False, - ): + ) -> np.ndarray: """Merge patch level predictions to form a 2-dimensional prediction map. #! Improve how the below reads. @@ -357,14 +366,14 @@ def merge_predictions( return output def _predict_engine( - self, - dataset, + self: PatchPredictor, + dataset: torch.utils.data.Dataset, *, - return_probabilities=False, - return_labels=False, - return_coordinates=False, - device="cpu", - ): + return_probabilities: bool = False, + return_labels: bool = False, + return_coordinates: bool = False, + device: str = "cpu", + ) -> np.ndarray: """Make a prediction on a dataset. The dataset may be mutated. Args: @@ -450,13 +459,13 @@ def _predict_engine( return cum_output def _update_ioconfig( - self, - ioconfig, - patch_input_shape, - stride_shape, - resolution, - units, - ): + self: PatchPredictor, + ioconfig: IOPatchPredictorConfig, + patch_input_shape: IntPair, + stride_shape: IntPair, + resolution: Resolution, + units: Units, + ) -> IOPatchPredictorConfig: """Update the ioconfig. Args: @@ -522,7 +531,15 @@ def _update_ioconfig( output_resolutions=[], ) - def _predict_patch(self, imgs, labels, return_probabilities, return_labels, on_gpu): + def _predict_patch( + self: PatchPredictor, + imgs: list | np.ndarray, + labels: list, + *, + return_probabilities: bool, + return_labels: bool, + device: str, + ) -> np.ndarray: """Process patch mode. Args: @@ -540,8 +557,8 @@ def _predict_patch(self, imgs, labels, return_probabilities, return_labels, on_g Whether to return per-class probabilities. return_labels (bool): Whether to return the labels with the predictions. - on_gpu (bool): - Whether to run model on the GPU. + device (str): + Select the device to run the engine. Returns: :class:`numpy.ndarray`: @@ -566,23 +583,24 @@ def _predict_patch(self, imgs, labels, return_probabilities, return_labels, on_g return_probabilities=return_probabilities, return_labels=return_labels, return_coordinates=return_coordinates, - on_gpu=on_gpu, + device=device, ) def _predict_tile_wsi( # noqa: PLR0913 - self, - imgs, - masks, - labels, - mode, - return_probabilities, - on_gpu, - ioconfig, - merge_predictions, - save_dir, - save_output, - highest_input_resolution, - ): + self: PatchPredictor, + imgs: list, + masks: list | None, + labels: list, + mode: str, + ioconfig: IOPatchPredictorConfig, + save_dir: str | Path, + highest_input_resolution: list[dict], + *, + save_output: bool, + return_probabilities: bool, + merge_predictions: bool, + on_gpu: bool, + ) -> list | dict: """Predict on Tile and WSIs. Args: @@ -714,29 +732,51 @@ def _predict_tile_wsi( # noqa: PLR0913 return file_dict if save_output else outputs - def run(self): + def run( + self: EngineABC, + images: list[os | Path | WSIReader] | np.ndarray, + masks: list[os | Path] | np.ndarray | None = None, + labels: list | None = None, + ioconfig: ModelIOConfigABC | None = None, + *, + patch_mode: bool = True, + save_dir: os | Path | None = None, # None will not save output + overwrite: bool = False, + output_type: str = "dict", + **kwargs: dict, + ) -> AnnotationStore | str: """Run engine.""" - super().run() + super().run( + images=images, + masks=masks, + labels=labels, + ioconfig=ioconfig, + patch_mode=patch_mode, + save_dir=save_dir, + overwrite=overwrite, + output_type=output_type, + **kwargs, + ) def predict( # noqa: PLR0913 - self, - imgs, - masks=None, - labels=None, - mode="patch", + self: PatchPredictor, + imgs: list, + masks: list | None = None, + labels: list | None = None, + mode: str = "patch", ioconfig: IOPatchPredictorConfig | None = None, patch_input_shape: tuple[int, int] | None = None, stride_shape: tuple[int, int] | None = None, - resolution=None, - units=None, + resolution: Resolution | None = None, + units: Units = None, *, - return_probabilities=False, - return_labels=False, - on_gpu=True, - merge_predictions=False, - save_dir=None, - save_output=False, - ): + return_probabilities: bool = False, + return_labels: bool = False, + on_gpu: bool = True, + merge_predictions: bool = False, + save_dir: str | Path | None = None, + save_output: bool = False, + ) -> np.ndarray | list | dict: """Make a prediction for a list of input data. Args: From b0d11fecfe838dbe594f2678c34b9cefca0463d5 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Tue, 10 Oct 2023 16:38:26 +0100 Subject: [PATCH 090/112] :bug: Fix flake8-annotations for models_abc.py Signed-off-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> --- tiatoolbox/models/models_abc.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tiatoolbox/models/models_abc.py b/tiatoolbox/models/models_abc.py index 1da0342a0..98ca29911 100644 --- a/tiatoolbox/models/models_abc.py +++ b/tiatoolbox/models/models_abc.py @@ -2,7 +2,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Callable import torch from torch import nn @@ -59,7 +59,7 @@ def model_to(model: torch.nn.Module, device: str = "cpu") -> torch.nn.Module: class ModelABC(ABC, nn.Module): """Abstract base class for models used in tiatoolbox.""" - def __init__(self) -> None: + def __init__(self: ModelABC) -> None: """Initialize Abstract class ModelABC.""" super().__init__() self._postproc = self.postproc @@ -67,13 +67,13 @@ def __init__(self) -> None: @abstractmethod # This is generic abc, else pylint will complain - def forward(self, *args, **kwargs): + def forward(self: ModelABC, *args: tuple[Any, ...], **kwargs: dict) -> None: """Torch method, this contains logic for using layers defined in init.""" ... # pragma: no cover @staticmethod @abstractmethod - def infer_batch(model: nn.Module, batch_data: np.ndarray, *, device: str): + def infer_batch(model: nn.Module, batch_data: np.ndarray, *, device: str) -> None: """Run inference on an input batch. Contains logic for forward operation as well as I/O aggregation. @@ -91,22 +91,22 @@ def infer_batch(model: nn.Module, batch_data: np.ndarray, *, device: str): ... # pragma: no cover @staticmethod - def preproc(image): + def preproc(image: np.ndarray) -> np.ndarray: """Define the pre-processing of this class of model.""" return image @staticmethod - def postproc(image): + def postproc(image: np.ndarray) -> np.ndarray: """Define the post-processing of this class of model.""" return image @property - def preproc_func(self): + def preproc_func(self: ModelABC) -> Callable: """Return the current pre-processing function of this instance.""" return self._preproc @preproc_func.setter - def preproc_func(self, func): + def preproc_func(self: ModelABC, func: Callable) -> None: """Set the pre-processing function for this instance. If `func=None`, the method will default to `self.preproc`. @@ -131,12 +131,12 @@ def preproc_func(self, func): self._preproc = func @property - def postproc_func(self): + def postproc_func(self: ModelABC) -> Callable: """Return the current post-processing function of this instance.""" return self._postproc @postproc_func.setter - def postproc_func(self, func): + def postproc_func(self: ModelABC, func: Callable) -> None: """Set the pre-processing function for this instance of model. If `func=None`, the method will default to `self.postproc`. From ae06e45f7a80404da7a9f688dd847b7aaa5ce255 Mon Sep 17 00:00:00 2001 From: abishekrajvg Date: Tue, 10 Oct 2023 18:39:38 +0100 Subject: [PATCH 091/112] Added Zarr save to post_process_patches method in engine_abc --- tests/engines/test_engine_abc.py | 13 +++++----- tiatoolbox/models/engine/engine_abc.py | 34 +++++++++++++++++++++----- 2 files changed, 34 insertions(+), 13 deletions(-) diff --git a/tests/engines/test_engine_abc.py b/tests/engines/test_engine_abc.py index 8c7b60a6c..1cb88a449 100644 --- a/tests/engines/test_engine_abc.py +++ b/tests/engines/test_engine_abc.py @@ -6,6 +6,8 @@ import numpy as np import pytest +import zarr +import os from tiatoolbox.models.architecture.vanilla import CNNModel from tiatoolbox.models.engine.engine_abc import EngineABC, prepare_engines_save_dir @@ -216,8 +218,7 @@ def test_engine_run() -> NoReturn: eng = TestEngineABC(model="alexnet-kather100k") out = eng.run(images=np.zeros((10, 224, 224, 3), dtype=np.uint8), on_gpu=False) - assert "predictions" in out - assert "labels" not in out + assert os.path.exists(out), f"Zarr output file does not exist" eng = TestEngineABC(model="alexnet-kather100k") out = eng.run( @@ -225,14 +226,12 @@ def test_engine_run() -> NoReturn: on_gpu=False, verbose=False, ) - assert "predictions" in out - assert "labels" not in out - + assert os.path.exists(out), f"Zarr output file does not exist" + eng = TestEngineABC(model="alexnet-kather100k") out = eng.run( images=np.zeros((10, 224, 224, 3), dtype=np.uint8), labels=list(range(10)), on_gpu=False, ) - assert "predictions" in out - assert "labels" in out + assert os.path.exists(out), f"Zarr output file does not exist" diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 83932d74a..3cfa8c95a 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -5,6 +5,8 @@ from abc import ABC, abstractmethod from pathlib import Path from typing import TYPE_CHECKING, NoReturn +import numcodecs +import zarr import numpy as np import pandas as pd @@ -359,13 +361,33 @@ def infer_patches( def post_process_patches( self: EngineABC, raw_predictions: dict, - output_type: str, - ) -> AnnotationStore | np.ndarray | pd.DataFrame | dict | str: + output_type: str = "zarr", + save_dir: Path | None = None, + **kwargs: dict + ) -> Path | AnnotationStore: + """Post-process an image patches.""" - return self._convert_output_to_requested_type( - output=raw_predictions, - output_type=output_type, - ) + # Create a Zarr and return the Path + + if not save_dir: + save_dir = Path.cwd() + + """ Compressor and Chunks defaults set if not received from kwargs """ + compressor = kwargs["compressor"] if "compressor" in kwargs else numcodecs.Zstd(level=1) + chunks = kwargs["chunks"] if "chunks" in kwargs else 10000 + + path_to_output_file = save_dir / "output.zarr" + + # save to zarr + predictions_array = np.array(raw_predictions["predictions"]) + z = zarr.open(path_to_output_file, mode='w', shape=predictions_array.shape, chunks=chunks, compressor=compressor) + z[:] = predictions_array + + if output_type is "AnnotationStore": + pass + # create_AnnotationStore() + + return path_to_output_file @abstractmethod def pre_process_wsi(self: EngineABC) -> NoReturn: From 034720e50a8ebc0a438c8aaddba79fcaac62369b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 10 Oct 2023 17:53:28 +0000 Subject: [PATCH 092/112] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/engines/test_engine_abc.py | 11 +++++------ tiatoolbox/models/engine/engine_abc.py | 24 +++++++++++++++--------- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/tests/engines/test_engine_abc.py b/tests/engines/test_engine_abc.py index 1cb88a449..6775d991b 100644 --- a/tests/engines/test_engine_abc.py +++ b/tests/engines/test_engine_abc.py @@ -1,13 +1,12 @@ """Test tiatoolbox.models.engine.engine_abc.""" from __future__ import annotations +import os from pathlib import Path from typing import TYPE_CHECKING, NoReturn import numpy as np import pytest -import zarr -import os from tiatoolbox.models.architecture.vanilla import CNNModel from tiatoolbox.models.engine.engine_abc import EngineABC, prepare_engines_save_dir @@ -218,7 +217,7 @@ def test_engine_run() -> NoReturn: eng = TestEngineABC(model="alexnet-kather100k") out = eng.run(images=np.zeros((10, 224, 224, 3), dtype=np.uint8), on_gpu=False) - assert os.path.exists(out), f"Zarr output file does not exist" + assert os.path.exists(out), "Zarr output file does not exist" eng = TestEngineABC(model="alexnet-kather100k") out = eng.run( @@ -226,12 +225,12 @@ def test_engine_run() -> NoReturn: on_gpu=False, verbose=False, ) - assert os.path.exists(out), f"Zarr output file does not exist" - + assert os.path.exists(out), "Zarr output file does not exist" + eng = TestEngineABC(model="alexnet-kather100k") out = eng.run( images=np.zeros((10, 224, 224, 3), dtype=np.uint8), labels=list(range(10)), on_gpu=False, ) - assert os.path.exists(out), f"Zarr output file does not exist" + assert os.path.exists(out), "Zarr output file does not exist" diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 3cfa8c95a..13ff51eca 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -5,13 +5,13 @@ from abc import ABC, abstractmethod from pathlib import Path from typing import TYPE_CHECKING, NoReturn -import numcodecs -import zarr +import numcodecs import numpy as np import pandas as pd import torch import tqdm +import zarr from torch import nn from tiatoolbox import logger @@ -363,29 +363,35 @@ def post_process_patches( raw_predictions: dict, output_type: str = "zarr", save_dir: Path | None = None, - **kwargs: dict + **kwargs: dict, ) -> Path | AnnotationStore: - """Post-process an image patches.""" # Create a Zarr and return the Path if not save_dir: - save_dir = Path.cwd() + save_dir = Path.cwd() """ Compressor and Chunks defaults set if not received from kwargs """ - compressor = kwargs["compressor"] if "compressor" in kwargs else numcodecs.Zstd(level=1) + compressor = ( + kwargs["compressor"] if "compressor" in kwargs else numcodecs.Zstd(level=1) + ) chunks = kwargs["chunks"] if "chunks" in kwargs else 10000 path_to_output_file = save_dir / "output.zarr" # save to zarr predictions_array = np.array(raw_predictions["predictions"]) - z = zarr.open(path_to_output_file, mode='w', shape=predictions_array.shape, chunks=chunks, compressor=compressor) + z = zarr.open( + path_to_output_file, + mode="w", + shape=predictions_array.shape, + chunks=chunks, + compressor=compressor, + ) z[:] = predictions_array - if output_type is "AnnotationStore": + if output_type == "AnnotationStore": pass - # create_AnnotationStore() return path_to_output_file From d1ce67107ae445454289c02c5da723630797d756 Mon Sep 17 00:00:00 2001 From: abishekrajvg Date: Wed, 11 Oct 2023 11:57:54 +0100 Subject: [PATCH 093/112] Fix linter errors changed os.path to pathlib.Path --- .gitignore | 3 +++ tests/engines/test_engine_abc.py | 7 +++---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index a192542d6..16ea54a83 100644 --- a/.gitignore +++ b/.gitignore @@ -115,3 +115,6 @@ ENV/ # vim/vi generated *.swp + +# output zarr generated +*.zarr diff --git a/tests/engines/test_engine_abc.py b/tests/engines/test_engine_abc.py index 6775d991b..e23ca4e4e 100644 --- a/tests/engines/test_engine_abc.py +++ b/tests/engines/test_engine_abc.py @@ -1,7 +1,6 @@ """Test tiatoolbox.models.engine.engine_abc.""" from __future__ import annotations -import os from pathlib import Path from typing import TYPE_CHECKING, NoReturn @@ -217,7 +216,7 @@ def test_engine_run() -> NoReturn: eng = TestEngineABC(model="alexnet-kather100k") out = eng.run(images=np.zeros((10, 224, 224, 3), dtype=np.uint8), on_gpu=False) - assert os.path.exists(out), "Zarr output file does not exist" + assert Path.exists(out), "Zarr output file does not exist" eng = TestEngineABC(model="alexnet-kather100k") out = eng.run( @@ -225,7 +224,7 @@ def test_engine_run() -> NoReturn: on_gpu=False, verbose=False, ) - assert os.path.exists(out), "Zarr output file does not exist" + assert Path.exists(out), "Zarr output file does not exist" eng = TestEngineABC(model="alexnet-kather100k") out = eng.run( @@ -233,4 +232,4 @@ def test_engine_run() -> NoReturn: labels=list(range(10)), on_gpu=False, ) - assert os.path.exists(out), "Zarr output file does not exist" + assert Path.exists(out), "Zarr output file does not exist" From c2419f1fcd50cac02ec20b86755c3e1a0915abdc Mon Sep 17 00:00:00 2001 From: abishekrajvg Date: Wed, 18 Oct 2023 03:30:23 +0100 Subject: [PATCH 094/112] Resolving review comments --- tests/engines/test_engine_abc.py | 81 ++++++++++++++- tests/test_utils.py | 37 +++++++ tiatoolbox/models/engine/engine_abc.py | 136 ++++++++++++++----------- tiatoolbox/utils/misc.py | 44 +++++++- 4 files changed, 233 insertions(+), 65 deletions(-) diff --git a/tests/engines/test_engine_abc.py b/tests/engines/test_engine_abc.py index e23ca4e4e..c9d3e2f21 100644 --- a/tests/engines/test_engine_abc.py +++ b/tests/engines/test_engine_abc.py @@ -17,9 +17,9 @@ class TestEngineABC(EngineABC): """Test EngineABC.""" - def __init__(self: TestEngineABC, model: str | torch.nn.Module) -> NoReturn: + def __init__(self: TestEngineABC, model: str | torch.nn.Module, verbose: bool | None = None) -> NoReturn: """Test EngineABC init.""" - super().__init__(model=model) + super().__init__(model=model, verbose=verbose) def infer_wsi(self: EngineABC) -> NoReturn: """Test infer_wsi.""" @@ -73,6 +73,19 @@ def test_incorrect_ioconfig() -> NoReturn: engine.run(images=[], masks=[], ioconfig=None) +def test_pretrained_ioconfig() -> NoReturn: + """ Test EngineABC initialization with ioconfig from the pretrained model in the toolbox """ + + #pre-trained model as a string + pretrained_model = "alexnet-kather100k" + + """Test engine run without ioconfig""" + eng = TestEngineABC(model=pretrained_model) + out = eng.run(images=np.zeros((10, 224, 224, 3), dtype=np.uint8), on_gpu=False, patch_mode=True, ioconfig=None) + assert "predictions" in out + assert "labels" not in out + + def test_prepare_engines_save_dir( tmp_path: pytest.TempPathFactory, caplog: pytest.LogCaptureFixture, @@ -215,7 +228,53 @@ def test_engine_run() -> NoReturn: ) eng = TestEngineABC(model="alexnet-kather100k") - out = eng.run(images=np.zeros((10, 224, 224, 3), dtype=np.uint8), on_gpu=False) + out = eng.run(images=np.zeros((10, 224, 224, 3), dtype=np.uint8), on_gpu=False, patch_mode=True) + assert "predictions" in out + assert "labels" not in out + + eng = TestEngineABC(model="alexnet-kather100k") + out = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + on_gpu=False, + verbose=False, + ) + assert "predictions" in out + assert "labels" not in out + + eng = TestEngineABC(model="alexnet-kather100k") + out = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + labels=list(range(10)), + on_gpu=False, + ) + assert "predictions" in out + assert "labels" in out + + +def test_engine_run_with_verbose() -> NoReturn: + """Test engine run with verbose""" + """Run pytest with `-rP` option to view progress bar on the captured stderr call""" + + eng = TestEngineABC(model="alexnet-kather100k", verbose=True) + out = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + labels=list(range(10)), + on_gpu=False, + ) + + assert "predictions" in out + assert "labels" in out + + +def test_patch_pred_zarr_store( + tmp_path: pytest.TempPathFactory +) -> NoReturn: + """Test the engine run and patch pred store""" + + save_dir=tmp_path / "patch_output" + + eng = TestEngineABC(model="alexnet-kather100k") + out = eng.run(images=np.zeros((10, 224, 224, 3), dtype=np.uint8), on_gpu=False, save_dir=save_dir, overwrite=True) assert Path.exists(out), "Zarr output file does not exist" eng = TestEngineABC(model="alexnet-kather100k") @@ -223,13 +282,29 @@ def test_engine_run() -> NoReturn: images=np.zeros((10, 224, 224, 3), dtype=np.uint8), on_gpu=False, verbose=False, + save_dir=save_dir, + overwrite=True + ) + assert Path.exists(out), "Zarr output file does not exist" + + eng = TestEngineABC(model="alexnet-kather100k") + out = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + labels=list(range(10)), + on_gpu=False, + save_dir=save_dir, + overwrite=True ) assert Path.exists(out), "Zarr output file does not exist" + ''' test custom zarr output file name''' eng = TestEngineABC(model="alexnet-kather100k") out = eng.run( images=np.zeros((10, 224, 224, 3), dtype=np.uint8), labels=list(range(10)), on_gpu=False, + save_dir=save_dir, + overwrite=True, + output_file="patch_pred_output" ) assert Path.exists(out), "Zarr output file does not exist" diff --git a/tests/test_utils.py b/tests/test_utils.py index 588b6b08b..784f6df81 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1693,3 +1693,40 @@ def test_patch_pred_store_sf() -> None: assert len(store) == 3 for annotation in store.values(): assert annotation.geometry.area == 4 + +def test_patch_pred_store_persist( + tmp_path: pytest.TempPathFactory +) -> None: + """Test patch_pred_store. and persists store output to a .db file""" + # Define a mock patch_output + patch_output = { + "predictions": [1, 0, 1], + "coordinates": [(0, 0, 1, 1), (1, 1, 2, 2), (2, 2, 3, 3)], + "probabilities": [[0.1, 0.9], [0.9, 0.1], [0.4, 0.6]], + "labels": [1, 0, 1], + } + save_dir = tmp_path / "patch_output" + + store_path = misc.patch_pred_store( + patch_output, + (1.0, 1.0), + save_dir=save_dir, + output_file="patch_pred_output") + + print("Annotation store path: ", store_path) + assert Path.exists(store_path), "Annotation Store output file does not exist" + + store = SQLiteStore(store_path) + + # Check that its an SQLiteStore containing the expected annotations + assert isinstance(store, SQLiteStore) + assert len(store) == 3 + for annotation in store.values(): + assert annotation.geometry.area == 1 + assert annotation.properties["type"] in [0, 1] + assert "other" not in annotation.properties + + patch_output.pop("coordinates") + # check correct error is raised if coordinates are missing + with pytest.raises(ValueError, match="coordinates"): + misc.patch_pred_store(patch_output, (1.0, 1.0)) diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 13ff51eca..426c8277c 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -6,18 +6,17 @@ from pathlib import Path from typing import TYPE_CHECKING, NoReturn -import numcodecs import numpy as np import pandas as pd import torch import tqdm -import zarr from torch import nn from tiatoolbox import logger from tiatoolbox.models.architecture import get_pretrained_model from tiatoolbox.models.dataset.dataset_abc import PatchDataset from tiatoolbox.models.models_abc import load_torch_model, model_to +from tiatoolbox.utils.misc import patch_pred_store, patch_pred_store_zarr if TYPE_CHECKING: # pragma: no cover import os @@ -47,7 +46,7 @@ def prepare_engines_save_dir( patch_mode(bool): Whether to treat input image as a patch or WSI. overwrite (bool): - Whether to overwrite the results. Default = False. + Whether to overwrite the results. Default = False. Returns: :class:`Path`: @@ -99,7 +98,7 @@ class EngineABC(ABC): Path to the weight of the corresponding `model`. >>> engine = EngineABC( - ... pretrained_model="pretrained-model-name", + ... model="pretrained-model-name", ... weights="pretrained-local-weights.pth") batch_size (int): @@ -122,7 +121,7 @@ class EngineABC(ABC): default = True. model (str | nn.Module): Defined PyTorch model. - Name of the existing models supported by the TIAToolbox for + Name of an existing model supported by the TIAToolbox for processing the data. For a full list of pretrained models, refer to the `docs `_ @@ -138,7 +137,7 @@ class EngineABC(ABC): Whether to return the labels with the predictions. merge_predictions (bool): Whether to merge the predictions to form a 2-dimensional - map. This is only applicable `patch_mode` is False in inference. + map. This is only applicable if `patch_mode` is False in inference. resolution (Resolution): Resolution used for reading the image. Please see :obj:`WSIReader` for details. @@ -170,22 +169,28 @@ class EngineABC(ABC): >>> # array of list of 2 image patches as input >>> import numpy as np >>> data = np.array([np.ndarray, np.ndarray]) - >>> engine = EngineABC(pretrained_model="resnet18-kather100k") + >>> engine = EngineABC(model="resnet18-kather100k") + >>> output = engine.run(data, patch_mode=True) + + >>> # array of list of 2 image patches as input + >>> import numpy as np + >>> data = np.array([np.ndarray, np.ndarray]) + >>> engine = EngineABC(model="resnet18-kather100k") >>> output = engine.run(data, patch_mode=True) >>> # list of 2 image patch files as input >>> data = ['path/img.png', 'path/img.png'] - >>> engine = EngineABC(pretrained_model="resnet18-kather100k") + >>> engine = EngineABC(model="resnet18-kather100k") >>> output = engine.run(data, patch_mode=False) >>> # list of 2 image files as input >>> image = ['path/image1.png', 'path/image2.png'] - >>> engine = EngineABC(pretraind_model="resnet18-kather100k") + >>> engine = EngineABC(model="resnet18-kather100k") >>> output = engine.run(image, patch_mode=False) >>> # list of 2 wsi files as input >>> wsi_file = ['path/wsi1.svs', 'path/wsi2.svs'] - >>> engine = EngineABC(pretraind_model="resnet18-kather100k") + >>> engine = EngineABC(model="resnet18-kather100k") >>> output = engine.run(wsi_file, patch_mode=True) """ @@ -296,24 +301,6 @@ def pre_process_patches( shuffle=False, ) - @staticmethod - def _convert_output_to_requested_type( - output: dict, - output_type: str, - ) -> AnnotationStore | np.ndarray | pd.DataFrame | dict | str: - """Converts inference output to requested type.""" - # function convert output to output_type - if output_type.lower() == "array": - return np.array(output["predictions"]) - - if output_type.lower() == "json": - return json.dumps(output, indent=4) - - if output_type.lower() == "dataframe": - return pd.DataFrame.from_dict(data=output) - - return output - def infer_patches( self: EngineABC, data_loader: DataLoader, @@ -361,39 +348,43 @@ def infer_patches( def post_process_patches( self: EngineABC, raw_predictions: dict, - output_type: str = "zarr", + output_type: str, save_dir: Path | None = None, **kwargs: dict, ) -> Path | AnnotationStore: - """Post-process an image patches.""" - # Create a Zarr and return the Path - + + """Post-process image patches.""" + + """Stores as an Annotation Store or Zarr (default) and returns the Path""" + + if not save_dir and self.patch_mode and output_type != "AnnotationStore": + return raw_predictions + if not save_dir: - save_dir = Path.cwd() - - """ Compressor and Chunks defaults set if not received from kwargs """ - compressor = ( - kwargs["compressor"] if "compressor" in kwargs else numcodecs.Zstd(level=1) - ) - chunks = kwargs["chunks"] if "chunks" in kwargs else 10000 - - path_to_output_file = save_dir / "output.zarr" - - # save to zarr - predictions_array = np.array(raw_predictions["predictions"]) - z = zarr.open( - path_to_output_file, - mode="w", - shape=predictions_array.shape, - chunks=chunks, - compressor=compressor, - ) - z[:] = predictions_array - + raise OSError("`save_dir` not specified.") + + output_file=kwargs["output_file"] and kwargs.pop("output_file") if "output_file" in kwargs else "output" + if output_type == "AnnotationStore": - pass - - return path_to_output_file + #scale_factor set from kwargs + scale_factor = kwargs["scale_factor"] if "scale_factor" in kwargs else None + #class_dict set from kwargs + class_dict = kwargs["class_dict"] if "class_dict" in kwargs else None + + return patch_pred_store( + raw_predictions, + scale_factor, + class_dict, + save_dir, + output_file + ) + + return patch_pred_store_zarr( + raw_predictions, + save_dir, + output_file, + **kwargs, + ) @abstractmethod def pre_process_wsi(self: EngineABC) -> NoReturn: @@ -503,7 +494,7 @@ def run( overwrite: bool = False, output_type: str = "dict", **kwargs: dict, - ) -> AnnotationStore | str: + ) -> AnnotationStore | Path | str: """Run the engine on input images. Args: @@ -561,8 +552,7 @@ def run( Examples: >>> wsis = ['wsi1.svs', 'wsi2.svs'] - >>> predictor = EngineABC( - ... pretrained_model="resnet18-kather100k") + >>> predictor = EngineABC(model="resnet18-kather100k") >>> output = predictor.run(wsis, patch_mode=False) >>> output.keys() ... ['wsi1.svs', 'wsi2.svs'] @@ -570,10 +560,34 @@ def run( ... {'raw': '0.raw.json', 'merged': '0.merged.npy'} >>> output['wsi2.svs'] ... {'raw': '1.raw.json', 'merged': '1.merged.npy'} - + + >>> predictor = EngineABC(model="alexnet-kather100k") + >>> output = predictor.run( + >>> images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + >>> labels=list(range(10)), + >>> on_gpu=False, + >>> ) + >>> output + ... {'predictions': [[0.7716791033744812, 0.0111849969252944, ..., 0.034451354295015335, 0.004817609209567308]], + ... 'labels': [tensor(0), tensor(1), tensor(2), tensor(3), tensor(4), tensor(5), tensor(6), tensor(7), tensor(8), tensor(9)]} + + >>> predictor = EngineABC(model="alexnet-kather100k") + >>> save_dir = Path("/tmp/patch_output/") + >>> output = eng.run( + >>> images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + >>> on_gpu=False, + >>> verbose=False, + >>> save_dir=save_dir, + >>> overwrite=True + >>> ) + >>> output + ... /tmp/patch_output/output.zarr """ + for key in kwargs: setattr(self, key, kwargs[key]) + + self.patch_mode = patch_mode self._validate_input_numbers(images=images, masks=masks, labels=labels) self.images = self._validate_images_masks(images=images) @@ -605,6 +619,8 @@ def run( return self.post_process_patches( raw_predictions=raw_predictions, output_type=output_type, + save_dir=save_dir, + **kwargs ) return {"save_dir": save_dir} diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py index 21a8d18cd..4fb1dfc4a 100644 --- a/tiatoolbox/utils/misc.py +++ b/tiatoolbox/utils/misc.py @@ -15,6 +15,8 @@ import pandas as pd import requests import yaml +import zarr +import numcodecs from filelock import FileLock from shapely.affinity import translate from shapely.geometry import Polygon @@ -1182,7 +1184,9 @@ def patch_pred_store( patch_output: dict, scale_factor: tuple[int, int], class_dict: dict | None = None, -) -> AnnotationStore: + save_dir: Path | None = None, + output_file: str | None = None, +) -> AnnotationStore | Path: """Create an SQLiteStore containing Annotations for each patch. Args: @@ -1237,5 +1241,41 @@ def patch_pred_store( annotations.append(Annotation(Polygon.from_bounds(*patch_coords[i]), props)) store = SQLiteStore() keys = store.append_many(annotations, [str(i) for i in range(len(annotations))]) + + if not save_dir: + return store + + else: + output_file += ".db" + path_to_output_file = save_dir / output_file + save_dir.mkdir(parents=True, exist_ok=True) + store.dump(path_to_output_file) + return path_to_output_file + +def patch_pred_store_zarr( + raw_predictions: dict, + save_dir: Path, + output_file: str, + **kwargs:dict, +) -> Path: + + """ Default values for Compressor and Chunks set if not received from kwargs """ + compressor = ( + kwargs["compressor"] if "compressor" in kwargs else numcodecs.Zstd(level=1) + ) + chunks = kwargs["chunks"] if "chunks" in kwargs else 10000 + + # save to zarr + output_file += ".zarr" + path_to_output_file = save_dir / output_file + predictions_array = np.array(raw_predictions["predictions"]) + z = zarr.open( + path_to_output_file, + mode="w", + shape=predictions_array.shape, + chunks=chunks, + compressor=compressor, + ) + z[:] = predictions_array - return store + return path_to_output_file From 9b805f809e7a05efc88f211d7ab8d5f29ab65536 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 18 Oct 2023 02:31:08 +0000 Subject: [PATCH 095/112] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/engines/test_engine_abc.py | 48 ++++++++++++++++---------- tests/test_utils.py | 15 ++++---- tiatoolbox/models/engine/engine_abc.py | 44 +++++++++++------------ tiatoolbox/utils/misc.py | 12 +++---- 4 files changed, 62 insertions(+), 57 deletions(-) diff --git a/tests/engines/test_engine_abc.py b/tests/engines/test_engine_abc.py index c9d3e2f21..6f911b06f 100644 --- a/tests/engines/test_engine_abc.py +++ b/tests/engines/test_engine_abc.py @@ -17,7 +17,9 @@ class TestEngineABC(EngineABC): """Test EngineABC.""" - def __init__(self: TestEngineABC, model: str | torch.nn.Module, verbose: bool | None = None) -> NoReturn: + def __init__( + self: TestEngineABC, model: str | torch.nn.Module, verbose: bool | None = None, + ) -> NoReturn: """Test EngineABC init.""" super().__init__(model=model, verbose=verbose) @@ -74,14 +76,18 @@ def test_incorrect_ioconfig() -> NoReturn: def test_pretrained_ioconfig() -> NoReturn: - """ Test EngineABC initialization with ioconfig from the pretrained model in the toolbox """ - - #pre-trained model as a string + """Test EngineABC initialization with ioconfig from the pretrained model in the toolbox.""" + # pre-trained model as a string pretrained_model = "alexnet-kather100k" """Test engine run without ioconfig""" eng = TestEngineABC(model=pretrained_model) - out = eng.run(images=np.zeros((10, 224, 224, 3), dtype=np.uint8), on_gpu=False, patch_mode=True, ioconfig=None) + out = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + on_gpu=False, + patch_mode=True, + ioconfig=None, + ) assert "predictions" in out assert "labels" not in out @@ -228,7 +234,11 @@ def test_engine_run() -> NoReturn: ) eng = TestEngineABC(model="alexnet-kather100k") - out = eng.run(images=np.zeros((10, 224, 224, 3), dtype=np.uint8), on_gpu=False, patch_mode=True) + out = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + on_gpu=False, + patch_mode=True, + ) assert "predictions" in out assert "labels" not in out @@ -252,7 +262,7 @@ def test_engine_run() -> NoReturn: def test_engine_run_with_verbose() -> NoReturn: - """Test engine run with verbose""" + """Test engine run with verbose.""" """Run pytest with `-rP` option to view progress bar on the captured stderr call""" eng = TestEngineABC(model="alexnet-kather100k", verbose=True) @@ -266,15 +276,17 @@ def test_engine_run_with_verbose() -> NoReturn: assert "labels" in out -def test_patch_pred_zarr_store( - tmp_path: pytest.TempPathFactory -) -> NoReturn: - """Test the engine run and patch pred store""" - - save_dir=tmp_path / "patch_output" +def test_patch_pred_zarr_store(tmp_path: pytest.TempPathFactory) -> NoReturn: + """Test the engine run and patch pred store.""" + save_dir = tmp_path / "patch_output" eng = TestEngineABC(model="alexnet-kather100k") - out = eng.run(images=np.zeros((10, 224, 224, 3), dtype=np.uint8), on_gpu=False, save_dir=save_dir, overwrite=True) + out = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + on_gpu=False, + save_dir=save_dir, + overwrite=True, + ) assert Path.exists(out), "Zarr output file does not exist" eng = TestEngineABC(model="alexnet-kather100k") @@ -283,7 +295,7 @@ def test_patch_pred_zarr_store( on_gpu=False, verbose=False, save_dir=save_dir, - overwrite=True + overwrite=True, ) assert Path.exists(out), "Zarr output file does not exist" @@ -293,11 +305,11 @@ def test_patch_pred_zarr_store( labels=list(range(10)), on_gpu=False, save_dir=save_dir, - overwrite=True + overwrite=True, ) assert Path.exists(out), "Zarr output file does not exist" - ''' test custom zarr output file name''' + """ test custom zarr output file name""" eng = TestEngineABC(model="alexnet-kather100k") out = eng.run( images=np.zeros((10, 224, 224, 3), dtype=np.uint8), @@ -305,6 +317,6 @@ def test_patch_pred_zarr_store( on_gpu=False, save_dir=save_dir, overwrite=True, - output_file="patch_pred_output" + output_file="patch_pred_output", ) assert Path.exists(out), "Zarr output file does not exist" diff --git a/tests/test_utils.py b/tests/test_utils.py index 784f6df81..f6082ce3c 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1694,10 +1694,9 @@ def test_patch_pred_store_sf() -> None: for annotation in store.values(): assert annotation.geometry.area == 4 -def test_patch_pred_store_persist( - tmp_path: pytest.TempPathFactory -) -> None: - """Test patch_pred_store. and persists store output to a .db file""" + +def test_patch_pred_store_persist(tmp_path: pytest.TempPathFactory) -> None: + """Test patch_pred_store. and persists store output to a .db file.""" # Define a mock patch_output patch_output = { "predictions": [1, 0, 1], @@ -1708,11 +1707,9 @@ def test_patch_pred_store_persist( save_dir = tmp_path / "patch_output" store_path = misc.patch_pred_store( - patch_output, - (1.0, 1.0), - save_dir=save_dir, - output_file="patch_pred_output") - + patch_output, (1.0, 1.0), save_dir=save_dir, output_file="patch_pred_output", + ) + print("Annotation store path: ", store_path) assert Path.exists(store_path), "Annotation Store output file does not exist" diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 426c8277c..90143854a 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -1,13 +1,11 @@ """Defines Abstract Base Class for TIAToolbox Model Engines.""" from __future__ import annotations -import json from abc import ABC, abstractmethod from pathlib import Path from typing import TYPE_CHECKING, NoReturn import numpy as np -import pandas as pd import torch import tqdm from torch import nn @@ -352,33 +350,32 @@ def post_process_patches( save_dir: Path | None = None, **kwargs: dict, ) -> Path | AnnotationStore: - """Post-process image patches.""" - """Stores as an Annotation Store or Zarr (default) and returns the Path""" - + if not save_dir and self.patch_mode and output_type != "AnnotationStore": return raw_predictions - + if not save_dir: - raise OSError("`save_dir` not specified.") - - output_file=kwargs["output_file"] and kwargs.pop("output_file") if "output_file" in kwargs else "output" - + msg = "`save_dir` not specified." + raise OSError(msg) + + output_file = ( + kwargs["output_file"] and kwargs.pop("output_file") + if "output_file" in kwargs + else "output" + ) + if output_type == "AnnotationStore": - #scale_factor set from kwargs + # scale_factor set from kwargs scale_factor = kwargs["scale_factor"] if "scale_factor" in kwargs else None - #class_dict set from kwargs + # class_dict set from kwargs class_dict = kwargs["class_dict"] if "class_dict" in kwargs else None return patch_pred_store( - raw_predictions, - scale_factor, - class_dict, - save_dir, - output_file + raw_predictions, scale_factor, class_dict, save_dir, output_file, ) - + return patch_pred_store_zarr( raw_predictions, save_dir, @@ -560,7 +557,7 @@ def run( ... {'raw': '0.raw.json', 'merged': '0.merged.npy'} >>> output['wsi2.svs'] ... {'raw': '1.raw.json', 'merged': '1.merged.npy'} - + >>> predictor = EngineABC(model="alexnet-kather100k") >>> output = predictor.run( >>> images=np.zeros((10, 224, 224, 3), dtype=np.uint8), @@ -568,9 +565,9 @@ def run( >>> on_gpu=False, >>> ) >>> output - ... {'predictions': [[0.7716791033744812, 0.0111849969252944, ..., 0.034451354295015335, 0.004817609209567308]], + ... {'predictions': [[0.7716791033744812, 0.0111849969252944, ..., 0.034451354295015335, 0.004817609209567308]], ... 'labels': [tensor(0), tensor(1), tensor(2), tensor(3), tensor(4), tensor(5), tensor(6), tensor(7), tensor(8), tensor(9)]} - + >>> predictor = EngineABC(model="alexnet-kather100k") >>> save_dir = Path("/tmp/patch_output/") >>> output = eng.run( @@ -583,10 +580,9 @@ def run( >>> output ... /tmp/patch_output/output.zarr """ - for key in kwargs: setattr(self, key, kwargs[key]) - + self.patch_mode = patch_mode self._validate_input_numbers(images=images, masks=masks, labels=labels) @@ -620,7 +616,7 @@ def run( raw_predictions=raw_predictions, output_type=output_type, save_dir=save_dir, - **kwargs + **kwargs, ) return {"save_dir": save_dir} diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py index 4fb1dfc4a..b062cb950 100644 --- a/tiatoolbox/utils/misc.py +++ b/tiatoolbox/utils/misc.py @@ -11,12 +11,12 @@ import cv2 import joblib +import numcodecs import numpy as np import pandas as pd import requests import yaml import zarr -import numcodecs from filelock import FileLock from shapely.affinity import translate from shapely.geometry import Polygon @@ -1241,10 +1241,10 @@ def patch_pred_store( annotations.append(Annotation(Polygon.from_bounds(*patch_coords[i]), props)) store = SQLiteStore() keys = store.append_many(annotations, [str(i) for i in range(len(annotations))]) - + if not save_dir: return store - + else: output_file += ".db" path_to_output_file = save_dir / output_file @@ -1252,14 +1252,14 @@ def patch_pred_store( store.dump(path_to_output_file) return path_to_output_file + def patch_pred_store_zarr( raw_predictions: dict, save_dir: Path, output_file: str, - **kwargs:dict, + **kwargs: dict, ) -> Path: - - """ Default values for Compressor and Chunks set if not received from kwargs """ + """Default values for Compressor and Chunks set if not received from kwargs.""" compressor = ( kwargs["compressor"] if "compressor" in kwargs else numcodecs.Zstd(level=1) ) From bdcee000dacc8b6590021e511d227baa8f64ced8 Mon Sep 17 00:00:00 2001 From: abishekrajvg Date: Wed, 18 Oct 2023 10:02:49 +0100 Subject: [PATCH 096/112] Fix linter errors resolved --- tiatoolbox/models/engine/engine_abc.py | 6 ++++-- tiatoolbox/utils/misc.py | 18 +++++++++++++----- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 426c8277c..e94b817d1 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -568,8 +568,10 @@ def run( >>> on_gpu=False, >>> ) >>> output - ... {'predictions': [[0.7716791033744812, 0.0111849969252944, ..., 0.034451354295015335, 0.004817609209567308]], - ... 'labels': [tensor(0), tensor(1), tensor(2), tensor(3), tensor(4), tensor(5), tensor(6), tensor(7), tensor(8), tensor(9)]} + ... {'predictions': [[0.7716791033744812, 0.0111849969252944, ..., + ... 0.034451354295015335, 0.004817609209567308]], + ... 'labels': [tensor(0), tensor(1), tensor(2), tensor(3), tensor(4), + ... tensor(5), tensor(6), tensor(7), tensor(8), tensor(9)]} >>> predictor = EngineABC(model="alexnet-kather100k") >>> save_dir = Path("/tmp/patch_output/") diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py index 4fb1dfc4a..f22879002 100644 --- a/tiatoolbox/utils/misc.py +++ b/tiatoolbox/utils/misc.py @@ -1197,9 +1197,17 @@ def patch_pred_store( conversion of annotations saved at non-baseline resolution to baseline. Should be model_mpp/slide_mpp. class_dict (dict): Optional dictionary mapping class indices to class names. + save_dir (str or pathlib.Path): Optional Output directory to save the Annotation + Store results. if the save_dir is not provided, then an SQLiteStore object + containing Annotations for each patch is returned. + output_file (str): Optional file name to save the Annotation Store results. + if the output_file is not provided, then an SQLiteStore object + containing Annotations for each patch is returned. + Returns: - SQLiteStore: An SQLiteStore containing Annotations for each patch. + SQLiteStore: An SQLiteStore containing Annotations for each patch + or Path to file storing SQLiteStore containing Annotations for each patch """ if "coordinates" not in patch_output: @@ -1242,15 +1250,15 @@ def patch_pred_store( store = SQLiteStore() keys = store.append_many(annotations, [str(i) for i in range(len(annotations))]) - if not save_dir: - return store - - else: + #if a save director is provided, then dump store into a file + if save_dir and output_file: output_file += ".db" path_to_output_file = save_dir / output_file save_dir.mkdir(parents=True, exist_ok=True) store.dump(path_to_output_file) return path_to_output_file + + return store def patch_pred_store_zarr( raw_predictions: dict, From 4467e78b2eb475dd5becc1be36fe37e1a65e1ce9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 18 Oct 2023 09:07:05 +0000 Subject: [PATCH 097/112] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/engines/test_engine_abc.py | 4 +++- tests/test_utils.py | 5 ++++- tiatoolbox/models/engine/engine_abc.py | 12 ++++++++---- tiatoolbox/utils/misc.py | 10 +++++----- 4 files changed, 20 insertions(+), 11 deletions(-) diff --git a/tests/engines/test_engine_abc.py b/tests/engines/test_engine_abc.py index 6f911b06f..f3230b05d 100644 --- a/tests/engines/test_engine_abc.py +++ b/tests/engines/test_engine_abc.py @@ -18,7 +18,9 @@ class TestEngineABC(EngineABC): """Test EngineABC.""" def __init__( - self: TestEngineABC, model: str | torch.nn.Module, verbose: bool | None = None, + self: TestEngineABC, + model: str | torch.nn.Module, + verbose: bool | None = None, ) -> NoReturn: """Test EngineABC init.""" super().__init__(model=model, verbose=verbose) diff --git a/tests/test_utils.py b/tests/test_utils.py index f6082ce3c..847aae07a 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1707,7 +1707,10 @@ def test_patch_pred_store_persist(tmp_path: pytest.TempPathFactory) -> None: save_dir = tmp_path / "patch_output" store_path = misc.patch_pred_store( - patch_output, (1.0, 1.0), save_dir=save_dir, output_file="patch_pred_output", + patch_output, + (1.0, 1.0), + save_dir=save_dir, + output_file="patch_pred_output", ) print("Annotation store path: ", store_path) diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 22935551b..b2d9fef33 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -373,7 +373,11 @@ def post_process_patches( class_dict = kwargs["class_dict"] if "class_dict" in kwargs else None return patch_pred_store( - raw_predictions, scale_factor, class_dict, save_dir, output_file, + raw_predictions, + scale_factor, + class_dict, + save_dir, + output_file, ) return patch_pred_store_zarr( @@ -566,10 +570,10 @@ def run( >>> ) >>> output ... {'predictions': [[0.7716791033744812, 0.0111849969252944, ..., - ... 0.034451354295015335, 0.004817609209567308]], - ... 'labels': [tensor(0), tensor(1), tensor(2), tensor(3), tensor(4), + ... 0.034451354295015335, 0.004817609209567308]], + ... 'labels': [tensor(0), tensor(1), tensor(2), tensor(3), tensor(4), ... tensor(5), tensor(6), tensor(7), tensor(8), tensor(9)]} - + >>> predictor = EngineABC(model="alexnet-kather100k") >>> save_dir = Path("/tmp/patch_output/") >>> output = eng.run( diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py index feada46d0..73a2bb3cc 100644 --- a/tiatoolbox/utils/misc.py +++ b/tiatoolbox/utils/misc.py @@ -1200,14 +1200,14 @@ def patch_pred_store( save_dir (str or pathlib.Path): Optional Output directory to save the Annotation Store results. if the save_dir is not provided, then an SQLiteStore object containing Annotations for each patch is returned. - output_file (str): Optional file name to save the Annotation Store results. + output_file (str): Optional file name to save the Annotation Store results. if the output_file is not provided, then an SQLiteStore object containing Annotations for each patch is returned. Returns: - SQLiteStore: An SQLiteStore containing Annotations for each patch - or Path to file storing SQLiteStore containing Annotations for each patch + SQLiteStore: An SQLiteStore containing Annotations for each patch + or Path to file storing SQLiteStore containing Annotations for each patch """ if "coordinates" not in patch_output: @@ -1250,14 +1250,14 @@ def patch_pred_store( store = SQLiteStore() keys = store.append_many(annotations, [str(i) for i in range(len(annotations))]) - #if a save director is provided, then dump store into a file + # if a save director is provided, then dump store into a file if save_dir and output_file: output_file += ".db" path_to_output_file = save_dir / output_file save_dir.mkdir(parents=True, exist_ok=True) store.dump(path_to_output_file) return path_to_output_file - + return store From f8d82928537a7c10c48bab264220b524da5a1805 Mon Sep 17 00:00:00 2001 From: abishekrajvg Date: Wed, 18 Oct 2023 10:23:03 +0100 Subject: [PATCH 098/112] fix linter error resolved --- tests/engines/test_engine_abc.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/engines/test_engine_abc.py b/tests/engines/test_engine_abc.py index 6f911b06f..0da87f735 100644 --- a/tests/engines/test_engine_abc.py +++ b/tests/engines/test_engine_abc.py @@ -76,7 +76,8 @@ def test_incorrect_ioconfig() -> NoReturn: def test_pretrained_ioconfig() -> NoReturn: - """Test EngineABC initialization with ioconfig from the pretrained model in the toolbox.""" + """Test EngineABC initialization with ioconfig from + the pretrained model in the toolbox.""" # pre-trained model as a string pretrained_model = "alexnet-kather100k" From 7f6941a0c563f2b2cdf358fab6918e56c096abc8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 18 Oct 2023 09:26:21 +0000 Subject: [PATCH 099/112] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/engines/test_engine_abc.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/engines/test_engine_abc.py b/tests/engines/test_engine_abc.py index 8c599d7a6..0a6b738ae 100644 --- a/tests/engines/test_engine_abc.py +++ b/tests/engines/test_engine_abc.py @@ -78,8 +78,9 @@ def test_incorrect_ioconfig() -> NoReturn: def test_pretrained_ioconfig() -> NoReturn: - """Test EngineABC initialization with ioconfig from - the pretrained model in the toolbox.""" + """Test EngineABC initialization with ioconfig from + the pretrained model in the toolbox. + """ # pre-trained model as a string pretrained_model = "alexnet-kather100k" From 40117ff7f0930143f381c33217c49ecd689d033e Mon Sep 17 00:00:00 2001 From: abishekrajvg Date: Wed, 18 Oct 2023 10:32:17 +0100 Subject: [PATCH 100/112] fix linter error resolved --- tests/engines/test_engine_abc.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/engines/test_engine_abc.py b/tests/engines/test_engine_abc.py index 8c599d7a6..d0752759b 100644 --- a/tests/engines/test_engine_abc.py +++ b/tests/engines/test_engine_abc.py @@ -80,6 +80,7 @@ def test_incorrect_ioconfig() -> NoReturn: def test_pretrained_ioconfig() -> NoReturn: """Test EngineABC initialization with ioconfig from the pretrained model in the toolbox.""" + # pre-trained model as a string pretrained_model = "alexnet-kather100k" From ec8fe97415e27964d77c7a3588bc1f4ff7139cf3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 18 Oct 2023 09:34:38 +0000 Subject: [PATCH 101/112] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/engines/test_engine_abc.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/engines/test_engine_abc.py b/tests/engines/test_engine_abc.py index bc51fff7f..0a6b738ae 100644 --- a/tests/engines/test_engine_abc.py +++ b/tests/engines/test_engine_abc.py @@ -78,11 +78,9 @@ def test_incorrect_ioconfig() -> NoReturn: def test_pretrained_ioconfig() -> NoReturn: - """Test EngineABC initialization with ioconfig from the pretrained model in the toolbox. """ - # pre-trained model as a string pretrained_model = "alexnet-kather100k" From cad3a4fe029a6882f6364139e7fb9c0c364a2323 Mon Sep 17 00:00:00 2001 From: abishekrajvg Date: Wed, 18 Oct 2023 10:50:37 +0100 Subject: [PATCH 102/112] fix D205 1 blank line required between summary line and description --- tests/engines/test_engine_abc.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/tests/engines/test_engine_abc.py b/tests/engines/test_engine_abc.py index bc51fff7f..c8b5e8658 100644 --- a/tests/engines/test_engine_abc.py +++ b/tests/engines/test_engine_abc.py @@ -78,15 +78,10 @@ def test_incorrect_ioconfig() -> NoReturn: def test_pretrained_ioconfig() -> NoReturn: - - """Test EngineABC initialization with ioconfig from - the pretrained model in the toolbox. - """ - - # pre-trained model as a string + """Test EngineABC initialization with pretrained model name in the toolbox.""" pretrained_model = "alexnet-kather100k" - """Test engine run without ioconfig""" + #Test engine run without ioconfig eng = TestEngineABC(model=pretrained_model) out = eng.run( images=np.zeros((10, 224, 224, 3), dtype=np.uint8), From 5882be6f27a316a2f0765a9215caded35b7e70fb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 18 Oct 2023 09:52:09 +0000 Subject: [PATCH 103/112] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/engines/test_engine_abc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/engines/test_engine_abc.py b/tests/engines/test_engine_abc.py index c8b5e8658..cfa1ff7bb 100644 --- a/tests/engines/test_engine_abc.py +++ b/tests/engines/test_engine_abc.py @@ -81,7 +81,7 @@ def test_pretrained_ioconfig() -> NoReturn: """Test EngineABC initialization with pretrained model name in the toolbox.""" pretrained_model = "alexnet-kather100k" - #Test engine run without ioconfig + # Test engine run without ioconfig eng = TestEngineABC(model=pretrained_model) out = eng.run( images=np.zeros((10, 224, 224, 3), dtype=np.uint8), From e7df499309816706eec95aba0f0dd9d620803abc Mon Sep 17 00:00:00 2001 From: abishekrajvg Date: Wed, 18 Oct 2023 11:40:02 +0100 Subject: [PATCH 104/112] docstring sytax error fixed --- tiatoolbox/models/engine/engine_abc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index b2d9fef33..2d6d72bcc 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -584,7 +584,7 @@ def run( >>> overwrite=True >>> ) >>> output - ... /tmp/patch_output/output.zarr + ... '/tmp/patch_output/output.zarr' """ for key in kwargs: setattr(self, key, kwargs[key]) From e86f440f6811253e08986a44125600974b056dc9 Mon Sep 17 00:00:00 2001 From: abishekrajvg Date: Fri, 20 Oct 2023 13:37:34 +0100 Subject: [PATCH 105/112] Improving engine_abc based on review comments --- tests/engines/test_engine_abc.py | 1 - tiatoolbox/models/engine/engine_abc.py | 45 +++++++++++++++++--------- 2 files changed, 30 insertions(+), 16 deletions(-) diff --git a/tests/engines/test_engine_abc.py b/tests/engines/test_engine_abc.py index cfa1ff7bb..860472197 100644 --- a/tests/engines/test_engine_abc.py +++ b/tests/engines/test_engine_abc.py @@ -60,7 +60,6 @@ def test_engine_abc_incorrect_model_type() -> NoReturn: TypeError, match="Input model must be a string or 'torch.nn.Module'.", ): - # Can't instantiate abstract class with abstract methods TestEngineABC(model=1) diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 2d6d72bcc..4354be339 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -112,7 +112,7 @@ class EngineABC(ABC): Whether to output logging information. Attributes: - images (str or :obj:`pathlib.Path` or :obj:`numpy.ndarray`): + images (list of str or list of :obj:`Path` or NHWC :obj:`numpy.ndarray`): A NHWC image or a path to WSI. patch_mode (str): Whether to treat input image as a patch or WSI. @@ -144,7 +144,7 @@ class EngineABC(ABC): from either `level`, `power` or `mpp`. Please see :obj:`WSIReader` for details. patch_input_shape (tuple): - Size of patches input to the model. Patches are at + Shape of patches input to the model as tupled of HW. Patches are at requested read resolution, not with respect to level 0, and must be positive. stride_shape (tuple): @@ -176,11 +176,6 @@ class EngineABC(ABC): >>> engine = EngineABC(model="resnet18-kather100k") >>> output = engine.run(data, patch_mode=True) - >>> # list of 2 image patch files as input - >>> data = ['path/img.png', 'path/img.png'] - >>> engine = EngineABC(model="resnet18-kather100k") - >>> output = engine.run(data, patch_mode=False) - >>> # list of 2 image files as input >>> image = ['path/image1.png', 'path/image2.png'] >>> engine = EngineABC(model="resnet18-kather100k") @@ -343,23 +338,43 @@ def infer_patches( return raw_predictions - def post_process_patches( + def setup_patch_dataset( self: EngineABC, raw_predictions: dict, output_type: str, save_dir: Path | None = None, **kwargs: dict, ) -> Path | AnnotationStore: - """Post-process image patches.""" - """Stores as an Annotation Store or Zarr (default) and returns the Path""" + """Post-process image patches. - if not save_dir and self.patch_mode and output_type != "AnnotationStore": - return raw_predictions + Args: + raw_predictions (dict): + A dictionary of patch prediction information. + save_dir (Path): + Optional Output Path to directory to save the patch dataset output to a + `.zarr` or `.db` file, provided patch_mode is True. if the patch_mode is + False then save_dir is required. + output_type (str): + The desired output type for resulting patch dataset. + **kwargs (dict): + Keyword Args to update setup_patch_dataset() method attributes. + + Returns: (dict, Path, :class:`SQLiteStore`): + if the output_type is "AnnotationStore", the function returns the patch + predictor output as an SQLiteStore containing Annotations for each or the + Path to a `.db` file depending on whether a save_dir Path is provided. + Otherwise, the function defaults to returning patch predictor output, either + as a dict or the Path to a `.zarr` file depending on whether a save_dir Path + is provided. - if not save_dir: - msg = "`save_dir` not specified." + """ + if not save_dir and not self.patch_mode: + msg = "`save_dir` must be specified when patch_mode is False." raise OSError(msg) + if not save_dir and output_type != "AnnotationStore": + return raw_predictions + output_file = ( kwargs["output_file"] and kwargs.pop("output_file") if "output_file" in kwargs @@ -618,7 +633,7 @@ def run( raw_predictions = self.infer_patches( data_loader=data_loader, ) - return self.post_process_patches( + return self.setup_patch_dataset( raw_predictions=raw_predictions, output_type=output_type, save_dir=save_dir, From 122e114bb73f5cada69ee481fd020c484a3528ee Mon Sep 17 00:00:00 2001 From: abishekrajvg Date: Mon, 23 Oct 2023 11:33:17 +0100 Subject: [PATCH 106/112] added link to pretrained models --- tiatoolbox/models/engine/engine_abc.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 4354be339..6324c78ea 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -264,6 +264,8 @@ def _initialize_model_ioconfig( if isinstance(model, str): # ioconfig is retrieved from the pretrained model in the toolbox. + # list of pretrained models in the TIA Toolbox is available here: + # https://tia-toolbox.readthedocs.io/en/add-bokeh-app/pretrained.html # no need to provide ioconfig in EngineABC.run() this case. return get_pretrained_model(model, weights) From b47cd35ffbb4d0ed9443a4a941a61195d06dfd10 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Thu, 26 Oct 2023 10:41:51 +0100 Subject: [PATCH 107/112] Delete ' --- ' | 6 ------ 1 file changed, 6 deletions(-) delete mode 100644 ' diff --git a/' b/' deleted file mode 100644 index fa2d74ea0..000000000 --- a/' +++ /dev/null @@ -1,6 +0,0 @@ -Merge branch 'dev-redefine-patchpredictor' of https://github.com/TissueImageAnalytics/tiatoolbox into dev-redefine-patchpredictor -# Please enter a commit message to explain why this merge is necessary, -# especially if it merges an updated upstream into a topic branch. -# -# Lines starting with '#' will be ignored, and an empty message aborts -# the commit. From 92bde5b585fca1c4cfa4292f982a88c45eed730c Mon Sep 17 00:00:00 2001 From: abishekrajvg Date: Thu, 26 Oct 2023 13:16:49 +0100 Subject: [PATCH 108/112] update method patch_pred method names in EngineABC --- tiatoolbox/models/engine/engine_abc.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 6324c78ea..164e02915 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -14,7 +14,7 @@ from tiatoolbox.models.architecture import get_pretrained_model from tiatoolbox.models.dataset.dataset_abc import PatchDataset from tiatoolbox.models.models_abc import load_torch_model, model_to -from tiatoolbox.utils.misc import patch_pred_store, patch_pred_store_zarr +from tiatoolbox.utils.misc import dict_to_store, dict_to_zarr if TYPE_CHECKING: # pragma: no cover import os @@ -389,7 +389,7 @@ def setup_patch_dataset( # class_dict set from kwargs class_dict = kwargs["class_dict"] if "class_dict" in kwargs else None - return patch_pred_store( + return dict_to_store( raw_predictions, scale_factor, class_dict, @@ -397,7 +397,7 @@ def setup_patch_dataset( output_file, ) - return patch_pred_store_zarr( + return dict_to_zarr( raw_predictions, save_dir, output_file, From ac4fc45e74c93d5d5b7e22126af163719a10f5f2 Mon Sep 17 00:00:00 2001 From: abishekrajvg Date: Thu, 26 Oct 2023 13:43:54 +0100 Subject: [PATCH 109/112] fix pytest errors --- tiatoolbox/models/engine/engine_abc.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 164e02915..67555eb23 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -383,6 +383,8 @@ def setup_patch_dataset( else "output" ) + save_path = save_dir / output_file + if output_type == "AnnotationStore": # scale_factor set from kwargs scale_factor = kwargs["scale_factor"] if "scale_factor" in kwargs else None @@ -393,14 +395,12 @@ def setup_patch_dataset( raw_predictions, scale_factor, class_dict, - save_dir, - output_file, + save_path ) return dict_to_zarr( raw_predictions, - save_dir, - output_file, + save_path, **kwargs, ) From 53594fe3809aa2ff340307b5408d1a30edb5ac8d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 26 Oct 2023 12:45:16 +0000 Subject: [PATCH 110/112] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tiatoolbox/models/engine/engine_abc.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 67555eb23..08871c54b 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -383,7 +383,7 @@ def setup_patch_dataset( else "output" ) - save_path = save_dir / output_file + save_path = save_dir / output_file if output_type == "AnnotationStore": # scale_factor set from kwargs @@ -391,12 +391,7 @@ def setup_patch_dataset( # class_dict set from kwargs class_dict = kwargs["class_dict"] if "class_dict" in kwargs else None - return dict_to_store( - raw_predictions, - scale_factor, - class_dict, - save_path - ) + return dict_to_store(raw_predictions, scale_factor, class_dict, save_path) return dict_to_zarr( raw_predictions, From 3c1445f6735e7a2add5a7ca898fd0925fcc07507 Mon Sep 17 00:00:00 2001 From: abishekrajvg Date: Thu, 26 Oct 2023 15:57:43 +0100 Subject: [PATCH 111/112] updates to model_to() and load_torch_model() methods in ModelABC --- tiatoolbox/models/models_abc.py | 54 ++++++++++++++++++++++++++++++--- 1 file changed, 49 insertions(+), 5 deletions(-) diff --git a/tiatoolbox/models/models_abc.py b/tiatoolbox/models/models_abc.py index 98ca29911..6904b09ad 100644 --- a/tiatoolbox/models/models_abc.py +++ b/tiatoolbox/models/models_abc.py @@ -13,6 +13,7 @@ import numpy as np +#Draft - will be moved into ModelABC as a class method def load_torch_model(model: nn.Module, weights: str | Path) -> nn.Module: """Helper function to load a torch model. @@ -34,6 +35,7 @@ def load_torch_model(model: nn.Module, weights: str | Path) -> nn.Module: return model +#Draft - will be moved into ModelABC as a class method def model_to(model: torch.nn.Module, device: str = "cpu") -> torch.nn.Module: """Transfers model to cpu/gpu. @@ -46,14 +48,15 @@ def model_to(model: torch.nn.Module, device: str = "cpu") -> torch.nn.Module: Returns: torch.nn.Module: The model after being moved to cpu/gpu. - """ - if device != "cpu": - # DataParallel work only for cuda + device = torch.device(device) + model = model.to(device) + + # If target device is CUDA and more than one GPU is available, use DataParallel + if device.type == "cuda" and torch.cuda.device_count() > 1: model = torch.nn.DataParallel(model) - device = torch.device(device) - return model.to(device) + return model class ModelABC(ABC, nn.Module): @@ -71,6 +74,47 @@ def forward(self: ModelABC, *args: tuple[Any, ...], **kwargs: dict) -> None: """Torch method, this contains logic for using layers defined in init.""" ... # pragma: no cover + def to(self: ModelABC, device: str = "cpu") -> torch.nn.Module: + """Transfers model to cpu/gpu. + + Args: + model (torch.nn.Module): + PyTorch defined model. + device (str): + Transfers model to the specified device. Default is "cpu". + + Returns: + torch.nn.Module: + The model after being moved to cpu/gpu. + """ + device = torch.device(device) + model = super().to(device) + + # If target device is CUDA and more than one GPU is available, use DataParallel + if device.type == "cuda" and torch.cuda.device_count() > 1: + model = torch.nn.DataParallel(model) + + return model + + def load_weights_from_path(self: ModelABC, weights: str | Path) -> nn.Module: + """Helper function to load a torch model. + + Args: + weights (str or Path): + Path to pretrained weights. + + Returns: + torch.nn.Module: + Torch model with pretrained weights loaded on CPU. + + """ + # ! assume to be saved in single GPU mode + # always load on to the CPU + saved_state_dict = torch.load(weights, map_location="cpu") + + return self.load_state_dict(saved_state_dict, strict=True) + + @staticmethod @abstractmethod def infer_batch(model: nn.Module, batch_data: np.ndarray, *, device: str) -> None: From f53a75bf1fd386d5b577106f610d99d4ed6e9f73 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 26 Oct 2023 15:14:37 +0000 Subject: [PATCH 112/112] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tiatoolbox/models/models_abc.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tiatoolbox/models/models_abc.py b/tiatoolbox/models/models_abc.py index 6904b09ad..09aef76dc 100644 --- a/tiatoolbox/models/models_abc.py +++ b/tiatoolbox/models/models_abc.py @@ -13,7 +13,7 @@ import numpy as np -#Draft - will be moved into ModelABC as a class method +# Draft - will be moved into ModelABC as a class method def load_torch_model(model: nn.Module, weights: str | Path) -> nn.Module: """Helper function to load a torch model. @@ -35,7 +35,7 @@ def load_torch_model(model: nn.Module, weights: str | Path) -> nn.Module: return model -#Draft - will be moved into ModelABC as a class method +# Draft - will be moved into ModelABC as a class method def model_to(model: torch.nn.Module, device: str = "cpu") -> torch.nn.Module: """Transfers model to cpu/gpu. @@ -114,7 +114,6 @@ def load_weights_from_path(self: ModelABC, weights: str | Path) -> nn.Module: return self.load_state_dict(saved_state_dict, strict=True) - @staticmethod @abstractmethod def infer_batch(model: nn.Module, batch_data: np.ndarray, *, device: str) -> None: