Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 173 additions & 0 deletions tests/models/test_arch_grandqc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
"""Unit test package for GrandQC Tissue Model."""

from collections.abc import Callable
from pathlib import Path

import numpy as np
import pytest
import torch
from torch import nn

from tiatoolbox.annotation.storage import SQLiteStore
from tiatoolbox.models.architecture import (
fetch_pretrained_weights,
get_pretrained_model,
)
from tiatoolbox.models.architecture.grandqc import (
CenterBlock,
GrandQCModel,
SegmentationHead,
UnetPlusPlusDecoder,
)
from tiatoolbox.models.engine.io_config import IOSegmentorConfig
from tiatoolbox.models.engine.semantic_segmentor import SemanticSegmentor
from tiatoolbox.utils import env_detection as toolbox_env
from tiatoolbox.wsicore.wsireader import VirtualWSIReader

device = "cuda" if toolbox_env.has_gpu() else "cpu"


def test_functional_grandqc() -> None:
"""Test for GrandQC model."""
# test fetch pretrained weights
pretrained_weights = fetch_pretrained_weights("grandqc_tissue_detection")
assert pretrained_weights is not None

# test creation
model = GrandQCModel(num_output_channels=2)
assert model is not None

# load pretrained weights
pretrained = torch.load(pretrained_weights, map_location=device)
model.load_state_dict(pretrained)

# test get pretrained model
model, ioconfig = get_pretrained_model("grandqc_tissue_detection")
assert isinstance(model, GrandQCModel)
assert isinstance(ioconfig, IOSegmentorConfig)
assert model.num_output_channels == 2
assert model.decoder_channels == (256, 128, 64, 32, 16)

# test inference
generator = np.random.default_rng(1337)
test_image = generator.integers(0, 256, size=(2048, 2048, 3), dtype=np.uint8)
reader = VirtualWSIReader.open(test_image)
read_kwargs = {"resolution": 0, "units": "level", "coord_space": "resolution"}
batch = np.array(
[
reader.read_bounds((0, 0, 512, 512), **read_kwargs),
reader.read_bounds((512, 512, 1024, 1024), **read_kwargs),
],
)
batch = torch.from_numpy(batch)
output = model.infer_batch(model, batch, device=device)
assert output.shape == (2, 512, 512, 2)


def test_grandqc_preproc_postproc() -> None:
"""Test GrandQC preproc and postproc functions."""
model = GrandQCModel(num_output_channels=2)

generator = np.random.default_rng(1337)
# test preproc
dummy_image = generator.integers(0, 256, size=(512, 512, 3), dtype=np.uint8)
preproc_image = model.preproc(dummy_image)
assert preproc_image.shape == dummy_image.shape
assert preproc_image.dtype == np.float64

# test postproc
dummy_output = generator.random(size=(512, 512, 2), dtype=np.float32)
postproc_image = model.postproc(dummy_output)
assert postproc_image.shape == (512, 512)
assert postproc_image.dtype == np.int64


def test_grandqc_with_semantic_segmentor(
remote_sample: Callable, track_tmp_path: Path
) -> None:
"""Test GrandQC tissue mask generation."""
segmentor = SemanticSegmentor(model="grandqc_tissue_detection")

sample_image = remote_sample("svs-1-small")
inputs = [str(sample_image)]

output = segmentor.run(
images=inputs,
device=device,
patch_mode=False,
output_type="annotationstore",
save_dir=track_tmp_path / "grandqc_test_outputs",
overwrite=True,
)

assert len(output) == 1
assert Path(output[sample_image]).exists()

store = SQLiteStore.open(output[sample_image])
assert len(store) == 3

tissue_area_px = 0.0
for annotation in store.values():
assert annotation.properties["type"] == "mask"
tissue_area_px += annotation.geometry.area
assert 2999000 < tissue_area_px < 3004000

store.close()


def test_segmentation_head_behaviour() -> None:
"""Verify SegmentationHead defaults and upsampling."""
head = SegmentationHead(3, 5, activation=None, upsampling=1)
assert isinstance(head[1], nn.Identity)
assert isinstance(head[2], nn.Identity)

x = torch.randn(1, 3, 6, 8)
out = head(x)
assert out.shape == (1, 5, 6, 8)

head = SegmentationHead(3, 2, activation=nn.Sigmoid(), upsampling=2)
x = torch.ones(1, 3, 4, 4)
out = head(x)
assert out.shape == (1, 2, 8, 8)
assert torch.all(out >= 0)
assert torch.all(out <= 1)


def test_unetplusplus_decoder_forward_shapes() -> None:
"""Ensure UnetPlusPlusDecoder handles dense connections."""
decoder = UnetPlusPlusDecoder(
encoder_channels=[1, 2, 4, 8],
decoder_channels=[8, 4, 2],
n_blocks=3,
)

features = [
torch.randn(1, 1, 32, 32),
torch.randn(1, 2, 16, 16),
torch.randn(1, 4, 8, 8),
torch.randn(1, 8, 4, 4),
]

output = decoder(features)
assert output.shape == (1, 2, 32, 32)


def test_center_block_behavior() -> None:
"""Test CenterBlock behavior in UnetPlusPlusDecoder."""
center_block = CenterBlock(in_channels=8, out_channels=8)

x = torch.randn(1, 8, 4, 4)
out = center_block(x)
assert out.shape == (1, 8, 4, 4)


def test_unetpp_raises_value_error() -> None:
"""Test UnetPlusPlusDecoder raises ValueError."""
with pytest.raises(
ValueError, match=r".*depth is 4, but you provide `decoder_channels` for 3.*"
):
_ = UnetPlusPlusDecoder(
encoder_channels=[1, 2, 4, 8],
decoder_channels=[8, 4, 2],
n_blocks=4,
)
177 changes: 177 additions & 0 deletions tests/models/test_arch_timm_efficientnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
"""Unit tests for timm EfficientNet encoder helpers."""

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from collections.abc import Sequence

import pytest
import torch
from torch import nn

from tiatoolbox.models.architecture import timm_efficientnet as effnet_mod
from tiatoolbox.models.architecture.timm_efficientnet import (
DEFAULT_IN_CHANNELS,
EfficientNetEncoder,
EncoderMixin,
replace_strides_with_dilation,
)


class DummyEncoder(nn.Module, EncoderMixin):
"""Lightweight encoder for testing mixin behavior."""

def __init__(self) -> None:
"""Initialize EncoderMixin for testing."""
nn.Module.__init__(self)
EncoderMixin.__init__(self)
self.conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)
self.conv32 = nn.Conv2d(4, 4, 3)
self._out_channels = [DEFAULT_IN_CHANNELS, 4, 8]
self._depth = 2

def get_stages(self) -> dict[int, Sequence[torch.nn.Module]]:
"""Get stages for dilation modification.

Returns:
Dictionary with keys as output stride and values as list of modules.
"""
return {16: [self.conv], 32: [self.conv32]}


def test_patch_first_conv() -> None:
"""patch_first_conv should reduce or expand correctly."""
# create simple conv
model = nn.Sequential(nn.Conv2d(3, 2, kernel_size=1, bias=False))
conv = model[0]

# collapsing 3 channels into 1
effnet_mod.patch_first_conv(model, new_in_channels=1, pretrained=True)
assert conv.in_channels == 1

# expanding to 5 channels
model = nn.Sequential(nn.Conv2d(3, 2, kernel_size=1, bias=False))
conv = model[0]

effnet_mod.patch_first_conv(model, new_in_channels=5, pretrained=True)
assert conv.in_channels == 5


def test_patch_first_conv_reset_weights_when_not_pretrained() -> None:
"""Ensure random reinit happens when pretrained flag is False."""
# start from known weights
model = nn.Sequential(nn.Conv2d(3, 1, kernel_size=1, bias=False))
original = model[0].weight.clone()
# changing channel count without pretrained should reinit parameters
effnet_mod.patch_first_conv(model, new_in_channels=4, pretrained=False)
assert model[0].in_channels == 4
assert model[0].weight.shape[1] == 4
# Almost surely changed due to reset_parameters
assert not torch.equal(original, model[0].weight[:1, :3])


def test_patch_first_conv_no_matching_layer_is_safe() -> None:
"""The function should silently exit when no suitable conv exists."""
model = nn.Sequential(nn.Conv2d(5, 1, kernel_size=1))
original = model[0].weight.clone()
# no conv with default channel count, so weights stay unchanged
effnet_mod.patch_first_conv(model, new_in_channels=3, pretrained=True)
assert torch.equal(original, model[0].weight)


def test_replace_strides_with_dilation_applies_to_nested_convs() -> None:
"""Strides become dilation and static padding gets removed."""
module = nn.Sequential(
nn.Conv2d(1, 1, kernel_size=3, stride=2, padding=1),
)
# attach static_padding to mirror EfficientNet convs
module[0].static_padding = nn.Conv2d(1, 1, 1)

# applying dilation should also strip static padding
replace_strides_with_dilation(module, dilation_rate=3)
conv = module[0]
assert conv.stride == (1, 1)
assert conv.dilation == (3, 3)
assert conv.padding == (3, 3)
assert isinstance(conv.static_padding, nn.Identity)


def test_encoder_mixin_properties_and_set_in_channels() -> None:
"""EncoderMixin should expose out_channels/output_stride and patch convs."""
# use dummy encoder to check property logic
encoder = DummyEncoder()
assert encoder.out_channels == [3, 4, 8]
# adjust internals to check min logic in output_stride
encoder._output_stride = 4
encoder._depth = 3
assert encoder.output_stride == 4 # min(output_stride, 2**depth)

# calling set_in_channels should patch first conv and update bookkeeping
encoder.set_in_channels(5, pretrained=False)
assert encoder._in_channels == 5
assert encoder.out_channels[0] == 5
assert encoder.conv.in_channels == 5


def test_encoder_mixin_make_dilated_and_validation() -> None:
"""make_dilated should error on invalid stride and patch convs otherwise."""
encoder = DummyEncoder()

# invalid stride raises
with pytest.raises(ValueError, match="Output stride should be 16 or 8"):
encoder.make_dilated(output_stride=4)

# valid stride should touch both stage groups
encoder.make_dilated(output_stride=8)
conv16, conv32 = encoder.get_stages()[16][0], encoder.get_stages()[32][0]
assert conv16.stride == (1, 1)
assert conv16.dilation == (2, 2)
assert conv32.stride == (1, 1)
assert conv32.dilation == (4, 4)


def test_get_efficientnet_kwargs_shapes_and_values() -> None:
"""get_efficientnet_kwargs should produce expected keys and scaling."""
# confirm output contains decoded blocks and scaled channels
kwargs = effnet_mod.get_efficientnet_kwargs(
channel_multiplier=1.2, depth_multiplier=1.4, drop_rate=0.3
)
assert kwargs.get("block_args")
assert kwargs["num_features"] == effnet_mod.round_channels(1280, 1.2, 8, None)
assert kwargs["drop_rate"] == 0.3


def test_efficientnet_encoder_depth_validation_and_forward() -> None:
"""EfficientNetEncoder should validate depth and run forward returning features."""
# invalid depth should fail fast
with pytest.raises(
ValueError, match=r"EfficientNetEncoder depth should be in range\s+\[1, 5\]"
):
EfficientNetEncoder(
stage_idxs=[2, 3, 5],
out_channels=[3, 32, 24, 40, 112, 320],
depth=6,
)

# build shallow encoder and run a forward pass
encoder = EfficientNetEncoder(
stage_idxs=[2, 3, 5],
out_channels=[3, 32, 24, 40, 112, 320],
depth=3,
channel_multiplier=0.5,
depth_multiplier=0.5,
)
x = torch.randn(1, 3, 32, 32)
features = encoder(x)
assert len(features) == encoder._depth + 1
assert torch.equal(features[0], x)

# ensure classifier keys are dropped before loading into the model
extended_state = dict(encoder.state_dict())
extended_state["classifier.bias"] = torch.tensor([1.0])
extended_state["classifier.weight"] = torch.tensor([[1.0]])
load_result = encoder.load_state_dict(extended_state, strict=True)
assert not load_result.missing_keys
assert not load_result.unexpected_keys
Loading