diff --git a/parcels/__init__.py b/parcels/__init__.py index 0eeb3d280..c549dbdc6 100644 --- a/parcels/__init__.py +++ b/parcels/__init__.py @@ -13,3 +13,4 @@ from parcels.particlefile import * from parcels.particleset import * from parcels.tools import * +from parcels.uxfieldset import * diff --git a/parcels/application_kernels/advection.py b/parcels/application_kernels/advection.py index d35d3063e..e2d626abd 100644 --- a/parcels/application_kernels/advection.py +++ b/parcels/application_kernels/advection.py @@ -3,8 +3,10 @@ import math from parcels.tools.statuscodes import StatusCode +from parcels.uxfieldset import UXFieldSet __all__ = [ + "UxAdvectionEuler", "AdvectionAnalytical", "AdvectionEE", "AdvectionRK4", @@ -14,6 +16,15 @@ ] +def UxAdvectionEuler(particle, fieldset: UXFieldSet, time): + """Advection of particles using Explicit Euler (aka Euler Forward) integration. + on an unstructured grid.""" + vel, ei = fieldset.eval(["u","v"],time,particle.depth,particle.lat,particle.lon, particle.ei[0]) + particle.ei[0] = ei + particle_dlon += vel["u"] * particle.dt + particle_dlat += vel["v"] * particle.dt + + def AdvectionRK4(particle, fieldset, time): # pragma: no cover """Advection of particles using fourth-order Runge-Kutta integration.""" (u1, v1) = fieldset.UV[particle] diff --git a/parcels/kernel.py b/parcels/kernel.py index 08224bb0c..795ef5920 100644 --- a/parcels/kernel.py +++ b/parcels/kernel.py @@ -206,6 +206,8 @@ def check_fieldsets_in_kernels(self, pyfunc): # TODO v4: this can go into anoth This function is to be called from the derived class when setting up the 'pyfunc'. """ + + ## To do : Add checks for UXFieldSet - joe@fluidnumerics.com if self.fieldset is not None: if pyfunc is AdvectionRK4_3D: warning = False @@ -335,17 +337,17 @@ def execute(self, pset, endtime, dt): stacklevel=2, ) - if pset.fieldset is not None: - for g in pset.fieldset.gridset.grids: - if len(g._load_chunk) > g._chunk_not_loaded: # not the case if a field in not called in the kernel - g._load_chunk = np.where( - g._load_chunk == g._chunk_loaded_touched, g._chunk_deprecated, g._load_chunk - ) + # if pset.fieldset is not None: + # for g in pset.fieldset.gridset.grids: + # if len(g._load_chunk) > g._chunk_not_loaded: # not the case if a field in not called in the kernel + # g._load_chunk = np.where( + # g._load_chunk == g._chunk_loaded_touched, g._chunk_deprecated, g._load_chunk + # ) - for f in self.fieldset.get_fields(): - if isinstance(f, (VectorField, NestedField)): - continue - f.data = np.array(f.data) + # for f in self.fieldset.get_fields(): + # if isinstance(f, (VectorField, NestedField)): + # continue + # f.data = np.array(f.data) if not self._positionupdate_kernels_added: self.add_positionupdate_kernels() @@ -424,6 +426,7 @@ def evaluate_particle(self, p, endtime): except KeyError: if abs(endtime - p.time_nextloop) < abs(p.dt) - 1e-6: p.dt = abs(endtime - p.time_nextloop) * sign_dt + res = self._pyfunc(p, self._fieldset, p.time_nextloop) if res is None: diff --git a/parcels/particleset.py b/parcels/particleset.py index e6946e2d8..ff1a85ff6 100644 --- a/parcels/particleset.py +++ b/parcels/particleset.py @@ -30,6 +30,7 @@ from parcels.tools.loggers import logger from parcels.tools.statuscodes import StatusCode from parcels.tools.warnings import ParticleSetWarning +from parcels.uxfieldset import UXFieldSet __all__ = ["ParticleSet"] @@ -148,7 +149,10 @@ def ArrayClass_init(self, *args, **kwargs): pid_orig = np.arange(lon.size) if depth is None: - mindepth = self.fieldset.gridset.dimrange("depth")[0] + if type(self.fieldset) == UXFieldSet: + mindepth = 0 # TO DO : get the min depth from the fieldset.uxgrid + else: + mindepth = self.fieldset.gridset.dimrange("depth")[0] depth = np.ones(lon.size) * mindepth else: depth = convert_to_flat_array(depth) @@ -163,11 +167,16 @@ def ArrayClass_init(self, *args, **kwargs): raise NotImplementedError("If fieldset.time_origin is not a date, time of a particle must be a double") time = np.array([self.time_origin.reltime(t) if _convert_to_reltime(t) else t for t in time]) assert lon.size == time.size, "time and positions (lon, lat, depth) do not have the same lengths." - if isinstance(fieldset.U, Field) and (not fieldset.U.allow_time_extrapolation): - _warn_particle_times_outside_fieldset_time_bounds(time, fieldset.U.grid.time_full) - if lonlatdepth_dtype is None: - lonlatdepth_dtype = self.lonlatdepth_dtype_from_field_interp_method(fieldset.U) + if type(fieldset) == UXFieldSet: + lonlatdepth_dtype = np.float32 # To do : get precision from fieldset + else: + if isinstance(fieldset.U, Field) and (not fieldset.U.allow_time_extrapolation): + _warn_particle_times_outside_fieldset_time_bounds(time, fieldset.U.grid.time_full) + + if lonlatdepth_dtype is None: + lonlatdepth_dtype = self.lonlatdepth_dtype_from_field_interp_method(fieldset.U) + assert lonlatdepth_dtype in [ np.float32, np.float64, @@ -191,7 +200,10 @@ def ArrayClass_init(self, *args, **kwargs): self._repeatkwargs = kwargs self._repeatkwargs.pop("partition_function", None) - ngrids = fieldset.gridset.size + if type(fieldset) == UXFieldSet: + ngrids = 1 + else: + ngrids = fieldset.gridset.size # Variables used for interaction kernels. inter_dist_horiz = None @@ -967,7 +979,10 @@ def execute( if runtime is not None and endtime is not None: raise RuntimeError("Only one of (endtime, runtime) can be specified") - mintime, maxtime = self.fieldset.gridset.dimrange("time_full") + if type(self.fieldset) == UXFieldSet: + mintime, maxtime = self.fieldset.get_time_range() + else: + mintime, maxtime = self.fieldset.gridset.dimrange("time_full") default_release_time = mintime if dt >= 0 else maxtime if np.any(np.isnan(self.particledata.data["time"])): @@ -1037,18 +1052,19 @@ def execute( time_at_startofloop = time - next_input = self.fieldset.computeTimeChunk(time, dt) + #next_input = self.fieldset.computeTimeChunk(time, dt) # Define next_time (the timestamp when the execution needs to be handed back to python) - if dt > 0: - next_time = min(next_prelease, next_input, next_output, next_callback, endtime) - else: - next_time = max(next_prelease, next_input, next_output, next_callback, endtime) + #if dt > 0: + # next_time = min(next_prelease, next_input, next_output, next_callback, endtime) + #else: + # next_time = max(next_prelease, next_input, next_output, next_callback, endtime) + next_time = endtime # If we don't perform interaction, only execute the normal kernel efficiently. if self._interaction_kernel is None: if not skip_kernel: - res = self._kernel.execute(self, endtime=next_time, dt=dt) + res = self._kernel.execute(self, endtime=next_time, dt=dt) # joe@fluidnumerics.com : switched to hardcoded endtime if res == StatusCode.StopAllExecution: return StatusCode.StopAllExecution # Interaction: interleave the interaction and non-interaction kernel for each time step. diff --git a/parcels/uxfieldset.py b/parcels/uxfieldset.py new file mode 100644 index 000000000..0cb5fb9ec --- /dev/null +++ b/parcels/uxfieldset.py @@ -0,0 +1,189 @@ +import cftime +import numpy as np +import uxarray as ux +import cftime + +from parcels._compat import MPI +from parcels._typing import GridIndexingType, InterpMethodOption, Mesh +from parcels.field import DeferredArray, Field, NestedField, VectorField +from parcels.grid import Grid +from parcels.gridset import GridSet +from parcels.particlefile import ParticleFile +from parcels.tools._helpers import default_repr, fieldset_repr +from parcels.tools.converters import TimeConverter, convert_xarray_time_units +from parcels.tools.loggers import logger +from parcels.tools.statuscodes import TimeExtrapolationError +from parcels.tools.warnings import FieldSetWarning + +__all__ = ["UXFieldSet"] + +_inside_tol = 1e-6 + +# class UXVectorField: +# def __init__(self, name: str, U: ux.UxDataArray, V: ux.UxDataArray, W: ux.UxDataArray | None = None): +# self.name = name +# self.U = U +# self.V = V +# self.W = W +# if self.W: +# self.vector_type = "3D" +# else: +# self.vector_type = "2D" + +# def __repr__(self): +# return f"""<{type(self).__name__}> +# name: {self.name!r} +# U: {default_repr(self.U)} +# V: {default_repr(self.V)} +# W: {default_repr(self.W)}""" + +# def eval(self, time, z, y, x, particle=None, applyConversion=True): + +class UXFieldSet: + """A FieldSet class that holds hydrodynamic data needed to execute particles + in a UXArray.Dataset""" + # Change uxds to ds_list - which is a list of either uxDataset or xarray dataset + def __init__(self, uxds: ux.UxDataset, time_origin: float | np.datetime64 | np.timedelta64 | cftime.datetime = 0): + # Ensure that dataset provides a grid, and the u and v velocity + # components at a minimum + if not hasattr(uxds, "uxgrid"): + raise ValueError("The UXArray dataset does not provide a grid") + if not hasattr(uxds, "u"): + raise ValueError("The UXArray dataset does not provide u velocity data") + if not hasattr(uxds, "v"): + raise ValueError("The UXArray dataset does not provide v velocity data") + + self.time_origin = time_origin + self.uxds = uxds + self._spatialhash = self.uxds.uxgrid.get_spatial_hash() + + #def _validate_uxds(self, uxds: ux.UxDataset): + #def _validate_xds(self, xds: xr.Dataset): + + def _check_complete(self): + assert self.uxds is not None, "UXFieldSet has not been loaded" + assert self.uxds.u is not None, "UXFieldSet does not provide u velocity data" + assert self.uxds.v is not None, "UXFieldSet does not provide v velocity data" + assert self.uxds.uxgrid is not None, "UXFieldSet does not provide a grid" + + def _face_interp(self, field, time, z, y, x, ei): + ti = 0 + zi = 0 + return field[ti,zi,ei] + + def _node_interp(self, field, time, z, y, x, ei): + """Performs barycentric interpolation of a field at a given location.""" + ti = 0 + zi = 0 + coords =np.deg2rad([[x, y]]) + n_nodes = self.uxds.uxgrid.n_nodes_per_face[ei].to_numpy() + node_ids = self.uxds.uxgrid.face_node_connectivity[ei, 0:n_nodes] + nodes = np.column_stack( + ( + np.deg2rad(self.uxds.uxgrid.node_lon[node_ids].to_numpy()), + np.deg2rad(self.uxds.uxgrid.node_lat[node_ids].to_numpy()), + ) + ) + bcoord = np.asarray(_barycentric_coordinates(nodes, coords)) + return np.sum(bcoord * field[ti,zi,node_ids].flatten(), axis=0) + + def get_time_range(self): + return self.uxds.time.min().to_numpy(), self.uxds.time.max().to_numpy() + + def _point_is_in_face(self, y, x, ei): + "Checks if a point is inside a given face id " + #ti, zi, fi = self.unravel_index(particle.ei) # Get the time, z, and face index of the particle + fi = ei + + # Check if particle is in the same face, otherwise search again. + n_nodes = self.uxds.uxgrid.n_nodes_per_face[fi].to_numpy() + node_ids = self.uxds.uxgrid.face_node_connectivity[fi, 0:n_nodes] + nodes = np.column_stack( + ( + np.deg2rad(self.uxds.uxgrid.node_lon[node_ids].to_numpy()), + np.deg2rad(self.uxds.uxgrid.node_lat[node_ids].to_numpy()), + ) + ) + + coord = np.deg2rad([x, y]) + bcoord = np.asarray(_barycentric_coordinates(nodes, coord)) + if ( not (bcoord >= 0).all() ) and (not (bcoord <= 1.0).all()): + return False + + return True + + def eval(self, field_names, time, z, y, x, ei: int=None, applyConversion=True): + + res = {} + + if ei is not None: + fi = ei + if not self._point_is_in_face(y,x,ei): + # If the point is not in the previously defined face, then + # search for the face again. + # To do : Update the search here to do nearest neighbors search, rather than spatial hash - joe@fluidnumerics.com + print(f"Position : {x}, {y}") + print(f"Hash indices : {self._spatialhash._hash_index2d(np.deg2rad([[x,y]]))}") + fi, bcoords = self._spatialhash.query([[x,y]]) # Get the face id for the particle + fi = fi[0] + print(f"Face index (updated): {fi}") + print(f"Barycentric coordinates (updated): {bcoords}") + + for f in field_names: + field = getattr(self.uxds, f) + face_registered = ("n_face" in field.dims) + + if face_registered: + r = self._face_interp(field, time, z, y, x, fi) + else: + r = self._node_interp(field, time, z, y, x, fi) + + #if applyConversion: + # res[f] = self.units.to_target(r, z, y, x) + #else: + # To do : Add call to units.to_target to handle unit conversion : joe@fluidnumerics.com + res[f] = r/111111.111111111 + + return res, fi + +def _barycentric_coordinates(nodes, point): + """ + Compute the barycentric coordinates of a point P inside a convex polygon using area-based weights. + So that this method generalizes to n-sided polygons, we use the Waschpress points as the generalized + barycentric coordinates, which is only valid for convex polygons. + + Parameters + ---------- + nodes : numpy.ndarray + Spherical coordinates (lon,lat) of each corner node of a face + point : numpy.ndarray + Spherical coordinates (lon,lat) of the point + Returns + ------- + numpy.ndarray + Barycentric coordinates corresponding to each vertex. + + """ + n = len(nodes) + sum_wi = 0 + w = [] + + for i in range(0, n): + vim1 = nodes[i - 1] + vi = nodes[i] + vi1 = nodes[(i + 1) % n] + a0 = _triangle_area(vim1, vi, vi1) + a1 = _triangle_area(point, vim1, vi) + a2 = _triangle_area(point, vi, vi1) + sum_wi += a0 / (a1 * a2) + w.append(a0 / (a1 * a2)) + + barycentric_coords = [w_i / sum_wi for w_i in w] + + return barycentric_coords + +def _triangle_area(A, B, C): + """ + Compute the area of a triangle given by three points. + """ + return 0.5 * (A[0] * (B[1] - C[1]) + B[0] * (C[1] - A[1]) + C[0] * (A[1] - B[1])) \ No newline at end of file diff --git a/tests/test_uxfieldset.py b/tests/test_uxfieldset.py new file mode 100644 index 000000000..483ad6948 --- /dev/null +++ b/tests/test_uxfieldset.py @@ -0,0 +1,56 @@ +import uxarray as ux + +from parcels import ( + UXFieldSet, + ParticleSet, + Particle, + UxAdvectionEuler +) + +from tests.utils import TEST_DATA + + +def test_fesom_fieldset(): + # Load a FESOM dataset + grid_path = f"{TEST_DATA}/fesom_channel.nc" + data_path = [ + f"{TEST_DATA}/u.fesom_channel.nc", + f"{TEST_DATA}/v.fesom_channel.nc", + f"{TEST_DATA}/w.fesom_channel.nc", + ] + ds = ux.open_mfdataset(grid_path, data_path) + fieldset = UXFieldSet(ds) + fieldset._check_complete() + # Check that the fieldset has the expected properties + assert fieldset.uxds == ds + + +def test_fesom_in_particleset(): + # Load a FESOM dataset + grid_path = f"{TEST_DATA}/fesom_channel.nc" + data_path = [ + f"{TEST_DATA}/u.fesom_channel.nc", + f"{TEST_DATA}/v.fesom_channel.nc", + f"{TEST_DATA}/w.fesom_channel.nc", + ] + ds = ux.open_mfdataset(grid_path, data_path) + fieldset = UXFieldSet(ds) + pset = ParticleSet(fieldset, pclass=Particle) + +def test_advection_fesom_channel(): + """Particles at high latitude move geographically faster due to the pole correction in `GeographicPolar`.""" + grid_path=f"{TEST_DATA}/fesom_channel.nc" + data_path=[f"{TEST_DATA}/u.fesom_channel.nc", + f"{TEST_DATA}/v.fesom_channel.nc", + f"{TEST_DATA}/w.fesom_channel.nc"] + ds = ux.open_mfdataset(grid_path,data_path) + fieldset = UXFieldSet(ds) + print(f"Spatial hash grid shape : {fieldset._spatialhash._nx}, {fieldset._spatialhash._ny}") + npart = 10 + pset = ParticleSet( + fieldset, + pclass=Particle, + lon=np.zeros(npart) + 2.0, + lat=np.linspace(5, 15, npart)) + pset.execute(UxAdvectionEuler, runtime=timedelta(hours=24), dt=timedelta(seconds=600)) + #assert (np.diff(pset3D.lon) > 1.0e-4).all() \ No newline at end of file