-
Notifications
You must be signed in to change notification settings - Fork 143
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[DO NOT MERGE] Feature/fieldset as uxarray obj #1934
base: v4-dev
Are you sure you want to change the base?
Changes from all commits
0dfa7e1
6821e20
f1802ce
e1cbc09
4caf74c
43f93c1
d683511
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,8 +3,10 @@ | |
import math | ||
|
||
from parcels.tools.statuscodes import StatusCode | ||
from parcels.uxfieldset import UXFieldSet | ||
|
||
__all__ = [ | ||
"UxAdvectionEuler", | ||
"AdvectionAnalytical", | ||
"AdvectionEE", | ||
"AdvectionRK4", | ||
|
@@ -14,6 +16,15 @@ | |
] | ||
|
||
|
||
def UxAdvectionEuler(particle, fieldset: UXFieldSet, time): | ||
"""Advection of particles using Explicit Euler (aka Euler Forward) integration. | ||
on an unstructured grid.""" | ||
vel, ei = fieldset.eval(["u","v"],time,particle.depth,particle.lat,particle.lon, particle.ei[0]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Push particle.ei to eval and select which igrid. Also, push the xarray/uxarray "wrapper" down to the |
||
particle.ei[0] = ei | ||
particle_dlon += vel["u"] * particle.dt | ||
particle_dlat += vel["v"] * particle.dt | ||
|
||
|
||
def AdvectionRK4(particle, fieldset, time): # pragma: no cover | ||
"""Advection of particles using fourth-order Runge-Kutta integration.""" | ||
(u1, v1) = fieldset.UV[particle] | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -206,6 +206,8 @@ def check_fieldsets_in_kernels(self, pyfunc): # TODO v4: this can go into anoth | |
|
||
This function is to be called from the derived class when setting up the 'pyfunc'. | ||
""" | ||
|
||
## To do : Add checks for UXFieldSet - [email protected] | ||
if self.fieldset is not None: | ||
if pyfunc is AdvectionRK4_3D: | ||
warning = False | ||
|
@@ -335,17 +337,17 @@ def execute(self, pset, endtime, dt): | |
stacklevel=2, | ||
) | ||
|
||
if pset.fieldset is not None: | ||
for g in pset.fieldset.gridset.grids: | ||
if len(g._load_chunk) > g._chunk_not_loaded: # not the case if a field in not called in the kernel | ||
g._load_chunk = np.where( | ||
g._load_chunk == g._chunk_loaded_touched, g._chunk_deprecated, g._load_chunk | ||
) | ||
# if pset.fieldset is not None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This didn't seem to be needing to execute for the UXArray data, so I've just commented this out. Not sure what the plans are here, given that I recall us discussing removing chunking and deferred loading. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is gone in some of the recent updates |
||
# for g in pset.fieldset.gridset.grids: | ||
# if len(g._load_chunk) > g._chunk_not_loaded: # not the case if a field in not called in the kernel | ||
# g._load_chunk = np.where( | ||
# g._load_chunk == g._chunk_loaded_touched, g._chunk_deprecated, g._load_chunk | ||
# ) | ||
|
||
for f in self.fieldset.get_fields(): | ||
if isinstance(f, (VectorField, NestedField)): | ||
continue | ||
f.data = np.array(f.data) | ||
# for f in self.fieldset.get_fields(): | ||
# if isinstance(f, (VectorField, NestedField)): | ||
# continue | ||
# f.data = np.array(f.data) | ||
|
||
if not self._positionupdate_kernels_added: | ||
self.add_positionupdate_kernels() | ||
|
@@ -424,6 +426,7 @@ def evaluate_particle(self, p, endtime): | |
except KeyError: | ||
if abs(endtime - p.time_nextloop) < abs(p.dt) - 1e-6: | ||
p.dt = abs(endtime - p.time_nextloop) * sign_dt | ||
|
||
res = self._pyfunc(p, self._fieldset, p.time_nextloop) | ||
|
||
if res is None: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,6 +30,7 @@ | |
from parcels.tools.loggers import logger | ||
from parcels.tools.statuscodes import StatusCode | ||
from parcels.tools.warnings import ParticleSetWarning | ||
from parcels.uxfieldset import UXFieldSet | ||
|
||
__all__ = ["ParticleSet"] | ||
|
||
|
@@ -148,7 +149,10 @@ def ArrayClass_init(self, *args, **kwargs): | |
pid_orig = np.arange(lon.size) | ||
|
||
if depth is None: | ||
mindepth = self.fieldset.gridset.dimrange("depth")[0] | ||
if type(self.fieldset) == UXFieldSet: | ||
mindepth = 0 # TO DO : get the min depth from the fieldset.uxgrid | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I suppose we'll need a new implementation for |
||
else: | ||
mindepth = self.fieldset.gridset.dimrange("depth")[0] | ||
depth = np.ones(lon.size) * mindepth | ||
else: | ||
depth = convert_to_flat_array(depth) | ||
|
@@ -163,11 +167,16 @@ def ArrayClass_init(self, *args, **kwargs): | |
raise NotImplementedError("If fieldset.time_origin is not a date, time of a particle must be a double") | ||
time = np.array([self.time_origin.reltime(t) if _convert_to_reltime(t) else t for t in time]) | ||
assert lon.size == time.size, "time and positions (lon, lat, depth) do not have the same lengths." | ||
if isinstance(fieldset.U, Field) and (not fieldset.U.allow_time_extrapolation): | ||
_warn_particle_times_outside_fieldset_time_bounds(time, fieldset.U.grid.time_full) | ||
|
||
if lonlatdepth_dtype is None: | ||
lonlatdepth_dtype = self.lonlatdepth_dtype_from_field_interp_method(fieldset.U) | ||
if type(fieldset) == UXFieldSet: | ||
lonlatdepth_dtype = np.float32 # To do : get precision from fieldset | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. An obvious to-do item here; we need to be able to get the precision from the XArray or UXArray dataset underneath the thin fieldset wrapper. |
||
else: | ||
if isinstance(fieldset.U, Field) and (not fieldset.U.allow_time_extrapolation): | ||
_warn_particle_times_outside_fieldset_time_bounds(time, fieldset.U.grid.time_full) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I suppose we'll need an updated implementation of Note that I'm assuming that a UXarray or XArray dataset have a single time dimension for all fields in the dataset. Getting the time extents for the dataset wouldn't depend on the individual fields (e.g. |
||
|
||
if lonlatdepth_dtype is None: | ||
lonlatdepth_dtype = self.lonlatdepth_dtype_from_field_interp_method(fieldset.U) | ||
|
||
assert lonlatdepth_dtype in [ | ||
np.float32, | ||
np.float64, | ||
|
@@ -191,7 +200,10 @@ def ArrayClass_init(self, *args, **kwargs): | |
self._repeatkwargs = kwargs | ||
self._repeatkwargs.pop("partition_function", None) | ||
|
||
ngrids = fieldset.gridset.size | ||
if type(fieldset) == UXFieldSet: | ||
ngrids = 1 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here, we explicitly set ngrids to 1. However, with the idea that the |
||
else: | ||
ngrids = fieldset.gridset.size | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
# Variables used for interaction kernels. | ||
inter_dist_horiz = None | ||
|
@@ -967,7 +979,10 @@ def execute( | |
if runtime is not None and endtime is not None: | ||
raise RuntimeError("Only one of (endtime, runtime) can be specified") | ||
|
||
mintime, maxtime = self.fieldset.gridset.dimrange("time_full") | ||
if type(self.fieldset) == UXFieldSet: | ||
mintime, maxtime = self.fieldset.get_time_range() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
else: | ||
mintime, maxtime = self.fieldset.gridset.dimrange("time_full") | ||
|
||
default_release_time = mintime if dt >= 0 else maxtime | ||
if np.any(np.isnan(self.particledata.data["time"])): | ||
|
@@ -1037,18 +1052,19 @@ def execute( | |
|
||
time_at_startofloop = time | ||
|
||
next_input = self.fieldset.computeTimeChunk(time, dt) | ||
#next_input = self.fieldset.computeTimeChunk(time, dt) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Some work is going to be needed to get this implemented. Though I suspect some work is being done in this area currently in the v4-dev. Would be best to discuss together. |
||
|
||
# Define next_time (the timestamp when the execution needs to be handed back to python) | ||
if dt > 0: | ||
next_time = min(next_prelease, next_input, next_output, next_callback, endtime) | ||
else: | ||
next_time = max(next_prelease, next_input, next_output, next_callback, endtime) | ||
#if dt > 0: | ||
# next_time = min(next_prelease, next_input, next_output, next_callback, endtime) | ||
#else: | ||
# next_time = max(next_prelease, next_input, next_output, next_callback, endtime) | ||
|
||
next_time = endtime | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This was hard-coded, just to get things working for the simple case with one time level in the UXArray dataset. |
||
# If we don't perform interaction, only execute the normal kernel efficiently. | ||
if self._interaction_kernel is None: | ||
if not skip_kernel: | ||
res = self._kernel.execute(self, endtime=next_time, dt=dt) | ||
res = self._kernel.execute(self, endtime=next_time, dt=dt) # [email protected] : switched to hardcoded endtime | ||
if res == StatusCode.StopAllExecution: | ||
return StatusCode.StopAllExecution | ||
# Interaction: interleave the interaction and non-interaction kernel for each time step. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,189 @@ | ||
import cftime | ||
import numpy as np | ||
import uxarray as ux | ||
import cftime | ||
|
||
from parcels._compat import MPI | ||
from parcels._typing import GridIndexingType, InterpMethodOption, Mesh | ||
from parcels.field import DeferredArray, Field, NestedField, VectorField | ||
from parcels.grid import Grid | ||
from parcels.gridset import GridSet | ||
from parcels.particlefile import ParticleFile | ||
from parcels.tools._helpers import default_repr, fieldset_repr | ||
from parcels.tools.converters import TimeConverter, convert_xarray_time_units | ||
from parcels.tools.loggers import logger | ||
from parcels.tools.statuscodes import TimeExtrapolationError | ||
from parcels.tools.warnings import FieldSetWarning | ||
|
||
__all__ = ["UXFieldSet"] | ||
|
||
_inside_tol = 1e-6 | ||
|
||
# class UXVectorField: | ||
# def __init__(self, name: str, U: ux.UxDataArray, V: ux.UxDataArray, W: ux.UxDataArray | None = None): | ||
# self.name = name | ||
# self.U = U | ||
# self.V = V | ||
# self.W = W | ||
# if self.W: | ||
# self.vector_type = "3D" | ||
# else: | ||
# self.vector_type = "2D" | ||
|
||
# def __repr__(self): | ||
# return f"""<{type(self).__name__}> | ||
# name: {self.name!r} | ||
# U: {default_repr(self.U)} | ||
# V: {default_repr(self.V)} | ||
# W: {default_repr(self.W)}""" | ||
|
||
# def eval(self, time, z, y, x, particle=None, applyConversion=True): | ||
|
||
class UXFieldSet: | ||
"""A FieldSet class that holds hydrodynamic data needed to execute particles | ||
in a UXArray.Dataset""" | ||
# Change uxds to ds_list - which is a list of either uxDataset or xarray dataset | ||
def __init__(self, uxds: ux.UxDataset, time_origin: float | np.datetime64 | np.timedelta64 | cftime.datetime = 0): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
# Ensure that dataset provides a grid, and the u and v velocity | ||
# components at a minimum | ||
if not hasattr(uxds, "uxgrid"): | ||
raise ValueError("The UXArray dataset does not provide a grid") | ||
if not hasattr(uxds, "u"): | ||
raise ValueError("The UXArray dataset does not provide u velocity data") | ||
if not hasattr(uxds, "v"): | ||
raise ValueError("The UXArray dataset does not provide v velocity data") | ||
|
||
self.time_origin = time_origin | ||
self.uxds = uxds | ||
self._spatialhash = self.uxds.uxgrid.get_spatial_hash() | ||
|
||
#def _validate_uxds(self, uxds: ux.UxDataset): | ||
#def _validate_xds(self, xds: xr.Dataset): | ||
|
||
def _check_complete(self): | ||
assert self.uxds is not None, "UXFieldSet has not been loaded" | ||
assert self.uxds.u is not None, "UXFieldSet does not provide u velocity data" | ||
assert self.uxds.v is not None, "UXFieldSet does not provide v velocity data" | ||
assert self.uxds.uxgrid is not None, "UXFieldSet does not provide a grid" | ||
|
||
def _face_interp(self, field, time, z, y, x, ei): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For now, hard-coded interpolation schemes are provided for face registered data and node registered data. We may want to discuss with the Geomar folks if this dichotomy is intuitive for them or if they have a different way of viewing the interpolation problem. |
||
ti = 0 | ||
zi = 0 | ||
return field[ti,zi,ei] | ||
|
||
def _node_interp(self, field, time, z, y, x, ei): | ||
"""Performs barycentric interpolation of a field at a given location.""" | ||
ti = 0 | ||
zi = 0 | ||
coords =np.deg2rad([[x, y]]) | ||
n_nodes = self.uxds.uxgrid.n_nodes_per_face[ei].to_numpy() | ||
node_ids = self.uxds.uxgrid.face_node_connectivity[ei, 0:n_nodes] | ||
nodes = np.column_stack( | ||
( | ||
np.deg2rad(self.uxds.uxgrid.node_lon[node_ids].to_numpy()), | ||
np.deg2rad(self.uxds.uxgrid.node_lat[node_ids].to_numpy()), | ||
) | ||
) | ||
bcoord = np.asarray(_barycentric_coordinates(nodes, coords)) | ||
return np.sum(bcoord * field[ti,zi,node_ids].flatten(), axis=0) | ||
|
||
def get_time_range(self): | ||
return self.uxds.time.min().to_numpy(), self.uxds.time.max().to_numpy() | ||
|
||
def _point_is_in_face(self, y, x, ei): | ||
"Checks if a point is inside a given face id " | ||
#ti, zi, fi = self.unravel_index(particle.ei) # Get the time, z, and face index of the particle | ||
fi = ei | ||
|
||
# Check if particle is in the same face, otherwise search again. | ||
n_nodes = self.uxds.uxgrid.n_nodes_per_face[fi].to_numpy() | ||
node_ids = self.uxds.uxgrid.face_node_connectivity[fi, 0:n_nodes] | ||
nodes = np.column_stack( | ||
( | ||
np.deg2rad(self.uxds.uxgrid.node_lon[node_ids].to_numpy()), | ||
np.deg2rad(self.uxds.uxgrid.node_lat[node_ids].to_numpy()), | ||
) | ||
) | ||
|
||
coord = np.deg2rad([x, y]) | ||
bcoord = np.asarray(_barycentric_coordinates(nodes, coord)) | ||
if ( not (bcoord >= 0).all() ) and (not (bcoord <= 1.0).all()): | ||
return False | ||
|
||
return True | ||
|
||
def eval(self, field_names, time, z, y, x, ei: int=None, applyConversion=True): | ||
|
||
res = {} | ||
|
||
if ei is not None: | ||
fi = ei | ||
if not self._point_is_in_face(y,x,ei): | ||
# If the point is not in the previously defined face, then | ||
# search for the face again. | ||
# To do : Update the search here to do nearest neighbors search, rather than spatial hash - [email protected] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For UXArray data, we ought to add a method that does a nearest neighbors search, rather than a hash query when the particle is not found in the I suppose this code here, and the proposed code, ought to placed in a |
||
print(f"Position : {x}, {y}") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. print statements can probable come out now.. these were needed for debugging.. |
||
print(f"Hash indices : {self._spatialhash._hash_index2d(np.deg2rad([[x,y]]))}") | ||
fi, bcoords = self._spatialhash.query([[x,y]]) # Get the face id for the particle | ||
fi = fi[0] | ||
print(f"Face index (updated): {fi}") | ||
print(f"Barycentric coordinates (updated): {bcoords}") | ||
|
||
for f in field_names: | ||
field = getattr(self.uxds, f) | ||
face_registered = ("n_face" in field.dims) | ||
|
||
if face_registered: | ||
r = self._face_interp(field, time, z, y, x, fi) | ||
else: | ||
r = self._node_interp(field, time, z, y, x, fi) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For 3-D interpolation, we need to add vertical interpolation. Currently, all interpolation uses the fields at the upper most vertical layer |
||
#if applyConversion: | ||
# res[f] = self.units.to_target(r, z, y, x) | ||
#else: | ||
# To do : Add call to units.to_target to handle unit conversion : [email protected] | ||
res[f] = r/111111.111111111 | ||
|
||
return res, fi | ||
|
||
def _barycentric_coordinates(nodes, point): | ||
""" | ||
Compute the barycentric coordinates of a point P inside a convex polygon using area-based weights. | ||
So that this method generalizes to n-sided polygons, we use the Waschpress points as the generalized | ||
barycentric coordinates, which is only valid for convex polygons. | ||
|
||
Parameters | ||
---------- | ||
nodes : numpy.ndarray | ||
Spherical coordinates (lon,lat) of each corner node of a face | ||
point : numpy.ndarray | ||
Spherical coordinates (lon,lat) of the point | ||
Returns | ||
------- | ||
numpy.ndarray | ||
Barycentric coordinates corresponding to each vertex. | ||
|
||
""" | ||
n = len(nodes) | ||
sum_wi = 0 | ||
w = [] | ||
|
||
for i in range(0, n): | ||
vim1 = nodes[i - 1] | ||
vi = nodes[i] | ||
vi1 = nodes[(i + 1) % n] | ||
a0 = _triangle_area(vim1, vi, vi1) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These areas need to be clipped with a minimum value. As we found in uxarray, when a particle lies on a face edge or vertex, these areas can be zero, leading to division by zero. |
||
a1 = _triangle_area(point, vim1, vi) | ||
a2 = _triangle_area(point, vi, vi1) | ||
sum_wi += a0 / (a1 * a2) | ||
w.append(a0 / (a1 * a2)) | ||
|
||
barycentric_coords = [w_i / sum_wi for w_i in w] | ||
|
||
return barycentric_coords | ||
|
||
def _triangle_area(A, B, C): | ||
""" | ||
Compute the area of a triangle given by three points. | ||
""" | ||
return 0.5 * (A[0] * (B[1] - C[1]) + B[0] * (C[1] - A[1]) + C[0] * (A[1] - B[1])) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We'll need to have advection routines for UXArray that are distinct from XArray, unless we are able to get uniformity in the velocity field naming conventions between structured and unstructured data sets.