diff --git a/movement/kinematics/__init__.py b/movement/kinematics/__init__.py index 7216a367d..e8a98611f 100644 --- a/movement/kinematics/__init__.py +++ b/movement/kinematics/__init__.py @@ -1,5 +1,6 @@ """Compute variables derived from ``position`` data.""" +from movement.kinematics.collective import compute_group_spread from movement.kinematics.distances import compute_pairwise_distances from movement.kinematics.kinematics import ( compute_acceleration, @@ -19,6 +20,7 @@ from movement.kinematics.kinetic_energy import compute_kinetic_energy __all__ = [ + "compute_group_spread", "compute_displacement", "compute_forward_displacement", "compute_backward_displacement", diff --git a/movement/kinematics/collective.py b/movement/kinematics/collective.py new file mode 100644 index 000000000..7223b59bb --- /dev/null +++ b/movement/kinematics/collective.py @@ -0,0 +1,72 @@ +"""Kinematics utilities for groups of individuals. + +Currently contains functions to compute group-level spread metrics. +""" + +import numpy as np +import xarray as xr + + +def compute_group_spread( + position: xr.DataArray, + *, + keypoint: str = "centroid", + method: str = "radius_of_gyration", +) -> xr.DataArray: + """Compute per-frame group spread (radius of gyration) from positions. + + Parameters + ---------- + position : xr.DataArray + Positions with dims including ``time``, ``space`` and + ``individuals``. May also include ``keypoints``. + keypoint : str, default "centroid" + Keypoint to use when ``keypoints`` coordinate exists. + method : str, default "radius_of_gyration" + Currently only "radius_of_gyration" is supported. + + Returns + ------- + xr.DataArray + Spread per frame (dimension: ``time``). + + """ + required_dims = {"time", "space", "individuals"} + # mypy: position.dims is tuple[Hashable, ...]; convert to str for set ops + missing_dims = required_dims - set(map(str, position.dims)) + if missing_dims: + raise ValueError( + f"`position` must contain dimensions {sorted(required_dims)}. " + f"Missing: {sorted(missing_dims)}" + ) + + if "keypoints" in position.dims: + kp_coord = position.coords.get("keypoints") + if position.sizes.get("keypoints", 0) == 1: + position = position.isel(keypoints=0) + else: + # try to select the requested keypoint (defaults to 'centroid') + if kp_coord is not None and keypoint in list(kp_coord.values): + position = position.sel(keypoints=keypoint) + else: + raise ValueError( + "Multiple keypoints present; pass `keypoint` to select " + "one or include a keypoint named 'centroid'." + ) + + # validate method + if method != "radius_of_gyration": + raise ValueError( + f"Unsupported method '{method}'. Supported: 'radius_of_gyration'" + ) + + center = position.mean(dim="individuals", skipna=True) + diff = position - center + sqdist = (diff**2).sum(dim="space") + rg2 = sqdist.mean(dim="individuals", skipna=True) + + spread: xr.DataArray = xr.apply_ufunc(np.sqrt, rg2) + spread.name = "group_spread" + spread.attrs["method"] = method + + return spread diff --git a/tests/test_unit/test_kinematics/test_collective.py b/tests/test_unit/test_kinematics/test_collective.py new file mode 100644 index 000000000..3ec22cd8c --- /dev/null +++ b/tests/test_unit/test_kinematics/test_collective.py @@ -0,0 +1,138 @@ +import numpy as np +import pytest +import xarray as xr + +from movement.kinematics.collective import compute_group_spread + + +def test_compute_group_spread_with_known_values(): + """Test group spread on a simple dataset with known outputs.""" + position = xr.DataArray( + np.array( + [ + [[0.0, 2.0], [0.0, 0.0]], + [[0.0, 0.0], [0.0, 4.0]], + ] + ), + dims=["time", "space", "individuals"], + coords={ + "time": [0, 1], + "space": ["x", "y"], + "individuals": ["id_0", "id_1"], + }, + ) + + result = compute_group_spread(position) + + expected = xr.DataArray( + np.array([1.0, 2.0]), + dims=["time"], + coords={"time": [0, 1]}, + name="group_spread", + ) + xr.testing.assert_equal(result, expected) + + +def test_compute_group_spread_prefers_centroid(valid_poses_dataset): + """Test automatic keypoint selection when ``centroid`` is present.""" + result = compute_group_spread(valid_poses_dataset.position) + + expected = xr.DataArray( + np.arange(valid_poses_dataset.sizes["time"], dtype=float), + dims=["time"], + coords={"time": valid_poses_dataset.time.values}, + name="group_spread", + ) + xr.testing.assert_equal(result, expected) + + +def test_compute_group_spread_with_explicit_keypoint(valid_poses_dataset): + """Test spread computation for an explicitly selected keypoint.""" + result = compute_group_spread( + valid_poses_dataset.position, keypoint="left" + ) + + time = valid_poses_dataset.time.values.astype(float) + expected = xr.DataArray( + np.sqrt(time**2 + time + 0.5), + dims=["time"], + coords={"time": valid_poses_dataset.time.values}, + name="group_spread", + ) + xr.testing.assert_allclose(result, expected) + + +def test_compute_group_spread_with_single_keypoint(valid_poses_dataset): + """Test that a single available keypoint is selected automatically.""" + single_keypoint_position = valid_poses_dataset.position.sel( + keypoints=["left"] + ) + + result = compute_group_spread(single_keypoint_position) + expected = compute_group_spread( + valid_poses_dataset.position, keypoint="left" + ) + + xr.testing.assert_equal(result, expected) + + +def test_compute_group_spread_ignores_nans(): + """Test NaN-safe behavior when some individuals are missing.""" + position = xr.DataArray( + np.array( + [ + [[0.0, 2.0, np.nan], [0.0, 0.0, np.nan]], + [[0.0, 4.0, 2.0], [0.0, 0.0, 0.0]], + ] + ), + dims=["time", "space", "individuals"], + coords={ + "time": [0, 1], + "space": ["x", "y"], + "individuals": ["id_0", "id_1", "id_2"], + }, + ) + + result = compute_group_spread(position) + + expected = xr.DataArray( + np.array([1.0, np.sqrt(8.0 / 3.0)]), + dims=["time"], + coords={"time": [0, 1]}, + name="group_spread", + ) + xr.testing.assert_allclose(result, expected) + + +def test_compute_group_spread_without_individuals_dimension(): + """Test that missing ``individuals`` raises a clear error.""" + position = xr.DataArray( + np.zeros((2, 2)), + dims=["time", "space"], + coords={"time": [0, 1], "space": ["x", "y"]}, + ) + + with pytest.raises( + ValueError, match="must have an 'individuals' dimension" + ): + compute_group_spread(position) + + +def test_compute_group_spread_with_multiple_keypoints_without_selection(): + """Test that ambiguous keypoint input raises a clear error.""" + position = xr.DataArray( + np.zeros((2, 2, 2, 2)), + dims=["time", "space", "keypoints", "individuals"], + coords={ + "time": [0, 1], + "space": ["x", "y"], + "keypoints": ["left", "right"], + "individuals": ["id_0", "id_1"], + }, + ) + + with pytest.raises( + ValueError, + match="Multiple keypoints present; pass `keypoint` to select one", + ): + compute_group_spread(position)