diff --git a/docs/requirements.txt b/docs/requirements.txt index cb55a343..b270cde9 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -9,4 +9,5 @@ sphinx-autodoc-typehints sphinx-design sphinx-gallery sphinx-notfound-page +sphinx-paramlinks sphinx-sitemap diff --git a/docs/source/conf.py b/docs/source/conf.py index 6db7d86d..d906807e 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -37,6 +37,7 @@ "sphinx.ext.autosummary", "sphinx.ext.viewcode", "sphinx.ext.intersphinx", + "sphinx.ext.doctest", # for lightning docstrings "myst_parser", "nbsphinx", "notfound.extension", @@ -44,6 +45,7 @@ "sphinx_gallery.gen_gallery", "sphinx_sitemap", "sphinx.ext.autosectionlabel", + "sphinx_paramlinks", ] # Configure the myst parser to enable cool markdown features @@ -186,6 +188,8 @@ "https://python-jsonschema.readthedocs.io/en/stable/", None, ), + "torch": ("https://pytorch.org/docs/stable/", None), + "pytorch_lightning": ("https://lightning.ai/docs/pytorch/stable/", None), } diff --git a/ethology/detectors/__init__.py b/ethology/detectors/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ethology/detectors/ensembles/__init__.py b/ethology/detectors/ensembles/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ethology/detectors/ensembles/fusion.py b/ethology/detectors/ensembles/fusion.py new file mode 100644 index 00000000..d0dbfe71 --- /dev/null +++ b/ethology/detectors/ensembles/fusion.py @@ -0,0 +1,393 @@ +"""Wrappers around ensemble-boxes fusion functions.""" + +from collections.abc import Callable +from functools import partial +from typing import Literal, TypeAlias, TypedDict, Unpack + +import ensemble_boxes +import numpy as np +import pandas as pd +import xarray as xr +from joblib import Parallel, delayed +from tqdm import tqdm + +from ethology.detectors.ensembles.utils import ( + _centroid_shape_to_corners, + _corners_to_centroid_shape, +) +from ethology.validators.detections import ( + ValidBboxDetectionsDataset, + ValidBboxDetectionsEnsembleDataset, +) +from ethology.validators.utils import _check_input, _check_output + +# ------------------- Supported fusion methods ------------------ +# from ensemble_boxes +VALID_FUSION_METHODS = { + "weighted_boxes_fusion": ensemble_boxes.weighted_boxes_fusion, + "nms": ensemble_boxes.nms, + "soft_nms": ensemble_boxes.soft_nms, + "non_maxium_weighted": ensemble_boxes.non_maximum_weighted, +} + + +# ------------------ Custom types ---------------------- +TypeFusionMethod = Literal[ + "weighted_boxes_fusion", + "nms", + "soft_nms", + "non_maxium_weighted", +] + +TupleFourDataArrays: TypeAlias = tuple[ + xr.DataArray, + xr.DataArray, + xr.DataArray, + xr.DataArray, +] + + +class _TypeFusionMethodKwargs(TypedDict, total=False): + """Type hints for fusion method keyword arguments. + + Parameters for methods as described in the ensemble_boxes documentation. + See https://github.com/ZFTurbo/Weighted-Boxes-Fusion + """ + + weights: list[float] | None + iou_thr: float + skip_box_thr: float + sigma: float + thresh: float + conf_type: Literal["avg", "box_and_model_avg", "absent_model_aware_avg"] + allows_overflow: bool + + +# ---------------------------------- + + +@_check_input(ValidBboxDetectionsEnsembleDataset) +@_check_output(ValidBboxDetectionsDataset) +def fuse_detections( + ensemble_detections_ds: xr.Dataset, + fusion_method: TypeFusionMethod, + fusion_method_kwargs: dict | None = None, + max_n_detections: int | None = None, + n_workers: int | None = -1, +) -> xr.Dataset: + """Fuse ensemble detections across models using the selected method. + + You can set a max_n_detections if upper bound is known a prior to + reduce memory usage. n_workers: number of workers for joblib.Parallel + + """ + # Check if image_width_height defined in dataset + image_shape = ensemble_detections_ds.attrs.get("image_shape") + if image_shape is None: + raise KeyError( + "Required attribute 'image_shape' not found in the dataset " + "attributes. Please ensure the dataset has 'image_shape' " + "(width, height in pixels) in its attributes." + ) + image_width_height = _validate_image_shape(image_shape) + + # Compute upper bound of max_n_detections + if not max_n_detections: + max_n_detections = _estimate_max_n_detections(ensemble_detections_ds) + + # Build single-image partial function for the selected fusion method + if fusion_method not in VALID_FUSION_METHODS: + raise ValueError( + f"Invalid fusion method: {fusion_method}. " + f"Valid methods are: {list(VALID_FUSION_METHODS.keys())}" + ) + fusion_function = VALID_FUSION_METHODS[fusion_method] + _fuse_single_image_detections_partial = partial( + _fuse_single_image_detections, fusion_function + ) + + # Parallelise fusion across image_id + results_per_img_id = Parallel(n_jobs=n_workers)( + delayed(_fuse_single_image_detections_partial)( + ensemble_detections_ds.position.sel(image_id=img_id).values, + ensemble_detections_ds.shape.sel(image_id=img_id).values, + ensemble_detections_ds.confidence.sel(image_id=img_id).values, + ensemble_detections_ds.label.sel(image_id=img_id).values, + image_width_height, + max_n_detections, + **fusion_method_kwargs, + ) + for img_id in tqdm(ensemble_detections_ds.image_id) + ) + + # Postprocess data arrays + fused_detections_ds = _postprocess_multi_image_fused_arrays( + results_per_img_id, ensemble_detections_ds.image_id + ) + + return fused_detections_ds + + +# ------- Multi image fusion ------------------ + + +@_check_output(ValidBboxDetectionsDataset) +def _postprocess_multi_image_fused_arrays( + results_per_img_id: list[TupleFourDataArrays], + list_img_id: list, +) -> xr.Dataset: + """Postprocess fused data arrays on multiple images after fusion. + + Fix padding and assign id coordinates. + """ + # Transpose results from list-of-tuples to tuple-of-lists + da_names = ("position", "shape", "confidence", "label") + da_lists = zip(*results_per_img_id, strict=True) + + # Concatenate lists of dataarrays along image_id dimension and + # remove extra padding in "id" dimension + fused_da_dict = {} + for da_str, list_da in zip(da_names, da_lists, strict=True): + fused_da_dict[da_str] = xr.concat( + list_da, pd.Index(list_img_id, name="image_id") + ).dropna(dim="id", how="all") + + # Pad labels with -1 rather than nan + fused_da_dict["label"] = fused_da_dict["label"].fillna(-1).astype(int) + + return xr.Dataset(data_vars=fused_da_dict) + + +def _validate_image_shape(image_shape) -> np.ndarray: + """Validate and cast image shape as numpy array.""" + # Try casting as numpy array + try: + image_shape = np.asarray(image_shape) + except (TypeError, ValueError) as e: + raise ValueError( + f"Cannot convert 'image_shape' to array: {e}. " + "Expected format: (width, height) as tuple or array-like." + ) from e + + # Check number of elements in array + if image_shape.size != 2: + raise ValueError( + f"'image_shape' must have exactly 2 elements (width, height), " + f"got shape {image_shape.shape}" + ) + return image_shape + + +@_check_input(ValidBboxDetectionsEnsembleDataset) +def _estimate_max_n_detections(ensemble_detections_ds: xr.Dataset) -> int: + """Get upper bound for maximum number of boxes per image after fusion. + + We assume no detections are fused and all images have as many + detections as the maximum number of non-nan detections per image. + """ + detections_w_non_nan_position = ( + ensemble_detections_ds.position.notnull().all(dim="space") + ) # True if non-nan x and y + return ( + detections_w_non_nan_position.sum(dim="id") + .max(dim="image_id") + .sum() + .item() + ) + + +# ------- Single image fusion ------------------ + + +def _fuse_single_image_detections( + fusion_function: Callable, + position: np.ndarray, + shape: np.ndarray, + confidence: np.ndarray, + label: np.ndarray, + image_width_height: np.ndarray, + max_n_detections: int, + **fusion_kwargs: Unpack[_TypeFusionMethodKwargs], # method-only kwargs +) -> TupleFourDataArrays: + """Fuse detections for a single image with selected method.""" + # Prepare single image arrays for fusion + list_bboxes_per_model, list_confidence_per_model, list_label_per_model = ( + _preprocess_single_image_detections( + position, shape, confidence, label, image_width_height + ) + ) + + # Run fusion method on one image + ensemble_x1y1_x2y2_norm, ensemble_scores, ensemble_labels = ( + fusion_function( + list_bboxes_per_model, + list_confidence_per_model, + list_label_per_model, + **fusion_kwargs, + ) + ) + + # Format output as xarray dataarrays + centroid_da, shape_da, confidence_da, label_da = ( + _postprocess_single_image_detections( + ensemble_x1y1_x2y2_norm, + ensemble_scores, + ensemble_labels, + image_width_height, + max_n_detections, + ) + ) + + return centroid_da, shape_da, confidence_da, label_da + + +def _preprocess_single_image_detections( + position: xr.DataArray, + shape: xr.DataArray, + confidence: xr.DataArray, + label: xr.DataArray, + image_width_height: np.ndarray, +) -> tuple[list[np.ndarray], list[np.ndarray], list[np.ndarray]]: + """Prepare detections of an ensemble on a single image for fusion.""" + # Prepare boxes array + # transform position and shape arrays to x1y1x2y normalised + x1y1, x2y2 = _centroid_shape_to_corners(position, shape) + bboxes_x1y1 = x1y1 / image_width_height[:, None, None] + bboxes_x2y2 = x2y2 / image_width_height[:, None, None] + bboxes_x1y1_x2y2_normalised = np.transpose( + np.concat( + [bboxes_x1y1, bboxes_x2y2] + ), # shape: 4, max_n_annotations_per_frame, n_models + (1, 0, 2), # shape: max_n_annotations_per_frame, 4, n_models + ) + + # -------------------- + # Get list of bboxes per model + # arrays need to be tall for fusion methods + n_models = bboxes_x1y1_x2y2_normalised.shape[-1] + list_x1y1_x2y2_norm_per_model = [ + arr.squeeze() + for arr in np.split(bboxes_x1y1_x2y2_normalised, n_models, axis=-1) + ] + list_confidence_per_model = [ + arr.squeeze() for arr in np.split(confidence, n_models, axis=-1) + ] + list_label_per_model = [ + arr.squeeze() for arr in np.split(label, n_models, axis=-1) + ] + # -------------------- + + # Remove rows with nan coordinates and return lists of arrays + list_non_nan_bboxes_per_model = [ + sum(~np.any(np.isnan(arr), axis=1)) + for arr in list_x1y1_x2y2_norm_per_model + ] + return ( + _chop_end_of_array( + list_x1y1_x2y2_norm_per_model, list_non_nan_bboxes_per_model + ), + _chop_end_of_array( + list_confidence_per_model, list_non_nan_bboxes_per_model + ), + _chop_end_of_array( + list_label_per_model, list_non_nan_bboxes_per_model + ), + ) + + +def _chop_end_of_array( + list_arrays: list[np.ndarray], list_end_lengths: list[int] +) -> list[np.ndarray]: + """Chop end of arrays in list to desired length along first dimension.""" + return [ + arr[:n] for arr, n in zip(list_arrays, list_end_lengths, strict=True) + ] + + +def _postprocess_single_image_detections( + ensemble_x1y1_x2y2_norm, + ensemble_scores, + ensemble_labels, + image_width_height, + max_n_detections, +): + """Postprocess fused single-image detections as dataarrays. + + Unnormalise, pad and format as data arrays. + """ + # Undo boxes x1y1 x2y2 normalization + ensemble_x1y1_x2y2 = ensemble_x1y1_x2y2_norm * np.tile( + image_width_height, (1, 2) + ) + + # Get 1d array for non-nan boxes + bool_non_nan_array = ~np.any(np.isnan(ensemble_x1y1_x2y2), axis=1) + n_non_nan_boxes = bool_non_nan_array.sum() + if n_non_nan_boxes > max_n_detections: + raise ValueError( + "Insufficient padding provided. " + "The estimated maximum number of detections per image was set to " + f"{max_n_detections}, " + f"but {n_non_nan_boxes} detections were " + "found in one of the images after fusion. Please increase the " + "maximum number of detections per image." + ) + + # Retain non-nan boxes only and pad each array + return _parse_single_image_detections_as_dataarrays( + *( + _remove_nan_and_pad_to_max( + arr, bool_non_nan_array, max_n_detections + ) + for arr in (ensemble_x1y1_x2y2, ensemble_scores, ensemble_labels) + ), + ) + + +def _remove_nan_and_pad_to_max( + input_array, mask_non_nan_rows, max_n_detections, fill_value=np.nan +): + """Remove non-nan from input array and pad, all along first dimension.""" + # Initialise array with nans + padded_array = np.full( + (max_n_detections, *input_array.shape[1:]), + fill_value, + dtype=input_array.dtype, + ) + # Replace top "mask_non_nan_rows.sum()" chunk with non-nan values from + # input array + padded_array[: mask_non_nan_rows.sum()] = input_array[mask_non_nan_rows] + return padded_array + + +def _parse_single_image_detections_as_dataarrays( + x1y1_x2y2_array: np.ndarray, + scores_array: np.ndarray, + labels_array: np.ndarray, + id_array: np.ndarray | None = None, +) -> TupleFourDataArrays: + """Format array of single image fused results as data arrays.""" + if id_array is None: + n_detections = x1y1_x2y2_array.shape[0] + id_array = np.arange(n_detections) + + # Extract bbox centre and shape + centroid, shape = _corners_to_centroid_shape( + x1y1_x2y2_array[:, 0:2], x1y1_x2y2_array[:, 2:4] + ) + + # Shared coordinates + id_coords = {"id": id_array} + spatial_id_coords = {"space": ["x", "y"], **id_coords} + + # Build all DataArrays + return ( + xr.DataArray( + centroid.T, + dims=["space", "id"], + coords=spatial_id_coords, + ), + xr.DataArray(shape.T, dims=["space", "id"], coords=spatial_id_coords), + xr.DataArray(scores_array, dims=["id"], coords=id_coords), + xr.DataArray(labels_array, dims=["id"], coords=id_coords), + ) diff --git a/ethology/detectors/ensembles/models.py b/ethology/detectors/ensembles/models.py new file mode 100644 index 00000000..ab1eada4 --- /dev/null +++ b/ethology/detectors/ensembles/models.py @@ -0,0 +1,226 @@ +"""Lightning Modules for ensembles of detectors.""" + +from itertools import chain +from pathlib import Path + +import numpy as np +import torch +import torch.nn as nn +import xarray as xr +import yaml +from lightning import LightningModule +from torchvision.models import detection, get_model, list_models + +from ethology.detectors.ensembles.utils import ( + _corners_to_centroid_shape, + _pad_to_max_first_dimension, +) +from ethology.validators.detections import ValidBboxDetectionsEnsembleDataset +from ethology.validators.utils import _check_output + + +class EnsembleDetector(LightningModule): + """Ensemble of (trained) detectors for inference. + + Attributes + ---------- + config_file: str + Path to the YAML config file. + + """ + + def __init__(self, config_file: str | Path): + """Initialise ensemble of detectors.""" + super().__init__() + + # Load config + self.config_file = Path(config_file) + with open(self.config_file) as f: + self.config = yaml.safe_load(f) + + # Run checks + self._validate_model_class(self.config["models"]["model_class"]) + + # Load list of models (nn.ModuleList) + self.list_models = self._load_models() + + @staticmethod + def _validate_model_class(model_class_str: str) -> None: + """Validate that the model is part of torchvision.models.detection.""" + valid_models = set(list_models(module=detection)) + if model_class_str not in valid_models: + valid_sorted = ", ".join(sorted(valid_models)) + raise ValueError( + f"'{model_class_str}' is not a supported detection model. " + f"Valid options: {valid_sorted}" + ) + + def _load_models(self) -> nn.ModuleList: + """Load models from checkpoints.""" + # Get model config + models_config = self.config["models"] + + # Load weights + list_models = [] + for checkpoint_path in models_config["checkpoints"]: + # Get checkpoint + checkpoint = torch.load(checkpoint_path, map_location=self.device) + + # Instantiate model with ckpt weights + model = get_model( + models_config["model_class"], + **models_config.get("model_kwargs", {}), + ) + model_state_dict = self._get_model_state_dict(checkpoint) + model.load_state_dict(model_state_dict, strict=True) + + # Append model to list + list_models.append(model) + + return nn.ModuleList(list_models) + + @staticmethod + def _get_model_state_dict(checkpoint): + # Handle different checkpoint formats + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + elif isinstance(checkpoint, dict): + # Checkpoint might be the state dict itself + state_dict = checkpoint + else: + raise ValueError( + "Checkpoint format not recognized. " + "Expected 'state_dict' key or dict of tensors." + ) + + # Load state dict into model + # PyTorch Lightning saves the model with a "model." + # prefix in the state_dict keys if you defined self.model + # in your LightningModule - we remove the prefix here. + if any(key.startswith("model.") for key in state_dict): + model_state_dict = { + key.replace("model.", "", 1): value + for key, value in state_dict.items() + if key.startswith("model.") + } + else: + model_state_dict = state_dict + + return model_state_dict + + def predict_step(self, batch, batch_idx): + """Predict step for a single batch.""" + # Run all models in ensemble in GPU + images_batch, _annotations_batch = batch + # # ----------------------------------- + raw_prediction_dicts_per_model = [ + model(images_batch) for model in self.list_models + ] # [num_models][batch_size] + + # Run all models in parallel on this GPU + # inputs = [(images_batch,)] * len(self.list_models[:3]) + # raw_prediction_dicts_per_model = parallel_apply( + # modules=self.list_models, #----- + # inputs=[(images_batch,)] * len(self.list_models), + # ) + # # ----------------------------------- + + # Transpose to [batch_size][num_models] for easier downstream + # processing + raw_prediction_dicts_per_sample = [ + list(one_sample_all_models) + for one_sample_all_models in zip( + *raw_prediction_dicts_per_model, strict=True + ) + ] # [batch_size][num_models] + + return raw_prediction_dicts_per_sample + + @staticmethod + @_check_output(ValidBboxDetectionsEnsembleDataset) + def format_predictions( + predictions: list[dict], attrs: dict | None = None + ) -> xr.Dataset: + """Format as ethology detections dataset with model axis. + + predictions: raw_predictions_per_model + """ + # Get results from trainer + # raw_predictions_per_model = self.trainer.predict_loop.predictions + + # Flatten batches + raw_prediction_dicts_per_sample = list( + chain.from_iterable(predictions) + ) # [sample][model] + n_models = len(raw_prediction_dicts_per_sample[0]) + + # Parse output from dicts + output_per_sample: dict[str, list] = { + "boxes": [], + "scores": [], + "labels": [], + } + for ky in output_per_sample: + output_per_sample[ky] = [ + [sample[m][ky] for m in range(n_models)] + for sample in raw_prediction_dicts_per_sample + ] # [sample][model] + + # Pad across models and across image_ids + fill_value = {"boxes": np.nan, "scores": np.nan, "labels": -1} + output_per_sample_padded: dict[str, list] = { + ky: [] for ky in output_per_sample + } + for ky in output_per_sample_padded: + output_per_sample_padded[ky] = _pad_to_max_first_dimension( + [ + # pad across models + np.stack( + _pad_to_max_first_dimension( + output_one_sample, fill_value[ky] + ), + axis=-1, + ) + for output_one_sample in output_per_sample[ky] + ], + fill_value[ky], + ) + + # Stack and reorder dimensions + bboxes_array = np.transpose( + np.stack(output_per_sample_padded["boxes"]), + (0, -2, 1, -1), + ) + scores_array = np.stack(output_per_sample_padded["scores"]) + labels_array = np.stack(output_per_sample_padded["labels"]) + # arrays of shape (image_id, 4/1, n_max_detections, n_models) + + # Compute centroid and shape arrays + # centroid_array = 0.5 * (bboxes_array[:, 0:2] + bboxes_array[:, 2:4]) + # shape_array = bboxes_array[:, 2:4] - bboxes_array[:, 0:2] + centroid_array, shape_array = _corners_to_centroid_shape( + bboxes_array[:, 0:2], bboxes_array[:, 2:4] + ) + + # Return as ethology detections dataset + max_n_detections = bboxes_array.shape[-2] + n_images = bboxes_array.shape[0] + + return xr.Dataset( + data_vars={ + "position": ( + ["image_id", "space", "id", "model"], + centroid_array, + ), + "shape": (["image_id", "space", "id", "model"], shape_array), + "confidence": (["image_id", "id", "model"], scores_array), + "label": (["image_id", "id", "model"], labels_array), + }, + coords={ + "image_id": np.arange(n_images), + "space": ["x", "y"], + "id": np.arange(max_n_detections), + "model": np.arange(n_models), + }, + attrs=attrs if attrs else {}, + ) diff --git a/ethology/detectors/ensembles/utils.py b/ethology/detectors/ensembles/utils.py new file mode 100644 index 00000000..4a686d44 --- /dev/null +++ b/ethology/detectors/ensembles/utils.py @@ -0,0 +1,52 @@ +"""Utility functions for reshaping outputs of ensembles of detectors.""" + +import numpy as np + + +def _get_padding_width(array, max_n): + """Get pad width for array to max_n detections in the first dimension.""" + pad_width = array.ndim * [(0, 0)] + pad_width[0] = (0, max_n - array.shape[0]) # before, after + return pad_width + + +def _pad_to_max_first_dimension(list_arrays, fill_value=np.nan): + """Pad arrays in list to maximum size of their first dimension.""" + max_n_detections = max(array.shape[0] for array in list_arrays) + list_arrays_padded = [ + np.pad( + arr, + _get_padding_width(arr, max_n_detections), + mode="constant", + constant_values=fill_value, + ) + for arr in list_arrays + ] + return list_arrays_padded + + +def _centroid_shape_to_corners(position, shape): + """Convert centroid and shape arrays to x1y1, x2y2 corner arrays. + + x1y1 is the top left corner (min x-coordinate, min y-coordinate), + x2y2 is the bottom right corner (max x-coordinate, max y-coordinate) + of the bounding box. + + Space dimension is assumed to be the second dimension. + """ + half_shape = shape / 2 + return ( + position - half_shape, # x1y1 + position + half_shape, # x2y2 + ) + + +def _corners_to_centroid_shape(x1y1, x2y2): + """Convert x1y1, x2y2 corner arrays to centroid and shape arrays. + + Space dimension is assumed to be the second dimension. + """ + return ( + 0.5 * (x1y1 + x2y2), # centroid + x2y2 - x1y1, # shape + ) diff --git a/ethology/validators/detections.py b/ethology/validators/detections.py index 8b75f5b2..cad21678 100644 --- a/ethology/validators/detections.py +++ b/ethology/validators/detections.py @@ -59,3 +59,58 @@ class ValidBboxDetectionsDataset(ValidDataset): "shape": {"image_id", "space", "id"}, "confidence": {"image_id", "id"}, } + + +@define +class ValidBboxDetectionsEnsembleDataset(ValidDataset): + """Class for valid ``ethology`` bounding box ensemble detections datasets. + + This class validates that the input dataset: + + - is an xarray Dataset, + - has ``image_id``, ``space``, ``id`` and ``model`` as dimensions, + - has ``position``, ``shape`` and ``confidence`` as data variables, + - ``position`` and ``shape`` span at least the dimensions ``image_id``, + ``space``, ``id`` and ``model``, + - ``confidence`` spans at least the dimensions ``image_id``, ``id`` + and ``model``. + + + Attributes + ---------- + dataset : xarray.Dataset + The xarray dataset to validate. + required_dims : ClassVar[set] + The set of required dimension names: ``image_id``, ``space``, ``id`` + and ``model``. + required_data_vars : ClassVar[dict[str, set]] + A dictionary mapping data variable names to their required minimum + dimensions: + + - ``position`` maps to ``image_id``, ``space``, ``id`` and ``model``, + - ``shape`` maps to ``image_id``, ``space``, ``id`` and ``model``, + - ``confidence`` maps to ``image_id``, ``id`` and ``model``. + + Raises + ------ + TypeError + If the input is not an xarray Dataset. + ValueError + If the dataset is missing required data variables or dimensions, + or if any required dimensions are missing for any data variable. + + Notes + ----- + The dataset can have other data variables and dimensions, but only the + required ones are checked. + + """ + + # Minimum requirements for a bbox dataset holding detections + # Should not be modified after initialization + required_dims: ClassVar[set] = {"image_id", "space", "id", "model"} + required_data_vars: ClassVar[dict[str, set]] = { + "position": {"image_id", "space", "id", "model"}, + "shape": {"image_id", "space", "id", "model"}, + "confidence": {"image_id", "id", "model"}, + } diff --git a/pyproject.toml b/pyproject.toml index 2f9b9cc7..4a7b7b63 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,13 +19,19 @@ classifiers = [ "License :: OSI Approved :: BSD License", ] dependencies = [ - "movement", + "xarray", + "pooch", + "pyyaml", "pandera[pandas]", "pycocotools", + "movement", "scikit-learn", "torch", "torchvision", + "ensemble-boxes", + "lightning", "loguru", + "joblib", ] [project.urls] diff --git a/tests/test_unit/test_datasets/__init__.py b/tests/test_unit/test_datasets/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_unit/test_detectors_ensembles/__init__.py b/tests/test_unit/test_detectors_ensembles/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_unit/test_detectors_ensembles/test_utils.py b/tests/test_unit/test_detectors_ensembles/test_utils.py new file mode 100644 index 00000000..fcd1a54a --- /dev/null +++ b/tests/test_unit/test_detectors_ensembles/test_utils.py @@ -0,0 +1,119 @@ +import numpy as np +import pytest + +from ethology.detectors.ensembles.utils import ( + _centroid_shape_to_corners, + _corners_to_centroid_shape, + _get_padding_width, + _pad_to_max_first_dimension, +) + + +@pytest.mark.parametrize( + "array, target_first_dim, expected_pad_width_first_dim", + [ + ( + np.zeros((3,)), + 5, + (0, 2), + ), # 1D array + ( + np.zeros((1, 2, 3)), + 4, + (0, 3), + ), # 3D array + ( + np.zeros((10, 2, 3)), + 10, + (0, 0), + ), # No padding needed + ], +) +def test_get_padding_width( + array, target_first_dim, expected_pad_width_first_dim +): + """Test getting padding width for arrays of different dimensions.""" + pad_width = _get_padding_width(array, target_first_dim) + + assert len(pad_width) == array.ndim + assert pad_width[0] == expected_pad_width_first_dim + assert all(pw == (0, 0) for pw in pad_width[1:]) + + +@pytest.mark.parametrize( + "fill_value", + [np.nan, np.inf, 42], +) +def test_pad_to_max_first_dimension(fill_value): + """Test padding all arrays in list along first dimension.""" + # Get max array length + list_arrays = [np.zeros((1, 2, 3)), np.zeros((10, 2, 3))] + max_array_length = max([arr.shape[0] for arr in list_arrays]) + + # Pad + list_arrays_padded = _pad_to_max_first_dimension(list_arrays, fill_value) + + # Assert all same length + assert all( + [arr.shape[0] == max_array_length for arr in list_arrays_padded] + ) + # Assert other dimensions stay the same + assert all( + [ + arr.shape[1:] == arr_input.shape[1:] + for arr, arr_input in zip( + list_arrays_padded, list_arrays, strict=True + ) + ] + ) + # Assert padding value + assert all( + [ + np.allclose( + arr[arr_input.shape[0] :], + np.full_like(arr[arr_input.shape[0] :], fill_value), + equal_nan=True, + ) + for arr, arr_input in zip( + list_arrays_padded, list_arrays, strict=True + ) + ] + ) + + +@pytest.mark.parametrize( + "position, shape, expected_x1y1, expected_x2y2", + [ + ( + np.zeros((1, 2)), + np.array([[4, 2]]), + np.array([[-2, -1]]), + np.array([[2, 1]]), + ) + ], +) +def test_centroid_shape_to_corners( + position, shape, expected_x1y1, expected_x2y2 +): + x1y1, x2y2 = _centroid_shape_to_corners(position, shape) + np.testing.assert_array_equal(x1y1, expected_x1y1) + np.testing.assert_array_equal(x2y2, expected_x2y2) + + +@pytest.mark.parametrize( + "x1y1, x2y2, expected_position, expected_shape", + [ + ( + np.zeros((1, 2)), + np.ones((1, 2)), + np.array([[0.5, 0.5]]), + np.array([[1, 1]]), + ) + ], +) +def test_corners_to_centroid_shape( + x1y1, x2y2, expected_position, expected_shape +): + position, shape = _corners_to_centroid_shape(x1y1, x2y2) + np.testing.assert_array_equal(position, expected_position) + np.testing.assert_array_equal(shape, expected_shape) diff --git a/tests/test_unit/test_validators/__init__.py b/tests/test_unit/test_validators/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_unit/test_validators/test_detections.py b/tests/test_unit/test_validators/test_detections.py index d053d6ef..5a60da56 100644 --- a/tests/test_unit/test_validators/test_detections.py +++ b/tests/test_unit/test_validators/test_detections.py @@ -4,7 +4,10 @@ import pytest import xarray as xr -from ethology.validators.detections import ValidBboxDetectionsDataset +from ethology.validators.detections import ( + ValidBboxDetectionsDataset, + ValidBboxDetectionsEnsembleDataset, +) @pytest.fixture @@ -38,6 +41,28 @@ def valid_bbox_detections_dataset(): return ds +@pytest.fixture +def valid_bbox_detections_ensemble_dataset(valid_bbox_detections_dataset): + """Create a valid bbox detections ensemble_dataset for validation.""" + # Add model dimension + ds = valid_bbox_detections_dataset.expand_dims( + model=["model_a", "model_b"] + ) + + return ds + + +@pytest.fixture +def valid_bbox_detections_ensemble_dataset_extra_vars_and_dims( + valid_bbox_detections_ensemble_dataset: xr.Dataset, +) -> xr.Dataset: + ds = valid_bbox_detections_ensemble_dataset.copy(deep=True) + ds.coords["extra_dim"] = [10, 20, 30] + ds["extra_var_1"] = (["image_id"], np.random.rand(len(ds.image_id))) + ds["extra_var_2"] = (["id"], np.random.rand(len(ds.id))) + return ds + + @pytest.fixture def valid_bbox_detections_dataset_extra_vars_and_dims( valid_bbox_detections_dataset: xr.Dataset, @@ -49,44 +74,71 @@ def valid_bbox_detections_dataset_extra_vars_and_dims( return ds +# Define validator configurations +VALIDATOR_CONFIGS: dict = { + "detections_ds": { + "validator_class": ValidBboxDetectionsDataset, + "valid_fixture": "valid_bbox_detections_dataset", + "valid_fixture_extra": ( + "valid_bbox_detections_dataset_extra_vars_and_dims" + ), + "required_dims": {"image_id", "space", "id"}, + "required_data_vars": { + "position": {"image_id", "space", "id"}, + "shape": {"image_id", "space", "id"}, + "confidence": {"image_id", "id"}, + }, + }, + "ensemble_ds": { + "validator_class": ValidBboxDetectionsEnsembleDataset, + "valid_fixture": "valid_bbox_detections_ensemble_dataset", + "valid_fixture_extra": ( + "valid_bbox_detections_ensemble_dataset_extra_vars_and_dims" + ), + "required_dims": {"image_id", "space", "id", "model"}, + "required_data_vars": { + "position": {"image_id", "space", "id", "model"}, + "shape": {"image_id", "space", "id", "model"}, + "confidence": {"image_id", "id", "model"}, + }, + }, +} + + +@pytest.mark.parametrize("validator_type", ["detections_ds", "ensemble_ds"]) +@pytest.mark.parametrize( + "valid_fixture_key", + [ + "valid_fixture", + "valid_fixture_extra", + ], +) +def test_validator_bbox_detections_dataset_valid( + validator_type: str, + valid_fixture_key: str, + request: pytest.FixtureRequest, +): + """Test bbox detections dataset validation with valid datasets.""" + config = VALIDATOR_CONFIGS[validator_type] + fixture_name = config[valid_fixture_key] + dataset = request.getfixturevalue(fixture_name) + + validator_class = config["validator_class"] + with does_not_raise(): + validator = validator_class(dataset=dataset) + + assert validator.dataset is dataset + assert validator.required_dims == config["required_dims"] + assert validator.required_data_vars == config["required_data_vars"] + + +@pytest.mark.parametrize( + "validator", + [ValidBboxDetectionsDataset, ValidBboxDetectionsEnsembleDataset], +) @pytest.mark.parametrize( "sample_dataset, expected_exception, expected_error_message", [ - ( - "valid_bbox_detections_dataset", - does_not_raise(), - "", - ), - ( - "valid_bbox_detections_dataset_extra_vars_and_dims", - does_not_raise(), - "", - ), - ( - xr.Dataset( - coords={ - "image_id": np.arange(3), - "space": np.arange(2), - "id": np.arange(2), - }, - data_vars={ - "position": ( - ["image_id", "space", "id"], - np.zeros((3, 2, 2)), - ), - "shape": ( - ["image_id", "space", "id", "foo"], - np.zeros((3, 2, 2, 1)), - ), - "confidence": ( - ["image_id", "id"], - np.zeros((3, 2)), - ), - }, - ), - does_not_raise(), - "", - ), ( {"position": [1, 2, 3], "shape": [4, 5, 6]}, pytest.raises(TypeError), @@ -130,13 +182,56 @@ def valid_bbox_detections_dataset_extra_vars_and_dims( pytest.raises(ValueError), "Missing required data variables: ['confidence', 'shape']", ), + ], + ids=[ + "invalid_type", + "invalid_missing_data_var", + "invalid_missing_multiple_data_vars", + ], +) +def test_validator_bbox_detections_dataset_invalid( + validator: type[ValidBboxDetectionsDataset] + | type[ValidBboxDetectionsEnsembleDataset], + sample_dataset: xr.Dataset, + expected_exception: pytest.raises, + expected_error_message: str, +): + """Test bbox annotations dataset validation in various input scenarios.""" + # Run validation and check exception + with expected_exception as excinfo: + _validator = validator(dataset=sample_dataset) + if excinfo: + error_msg = str(excinfo.value) + assert error_msg in expected_error_message + + +@pytest.mark.parametrize( + "validator", + [ValidBboxDetectionsDataset, ValidBboxDetectionsEnsembleDataset], +) +@pytest.mark.parametrize( + "sample_dataset, expected_exception, expected_error_message", + [ ( xr.Dataset( - coords={"image_id": np.arange(3), "id": np.arange(2)}, + coords={ + "image_id": np.arange(3), + "id": np.arange(2), + "model": np.arange(2), + }, data_vars={ - "position": (["image_id", "id"], np.zeros((3, 2))), - "shape": (["image_id", "id"], np.zeros((3, 2))), - "confidence": (["image_id", "id"], np.zeros((3, 2))), + "position": ( + ["image_id", "id", "model"], + np.zeros((3, 2, 2)), + ), + "shape": ( + ["image_id", "id", "model"], + np.zeros((3, 2, 2)), + ), + "confidence": ( + ["image_id", "id", "model"], + np.zeros((3, 2, 2)), + ), }, ), pytest.raises(ValueError), @@ -148,19 +243,20 @@ def valid_bbox_detections_dataset_extra_vars_and_dims( "foo": np.arange(3), "bar": ["x", "y"], "id": np.arange(2), + "model": np.arange(2), }, data_vars={ "position": ( - ["foo", "bar", "id"], - np.zeros((3, 2, 2)), + ["foo", "bar", "id", "model"], + np.zeros((3, 2, 2, 2)), ), "shape": ( - ["foo", "bar", "id"], - np.zeros((3, 2, 2)), + ["foo", "bar", "id", "model"], + np.zeros((3, 2, 2, 2)), ), "confidence": ( - ["foo", "id"], - np.zeros((3, 2)), + ["foo", "id", "model"], + np.zeros((3, 2, 2)), ), }, ), @@ -173,19 +269,20 @@ def valid_bbox_detections_dataset_extra_vars_and_dims( "image_id": np.arange(3), "space": np.arange(2), "id": np.arange(2), + "model": np.arange(2), }, data_vars={ "position": ( - ["image_id", "space", "id"], - np.zeros((3, 2, 2)), + ["image_id", "space", "id", "model"], + np.zeros((3, 2, 2, 2)), ), "shape": ( - ["image_id", "id"], - np.zeros((3, 2)), + ["image_id", "id", "model"], + np.zeros((3, 2, 2)), ), "confidence": ( - ["image_id", "id"], - np.zeros((3, 2)), + ["image_id", "id", "model"], + np.zeros((3, 2, 2)), ), }, ), @@ -197,42 +294,21 @@ def valid_bbox_detections_dataset_extra_vars_and_dims( ), ], ids=[ - "valid_bbox_detections", - "valid_bbox_detections_extra_vars_and_dims", - "valid_bbox_detections_extra_dims_in_shape_var", - "invalid_bbox_detections_type", - "invalid_bbox_detections_dataset_missing_data_var", - "invalid_bbox_detections_missing_multiple_data_vars", - "invalid_bbox_detections_missing_dimension", - "invalid_bbox_detections_missing_multiple_dimensions", - "invalid_bbox_detections_missing_dimension_in_data_var", + "invalid_missing_dimension", + "invalid_missing_multiple_dimensions", + "invalid_missing_dimension_in_data_var", ], ) -def test_validator_bbox_detections_dataset( - sample_dataset: str | dict, +def test_validator_bbox_detections_dataset_missing_dims( + validator: type[ValidBboxDetectionsDataset] + | type[ValidBboxDetectionsEnsembleDataset], + sample_dataset: xr.Dataset, expected_exception: pytest.raises, expected_error_message: str, - request: pytest.FixtureRequest, ): - """Test bbox annotations dataset validation in various input scenarios.""" - # Get dataset to validate - if isinstance(sample_dataset, str): - dataset = request.getfixturevalue(sample_dataset) - else: - dataset = sample_dataset - # Run validation and check exception with expected_exception as excinfo: - validator = ValidBboxDetectionsDataset(dataset=dataset) - + _validator = validator(dataset=sample_dataset) if excinfo: error_msg = str(excinfo.value) assert error_msg in expected_error_message - else: - assert validator.dataset is dataset - assert validator.required_dims == {"image_id", "space", "id"} - assert validator.required_data_vars == { - "position": {"image_id", "space", "id"}, - "shape": {"image_id", "space", "id"}, - "confidence": {"image_id", "id"}, - }