-
Notifications
You must be signed in to change notification settings - Fork 102
🆕 Add GrandQC Tissue Segmentation Model #965
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
Jiaqi-Lv
wants to merge
19
commits into
dev-define-engines-abc
Choose a base branch
from
dev-add-grandQC
base: dev-define-engines-abc
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+1,707
−6
Open
Changes from 4 commits
Commits
Show all changes
19 commits
Select commit
Hold shift + click to select a range
b18b98f
add grandqc tissue model
Jiaqi-Lv 899d6cb
add example
Jiaqi-Lv 8a7295d
fix tests
Jiaqi-Lv 5c5bfc4
fix error
Jiaqi-Lv fd692da
update docstring
Jiaqi-Lv d82cc3d
improve test coverage
Jiaqi-Lv 93a24a1
add unet++ model
Jiaqi-Lv 2d076c0
Merge branch 'dev-define-engines-abc' into dev-add-grandQC
shaneahmed 283b888
Merge branch 'dev-add-grandQC' of https://github.com/TissueImageAnaly…
Jiaqi-Lv 94c43ee
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 98cef83
remove smp dependency
Jiaqi-Lv d47fa0a
refactor code
Jiaqi-Lv d2a66ca
add tests
Jiaqi-Lv 19cca90
address comments
Jiaqi-Lv 1895e38
:memo: Update docstring for grandqc.py and timm_efficientnet.py
shaneahmed 3ade99a
:bug: Fix docstring
shaneahmed 5f0202f
:memo: Remove duplicate docstring for classses.
shaneahmed 6b8eb90
address comments
Jiaqi-Lv 2ce379f
update test
Jiaqi-Lv File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,54 @@ | ||
| """Unit test package for GrandQC Tissue Model.""" | ||
|
|
||
| from collections.abc import Callable | ||
| from pathlib import Path | ||
|
|
||
| import numpy as np | ||
| import torch | ||
|
|
||
| from tiatoolbox.models.architecture import ( | ||
| fetch_pretrained_weights, | ||
| get_pretrained_model, | ||
| ) | ||
| from tiatoolbox.models.architecture.grandqc import TissueDetectionModel | ||
| from tiatoolbox.models.engine.io_config import IOSegmentorConfig | ||
| from tiatoolbox.utils.misc import select_device | ||
| from tiatoolbox.wsicore.wsireader import WSIReader | ||
|
|
||
| ON_GPU = False | ||
|
|
||
|
|
||
| def test_functional_grandqc(remote_sample: Callable) -> None: | ||
| """Test for GrandQC model.""" | ||
| # test fetch pretrained weights | ||
| pretrained_weights = fetch_pretrained_weights("grandqc_tissue_detection_mpp10") | ||
| assert pretrained_weights is not None | ||
|
|
||
| # test creation | ||
| model = TissueDetectionModel(num_input_channels=3, num_output_channels=2) | ||
| assert model is not None | ||
|
|
||
| # load pretrained weights | ||
| pretrained = torch.load(pretrained_weights, map_location="cpu") | ||
| model.load_state_dict(pretrained) | ||
|
|
||
| # test get pretrained model | ||
| model, ioconfig = get_pretrained_model("grandqc_tissue_detection_mpp10") | ||
| assert isinstance(model, TissueDetectionModel) | ||
| assert isinstance(ioconfig, IOSegmentorConfig) | ||
| assert model.num_input_channels == 3 | ||
| assert model.num_output_channels == 2 | ||
|
|
||
| # test inference | ||
| mini_wsi_svs = Path(remote_sample("wsi2_4k_4k_svs")) | ||
| reader = WSIReader.open(mini_wsi_svs) | ||
| read_kwargs = {"resolution": 10.0, "units": "mpp", "coord_space": "resolution"} | ||
| batch = np.array( | ||
| [ | ||
| reader.read_bounds((0, 0, 512, 512), **read_kwargs), | ||
| reader.read_bounds((512, 512, 1024, 1024), **read_kwargs), | ||
| ], | ||
| ) | ||
| batch = torch.from_numpy(batch) | ||
| output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU)) | ||
| assert output.shape == (2, 512, 512, 2) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,130 @@ | ||
| """Define GrandQC Tissue Detection Model architecture.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from typing import TYPE_CHECKING, Any | ||
|
|
||
| if TYPE_CHECKING: | ||
| from collections.abc import Mapping | ||
|
|
||
| import cv2 | ||
| import numpy as np | ||
| import segmentation_models_pytorch as smp | ||
| import torch | ||
|
|
||
| from tiatoolbox.models.models_abc import ModelABC | ||
|
|
||
|
|
||
| class TissueDetectionModel(ModelABC): | ||
Jiaqi-Lv marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| """GrandQC Tissue Detection Model. | ||
| Example: | ||
| >>> from tiatoolbox.models.engine.semantic_segmentor import SemanticSegmentor | ||
| >>> semantic_segmentor = SemanticSegmentor( | ||
| ... model="grandqc_tissue_detection_mpp10", | ||
| ... ) | ||
| >>> results = semantic_segmentor.run( | ||
| ... ["/example_wsi.svs"], | ||
| ... masks=None, | ||
| ... auto_get_mask=False, | ||
| ... patch_mode=False, | ||
| ... save_dir=Path("/tissue_mask/"), | ||
| ... output_type="annotationstore", | ||
| ... ) | ||
| """ | ||
Jiaqi-Lv marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| def __init__( | ||
| self: TissueDetectionModel, num_input_channels: int, num_output_channels: int | ||
| ) -> None: | ||
| """Initialize TissueDetectionModel.""" | ||
| super().__init__() | ||
| self.num_input_channels = num_input_channels | ||
| self.num_output_channels = num_output_channels | ||
| self._postproc = self.postproc | ||
| self._preproc = self.preproc | ||
| self.tissue_detection_model = smp.UnetPlusPlus( | ||
| encoder_name="timm-efficientnet-b0", | ||
| encoder_weights=None, | ||
| in_channels=self.num_input_channels, | ||
| classes=self.num_output_channels, | ||
| activation=None, | ||
| ) | ||
|
|
||
| @staticmethod | ||
| def preproc(image: np.ndarray) -> np.ndarray: | ||
| """Apply jpg compression then ImageNet normalise.""" | ||
Jiaqi-Lv marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
Jiaqi-Lv marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 80] | ||
| _, compressed_image = cv2.imencode(".jpg", image, encode_param) | ||
| compressed_image = np.array(cv2.imdecode(compressed_image, 1)) | ||
|
|
||
Jiaqi-Lv marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| mean = np.array([0.485, 0.456, 0.406]) | ||
| std = np.array([0.229, 0.224, 0.225]) | ||
| return (compressed_image / 255.0 - mean) / std | ||
Jiaqi-Lv marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| @staticmethod | ||
| def postproc(image: np.ndarray) -> np.ndarray: | ||
| """Define post-processing for this model. | ||
| This returns the class index with the minimum probability. | ||
| In this model, this means selecting tissue class. | ||
| """ | ||
| return image.argmin(axis=-1) | ||
Jiaqi-Lv marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| def forward( | ||
| self: TissueDetectionModel, | ||
| imgs: torch.Tensor, | ||
| *args: tuple[Any, ...], # skipcq: PYL-W0613 # noqa: ARG002 | ||
| **kwargs: dict, # skipcq: PYL-W0613 # noqa: ARG002 | ||
| ) -> torch.Tensor: | ||
| """Forward function for model.""" | ||
| return self.tissue_detection_model(imgs) | ||
|
|
||
| @staticmethod | ||
| def infer_batch( | ||
| model: torch.nn.Module, | ||
| batch_data: torch.Tensor, | ||
| *, | ||
| device: str, | ||
| ) -> np.ndarray: | ||
| """Run inference on an input batch. | ||
| This contains logic for forward operation as well as i/o | ||
| Args: | ||
| model (nn.Module): | ||
| PyTorch defined model. | ||
| batch_data (:class:`torch.Tensor`): | ||
| A batch of data generated by | ||
| `torch.utils.data.DataLoader`. | ||
| device (str): | ||
| Transfers model to the specified device. Default is "cpu". | ||
Jiaqi-Lv marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| Returns: | ||
| np.ndarray: | ||
| The inference results as a numpy array. | ||
| """ | ||
| model.eval() | ||
|
|
||
| #### | ||
Jiaqi-Lv marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| imgs = batch_data | ||
|
|
||
| imgs = imgs.to(device).type(torch.float32) | ||
| imgs = imgs.permute(0, 3, 1, 2) # to NCHW | ||
|
|
||
| with torch.inference_mode(): | ||
| logits = model(imgs) | ||
| probs = torch.nn.functional.softmax(logits, 1) | ||
| probs = probs.permute(0, 2, 3, 1) # to NHWC | ||
|
|
||
| return probs.cpu().numpy() | ||
|
|
||
| def load_state_dict( | ||
| self: TissueDetectionModel, | ||
| state_dict: Mapping[str, Any], | ||
| **kwargs: bool, | ||
| ) -> torch.nn.modules.module._IncompatibleKeys: | ||
| """Load state dict for the TissueDetectionModel.""" | ||
| return self.tissue_detection_model.load_state_dict(state_dict, **kwargs) | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.