133 changes: 85 additions & 48 deletions tests/models/test_abc.py
@@ -4,16 +4,73 @@
from typing import TYPE_CHECKING

import pytest
import torch
from torch import nn

from tiatoolbox import rcParam
from tiatoolbox.models.architecture import get_pretrained_model
from tiatoolbox.models.architecture import (
fetch_pretrained_weights,
get_pretrained_model,
)
from tiatoolbox.models.models_abc import ModelABC
from tiatoolbox.utils import env_detection as toolbox_env

if TYPE_CHECKING:
import numpy as np


class ProtoRaisesTypeError(ModelABC):
"""Intentionally created to check for TypeError."""

# skipcq
def __init__(self: ProtoRaisesTypeError) -> None:
"""Initialize ProtoRaisesTypeError."""
super().__init__()

@staticmethod
# skipcq
def infer_batch() -> None:
"""Define infer batch."""
# base class definition pass


class ProtoNoPostProcess(ModelABC):
"""Intentionally created to check No Post Processing."""

def forward(self: ProtoNoPostProcess) -> None:
"""Define forward function."""

@staticmethod
# skipcq
def infer_batch() -> None:
"""Define infer batch."""


class Proto(ModelABC):
"""Intentionally created to check error."""

def __init__(self: Proto) -> None:
"""Initialize Proto."""
super().__init__()
self.dummy_param = nn.Parameter(torch.empty(0))

@staticmethod
# skipcq
def postproc(image: np.ndarray) -> np.ndarray:
"""Define postproc function."""
return image - 2

# skipcq
def forward(self: Proto) -> None:
"""Define forward function."""

@staticmethod
# skipcq
def infer_batch() -> None:
"""Define infer batch."""
pass # base class definition pass # noqa: PIE790


@pytest.mark.skipif(
toolbox_env.running_on_ci() or not toolbox_env.has_gpu(),
reason="Local test on machine with GPU.",
@@ -25,67 +82,37 @@ def test_get_pretrained_model() -> None:
get_pretrained_model(pretrained_name, overwrite=True)


@pytest.mark.skipif(
toolbox_env.running_on_ci() or not toolbox_env.has_gpu(),
reason="Local test on CLI",
)
def test_model_to_cuda() -> None:
"""This Test should pass locally if GPU is available."""
# Test on GPU
# no GPU on Travis so this will crash
model = Proto() # skipcq
assert model.dummy_param.device.type == "cpu"
model = model.to(device="cuda")
assert isinstance(model, nn.Module)
assert model.dummy_param.device.type == "cuda"


def test_model_abc() -> None:
"""Test API in model ABC."""
# test missing definition for abstract
with pytest.raises(TypeError):
# crash due to not defining forward, infer_batch, postproc
ModelABC() # skipcq

# intentionally created to check error
# skipcq
class Proto(ModelABC):
# skipcq
def __init__(self: Proto) -> None:
super().__init__()

@staticmethod
# skipcq
def infer_batch() -> None:
pass # base class definition pass

# skipcq
with pytest.raises(TypeError):
# crash due to not defining forward and postproc
Proto() # skipcq
ProtoRaisesTypeError() # skipcq

# intentionally created to check inheritance
# skipcq
class Proto(ModelABC):
# skipcq
def forward(self: Proto) -> None:
pass # base class definition pass

@staticmethod
# skipcq
def infer_batch() -> None:
pass # base class definition pass

model = Proto()
model = ProtoNoPostProcess()
assert model.preproc(1) == 1, "Must be unchanged!"
assert model.postproc(1) == 1, "Must be unchanged!"

# intentionally created to check error
# skipcq
class Proto(ModelABC):
# skipcq
def __init__(self: Proto) -> None:
super().__init__()

@staticmethod
# skipcq
def postproc(image: np.ndarray) -> None:
return image - 2

# skipcq
def forward(self: Proto) -> None:
pass # base class definition pass

@staticmethod
# skipcq
def infer_batch() -> None:
pass # base class definition pass

model = Proto() # skipcq
# test assign un-callable to preproc_func/postproc_func
with pytest.raises(ValueError, match=r".*callable*"):
@@ -111,3 +138,13 @@ def infer_batch() -> None:
# coverage setter check
model.postproc_func = None # skipcq: PYL-W0201
assert model.postproc_func(2) == 0

# Test on CPU
model = model.to(device="cpu")
assert isinstance(model, nn.Module)
assert model.dummy_param.device.type == "cpu"

# Test load_weights_from_file() method
weights_path = fetch_pretrained_weights("alexnet-kather100k")
with pytest.raises(RuntimeError, match=r".*loading state_dict*"):
_ = model.load_weights_from_file(weights_path)
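
For context on the new assertion above: load_weights_from_file() delegates to load_state_dict(..., strict=True), which raises a RuntimeError whenever the checkpoint keys do not match the model's parameter names (here, AlexNet weights loaded into the bare Proto model). A minimal standalone sketch of the same behavior; the Dummy module and the mismatched dict are illustrative only, not part of this PR:

import torch
from torch import nn


class Dummy(nn.Module):
    """Tiny module whose only parameters are fc.weight and fc.bias."""

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(2, 2)


# These keys ("other.weight"/"other.bias") do not match Dummy's parameter
# names, so strict loading raises RuntimeError instead of loading silently.
mismatched = {"other.weight": torch.zeros(2, 2), "other.bias": torch.zeros(2)}
try:
    Dummy().load_state_dict(mismatched, strict=True)
except RuntimeError as err:
    print(err)  # Error(s) in loading state_dict for Dummy: Missing key(s) ...
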
57 changes: 54 additions & 3 deletions tiatoolbox/models/models_abc.py
@@ -4,9 +4,12 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Callable

from torch import nn
import torch
from torch import device as torch_device

if TYPE_CHECKING: # pragma: no cover
from pathlib import Path

import numpy as np


@@ -31,7 +34,7 @@ def output_resolutions(self: IOConfigABC) -> None:
raise NotImplementedError


class ModelABC(ABC, nn.Module):
class ModelABC(ABC, torch.nn.Module):
"""Abstract base class for models used in tiatoolbox."""

def __init__(self: ModelABC) -> None:
@@ -48,7 +51,12 @@ def forward(self: ModelABC, *args: tuple[Any, ...], **kwargs: dict) -> None:

@staticmethod
@abstractmethod
def infer_batch(model: nn.Module, batch_data: np.ndarray, *, on_gpu: bool) -> None:
def infer_batch(
model: torch.nn.Module,
batch_data: np.ndarray,
*,
on_gpu: bool,
) -> None:
"""Run inference on an input batch.

Contains logic for forward operation as well as I/O aggregation.
@@ -135,3 +143,46 @@ def postproc_func(self: ModelABC, func: Callable) -> None:
self._postproc = self.postproc
else:
self._postproc = func

def to(self: ModelABC, device: str = "cpu") -> torch.nn.Module:
"""Transfers model to cpu/gpu.

Args:
model (torch.nn.Module):
PyTorch defined model.
device (str):
Transfers model to the specified device. Default is "cpu".

Returns:
torch.nn.Module:
The model after being moved to cpu/gpu.

"""
device = torch_device(device)
model = super().to(device)

# If the target device is CUDA and more
# than one GPU is available, use DataParallel
if device.type == "cuda" and torch.cuda.device_count() > 1:
model = torch.nn.DataParallel(model) # pragma: no cover

return model

def load_weights_from_file(self: ModelABC, weights: str | Path) -> torch.nn.Module:
"""Helper function to load a torch model.

Args:
self (ModelABC):
A torch model as :class:`ModelABC`.
weights (str or Path):
Path to pretrained weights.

Returns:
torch.nn.Module:
Torch model with pretrained weights loaded on CPU.

"""
# ! weights are assumed to have been saved in single-GPU mode
# always load onto the CPU first
saved_state_dict = torch.load(weights, map_location="cpu")
return super().load_state_dict(saved_state_dict, strict=True)
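
Taken together, the two new methods give every ModelABC subclass a uniform device-transfer and weight-loading path. A minimal usage sketch, assuming a CPU-only or single-GPU environment (the TinyProto subclass below is hypothetical and written only for illustration; to() and load_weights_from_file() are the methods added in this diff):

import tempfile
from pathlib import Path

import torch
from torch import nn

from tiatoolbox.models.models_abc import ModelABC


class TinyProto(ModelABC):
    """Hypothetical minimal concrete subclass, for illustration only."""

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)

    @staticmethod
    def infer_batch(model: nn.Module, batch_data: torch.Tensor, *, on_gpu: bool) -> None:
        """No-op for this sketch."""


model = TinyProto()

# Device transfer via the new to(); fall back to CPU when no GPU is present.
model = model.to(device="cuda" if torch.cuda.is_available() else "cpu")

# Round-trip the weights through disk with the new load_weights_from_file();
# the keys match TinyProto's own state_dict, so strict=True loading succeeds.
model = model.to(device="cpu")
weights_path = Path(tempfile.gettempdir()) / "tiny_proto.pth"
torch.save(model.state_dict(), weights_path)
model.load_weights_from_file(weights_path)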