-
Notifications
You must be signed in to change notification settings - Fork 139
feat: Implement compute_directional_change for trajectory complexity
#991
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
0956cd5
3b35bb9
84dcf6a
998c58e
9da3971
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,11 +6,32 @@ | |
| import xarray as xr | ||
|
|
||
| from movement.kinematics import ( | ||
| compute_directional_change, | ||
| compute_path_length, | ||
| compute_path_straightness, | ||
| compute_turning_angle, | ||
| ) | ||
|
|
||
| # Shared by all metrics that require at least 2 time points. | ||
| time_points_value_error = pytest.raises( | ||
| ValueError, | ||
| match="At least 2 time points are required", | ||
| ) | ||
|
|
||
| # Pre-sliced time-range cases shared by the straightness and directional | ||
| # change tests, which validate via the same minimum-time-points check. | ||
| time_range_cases = [ | ||
| pytest.param(slice(None, None), does_not_raise(), id="full-range"), | ||
| pytest.param(slice(0, 9), does_not_raise(), id="explicit-full-range"), | ||
| pytest.param(slice(1, 8), does_not_raise(), id="partial-range"), | ||
| pytest.param( | ||
| slice(9, 0), time_points_value_error, id="start-greater-than-stop" | ||
| ), | ||
| pytest.param( | ||
| slice(0, 0.5), time_points_value_error, id="too-few-time-points" | ||
| ), | ||
| ] | ||
|
|
||
| # ───────────────────────────────────────────── | ||
| # Fixtures | ||
| # ───────────────────────────────────────────── | ||
|
|
@@ -98,11 +119,6 @@ def sharp_turn_paths(straight_paths): | |
| # Path length tests | ||
| # ───────────────────────────────────────────── | ||
|
|
||
| time_points_value_error = pytest.raises( | ||
| ValueError, | ||
| match="At least 2 time points are required to compute path length", | ||
| ) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| "time_slice, expected_exception", | ||
|
|
@@ -288,42 +304,8 @@ def test_path_length_nan_warn_threshold( | |
| # Path straightness tests | ||
| # ───────────────────────────────────────────── | ||
|
|
||
| time_points_value_error_straightness = pytest.raises( | ||
| ValueError, | ||
| match=("At least 2 time points are required to compute path straightness"), | ||
| ) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| "time_slice, expected_exception", | ||
| [ | ||
| pytest.param( | ||
| slice(None, None), | ||
| does_not_raise(), | ||
| id="full-range", | ||
| ), | ||
| pytest.param( | ||
| slice(0, 9), | ||
| does_not_raise(), | ||
| id="explicit-full-range", | ||
| ), | ||
| pytest.param( | ||
| slice(1, 8), | ||
| does_not_raise(), | ||
| id="partial-range", | ||
| ), | ||
| pytest.param( | ||
| slice(9, 0), | ||
| time_points_value_error_straightness, | ||
| id="start-greater-than-stop", | ||
| ), | ||
| pytest.param( | ||
| slice(0, 0.5), | ||
| time_points_value_error_straightness, | ||
| id="too-few-time-points", | ||
| ), | ||
| ], | ||
| ) | ||
| @pytest.mark.parametrize("time_slice, expected_exception", time_range_cases) | ||
| def test_path_straightness_across_time_ranges( | ||
| valid_poses_dataset, time_slice, expected_exception | ||
| ): | ||
|
|
@@ -633,3 +615,85 @@ def test_turning_angle_stationary_keypoint_independent_masking(): | |
| # Stationary keypoint (kp_1): should be all NaN | ||
| angles_kp1 = angles.isel(keypoint=1) | ||
| assert np.all(np.isnan(angles_kp1.values)) | ||
|
|
||
|
|
||
| # ───────────────────────────────────────────── | ||
| # Directional change tests | ||
| # ───────────────────────────────────────────── | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| "fixture_name, expected_value, expected_all_nan", | ||
| [ | ||
| pytest.param("straight_paths", 0.0, False, id="straight-line"), | ||
| pytest.param("stationary_paths", None, True, id="stationary"), | ||
| ], | ||
| ) | ||
| def test_directional_change_known_values( | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The tests only assert DC = 0 (straight) and all-NaN (stationary). Neither exercises a known nonzero DC value, so neither the radians-per-time scaling nor the
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree, for better coverage, added |
||
| request, fixture_name, expected_value, expected_all_nan | ||
| ): | ||
| """Test directional change for trajectories with known geometry. | ||
|
|
||
| Straight-line motion produces zero turning angle, so DC is 0 at | ||
| every valid time step. Stationary paths produce NaN turning angles, | ||
| so DC is NaN everywhere. | ||
| """ | ||
| position = request.getfixturevalue(fixture_name) | ||
| dc = compute_directional_change(position) | ||
| assert dc.name == "directional_change" | ||
| assert dc.attrs["long_name"] == "Directional Change" | ||
|
|
||
| if expected_all_nan: | ||
| assert dc.isnull().all() | ||
| else: | ||
| valid = dc.isel(time=slice(2, None)) | ||
| xr.testing.assert_allclose(valid, xr.full_like(valid, expected_value)) | ||
| # Only the first two time steps are NaN. | ||
| assert dc.isel(time=[0, 1]).isnull().all() | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| "in_degrees, expected_turn", | ||
| [ | ||
| pytest.param(False, 3 * np.pi / 4, id="radians"), | ||
| pytest.param(True, 135.0, id="degrees"), | ||
| ], | ||
| ) | ||
| def test_directional_change_nonzero_value( | ||
| sharp_turn_paths, in_degrees, expected_turn | ||
| ): | ||
| """Test DC against a known nonzero value on non-uniform time. | ||
|
|
||
| In ``sharp_turn_paths`` both individuals move in a straight line for | ||
| 8 steps, then turn sharply on the final step, producing a turning | ||
| angle of ``3 * pi / 4`` (135 degrees) at the last time step. | ||
| Dividing by the turning interval ``t[-1] - t[-3]`` gives the | ||
| expected DC there. The non-uniform ``time`` coordinates ensure the | ||
| interval is aligned to the turning angle's support (positions | ||
| ``i-2..i``) rather than a centered difference around ``i``. | ||
| """ | ||
| time = np.array([0, 1, 2, 4, 7, 11, 16, 22, 29, 37], dtype=float) | ||
| path = sharp_turn_paths.assign_coords(time=time) | ||
|
|
||
| dc = compute_directional_change(path, in_degrees=in_degrees) | ||
|
|
||
| expected_last = expected_turn / (time[-1] - time[-3]) | ||
| last = dc.isel(time=-1) | ||
| xr.testing.assert_allclose(last, xr.full_like(last, expected_last)) | ||
| # Only the first two time steps are NaN; the last step is now valid. | ||
| assert dc.isel(time=[0, 1]).isnull().all() | ||
| assert dc.isel(time=slice(2, None)).notnull().all() | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("time_slice, expected_exception", time_range_cases) | ||
| def test_directional_change_across_time_ranges( | ||
| valid_poses_dataset, time_slice, expected_exception | ||
| ): | ||
| """Test that DC raises with too few time points, and works | ||
| otherwise. | ||
| """ | ||
| position = valid_poses_dataset.position.sel(time=time_slice) | ||
| with expected_exception: | ||
| dc = compute_directional_change(position) | ||
| assert dc.name == "directional_change" | ||
| assert dc.attrs["long_name"] == "Directional Change" | ||
Uh oh!
There was an error while loading. Please reload this page.