diff --git a/tests/models/test_arch_mapde.py b/tests/models/test_arch_mapde.py index e4b8d30a7..875cf4b85 100644 --- a/tests/models/test_arch_mapde.py +++ b/tests/models/test_arch_mapde.py @@ -11,24 +11,23 @@ ON_GPU = toolbox_env.has_gpu() -def _load_mapde(tmp_path, name): +def _load_mapde(name): """Loads MapDe model with specified weights.""" model = MapDe() - fetch_pretrained_weights(name, f"{tmp_path}/weights.pth") + weights_path = fetch_pretrained_weights(name) map_location = select_device(ON_GPU) - pretrained = torch.load(f"{tmp_path}/weights.pth", map_location=map_location) + pretrained = torch.load(weights_path, map_location=map_location) model.load_state_dict(pretrained) return model -def test_functionality(remote_sample, tmp_path): +def test_functionality(remote_sample): """Functionality test for MapDe. Tests the functionality of MapDe model for inference at the patch level. """ - tmp_path = str(tmp_path) sample_wsi = str(remote_sample("wsi1_2k_2k_svs")) reader = WSIReader.open(sample_wsi) @@ -37,7 +36,7 @@ def test_functionality(remote_sample, tmp_path): (0, 0, 252, 252), resolution=0.50, units="mpp", coord_space="resolution" ) - model = _load_mapde(tmp_path=tmp_path, name="mapde-conic") + model = _load_mapde(name="mapde-conic") patch = model.preproc(patch) batch = torch.from_numpy(patch)[None] model = model.to(select_device(ON_GPU)) diff --git a/tests/models/test_arch_micronet.py b/tests/models/test_arch_micronet.py index 5abc5dd16..c24f07c8a 100644 --- a/tests/models/test_arch_micronet.py +++ b/tests/models/test_arch_micronet.py @@ -29,9 +29,9 @@ def test_functionality(remote_sample, tmp_path): model = MicroNet() patch = model.preproc(patch) batch = torch.from_numpy(patch)[None] - fetch_pretrained_weights("micronet-consep", f"{tmp_path}/weights.pth") + weights_path = fetch_pretrained_weights("micronet-consep") map_location = select_device(ON_GPU) - pretrained = torch.load(f"{tmp_path}/weights.pth", map_location=map_location) + pretrained = torch.load(weights_path, map_location=map_location) model.load_state_dict(pretrained) output = model.infer_batch(model, batch, on_gpu=ON_GPU) output, _ = model.postproc(output[0]) diff --git a/tests/models/test_arch_nuclick.py b/tests/models/test_arch_nuclick.py index c49e3b85b..a0f0ea724 100644 --- a/tests/models/test_arch_nuclick.py +++ b/tests/models/test_arch_nuclick.py @@ -20,8 +20,7 @@ def test_functional_nuclick(remote_sample, tmp_path, caplog): tile_path = pathlib.Path(remote_sample("patch-extraction-vf")) img = imread(tile_path) - _pretrained_path = f"{tmp_path}/weights.pth" - fetch_pretrained_weights("nuclick_original-pannuke", _pretrained_path) + weights_path = fetch_pretrained_weights("nuclick_original-pannuke") # test creation _ = NuClick(num_input_channels=5, num_output_channels=1) @@ -46,7 +45,7 @@ def test_functional_nuclick(remote_sample, tmp_path, caplog): batch = torch.from_numpy(batch[np.newaxis, ...]) model = NuClick(num_input_channels=5, num_output_channels=1) - pretrained = torch.load(_pretrained_path, map_location="cpu") + pretrained = torch.load(weights_path, map_location="cpu") model.load_state_dict(pretrained) output = model.infer_batch(model, batch, on_gpu=ON_GPU) postproc_masks = model.postproc( diff --git a/tests/models/test_arch_sccnn.py b/tests/models/test_arch_sccnn.py index beafc6d1d..6f320b11f 100644 --- a/tests/models/test_arch_sccnn.py +++ b/tests/models/test_arch_sccnn.py @@ -8,24 +8,23 @@ from tiatoolbox.wsicore.wsireader import WSIReader -def _load_sccnn(tmp_path, name): +def _load_sccnn(name): """Loads SCCNN model with specified weights.""" model = SCCNN() - fetch_pretrained_weights(name, f"{tmp_path}/weights.pth") + weights_path = fetch_pretrained_weights(name) map_location = utils.misc.select_device(utils.env_detection.has_gpu()) - pretrained = torch.load(f"{tmp_path}/weights.pth", map_location=map_location) + pretrained = torch.load(weights_path, map_location=map_location) model.load_state_dict(pretrained) return model -def test_functionality(remote_sample, tmp_path): +def test_functionality(remote_sample): """Functionality test for SCCNN. Tests the functionality of SCCNN model for inference at the patch level. """ - tmp_path = str(tmp_path) sample_wsi = str(remote_sample("wsi1_2k_2k_svs")) reader = WSIReader.open(sample_wsi) @@ -34,12 +33,12 @@ def test_functionality(remote_sample, tmp_path): (30, 30, 61, 61), resolution=0.25, units="mpp", coord_space="resolution" ) batch = torch.from_numpy(patch)[None] - model = _load_sccnn(tmp_path=tmp_path, name="sccnn-crchisto") + model = _load_sccnn(name="sccnn-crchisto") output = model.infer_batch(model, batch, on_gpu=False) output = model.postproc(output[0]) assert np.all(output == [[8, 7]]) - model = _load_sccnn(tmp_path=tmp_path, name="sccnn-conic") + model = _load_sccnn(name="sccnn-conic") output = model.infer_batch(model, batch, on_gpu=False) output = model.postproc(output[0]) assert np.all(output == [[7, 8]]) diff --git a/tests/models/test_arch_unet.py b/tests/models/test_arch_unet.py index c9ce84fdf..044c9d6e8 100644 --- a/tests/models/test_arch_unet.py +++ b/tests/models/test_arch_unet.py @@ -20,8 +20,7 @@ def test_functional_unet(remote_sample, tmp_path): # convert to pathlib Path to prevent wsireader complaint mini_wsi_svs = pathlib.Path(remote_sample("wsi2_4k_4k_svs")) - _pretrained_path = f"{tmp_path}/weights.pth" - fetch_pretrained_weights("fcn-tissue_mask", _pretrained_path) + pretrained_weights = fetch_pretrained_weights("fcn-tissue_mask") reader = WSIReader.open(mini_wsi_svs) with pytest.raises(ValueError, match=r".*Unknown encoder*"): @@ -47,7 +46,7 @@ def test_functional_unet(remote_sample, tmp_path): batch = torch.from_numpy(batch) model = UNetModel(3, 2, encoder="resnet50", decoder_block=[3]) - pretrained = torch.load(_pretrained_path, map_location="cpu") + pretrained = torch.load(pretrained_weights, map_location="cpu") model.load_state_dict(pretrained) output = model.infer_batch(model, batch, on_gpu=ON_GPU) output = output[0] diff --git a/tests/models/test_dataset.py b/tests/models/test_dataset.py index bbc2d202a..728586486 100644 --- a/tests/models/test_dataset.py +++ b/tests/models/test_dataset.py @@ -61,12 +61,17 @@ def test_dataset_abc(): def test_kather_dataset_default(tmp_path): """Test for kather patch dataset with default parameters.""" # test kather with default init + dataset_path = os.path.join( + rcParam["TIATOOLBOX_HOME"], "dataset", "kather100k-validation" + ) + shutil.rmtree(dataset_path, ignore_errors=True) + _ = KatherPatchDataset() # kather with default data path skip download _ = KatherPatchDataset() # remove generated data - shutil.rmtree(rcParam["TIATOOLBOX_HOME"]) + shutil.rmtree(dataset_path, ignore_errors=False) def test_kather_nonexisting_dir(): diff --git a/tests/models/test_hovernet.py b/tests/models/test_hovernet.py index 717234a98..9a5421989 100644 --- a/tests/models/test_hovernet.py +++ b/tests/models/test_hovernet.py @@ -15,9 +15,8 @@ from tiatoolbox.wsicore.wsireader import WSIReader -def test_functionality(remote_sample, tmp_path): +def test_functionality(remote_sample): """Functionality test.""" - tmp_path = str(tmp_path) sample_wsi = str(remote_sample("wsi1_2k_2k_svs")) reader = WSIReader.open(sample_wsi) @@ -27,8 +26,8 @@ def test_functionality(remote_sample, tmp_path): ) batch = torch.from_numpy(patch)[None] model = HoVerNet(num_types=6, mode="fast") - fetch_pretrained_weights("hovernet_fast-pannuke", f"{tmp_path}/weights.pth") - pretrained = torch.load(f"{tmp_path}/weights.pth") + weights_path = fetch_pretrained_weights("hovernet_fast-pannuke") + pretrained = torch.load(weights_path) model.load_state_dict(pretrained) output = model.infer_batch(model, batch, on_gpu=False) output = [v[0] for v in output] @@ -41,8 +40,8 @@ def test_functionality(remote_sample, tmp_path): ) batch = torch.from_numpy(patch)[None] model = HoVerNet(num_types=5, mode="fast") - fetch_pretrained_weights("hovernet_fast-monusac", f"{tmp_path}/weights.pth") - pretrained = torch.load(f"{tmp_path}/weights.pth") + weights_path = fetch_pretrained_weights("hovernet_fast-monusac") + pretrained = torch.load(weights_path) model.load_state_dict(pretrained) output = model.infer_batch(model, batch, on_gpu=False) output = [v[0] for v in output] @@ -55,8 +54,8 @@ def test_functionality(remote_sample, tmp_path): ) batch = torch.from_numpy(patch)[None] model = HoVerNet(num_types=5, mode="original") - fetch_pretrained_weights("hovernet_original-consep", f"{tmp_path}/weights.pth") - pretrained = torch.load(f"{tmp_path}/weights.pth") + weights_path = fetch_pretrained_weights("hovernet_original-consep") + pretrained = torch.load(weights_path) model.load_state_dict(pretrained) output = model.infer_batch(model, batch, on_gpu=False) output = [v[0] for v in output] @@ -69,8 +68,8 @@ def test_functionality(remote_sample, tmp_path): ) batch = torch.from_numpy(patch)[None] model = HoVerNet(num_types=None, mode="original") - fetch_pretrained_weights("hovernet_original-kumar", f"{tmp_path}/weights.pth") - pretrained = torch.load(f"{tmp_path}/weights.pth") + weights_path = fetch_pretrained_weights("hovernet_original-kumar") + pretrained = torch.load(weights_path) model.load_state_dict(pretrained) output = model.infer_batch(model, batch, on_gpu=False) output = [v[0] for v in output] diff --git a/tests/models/test_hovernetplus.py b/tests/models/test_hovernetplus.py index bfeddc234..e513994ba 100644 --- a/tests/models/test_hovernetplus.py +++ b/tests/models/test_hovernetplus.py @@ -24,8 +24,8 @@ def test_functionality(remote_sample, tmp_path): assert len(model.decoder["hv"]) > 0, "Decoder must contain hv branch." assert len(model.decoder["tp"]) > 0, "Decoder must contain tp branch." assert len(model.decoder["ls"]) > 0, "Decoder must contain ls branch." - fetch_pretrained_weights("hovernetplus-oed", f"{tmp_path}/weigths.pth") - pretrained = torch.load(f"{tmp_path}/weigths.pth") + weights_path = fetch_pretrained_weights("hovernetplus-oed") + pretrained = torch.load(weights_path) model.load_state_dict(pretrained) output = model.infer_batch(model, batch, on_gpu=False) assert len(output) == 4, "Must contain predictions for: np, hv, tp and ls branches." diff --git a/tests/models/test_nucleus_instance_segmentor.py b/tests/models/test_nucleus_instance_segmentor.py index d7f711dea..ab21699bc 100644 --- a/tests/models/test_nucleus_instance_segmentor.py +++ b/tests/models/test_nucleus_instance_segmentor.py @@ -519,9 +519,7 @@ def test_cli_nucleus_instance_segment_ioconfig(remote_sample, tmp_path): mini_wsi_jpg = f"{tmp_path}/mini_svs.jpg" imwrite(mini_wsi_jpg, thumb) - fetch_pretrained_weights( - "hovernet_fast-pannuke", str(tmp_path.joinpath("hovernet_fast-pannuke.pth")) - ) + pretrained_weights = fetch_pretrained_weights("hovernet_fast-pannuke") # resolution for travis testing, not the correct ones config = { @@ -550,7 +548,7 @@ def test_cli_nucleus_instance_segment_ioconfig(remote_sample, tmp_path): "--img-input", str(mini_wsi_jpg), "--pretrained-weights", - str(tmp_path.joinpath("hovernet_fast-pannuke.pth")), + str(pretrained_weights), "--num-loader-workers", str(0), "--num-postproc-workers", diff --git a/tests/models/test_patch_predictor.py b/tests/models/test_patch_predictor.py index f69274623..6da87acf6 100644 --- a/tests/models/test_patch_predictor.py +++ b/tests/models/test_patch_predictor.py @@ -11,7 +11,7 @@ import torch from click.testing import CliRunner -from tiatoolbox import cli, rcParam +from tiatoolbox import cli from tiatoolbox.models import IOPatchPredictorConfig, PatchPredictor from tiatoolbox.models.architecture.vanilla import CNNModel from tiatoolbox.models.dataset import ( @@ -205,7 +205,6 @@ def test_patch_dataset_crash(tmp_path): match="Cannot load image data from", ): _ = PatchDataset(imgs) - _rm_dir(rcParam["TIATOOLBOX_HOME"]) # preproc func for not defined dataset with pytest.raises( @@ -657,8 +656,9 @@ def test_patch_predictor_api(sample_patch1, sample_patch2, tmp_path): # remove prev generated data _rm_dir(save_dir_path) os.makedirs(save_dir_path) + pretrained_weights = os.path.join( - rcParam["TIATOOLBOX_HOME"], + save_dir_path, "tmp_pretrained_weigths", "resnet18-kather100k.pth", ) diff --git a/tests/models/test_semantic_segmentation.py b/tests/models/test_semantic_segmentation.py index cde7221e3..029f58f7b 100644 --- a/tests/models/test_semantic_segmentation.py +++ b/tests/models/test_semantic_segmentation.py @@ -765,9 +765,7 @@ def test_cli_semantic_segmentation_ioconfig(remote_sample, tmp_path): sample_wsi_msk = np.load(sample_wsi_msk).astype(np.uint8) imwrite(f"{tmp_path}/small_svs_tissue_mask.jpg", sample_wsi_msk) sample_wsi_msk = f"{tmp_path}/small_svs_tissue_mask.jpg" - fetch_pretrained_weights( - "fcn-tissue_mask", str(tmp_path.joinpath("fcn-tissue_mask.pth")) - ) + pretrained_weights = fetch_pretrained_weights("fcn-tissue_mask") config = { "input_resolutions": [{"units": "mpp", "resolution": 2.0}], @@ -789,7 +787,7 @@ def test_cli_semantic_segmentation_ioconfig(remote_sample, tmp_path): "--img-input", str(mini_wsi_svs), "--pretrained-weights", - str(tmp_path.joinpath("fcn-tissue_mask.pth")), + str(pretrained_weights), "--mode", "wsi", "--masks", diff --git a/tests/test_utils.py b/tests/test_utils.py index 424ed1633..d08209697 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -17,6 +17,7 @@ from tests.test_annotation_stores import cell_polygon from tiatoolbox import rcParam, utils +from tiatoolbox.models.architecture import fetch_pretrained_weights from tiatoolbox.utils import misc from tiatoolbox.utils.exceptions import FileNotSupported from tiatoolbox.utils.transforms import locsize2bounds @@ -1472,3 +1473,18 @@ def test_from_multi_head_dat_type_dict(tmp_path): assert len(result) == 1 result = store.query(where=lambda x: x["type"][0:4] == "cell") assert len(result) == 2 + + +def test_fetch_pretrained_weights(tmp_path): + """Test fetching pretrained weights for a model.""" + + file_path = os.path.join(tmp_path, "test_fetch_pretrained_weights.pth") + if os.path.exists(file_path): + os.remove(file_path) + + fetch_pretrained_weights("mobilenet_v3_small-pcam", file_path) + assert os.path.exists(file_path) + assert os.path.getsize(file_path) > 0 + + with pytest.raises(ValueError, match="does not exist"): + fetch_pretrained_weights("abc", file_path) diff --git a/tiatoolbox/models/architecture/__init__.py b/tiatoolbox/models/architecture/__init__.py index 159c874e1..bf26cb47b 100644 --- a/tiatoolbox/models/architecture/__init__.py +++ b/tiatoolbox/models/architecture/__init__.py @@ -8,7 +8,6 @@ import torch from tiatoolbox import rcParam -from tiatoolbox.models.architecture.vanilla import CNNBackbone, CNNModel from tiatoolbox.models.dataset.classification import predefined_preproc_func from tiatoolbox.utils import download_data @@ -16,7 +15,9 @@ PRETRAINED_INFO = rcParam["pretrained_model_info"] -def fetch_pretrained_weights(model_name: str, save_path: str, overwrite: bool = True): +def fetch_pretrained_weights( + model_name: str, save_path: str = None, overwrite: bool = False +) -> pathlib.Path: """Get the pretrained model information from yml file. Args: @@ -29,9 +30,22 @@ def fetch_pretrained_weights(model_name: str, save_path: str, overwrite: bool = overwrite (bool): Overwrite existing downloaded weights. + Returns: + pathlib.Path: + The local path to the cached pretrained weights after downloading. + """ + if model_name not in PRETRAINED_INFO: + raise ValueError(f"Pretrained model `{model_name}` does not exist") + info = PRETRAINED_INFO[model_name] + + if save_path is None: + file_name = info["url"].split("/")[-1] + save_path = os.path.join(rcParam["TIATOOLBOX_HOME"], "models/", file_name) + download_data(info["url"], save_path, overwrite) + return pathlib.Path(save_path) def get_pretrained_model( @@ -111,11 +125,9 @@ def get_pretrained_model( model.preproc_func = predefined_preproc_func(info["dataset"]) if pretrained_weights is None: - file_name = info["url"].split("/")[-1] - pretrained_weights = os.path.join( - rcParam["TIATOOLBOX_HOME"], "models/", file_name + pretrained_weights = fetch_pretrained_weights( + pretrained_model, overwrite=overwrite ) - fetch_pretrained_weights(pretrained_model, pretrained_weights, overwrite) # ! assume to be saved in single GPU mode # always load on to the CPU diff --git a/tiatoolbox/models/dataset/info.py b/tiatoolbox/models/dataset/info.py index 3db496bc8..5562f1fb6 100644 --- a/tiatoolbox/models/dataset/info.py +++ b/tiatoolbox/models/dataset/info.py @@ -85,7 +85,7 @@ def __init__( if save_dir_path is None: # pragma: no cover save_dir_path = Path(rcParam["TIATOOLBOX_HOME"], "dataset") - if not os.path.exists(save_dir_path): + if not os.path.exists(os.path.join(save_dir_path, "kather100k-validation")): save_zip_path = os.path.join(save_dir_path, "Kather.zip") url = ( "https://tiatoolbox.dcs.warwick.ac.uk/datasets"