Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DO NOT MERGE] Feature/fieldset as uxarray obj #1934

Draft
wants to merge 7 commits into
base: v4-dev
Choose a base branch
from
1 change: 1 addition & 0 deletions parcels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@
from parcels.particlefile import *
from parcels.particleset import *
from parcels.tools import *
from parcels.uxfieldset import *
11 changes: 11 additions & 0 deletions parcels/application_kernels/advection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import math

from parcels.tools.statuscodes import StatusCode
from parcels.uxfieldset import UXFieldSet

__all__ = [
"UxAdvectionEuler",
"AdvectionAnalytical",
"AdvectionEE",
"AdvectionRK4",
Expand All @@ -14,6 +16,15 @@
]


def UxAdvectionEuler(particle, fieldset: UXFieldSet, time):
    """Advection of particles using Explicit Euler (aka Euler Forward) integration.
    on an unstructured grid.

    NOTE(review): unstructured (UXArray) grids need advection routines distinct
    from the XArray ones, unless velocity-field naming conventions can be
    unified between structured and unstructured data sets.
    """
    # Evaluate u and v at the particle position. fieldset.eval also returns the
    # face index containing the particle, cached in particle.ei[0] so the next
    # evaluation can start its search from the previously known face.
    # TODO(review): push particle.ei down into eval to select which igrid, and
    # push the xarray/uxarray "wrapper" down to the field class.
    vel, ei = fieldset.eval(["u","v"],time,particle.depth,particle.lat,particle.lon, particle.ei[0])
    particle.ei[0] = ei
    # particle_dlon / particle_dlat are displacement accumulators injected by
    # the parcels kernel machinery (not defined in this module).
    particle_dlon += vel["u"] * particle.dt
    particle_dlat += vel["v"] * particle.dt


def AdvectionRK4(particle, fieldset, time): # pragma: no cover
"""Advection of particles using fourth-order Runge-Kutta integration."""
(u1, v1) = fieldset.UV[particle]
Expand Down
23 changes: 13 additions & 10 deletions parcels/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,8 @@ def check_fieldsets_in_kernels(self, pyfunc): # TODO v4: this can go into anoth

This function is to be called from the derived class when setting up the 'pyfunc'.
"""

## To do : Add checks for UXFieldSet - [email protected]
if self.fieldset is not None:
if pyfunc is AdvectionRK4_3D:
warning = False
Expand Down Expand Up @@ -335,17 +337,17 @@ def execute(self, pset, endtime, dt):
stacklevel=2,
)

if pset.fieldset is not None:
for g in pset.fieldset.gridset.grids:
if len(g._load_chunk) > g._chunk_not_loaded: # not the case if a field in not called in the kernel
g._load_chunk = np.where(
g._load_chunk == g._chunk_loaded_touched, g._chunk_deprecated, g._load_chunk
)
# if pset.fieldset is not None:
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This didn't seem to need to execute for the UXArray data, so I've commented it out. I'm not sure what the plans are here, given that I recall us discussing removing chunking and deferred loading.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is gone in some of the recent updates

# for g in pset.fieldset.gridset.grids:
# if len(g._load_chunk) > g._chunk_not_loaded: # not the case if a field in not called in the kernel
# g._load_chunk = np.where(
# g._load_chunk == g._chunk_loaded_touched, g._chunk_deprecated, g._load_chunk
# )

for f in self.fieldset.get_fields():
if isinstance(f, (VectorField, NestedField)):
continue
f.data = np.array(f.data)
# for f in self.fieldset.get_fields():
# if isinstance(f, (VectorField, NestedField)):
# continue
# f.data = np.array(f.data)

if not self._positionupdate_kernels_added:
self.add_positionupdate_kernels()
Expand Down Expand Up @@ -424,6 +426,7 @@ def evaluate_particle(self, p, endtime):
except KeyError:
if abs(endtime - p.time_nextloop) < abs(p.dt) - 1e-6:
p.dt = abs(endtime - p.time_nextloop) * sign_dt

res = self._pyfunc(p, self._fieldset, p.time_nextloop)

if res is None:
Expand Down
42 changes: 29 additions & 13 deletions parcels/particleset.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from parcels.tools.loggers import logger
from parcels.tools.statuscodes import StatusCode
from parcels.tools.warnings import ParticleSetWarning
from parcels.uxfieldset import UXFieldSet

__all__ = ["ParticleSet"]

Expand Down Expand Up @@ -148,7 +149,10 @@ def ArrayClass_init(self, *args, **kwargs):
pid_orig = np.arange(lon.size)

if depth is None:
mindepth = self.fieldset.gridset.dimrange("depth")[0]
if type(self.fieldset) == UXFieldSet:
mindepth = 0 # TO DO : get the min depth from the fieldset.uxgrid
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose we'll need a new implementation for fieldset.gridset.dimrange . Since the new proposed fieldset is meant to wrap around either an XArray or UXArray dataset, which each have a grid baked into the structure, the gridset subclass here is irrelevant.

else:
mindepth = self.fieldset.gridset.dimrange("depth")[0]
depth = np.ones(lon.size) * mindepth
else:
depth = convert_to_flat_array(depth)
Expand All @@ -163,11 +167,16 @@ def ArrayClass_init(self, *args, **kwargs):
raise NotImplementedError("If fieldset.time_origin is not a date, time of a particle must be a double")
time = np.array([self.time_origin.reltime(t) if _convert_to_reltime(t) else t for t in time])
assert lon.size == time.size, "time and positions (lon, lat, depth) do not have the same lengths."
if isinstance(fieldset.U, Field) and (not fieldset.U.allow_time_extrapolation):
_warn_particle_times_outside_fieldset_time_bounds(time, fieldset.U.grid.time_full)

if lonlatdepth_dtype is None:
lonlatdepth_dtype = self.lonlatdepth_dtype_from_field_interp_method(fieldset.U)
if type(fieldset) == UXFieldSet:
lonlatdepth_dtype = np.float32 # To do : get precision from fieldset
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An obvious to-do item here; we need to be able to get the precision from the XArray or UXArray dataset underneath the thin fieldset wrapper.

else:
if isinstance(fieldset.U, Field) and (not fieldset.U.allow_time_extrapolation):
_warn_particle_times_outside_fieldset_time_bounds(time, fieldset.U.grid.time_full)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose we'll need an updated implementation of _warn_particle_times_outside_fieldset_time_bounds .. Alternatively, we'll need to replace the fieldset.U.grid.time_full with a fieldset.time_full call.

Note that I'm assuming that a UXArray or XArray dataset has a single time dimension shared by all fields in the dataset. Getting the time extents for the dataset wouldn't depend on the individual fields (e.g. U or V); rather, it'd be associated with the dimensions of the dataset itself.


if lonlatdepth_dtype is None:
lonlatdepth_dtype = self.lonlatdepth_dtype_from_field_interp_method(fieldset.U)

assert lonlatdepth_dtype in [
np.float32,
np.float64,
Expand All @@ -191,7 +200,10 @@ def ArrayClass_init(self, *args, **kwargs):
self._repeatkwargs = kwargs
self._repeatkwargs.pop("partition_function", None)

ngrids = fieldset.gridset.size
if type(fieldset) == UXFieldSet:
ngrids = 1
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, we explicitly set ngrids to 1. However, with the idea that the fieldset will store a list of either XArray or UXArray datasets, we will want to set ngrids to the length of the list of datasets.

else:
ngrids = fieldset.gridset.size
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

gridset attribute will be going away with the updated fieldset.


# Variables used for interaction kernels.
inter_dist_horiz = None
Expand Down Expand Up @@ -967,7 +979,10 @@ def execute(
if runtime is not None and endtime is not None:
raise RuntimeError("Only one of (endtime, runtime) can be specified")

mintime, maxtime = self.fieldset.gridset.dimrange("time_full")
if type(self.fieldset) == UXFieldSet:
mintime, maxtime = self.fieldset.get_time_range()
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fieldset.get_time_range is meant to replace fieldset.gridset.dimrange("time_full")

else:
mintime, maxtime = self.fieldset.gridset.dimrange("time_full")

default_release_time = mintime if dt >= 0 else maxtime
if np.any(np.isnan(self.particledata.data["time"])):
Expand Down Expand Up @@ -1037,18 +1052,19 @@ def execute(

time_at_startofloop = time

next_input = self.fieldset.computeTimeChunk(time, dt)
#next_input = self.fieldset.computeTimeChunk(time, dt)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some work is going to be needed to get this implemented. Though I suspect some work is being done in this area currently in the v4-dev. Would be best to discuss together.


# Define next_time (the timestamp when the execution needs to be handed back to python)
if dt > 0:
next_time = min(next_prelease, next_input, next_output, next_callback, endtime)
else:
next_time = max(next_prelease, next_input, next_output, next_callback, endtime)
#if dt > 0:
# next_time = min(next_prelease, next_input, next_output, next_callback, endtime)
#else:
# next_time = max(next_prelease, next_input, next_output, next_callback, endtime)

next_time = endtime
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was hard-coded, just to get things working for the simple case with one time level in the UXArray dataset.

# If we don't perform interaction, only execute the normal kernel efficiently.
if self._interaction_kernel is None:
if not skip_kernel:
res = self._kernel.execute(self, endtime=next_time, dt=dt)
res = self._kernel.execute(self, endtime=next_time, dt=dt) # [email protected] : switched to hardcoded endtime
if res == StatusCode.StopAllExecution:
return StatusCode.StopAllExecution
# Interaction: interleave the interaction and non-interaction kernel for each time step.
Expand Down
189 changes: 189 additions & 0 deletions parcels/uxfieldset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
import cftime
import numpy as np
import uxarray as ux
import cftime

from parcels._compat import MPI
from parcels._typing import GridIndexingType, InterpMethodOption, Mesh
from parcels.field import DeferredArray, Field, NestedField, VectorField
from parcels.grid import Grid
from parcels.gridset import GridSet
from parcels.particlefile import ParticleFile
from parcels.tools._helpers import default_repr, fieldset_repr
from parcels.tools.converters import TimeConverter, convert_xarray_time_units
from parcels.tools.loggers import logger
from parcels.tools.statuscodes import TimeExtrapolationError
from parcels.tools.warnings import FieldSetWarning

__all__ = ["UXFieldSet"]

_inside_tol = 1e-6

# class UXVectorField:
# def __init__(self, name: str, U: ux.UxDataArray, V: ux.UxDataArray, W: ux.UxDataArray | None = None):
# self.name = name
# self.U = U
# self.V = V
# self.W = W
# if self.W:
# self.vector_type = "3D"
# else:
# self.vector_type = "2D"

# def __repr__(self):
# return f"""<{type(self).__name__}>
# name: {self.name!r}
# U: {default_repr(self.U)}
# V: {default_repr(self.V)}
# W: {default_repr(self.W)}"""

# def eval(self, time, z, y, x, particle=None, applyConversion=True):

class UXFieldSet:
    """A FieldSet class that holds hydrodynamic data needed to execute particles
    in a UXArray.Dataset.

    The dataset must provide its grid (``uxgrid``) and at least the ``u`` and
    ``v`` velocity components. Point location uses the grid's spatial hash;
    interpolation is piecewise-constant for face-registered fields and
    barycentric (Wachspress) for node-registered fields.

    TODO(review): the ``uxds`` argument will be changed to ``ds``, a
    list of (uxarray.UxDataset | xarray.Dataset).
    """

    def __init__(self, uxds: ux.UxDataset, time_origin: float | np.datetime64 | np.timedelta64 | cftime.datetime = 0):
        # Ensure that the dataset provides a grid, and the u and v velocity
        # components at a minimum.
        if not hasattr(uxds, "uxgrid"):
            raise ValueError("The UXArray dataset does not provide a grid")
        if not hasattr(uxds, "u"):
            raise ValueError("The UXArray dataset does not provide u velocity data")
        if not hasattr(uxds, "v"):
            raise ValueError("The UXArray dataset does not provide v velocity data")

        self.time_origin = time_origin
        self.uxds = uxds
        # Spatial hash used to locate the face containing a (lon, lat) point.
        self._spatialhash = self.uxds.uxgrid.get_spatial_hash()

    def _check_complete(self):
        """Assert that the wrapped dataset provides a grid and u/v velocity data."""
        assert self.uxds is not None, "UXFieldSet has not been loaded"
        assert self.uxds.u is not None, "UXFieldSet does not provide u velocity data"
        assert self.uxds.v is not None, "UXFieldSet does not provide v velocity data"
        assert self.uxds.uxgrid is not None, "UXFieldSet does not provide a grid"

    def _face_interp(self, field, time, z, y, x, ei):
        """Piecewise-constant interpolation for face-registered data.

        TODO(review): the time (ti) and depth (zi) indices are hard-coded to
        the first level; temporal and vertical interpolation still need to be
        added.
        """
        ti = 0
        zi = 0
        return field[ti, zi, ei]

    def _node_interp(self, field, time, z, y, x, ei):
        """Performs barycentric interpolation of a field at a given location.

        TODO(review): ti/zi are hard-coded to the first time/depth level;
        3-D use needs vertical interpolation.
        """
        ti = 0
        zi = 0
        # BUGFIX: the point must be a flat (lon, lat) pair in radians, matching
        # the stacked node coordinates below. The original wrapped it in an
        # extra list ([[x, y]]), which breaks point[1] indexing inside
        # _barycentric_coordinates.
        coords = np.deg2rad([x, y])
        n_nodes = self.uxds.uxgrid.n_nodes_per_face[ei].to_numpy()
        node_ids = self.uxds.uxgrid.face_node_connectivity[ei, 0:n_nodes]
        nodes = np.column_stack(
            (
                np.deg2rad(self.uxds.uxgrid.node_lon[node_ids].to_numpy()),
                np.deg2rad(self.uxds.uxgrid.node_lat[node_ids].to_numpy()),
            )
        )
        bcoord = np.asarray(_barycentric_coordinates(nodes, coords))
        return np.sum(bcoord * field[ti, zi, node_ids].flatten(), axis=0)

    def get_time_range(self):
        """Return (min, max) of the dataset's time coordinate.

        Replaces fieldset.gridset.dimrange("time_full") for unstructured data.
        """
        return self.uxds.time.min().to_numpy(), self.uxds.time.max().to_numpy()

    def _point_is_in_face(self, y, x, ei):
        """Check if a point (lat=y, lon=x, degrees) is inside the face ``ei``.

        A point lies inside a convex face iff all of its barycentric
        coordinates are within [0, 1].
        """
        fi = ei
        n_nodes = self.uxds.uxgrid.n_nodes_per_face[fi].to_numpy()
        node_ids = self.uxds.uxgrid.face_node_connectivity[fi, 0:n_nodes]
        nodes = np.column_stack(
            (
                np.deg2rad(self.uxds.uxgrid.node_lon[node_ids].to_numpy()),
                np.deg2rad(self.uxds.uxgrid.node_lat[node_ids].to_numpy()),
            )
        )

        coord = np.deg2rad([x, y])
        bcoord = np.asarray(_barycentric_coordinates(nodes, coord))
        # BUGFIX: a point is outside as soon as ANY coordinate falls out of
        # [0, 1]. The original condition required BOTH bound checks to fail
        # simultaneously ("and"), wrongly classifying outside points as inside.
        if not ((bcoord >= 0).all() and (bcoord <= 1.0).all()):
            return False

        return True

    def eval(self, field_names, time, z, y, x, ei: int | None = None, applyConversion=True):
        """Evaluate the named fields at (time, z, y, x).

        Parameters
        ----------
        field_names : list of str
            Names of variables in the wrapped dataset (e.g. ["u", "v"]).
        ei : int, optional
            Previously known face index for the particle, used as a first
            guess before falling back to a spatial-hash query.

        Returns
        -------
        (dict, int)
            Mapping of field name to interpolated value, and the index of the
            face containing the point.
        """
        res = {}

        # Reuse the cached face index when the point is still inside that
        # face; otherwise — including when no index was supplied, which
        # previously left ``fi`` unbound — locate the face via the spatial
        # hash.
        # TODO(review): replace the hash query with a nearest-neighbours
        # search when the particle has left the previously known face.
        if ei is not None and self._point_is_in_face(y, x, ei):
            fi = ei
        else:
            fi, _bcoords = self._spatialhash.query([[x, y]])  # face id for the particle
            fi = fi[0]

        for f in field_names:
            field = getattr(self.uxds, f)
            face_registered = "n_face" in field.dims

            if face_registered:
                r = self._face_interp(field, time, z, y, x, fi)
            else:
                r = self._node_interp(field, time, z, y, x, fi)

            # TODO(review): replace this hard-coded conversion (metres per
            # degree, ~111.1 km/deg) with a proper units.to_target call:
            # if applyConversion: res[f] = self.units.to_target(r, z, y, x)
            res[f] = r / 111111.111111111

        return res, fi

def _barycentric_coordinates(nodes, point):
"""
Compute the barycentric coordinates of a point P inside a convex polygon using area-based weights.
So that this method generalizes to n-sided polygons, we use the Waschpress points as the generalized
barycentric coordinates, which is only valid for convex polygons.

Parameters
----------
nodes : numpy.ndarray
Spherical coordinates (lon,lat) of each corner node of a face
point : numpy.ndarray
Spherical coordinates (lon,lat) of the point
Returns
-------
numpy.ndarray
Barycentric coordinates corresponding to each vertex.

"""
n = len(nodes)
sum_wi = 0
w = []

for i in range(0, n):
vim1 = nodes[i - 1]
vi = nodes[i]
vi1 = nodes[(i + 1) % n]
a0 = _triangle_area(vim1, vi, vi1)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These areas need to be clipped with a minimum value. As we found in uxarray, when a particle lies on a face edge or vertex, these areas can be zero, leading to division by zero.

a1 = _triangle_area(point, vim1, vi)
a2 = _triangle_area(point, vi, vi1)
sum_wi += a0 / (a1 * a2)
w.append(a0 / (a1 * a2))

barycentric_coords = [w_i / sum_wi for w_i in w]

return barycentric_coords

def _triangle_area(A, B, C):
"""
Compute the area of a triangle given by three points.
"""
return 0.5 * (A[0] * (B[1] - C[1]) + B[0] * (C[1] - A[1]) + C[0] * (A[1] - B[1]))
Loading