Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions test/test_cross_sections.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,3 +226,27 @@ def test_cross_section(gridpath, datasetpath):
_ = uxds['RELHUM'].cross_section(start=(45, 45))
_ = uxds['RELHUM'].cross_section(lon=45, end=(45, 45))
_ = uxds['RELHUM'].cross_section()


def test_cross_section_cumulative_integrate(gridpath, datasetpath):
uxds = ux.open_dataset(gridpath("scrip", "ne30pg2", "grid.nc"), datasetpath("scrip", "ne30pg2", "data.nc"))
cs = uxds['RELHUM'].cross_section(start=(-45, -45), end=(45, 45), steps=6)
cs = cs.assign_coords(distance=("steps", np.linspace(0.0, 1.0, cs.sizes["steps"])))

cs_ux = ux.UxDataArray(cs, uxgrid=uxds.uxgrid)

result = cs_ux.cumulative_integrate(coord="distance")
expected = cs.cumulative_integrate(coord="distance")

assert isinstance(result, ux.UxDataArray)
assert result.uxgrid == cs_ux.uxgrid
xr.testing.assert_allclose(result.to_xarray(), expected)


def test_cumulative_integrate_requires_coord(gridpath, datasetpath):
uxds = ux.open_dataset(gridpath("scrip", "ne30pg2", "grid.nc"), datasetpath("scrip", "ne30pg2", "data.nc"))
cs = uxds['RELHUM'].cross_section(start=(-45, -45), end=(45, 45), steps=3)
cs_ux = ux.UxDataArray(cs, uxgrid=uxds.uxgrid)

with pytest.raises(ValueError, match="Coordinate .* must be specified"):
cs_ux.cumulative_integrate()
36 changes: 35 additions & 1 deletion uxarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import warnings
from html import escape
from typing import TYPE_CHECKING, Any, Hashable, Literal, Mapping, Optional
from typing import TYPE_CHECKING, Any, Hashable, Literal, Mapping, Optional, Sequence
from warnings import warn

import numpy as np
Expand Down Expand Up @@ -513,6 +513,40 @@ def integrate(

return uxda

def cumulative_integrate(
self,
coord: Hashable | Sequence[Hashable] | None = None,
datetime_unit: Optional[str] = None,
) -> "UxDataArray":
"""
Integrate cumulatively along the given coordinate using the trapezoidal rule.

Mirrors :py:meth:`xarray.DataArray.cumulative_integrate` while preserving
``uxgrid`` on the result.

Parameters
----------
coord : Hashable or sequence of Hashable
Coordinate(s) used for the integration. This must be provided.
datetime_unit : str, optional
Unit to use when integrating over datetime coordinates.

Returns
-------
UxDataArray
The cumulative integral along the specified coordinate.
"""
if coord is None:
raise ValueError(
"Coordinate ('coord') must be specified for cumulative_integrate."
)

integrated = super().cumulative_integrate(
coord=coord, datetime_unit=datetime_unit
)

return UxDataArray(integrated, uxgrid=self.uxgrid)

def zonal_mean(self, lat=(-90, 90, 10), conservative: bool = False, **kwargs):
"""Compute non-conservative or conservative averages of a face-centered variable along lines of constant latitude or latitude bands.

Expand Down
Loading