diff --git a/test/test_cross_sections.py b/test/test_cross_sections.py index 081e9b8e2..5ceba29da 100644 --- a/test/test_cross_sections.py +++ b/test/test_cross_sections.py @@ -226,3 +226,27 @@ def test_cross_section(gridpath, datasetpath): _ = uxds['RELHUM'].cross_section(start=(45, 45)) _ = uxds['RELHUM'].cross_section(lon=45, end=(45, 45)) _ = uxds['RELHUM'].cross_section() + + +def test_cross_section_cumulative_integrate(gridpath, datasetpath): + uxds = ux.open_dataset(gridpath("scrip", "ne30pg2", "grid.nc"), datasetpath("scrip", "ne30pg2", "data.nc")) + cs = uxds['RELHUM'].cross_section(start=(-45, -45), end=(45, 45), steps=6) + cs = cs.assign_coords(distance=("steps", np.linspace(0.0, 1.0, cs.sizes["steps"]))) + + cs_ux = ux.UxDataArray(cs, uxgrid=uxds.uxgrid) + + result = cs_ux.cumulative_integrate(coord="distance") + expected = cs.cumulative_integrate(coord="distance") + + assert isinstance(result, ux.UxDataArray) + assert result.uxgrid == cs_ux.uxgrid + xr.testing.assert_allclose(result.to_xarray(), expected) + + +def test_cumulative_integrate_requires_coord(gridpath, datasetpath): + uxds = ux.open_dataset(gridpath("scrip", "ne30pg2", "grid.nc"), datasetpath("scrip", "ne30pg2", "data.nc")) + cs = uxds['RELHUM'].cross_section(start=(-45, -45), end=(45, 45), steps=3) + cs_ux = ux.UxDataArray(cs, uxgrid=uxds.uxgrid) + + with pytest.raises(ValueError, match="Coordinate .* must be specified"): + cs_ux.cumulative_integrate() diff --git a/uxarray/core/dataarray.py b/uxarray/core/dataarray.py index 7ad62050c..e1d51bed8 100644 --- a/uxarray/core/dataarray.py +++ b/uxarray/core/dataarray.py @@ -2,7 +2,7 @@ import warnings from html import escape -from typing import TYPE_CHECKING, Any, Hashable, Literal, Mapping, Optional +from typing import TYPE_CHECKING, Any, Hashable, Literal, Mapping, Optional, Sequence from warnings import warn import numpy as np @@ -513,6 +513,40 @@ def integrate( return uxda + def cumulative_integrate( + self, + coord: Hashable | Sequence[Hashable] | None = None, + datetime_unit: Optional[str] = None, + ) -> "UxDataArray": + """ + Integrate cumulatively along the given coordinate using the trapezoidal rule. + + Mirrors :py:meth:`xarray.DataArray.cumulative_integrate` while preserving + ``uxgrid`` on the result. + + Parameters + ---------- + coord : Hashable or sequence of Hashable + Coordinate(s) used for the integration. This must be provided. + datetime_unit : str, optional + Unit to use when integrating over datetime coordinates. + + Returns + ------- + UxDataArray + The cumulative integral along the specified coordinate. + """ + if coord is None: + raise ValueError( + "Coordinate ('coord') must be specified for cumulative_integrate." + ) + + integrated = super().cumulative_integrate( + coord=coord, datetime_unit=datetime_unit + ) + + return UxDataArray(integrated, uxgrid=self.uxgrid) + def zonal_mean(self, lat=(-90, 90, 10), conservative: bool = False, **kwargs): """Compute non-conservative or conservative averages of a face-centered variable along lines of constant latitude or latitude bands.