Merged
Changes from 6 commits
Commits
38 commits
8d025da
Start of implementing C-Grid interpolation
erikvansebille Aug 15, 2025
50d9d52
Moving C-grid velocity code from v3 to v4
erikvansebille Aug 15, 2025
3226306
use time_interval type to set default time
erikvansebille Aug 15, 2025
1bd7e0c
Adding nemo curvilinear test for C-grid
erikvansebille Aug 15, 2025
2da436a
Speeding up curvilinear search by dask loading lon and lat
erikvansebille Aug 15, 2025
0e1a221
Updating c-grid velocity test and algorithm
erikvansebille Aug 15, 2025
1ecab22
Fixing vector interpolation
erikvansebille Aug 18, 2025
8b1b830
Fixing CGrid_Velocity interpolation for multiple particles
erikvansebille Aug 18, 2025
cad34e0
Adding better error message handling for CGrid_velocity interpolation
erikvansebille Aug 18, 2025
26cffcc
Adding warning suppression for index_search
erikvansebille Aug 18, 2025
e2376e2
Fixing error when Grid does not have lon or lat
erikvansebille Aug 18, 2025
6529e13
Fixing to keep the maximum Error code in field
erikvansebille Aug 18, 2025
bbc2c30
Merge branch 'v4-dev' into c-grid-interpolation
erikvansebille Aug 19, 2025
dda469d
Adding NEMO3D test
erikvansebille Aug 21, 2025
76d015b
Updating CGrid interpolation to not interpolate over depth for U and V
erikvansebille Aug 22, 2025
196c6e7
Fixing W interpolation for CGrid
erikvansebille Aug 22, 2025
1c948dd
Fixing stommel gyre CGrid interpolation test
erikvansebille Aug 22, 2025
df78b62
Updating failing unit test
erikvansebille Aug 22, 2025
cd69147
Further fixing unit test by dropping unused dimensions
erikvansebille Aug 22, 2025
1a64033
Temporary fix to spatialhash
erikvansebille Aug 22, 2025
7c1a87a
Adding TODO statement about spherical meshes
erikvansebille Aug 22, 2025
07a8238
Updating spherical mash hashmap creation
erikvansebille Aug 22, 2025
f77f5f4
merge
erikvansebille Aug 25, 2025
2924ee0
Fixing grid._mesh in interpolator
erikvansebille Aug 25, 2025
9264d8a
Merge branch 'feature/morton-hashing' into c-grid-interpolation
erikvansebille Aug 25, 2025
1a17911
Using is_dask_collection to check for dask in interpolation
erikvansebille Aug 25, 2025
6753e81
Merge branch 'v4-dev' into c-grid-interpolation
erikvansebille Aug 27, 2025
5bdc874
Fixing vector_interp_method
erikvansebille Aug 28, 2025
a0962dc
Merge branch 'v4-dev' into c-grid-interpolation
erikvansebille Sep 1, 2025
6659a1d
fixing merging bugs
erikvansebille Sep 1, 2025
f14e8cb
Merge branch 'v4-dev' into c-grid-interpolation
erikvansebille Sep 1, 2025
2568feb
Using is_dask_collection for c-grid interpolator
erikvansebille Sep 1, 2025
525ec71
Merge branch 'v4-dev' into c-grid-interpolation
erikvansebille Sep 3, 2025
cebda92
Merge branch 'v4-dev' into c-grid-interpolation
erikvansebille Sep 8, 2025
0ac7ae0
Removing mesh types check for VectorField
erikvansebille Sep 9, 2025
872b75c
Setting lon and lat as coordinates
erikvansebille Sep 9, 2025
0be3404
Merge branch 'v4-dev' into c-grid-interpolation
erikvansebille Sep 9, 2025
38622fc
Merge branch 'v4-dev' into c-grid-interpolation
VeckoTheGecko Sep 9, 2025
226 changes: 221 additions & 5 deletions parcels/application_kernels/interpolation.py
@@ -8,12 +8,16 @@
import numpy as np
import xarray as xr

import parcels.tools.interpolation_utils as i_u

if TYPE_CHECKING:
from parcels.field import Field
from parcels.field import Field, VectorField
from parcels.uxgrid import _UXGRID_AXES
from parcels.xgrid import _XGRID_AXES

__all__ = [
"CGrid_Tracer",
"CGrid_Velocity",
"UXPiecewiseConstantFace",
"UXPiecewiseLinearNode",
"XLinear",
@@ -52,6 +56,7 @@ def XLinear(

axis_dim = field.grid.get_axis_dim_mapping(field.data.dims)
data = field.data
tdim, zdim, ydim, xdim = data.shape[0], data.shape[1], data.shape[2], data.shape[3]

lenT = 2 if np.any(tau > 0) else 1
lenZ = 2 if np.any(zeta > 0) else 1
@@ -60,22 +65,22 @@
if lenT == 1:
ti = np.repeat(ti, lenZ * 4)
else:
ti_1 = np.clip(ti + 1, 0, data.shape[0] - 1)
ti_1 = np.clip(ti + 1, 0, tdim - 1)
ti = np.concatenate([np.repeat(ti, lenZ * 4), np.repeat(ti_1, lenZ * 4)])

# Depth coordinates: 4 points at zi, 4 at zi+1, repeated for both time levels
if lenZ == 1:
zi = np.repeat(zi, lenT * 4)
else:
zi_1 = np.clip(zi + 1, 0, data.shape[1] - 1)
zi_1 = np.clip(zi + 1, 0, zdim - 1)
zi = np.tile(np.array([zi, zi, zi, zi, zi_1, zi_1, zi_1, zi_1]).flatten(), lenT)

# Y coordinates: [yi, yi, yi+1, yi+1] for each spatial point, repeated for time/depth
yi_1 = np.clip(yi + 1, 0, data.shape[2] - 1)
yi_1 = np.clip(yi + 1, 0, ydim - 1)
yi = np.tile(np.repeat(np.column_stack([yi, yi_1]), 2), (lenT) * (lenZ))

# X coordinates: [xi, xi+1, xi, xi+1] for each spatial point, repeated for time/depth
xi_1 = np.clip(xi + 1, 0, data.shape[3] - 1)
xi_1 = np.clip(xi + 1, 0, xdim - 1)
xi = np.tile(np.column_stack([xi, xi_1, xi, xi_1]).flatten(), (lenT) * (lenZ))

# Create DataArrays for indexing
@@ -111,6 +116,217 @@ def XLinear(
return value.compute() if isinstance(value, dask.Array) else value


def CGrid_Velocity(
vectorfield: VectorField,
ti: int,
position: dict[_XGRID_AXES, tuple[int, float | np.ndarray]],
tau: np.float32 | np.float64,
t: np.float32 | np.float64,
z: np.float32 | np.float64,
y: np.float32 | np.float64,
x: np.float32 | np.float64,
applyConversion: bool,
):
"""
Interpolation kernel for velocity fields on a C-Grid.
Following Delandmeter and Van Sebille (2019), velocity fields should be interpolated
only in the direction of the grid cell faces.
"""
xi, xsi = position["X"]
yi, eta = position["Y"]
zi, zeta = position["Z"]

U = vectorfield.U.data
V = vectorfield.V.data
grid = vectorfield.grid
tdim, zdim, ydim, xdim = U.shape[0], U.shape[1], U.shape[2], U.shape[3]

if grid.lon.ndim == 1:
px = np.array([grid.lon[xi], grid.lon[xi + 1], grid.lon[xi + 1], grid.lon[xi]])
py = np.array([grid.lat[yi], grid.lat[yi], grid.lat[yi + 1], grid.lat[yi + 1]])
else:
px = np.array([grid.lon[yi, xi], grid.lon[yi, xi + 1], grid.lon[yi + 1, xi + 1], grid.lon[yi + 1, xi]])
py = np.array([grid.lat[yi, xi], grid.lat[yi, xi + 1], grid.lat[yi + 1, xi + 1], grid.lat[yi + 1, xi]])

if grid.mesh == "spherical":
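# On spherical meshes, shift the corner longitudes onto the same 360-degree branch as the
# particle longitude x, so that cells straddling the dateline or the 0/360 seam are handled
# consistently; the assert_allclose below checks that the bilinear map of the (shifted)
# corners reproduces x.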
px[0] = px[0] + 360 if px[0] < x - 225 else px[0]
px[0] = px[0] - 360 if px[0] > x + 225 else px[0]
px[1:] = np.where(px[1:] - px[0] > 180, px[1:] - 360, px[1:])
px[1:] = np.where(-px[1:] + px[0] > 180, px[1:] + 360, px[1:])
xx = (1 - xsi) * (1 - eta) * px[0] + xsi * (1 - eta) * px[1] + xsi * eta * px[2] + (1 - xsi) * eta * px[3]
np.testing.assert_allclose(xx, x, atol=1e-4)
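# c1..c4 below are the (geodetic) lengths of the four cell edges (south, east, north, west,
# between the corner points px/py). Multiplying the face-normal velocities by these edge
# lengths further down effectively interpolates fluxes through the cell faces, following
# Delandmeter and Van Sebille (2019).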
c1 = i_u._geodetic_distance(py[0], py[1], px[0], px[1], grid.mesh, np.dot(i_u.phi2D_lin(0.0, xsi), py))
c2 = i_u._geodetic_distance(py[1], py[2], px[1], px[2], grid.mesh, np.dot(i_u.phi2D_lin(eta, 1.0), py))
c3 = i_u._geodetic_distance(py[2], py[3], px[2], px[3], grid.mesh, np.dot(i_u.phi2D_lin(1.0, xsi), py))
c4 = i_u._geodetic_distance(py[3], py[0], px[3], px[0], grid.mesh, np.dot(i_u.phi2D_lin(eta, 0.0), py))

lenT = 2 if np.any(tau > 0) else 1
lenZ = 2 if np.any(zeta > 0) else 1

# Create arrays of corner points for xarray.isel
# TODO C grid may not need all xi and yi cornerpoints, so could speed up here?

# Time coordinates: 8 points at ti, then 8 points at ti+1
if lenT == 1:
ti = np.repeat(ti, lenZ * 4)
else:
ti_1 = np.clip(ti + 1, 0, tdim - 1)
ti = np.concatenate([np.repeat(ti, lenZ * 4), np.repeat(ti_1, lenZ * 4)])

# Depth coordinates: 4 points at zi, 4 at zi+1, repeated for both time levels
if lenZ == 1:
zi = np.repeat(zi, lenT * 4)
else:
zi_1 = np.clip(zi + 1, 0, zdim - 1)
zi = np.tile(np.array([zi, zi, zi, zi, zi_1, zi_1, zi_1, zi_1]).flatten(), lenT)

# Y coordinates: [yi, yi, yi+1, yi+1] for each spatial point, repeated for time/depth
yi_1 = np.clip(yi + 1, 0, ydim - 1)
yi = np.tile(np.repeat(np.column_stack([yi, yi_1]), 2), (lenT) * (lenZ))

# X coordinates: [xi, xi+1, xi, xi+1] for each spatial point, repeated for time/depth
xi_1 = np.clip(xi + 1, 0, xdim - 1)
xi = np.tile(np.column_stack([xi, xi_1, xi, xi_1]).flatten(), (lenT) * (lenZ))

for data in [U, V]:
axis_dim = grid.get_axis_dim_mapping(data.dims)

# Create DataArrays for indexing
selection_dict = {
axis_dim["X"]: xr.DataArray(xi, dims=("points")),
axis_dim["Y"]: xr.DataArray(yi, dims=("points")),
}
if "Z" in axis_dim:
selection_dict[axis_dim["Z"]] = xr.DataArray(zi, dims=("points"))
if "time" in data.dims:
selection_dict["time"] = xr.DataArray(ti, dims=("points"))

corner_data = data.isel(selection_dict).data.reshape(lenT, lenZ, len(xsi), 4)

if lenT == 2:
tau = tau[np.newaxis, :, np.newaxis]
corner_data = corner_data[0, :, :, :] * (1 - tau) + corner_data[1, :, :, :] * tau
else:
corner_data = corner_data[0, :, :, :]

if lenZ == 2:
zeta = zeta[:, np.newaxis]
corner_data = corner_data[0, :, :] * (1 - zeta) + corner_data[1, :, :] * zeta
else:
corner_data = corner_data[0, :, :]
# # See code below for v3 version
# # if self.gridindexingtype == "nemo":
# # U0 = self.U.data[ti, zi, yi + 1, xi] * c4
# # U1 = self.U.data[ti, zi, yi + 1, xi + 1] * c2
# # V0 = self.V.data[ti, zi, yi, xi + 1] * c1
# # V1 = self.V.data[ti, zi, yi + 1, xi + 1] * c3
# # elif self.gridindexingtype in ["mitgcm", "croco"]:
# # U0 = self.U.data[ti, zi, yi, xi] * c4
# # U1 = self.U.data[ti, zi, yi, xi + 1] * c2
# # V0 = self.V.data[ti, zi, yi, xi] * c1
# # V1 = self.V.data[ti, zi, yi + 1, xi] * c3
# # TODO Nick can you help use xgcm to fix this implementation?

# # CROCO and MITgcm grid indexing,
# if data is U:
# U0 = corner_data[:, 0] * c4
# U1 = corner_data[:, 1] * c2
# elif data is V:
# V0 = corner_data[:, 0] * c1
# V1 = corner_data[:, 2] * c3
# # NEMO grid indexing
if data is U:
U0 = corner_data[:, 2] * c4
U1 = corner_data[:, 3] * c2
elif data is V:
V0 = corner_data[:, 1] * c1
V1 = corner_data[:, 3] * c3

U = (1 - xsi) * U0 + xsi * U1
V = (1 - eta) * V0 + eta * V1

deg2m = 1852 * 60.0
if applyConversion:
meshJac = (deg2m * deg2m * np.cos(np.deg2rad(y))) if grid.mesh == "spherical" else 1
else:
meshJac = deg2m if grid.mesh == "spherical" else 1

jac = i_u._compute_jacobian_determinant(py, px, eta, xsi) * meshJac

u = (
(-(1 - eta) * U - (1 - xsi) * V) * px[0]
+ ((1 - eta) * U - xsi * V) * px[1]
+ (eta * U + xsi * V) * px[2]
+ (-eta * U + (1 - xsi) * V) * px[3]
) / jac
v = (
(-(1 - eta) * U - (1 - xsi) * V) * py[0]
+ ((1 - eta) * U - xsi * V) * py[1]
+ (eta * U + xsi * V) * py[2]
+ (-eta * U + (1 - xsi) * V) * py[3]
) / jac
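# A reading of the two expressions above (cf. Delandmeter and Van Sebille, 2019): with U and V
# the edge-scaled face-normal velocities on the reference cell, the zonal and meridional
# components are u = (dx/dxsi * U + dx/deta * V) / jac and v = (dy/dxsi * U + dy/deta * V) / jac,
# i.e. the reference-cell velocity is pushed through the bilinear mapping of the corners (px, py).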
if isinstance(u, dask.Array):
u = u.compute()
v = v.compute()

return (u, v, 0) # TODO fix and test W also


def CGrid_Tracer(
field: Field,
ti: int,
position: dict[_XGRID_AXES, tuple[int, float | np.ndarray]],
tau: np.float32 | np.float64,
t: np.float32 | np.float64,
z: np.float32 | np.float64,
y: np.float32 | np.float64,
x: np.float32 | np.float64,
):
"""Interpolation kernel for tracer fields on a C-Grid.

Following Delandmeter and Van Sebille (2019), tracer fields should be interpolated
as piecewise constant over the grid cell.
"""
xi, _ = position["X"]
yi, _ = position["Y"]
zi, _ = position["Z"]

axis_dim = field.grid.get_axis_dim_mapping(field.data.dims)
data = field.data

lenT = 2 if np.any(tau > 0) else 1

if lenT == 2:
ti_1 = np.clip(ti + 1, 0, data.shape[0] - 1)
ti = np.concatenate([np.repeat(ti), np.repeat(ti_1)])
zi_1 = np.clip(zi + 1, 0, data.shape[1] - 1)
zi = np.concatenate([np.repeat(zi), np.repeat(zi_1)])
yi_1 = np.clip(yi + 1, 0, data.shape[2] - 1)
yi = np.concatenate([np.repeat(yi), np.repeat(yi_1)])
xi_1 = np.clip(xi + 1, 0, data.shape[3] - 1)
xi = np.concatenate([np.repeat(xi), np.repeat(xi_1)])

# Create DataArrays for indexing
selection_dict = {
axis_dim["X"]: xr.DataArray(xi, dims=("points")),
axis_dim["Y"]: xr.DataArray(yi, dims=("points")),
}
if "Z" in axis_dim:
selection_dict[axis_dim["Z"]] = xr.DataArray(zi, dims=("points"))
if "time" in field.data.dims:
selection_dict["time"] = xr.DataArray(ti, dims=("points"))

value = data.isel(selection_dict).data.reshape(lenT, len(xi))

if lenT == 2:
tau = tau[:, np.newaxis]
value = value[0, :] * (1 - tau) + value[1, :] * tau
else:
value = value[0, :]

return value.compute() if isinstance(value, dask.Array) else value


def UXPiecewiseConstantFace(
field: Field,
ti: int,
10 changes: 5 additions & 5 deletions parcels/field.py
@@ -16,7 +16,7 @@
VectorType,
assert_valid_mesh,
)
from parcels.application_kernels.interpolation import UXPiecewiseLinearNode, XLinear, ZeroInterpolator
from parcels.application_kernels.interpolation import CGrid_Velocity, UXPiecewiseLinearNode, XLinear, ZeroInterpolator
from parcels.particle import KernelParticle
from parcels.particleset import ParticleSet
from parcels.tools.converters import (
@@ -292,8 +292,8 @@ def __init__(
if vector_interp_method is None:
self._vector_interp_method = None
else:
_assert_same_function_signature(vector_interp_method, ref=ZeroInterpolator)
self._interp_method = vector_interp_method
_assert_same_function_signature(vector_interp_method, ref=CGrid_Velocity)
self._vector_interp_method = vector_interp_method

def __repr__(self):
return f"""<{type(self).__name__}>
@@ -308,7 +308,7 @@ def vector_interp_method(self):

@vector_interp_method.setter
def vector_interp_method(self, method: Callable):
_assert_same_function_signature(method, ref=ZeroInterpolator)
_assert_same_function_signature(method, ref=CGrid_Velocity)
self._vector_interp_method = method

def eval(self, time: datetime, z, y, x, particle=None, applyConversion=True):
@@ -333,7 +333,7 @@ def eval(self, time: datetime, z, y, x, particle=None, applyConversion=True):
if "3D" in self.vector_type:
w = self.W._interp_method(self.W, ti, position, tau, time, z, y, x)
else:
(u, v, w) = self._vector_interp_method(self, ti, position, time, z, y, x)
(u, v, w) = self._vector_interp_method(self, ti, position, tau, time, z, y, x, applyConversion)

if applyConversion:
u = self.U.units.to_target(u, z, y, x)
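A hedged sketch of what switching the signature reference to CGrid_Velocity implies for user code (the uv_field instance and the zero-returning body are illustrative assumptions, not part of this PR): a custom vector interpolator now has to match CGrid_Velocity's call signature, including tau and applyConversion, and is attached via the vector_interp_method setter shown above.

def MyVectorInterp(vectorfield, ti, position, tau, t, z, y, x, applyConversion):
    # placeholder body: a real interpolator would sample vectorfield.U and vectorfield.V here
    return (0.0, 0.0, 0.0)

uv_field.vector_interp_method = MyVectorInterp  # checked against the CGrid_Velocity signature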
8 changes: 5 additions & 3 deletions parcels/particleset.py
@@ -96,9 +96,11 @@ def __init__(
assert lon.size == lat.size and lon.size == depth.size, "lon, lat, depth don't all have the same lengths"

if time is None or len(time) == 0:
time = type(fieldset.U.data.time[0].values)(
"NaT", "ns"
) # do not set a time yet (because sign_dt not known)
# do not set a time yet (because sign_dt not known)
if fieldset.time_interval is None:
time = np.timedelta64("NaT", "ns")
else:
time = type(fieldset.time_interval.left)("NaT", "ns")
elif type(time[0]) in [np.datetime64, np.timedelta64]:
pass # already in the right format
else:
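A small, self-contained sketch of the default-time logic added above (values are illustrative only): type(x)("NaT", "ns") builds a NaT of the same kind, datetime64 or timedelta64, as the fieldset's time axis, so the particle time dtype matches before sign_dt is known.

import numpy as np

left = np.datetime64("2000-01-01", "ns")      # stands in for fieldset.time_interval.left
assert np.isnat(type(left)("NaT", "ns"))      # datetime64("NaT") when a time_interval exists
assert np.isnat(np.timedelta64("NaT", "ns"))  # fallback used when time_interval is None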
23 changes: 12 additions & 11 deletions parcels/tools/interpolation_utils.py
@@ -24,10 +24,10 @@ def phi1D_quad(xsi: float) -> list[float]:


def phi2D_lin(eta: float, xsi: float) -> list[float]:
phi = [(1-xsi) * (1-eta),
xsi * (1-eta),
xsi * eta ,
(1-xsi) * eta ]
phi = np.column_stack([(1-xsi) * (1-eta),
xsi * (1-eta),
xsi * eta ,
(1-xsi) * eta ])

return phi

@@ -185,12 +185,13 @@ def _geodetic_distance(lat1: float, lat2: float, lon1: float, lon2: float, mesh:


def _compute_jacobian_determinant(py: np.ndarray, px: np.ndarray, eta: float, xsi: float) -> float:
dphidxsi = [eta - 1, 1 - eta, eta, -eta]
dphideta = [xsi - 1, -xsi, xsi, 1 - xsi]
dphidxsi = np.column_stack([eta - 1, 1 - eta, eta, -eta])
dphideta = np.column_stack([xsi - 1, -xsi, xsi, 1 - xsi])

dxdxsi = np.dot(px, dphidxsi)
dxdeta = np.dot(px, dphideta)
dydxsi = np.dot(py, dphidxsi)
dydeta = np.dot(py, dphideta)
dxdxsi = np.dot(dphidxsi, px)
Review comment:
@erikvansebille, following up on our discussion about these four $n \times n$ matrices, and the possible memory usage that may incur for large particle sets, I think the following code can work (and reduce any need to sparse)

dxdxsi_diag = np.einsum('ij,ji->i', dphidxsi, px)
dxdeta_diag = np.einsum('ij,ji->i', dphideta, px)
dydxsi_diag = np.einsum('ij,ji->i', dphidxsi, py)
dydeta_diag = np.einsum('ij,ji->i', dphideta, py)

jac_diag = dxdxsi_diag * dydeta_diag - dxdeta_diag * dydxsi_diag
return jac_diag

dxdeta = np.dot(dphideta, px)
dydxsi = np.dot(dphidxsi, py)
dydeta = np.dot(dphideta, py)
jac = dxdxsi * dydeta - dxdeta * dydxsi
return jac
# TODO check how to properly vectorize this function (and not return only half of the Jacobian)
return jac.diagonal()
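A self-contained check of the reviewer's einsum suggestion above (shapes assumed from the new code: px and py are (4, n) corner arrays, the dphi arrays are (n, 4)): the einsum form picks out only the per-particle diagonal and matches jac.diagonal() without building the n x n intermediates.

import numpy as np

rng = np.random.default_rng(0)
n = 5
px, py = rng.random((4, n)), rng.random((4, n))
dphidxsi, dphideta = rng.random((n, 4)), rng.random((n, 4))

jac_full = (np.dot(dphidxsi, px) * np.dot(dphideta, py)
            - np.dot(dphideta, px) * np.dot(dphidxsi, py))   # n x n, as in the code above
jac_diag = (np.einsum("ij,ji->i", dphidxsi, px) * np.einsum("ij,ji->i", dphideta, py)
            - np.einsum("ij,ji->i", dphideta, px) * np.einsum("ij,ji->i", dphidxsi, py))
np.testing.assert_allclose(jac_diag, jac_full.diagonal())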
4 changes: 4 additions & 0 deletions parcels/xgrid.py
@@ -100,6 +100,10 @@ def __init__(self, grid: xgcm.Grid, mesh="flat"):
self.mesh = mesh
self._spatialhash = None
ds = grid._ds
if hasattr(ds["lon"], "load"):
ds["lon"].load()
if hasattr(ds["lat"], "load"):
ds["lat"].load()

if len(set(grid.axes) & {"X", "Y", "Z"}) > 0: # Only if spatial grid is >0D (see #2054 for further development)
assert_valid_lat_lon(ds["lat"], ds["lon"], grid.axes)
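A minimal sketch of what the eager lon/lat load above buys (assuming the dataset was opened lazily so the coordinates are dask-backed; dask is needed for .chunk): after .load() the coordinate values are plain numpy arrays in memory, so the repeated curvilinear index searches no longer trigger a dask compute each time.

import numpy as np
import xarray as xr

lon = xr.DataArray(np.linspace(-180.0, 180.0, 10), dims="x", name="lon").chunk({"x": 5})
lon.load()                               # loads in place, mirroring ds["lon"].load() in the diff
assert isinstance(lon.data, np.ndarray)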