diff --git a/movement/utils/vector.py b/movement/utils/vector.py index ca7430df1..e65febfda 100644 --- a/movement/utils/vector.py +++ b/movement/utils/vector.py @@ -42,7 +42,13 @@ def compute_norm(data: xr.DataArray) -> xr.DataArray: """ if "space" in data.dims: - validate_dims_coords(data, {"space": ["x", "y"]}) + # Allow both 2D and 3D + if len(data.coords["space"]) == 2: + validate_dims_coords(data, {"space": ["x", "y"]}) + elif len(data.coords["space"]) == 3: + validate_dims_coords(data, {"space": ["x", "y", "z"]}) + else: + _raise_error_for_invalid_spatial_dim_length("space", 2, 3) return xr.apply_ufunc( np.linalg.norm, data, @@ -50,7 +56,13 @@ def compute_norm(data: xr.DataArray) -> xr.DataArray: kwargs={"axis": -1}, ) elif "space_pol" in data.dims: - validate_dims_coords(data, {"space_pol": ["rho", "phi"]}) + # Allow both 2D polar and 3D cylindrical + if len(data.coords["space_pol"]) == 2: + validate_dims_coords(data, {"space_pol": ["rho", "phi"]}) + elif len(data.coords["space_pol"]) == 3: + validate_dims_coords(data, {"space_pol": ["rho", "phi", "z"]}) + else: + _raise_error_for_invalid_spatial_dim_length("space_pol", 2, 3) return data.sel(space_pol="rho", drop=True) else: _raise_error_for_missing_spatial_dim() @@ -81,10 +93,22 @@ def convert_to_unit(data: xr.DataArray) -> xr.DataArray: """ if "space" in data.dims: - validate_dims_coords(data, {"space": ["x", "y"]}) + # Allow both 2D and 3D + if len(data.coords["space"]) == 2: + validate_dims_coords(data, {"space": ["x", "y"]}) + elif len(data.coords["space"]) == 3: + validate_dims_coords(data, {"space": ["x", "y", "z"]}) + else: + _raise_error_for_invalid_spatial_dim_length("space", 2, 3) return data / compute_norm(data) elif "space_pol" in data.dims: - validate_dims_coords(data, {"space_pol": ["rho", "phi"]}) + # Allow both 2D polar and 3D cylindrical + if len(data.coords["space_pol"]) == 2: + validate_dims_coords(data, {"space_pol": ["rho", "phi"]}) + elif len(data.coords["space_pol"]) == 3: + validate_dims_coords(data, {"space_pol": ["rho", "phi", "z"]}) + else: + _raise_error_for_invalid_spatial_dim_length("space_pol", 2, 3) # Set both rho and phi values to NaN at null vectors (where rho = 0) new_data = xr.where(data.sel(space_pol="rho") == 0, np.nan, data) # Set the rho values to 1 for non-null vectors (phi is preserved) @@ -97,21 +121,26 @@ def convert_to_unit(data: xr.DataArray) -> xr.DataArray: def cart2pol(data: xr.DataArray) -> xr.DataArray: - """Transform Cartesian coordinates to polar. + """Transform Cartesian coordinates to polar (2D) or cylindrical (3D). Parameters ---------- data The input data containing ``space`` as a dimension, - with ``x`` and ``y`` in the dimension coordinate. + with ``x`` and ``y`` (2D) or ``x``, ``y``, and ``z`` (3D) + in the dimension coordinate. Returns ------- xarray.DataArray - An xarray DataArray containing the polar coordinates - stored in the ``space_pol`` dimension, with ``rho`` - and ``phi`` in the dimension coordinate. The angles - ``phi`` returned are in radians, in the range ``[-pi, pi]``. + An xarray DataArray containing the polar/cylindrical coordinates + stored in the ``space_pol`` dimension: + + - 2D: ``rho`` and ``phi`` + - 3D: ``rho``, ``phi``, and ``z`` (cylindrical coordinates) + + The angle ``phi`` is in radians, in the range ``[-pi, pi]``. + For 3D input, ``z`` is passed through unchanged. Notes ----- @@ -124,7 +153,7 @@ def cart2pol(data: xr.DataArray) -> xr.DataArray: References ---------- - .. [1] ISO/IEC standard 9899:1999, “Programming language C.” + .. [1] ISO/IEC standard 9899:1999, "Programming language C." .. [2] https://en.wikipedia.org/wiki/Atan2 .. [3] https://en.wikipedia.org/wiki/Signed_zero @@ -133,63 +162,200 @@ def cart2pol(data: xr.DataArray) -> xr.DataArray: :obj:`numpy.arctan2` """ - validate_dims_coords(data, {"space": ["x", "y"]}) - rho = compute_norm(data) - phi = xr.apply_ufunc( - np.arctan2, - data.sel(space="y"), - data.sel(space="x"), - ) + # Validate space dimension exists + if "space" not in data.dims: + raise logger.error( + ValueError("Input data must contain 'space' as a dimension.") + ) + + # Validate 2D or 3D input + is_3d = len(data.coords["space"]) == 3 + if is_3d: + validate_dims_coords(data, {"space": ["x", "y", "z"]}) + else: + validate_dims_coords(data, {"space": ["x", "y"]}) + + x = data.sel(space="x", drop=True) + y = data.sel(space="y", drop=True) + rho = (x**2 + y**2) ** 0.5 + phi = xr.apply_ufunc(np.arctan2, y, x) # Make all zeros in phi positive zeros # - where rho == 0, set phi to 0 # - where rho != 0, keep the phi value from atan2 phi = xr.where(np.isclose(rho.values, 0.0, atol=1e-9), 0.0, phi) + # Build output components + components = [ + rho.assign_coords({"space_pol": "rho"}), + phi.assign_coords({"space_pol": "phi"}), + ] + + # For 3D, pass z through unchanged + if is_3d: + z = data.sel(space="z", drop=True) + components.append(z.assign_coords({"space_pol": "z"})) + # Replace space dim with space_pol dims = list(data.dims) dims[dims.index("space")] = "space_pol" - return xr.concat( - [ - rho.assign_coords({"space_pol": "rho"}), - phi.assign_coords({"space_pol": "phi"}), - ], - dim="space_pol", - ).transpose(*dims) + return xr.concat(components, dim="space_pol").transpose(*dims) def pol2cart(data: xr.DataArray) -> xr.DataArray: - """Transform polar coordinates to Cartesian. + """Transform polar (2D) or cylindrical (3D) coordinates to Cartesian. Parameters ---------- data The input data containing ``space_pol`` as a dimension, - with ``rho`` and ``phi`` in the dimension coordinate. + with ``rho`` and ``phi`` (2D) or ``rho``, ``phi``, and ``z`` (3D) + in the dimension coordinate. Returns ------- xarray.DataArray An xarray DataArray containing the Cartesian coordinates - stored in the ``space`` dimension, with ``x`` and ``y`` - in the dimension coordinate. + stored in the ``space`` dimension: + + - 2D: ``x`` and ``y`` + - 3D: ``x``, ``y``, and ``z`` """ - validate_dims_coords(data, {"space_pol": ["rho", "phi"]}) - rho = data.sel(space_pol="rho") - phi = data.sel(space_pol="phi") + # Validate space_pol dimension exists + if "space_pol" not in data.dims: + raise logger.error( + ValueError("Input data must contain 'space_pol' as a dimension.") + ) + + # Validate 2D or 3D input + is_3d = len(data.coords["space_pol"]) == 3 + if is_3d: + validate_dims_coords(data, {"space_pol": ["rho", "phi", "z"]}) + else: + validate_dims_coords(data, {"space_pol": ["rho", "phi"]}) + + rho = data.sel(space_pol="rho", drop=True) + phi = data.sel(space_pol="phi", drop=True) x = rho * np.cos(phi) y = rho * np.sin(phi) + # Build output components + components = [ + x.assign_coords({"space": "x"}), + y.assign_coords({"space": "y"}), + ] + + # For 3D, pass z through unchanged + if is_3d: + z = data.sel(space_pol="z", drop=True) + components.append(z.assign_coords({"space": "z"})) + # Replace space_pol dim with space dims = list(data.dims) dims[dims.index("space_pol")] = "space" + return xr.concat(components, dim="space").transpose(*dims) + + +def cart2sph(data: xr.DataArray) -> xr.DataArray: + """Transform 3D Cartesian coordinates to spherical. + + Parameters + ---------- + data + The input data containing ``space`` as a dimension, + with ``x``, ``y``, and ``z`` in the dimension coordinate. + + Returns + ------- + xarray.DataArray + An xarray DataArray containing the spherical coordinates + stored in the ``space_sph`` dimension, with ``rho``, + ``azimuth``, and ``elevation`` in the dimension coordinate: + + - ``rho``: radial distance (magnitude of the vector) + - ``azimuth``: angle in the x-y plane from the positive x-axis, + in radians, in the range ``[-pi, pi]`` + - ``elevation``: angle from the x-y plane, in radians, + in the range ``[-pi/2, pi/2]`` + + See Also + -------- + sph2cart : Inverse transformation from spherical to Cartesian. + + """ + validate_dims_coords(data, {"space": ["x", "y", "z"]}) + + x = data.sel(space="x", drop=True) + y = data.sel(space="y", drop=True) + z = data.sel(space="z", drop=True) + + rho = (x**2 + y**2 + z**2) ** 0.5 + azimuth = xr.apply_ufunc(np.arctan2, y, x) + # Compute elevation, handling zero-magnitude vectors + elevation = xr.where( + rho > 0, + np.arcsin((z / rho).clip(-1, 1)), + 0.0, + ) + + # Replace space dim with space_sph + dims = list(data.dims) + dims[dims.index("space")] = "space_sph" + return xr.concat( + [ + rho.assign_coords({"space_sph": "rho"}), + azimuth.assign_coords({"space_sph": "azimuth"}), + elevation.assign_coords({"space_sph": "elevation"}), + ], + dim="space_sph", + coords="minimal", + ).transpose(*dims) + + +def sph2cart(data: xr.DataArray) -> xr.DataArray: + """Transform spherical coordinates to 3D Cartesian. + + Parameters + ---------- + data + The input data containing ``space_sph`` as a dimension, + with ``rho``, ``azimuth``, and ``elevation`` in the + dimension coordinate. + + Returns + ------- + xarray.DataArray + An xarray DataArray containing the Cartesian coordinates + stored in the ``space`` dimension, with ``x``, ``y``, and ``z`` + in the dimension coordinate. + + See Also + -------- + cart2sph : Inverse transformation from Cartesian to spherical. + + """ + validate_dims_coords(data, {"space_sph": ["rho", "azimuth", "elevation"]}) + + rho = data.sel(space_sph="rho", drop=True) + azimuth = data.sel(space_sph="azimuth", drop=True) + elevation = data.sel(space_sph="elevation", drop=True) + + x = rho * np.cos(elevation) * np.cos(azimuth) + y = rho * np.cos(elevation) * np.sin(azimuth) + z = rho * np.sin(elevation) + + # Replace space_sph dim with space + dims = list(data.dims) + dims[dims.index("space_sph")] = "space" return xr.concat( [ x.assign_coords({"space": "x"}), y.assign_coords({"space": "y"}), + z.assign_coords({"space": "z"}), ], dim="space", + coords="minimal", ).transpose(*dims) @@ -314,3 +480,14 @@ def _raise_error_for_missing_spatial_dim() -> NoReturn: "as dimensions." ) ) + + +def _raise_error_for_invalid_spatial_dim_length( + dim_name: str, *valid_lengths: int +) -> NoReturn: + lengths_str = " or ".join(str(n) for n in valid_lengths) + raise logger.error( + ValueError( + f"Dimension '{dim_name}' must have {lengths_str} coordinates." + ) + ) diff --git a/tests/test_unit/test_vector.py b/tests/test_unit/test_vector.py index 2e9f28576..2669b34a4 100644 --- a/tests/test_unit/test_vector.py +++ b/tests/test_unit/test_vector.py @@ -89,6 +89,69 @@ def cart_pol_dataset_missing_pol_coords(self, cart_pol_dataset): cart_pol_dataset["space_pol"] = ["a", "b"] return cart_pol_dataset + @pytest.fixture + def cart_pol_dataset_3d(self): + """Return xarray.Dataset with 3D Cartesian and cylindrical coords.""" + # 3D Cartesian coordinates + x_vals = np.array([1.0, 0.0, -1.0, 1.0, 0.0]) + y_vals = np.array([0.0, 1.0, 0.0, 1.0, 0.0]) + z_vals = np.array([0.0, 0.0, 2.0, 3.0, -1.0]) + time_coords = np.arange(len(x_vals)) + + # Expected cylindrical coordinates (rho in x-y plane, z passes through) + rho = np.sqrt(x_vals**2 + y_vals**2) + phi = np.arctan2(y_vals, x_vals) + # Handle null vectors in x-y plane + phi = np.where(np.isclose(rho, 0.0), 0.0, phi) + + cart = xr.DataArray( + np.column_stack((x_vals, y_vals, z_vals)), + dims=["time", "space"], + coords={"time": time_coords, "space": ["x", "y", "z"]}, + ) + pol = xr.DataArray( + np.column_stack((rho, phi, z_vals)), + dims=["time", "space_pol"], + coords={"time": time_coords, "space_pol": ["rho", "phi", "z"]}, + ) + return xr.Dataset(data_vars={"cart": cart, "pol": pol}) + + @pytest.fixture + def cart_sph_dataset(self): + """Return an xarray.Dataset with 3D Cartesian and spherical coords.""" + # 3D Cartesian coordinates (points on/around unit sphere) + cart_data = np.array( + [ + [1.0, 0.0, 0.0], # +x axis + [0.0, 1.0, 0.0], # +y axis + [0.0, 0.0, 1.0], # +z axis (north pole) + [0.0, 0.0, -1.0], # -z axis (south pole) + [1.0, 1.0, 0.0], # x-y plane, 45 deg + ] + ) + time_coords = np.arange(len(cart_data)) + + # Expected spherical coordinates + x, y, z = cart_data[:, 0], cart_data[:, 1], cart_data[:, 2] + rho = np.sqrt(x**2 + y**2 + z**2) + azimuth = np.arctan2(y, x) + elevation = np.where(rho > 0, np.arcsin(z / rho), 0.0) + + cart = xr.DataArray( + cart_data, + dims=["time", "space"], + coords={"time": time_coords, "space": ["x", "y", "z"]}, + ) + sph = xr.DataArray( + np.column_stack((rho, azimuth, elevation)), + dims=["time", "space_sph"], + coords={ + "time": time_coords, + "space_sph": ["rho", "azimuth", "elevation"], + }, + ) + return xr.Dataset(data_vars={"cart": cart, "sph": sph}) + @pytest.mark.parametrize( "ds, expected_exception", [ @@ -197,6 +260,70 @@ def test_convert_to_unit(self, ds, expected_exception, request): expected_unit_pol = expected_unit_pol.where(~expected_nan_idxs) xr.testing.assert_allclose(unit_pol, expected_unit_pol) + def test_cart2pol_3d(self, cart_pol_dataset_3d): + """Test 3D Cartesian to cylindrical coordinates.""" + result = vector.cart2pol(cart_pol_dataset_3d.cart) + xr.testing.assert_allclose(result, cart_pol_dataset_3d.pol) + # Verify z passes through unchanged + xr.testing.assert_equal( + result.sel(space_pol="z", drop=True), + cart_pol_dataset_3d.cart.sel(space="z", drop=True), + ) + + def test_pol2cart_3d(self, cart_pol_dataset_3d): + """Test 3D cylindrical to Cartesian coordinates.""" + result = vector.pol2cart(cart_pol_dataset_3d.pol) + xr.testing.assert_allclose(result, cart_pol_dataset_3d.cart) + + def test_compute_norm_3d(self, cart_pol_dataset_3d): + """Test norm computation on 3D Cartesian and cylindrical data.""" + cart = cart_pol_dataset_3d.cart + pol = cart_pol_dataset_3d.pol + + # 3D Cartesian norm: sqrt(x^2 + y^2 + z^2) + result_cart = vector.compute_norm(cart) + expected = np.sqrt( + cart.sel(space="x", drop=True) ** 2 + + cart.sel(space="y", drop=True) ** 2 + + cart.sel(space="z", drop=True) ** 2 + ) + xr.testing.assert_allclose(result_cart, expected) + + # Cylindrical norm should return rho (x-y plane magnitude) + result_pol = vector.compute_norm(pol) + expected_pol = pol.sel(space_pol="rho", drop=True) + xr.testing.assert_allclose(result_pol, expected_pol) + + def test_convert_to_unit_3d(self, cart_pol_dataset_3d): + """Test unit vector conversion on 3D Cartesian data.""" + cart = cart_pol_dataset_3d.cart + unit_cart = vector.convert_to_unit(cart) + + # Unit vectors should have norm = 1 + norms = vector.compute_norm(unit_cart) + # Skip null vectors (where original norm was 0) + original_norms = vector.compute_norm(cart) + valid_mask = original_norms > 0 + xr.testing.assert_allclose( + norms.where(valid_mask), xr.ones_like(norms).where(valid_mask) + ) + + def test_cart2sph(self, cart_sph_dataset): + """Test 3D Cartesian to spherical coordinates.""" + result = vector.cart2sph(cart_sph_dataset.cart) + xr.testing.assert_allclose(result, cart_sph_dataset.sph) + + def test_sph2cart(self, cart_sph_dataset): + """Test spherical to 3D Cartesian coordinates.""" + result = vector.sph2cart(cart_sph_dataset.sph) + xr.testing.assert_allclose(result, cart_sph_dataset.cart) + + def test_cart2sph_sph2cart_roundtrip(self, cart_sph_dataset): + """Test roundtrip conversion: cart -> sph -> cart.""" + cart = cart_sph_dataset.cart + roundtrip = vector.sph2cart(vector.cart2sph(cart)) + xr.testing.assert_allclose(roundtrip, cart) + class TestComputeSignedAngle: """Tests for the compute_signed_angle_2d method."""