diff --git a/requirements/requirements.txt b/requirements/requirements.txt index c0ace5cfc..175e4bb9d 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -35,6 +35,7 @@ timm>=1.0.3 torch>=2.1.0 torchvision>=0.15.0 tqdm>=4.64.1 +transformers>=4.51.1 umap-learn>=0.5.3 wsidicom>=0.18.0 zarr>=2.13.3, <3.0.0 diff --git a/tests/models/test_arch_sam.py b/tests/models/test_arch_sam.py new file mode 100644 index 000000000..03e31a94b --- /dev/null +++ b/tests/models/test_arch_sam.py @@ -0,0 +1,65 @@ +"""Unit test package for SAM.""" + +from collections.abc import Callable +from pathlib import Path + +import numpy as np +import torch + +from tiatoolbox.models.architecture.sam import SAM +from tiatoolbox.utils import env_detection as toolbox_env +from tiatoolbox.utils import imread +from tiatoolbox.utils.misc import select_device + +ON_GPU = toolbox_env.has_gpu() + +# Test pretrained Model ============================= + + +def test_functional_sam( + remote_sample: Callable, +) -> None: + """Test for SAM.""" + # convert to pathlib Path to prevent wsireader complaint + tile_path = Path(remote_sample("patch-extraction-vf")) + img = imread(tile_path) + + # test creation + + model = SAM(device=select_device(on_gpu=ON_GPU)) + + # create image patch and prompts + patch = img[63:191, 750:878, :] + + points = [[[64, 64]]] + boxes = [[[64, 64, 128, 128]]] + + # test preproc + tensor = torch.from_numpy(img) + patch = np.expand_dims(model.preproc(tensor), axis=0) + patch = model.preproc(patch) + + # test inference + + mask_output, score_output = model.infer_batch( + model, patch, points, device=select_device(on_gpu=ON_GPU) + ) + + assert mask_output is not None, "Output should not be None" + assert len(mask_output) > 0, "Output should have at least one element" + assert len(score_output) > 0, "Output should have at least one element" + + mask_output, score_output = model.infer_batch( + model, patch, box_coords=boxes, device=select_device(on_gpu=ON_GPU) + ) + + assert len(mask_output) > 0, "Output should have at least one element" + assert len(score_output) > 0, "Output should have at least one element" + + mask_output, score_output = model.infer_batch( + model, patch, device=select_device(on_gpu=ON_GPU) + ) + + assert mask_output is not None, "Output should not be None" + assert len(mask_output) > 0, "Output should have at least one element" + assert len(score_output) > 0, "Output should have at least one element" diff --git a/tests/models/test_prompt_segmentor.py b/tests/models/test_prompt_segmentor.py new file mode 100644 index 000000000..1996a9448 --- /dev/null +++ b/tests/models/test_prompt_segmentor.py @@ -0,0 +1,273 @@ +"""Unit test package for Prompt Segmentor.""" + +from __future__ import annotations + +# ! 
The garbage collector +import multiprocessing +import shutil +from collections.abc import Callable +from pathlib import Path + +import numpy as np +import pytest + +from tiatoolbox.models import PromptSegmentor +from tiatoolbox.models.architecture.sam import SAM +from tiatoolbox.models.engine.semantic_segmentor import ( + IOSegmentorConfig, +) +from tiatoolbox.utils import env_detection as toolbox_env +from tiatoolbox.utils import imwrite +from tiatoolbox.utils.misc import select_device +from tiatoolbox.wsicore.wsireader import WSIReader + +ON_GPU = toolbox_env.has_gpu() +BATCH_SIZE = 1 if not ON_GPU else 2 +try: + NUM_LOADER_WORKERS = multiprocessing.cpu_count() +except NotImplementedError: + NUM_LOADER_WORKERS = 2 + + +def test_functional_segmentor( + remote_sample: Callable, + tmp_path: Path, +) -> None: + """Functional test for segmentor.""" + save_dir = tmp_path / "dump" + # # convert to pathlib Path to prevent wsireader complaint + resolution = 2.0 + mini_wsi_svs = Path(remote_sample("patch-extraction-vf")) + reader = WSIReader.open(mini_wsi_svs, resolution) + thumb = reader.slide_thumbnail(resolution=resolution, units="mpp") + thumb = thumb[63:191, 750:878, :] + mini_wsi_jpg = f"{tmp_path}/mini_svs.jpg" + imwrite(mini_wsi_jpg, thumb) + + # preemptive clean up + shutil.rmtree(save_dir, ignore_errors=True) + + model = SAM() + + # test engine setup + + _ = PromptSegmentor(None, BATCH_SIZE, NUM_LOADER_WORKERS) + + prompt_segmentor = PromptSegmentor(model, BATCH_SIZE, NUM_LOADER_WORKERS) + + ioconfig = IOSegmentorConfig( + input_resolutions=[ + {"units": "mpp", "resolution": 4.0}, + ], + output_resolutions=[{"units": "mpp", "resolution": 4.0}], + patch_input_shape=[512, 512], + patch_output_shape=[512, 512], + stride_shape=[512, 512], + ) + + # test inference + + points = np.array([[[64, 64]], [[64, 64]]]) # Point on nuclei + + # Run on tile mode with multi-prompt + # Test running with multiple images + shutil.rmtree(save_dir, ignore_errors=True) + output_list = prompt_segmentor.predict( + [mini_wsi_jpg, mini_wsi_jpg], + mode="tile", + multi_prompt=True, + device=select_device(on_gpu=ON_GPU), + point_coords=points, + ioconfig=ioconfig, + crash_on_exception=False, + save_dir=save_dir, + ) + + pred_1 = np.load(output_list[0][1] + "/0.raw.0.npy") + pred_2 = np.load(output_list[1][1] + "/0.raw.0.npy") + assert len(output_list) == 2 + assert np.sum(pred_1 - pred_2) == 0 + + points = np.array([[[64, 64], [100, 40], [100, 70]]]) # Points on nuclei + boxes = np.array([[[10, 10, 50, 50], [80, 80, 110, 110]]]) # Boxes on nuclei + + # Run on tile mode with single-prompt + # Also tests boxes + shutil.rmtree(save_dir, ignore_errors=True) + output_list = prompt_segmentor.predict( + [mini_wsi_jpg], + mode="tile", + multi_prompt=False, + device=select_device(on_gpu=ON_GPU), + point_coords=points, + box_coords=boxes, + ioconfig=ioconfig, + crash_on_exception=False, + save_dir=save_dir, + ) + + total_prompts = points.shape[1] + boxes.shape[1] + preds = [ + np.load(output_list[0][1] + f"/{i}.raw.0.npy") for i in range(total_prompts) + ] + + assert len(output_list) == 1 + assert len(preds) == total_prompts + + # Generate mask + mask = np.zeros((thumb.shape[0], thumb.shape[1]), dtype=np.uint8) + mask[32:120, 32:120] = 1 + mini_wsi_msk = f"{tmp_path}/mini_svs_mask.jpg" + imwrite(mini_wsi_msk, mask) + + ioconfig = IOSegmentorConfig( + input_resolutions=[ + {"units": "baseline", "resolution": 1.0}, + ], + output_resolutions=[{"units": "baseline", "resolution": 1.0}], + patch_input_shape=[512, 512], + 
patch_output_shape=[512, 512], + stride_shape=[512, 512], + save_resolution={"units": "baseline", "resolution": 1.0}, + ) + + # Only point within mask should generate a segmentation + points = np.array([[[64, 64], [100, 40]]]) + save_dir = tmp_path / "dump" + + # Run on wsi mode with multi-prompt + # Also tests masks + shutil.rmtree(save_dir, ignore_errors=True) + output_list = prompt_segmentor.predict( + [mini_wsi_jpg], + masks=[mini_wsi_msk], + mode="wsi", + multi_prompt=True, + device=select_device(on_gpu=ON_GPU), + point_coords=points, + ioconfig=ioconfig, + crash_on_exception=False, + save_dir=save_dir, + ) + + # Check if db exists + assert Path(output_list[0][1] + ".0.db").exists() + + points = np.array([[[10, 30]]]) + boxes = np.array([[[10, 10, 30, 30]]]) + # Test no prompts within mask + shutil.rmtree(save_dir, ignore_errors=True) + output_list = prompt_segmentor.predict( + [mini_wsi_jpg], + masks=[mini_wsi_msk], + mode="wsi", + multi_prompt=True, + device=select_device(on_gpu=ON_GPU), + point_coords=points, + box_coords=boxes, + ioconfig=ioconfig, + crash_on_exception=False, + save_dir=save_dir, + ) + # Check if db exists + assert Path(output_list[0][1] + ".0.db").exists() + + # Run on wsi mode with single-prompt + shutil.rmtree(save_dir, ignore_errors=True) + output_list = prompt_segmentor.predict( + [mini_wsi_jpg], + mode="wsi", + multi_prompt=False, + device=select_device(on_gpu=ON_GPU), + point_coords=points, + ioconfig=ioconfig, + crash_on_exception=False, + save_dir=save_dir, + ) + + # Check if db exists + assert Path(output_list[0][1] + ".0.db").exists() + + +def test_crash_segmentor(remote_sample: Callable, tmp_path: Path) -> None: + """Functional crash tests for segmentor.""" + # # convert to pathlib Path to prevent wsireader complaint + mini_wsi_svs = Path(remote_sample("wsi2_4k_4k_svs")) + mini_wsi_msk = Path(remote_sample("wsi2_4k_4k_msk")) + + save_dir = tmp_path / "test_crash_segmentor" + prompt_segmentor = PromptSegmentor(batch_size=BATCH_SIZE) + + # * test basic crash + with pytest.raises(TypeError, match=r".*`mask_reader`.*"): + prompt_segmentor.filter_coordinates(mini_wsi_msk, np.array(["a", "b", "c"])) + with pytest.raises(TypeError, match=r".*`mask_reader`.*"): + prompt_segmentor.get_mask_bounds(mini_wsi_msk) + with pytest.raises(TypeError, match=r".*mask_reader.*"): + prompt_segmentor.clip_coordinates(mini_wsi_msk, np.array(["a", "b", "c"])) + + with pytest.raises(ValueError, match=r".*ndarray.*integer.*"): + prompt_segmentor.filter_coordinates( + WSIReader.open(mini_wsi_msk), + np.array([1.0, 2.0]), + ) + with pytest.raises(ValueError, match=r".*ndarray.*integer.*"): + prompt_segmentor.clip_coordinates( + WSIReader.open(mini_wsi_msk), + np.array([1.0, 2.0]), + ) + prompt_segmentor.get_reader(mini_wsi_svs, None, "wsi", auto_get_mask=True) + with pytest.raises(ValueError, match=r".*must be a valid file path.*"): + prompt_segmentor.get_reader( + mini_wsi_msk, + "not_exist", + "wsi", + auto_get_mask=True, + ) + + shutil.rmtree(save_dir, ignore_errors=True) # default output dir test + with pytest.raises(ValueError, match=r".*valid mode.*"): + prompt_segmentor.predict([], mode="abc") + + crash_segmentor = PromptSegmentor() + + # * test crash segmentor + def _predict_one_wsi( + *args: dict, + **kwargs: dict, + ) -> tuple[WSIReader, str]: + """Override the predict function to test crash segmentor.""" + msg = f"Test crash segmentor:{args} {kwargs}" + raise RuntimeError(msg) + + crash_segmentor._predict_one_wsi = _predict_one_wsi + shutil.rmtree(save_dir, 
ignore_errors=True) + with pytest.raises( + RuntimeError, + match=r"Test crash segmentor:\(.*\) \{.*\}", + ): + crash_segmentor.predict( + [mini_wsi_svs], + mode="wsi", + multi_prompt=True, + device=select_device(on_gpu=ON_GPU), + patch_input_shape=[512, 512], + resolution=2.0, + units="mpp", + crash_on_exception=True, + save_dir=save_dir, + ) + + # test ignore crash + shutil.rmtree(save_dir, ignore_errors=True) + crash_segmentor.predict( + [mini_wsi_svs], + mode="wsi", + multi_prompt=True, + device=select_device(on_gpu=ON_GPU), + patch_input_shape=[512, 512], + resolution=2.0, + units="mpp", + crash_on_exception=False, + save_dir=save_dir, + ) diff --git a/tests/test_app_bokeh.py b/tests/test_app_bokeh.py index ce97fb2fd..a926a86de 100644 --- a/tests/test_app_bokeh.py +++ b/tests/test_app_bokeh.py @@ -512,6 +512,50 @@ def test_hovernet_on_box(doc: Document, data_path: pytest.TempPathFactory) -> No assert len(main.UI["type_column"].children) == 1 +def test_sam_segment(doc: Document, data_path: pytest.TempPathFactory) -> None: + """Test running SAM on points and a box.""" + slide_select = doc.get_model_by_name("slide_select0") + slide_select.value = [data_path["slide2"].name] + run_button = doc.get_model_by_name("to_model0") + assert len(main.UI["color_column"].children) == 0 + slide_select.value = [data_path["slide1"].name] + # set up a box selection + main.UI["box_source"].data = { + "x": [1200], + "y": [-2000], + "width": [400], + "height": [400], + } + + # select SAM model and run it on box + model_select = doc.get_model_by_name("model_drop0") + model_select.value = "SAM" + + click = ButtonClick(run_button) + run_button._trigger_event(click) + assert len(main.UI["color_column"].children) > 0 + + # test save functionality + save_button = doc.get_model_by_name("save_button0") + click = ButtonClick(save_button) + save_button._trigger_event(click) + saved_path = ( + data_path["base_path"] / "overlays" / (data_path["slide1"].stem + ".db") + ) + assert saved_path.exists() + + # load an overlay with different types + cprop_select = doc.get_model_by_name("cprop0") + cprop_select.value = ["prob"] + layer_drop = doc.get_model_by_name("layer_drop0") + click = MenuItemClick(layer_drop, str(data_path["dat_anns"])) + layer_drop._trigger_event(click) + assert main.UI["vstate"].types == ["annotation"] + # check the per-type ui controls have been updated + assert len(main.UI["color_column"].children) == 1 + assert len(main.UI["type_column"].children) == 1 + + def test_alpha_sliders(doc: Document) -> None: """Test sliders for adjusting slide and overlay alpha.""" slide_alpha = doc.get_model_by_name("slide_alpha0") diff --git a/tests/test_utils.py b/tests/test_utils.py index 9ee896365..844cee384 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1953,6 +1953,7 @@ def test_dict_to_store_semantic_segment() -> None: assert "Line String" in annotations_geometry_type +# Tests for OME tiff writer def test_dict_to_store_semantic_segment_holes(track_tmp_path: Path) -> None: """Tests behaviour of holes in dict_to_store and save_path.""" test_pred = np.array( diff --git a/tiatoolbox/models/__init__.py b/tiatoolbox/models/__init__.py index 39d1441ce..42b758c33 100644 --- a/tiatoolbox/models/__init__.py +++ b/tiatoolbox/models/__init__.py @@ -8,6 +8,7 @@ from .architecture.mapde import MapDe from .architecture.micronet import MicroNet from .architecture.nuclick import NuClick +from .architecture.sam import SAM from .architecture.sccnn import SCCNN from .engine.multi_task_segmentor import 
MultiTaskSegmentor from .engine.nucleus_instance_segmentor import NucleusInstanceSegmentor @@ -17,6 +18,7 @@ PatchPredictor, WSIPatchDataset, ) +from .engine.prompt_segmentor import PromptSegmentor from .engine.semantic_segmentor import ( DeepFeatureExtractor, IOSegmentorConfig, @@ -25,6 +27,7 @@ ) __all__ = [ + "SAM", "SCCNN", "HoVerNet", "HoVerNetPlus", @@ -35,5 +38,6 @@ "NuClick", "NucleusInstanceSegmentor", "PatchPredictor", + "PromptSegmentor", "SemanticSegmentor", ] diff --git a/tiatoolbox/models/architecture/sam.py b/tiatoolbox/models/architecture/sam.py new file mode 100644 index 000000000..99be318f1 --- /dev/null +++ b/tiatoolbox/models/architecture/sam.py @@ -0,0 +1,222 @@ +"""Define SAM architecture.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np +import torch +from PIL import Image +from transformers import SamModel, SamProcessor + +from tiatoolbox.models.models_abc import ModelABC + +if TYPE_CHECKING: # pragma: no cover + from tiatoolbox.type_hints import IntBounds, IntPair + + +class SAM(ModelABC): + """Segment Anything Model (SAM) Architecture. + + Meta AI's zero-shot segmentation model. + SAM is used for interactive general-purpose segmentation. + + Currently supports SAM, which requires a checkpoint and model type. + + SAM accepts an RGB image patch along with a list of point and bounding + box coordinates as prompts. + + Args: + model_type (str): + Model type. + Currently supported: vit_b, vit_l, vit_h. + checkpoint_path (str): + Path to the model checkpoint. + device (str): + Device to run inference on. + + Examples: + >>> # instantiate SAM with checkpoint path and model type + >>> sam = SAM( + ... model_type="vit_b", + ... checkpoint_path="path/to/sam_checkpoint.pth" + ... ) + """ + + def __init__( + self: SAM, + model_path: str = "facebook/sam-vit-huge", + *, + device: str = "cpu", + ) -> None: + """Initialize :class:`SAM`.""" + super().__init__() + self.net_name = "SAM" + self.device = device + + self.model = SamModel.from_pretrained(model_path).to(device) + self.processor = SamProcessor.from_pretrained(model_path) + + def forward( # skipcq: PYL-W0221 + self: SAM, + imgs: list, + point_coords: list | None = None, + box_coords: list | None = None, + ) -> np.ndarray: + """Torch method. Defines forward pass on each image in the batch. + + Note: This architecture only uses a single layer, so only one forward pass + is needed. + + Args: + imgs (list): + List of images to process, of the shape NHWC. + point_coords (list): + List of point coordinates for each image. + box_coords (list): + Bounding box coordinates for each image. + + Returns: + list: + List of masks and scores for each image. 
+ + """ + masks, scores = [], [] + for i, img in enumerate(imgs): + image = [Image.fromarray(img)] + embeddings, orig_sizes, reshaped_sizes = self._encode_image(image) + point_labels = None + points = None + boxes = None + + # Processor expects coordinates to be lists + def format_coords(coords: np.ndarray | list) -> list: + """Helper function that converts coordinates to list format.""" + if isinstance(coords, np.ndarray): + return coords.tolist() + if isinstance(coords[0], np.ndarray): + return [ + item.tolist() if isinstance(item, np.ndarray) else item + for item in coords + ] + return coords + + if point_coords is not None: + points = point_coords[i] + # Convert point coordinates to list + if points is not None: + point_labels = [[[1] * len(points)]] + points = [format_coords(points)] + + if box_coords is not None: + boxes = box_coords[i] + # Convert box coordinates to list + if boxes is not None: + boxes = [format_coords(boxes)] + + inputs = self.processor( + image, + input_points=points, + input_labels=point_labels, + input_boxes=boxes, + return_tensors="pt", + ).to(self.device) + + # Replaces pixel_values with image embeddings + inputs.pop("pixel_values", None) + inputs.update( + { + "image_embeddings": embeddings, + "original_sizes": orig_sizes, + "reshaped_input_sizes": reshaped_sizes, + } + ) + + with torch.inference_mode(): + # Forward pass through the model + outputs = self.model(**inputs, multimask_output=False) + image_masks = self.processor.image_processor.post_process_masks( + outputs.pred_masks.cpu(), + inputs["original_sizes"].cpu(), + inputs["reshaped_input_sizes"].cpu(), + ) + image_scores = outputs.iou_scores.cpu() + masks.append(image_masks) + scores.append(image_scores) + torch.cuda.empty_cache() + + return np.array(masks), np.array(scores) + + @staticmethod + def infer_batch( + model: torch.nn.Module, + batch_data: list, + point_coords: list[list[IntPair]] | None = None, + box_coords: list[IntBounds] | None = None, + *, + device: str = "cpu", + ) -> np.ndarray: + """Run inference on an input batch. + + Contains logic for forward operation as well as I/O aggregation. + SAM accepts a list of points and a single bounding box per image. + + Args: + model (nn.Module): + PyTorch defined model. + batch_data (list): + A batch of data generated by + `torch.utils.data.DataLoader`. + point_coords (list): + Point coordinates for each image in the batch. + box_coords (list): + Bounding box coordinates for each image in the batch. + device (str): + Device to run inference on. + + Returns: + pred_info (list): + Tuple of masks and scores for each image in the batch. 
+ + """ + model.eval().to(device) + + if isinstance(batch_data, torch.Tensor): + batch_data = batch_data.cpu().numpy() + + with torch.inference_mode(): + masks, scores = model(batch_data, point_coords, box_coords) + + return masks, scores + + def _encode_image(self: SAM, image: np.ndarray) -> np.ndarray: + """Encodes image and stores size info for later mask post-processing.""" + processed = self.processor(image, return_tensors="pt") + original_sizes = processed["original_sizes"] + reshaped_sizes = processed["reshaped_input_sizes"] + + inputs = processed.to(self.device) + embeddings = self.model.get_image_embeddings(inputs["pixel_values"]) + return embeddings, original_sizes, reshaped_sizes + + @staticmethod + def preproc(image: np.ndarray) -> np.ndarray: + """Pre-processes an image - Converts it into a format accepted by SAM (HWC).""" + # Move the tensor to the CPU if it's a PyTorch tensor + if isinstance(image, torch.Tensor): + image = image.permute(1, 2, 0).cpu().numpy() + + return image[..., :3] # Remove alpha channel if present + + def to( + self: ModelABC, + device: str = "cpu", + dtype: torch.dtype | None = None, + *, + non_blocking: bool = False, + ) -> ModelABC | torch.nn.DataParallel[ModelABC]: + """Moves the model to the specified device.""" + super().to(device, dtype=dtype, non_blocking=non_blocking) + self.device = device + self.model.to(device) + return self diff --git a/tiatoolbox/models/engine/prompt_segmentor.py b/tiatoolbox/models/engine/prompt_segmentor.py new file mode 100644 index 000000000..9b9fb9b17 --- /dev/null +++ b/tiatoolbox/models/engine/prompt_segmentor.py @@ -0,0 +1,963 @@ +"""This module enables interactive segmentation.""" + +from __future__ import annotations + +import logging +import shutil +from pathlib import Path +from typing import TYPE_CHECKING + +import joblib +import numpy as np +import torch +import torch.multiprocessing as torch_mp +import torch.utils.data as torch_data +import tqdm + +from tiatoolbox import logger +from tiatoolbox.models.architecture.sam import SAM +from tiatoolbox.models.engine.semantic_segmentor import ( + IOSegmentorConfig, + SemanticSegmentor, + WSIStreamDataset, +) +from tiatoolbox.models.models_abc import model_to +from tiatoolbox.tools.patchextraction import PointsPatchExtractor +from tiatoolbox.utils.misc import dict_to_store_semantic_segmentor +from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIReader + +if TYPE_CHECKING: # pragma: no cover + from tiatoolbox.type_hints import Callable, IntBounds, IntPair, Resolution, Units + + +class PromptSegmentor(SemanticSegmentor): + """Engine for prompt-based segmentation of WSIs. + + This class is designed to work with the SAM model architecture. + It allows for interactive segmentation by providing point and bounding box + coordinates as prompts. The model can be used in both tile and WSI modes, + where tile mode processes individual image patches and WSI mode processes + whole-slide images. The class also supports multi-prompt segmentation, + where multiple point and bounding box coordinates can be provided for + segmentation. + + Args: + model (SAM): + Model architecture to use. + batch_size (int): + Batch size for processing. + num_loader_workers (int): + Number of workers for data loading. + dataset_class (Callable): + Dataset class to use. 
+ + """ + + def __init__( + self, + model: torch.nn.Module = None, + batch_size: int = 4, + num_loader_workers: int = 0, + dataset_class: Callable = WSIStreamDataset, + ) -> None: + """Initializes the PromptSegmentor.""" + if model is None: + model = SAM() + super().__init__( + batch_size=batch_size, + num_loader_workers=num_loader_workers, + model=model, + dataset_class=dataset_class, + ) + self.multi_prompt = True + + def predict( # skipcq: PYL-W0221 + self, + imgs: list, + masks: list | None = None, + mode: str = "tile", + ioconfig: IOSegmentorConfig = None, + point_coords: list[list[IntPair]] | None = None, + box_coords: list[list[IntBounds]] | None = None, + save_dir: str | Path | None = None, + device: str = "cpu", + *, + multi_prompt: bool = True, + crash_on_exception: bool = False, + **ioconfig_kwargs: dict, + ) -> list[tuple[Path, Path]]: + """Predict on a list of WSIs using prompts. + + Args: + imgs (list, ndarray): + A list of paths to the input WSIs. + masks (list): + A list of masks corresponding to the input WSIs. + Used to filter the coordinates of patches for inference. + mode (str): + The mode of prediction. Can be either `tile` or `wsi`. + Affects how the input images are processed and saved. + Use 'tile' for saving as raw numpy files, or 'wsi' for + saving as annotations. + ioconfig (:class:`IOSegmentorConfig`): + Configuration for input/output processing. + point_coords (list): + Point coordinates for each image as `[x, y]` pairs. + Stored as a list of lists of coordinates. + box_coords (list): + Bounding box coordinates for each image as `[x1, y1, x2, y2]` pairs. + Stored as a list of lists of coordinates. + save_dir (str, Path): + Directory to save the output predictions. + device (str): + Device to run inference on. + crash_on_exception (bool): + Whether to crash on exceptions during prediction. + multi_prompt (bool): + Whether to use multiple prompts simulataneously for segmentation. + If false, the image will be processed for each prompt separately. + **ioconfig_kwargs (dict): + Additional keyword arguments for the IOSegmentorConfig. + + Returns: + output_paths(list[tuple[Path, Path]]): + A list of tuples containing the input image path and the corresponding + output path for the predictions. + Each tuple is of the form (input_path, save_path). + + Examples: + >>> segmentor = PromptSegmentor(model=model) + >>> imgs = ["path/to/image1", "path/to/image2"] + >>> masks = ["path/to/mask1", "path/to/mask2"] + >>> point_coords = [[[100, 200]], [[150, 250]]] + >>> box_coords = [[[50, 50, 150, 150]], [[100, 100, 200, 200]]] + >>> output_paths = segmentor.predict( + ... imgs, + ... masks=masks, + ... mode="tile", + ... point_coords=point_coords, + ... box_coords=box_coords, + ... save_dir="output_dir",) + + """ + if mode not in ["wsi", "tile"]: + msg = f"{mode} is not a valid mode. Use either `tile` or `wsi`." 
+ raise ValueError(msg) + + save_dir, self._cache_dir = self._prepare_save_dir(save_dir=save_dir) + ioconfig_kwargs = self.pad_ioconfig(**ioconfig_kwargs) + ioconfig = self._update_ioconfig(ioconfig, mode, **ioconfig_kwargs) + + # use external for testing + self._device = device + self._model = model_to(model=self.model, device=device) + + # workers should be > 0 else Value Error will be thrown + self._prepare_workers() + + mp_manager = torch_mp.Manager() + mp_shared_space = mp_manager.Namespace() + self._mp_shared_space = mp_shared_space + + ds = self.dataset_class( + ioconfig=ioconfig, + preproc=self.model.preproc_func, + wsi_paths=imgs, + mp_shared_space=mp_shared_space, + mode=mode, + ) + + loader = torch_data.DataLoader( + ds, + drop_last=False, + batch_size=self.batch_size, + num_workers=self.num_loader_workers, + persistent_workers=self.num_loader_workers > 0, + ) + + self._loader = loader + self.imgs = imgs + self.masks = masks + self.multi_prompt = multi_prompt + + self._outputs = [] + + for wsi_idx, image_path in enumerate(imgs): + self._predict_wsi_handle_exception( + imgs, + wsi_idx, + image_path, + mode, + ioconfig, + point_coords[wsi_idx] if point_coords is not None else None, + box_coords[wsi_idx] if box_coords is not None else None, + save_dir, + crash_on_exception=crash_on_exception, + ) + + # clean up the cache directories + try: + shutil.rmtree(self._cache_dir) + except PermissionError: # pragma: no cover + logger.warning("Unable to remove %s", self._cache_dir) + + self._memory_cleanup() + + return self._outputs + + def _predict_one_wsi( # skipcq: PYL-W0221 + self, + wsi_idx: int, + ioconfig: IOSegmentorConfig, + point_coords: np.ndarray | None = None, + box_coords: np.ndarray | None = None, + save_path: str | Path | None = None, + mode: str = "tile", + ) -> tuple[Path, Path, Path]: + """Predict on a single WSI. + + Args: + wsi_idx (int): + Index of the WSI to process. + ioconfig (:class:`IOSegmentorConfig`): + Configuration for input/output processing. + point_coords (list): + Point coordinates for the current image as [x, y] pairs. + box_coords (list): + Bounding box coordinates for the current image as + [x1, y1, x2, y2] pairs. + save_path (str, Path): + Directory to save the output predictions. + mode (str): + The mode of prediction. Can be either "tile" or "wsi". + + Returns: + tuple[Path, Path, Path]: + A tuple containing the input image path and the corresponding + output paths for the predictions. + Each tuple is of the form (input_path, mask_path, score_path). + """ + cache_dir = self._cache_dir / str(wsi_idx) + cache_dir.mkdir(parents=True) + + wsi_path = self.imgs[wsi_idx] + mask_path = None if self.masks is None else self.masks[wsi_idx] + wsi_reader, mask_reader = self.get_reader( + wsi_path, + mask_path, + mode, + auto_get_mask=self.auto_generate_mask, + ) + + resolution = ioconfig.to_baseline().highest_input_resolution + + if mask_reader is not None: + # Filters the point coordinates to only include those within the + # mask. 
filter_coordinates only accepts bounding-box style coordinates + point_coords = ( + point_coords[ + PromptSegmentor.filter_coordinates( + mask_reader, + point_coords, + **resolution, + ) + ] + if point_coords is not None + else None + ) + if np.array(point_coords).size == 0: + point_coords = None + + box_coords = ( + PromptSegmentor.clip_coordinates(mask_reader, box_coords, **resolution) + if box_coords is not None + else None + ) + if np.array(box_coords).size == 0: + box_coords = None + + patch_inputs, point_coords, box_coords = self.get_coordinates( + wsi_reader=wsi_reader, + ioconfig=ioconfig, + mode=mode, + point_coords=point_coords, + box_coords=box_coords, + multi_prompt=self.multi_prompt, + ) + + resolution = ioconfig.highest_input_resolution + + patch_inputs = ( + np.array(self.clip_coordinates(mask_reader, patch_inputs, **resolution)) + if mask_reader is not None + else patch_inputs + ) + + patch_outputs = patch_inputs.copy() + + # modify the shared space so that we can update worker info + # without needing to re-create the worker. There should be no + # race-condition because only the following enumerate loop + # triggers the parallelism, and this portion is still in + # sequential execution order + patch_inputs = torch.from_numpy(patch_inputs).share_memory_() + patch_outputs = torch.from_numpy(patch_outputs).share_memory_() + self._mp_shared_space.patch_inputs = patch_inputs + self._mp_shared_space.patch_outputs = patch_outputs + self._mp_shared_space.wsi_idx = torch.Tensor([wsi_idx]).share_memory_() + + pbar_desc = "Process Batch: " + pbar = tqdm.tqdm( + desc=pbar_desc, + leave=True, + total=len(self._loader), + ncols=80, + ascii=True, + position=0, + ) + + cum_output = [] + for i, batch_data in enumerate(self._loader): + sample_datas, sample_infos = batch_data + batch_size = sample_infos.shape[0] + # ! depending on the protocol of the output within infer_batch + # ! this may change, how to enforce/document/expose this in a + # ! sensible way? + + prompt_slice = slice(i * batch_size, (i + 1) * batch_size) + + points = point_coords[prompt_slice] if point_coords is not None else None + boxes = box_coords[prompt_slice] if box_coords is not None else None + + # assume to return a list of L output, + # each of shape N x etc. (N=batch size) + + sample_outputs = self.model.infer_batch( + self.model, + sample_datas, + point_coords=points, + box_coords=boxes, + device=self._device, + ) + + # repackage so that it's an N list, each contains + # L x etc. output + sample_outputs = [ + np.split(np.array(v), batch_size, axis=0) for v in sample_outputs + ] + sample_outputs = list(zip(*sample_outputs, strict=False)) + + # tensor to numpy, costly? 
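+            # (patch location info arrives as a single batched tensor; convert
+            # to numpy and split back into one array per sample)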
+ sample_infos = sample_infos.numpy() + sample_infos = np.split(sample_infos, batch_size, axis=0) + + sample_outputs = list(zip(sample_infos, sample_outputs, strict=False)) + cum_output.extend(sample_outputs) + pbar.update() + + pbar.close() + + self._process_predictions( + cum_output, + wsi_reader, + ioconfig, + save_path, + cache_dir, + mode, + ) + + # clean up the cache directories + shutil.rmtree(cache_dir) + + @staticmethod + def pad_ioconfig( + **kw_ioconfig: dict, + ) -> dict: + """Assign None to missing keyword ioconfig info.""" + # Define the expected keys + required_keys = [ + "patch_input_shape", + "patch_output_shape", + "stride_shape", + "resolution", + "units", + ] + + # Fill in any missing keys with None + for key in required_keys: + kw_ioconfig.setdefault(key, None) + return kw_ioconfig + + @staticmethod + def _adjust_prompt_resolution( + wsi_reader: WSIReader, + coords: np.ndarray | None, + resolution: Resolution, + units: Units, + ) -> np.ndarray | None: + """Adjust the resolution of the prompt coordinates. + + This function scales the provided coordinates to the specified + resolution and units. It is used to ensure that the coordinates + are in the correct format for processing. + + Args: + wsi_reader (WSIReader): + A reader for the image where the predictions come from. + resolution (Resolution): + The resolution of the image. + units (Units): + The units of the image. + coords (np.ndarray): + Coordinates to adjust. + """ + if coords is not None: + coords = coords * ( + wsi_reader.slide_dimensions(resolution, units)[0] + / wsi_reader.slide_dimensions(1.0, "baseline")[0] + ) + return coords + + @staticmethod + def get_coordinates( # skipcq: PYL-W0221 + wsi_reader: WSIReader, + ioconfig: IOSegmentorConfig, + mode: str, + point_coords: np.ndarray | None = None, + box_coords: np.ndarray | None = None, + *, + multi_prompt: bool = False, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Calculate patch tiling coordinates. + + ! Update this docstring to reflect the new API. + + By default, internally, it will call the + `PatchExtractor.get_coordinates`. To use your own approach, + either subclass to overwrite or directly assign your own + function to this name. In either cases, the function must obey + the API defined here. + + Args: + wsi_reader (WSIReader): + A reader for the image where the predictions come from. + ioconfig (:class:`IOSegmentorConfig`): + Configuration for input/output processing. + mode (str): + The mode of prediction. Can be either `tile` or `wsi`. + point_coords (np.ndarray): + Point coordinates for the current image as [x, y] pairs. + box_coords (np.ndarray): + Bounding box coordinates for the current image as + [x1, y1, x2, y2] pairs. + multi_prompt (bool): + Whether to use multiple prompts simultaneously for segmentation. + If false, the image will be processed for each prompt separately. + + Returns: + tuple: + List of patch inputs and outputs + + - :py:obj:`list` - patch_inputs: + A list of corrdinates in `[start_x, start_y, end_x, + end_y]` format indicating the read location of the + patch in the mother image. + + - point_coords: + A list of point coordinates for the current image + as `[x, y]` pairs. + - box_coords: + A list of bounding box coordinates for the current + image as `[x1, y1, x2, y2]` pairs. + + Examples: + >>> # API of function expected to overwrite `get_coordinates` + >>> def func(image_shape, ioconfig): + ... patch_inputs = np.array([[0, 0, 256, 256]]) + ... patch_outputs = np.array([[0, 0, 256, 256]]) + ... 
return patch_inputs, patch_outputs + >>> segmentor = SemanticSegmentor(model='unet') + >>> segmentor.get_coordinates = func + + """ + resolution = ioconfig.highest_input_resolution + wsi_proc_shape = wsi_reader.slide_dimensions(**resolution) + image_patch = np.array([0, 0, wsi_proc_shape[0], wsi_proc_shape[1]]) + + point_coords = PromptSegmentor._adjust_prompt_resolution( + wsi_reader, point_coords, **resolution + ) + box_coords = PromptSegmentor._adjust_prompt_resolution( + wsi_reader, box_coords, **resolution + ) + + if multi_prompt: + patch_inputs = np.array([np.copy(image_patch)]) + point_coords = ( + np.array([point_coords]) if point_coords is not None else None + ) + # Will only use the first box passed in + box_coords = np.array([[box_coords[0]]]) if box_coords is not None else None + else: + num_points = len(point_coords) if point_coords is not None else 0 + num_boxes = len(box_coords) if box_coords is not None else 0 + + patch_inputs = PromptSegmentor._extract_patches( + wsi_reader=wsi_reader, + ioconfig=ioconfig, + point_coords=point_coords, + box_coords=box_coords, + mode=mode, + ) + + # Format coordinates by adding padding + # Required for slicing when iterating over DataLoader + point_coords = ( + ([[x] for x in point_coords] + [None] * num_boxes) + if point_coords is not None + else None + ) + box_coords = ( + [None] * num_points + [[y] for y in box_coords] + if box_coords is not None + else None + ) + + return patch_inputs, point_coords, box_coords + + @staticmethod + def _extract_patches( + wsi_reader: WSIReader, + ioconfig: IOSegmentorConfig, + point_coords: np.ndarray | None = None, + box_coords: np.ndarray | None = None, + mode: str = "tile", + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Extract patches from the WSI, given that it is WSI mode. + + Args: + wsi_reader (WSIReader): + A reader for the image where the predictions come from. + ioconfig (:class:`IOSegmentorConfig`): + Configuration for input/output processing. + mode (str): + The mode of prediction. Can be either `tile` or `wsi`. + point_coords (np.ndarray): + Point coordinates for the current image as [x, y] pairs. + box_coords (np.ndarray): + Bounding box coordinates for the current image as + [x1, y1, x2, y2] pairs. + + Returns: + tuple: + List of patch inputs and outputs + + - :py:obj:`list` - patch_inputs: + A list of corrdinates in `[start_x, start_y, end_x, + end_y]` format indicating the read location of the + patch in the mother image. + + """ + resolution = ioconfig.highest_input_resolution + wsi_proc_shape = wsi_reader.slide_dimensions(**resolution) + image_patch = np.array([0, 0, wsi_proc_shape[0], wsi_proc_shape[1]]) + num_points = len(point_coords) if point_coords is not None else 0 + num_boxes = len(box_coords) if box_coords is not None else 0 + num_patches = num_points + num_boxes + + if mode == "tile": + patch_inputs = np.array([np.copy(image_patch) for _ in range(num_patches)]) + else: + patch_extractor = PointsPatchExtractor( + wsi_reader, point_coords, ioconfig.patch_input_shape, **resolution + ) + patch_inputs = patch_extractor.get_coordinates( + image_shape=wsi_proc_shape, + patch_input_shape=ioconfig.patch_input_shape, + stride_shape=ioconfig.stride_shape, + ) + return patch_inputs + + @staticmethod + def filter_coordinates( + mask_reader: VirtualWSIReader, + bounds: np.ndarray, + resolution: Resolution | None = None, + units: Units | None = None, + ) -> np.ndarray: + """Indicates which coordinate is valid basing on the mask. 
+ + To use your own approaches, either subclass to overwrite or + directly assign your own function to this name. In either cases, + the function must obey the API defined here. + + Args: + mask_reader (:class:`.VirtualReader`): + A virtual pyramidal reader of the mask related to the + WSI from which we want to extract the patches. + bounds (ndarray and np.int32): + Coordinates to be checked via the `func`. They must be + in the same resolution as requested `resolution` and + `units`. The shape of `coordinates` is (N, K) where N is + the number of coordinate sets and K is either 2 for + centroids or 4 for bounding boxes. When using the + default `func=None`, K should be 4, as we expect the + `coordinates` to be bounding boxes in `[start_x, + start_y, end_x, end_y]` format. + resolution (Resolution): + Resolution of the requested patch. + units (Units): + Units of the requested patch. + + Returns: + :class:`numpy.ndarray`: + List of flags to indicate which coordinate is valid. + + Examples: + >>> # API of function expected to overwrite `filter_coordinates` + >>> def func(reader, bounds, resolution, units): + ... # as example, only select first bound + ... return np.array([1, 0]) + >>> coords = [[0, 0, 256, 256], [128, 128, 384, 384]] + >>> segmentor = SemanticSegmentor(model='unet') + >>> segmentor.filter_coordinates = func + + """ + if not isinstance(mask_reader, VirtualWSIReader): + msg = "`mask_reader` should be VirtualWSIReader." + raise TypeError(msg) + + if not isinstance(bounds, np.ndarray) or not np.issubdtype( + bounds.dtype, + np.integer, + ): + msg = "`coordinates` should be ndarray of integer type." + raise ValueError(msg) + + mask_real_shape = mask_reader.img.shape[:2] + mask_resolution_shape = mask_reader.slide_dimensions( + resolution=resolution, + units=units, + )[::-1] + mask_real_shape = np.array(mask_real_shape) + mask_resolution_shape = np.array(mask_resolution_shape) + scale_factor = mask_real_shape / mask_resolution_shape + scale_factor = scale_factor[0] # what if ratio x != y + + # Get mask bounding box + mask_bbox = PromptSegmentor.get_mask_bounds(mask_reader) + scaled_bbox = np.ceil(mask_bbox / scale_factor).astype(np.int32) + + def sel_func(coord: np.ndarray) -> bool: + """Accept coord if it is part of mask.""" + x, y = coord + return (scaled_bbox[0] <= x <= scaled_bbox[2]) and ( + scaled_bbox[1] <= y <= scaled_bbox[3] + ) + + flags = [sel_func(bound) for bound in bounds] + return np.array(flags) + + def _process_predictions( + self, + cum_batch_predictions: list, + wsi_reader: WSIReader, + ioconfig: IOSegmentorConfig, + save_path: str, + cache_dir: str, + mode: str = "tile", + ) -> None: + """Define how the aggregated predictions are processed. + + This includes merging the prediction if necessary and also saving afterwards. + Note that items within `cum_batch_predictions` will be consumed during + the operation. + + Args: + cum_batch_predictions (list): + List of batch predictions. Each item within the list + should be of (location, patch_predictions). + wsi_reader (:class:`WSIReader`): + A reader for the image where the predictions come from. + ioconfig (:class: IOSegmentorConfig): + Configuration for input/output processing. + save_path (str): + Root path to save current WSI predictions. + cache_dir (str): + Root path to cache current WSI data. + mode (str): + Type of input to process. Can either be `tile` or + `wsi`. 
+ + """ + wsi_shape = wsi_reader.slide_dimensions(1.0, "baseline")[::-1] + + if mode == "tile": + self._prepare_save_dir(save_path) + for i, (_, patch_prediction) in enumerate(cum_batch_predictions): + mask_memmap, score_memmap = self._prepare_save_output( + Path(save_path) / f"{i}.raw.0.npy", + Path(save_path) / f"{i}.raw.1.npy", + tuple(wsi_shape), + (len(cum_batch_predictions),), + ) + mask = patch_prediction[0] + score = patch_prediction[1] + + # store the predictions + mask_memmap[:, :] = mask[0] + score_memmap[i] = score[0][0][0] + + mask_memmap.flush() + score_memmap.flush() + + else: + locations, predictions = list(zip(*cum_batch_predictions, strict=False)) + # Nx4 (N x [tl_x, tl_y, br_x, br_y), denotes the location of + # output patch this can exceed the image bound at the requested + # resolution remove singleton due to split. + locations = np.array([v[0] for v in locations]) + for index, output_resolution in enumerate(ioconfig.output_resolutions): + # assume resolution index to be in the same order as L + merged_resolution = ioconfig.highest_input_resolution + merged_locations = locations + if ioconfig.save_resolution is not None: + merged_resolution = ioconfig.save_resolution + output_shape = wsi_reader.slide_dimensions(**output_resolution) + merged_shape = wsi_reader.slide_dimensions(**merged_resolution) + fx = merged_shape[0] / output_shape[0] + merged_locations = np.ceil(merged_locations * fx).astype(np.int64) + merged_shape = wsi_reader.slide_dimensions(**merged_resolution) + # ! Need to find better way to extract prediction + to_merge_predictions = predictions[0][0][0][0][0] + sub_save_path = f"{save_path}.raw.{index}.npy" + sub_count_path = f"{cache_dir}/count.{index}.npy" + merged_output = { + "predictions": self.merge_prediction( + merged_shape[::-1], # XY to YX + to_merge_predictions, + merged_locations, + save_path=sub_save_path, + cache_count_path=sub_count_path, + ) + } + # Scale the merged output to the original WSI shape + scale_factor = np.array(wsi_shape) / np.array(merged_shape[::-1]) + # Generate annotations + dict_to_store_semantic_segmentor( + patch_output=merged_output, + scale_factor=scale_factor, + save_path=Path(f"{save_path}.{index}.db"), + ) + + @staticmethod + def get_mask_bounds( + mask_reader: VirtualWSIReader, + ) -> np.ndarray: + """Generate a bounding box for the mask.""" + if not isinstance(mask_reader, VirtualWSIReader): + msg = "`mask_reader` should be VirtualWSIReader." + raise TypeError(msg) + + ys, xs = np.where(mask_reader.img > 0) + x_min, x_max = xs.min(), xs.max() + y_min, y_max = ys.min(), ys.max() + return np.array([x_min, y_min, x_max, y_max]) + + @staticmethod + def clip_coordinates( + mask_reader: VirtualWSIReader, + bounds: np.ndarray, + resolution: Resolution | None = None, + units: Units | None = None, + ) -> np.ndarray: + """Clip coordinates to the mask bounding box. + + This function scales the provided coordinates to the mask + resolution and clips them to the mask bounding box. + Only non-empty boxes are kept. + + Unlike the `filter_coordinates` function in the base class, this + function clips patches to within the mask bounding box, and discards + patches that are completely outside. Therefore, masks should + be overestimates of the area of interest. + + Args: + mask_reader (VirtualWSIReader): + A reader for the image where the predictions come from. + bounds (np.ndarray): + The coordinates to filter. + resolution (Resolution): + The resolution of the image. + units (Units): + The units of the image. 
+ + Returns: + np.ndarray: + The filtered coordinates. + """ + if not isinstance(mask_reader, VirtualWSIReader): + msg = "`mask_reader` should be VirtualWSIReader." + raise TypeError(msg) + + if not isinstance(bounds, np.ndarray) or not np.issubdtype( + bounds.dtype, + np.integer, + ): + msg = "`coordinates` should be ndarray of integer type." + raise ValueError(msg) + + mask_real_shape = mask_reader.img.shape[:2] + mask_resolution_shape = mask_reader.slide_dimensions( + resolution=resolution, + units=units, + )[::-1] + + mask_real_shape = np.array(mask_real_shape) + mask_resolution_shape = np.array(mask_resolution_shape) + scale_factor = mask_real_shape / mask_resolution_shape + scale_factor = scale_factor[0] # what if ratio x != y + + # Get mask bounding box + mask_bbox = PromptSegmentor.get_mask_bounds(mask_reader) + scaled_bbox = np.ceil(mask_bbox / scale_factor).astype(np.int32) + + # Clip to mask bounding box + new_bounds = [] + for box in bounds: + x1, y1, x2, y2 = box + x1_new = max(x1, scaled_bbox[0]) + y1_new = max(y1, scaled_bbox[1]) + x2_new = min(x2, scaled_bbox[2]) + y2_new = min(y2, scaled_bbox[3]) + + # Only keep if box is non-empty + if x1_new < x2_new and y1_new < y2_new: + new_bounds.append([x1_new, y1_new, x2_new, y2_new]) + + return np.array(new_bounds, dtype=np.int32) + + def _predict_wsi_handle_exception( # skipcq: PYL-W0221 + self: PromptSegmentor, + imgs: list, + wsi_idx: int, + img_path: str | Path, + mode: str, + ioconfig: IOSegmentorConfig, + point_coords: list[IntPair], + box_coords: IntBounds, + save_dir: str | Path, + *, + crash_on_exception: bool, + ) -> None: + """Predict on multiple WSIs. + + Args: + imgs (list, ndarray): + List of image file paths to process. + wsi_idx (int): + index of current WSI being processed. + img_path(str or Path): + Path to current image. + mode (str): + Type of input to process. Can either be `tile` or + `wsi`. + ioconfig (:class:`IOSegmentorConfig`): + Object defines information about input and output + placement of patches. + point_coords (list): + List of point coordinates. + box_coords (IntBounds): + Bounding box coordinates in [x1, y1, x2, y2] form. + save_dir (str, Path): + Output directory when processing multiple tiles and + whole-slide images. By default, it is folder `output` + where the running script is invoked. + crash_on_exception (bool): + If `True`, the running loop will crash if there is any + error during processing a WSI. Otherwise, the loop will + move on to the next wsi for processing. + + Returns: + list: + A list of tuple(input_path, save_path) where + `input_path` is the path of the input wsi while + `save_path` corresponds to the output predictions. + + """ + try: + wsi_save_path = save_dir / f"{wsi_idx}" + self._predict_one_wsi( + wsi_idx, ioconfig, point_coords, box_coords, str(wsi_save_path), mode + ) + + # Do not use dict with file name as key, because it can be + # overwritten. It may be user intention to provide files with a + # same name multiple times (maybe they have different root path) + self._outputs.append([str(img_path), str(wsi_save_path)]) + + # ? will this corrupt old version if control + c midway? + map_file_path = save_dir / "file_map.dat" + # backup old version first + if Path.exists(map_file_path): + old_map_file_path = save_dir / "file_map_old.dat" + shutil.copy(map_file_path, old_map_file_path) + joblib.dump(self._outputs, map_file_path) + + # verbose mode, error by passing ? 
+ logging.info("Finish: %d", wsi_idx / len(imgs)) + logging.info("--Input: %s", str(img_path)) + logging.info("--Output: %s", str(wsi_save_path)) + # prevent deep source check because this is bypass and + # delegating error message + except Exception as err: # skipcq: PYL-W0703 + wsi_save_path = save_dir.joinpath(f"{wsi_idx}") + if crash_on_exception: + raise err # noqa: TRY201 + logging.exception("Crashed on %s", wsi_save_path) + + @staticmethod + def _prepare_save_output( + mask_path: str | Path, + score_path: str | Path, + mask_shape: tuple[int, ...], + scores_shape: tuple[int, ...], + ) -> tuple: + """Prepares for saving the cached output.""" + # Check if save path exists + + if mask_path is not None and score_path is not None: + mask_path = Path(mask_path) + score_path = Path(score_path) + mask_memmap = np.lib.format.open_memmap( + mask_path, + mode="w+", + shape=mask_shape, + dtype=np.uint8, + ) + score_memmap = np.lib.format.open_memmap( + score_path, + mode="w+", + shape=scores_shape, + dtype=np.float32, + ) + return mask_memmap, score_memmap + + @staticmethod + def calc_mpp(area_dims: IntPair, base_mpp: float, fixed_size: int = 1500) -> float: + """Calculates the microns per pixel for a fixed area of an image. + + Args: + area_dims (tuple): + Dimensions of the area to be scaled. + base_mpp (float): + Microns per pixel of the base image. + fixed_size (int): + Fixed size of the area. + + Returns: + float: + Microns per pixel required to scale the area to a fixed size. + """ + scale = max(area_dims) / fixed_size if max(area_dims) > fixed_size else 1.0 + return base_mpp * scale diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py index a81098acd..ee58936b3 100644 --- a/tiatoolbox/utils/misc.py +++ b/tiatoolbox/utils/misc.py @@ -1204,6 +1204,8 @@ def add_from_dat( store.append_many(anns) +def dict_to_store_semantic_segmentor( + patch_output: dict | zarr.group, def process_contours( contours: list[np.ndarray], hierarchy: np.ndarray, @@ -1339,6 +1341,8 @@ def dict_to_store_semantic_segmentor( layer_list = np.delete(layer_list, np.where(layer_list == 0)) + count = 1 + store = SQLiteStore() _ = class_dict # use it once overlay is working diff --git a/tiatoolbox/visualization/bokeh_app/main.py b/tiatoolbox/visualization/bokeh_app/main.py index 5df7acb04..39ff4112d 100644 --- a/tiatoolbox/visualization/bokeh_app/main.py +++ b/tiatoolbox/visualization/bokeh_app/main.py @@ -65,9 +65,8 @@ # GitHub actions seems unable to find TIAToolbox unless this is here sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent)) from tiatoolbox import logger -from tiatoolbox.models.engine.nucleus_instance_segmentor import ( - NucleusInstanceSegmentor, -) +from tiatoolbox.models.engine.nucleus_instance_segmentor import NucleusInstanceSegmentor +from tiatoolbox.models.engine.prompt_segmentor import PromptSegmentor from tiatoolbox.tools.pyramid import ZoomifyGenerator from tiatoolbox.utils.misc import select_device from tiatoolbox.utils.visualization import random_colors @@ -1118,6 +1117,8 @@ def to_model_cb(attr: ButtonClick) -> None: # noqa: ARG001 """Callback to run currently selected model.""" if UI["vstate"].current_model == "hovernet": segment_on_box() + elif UI["vstate"].current_model == "SAM": + sam_segment() # Add any other models here else: # pragma: no cover logger.warning("unknown model") @@ -1273,6 +1274,101 @@ def segment_on_box() -> None: rmtree(tmp_mask_dir) +def sam_segment() -> None: + """Callback to run SAM using a point on the slide. 
+ + Will run GeneralSegmentor on selected region of wsi defined + by the point in pt_source. + + """ + # Get point coordinates + x = np.round(UI["pt_source"].data["x"]) + y = np.round(UI["pt_source"].data["y"]) + point_coords = ( + np.array([[[x[i], -y[i]] for i in range(len(x))]], np.uint32) + if len(x) > 0 + else None + ) + + # Get box coordinates + x = np.round(UI["box_source"].data["x"]) + y = np.round(UI["box_source"].data["y"]) + height = np.round(UI["box_source"].data["height"]) + width = np.round(UI["box_source"].data["width"]) + box_coords = ( + np.array( + [[[x[i], -y[i], x[i] + width[i], height[i] - y[i]] for i in range(len(x))]], + np.uint32, + ) + if len(x) > 0 + else None + ) + + prompt_segmentor = PromptSegmentor() + tmp_save_dir = Path(tempfile.mkdtemp()) + tmp_mask_dir = Path(tempfile.mkdtemp()) + + x_start = max(0, UI["p"].x_range.start) + y_start = max(0, -UI["p"].y_range.end) + x_end = min(UI["p"].x_range.end, UI["vstate"].dims[0]) + y_end = min(-UI["p"].y_range.start, UI["vstate"].dims[1]) + + height = y_end - y_start + width = x_end - x_start + res = prompt_segmentor.calc_mpp((width, height), UI["vstate"].mpp[0], 1500) + + # Make a mask defining the box + thumb = UI["vstate"].wsi.slide_thumbnail() + conv_mpp = UI["vstate"].dims[0] / thumb.shape[1] + x = round(x_start / conv_mpp) + y = round(y_start / conv_mpp) + width = round((x_end - x_start) / conv_mpp) + height = round((y_end - y_start) / conv_mpp) + + mask = np.zeros((thumb.shape[0], thumb.shape[1]), dtype=np.uint8) + mask[y : y + height, x : x + width] = 1 + + Image.fromarray(mask).save(tmp_mask_dir / "mask.png") + # ! Mask is currently causing issues. Tool works fine without it, + # ! but reduction in segmentation quality for larger WSIs. + + # Run SAM on the point + prediction = prompt_segmentor.predict( + imgs=[UI["vstate"].slide_path], + masks=[tmp_mask_dir / "mask.png"], + device=select_device(on_gpu=torch.cuda.is_available()), + save_dir=tmp_save_dir / "sam_out", + point_coords=point_coords, + box_coords=box_coords, + mode="wsi", + patch_input_shape=(1024, 1024), + patch_output_shape=(1024, 1024), + resolution=res, + units="mpp", + multi_prompt=True, + ) + + ann_loc = f"{prediction[0][1]}.0.db" + + slide_filename = UI["vstate"].slide_path.stem + ".db" + destination = doc_config["overlay_folder"] / slide_filename + + # Move the database file + # ! Need to check if this is necessary + move(ann_loc, destination) + + fname = make_safe_name(destination) + resp = UI["s"].put( + f"http://{host2}:{port}/tileserver/overlay", + data={"overlay_path": fname}, + ) + ann_types = json.loads(resp.text) + update_ui_on_new_annotations(ann_types) + + # Clean up temp files + rmtree(tmp_save_dir) + + # endregion # Set up main window @@ -1501,7 +1597,7 @@ def gather_ui_elements( # noqa: PLR0915 ) model_drop = Select( title="choose model:", - options=["hovernet"], + options=["hovernet", "SAM"], height=25, width=120, max_width=120,
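
Example usage of the prompt-segmentation API added by this diff. This is a
sketch only: the tile path and output directory are placeholders, the prompt
coordinates are illustrative, and the I/O configuration is copied from the
tile-mode unit test above rather than being a recommended setting.

    import numpy as np

    from tiatoolbox.models import SAM, PromptSegmentor
    from tiatoolbox.models.engine.semantic_segmentor import IOSegmentorConfig

    # Engine wrapping the zero-shot SAM backbone; the default weights
    # ("facebook/sam-vit-huge") are fetched from the Hugging Face hub.
    segmentor = PromptSegmentor(model=SAM(device="cpu"), batch_size=1)

    ioconfig = IOSegmentorConfig(
        input_resolutions=[{"units": "mpp", "resolution": 4.0}],
        output_resolutions=[{"units": "mpp", "resolution": 4.0}],
        patch_input_shape=[512, 512],
        patch_output_shape=[512, 512],
        stride_shape=[512, 512],
    )

    # One point prompt ([x, y]) and one box prompt ([x1, y1, x2, y2]) per image.
    outputs = segmentor.predict(
        ["mini_tile.jpg"],  # placeholder path to an RGB tile
        mode="tile",
        multi_prompt=True,
        point_coords=np.array([[[64, 64]]]),
        box_coords=np.array([[[10, 10, 120, 120]]]),
        ioconfig=ioconfig,
        save_dir="sam_out",  # placeholder output directory
    )
    # Each entry pairs an input path with its prediction path; tile mode
    # writes raw .npy masks, wsi mode writes .db annotation stores.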