🆕 Add GrandQC Tissue Segmentation Model #965
Open
Jiaqi-Lv wants to merge 20 commits into dev-define-engines-abc from dev-add-grandQC
+1,723 −6
Commits (20)
b18b98f add grandqc tissue model (Jiaqi-Lv)
899d6cb add example (Jiaqi-Lv)
8a7295d fix tests (Jiaqi-Lv)
5c5bfc4 fix error (Jiaqi-Lv)
fd692da update docstring (Jiaqi-Lv)
d82cc3d improve test coverage (Jiaqi-Lv)
93a24a1 add unet++ model (Jiaqi-Lv)
2d076c0 Merge branch 'dev-define-engines-abc' into dev-add-grandQC (shaneahmed)
283b888 Merge branch 'dev-add-grandQC' of https://github.com/TissueImageAnaly… (Jiaqi-Lv)
94c43ee [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
98cef83 remove smp dependency (Jiaqi-Lv)
d47fa0a refactor code (Jiaqi-Lv)
d2a66ca add tests (Jiaqi-Lv)
19cca90 address comments (Jiaqi-Lv)
1895e38 :memo: Update docstring for grandqc.py and timm_efficientnet.py (shaneahmed)
3ade99a :bug: Fix docstring (shaneahmed)
5f0202f :memo: Remove duplicate docstring for classses. (shaneahmed)
6b8eb90 address comments (Jiaqi-Lv)
2ce379f update test (Jiaqi-Lv)
9c62b72 :white_check_mark: Add test to improve coverage (shaneahmed)
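
The new weights are registered under the pretrained key "grandqc_tissue_detection" and are driven through the SemanticSegmentor engine in the tests below. As a rough usage sketch grounded in those tests (the image path and output directory are placeholders, not part of this PR):

```python
from tiatoolbox.models.engine.semantic_segmentor import SemanticSegmentor

# Minimal sketch of tissue-mask generation with the new GrandQC weights.
# Mirrors test_grandqc_with_semantic_segmentor below; "sample.svs" and
# "grandqc_outputs" are placeholder paths.
segmentor = SemanticSegmentor(model="grandqc_tissue_detection")
output = segmentor.run(
    images=["sample.svs"],
    patch_mode=False,                # treat inputs as whole-slide images
    output_type="annotationstore",   # save masks into an SQLite annotation store
    save_dir="grandqc_outputs",
    overwrite=True,
    device="cpu",                    # or "cuda" when a GPU is available
)
# `output` maps each input image to the path of its saved annotation store.
```
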
@@ -0,0 +1,173 @@
"""Unit test package for GrandQC Tissue Model."""

from collections.abc import Callable
from pathlib import Path

import numpy as np
import pytest
import torch
from torch import nn

from tiatoolbox.annotation.storage import SQLiteStore
from tiatoolbox.models.architecture import (
    fetch_pretrained_weights,
    get_pretrained_model,
)
from tiatoolbox.models.architecture.grandqc import (
    CenterBlock,
    GrandQCModel,
    SegmentationHead,
    UnetPlusPlusDecoder,
)
from tiatoolbox.models.engine.io_config import IOSegmentorConfig
from tiatoolbox.models.engine.semantic_segmentor import SemanticSegmentor
from tiatoolbox.utils import env_detection as toolbox_env
from tiatoolbox.wsicore.wsireader import VirtualWSIReader

device = "cuda" if toolbox_env.has_gpu() else "cpu"


def test_functional_grandqc() -> None:
    """Test for GrandQC model."""
    # test fetch pretrained weights
    pretrained_weights = fetch_pretrained_weights("grandqc_tissue_detection")
    assert pretrained_weights is not None

    # test creation
    model = GrandQCModel(num_output_channels=2)
    assert model is not None

    # load pretrained weights
    pretrained = torch.load(pretrained_weights, map_location=device)
    model.load_state_dict(pretrained)

    # test get pretrained model
    model, ioconfig = get_pretrained_model("grandqc_tissue_detection")
    assert isinstance(model, GrandQCModel)
    assert isinstance(ioconfig, IOSegmentorConfig)
    assert model.num_output_channels == 2
    assert model.decoder_channels == (256, 128, 64, 32, 16)

    # test inference
    generator = np.random.default_rng(1337)
    test_image = generator.integers(0, 256, size=(2048, 2048, 3), dtype=np.uint8)
    reader = VirtualWSIReader.open(test_image)
    read_kwargs = {"resolution": 0, "units": "level", "coord_space": "resolution"}
    batch = np.array(
        [
            reader.read_bounds((0, 0, 512, 512), **read_kwargs),
            reader.read_bounds((512, 512, 1024, 1024), **read_kwargs),
        ],
    )
    batch = torch.from_numpy(batch)
    output = model.infer_batch(model, batch, device=device)
    assert output.shape == (2, 512, 512, 2)


def test_grandqc_preproc_postproc() -> None:
    """Test GrandQC preproc and postproc functions."""
    model = GrandQCModel(num_output_channels=2)

    generator = np.random.default_rng(1337)
    # test preproc
    dummy_image = generator.integers(0, 256, size=(512, 512, 3), dtype=np.uint8)
    preproc_image = model.preproc(dummy_image)
    assert preproc_image.shape == dummy_image.shape
    assert preproc_image.dtype == np.float64

    # test postproc
    dummy_output = generator.random(size=(512, 512, 2), dtype=np.float32)
    postproc_image = model.postproc(dummy_output)
    assert postproc_image.shape == (512, 512)
    assert postproc_image.dtype == np.int64


def test_grandqc_with_semantic_segmentor(
    remote_sample: Callable, track_tmp_path: Path
) -> None:
    """Test GrandQC tissue mask generation."""
    segmentor = SemanticSegmentor(model="grandqc_tissue_detection")

    sample_image = remote_sample("svs-1-small")
    inputs = [str(sample_image)]

    output = segmentor.run(
        images=inputs,
        device=device,
        patch_mode=False,
        output_type="annotationstore",
        save_dir=track_tmp_path / "grandqc_test_outputs",
        overwrite=True,
    )

    assert len(output) == 1
    assert Path(output[sample_image]).exists()

    store = SQLiteStore.open(output[sample_image])
    assert len(store) == 3

    tissue_area_px = 0.0
    for annotation in store.values():
        assert annotation.properties["type"] == "mask"
        tissue_area_px += annotation.geometry.area
    assert 2999000 < tissue_area_px < 3004000

    store.close()


def test_segmentation_head_behaviour() -> None:
    """Verify SegmentationHead defaults and upsampling."""
    head = SegmentationHead(3, 5, activation=None, upsampling=1)
    assert isinstance(head[1], nn.Identity)
    assert isinstance(head[2], nn.Identity)

    x = torch.randn(1, 3, 6, 8)
    out = head(x)
    assert out.shape == (1, 5, 6, 8)

    head = SegmentationHead(3, 2, activation=nn.Sigmoid(), upsampling=2)
    x = torch.ones(1, 3, 4, 4)
    out = head(x)
    assert out.shape == (1, 2, 8, 8)
    assert torch.all(out >= 0)
    assert torch.all(out <= 1)


def test_unetplusplus_decoder_forward_shapes() -> None:
    """Ensure UnetPlusPlusDecoder handles dense connections."""
    decoder = UnetPlusPlusDecoder(
        encoder_channels=[1, 2, 4, 8],
        decoder_channels=[8, 4, 2],
        n_blocks=3,
    )

    features = [
        torch.randn(1, 1, 32, 32),
        torch.randn(1, 2, 16, 16),
        torch.randn(1, 4, 8, 8),
        torch.randn(1, 8, 4, 4),
    ]

    output = decoder(features)
    assert output.shape == (1, 2, 32, 32)


def test_center_block_behavior() -> None:
    """Test CenterBlock behavior in UnetPlusPlusDecoder."""
    center_block = CenterBlock(in_channels=8, out_channels=8)

    x = torch.randn(1, 8, 4, 4)
    out = center_block(x)
    assert out.shape == (1, 8, 4, 4)


def test_unetpp_raises_value_error() -> None:
    """Test UnetPlusPlusDecoder raises ValueError."""
    with pytest.raises(
        ValueError, match=r".*depth is 4, but you provide `decoder_channels` for 3.*"
    ):
        _ = UnetPlusPlusDecoder(
            encoder_channels=[1, 2, 4, 8],
            decoder_channels=[8, 4, 2],
            n_blocks=4,
        )
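
For readers who want to call the model without the segmentation engine, the functional test above suggests the following engine-free sketch (random data stands in for a real 512x512 RGB patch, and applying `postproc` per image to the batched output is an assumption based on the shape checks above):

```python
import numpy as np
import torch

from tiatoolbox.models.architecture import get_pretrained_model

# Sketch of direct inference, following test_functional_grandqc above.
model, ioconfig = get_pretrained_model("grandqc_tissue_detection")
rng = np.random.default_rng(0)
patch = rng.integers(0, 256, size=(512, 512, 3), dtype=np.uint8)  # fake RGB tile
batch = torch.from_numpy(patch[np.newaxis, ...])        # NHWC uint8 batch
probs = model.infer_batch(model, batch, device="cpu")   # -> (1, 512, 512, 2)
mask = model.postproc(probs[0])                         # -> (512, 512) int64 labels
```
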
@@ -0,0 +1,177 @@
"""Unit tests for timm EfficientNet encoder helpers."""

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from collections.abc import Sequence

import pytest
import torch
from torch import nn

from tiatoolbox.models.architecture import timm_efficientnet as effnet_mod
from tiatoolbox.models.architecture.timm_efficientnet import (
    DEFAULT_IN_CHANNELS,
    EfficientNetEncoder,
    EncoderMixin,
    replace_strides_with_dilation,
)


class DummyEncoder(nn.Module, EncoderMixin):
    """Lightweight encoder for testing mixin behavior."""

    def __init__(self) -> None:
        """Initialize EncoderMixin for testing."""
        nn.Module.__init__(self)
        EncoderMixin.__init__(self)
        self.conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)
        self.conv32 = nn.Conv2d(4, 4, 3)
        self._out_channels = [DEFAULT_IN_CHANNELS, 4, 8]
        self._depth = 2

    def get_stages(self) -> dict[int, Sequence[torch.nn.Module]]:
        """Get stages for dilation modification.

        Returns:
            Dictionary with keys as output stride and values as list of modules.
        """
        return {16: [self.conv], 32: [self.conv32]}


def test_patch_first_conv() -> None:
    """patch_first_conv should reduce or expand correctly."""
    # create simple conv
    model = nn.Sequential(nn.Conv2d(3, 2, kernel_size=1, bias=False))
    conv = model[0]

    # collapsing 3 channels into 1
    effnet_mod.patch_first_conv(model, new_in_channels=1, pretrained=True)
    assert conv.in_channels == 1

    # expanding to 5 channels
    model = nn.Sequential(nn.Conv2d(3, 2, kernel_size=1, bias=False))
    conv = model[0]

    effnet_mod.patch_first_conv(model, new_in_channels=5, pretrained=True)
    assert conv.in_channels == 5


def test_patch_first_conv_reset_weights_when_not_pretrained() -> None:
    """Ensure random reinit happens when pretrained flag is False."""
    # start from known weights
    model = nn.Sequential(nn.Conv2d(3, 1, kernel_size=1, bias=False))
    original = model[0].weight.clone()
    # changing channel count without pretrained should reinit parameters
    effnet_mod.patch_first_conv(model, new_in_channels=4, pretrained=False)
    assert model[0].in_channels == 4
    assert model[0].weight.shape[1] == 4
    # Almost surely changed due to reset_parameters
    assert not torch.equal(original, model[0].weight[:1, :3])


def test_patch_first_conv_no_matching_layer_is_safe() -> None:
    """The function should silently exit when no suitable conv exists."""
    model = nn.Sequential(nn.Conv2d(5, 1, kernel_size=1))
    original = model[0].weight.clone()
    # no conv with default channel count, so weights stay unchanged
    effnet_mod.patch_first_conv(model, new_in_channels=3, pretrained=True)
    assert torch.equal(original, model[0].weight)


def test_replace_strides_with_dilation_applies_to_nested_convs() -> None:
    """Strides become dilation and static padding gets removed."""
    module = nn.Sequential(
        nn.Conv2d(1, 1, kernel_size=3, stride=2, padding=1),
    )
    # attach static_padding to mirror EfficientNet convs
    module[0].static_padding = nn.Conv2d(1, 1, 1)

    # applying dilation should also strip static padding
    replace_strides_with_dilation(module, dilation_rate=3)
    conv = module[0]
    assert conv.stride == (1, 1)
    assert conv.dilation == (3, 3)
    assert conv.padding == (3, 3)
    assert isinstance(conv.static_padding, nn.Identity)


def test_encoder_mixin_properties_and_set_in_channels() -> None:
    """EncoderMixin should expose out_channels/output_stride and patch convs."""
    # use dummy encoder to check property logic
    encoder = DummyEncoder()
    assert encoder.out_channels == [3, 4, 8]
    # adjust internals to check min logic in output_stride
    encoder._output_stride = 4
    encoder._depth = 3
    assert encoder.output_stride == 4  # min(output_stride, 2**depth)

    # calling set_in_channels should patch first conv and update bookkeeping
    encoder.set_in_channels(5, pretrained=False)
    assert encoder._in_channels == 5
    assert encoder.out_channels[0] == 5
    assert encoder.conv.in_channels == 5


def test_encoder_mixin_make_dilated_and_validation() -> None:
    """make_dilated should error on invalid stride and patch convs otherwise."""
    encoder = DummyEncoder()

    # invalid stride raises
    with pytest.raises(ValueError, match="Output stride should be 16 or 8"):
        encoder.make_dilated(output_stride=4)

    # valid stride should touch both stage groups
    encoder.make_dilated(output_stride=8)
    conv16, conv32 = encoder.get_stages()[16][0], encoder.get_stages()[32][0]
    assert conv16.stride == (1, 1)
    assert conv16.dilation == (2, 2)
    assert conv32.stride == (1, 1)
    assert conv32.dilation == (4, 4)


def test_get_efficientnet_kwargs_shapes_and_values() -> None:
    """get_efficientnet_kwargs should produce expected keys and scaling."""
    # confirm output contains decoded blocks and scaled channels
    kwargs = effnet_mod.get_efficientnet_kwargs(
        channel_multiplier=1.2, depth_multiplier=1.4, drop_rate=0.3
    )
    assert kwargs.get("block_args")
    assert kwargs["num_features"] == effnet_mod.round_channels(1280, 1.2, 8, None)
    assert kwargs["drop_rate"] == 0.3


def test_efficientnet_encoder_depth_validation_and_forward() -> None:
    """EfficientNetEncoder should validate depth and run forward returning features."""
    # invalid depth should fail fast
    with pytest.raises(
        ValueError, match=r"EfficientNetEncoder depth should be in range\s+\[1, 5\]"
    ):
        EfficientNetEncoder(
            stage_idxs=[2, 3, 5],
            out_channels=[3, 32, 24, 40, 112, 320],
            depth=6,
        )

    # build shallow encoder and run a forward pass
    encoder = EfficientNetEncoder(
        stage_idxs=[2, 3, 5],
        out_channels=[3, 32, 24, 40, 112, 320],
        depth=3,
        channel_multiplier=0.5,
        depth_multiplier=0.5,
    )
    x = torch.randn(1, 3, 32, 32)
    features = encoder(x)
    assert len(features) == encoder._depth + 1
    assert torch.equal(features[0], x)

    # ensure classifier keys are dropped before loading into the model
    extended_state = dict(encoder.state_dict())
    extended_state["classifier.bias"] = torch.tensor([1.0])
    extended_state["classifier.weight"] = torch.tensor([[1.0]])
    load_result = encoder.load_state_dict(extended_state, strict=True)
    assert not load_result.missing_keys
    assert not load_result.unexpected_keys
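
The encoder exercised above can also be instantiated on its own. The sketch below simply reuses the constructor arguments from the forward-pass test; in the PR itself it is presumably GrandQCModel that builds this encoder and feeds its feature maps to the UNet++ decoder.

```python
import torch

from tiatoolbox.models.architecture.timm_efficientnet import EfficientNetEncoder

# Standalone sketch of the EfficientNet encoder using the same arguments as
# test_efficientnet_encoder_depth_validation_and_forward above.
encoder = EfficientNetEncoder(
    stage_idxs=[2, 3, 5],
    out_channels=[3, 32, 24, 40, 112, 320],
    depth=3,
    channel_multiplier=0.5,
    depth_multiplier=0.5,
)
x = torch.randn(1, 3, 32, 32)
features = encoder(x)  # list of depth + 1 tensors; features[0] is the input itself
print([tuple(f.shape) for f in features])
```
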