diff --git a/docs/source/conf.py b/docs/source/conf.py index b113e3a2..8558a8aa 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -46,10 +46,13 @@ "notfound.extension", "sphinx_design", "sphinx_gallery.gen_gallery", - "sphinx_sitemap", "sphinx.ext.autosectionlabel", ] +# Only enable sphinx_sitemap if not running linkcheck +if "linkcheck" not in sys.argv: + extensions.append("sphinx_sitemap") + # Configure the myst parser to enable cool markdown features # See https://sphinx-design.readthedocs.io myst_enable_extensions = [ diff --git a/ethology/io/annotations/__init__.py b/ethology/io/annotations/__init__.py index e69de29b..49ce4f6b 100644 --- a/ethology/io/annotations/__init__.py +++ b/ethology/io/annotations/__init__.py @@ -0,0 +1,10 @@ +"""Load and export annotations datasets.""" + +from . import load_bboxes, load_keypoints, save_bboxes, save_keypoints + +__all__ = [ + "load_bboxes", + "save_bboxes", + "load_keypoints", + "save_keypoints", +] diff --git a/ethology/io/annotations/load_keypoints.py b/ethology/io/annotations/load_keypoints.py new file mode 100644 index 00000000..03899458 --- /dev/null +++ b/ethology/io/annotations/load_keypoints.py @@ -0,0 +1,410 @@ +"""Load keypoints annotations into ``ethology``.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any, Literal + +import numpy as np +import xarray as xr + +from ethology.validators.annotations import ValidKeypointsAnnotationsDataset +from ethology.validators.utils import _check_output + + +def _require_sleap_io(): + try: + import sleap_io as sio # type: ignore + except ModuleNotFoundError as exc: + raise ModuleNotFoundError( + "sleap-io is required for keypoints IO. " + "Install it with `pip install sleap-io`." + ) from exc + return sio + + +def _get_labeled_frames(labels: Any) -> list[Any]: + if hasattr(labels, "labeled_frames"): + return list(labels.labeled_frames) + if hasattr(labels, "frames"): + return list(labels.frames) + if hasattr(labels, "labeled_frames_by_video"): + frames: list[Any] = [] + for frames_list in labels.labeled_frames_by_video.values(): + frames.extend(list(frames_list)) + return frames + raise AttributeError( + "Could not find labeled frames on sleap Labels object." + ) + + +def _get_frame_index(frame: Any) -> int: + for attr in ["frame_idx", "frame_index", "frame_number"]: + if hasattr(frame, attr): + return int(getattr(frame, attr)) + raise AttributeError("Could not find frame index on labeled frame.") + + +def _get_video_filename(frame: Any) -> str | None: + if not hasattr(frame, "video") or frame.video is None: + return None + video = frame.video + for attr in ["filename", "path", "source", "name"]: + if hasattr(video, attr): + value = getattr(video, attr) + if value is None: + continue + return str(value) + return None + + +def _get_instances(frame: Any) -> list[Any]: + for attr in ["user_instances", "instances", "predicted_instances"]: + if hasattr(frame, attr): + instances = getattr(frame, attr) + if instances is None: + continue + instances_list = list(instances) + if instances_list: + return instances_list + return [] + + +def _points_from_point_objects( + points: list[Any], n_keypoints: int +) -> tuple[np.ndarray, np.ndarray | None, np.ndarray | None]: + coords = np.full((n_keypoints, 2), np.nan, dtype=np.float64) + confidence = np.full((n_keypoints,), np.nan, dtype=np.float64) + visibility = np.full((n_keypoints,), np.nan, dtype=np.float64) + for idx, point in enumerate(points): + if point is None or idx >= n_keypoints: + continue + x = getattr(point, "x", None) + y = getattr(point, "y", None) + if x is None or y is None: + continue + visible = getattr(point, "visible", None) + if visible is None: + visible = getattr(point, "is_visible", None) + if visible is False: + visibility[idx] = 0.0 + continue + if visible is True: + visibility[idx] = 1.0 + coords[idx] = [float(x), float(y)] + confidence[idx] = getattr( + point, + "score", + getattr(point, "confidence", np.nan), + ) + return coords, confidence, visibility + + +def _points_from_instance( + instance: Any, n_keypoints: int +) -> tuple[np.ndarray, np.ndarray | None, np.ndarray | None]: + for attr in ["numpy", "to_numpy", "points_array", "points"]: + if not hasattr(instance, attr): + continue + value = getattr(instance, attr) + data = value() if callable(value) else value + if isinstance(data, list): + return _points_from_point_objects(data, n_keypoints) + if isinstance(data, np.ndarray): + arr = data.astype(np.float64, copy=False) + if arr.ndim == 2 and arr.shape[1] >= 2: + if arr.shape[0] == n_keypoints: + return arr[:, :2], None, None + if arr.shape[1] == n_keypoints and arr.shape[0] >= 2: + return arr[:2, :].T, None, None + if arr.ndim == 3 and arr.shape[-1] >= 2: + # Some formats store (n_keypoints, 1, 2) + arr = arr.reshape(arr.shape[0], -1) + if arr.shape[0] == n_keypoints: + return arr[:, :2], None, None + raise ValueError( + "Unsupported instance points format in sleap Labels object." + ) + + +def _get_skeleton_keypoints(labels: Any) -> list[str]: + if hasattr(labels, "skeletons") and labels.skeletons: + skeleton = labels.skeletons[0] + if hasattr(skeleton, "nodes"): + return [node.name for node in skeleton.nodes] + if hasattr(labels, "skeleton"): + skeleton = labels.skeleton + if hasattr(skeleton, "nodes"): + return [node.name for node in skeleton.nodes] + return [] + + +def _infer_keypoint_count(instance: Any) -> int: + for attr in ["numpy", "to_numpy", "points_array", "points"]: + if not hasattr(instance, attr): + continue + value = getattr(instance, attr) + data = value() if callable(value) else value + if isinstance(data, list): + return len(data) + if isinstance(data, np.ndarray): + arr = data + if arr.ndim == 2: + if arr.shape[1] == 2: + return arr.shape[0] + if arr.shape[0] == 2: + return arr.shape[1] + return arr.shape[0] + if arr.ndim == 3: + return arr.shape[0] + raise ValueError("Could not infer keypoint count from instance.") + + +def _prepare_frame_records(labels: Any) -> list[dict[str, Any]]: + frame_records = [] + for frame in _get_labeled_frames(labels): + frame_idx = _get_frame_index(frame) + video_filename = _get_video_filename(frame) + frame_records.append( + { + "frame": frame, + "frame_idx": frame_idx, + "video_filename": video_filename, + } + ) + frame_records.sort( + key=lambda r: ( + r["video_filename"] or "", + r["frame_idx"], + ) + ) + return frame_records + + +def _frame_label(video_filename: str | None, frame_idx: int) -> str: + if video_filename: + return f"{video_filename}::frame_{frame_idx}" + return f"frame_{frame_idx}" + + +def _from_single_file( # noqa: C901 + file_path: Path | str, + format: Literal["SLEAP"], + images_dirs: Path | str | list[Path | str] | None, +) -> xr.Dataset: + if format != "SLEAP": + raise ValueError(f"Unsupported format: {format}") + + sio = _require_sleap_io() + labels = sio.load_file(file_path) + keypoint_names = _get_skeleton_keypoints(labels) + + frame_records = _prepare_frame_records(labels) + if not frame_records: + raise ValueError("No labeled frames found in keypoints file.") + + max_instances = 0 + for record in frame_records: + instances = _get_instances(record["frame"]) + max_instances = max(max_instances, len(instances)) + + if max_instances == 0: + raise ValueError("No instances found in keypoints file.") + + if not keypoint_names: + # Fallback: infer number of keypoints from the first instance + # Find the first frame that actually has instances + # (frame 0 might be empty) + first_instance_to_infer = None + for record in frame_records: + insts = _get_instances(record["frame"]) + if insts: + first_instance_to_infer = insts[0] + break + + if first_instance_to_infer is None: + raise ValueError("No instances found to infer keypoints.") + + n_keypoints = _infer_keypoint_count(first_instance_to_infer) + keypoint_names = [f"keypoint_{i}" for i in range(n_keypoints)] + + n_keypoints = len(keypoint_names) + n_frames = len(frame_records) + + position = np.full( + (n_frames, 2, n_keypoints, max_instances), + np.nan, + dtype=np.float64, + ) + confidence = np.full( + (n_frames, n_keypoints, max_instances), + np.nan, + dtype=np.float64, + ) + visibility = np.full( + (n_frames, n_keypoints, max_instances), + np.nan, + dtype=np.float64, + ) + + map_image_id_to_filename: dict[int, str] = {} + map_image_id_to_video: dict[int, str] = {} + map_image_id_to_frame_idx: dict[int, int] = {} + + for image_id, record in enumerate(frame_records): + frame = record["frame"] + frame_idx = record["frame_idx"] + video_filename = record["video_filename"] + map_image_id_to_filename[image_id] = _frame_label( + video_filename, frame_idx + ) + if video_filename: + map_image_id_to_video[image_id] = video_filename + map_image_id_to_frame_idx[image_id] = frame_idx + + for inst_idx, instance in enumerate(_get_instances(frame)): + coords, conf, vis = _points_from_instance(instance, n_keypoints) + if coords.shape[0] != n_keypoints: + raise ValueError( + "Instance keypoints do not match skeleton definition." + ) + position[image_id, :, :, inst_idx] = coords.T + if conf is not None: + confidence[image_id, :, inst_idx] = conf + if vis is not None: + visibility[image_id, :, inst_idx] = vis + + ds = xr.Dataset( + data_vars={ + "position": ( + ["image_id", "space", "keypoint", "id"], + position, + ), + }, + coords={ + "image_id": np.arange(n_frames), + "space": ["x", "y"], + "keypoint": keypoint_names, + "id": np.arange(max_instances), + }, + ) + + if np.isfinite(confidence).any(): + ds["confidence"] = ( + ["image_id", "keypoint", "id"], + confidence, + ) + if np.isfinite(visibility).any(): + ds["visibility"] = ( + ["image_id", "keypoint", "id"], + visibility, + ) + + ds.attrs = { + "annotation_files": file_path, + "annotation_format": format, + "images_directories": images_dirs, + "map_keypoint_to_str": dict(enumerate(keypoint_names)), + "map_image_id_to_filename": map_image_id_to_filename, + "map_image_id_to_video": map_image_id_to_video, + "map_image_id_to_frame_idx": map_image_id_to_frame_idx, + } + return ds + + +@_check_output(ValidKeypointsAnnotationsDataset) +def from_files( + file_paths: Path | str | list[Path | str], + format: Literal["SLEAP"] = "SLEAP", + images_dirs: Path | str | list[Path | str] | None = None, +) -> xr.Dataset: + """Load an ``ethology`` keypoints annotations dataset from a file. + + Parameters + ---------- + file_paths : pathlib.Path | str | list[pathlib.Path | str] + Path or list of paths to the input keypoints annotation files. + format : {"SLEAP"} + Format of the input annotation files. Currently only "SLEAP". + images_dirs : pathlib.Path | str | list[pathlib.Path | str], optional + Path or list of paths to the directories containing the images the + annotations refer to. The paths are added to the dataset attributes. + + Returns + ------- + xarray.Dataset + A valid keypoints annotations dataset with dimensions + `image_id`, `space`, `keypoint`, `id` and data variable `position`. + + """ + if isinstance(file_paths, list): + datasets = [] + map_keypoint_to_str = None + image_id_offset = 0 + for path in file_paths: + ds = _from_single_file( + path, format=format, images_dirs=images_dirs + ) + if map_keypoint_to_str is None: + map_keypoint_to_str = ds.attrs.get("map_keypoint_to_str") + elif map_keypoint_to_str != ds.attrs.get("map_keypoint_to_str"): + raise ValueError( + "Keypoint labels differ across input files; " + "cannot merge datasets." + ) + + ds = ds.assign_coords(image_id=ds.image_id + image_id_offset) + # Update mapping attrs to new image_id range + map_image_id_to_filename = {} + map_image_id_to_video = {} + map_image_id_to_frame_idx = {} + for old_id in ds.attrs["map_image_id_to_filename"]: + new_id = int(old_id) + image_id_offset + map_image_id_to_filename[new_id] = ds.attrs[ + "map_image_id_to_filename" + ][old_id] + if old_id in ds.attrs.get("map_image_id_to_video", {}): + map_image_id_to_video[new_id] = ds.attrs[ + "map_image_id_to_video" + ][old_id] + map_image_id_to_frame_idx[new_id] = ds.attrs[ + "map_image_id_to_frame_idx" + ][old_id] + ds.attrs["map_image_id_to_filename"] = map_image_id_to_filename + ds.attrs["map_image_id_to_video"] = map_image_id_to_video + ds.attrs["map_image_id_to_frame_idx"] = map_image_id_to_frame_idx + + datasets.append(ds) + image_id_offset += ds.sizes["image_id"] + + ds_all = xr.concat(datasets, dim="image_id") + ds_all.attrs = { + "annotation_files": file_paths, + "annotation_format": format, + "images_directories": images_dirs, + "map_keypoint_to_str": map_keypoint_to_str or {}, + "map_image_id_to_filename": { + k: v + for ds in datasets + for k, v in ds.attrs.get( + "map_image_id_to_filename", {} + ).items() + }, + "map_image_id_to_video": { + k: v + for ds in datasets + for k, v in ds.attrs.get("map_image_id_to_video", {}).items() + }, + "map_image_id_to_frame_idx": { + k: v + for ds in datasets + for k, v in ds.attrs.get( + "map_image_id_to_frame_idx", {} + ).items() + }, + } + return ds_all + + return _from_single_file( + file_paths, format=format, images_dirs=images_dirs + ) diff --git a/ethology/io/annotations/save_keypoints.py b/ethology/io/annotations/save_keypoints.py new file mode 100644 index 00000000..d3114c09 --- /dev/null +++ b/ethology/io/annotations/save_keypoints.py @@ -0,0 +1,209 @@ +"""Save ``ethology`` keypoints annotations datasets to various formats.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any, Literal + +import numpy as np +import xarray as xr + +from ethology.validators.annotations import ValidKeypointsAnnotationsDataset +from ethology.validators.utils import _check_input + + +def _require_sleap_io(): + try: + import sleap_io as sio # type: ignore + except ModuleNotFoundError as exc: + raise ModuleNotFoundError( + "sleap-io is required for keypoints IO. " + "Install it with `pip install sleap-io`." + ) from exc + return sio + + +def _get_keypoint_names(ds: xr.Dataset) -> list[str]: + if "map_keypoint_to_str" in ds.attrs: + mapping = ds.attrs["map_keypoint_to_str"] + if isinstance(mapping, dict) and mapping: + return [mapping[i] for i in range(len(mapping))] + return [str(kp) for kp in ds.keypoint.values] + + +def _get_image_id_maps( + ds: xr.Dataset, +) -> tuple[dict[int, str], dict[int, str], dict[int, int]]: + map_image_id_to_filename = ds.attrs.get("map_image_id_to_filename", {}) + map_image_id_to_video = ds.attrs.get("map_image_id_to_video", {}) + map_image_id_to_frame_idx = ds.attrs.get("map_image_id_to_frame_idx", {}) + return ( + map_image_id_to_filename, + map_image_id_to_video, + map_image_id_to_frame_idx, + ) + + +def _build_sleap_objects(ds: xr.Dataset) -> Any: # noqa: C901 + sio = _require_sleap_io() + keypoint_names = _get_keypoint_names(ds) + + node_cls = getattr(sio, "Node", None) + skeleton_cls = getattr(sio, "Skeleton", None) + labeled_frame_cls = getattr(sio, "LabeledFrame", None) + instance_cls = getattr(sio, "Instance", None) + video_cls = getattr(sio, "Video", None) + point_cls = getattr(sio, "Point", None) + + if not all([skeleton_cls, labeled_frame_cls, instance_cls, video_cls]): + raise AttributeError( + "sleap-io is missing required classes for saving Labels." + ) + + # Type assertions after None check + assert skeleton_cls is not None + assert labeled_frame_cls is not None + assert instance_cls is not None + assert video_cls is not None + + nodes = ( + [node_cls(name=name) for name in keypoint_names] + if node_cls is not None + else keypoint_names + ) + try: + skeleton = skeleton_cls(nodes=nodes, edges=[]) + except TypeError: + skeleton = skeleton_cls(nodes=nodes) + + ( + map_image_id_to_filename, + map_image_id_to_video, + map_image_id_to_frame_idx, + ) = _get_image_id_maps(ds) # noqa: E501 + + videos: dict[str, Any] = {} + labeled_frames = [] + confidence = ds.get("confidence") + visibility = ds.get("visibility") + + for image_id in ds.image_id.values: + image_id_int = int(image_id) + video_filename = map_image_id_to_video.get( + image_id_int, + map_image_id_to_filename.get(image_id_int, ""), + ) + if not video_filename: + raise ValueError( + "Missing video or filename information in dataset attrs." + ) + + if video_filename not in videos: + try: + video = video_cls.from_filename(video_filename) # type: ignore + except AttributeError: + try: + video = video_cls(filename=video_filename) # type: ignore + except TypeError: + video = video_cls(video_filename) # type: ignore + videos[video_filename] = video + else: + video = videos[video_filename] + + frame_idx = map_image_id_to_frame_idx.get(image_id_int, image_id_int) + + try: + labeled_frame = labeled_frame_cls(video=video, frame_idx=frame_idx) # type: ignore + except TypeError: + labeled_frame = labeled_frame_cls(video, frame_idx) # type: ignore + + instances = [] + for inst_id in ds.id.values: + coords = ds.position.sel(image_id=image_id, id=inst_id).values + if np.isnan(coords).all(): + continue + points: list[Any] = [] + for kp_idx, _name in enumerate(keypoint_names): + x, y = coords[:, kp_idx] + if np.isnan(x) or np.isnan(y): + points.append(None) + continue + score = None + if confidence is not None: + score = float( + confidence.sel(image_id=image_id, id=inst_id).values[ + kp_idx + ] + ) + visible = None + if visibility is not None: + visible = float( + visibility.sel(image_id=image_id, id=inst_id).values[ + kp_idx + ] + ) + if point_cls is not None: + kwargs = {} + if score is not None and not np.isnan(score): + kwargs["score"] = score + if visible is not None and not np.isnan(visible): + kwargs["visible"] = bool(int(visible)) + point = point_cls(x=float(x), y=float(y), **kwargs) + points.append(point) + else: + points.append([float(x), float(y)]) + + try: + instance = instance_cls(points=points, skeleton=skeleton) # type: ignore + except TypeError: + instance = instance_cls(points, skeleton) # type: ignore + instances.append(instance) + + if instances: + labeled_frame.instances = instances + labeled_frames.append(labeled_frame) + + try: + labels = sio.Labels( + labeled_frames=labeled_frames, skeletons=[skeleton] + ) # type: ignore + except TypeError: + labels = sio.Labels(labeled_frames) # type: ignore + if hasattr(labels, "skeletons"): + labels.skeletons = [skeleton] + return labels + + +@_check_input(validator=ValidKeypointsAnnotationsDataset) +def to_file( + dataset: xr.Dataset, + output_filepath: str | Path, + format: Literal["SLEAP"] = "SLEAP", +) -> str | Path: + """Save an ``ethology`` keypoints annotations dataset to a file. + + Parameters + ---------- + dataset : xarray.Dataset + Keypoints annotations xarray dataset. + output_filepath : str or pathlib.Path + Path for the output file. + format : {"SLEAP"} + Format of the output file. + + Returns + ------- + str or pathlib.Path + Path for the output file. + + """ + if format != "SLEAP": + raise ValueError(f"Unsupported format: {format}") + + sio = _require_sleap_io() + labels = _build_sleap_objects(dataset) + try: + sio.save_file(labels, output_filepath) + except TypeError: + sio.save_file(output_filepath, labels) + return output_filepath diff --git a/ethology/validators/annotations.py b/ethology/validators/annotations.py index 6a02f10a..0977ab47 100644 --- a/ethology/validators/annotations.py +++ b/ethology/validators/annotations.py @@ -265,6 +265,53 @@ class ValidBboxAnnotationsDataset(ValidDataset): } +@define +class ValidKeypointsAnnotationsDataset(ValidDataset): + """Class for valid ``ethology`` keypoints annotations datasets. + + This class validates that the input dataset: + + - is an xarray Dataset, + - has ``image_id``, ``space``, ``keypoint``, ``id`` as dimensions, + - has ``position`` as a data variable, + - ``position`` spans at least the dimensions ``image_id``, ``space``, + ``keypoint`` and ``id``. + + Attributes + ---------- + dataset : xarray.Dataset + The xarray dataset to validate. + required_dims : ClassVar[set] + The set of required dimension names: ``image_id``, ``space``, + ``keypoint`` and ``id``. + required_data_vars : ClassVar[dict[str, set]] + A dictionary mapping data variable names to their required minimum + dimensions: + + - ``position`` maps to ``image_id``, ``space``, ``keypoint`` and + ``id``. + + Raises + ------ + TypeError + If the input is not an xarray Dataset. + ValueError + If the dataset is missing required data variables or dimensions, + or if any required dimensions are missing for any data variable. + + Notes + ----- + The dataset can have other data variables and dimensions, but only the + required ones are checked. + + """ + + required_dims: ClassVar[set] = {"image_id", "space", "keypoint", "id"} + required_data_vars: ClassVar[dict[str, set]] = { + "position": {"image_id", "space", "keypoint", "id"}, + } + + class ValidBboxAnnotationsDataFrame(pa.DataFrameModel): """Class for valid bounding boxes intermediate dataframes. diff --git a/pyproject.toml b/pyproject.toml index 1ae41b6a..36d38778 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ dependencies = [ "pandera[pandas]", "pycocotools", "scikit-learn", + "sleap-io", "torch", "torchvision", "loguru", diff --git a/tests/fixtures/annotations.py b/tests/fixtures/annotations.py index 49cc9415..282a6ef9 100644 --- a/tests/fixtures/annotations.py +++ b/tests/fixtures/annotations.py @@ -163,3 +163,51 @@ def valid_bbox_annotations_dataset_extra_vars_and_dims( ds["extra_var_1"] = (["image_id"], np.random.rand(len(ds.image_id))) ds["extra_var_2"] = (["id"], np.random.rand(len(ds.id))) return ds + + +# ----------------- Keypoints dataset validation fixtures ----------------- +@pytest.fixture +def valid_keypoints_annotations_dataset(): + """Create a valid keypoints annotations dataset for validation.""" + image_ids = [1, 2] + annotation_ids = [0, 1] + keypoints = ["nose", "tail"] + space_dims = ["x", "y"] + + position_data = np.zeros( + ( + len(image_ids), + len(space_dims), + len(keypoints), + len(annotation_ids), + ) + ) + + ds = xr.Dataset( + data_vars={ + "position": ( + ["image_id", "space", "keypoint", "id"], + position_data, + ), + }, + coords={ + "image_id": image_ids, + "space": space_dims, + "keypoint": keypoints, + "id": annotation_ids, + }, + ) + + return ds + + +@pytest.fixture +def valid_keypoints_annotations_dataset_extra_vars_and_dims( + valid_keypoints_annotations_dataset: xr.Dataset, +) -> xr.Dataset: + """Create a valid keypoints annotations dataset with extra dims/vars.""" + ds = valid_keypoints_annotations_dataset.copy(deep=True) + ds.coords["extra_dim"] = [10, 20] + ds["extra_var_1"] = (["image_id"], np.random.rand(len(ds.image_id))) + ds["extra_var_2"] = (["id"], np.random.rand(len(ds.id))) + return ds diff --git a/tests/test_unit/test_io_annotations/test_load_keypoints.py b/tests/test_unit/test_io_annotations/test_load_keypoints.py new file mode 100644 index 00000000..37221299 --- /dev/null +++ b/tests/test_unit/test_io_annotations/test_load_keypoints.py @@ -0,0 +1,485 @@ +"""Test loading keypoints annotations into ethology datasets.""" + +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest +import xarray as xr + +from ethology.io.annotations.load_keypoints import ( + _frame_label, + _from_single_file, + _get_frame_index, + _get_instances, + _get_labeled_frames, + _get_skeleton_keypoints, + _get_video_filename, + _infer_keypoint_count, + _points_from_instance, + _points_from_point_objects, + _prepare_frame_records, + _require_sleap_io, + from_files, +) + + +def test_require_sleap_io_import_success(): + """Test successful import of sleap_io when it exists.""" + try: + sio = _require_sleap_io() + assert sio is not None + assert hasattr(sio, "load_file") + except ModuleNotFoundError: + pytest.skip("sleap-io not installed") + + +def test_require_sleap_io_import_missing(): + """Test that ModuleNotFoundError is raised when sleap_io is missing.""" + with ( + patch.dict("sys.modules", {"sleap_io": None}), + pytest.raises(ModuleNotFoundError, match="sleap-io is required"), + ): + _require_sleap_io() + + +def test_get_labeled_frames(): + """Test extracting labeled frames from various attributes.""" + # Test 'labeled_frames' attribute + mock_labels_1 = MagicMock() + mock_frame1, mock_frame2 = MagicMock(), MagicMock() + mock_labels_1.labeled_frames = [mock_frame1, mock_frame2] + assert _get_labeled_frames(mock_labels_1) == [mock_frame1, mock_frame2] + + # Test fallback to 'frames' attribute + mock_labels_2 = MagicMock(spec=[]) + mock_labels_2.frames = [mock_frame1] + assert _get_labeled_frames(mock_labels_2) == [mock_frame1] + + # Test fallback to 'labeled_frames_by_video' + mock_labels_3 = MagicMock(spec=[]) + mock_labels_3.labeled_frames_by_video = {"v1": [mock_frame1]} + assert _get_labeled_frames(mock_labels_3) == [mock_frame1] + + # Test error + mock_labels_empty = MagicMock(spec=[]) + with pytest.raises(AttributeError, match="Could not find labeled frames"): + _get_labeled_frames(mock_labels_empty) + + +@pytest.mark.parametrize( + "attr_name, attr_value", + [ + ("frame_idx", 10), + ("frame_index", 20), + ("frame_number", 30), + ("frame_idx", "42"), # Test string conversion + ], +) +def test_get_frame_index(attr_name, attr_value): + """Test successful extraction of frame index from various attributes.""" + # Use spec=[attr_name] to ensure the mock ONLY has this attribute. + mock_frame = MagicMock(spec=[attr_name]) + setattr(mock_frame, attr_name, attr_value) + + result = _get_frame_index(mock_frame) + assert result == int(attr_value) + assert isinstance(result, int) + + +def test_get_frame_index_error(): + """Test that AttributeError is raised when no valid attribute exists.""" + mock_frame = MagicMock(spec=[]) + with pytest.raises(AttributeError, match="Could not find frame index"): + _get_frame_index(mock_frame) + + +@pytest.mark.parametrize( + "video_attr, expected_filename", + [ + ({"filename": "v.mp4"}, "v.mp4"), + ({"path": "v.mp4", "filename": None}, "v.mp4"), + (None, None), # No video object + ({}, None), # Video object with no path/filename + ], +) +def test_get_video_filename(video_attr, expected_filename): + """Test extraction of filename from video object.""" + mock_frame = MagicMock() + if video_attr is None: + mock_frame.video = None + else: + mock_frame.video = MagicMock() + for k, v in video_attr.items(): + setattr(mock_frame.video, k, v) + # Handle case where attributes are missing from spec + if not video_attr: + mock_frame.video = MagicMock(spec=[]) + + assert _get_video_filename(mock_frame) == expected_filename + + +@pytest.mark.parametrize( + "attr_config, expected_count", + [ + ({"user_instances": [1, 2]}, 2), + ({"instances": [1, 2]}, 2), + ({"predicted_instances": [1, 2]}, 2), + ({"user_instances": [], "instances": []}, 0), + ({"user_instances": None}, 0), + ], +) +def test_get_instances(attr_config, expected_count): + """Test successful extraction of instances.""" + mock_frame = MagicMock() + # Set all potential attributes to None/Empty first + mock_frame.user_instances = None + mock_frame.instances = None + mock_frame.predicted_instances = None + + for k, v in attr_config.items(): + setattr(mock_frame, k, v) + + result = _get_instances(mock_frame) + assert len(result) == expected_count + + +def test_points_from_point_objects(): + """Test extraction of points from a list of point objects.""" + # Standard case + p1 = MagicMock(x=10.0, y=20.0, visible=True, score=0.95) + p2 = MagicMock(x=30.0, y=40.0, visible=True, score=0.85) + + coords, conf, vis = _points_from_point_objects([p1, p2], n_keypoints=2) + assert np.allclose(coords, [[10, 20], [30, 40]]) + assert np.allclose(conf, [0.95, 0.85]) + assert np.allclose(vis, [1.0, 1.0]) + + # Invisible / Missing case + p_inv = MagicMock(x=10.0, y=20.0, visible=False) + coords, _, vis = _points_from_point_objects([p_inv, None], n_keypoints=2) + assert np.isnan(coords[0]).all() + assert vis[0] == 0.0 + assert np.isnan(coords[1]).all() + + # Point with x=None + p_no_x = MagicMock(spec=["x", "y"]) + p_no_x.x = None + p_no_x.y = 5.0 + coords, _, _ = _points_from_point_objects([p_no_x], n_keypoints=1) + assert np.isnan(coords[0]).all() + + # point using is_visible fallback.. + p_isvis = MagicMock(spec=["x", "y", "is_visible", "score"]) + p_isvis.x = 10.0 + p_isvis.y = 20.0 + p_isvis.is_visible = True + p_isvis.score = 0.5 + coords, conf, vis = _points_from_point_objects([p_isvis], n_keypoints=1) + assert np.allclose(coords[0], [10.0, 20.0]) + assert vis[0] == 1.0 + + +def test_points_from_instance(): + """Test extraction of points from instance (numpy vs list).""" + # Numpy Array Case + mock_inst_np = MagicMock() + mock_inst_np.numpy = np.array([[10.0, 20.0], [30.0, 40.0]]) + c, _, _ = _points_from_instance(mock_inst_np, 2) + assert np.allclose(c, [[10.0, 20.0], [30.0, 40.0]]) + + # transposed 2D array: shape (2, 3) with n_keypoints=3 + mock_inst_t = MagicMock() + mock_inst_t.numpy = np.array([[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]]) + c, _, _ = _points_from_instance(mock_inst_t, 3) + assert c.shape == (3, 2) + assert np.allclose(c[0], [10.0, 40.0]) + + # 3D Array Case (Reshape) + mock_inst_3d = MagicMock() + mock_inst_3d.numpy = np.array([[[10.0, 20.0]], [[30.0, 40.0]]]) + c, _, _ = _points_from_instance(mock_inst_3d, 2) + assert c.shape == (2, 2) + + # List Case + mock_inst_list = MagicMock() + mock_inst_list.points = [MagicMock(x=10.0, y=20.0, visible=True)] + c, _, _ = _points_from_instance(mock_inst_list, 1) + assert c.shape == (1, 2) + + # Error Case + with pytest.raises(ValueError, match="Unsupported instance points format"): + _points_from_instance(MagicMock(spec=[]), 1) + + +def test_get_skeleton_keypoints(): + """Test extraction of keypoint names from skeletons.""" + node = MagicMock() + node.name = "n1" + + # Skeletons list + mock_labels = MagicMock() + mock_labels.skeletons = [MagicMock(nodes=[node])] + assert _get_skeleton_keypoints(mock_labels) == ["n1"] + + # Single skeleton attribute + node2 = MagicMock() + node2.name = "n2" + mock_labels.skeletons = [] + mock_labels.skeleton = MagicMock(nodes=[node2]) + assert _get_skeleton_keypoints(mock_labels) == ["n2"] + + +def test_infer_keypoint_count(): + """Test inferring keypoint count from different formats.""" + # Numpy (n, 2) + mock_np = MagicMock() + mock_np.numpy = np.zeros((3, 2)) + assert _infer_keypoint_count(mock_np) == 3 + + # Transposed numpy (2, n) where n != 2 + mock_t = MagicMock() + mock_t.numpy = np.zeros((2, 4)) + assert _infer_keypoint_count(mock_t) == 4 + + # Ambiguous 2D (neither dim is 2) + mock_a = MagicMock() + mock_a.numpy = np.zeros((5, 3)) + assert _infer_keypoint_count(mock_a) == 5 + + # 3D array + mock_3d = MagicMock() + mock_3d.numpy = np.zeros((4, 1, 3)) + assert _infer_keypoint_count(mock_3d) == 4 + + # List + mock_list = MagicMock() + mock_list.points = [1, 2, 3] + assert _infer_keypoint_count(mock_list) == 3 + + # Error case + with pytest.raises(ValueError, match="Could not infer keypoint count"): + _infer_keypoint_count(MagicMock(spec=[])) + + +@pytest.mark.parametrize( + "video_file, frame_idx, expected", + [ + ("v.mp4", 42, "v.mp4::frame_42"), + (None, 42, "frame_42"), + ], +) +def test_frame_label(video_file, frame_idx, expected): + assert _frame_label(video_file, frame_idx) == expected + + +def test_prepare_frame_records_sorting(): + """Test that frame records are sorted by video filename and frame index.""" + mock_labels = MagicMock() + f1 = MagicMock(frame_idx=5, video=MagicMock(filename="b.mp4")) + f2 = MagicMock(frame_idx=1, video=MagicMock(filename="a.mp4")) + f3 = MagicMock(frame_idx=3, video=None) + mock_labels.labeled_frames = [f1, f2, f3] + + records = _prepare_frame_records(mock_labels) + # Expected order: No Video (f3), a.mp4 (f2), b.mp4 (f1) + assert records[0]["frame_idx"] == 3 + assert records[1]["frame_idx"] == 1 + assert records[2]["frame_idx"] == 5 + + +def test_from_files_unsupported_format(tmp_path): + """Test that ValueError is raised for unsupported formats.""" + p = tmp_path / "test.txt" + p.touch() + with pytest.raises(ValueError, match="Unsupported format"): + from_files(p, format="INVALID") + + +@patch("ethology.io.annotations.load_keypoints._from_single_file") +def test_from_files_concatenation(mock_single): + """Test concatenation of multiple file datasets.""" + common_attrs = { + "map_keypoint_to_str": {0: "n1"}, + "map_image_id_to_filename": {0: "f"}, + "map_image_id_to_video": {0: "v.mp4"}, + "map_image_id_to_frame_idx": {0: 0}, + } + + ds1 = xr.Dataset( + { + "position": ( + ("image_id", "space", "keypoint", "id"), + np.zeros((1, 2, 1, 1)), + ) + }, + coords={ + "image_id": [0], + "keypoint": ["n1"], + "space": ["x", "y"], + "id": [0], + }, + attrs=common_attrs.copy(), + ) + ds1.attrs["map_image_id_to_filename"] = {0: "f1"} + ds1.attrs["map_image_id_to_video"] = {0: "v1.mp4"} + + ds2 = xr.Dataset( + { + "position": ( + ("image_id", "space", "keypoint", "id"), + np.zeros((1, 2, 1, 1)), + ) + }, + coords={ + "image_id": [0], + "keypoint": ["n1"], + "space": ["x", "y"], + "id": [0], + }, + attrs=common_attrs.copy(), + ) + ds2.attrs["map_image_id_to_filename"] = {0: "f2"} + ds2.attrs["map_image_id_to_video"] = {0: "v2.mp4"} + + mock_single.side_effect = [ds1, ds2] + + ds = from_files(["a", "b"], format="SLEAP") + + assert ds.sizes["image_id"] == 2 + assert ds.attrs["map_image_id_to_filename"] == {0: "f1", 1: "f2"} + assert ds.attrs["map_image_id_to_video"] == {0: "v1.mp4", 1: "v2.mp4"} + + +@patch("ethology.io.annotations.load_keypoints._from_single_file") +def test_from_files_mismatch_error(mock_single): + """Test error when keypoint labels differ across files.""" + ds1 = xr.Dataset( + coords={"image_id": [0]}, + attrs={ + "map_keypoint_to_str": {0: "A"}, + "map_image_id_to_filename": {0: "f"}, + "map_image_id_to_frame_idx": {0: 0}, + }, + ) + ds2 = xr.Dataset( + coords={"image_id": [0]}, + attrs={ + "map_keypoint_to_str": {0: "B"}, + "map_image_id_to_filename": {0: "f"}, + "map_image_id_to_frame_idx": {0: 0}, + }, + ) + mock_single.side_effect = [ds1, ds2] + + with pytest.raises(ValueError, match="Keypoint labels differ"): + from_files(["a", "b"], format="SLEAP") + + +@patch("ethology.io.annotations.load_keypoints._require_sleap_io") +def test_from_single_file_integration_mock(mock_require): + """Test the full flow of _from_single_file using mocks.""" + mock_sio = mock_require.return_value + + inst = MagicMock() + inst.points = [MagicMock(x=10, y=20, visible=True, score=0.9)] + frame = MagicMock( + frame_idx=0, + video=MagicMock(filename="v.mp4"), + user_instances=[inst], + ) + + node = MagicMock() + node.name = "nose" + skel = MagicMock(nodes=[node]) + + labels = MagicMock(labeled_frames=[frame], skeletons=[skel]) + mock_sio.load_file.return_value = labels + + ds = _from_single_file("test.slp", "SLEAP", None) + + assert "position" in ds + assert ds.keypoint.values[0] == "nose" + assert np.allclose(ds.position.values[0, 0, 0, 0], 10) + assert np.allclose(ds.position.values[0, 1, 0, 0], 20) + assert ds.confidence.values[0, 0, 0] == 0.9 + + +@patch("ethology.io.annotations.load_keypoints._require_sleap_io") +def test_from_single_file_inference_fallback(mock_require): + """Test that keypoints are inferred when no skeleton is present.""" + mock_sio = mock_require.return_value + + p1 = MagicMock(x=10, y=10, visible=True) + p2 = MagicMock(x=20, y=20, visible=True) + inst = MagicMock(points=[p1, p2]) + + frame = MagicMock(frame_idx=0, video=None, user_instances=[inst]) + + # No skeletons provided! + labels = MagicMock(labeled_frames=[frame], skeletons=[]) + mock_sio.load_file.return_value = labels + + ds = _from_single_file("test.slp", "SLEAP", None) + + assert ds.sizes["keypoint"] == 2 + assert ds.keypoint.values.tolist() == ["keypoint_0", "keypoint_1"] + + +@patch("ethology.io.annotations.load_keypoints._require_sleap_io") +def test_from_single_file_errors(mock_require): + """Test error conditions in single file loading.""" + mock_sio = mock_require.return_value + + # Case: No Frames + mock_sio.load_file.return_value = MagicMock(labeled_frames=[]) + with pytest.raises(ValueError, match="No labeled frames found"): + _from_single_file("t.slp", "SLEAP", None) + + # Case: No Instances + frame_empty = MagicMock(user_instances=[]) + mock_sio.load_file.return_value = MagicMock(labeled_frames=[frame_empty]) + with pytest.raises(ValueError, match="No instances found"): + _from_single_file("t.slp", "SLEAP", None) + + +@patch("ethology.io.annotations.load_keypoints._require_sleap_io") +def test_from_single_file_unsupported_instance_format(mock_require): + """Test error when instance points have an unsupported format.""" + mock_sio = mock_require.return_value + + node = MagicMock() + node.name = "nose" + skel = MagicMock(nodes=[node]) + + # Instance with no recognized points attribute + inst = MagicMock(spec=[]) + frame = MagicMock(frame_idx=0, video=None, user_instances=[inst]) + + mock_sio.load_file.return_value = MagicMock( + labeled_frames=[frame], skeletons=[skel] + ) + + with pytest.raises(ValueError, match="Unsupported instance points format"): + _from_single_file("d.slp", "SLEAP", None) + + +@patch("ethology.io.annotations.load_keypoints._require_sleap_io") +def test_from_single_file_multiple_instances(mock_require): + """Test multiple instances stack in the 'id' dimension.""" + mock_sio = mock_require.return_value + + inst1 = MagicMock(points=[MagicMock(x=10, y=10, visible=True)]) + inst2 = MagicMock(points=[MagicMock(x=20, y=20, visible=True)]) + + frame = MagicMock(frame_idx=0, video=None, user_instances=[inst1, inst2]) + + node = MagicMock() + node.name = "k1" + labels = MagicMock( + labeled_frames=[frame], skeletons=[MagicMock(nodes=[node])] + ) + mock_sio.load_file.return_value = labels + + ds = _from_single_file("test.slp", "SLEAP", None) + + assert ds.sizes["id"] == 2 diff --git a/tests/test_unit/test_io_annotations/test_save_keypoints.py b/tests/test_unit/test_io_annotations/test_save_keypoints.py new file mode 100644 index 00000000..2acff4ca --- /dev/null +++ b/tests/test_unit/test_io_annotations/test_save_keypoints.py @@ -0,0 +1,414 @@ +"""Test saving keypoints annotations to file formats.""" + +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest +import xarray as xr + +from ethology.io.annotations.save_keypoints import ( + _build_sleap_objects, + _get_image_id_maps, + _get_keypoint_names, + _require_sleap_io, + to_file, +) + +# ============================================================================ +# Helper Functions for Testing +# ============================================================================ + + +def create_valid_keypoints_dataset( + n_images: int = 2, + n_keypoints: int = 2, + n_instances: int = 1, + include_confidence: bool = False, + include_visibility: bool = False, +) -> xr.Dataset: + """Create a valid keypoints dataset for testing.""" + position_data = np.random.rand(n_images, 2, n_keypoints, n_instances) * 100 + + data_vars = { + "position": ( + ["image_id", "space", "keypoint", "id"], + position_data, + ), + } + + if include_confidence: + data_vars["confidence"] = ( + ["image_id", "keypoint", "id"], + np.random.rand(n_images, n_keypoints, n_instances), + ) + + if include_visibility: + data_vars["visibility"] = ( + ["image_id", "keypoint", "id"], + np.random.rand(n_images, n_keypoints, n_instances), + ) + + ds = xr.Dataset( + data_vars=data_vars, + coords={ + "image_id": np.arange(n_images), + "space": ["x", "y"], + "keypoint": [f"kp_{i}" for i in range(n_keypoints)], + "id": np.arange(n_instances), + }, + ) + + # Add required attributes + ds.attrs = { + "annotation_format": "SLEAP", + "map_keypoint_to_str": { + i: f"keypoint_{i}" for i in range(n_keypoints) + }, + "map_image_id_to_filename": { + i: f"frame_{i}.png" for i in range(n_images) + }, + "map_image_id_to_video": { + i: f"video_{i}.mp4" for i in range(n_images) + }, + "map_image_id_to_frame_idx": {i: i for i in range(n_images)}, + } + + return ds + + +def test_require_sleap_io_import_success(): + """Test successful import of sleap_io when installed.""" + try: + sio = _require_sleap_io() + assert sio is not None + assert hasattr(sio, "load_file") + except ModuleNotFoundError: + pytest.skip("sleap-io not installed") + + +def test_require_sleap_io_import_missing(): + """Test that ModuleNotFoundError is raised when sleap_io missing.""" + with ( + patch.dict("sys.modules", {"sleap_io": None}), + pytest.raises(ModuleNotFoundError, match="sleap-io is required"), + ): + _require_sleap_io() + + +def test_get_keypoint_names_from_map(): + """Test extraction of keypoint names from map_keypoint_to_str.""" + ds = create_valid_keypoints_dataset(n_keypoints=3) + names = _get_keypoint_names(ds) + assert len(names) == 3 + assert names == ["keypoint_0", "keypoint_1", "keypoint_2"] + + +def test_get_keypoint_names_from_coordinates(): + """Test extraction of keypoint names from coordinates.""" + ds = create_valid_keypoints_dataset(n_keypoints=2) + del ds.attrs["map_keypoint_to_str"] + + names = _get_keypoint_names(ds) + assert len(names) == 2 + assert all(isinstance(name, str) for name in names) + + +def test_get_keypoint_names_empty_map(): + """Test fallback when map_keypoint_to_str is empty.""" + ds = create_valid_keypoints_dataset(n_keypoints=2) + ds.attrs["map_keypoint_to_str"] = {} + + names = _get_keypoint_names(ds) + assert len(names) == 2 + assert all(isinstance(name, str) for name in names) + + +def test_get_image_id_maps_all_present(): + """Test extraction of all image ID mapping attributes.""" + ds = create_valid_keypoints_dataset() + maps = _get_image_id_maps(ds) + + assert len(maps) == 3 + assert len(maps[0]) == 2 # filename + assert len(maps[1]) == 2 # video + assert len(maps[2]) == 2 # frame_idx + + +def test_get_image_id_maps_missing_attributes(): + """Test handling of missing mapping attributes.""" + ds = create_valid_keypoints_dataset() + del ds.attrs["map_image_id_to_video"] + + maps = _get_image_id_maps(ds) + assert len(maps[1]) == 0 # Video map empty + + +@patch("ethology.io.annotations.save_keypoints._require_sleap_io") +def test_build_sleap_objects_basic_structure(mock_sio): + """Test basic structure of SLEAP objects built from dataset.""" + ds = create_valid_keypoints_dataset(n_images=1, n_keypoints=2) + + # Mock return values for sleap classes + mock_module = mock_sio.return_value + mock_module.Labels.return_value = MagicMock() + + labels = _build_sleap_objects(ds) + assert labels is not None + # Ensure Labels constructor was called + mock_module.Labels.assert_called_once() + + +@patch("ethology.io.annotations.save_keypoints._require_sleap_io") +def test_build_sleap_objects_missing_video_info_error(mock_sio): + """Test error when video/filename info is missing.""" + ds = create_valid_keypoints_dataset() + del ds.attrs["map_image_id_to_video"] + del ds.attrs["map_image_id_to_filename"] + + with pytest.raises(ValueError, match="Missing video or filename"): + _build_sleap_objects(ds) + + +@patch("ethology.io.annotations.save_keypoints._require_sleap_io") +def test_build_sleap_objects_skips_all_nan_instances(mock_sio): + """Test that instances with all NaN coordinates are skipped.""" + ds = create_valid_keypoints_dataset( + n_images=1, n_keypoints=2, n_instances=2 + ) + # Set second instance to all NaN + ds["position"].values[0, :, :, 1] = np.nan + + _build_sleap_objects(ds) + + # Check that Instance() was instantiated fewer times than total potential + # instances. We expect 1 instance to be created (the valid one). + assert mock_sio.return_value.Instance.call_count == 1 + + +@patch("ethology.io.annotations.save_keypoints._require_sleap_io") +def test_build_sleap_objects_handles_missing_keypoints(mock_sio): + """Test handling of missing keypoints (NaN values).""" + ds = create_valid_keypoints_dataset(n_images=1, n_keypoints=3) + # Set one keypoint to NaN + ds["position"].values[0, :, 1, 0] = np.nan + + _build_sleap_objects(ds) + + # Verify Point was called. + # Total points = 3. One is NaN, so we expect 2 Point creations. + assert mock_sio.return_value.Point.call_count == 2 + + +def test_to_file_unsupported_format(tmp_path): + """Test that ValueError is raised for unsupported formats.""" + ds = create_valid_keypoints_dataset() + output_file = tmp_path / "output.sleap" + with pytest.raises(ValueError, match="Unsupported format"): + to_file(ds, output_file, format="INVALID") + + +def test_to_file_validates_input(tmp_path): + """Test that to_file validates the input dataset.""" + invalid_ds = xr.Dataset() # Missing vars + output_file = tmp_path / "output.sleap" + # Validator raises ValueError, not TypeError, for missing vars + with pytest.raises(ValueError): + to_file(invalid_ds, output_file, format="SLEAP") + + +@pytest.mark.parametrize( + "dataset_params", + [ + {"n_images": 1, "n_keypoints": 2}, # Single Image + {"n_images": 2, "n_keypoints": 17}, # Many Keypoints + {"include_confidence": True}, # With Confidence + {"include_visibility": True}, # With Visibility + ], +) +@patch("ethology.io.annotations.save_keypoints._require_sleap_io") +@patch("ethology.io.annotations.save_keypoints._build_sleap_objects") +def test_to_file_sleap_variations( + mock_build, mock_sio, dataset_params, tmp_path +): + """Test saving to SLEAP format with various dataset configurations.""" + ds = create_valid_keypoints_dataset(**dataset_params) + output_file = tmp_path / "output.sleap" + + # Mock the internal calls + mock_build.return_value = MagicMock() + mock_sio.return_value.save_file = MagicMock() + + result = to_file(ds, output_file, format="SLEAP") + + assert result == output_file + mock_build.assert_called_once() + mock_sio.return_value.save_file.assert_called_once() + + +@patch("ethology.io.annotations.save_keypoints._require_sleap_io") +@patch("ethology.io.annotations.save_keypoints._build_sleap_objects") +def test_to_file_output_path_as_string(mock_build, mock_sio, tmp_path): + """Test that output_filepath can be a string.""" + ds = create_valid_keypoints_dataset() + output_file = str(tmp_path / "output.sleap") + + # Mock return values so we don't crash on saving + mock_build.return_value = MagicMock() + mock_sio.return_value.save_file = MagicMock() + + result = to_file(ds, output_file, format="SLEAP") + + # The code returns input path as-is, so we check equality, not type + assert result == output_file + + +@patch("ethology.io.annotations.save_keypoints._require_sleap_io") +def test_build_sleap_objects_missing_classes(mock_sio): + """Test error when sleap-io is missing required classes.""" + ds = create_valid_keypoints_dataset() + mock_module = mock_sio.return_value + mock_module.Instance = None + + with pytest.raises(AttributeError, match="sleap-io is missing"): + _build_sleap_objects(ds) + + +@patch("ethology.io.annotations.save_keypoints._require_sleap_io") +def test_build_sleap_objects_skeleton_fallback(mock_sio): + """Test Skeleton creation falls back when edges kwarg not supported.""" + ds = create_valid_keypoints_dataset(n_images=1, n_keypoints=1) + mock_module = mock_sio.return_value + + # First call with edges=[] raises TypeError, second without works + mock_skeleton = MagicMock() + mock_module.Skeleton.side_effect = [TypeError, mock_skeleton] + mock_module.Labels.return_value = MagicMock() + + labels = _build_sleap_objects(ds) + assert labels is not None + assert mock_module.Skeleton.call_count == 2 + + +@patch("ethology.io.annotations.save_keypoints._require_sleap_io") +def test_build_sleap_objects_video_fallbacks(mock_sio): + """Test Video construction fallback chain.""" + ds = create_valid_keypoints_dataset(n_images=1, n_keypoints=1) + mock_module = mock_sio.return_value + mock_module.Labels.return_value = MagicMock() + + # from_filename raises AttributeError, then filename= raises TypeError + mock_video = MagicMock() + mock_module.Video.from_filename.side_effect = AttributeError + mock_module.Video.side_effect = [TypeError, mock_video] + + _build_sleap_objects(ds) + assert mock_module.Video.call_count == 2 + + +@patch("ethology.io.annotations.save_keypoints._require_sleap_io") +def test_build_sleap_objects_video_cache(mock_sio): + """Test that videos are reused across frames from the same source.""" + ds = create_valid_keypoints_dataset(n_images=2, n_keypoints=1) + # Both images map to the same video + ds.attrs["map_image_id_to_video"] = {0: "shared.mp4", 1: "shared.mp4"} + mock_module = mock_sio.return_value + mock_module.Labels.return_value = MagicMock() + + _build_sleap_objects(ds) + # Video constructor called only once for the shared filename + mock_module.Video.from_filename.assert_called_once() + + +@patch("ethology.io.annotations.save_keypoints._require_sleap_io") +def test_build_sleap_objects_labeled_frame_fallback(mock_sio): + """Test LabeledFrame creation falls back to positional args.""" + ds = create_valid_keypoints_dataset(n_images=1, n_keypoints=1) + mock_module = mock_sio.return_value + mock_module.Labels.return_value = MagicMock() + + mock_lf = MagicMock() + mock_module.LabeledFrame.side_effect = [TypeError, mock_lf] + + _build_sleap_objects(ds) + assert mock_module.LabeledFrame.call_count == 2 + + +@patch("ethology.io.annotations.save_keypoints._require_sleap_io") +def test_build_sleap_objects_with_confidence_and_visibility(mock_sio): + """Test point creation with confidence and visibility data.""" + ds = create_valid_keypoints_dataset( + n_images=1, + n_keypoints=2, + n_instances=1, + include_confidence=True, + include_visibility=True, + ) + mock_module = mock_sio.return_value + mock_module.Labels.return_value = MagicMock() + + _build_sleap_objects(ds) + + # Points should be created with score and visible kwargs + point_calls = mock_module.Point.call_args_list + assert len(point_calls) == 2 + for call in point_calls: + assert "score" in call.kwargs or "x" in call.kwargs + + +@patch("ethology.io.annotations.save_keypoints._require_sleap_io") +def test_build_sleap_objects_no_point_cls(mock_sio): + """Test fallback when Point class is not available in sleap-io.""" + ds = create_valid_keypoints_dataset(n_images=1, n_keypoints=1) + mock_module = mock_sio.return_value + mock_module.Point = None + mock_module.Labels.return_value = MagicMock() + + _build_sleap_objects(ds) + + # Instance should be created with list-of-lists points + inst_call = mock_module.Instance.call_args + points_arg = inst_call.kwargs.get("points") or inst_call.args[0] + assert isinstance(points_arg[0], list) + + +@patch("ethology.io.annotations.save_keypoints._require_sleap_io") +def test_build_sleap_objects_instance_fallback(mock_sio): + """Test Instance creation falls back to positional args.""" + ds = create_valid_keypoints_dataset(n_images=1, n_keypoints=1) + mock_module = mock_sio.return_value + mock_module.Labels.return_value = MagicMock() + + mock_inst = MagicMock() + mock_module.Instance.side_effect = [TypeError, mock_inst] + + _build_sleap_objects(ds) + assert mock_module.Instance.call_count == 2 + + +@patch("ethology.io.annotations.save_keypoints._require_sleap_io") +def test_build_sleap_objects_labels_fallback(mock_sio): + """Test Labels creation falls back when skeletons kwarg not supported.""" + ds = create_valid_keypoints_dataset(n_images=1, n_keypoints=1) + mock_module = mock_sio.return_value + + mock_labels = MagicMock() + mock_module.Labels.side_effect = [TypeError, mock_labels] + + _build_sleap_objects(ds) + assert mock_module.Labels.call_count == 2 + + +@patch("ethology.io.annotations.save_keypoints._require_sleap_io") +@patch("ethology.io.annotations.save_keypoints._build_sleap_objects") +def test_to_file_save_fallback(mock_build, mock_sio, tmp_path): + """Test save_file falls back to swapped argument order.""" + ds = create_valid_keypoints_dataset() + output_file = tmp_path / "output.sleap" + + mock_build.return_value = MagicMock() + mock_module = mock_sio.return_value + mock_module.save_file.side_effect = [TypeError, None] + + result = to_file(ds, output_file, format="SLEAP") + assert result == output_file + assert mock_module.save_file.call_count == 2 diff --git a/tests/test_unit/test_validators/test_annotations.py b/tests/test_unit/test_validators/test_annotations.py index 93a15713..c8e271ca 100644 --- a/tests/test_unit/test_validators/test_annotations.py +++ b/tests/test_unit/test_validators/test_annotations.py @@ -8,6 +8,7 @@ from ethology.validators.annotations import ( ValidBboxAnnotationsDataset, ValidCOCO, + ValidKeypointsAnnotationsDataset, ValidVIA, ) @@ -399,3 +400,114 @@ def test_validator_bbox_annotations_dataset( "shape": {"id", "image_id", "space"}, "category": {"id", "image_id"}, } + + +@pytest.mark.parametrize( + "sample_dataset, expected_exception, expected_error_message", + [ + ( + "valid_keypoints_annotations_dataset", + does_not_raise(), + "", + ), + ( + "valid_keypoints_annotations_dataset_extra_vars_and_dims", + does_not_raise(), + "", + ), + ( + {"position": [1, 2, 3]}, + pytest.raises(TypeError), + "Expected an xarray Dataset, but got .", + ), + ( + xr.Dataset( + coords={ + "image_id": np.arange(3), + "space": ["x", "y"], + "keypoint": ["nose", "tail"], + "id": np.arange(2), + }, + data_vars={}, + ), + pytest.raises(ValueError), + "Missing required data variables: ['position']", + ), + ( + xr.Dataset( + coords={ + "image_id": np.arange(3), + "space": ["x", "y"], + "id": np.arange(2), + }, + data_vars={ + "position": ( + ["image_id", "space", "id"], + np.zeros((3, 2, 2)), + ), + }, + ), + pytest.raises(ValueError), + "Missing required dimensions: ['keypoint']", + ), + ( + xr.Dataset( + coords={ + "image_id": np.arange(3), + "space": ["x", "y"], + "keypoint": ["nose", "tail"], + "id": np.arange(2), + }, + data_vars={ + "position": ( + ["image_id", "id", "keypoint"], + np.zeros((3, 2, 2)), + ), + }, + ), + pytest.raises(ValueError), + ( + "Some data variables are missing required dimensions:" + "\n - data variable 'position' is missing dimensions " + "['space']" + ), + ), + ], + ids=[ + "valid_keypoints_annotations", + "valid_keypoints_annotations_extra_vars_and_dims", + "invalid_keypoints_annotations_type", + "invalid_keypoints_annotations_missing_data_var", + "invalid_keypoints_annotations_missing_dimension", + "invalid_keypoints_annotations_missing_dimension_in_data_var", + ], +) +def test_validator_keypoints_annotations_dataset( + sample_dataset: str | dict, + expected_exception: pytest.raises, + expected_error_message: str, + request: pytest.FixtureRequest, +): + """Test keypoints annotations dataset validation in various scenarios.""" + if isinstance(sample_dataset, str): + dataset = request.getfixturevalue(sample_dataset) + else: + dataset = sample_dataset + + with expected_exception as excinfo: + validator = ValidKeypointsAnnotationsDataset(dataset=dataset) + + if excinfo: + error_msg = str(excinfo.value) + assert error_msg in expected_error_message + else: + assert validator.dataset is dataset + assert validator.required_dims == { + "image_id", + "space", + "keypoint", + "id", + } + assert validator.required_data_vars == { + "position": {"id", "image_id", "space", "keypoint"}, + }