Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 117 additions & 2 deletions uxarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import warnings
from html import escape
from typing import TYPE_CHECKING, Any, Hashable, Literal, Mapping, Optional
from typing import TYPE_CHECKING, Any, Callable, Hashable, Literal, Mapping, Optional
from warnings import warn

import numpy as np
Expand All @@ -19,6 +19,7 @@
_calculate_edge_node_difference,
_calculate_grad_on_edge_from_faces,
)
from uxarray.constants import GRID_DIMS
from uxarray.core.utils import _map_dims_to_ugrid
from uxarray.core.zonal import _compute_non_conservative_zonal_mean
from uxarray.cross_sections import UxDataArrayCrossSectionAccessor
Expand Down Expand Up @@ -1229,7 +1230,6 @@ def isel(
ValueError
If more than one grid dimension is selected and `ignore_grid=False`.
"""
from uxarray.constants import GRID_DIMS
from uxarray.core.dataarray import UxDataArray

# merge dict‐style + kw‐style indexers
Expand Down Expand Up @@ -1424,6 +1424,121 @@ def get_dual(self):

return uxda


def neighborhood_filter(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This implementation looks great! May we move the bulk of the logic into the uxarray.grid.neighbors module and call that helper from here?

We can keep the data-mapping checks here, and anything related to constructing and returining the final data array but the bulk of the computations would go inside a helper in the module mentioned above.

Copy link
Collaborator Author

@ahijevyc ahijevyc Sep 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have to think about how to do that, but I am happy to defer to you.

self,
func: Callable = np.mean,
r: float = 1.0,
) -> UxDataArray:
"""Apply neighborhood filter
Parameters:
-----------
func: Callable, default=np.mean
Apply this function to neighborhood
r : float, default=1.
Radius of neighborhood. For spherical coordinates, the radius is in units of degrees,
and for cartesian coordinates, the radius is in meters.
Returns:
--------
destination_data : np.ndarray
Filtered data.
"""

if self._face_centered():
data_mapping = "face centers"
elif self._node_centered():
data_mapping = "nodes"
elif self._edge_centered():
data_mapping = "edge centers"
else:
raise ValueError(
"Data_mapping is not face, node, or edge. Could not define data_mapping."
)

# reconstruct because the cached tree could be built from
# face centers, edge centers or nodes.
tree = self.uxgrid.get_ball_tree(coordinates=data_mapping, reconstruct=True)
Comment on lines +1459 to +1460
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@aaronzedwick

We should probably fix this logic in get_ball_tree(), since we shouldn't need to manually set reconstruct=False

        if self._ball_tree is None or reconstruct:
            self._ball_tree = BallTree(
                self,
                coordinates=coordinates,
                distance_metric=distance_metric,
                coordinate_system=coordinate_system,
                reconstruct=reconstruct,
            )
        else:
            if coordinates != self._ball_tree._coordinates:
                self._ball_tree.coordinates = coordinates

The coordinates != self._ball_tree._coordinates check should be included in the first if

Copy link
Collaborator Author

@ahijevyc ahijevyc Sep 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense. So, move the coordinates check to the if-clause like this?

                if (
                    self._ball_tree is None
                    or coordinates != self._ball_tree._coordinates
                    or reconstruct
                ):

                    self._ball_tree = BallTree(
                        self,
                        coordinates=coordinates,
                        distance_metric=distance_metric,
                        coordinate_system=coordinate_system,
                        reconstruct=reconstruct,
                    )

What if the coordinate_system is different? Would that also require a newly constructed tree?

Copy link
Collaborator Author

@ahijevyc ahijevyc Sep 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whatever logic is fixed in Grid.get_ball_tree should also be applied to Grid.get_kdtree.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

checking coordinate system also (coordinate_system is not a hidden variable of _ball_tree; it has no underscore):

                if (
                    self._ball_tree is None
                    or coordinates != self._ball_tree._coordinates
                    or coordinate_system != self._ball_tree.coordinate_system
                    or reconstruct
                ):

                    self._ball_tree = BallTree(
                        self,
                        coordinates=coordinates,
                        distance_metric=distance_metric,
                        coordinate_system=coordinate_system,
                        reconstruct=reconstruct,
                    )


coordinate_system = tree.coordinate_system

if coordinate_system == "spherical":
if data_mapping == "nodes":
lon, lat = (
self.uxgrid.node_lon.values,
self.uxgrid.node_lat.values,
)
elif data_mapping == "face centers":
lon, lat = (
self.uxgrid.face_lon.values,
self.uxgrid.face_lat.values,
)
elif data_mapping == "edge centers":
lon, lat = (
self.uxgrid.edge_lon.values,
self.uxgrid.edge_lat.values,
)
else:
raise ValueError(
f"Invalid data_mapping. Expected 'nodes', 'edge centers', or 'face centers', "
f"but received: {data_mapping}"
)

dest_coords = np.vstack((lon, lat)).T

elif coordinate_system == "cartesian":
if data_mapping == "nodes":
x, y, z = (
self.uxgrid.node_x.values,
self.uxgrid.node_y.values,
self.uxgrid.node_z.values,
)
elif data_mapping == "face centers":
x, y, z = (
self.uxgrid.face_x.values,
self.uxgrid.face_y.values,
self.uxgrid.face_z.values,
)
elif data_mapping == "edge centers":
x, y, z = (
self.uxgrid.edge_x.values,
self.uxgrid.edge_y.values,
self.uxgrid.edge_z.values,
)
else:
raise ValueError(
f"Invalid data_mapping. Expected 'nodes', 'edge centers', or 'face centers', "
f"but received: {data_mapping}"
)
Comment on lines +1461 to +1511
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use #974 's new remap.utils._remap_grid_parse instead of this code block.


dest_coords = np.vstack((x, y, z)).T

else:
raise ValueError(
f"Invalid coordinate_system. Expected either 'spherical' or 'cartesian', but received {coordinate_system}"
)

neighbor_indices = tree.query_radius(dest_coords, r=r)

# Construct numpy array for filtered variable.
destination_data = np.empty(self.data.shape)

# Assert last dimension is a GRID dimension.
assert self.dims[-1] in GRID_DIMS, (
f"expected last dimension of uxDataArray {self.data.dims[-1]} "
f"to be one of {GRID_DIMS}"
)
# Apply function to indices on last axis.
for i, idx in enumerate(neighbor_indices):
if len(idx):
destination_data[..., i] = func(self.data[..., idx])

# Construct UxDataArray for filtered variable.
uxda_filter = self._copy()

uxda_filter.data = destination_data

return uxda_filter

def where(self, cond: Any, other: Any = dtypes.NA, drop: bool = False):
return UxDataArray(super().where(cond, other, drop), uxgrid=self.uxgrid)

Expand Down
40 changes: 39 additions & 1 deletion uxarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import sys
from html import escape
from typing import IO, Any, Optional, Union
from typing import IO, Any, Callable, Optional, Union
from warnings import warn

import numpy as np
Expand All @@ -13,6 +13,7 @@
from xarray.core.utils import UncachedAccessor

import uxarray
from uxarray.constants import GRID_DIMS
from uxarray.core.dataarray import UxDataArray
from uxarray.core.utils import _map_dims_to_ugrid
from uxarray.formatting_html import dataset_repr
Expand Down Expand Up @@ -443,6 +444,42 @@ def to_array(self) -> UxDataArray:
xarr = super().to_array()
return UxDataArray(xarr, uxgrid=self.uxgrid)


def neighborhood_filter(
self,
func: Callable = np.mean,
r: float = 1.0,
):
"""Neighborhood function implementation for ``UxDataset``.
Parameters
---------
func : Callable = np.mean
Apply this function to neighborhood
r : float, default=1.
Radius of neighborhood. For spherical coordinates, the radius is in units of degrees,
and for cartesian coordinates, the radius is in meters.
"""

destination_uxds = self._copy()
# Loop through uxDataArrays in uxDataset
for var_name in self.data_vars:
uxda = self[var_name]

# Skip if uxDataArray has no GRID dimension.
grid_dims = [dim for dim in uxda.dims if dim in GRID_DIMS]
if len(grid_dims) == 0:
continue

# Put GRID dimension last for UxDataArray.neighborhood_filter.
remember_dim_order = uxda.dims
uxda = uxda.transpose(..., grid_dims[0])
# Filter uxDataArray.
uxda = uxda.neighborhood_filter(func, r)
# Restore old dimension order.
destination_uxds[var_name] = uxda.transpose(*remember_dim_order)

return destination_uxds

def to_xarray(self, grid_format: str = "UGRID") -> xr.Dataset:
"""
Converts a ``ux.UXDataset`` to a ``xr.Dataset``.
Expand All @@ -464,6 +501,7 @@ def to_xarray(self, grid_format: str = "UGRID") -> xr.Dataset:

return xr.Dataset(self)


def get_dual(self):
"""Compute the dual mesh for a dataset, returns a new dataset object.

Expand Down
Loading