diff --git a/.gitignore b/.gitignore index a192542d6..16ea54a83 100644 --- a/.gitignore +++ b/.gitignore @@ -115,3 +115,6 @@ ENV/ # vim/vi generated *.swp + +# output zarr generated +*.zarr diff --git a/tests/engines/_test_feature_extractor.py b/tests/engines/_test_feature_extractor.py deleted file mode 100644 index 3315cf0c3..000000000 --- a/tests/engines/_test_feature_extractor.py +++ /dev/null @@ -1,99 +0,0 @@ -"""Test for feature extractor.""" - -import shutil -from pathlib import Path -from typing import Callable - -import numpy as np -import torch - -from tiatoolbox.models import IOSegmentorConfig -from tiatoolbox.models.architecture.vanilla import CNNBackbone -from tiatoolbox.models.engine.semantic_segmentor import DeepFeatureExtractor -from tiatoolbox.utils import env_detection as toolbox_env -from tiatoolbox.wsicore.wsireader import WSIReader - -ON_GPU = not toolbox_env.running_on_ci() and toolbox_env.has_gpu() - -# ------------------------------------------------------------------------------------- -# Engine -# ------------------------------------------------------------------------------------- - - -def test_functional(remote_sample: Callable, tmp_path: Path) -> None: - """Test for feature extraction.""" - save_dir = tmp_path / "output" - # # convert to pathlib Path to prevent wsireader complaint - mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs")) - - # * test providing pretrained from torch vs pretrained_model.yaml - shutil.rmtree(save_dir, ignore_errors=True) # default output dir test - extractor = DeepFeatureExtractor(batch_size=1, pretrained_model="fcn-tissue_mask") - output_list = extractor.predict( - [mini_wsi_svs], - mode="wsi", - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=save_dir, - ) - wsi_0_root_path = output_list[0][1] - positions = np.load(f"{wsi_0_root_path}.position.npy") - features = np.load(f"{wsi_0_root_path}.features.0.npy") - assert len(features.shape) == 4 - - # * test same output between full infer and engine - # pre-emptive clean up - shutil.rmtree(save_dir, ignore_errors=True) # default output dir test - - ioconfig = IOSegmentorConfig( - input_resolutions=[ - {"units": "mpp", "resolution": 0.25}, - ], - output_resolutions=[ - {"units": "mpp", "resolution": 0.25}, - ], - patch_input_shape=[512, 512], - patch_output_shape=[512, 512], - stride_shape=[256, 256], - save_resolution={"units": "mpp", "resolution": 8.0}, - ) - - model = CNNBackbone("resnet50") - extractor = DeepFeatureExtractor(batch_size=4, model=model) - # should still run because we skip exception - output_list = extractor.predict( - [mini_wsi_svs], - mode="wsi", - ioconfig=ioconfig, - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=save_dir, - ) - wsi_0_root_path = output_list[0][1] - positions = np.load(f"{wsi_0_root_path}.position.npy") - features = np.load(f"{wsi_0_root_path}.features.0.npy") - - reader = WSIReader.open(mini_wsi_svs) - patches = [ - reader.read_bounds( - positions[patch_idx], - resolution=0.25, - units="mpp", - pad_constant_values=0, - coord_space="resolution", - ) - for patch_idx in range(4) - ] - patches = np.array(patches) - patches = torch.from_numpy(patches) # NHWC - patches = patches.permute(0, 3, 1, 2) # NCHW - patches = patches.type(torch.float32) - model = model.to("cpu") - # Inference mode - model.eval() - with torch.inference_mode(): - _features = model(patches).numpy() - # ! must maintain same batch size and likely same ordering - # ! else the output values will not exactly be the same (still < 1.0e-4 - # ! of epsilon though) - assert np.mean(np.abs(features[:4] - _features)) < 1.0e-1 diff --git a/tests/engines/_test_multi_task_segmentor.py b/tests/engines/_test_multi_task_segmentor.py deleted file mode 100644 index c3cc85cea..000000000 --- a/tests/engines/_test_multi_task_segmentor.py +++ /dev/null @@ -1,423 +0,0 @@ -"""Unit test package for HoVerNet+.""" - -import copy - -# ! The garbage collector -import gc -import multiprocessing -import shutil -from pathlib import Path -from typing import Callable - -import joblib -import numpy as np -import pytest - -from tiatoolbox.models import ( - IOInstanceSegmentorConfig, - MultiTaskSegmentor, - SemanticSegmentor, -) -from tiatoolbox.utils import env_detection as toolbox_env -from tiatoolbox.utils import imwrite -from tiatoolbox.utils.metrics import f1_detection - -ON_GPU = toolbox_env.has_gpu() -BATCH_SIZE = 1 if not ON_GPU else 8 # 16 -try: - NUM_POSTPROC_WORKERS = multiprocessing.cpu_count() -except NotImplementedError: - NUM_POSTPROC_WORKERS = 2 - -# ---------------------------------------------------- - - -def _crash_func(_: object) -> None: - """Helper to induce crash.""" - msg = "Propagation Crash." - raise ValueError(msg) - - -def semantic_postproc_func(raw_output: np.ndarray) -> np.ndarray: - """Function to post process semantic segmentations. - - Post processes semantic segmentation to form one map output. - - """ - return np.argmax(raw_output, axis=-1) - - -@pytest.mark.skipif( - toolbox_env.running_on_ci() or not ON_GPU, - reason="Local test on machine with GPU.", -) -def test_functionality_local(remote_sample: Callable, tmp_path: Path) -> None: - """Local functionality test for multi task segmentor.""" - gc.collect() - root_save_dir = Path(tmp_path) - mini_wsi_svs = Path(remote_sample("svs-1-small")) - save_dir = root_save_dir / "multitask" - shutil.rmtree(save_dir, ignore_errors=True) - - # * generate full output w/o parallel post-processing worker first - multi_segmentor = MultiTaskSegmentor( - pretrained_model="hovernetplus-oed", - batch_size=BATCH_SIZE, - num_postproc_workers=0, - ) - output = multi_segmentor.predict( - [mini_wsi_svs], - mode="wsi", - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=save_dir, - ) - - inst_dict_a = joblib.load(f"{output[0][1]}.0.dat") - - # * then test run when using workers, will then compare results - # * to ensure the predictions are the same - shutil.rmtree(save_dir, ignore_errors=True) - multi_segmentor = MultiTaskSegmentor( - pretrained_model="hovernetplus-oed", - batch_size=BATCH_SIZE, - num_postproc_workers=NUM_POSTPROC_WORKERS, - ) - assert multi_segmentor.num_postproc_workers == NUM_POSTPROC_WORKERS - output = multi_segmentor.predict( - [mini_wsi_svs], - mode="wsi", - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=save_dir, - ) - - inst_dict_b = joblib.load(f"{output[0][1]}.0.dat") - layer_map_b = np.load(f"{output[0][1]}.1.npy") - assert len(inst_dict_b) > 0, "Must have some nuclei" - assert layer_map_b is not None, "Must have some layers." - - inst_coords_a = np.array([v["centroid"] for v in inst_dict_a.values()]) - inst_coords_b = np.array([v["centroid"] for v in inst_dict_b.values()]) - score = f1_detection(inst_coords_b, inst_coords_a, radius=1.0) - assert score > 0.95, "Heavy loss of precision!" - - -def test_functionality_hovernetplus(remote_sample: Callable, tmp_path: Path) -> None: - """Functionality test for multitask segmentor.""" - root_save_dir = Path(tmp_path) - mini_wsi_svs = Path(remote_sample("wsi4_512_512_svs")) - required_dims = (258, 258) - # above image is 512 x 512 at 0.252 mpp resolution. This is 258 x 258 at 0.500 mpp. - - save_dir = f"{root_save_dir}/multi/" - shutil.rmtree(save_dir, ignore_errors=True) - - multi_segmentor = MultiTaskSegmentor( - pretrained_model="hovernetplus-oed", - batch_size=BATCH_SIZE, - num_postproc_workers=NUM_POSTPROC_WORKERS, - ) - output = multi_segmentor.predict( - [mini_wsi_svs], - mode="wsi", - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=save_dir, - ) - - inst_dict = joblib.load(f"{output[0][1]}.0.dat") - layer_map = np.load(f"{output[0][1]}.1.npy") - - assert len(inst_dict) > 0, "Must have some nuclei." - assert layer_map is not None, "Must have some layers." - assert ( - layer_map.shape == required_dims - ), "Output layer map dimensions must be same as the expected output shape" - - -def test_functionality_hovernet(remote_sample: Callable, tmp_path: Path) -> None: - """Functionality test for multitask segmentor.""" - root_save_dir = Path(tmp_path) - mini_wsi_svs = Path(remote_sample("wsi4_512_512_svs")) - - save_dir = root_save_dir / "multi" - shutil.rmtree(save_dir, ignore_errors=True) - - multi_segmentor = MultiTaskSegmentor( - pretrained_model="hovernet_fast-pannuke", - batch_size=BATCH_SIZE, - num_postproc_workers=NUM_POSTPROC_WORKERS, - ) - output = multi_segmentor.predict( - [mini_wsi_svs], - mode="wsi", - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=save_dir, - ) - - inst_dict = joblib.load(f"{output[0][1]}.0.dat") - - assert len(inst_dict) > 0, "Must have some nuclei." - - -def test_masked_segmentor(remote_sample: Callable, tmp_path: Path) -> None: - """Test segmentor when image is masked.""" - root_save_dir = Path(tmp_path) - sample_wsi_svs = Path(remote_sample("svs-1-small")) - sample_wsi_msk = remote_sample("small_svs_tissue_mask") - sample_wsi_msk = np.load(sample_wsi_msk).astype(np.uint8) - imwrite(f"{tmp_path}/small_svs_tissue_mask.jpg", sample_wsi_msk) - sample_wsi_msk = tmp_path.joinpath("small_svs_tissue_mask.jpg") - - save_dir = root_save_dir / "instance" - - # resolution for travis testing, not the correct ones - resolution = 4.0 - ioconfig = IOInstanceSegmentorConfig( - input_resolutions=[{"units": "mpp", "resolution": resolution}], - output_resolutions=[ - {"units": "mpp", "resolution": resolution}, - {"units": "mpp", "resolution": resolution}, - {"units": "mpp", "resolution": resolution}, - ], - margin=128, - tile_shape=(512, 512), - patch_input_shape=[256, 256], - patch_output_shape=[164, 164], - stride_shape=[164, 164], - ) - multi_segmentor = MultiTaskSegmentor( - batch_size=BATCH_SIZE, - num_postproc_workers=2, - pretrained_model="hovernet_fast-pannuke", - ) - - output = multi_segmentor.predict( - [sample_wsi_svs], - masks=[sample_wsi_msk], - mode="wsi", - ioconfig=ioconfig, - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=save_dir, - ) - - inst_dict = joblib.load(f"{output[0][1]}.0.dat") - - assert len(inst_dict) > 0, "Must have some nuclei." - - -def test_functionality_process_instance_predictions( - remote_sample: Callable, - tmp_path: Path, -) -> None: - """Test the functionality of instance predictions processing.""" - root_save_dir = Path(tmp_path) - mini_wsi_svs = Path(remote_sample("wsi4_512_512_svs")) - - save_dir = root_save_dir / "semantic" - shutil.rmtree(save_dir, ignore_errors=True) - - semantic_segmentor = SemanticSegmentor( - pretrained_model="hovernetplus-oed", - batch_size=BATCH_SIZE, - num_postproc_workers=0, - ) - multi_segmentor = MultiTaskSegmentor( - pretrained_model="hovernetplus-oed", - batch_size=BATCH_SIZE, - num_postproc_workers=0, - ) - - output = semantic_segmentor.predict( - [mini_wsi_svs], - mode="wsi", - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=save_dir, - ) - raw_maps = [np.load(f"{output[0][1]}.raw.{head_idx}.npy") for head_idx in range(4)] - - dummy_reference = [{i: {"box": np.array([0, 0, 32, 32])} for i in range(1000)}] - - dummy_tiles = [np.zeros((512, 512))] - dummy_bounds = np.array([0, 0, 512, 512]) - - multi_segmentor.wsi_layers = [np.zeros_like(raw_maps[0][..., 0])] - multi_segmentor._wsi_inst_info = copy.deepcopy(dummy_reference) - multi_segmentor._futures = [ - [dummy_reference, [dummy_reference[0].keys()], dummy_tiles, dummy_bounds], - ] - multi_segmentor._merge_post_process_results() - assert len(multi_segmentor._wsi_inst_info[0]) == 0 - - -def test_empty_image(tmp_path: Path) -> None: - """Test MultiTaskSegmentor for an empty image.""" - root_save_dir = Path(tmp_path) - sample_patch = np.ones((256, 256, 3), dtype="uint8") * 255 - sample_patch_path = root_save_dir / "sample_tile.png" - imwrite(sample_patch_path, sample_patch) - - save_dir = root_save_dir / "hovernetplus" - - multi_segmentor = MultiTaskSegmentor( - pretrained_model="hovernetplus-oed", - batch_size=BATCH_SIZE, - num_postproc_workers=0, - ) - - _ = multi_segmentor.predict( - [sample_patch_path], - mode="tile", - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=save_dir, - ) - - save_dir = root_save_dir / "hovernet" - - multi_segmentor = MultiTaskSegmentor( - pretrained_model="hovernet_fast-pannuke", - batch_size=BATCH_SIZE, - num_postproc_workers=0, - ) - - _ = multi_segmentor.predict( - [sample_patch_path], - mode="tile", - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=save_dir, - ) - - save_dir = root_save_dir / "semantic" - - multi_segmentor = MultiTaskSegmentor( - pretrained_model="fcn_resnet50_unet-bcss", - batch_size=BATCH_SIZE, - num_postproc_workers=0, - output_types=["semantic"], - ) - - bcc_wsi_ioconfig = IOInstanceSegmentorConfig( - input_resolutions=[{"units": "mpp", "resolution": 0.25}], - output_resolutions=[{"units": "mpp", "resolution": 0.25}], - tile_shape=(2048, 2048), - patch_input_shape=[1024, 1024], - patch_output_shape=[512, 512], - stride_shape=[512, 512], - margin=128, - save_resolution={"units": "mpp", "resolution": 2}, - ) - - _ = multi_segmentor.predict( - [sample_patch_path], - mode="tile", - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=save_dir, - ioconfig=bcc_wsi_ioconfig, - ) - - -def test_functionality_semantic(remote_sample: Callable, tmp_path: Path) -> None: - """Functionality test for multitask segmentor.""" - root_save_dir = Path(tmp_path) - - save_dir = root_save_dir / "multi" - shutil.rmtree(save_dir, ignore_errors=True) - with pytest.raises( - ValueError, - match=r"Output type must be specified for instance or semantic segmentation.", - ): - MultiTaskSegmentor( - pretrained_model="fcn_resnet50_unet-bcss", - batch_size=BATCH_SIZE, - num_postproc_workers=NUM_POSTPROC_WORKERS, - ) - - mini_wsi_svs = Path(remote_sample("wsi4_512_512_svs")) - save_dir = f"{root_save_dir}/multi/" - - multi_segmentor = MultiTaskSegmentor( - pretrained_model="fcn_resnet50_unet-bcss", - batch_size=BATCH_SIZE, - num_postproc_workers=NUM_POSTPROC_WORKERS, - output_types=["semantic"], - ) - - bcc_wsi_ioconfig = IOInstanceSegmentorConfig( - input_resolutions=[{"units": "mpp", "resolution": 0.25}], - output_resolutions=[{"units": "mpp", "resolution": 0.25}], - tile_shape=2048, - patch_input_shape=[1024, 1024], - patch_output_shape=[512, 512], - stride_shape=[512, 512], - margin=128, - save_resolution={"units": "mpp", "resolution": 2}, - ) - - multi_segmentor.model.postproc_func = semantic_postproc_func - - output = multi_segmentor.predict( - [mini_wsi_svs], - mode="wsi", - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=save_dir, - ioconfig=bcc_wsi_ioconfig, - ) - - layer_map = np.load(f"{output[0][1]}.0.npy") - - assert layer_map is not None, "Must have some segmentations." - - -def test_crash_segmentor(remote_sample: Callable, tmp_path: Path) -> None: - """Test engine crash when given malformed input.""" - root_save_dir = Path(tmp_path) - sample_wsi_svs = Path(remote_sample("svs-1-small")) - sample_wsi_msk = remote_sample("small_svs_tissue_mask") - sample_wsi_msk = np.load(sample_wsi_msk).astype(np.uint8) - imwrite(f"{tmp_path}/small_svs_tissue_mask.jpg", sample_wsi_msk) - sample_wsi_msk = tmp_path.joinpath("small_svs_tissue_mask.jpg") - - save_dir = f"{root_save_dir}/multi/" - - # resolution for travis testing, not the correct ones - resolution = 4.0 - ioconfig = IOInstanceSegmentorConfig( - input_resolutions=[{"units": "mpp", "resolution": resolution}], - output_resolutions=[ - {"units": "mpp", "resolution": resolution}, - {"units": "mpp", "resolution": resolution}, - {"units": "mpp", "resolution": resolution}, - ], - margin=128, - tile_shape=[512, 512], - patch_input_shape=[256, 256], - patch_output_shape=[164, 164], - stride_shape=[164, 164], - ) - multi_segmentor = MultiTaskSegmentor( - batch_size=BATCH_SIZE, - num_postproc_workers=2, - pretrained_model="hovernetplus-oed", - ) - - # * Test crash propagation when parallelize post-processing - shutil.rmtree(save_dir, ignore_errors=True) - multi_segmentor.model.postproc_func = _crash_func - with pytest.raises(ValueError, match=r"Crash."): - multi_segmentor.predict( - [sample_wsi_svs], - masks=[sample_wsi_msk], - mode="wsi", - ioconfig=ioconfig, - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=save_dir, - ) diff --git a/tests/engines/_test_nucleus_instance_segmentor.py b/tests/engines/_test_nucleus_instance_segmentor.py deleted file mode 100644 index 6d3ea2f67..000000000 --- a/tests/engines/_test_nucleus_instance_segmentor.py +++ /dev/null @@ -1,596 +0,0 @@ -"""Test for Nucleus Instance Segmentor.""" - -import copy - -# ! The garbage collector -import gc -import shutil -from pathlib import Path -from typing import Callable - -import joblib -import numpy as np -import pytest -import yaml -from click.testing import CliRunner - -from tiatoolbox import cli -from tiatoolbox.models import ( - IOInstanceSegmentorConfig, - NucleusInstanceSegmentor, - SemanticSegmentor, -) -from tiatoolbox.models.architecture import fetch_pretrained_weights -from tiatoolbox.models.engine.nucleus_instance_segmentor import ( - _process_tile_predictions, -) -from tiatoolbox.utils import env_detection as toolbox_env -from tiatoolbox.utils import imwrite -from tiatoolbox.utils.metrics import f1_detection -from tiatoolbox.wsicore.wsireader import WSIReader - -ON_GPU = toolbox_env.has_gpu() -# The value is based on 2 TitanXP each with 12GB -BATCH_SIZE = 1 if not ON_GPU else 16 - -# ---------------------------------------------------- - - -def _crash_func(_x: object) -> None: - """Helper to induce crash.""" - msg = "Propagation Crash." - raise ValueError(msg) - - -def helper_tile_info() -> list: - """Helper function for tile information.""" - predictor = NucleusInstanceSegmentor(model="A") - # ! assuming the tiles organized as follows (coming out from - # ! PatchExtractor). If this is broken, need to check back - # ! PatchExtractor output ordering first - # left to right, top to bottom - # --------------------- - # | 0 | 1 | 2 | 3 | - # --------------------- - # | 4 | 5 | 6 | 7 | - # --------------------- - # | 8 | 9 | 10 | 11 | - # --------------------- - # | 12 | 13 | 14 | 15 | - # --------------------- - # ! assume flag index ordering: left right top bottom - ioconfig = IOInstanceSegmentorConfig( - input_resolutions=[{"units": "mpp", "resolution": 0.25}], - output_resolutions=[ - {"units": "mpp", "resolution": 0.25}, - {"units": "mpp", "resolution": 0.25}, - {"units": "mpp", "resolution": 0.25}, - ], - margin=1, - tile_shape=(4, 4), - stride_shape=[4, 4], - patch_input_shape=[4, 4], - patch_output_shape=[4, 4], - ) - - return predictor._get_tile_info([16, 16], ioconfig) - - -# ---------------------------------------------------- - - -def test_get_tile_info() -> None: - """Test for getting tile info.""" - info = helper_tile_info() - _, flag = info[0] # index 0 should be full grid, removal - # removal flag at top edges - assert ( - np.sum( - np.nonzero(flag[:, 0]) - != np.array([4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]), - ) - == 0 - ), "Fail Top" - # removal flag at bottom edges - assert ( - np.sum( - np.nonzero(flag[:, 1]) != np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]), - ) - == 0 - ), "Fail Bottom" - # removal flag at left edges - assert ( - np.sum( - np.nonzero(flag[:, 2]) - != np.array([1, 2, 3, 5, 6, 7, 9, 10, 11, 13, 14, 15]), - ) - == 0 - ), "Fail Left" - # removal flag at right edges - assert ( - np.sum( - np.nonzero(flag[:, 3]) - != np.array([0, 1, 2, 4, 5, 6, 8, 9, 10, 12, 13, 14]), - ) - == 0 - ), "Fail Right" - - -def test_vertical_boundary_boxes() -> None: - """Test for vertical boundary boxes.""" - info = helper_tile_info() - _boxes = np.array( - [ - [3, 0, 5, 4], - [7, 0, 9, 4], - [11, 0, 13, 4], - [3, 4, 5, 8], - [7, 4, 9, 8], - [11, 4, 13, 8], - [3, 8, 5, 12], - [7, 8, 9, 12], - [11, 8, 13, 12], - [3, 12, 5, 16], - [7, 12, 9, 16], - [11, 12, 13, 16], - ], - ) - _flag = np.array( - [ - [0, 1, 0, 0], - [0, 1, 0, 0], - [0, 1, 0, 0], - [1, 1, 0, 0], - [1, 1, 0, 0], - [1, 1, 0, 0], - [1, 1, 0, 0], - [1, 1, 0, 0], - [1, 1, 0, 0], - [1, 0, 0, 0], - [1, 0, 0, 0], - [1, 0, 0, 0], - ], - ) - boxes, flag = info[1] - assert np.sum(_boxes - boxes) == 0, "Wrong Vertical Bounds" - assert np.sum(flag - _flag) == 0, "Fail Vertical Flag" - - -def test_horizontal_boundary_boxes() -> None: - """Test for horizontal boundary boxes.""" - info = helper_tile_info() - _boxes = np.array( - [ - [0, 3, 4, 5], - [4, 3, 8, 5], - [8, 3, 12, 5], - [12, 3, 16, 5], - [0, 7, 4, 9], - [4, 7, 8, 9], - [8, 7, 12, 9], - [12, 7, 16, 9], - [0, 11, 4, 13], - [4, 11, 8, 13], - [8, 11, 12, 13], - [12, 11, 16, 13], - ], - ) - _flag = np.array( - [ - [0, 0, 0, 1], - [0, 0, 1, 1], - [0, 0, 1, 1], - [0, 0, 1, 0], - [0, 0, 0, 1], - [0, 0, 1, 1], - [0, 0, 1, 1], - [0, 0, 1, 0], - [0, 0, 0, 1], - [0, 0, 1, 1], - [0, 0, 1, 1], - [0, 0, 1, 0], - ], - ) - boxes, flag = info[2] - assert np.sum(_boxes - boxes) == 0, "Wrong Horizontal Bounds" - assert np.sum(flag - _flag) == 0, "Fail Horizontal Flag" - - -def test_cross_section_boundary_boxes() -> None: - """Test for cross-section boundary boxes.""" - info = helper_tile_info() - _boxes = np.array( - [ - [2, 2, 6, 6], - [6, 2, 10, 6], - [10, 2, 14, 6], - [2, 6, 6, 10], - [6, 6, 10, 10], - [10, 6, 14, 10], - [2, 10, 6, 14], - [6, 10, 10, 14], - [10, 10, 14, 14], - ], - ) - _flag = np.array( - [ - [1, 1, 1, 1], - [1, 1, 1, 1], - [1, 1, 1, 1], - [1, 1, 1, 1], - [1, 1, 1, 1], - [1, 1, 1, 1], - [1, 1, 1, 1], - [1, 1, 1, 1], - [1, 1, 1, 1], - ], - ) - boxes, flag = info[3] - assert np.sum(boxes - _boxes) == 0, "Wrong Cross Section Bounds" - assert np.sum(flag - _flag) == 0, "Fail Cross Section Flag" - - -def test_crash_segmentor(remote_sample: Callable, tmp_path: Path) -> None: - """Test engine crash when given malformed input.""" - root_save_dir = Path(tmp_path) - sample_wsi_svs = Path(remote_sample("svs-1-small")) - sample_wsi_msk = remote_sample("small_svs_tissue_mask") - sample_wsi_msk = np.load(sample_wsi_msk).astype(np.uint8) - imwrite(f"{tmp_path}/small_svs_tissue_mask.jpg", sample_wsi_msk) - sample_wsi_msk = tmp_path.joinpath("small_svs_tissue_mask.jpg") - - save_dir = f"{root_save_dir}/instance/" - - # resolution for travis testing, not the correct ones - resolution = 4.0 - ioconfig = IOInstanceSegmentorConfig( - input_resolutions=[{"units": "mpp", "resolution": resolution}], - output_resolutions=[ - {"units": "mpp", "resolution": resolution}, - {"units": "mpp", "resolution": resolution}, - {"units": "mpp", "resolution": resolution}, - ], - margin=128, - tile_shape=(512, 512), - patch_input_shape=[256, 256], - patch_output_shape=[164, 164], - stride_shape=[164, 164], - ) - instance_segmentor = NucleusInstanceSegmentor( - batch_size=BATCH_SIZE, - num_postproc_workers=2, - pretrained_model="hovernet_fast-pannuke", - ) - - # * Test crash propagation when parallelize post-processing - shutil.rmtree("output", ignore_errors=True) - shutil.rmtree(save_dir, ignore_errors=True) - instance_segmentor.model.postproc_func = _crash_func - with pytest.raises(ValueError, match=r"Propagation Crash."): - instance_segmentor.predict( - [sample_wsi_svs], - masks=[sample_wsi_msk], - mode="wsi", - ioconfig=ioconfig, - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=save_dir, - ) - - -def test_functionality_ci(remote_sample: Callable, tmp_path: Path) -> None: - """Functionality test for nuclei instance segmentor.""" - gc.collect() - root_save_dir = Path(tmp_path) - mini_wsi_svs = Path(remote_sample("wsi4_512_512_svs")) - - resolution = 2.0 - - reader = WSIReader.open(mini_wsi_svs) - thumb = reader.slide_thumbnail(resolution=resolution, units="mpp") - mini_wsi_jpg = f"{tmp_path}/mini_svs.jpg" - imwrite(mini_wsi_jpg, thumb) - - save_dir = f"{root_save_dir}/instance/" - - # * test run on wsi, test run with worker - # resolution for travis testing, not the correct ones - ioconfig = IOInstanceSegmentorConfig( - input_resolutions=[{"units": "mpp", "resolution": resolution}], - output_resolutions=[ - {"units": "mpp", "resolution": resolution}, - {"units": "mpp", "resolution": resolution}, - ], - margin=128, - tile_shape=(1024, 1024), - patch_input_shape=[256, 256], - patch_output_shape=[164, 164], - stride_shape=[164, 164], - ) - - shutil.rmtree(save_dir, ignore_errors=True) - - inst_segmentor = NucleusInstanceSegmentor( - batch_size=1, - num_loader_workers=0, - num_postproc_workers=0, - pretrained_model="hovernet_fast-pannuke", - ) - inst_segmentor.predict( - [mini_wsi_svs], - mode="wsi", - ioconfig=ioconfig, - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=save_dir, - ) - - -def test_functionality_merge_tile_predictions_ci( - remote_sample: Callable, - tmp_path: Path, -) -> None: - """Functional tests for merging tile predictions.""" - gc.collect() # Force clean up everything on hold - save_dir = Path(f"{tmp_path}/output") - mini_wsi_svs = Path(remote_sample("wsi4_512_512_svs")) - - resolution = 0.5 - ioconfig = IOInstanceSegmentorConfig( - input_resolutions=[{"units": "mpp", "resolution": resolution}], - output_resolutions=[ - {"units": "mpp", "resolution": resolution}, - {"units": "mpp", "resolution": resolution}, - {"units": "mpp", "resolution": resolution}, - ], - margin=128, - tile_shape=(512, 512), - patch_input_shape=[256, 256], - patch_output_shape=[164, 164], - stride_shape=[164, 164], - ) - - # mainly to hook the merge prediction function - inst_segmentor = NucleusInstanceSegmentor( - batch_size=BATCH_SIZE, - num_postproc_workers=0, - pretrained_model="hovernet_fast-pannuke", - ) - - shutil.rmtree(save_dir, ignore_errors=True) - semantic_segmentor = SemanticSegmentor( - pretrained_model="hovernet_fast-pannuke", - batch_size=BATCH_SIZE, - num_postproc_workers=0, - ) - - output = semantic_segmentor.predict( - [mini_wsi_svs], - mode="wsi", - on_gpu=ON_GPU, - ioconfig=ioconfig, - crash_on_exception=True, - save_dir=save_dir, - ) - raw_maps = [np.load(f"{output[0][1]}.raw.{head_idx}.npy") for head_idx in range(3)] - raw_maps = [[v] for v in raw_maps] # mask it as patch output - - dummy_reference = {i: {"box": np.array([0, 0, 32, 32])} for i in range(1000)} - dummy_flag_mode_list = [ - [[1, 1, 0, 0], 1], - [[0, 0, 1, 1], 2], - [[1, 1, 1, 1], 3], - [[0, 0, 0, 0], 0], - ] - - inst_segmentor._wsi_inst_info = copy.deepcopy(dummy_reference) - inst_segmentor._futures = [[dummy_reference, dummy_reference.keys()]] - inst_segmentor._merge_post_process_results() - assert len(inst_segmentor._wsi_inst_info) == 0 - - blank_raw_maps = [np.zeros_like(v) for v in raw_maps] - _process_tile_predictions( - ioconfig=ioconfig, - tile_bounds=np.array([0, 0, 512, 512]), - tile_flag=dummy_flag_mode_list[0][0], - tile_mode=dummy_flag_mode_list[0][1], - tile_output=[[np.array([0, 0, 512, 512]), blank_raw_maps]], - ref_inst_dict=dummy_reference, - postproc=semantic_segmentor.model.postproc_func, - merge_predictions=semantic_segmentor.merge_prediction, - ) - - for tile_flag, tile_mode in dummy_flag_mode_list: - _process_tile_predictions( - ioconfig=ioconfig, - tile_bounds=np.array([0, 0, 512, 512]), - tile_flag=tile_flag, - tile_mode=tile_mode, - tile_output=[[np.array([0, 0, 512, 512]), raw_maps]], - ref_inst_dict=dummy_reference, - postproc=semantic_segmentor.model.postproc_func, - merge_predictions=semantic_segmentor.merge_prediction, - ) - - # test exception flag - tile_flag = [0, 0, 0, 0] - with pytest.raises(ValueError, match=r".*Unknown tile mode.*"): - _process_tile_predictions( - ioconfig=ioconfig, - tile_bounds=np.array([0, 0, 512, 512]), - tile_flag=tile_flag, - tile_mode=-1, - tile_output=[[np.array([0, 0, 512, 512]), raw_maps]], - ref_inst_dict=dummy_reference, - postproc=semantic_segmentor.model.postproc_func, - merge_predictions=semantic_segmentor.merge_prediction, - ) - - -@pytest.mark.skipif( - toolbox_env.running_on_ci() or not ON_GPU, - reason="Local test on machine with GPU.", -) -def test_functionality_local(remote_sample: Callable, tmp_path: Path) -> None: - """Local functionality test for nuclei instance segmentor.""" - root_save_dir = Path(tmp_path) - save_dir = Path(f"{tmp_path}/output") - mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs")) - - # * generate full output w/o parallel post-processing worker first - shutil.rmtree(save_dir, ignore_errors=True) - inst_segmentor = NucleusInstanceSegmentor( - batch_size=8, - num_postproc_workers=0, - pretrained_model="hovernet_fast-pannuke", - ) - output = inst_segmentor.predict( - [mini_wsi_svs], - mode="wsi", - on_gpu=True, - crash_on_exception=True, - save_dir=save_dir, - ) - inst_dict_a = joblib.load(f"{output[0][1]}.dat") - - # * then test run when using workers, will then compare results - # * to ensure the predictions are the same - shutil.rmtree(save_dir, ignore_errors=True) - inst_segmentor = NucleusInstanceSegmentor( - pretrained_model="hovernet_fast-pannuke", - batch_size=BATCH_SIZE, - num_postproc_workers=2, - ) - assert inst_segmentor.num_postproc_workers == 2 - output = inst_segmentor.predict( - [mini_wsi_svs], - mode="wsi", - on_gpu=True, - crash_on_exception=True, - save_dir=save_dir, - ) - inst_dict_b = joblib.load(f"{output[0][1]}.dat") - inst_coords_a = np.array([v["centroid"] for v in inst_dict_a.values()]) - inst_coords_b = np.array([v["centroid"] for v in inst_dict_b.values()]) - score = f1_detection(inst_coords_b, inst_coords_a, radius=1.0) - assert score > 0.95, "Heavy loss of precision!" - - # ** - # To evaluate the precision of doing post-processing on tile - # then re-assemble without using full image prediction maps, - # we compare its output with the output when doing - # post-processing on the entire images. - save_dir = root_save_dir / "semantic" - shutil.rmtree(save_dir, ignore_errors=True) - semantic_segmentor = SemanticSegmentor( - pretrained_model="hovernet_fast-pannuke", - batch_size=BATCH_SIZE, - num_postproc_workers=2, - ) - output = semantic_segmentor.predict( - [mini_wsi_svs], - mode="wsi", - on_gpu=True, - crash_on_exception=True, - save_dir=save_dir, - ) - raw_maps = [np.load(f"{output[0][1]}.raw.{head_idx}.npy") for head_idx in range(3)] - _, inst_dict_b = semantic_segmentor.model.postproc(raw_maps) - - inst_coords_a = np.array([v["centroid"] for v in inst_dict_a.values()]) - inst_coords_b = np.array([v["centroid"] for v in inst_dict_b.values()]) - score = f1_detection(inst_coords_b, inst_coords_a, radius=1.0) - assert score > 0.9, "Heavy loss of precision!" - - -def test_cli_nucleus_instance_segment_ioconfig( - remote_sample: Callable, - tmp_path: Path, -) -> None: - """Test for nucleus segmentation with IOConfig.""" - mini_wsi_svs = Path(remote_sample("wsi4_512_512_svs")) - output_path = tmp_path / "output" - - resolution = 2.0 - - reader = WSIReader.open(mini_wsi_svs) - thumb = reader.slide_thumbnail(resolution=resolution, units="mpp") - mini_wsi_jpg = f"{tmp_path}/mini_svs.jpg" - imwrite(mini_wsi_jpg, thumb) - - pretrained_weights = fetch_pretrained_weights("hovernet_fast-pannuke") - - # resolution for travis testing, not the correct ones - config = { - "input_resolutions": [{"units": "mpp", "resolution": resolution}], - "output_resolutions": [ - {"units": "mpp", "resolution": resolution}, - {"units": "mpp", "resolution": resolution}, - {"units": "mpp", "resolution": resolution}, - ], - "margin": 128, - "tile_shape": [512, 512], - "patch_input_shape": [256, 256], - "patch_output_shape": [164, 164], - "stride_shape": [164, 164], - "save_resolution": {"units": "mpp", "resolution": 8.0}, - } - - with Path.open(tmp_path / "config.yaml", "w") as fptr: - yaml.dump(config, fptr) - - runner = CliRunner() - nucleus_instance_segment_result = runner.invoke( - cli.main, - [ - "nucleus-instance-segment", - "--img-input", - str(mini_wsi_jpg), - "--pretrained-weights", - str(pretrained_weights), - "--num-loader-workers", - str(0), - "--num-postproc-workers", - str(0), - "--mode", - "tile", - "--output-path", - str(output_path), - "--yaml-config-path", - str(tmp_path.joinpath("config.yaml")), - ], - ) - - assert nucleus_instance_segment_result.exit_code == 0 - assert output_path.joinpath("0.dat").exists() - assert output_path.joinpath("file_map.dat").exists() - assert output_path.joinpath("results.json").exists() - - -def test_cli_nucleus_instance_segment(remote_sample: Callable, tmp_path: Path) -> None: - """Test for nucleus segmentation.""" - mini_wsi_svs = Path(remote_sample("wsi4_512_512_svs")) - output_path = tmp_path / "output" - - runner = CliRunner() - nucleus_instance_segment_result = runner.invoke( - cli.main, - [ - "nucleus-instance-segment", - "--img-input", - str(mini_wsi_svs), - "--mode", - "wsi", - "--num-loader-workers", - str(0), - "--num-postproc-workers", - str(0), - "--output-path", - str(output_path), - ], - ) - - assert nucleus_instance_segment_result.exit_code == 0 - assert output_path.joinpath("0.dat").exists() - assert output_path.joinpath("file_map.dat").exists() - assert output_path.joinpath("results.json").exists() diff --git a/tests/engines/_test_patch_predictor.py b/tests/engines/_test_patch_predictor.py deleted file mode 100644 index b3322635d..000000000 --- a/tests/engines/_test_patch_predictor.py +++ /dev/null @@ -1,763 +0,0 @@ -"""Test for Patch Predictor.""" -from __future__ import annotations - -import copy -import shutil -from pathlib import Path -from typing import Callable - -import numpy as np -import pytest -from click.testing import CliRunner - -from tiatoolbox import cli -from tiatoolbox.models import IOPatchPredictorConfig, PatchPredictor -from tiatoolbox.models.architecture.vanilla import CNNModel -from tiatoolbox.utils import download_data, imwrite -from tiatoolbox.utils import env_detection as toolbox_env - -ON_GPU = toolbox_env.has_gpu() - -# ------------------------------------------------------------------------------------- -# Engine -# ------------------------------------------------------------------------------------- - - -def test_predictor_crash(tmp_path: Path) -> None: - """Test for crash when making predictor.""" - # without providing any model - with pytest.raises(ValueError, match=r"Must provide.*"): - PatchPredictor() - - # provide wrong unknown pretrained model - with pytest.raises(ValueError, match=r"Pretrained .* does not exist"): - PatchPredictor(pretrained_model="secret_model-kather100k") - - # provide wrong model of unknown type, deprecated later with type hint - with pytest.raises(TypeError, match=r".*must be a string.*"): - PatchPredictor(pretrained_model=123) - - # test predict crash - predictor = PatchPredictor(pretrained_model="resnet18-kather100k", batch_size=32) - - with pytest.raises(ValueError, match=r".*not a valid mode.*"): - predictor.predict("aaa", mode="random", save_dir=tmp_path) - # remove previously generated data - shutil.rmtree(tmp_path / "output", ignore_errors=True) - with pytest.raises(TypeError, match=r".*must be a list of file paths.*"): - predictor.predict("aaa", mode="wsi", save_dir=tmp_path) - # remove previously generated data - shutil.rmtree(tmp_path / "output", ignore_errors=True) - with pytest.raises(ValueError, match=r".*masks.*!=.*imgs.*"): - predictor.predict([1, 2, 3], masks=[1, 2], mode="wsi", save_dir=tmp_path) - with pytest.raises(ValueError, match=r".*labels.*!=.*imgs.*"): - predictor.predict([1, 2, 3], labels=[1, 2], mode="patch", save_dir=tmp_path) - # remove previously generated data - shutil.rmtree(tmp_path / "output", ignore_errors=True) - - -def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None: - """Test for delegating args to io config.""" - mini_wsi_svs = Path(remote_sample("wsi2_4k_4k_svs")) - - # test not providing config / full input info for not pretrained models - model = CNNModel("resnet50") - predictor = PatchPredictor(model=model) - with pytest.raises(ValueError, match=r".*Must provide.*`ioconfig`.*"): - predictor.predict([mini_wsi_svs], mode="wsi", save_dir=tmp_path / "dump") - shutil.rmtree(tmp_path / "dump", ignore_errors=True) - - kwargs = { - "patch_input_shape": [512, 512], - "resolution": 1.75, - "units": "mpp", - } - for key in kwargs: - _kwargs = copy.deepcopy(kwargs) - _kwargs.pop(key) - with pytest.raises(ValueError, match=r".*Must provide.*`ioconfig`.*"): - predictor.predict( - [mini_wsi_svs], - mode="wsi", - save_dir=f"{tmp_path}/dump", - on_gpu=ON_GPU, - **_kwargs, - ) - shutil.rmtree(tmp_path / "dump", ignore_errors=True) - - # test providing config / full input info for not pretrained models - ioconfig = IOPatchPredictorConfig( - patch_input_shape=(512, 512), - stride_shape=(256, 256), - input_resolutions=[{"resolution": 1.35, "units": "mpp"}], - output_resolutions=[], - ) - predictor.predict( - [mini_wsi_svs], - ioconfig=ioconfig, - mode="wsi", - save_dir=f"{tmp_path}/dump", - on_gpu=ON_GPU, - ) - shutil.rmtree(tmp_path / "dump", ignore_errors=True) - - predictor.predict( - [mini_wsi_svs], - mode="wsi", - save_dir=f"{tmp_path}/dump", - on_gpu=ON_GPU, - **kwargs, - ) - shutil.rmtree(tmp_path / "dump", ignore_errors=True) - - # test overwriting pretrained ioconfig - predictor = PatchPredictor(pretrained_model="resnet18-kather100k", batch_size=1) - predictor.predict( - [mini_wsi_svs], - patch_input_shape=(300, 300), - mode="wsi", - on_gpu=ON_GPU, - save_dir=f"{tmp_path}/dump", - ) - assert predictor._ioconfig.patch_input_shape == (300, 300) - shutil.rmtree(tmp_path / "dump", ignore_errors=True) - - predictor.predict( - [mini_wsi_svs], - stride_shape=(300, 300), - mode="wsi", - on_gpu=ON_GPU, - save_dir=f"{tmp_path}/dump", - ) - assert predictor._ioconfig.stride_shape == (300, 300) - shutil.rmtree(tmp_path / "dump", ignore_errors=True) - - predictor.predict( - [mini_wsi_svs], - resolution=1.99, - mode="wsi", - on_gpu=ON_GPU, - save_dir=f"{tmp_path}/dump", - ) - assert predictor._ioconfig.input_resolutions[0]["resolution"] == 1.99 - shutil.rmtree(tmp_path / "dump", ignore_errors=True) - - predictor.predict( - [mini_wsi_svs], - units="baseline", - mode="wsi", - on_gpu=ON_GPU, - save_dir=f"{tmp_path}/dump", - ) - assert predictor._ioconfig.input_resolutions[0]["units"] == "baseline" - shutil.rmtree(tmp_path / "dump", ignore_errors=True) - - predictor = PatchPredictor(pretrained_model="resnet18-kather100k") - predictor.predict( - [mini_wsi_svs], - mode="wsi", - merge_predictions=True, - save_dir=f"{tmp_path}/dump", - on_gpu=ON_GPU, - ) - shutil.rmtree(tmp_path / "dump", ignore_errors=True) - - -def test_patch_predictor_api( - sample_patch1: Path, - sample_patch2: Path, - tmp_path: Path, -) -> None: - """Helper function to get the model output using API 1.""" - save_dir_path = tmp_path - - # convert to pathlib Path to prevent reader complaint - inputs = [Path(sample_patch1), Path(sample_patch2)] - predictor = PatchPredictor(pretrained_model="resnet18-kather100k", batch_size=1) - # don't run test on GPU - output = predictor.predict( - inputs, - on_gpu=ON_GPU, - save_dir=save_dir_path, - ) - assert sorted(output.keys()) == ["predictions"] - assert len(output["predictions"]) == 2 - shutil.rmtree(save_dir_path, ignore_errors=True) - - output = predictor.predict( - inputs, - labels=[1, "a"], - return_labels=True, - on_gpu=ON_GPU, - save_dir=save_dir_path, - ) - assert sorted(output.keys()) == sorted(["labels", "predictions"]) - assert len(output["predictions"]) == len(output["labels"]) - assert output["labels"] == [1, "a"] - shutil.rmtree(save_dir_path, ignore_errors=True) - - output = predictor.predict( - inputs, - return_probabilities=True, - on_gpu=ON_GPU, - save_dir=save_dir_path, - ) - assert sorted(output.keys()) == sorted(["predictions", "probabilities"]) - assert len(output["predictions"]) == len(output["probabilities"]) - shutil.rmtree(save_dir_path, ignore_errors=True) - - output = predictor.predict( - inputs, - return_probabilities=True, - labels=[1, "a"], - return_labels=True, - on_gpu=ON_GPU, - save_dir=save_dir_path, - ) - assert sorted(output.keys()) == sorted(["labels", "predictions", "probabilities"]) - assert len(output["predictions"]) == len(output["labels"]) - assert len(output["predictions"]) == len(output["probabilities"]) - - # test saving output, should have no effect - _ = predictor.predict( - inputs, - on_gpu=ON_GPU, - save_dir="special_dir_not_exist", - ) - assert not Path.is_dir(Path("special_dir_not_exist")) - - # test loading user weight - pretrained_weights_url = ( - "https://tiatoolbox.dcs.warwick.ac.uk/models/pc/resnet18-kather100k.pth" - ) - - # remove prev generated data - shutil.rmtree(save_dir_path, ignore_errors=True) - save_dir_path.mkdir(parents=True) - pretrained_weights = ( - save_dir_path / "tmp_pretrained_weigths" / "resnet18-kather100k.pth" - ) - - download_data(pretrained_weights_url, pretrained_weights) - - _ = PatchPredictor( - pretrained_model="resnet18-kather100k", - pretrained_weights=pretrained_weights, - batch_size=1, - ) - - # --- test different using user model - model = CNNModel(backbone="resnet18", num_classes=9) - # test prediction - predictor = PatchPredictor(model=model, batch_size=1, verbose=False) - output = predictor.predict( - inputs, - return_probabilities=True, - labels=[1, "a"], - return_labels=True, - on_gpu=ON_GPU, - save_dir=save_dir_path, - ) - assert sorted(output.keys()) == sorted(["labels", "predictions", "probabilities"]) - assert len(output["predictions"]) == len(output["labels"]) - assert len(output["predictions"]) == len(output["probabilities"]) - - -def test_wsi_predictor_api( - sample_wsi_dict: dict, - tmp_path: Path, - chdir: Callable, -) -> None: - """Test normal run of wsi predictor.""" - save_dir_path = tmp_path - - # convert to pathlib Path to prevent wsireader complaint - mini_wsi_svs = Path(sample_wsi_dict["wsi2_4k_4k_svs"]) - mini_wsi_jpg = Path(sample_wsi_dict["wsi2_4k_4k_jpg"]) - mini_wsi_msk = Path(sample_wsi_dict["wsi2_4k_4k_msk"]) - - patch_size = np.array([224, 224]) - predictor = PatchPredictor(pretrained_model="resnet18-kather100k", batch_size=32) - - save_dir = f"{save_dir_path}/model_wsi_output" - - # wrapper to make this more clean - kwargs = { - "return_probabilities": True, - "return_labels": True, - "on_gpu": ON_GPU, - "patch_input_shape": patch_size, - "stride_shape": patch_size, - "resolution": 1.0, - "units": "baseline", - "save_dir": save_dir, - } - # ! add this test back once the read at `baseline` is fixed - # sanity check, both output should be the same with same resolution read args - wsi_output = predictor.predict( - [mini_wsi_svs], - masks=[mini_wsi_msk], - mode="wsi", - **kwargs, - ) - - shutil.rmtree(save_dir, ignore_errors=True) - - tile_output = predictor.predict( - [mini_wsi_jpg], - masks=[mini_wsi_msk], - mode="tile", - **kwargs, - ) - - wpred = np.array(wsi_output[0]["predictions"]) - tpred = np.array(tile_output[0]["predictions"]) - diff = tpred == wpred - accuracy = np.sum(diff) / np.size(wpred) - assert accuracy > 0.9, np.nonzero(~diff) - - # remove previously generated data - shutil.rmtree(save_dir, ignore_errors=True) - - kwargs = { - "return_probabilities": True, - "return_labels": True, - "on_gpu": ON_GPU, - "patch_input_shape": patch_size, - "stride_shape": patch_size, - "resolution": 0.5, - "save_dir": save_dir, - "merge_predictions": True, # to test the api coverage - "units": "mpp", - } - - _kwargs = copy.deepcopy(kwargs) - _kwargs["merge_predictions"] = False - # test reading of multiple whole-slide images - output = predictor.predict( - [mini_wsi_svs, mini_wsi_svs], - masks=[mini_wsi_msk, mini_wsi_msk], - mode="wsi", - **_kwargs, - ) - for output_info in output.values(): - assert Path(output_info["raw"]).exists() - assert "merged" not in output_info - shutil.rmtree(_kwargs["save_dir"], ignore_errors=True) - - # coverage test - _kwargs = copy.deepcopy(kwargs) - _kwargs["merge_predictions"] = True - # test reading of multiple whole-slide images - predictor.predict( - [mini_wsi_svs, mini_wsi_svs], - masks=[mini_wsi_msk, mini_wsi_msk], - mode="wsi", - **_kwargs, - ) - _kwargs = copy.deepcopy(kwargs) - with pytest.raises(FileExistsError): - predictor.predict( - [mini_wsi_svs, mini_wsi_svs], - masks=[mini_wsi_msk, mini_wsi_msk], - mode="wsi", - **_kwargs, - ) - # remove previously generated data - shutil.rmtree(_kwargs["save_dir"], ignore_errors=True) - - with chdir(save_dir_path): - # test reading of multiple whole-slide images - _kwargs = copy.deepcopy(kwargs) - _kwargs["save_dir"] = None # default coverage - _kwargs["return_probabilities"] = False - output = predictor.predict( - [mini_wsi_svs, mini_wsi_svs], - masks=[mini_wsi_msk, mini_wsi_msk], - mode="wsi", - **_kwargs, - ) - assert Path.exists(Path("output")) - for output_info in output.values(): - assert Path(output_info["raw"]).exists() - assert "merged" in output_info - assert Path(output_info["merged"]).exists() - - # remove previously generated data - shutil.rmtree("output", ignore_errors=True) - - -def test_wsi_predictor_merge_predictions(sample_wsi_dict: dict) -> None: - """Test normal run of wsi predictor with merge predictions option.""" - # convert to pathlib Path to prevent reader complaint - mini_wsi_svs = Path(sample_wsi_dict["wsi2_4k_4k_svs"]) - mini_wsi_jpg = Path(sample_wsi_dict["wsi2_4k_4k_jpg"]) - mini_wsi_msk = Path(sample_wsi_dict["wsi2_4k_4k_msk"]) - - # blind test - # pseudo output dict from model with 2 patches - output = { - "resolution": 1.0, - "units": "baseline", - "probabilities": [[0.45, 0.55], [0.90, 0.10]], - "predictions": [1, 0], - "coordinates": [[0, 0, 2, 2], [2, 2, 4, 4]], - } - merged = PatchPredictor.merge_predictions( - np.zeros([4, 4]), - output, - resolution=1.0, - units="baseline", - ) - _merged = np.array([[2, 2, 0, 0], [2, 2, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]) - assert np.sum(merged - _merged) == 0 - - # blind test for merging probabilities - merged = PatchPredictor.merge_predictions( - np.zeros([4, 4]), - output, - resolution=1.0, - units="baseline", - return_raw=True, - ) - _merged = np.array( - [ - [0.45, 0.45, 0, 0], - [0.45, 0.45, 0, 0], - [0, 0, 0.90, 0.90], - [0, 0, 0.90, 0.90], - ], - ) - assert merged.shape == (4, 4, 2) - assert np.mean(np.abs(merged[..., 0] - _merged)) < 1.0e-6 - - # integration test - predictor = PatchPredictor(pretrained_model="resnet18-kather100k", batch_size=1) - - kwargs = { - "return_probabilities": True, - "return_labels": True, - "on_gpu": ON_GPU, - "patch_input_shape": np.array([224, 224]), - "stride_shape": np.array([224, 224]), - "resolution": 1.0, - "units": "baseline", - "merge_predictions": True, - } - # sanity check, both output should be the same with same resolution read args - wsi_output = predictor.predict( - [mini_wsi_svs], - masks=[mini_wsi_msk], - mode="wsi", - **kwargs, - ) - - # mock up to change the preproc func and - # force to use the default in merge function - # still should have the same results - kwargs["merge_predictions"] = False - tile_output = predictor.predict( - [mini_wsi_jpg], - masks=[mini_wsi_msk], - mode="tile", - **kwargs, - ) - merged_tile_output = predictor.merge_predictions( - mini_wsi_jpg, - tile_output[0], - resolution=kwargs["resolution"], - units=kwargs["units"], - ) - tile_output.append(merged_tile_output) - - # first make sure nothing breaks with predictions - wpred = np.array(wsi_output[0]["predictions"]) - tpred = np.array(tile_output[0]["predictions"]) - diff = tpred == wpred - accuracy = np.sum(diff) / np.size(wpred) - assert accuracy > 0.9, np.nonzero(~diff) - - merged_wsi = wsi_output[1] - merged_tile = tile_output[1] - # ensure shape of merged predictions of tile and wsi input are the same - assert merged_wsi.shape == merged_tile.shape - # ensure consistent predictions between tile and wsi mode - diff = merged_tile == merged_wsi - accuracy = np.sum(diff) / np.size(merged_wsi) - assert accuracy > 0.9, np.nonzero(~diff) - - -def _test_predictor_output( - inputs: list, - pretrained_model: str, - probabilities_check: list | None = None, - predictions_check: list | None = None, - *, - on_gpu: bool = ON_GPU, -) -> None: - """Test the predictions of multiple models included in tiatoolbox.""" - predictor = PatchPredictor( - pretrained_model=pretrained_model, - batch_size=32, - verbose=False, - ) - # don't run test on GPU - output = predictor.predict( - inputs, - return_probabilities=True, - return_labels=False, - on_gpu=on_gpu, - ) - predictions = output["predictions"] - probabilities = output["probabilities"] - for idx, probabilities_ in enumerate(probabilities): - probabilities_max = max(probabilities_) - assert np.abs(probabilities_max - probabilities_check[idx]) <= 1e-3, ( - pretrained_model, - probabilities_max, - probabilities_check[idx], - predictions[idx], - predictions_check[idx], - ) - assert predictions[idx] == predictions_check[idx], ( - pretrained_model, - probabilities_max, - probabilities_check[idx], - predictions[idx], - predictions_check[idx], - ) - - -def test_patch_predictor_kather100k_output( - sample_patch1: Path, - sample_patch2: Path, -) -> None: - """Test the output of patch prediction models on Kather100K dataset.""" - inputs = [Path(sample_patch1), Path(sample_patch2)] - pretrained_info = { - "alexnet-kather100k": [1.0, 0.9999735355377197], - "resnet18-kather100k": [1.0, 0.9999911785125732], - "resnet34-kather100k": [1.0, 0.9979840517044067], - "resnet50-kather100k": [1.0, 0.9999986886978149], - "resnet101-kather100k": [1.0, 0.9999932050704956], - "resnext50_32x4d-kather100k": [1.0, 0.9910059571266174], - "resnext101_32x8d-kather100k": [1.0, 0.9999971389770508], - "wide_resnet50_2-kather100k": [1.0, 0.9953408241271973], - "wide_resnet101_2-kather100k": [1.0, 0.9999831914901733], - "densenet121-kather100k": [1.0, 1.0], - "densenet161-kather100k": [1.0, 0.9999959468841553], - "densenet169-kather100k": [1.0, 0.9999934434890747], - "densenet201-kather100k": [1.0, 0.9999983310699463], - "mobilenet_v2-kather100k": [0.9999998807907104, 0.9999126195907593], - "mobilenet_v3_large-kather100k": [0.9999996423721313, 0.9999878406524658], - "mobilenet_v3_small-kather100k": [0.9999998807907104, 0.9999997615814209], - "googlenet-kather100k": [1.0, 0.9999639987945557], - } - for pretrained_model, expected_prob in pretrained_info.items(): - _test_predictor_output( - inputs, - pretrained_model, - probabilities_check=expected_prob, - predictions_check=[6, 3], - on_gpu=ON_GPU, - ) - # only test 1 on travis to limit runtime - if toolbox_env.running_on_ci(): - break - - -def test_patch_predictor_pcam_output(sample_patch3: Path, sample_patch4: Path) -> None: - """Test the output of patch prediction models on PCam dataset.""" - inputs = [Path(sample_patch3), Path(sample_patch4)] - pretrained_info = { - "alexnet-pcam": [0.999980092048645, 0.9769067168235779], - "resnet18-pcam": [0.999992847442627, 0.9466130137443542], - "resnet34-pcam": [1.0, 0.9976525902748108], - "resnet50-pcam": [0.9999270439147949, 0.9999996423721313], - "resnet101-pcam": [1.0, 0.9997289776802063], - "resnext50_32x4d-pcam": [0.9999996423721313, 0.9984435439109802], - "resnext101_32x8d-pcam": [0.9997072815895081, 0.9969086050987244], - "wide_resnet50_2-pcam": [0.9999837875366211, 0.9959040284156799], - "wide_resnet101_2-pcam": [1.0, 0.9945427179336548], - "densenet121-pcam": [0.9999251365661621, 0.9997479319572449], - "densenet161-pcam": [0.9999969005584717, 0.9662821292877197], - "densenet169-pcam": [0.9999998807907104, 0.9993504881858826], - "densenet201-pcam": [0.9999942779541016, 0.9950824975967407], - "mobilenet_v2-pcam": [0.9999876022338867, 0.9942564368247986], - "mobilenet_v3_large-pcam": [0.9999922513961792, 0.9719613790512085], - "mobilenet_v3_small-pcam": [0.9999963045120239, 0.9747149348258972], - "googlenet-pcam": [0.9999929666519165, 0.8701475858688354], - } - for pretrained_model, expected_prob in pretrained_info.items(): - _test_predictor_output( - inputs, - pretrained_model, - probabilities_check=expected_prob, - predictions_check=[1, 0], - on_gpu=ON_GPU, - ) - # only test 1 on travis to limit runtime - if toolbox_env.running_on_ci(): - break - - -# ------------------------------------------------------------------------------------- -# Command Line Interface -# ------------------------------------------------------------------------------------- - - -def test_command_line_models_file_not_found(sample_svs: Path, tmp_path: Path) -> None: - """Test for models CLI file not found error.""" - runner = CliRunner() - model_file_not_found_result = runner.invoke( - cli.main, - [ - "patch-predictor", - "--img-input", - str(sample_svs)[:-1], - "--file-types", - '"*.ndpi, *.svs"', - "--output-path", - str(tmp_path.joinpath("output")), - ], - ) - - assert model_file_not_found_result.output == "" - assert model_file_not_found_result.exit_code == 1 - assert isinstance(model_file_not_found_result.exception, FileNotFoundError) - - -def test_command_line_models_incorrect_mode(sample_svs: Path, tmp_path: Path) -> None: - """Test for models CLI mode not in wsi, tile.""" - runner = CliRunner() - mode_not_in_wsi_tile_result = runner.invoke( - cli.main, - [ - "patch-predictor", - "--img-input", - str(sample_svs), - "--file-types", - '"*.ndpi, *.svs"', - "--mode", - '"patch"', - "--output-path", - str(tmp_path.joinpath("output")), - ], - ) - - assert "Invalid value for '--mode'" in mode_not_in_wsi_tile_result.output - assert mode_not_in_wsi_tile_result.exit_code != 0 - assert isinstance(mode_not_in_wsi_tile_result.exception, SystemExit) - - -def test_cli_model_single_file(sample_svs: Path, tmp_path: Path) -> None: - """Test for models CLI single file.""" - runner = CliRunner() - models_wsi_result = runner.invoke( - cli.main, - [ - "patch-predictor", - "--img-input", - str(sample_svs), - "--mode", - "wsi", - "--output-path", - str(tmp_path.joinpath("output")), - ], - ) - - assert models_wsi_result.exit_code == 0 - assert tmp_path.joinpath("output/0.merged.npy").exists() - assert tmp_path.joinpath("output/0.raw.json").exists() - assert tmp_path.joinpath("output/results.json").exists() - - -def test_cli_model_single_file_mask(remote_sample: Callable, tmp_path: Path) -> None: - """Test for models CLI single file with mask.""" - mini_wsi_svs = Path(remote_sample("svs-1-small")) - sample_wsi_msk = remote_sample("small_svs_tissue_mask") - sample_wsi_msk = np.load(sample_wsi_msk).astype(np.uint8) - imwrite(f"{tmp_path}/small_svs_tissue_mask.jpg", sample_wsi_msk) - sample_wsi_msk = f"{tmp_path}/small_svs_tissue_mask.jpg" - - runner = CliRunner() - models_tiles_result = runner.invoke( - cli.main, - [ - "patch-predictor", - "--img-input", - str(mini_wsi_svs), - "--mode", - "wsi", - "--masks", - str(sample_wsi_msk), - "--output-path", - str(tmp_path.joinpath("output")), - ], - ) - - assert models_tiles_result.exit_code == 0 - assert tmp_path.joinpath("output/0.merged.npy").exists() - assert tmp_path.joinpath("output/0.raw.json").exists() - assert tmp_path.joinpath("output/results.json").exists() - - -def test_cli_model_multiple_file_mask(remote_sample: Callable, tmp_path: Path) -> None: - """Test for models CLI multiple file with mask.""" - mini_wsi_svs = Path(remote_sample("svs-1-small")) - sample_wsi_msk = remote_sample("small_svs_tissue_mask") - sample_wsi_msk = np.load(sample_wsi_msk).astype(np.uint8) - imwrite(f"{tmp_path}/small_svs_tissue_mask.jpg", sample_wsi_msk) - mini_wsi_msk = tmp_path.joinpath("small_svs_tissue_mask.jpg") - - # Make multiple copies for test - dir_path = tmp_path.joinpath("new_copies") - dir_path.mkdir() - - dir_path_masks = tmp_path.joinpath("new_copies_masks") - dir_path_masks.mkdir() - - try: - dir_path.joinpath("1_" + mini_wsi_svs.name).symlink_to(mini_wsi_svs) - dir_path.joinpath("2_" + mini_wsi_svs.name).symlink_to(mini_wsi_svs) - dir_path.joinpath("3_" + mini_wsi_svs.name).symlink_to(mini_wsi_svs) - except OSError: - shutil.copy(mini_wsi_svs, dir_path.joinpath("1_" + mini_wsi_svs.name)) - shutil.copy(mini_wsi_svs, dir_path.joinpath("2_" + mini_wsi_svs.name)) - shutil.copy(mini_wsi_svs, dir_path.joinpath("3_" + mini_wsi_svs.name)) - - try: - dir_path_masks.joinpath("1_" + mini_wsi_msk.name).symlink_to(mini_wsi_msk) - dir_path_masks.joinpath("2_" + mini_wsi_msk.name).symlink_to(mini_wsi_msk) - dir_path_masks.joinpath("3_" + mini_wsi_msk.name).symlink_to(mini_wsi_msk) - except OSError: - shutil.copy(mini_wsi_msk, dir_path_masks.joinpath("1_" + mini_wsi_msk.name)) - shutil.copy(mini_wsi_msk, dir_path_masks.joinpath("2_" + mini_wsi_msk.name)) - shutil.copy(mini_wsi_msk, dir_path_masks.joinpath("3_" + mini_wsi_msk.name)) - - tmp_path = tmp_path.joinpath("output") - - runner = CliRunner() - models_tiles_result = runner.invoke( - cli.main, - [ - "patch-predictor", - "--img-input", - str(dir_path), - "--mode", - "wsi", - "--masks", - str(dir_path_masks), - "--output-path", - str(tmp_path), - ], - ) - - assert models_tiles_result.exit_code == 0 - assert tmp_path.joinpath("0.merged.npy").exists() - assert tmp_path.joinpath("0.raw.json").exists() - assert tmp_path.joinpath("1.merged.npy").exists() - assert tmp_path.joinpath("1.raw.json").exists() - assert tmp_path.joinpath("2.merged.npy").exists() - assert tmp_path.joinpath("2.raw.json").exists() - assert tmp_path.joinpath("results.json").exists() diff --git a/tests/engines/_test_semantic_segmentation.py b/tests/engines/_test_semantic_segmentation.py deleted file mode 100644 index 0bc2babde..000000000 --- a/tests/engines/_test_semantic_segmentation.py +++ /dev/null @@ -1,854 +0,0 @@ -"""Test for Semantic Segmentor.""" -from __future__ import annotations - -# ! The garbage collector -import gc -import multiprocessing -import shutil -from pathlib import Path -from typing import Callable - -import numpy as np -import pytest -import torch -import torch.multiprocessing as torch_mp -import torch.nn.functional as F # noqa: N812 -import yaml -from click.testing import CliRunner -from torch import nn - -from tiatoolbox import cli -from tiatoolbox.models import IOSegmentorConfig, SemanticSegmentor -from tiatoolbox.models.architecture import fetch_pretrained_weights -from tiatoolbox.models.architecture.utils import centre_crop -from tiatoolbox.models.engine.semantic_segmentor import WSIStreamDataset -from tiatoolbox.models.models_abc import ModelABC -from tiatoolbox.utils import env_detection as toolbox_env -from tiatoolbox.utils import imread, imwrite -from tiatoolbox.wsicore.wsireader import WSIReader - -ON_GPU = toolbox_env.has_gpu() -# The value is based on 2 TitanXP each with 12GB -BATCH_SIZE = 1 if not ON_GPU else 16 -try: - NUM_POSTPROC_WORKERS = multiprocessing.cpu_count() -except NotImplementedError: - NUM_POSTPROC_WORKERS = 2 - -# ---------------------------------------------------- - - -def _crash_func(_x: object) -> None: - """Helper to induce crash.""" - msg = "Propagation Crash." - raise ValueError(msg) - - -class _CNNTo1(ModelABC): - """Contains a convolution. - - Simple model to test functionality, this contains a single - convolution layer which has weight=0 and bias=1. - - """ - - def __init__(self: _CNNTo1) -> None: - super().__init__() - self.conv = nn.Conv2d(3, 1, 3, padding=1) - self.conv.weight.data.fill_(0) - self.conv.bias.data.fill_(1) - - def forward(self: _CNNTo1, img: np.ndarray) -> torch.Tensor: - """Define how to use layer.""" - return self.conv(img) - - @staticmethod - def infer_batch( - model: nn.Module, - batch_data: torch.Tensor, - *, - on_gpu: bool, - ) -> list: - """Run inference on an input batch. - - Contains logic for forward operation as well as i/o - aggregation for a single data batch. - - Args: - model (nn.Module): PyTorch defined model. - batch_data (torch.Tensor): A batch of data generated by - torch.utils.data.DataLoader. - on_gpu (bool): Whether to run inference on a GPU. - - """ - device = "cuda" if on_gpu else "cpu" - #### - model.eval() # infer mode - - #### - img_list = batch_data - - img_list = img_list.to(device).type(torch.float32) - img_list = img_list.permute(0, 3, 1, 2) # to NCHW - - hw = np.array(img_list.shape[2:]) - with torch.inference_mode(): # do not compute gradient - logit_list = model(img_list) - logit_list = centre_crop(logit_list, hw // 2) - logit_list = logit_list.permute(0, 2, 3, 1) # to NHWC - prob_list = F.relu(logit_list) - - prob_list = prob_list.cpu().numpy() - return [prob_list] - - -# ------------------------------------------------------------------------------------- -# IOConfig -# ------------------------------------------------------------------------------------- - - -def test_segmentor_ioconfig() -> None: - """Test for IOConfig.""" - ioconfig = IOSegmentorConfig( - input_resolutions=[ - {"units": "mpp", "resolution": 0.25}, - {"units": "mpp", "resolution": 0.50}, - {"units": "mpp", "resolution": 0.75}, - ], - output_resolutions=[ - {"units": "mpp", "resolution": 0.25}, - {"units": "mpp", "resolution": 0.50}, - ], - patch_input_shape=[2048, 2048], - patch_output_shape=[1024, 1024], - stride_shape=[512, 512], - ) - assert ioconfig.highest_input_resolution == {"units": "mpp", "resolution": 0.25} - ioconfig = ioconfig.to_baseline() - assert ioconfig.input_resolutions[0]["resolution"] == 1.0 - assert ioconfig.input_resolutions[1]["resolution"] == 0.5 - assert ioconfig.input_resolutions[2]["resolution"] == 1 / 3 - - ioconfig = IOSegmentorConfig( - input_resolutions=[ - {"units": "power", "resolution": 20}, - {"units": "power", "resolution": 40}, - ], - output_resolutions=[ - {"units": "power", "resolution": 20}, - {"units": "power", "resolution": 40}, - ], - patch_input_shape=[2048, 2048], - patch_output_shape=[1024, 1024], - stride_shape=[512, 512], - save_resolution={"units": "power", "resolution": 8.0}, - ) - assert ioconfig.highest_input_resolution == {"units": "power", "resolution": 40} - ioconfig = ioconfig.to_baseline() - assert ioconfig.input_resolutions[0]["resolution"] == 0.5 - assert ioconfig.input_resolutions[1]["resolution"] == 1.0 - assert ioconfig.save_resolution["resolution"] == 8.0 / 40.0 - - resolutions = [ - {"units": "mpp", "resolution": 0.25}, - {"units": "mpp", "resolution": 0.50}, - {"units": "mpp", "resolution": 0.75}, - ] - with pytest.raises(ValueError, match=r".*Unknown units.*"): - ioconfig.scale_to_highest(resolutions, "axx") - - -# ------------------------------------------------------------------------------------- -# Dataset -# ------------------------------------------------------------------------------------- - - -def test_functional_wsi_stream_dataset(remote_sample: Callable) -> None: - """Functional test for WSIStreamDataset.""" - gc.collect() # Force clean up everything on hold - mini_wsi_svs = Path(remote_sample("wsi4_512_512_svs")) - - ioconfig = IOSegmentorConfig( - input_resolutions=[ - {"units": "mpp", "resolution": 0.25}, - {"units": "mpp", "resolution": 0.50}, - {"units": "mpp", "resolution": 0.75}, - ], - output_resolutions=[ - {"units": "mpp", "resolution": 0.25}, - {"units": "mpp", "resolution": 0.50}, - ], - patch_input_shape=[2048, 2048], - patch_output_shape=[1024, 1024], - stride_shape=[512, 512], - ) - mp_manager = torch_mp.Manager() - mp_shared_space = mp_manager.Namespace() - - sds = WSIStreamDataset(ioconfig, [mini_wsi_svs], mp_shared_space) - # test for collate - out = sds.collate_fn([None, 1, 2, 3]) - assert np.sum(out.numpy() != np.array([1, 2, 3])) == 0 - - # artificial data injection - mp_shared_space.wsi_idx = torch.tensor(0) # a scalar - mp_shared_space.patch_inputs = torch.from_numpy( - np.array( - [ - [0, 0, 256, 256], - [256, 256, 512, 512], - ], - ), - ) - mp_shared_space.patch_outputs = torch.from_numpy( - np.array( - [ - [0, 0, 256, 256], - [256, 256, 512, 512], - ], - ), - ) - # test read - for _, sample in enumerate(sds): - patch_data, _ = sample - (patch_resolution1, patch_resolution2, patch_resolution3) = patch_data - assert np.round(patch_resolution1.shape[0] / patch_resolution2.shape[0]) == 2 - assert np.round(patch_resolution1.shape[0] / patch_resolution3.shape[0]) == 3 - - -# ------------------------------------------------------------------------------------- -# Engine -# ------------------------------------------------------------------------------------- - - -def test_crash_segmentor(remote_sample: Callable) -> None: - """Functional crash tests for segmentor.""" - # # convert to pathlib Path to prevent wsireader complaint - mini_wsi_svs = Path(remote_sample("wsi2_4k_4k_svs")) - mini_wsi_jpg = Path(remote_sample("wsi2_4k_4k_jpg")) - mini_wsi_msk = Path(remote_sample("wsi2_4k_4k_msk")) - - model = _CNNTo1() - semantic_segmentor = SemanticSegmentor(batch_size=BATCH_SIZE, model=model) - # fake injection to trigger Segmentor to create parallel - # post-processing workers because baseline Semantic Segmentor does not support - # post-processing out of the box. It only contains condition to create it - # for any subclass - semantic_segmentor.num_postproc_workers = 1 - - # * test basic crash - shutil.rmtree("output", ignore_errors=True) # default output dir test - with pytest.raises(TypeError, match=r".*`mask_reader`.*"): - semantic_segmentor.filter_coordinates(mini_wsi_msk, np.array(["a", "b", "c"])) - with pytest.raises(ValueError, match=r".*ndarray.*integer.*"): - semantic_segmentor.filter_coordinates( - WSIReader.open(mini_wsi_msk), - np.array([1.0, 2.0]), - ) - semantic_segmentor.get_reader(mini_wsi_svs, None, "wsi", auto_get_mask=True) - with pytest.raises(ValueError, match=r".*must be a valid file path.*"): - semantic_segmentor.get_reader( - mini_wsi_msk, - "not_exist", - "wsi", - auto_get_mask=True, - ) - - shutil.rmtree("output", ignore_errors=True) # default output dir test - with pytest.raises(ValueError, match=r".*provide.*"): - SemanticSegmentor() - with pytest.raises(ValueError, match=r".*valid mode.*"): - semantic_segmentor.predict([], mode="abc") - - # * test not providing any io_config info when not using pretrained model - with pytest.raises(ValueError, match=r".*provide either `ioconfig`.*"): - semantic_segmentor.predict( - [mini_wsi_jpg], - mode="tile", - on_gpu=ON_GPU, - crash_on_exception=True, - ) - with pytest.raises(ValueError, match=r".*already exists.*"): - semantic_segmentor.predict([], mode="tile", patch_input_shape=(2048, 2048)) - shutil.rmtree("output", ignore_errors=True) # default output dir test - - # * test not providing any io_config info when not using pretrained model - with pytest.raises(ValueError, match=r".*provide either `ioconfig`.*"): - semantic_segmentor.predict( - [mini_wsi_jpg], - mode="tile", - on_gpu=ON_GPU, - crash_on_exception=True, - ) - shutil.rmtree("output", ignore_errors=True) # default output dir test - - # * Test crash propagation when parallelize post-processing - semantic_segmentor.num_postproc_workers = 2 - semantic_segmentor.model.forward = _crash_func - with pytest.raises(ValueError, match=r"Propagation Crash."): - semantic_segmentor.predict( - [mini_wsi_svs], - patch_input_shape=(2048, 2048), - mode="wsi", - on_gpu=ON_GPU, - crash_on_exception=True, - resolution=1.0, - units="baseline", - ) - - shutil.rmtree("output", ignore_errors=True) - - with pytest.raises(ValueError, match=r"Invalid resolution.*"): - semantic_segmentor.predict( - [mini_wsi_svs], - patch_input_shape=(2048, 2048), - mode="wsi", - on_gpu=ON_GPU, - crash_on_exception=True, - ) - shutil.rmtree("output", ignore_errors=True) - # test ignore crash - semantic_segmentor.predict( - [mini_wsi_svs], - patch_input_shape=(2048, 2048), - mode="wsi", - on_gpu=ON_GPU, - crash_on_exception=False, - resolution=1.0, - units="baseline", - ) - shutil.rmtree("output", ignore_errors=True) - - -def test_functional_segmentor_merging(tmp_path: Path) -> None: - """Functional test for assmebling output.""" - save_dir = Path(tmp_path) - - model = _CNNTo1() - semantic_segmentor = SemanticSegmentor(batch_size=BATCH_SIZE, model=model) - - shutil.rmtree(save_dir, ignore_errors=True) - save_dir.mkdir() - # predictions with HW - _output = np.array( - [ - [1, 1, 0, 0], - [1, 1, 0, 0], - [0, 0, 2, 2], - [0, 0, 2, 2], - ], - ) - canvas = semantic_segmentor.merge_prediction( - [4, 4], - [np.full((2, 2), 1), np.full((2, 2), 2)], - [[0, 0, 2, 2], [2, 2, 4, 4]], - save_path=f"{save_dir}/raw.py", - cache_count_path=f"{save_dir}/count.py", - ) - assert np.sum(canvas - _output) < 1.0e-8 - # a second rerun to test overlapping count, - # should still maintain same result - canvas = semantic_segmentor.merge_prediction( - [4, 4], - [np.full((2, 2), 1), np.full((2, 2), 2)], - [[0, 0, 2, 2], [2, 2, 4, 4]], - save_path=f"{save_dir}/raw.py", - cache_count_path=f"{save_dir}/count.py", - ) - assert np.sum(canvas - _output) < 1.0e-8 - # else will leave hanging file pointer - # and hence cant remove its folder later - del canvas # skipcq - - # * predictions with HWC - shutil.rmtree(save_dir, ignore_errors=True) - save_dir.mkdir() - _ = semantic_segmentor.merge_prediction( - [4, 4], - [np.full((2, 2, 1), 1), np.full((2, 2, 1), 2)], - [[0, 0, 2, 2], [2, 2, 4, 4]], - save_path=f"{save_dir}/raw.py", - cache_count_path=f"{save_dir}/count.py", - ) - del _ # skipcq - - # * test crashing when switch to image having larger - # * shape but still provide old links - semantic_segmentor.merge_prediction( - [8, 8], - [np.full((2, 2, 1), 1), np.full((2, 2, 1), 2)], - [[0, 0, 2, 2], [2, 2, 4, 4]], - save_path=f"{save_dir}/raw.1.py", - cache_count_path=f"{save_dir}/count.1.py", - ) - with pytest.raises(ValueError, match=r".*`save_path` does not match.*"): - semantic_segmentor.merge_prediction( - [4, 4], - [np.full((2, 2, 1), 1), np.full((2, 2, 1), 2)], - [[0, 0, 2, 2], [2, 2, 4, 4]], - save_path=f"{save_dir}/raw.1.py", - cache_count_path=f"{save_dir}/count.py", - ) - - with pytest.raises(ValueError, match=r".*`cache_count_path` does not match.*"): - semantic_segmentor.merge_prediction( - [4, 4], - [np.full((2, 2, 1), 1), np.full((2, 2, 1), 2)], - [[0, 0, 2, 2], [2, 2, 4, 4]], - save_path=f"{save_dir}/raw.py", - cache_count_path=f"{save_dir}/count.1.py", - ) - # * test non HW predictions - with pytest.raises(ValueError, match=r".*Prediction is no HW or HWC.*"): - semantic_segmentor.merge_prediction( - [4, 4], - [np.full((2,), 1), np.full((2,), 2)], - [[0, 0, 2, 2], [2, 2, 4, 4]], - save_path=f"{save_dir}/raw.py", - cache_count_path=f"{save_dir}/count.1.py", - ) - - shutil.rmtree(save_dir, ignore_errors=True) - save_dir.mkdir() - - # * with an out of bound location - canvas = semantic_segmentor.merge_prediction( - [4, 4], - [ - np.full((2, 2), 1), - np.full((2, 2), 2), - np.full((2, 2), 3), - np.full((2, 2), 4), - ], - [[0, 0, 2, 2], [2, 2, 4, 4], [0, 4, 2, 6], [4, 0, 6, 2]], - save_path=None, - ) - assert np.sum(canvas - _output) < 1.0e-8 - del canvas # skipcq - - -def test_functional_segmentor(remote_sample: Callable, tmp_path: Path) -> None: - """Functional test for segmentor.""" - save_dir = tmp_path / "dump" - # # convert to pathlib Path to prevent wsireader complaint - resolution = 2.0 - mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs")) - reader = WSIReader.open(mini_wsi_svs) - thumb = reader.slide_thumbnail(resolution=resolution, units="baseline") - mini_wsi_jpg = f"{tmp_path}/mini_svs.jpg" - imwrite(mini_wsi_jpg, thumb) - mini_wsi_msk = f"{tmp_path}/mini_mask.jpg" - imwrite(mini_wsi_msk, (thumb > 0).astype(np.uint8)) - - # preemptive clean up - shutil.rmtree("output", ignore_errors=True) # default output dir test - model = _CNNTo1() - semantic_segmentor = SemanticSegmentor(batch_size=BATCH_SIZE, model=model) - # fake injection to trigger Segmentor to create parallel - # post-processing workers because baseline Semantic Segmentor does not support - # post-processing out of the box. It only contains condition to create it - # for any subclass - semantic_segmentor.num_postproc_workers = 1 - - # should still run because we skip exception - semantic_segmentor.predict( - [mini_wsi_jpg], - mode="tile", - on_gpu=ON_GPU, - patch_input_shape=(512, 512), - resolution=resolution, - units="mpp", - crash_on_exception=False, - ) - - shutil.rmtree("output", ignore_errors=True) # default output dir test - semantic_segmentor.predict( - [mini_wsi_jpg], - mode="tile", - on_gpu=ON_GPU, - patch_input_shape=(512, 512), - resolution=1 / resolution, - units="baseline", - crash_on_exception=True, - ) - shutil.rmtree("output", ignore_errors=True) # default output dir test - - # * check exception bypass in the log - # there should be no exception, but how to check the log? - semantic_segmentor.predict( - [mini_wsi_jpg], - mode="tile", - on_gpu=ON_GPU, - patch_input_shape=(512, 512), - patch_output_shape=(512, 512), - stride_shape=(512, 512), - resolution=1 / resolution, - units="baseline", - crash_on_exception=False, - ) - shutil.rmtree("output", ignore_errors=True) # default output dir test - - # * test basic running and merging prediction - # * should dumping all 1 in the output - ioconfig = IOSegmentorConfig( - input_resolutions=[{"units": "baseline", "resolution": 1.0}], - output_resolutions=[{"units": "baseline", "resolution": 1.0}], - patch_input_shape=[512, 512], - patch_output_shape=[512, 512], - stride_shape=[512, 512], - ) - - shutil.rmtree(save_dir, ignore_errors=True) - file_list = [ - mini_wsi_jpg, - mini_wsi_jpg, - ] - output_list = semantic_segmentor.predict( - file_list, - mode="tile", - on_gpu=ON_GPU, - ioconfig=ioconfig, - crash_on_exception=True, - save_dir=f"{save_dir}/raw/", - ) - pred_1 = np.load(output_list[0][1] + ".raw.0.npy") - pred_2 = np.load(output_list[1][1] + ".raw.0.npy") - assert len(output_list) == 2 - assert np.sum(pred_1 - pred_2) == 0 - # due to overlapping merge and division, will not be - # exactly 1, but should be approximately so - assert np.sum((pred_1 - 1) > 1.0e-6) == 0 - shutil.rmtree(save_dir, ignore_errors=True) - - # * test running with mask and svs - # * also test merging prediction at designated resolution - ioconfig = IOSegmentorConfig( - input_resolutions=[{"units": "mpp", "resolution": resolution}], - output_resolutions=[{"units": "mpp", "resolution": resolution}], - save_resolution={"units": "mpp", "resolution": resolution}, - patch_input_shape=[512, 512], - patch_output_shape=[256, 256], - stride_shape=[512, 512], - ) - shutil.rmtree(save_dir, ignore_errors=True) - output_list = semantic_segmentor.predict( - [mini_wsi_svs], - masks=[mini_wsi_msk], - mode="wsi", - on_gpu=ON_GPU, - ioconfig=ioconfig, - crash_on_exception=True, - save_dir=f"{save_dir}/raw/", - ) - reader = WSIReader.open(mini_wsi_svs) - expected_shape = reader.slide_dimensions(**ioconfig.save_resolution) - expected_shape = np.array(expected_shape)[::-1] # to YX - pred_1 = np.load(output_list[0][1] + ".raw.0.npy") - saved_shape = np.array(pred_1.shape[:2]) - assert np.sum(expected_shape - saved_shape) == 0 - assert np.sum((pred_1 - 1) > 1.0e-6) == 0 - shutil.rmtree(save_dir, ignore_errors=True) - - # check normal run with auto get mask - semantic_segmentor = SemanticSegmentor( - batch_size=BATCH_SIZE, - model=model, - auto_generate_mask=True, - ) - _ = semantic_segmentor.predict( - [mini_wsi_svs], - masks=[mini_wsi_msk], - mode="wsi", - on_gpu=ON_GPU, - ioconfig=ioconfig, - crash_on_exception=True, - save_dir=f"{save_dir}/raw/", - ) - - -def test_subclass(remote_sample: Callable, tmp_path: Path) -> None: - """Create subclass and test parallel processing setup.""" - save_dir = Path(tmp_path) - mini_wsi_jpg = Path(remote_sample("wsi2_4k_4k_jpg")) - - model = _CNNTo1() - - class XSegmentor(SemanticSegmentor): - """Dummy class to test subclassing.""" - - def __init__(self: XSegmentor) -> None: - super().__init__(model=model) - self.num_postproc_worker = 2 - - semantic_segmentor = XSegmentor() - shutil.rmtree(save_dir, ignore_errors=True) # default output dir test - semantic_segmentor.predict( - [mini_wsi_jpg], - mode="tile", - on_gpu=ON_GPU, - patch_input_shape=(1024, 1024), - patch_output_shape=(512, 512), - stride_shape=(256, 256), - resolution=1.0, - units="baseline", - crash_on_exception=False, - save_dir=save_dir / "raw", - ) - - -# specifically designed for travis -def test_functional_pretrained(remote_sample: Callable, tmp_path: Path) -> None: - """Test for load up pretrained and over-writing tile mode ioconfig.""" - save_dir = Path(f"{tmp_path}/output") - mini_wsi_svs = Path(remote_sample("wsi4_512_512_svs")) - reader = WSIReader.open(mini_wsi_svs) - thumb = reader.slide_thumbnail(resolution=1.0, units="baseline") - mini_wsi_jpg = f"{tmp_path}/mini_svs.jpg" - imwrite(mini_wsi_jpg, thumb) - - semantic_segmentor = SemanticSegmentor( - batch_size=BATCH_SIZE, - pretrained_model="fcn-tissue_mask", - ) - - shutil.rmtree(save_dir, ignore_errors=True) - semantic_segmentor.predict( - [mini_wsi_svs], - mode="wsi", - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=f"{save_dir}/raw/", - ) - - shutil.rmtree(save_dir, ignore_errors=True) - - # mainly to test prediction on tile - semantic_segmentor.predict( - [mini_wsi_jpg], - mode="tile", - on_gpu=ON_GPU, - crash_on_exception=True, - save_dir=f"{save_dir}/raw/", - ) - - assert save_dir.joinpath("raw/0.raw.0.npy").exists() - assert save_dir.joinpath("raw/file_map.dat").exists() - - -@pytest.mark.skipif( - toolbox_env.running_on_ci() or not ON_GPU, - reason="Local test on machine with GPU.", -) -def test_behavior_tissue_mask_local(remote_sample: Callable, tmp_path: Path) -> None: - """Contain test for behavior of the segmentor and pretrained models.""" - save_dir = tmp_path - wsi_with_artifacts = Path(remote_sample("wsi3_20k_20k_svs")) - mini_wsi_jpg = Path(remote_sample("wsi2_4k_4k_jpg")) - - semantic_segmentor = SemanticSegmentor( - batch_size=BATCH_SIZE, - pretrained_model="fcn-tissue_mask", - ) - shutil.rmtree(save_dir, ignore_errors=True) - semantic_segmentor.predict( - [wsi_with_artifacts], - mode="wsi", - on_gpu=True, - crash_on_exception=True, - save_dir=save_dir / "raw", - ) - # load up the raw prediction and perform precision check - _cache_pred = imread(Path(remote_sample("wsi3_20k_20k_pred"))) - _test_pred = np.load(str(save_dir / "raw" / "0.raw.0.npy")) - _test_pred = (_test_pred[..., 1] > 0.75) * 255 - # divide 255 to binarize - assert np.mean(_cache_pred[..., 0] == _test_pred) > 0.99 - - shutil.rmtree(save_dir, ignore_errors=True) - # mainly to test prediction on tile - semantic_segmentor.predict( - [mini_wsi_jpg], - mode="tile", - on_gpu=True, - crash_on_exception=True, - save_dir=f"{save_dir}/raw/", - ) - - -@pytest.mark.skipif( - toolbox_env.running_on_ci() or not ON_GPU, - reason="Local test on machine with GPU.", -) -def test_behavior_bcss_local(remote_sample: Callable, tmp_path: Path) -> None: - """Contain test for behavior of the segmentor and pretrained models.""" - save_dir = tmp_path - - wsi_breast = Path(remote_sample("wsi4_4k_4k_svs")) - semantic_segmentor = SemanticSegmentor( - num_loader_workers=4, - batch_size=BATCH_SIZE, - pretrained_model="fcn_resnet50_unet-bcss", - ) - semantic_segmentor.predict( - [wsi_breast], - mode="wsi", - on_gpu=True, - crash_on_exception=True, - save_dir=save_dir / "raw", - ) - # load up the raw prediction and perform precision check - _cache_pred = np.load(Path(remote_sample("wsi4_4k_4k_pred"))) - _test_pred = np.load(f"{save_dir}/raw/0.raw.0.npy") - _test_pred = np.argmax(_test_pred, axis=-1) - assert np.mean(np.abs(_cache_pred - _test_pred)) < 1.0e-2 - - -# ------------------------------------------------------------------------------------- -# Command Line Interface -# ------------------------------------------------------------------------------------- - - -def test_cli_semantic_segment_out_exists_error( - remote_sample: Callable, - tmp_path: Path, -) -> None: - """Test for semantic segmentation if output path exists.""" - mini_wsi_svs = Path(remote_sample("svs-1-small")) - sample_wsi_msk = remote_sample("small_svs_tissue_mask") - sample_wsi_msk = np.load(sample_wsi_msk).astype(np.uint8) - imwrite(f"{tmp_path}/small_svs_tissue_mask.jpg", sample_wsi_msk) - sample_wsi_msk = f"{tmp_path}/small_svs_tissue_mask.jpg" - runner = CliRunner() - semantic_segment_result = runner.invoke( - cli.main, - [ - "semantic-segment", - "--img-input", - str(mini_wsi_svs), - "--mode", - "wsi", - "--masks", - str(sample_wsi_msk), - "--output-path", - tmp_path, - ], - ) - - assert semantic_segment_result.output == "" - assert semantic_segment_result.exit_code == 1 - assert isinstance(semantic_segment_result.exception, FileExistsError) - - -def test_cli_semantic_segmentation_ioconfig( - remote_sample: Callable, - tmp_path: Path, -) -> None: - """Test for semantic segmentation single file custom ioconfig.""" - mini_wsi_svs = Path(remote_sample("svs-1-small")) - sample_wsi_msk = remote_sample("small_svs_tissue_mask") - sample_wsi_msk = np.load(sample_wsi_msk).astype(np.uint8) - imwrite(f"{tmp_path}/small_svs_tissue_mask.jpg", sample_wsi_msk) - sample_wsi_msk = f"{tmp_path}/small_svs_tissue_mask.jpg" - - pretrained_weights = fetch_pretrained_weights("fcn-tissue_mask") - - config = { - "input_resolutions": [{"units": "mpp", "resolution": 2.0}], - "output_resolutions": [{"units": "mpp", "resolution": 2.0}], - "patch_input_shape": [1024, 1024], - "patch_output_shape": [512, 512], - "stride_shape": [256, 256], - "save_resolution": {"units": "mpp", "resolution": 8.0}, - } - with Path.open(tmp_path.joinpath("config.yaml"), "w") as fptr: - yaml.dump(config, fptr) - - runner = CliRunner() - - semantic_segment_result = runner.invoke( - cli.main, - [ - "semantic-segment", - "--img-input", - str(mini_wsi_svs), - "--pretrained-weights", - str(pretrained_weights), - "--mode", - "wsi", - "--masks", - str(sample_wsi_msk), - "--output-path", - tmp_path.joinpath("output"), - "--yaml-config-path", - tmp_path.joinpath("config.yaml"), - ], - ) - - assert semantic_segment_result.exit_code == 0 - assert tmp_path.joinpath("output/0.raw.0.npy").exists() - assert tmp_path.joinpath("output/file_map.dat").exists() - assert tmp_path.joinpath("output/results.json").exists() - - -def test_cli_semantic_segmentation_multi_file( - remote_sample: Callable, - tmp_path: Path, -) -> None: - """Test for models CLI multiple file with mask.""" - mini_wsi_svs = Path(remote_sample("svs-1-small")) - sample_wsi_msk = remote_sample("small_svs_tissue_mask") - sample_wsi_msk = np.load(sample_wsi_msk).astype(np.uint8) - imwrite(f"{tmp_path}/small_svs_tissue_mask.jpg", sample_wsi_msk) - sample_wsi_msk = tmp_path / "small_svs_tissue_mask.jpg" - - # Make multiple copies for test - dir_path = tmp_path / "new_copies" - dir_path.mkdir() - - dir_path_masks = tmp_path / "new_copies_masks" - dir_path_masks.mkdir() - - try: - dir_path.joinpath("1_" + mini_wsi_svs.name).symlink_to(mini_wsi_svs) - dir_path.joinpath("2_" + mini_wsi_svs.name).symlink_to(mini_wsi_svs) - except OSError: - shutil.copy(mini_wsi_svs, dir_path.joinpath("1_" + mini_wsi_svs.name)) - shutil.copy(mini_wsi_svs, dir_path.joinpath("2_" + mini_wsi_svs.name)) - - try: - dir_path_masks.joinpath("1_" + sample_wsi_msk.name).symlink_to(sample_wsi_msk) - dir_path_masks.joinpath("2_" + sample_wsi_msk.name).symlink_to(sample_wsi_msk) - except OSError: - shutil.copy(sample_wsi_msk, dir_path_masks.joinpath("1_" + sample_wsi_msk.name)) - shutil.copy(sample_wsi_msk, dir_path_masks.joinpath("2_" + sample_wsi_msk.name)) - - tmp_path = tmp_path / "output" - - runner = CliRunner() - semantic_segment_result = runner.invoke( - cli.main, - [ - "semantic-segment", - "--img-input", - str(dir_path), - "--mode", - "wsi", - "--masks", - str(dir_path_masks), - "--output-path", - str(tmp_path), - ], - ) - - assert semantic_segment_result.exit_code == 0 - assert tmp_path.joinpath("0.raw.0.npy").exists() - assert tmp_path.joinpath("1.raw.0.npy").exists() - assert tmp_path.joinpath("file_map.dat").exists() - assert tmp_path.joinpath("results.json").exists() - - # load up the raw prediction and perform precision check - _cache_pred = imread(Path(remote_sample("small_svs_tissue_mask"))) - _test_pred = np.load(str(tmp_path.joinpath("0.raw.0.npy"))) - _test_pred = (_test_pred[..., 1] > 0.50) * 255 - - assert np.mean(np.abs(_cache_pred - _test_pred) / 255) < 1e-3 diff --git a/tests/engines/test_engine_abc.py b/tests/engines/test_engine_abc.py new file mode 100644 index 000000000..860472197 --- /dev/null +++ b/tests/engines/test_engine_abc.py @@ -0,0 +1,322 @@ +"""Test tiatoolbox.models.engine.engine_abc.""" +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING, NoReturn + +import numpy as np +import pytest + +from tiatoolbox.models.architecture.vanilla import CNNModel +from tiatoolbox.models.engine.engine_abc import EngineABC, prepare_engines_save_dir + +if TYPE_CHECKING: + import torch.nn + + +class TestEngineABC(EngineABC): + """Test EngineABC.""" + + def __init__( + self: TestEngineABC, + model: str | torch.nn.Module, + verbose: bool | None = None, + ) -> NoReturn: + """Test EngineABC init.""" + super().__init__(model=model, verbose=verbose) + + def infer_wsi(self: EngineABC) -> NoReturn: + """Test infer_wsi.""" + ... # dummy function for tests. + + def post_process_wsi(self: EngineABC) -> NoReturn: + """Test post_process_wsi.""" + ... # dummy function for tests. + + def pre_process_wsi(self: EngineABC) -> NoReturn: + """Test pre_process_wsi.""" + ... # dummy function for tests. + + +def test_engine_abc() -> NoReturn: + """Test EngineABC initialization.""" + with pytest.raises( + TypeError, + match=r".*Can't instantiate abstract class EngineABC with abstract methods*", + ): + # Can't instantiate abstract class with abstract methods + EngineABC() # skipcq + + +def test_engine_abc_incorrect_model_type() -> NoReturn: + """Test EngineABC initialization with incorrect model type.""" + with pytest.raises( + TypeError, + match=r".*missing 1 required positional argument: 'model'", + ): + TestEngineABC() # skipcq + + with pytest.raises( + TypeError, + match="Input model must be a string or 'torch.nn.Module'.", + ): + TestEngineABC(model=1) + + +def test_incorrect_ioconfig() -> NoReturn: + """Test EngineABC initialization with incorrect ioconfig.""" + import torchvision.models as torch_models + + model = torch_models.resnet18() + engine = TestEngineABC(model=model) + with pytest.raises( + ValueError, + match=r".*provide a valid ModelIOConfigABC.*", + ): + engine.run(images=[], masks=[], ioconfig=None) + + +def test_pretrained_ioconfig() -> NoReturn: + """Test EngineABC initialization with pretrained model name in the toolbox.""" + pretrained_model = "alexnet-kather100k" + + # Test engine run without ioconfig + eng = TestEngineABC(model=pretrained_model) + out = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + on_gpu=False, + patch_mode=True, + ioconfig=None, + ) + assert "predictions" in out + assert "labels" not in out + + +def test_prepare_engines_save_dir( + tmp_path: pytest.TempPathFactory, + caplog: pytest.LogCaptureFixture, +) -> NoReturn: + """Test prepare save directory for engines.""" + out_dir = prepare_engines_save_dir( + save_dir=tmp_path / "patch_output", + patch_mode=True, + len_images=1, + overwrite=False, + ) + + assert out_dir == tmp_path / "patch_output" + assert out_dir.exists() + + out_dir = prepare_engines_save_dir( + save_dir=tmp_path / "patch_output", + patch_mode=True, + len_images=1, + overwrite=True, + ) + + assert out_dir == tmp_path / "patch_output" + assert out_dir.exists() + + out_dir = prepare_engines_save_dir( + save_dir=None, + patch_mode=True, + len_images=1, + overwrite=False, + ) + assert out_dir is None + + with pytest.raises( + OSError, + match=r".*More than 1 WSIs detected but there is no save directory provided.*", + ): + _ = prepare_engines_save_dir( + save_dir=None, + patch_mode=False, + len_images=2, + overwrite=False, + ) + + out_dir = prepare_engines_save_dir( + save_dir=None, + patch_mode=False, + len_images=1, + overwrite=False, + ) + + assert out_dir == Path.cwd() + + out_dir = prepare_engines_save_dir( + save_dir=tmp_path / "wsi_single_output", + patch_mode=False, + len_images=1, + overwrite=False, + ) + + assert out_dir == tmp_path / "wsi_single_output" + assert out_dir.exists() + assert r"When providing multiple whole-slide images / tiles" not in caplog.text + + out_dir = prepare_engines_save_dir( + save_dir=tmp_path / "wsi_multiple_output", + patch_mode=False, + len_images=2, + overwrite=False, + ) + + assert out_dir == tmp_path / "wsi_multiple_output" + assert out_dir.exists() + assert r"When providing multiple whole slide images" in caplog.text + + +def test_engine_initalization() -> NoReturn: + """Test engine initialization.""" + with pytest.raises( + TypeError, + match="Input model must be a string or 'torch.nn.Module'.", + ): + _ = TestEngineABC(model=0) + + eng = TestEngineABC(model="alexnet-kather100k") + assert isinstance(eng, EngineABC) + model = CNNModel("alexnet", num_classes=1) + eng = TestEngineABC(model=model) + assert isinstance(eng, EngineABC) + + +def test_engine_run() -> NoReturn: + """Test engine run.""" + eng = TestEngineABC(model="alexnet-kather100k") + assert isinstance(eng, EngineABC) + + eng = TestEngineABC(model="alexnet-kather100k") + with pytest.raises( + ValueError, + match=r".*The input numpy array should be four dimensional.*", + ): + eng.run(images=np.zeros((10, 10))) + + eng = TestEngineABC(model="alexnet-kather100k") + with pytest.raises( + TypeError, + match=r"Input must be a list of file paths or a numpy array.", + ): + eng.run(images=1) + + eng = TestEngineABC(model="alexnet-kather100k") + with pytest.raises( + ValueError, + match=r".*len\(labels\) is not equal to len(images)*", + ): + eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + labels=list(range(1)), + on_gpu=False, + ) + + with pytest.raises( + ValueError, + match=r".*len\(masks\) is not equal to len(images)*", + ): + eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + masks=np.zeros((1, 224, 224, 3)), + on_gpu=False, + ) + + with pytest.raises( + ValueError, + match=r".*The shape of the numpy array should be NHWC*", + ): + eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + masks=np.zeros((10, 3)), + on_gpu=False, + ) + + eng = TestEngineABC(model="alexnet-kather100k") + out = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + on_gpu=False, + patch_mode=True, + ) + assert "predictions" in out + assert "labels" not in out + + eng = TestEngineABC(model="alexnet-kather100k") + out = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + on_gpu=False, + verbose=False, + ) + assert "predictions" in out + assert "labels" not in out + + eng = TestEngineABC(model="alexnet-kather100k") + out = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + labels=list(range(10)), + on_gpu=False, + ) + assert "predictions" in out + assert "labels" in out + + +def test_engine_run_with_verbose() -> NoReturn: + """Test engine run with verbose.""" + """Run pytest with `-rP` option to view progress bar on the captured stderr call""" + + eng = TestEngineABC(model="alexnet-kather100k", verbose=True) + out = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + labels=list(range(10)), + on_gpu=False, + ) + + assert "predictions" in out + assert "labels" in out + + +def test_patch_pred_zarr_store(tmp_path: pytest.TempPathFactory) -> NoReturn: + """Test the engine run and patch pred store.""" + save_dir = tmp_path / "patch_output" + + eng = TestEngineABC(model="alexnet-kather100k") + out = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + on_gpu=False, + save_dir=save_dir, + overwrite=True, + ) + assert Path.exists(out), "Zarr output file does not exist" + + eng = TestEngineABC(model="alexnet-kather100k") + out = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + on_gpu=False, + verbose=False, + save_dir=save_dir, + overwrite=True, + ) + assert Path.exists(out), "Zarr output file does not exist" + + eng = TestEngineABC(model="alexnet-kather100k") + out = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + labels=list(range(10)), + on_gpu=False, + save_dir=save_dir, + overwrite=True, + ) + assert Path.exists(out), "Zarr output file does not exist" + + """ test custom zarr output file name""" + eng = TestEngineABC(model="alexnet-kather100k") + out = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + labels=list(range(10)), + on_gpu=False, + save_dir=save_dir, + overwrite=True, + output_file="patch_pred_output", + ) + assert Path.exists(out), "Zarr output file does not exist" diff --git a/tests/models/test_arch_mapde.py b/tests/models/test_arch_mapde.py index df60d3b47..f0142406d 100644 --- a/tests/models/test_arch_mapde.py +++ b/tests/models/test_arch_mapde.py @@ -44,7 +44,7 @@ def test_functionality(remote_sample: Callable) -> None: model = _load_mapde(name="mapde-conic") patch = model.preproc(patch) batch = torch.from_numpy(patch)[None] - model = model.to(select_device(on_gpu=ON_GPU)) - output = model.infer_batch(model, batch, on_gpu=ON_GPU) + model = model.to() + output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU)) output = model.postproc(output[0]) assert np.all(output[0:2] == [[19, 171], [53, 89]]) diff --git a/tests/models/test_arch_micronet.py b/tests/models/test_arch_micronet.py index cd4bd0833..e7aa23d5b 100644 --- a/tests/models/test_arch_micronet.py +++ b/tests/models/test_arch_micronet.py @@ -39,7 +39,7 @@ def test_functionality( model = model.to(map_location) pretrained = torch.load(weights_path, map_location=map_location) model.load_state_dict(pretrained) - output = model.infer_batch(model, batch, on_gpu=ON_GPU) + output = model.infer_batch(model, batch, device=map_location) output, _ = model.postproc(output[0]) assert np.max(np.unique(output)) == 46 diff --git a/tests/models/test_arch_nuclick.py b/tests/models/test_arch_nuclick.py index fda0c01a6..b84516125 100644 --- a/tests/models/test_arch_nuclick.py +++ b/tests/models/test_arch_nuclick.py @@ -10,6 +10,7 @@ from tiatoolbox.models import NuClick from tiatoolbox.models.architecture import fetch_pretrained_weights from tiatoolbox.utils import imread +from tiatoolbox.utils.misc import select_device ON_GPU = False @@ -53,7 +54,7 @@ def test_functional_nuclick( model = NuClick(num_input_channels=5, num_output_channels=1) pretrained = torch.load(weights_path, map_location="cpu") model.load_state_dict(pretrained) - output = model.infer_batch(model, batch, on_gpu=ON_GPU) + output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU)) postproc_masks = model.postproc( output, do_reconstruction=True, diff --git a/tests/models/test_arch_sccnn.py b/tests/models/test_arch_sccnn.py index bdec99e0b..58d3f67d0 100644 --- a/tests/models/test_arch_sccnn.py +++ b/tests/models/test_arch_sccnn.py @@ -4,9 +4,10 @@ import numpy as np import torch -from tiatoolbox import utils from tiatoolbox.models import SCCNN from tiatoolbox.models.architecture import fetch_pretrained_weights +from tiatoolbox.utils import env_detection +from tiatoolbox.utils.misc import select_device from tiatoolbox.wsicore.wsireader import WSIReader @@ -14,7 +15,7 @@ def _load_sccnn(name: str) -> torch.nn.Module: """Loads SCCNN model with specified weights.""" model = SCCNN() weights_path = fetch_pretrained_weights(name) - map_location = utils.misc.select_device(on_gpu=utils.env_detection.has_gpu()) + map_location = select_device(on_gpu=env_detection.has_gpu()) pretrained = torch.load(weights_path, map_location=map_location) model.load_state_dict(pretrained) @@ -39,11 +40,19 @@ def test_functionality(remote_sample: Callable) -> None: ) batch = torch.from_numpy(patch)[None] model = _load_sccnn(name="sccnn-crchisto") - output = model.infer_batch(model, batch, on_gpu=False) + output = model.infer_batch( + model, + batch, + device=select_device(on_gpu=env_detection.has_gpu()), + ) output = model.postproc(output[0]) assert np.all(output == [[8, 7]]) model = _load_sccnn(name="sccnn-conic") - output = model.infer_batch(model, batch, on_gpu=False) + output = model.infer_batch( + model, + batch, + device=select_device(on_gpu=env_detection.has_gpu()), + ) output = model.postproc(output[0]) assert np.all(output == [[7, 8]]) diff --git a/tests/models/test_arch_unet.py b/tests/models/test_arch_unet.py index f15a5dc71..69496c7aa 100644 --- a/tests/models/test_arch_unet.py +++ b/tests/models/test_arch_unet.py @@ -8,6 +8,7 @@ from tiatoolbox.models.architecture import fetch_pretrained_weights from tiatoolbox.models.architecture.unet import UNetModel +from tiatoolbox.utils.misc import select_device from tiatoolbox.wsicore.wsireader import WSIReader ON_GPU = False @@ -47,7 +48,7 @@ def test_functional_unet(remote_sample: Callable) -> None: model = UNetModel(3, 2, encoder="resnet50", decoder_block=[3]) pretrained = torch.load(pretrained_weights, map_location="cpu") model.load_state_dict(pretrained) - output = model.infer_batch(model, batch, on_gpu=ON_GPU) + output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU)) _ = output[0] # run untrained network to test for architecture @@ -59,4 +60,4 @@ def test_functional_unet(remote_sample: Callable) -> None: encoder_levels=[32, 64], skip_type="concat", ) - _ = model.infer_batch(model, batch, on_gpu=ON_GPU) + _ = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU)) diff --git a/tests/models/test_arch_vanilla.py b/tests/models/test_arch_vanilla.py index a2b1ac5c9..cfae665b2 100644 --- a/tests/models/test_arch_vanilla.py +++ b/tests/models/test_arch_vanilla.py @@ -6,9 +6,11 @@ from tiatoolbox.models.architecture.vanilla import CNNModel from tiatoolbox.models.models_abc import model_to +from tiatoolbox.utils.misc import select_device ON_GPU = False RNG = np.random.default_rng() # Numpy Random Generator +device = "cuda" if ON_GPU else "cpu" def test_functional() -> None: @@ -43,8 +45,8 @@ def test_functional() -> None: try: for backbone in backbones: model = CNNModel(backbone, num_classes=1) - model_ = model_to(on_gpu=ON_GPU, model=model) - model.infer_batch(model_, samples, on_gpu=ON_GPU) + model_ = model_to(device=device, model=model) + model.infer_batch(model_, samples, device=select_device(on_gpu=ON_GPU)) except ValueError as exc: msg = f"Model {backbone} failed." raise AssertionError(msg) from exc diff --git a/tests/models/test_hovernet.py b/tests/models/test_hovernet.py index bf77b46ba..dcf2251ac 100644 --- a/tests/models/test_hovernet.py +++ b/tests/models/test_hovernet.py @@ -14,6 +14,7 @@ ResidualBlock, TFSamepaddingLayer, ) +from tiatoolbox.utils.misc import select_device from tiatoolbox.wsicore.wsireader import WSIReader @@ -34,7 +35,7 @@ def test_functionality(remote_sample: Callable) -> None: weights_path = fetch_pretrained_weights("hovernet_fast-pannuke") pretrained = torch.load(weights_path) model.load_state_dict(pretrained) - output = model.infer_batch(model, batch, on_gpu=False) + output = model.infer_batch(model, batch, device=select_device(on_gpu=False)) output = [v[0] for v in output] output = model.postproc(output) assert len(output[1]) > 0, "Must have some nuclei." @@ -51,7 +52,7 @@ def test_functionality(remote_sample: Callable) -> None: weights_path = fetch_pretrained_weights("hovernet_fast-monusac") pretrained = torch.load(weights_path) model.load_state_dict(pretrained) - output = model.infer_batch(model, batch, on_gpu=False) + output = model.infer_batch(model, batch, device=select_device(on_gpu=False)) output = [v[0] for v in output] output = model.postproc(output) assert len(output[1]) > 0, "Must have some nuclei." @@ -68,7 +69,7 @@ def test_functionality(remote_sample: Callable) -> None: weights_path = fetch_pretrained_weights("hovernet_original-consep") pretrained = torch.load(weights_path) model.load_state_dict(pretrained) - output = model.infer_batch(model, batch, on_gpu=False) + output = model.infer_batch(model, batch, device=select_device(on_gpu=False)) output = [v[0] for v in output] output = model.postproc(output) assert len(output[1]) > 0, "Must have some nuclei." @@ -85,7 +86,7 @@ def test_functionality(remote_sample: Callable) -> None: weights_path = fetch_pretrained_weights("hovernet_original-kumar") pretrained = torch.load(weights_path) model.load_state_dict(pretrained) - output = model.infer_batch(model, batch, on_gpu=False) + output = model.infer_batch(model, batch, device=select_device(on_gpu=False)) output = [v[0] for v in output] output = model.postproc(output) assert len(output[1]) > 0, "Must have some nuclei." diff --git a/tests/models/test_hovernetplus.py b/tests/models/test_hovernetplus.py index 96d0f9d23..1377fdd82 100644 --- a/tests/models/test_hovernetplus.py +++ b/tests/models/test_hovernetplus.py @@ -7,6 +7,7 @@ from tiatoolbox.models import HoVerNetPlus from tiatoolbox.models.architecture import fetch_pretrained_weights from tiatoolbox.utils import imread +from tiatoolbox.utils.misc import select_device from tiatoolbox.utils.transforms import imresize @@ -28,7 +29,7 @@ def test_functionality(remote_sample: Callable) -> None: weights_path = fetch_pretrained_weights("hovernetplus-oed") pretrained = torch.load(weights_path) model.load_state_dict(pretrained) - output = model.infer_batch(model, batch, on_gpu=False) + output = model.infer_batch(model, batch, device=select_device(on_gpu=False)) assert len(output) == 4, "Must contain predictions for: np, hv, tp and ls branches." output = [v[0] for v in output] output = model.postproc(output) diff --git a/tests/models/test_abc.py b/tests/models/test_models_abc.py similarity index 96% rename from tests/models/test_abc.py rename to tests/models/test_models_abc.py index 3537735ce..635b13be1 100644 --- a/tests/models/test_abc.py +++ b/tests/models/test_models_abc.py @@ -124,9 +124,9 @@ def test_model_to() -> None: if not utils.env_detection.has_gpu(): model = torch_models.resnet18() with pytest.raises((AssertionError, RuntimeError)): - _ = tiatoolbox.models.models_abc.model_to(on_gpu=True, model=model) + _ = tiatoolbox.models.models_abc.model_to(device="cuda", model=model) # Test on CPU model = torch_models.resnet18() - model = tiatoolbox.models.models_abc.model_to(on_gpu=False, model=model) + model = tiatoolbox.models.models_abc.model_to(device="cpu", model=model) assert isinstance(model, nn.Module) diff --git a/tests/test_annotation_tilerendering.py b/tests/test_annotation_tilerendering.py index 4d4651f06..1b2dc8826 100644 --- a/tests/test_annotation_tilerendering.py +++ b/tests/test_annotation_tilerendering.py @@ -460,6 +460,7 @@ def test_function_mapper(fill_store: Callable, tmp_path: Path) -> None: _, store = fill_store(SQLiteStore, tmp_path / "test.db") def color_fn(props: dict[str, str]) -> tuple[int, int, int]: + """Tests Red for cells, otherwise green.""" # simple test function that returns red for cells, otherwise green. if props["type"] == "cell": return 1, 0, 0 diff --git a/tiatoolbox/annotation/storage.py b/tiatoolbox/annotation/storage.py index 3440b5aec..4137b4f67 100644 --- a/tiatoolbox/annotation/storage.py +++ b/tiatoolbox/annotation/storage.py @@ -2316,6 +2316,21 @@ def _unpack_wkb( cx: float, cy: float, ) -> bytes: + """Return the geometry as bytes using WKB. + + Args: + data (bytes or str): + The WKB/WKT data to be unpacked. + cx (int): + The X coordinate of the centroid/representative point. + cy (float): + The Y coordinate of the centroid/representative point. + + Returns: + bytes: + The geometry as bytes. + + """ return ( self._decompress_data(data) if data diff --git a/tiatoolbox/cli/patch_predictor.py b/tiatoolbox/cli/patch_predictor.py index 2c754d1c6..8c6128e8c 100644 --- a/tiatoolbox/cli/patch_predictor.py +++ b/tiatoolbox/cli/patch_predictor.py @@ -83,7 +83,7 @@ def patch_predictor( predictor = PatchPredictor( pretrained_model=pretrained_model, - pretrained_weights=pretrained_weights, + weights=pretrained_weights, batch_size=batch_size, num_loader_workers=num_loader_workers, verbose=verbose, diff --git a/tiatoolbox/models/architecture/__init__.py b/tiatoolbox/models/architecture/__init__.py index d37ad5c80..7776cdb60 100644 --- a/tiatoolbox/models/architecture/__init__.py +++ b/tiatoolbox/models/architecture/__init__.py @@ -1,21 +1,20 @@ """Define a set of models to be used within tiatoolbox.""" from __future__ import annotations -import os from pydoc import locate -from typing import TYPE_CHECKING, Optional, Union - -import torch +from typing import TYPE_CHECKING from tiatoolbox import rcParam from tiatoolbox.models.dataset.classification import predefined_preproc_func +from tiatoolbox.models.models_abc import load_torch_model from tiatoolbox.utils import download_data if TYPE_CHECKING: # pragma: no cover from pathlib import Path - from tiatoolbox.models.models_abc import IOConfigABC + import torch + from tiatoolbox.models.engine.io_config import ModelIOConfigABC __all__ = ["get_pretrained_model", "fetch_pretrained_weights"] PRETRAINED_INFO = rcParam["pretrained_model_info"] @@ -63,7 +62,7 @@ def get_pretrained_model( pretrained_weights: str | Path | None = None, *, overwrite: bool = False, -) -> tuple[torch.nn.Module, IOConfigABC]: +) -> tuple[torch.nn.Module, ModelIOConfigABC]: """Load a predefined PyTorch model with the appropriate pretrained weights. Args: @@ -143,14 +142,11 @@ def get_pretrained_model( overwrite=overwrite, ) - # ! assume to be saved in single GPU mode - # always load on to the CPU - saved_state_dict = torch.load(pretrained_weights, map_location="cpu") - model.load_state_dict(saved_state_dict, strict=True) + model = load_torch_model(model=model, weights=pretrained_weights) # ! io_info = info["ioconfig"] creator = locate(f"tiatoolbox.models.engine.{io_info['class']}") - iostate = creator(**io_info["kwargs"]) - return model, iostate + ioconfig = creator(**io_info["kwargs"]) + return model, ioconfig diff --git a/tiatoolbox/models/architecture/hovernet.py b/tiatoolbox/models/architecture/hovernet.py index cad29fe83..216e06ee5 100644 --- a/tiatoolbox/models/architecture/hovernet.py +++ b/tiatoolbox/models/architecture/hovernet.py @@ -19,7 +19,6 @@ centre_crop_to_shape, ) from tiatoolbox.models.models_abc import ModelABC -from tiatoolbox.utils import misc from tiatoolbox.utils.misc import get_bounding_box @@ -781,7 +780,7 @@ def postproc(raw_maps: list[np.ndarray]) -> tuple[np.ndarray, dict]: return pred_inst, nuc_inst_info_dict @staticmethod - def infer_batch(model: nn.Module, batch_data: np.ndarray, *, on_gpu: bool) -> tuple: + def infer_batch(model: nn.Module, batch_data: np.ndarray, *, device: str) -> tuple: """Run inference on an input batch. This contains logic for forward operation as well as batch i/o @@ -793,8 +792,8 @@ def infer_batch(model: nn.Module, batch_data: np.ndarray, *, on_gpu: bool) -> tu batch_data (ndarray): A batch of data generated by `torch.utils.data.DataLoader`. - on_gpu (bool): - Whether to run inference on a GPU. + device (str): + Transfers model to the specified device. Default is "cpu". Returns: tuple: @@ -806,7 +805,6 @@ def infer_batch(model: nn.Module, batch_data: np.ndarray, *, on_gpu: bool) -> tu """ patch_imgs = batch_data - device = misc.select_device(on_gpu=on_gpu) patch_imgs_gpu = patch_imgs.to(device).type(torch.float32) # to NCHW patch_imgs_gpu = patch_imgs_gpu.permute(0, 3, 1, 2).contiguous() diff --git a/tiatoolbox/models/architecture/hovernetplus.py b/tiatoolbox/models/architecture/hovernetplus.py index 59135a350..ddcce67ea 100644 --- a/tiatoolbox/models/architecture/hovernetplus.py +++ b/tiatoolbox/models/architecture/hovernetplus.py @@ -12,7 +12,6 @@ from tiatoolbox.models.architecture.hovernet import HoVerNet from tiatoolbox.models.architecture.utils import UpSample2x -from tiatoolbox.utils import misc class HoVerNetPlus(HoVerNet): @@ -320,7 +319,7 @@ def postproc(raw_maps: list[np.ndarray]) -> tuple: return pred_inst, nuc_inst_info_dict, pred_layer, layer_info_dict @staticmethod - def infer_batch(model: nn.Module, batch_data: np.ndarray, *, on_gpu: bool) -> tuple: + def infer_batch(model: nn.Module, batch_data: np.ndarray, *, device: str) -> tuple: """Run inference on an input batch. This contains logic for forward operation as well as batch i/o @@ -332,13 +331,12 @@ def infer_batch(model: nn.Module, batch_data: np.ndarray, *, on_gpu: bool) -> tu batch_data (ndarray): A batch of data generated by `torch.utils.data.DataLoader`. - on_gpu (bool): - Whether to run inference on a GPU. + device (str): + Transfers model to the specified device. Default is "cpu". """ patch_imgs = batch_data - device = misc.select_device(on_gpu=on_gpu) patch_imgs_gpu = patch_imgs.to(device).type(torch.float32) # to NCHW patch_imgs_gpu = patch_imgs_gpu.permute(0, 3, 1, 2).contiguous() diff --git a/tiatoolbox/models/architecture/mapde.py b/tiatoolbox/models/architecture/mapde.py index 21c588c29..863ce985d 100644 --- a/tiatoolbox/models/architecture/mapde.py +++ b/tiatoolbox/models/architecture/mapde.py @@ -13,7 +13,6 @@ from skimage.feature import peak_local_max from tiatoolbox.models.architecture.micronet import MicroNet -from tiatoolbox.utils.misc import select_device class MapDe(MicroNet): @@ -258,7 +257,7 @@ def infer_batch( model: torch.nn.Module, batch_data: torch.Tensor, *, - on_gpu: bool, + device: str, ) -> list[np.ndarray]: """Run inference on an input batch. @@ -271,8 +270,8 @@ def infer_batch( batch_data (:class:`numpy.ndarray`): A batch of data generated by `torch.utils.data.DataLoader`. - on_gpu (bool): - Whether to run inference on a GPU. + device (str): + Transfers model to the specified device. Default is "cpu". Returns: list(np.ndarray): @@ -281,7 +280,6 @@ def infer_batch( """ patch_imgs = batch_data - device = select_device(on_gpu=on_gpu) patch_imgs_gpu = patch_imgs.to(device).type(torch.float32) # to NCHW patch_imgs_gpu = patch_imgs_gpu.permute(0, 3, 1, 2).contiguous() diff --git a/tiatoolbox/models/architecture/micronet.py b/tiatoolbox/models/architecture/micronet.py index 69daa120f..c18e51e6b 100644 --- a/tiatoolbox/models/architecture/micronet.py +++ b/tiatoolbox/models/architecture/micronet.py @@ -18,7 +18,6 @@ from tiatoolbox.models.architecture.hovernet import HoVerNet from tiatoolbox.models.models_abc import ModelABC -from tiatoolbox.utils import misc def group1_forward_branch( @@ -628,7 +627,7 @@ def infer_batch( model: torch.nn.Module, batch_data: torch.Tensor, *, - on_gpu: bool, + device: str, ) -> list[np.ndarray]: """Run inference on an input batch. @@ -641,8 +640,8 @@ def infer_batch( batch_data (:class:`torch.Tensor`): A batch of data generated by `torch.utils.data.DataLoader`. - on_gpu (bool): - Whether to run inference on a GPU. + device (str): + Transfers model to the specified device. Default is "cpu". Returns: list(np.ndarray): @@ -651,7 +650,6 @@ def infer_batch( """ patch_imgs = batch_data - device = misc.select_device(on_gpu=on_gpu) patch_imgs_gpu = patch_imgs.to(device).type(torch.float32) # to NCHW patch_imgs_gpu = patch_imgs_gpu.permute(0, 3, 1, 2).contiguous() diff --git a/tiatoolbox/models/architecture/nuclick.py b/tiatoolbox/models/architecture/nuclick.py index 85a759bb6..cb5f52509 100644 --- a/tiatoolbox/models/architecture/nuclick.py +++ b/tiatoolbox/models/architecture/nuclick.py @@ -21,7 +21,6 @@ from tiatoolbox import logger from tiatoolbox.models.models_abc import ModelABC -from tiatoolbox.utils import misc if TYPE_CHECKING: # pragma: no cover from tiatoolbox.typing import IntPair @@ -646,7 +645,7 @@ def infer_batch( model: nn.Module, batch_data: torch.Tensor, *, - on_gpu: bool, + device: str, ) -> np.ndarray: """Run inference on an input batch. @@ -655,16 +654,16 @@ def infer_batch( Args: model (nn.Module): PyTorch defined model. - batch_data (torch.Tensor): a batch of data generated by - torch.utils.data.DataLoader. - on_gpu (bool): Whether to run inference on a GPU. + batch_data (torch.Tensor): + A batch of data generated by torch.utils.data.DataLoader. + device (str): + Transfers model to the specified device. Default is "cpu". Returns: Pixel-wise nuclei prediction for each patch, shape: (no.patch, h, w). """ model.eval() - device = misc.select_device(on_gpu=on_gpu) # Assume batch_data is NCHW batch_data = batch_data.to(device).type(torch.float32) diff --git a/tiatoolbox/models/architecture/sccnn.py b/tiatoolbox/models/architecture/sccnn.py index bbeb58094..bdb8926e3 100644 --- a/tiatoolbox/models/architecture/sccnn.py +++ b/tiatoolbox/models/architecture/sccnn.py @@ -16,7 +16,6 @@ from torch import nn from tiatoolbox.models.models_abc import ModelABC -from tiatoolbox.utils import misc class SCCNN(ModelABC): @@ -354,7 +353,7 @@ def infer_batch( model: nn.Module, batch_data: np.ndarray | torch.Tensor, *, - on_gpu: bool, + device: str, ) -> list[np.ndarray]: """Run inference on an input batch. @@ -367,8 +366,8 @@ def infer_batch( batch_data (:class:`numpy.ndarray` or :class:`torch.Tensor`): A batch of data generated by `torch.utils.data.DataLoader`. - on_gpu (bool): - Whether to run inference on a GPU. + device (str): + Transfers model to the specified device. Default is "cpu". Returns: list of :class:`numpy.ndarray`: @@ -377,7 +376,6 @@ def infer_batch( """ patch_imgs = batch_data - device = misc.select_device(on_gpu=on_gpu) patch_imgs_gpu = patch_imgs.to(device).type(torch.float32) # to NCHW patch_imgs_gpu = patch_imgs_gpu.permute(0, 3, 1, 2).contiguous() diff --git a/tiatoolbox/models/architecture/unet.py b/tiatoolbox/models/architecture/unet.py index 8f628fb52..7e2e35c02 100644 --- a/tiatoolbox/models/architecture/unet.py +++ b/tiatoolbox/models/architecture/unet.py @@ -11,7 +11,6 @@ from tiatoolbox.models.architecture.utils import UpSample2x, centre_crop from tiatoolbox.models.models_abc import ModelABC -from tiatoolbox.utils import misc class ResNetEncoder(ResNet): @@ -415,7 +414,7 @@ def infer_batch( model: nn.Module, batch_data: torch.Tensor, *, - on_gpu: bool, + device: str, ) -> list: """Run inference on an input batch. @@ -428,8 +427,8 @@ def infer_batch( batch_data (:class:`torch.Tensor`): A batch of data generated by `torch.utils.data.DataLoader`. - on_gpu (bool): - Whether to run inference on a GPU. + device (str): + Transfers model to the specified device. Default is "cpu". Returns: list: @@ -438,7 +437,6 @@ def infer_batch( """ model.eval() - device = misc.select_device(on_gpu=on_gpu) #### imgs = batch_data diff --git a/tiatoolbox/models/architecture/vanilla.py b/tiatoolbox/models/architecture/vanilla.py index 5855971d5..2ecbd5b86 100644 --- a/tiatoolbox/models/architecture/vanilla.py +++ b/tiatoolbox/models/architecture/vanilla.py @@ -9,7 +9,6 @@ from torch import nn from tiatoolbox.models.models_abc import ModelABC -from tiatoolbox.utils.misc import select_device if TYPE_CHECKING: # pragma: no cover from torchvision.models import WeightsEnum @@ -142,7 +141,7 @@ def infer_batch( model: nn.Module, batch_data: torch.Tensor, *, - on_gpu: bool, + device: str = "cpu", ) -> np.ndarray: """Run inference on an input batch. @@ -154,11 +153,11 @@ def infer_batch( batch_data (torch.Tensor): A batch of data generated by `torch.utils.data.DataLoader`. - on_gpu (bool): - Whether to run inference on a GPU. + device (str): + Transfers model to the specified device. Default is "cpu". """ - img_patches_device = batch_data.to(select_device(on_gpu=on_gpu)).type( + img_patches_device = batch_data.to(device).type( torch.float32, ) # to NCHW img_patches_device = img_patches_device.permute(0, 3, 1, 2).contiguous() @@ -239,7 +238,7 @@ def infer_batch( model: nn.Module, batch_data: torch.Tensor, *, - on_gpu: bool, + device: str, ) -> list[np.ndarray, ...]: """Run inference on an input batch. @@ -251,11 +250,11 @@ def infer_batch( batch_data (torch.Tensor): A batch of data generated by `torch.utils.data.DataLoader`. - on_gpu (bool): - Whether to run inference on a GPU. + device (str): + Transfers model to the specified device. Default is "cpu". """ - img_patches_device = batch_data.to(select_device(on_gpu=on_gpu)).type( + img_patches_device = batch_data.to(device).type( torch.float32, ) # to NCHW img_patches_device = img_patches_device.permute(0, 3, 1, 2).contiguous() diff --git a/tiatoolbox/models/engine/__init__.py b/tiatoolbox/models/engine/__init__.py index 0a5968b44..7d0dfe0e1 100644 --- a/tiatoolbox/models/engine/__init__.py +++ b/tiatoolbox/models/engine/__init__.py @@ -1,11 +1,13 @@ """Engines to run models implemented in tiatoolbox.""" -from tiatoolbox.models.engine import ( +from . import ( + engine_abc, nucleus_instance_segmentor, patch_predictor, semantic_segmentor, ) __all__ = [ + "engine_abc", "nucleus_instance_segmentor", "patch_predictor", "semantic_segmentor", diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 69d66af73..08871c54b 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -1,21 +1,640 @@ """Defines Abstract Base Class for TIAToolbox Model Engines.""" +from __future__ import annotations + from abc import ABC, abstractmethod +from pathlib import Path +from typing import TYPE_CHECKING, NoReturn + +import numpy as np +import torch +import tqdm +from torch import nn + +from tiatoolbox import logger +from tiatoolbox.models.architecture import get_pretrained_model +from tiatoolbox.models.dataset.dataset_abc import PatchDataset +from tiatoolbox.models.models_abc import load_torch_model, model_to +from tiatoolbox.utils.misc import dict_to_store, dict_to_zarr + +if TYPE_CHECKING: # pragma: no cover + import os + + from torch.utils.data import DataLoader + + from tiatoolbox.annotation import AnnotationStore + from tiatoolbox.wsicore.wsireader import WSIReader + + from .io_config import ModelIOConfigABC + + +def prepare_engines_save_dir( + save_dir: os | Path | None, + len_images: int, + *, + patch_mode: bool, + overwrite: bool, +) -> Path | None: + """Create directory if not defined and number of images is more than 1. + + Args: + save_dir (str or Path): + Path to output directory. + len_images (int): + List of inputs to process. + patch_mode(bool): + Whether to treat input image as a patch or WSI. + overwrite (bool): + Whether to overwrite the results. Default = False. + + Returns: + :class:`Path`: + Path to output directory. + + """ + if patch_mode is True: + if save_dir is not None: + save_dir.mkdir(parents=True, exist_ok=overwrite) + return save_dir + + if save_dir is None: + if len_images > 1: + msg = ( + "More than 1 WSIs detected but there is no save directory provided." + "Please provide a 'save_dir'." + ) + raise OSError(msg) + return ( + Path.cwd() + ) # save the output to current working directory and return save_dir + + if len_images > 1: + logger.info( + "When providing multiple whole slide images, " + "the outputs will be saved and the locations of outputs " + "will be returned to the calling function.", + ) + + save_dir = Path(save_dir) + save_dir.mkdir(parents=True, exist_ok=overwrite) + + return save_dir class EngineABC(ABC): - """Abstract base class for engines used in tiatoolbox.""" + """Abstract base class for engines used in tiatoolbox. + + Args: + model (str | nn.Module): + A PyTorch model. Default is `None`. + The user can request pretrained models from the toolbox using + the list of pretrained models available at this `link + `_ + By default, the corresponding pretrained weights will also + be downloaded. However, you can override with your own set + of weights. + weights (str or Path): + Path to the weight of the corresponding `model`. + + >>> engine = EngineABC( + ... model="pretrained-model-name", + ... weights="pretrained-local-weights.pth") + + batch_size (int): + Number of images fed into the model each time. + num_loader_workers (int): + Number of workers to load the data using :class:`torch.utils.data.Dataset`. + Please note that they will also perform preprocessing. default = 0 + num_post_proc_workers (int): + Number of workers to postprocess the results of the model. default = 0 + device (str): + Select the device to run the model. Default is "cpu". + verbose (bool): + Whether to output logging information. + + Attributes: + images (list of str or list of :obj:`Path` or NHWC :obj:`numpy.ndarray`): + A NHWC image or a path to WSI. + patch_mode (str): + Whether to treat input image as a patch or WSI. + default = True. + model (str | nn.Module): + Defined PyTorch model. + Name of an existing model supported by the TIAToolbox for + processing the data. For a full list of pretrained models, + refer to the `docs + `_ + By default, the corresponding pretrained weights will also + be downloaded. However, you can override with your own set + of weights via the `weights` argument. Argument + is case-insensitive. + ioconfig (ModelIOConfigABC): + Input IO configuration to run the Engine. + _ioconfig (): + Runtime ioconfig. + return_labels (bool): + Whether to return the labels with the predictions. + merge_predictions (bool): + Whether to merge the predictions to form a 2-dimensional + map. This is only applicable if `patch_mode` is False in inference. + resolution (Resolution): + Resolution used for reading the image. Please see + :obj:`WSIReader` for details. + units (Units): + Units of resolution used for reading the image. Choose + from either `level`, `power` or `mpp`. Please see + :obj:`WSIReader` for details. + patch_input_shape (tuple): + Shape of patches input to the model as tupled of HW. Patches are at + requested read resolution, not with respect to level 0, + and must be positive. + stride_shape (tuple): + Stride used during WSI processing. Stride is + at requested read resolution, not with respect to + level 0, and must be positive. If not provided, + `stride_shape=patch_input_shape`. + batch_size (int): + Number of images fed into the model each time. + labels (list | None): + List of labels. Only a single label per image is supported. + device (str): + Select the device to run the model. Default is "cpu". + num_loader_workers (int): + Number of workers used in torch.utils.data.DataLoader. + verbose (bool): + Whether to output logging information. + + Examples: + >>> # array of list of 2 image patches as input + >>> import numpy as np + >>> data = np.array([np.ndarray, np.ndarray]) + >>> engine = EngineABC(model="resnet18-kather100k") + >>> output = engine.run(data, patch_mode=True) + + >>> # array of list of 2 image patches as input + >>> import numpy as np + >>> data = np.array([np.ndarray, np.ndarray]) + >>> engine = EngineABC(model="resnet18-kather100k") + >>> output = engine.run(data, patch_mode=True) + + >>> # list of 2 image files as input + >>> image = ['path/image1.png', 'path/image2.png'] + >>> engine = EngineABC(model="resnet18-kather100k") + >>> output = engine.run(image, patch_mode=False) + + >>> # list of 2 wsi files as input + >>> wsi_file = ['path/wsi1.svs', 'path/wsi2.svs'] + >>> engine = EngineABC(model="resnet18-kather100k") + >>> output = engine.run(wsi_file, patch_mode=True) - def __init__(self) -> None: + """ + + def __init__( + self: EngineABC, + model: str | nn.Module, + batch_size: int = 8, + num_loader_workers: int = 0, + num_post_proc_workers: int = 0, + weights: str | Path | None = None, + *, + device: str = "cpu", + verbose: bool = False, + ) -> None: """Initialize Engine.""" super().__init__() + self.masks = None + self.images = None + self.patch_mode = None + self.device = device + + # Initialize model with specified weights and ioconfig. + self.model, self.ioconfig = self._initialize_model_ioconfig( + model=model, + weights=weights, + ) + self.model = model_to(model=self.model, device=self.device) + self._ioconfig = self.ioconfig # runtime ioconfig + + self.batch_size = batch_size + self.num_loader_workers = num_loader_workers + self.num_post_proc_workers = num_post_proc_workers + self.verbose = verbose + self.return_labels = False + self.merge_predictions = False + self.units = "baseline" + self.resolution = 1.0 + self.patch_input_shape = None + self.stride_shape = None + self.labels = None + + @staticmethod + def _initialize_model_ioconfig( + model: str | nn.Module, + weights: str | Path | None, + ) -> tuple[nn.Module, ModelIOConfigABC | None]: + """Helper function to initialize model and ioconfig attributes. + + If a pretrained model provided by the TIAToolbox is requested. The model + can be specified as a string otherwise torch.nn.Module is required. + This function also loads the :class:`ModelIOConfigABC` using the information + from the pretrained models in TIAToolbox. If ioconfig is not available then it + should be provided in the :func:`run` function. + + Args: + model (str | nn.Module): + A torch model which should be run by the engine. + + weights (str | Path | None): + Path to pretrained weights. If no pretrained weights are provided + and the model is provided by TIAToolbox, then pretrained weights will + be automatically loaded from the TIA servers. + + Returns: + nn.Module: + The requested PyTorch model. + + ModelIOConfigABC | None: + The model io configuration for TIAToolbox pretrained models. + Otherwise, None. + + """ + if not isinstance(model, (str, nn.Module)): + msg = "Input model must be a string or 'torch.nn.Module'." + raise TypeError(msg) + + if isinstance(model, str): + # ioconfig is retrieved from the pretrained model in the toolbox. + # list of pretrained models in the TIA Toolbox is available here: + # https://tia-toolbox.readthedocs.io/en/add-bokeh-app/pretrained.html + # no need to provide ioconfig in EngineABC.run() this case. + return get_pretrained_model(model, weights) + + if weights is not None: + model = load_torch_model(model=model, weights=weights) + + return model, None + + def pre_process_patches( + self: EngineABC, + images: np.ndarray | list, + labels: list, + ) -> torch.utils.data.DataLoader: + """Pre-process an image patch.""" + if labels: + # if a labels is provided, then return with the prediction + self.return_labels = bool(labels) + + dataset = PatchDataset(inputs=images, labels=labels) + dataset.preproc_func = self.model.preproc_func + + # preprocessing must be defined with the dataset + return torch.utils.data.DataLoader( + dataset, + num_workers=self.num_loader_workers, + batch_size=self.batch_size, + drop_last=False, + shuffle=False, + ) + + def infer_patches( + self: EngineABC, + data_loader: DataLoader, + ) -> dict: + """Model inference on an image patch.""" + progress_bar = None + + if self.verbose: + progress_bar = tqdm.tqdm( + total=int(len(data_loader)), + leave=True, + ncols=80, + ascii=True, + position=0, + ) + raw_predictions = { + "predictions": [], + } + + if self.return_labels: + raw_predictions["labels"] = [] + + for _, batch_data in enumerate(data_loader): + batch_output_predictions = self.model.infer_batch( + self.model, + batch_data["image"], + device=self.device, + ) + + raw_predictions["predictions"].extend(batch_output_predictions.tolist()) + + if self.return_labels: # be careful of `s` + # We do not use tolist here because label may be of mixed types + # and hence collated as list by torch + raw_predictions["labels"].extend(list(batch_data["label"])) + + if progress_bar: + progress_bar.update() + + if progress_bar: + progress_bar.close() + + return raw_predictions + + def setup_patch_dataset( + self: EngineABC, + raw_predictions: dict, + output_type: str, + save_dir: Path | None = None, + **kwargs: dict, + ) -> Path | AnnotationStore: + """Post-process image patches. + + Args: + raw_predictions (dict): + A dictionary of patch prediction information. + save_dir (Path): + Optional Output Path to directory to save the patch dataset output to a + `.zarr` or `.db` file, provided patch_mode is True. if the patch_mode is + False then save_dir is required. + output_type (str): + The desired output type for resulting patch dataset. + **kwargs (dict): + Keyword Args to update setup_patch_dataset() method attributes. + + Returns: (dict, Path, :class:`SQLiteStore`): + if the output_type is "AnnotationStore", the function returns the patch + predictor output as an SQLiteStore containing Annotations for each or the + Path to a `.db` file depending on whether a save_dir Path is provided. + Otherwise, the function defaults to returning patch predictor output, either + as a dict or the Path to a `.zarr` file depending on whether a save_dir Path + is provided. + + """ + if not save_dir and not self.patch_mode: + msg = "`save_dir` must be specified when patch_mode is False." + raise OSError(msg) + + if not save_dir and output_type != "AnnotationStore": + return raw_predictions + + output_file = ( + kwargs["output_file"] and kwargs.pop("output_file") + if "output_file" in kwargs + else "output" + ) + + save_path = save_dir / output_file + + if output_type == "AnnotationStore": + # scale_factor set from kwargs + scale_factor = kwargs["scale_factor"] if "scale_factor" in kwargs else None + # class_dict set from kwargs + class_dict = kwargs["class_dict"] if "class_dict" in kwargs else None + + return dict_to_store(raw_predictions, scale_factor, class_dict, save_path) + + return dict_to_zarr( + raw_predictions, + save_path, + **kwargs, + ) + @abstractmethod - def process_patch(self): - """Process an image patch.""" + def pre_process_wsi(self: EngineABC) -> NoReturn: + """Pre-process a WSI.""" raise NotImplementedError - # how to deal with patches, list of patches/numpy arrays, WSIs - # how to communicate with sub-processes. - # define how to deal with patches as numpy/zarr arrays. - # convert list of patches/numpy arrays to zarr and then pass to each sub-processes. - # define how to read WSIs, read the image and convert to zarr array. + @abstractmethod + def infer_wsi(self: EngineABC) -> NoReturn: + """Model inference on a WSI.""" + raise NotImplementedError + + @abstractmethod + def post_process_wsi(self: EngineABC) -> NoReturn: + """Post-process a WSI.""" + raise NotImplementedError + + def _load_ioconfig(self: EngineABC, ioconfig: ModelIOConfigABC) -> ModelIOConfigABC: + """Helper function to load ioconfig. + + If the model is provided by TIAToolbox it will load the default ioconfig. + Otherwise, ioconfig must be specified. + + Args: + ioconfig (ModelIOConfigABC): + IO configuration to run the engines. + + Raises: + ValueError: + If no io configuration is provided or found in the pretrained TIAToolbox + models. + + Returns: + ModelIOConfigABC: + The ioconfig used for the run. + + """ + if self.ioconfig is None and ioconfig is None: + msg = ( + "Please provide a valid ModelIOConfigABC. " + "No default ModelIOConfigABC found." + ) + raise ValueError(msg) + + if ioconfig is not None: + self.ioconfig = ioconfig + + return self.ioconfig + + @staticmethod + def _validate_images_masks(images: list | np.ndarray) -> list | np.ndarray: + """Validate input images for a run.""" + if not isinstance(images, (list, np.ndarray)): + msg = "Input must be a list of file paths or a numpy array." + raise TypeError( + msg, + ) + + if isinstance(images, np.ndarray) and images.ndim != 4: # noqa: PLR2004 + msg = ( + "The input numpy array should be four dimensional." + "The shape of the numpy array should be NHWC." + ) + raise ValueError(msg) + + return images + + @staticmethod + def _validate_input_numbers( + images: list | np.ndarray, + masks: list[os | Path] | np.ndarray | None = None, + labels: list | None = None, + ) -> None: + """Validates number of input images, masks and labels.""" + if masks is None and labels is None: + return + + len_images = len(images) + + if masks is not None and len_images != len(masks): + msg = ( + f"len(masks) is not equal to len(images) " + f": {len(masks)} != {len(images)}" + ) + raise ValueError( + msg, + ) + + if labels is not None and len_images != len(labels): + msg = ( + f"len(labels) is not equal to len(images) " + f": {len(labels)} != {len(images)}" + ) + raise ValueError( + msg, + ) + return + + def run( + self: EngineABC, + images: list[os | Path | WSIReader] | np.ndarray, + masks: list[os | Path] | np.ndarray | None = None, + labels: list | None = None, + ioconfig: ModelIOConfigABC | None = None, + *, + patch_mode: bool = True, + save_dir: os | Path | None = None, # None will not save output + overwrite: bool = False, + output_type: str = "dict", + **kwargs: dict, + ) -> AnnotationStore | Path | str: + """Run the engine on input images. + + Args: + images (list, ndarray): + List of inputs to process. when using `patch` mode, the + input must be either a list of images, a list of image + file paths or a numpy array of an image list. + masks (list | None): + List of masks. Only utilised when patch_mode is False. + Patches are only generated within a masked area. + If not provided, then a tissue mask will be automatically + generated for whole slide images. + labels (list | None): + List of labels. Only a single label per image is supported. + patch_mode (bool): + Whether to treat input image as a patch or WSI. + default = True. + ioconfig (IOPatchPredictorConfig): + IO configuration. + save_dir (str or pathlib.Path): + Output directory to save the results. + If save_dir is not provided when patch_mode is False, + then for a single image the output is created in the current directory. + If there are multiple WSIs as input then the user must provide + path to save directory otherwise an OSError will be raised. + overwrite (bool): + Whether to overwrite the results. Default = False. + output_type (str): + The format of the output type. "output_type" can be + "zarr", "AnnotationStore". Default is "zarr". + When saving in the zarr format the output is saved using the + `python zarr library `__ + as a zarr group. If the required output type is an "AnnotationStore" + then the output will be intermediately saved as zarr but converted + to :class:`AnnotationStore` and saved as a `.db` file + at the end of the loop. + **kwargs (dict): + Keyword Args to update :class:`EngineABC` attributes. + + Returns: + (:class:`numpy.ndarray`, dict): + Model predictions of the input dataset. If multiple + whole slide images are provided as input, + or save_output is True, then results are saved to + `save_dir` and a dictionary indicating save location for + each input is returned. + + The dict has the following format: + + - img_path: path of the input image. + - raw: path to save location for raw prediction, + saved in .json. + - merged: path to .npy contain merged + predictions if `merge_predictions` is `True`. + + Examples: + >>> wsis = ['wsi1.svs', 'wsi2.svs'] + >>> predictor = EngineABC(model="resnet18-kather100k") + >>> output = predictor.run(wsis, patch_mode=False) + >>> output.keys() + ... ['wsi1.svs', 'wsi2.svs'] + >>> output['wsi1.svs'] + ... {'raw': '0.raw.json', 'merged': '0.merged.npy'} + >>> output['wsi2.svs'] + ... {'raw': '1.raw.json', 'merged': '1.merged.npy'} + + >>> predictor = EngineABC(model="alexnet-kather100k") + >>> output = predictor.run( + >>> images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + >>> labels=list(range(10)), + >>> on_gpu=False, + >>> ) + >>> output + ... {'predictions': [[0.7716791033744812, 0.0111849969252944, ..., + ... 0.034451354295015335, 0.004817609209567308]], + ... 'labels': [tensor(0), tensor(1), tensor(2), tensor(3), tensor(4), + ... tensor(5), tensor(6), tensor(7), tensor(8), tensor(9)]} + + >>> predictor = EngineABC(model="alexnet-kather100k") + >>> save_dir = Path("/tmp/patch_output/") + >>> output = eng.run( + >>> images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + >>> on_gpu=False, + >>> verbose=False, + >>> save_dir=save_dir, + >>> overwrite=True + >>> ) + >>> output + ... '/tmp/patch_output/output.zarr' + """ + for key in kwargs: + setattr(self, key, kwargs[key]) + + self.patch_mode = patch_mode + + self._validate_input_numbers(images=images, masks=masks, labels=labels) + self.images = self._validate_images_masks(images=images) + + if masks is not None: + self.masks = self._validate_images_masks(images=masks) + + self.labels = labels + + # if necessary Move model parameters to "cpu" or "gpu" and update ioconfig + self._ioconfig = self._load_ioconfig(ioconfig=ioconfig) + self.model = model_to(model=self.model, device=self.device) + + save_dir = prepare_engines_save_dir( + save_dir, + len(self.images), + patch_mode=patch_mode, + overwrite=overwrite, + ) + + if patch_mode: + data_loader = self.pre_process_patches( + self.images, + self.labels, + ) + raw_predictions = self.infer_patches( + data_loader=data_loader, + ) + return self.setup_patch_dataset( + raw_predictions=raw_predictions, + output_type=output_type, + save_dir=save_dir, + **kwargs, + ) + + return {"save_dir": save_dir} diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py index 7fbae30bb..287fc456b 100644 --- a/tiatoolbox/models/engine/multi_task_segmentor.py +++ b/tiatoolbox/models/engine/multi_task_segmentor.py @@ -39,32 +39,34 @@ if TYPE_CHECKING: # pragma: no cover import torch - from .io_config import IOInstanceSegmentorConfig + from tiatoolbox.typing import IntBounds + + from .io_config import IOInstanceSegmentorConfig, IOSegmentorConfig # Python is yet to be able to natively pickle Object method/static method. # Only top-level function is passable to multi-processing as caller. # May need 3rd party libraries to use method/static method otherwise. def _process_tile_predictions( - ioconfig, - tile_bounds, - tile_flag, - tile_mode, - tile_output, + ioconfig: IOSegmentorConfig, + tile_bounds: IntBounds, + tile_flag: list, + tile_mode: int, + tile_output: list, # this would be replaced by annotation store # in the future - ref_inst_dict, - postproc, - merge_predictions, - model_name, -): + ref_inst_dict: dict, + postproc: Callable, + merge_predictions: Callable, + model_name: str, +) -> tuple: """Process Tile Predictions. Function to merge new tile prediction with existing prediction, using the output from each task. Args: - ioconfig (:class:`IOInstanceSegmentorConfig`): Object defines information + ioconfig (:class:`IOSegmentorConfig`): Object defines information about input and output placement of patches. tile_bounds (:class:`numpy.array`): Boundary of the current tile, defined as (top_left_x, top_left_y, bottom_x, bottom_y). @@ -239,7 +241,7 @@ class MultiTaskSegmentor(NucleusInstanceSegmentor): """ def __init__( # noqa: PLR0913 - self, + self: MultiTaskSegmentor, batch_size: int = 8, num_loader_workers: int = 0, num_postproc_workers: int = 0, @@ -286,12 +288,12 @@ def __init__( # noqa: PLR0913 ) def _predict_one_wsi( - self, + self: MultiTaskSegmentor, wsi_idx: int, ioconfig: IOInstanceSegmentorConfig, save_path: str, mode: str, - ): + ) -> None: """Make a prediction on tile/wsi. Args: @@ -393,13 +395,13 @@ def _predict_one_wsi( # may need to chain it with parents def _process_tile_predictions( - self, - ioconfig, - tile_bounds, - tile_flag, - tile_mode, - tile_output, - ): + self: MultiTaskSegmentor, + ioconfig: IOSegmentorConfig, + tile_bounds: IntBounds, + tile_flag: list, + tile_mode: int, + tile_output: list, + ) -> None: """Function to dispatch parallel post processing.""" args = [ ioconfig, @@ -418,10 +420,15 @@ def _process_tile_predictions( future = _process_tile_predictions(*args) self._futures.append(future) - def _merge_post_process_results(self): + def _merge_post_process_results(self: MultiTaskSegmentor) -> None: """Helper to aggregate results from parallel workers.""" - def callback(new_inst_dicts, remove_uuid_lists, tiles, bounds): + def callback( + new_inst_dicts: dict, + remove_uuid_lists: list, + tiles: dict, + bounds: IntBounds, + ) -> None: """Helper to aggregate worker's results.""" # ! DEPRECATION: # ! will be deprecated upon finalization of SQL annotation store @@ -444,7 +451,7 @@ def callback(new_inst_dicts, remove_uuid_lists, tiles, bounds): callback(*future) continue # some errors happen, log it and propagate exception - # ! this will lead to discard a bunch of + # ! this will lead to discard a whole bunch of # ! inferred tiles within this current WSI if future.exception() is not None: raise future.exception() # noqa: RSE102 diff --git a/tiatoolbox/models/engine/nucleus_instance_segmentor.py b/tiatoolbox/models/engine/nucleus_instance_segmentor.py index 4156e2c2a..9aac3b8f5 100644 --- a/tiatoolbox/models/engine/nucleus_instance_segmentor.py +++ b/tiatoolbox/models/engine/nucleus_instance_segmentor.py @@ -18,18 +18,18 @@ from tiatoolbox.tools.patchextraction import PatchExtractor if TYPE_CHECKING: # pragma: no cover - from .io_config import IOInstanceSegmentorConfig + from .io_config import IOInstanceSegmentorConfig, IOSegmentorConfig def _process_instance_predictions( - inst_dict, - ioconfig, - tile_shape, - tile_flag, - tile_mode, - tile_tl, - ref_inst_dict, -): + inst_dict: dict, + ioconfig: IOSegmentorConfig, + tile_shape: list, + tile_flag: list, + tile_mode: int, + tile_tl: tuple, + ref_inst_dict: dict, +) -> list | tuple: """Function to merge new tile prediction with existing prediction. Args: @@ -50,12 +50,12 @@ def _process_instance_predictions( an overlapping tile from tile generation. The predicted instances are immediately added to accumulated output. - 1: Vertical tile strip that stands between two normal tiles - (flag 0). It has the the same height as normal tile but + (flag 0). It has the same height as normal tile but less width (hence vertical strip). - 2: Horizontal tile strip that stands between two normal tiles - (flag 0). It has the the same width as normal tile but + (flag 0). It has the same width as normal tile but less height (hence horizontal strip). - - 3: tile strip stands at the cross section of four normal tiles + - 3: tile strip stands at the cross-section of four normal tiles (flag 0). tile_tl (tuple): Top left coordinates of the current tile. ref_inst_dict (dict): Dictionary contains accumulated output. The @@ -144,7 +144,7 @@ def _process_instance_predictions( msg = f"Unknown tile mode {tile_mode}." raise ValueError(msg) - def retrieve_sel_uids(sel_indices, inst_dict): + def retrieve_sel_uids(sel_indices: list, inst_dict: dict) -> list: """Helper to retrieved selected instance uids.""" if len(sel_indices) > 0: # not sure how costly this is in large dict @@ -153,7 +153,7 @@ def retrieve_sel_uids(sel_indices, inst_dict): remove_insts_in_tile = retrieve_sel_uids(sel_indices, inst_dict) - # external removal only for tile at cross sections + # external removal only for tile at cross-sections # this one should contain UUID with the reference database remove_insts_in_orig = [] if tile_mode == 3: # noqa: PLR2004 @@ -186,17 +186,17 @@ def retrieve_sel_uids(sel_indices, inst_dict): # caller. May need 3rd party libraries to use method/static method # otherwise. def _process_tile_predictions( - ioconfig, - tile_bounds, - tile_flag, - tile_mode, - tile_output, + ioconfig: IOSegmentorConfig, + tile_bounds: np.ndarray, + tile_flag: list, + tile_mode: int, + tile_output: list, # this would be replaced by annotation store # in the future - ref_inst_dict, - postproc, - merge_predictions, -): + ref_inst_dict: dict, + postproc: Callable, + merge_predictions: Callable, +) -> tuple[dict, list]: """Function to merge new tile prediction with existing prediction. Args: @@ -368,7 +368,7 @@ class NucleusInstanceSegmentor(SemanticSegmentor): """ def __init__( - self, + self: NucleusInstanceSegmentor, batch_size: int = 8, num_loader_workers: int = 0, num_postproc_workers: int = 0, @@ -406,7 +406,7 @@ def __init__( def _get_tile_info( image_shape: list[int] | np.ndarray, ioconfig: IOInstanceSegmentorConfig, - ): + ) -> list[list, ...]: """Generating tile information. To avoid out of memory problem when processing WSI-scale in @@ -467,7 +467,7 @@ def _get_tile_info( # * remove all sides for boxes # unset for those lie within the selection - def unset_removal_flag(boxes, removal_flag): + def unset_removal_flag(boxes: tuple, removal_flag: np.ndarray) -> np.ndarray: """Unset removal flags for tiles intersecting image boundaries.""" sel_boxes = [ shapely_box(0, 0, w, 0), # top edge @@ -581,7 +581,12 @@ def unset_removal_flag(boxes, removal_flag): return info - def _to_shared_space(self, wsi_idx, patch_inputs, patch_outputs): + def _to_shared_space( + self: NucleusInstanceSegmentor, + wsi_idx: int, + patch_inputs: list, + patch_outputs: list, + ) -> None: """Helper functions to transfer variable to shared space. We modify the shared space so that we can update worker info @@ -613,7 +618,7 @@ def _to_shared_space(self, wsi_idx, patch_inputs, patch_outputs): self._mp_shared_space.patch_outputs = patch_outputs self._mp_shared_space.wsi_idx = torch.Tensor([wsi_idx]).share_memory_() - def _infer_once(self): + def _infer_once(self: NucleusInstanceSegmentor) -> list: """Running the inference only once for the currently active dataloader.""" num_steps = len(self._loader) @@ -640,7 +645,7 @@ def _infer_once(self): sample_outputs = self.model.infer_batch( self._model, sample_datas, - on_gpu=self._on_gpu, + device=self._device, ) # repackage so that it's a N list, each contains # L x etc. output @@ -658,12 +663,12 @@ def _infer_once(self): return cum_output def _predict_one_wsi( - self, + self: NucleusInstanceSegmentor, wsi_idx: int, - ioconfig: IOInstanceSegmentorConfig, + ioconfig: IOSegmentorConfig, save_path: str, mode: str, - ): + ) -> None: """Make a prediction on tile/wsi. Args: @@ -751,13 +756,13 @@ def _predict_one_wsi( self._wsi_inst_info = None # clean up def _process_tile_predictions( - self, - ioconfig, - tile_bounds, - tile_flag, - tile_mode, - tile_output, - ): + self: NucleusInstanceSegmentor, + ioconfig: IOSegmentorConfig, + tile_bounds: np.ndarray, + tile_flag: list, + tile_mode: int, + tile_output: list, + ) -> None: """Function to dispatch parallel post processing.""" args = [ ioconfig, @@ -775,10 +780,10 @@ def _process_tile_predictions( future = _process_tile_predictions(*args) self._futures.append(future) - def _merge_post_process_results(self): + def _merge_post_process_results(self: NucleusInstanceSegmentor) -> None: """Helper to aggregate results from parallel workers.""" - def callback(new_inst_dict, remove_uuid_list): + def callback(new_inst_dict: dict, remove_uuid_list: list) -> None: """Helper to aggregate worker's results.""" # ! DEPRECATION: # ! will be deprecated upon finalization of SQL annotation store diff --git a/tiatoolbox/models/engine/patch_predictor.py b/tiatoolbox/models/engine/patch_predictor.py index 807cd9fad..3092f827b 100644 --- a/tiatoolbox/models/engine/patch_predictor.py +++ b/tiatoolbox/models/engine/patch_predictor.py @@ -4,7 +4,7 @@ import copy from collections import OrderedDict from pathlib import Path -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING, Callable, NoReturn import numpy as np import torch @@ -12,20 +12,24 @@ import tiatoolbox.models.models_abc from tiatoolbox import logger +from tiatoolbox.models.dataset.dataset_abc import PatchDataset, WSIPatchDataset from tiatoolbox.utils import save_as_json from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIReader if TYPE_CHECKING: # pragma: no cover - from tiatoolbox.typing import Resolution, Units + import os -from tiatoolbox.models.architecture import get_pretrained_model -from tiatoolbox.models.dataset.dataset_abc import PatchDataset, WSIPatchDataset + from tiatoolbox.annotation import AnnotationStore + from tiatoolbox.typing import IntPair, Resolution, Units + + from .io_config import ModelIOConfigABC +from .engine_abc import EngineABC from .io_config import IOPatchPredictorConfig -class PatchPredictor: - r"""Patch level predictor. +class PatchPredictor(EngineABC): + r"""Patch level predictor for digital histology images. The models provided by tiatoolbox should give the following results: @@ -125,12 +129,12 @@ class PatchPredictor: be downloaded. However, you can override with your own set of weights via the `pretrained_weights` argument. Argument is case-insensitive. - pretrained_weights (str): + weights (str): Path to the weight of the corresponding `pretrained_model`. >>> predictor = PatchPredictor( ... pretrained_model="resnet18-kather100k", - ... pretrained_weights="resnet18_local_weight") + ... weights="resnet18_local_weight") batch_size (int): Number of images fed into the model each time. @@ -141,14 +145,14 @@ class PatchPredictor: Whether to output logging information. Attributes: - img (:obj:`str` or :obj:`pathlib.Path` or :obj:`numpy.ndarray`): + images (str or :obj:`pathlib.Path` or :obj:`numpy.ndarray`): A HWC image or a path to WSI. mode (str): Type of input to process. Choose from either `patch`, `tile` or `wsi`. model (nn.Module): Defined PyTorch model. - pretrained_model (str): + model (str): Name of the existing models support by tiatoolbox for processing the data. For a full list of pretrained models, refer to the `docs @@ -166,7 +170,7 @@ class PatchPredictor: Examples: >>> # list of 2 image patches as input - >>> data = [img1, img2] + >>> data = ['path/img.svs', 'path/img.svs'] >>> predictor = PatchPredictor(pretrained_model="resnet18-kather100k") >>> output = predictor.predict(data, mode='patch') @@ -202,38 +206,46 @@ class PatchPredictor: """ def __init__( - self, - batch_size=8, - num_loader_workers=0, - model=None, - pretrained_model=None, - pretrained_weights=None, + self: PatchPredictor, + batch_size: int = 8, + num_loader_workers: int = 0, + num_post_proc_workers: int = 0, + model: torch.nn.Module = None, + pretrained_model: str | None = None, + weights: str | None = None, *, - verbose=True, + verbose: bool = True, ) -> None: """Initialize :class:`PatchPredictor`.""" - super().__init__() + super().__init__( + batch_size=batch_size, + num_loader_workers=num_loader_workers, + num_post_proc_workers=num_post_proc_workers, + model=model, + pretrained_model=pretrained_model, + weights=weights, + verbose=verbose, + ) - self.imgs = None - self.mode = None + def pre_process_wsi(self: PatchPredictor) -> NoReturn: + """Pre-process a WSI.""" + ... - if model is None and pretrained_model is None: - msg = "Must provide either `model` or `pretrained_model`." - raise ValueError(msg) + def infer_wsi(self: PatchPredictor) -> NoReturn: + """Model inference on a WSI.""" + ... - if model is not None: - self.model = model - ioconfig = None # retrieve iostate from provided model ? - else: - model, ioconfig = get_pretrained_model(pretrained_model, pretrained_weights) + def post_process_patches( + self: PatchPredictor, + raw_predictions: dict, + output_type: str, + ) -> None: + """Post-process an image patch.""" + ... - self.ioconfig = ioconfig # for storing original - self._ioconfig = None # for storing runtime - self.model = model # for runtime, such as after wrapping with nn.DataParallel - self.pretrained_model = pretrained_model - self.batch_size = batch_size - self.num_loader_worker = num_loader_workers - self.verbose = verbose + def post_process_wsi(self: PatchPredictor) -> NoReturn: + """Post-process a WSI.""" + ... @staticmethod def merge_predictions( @@ -241,10 +253,10 @@ def merge_predictions( output: dict, resolution: Resolution | None = None, units: Units | None = None, - postproc_func: Callable | None = None, + post_proc_func: Callable | None = None, *, return_raw: bool = False, - ): + ) -> np.ndarray: """Merge patch level predictions to form a 2-dimensional prediction map. #! Improve how the below reads. @@ -263,7 +275,7 @@ def merge_predictions( units (Units): Units of resolution used when merging predictions. This must be the same `units` used when processing the data. - postproc_func (callable): + post_proc_func (callable): A function to post-process raw prediction from model. By default, internal code uses the `np.argmax` function. return_raw (bool): @@ -345,8 +357,8 @@ def merge_predictions( output = output / (np.expand_dims(denominator, -1) + 1.0e-8) if not return_raw: # convert raw probabilities to predictions - if postproc_func is not None: - output = postproc_func(output) + if post_proc_func is not None: + output = post_proc_func(output) else: output = np.argmax(output, axis=-1) # to make sure background is 0 while class will be 1...N @@ -354,14 +366,14 @@ def merge_predictions( return output def _predict_engine( - self, - dataset, + self: PatchPredictor, + dataset: torch.utils.data.Dataset, *, - return_probabilities=False, - return_labels=False, - return_coordinates=False, - on_gpu=True, - ): + return_probabilities: bool = False, + return_labels: bool = False, + return_coordinates: bool = False, + device: str = "cpu", + ) -> np.ndarray: """Make a prediction on a dataset. The dataset may be mutated. Args: @@ -374,8 +386,8 @@ def _predict_engine( Whether to return labels. return_coordinates (bool): Whether to return patch coordinates. - on_gpu (bool): - Whether to run model on the GPU. + device (str): + Select the device to run the model. Default is "cpu". Returns: :class:`numpy.ndarray`: @@ -387,7 +399,7 @@ def _predict_engine( # preprocessing must be defined with the dataset dataloader = torch.utils.data.DataLoader( dataset, - num_workers=self.num_loader_worker, + num_workers=self.num_loader_workers, batch_size=self.batch_size, drop_last=False, shuffle=False, @@ -403,7 +415,7 @@ def _predict_engine( ) # use external for testing - model = tiatoolbox.models.models_abc.model_to(model=self.model, on_gpu=on_gpu) + model = tiatoolbox.models.models_abc.model_to(model=self.model, device=device) cum_output = { "probabilities": [], @@ -415,7 +427,7 @@ def _predict_engine( batch_output_probabilities = self.model.infer_batch( model, batch_data["image"], - on_gpu=on_gpu, + device=device, ) # We get the index of the class with the maximum probability batch_output_predictions = self.model.postproc_func( @@ -447,13 +459,13 @@ def _predict_engine( return cum_output def _update_ioconfig( - self, - ioconfig, - patch_input_shape, - stride_shape, - resolution, - units, - ): + self: PatchPredictor, + ioconfig: IOPatchPredictorConfig, + patch_input_shape: IntPair, + stride_shape: IntPair, + resolution: Resolution, + units: Units, + ) -> IOPatchPredictorConfig: """Update the ioconfig. Args: @@ -519,44 +531,15 @@ def _update_ioconfig( output_resolutions=[], ) - @staticmethod - def _prepare_save_dir(save_dir, imgs): - """Create directory if not defined and number of images is more than 1. - - Args: - save_dir (str or pathlib.Path): - Path to output directory. - imgs (list, ndarray): - List of inputs to process. - - Returns: - :class:`pathlib.Path`: - Path to output directory. - - """ - if save_dir is None and len(imgs) > 1: - logger.warning( - "More than 1 WSIs detected but there is no save directory set." - "All subsequent output will be saved to current runtime" - "location under folder 'output'. Overwriting may happen!", - stacklevel=2, - ) - save_dir = Path.cwd() / "output" - elif save_dir is not None and len(imgs) > 1: - logger.warning( - "When providing multiple whole-slide images / tiles, " - "we save the outputs and return the locations " - "to the corresponding files.", - stacklevel=2, - ) - - if save_dir is not None: - save_dir = Path(save_dir) - save_dir.mkdir(parents=True, exist_ok=False) - - return save_dir - - def _predict_patch(self, imgs, labels, return_probabilities, return_labels, on_gpu): + def _predict_patch( + self: PatchPredictor, + imgs: list | np.ndarray, + labels: list, + *, + return_probabilities: bool, + return_labels: bool, + device: str, + ) -> np.ndarray: """Process patch mode. Args: @@ -574,8 +557,8 @@ def _predict_patch(self, imgs, labels, return_probabilities, return_labels, on_g Whether to return per-class probabilities. return_labels (bool): Whether to return the labels with the predictions. - on_gpu (bool): - Whether to run model on the GPU. + device (str): + Select the device to run the engine. Returns: :class:`numpy.ndarray`: @@ -600,23 +583,24 @@ def _predict_patch(self, imgs, labels, return_probabilities, return_labels, on_g return_probabilities=return_probabilities, return_labels=return_labels, return_coordinates=return_coordinates, - on_gpu=on_gpu, + device=device, ) def _predict_tile_wsi( # noqa: PLR0913 - self, - imgs, - masks, - labels, - mode, - return_probabilities, - on_gpu, - ioconfig, - merge_predictions, - save_dir, - save_output, - highest_input_resolution, - ): + self: PatchPredictor, + imgs: list, + masks: list | None, + labels: list, + mode: str, + ioconfig: IOPatchPredictorConfig, + save_dir: str | Path, + highest_input_resolution: list[dict], + *, + save_output: bool, + return_probabilities: bool, + merge_predictions: bool, + on_gpu: bool, + ) -> list | dict: """Predict on Tile and WSIs. Args: @@ -626,7 +610,7 @@ def _predict_tile_wsi( # noqa: PLR0913 file paths or a numpy array of an image list. When using `tile` or `wsi` mode, the input must be a list of file paths. - masks (list): + masks (list or None): List of masks. Only utilised when processing image tiles and whole-slide images. Patches are only processed if they are within a masked area. If not provided, then a @@ -715,7 +699,7 @@ def _predict_tile_wsi( # noqa: PLR0913 ) output_model["label"] = img_label # add extra information useful for downstream analysis - output_model["pretrained_model"] = self.pretrained_model + output_model["pretrained_model"] = self.model output_model["resolution"] = highest_input_resolution["resolution"] output_model["units"] = highest_input_resolution["units"] @@ -727,7 +711,7 @@ def _predict_tile_wsi( # noqa: PLR0913 output_model, resolution=output_model["resolution"], units=output_model["units"], - postproc_func=self.model.postproc, + post_proc_func=self.model.postproc, ) outputs.append(merged_prediction) @@ -748,25 +732,51 @@ def _predict_tile_wsi( # noqa: PLR0913 return file_dict if save_output else outputs + def run( + self: EngineABC, + images: list[os | Path | WSIReader] | np.ndarray, + masks: list[os | Path] | np.ndarray | None = None, + labels: list | None = None, + ioconfig: ModelIOConfigABC | None = None, + *, + patch_mode: bool = True, + save_dir: os | Path | None = None, # None will not save output + overwrite: bool = False, + output_type: str = "dict", + **kwargs: dict, + ) -> AnnotationStore | str: + """Run engine.""" + super().run( + images=images, + masks=masks, + labels=labels, + ioconfig=ioconfig, + patch_mode=patch_mode, + save_dir=save_dir, + overwrite=overwrite, + output_type=output_type, + **kwargs, + ) + def predict( # noqa: PLR0913 - self, - imgs, - masks=None, - labels=None, - mode="patch", + self: PatchPredictor, + imgs: list, + masks: list | None = None, + labels: list | None = None, + mode: str = "patch", ioconfig: IOPatchPredictorConfig | None = None, patch_input_shape: tuple[int, int] | None = None, stride_shape: tuple[int, int] | None = None, - resolution=None, - units=None, + resolution: Resolution | None = None, + units: Units = None, *, - return_probabilities=False, - return_labels=False, - on_gpu=True, - merge_predictions=False, - save_dir=None, - save_output=False, - ): + return_probabilities: bool = False, + return_labels: bool = False, + on_gpu: bool = True, + merge_predictions: bool = False, + save_dir: str | Path | None = None, + save_output: bool = False, + ) -> np.ndarray | list | dict: """Make a prediction for a list of input data. Args: diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index e1341c640..237d032f1 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -27,10 +27,13 @@ from .io_config import IOSegmentorConfig if TYPE_CHECKING: # pragma: no cover - from tiatoolbox.typing import Resolution, Units + from tiatoolbox.typing import IntPair, Resolution, Units -def _estimate_canvas_parameters(sample_prediction, canvas_shape): +def _estimate_canvas_parameters( + sample_prediction: np.ndarray, + canvas_shape: np.ndarray, +) -> tuple[tuple, tuple, bool]: """Estimates canvas parameters. Args: @@ -58,11 +61,11 @@ def _estimate_canvas_parameters(sample_prediction, canvas_shape): def _prepare_save_output( - save_path, - cache_count_path, - canvas_cum_shape_, - canvas_count_shape_, -): + save_path: str | Path, + cache_count_path: str | Path, + canvas_cum_shape_: tuple[int, ...], + canvas_count_shape_: tuple[int, ...], +) -> tuple: """Prepares for saving the cached output.""" if save_path is not None: save_path = Path(save_path) @@ -193,7 +196,7 @@ class SemanticSegmentor: """ def __init__( - self, + self: SemanticSegmentor, batch_size: int = 8, num_loader_workers: int = 0, num_postproc_workers: int = 0, @@ -251,7 +254,7 @@ def __init__( def get_coordinates( image_shape: list[int] | np.ndarray, ioconfig: IOSegmentorConfig, - ): + ) -> tuple[list, list]: """Calculate patch tiling coordinates. By default, internally, it will call the @@ -309,7 +312,7 @@ def filter_coordinates( bounds: np.ndarray, resolution: Resolution | None = None, units: Units | None = None, - ): + ) -> np.ndarray: """Indicates which coordinate is valid basing on the mask. To use your own approaches, either subclass to overwrite or @@ -369,7 +372,7 @@ def filter_coordinates( scale_factor = mask_real_shape / mask_resolution_shape scale_factor = scale_factor[0] # what if ratio x != y - def sel_func(coord: np.ndarray): + def sel_func(coord: np.ndarray) -> bool: """Accept coord as long as its box contains part of mask.""" coord_in_real_mask = np.ceil(scale_factor * coord).astype(np.int32) start_x, start_y, end_x, end_y = coord_in_real_mask @@ -386,7 +389,7 @@ def get_reader( mode: str, *, auto_get_mask: bool, - ): + ) -> tuple[WSIReader, WSIReader]: """Define how to get reader for mask and source image.""" img_path = Path(img_path) reader = WSIReader.open(img_path) @@ -411,12 +414,12 @@ def get_reader( return reader, mask_reader def _predict_one_wsi( - self, + self: SemanticSegmentor, wsi_idx: int, ioconfig: IOSegmentorConfig, save_path: str, mode: str, - ): + ) -> None: """Make a prediction on tile/wsi. Args: @@ -527,13 +530,13 @@ def _predict_one_wsi( shutil.rmtree(cache_dir) def _process_predictions( - self, + self: SemanticSegmentor, cum_batch_predictions: list, wsi_reader: WSIReader, ioconfig: IOSegmentorConfig, save_path: str, cache_dir: str, - ): + ) -> None: """Define how the aggregated predictions are processed. This includes merging the prediction if necessary and also saving afterwards. @@ -595,7 +598,7 @@ def merge_prediction( locations: list | np.ndarray, save_path: str | Path | None = None, cache_count_path: str | Path | None = None, - ): + ) -> np.ndarray: """Merge patch-level predictions to form a 2-dimensional prediction map. When accumulating the raw prediction onto a same canvas (via @@ -665,7 +668,7 @@ def merge_prediction( canvas_count_shape_, ) - def index(arr, tl, br): + def index(arr: np.ndarray, tl: np.ndarray, br: np.ndarray) -> np.ndarray: """Helper to shorten indexing.""" return arr[tl[0] : br[0], tl[1] : br[1]] @@ -726,7 +729,7 @@ def index(arr, tl, br): return cum_canvas @staticmethod - def _prepare_save_dir(save_dir): + def _prepare_save_dir(save_dir: str | Path | None) -> tuple[Path, Path]: """Prepare save directory and cache.""" if save_dir is None: logger.warning( @@ -749,14 +752,14 @@ def _prepare_save_dir(save_dir): @staticmethod def _update_ioconfig( - ioconfig, - mode, - patch_input_shape, - patch_output_shape, - stride_shape, - resolution, - units, - ): + ioconfig: IOSegmentorConfig, + mode: str, + patch_input_shape: IntPair, + patch_output_shape: IntPair, + stride_shape: IntPair, + resolution: Resolution, + units: Units, + ) -> IOSegmentorConfig: """Update ioconfig according to input parameters. Args: @@ -815,7 +818,7 @@ def _update_ioconfig( return ioconfig - def _prepare_workers(self): + def _prepare_workers(self: SemanticSegmentor) -> None: """Prepare number of workers.""" self._postproc_workers = None if self.num_postproc_workers is not None: @@ -823,7 +826,7 @@ def _prepare_workers(self): max_workers=self.num_postproc_workers, ) - def _memory_cleanup(self): + def _memory_cleanup(self: SemanticSegmentor) -> None: """Memory clean up.""" self.imgs = None self.masks = None @@ -838,15 +841,16 @@ def _memory_cleanup(self): self._postproc_workers = None def _predict_wsi_handle_exception( - self, - imgs, - wsi_idx, - img_path, - mode, - ioconfig, - save_dir, - crash_on_exception, - ): + self: SemanticSegmentor, + imgs: list, + wsi_idx: int, + img_path: str | Path, + mode: str, + ioconfig: IOSegmentorConfig, + save_dir: str | Path, + *, + crash_on_exception: bool, + ) -> None: """Predict on multiple WSIs. Args: @@ -916,21 +920,21 @@ def _predict_wsi_handle_exception( logging.exception("Crashed on %s", wsi_save_path) def predict( # noqa: PLR0913 - self, - imgs, - masks=None, - mode="tile", - ioconfig=None, - patch_input_shape=None, - patch_output_shape=None, - stride_shape=None, - resolution=None, - units=None, - save_dir=None, + self: SemanticSegmentor, + imgs: list, + masks: list | None = None, + mode: str = "tile", + ioconfig: IOSegmentorConfig = None, + patch_input_shape: IntPair = None, + patch_output_shape: IntPair = None, + stride_shape: IntPair = None, + resolution: Resolution = 1.0, + units: Units = "baseline", + save_dir: str | Path | None = None, *, - on_gpu=True, - crash_on_exception=False, - ): + device: str = "cpu", + crash_on_exception: bool = False, + ) -> list[tuple[Path, Path]]: """Make a prediction for a list of input data. By default, if the input model at the object instantiation time @@ -966,8 +970,8 @@ def predict( # noqa: PLR0913 `stride_shape`, `resolution`, and `units` arguments are ignored. Otherwise, those arguments will be internally converted to a :class:`IOSegmentorConfig` object. - on_gpu (bool): - Whether to run the model on the GPU. + device (str): + Select the device to run the model. Default is "cpu". patch_input_shape (tuple): Size of patches input to the model. The values are at requested read resolution and must be positive. @@ -1049,10 +1053,10 @@ def predict( # noqa: PLR0913 ) # use external for testing - self._on_gpu = on_gpu + self._device = device self._model = tiatoolbox.models.models_abc.model_to( model=self.model, - on_gpu=on_gpu, + device=device, ) # workers should be > 0 else Value Error will be thrown @@ -1170,7 +1174,7 @@ class DeepFeatureExtractor(SemanticSegmentor): """ def __init__( - self, + self: DeepFeatureExtractor, batch_size: int = 8, num_loader_workers: int = 0, num_postproc_workers: int = 0, @@ -1197,13 +1201,13 @@ def __init__( self.process_prediction_per_batch = False def _process_predictions( - self, + self: DeepFeatureExtractor, cum_batch_predictions: list, wsi_reader: WSIReader, # skipcq: PYL-W0613 # noqa: ARG002 ioconfig: IOSegmentorConfig, save_path: str, cache_dir: str, # skipcq: PYL-W0613 # noqa: ARG002 - ): + ) -> None: """Define how the aggregated predictions are processed. This includes merging the prediction if necessary and also @@ -1241,21 +1245,21 @@ def _process_predictions( np.save(f"{save_path}.features.{idx}.npy", prediction_list) def predict( # noqa: PLR0913 - self, - imgs, - masks=None, - mode="tile", - ioconfig=None, - patch_input_shape=None, - patch_output_shape=None, - stride_shape=None, - resolution=1.0, - units="baseline", - save_dir=None, + self: DeepFeatureExtractor, + imgs: list, + masks: list | None = None, + mode: str = "tile", + ioconfig: IOSegmentorConfig | None = None, + patch_input_shape: IntPair | None = None, + patch_output_shape: IntPair | None = None, + stride_shape: IntPair = None, + resolution: Resolution = 1.0, + units: Units = "baseline", + save_dir: str | Path | None = None, *, - on_gpu=True, - crash_on_exception=False, - ): + device: str = "cpu", + crash_on_exception: bool = False, + ) -> list[tuple[Path, Path]]: """Make a prediction for a list of input data. By default, if the input model at the time of object @@ -1291,8 +1295,8 @@ def predict( # noqa: PLR0913 `stride_shape`, `resolution`, and `units` arguments are ignored. Otherwise, those arguments will be internally converted to a :class:`IOSegmentorConfig` object. - on_gpu (bool): - Whether to run the model on the GPU. + device (str): + Select the device to run the model. Default is "cpu". patch_input_shape (tuple): Size of patches input to the model. The values are at requested read resolution and must be positive. @@ -1348,7 +1352,7 @@ def predict( # noqa: PLR0913 imgs=imgs, masks=masks, mode=mode, - on_gpu=on_gpu, + device=device, ioconfig=ioconfig, patch_input_shape=patch_input_shape, patch_output_shape=patch_output_shape, diff --git a/tiatoolbox/models/models_abc.py b/tiatoolbox/models/models_abc.py index 9c5bb4cd1..09aef76dc 100644 --- a/tiatoolbox/models/models_abc.py +++ b/tiatoolbox/models/models_abc.py @@ -2,19 +2,67 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Callable import torch from torch import nn if TYPE_CHECKING: # pragma: no cover + from pathlib import Path + import numpy as np +# Draft - will be moved into ModelABC as a class method +def load_torch_model(model: nn.Module, weights: str | Path) -> nn.Module: + """Helper function to load a torch model. + + Args: + model (torch.nn.Module): + A torch model. + weights (str or Path): + Path to pretrained weights. + + Returns: + torch.nn.Module: + Torch model with pretrained weights loaded on CPU. + + """ + # ! assume to be saved in single GPU mode + # always load on to the CPU + saved_state_dict = torch.load(weights, map_location="cpu") + model.load_state_dict(saved_state_dict, strict=True) + return model + + +# Draft - will be moved into ModelABC as a class method +def model_to(model: torch.nn.Module, device: str = "cpu") -> torch.nn.Module: + """Transfers model to cpu/gpu. + + Args: + model (torch.nn.Module): + PyTorch defined model. + device (str): + Transfers model to the specified device. Default is "cpu". + + Returns: + torch.nn.Module: + The model after being moved to cpu/gpu. + """ + device = torch.device(device) + model = model.to(device) + + # If target device is CUDA and more than one GPU is available, use DataParallel + if device.type == "cuda" and torch.cuda.device_count() > 1: + model = torch.nn.DataParallel(model) + + return model + + class ModelABC(ABC, nn.Module): """Abstract base class for models used in tiatoolbox.""" - def __init__(self) -> None: + def __init__(self: ModelABC) -> None: """Initialize Abstract class ModelABC.""" super().__init__() self._postproc = self.postproc @@ -22,13 +70,53 @@ def __init__(self) -> None: @abstractmethod # This is generic abc, else pylint will complain - def forward(self, *args, **kwargs): + def forward(self: ModelABC, *args: tuple[Any, ...], **kwargs: dict) -> None: """Torch method, this contains logic for using layers defined in init.""" ... # pragma: no cover + def to(self: ModelABC, device: str = "cpu") -> torch.nn.Module: + """Transfers model to cpu/gpu. + + Args: + model (torch.nn.Module): + PyTorch defined model. + device (str): + Transfers model to the specified device. Default is "cpu". + + Returns: + torch.nn.Module: + The model after being moved to cpu/gpu. + """ + device = torch.device(device) + model = super().to(device) + + # If target device is CUDA and more than one GPU is available, use DataParallel + if device.type == "cuda" and torch.cuda.device_count() > 1: + model = torch.nn.DataParallel(model) + + return model + + def load_weights_from_path(self: ModelABC, weights: str | Path) -> nn.Module: + """Helper function to load a torch model. + + Args: + weights (str or Path): + Path to pretrained weights. + + Returns: + torch.nn.Module: + Torch model with pretrained weights loaded on CPU. + + """ + # ! assume to be saved in single GPU mode + # always load on to the CPU + saved_state_dict = torch.load(weights, map_location="cpu") + + return self.load_state_dict(saved_state_dict, strict=True) + @staticmethod @abstractmethod - def infer_batch(model: nn.Module, batch_data: np.ndarray, *, on_gpu: bool): + def infer_batch(model: nn.Module, batch_data: np.ndarray, *, device: str) -> None: """Run inference on an input batch. Contains logic for forward operation as well as I/O aggregation. @@ -39,29 +127,29 @@ def infer_batch(model: nn.Module, batch_data: np.ndarray, *, on_gpu: bool): batch_data (np.ndarray): A batch of data generated by `torch.utils.data.DataLoader`. - on_gpu (bool): - Whether to run inference on a GPU. + device (str): + Transfers model to the specified device. Default is "cpu". """ ... # pragma: no cover @staticmethod - def preproc(image): + def preproc(image: np.ndarray) -> np.ndarray: """Define the pre-processing of this class of model.""" return image @staticmethod - def postproc(image): + def postproc(image: np.ndarray) -> np.ndarray: """Define the post-processing of this class of model.""" return image @property - def preproc_func(self): + def preproc_func(self: ModelABC) -> Callable: """Return the current pre-processing function of this instance.""" return self._preproc @preproc_func.setter - def preproc_func(self, func): + def preproc_func(self: ModelABC, func: Callable) -> None: """Set the pre-processing function for this instance. If `func=None`, the method will default to `self.preproc`. @@ -73,7 +161,7 @@ def preproc_func(self, func): >>> # `func` is a user defined function >>> model = ModelABC() >>> model.preproc_func = func - >>> transformed_img = model.preproc_func(img) + >>> transformed_img = model.preproc_func(image=np.ndarray) """ if func is not None and not callable(func): @@ -86,12 +174,12 @@ def preproc_func(self, func): self._preproc = func @property - def postproc_func(self): + def postproc_func(self: ModelABC) -> Callable: """Return the current post-processing function of this instance.""" return self._postproc @postproc_func.setter - def postproc_func(self, func): + def postproc_func(self: ModelABC, func: Callable) -> None: """Set the pre-processing function for this instance of model. If `func=None`, the method will default to `self.postproc`. @@ -104,7 +192,7 @@ def postproc_func(self, func): >>> # `func` is a user defined function >>> model = ModelABC() >>> model.postproc_func = func - >>> transformed_img = model.postproc_func(img) + >>> transformed_img = model.postproc_func(image=np.ndarray) """ if func is not None and not callable(func): @@ -115,22 +203,3 @@ def postproc_func(self, func): self._postproc = self.postproc else: self._postproc = func - - -def model_to(model: torch.nn.Module, *, on_gpu: bool) -> torch.nn.Module: - """Transfers model to cpu/gpu. - - Args: - model (torch.nn.Module): PyTorch defined model. - on_gpu (bool): Transfers model to gpu if True otherwise to cpu. - - Returns: - torch.nn.Module: - The model after being moved to cpu/gpu. - - """ - if on_gpu: # DataParallel work only for cuda - model = torch.nn.DataParallel(model) - return model.to("cuda") - - return model.to("cpu")