Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions movement/kinematics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Compute variables derived from ``position`` data."""

from movement.kinematics.collective import compute_group_spread
from movement.kinematics.distances import compute_pairwise_distances
from movement.kinematics.kinematics import (
compute_acceleration,
Expand All @@ -19,6 +20,7 @@
from movement.kinematics.kinetic_energy import compute_kinetic_energy

__all__ = [
"compute_group_spread",
"compute_displacement",
"compute_forward_displacement",
"compute_backward_displacement",
Expand Down
72 changes: 72 additions & 0 deletions movement/kinematics/collective.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""Kinematics utilities for groups of individuals.

Currently contains functions to compute group-level spread metrics.
"""

import numpy as np
import xarray as xr


def compute_group_spread(
position: xr.DataArray,
*,
keypoint: str = "centroid",
method: str = "radius_of_gyration",
) -> xr.DataArray:
"""Compute per-frame group spread (radius of gyration) from positions.

Parameters
----------
position : xr.DataArray
Positions with dims including ``time``, ``space`` and
``individuals``. May also include ``keypoints``.
keypoint : str, default "centroid"
Keypoint to use when ``keypoints`` coordinate exists.
method : str, default "radius_of_gyration"
Currently only "radius_of_gyration" is supported.

Returns
-------
xr.DataArray
Spread per frame (dimension: ``time``).

"""
required_dims = {"time", "space", "individuals"}
# mypy: position.dims is tuple[Hashable, ...]; convert to str for set ops
missing_dims = required_dims - set(map(str, position.dims))
if missing_dims:
raise ValueError(
f"`position` must contain dimensions {sorted(required_dims)}. "
f"Missing: {sorted(missing_dims)}"
)

if "keypoints" in position.dims:
kp_coord = position.coords.get("keypoints")
if position.sizes.get("keypoints", 0) == 1:
position = position.isel(keypoints=0)
else:
# try to select the requested keypoint (defaults to 'centroid')
if kp_coord is not None and keypoint in list(kp_coord.values):
position = position.sel(keypoints=keypoint)
else:
raise ValueError(
"Multiple keypoints present; pass `keypoint` to select "
"one or include a keypoint named 'centroid'."
)

# validate method
if method != "radius_of_gyration":
raise ValueError(
f"Unsupported method '{method}'. Supported: 'radius_of_gyration'"
)

center = position.mean(dim="individuals", skipna=True)
diff = position - center
sqdist = (diff**2).sum(dim="space")
rg2 = sqdist.mean(dim="individuals", skipna=True)

spread: xr.DataArray = xr.apply_ufunc(np.sqrt, rg2)
spread.name = "group_spread"
spread.attrs["method"] = method

return spread
138 changes: 138 additions & 0 deletions tests/test_unit/test_kinematics/test_collective.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import numpy as np
import pytest
import xarray as xr

from movement.kinematics.collective import compute_group_spread


def test_compute_group_spread_with_known_values():
"""Test group spread on a simple dataset with known outputs."""
position = xr.DataArray(
np.array(
[
[[0.0, 2.0], [0.0, 0.0]],
[[0.0, 0.0], [0.0, 4.0]],
]
),
dims=["time", "space", "individuals"],
coords={
"time": [0, 1],
"space": ["x", "y"],
"individuals": ["id_0", "id_1"],
},
)

result = compute_group_spread(position)

expected = xr.DataArray(
np.array([1.0, 2.0]),
dims=["time"],
coords={"time": [0, 1]},
name="group_spread",
)
xr.testing.assert_equal(result, expected)


def test_compute_group_spread_prefers_centroid(valid_poses_dataset):
"""Test automatic keypoint selection when ``centroid`` is present."""
result = compute_group_spread(valid_poses_dataset.position)

expected = xr.DataArray(
np.arange(valid_poses_dataset.sizes["time"], dtype=float),
dims=["time"],
coords={"time": valid_poses_dataset.time.values},
name="group_spread",
)
xr.testing.assert_equal(result, expected)


def test_compute_group_spread_with_explicit_keypoint(valid_poses_dataset):
"""Test spread computation for an explicitly selected keypoint."""
result = compute_group_spread(
valid_poses_dataset.position, keypoint="left"
)

time = valid_poses_dataset.time.values.astype(float)
expected = xr.DataArray(
np.sqrt(time**2 + time + 0.5),
dims=["time"],
coords={"time": valid_poses_dataset.time.values},
name="group_spread",
)
xr.testing.assert_allclose(result, expected)


def test_compute_group_spread_with_single_keypoint(valid_poses_dataset):
"""Test that a single available keypoint is selected automatically."""
single_keypoint_position = valid_poses_dataset.position.sel(
keypoints=["left"]
)

result = compute_group_spread(single_keypoint_position)
expected = compute_group_spread(
valid_poses_dataset.position, keypoint="left"
)

xr.testing.assert_equal(result, expected)


def test_compute_group_spread_ignores_nans():
"""Test NaN-safe behavior when some individuals are missing."""
position = xr.DataArray(
np.array(
[
[[0.0, 2.0, np.nan], [0.0, 0.0, np.nan]],
[[0.0, 4.0, 2.0], [0.0, 0.0, 0.0]],
]
),
dims=["time", "space", "individuals"],
coords={
"time": [0, 1],
"space": ["x", "y"],
"individuals": ["id_0", "id_1", "id_2"],
},
)

result = compute_group_spread(position)

expected = xr.DataArray(
np.array([1.0, np.sqrt(8.0 / 3.0)]),
dims=["time"],
coords={"time": [0, 1]},
name="group_spread",
)
xr.testing.assert_allclose(result, expected)


def test_compute_group_spread_without_individuals_dimension():
"""Test that missing ``individuals`` raises a clear error."""
position = xr.DataArray(
np.zeros((2, 2)),
dims=["time", "space"],
coords={"time": [0, 1], "space": ["x", "y"]},
)

with pytest.raises(
ValueError, match="must have an 'individuals' dimension"
):
compute_group_spread(position)


def test_compute_group_spread_with_multiple_keypoints_without_selection():
"""Test that ambiguous keypoint input raises a clear error."""
position = xr.DataArray(
np.zeros((2, 2, 2, 2)),
dims=["time", "space", "keypoints", "individuals"],
coords={
"time": [0, 1],
"space": ["x", "y"],
"keypoints": ["left", "right"],
"individuals": ["id_0", "id_1"],
},
)

with pytest.raises(
ValueError,
match="Multiple keypoints present; pass `keypoint` to select one",
):
compute_group_spread(position)