diff --git a/ethology/io/annotations/load_idtracker.py b/ethology/io/annotations/load_idtracker.py new file mode 100644 index 00000000..be4cf1da --- /dev/null +++ b/ethology/io/annotations/load_idtracker.py @@ -0,0 +1,477 @@ +"""Load bounding box annotations from idtracker.ai output files.""" + +import pickle +from pathlib import Path + +import numpy as np +import xarray as xr +from loguru import logger + +from ethology.validators.annotations import ValidBboxAnnotationsDataset +from ethology.validators.utils import _check_output + + +@_check_output(ValidBboxAnnotationsDataset) +def from_idtracker( + trajectories_path: Path | str, + frame_indices: list[int], + bbox_size: tuple[float, float] | None = None, + blobs_collection_path: Path | str | None = None, + images_dir: Path | str | None = None, +) -> xr.Dataset: + """Generate a bounding box annotations dataset from idtracker.ai output. + + Creates an ``ethology`` bounding box annotations dataset from + idtracker.ai trajectory data and, optionally, from blob detection + data. Each selected frame becomes one ``image_id``; each tracked + animal becomes one ``id`` entry per frame. + + Parameters + ---------- + trajectories_path : Path or str + Path to the idtracker.ai trajectories file (``.npy``). + The array must have shape ``(n_frames, n_animals, 2)`` where the + last dimension holds ``(x, y)`` centroid coordinates in pixels. + ``NaN`` values indicate that an animal was not detected in a + given frame. + frame_indices : list[int] + Zero-based frame indices for which to generate bounding box + annotations. Duplicate indices are silently removed. Every + index must be non-negative and within the number of frames + recorded in the trajectories file. + bbox_size : tuple[float, float], optional + Fixed bounding box size ``(width, height)`` in pixels applied to + every detected animal. Required when ``blobs_collection_path`` + is ``None``. Ignored when ``blobs_collection_path`` is provided. + Both values must be strictly positive. + blobs_collection_path : Path or str, optional + Path to the idtracker.ai blobs collection file (``.pkl``). When + provided, bounding boxes are extracted directly from the blob + objects rather than computed from centroids and a fixed size. + The pickled object must expose a ``.blobs_in_video`` attribute: + a list (indexed by frame) of lists of blob objects, where each + blob exposes: + + - ``.bounding_box``: sequence ``[x_min, y_min, x_max, y_max]`` + in pixel coordinates; + - ``.identity``: 1-based integer animal identity (``0`` or + ``None`` means the blob was not identified). + + images_dir : Path or str, optional + Directory that contains the extracted video frames. Stored in + dataset attributes when provided but otherwise not used. + + Returns + ------- + xarray.Dataset + A valid ``ethology`` bounding box annotations dataset with + dimensions ``image_id``, ``space``, ``id`` and data variables: + + - ``position`` (``image_id``, ``space``, ``id``): bbox centroid + ``(x, y)`` in pixels. ``NaN`` for undetected animals. + - ``shape`` (``image_id``, ``space``, ``id``): bbox + ``(width, height)`` in pixels. ``NaN`` for undetected animals. + - ``category`` (``image_id``, ``id``): 1-based integer animal + identity. ``-1`` for undetected animals. + + Dataset attributes: + + - ``trajectories_file``: path to the trajectories file. + - ``blobs_collection_file``: path to the blobs file, or ``None``. + - ``images_directory``: path to the images directory, or ``None``. + - ``map_category_to_str``: mapping from 1-based animal ID to the + string label ``"animal_"``. + - ``map_image_id_to_filename``: mapping from ``image_id`` to the + canonical frame filename ``"frame_.png"``. + + Raises + ------ + FileNotFoundError + If ``trajectories_path`` or ``blobs_collection_path`` does not + exist on disk. + ValueError + If ``frame_indices`` is empty or contains negative values. + If any frame index is out of range for the given trajectories. + If neither ``bbox_size`` nor ``blobs_collection_path`` is + provided. + If ``bbox_size`` is provided but does not have exactly two + elements or contains non-positive values. + If the loaded trajectories array does not have shape + ``(n_frames, n_animals, 2)``. + If the blobs collection object does not have a + ``.blobs_in_video`` attribute. + + Notes + ----- + The ``image_id`` coordinate is assigned as the 0-based position of + each frame in the **sorted, deduplicated** ``frame_indices`` list. + The ``map_image_id_to_filename`` attribute maps each ``image_id`` to + a canonical frame filename ``"frame_{frame_index:06d}.png"``. + + The ``id`` coordinate ranges from ``0`` to ``n_animals - 1`` and + corresponds directly to the column index in the trajectories array. + For the blobs case the animal index is derived from the 1-based blob + identity as ``identity - 1``. + + Examples + -------- + Generate annotations with a fixed bounding box size: + + >>> import numpy as np + >>> from ethology.io.annotations.load_idtracker import from_idtracker + >>> ds = from_idtracker( + ... trajectories_path="path/to/trajectories.npy", + ... frame_indices=[0, 10, 20], + ... bbox_size=(50.0, 50.0), + ... ) + >>> print(ds.position.shape) # (3, 2, n_animals) + + Generate annotations from a blobs collection: + + >>> ds = from_idtracker( + ... trajectories_path="path/to/trajectories.npy", + ... frame_indices=[0, 10, 20], + ... blobs_collection_path="path/to/blobs_collection.pkl", + ... ) + + """ + # Input validation + load trajectories + trajectories_path = Path(trajectories_path) + if blobs_collection_path is not None: + blobs_collection_path = Path(blobs_collection_path) + + trajectories = _validate_inputs( + trajectories_path=trajectories_path, + frame_indices=frame_indices, + bbox_size=bbox_size, + blobs_collection_path=blobs_collection_path, + ) + n_total_frames, n_animals, _ = trajectories.shape + + # Sort and deduplicate frame indices + + frame_indices_sorted = sorted(set(frame_indices)) + n_selected_frames = len(frame_indices_sorted) + + # Build position / shape / category arrays + + if blobs_collection_path is not None: + logger.info( + "Loading bounding boxes from blobs collection: " + f"{blobs_collection_path}" + ) + position_arr, shape_arr, category_arr = _arrays_from_blobs( + blobs_collection_path, + frame_indices_sorted, + n_animals, + ) + else: + logger.info( + "Computing bounding boxes from trajectories with fixed " + f"bbox_size={bbox_size}." + ) + position_arr, shape_arr, category_arr = _arrays_from_trajectories( + trajectories, + frame_indices_sorted, + n_animals, + bbox_size, # type: ignore[arg-type] # cannot be None here + ) + + # Build metadata maps + + map_image_id_to_filename = { + img_id: f"frame_{frame_idx:06d}.png" + for img_id, frame_idx in enumerate(frame_indices_sorted) + } + map_category_to_str = { + animal_id + 1: f"animal_{animal_id + 1}" + for animal_id in range(n_animals) + } + + # Assemble xarray dataset + + return xr.Dataset( + data_vars={ + "position": (["image_id", "space", "id"], position_arr), + "shape": (["image_id", "space", "id"], shape_arr), + "category": (["image_id", "id"], category_arr), + }, + coords={ + "image_id": list(range(n_selected_frames)), + "space": ["x", "y"], + "id": list(range(n_animals)), + }, + attrs={ + "trajectories_file": str(trajectories_path), + "blobs_collection_file": ( + str(blobs_collection_path) + if blobs_collection_path is not None + else None + ), + "images_directory": ( + str(images_dir) if images_dir is not None else None + ), + "map_category_to_str": map_category_to_str, + "map_image_id_to_filename": map_image_id_to_filename, + }, + ) + + +def _validate_paths_and_bbox( + trajectories_path: Path, + frame_indices: list[int], + bbox_size: tuple[float, float] | None, + blobs_collection_path: Path | None, +) -> None: + """Validate file paths, frame indices and bbox_size. + + Parameters + ---------- + trajectories_path + Path to the trajectories ``.npy`` file. + frame_indices + List of 0-based frame indices to process. + bbox_size + Fixed bounding box ``(width, height)``, or ``None``. + blobs_collection_path + Path to the blobs collection ``.pkl``, or ``None``. + + Raises + ------ + FileNotFoundError + If either file path does not exist. + ValueError + If any input constraint is violated. + + """ + if not trajectories_path.exists(): + raise FileNotFoundError( + f"Trajectories file not found: {trajectories_path}" + ) + if not frame_indices: + raise ValueError("frame_indices must not be empty.") + negative = [i for i in frame_indices if i < 0] + if negative: + raise ValueError( + f"All frame indices must be non-negative integers, got {negative}." + ) + if blobs_collection_path is None and bbox_size is None: + raise ValueError( + "Either bbox_size or blobs_collection_path must be provided." + ) + if bbox_size is not None: + if len(bbox_size) != 2: + raise ValueError( + "bbox_size must be a tuple of exactly two elements " + "(width, height)." + ) + if any(v <= 0 for v in bbox_size): + raise ValueError( + "Both elements of bbox_size must be strictly positive, " + f"got {bbox_size}." + ) + blobs_missing = ( + blobs_collection_path is not None + and not blobs_collection_path.exists() + ) + if blobs_missing: + raise FileNotFoundError( + f"Blobs collection file not found: {blobs_collection_path}" + ) + + +def _validate_inputs( + trajectories_path: Path, + frame_indices: list[int], + bbox_size: tuple[float, float] | None, + blobs_collection_path: Path | None, +) -> np.ndarray: + """Validate all inputs and return the loaded trajectories array. + + Parameters + ---------- + trajectories_path + Path to the trajectories ``.npy`` file. + frame_indices + List of 0-based frame indices to process. + bbox_size + Fixed bounding box ``(width, height)``, or ``None``. + blobs_collection_path + Path to the blobs collection ``.pkl``, or ``None``. + + Returns + ------- + np.ndarray + Loaded trajectories array of shape ``(n_frames, n_animals, 2)``. + + Raises + ------ + FileNotFoundError + If either file path does not exist. + ValueError + If any input constraint is violated. + + """ + _validate_paths_and_bbox( + trajectories_path, frame_indices, bbox_size, blobs_collection_path + ) + trajectories = np.load(trajectories_path, allow_pickle=False) + if trajectories.ndim != 3 or trajectories.shape[2] != 2: + raise ValueError( + "Expected trajectories array of shape " + f"(n_frames, n_animals, 2), got {trajectories.shape}." + ) + out_of_range = [i for i in frame_indices if i >= trajectories.shape[0]] + if out_of_range: + raise ValueError( + f"Frame indices {out_of_range} are out of range for the " + f"trajectories array with {trajectories.shape[0]} frames." + ) + return trajectories + + +def _arrays_from_trajectories( + trajectories: np.ndarray, + frame_indices: list[int], + n_animals: int, + bbox_size: tuple[float, float], +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Build position, shape and category arrays from a trajectories array. + + Parameters + ---------- + trajectories + Array of shape ``(n_frames, n_animals, 2)`` holding centroid + coordinates ``(x, y)``. ``NaN`` signals an undetected animal. + frame_indices + Sorted, deduplicated list of 0-based frame indices to process. + n_animals + Number of tracked animals (second axis of ``trajectories``). + bbox_size + Fixed bounding box ``(width, height)`` in pixels applied to + every detected animal. + + Returns + ------- + tuple[np.ndarray, np.ndarray, np.ndarray] + - ``position_arr`` shape ``(n_frames, 2, n_animals)``. + - ``shape_arr`` shape ``(n_frames, 2, n_animals)``. + - ``category_arr`` shape ``(n_frames, n_animals)``, dtype int. + ``NaN`` / ``-1`` are used for undetected animals. + + """ + n_frames = len(frame_indices) + bbox_width = float(bbox_size[0]) + bbox_height = float(bbox_size[1]) + + position_arr = np.full((n_frames, 2, n_animals), np.nan) + shape_arr = np.full((n_frames, 2, n_animals), np.nan) + category_arr = np.full((n_frames, n_animals), -1, dtype=int) + + for img_id, frame_idx in enumerate(frame_indices): + for animal_idx in range(n_animals): + centroid = trajectories[frame_idx, animal_idx, :] + if np.any(np.isnan(centroid)): + continue # not detected: keep NaN / -1 defaults + + position_arr[img_id, 0, animal_idx] = centroid[0] # x + position_arr[img_id, 1, animal_idx] = centroid[1] # y + shape_arr[img_id, 0, animal_idx] = bbox_width + shape_arr[img_id, 1, animal_idx] = bbox_height + category_arr[img_id, animal_idx] = animal_idx + 1 # 1-based + + return position_arr, shape_arr, category_arr + + +def _arrays_from_blobs( + blobs_collection_path: Path, + frame_indices: list[int], + n_animals: int, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Build position, shape and category arrays from a blobs collection. + + Parameters + ---------- + blobs_collection_path + Path to a pickled blobs collection. The object must expose + ``.blobs_in_video``: a list (indexed by frame) of lists of blob + objects, each with: + + - ``.bounding_box``: ``[x_min, y_min, x_max, y_max]``; + - ``.identity``: 1-based int animal identity (0/None = unidentified). + + frame_indices + Sorted, deduplicated list of 0-based frame indices to process. + n_animals + Number of tracked animals derived from the trajectories array. + + Returns + ------- + tuple[np.ndarray, np.ndarray, np.ndarray] + Same shape convention as :func:`_arrays_from_trajectories`. + + Raises + ------ + ValueError + If the loaded object does not have a ``.blobs_in_video`` attribute. + + """ + n_frames = len(frame_indices) + + with open(blobs_collection_path, "rb") as fh: + blobs_collection = pickle.load(fh) + + if not hasattr(blobs_collection, "blobs_in_video"): + raise ValueError( + "The blobs collection object loaded from " + f"'{blobs_collection_path}' does not have a " + "'blobs_in_video' attribute." + ) + + n_video_frames = len(blobs_collection.blobs_in_video) + + position_arr = np.full((n_frames, 2, n_animals), np.nan) + shape_arr = np.full((n_frames, 2, n_animals), np.nan) + category_arr = np.full((n_frames, n_animals), -1, dtype=int) + + for img_id, frame_idx in enumerate(frame_indices): + if frame_idx >= n_video_frames: + logger.warning( + f"Frame index {frame_idx} exceeds the number of frames in " + f"the blobs collection ({n_video_frames}). Skipping." + ) + continue + + for blob in blobs_collection.blobs_in_video[frame_idx]: + if not hasattr(blob, "bounding_box") or not hasattr( + blob, "identity" + ): + logger.warning( + f"Blob in frame {frame_idx} is missing 'bounding_box' " + "or 'identity' attributes. Skipping." + ) + continue + + identity = blob.identity + if not identity: # 0 or None: unidentified blob + continue + + animal_idx = identity - 1 # convert 1-based to 0-based + if animal_idx >= n_animals: + logger.warning( + f"Blob identity {identity} exceeds n_animals={n_animals}." + " Skipping." + ) + continue + + x_min, y_min, x_max, y_max = blob.bounding_box + width = float(x_max - x_min) + height = float(y_max - y_min) + + position_arr[img_id, 0, animal_idx] = x_min + width / 2.0 + position_arr[img_id, 1, animal_idx] = y_min + height / 2.0 + shape_arr[img_id, 0, animal_idx] = width + shape_arr[img_id, 1, animal_idx] = height + category_arr[img_id, animal_idx] = identity + + return position_arr, shape_arr, category_arr diff --git a/tests/test_unit/test_io_annotations/test_load_idtracker.py b/tests/test_unit/test_io_annotations/test_load_idtracker.py new file mode 100644 index 00000000..4065c19f --- /dev/null +++ b/tests/test_unit/test_io_annotations/test_load_idtracker.py @@ -0,0 +1,679 @@ +"""Tests for ethology.io.annotations.load_idtracker.""" + +import pickle +from pathlib import Path + +import numpy as np +import pytest +import xarray as xr + +from ethology.io.annotations.load_idtracker import ( + _arrays_from_blobs, + _arrays_from_trajectories, + from_idtracker, +) + +# Constants used across all tests +_N_FRAMES = 10 +_N_ANIMALS = 3 +_BBOX_SIZE = (40.0, 30.0) + + +# Fixtures +@pytest.fixture +def sample_trajectories() -> np.ndarray: + """Return a synthetic trajectories array of shape (10, 3, 2). + + Animal 0: detected in every frame. + Animal 1: detected in even frames only (NaN on odd frames). + Animal 2: never detected (all NaN). + """ + rng = np.random.default_rng(42) + traj = rng.uniform(100.0, 900.0, size=(_N_FRAMES, _N_ANIMALS, 2)) + traj[1::2, 1, :] = np.nan # animal 1 absent on odd frames + traj[:, 2, :] = np.nan # animal 2 always absent + return traj + + +@pytest.fixture +def trajectories_file(tmp_path: Path, sample_trajectories: np.ndarray) -> Path: + """Save sample_trajectories to a .npy file and return its path.""" + path = tmp_path / "trajectories.npy" + np.save(path, sample_trajectories) + return path + + +class _Blob: + """Minimal picklable blob object matching the idtracker.ai interface.""" + + def __init__(self, identity: int, bounding_box: list[float]): + self.identity = identity + self.bounding_box = bounding_box + + +class _BlobsCollection: + """Minimal picklable blobs collection object.""" + + def __init__(self, blobs_in_video: list[list[_Blob]]): + self.blobs_in_video = blobs_in_video + + +@pytest.fixture +def sample_blobs_collection( + sample_trajectories: np.ndarray, +) -> _BlobsCollection: + """Return a picklable blobs collection whose bboxes match trajectories. + + For every detected animal in each frame the blob bounding box is + centred at the trajectory centroid with a fixed size of (40, 30) + pixels, i.e. [cx-20, cy-15, cx+20, cy+15]. + """ + blobs_in_video = [] + for frame_idx in range(_N_FRAMES): + frame_blobs = [] + for animal_idx in range(_N_ANIMALS): + centroid = sample_trajectories[frame_idx, animal_idx, :] + if np.any(np.isnan(centroid)): + continue + cx, cy = float(centroid[0]), float(centroid[1]) + frame_blobs.append( + _Blob( + identity=animal_idx + 1, + bounding_box=[cx - 20.0, cy - 15.0, cx + 20.0, cy + 15.0], + ) + ) + blobs_in_video.append(frame_blobs) + return _BlobsCollection(blobs_in_video) + + +@pytest.fixture +def blobs_collection_file( + tmp_path: Path, sample_blobs_collection: _BlobsCollection +) -> Path: + """Pickle sample_blobs_collection to disk and return its path.""" + path = tmp_path / "blobs_collection.pkl" + with open(path, "wb") as fh: + pickle.dump(sample_blobs_collection, fh) + return path + + +# Tests: from_idtracker – trajectories + fixed bbox_size +class TestFromIdtrackerTrajectories: + """Tests for from_idtracker when using trajectories + bbox_size.""" + + def test_returns_xarray_dataset(self, trajectories_file: Path): + """Output must be an xr.Dataset.""" + ds = from_idtracker( + trajectories_path=trajectories_file, + frame_indices=[0, 2, 4], + bbox_size=_BBOX_SIZE, + ) + assert isinstance(ds, xr.Dataset) + + def test_required_data_vars_present(self, trajectories_file: Path): + """position, shape and category must all be present.""" + ds = from_idtracker( + trajectories_path=trajectories_file, + frame_indices=[0], + bbox_size=_BBOX_SIZE, + ) + for var in ("position", "shape", "category"): + assert var in ds.data_vars + + def test_required_dims_present(self, trajectories_file: Path): + """image_id, space and id must all be dimensions.""" + ds = from_idtracker( + trajectories_path=trajectories_file, + frame_indices=[0], + bbox_size=_BBOX_SIZE, + ) + for dim in ("image_id", "space", "id"): + assert dim in ds.dims + + def test_position_array_shape(self, trajectories_file: Path): + """Position shape: (n_selected_frames, 2, n_animals).""" + frame_indices = [0, 2, 4] + ds = from_idtracker( + trajectories_path=trajectories_file, + frame_indices=frame_indices, + bbox_size=_BBOX_SIZE, + ) + assert ds.position.shape == (len(frame_indices), 2, _N_ANIMALS) + + def test_shape_array_shape(self, trajectories_file: Path): + """Shape array shape: (n_selected_frames, 2, n_animals).""" + frame_indices = [0, 2, 4] + ds = from_idtracker( + trajectories_path=trajectories_file, + frame_indices=frame_indices, + bbox_size=_BBOX_SIZE, + ) + assert ds.shape.shape == (len(frame_indices), 2, _N_ANIMALS) + + def test_category_array_shape(self, trajectories_file: Path): + """Category shape: (n_selected_frames, n_animals).""" + frame_indices = [0, 2, 4] + ds = from_idtracker( + trajectories_path=trajectories_file, + frame_indices=frame_indices, + bbox_size=_BBOX_SIZE, + ) + assert ds.category.shape == (len(frame_indices), _N_ANIMALS) + + def test_space_coordinate_values(self, trajectories_file: Path): + """Space coordinate must be ['x', 'y'].""" + ds = from_idtracker( + trajectories_path=trajectories_file, + frame_indices=[0], + bbox_size=_BBOX_SIZE, + ) + assert list(ds.coords["space"].values) == ["x", "y"] + + def test_image_id_coordinate_length(self, trajectories_file: Path): + """image_id length must equal the number of unique selected frames.""" + frame_indices = [0, 2, 4] + ds = from_idtracker( + trajectories_path=trajectories_file, + frame_indices=frame_indices, + bbox_size=_BBOX_SIZE, + ) + assert len(ds.coords["image_id"]) == len(frame_indices) + + def test_id_coordinate_length(self, trajectories_file: Path): + """Id coordinate length must equal n_animals.""" + ds = from_idtracker( + trajectories_path=trajectories_file, + frame_indices=[0], + bbox_size=_BBOX_SIZE, + ) + assert len(ds.coords["id"]) == _N_ANIMALS + + def test_nan_for_undetected_animal_in_frame( + self, + trajectories_file: Path, + sample_trajectories: np.ndarray, + ): + """Animal not detected in a frame must produce NaN position/shape.""" + # Frame 1 is odd -> animal 1 (idx 1) is absent + ds = from_idtracker( + trajectories_path=trajectories_file, + frame_indices=[1], + bbox_size=_BBOX_SIZE, + ) + assert np.isnan(ds.position.values[0, :, 1]).all() + assert np.isnan(ds.shape.values[0, :, 1]).all() + + def test_minus_one_category_for_undetected_animal( + self, trajectories_file: Path + ): + """Undetected animal must have category == -1.""" + # Animal 2 is never detected + ds = from_idtracker( + trajectories_path=trajectories_file, + frame_indices=[0], + bbox_size=_BBOX_SIZE, + ) + assert ds.category.values[0, 2] == -1 + + def test_detected_animal_position_equals_centroid( + self, + trajectories_file: Path, + sample_trajectories: np.ndarray, + ): + """Detected animal position must equal the trajectory centroid.""" + ds = from_idtracker( + trajectories_path=trajectories_file, + frame_indices=[0], + bbox_size=_BBOX_SIZE, + ) + # Animal 0 is always detected + assert ds.position.values[0, 0, 0] == pytest.approx( + sample_trajectories[0, 0, 0] + ) + assert ds.position.values[0, 1, 0] == pytest.approx( + sample_trajectories[0, 0, 1] + ) + + def test_fixed_bbox_size_in_shape_array(self, trajectories_file: Path): + """Shape array values must match the supplied bbox_size.""" + ds = from_idtracker( + trajectories_path=trajectories_file, + frame_indices=[0], + bbox_size=_BBOX_SIZE, + ) + # Animal 0 detected in frame 0 + assert ds.shape.values[0, 0, 0] == pytest.approx(_BBOX_SIZE[0]) + assert ds.shape.values[0, 1, 0] == pytest.approx(_BBOX_SIZE[1]) + + def test_category_values_are_one_based(self, trajectories_file: Path): + """Detected animal categories must be 1-based integers.""" + ds = from_idtracker( + trajectories_path=trajectories_file, + frame_indices=[0], + bbox_size=_BBOX_SIZE, + ) + assert ds.category.values[0, 0] == 1 # animal_idx=0 -> category=1 + + def test_duplicate_frame_indices_removed(self, trajectories_file: Path): + """Duplicate frame indices must be silently deduplicated.""" + ds = from_idtracker( + trajectories_path=trajectories_file, + frame_indices=[0, 0, 2, 2, 4], + bbox_size=_BBOX_SIZE, + ) + assert len(ds.coords["image_id"]) == 3 # 0, 2, 4 + + def test_image_id_to_filename_map(self, trajectories_file: Path): + """map_image_id_to_filename must use frame_.png format.""" + ds = from_idtracker( + trajectories_path=trajectories_file, + frame_indices=[3, 7], + bbox_size=_BBOX_SIZE, + ) + assert ds.attrs["map_image_id_to_filename"] == { + 0: "frame_000003.png", + 1: "frame_000007.png", + } + + def test_image_id_assigned_in_sorted_order(self, trajectories_file: Path): + """Frames must be sorted before assigning image_ids.""" + ds = from_idtracker( + trajectories_path=trajectories_file, + frame_indices=[7, 3], # deliberately unsorted + bbox_size=_BBOX_SIZE, + ) + # image_id 0 -> frame 3, image_id 1 -> frame 7 + assert ds.attrs["map_image_id_to_filename"][0] == "frame_000003.png" + assert ds.attrs["map_image_id_to_filename"][1] == "frame_000007.png" + + def test_category_to_str_map(self, trajectories_file: Path): + """map_category_to_str must cover all animals with 'animal_'.""" + ds = from_idtracker( + trajectories_path=trajectories_file, + frame_indices=[0], + bbox_size=_BBOX_SIZE, + ) + expected = {i + 1: f"animal_{i + 1}" for i in range(_N_ANIMALS)} + assert ds.attrs["map_category_to_str"] == expected + + def test_trajectories_file_attribute(self, trajectories_file: Path): + """trajectories_file attr must equal the input path as string.""" + ds = from_idtracker( + trajectories_path=trajectories_file, + frame_indices=[0], + bbox_size=_BBOX_SIZE, + ) + assert ds.attrs["trajectories_file"] == str(trajectories_file) + + def test_blobs_collection_file_none_when_not_provided( + self, trajectories_file: Path + ): + """blobs_collection_file attr must be None when blobs not given.""" + ds = from_idtracker( + trajectories_path=trajectories_file, + frame_indices=[0], + bbox_size=_BBOX_SIZE, + ) + assert ds.attrs["blobs_collection_file"] is None + + def test_images_dir_stored_in_attributes( + self, trajectories_file: Path, tmp_path: Path + ): + """images_directory attr must match the supplied images_dir.""" + ds = from_idtracker( + trajectories_path=trajectories_file, + frame_indices=[0], + bbox_size=_BBOX_SIZE, + images_dir=tmp_path, + ) + assert ds.attrs["images_directory"] == str(tmp_path) + + def test_images_dir_none_when_not_provided(self, trajectories_file: Path): + """images_directory attr must be None when images_dir not given.""" + ds = from_idtracker( + trajectories_path=trajectories_file, + frame_indices=[0], + bbox_size=_BBOX_SIZE, + ) + assert ds.attrs["images_directory"] is None + + +# Tests: from_idtracker – blobs collection +class TestFromIdtrackerBlobsCollection: + """Tests for from_idtracker when using a blobs collection.""" + + def test_returns_xarray_dataset( + self, + trajectories_file: Path, + blobs_collection_file: Path, + ): + """Output must be an xr.Dataset when blobs are used.""" + ds = from_idtracker( + trajectories_path=trajectories_file, + frame_indices=[0, 2], + blobs_collection_path=blobs_collection_file, + ) + assert isinstance(ds, xr.Dataset) + + def test_position_from_blob_centroid( + self, + trajectories_file: Path, + blobs_collection_file: Path, + sample_trajectories: np.ndarray, + ): + """Position must equal the centroid implied by the blob bbox.""" + ds = from_idtracker( + trajectories_path=trajectories_file, + frame_indices=[0], + blobs_collection_path=blobs_collection_file, + ) + # Blob bbox centred at trajectory centroid -> position must match + assert ds.position.values[0, 0, 0] == pytest.approx( + sample_trajectories[0, 0, 0] + ) + assert ds.position.values[0, 1, 0] == pytest.approx( + sample_trajectories[0, 0, 1] + ) + + def test_shape_from_blob_bounding_box( + self, + trajectories_file: Path, + blobs_collection_file: Path, + ): + """Shape values must reflect the actual blob bbox dimensions.""" + ds = from_idtracker( + trajectories_path=trajectories_file, + frame_indices=[0], + blobs_collection_path=blobs_collection_file, + ) + # Blob: x_min=cx-20, x_max=cx+20 -> width=40 + # y_min=cy-15, y_max=cy+15 -> height=30 + assert ds.shape.values[0, 0, 0] == pytest.approx(40.0) + assert ds.shape.values[0, 1, 0] == pytest.approx(30.0) + + def test_blobs_collection_file_attribute( + self, + trajectories_file: Path, + blobs_collection_file: Path, + ): + """blobs_collection_file attr must equal the input path as string.""" + ds = from_idtracker( + trajectories_path=trajectories_file, + frame_indices=[0], + blobs_collection_path=blobs_collection_file, + ) + assert ds.attrs["blobs_collection_file"] == str(blobs_collection_file) + + def test_bbox_size_ignored_when_blobs_provided( + self, + trajectories_file: Path, + blobs_collection_file: Path, + ): + """bbox_size must be ignored when a blobs collection is supplied.""" + ds = from_idtracker( + trajectories_path=trajectories_file, + frame_indices=[0], + bbox_size=(999.0, 999.0), # should be overridden by blobs + blobs_collection_path=blobs_collection_file, + ) + # Width must come from the blob (40), not from bbox_size (999) + assert ds.shape.values[0, 0, 0] == pytest.approx(40.0) + + def test_nan_for_undetected_animal_via_blobs( + self, + trajectories_file: Path, + blobs_collection_file: Path, + ): + """Animal with no blob in a frame must produce NaN position/shape.""" + # Animal 2 is never detected -> no blob exists for it + ds = from_idtracker( + trajectories_path=trajectories_file, + frame_indices=[0], + blobs_collection_path=blobs_collection_file, + ) + assert np.isnan(ds.position.values[0, :, 2]).all() + assert np.isnan(ds.shape.values[0, :, 2]).all() + assert ds.category.values[0, 2] == -1 + + +# Tests: from_idtracker – error cases +class TestFromIdtrackerErrors: + """Tests that from_idtracker raises the correct errors.""" + + def test_trajectories_file_not_found(self, tmp_path: Path): + """FileNotFoundError when trajectories file is missing.""" + with pytest.raises( + FileNotFoundError, match="Trajectories file not found" + ): + from_idtracker( + trajectories_path=tmp_path / "missing.npy", + frame_indices=[0], + bbox_size=_BBOX_SIZE, + ) + + def test_blobs_file_not_found( + self, trajectories_file: Path, tmp_path: Path + ): + """FileNotFoundError when blobs collection file is missing.""" + with pytest.raises( + FileNotFoundError, match="Blobs collection file not found" + ): + from_idtracker( + trajectories_path=trajectories_file, + frame_indices=[0], + blobs_collection_path=tmp_path / "missing.pkl", + ) + + def test_empty_frame_indices(self, trajectories_file: Path): + """ValueError when frame_indices is an empty list.""" + with pytest.raises( + ValueError, match="frame_indices must not be empty" + ): + from_idtracker( + trajectories_path=trajectories_file, + frame_indices=[], + bbox_size=_BBOX_SIZE, + ) + + def test_negative_frame_index(self, trajectories_file: Path): + """ValueError when a frame index is negative.""" + with pytest.raises(ValueError, match="non-negative"): + from_idtracker( + trajectories_path=trajectories_file, + frame_indices=[-1, 0], + bbox_size=_BBOX_SIZE, + ) + + def test_out_of_range_frame_index(self, trajectories_file: Path): + """ValueError when a frame index exceeds trajectories length.""" + with pytest.raises(ValueError, match="out of range"): + from_idtracker( + trajectories_path=trajectories_file, + frame_indices=[_N_FRAMES + 5], + bbox_size=_BBOX_SIZE, + ) + + def test_no_bbox_size_and_no_blobs(self, trajectories_file: Path): + """ValueError when neither bbox_size nor blobs path is given.""" + with pytest.raises( + ValueError, + match="Either bbox_size or blobs_collection_path", + ): + from_idtracker( + trajectories_path=trajectories_file, + frame_indices=[0], + ) + + def test_bbox_size_wrong_number_of_elements(self, trajectories_file: Path): + """ValueError when bbox_size has the wrong number of elements.""" + with pytest.raises(ValueError, match="two elements"): + from_idtracker( + trajectories_path=trajectories_file, + frame_indices=[0], + bbox_size=(50.0,), # type: ignore[arg-type] + ) + + def test_bbox_size_non_positive_width(self, trajectories_file: Path): + """ValueError when bbox_size width is non-positive.""" + with pytest.raises(ValueError, match="positive"): + from_idtracker( + trajectories_path=trajectories_file, + frame_indices=[0], + bbox_size=(0.0, 20.0), + ) + + def test_bbox_size_non_positive_height(self, trajectories_file: Path): + """ValueError when bbox_size height is non-positive.""" + with pytest.raises(ValueError, match="positive"): + from_idtracker( + trajectories_path=trajectories_file, + frame_indices=[0], + bbox_size=(20.0, -5.0), + ) + + def test_invalid_trajectories_ndim(self, tmp_path: Path): + """ValueError when trajectories array is 2-D.""" + path = tmp_path / "bad.npy" + np.save(path, np.zeros((_N_FRAMES, _N_ANIMALS))) + with pytest.raises(ValueError, match="Expected trajectories array"): + from_idtracker( + trajectories_path=path, + frame_indices=[0], + bbox_size=_BBOX_SIZE, + ) + + def test_invalid_trajectories_last_dim(self, tmp_path: Path): + """ValueError when last dim of trajectories array is not 2.""" + path = tmp_path / "bad.npy" + np.save(path, np.zeros((_N_FRAMES, _N_ANIMALS, 3))) + with pytest.raises(ValueError, match="Expected trajectories array"): + from_idtracker( + trajectories_path=path, + frame_indices=[0], + bbox_size=_BBOX_SIZE, + ) + + def test_blobs_missing_attribute( + self, trajectories_file: Path, tmp_path: Path + ): + """ValueError when blobs object has no blobs_in_video attribute.""" + bad_path = tmp_path / "bad.pkl" + with open(bad_path, "wb") as fh: + pickle.dump({"wrong_key": []}, fh) + with pytest.raises(ValueError, match="blobs_in_video"): + from_idtracker( + trajectories_path=trajectories_file, + frame_indices=[0], + blobs_collection_path=bad_path, + ) + + +# Tests: _arrays_from_trajectories +class TestArraysFromTrajectories: + """Unit tests for the _arrays_from_trajectories private helper.""" + + def test_output_shapes(self, sample_trajectories: np.ndarray): + """All three output arrays must have the correct shapes.""" + frame_indices = [0, 2, 4] + pos, shp, cat = _arrays_from_trajectories( + sample_trajectories, frame_indices, _N_ANIMALS, _BBOX_SIZE + ) + assert pos.shape == (3, 2, _N_ANIMALS) + assert shp.shape == (3, 2, _N_ANIMALS) + assert cat.shape == (3, _N_ANIMALS) + + def test_nan_for_always_undetected_animal( + self, sample_trajectories: np.ndarray + ): + """Animal 2 (always NaN) must produce NaN position and shape.""" + pos, shp, cat = _arrays_from_trajectories( + sample_trajectories, [0], _N_ANIMALS, _BBOX_SIZE + ) + assert np.isnan(pos[0, :, 2]).all() + assert np.isnan(shp[0, :, 2]).all() + assert cat[0, 2] == -1 + + def test_detected_animal_position_correct( + self, sample_trajectories: np.ndarray + ): + """Detected animal position must equal the trajectory centroid.""" + pos, _, _ = _arrays_from_trajectories( + sample_trajectories, [0], _N_ANIMALS, _BBOX_SIZE + ) + assert pos[0, 0, 0] == pytest.approx(sample_trajectories[0, 0, 0]) + assert pos[0, 1, 0] == pytest.approx(sample_trajectories[0, 0, 1]) + + def test_detected_animal_shape_equals_bbox_size( + self, sample_trajectories: np.ndarray + ): + """Detected animal shape must equal the supplied bbox_size.""" + _, shp, _ = _arrays_from_trajectories( + sample_trajectories, [0], _N_ANIMALS, _BBOX_SIZE + ) + assert shp[0, 0, 0] == pytest.approx(_BBOX_SIZE[0]) + assert shp[0, 1, 0] == pytest.approx(_BBOX_SIZE[1]) + + def test_detected_animal_category_is_one_based( + self, sample_trajectories: np.ndarray + ): + """Detected animal category must equal animal_idx + 1.""" + _, _, cat = _arrays_from_trajectories( + sample_trajectories, [0], _N_ANIMALS, _BBOX_SIZE + ) + assert cat[0, 0] == 1 + + def test_partially_detected_animal(self, sample_trajectories: np.ndarray): + """Animal 1 must be detected on even frames and absent on odd ones.""" + # Even frame: detected + _, _, cat_even = _arrays_from_trajectories( + sample_trajectories, [0], _N_ANIMALS, _BBOX_SIZE + ) + assert cat_even[0, 1] == 2 # 1-based + + # Odd frame: not detected + pos_odd, _, cat_odd = _arrays_from_trajectories( + sample_trajectories, [1], _N_ANIMALS, _BBOX_SIZE + ) + assert np.isnan(pos_odd[0, :, 1]).all() + assert cat_odd[0, 1] == -1 + + +# Tests: _arrays_from_blobs +class TestArraysFromBlobs: + """Unit tests for the _arrays_from_blobs private helper.""" + + def test_output_shapes(self, blobs_collection_file: Path): + """All three output arrays must have the correct shapes.""" + frame_indices = [0, 2] + pos, shp, cat = _arrays_from_blobs( + blobs_collection_file, frame_indices, _N_ANIMALS + ) + assert pos.shape == (2, 2, _N_ANIMALS) + assert shp.shape == (2, 2, _N_ANIMALS) + assert cat.shape == (2, _N_ANIMALS) + + def test_position_equals_blob_centroid( + self, + blobs_collection_file: Path, + sample_trajectories: np.ndarray, + ): + """Position must be the centre of the blob bounding box.""" + pos, _, _ = _arrays_from_blobs(blobs_collection_file, [0], _N_ANIMALS) + assert pos[0, 0, 0] == pytest.approx(sample_trajectories[0, 0, 0]) + assert pos[0, 1, 0] == pytest.approx(sample_trajectories[0, 0, 1]) + + def test_shape_equals_blob_bbox_dimensions( + self, blobs_collection_file: Path + ): + """Shape values must equal the actual blob bbox width and height.""" + _, shp, _ = _arrays_from_blobs(blobs_collection_file, [0], _N_ANIMALS) + assert shp[0, 0, 0] == pytest.approx(40.0) # width + assert shp[0, 1, 0] == pytest.approx(30.0) # height + + def test_missing_attribute_raises_value_error(self, tmp_path: Path): + """ValueError when the loaded object has no blobs_in_video attr.""" + bad_path = tmp_path / "bad.pkl" + with open(bad_path, "wb") as fh: + pickle.dump(object(), fh) + with pytest.raises(ValueError, match="blobs_in_video"): + _arrays_from_blobs(bad_path, [0], _N_ANIMALS)