From 0d77e993f027528407e13592cffa5a0f14ece07c Mon Sep 17 00:00:00 2001 From: blaginin Date: Mon, 17 Jul 2023 16:45:09 +0100 Subject: [PATCH 1/3] =?UTF-8?q?=E2=9C=85Reuse=20models=20and=20datasets=20?= =?UTF-8?q?in=20tests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/models/test_arch_mapde.py | 13 ++++++------ tests/models/test_arch_micronet.py | 4 ++-- tests/models/test_arch_nuclick.py | 5 ++--- tests/models/test_arch_sccnn.py | 13 ++++++------ tests/models/test_arch_unet.py | 5 ++--- tests/models/test_dataset.py | 7 ++++++- tests/models/test_hovernet.py | 19 +++++++++--------- tests/models/test_hovernetplus.py | 4 ++-- .../models/test_nucleus_instance_segmentor.py | 6 ++---- tests/models/test_patch_predictor.py | 6 +++--- tests/models/test_semantic_segmentation.py | 6 ++---- tiatoolbox/models/architecture/__init__.py | 20 ++++++++++++++----- tiatoolbox/models/dataset/info.py | 2 +- 13 files changed, 58 insertions(+), 52 deletions(-) diff --git a/tests/models/test_arch_mapde.py b/tests/models/test_arch_mapde.py index 49a144343..122b8252b 100644 --- a/tests/models/test_arch_mapde.py +++ b/tests/models/test_arch_mapde.py @@ -8,24 +8,23 @@ from tiatoolbox.wsicore.wsireader import WSIReader -def _load_mapde(tmp_path, name): +def _load_mapde(name): """Loads MapDe model with specified weights.""" model = MapDe() - fetch_pretrained_weights(name, f"{tmp_path}/weights.pth") + weights_path = fetch_pretrained_weights(name) map_location = utils.misc.select_device(utils.env_detection.has_gpu()) - pretrained = torch.load(f"{tmp_path}/weights.pth", map_location=map_location) + pretrained = torch.load(weights_path, map_location=map_location) model.load_state_dict(pretrained) return model -def test_functionality(remote_sample, tmp_path): +def test_functionality(remote_sample): """Functionality test for MapDe. Tests the functionality of MapDe model for inference at the patch level. """ - tmp_path = str(tmp_path) sample_wsi = str(remote_sample("wsi1_2k_2k_svs")) reader = WSIReader.open(sample_wsi) @@ -34,14 +33,14 @@ def test_functionality(remote_sample, tmp_path): (0, 0, 252, 252), resolution=0.50, units="mpp", coord_space="resolution" ) - model = _load_mapde(tmp_path=tmp_path, name="mapde-crchisto") + model = _load_mapde(name="mapde-crchisto") patch = model.preproc(patch) batch = torch.from_numpy(patch)[None] output = model.infer_batch(model, batch, on_gpu=False) output = model.postproc(output[0]) assert np.all(output[0:2] == [[99, 178], [64, 218]]) - model = _load_mapde(tmp_path=tmp_path, name="mapde-conic") + model = _load_mapde(name="mapde-conic") output = model.infer_batch(model, batch, on_gpu=False) output = model.postproc(output[0]) assert np.all(output[0:2] == [[19, 171], [53, 89]]) diff --git a/tests/models/test_arch_micronet.py b/tests/models/test_arch_micronet.py index d359a72b6..44e3794ed 100644 --- a/tests/models/test_arch_micronet.py +++ b/tests/models/test_arch_micronet.py @@ -28,9 +28,9 @@ def test_functionality(remote_sample, tmp_path): model = MicroNet() patch = model.preproc(patch) batch = torch.from_numpy(patch)[None] - fetch_pretrained_weights("micronet-consep", f"{tmp_path}/weights.pth") + weights_path = fetch_pretrained_weights("micronet-consep") map_location = utils.misc.select_device(utils.env_detection.has_gpu()) - pretrained = torch.load(f"{tmp_path}/weights.pth", map_location=map_location) + pretrained = torch.load(weights_path, map_location=map_location) model.load_state_dict(pretrained) output = model.infer_batch(model, batch, on_gpu=False) output, _ = model.postproc(output[0]) diff --git a/tests/models/test_arch_nuclick.py b/tests/models/test_arch_nuclick.py index c49e3b85b..a0f0ea724 100644 --- a/tests/models/test_arch_nuclick.py +++ b/tests/models/test_arch_nuclick.py @@ -20,8 +20,7 @@ def test_functional_nuclick(remote_sample, tmp_path, caplog): tile_path = pathlib.Path(remote_sample("patch-extraction-vf")) img = imread(tile_path) - _pretrained_path = f"{tmp_path}/weights.pth" - fetch_pretrained_weights("nuclick_original-pannuke", _pretrained_path) + weights_path = fetch_pretrained_weights("nuclick_original-pannuke") # test creation _ = NuClick(num_input_channels=5, num_output_channels=1) @@ -46,7 +45,7 @@ def test_functional_nuclick(remote_sample, tmp_path, caplog): batch = torch.from_numpy(batch[np.newaxis, ...]) model = NuClick(num_input_channels=5, num_output_channels=1) - pretrained = torch.load(_pretrained_path, map_location="cpu") + pretrained = torch.load(weights_path, map_location="cpu") model.load_state_dict(pretrained) output = model.infer_batch(model, batch, on_gpu=ON_GPU) postproc_masks = model.postproc( diff --git a/tests/models/test_arch_sccnn.py b/tests/models/test_arch_sccnn.py index beafc6d1d..6f320b11f 100644 --- a/tests/models/test_arch_sccnn.py +++ b/tests/models/test_arch_sccnn.py @@ -8,24 +8,23 @@ from tiatoolbox.wsicore.wsireader import WSIReader -def _load_sccnn(tmp_path, name): +def _load_sccnn(name): """Loads SCCNN model with specified weights.""" model = SCCNN() - fetch_pretrained_weights(name, f"{tmp_path}/weights.pth") + weights_path = fetch_pretrained_weights(name) map_location = utils.misc.select_device(utils.env_detection.has_gpu()) - pretrained = torch.load(f"{tmp_path}/weights.pth", map_location=map_location) + pretrained = torch.load(weights_path, map_location=map_location) model.load_state_dict(pretrained) return model -def test_functionality(remote_sample, tmp_path): +def test_functionality(remote_sample): """Functionality test for SCCNN. Tests the functionality of SCCNN model for inference at the patch level. """ - tmp_path = str(tmp_path) sample_wsi = str(remote_sample("wsi1_2k_2k_svs")) reader = WSIReader.open(sample_wsi) @@ -34,12 +33,12 @@ def test_functionality(remote_sample, tmp_path): (30, 30, 61, 61), resolution=0.25, units="mpp", coord_space="resolution" ) batch = torch.from_numpy(patch)[None] - model = _load_sccnn(tmp_path=tmp_path, name="sccnn-crchisto") + model = _load_sccnn(name="sccnn-crchisto") output = model.infer_batch(model, batch, on_gpu=False) output = model.postproc(output[0]) assert np.all(output == [[8, 7]]) - model = _load_sccnn(tmp_path=tmp_path, name="sccnn-conic") + model = _load_sccnn(name="sccnn-conic") output = model.infer_batch(model, batch, on_gpu=False) output = model.postproc(output[0]) assert np.all(output == [[7, 8]]) diff --git a/tests/models/test_arch_unet.py b/tests/models/test_arch_unet.py index c9ce84fdf..044c9d6e8 100644 --- a/tests/models/test_arch_unet.py +++ b/tests/models/test_arch_unet.py @@ -20,8 +20,7 @@ def test_functional_unet(remote_sample, tmp_path): # convert to pathlib Path to prevent wsireader complaint mini_wsi_svs = pathlib.Path(remote_sample("wsi2_4k_4k_svs")) - _pretrained_path = f"{tmp_path}/weights.pth" - fetch_pretrained_weights("fcn-tissue_mask", _pretrained_path) + pretrained_weights = fetch_pretrained_weights("fcn-tissue_mask") reader = WSIReader.open(mini_wsi_svs) with pytest.raises(ValueError, match=r".*Unknown encoder*"): @@ -47,7 +46,7 @@ def test_functional_unet(remote_sample, tmp_path): batch = torch.from_numpy(batch) model = UNetModel(3, 2, encoder="resnet50", decoder_block=[3]) - pretrained = torch.load(_pretrained_path, map_location="cpu") + pretrained = torch.load(pretrained_weights, map_location="cpu") model.load_state_dict(pretrained) output = model.infer_batch(model, batch, on_gpu=ON_GPU) output = output[0] diff --git a/tests/models/test_dataset.py b/tests/models/test_dataset.py index bbc2d202a..728586486 100644 --- a/tests/models/test_dataset.py +++ b/tests/models/test_dataset.py @@ -61,12 +61,17 @@ def test_dataset_abc(): def test_kather_dataset_default(tmp_path): """Test for kather patch dataset with default parameters.""" # test kather with default init + dataset_path = os.path.join( + rcParam["TIATOOLBOX_HOME"], "dataset", "kather100k-validation" + ) + shutil.rmtree(dataset_path, ignore_errors=True) + _ = KatherPatchDataset() # kather with default data path skip download _ = KatherPatchDataset() # remove generated data - shutil.rmtree(rcParam["TIATOOLBOX_HOME"]) + shutil.rmtree(dataset_path, ignore_errors=False) def test_kather_nonexisting_dir(): diff --git a/tests/models/test_hovernet.py b/tests/models/test_hovernet.py index 717234a98..9a5421989 100644 --- a/tests/models/test_hovernet.py +++ b/tests/models/test_hovernet.py @@ -15,9 +15,8 @@ from tiatoolbox.wsicore.wsireader import WSIReader -def test_functionality(remote_sample, tmp_path): +def test_functionality(remote_sample): """Functionality test.""" - tmp_path = str(tmp_path) sample_wsi = str(remote_sample("wsi1_2k_2k_svs")) reader = WSIReader.open(sample_wsi) @@ -27,8 +26,8 @@ def test_functionality(remote_sample, tmp_path): ) batch = torch.from_numpy(patch)[None] model = HoVerNet(num_types=6, mode="fast") - fetch_pretrained_weights("hovernet_fast-pannuke", f"{tmp_path}/weights.pth") - pretrained = torch.load(f"{tmp_path}/weights.pth") + weights_path = fetch_pretrained_weights("hovernet_fast-pannuke") + pretrained = torch.load(weights_path) model.load_state_dict(pretrained) output = model.infer_batch(model, batch, on_gpu=False) output = [v[0] for v in output] @@ -41,8 +40,8 @@ def test_functionality(remote_sample, tmp_path): ) batch = torch.from_numpy(patch)[None] model = HoVerNet(num_types=5, mode="fast") - fetch_pretrained_weights("hovernet_fast-monusac", f"{tmp_path}/weights.pth") - pretrained = torch.load(f"{tmp_path}/weights.pth") + weights_path = fetch_pretrained_weights("hovernet_fast-monusac") + pretrained = torch.load(weights_path) model.load_state_dict(pretrained) output = model.infer_batch(model, batch, on_gpu=False) output = [v[0] for v in output] @@ -55,8 +54,8 @@ def test_functionality(remote_sample, tmp_path): ) batch = torch.from_numpy(patch)[None] model = HoVerNet(num_types=5, mode="original") - fetch_pretrained_weights("hovernet_original-consep", f"{tmp_path}/weights.pth") - pretrained = torch.load(f"{tmp_path}/weights.pth") + weights_path = fetch_pretrained_weights("hovernet_original-consep") + pretrained = torch.load(weights_path) model.load_state_dict(pretrained) output = model.infer_batch(model, batch, on_gpu=False) output = [v[0] for v in output] @@ -69,8 +68,8 @@ def test_functionality(remote_sample, tmp_path): ) batch = torch.from_numpy(patch)[None] model = HoVerNet(num_types=None, mode="original") - fetch_pretrained_weights("hovernet_original-kumar", f"{tmp_path}/weights.pth") - pretrained = torch.load(f"{tmp_path}/weights.pth") + weights_path = fetch_pretrained_weights("hovernet_original-kumar") + pretrained = torch.load(weights_path) model.load_state_dict(pretrained) output = model.infer_batch(model, batch, on_gpu=False) output = [v[0] for v in output] diff --git a/tests/models/test_hovernetplus.py b/tests/models/test_hovernetplus.py index bfeddc234..e513994ba 100644 --- a/tests/models/test_hovernetplus.py +++ b/tests/models/test_hovernetplus.py @@ -24,8 +24,8 @@ def test_functionality(remote_sample, tmp_path): assert len(model.decoder["hv"]) > 0, "Decoder must contain hv branch." assert len(model.decoder["tp"]) > 0, "Decoder must contain tp branch." assert len(model.decoder["ls"]) > 0, "Decoder must contain ls branch." - fetch_pretrained_weights("hovernetplus-oed", f"{tmp_path}/weigths.pth") - pretrained = torch.load(f"{tmp_path}/weigths.pth") + weights_path = fetch_pretrained_weights("hovernetplus-oed") + pretrained = torch.load(weights_path) model.load_state_dict(pretrained) output = model.infer_batch(model, batch, on_gpu=False) assert len(output) == 4, "Must contain predictions for: np, hv, tp and ls branches." diff --git a/tests/models/test_nucleus_instance_segmentor.py b/tests/models/test_nucleus_instance_segmentor.py index d7f711dea..ab21699bc 100644 --- a/tests/models/test_nucleus_instance_segmentor.py +++ b/tests/models/test_nucleus_instance_segmentor.py @@ -519,9 +519,7 @@ def test_cli_nucleus_instance_segment_ioconfig(remote_sample, tmp_path): mini_wsi_jpg = f"{tmp_path}/mini_svs.jpg" imwrite(mini_wsi_jpg, thumb) - fetch_pretrained_weights( - "hovernet_fast-pannuke", str(tmp_path.joinpath("hovernet_fast-pannuke.pth")) - ) + pretrained_weights = fetch_pretrained_weights("hovernet_fast-pannuke") # resolution for travis testing, not the correct ones config = { @@ -550,7 +548,7 @@ def test_cli_nucleus_instance_segment_ioconfig(remote_sample, tmp_path): "--img-input", str(mini_wsi_jpg), "--pretrained-weights", - str(tmp_path.joinpath("hovernet_fast-pannuke.pth")), + str(pretrained_weights), "--num-loader-workers", str(0), "--num-postproc-workers", diff --git a/tests/models/test_patch_predictor.py b/tests/models/test_patch_predictor.py index 23c0e7a21..6849fdd21 100644 --- a/tests/models/test_patch_predictor.py +++ b/tests/models/test_patch_predictor.py @@ -11,7 +11,7 @@ import torch from click.testing import CliRunner -from tiatoolbox import cli, rcParam +from tiatoolbox import cli from tiatoolbox.models.architecture.vanilla import CNNModel from tiatoolbox.models.dataset import ( PatchDataset, @@ -208,7 +208,6 @@ def test_patch_dataset_crash(tmp_path): match="Cannot load image data from", ): _ = PatchDataset(imgs) - _rm_dir(rcParam["TIATOOLBOX_HOME"]) # preproc func for not defined dataset with pytest.raises( @@ -660,8 +659,9 @@ def test_patch_predictor_api(sample_patch1, sample_patch2, tmp_path): # remove prev generated data _rm_dir(save_dir_path) os.makedirs(save_dir_path) + pretrained_weights = os.path.join( - rcParam["TIATOOLBOX_HOME"], + save_dir_path, "tmp_pretrained_weigths", "resnet18-kather100k.pth", ) diff --git a/tests/models/test_semantic_segmentation.py b/tests/models/test_semantic_segmentation.py index 6bc7b0903..11d3942ce 100644 --- a/tests/models/test_semantic_segmentation.py +++ b/tests/models/test_semantic_segmentation.py @@ -765,9 +765,7 @@ def test_cli_semantic_segmentation_ioconfig(remote_sample, tmp_path): sample_wsi_msk = np.load(sample_wsi_msk).astype(np.uint8) imwrite(f"{tmp_path}/small_svs_tissue_mask.jpg", sample_wsi_msk) sample_wsi_msk = f"{tmp_path}/small_svs_tissue_mask.jpg" - fetch_pretrained_weights( - "fcn-tissue_mask", str(tmp_path.joinpath("fcn-tissue_mask.pth")) - ) + pretrained_weights = fetch_pretrained_weights("fcn-tissue_mask") config = { "input_resolutions": [{"units": "mpp", "resolution": 2.0}], @@ -789,7 +787,7 @@ def test_cli_semantic_segmentation_ioconfig(remote_sample, tmp_path): "--img-input", str(mini_wsi_svs), "--pretrained-weights", - str(tmp_path.joinpath("fcn-tissue_mask.pth")), + str(pretrained_weights), "--mode", "wsi", "--masks", diff --git a/tiatoolbox/models/architecture/__init__.py b/tiatoolbox/models/architecture/__init__.py index 80a21d4f9..cda90d0aa 100644 --- a/tiatoolbox/models/architecture/__init__.py +++ b/tiatoolbox/models/architecture/__init__.py @@ -16,7 +16,9 @@ PRETRAINED_INFO = rcParam["pretrained_model_info"] -def fetch_pretrained_weights(model_name: str, save_path: str, overwrite: bool = True): +def fetch_pretrained_weights( + model_name: str, save_path: str = None, overwrite: bool = False +) -> pathlib.Path: """Get the pretrained model information from yml file. Args: @@ -28,9 +30,19 @@ def fetch_pretrained_weights(model_name: str, save_path: str, overwrite: bool = corresponding `model_name`. overwrite (bool): Overwrite existing downloaded weights. + + Returns: + pathlib.Path: + The local path to the cached pretrained weights after downloading. """ info = PRETRAINED_INFO[model_name] + + if save_path is None: + file_name = info["url"].split("/")[-1] + save_path = os.path.join(rcParam["TIATOOLBOX_HOME"], "models/", file_name) + download_data(info["url"], save_path, overwrite) + return pathlib.Path(save_path) def get_pretrained_model( @@ -110,11 +122,9 @@ def get_pretrained_model( model.preproc_func = predefined_preproc_func(info["dataset"]) if pretrained_weights is None: - file_name = info["url"].split("/")[-1] - pretrained_weights = os.path.join( - rcParam["TIATOOLBOX_HOME"], "models/", file_name + pretrained_weights = fetch_pretrained_weights( + pretrained_model, overwrite=overwrite ) - fetch_pretrained_weights(pretrained_model, pretrained_weights, overwrite) # ! assume to be saved in single GPU mode # always load on to the CPU diff --git a/tiatoolbox/models/dataset/info.py b/tiatoolbox/models/dataset/info.py index 3db496bc8..5562f1fb6 100644 --- a/tiatoolbox/models/dataset/info.py +++ b/tiatoolbox/models/dataset/info.py @@ -85,7 +85,7 @@ def __init__( if save_dir_path is None: # pragma: no cover save_dir_path = Path(rcParam["TIATOOLBOX_HOME"], "dataset") - if not os.path.exists(save_dir_path): + if not os.path.exists(os.path.join(save_dir_path, "kather100k-validation")): save_zip_path = os.path.join(save_dir_path, "Kather.zip") url = ( "https://tiatoolbox.dcs.warwick.ac.uk/datasets" From 5a2d5ed122eefa0dfd9f107ccbe720bb35fe1f02 Mon Sep 17 00:00:00 2001 From: blaginin Date: Tue, 18 Jul 2023 10:49:35 +0100 Subject: [PATCH 2/3] =?UTF-8?q?=E2=9C=85=20Add=20missing=20tests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_utils.py | 16 ++++++++++++++++ tiatoolbox/models/architecture/__init__.py | 3 +++ 2 files changed, 19 insertions(+) diff --git a/tests/test_utils.py b/tests/test_utils.py index 424ed1633..d08209697 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -17,6 +17,7 @@ from tests.test_annotation_stores import cell_polygon from tiatoolbox import rcParam, utils +from tiatoolbox.models.architecture import fetch_pretrained_weights from tiatoolbox.utils import misc from tiatoolbox.utils.exceptions import FileNotSupported from tiatoolbox.utils.transforms import locsize2bounds @@ -1472,3 +1473,18 @@ def test_from_multi_head_dat_type_dict(tmp_path): assert len(result) == 1 result = store.query(where=lambda x: x["type"][0:4] == "cell") assert len(result) == 2 + + +def test_fetch_pretrained_weights(tmp_path): + """Test fetching pretrained weights for a model.""" + + file_path = os.path.join(tmp_path, "test_fetch_pretrained_weights.pth") + if os.path.exists(file_path): + os.remove(file_path) + + fetch_pretrained_weights("mobilenet_v3_small-pcam", file_path) + assert os.path.exists(file_path) + assert os.path.getsize(file_path) > 0 + + with pytest.raises(ValueError, match="does not exist"): + fetch_pretrained_weights("abc", file_path) diff --git a/tiatoolbox/models/architecture/__init__.py b/tiatoolbox/models/architecture/__init__.py index 9f68a6e11..65b4e10ff 100644 --- a/tiatoolbox/models/architecture/__init__.py +++ b/tiatoolbox/models/architecture/__init__.py @@ -36,6 +36,9 @@ def fetch_pretrained_weights( The local path to the cached pretrained weights after downloading. """ + if model_name not in PRETRAINED_INFO: + raise ValueError(f"Pretrained model `{model_name}` does not exist") + info = PRETRAINED_INFO[model_name] if save_path is None: From 096bc31360d82d673242e0c3da1a9f6028b93111 Mon Sep 17 00:00:00 2001 From: blaginin Date: Tue, 18 Jul 2023 14:58:52 +0100 Subject: [PATCH 3/3] :art: Remove unused import --- tiatoolbox/models/architecture/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tiatoolbox/models/architecture/__init__.py b/tiatoolbox/models/architecture/__init__.py index 65b4e10ff..bf26cb47b 100644 --- a/tiatoolbox/models/architecture/__init__.py +++ b/tiatoolbox/models/architecture/__init__.py @@ -8,7 +8,6 @@ import torch from tiatoolbox import rcParam -from tiatoolbox.models.architecture.vanilla import CNNBackbone, CNNModel from tiatoolbox.models.dataset.classification import predefined_preproc_func from tiatoolbox.utils import download_data