diff --git a/docs/source/_templates/autosummary/class.rst b/docs/source/_templates/autosummary/class.rst index 4c075c91..3ca0f136 100644 --- a/docs/source/_templates/autosummary/class.rst +++ b/docs/source/_templates/autosummary/class.rst @@ -3,9 +3,8 @@ .. currentmodule:: {{ module }} .. autoclass:: {{ objname }} - {% if objname != 'ValidDataset' %}:members:{% endif %} - {% if objname != 'ValidDataset' %}:inherited-members:{% endif %} - {% if objname == 'ValidBboxAnnotationsDataFrame' %}:exclude-members: Config{% endif %} + :members: + :show-inheritance: {% block methods %} {% set ns = namespace(has_public_methods=false) %} @@ -24,11 +23,13 @@ .. autosummary:: {% for item in methods %} {% if not item.startswith('_') %} - ~{{ name }}.{{ item }} + {{ item|is_own_method(name, module) }} {% endif %} {%- endfor %} {% endif %} {% endblock %} + .. rubric:: {{ _('Details') }} + .. minigallery:: {{ module }}.{{ objname }} :add-heading: Examples using ``{{ objname }}`` diff --git a/docs/source/conf.py b/docs/source/conf.py index eff49ef4..4a4ede24 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1,9 +1,12 @@ """Sphinx configuration for ethology documentation.""" +import inspect import os import sys +from importlib import import_module from importlib.metadata import version as get_version +from jinja2.filters import FILTERS from sphinx_gallery import sorting # Used when building API docs, put the dependencies @@ -68,6 +71,7 @@ # Automatically add anchors to markdown headings myst_heading_anchors = 4 +# -------- Autosummary # Add any paths that contain templates here, relative to this directory. templates_path = ["_templates"] @@ -76,6 +80,47 @@ autosummary_generate_overwrite = False autodoc_default_options = {"show-inheritance": True} # applies to all classes + +def is_own_method(method_name, obj, module_name): + """Check if a method is defined in the class itself (not inherited). + + Returns the method reference string if it's defined in the class, + empty string otherwise. + """ + module = import_module(module_name) + if hasattr(module, "__all__") and obj not in module.__all__: + return "" + + cls = getattr(module, obj) + if not inspect.isclass(cls): + return "" + + # Check if method is defined in this class (not inherited) + if hasattr(cls, method_name): + # Check if it's in the class's __dict__ (defined in this class) + if method_name in cls.__dict__: + return f"~{obj}.{method_name}" + # Or check using inspect to see if it's defined in this class + try: + method = getattr(cls, method_name) + if inspect.ismethod(method) or inspect.isfunction(method): + # Check if the method's defining class is this class + if hasattr(method, "__qualname__"): + qualname_parts = method.__qualname__.split(".") + if len(qualname_parts) >= 2 and qualname_parts[-2] == obj: + return f"~{obj}.{method_name}" + # Fallback: check if it's in __dict__ + if method_name in cls.__dict__: + return f"~{obj}.{method_name}" + except (AttributeError, TypeError): + pass + + return "" + + +FILTERS["is_own_method"] = is_own_method + +# ------------- # Prefix section labels with the document name autosectionlabel_prefix_document = True @@ -197,6 +242,9 @@ "https://python-jsonschema.readthedocs.io/en/stable/", None, ), + "torch": ("https://pytorch.org/docs/stable/", None), + "torchvision": ("https://pytorch.org/vision/stable/", None), + "lightning": ("https://lightning.ai/docs/pytorch/stable/", None), } @@ -221,7 +269,6 @@ # sphinx-gallery configuration - sphinx_gallery_conf = { "examples_dirs": ["../../examples"], "within_subsection_order": sorting.ExplicitOrder( diff --git a/ethology/datasets/inference.py b/ethology/datasets/inference.py new file mode 100644 index 00000000..e52b428f --- /dev/null +++ b/ethology/datasets/inference.py @@ -0,0 +1,143 @@ +"""Datasets and related utilities for inference without ground-truth.""" + +from collections.abc import Callable +from pathlib import Path + +import torch +import torchvision.transforms.v2 as transforms +from PIL import Image +from torch.utils.data import Dataset + + +class InferenceImageDataset(Dataset): + """A simple dataset for images with no ground-truth annotations. + + The image files are sorted alphabetically. The annotations dictionary + returned by ``__getitem__`` is always empty to maintain a consistent + interface with training datasets. + + Parameters + ---------- + images_dir + Path to the root directory containing the images. + file_pattern + Glob pattern to match image filenames (e.g., "*.png", "*.jpg"). + transforms + Transforms to apply to the images. If `None` (default), the + transforms from :func:`get_default_inference_transforms` are used. + + Attributes + ---------- + images_dir : pathlib.Path + Path to the root directory containing the images. + transforms : torchvision.transforms.v2.Compose + Transforms to apply to the images. + image_files : list[pathlib.Path] + List of paths to each of the image files, sorted + alphabetically. + + See Also + -------- + get_default_inference_transforms : Returns default transforms for + inference. + + Examples + -------- + Create a dataset from 100 ``.png`` files in the ``/path/to/images`` + directory: + + >>> from ethology.datasets.inference import InferenceImageDataset + >>> dataset = InferenceImageDataset( + ... images_dir="/path/to/images", + ... file_pattern="*.png", + ... ) + >>> len(dataset) + 100 + + """ + + def __init__( + self, + images_dir: Path | str, + file_pattern: str, + transforms: transforms.Compose | None = None, + ): + """Initialise dataset.""" + self.images_dir = Path(images_dir) + self.transforms = ( + transforms + if transforms is not None + else get_default_inference_transforms() + ) + self.image_files = sorted(self.images_dir.glob(file_pattern)) + + def __len__(self) -> int: + """Return the number of images in the dataset.""" + return len(self.image_files) + + def __getitem__(self, idx: int) -> tuple[torch.Tensor, dict]: + """Return the image and an empty annotations dictionary. + + Parameters + ---------- + idx : int + Index of the image to retrieve. + + Returns + ------- + tuple[torch.Tensor, dict] + A tuple containing the image as a tensor and an empty + annotations dictionary. + + """ + # Open requested image + img_path = Path(self.images_dir) / self.image_files[idx] + image = Image.open(img_path).convert("RGB") + + # If transforms are specified, apply to the image + if self.transforms: + image = self.transforms(image) + return image, {} + + +def get_default_inference_transforms() -> transforms.Compose: + """Return the default transforms for inference. + + Transforms the input image to a tensor and scales the pixel + values from ``[0, 255]`` (uint8) to ``[0, 1]`` (float32). + + Returns + ------- + torchvision.transforms.v2.Compose + The default transforms for inference. + + """ + return transforms.Compose( + [ + transforms.ToImage(), + transforms.ToDtype(torch.float32, scale=True), + ] + ) + + +def get_detector_collate_fn() -> Callable: + """Return collate function for a detector. + + It supports images and annotations of different sizes. + A collate function takes a list of samples from a dataset + and batches them for the model to process it. Torch detectors + expect a list/tuple of images and annotations, since they can + be of different sizes across the dataset. + + Returns + ------- + Callable + A collate function for detector models. + + """ + + def collate_fn(dataset_samples): + """Return a tuple of tuples: (images_tuple, annots_tuple).""" + return tuple(zip(*dataset_samples, strict=True)) + + return collate_fn diff --git a/ethology/detectors/models.py b/ethology/detectors/models.py new file mode 100644 index 00000000..02e169a1 --- /dev/null +++ b/ethology/detectors/models.py @@ -0,0 +1,592 @@ +"""PyTorch Lightning modules for detectors.""" + +import difflib +from itertools import chain +from typing import Any + +import numpy as np +import torch +import xarray as xr +from lightning import LightningModule, Trainer +from torch.utils.data import DataLoader +from torchvision.models import get_model +from torchvision.models.detection import faster_rcnn, fcos, retinanet + +from ethology.detectors.utils import ( + _pad_to_max_first_dimension, + corners_to_centroid_shape, +) +from ethology.validators.detections import ValidBboxDetectionsDataset +from ethology.validators.utils import _check_output + +# Registry of supported models with their constructors +MODEL_CONSTRUCTORS_REGISTRY = { + "fasterrcnn_resnet50_fpn_v2": faster_rcnn.fasterrcnn_resnet50_fpn_v2, + "fasterrcnn_mobilenet_v3_large_fpn": ( + faster_rcnn.fasterrcnn_mobilenet_v3_large_fpn + ), + "fcos_resnet50_fpn": fcos.fcos_resnet50_fpn, + "retinanet_resnet50_fpn_v2": retinanet.retinanet_resnet50_fpn_v2, +} + +# Default number of classes in torchvision detection models trained on COCO2017 +# Can verify with: +# from torchvision.models.detection import FasterRCNN_ResNet50_FPN_V2_Weights +# len(FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT.meta['categories']) +DEFAULT_NUM_CLASSES = 91 + + +class ObjectDetector(LightningModule): + """LightningModule for `torchvision detection models `_. + + Supports Faster R-CNN, RetinaNet, and FCOS architectures. + This module is intended for inference only. + + Parameters + ---------- + config : dict + Configuration of the model. Expected keys: + + - **model_class** (*str*) -- + Name of the model to initialise. Should be one of + ``fasterrcnn_resnet50_fpn_v2``, + ``fasterrcnn_mobilenet_v3_large_fpn``, + ``fcos_resnet50_fpn``, or ``retinanet_resnet50_fpn_v2``. + + - **model_kwargs** (*dict*) -- + Keyword arguments to pass to the model constructor. See + the `torchvision detection models docs + `_ + for possible values for each supported model. + + All models support ``num_classes`` as a keyword argument. If not + specified, it defaults to 91 (the number of COCO2017 categories). + See the Notes section for details on how weights are initialised + when requesting COCO2017 pretrained weights for a custom number + of classes. + + - **checkpoint** (*str, Path or None*) -- + Path to the trained model checkpoint. If provided, model weights + are loaded entirely from the checkpoint file. The checkpoint must + match the architecture specified by ``model_class`` and the number + of classes in ``model_kwargs``. If ``None``, the model is + initialised using pretrained COCO2017 weights. + + Attributes + ---------- + config : dict + The configuration dictionary passed to the constructor. + model : torch.nn.Module + The object detector model. + model_params : dict + The parameters used to construct the selected model, with defaults + applied. + + Raises + ------ + TypeError + If ``config`` is not a dictionary, or if ``model_kwargs`` is provided + but is not a dictionary. + ValueError + If ``model_class`` is not defined in the ``config``, or if it is not + supported. See the Parameters section for the supported model classes. + Also if ``model_kwargs`` contains a key that is a misspelt variant + of ``num_classes`` (e.g. ``n_classes``). + + Notes + ----- + For two of the models using a ResNet backbone (Faster-RCNN and RetinaNet), + we use the improved ``v2`` versions from ``torchvision`` (see for + `FasterRCNN \ + `_ + and for + `RetinaNet `_). + + We cover the following cases for weights initialisation. If no checkpoint + is provided and... + + - ``num_classes`` is not specified (or is 91): pretrained COCO2017 weights + are loaded for both backbone and detection head. + - a custom ``num_classes`` is used: pretrained backbone weights are + retained and class-dependent layers are initialised with random weights. + + For RetinaNet and FCOS, only the classification head is replaced when using + a custom number of classes, since bounding box regression is + class-agnostic. For Faster R-CNN, the entire box predictor + (classification and regression) is replaced since bounding box regression + is class-specific. + + If a checkpoint is provided, all weights are loaded from the checkpoint. + Users must ensure that ``num_classes`` matches the number of classes the + checkpoint was trained with, otherwise loading will fail due to shape + mismatches. + + Examples + -------- + Initialise a FCOS model pretrained on COCO2017 with the default + 91 classes: + + >>> from ethology.detectors.models import ObjectDetector + >>> model = ObjectDetector({"model_class": "fcos_resnet50_fpn"}) + + Initialise a FCOS model for three classes (background included), + reusing COCO2017 weights wherever possible and randomly initialising + class-dependent layers: + + >>> from ethology.detectors.models import ObjectDetector + >>> model = ObjectDetector( + ... { + ... "model_class": "fcos_resnet50_fpn", + ... "model_kwargs": {"num_classes": 3}, + ... } + ... ) + + Initialise a Faster R-CNN model with the default number of classes + (91) from a saved checkpoint: + + >>> from ethology.detectors.models import ObjectDetector + >>> config = { + ... "model_class": "fasterrcnn_resnet50_fpn_v2", + ... "checkpoint": "/path/to/checkpoint.ckpt", + ... } + >>> model = ObjectDetector(config) + + Initialise a Faster R-CNN model with two classes (background included) + from a saved checkpoint: + + >>> from ethology.detectors.models import ObjectDetector + >>> config = { + ... "model_class": "fasterrcnn_resnet50_fpn_v2", + ... "model_kwargs": {"num_classes": 2}, + ... "checkpoint": "/path/to/checkpoint/two/classes.ckpt", + ... } + >>> model = ObjectDetector(config) + + """ + + def __init__(self, config: dict[str, Any]): + """Initialise object detector for the given configuration.""" + super().__init__() + self.config = self._validate_config(config) + self.model = self._configure_model() + + # save all arguments passed to __init__ to + # hparams attribute + self.save_hyperparameters() + + @staticmethod + def _validate_config(config: dict): + """Validate config dict for detector.""" + # Check config is a dictionary + if not isinstance(config, dict): + raise TypeError( + "config must be a dictionary", + f"but got {type(config).__name__}", + ) + + # model_class should always be defined + if "model_class" not in config: + raise ValueError("model_class must be defined in config") + + # Check if model_class is supported if defined + if config["model_class"] not in MODEL_CONSTRUCTORS_REGISTRY: + model_class = config["model_class"] + raise ValueError( + f"Model '{model_class}' not supported. " + f"Available: {list(MODEL_CONSTRUCTORS_REGISTRY.keys())}" + ) + + # Check model_kwargs type is dict if provided + if "model_kwargs" in config: + if not isinstance(config["model_kwargs"], dict): + raise TypeError( + f"model_kwargs must be a dict, got " + f"{type(config['model_kwargs']).__name__}" + ) + + # Check for misspellings of the main model_kwargs passed + # to the torchvision constructors + # (torchvision does not report errors on this) + list_fuzzy_matches = [ + "weights", + "progress", + "num_classes", + "weights_backbone", + "trainable_backbone_layers", + ] + for key in config["model_kwargs"]: + if key not in list_fuzzy_matches: + close_matches = difflib.get_close_matches( + key, list_fuzzy_matches + ) + if close_matches: + raise ValueError( + f"Invalid key '{key}' in model_kwargs. " + f"Did you mean '{close_matches[0]}'?" + ) + + return config + + def _configure_model(self) -> torch.nn.Module: + """Initialise model from ckpt if provided, else from pretrained.""" + # Extract model params as attributes + self._model_class = self.config.get("model_class") + self.model_params = self.config.get("model_kwargs", {}) + + # Set num_classes to default if not set + if "num_classes" not in self.model_params: + self.model_params["num_classes"] = DEFAULT_NUM_CLASSES + + # Delegate to the appropriate function + if "checkpoint" not in self.config: + model = self._configure_model_pretrained() + else: + model = self._configure_model_from_checkpoint( + str(self.config["checkpoint"]) + ) + + return model + + def _configure_model_pretrained(self) -> torch.nn.Module: + """Load pretrained weights into model. + + Default weights are used when possible. If there is a shape mismatch + in the layers, the weights are initialised with random weights. + """ + # Load selected model with pretrained weights in backbone and head + model = MODEL_CONSTRUCTORS_REGISTRY[self._model_class]( + weights="COCO_V1" # equivalent to weights="DEFAULT" + ) + + # Adapt model if there is a mismatch with the requested number of + # classes + n_classes_model = _get_n_classes_in_detector( + model, self._model_class + ) # shape of loaded model + if self.model_params["num_classes"] != n_classes_model: + # Keep as much as possible from the bbox prediction head + if "fasterrcnn" in self._model_class: + # Reinitialise box predictor for the required number of classes + # (both cls_score and bbox_pred are reinitialised) + # Note: in Faster R-CNN, the bbox regression is class-specific; + # it learns a different way of refining bboxes for each class. + # So we need to reinitialise the full box predictor if the + # number of classes is different from COCO2017. + in_features = ( + model.roi_heads.box_predictor.cls_score.in_features + ) + model.roi_heads.box_predictor = faster_rcnn.FastRCNNPredictor( + in_features, + self.model_params["num_classes"], + ) + elif "retinanet" in self._model_class: + # In retinanet bbox regression is class-agnostic, so we can + # retain it + in_channels = model.head.classification_head.conv[0][ + 0 + ].in_channels + num_anchors = model.head.classification_head.num_anchors + model.head.classification_head = ( + retinanet.RetinaNetClassificationHead( + in_channels, + num_anchors, + self.model_params["num_classes"], + ) + ) + + elif "fcos" in self._model_class: + # In fcos bbox regression is class-agnostic, so we can retain + # it + in_channels = model.head.classification_head.conv[ + 0 + ].in_channels + num_anchors = model.head.classification_head.num_anchors + model.head.classification_head = fcos.FCOSClassificationHead( + in_channels, + num_anchors, + self.model_params["num_classes"], + ) + + return model + + def _configure_model_from_checkpoint( + self, checkpoint_path: str + ) -> torch.nn.Module: + """Load weights from checkpoint into model.""" + # Get checkpoint + checkpoint_dict = torch.load(checkpoint_path, map_location=self.device) + + # Instantiate model + model = get_model(self._model_class, **self.model_params) + + # Get state dict from checkpoint and load into model + model_state_dict = self._get_model_state_dict(checkpoint_dict) + model.load_state_dict(model_state_dict, strict=True) + return model + + # ------ Convenience functions -------------- + @staticmethod + def _get_model_state_dict(checkpoint: dict) -> dict: + """Get model state dict from checkpoint dictionary. + + The checkpoint dictionary is expected to be in one of the following: + - A dictionary with the state dictionary itself (torch flat + convention). + - A dictionary with a "state_dict" key containing the state dictionary + (torch nested convention). + - A dictionary with a "state_dict" key containing the state dictionary + where each key has a "model." prefix (Lightning convention). + """ + # Get the state_dict key if it exists, + # otherwise use the checkpoint itself as the state dict + state_dict = checkpoint.get("state_dict", checkpoint) + + # Remove "model." prefix if present (if not, it leaves + # keys unchanged). + # Note: PyTorch Lightning saves the model with a "model." + # prefix in the state_dict keys if you defined self.model + # in your LightningModule + return { + key.removeprefix("model."): value + for key, value in state_dict.items() + } + + # ------- Inference ----------------------- + def predict_step( + self, + batch: tuple[torch.Tensor, dict], + batch_idx: int, + ) -> list[dict[str, torch.Tensor]]: + """Run an inference step on a batch of images. + + Parameters + ---------- + batch + A tuple containing the batch of images and the corresponding + annotations. + batch_idx + The index of the batch. + + Returns + ------- + list[dict[str, torch.Tensor]] + A list of raw predictions as a dictionary, one per image in the + batch and each with the following keys: + + - **"boxes"** (*torch.Tensor*) -- + Tensor of shape ``(n_boxes, 4)`` and floating-point dtype + (typically ``torch.float32``), holding the bounding box corners + ``[x1, y1, x2, y2]`` in pixel coordinates for each detection. + - **"scores"** (*torch.Tensor*) -- + Tensor of shape ``(n_boxes,)`` and floating-point dtype + (typically ``torch.float32``), holding the confidence score for + each detection. + - **"labels"** (*torch.Tensor*) -- + Tensor of shape ``(n_boxes,)`` and dtype ``torch.int64``, + holding the integer label for each detection. + + """ + images_batch, _annotations_batch = batch + raw_prediction_dicts = self.model(images_batch) + + return raw_prediction_dicts + + @_check_output(ValidBboxDetectionsDataset) + def run_inference( + self, + trainer: Trainer, + dataloader: DataLoader, + attrs: dict | None = None, + ) -> xr.Dataset: + """Run inference on the input dataloader. + + Convenience method that wraps + :meth:`Trainer.predict \ + ` + and returns the formatted predictions as an ``ethology`` bounding box + detections dataset. + + Parameters + ---------- + trainer + The Lightning trainer to use for inference. The trainer + object handles device placement, precision settings, and + orchestrating the prediction loop. + dataloader + The dataloader providing the dataset for inference. + attrs + Attributes to add to the ``ethology`` detections dataset. + + Returns + ------- + xarray.Dataset + The predictions for each image in the dataloader, formatted + as an ``ethology`` detections dataset. + + """ + predictions = trainer.predict(self, dataloader) + return self._format_predictions(predictions, attrs=attrs) + + # ------- Formatting ------------------- + @staticmethod + @_check_output(ValidBboxDetectionsDataset) + def _format_predictions( + predictions: list[list[dict[str, torch.Tensor]]], + attrs: dict | None = None, + ) -> xr.Dataset: + """Format predictions as an ``ethology`` detections dataset. + + Parameters + ---------- + predictions : list[list[dict[str, torch.Tensor]]] + The raw predictions to format. The outer list corresponds to + batches, the inner list corresponds to images within a batch. + The dictionaries contain the following keys: + + - **boxes** (*torch.Tensor*) -- + Tensor of shape ``(n_boxes, 4)`` and floating-point dtype + (typically ``torch.float32``), holding the bounding box corners + ``[x1, y1, x2, y2]`` in pixel coordinates for each detection. + - **scores** (*torch.Tensor*) -- + Tensor of shape ``(n_boxes,)`` and floating-point dtype + (typically ``torch.float32``), holding the confidence score for + each detection. + - **labels** (*torch.Tensor*) -- + Tensor of shape ``(n_boxes,)`` and dtype ``torch.int64``, + holding the integer label for each detection. + + attrs : dict | None + Dictionary of attributes to add to the predictions dataset as + ``attrs``. + + Returns + ------- + xr.Dataset + The predictions formatted as an ``ethology`` detections dataset. + + Raises + ------ + TypeError : If predictions is not a list. + ValueError : If predictions list is empty or contains no image data. + + """ + # Check input data + if not isinstance(predictions, list): + raise TypeError( + f"predictions must be a list, got {type(predictions).__name__}" + ) + if len(predictions) == 0: + raise ValueError( + "predictions list is empty. " + "Cannot format an empty predictions list." + ) + + # Flatten output predictions + predictions_dict_per_img = list(chain.from_iterable(predictions)) + + # Check flattened data + if len(predictions_dict_per_img) == 0: + raise ValueError( + "No predictions to format. " + "predictions list contains no image data." + ) + + # Parse output from dicts and convert to numpy arrays + output_per_sample = { + key: [ + sample[key].cpu().numpy() + for sample in predictions_dict_per_img + ] + for key in ["boxes", "scores", "labels"] + } + + # Pad across image_ids + # (note: np.asarray(np.nan).dtype is float64) + fill_value = {"boxes": np.nan, "scores": np.nan, "labels": -1} + output_per_sample_padded = { + key: np.stack( + _pad_to_max_first_dimension(output_per_sample[key], val), + axis=0, + ) + for key, val in fill_value.items() + } + + # Compute centroid and shape arrays + bboxes_array = np.transpose( + output_per_sample_padded["boxes"], (0, -1, 1) + ) + centroid_array, shape_array = corners_to_centroid_shape( + bboxes_array[:, 0:2], bboxes_array[:, 2:4] + ) + + # Return as ethology detections dataset + max_n_detections = bboxes_array.shape[-1] + n_images = bboxes_array.shape[0] + return xr.Dataset( + data_vars={ + "position": ( + ["image_id", "space", "id"], + centroid_array, + ), + "shape": (["image_id", "space", "id"], shape_array), + "confidence": ( + ["image_id", "id"], + output_per_sample_padded["scores"], + ), + "category": ( # labels are renamed as "category" array + ["image_id", "id"], + output_per_sample_padded["labels"], + ), + }, + coords={ + "image_id": np.arange(n_images), + "space": ["x", "y"], + "id": np.arange(max_n_detections), + }, + attrs=attrs if attrs else {}, + ) + + +def _get_n_classes_in_detector( + model: torch.nn.Module, model_class: str +) -> int: + """Extract the number of classes from model based on its architecture. + + Parameters + ---------- + model : torch.nn.Module + The object detector model. + model_class : str + Name of the model architecture (e.g., "fasterrcnn_resnet50_fpn_v2"). + + Returns + ------- + int + The number of classes the model is configured to detect. + + Raises + ------ + ValueError + If the model architecture is not supported. + + """ + if model_class not in MODEL_CONSTRUCTORS_REGISTRY: + raise ValueError(f"Unsupported model class: {model_class}") + if "fasterrcnn" in model_class: + return _get_n_classes_fasterrcnn(model) + else: + # retinanet and fcos use anchor-based classification heads + return _get_n_classes_anchor_based(model) + + +def _get_n_classes_fasterrcnn(model: torch.nn.Module) -> int: + """Get the number of classes from a Faster R-CNN model.""" + return model.roi_heads.box_predictor.cls_score.out_features + + +def _get_n_classes_anchor_based(model: torch.nn.Module) -> int: + """Get the number of classes from an anchor-based model.""" + # In anchor-based detectors, the classification head makes predictions + # for every anchor at each spatial location + cls_head = model.head.classification_head + return cls_head.cls_logits.out_channels // cls_head.num_anchors diff --git a/ethology/detectors/utils.py b/ethology/detectors/utils.py new file mode 100644 index 00000000..4ba0591d --- /dev/null +++ b/ethology/detectors/utils.py @@ -0,0 +1,180 @@ +"""Utility functions for reshaping outputs of ensembles of detectors.""" + +import numpy as np + + +def _get_padding_width(array: np.ndarray, final_first_dim: int) -> list[tuple]: + """Get pad_width to pad the end of an array along the first dimension.""" + # Throw an error if shape mismatch + if array.shape[0] > final_first_dim: + raise ValueError( + "Array has more rows than the requested padded size: " + f"{array.shape[0]} > {final_first_dim}" + ) + pad_width = array.ndim * [(0, 0)] + pad_width[0] = (0, final_first_dim - array.shape[0]) + return pad_width + + +def _pad_to_max_first_dimension( + list_arrays: list[np.ndarray], fill_value=np.nan +) -> list[np.ndarray]: + """Pad arrays in list to maximum size of their first dimension.""" + max_first_dimension = max(array.shape[0] for array in list_arrays) + + # Check for dtype compatibility between fill_value and arrays + # (convert fill_value to numpy scalar/array to get its dtype first) + for i, arr in enumerate(list_arrays): + if not np.can_cast( + np.asarray(fill_value).dtype, arr.dtype, "same_kind" + ): + raise TypeError( + f"Cannot pad array (index {i}, dtype={arr.dtype}) " + f"with fill_value={fill_value!r} " + f"(type={type(fill_value).__name__}). " + f"Ensure fill_value is compatible with array dtype." + ) + + list_arrays_padded = [ + np.pad( + arr, + _get_padding_width(arr, max_first_dimension), + mode="constant", + constant_values=fill_value, + ) + for arr in list_arrays + ] + return list_arrays_padded + + +def centroid_shape_to_corners( + centroid: np.ndarray, shape: np.ndarray +) -> tuple[np.ndarray, np.ndarray]: + """Convert box centroid and shape arrays to x1y1, x2y2 corner arrays. + + The function assumes all coordinates are expressed in an image coordinate + system whose origin is at the centre of the top-left pixel in the image, + its x coordinate values increase from left to right of the image, + and its y coordinate values increase from top to bottom. + + Parameters + ---------- + centroid + Array of bounding box centroid coordinates with shape (..., 2, ...), + where the second dimension contains (x, y) coordinates in the image + coordinate system. + shape + Array of bounding box dimensions with shape (..., 2, ...), where the + second dimension contains (width, height) values, in the same units as + the ``centroid`` array. + + Returns + ------- + x1y1 : numpy.ndarray + Array of bounding box top-left corner coordinates + with shape (..., 2,..), where the second dimension contains (x, y) + coordinates. The top-left corner is the corner of the bounding box + with minimum x and y coordinates in the image coordinate system. + x2y2 : numpy.ndarray + Array of bottom-right corner coordinates with shape (..., 2,..), where + the second dimension contains (x, y) coordinates. The bottom-right + corner is the corner of the bounding box with maximum x and y + coordinates in the image coordinate system. + + Raises + ------ + ValueError + If ``position`` and ``shape`` arrays have different shapes, or + if any of their second dimensions is not 2. + + See Also + -------- + corners_to_centroid_shape : Inverse operation. + + """ + # Check position and shape have compatible shapes + if centroid.shape != shape.shape: + raise ValueError( + f"position and shape must have the same shape, " + f"got {centroid.shape} and {shape.shape}" + ) + + # Check size of second dimension is 2D + if centroid.shape[1] != 2 or shape.shape[1] != 2: + raise ValueError( + "Dimension at index 1 must be 2 " + "for both position and shape arrays, " + f"but got position: {centroid.shape}, shape: {shape.shape}" + ) + + half_shape = shape / 2 + return ( + centroid - half_shape, # x1y1, top-left corner + centroid + half_shape, # x2y2, bottom-right corner + ) + + +def corners_to_centroid_shape( + x1y1: np.ndarray, x2y2: np.ndarray +) -> tuple[np.ndarray, np.ndarray]: + """Convert x1y1, x2y2 box corner arrays to centroid and shape arrays. + + The function assumes all coordinates are expressed in an image coordinate + system whose origin is at the centre of the top-left pixel in the image, + its x coordinate values increase from left to right of the image, + and its y coordinate values increase from top to bottom. + + Parameters + ---------- + x1y1 + Array of bounding box top-left corner coordinates + with shape (..., 2,..), where the second dimension contains (x, y) + coordinates. The top-left corner is the corner of the bounding box + with minimum x and y coordinates in the image coordinate system. + x2y2 + Array of bottom-right corner coordinates with shape (..., 2,..), where + the second dimension contains (x, y) coordinates. The bottom-right + corner is the corner of the bounding box with maximum x and y + coordinates in the image coordinate system. + + Returns + ------- + centroid : numpy.ndarray + Array of bounding box centroid coordinates with shape (..., 2, ...), + where the second dimension contains (x, y) coordinates in the image + coordinate system. + + shape : numpy.ndarray + Array of bounding box dimensions with shape (..., 2, ...), where the + second dimension contains (width, height) values, in the same units as + the ``centroid`` array. + + Raises + ------ + ValueError + If x1y1 and x2y2 have different shapes, or + if any of their second dimensions is not 2. + + See Also + -------- + centroid_shape_to_corners : Inverse operation. + + """ + # Check x1y1 and x2y2 have compatible shapes + if x1y1.shape != x2y2.shape: + raise ValueError( + f"x1y1 and x2y2 must have the same shape, " + f"got x1y1: {x1y1.shape}, x2y2: {x2y2.shape}" + ) + + # Check dimension at index 1 is 2D + if x1y1.shape[1] != 2 or x2y2.shape[1] != 2: + raise ValueError( + f"Dimension at index 1 must be 2 for both x1y1 and x2y2, " + f"but got x1y1: {x1y1.shape}, x2y2: {x2y2.shape}" + ) + + return ( + 0.5 * (x1y1 + x2y2), # centroid + x2y2 - x1y1, # shape + ) diff --git a/ethology/validators/annotations.py b/ethology/validators/annotations.py index 6a02f10a..fc1c6b88 100644 --- a/ethology/validators/annotations.py +++ b/ethology/validators/annotations.py @@ -445,6 +445,12 @@ class ValidBboxAnnotationsCOCO(pa.DataFrameModel): """ + class Config: + """Pandera configuration for this schema.""" + + # Allow automatic type coercion (e.g., float32 -> float64) + coerce = True + # index idx: Index[int] = pa.Field(ge=0, check_name=False) diff --git a/examples/inference_with_trained_detector.py b/examples/inference_with_trained_detector.py new file mode 100644 index 00000000..799bc3ac --- /dev/null +++ b/examples/inference_with_trained_detector.py @@ -0,0 +1,164 @@ +"""Run inference with a trained detector +======================================== + +Run inference with a trained detector on a dataset of images for proofreading. +""" + +# %% +# This example demonstrates how to run inference with a trained detector on a +# dataset of images for later proofreading, for example using the VIA +# annotation tool. + +# %% +# Imports +# ------- +import os +from datetime import datetime +from pathlib import Path + +import pooch +from lightning import Trainer +from torch.utils.data import DataLoader +from torchvision.models.detection import FasterRCNN_ResNet50_FPN_V2_Weights + +from ethology.datasets.inference import ( + InferenceImageDataset, + get_default_inference_transforms, + get_detector_collate_fn, +) +from ethology.detectors.models import ObjectDetector +from ethology.io.annotations import save_bboxes + +# For interactive plots: install ipympl with `pip install ipympl` and uncomment +# the following line in your notebook +# %matplotlib widget + + +# %% +# Download dataset +# ----------------- +# Source of the dataset +data_source = { + "url": "https://storage.googleapis.com/public-datasets-lila/uas-imagery-of-migratory-waterfowl/uas-imagery-of-migratory-waterfowl.20240220.zip", + "hash": "c5b8dfc5a87ef625770ac8f22335dc9eb8a67688b610490a029dae81815a9896", +} + +# Define cache directory +ethology_cache = Path.home() / ".ethology" +ethology_cache.mkdir(exist_ok=True) + +# Download the dataset to the cache directory +extracted_files = pooch.retrieve( + url=data_source["url"], + known_hash=data_source["hash"], + fname="waterfowl_dataset.zip", + path=ethology_cache, + processor=pooch.Unzip(extract_dir=ethology_cache), +) + +data_dir = ethology_cache / "uas-imagery-of-migratory-waterfowl" + + +# %% +# Prepare dataset for inference +# ----------------------------- + +# Create dataset + +images_dir = data_dir / "experts" / "images" +dataset = InferenceImageDataset( + images_dir, + "*.jpg", + transforms=get_default_inference_transforms(), +) + +# %% +# Create dataloader +# --------------------- + +# The default collate function for the dataloader +# stacks all images (torch.stack([img1, img2])), +# which fails if images have different sizes. +# We use a detector collate fn. + +# Create dataloader for detector +dataloader = DataLoader( + dataset, + batch_size=12, # 12, + shuffle=False, + num_workers=8, # 4 + collate_fn=get_detector_collate_fn(), +) + +# %% +# Prepare model and trainer +# ------------------------- + +# Pretrained +detector = ObjectDetector({"model_class": "fasterrcnn_resnet50_fpn_v2"}) + + +# Instantiate trainer +trainer = Trainer( + accelerator="cpu", # recommended gpu if available + devices=1, + logger=False, +) + +# %% +# Define dataset attrs to add to predictions +# ------------------------------------------- + +# We need to add `map_category_to_str` as a dataset attribute +# to be able to export the predictions as COCO + +# %% +# Retrieve list of categories used in torchvision models +# trained on COCO2017 +list_category_str = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT.meta[ + "categories" +] + +ds_attrs = { + "images_dir": images_dir, + "map_image_id_to_filename": { + id: filename.relative_to(images_dir) + for id, filename in enumerate(dataset.image_files) + }, + "map_category_to_str": {k: cat for k, cat in enumerate(list_category_str)}, +} # required for COCO export + +# %% +# Run inference using model on dataloader +# ---------------------------------------- +# The predictions are formatted as an ``ethology`` detections dataset. + +# Run inference using model on dataloader +predictions_ds = detector.run_inference(trainer, dataloader, attrs=ds_attrs) + + +# %% +# Export predictions as COCO annotations for proofreading +# --------------------------------------------------------- +timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") +out_file = save_bboxes.to_COCO_file( + predictions_ds, output_filepath=f"out_{timestamp}.json" +) + +# %% +# Load proofread annotations and compare +# --------------------------------------- + +# proofread_ds = load_bboxes.from_files( +# "via_project_23Dec2025_15h25m_coco.json", +# format="COCO", +# ) + +# %% +# Clean-up +# --------- +# To remove the output files we have just created, we can run the following: + +os.remove(out_file) + +# %% diff --git a/pyproject.toml b/pyproject.toml index 1ae41b6a..effe7777 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ dependencies = [ "torch", "torchvision", "loguru", + "lightning>=2.6.0", "tables>=3.10.1" # wheels compatible with MacOS ARM ] diff --git a/tests/test_unit/test_datasets/__init__.py b/tests/test_unit/test_datasets/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_unit/test_datasets/test_inference.py b/tests/test_unit/test_datasets/test_inference.py new file mode 100644 index 00000000..b9e8edfd --- /dev/null +++ b/tests/test_unit/test_datasets/test_inference.py @@ -0,0 +1,141 @@ +import pytest +import torch +import torchvision.transforms.v2 as transforms +from PIL import Image + +from ethology.datasets.inference import ( + InferenceImageDataset, + get_default_inference_transforms, + get_detector_collate_fn, +) + + +def test_get_default_inference_transforms(): + """Test default transforms produce expected output.""" + out_transforms = get_default_inference_transforms() + + # Create a dummy image and transform it + dummy_img = Image.new( + "RGB", + (10, 10), + color=(128, 128, 128), + ) + result = out_transforms(dummy_img) + + assert isinstance(result, torch.Tensor) + assert result.dtype == torch.float32 + assert result.min() >= 0.0 and result.max() <= 1.0 # scaled + + +def test_get_detector_collate_fn(): + """Test collate_fn for detectors returns batch in expected format.""" + # Simulate output from a dataset with no annotations (empty dict) + list_dataset_samples = [ + (torch.zeros((3, 24, 24)), {}), # C, H, W + (torch.zeros((3, 100, 200)), {}), + (torch.zeros((3, 12, 12)), {}), + ] + + # Run thru detector collate_fn + collate_fn = get_detector_collate_fn() + batched_data = collate_fn(list_dataset_samples) + + # Check output is a (image_tuple, annots_tuple) + assert isinstance(batched_data, tuple) + assert len(batched_data) == 2 # (images_tuple, annots_tuple) + assert isinstance(batched_data[0], tuple) + assert isinstance(batched_data[1], tuple) + + # Check number of samples in batch + assert len(batched_data[0]) == len(list_dataset_samples) + assert len(batched_data[1]) == len(list_dataset_samples) + + +class TestInferenceImageDataset: + """Tests for InferenceImageDataset.""" + + @pytest.fixture + def sample_images_dir(self, tmp_path): + """Create a temporary directory with 3 black images.""" + list_images = [] + for i in range(3): + img = Image.new("RGB", (200, 100)) # width, height + output_path = tmp_path / f"img_{i:02d}.png" + img.save(output_path) + list_images.append(output_path) + return tmp_path, list_images + + def test_len(self, sample_images_dir): + """Test dataset length matches number of images.""" + images_dir_path, list_images = sample_images_dir + dataset = InferenceImageDataset( + images_dir=images_dir_path, file_pattern="*.png" + ) + assert len(dataset) == len(list_images) + + def test_getitem(self, sample_images_dir): + """Test __getitem__ returns (image, empty dict).""" + images_dir_path, _ = sample_images_dir + dataset = InferenceImageDataset( + images_dir=images_dir_path, + file_pattern="*.png", + transforms=get_default_inference_transforms(), + ) + # Take one sample + img, annots = dataset[0] + + # Check outputs + assert isinstance(img, torch.Tensor) + assert img.shape == (3, 100, 200) # C, H, W + assert annots == {} + + def test_images_sorted(self, sample_images_dir): + """Test images in dataset are in alphabetical order.""" + images_dir_path, list_files_images_dir = sample_images_dir + dataset = InferenceImageDataset( + images_dir=images_dir_path, + file_pattern="*.png", + ) + + list_filenames = [f.name for f in dataset.image_files] + assert list_filenames == sorted( + [f.name for f in list_files_images_dir] + ) + + def test_file_pattern_filters(self, sample_images_dir): + """Test that file_pattern correctly filters files.""" + images_dir_path, _ = sample_images_dir + + # Add a jpg file to the dataset directory to filter out + img_jpg = Image.new("RGB", (200, 100)) + img_jpg.save(images_dir_path / "test.jpg") + + # Build a dataset from that directory with png filter + dataset = InferenceImageDataset( + images_dir=images_dir_path, + file_pattern="*.png", + ) + + # Check there are no jpg files + assert not all([im.suffix == ".jpg" for im in dataset.image_files]) + + def test_default_transforms(self, sample_images_dir): + """Check that default transforms are assigned if None specified.""" + # Create a minimal dataset + images_dir_path, _ = sample_images_dir + dataset = InferenceImageDataset( + images_dir=images_dir_path, + file_pattern="*.png", + ) + + # Get one sample + img, _annot = dataset[0] + + # Check transforms are applied as expected + assert isinstance(img, torch.Tensor) + assert img.dtype == torch.float32 + assert img.max() <= 1.0 # scaled from [0,255] to [0,1] + + # Check type + assert isinstance(dataset.transforms, transforms.Compose) + assert len(dataset.transforms.transforms) == 2 diff --git a/tests/test_unit/test_detectors/__init__.py b/tests/test_unit/test_detectors/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_unit/test_detectors/test_models.py b/tests/test_unit/test_detectors/test_models.py new file mode 100644 index 00000000..775b9997 --- /dev/null +++ b/tests/test_unit/test_detectors/test_models.py @@ -0,0 +1,720 @@ +from collections.abc import Callable +from contextlib import nullcontext as does_not_raise +from pathlib import Path +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest +import torch +import xarray as xr +from torchvision.models import get_model +from torchvision.models.detection import faster_rcnn, fcos, retinanet + +from ethology.detectors.models import ( + DEFAULT_NUM_CLASSES, + ObjectDetector, + _get_n_classes_anchor_based, + _get_n_classes_fasterrcnn, + _get_n_classes_in_detector, +) + +# Map model names to class types +MODEL_CLASS_REGISTRY = { + "fasterrcnn_resnet50_fpn_v2": faster_rcnn.FasterRCNN, + "fasterrcnn_mobilenet_v3_large_fpn": faster_rcnn.FasterRCNN, + "fcos_resnet50_fpn": fcos.FCOS, + "retinanet_resnet50_fpn_v2": retinanet.RetinaNet, +} + + +@pytest.fixture +def sample_checkpoint_path_and_classes(tmp_path: Path) -> Callable: + def _checkpoint_path_and_classes( + model_class, ckpt_format: str, num_classes=91 + ) -> tuple[Path, int]: + """Return the path to a sample checkpoint. + + The checkpoint is for the requested architecture (model_class) and + format (Pytorch or Pytorch Lightning convention). By default, + """ + # Create a model with randomly initialised weights and save its state + # If not specified it has 91 categories by default + model = get_model(model_class, num_classes=num_classes) + ckpt_filename = "test_checkpoint" + + if ckpt_format == "lightning": + checkpoint_path = tmp_path / f"{ckpt_filename}.ckpt" + + # Save as Lightning-style checkpoint (with "model." prefix) + state_dict = { + f"model.{k}": v for k, v in model.state_dict().items() + } + torch.save({"state_dict": state_dict}, checkpoint_path) + + elif ckpt_format == "torch": + checkpoint_path = tmp_path / f"{ckpt_filename}.pt" + + # Save the state_dict as recommended in pytorch docs + # (see https://docs.pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-model-for-inference) + torch.save(model.state_dict(), checkpoint_path) + + else: + raise ValueError(f"Unsupported format: {ckpt_format}") + + return checkpoint_path, num_classes + + return _checkpoint_path_and_classes + + +@pytest.fixture +def valid_predictions_dataset(): + """Create a valid bbox detections dataset to simulate predictions.""" + image_ids = [ + 1, + 2, + ] + annotation_ids = [0, 1] + space_dims = ["x", "y"] + + # Create position, shape and confidence data all zeros + position_data = np.zeros( + (len(image_ids), len(space_dims), len(annotation_ids)) + ) + shape_data = np.copy(position_data) + category_data = np.ones((len(image_ids), len(annotation_ids))) + confidence_data = np.zeros((len(image_ids), len(annotation_ids))) + + # Create the dataset + ds = xr.Dataset( + data_vars={ + "position": (["image_id", "space", "id"], position_data), + "shape": (["image_id", "space", "id"], shape_data), + "category": (["image_id", "id"], category_data), + "confidence": (["image_id", "id"], confidence_data), + }, + coords={ + "image_id": image_ids, + "space": ["x", "y"], + "id": annotation_ids, + }, + ) + + return ds + + +# -------------- Gral model configuration --------------- + + +@pytest.mark.parametrize( + "config, expected_exception", + [ + ( + "", + pytest.raises(TypeError, match="config must be a dictionary"), + ), + ( + {"checkpoint": "path/to/checkpoint"}, + pytest.raises( + ValueError, + match=("model_class must be defined in config"), + ), + ), + ( + {"model_class": "foo"}, + pytest.raises(ValueError, match="Model 'foo' not supported"), + ), + ( + {"model_class": "fcos_resnet50_fpn", "model_kwargs": "foo"}, + pytest.raises(TypeError, match="model_kwargs must be a dict"), + ), + ( + { + "model_class": "fcos_resnet50_fpn", + "model_kwargs": {"n_classes": 3}, + }, + pytest.raises( + ValueError, + match=( + "Invalid key 'n_classes' in model_kwargs. " + "Did you mean 'num_classes'?" + ), + ), + ), + ( + { + "model_class": "fcos_resnet50_fpn", + "model_kwargs": {"trainable_backbone_layrs": 3}, + }, + pytest.raises( + ValueError, + match=( + "Invalid key 'trainable_backbone_layrs' in model_kwargs. " + "Did you mean 'trainable_backbone_layers'?" + ), + ), + ), + ( + { + "model_class": "fcos_resnet50_fpn", + "model_kwargs": {"num_classes": 3}, + }, + does_not_raise(), + ), + ( + { + "model_class": "fcos_resnet50_fpn", + "model_kwargs": {"trainable_backbone_layers": 5}, + }, + does_not_raise(), + ), + ( + { + "model_class": "fcos_resnet50_fpn", + "model_kwargs": { + "num_classes": 3, + "trainable_backbone_layers": 5, + }, + }, + does_not_raise(), + ), + ], + ids=[ + "invalid config wrong type", + "invalid model_class missing", + "invalid model_class unsupported", + "invalid model_kwargs wrong type", + "invalid num_classes misspelt", + "invalid trainable_backbone_layers misspelt", + "valid model_kwargs single 1", + "valid model_kwargs single 2", + "valid model_kwargs multiple", + ], +) +def test_validate_config(config, expected_exception): + """Test the config validation throws the expected errors.""" + with expected_exception: + ObjectDetector._validate_config(config) + + +@pytest.mark.parametrize( + "input_config, expected_config_function", + [ + ( + {"model_class": "fcos_resnet50_fpn"}, + "ethology.detectors.models.ObjectDetector._configure_model_pretrained", + ), + ( + { + "model_class": "fcos_resnet50_fpn", + "checkpoint": "/path/to/checkpoint", + }, + "ethology.detectors.models.ObjectDetector._configure_model_from_checkpoint", + ), + ], + ids=[ + "config without checkpoint", + "config with checkpoint", + ], +) +def test_configure_model(input_config, expected_config_function): + """Test the constructor delegates correctly to the weight loading fn.""" + with patch(expected_config_function) as mock_config_function: + model = ObjectDetector(input_config) + + # check expected function was called + mock_config_function.assert_called_once() + + # check model params are saved as attributes + assert model._model_class # check it is truthy + assert model.model_params # should have at least num_classes + assert "num_classes" in model.model_params + + +@pytest.mark.parametrize( + "model_class", + [ + "fasterrcnn_resnet50_fpn_v2", + "fasterrcnn_mobilenet_v3_large_fpn", + "retinanet_resnet50_fpn_v2", + "fcos_resnet50_fpn", + ], +) +@pytest.mark.parametrize( + "model_kwargs, expected_num_classes", + [ + ({"num_classes": 1}, 1), + ({"num_classes": 100}, 100), + ({}, DEFAULT_NUM_CLASSES), + ], +) +def test_configure_model_pretrained_n_classes( + model_class, model_kwargs, expected_num_classes +): + """Test that the requested number of classes is passed to the model.""" + # Instantiate detector + config = { + "model_class": model_class, + "model_kwargs": model_kwargs, + } + detector = ObjectDetector(config) + + # Check n of classes in output layer + assert ( + _get_n_classes_in_detector(detector.model, model_class) + == expected_num_classes + ) + + # Check model architecture and type + assert isinstance(detector.model, MODEL_CLASS_REGISTRY[model_class]) + assert isinstance(detector.model, torch.nn.Module) + + +# -------------- Configure model from ckpt ----------- + + +@pytest.mark.parametrize( + "ckpt_format", + [ + "lightning", + "torch", + ], +) +@pytest.mark.parametrize( + "model_class", + [ + "fasterrcnn_resnet50_fpn_v2", + "fasterrcnn_mobilenet_v3_large_fpn", + "fcos_resnet50_fpn", + "retinanet_resnet50_fpn_v2", + ], +) +@pytest.mark.parametrize( + "num_classes", + [ + None, # use default (91) + 5, + ], +) +def test_configure_model_from_checkpoint( + sample_checkpoint_path_and_classes, + model_class, + ckpt_format, + num_classes, +): + """Test loading weights from a checkpoint. + + We test the cases of different checkpoint saving formats, + different model classes, and whether the number of classes in + the model architecture is different from the default (91 classes + for COCO2017). + """ + # Get ckpt for specified architecture, num_classes and ckpt format + # If num_classes is None: model will the default number of classes + ckpt_kwargs = {} if num_classes is None else {"num_classes": num_classes} + ckpt_path, out_num_classes = sample_checkpoint_path_and_classes( + model_class, ckpt_format, **ckpt_kwargs + ) + + # Define config + input_config = { + "model_class": model_class, + "checkpoint": str(ckpt_path), + } + if num_classes is not None: + input_config["model_kwargs"] = {"num_classes": out_num_classes} + + # Instantiate detector + detector = ObjectDetector(input_config) + + # Check type and type + assert isinstance(detector.model, torch.nn.Module) + assert ( + _get_n_classes_in_detector(detector.model, model_class) + == out_num_classes + ) + + +@pytest.mark.parametrize( + "format", + [ + "lightning", + "torch", + ], +) +@pytest.mark.parametrize( + "input_config, expected_exception", + [ + ( + { + "model_class": "fasterrcnn_resnet50_fpn_v2", + "model_kwargs": {"num_classes": 2}, + }, + pytest.raises( + RuntimeError, + match=( + r"Error\(s\) in loading state_dict for FasterRCNN:\n\t" + r"size mismatch.*" + ), + ), + ), # ckpt has 91 classes but config specifies 2 + ( + { + "model_class": "fcos_resnet50_fpn", + }, + pytest.raises( + RuntimeError, + match=( + r"Error\(s\) in loading state_dict for FCOS:\n\t" + r"Missing key\(s\) in state_dict.*" + ), + ), + ), # ckpt is fasterrcnn_resnet50_fpn_v2 but config fcos_resnet50_fpn + ], + ids=["mismatch_n_classes", "mismatch_architecture"], +) +def test_configure_model_from_checkpoint_invalid( + sample_checkpoint_path_and_classes, + format, + input_config, + expected_exception, +): + """Test loading weights from a FasterRCNN checkpoint with 91 classes.""" + # Get fasterrcnn COCO2017 ckpt and add to config + fasterrcnn_ckpt_path, _ = sample_checkpoint_path_and_classes( + "fasterrcnn_resnet50_fpn_v2", format + ) + input_config["checkpoint"] = str(fasterrcnn_ckpt_path) + + # Check detector instantiation throws error + with expected_exception: + _detector = ObjectDetector(input_config) + + +# -------- Convenience functions --------------------------- + + +@pytest.mark.parametrize( + "model_class, expected_get_n_classes_fn", + [ + ("fasterrcnn_resnet50_fpn_v2", "_get_n_classes_fasterrcnn"), + ("fasterrcnn_mobilenet_v3_large_fpn", "_get_n_classes_fasterrcnn"), + ("fcos_resnet50_fpn", "_get_n_classes_anchor_based"), + ("retinanet_resnet50_fpn_v2", "_get_n_classes_anchor_based"), + ], +) +def test_get_n_classes_in_detector(model_class, expected_get_n_classes_fn): + """Test _get_n_classes_in_detector delegates to the right function.""" + model = torch.nn.Module() + function_to_patch = ( + f"ethology.detectors.models.{expected_get_n_classes_fn}" + ) + with patch(function_to_patch) as mock_get_n_classes_fn: + _ = _get_n_classes_in_detector(model, model_class) + + # check expected function was called + mock_get_n_classes_fn.assert_called_once() + + +def test_get_n_classes_in_detector_invalid(): + model = torch.nn.Module() + with pytest.raises(ValueError, match="Unsupported model class: foo"): + _ = _get_n_classes_in_detector(model, "foo") + + +@pytest.mark.parametrize( + "model_class, counting_function", + [ + ("fasterrcnn_resnet50_fpn_v2", _get_n_classes_fasterrcnn), + ("fasterrcnn_mobilenet_v3_large_fpn", _get_n_classes_fasterrcnn), + ("fcos_resnet50_fpn", _get_n_classes_anchor_based), + ("retinanet_resnet50_fpn_v2", _get_n_classes_anchor_based), + ], +) +@pytest.mark.parametrize( + "num_classes", + [ + 2, # minimum meaningful value (background + 1 object class) + 100, + ], +) +def test_get_n_classes_specific_function( + model_class, counting_function, num_classes +): + """Test the _get_n_classes... architecture-specific functions.""" + # Instantiate model using torch convenience fn + model = get_model(model_class, num_classes=num_classes) + assert counting_function(model) == num_classes + + +@pytest.mark.parametrize( + "checkpoint_dict", + [ + {"layer1.weight": 1, "layer1.bias": 2}, + # state dict at first level (torch no nesting) + {"state_dict": {"layer1.weight": 1, "layer1.bias": 2}}, + # state dict under "state_dict" key (torch nested) + {"state_dict": {"model.layer1.weight": 1, "model.layer1.bias": 2}}, + # state dict "state_dict" key using "model." prefix + # (nested with model. prefix, Lightning convention) + ], +) +def test_get_model_state_dict(checkpoint_dict): + """Test state_dict extraction from different checkpoint dict formats.""" + expected_state_dict = {"layer1.weight": 1, "layer1.bias": 2} + state_dict_out = ObjectDetector._get_model_state_dict(checkpoint_dict) + assert state_dict_out == expected_state_dict + + +# ---------- Inference --------------------- + + +def test_predict_step(): + """Check predict_step returns expected structure.""" + # Create a minimal detector + config = {"model_class": "fasterrcnn_resnet50_fpn_v2"} + detector = ObjectDetector(config) + detector.eval() + + # Create a fake batch of 2 RGB images and empty annotations + batch_size = 2 + images = torch.rand(batch_size, 3, 224, 224) + annotations = {} # unused by predict_step + batch = (images, annotations) + + # Run predict_step + predictions = detector.predict_step(batch, batch_idx=0) + + # Check output is a list of dicts + assert isinstance(predictions, list) + assert len(predictions) == batch_size + assert all(isinstance(pred, dict) for pred in predictions) + + # Check relevant keys exist + assert all("boxes" in pred for pred in predictions) + assert all("scores" in pred for pred in predictions) + assert all("labels" in pred for pred in predictions) + + # Check arrays are torch tensors + assert all(isinstance(pred["boxes"], torch.Tensor) for pred in predictions) + assert all( + isinstance(pred["scores"], torch.Tensor) for pred in predictions + ) + assert all( + isinstance(pred["labels"], torch.Tensor) for pred in predictions + ) + + +def test_run_inference(valid_predictions_dataset): + """Test that both .predict and ._format_predictions are called.""" + # Create a minimal detector + config = {"model_class": "fasterrcnn_resnet50_fpn_v2"} + detector = ObjectDetector(config) + + # Mock inputs to run_inference + mock_trainer = MagicMock() + mock_dataloader = MagicMock() + ds_attrs = {"source": "test"} + + # Mock the output of trainer.predict + mock_predictions = [ + [ # batch 0 + { # image 0 + "boxes": torch.tensor( + [ + [10.0, 20.0, 50.0, 60.0], + [15.0, 25.0, 55.0, 65.0], + ] + ), + "scores": torch.tensor([0.95, 0.87]), + "labels": torch.tensor([1, 2]), + } + ] + ] + mock_trainer.predict.return_value = mock_predictions + + # Mock the output of detector._format_predictions + with patch.object( + ObjectDetector, + "_format_predictions", + return_value=valid_predictions_dataset, + ) as mock_format_fn: + result = detector.run_inference( + mock_trainer, + mock_dataloader, + attrs=ds_attrs, + ) + + # Verify trainer.predict was called once + mock_trainer.predict.assert_called_once_with(detector, mock_dataloader) + + # Verify _format_predictions was called once + mock_format_fn.assert_called_once_with(mock_predictions, attrs=ds_attrs) + + # Check the result from run_inference is the same as _format_predictions + assert result == valid_predictions_dataset + + +@pytest.mark.parametrize( + ( + "list_raw_predictions, " + "expected_n_images, expected_max_detections, expected_exception" + ), + [ + # ---------------- Two valid batches -------------------- + ( + [ + # Batch 0: 3 images + [ + { + "boxes": torch.tensor([[10.0, 20.0, 30.0, 40.0]]), + "scores": torch.tensor([0.9]), + "labels": torch.tensor([1]), + }, # one detection in image 1. + # Note: float literals are torch.float32 by default + { + "boxes": torch.tensor( + [ + [14.0, 24.0, 34.0, 44.0], + [15.0, 25.0, 35.0, 45.0], + ], + ), + "scores": torch.tensor([0.95, 0.85]), + "labels": torch.tensor([1, 0]), + }, # two detections in image 2 + { + "boxes": torch.tensor( + [ + [18.0, 28.0, 38.0, 48.0], + [19.0, 29.0, 39.0, 49.0], + [20.0, 30.0, 40.0, 50.0], + [21.0, 31.0, 41.0, 51.0], + ], + ), + "scores": torch.tensor([0.91, 0.81, 0.71, 0.61]), + "labels": torch.tensor([0, 0, 1, 1]), + }, # four detections in image 3 + ], + # Batch 1: 2 images + [ + { + "boxes": torch.tensor( + [ + [22.0, 32.0, 42.0, 52.0], + [23.0, 33.0, 43.0, 53.0], + ], + ), + "scores": torch.tensor([0.92, 0.82]), + "labels": torch.tensor([1, 1]), + }, # 2 detections in image 1 + { + "boxes": torch.tensor( + [ + [26.0, 36.0, 46.0, 56.0], + [27.0, 37.0, 47.0, 57.0], + ], + ), + "scores": torch.tensor([0.93, 0.83]), + "labels": torch.tensor([0, 1]), + }, # 2 detections in image 2 + ], + ], + 5, # expected_n_images + 4, # expected_n_max_detections + does_not_raise(), + ), + # ---------- Single valid batch with one image and one detection ----- + ( + [ + [ + { + "boxes": torch.tensor( + [[10.0, 20.0, 30.0, 40.0]], dtype=float + ), + "scores": torch.tensor([0.95]), + "labels": torch.tensor([2]), + } + ], + ], + 1, # expected_n_images + 1, # expected_max_detections + does_not_raise(), + # ------ One valid batch with one image and one empty batch ----- + ), + ( + [ + [ + { + "boxes": torch.tensor( + [[10.0, 20.0, 30.0, 40.0]], dtype=float + ), + "scores": torch.tensor([0.95]), + "labels": torch.tensor([2]), + } + ], + [], + ], + 1, # expected_n_images + 1, # expected_max_detections + does_not_raise(), + ), + # ---------------- Invalid input lists ---------------- + ( + {}, # wrong type + None, + None, + pytest.raises(TypeError, match="predictions must be a list"), + ), + ( + [], # empty outer lists (no batches) + None, + None, + pytest.raises(ValueError, match="predictions list is empty"), + ), + ( + [[], []], # two empty batches, 0 images in each batch + None, + None, + pytest.raises( + ValueError, match="predictions list contains no image data" + ), + ), + ], + ids=[ + "two_batches_padding", + "single_batch_no_padding", + "wrong_type", + "no_batches", + "two_batches_both_no_images", + "two_batches_one_no_images", + ], +) +def test_format_predictions( + list_raw_predictions, + expected_n_images, + expected_max_detections, + expected_exception, +): + """Test that predictions are formatted as a detections dataset.""" + with expected_exception as excinfo: + ds = ObjectDetector._format_predictions(list_raw_predictions) + + # If valid, check array shapes match input data + if not excinfo: + assert ds.position.shape == ( + expected_n_images, + 2, # space dimension + expected_max_detections, + ) + assert ds.shape.shape == ( + expected_n_images, + 2, # space dimension + expected_max_detections, + ) + assert ds.confidence.shape == ( + expected_n_images, + expected_max_detections, + ) + assert ds.category.shape == ( + expected_n_images, + expected_max_detections, + ) diff --git a/tests/test_unit/test_detectors/test_utils.py b/tests/test_unit/test_detectors/test_utils.py new file mode 100644 index 00000000..dd3c8f5a --- /dev/null +++ b/tests/test_unit/test_detectors/test_utils.py @@ -0,0 +1,164 @@ +from contextlib import nullcontext as does_not_raise + +import numpy as np +import pytest + +from ethology.detectors.utils import ( + _get_padding_width, + _pad_to_max_first_dimension, + centroid_shape_to_corners, + corners_to_centroid_shape, +) + + +@pytest.mark.parametrize( + "array, final_first_dim, expected_exception", + [ + (np.zeros(3), 5, does_not_raise()), + ( + np.zeros(100), + 5, + pytest.raises( + ValueError, + match="more rows than the requested padded size", + ), + ), + (np.zeros((2, 2)), 5, does_not_raise()), + (np.zeros((2, 2, 3)), 5, does_not_raise()), + ], +) +def test_get_padding_width(array, final_first_dim, expected_exception): + """Test the computation of the width to pad along the first dimension.""" + with expected_exception as excinfo: + pad_width = _get_padding_width(array, final_first_dim) + + if not excinfo: + expected_pad_width = array.ndim * [(0, 0)] + expected_pad_width[0] = (0, final_first_dim - array.shape[0]) + assert pad_width == expected_pad_width + + +@pytest.mark.parametrize( + "fill_value", + [np.nan, 42, -1], +) +@pytest.mark.parametrize( + "list_arrays", + [ + [np.zeros((1, 1)), np.zeros((3, 1)), np.zeros((42, 1))], + ], +) +def test_pad_to_max_first_dimension(list_arrays, fill_value): + """Test the padding of a list of arrays to the max first dimension size.""" + # Pad input arrays with fill value + list_arrays_padded = _pad_to_max_first_dimension(list_arrays, fill_value) + + # Check shapes + max_first_dimension = max([x.shape[0] for x in list_arrays]) + assert all([x.shape[0] == max_first_dimension for x in list_arrays_padded]) + + # Check fill value + assert all( + np.allclose(padded[orig.shape[0] :], fill_value, equal_nan=True) + for orig, padded in zip(list_arrays, list_arrays_padded, strict=True) + if padded[orig.shape[0] :].size > 0 + ) + + +@pytest.mark.parametrize( + "fill_value", + [np.nan, 0.5], +) +def test_pad_to_max_first_dimension_dtype_mismatch(fill_value): + """Test that TypeError is raised for incompatible fill_value dtype. + + We test a list of integer input arrays. + """ + with pytest.raises( + TypeError, + match="Ensure fill_value is compatible with array dtype", + ): + list_int_arrays = [ + np.zeros((2, 2), dtype=int), + np.zeros((3, 2), dtype=int), + ] + _pad_to_max_first_dimension(list_int_arrays, fill_value) + + +@pytest.mark.parametrize( + "position, shape, expected_exception", + [ + ( + np.zeros((2, 2)), + np.ones((2, 2)), + does_not_raise(), + ), + ( + np.zeros((2, 2)), + np.ones((1, 2)), + pytest.raises( + ValueError, match="position and shape must have the same shape" + ), + ), + ( + np.zeros((2, 3)), + np.ones((2, 3)), + pytest.raises( + ValueError, + match=( + "Dimension at index 1 must be 2 " + "for both position and shape" + ), + ), + ), + ], +) +def test_centroid_shape_to_corners(position, shape, expected_exception): + """Test conversion of centroid and shape to x1y1, x2y2 corner arrays.""" + with expected_exception as excinfo: + x1y1, x2y2 = centroid_shape_to_corners(position, shape) + + if not excinfo: + # Check values + assert np.allclose(x1y1, np.minimum(x1y1, x2y2)) + assert np.allclose(x2y2, np.maximum(x1y1, x2y2)) + assert np.allclose(x1y1, position - shape / 2) + assert np.allclose(x2y2, position + shape / 2) + + +@pytest.mark.parametrize( + "x1y1, x2y2, expected_exception", + [ + ( + np.zeros((2, 2)), + np.ones((2, 2)), + does_not_raise(), + ), + ( + np.zeros((2, 2)), + np.ones((1, 2)), + pytest.raises( + ValueError, match="x1y1 and x2y2 must have the same shape" + ), + ), + ( + np.zeros((2, 3)), + np.ones((2, 3)), + pytest.raises( + ValueError, + match=( + "Dimension at index 1 must be 2 for both x1y1 and x2y2" + ), + ), + ), + ], +) +def test_corners_to_centroid_shape(x1y1, x2y2, expected_exception): + """Test conversion of x1y1, x2y2 arrays to centroid and shape arrays.""" + with expected_exception as excinfo: + centroid, shape = corners_to_centroid_shape(x1y1, x2y2) + + if not excinfo: + # Check values + assert np.allclose(centroid, 0.5 * (x1y1 + x2y2)) + assert np.allclose(shape, x2y2 - x1y1) diff --git a/tests/test_unit/test_validators/__init__.py b/tests/test_unit/test_validators/__init__.py new file mode 100644 index 00000000..e69de29b