1 change: 1 addition & 0 deletions requirements/requirements.txt
@@ -28,6 +28,7 @@ requests>=2.28.1
scikit-image>=0.20
scikit-learn>=1.2.0
scipy>=1.8
segmentation-models-pytorch>=0.5.0
shapely>=2.0.0
SimpleITK>=2.2.1
sphinx>=5.3.0
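The new dependency provides the U-Net++ backbone used by the GrandQC model added below; a minimal import check, assuming only the package's import name segmentation_models_pytorch (which the new module also uses):

    import segmentation_models_pytorch as smp

    # Same constructor call as in tiatoolbox/models/architecture/grandqc.py;
    # encoder_weights=None skips downloading ImageNet encoder weights.
    net = smp.UnetPlusPlus(
        encoder_name="timm-efficientnet-b0",
        encoder_weights=None,
        in_channels=3,
        classes=2,
    )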
54 changes: 54 additions & 0 deletions tests/models/test_arch_grandqc.py
@@ -0,0 +1,54 @@
"""Unit test package for GrandQC Tissue Model."""

from collections.abc import Callable
from pathlib import Path

import numpy as np
import torch

from tiatoolbox.models.architecture import (
fetch_pretrained_weights,
get_pretrained_model,
)
from tiatoolbox.models.architecture.grandqc import TissueDetectionModel
from tiatoolbox.models.engine.io_config import IOSegmentorConfig
from tiatoolbox.utils.misc import select_device
from tiatoolbox.wsicore.wsireader import WSIReader

ON_GPU = False


def test_functional_grandqc(remote_sample: Callable) -> None:
"""Test for GrandQC model."""
# test fetch pretrained weights
pretrained_weights = fetch_pretrained_weights("grandqc_tissue_detection_mpp10")
assert pretrained_weights is not None

# test creation
model = TissueDetectionModel(num_input_channels=3, num_output_channels=2)
assert model is not None

# load pretrained weights
pretrained = torch.load(pretrained_weights, map_location="cpu")
model.load_state_dict(pretrained)

# test get pretrained model
model, ioconfig = get_pretrained_model("grandqc_tissue_detection_mpp10")
assert isinstance(model, TissueDetectionModel)
assert isinstance(ioconfig, IOSegmentorConfig)
assert model.num_input_channels == 3
assert model.num_output_channels == 2

# test inference
mini_wsi_svs = Path(remote_sample("wsi2_4k_4k_svs"))
reader = WSIReader.open(mini_wsi_svs)
read_kwargs = {"resolution": 10.0, "units": "mpp", "coord_space": "resolution"}
batch = np.array(
[
reader.read_bounds((0, 0, 512, 512), **read_kwargs),
reader.read_bounds((512, 512, 1024, 1024), **read_kwargs),
],
)
batch = torch.from_numpy(batch)
output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU))
assert output.shape == (2, 512, 512, 2)
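For illustration only (not part of the test file), the probability map asserted above could be reduced to per-patch tissue masks with the model's own postproc; this sketch reuses the `model`, `batch`, and `output` names from the test and assumes postproc is applied patch by patch:

    # `output` is the (2, 512, 512, 2) softmax map returned by infer_batch above.
    masks = np.stack([model.postproc(patch) for patch in output])
    assert masks.shape == (2, 512, 512)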
31 changes: 25 additions & 6 deletions tiatoolbox/data/pretrained_model.yaml
@@ -815,7 +815,7 @@ mapde-crchisto:
threshold_abs: 250
num_classes: 1
ioconfig:
class: semantic_segmentor.IOSegmentorConfig
class: io_config.IOSegmentorConfig
kwargs:
input_resolutions:
- { "units": "mpp", "resolution": 0.5 }
@@ -837,7 +837,7 @@ mapde-conic:
threshold_abs: 205
num_classes: 1
ioconfig:
class: semantic_segmentor.IOSegmentorConfig
class: io_config.IOSegmentorConfig
kwargs:
input_resolutions:
- { "units": "mpp", "resolution": 0.5 }
@@ -860,7 +860,7 @@ sccnn-crchisto:
threshold_abs: 0.20
patch_output_shape: [ 13, 13 ]
ioconfig:
class: semantic_segmentor.IOSegmentorConfig
class: io_config.IOSegmentorConfig
kwargs:
input_resolutions:
- { "units": "mpp", "resolution": 0.5 }
@@ -883,7 +883,7 @@ sccnn-conic:
threshold_abs: 0.05
patch_output_shape: [ 13, 13 ]
ioconfig:
class: semantic_segmentor.IOSegmentorConfig
class: io_config.IOSegmentorConfig
kwargs:
input_resolutions:
- { "units": "mpp", "resolution": 0.5 }
@@ -903,7 +903,7 @@ nuclick_original-pannuke:
num_input_channels: 5
num_output_channels: 1
ioconfig:
class: semantic_segmentor.IOSegmentorConfig
class: io_config.IOSegmentorConfig
kwargs:
input_resolutions:
- {'units': 'baseline', 'resolution': 0.25}
@@ -925,7 +925,7 @@ nuclick_light-pannuke:
decoder_block: [3,3]
skip_type: "add"
ioconfig:
class: semantic_segmentor.IOSegmentorConfig
class: io_config.IOSegmentorConfig
kwargs:
input_resolutions:
- {'units': 'baseline', 'resolution': 0.25}
@@ -934,3 +934,22 @@ nuclick_light-pannuke:
patch_input_shape: [128, 128]
patch_output_shape: [128, 128]
save_resolution: {'units': 'baseline', 'resolution': 1.0}

grandqc_tissue_detection_mpp10:
hf_repo_id: TIACentre/GrandQC_Tissue_Detection
architecture:
class: grandqc.TissueDetectionModel
kwargs:
num_input_channels: 3
num_output_channels: 2
ioconfig:
class: io_config.IOSegmentorConfig
kwargs:
input_resolutions:
- {'units': 'mpp', 'resolution': 10.0}
output_resolutions:
- {'units': 'mpp', 'resolution': 10.0}
patch_input_shape: [512, 512]
patch_output_shape: [512, 512]
stride_shape: [256, 256]
save_resolution: {'units': 'mpp', 'resolution': 10.0}
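A minimal sketch of how this registry entry is consumed, assuming the `get_pretrained_model` helper exercised in the new test and that the `IOSegmentorConfig` keyword arguments above are exposed as attributes of the returned config:

    from tiatoolbox.models.architecture import get_pretrained_model

    # Builds grandqc.TissueDetectionModel with the kwargs above, fetches the
    # weights from hf_repo_id, and parses the ioconfig block.
    model, ioconfig = get_pretrained_model("grandqc_tissue_detection_mpp10")
    print(ioconfig.patch_input_shape)   # [512, 512]
    print(ioconfig.stride_shape)        # [256, 256]
    print(ioconfig.input_resolutions)   # [{'units': 'mpp', 'resolution': 10.0}]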
130 changes: 130 additions & 0 deletions tiatoolbox/models/architecture/grandqc.py
@@ -0,0 +1,130 @@
"""Define GrandQC Tissue Detection Model architecture."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from collections.abc import Mapping

import cv2
import numpy as np
import segmentation_models_pytorch as smp
import torch

from tiatoolbox.models.models_abc import ModelABC


class TissueDetectionModel(ModelABC):
"""GrandQC Tissue Detection Model.
Example:
>>> from tiatoolbox.models.engine.semantic_segmentor import SemanticSegmentor
>>> semantic_segmentor = SemanticSegmentor(
... model="grandqc_tissue_detection_mpp10",
... )
>>> results = semantic_segmentor.run(
... ["/example_wsi.svs"],
... masks=None,
... auto_get_mask=False,
... patch_mode=False,
... save_dir=Path("/tissue_mask/"),
... output_type="annotationstore",
... )
"""

def __init__(
self: TissueDetectionModel, num_input_channels: int, num_output_channels: int
) -> None:
"""Initialize TissueDetectionModel."""
super().__init__()
self.num_input_channels = num_input_channels
self.num_output_channels = num_output_channels
self._postproc = self.postproc
self._preproc = self.preproc
self.tissue_detection_model = smp.UnetPlusPlus(
encoder_name="timm-efficientnet-b0",
encoder_weights=None,
in_channels=self.num_input_channels,
classes=self.num_output_channels,
activation=None,
)

@staticmethod
def preproc(image: np.ndarray) -> np.ndarray:
"""Apply jpg compression then ImageNet normalise."""
encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 80]
_, compressed_image = cv2.imencode(".jpg", image, encode_param)
compressed_image = np.array(cv2.imdecode(compressed_image, 1))

mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
return (compressed_image / 255.0 - mean) / std

@staticmethod
def postproc(image: np.ndarray) -> np.ndarray:
"""Define post-processing for this model.
This returns the class index with the minimum probability.
In this model, this means selecting tissue class.
"""
return image.argmin(axis=-1)

def forward(
self: TissueDetectionModel,
imgs: torch.Tensor,
*args: tuple[Any, ...], # skipcq: PYL-W0613 # noqa: ARG002
**kwargs: dict, # skipcq: PYL-W0613 # noqa: ARG002
) -> torch.Tensor:
"""Forward function for model."""
return self.tissue_detection_model(imgs)

@staticmethod
def infer_batch(
model: torch.nn.Module,
batch_data: torch.Tensor,
*,
device: str,
) -> np.ndarray:
"""Run inference on an input batch.
This contains logic for forward operation as well as i/o
Args:
model (nn.Module):
PyTorch defined model.
batch_data (:class:`torch.Tensor`):
A batch of data generated by
`torch.utils.data.DataLoader`.
device (str):
Transfers model to the specified device. Default is "cpu".
Returns:
np.ndarray:
The inference results as a numpy array.
"""
model.eval()

        imgs = batch_data

imgs = imgs.to(device).type(torch.float32)
imgs = imgs.permute(0, 3, 1, 2) # to NCHW

with torch.inference_mode():
logits = model(imgs)
probs = torch.nn.functional.softmax(logits, 1)
probs = probs.permute(0, 2, 3, 1) # to NHWC

return probs.cpu().numpy()

def load_state_dict(
self: TissueDetectionModel,
state_dict: Mapping[str, Any],
**kwargs: bool,
) -> torch.nn.modules.module._IncompatibleKeys:
"""Load state dict for the TissueDetectionModel."""
return self.tissue_detection_model.load_state_dict(state_dict, **kwargs)
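As an end-to-end illustration of the pieces defined above, the following sketch (not part of the module) wires preproc, infer_batch, and postproc together on a single RGB patch; the weight tag and helpers are the ones exercised in the new test, and the random patch is only a stand-in for a real 10 mpp tile:

    import numpy as np
    import torch

    from tiatoolbox.models.architecture import fetch_pretrained_weights
    from tiatoolbox.models.architecture.grandqc import TissueDetectionModel

    model = TissueDetectionModel(num_input_channels=3, num_output_channels=2)
    weights = fetch_pretrained_weights("grandqc_tissue_detection_mpp10")
    model.load_state_dict(torch.load(weights, map_location="cpu"))

    patch = np.random.randint(0, 255, (512, 512, 3), dtype=np.uint8)  # stand-in tile
    norm = TissueDetectionModel.preproc(patch)       # JPEG round-trip + ImageNet normalisation
    batch = torch.from_numpy(norm[np.newaxis, ...])  # NHWC batch of one
    probs = TissueDetectionModel.infer_batch(model, batch, device="cpu")  # (1, 512, 512, 2)
    mask = TissueDetectionModel.postproc(probs[0])   # (512, 512) class-index map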