Skip to content
Merged
Show file tree
Hide file tree
Changes from 61 commits
Commits
Show all changes
64 commits
Select commit Hold shift + click to select a range
fe5d894
First attempt to vectorize kernel loop
erikvansebille Jul 29, 2025
867662b
Adding radial rotation flow
erikvansebille Jul 29, 2025
bce90d5
Fixing kernel statement
erikvansebille Jul 29, 2025
aa42ac2
Make Parcels vectorized for simple kernels
erikvansebille Jul 29, 2025
f6097eb
Support for Field[particle] evaluation
erikvansebille Jul 29, 2025
926af77
Making XBiLinear also work on time-evolving fields
erikvansebille Jul 29, 2025
c975c2a
Fixing unit tests for vectorized kernels
erikvansebille Jul 30, 2025
809eb78
Fixing more unit tests for vectorised kernels
erikvansebille Jul 30, 2025
2ec28da
Fix diffusion tests for vectorized kernel
erikvansebille Jul 30, 2025
5e536f7
merge v4-dev
erikvansebille Jul 30, 2025
088c2cb
fixing more unit tests for vectorized kernels
erikvansebille Jul 30, 2025
cb702e1
Update search to handle array of lat,lon
fluidnumerics-joe Jul 30, 2025
bac89b6
Merge branch 'vectorized-kernel' of github.com:OceanParcels/Parcels i…
fluidnumerics-joe Jul 30, 2025
eeff3ad
Fixing setattr on particleset level
erikvansebille Jul 31, 2025
b87691d
Fixing last failing tests in vectorized kernels
erikvansebille Jul 31, 2025
613f873
Fixing import in uxgrid to use right file
erikvansebille Jul 31, 2025
88c7388
clean up kernel loop
erikvansebille Jul 31, 2025
6c4acc0
Setting default interpolators on Fields
erikvansebille Jul 31, 2025
cac1f17
simplifying vectorized XLinear interpolation function
erikvansebille Jul 31, 2025
a775671
Small speedup in _search_1d_array
erikvansebille Aug 4, 2025
dae70fb
improving XLinear by using only 1 isel call
erikvansebille Aug 4, 2025
a0c0573
Further improving XLinear Interpolaiton
erikvansebille Aug 4, 2025
21d7c73
Further speeding up access pattern creation in XLinear
erikvansebille Aug 4, 2025
587bc6e
speeding up XLinear by reducing clipping operation
erikvansebille Aug 4, 2025
dd9bfd2
Calling dask.compute() at end of XInterpolation
erikvansebille Aug 4, 2025
92c3d82
Fixing bug in creation of xi and yi indices
erikvansebille Aug 4, 2025
9bd07f7
Fixing vectorized kernels from running particles that start after end…
erikvansebille Aug 4, 2025
067d082
Fixing xi and yi calculations
erikvansebille Aug 5, 2025
9cd1d8e
Using numpy for field indexing (instead of slower xarray.isel)
erikvansebille Aug 5, 2025
26ec76f
Reducing number of support variables
erikvansebille Aug 5, 2025
ded4ebd
Reverting back to xarray.isel, as dask only works this way(?)
erikvansebille Aug 5, 2025
a55db91
Another fix to interpolation xi vector generation
erikvansebille Aug 5, 2025
4c89533
Extending tests in test_particleset_execute also to more particles
erikvansebille Aug 5, 2025
95f0726
Fixing bug in reshaping of corner_data in XLinear
erikvansebille Aug 6, 2025
e0faea6
Fix to how zi array is compiled in XLinear
erikvansebille Aug 6, 2025
c9dc0bd
Merge branch 'v4-dev' into vectorized-kernel
erikvansebille Aug 6, 2025
ec03f53
Merge branch 'v4-dev' into vectorized-kernel
erikvansebille Aug 6, 2025
408bcf6
Clarifying docstring in XLinear interpolation
erikvansebille Aug 6, 2025
db776fb
fixing merge conflict errors
erikvansebille Aug 7, 2025
a8e4e68
fix CLinear interpolation for grids without depth or time dimension
erikvansebille Aug 7, 2025
d4b1adc
Quick clean-up of RK45 in vectorized kernels
erikvansebille Aug 7, 2025
dd9ccee
Merge branch 'v4-dev' into vectorized-kernel
erikvansebille Aug 7, 2025
9ee900e
fixing merge issues
erikvansebille Aug 7, 2025
354b9be
Changing time_interval.__contains__ to is_all_time_in_interval()
erikvansebille Aug 13, 2025
d6677b2
Renaming test function to be more explicit
erikvansebille Aug 13, 2025
535371a
Expanding unit test to make its function clearer
erikvansebille Aug 13, 2025
d0e6573
Adding vectorised tests for _search_1d_array
erikvansebille Aug 13, 2025
f375db0
vectorized version of _search_indices_curvilinear_2d
erikvansebille Aug 13, 2025
af9aee0
Fixing RK45 for vectorized kernels
erikvansebille Aug 13, 2025
4c51124
Fixing breaking test now that _search_inidces_curvilinear_2d expects …
erikvansebille Aug 13, 2025
9335ddb
Using particle.time in the advection kernels
erikvansebille Aug 13, 2025
04946b7
Fix unit test
erikvansebille Aug 13, 2025
bee5464
Cleaning up kernel loop
erikvansebille Aug 14, 2025
7c6b21d
Adding KernelParticle getitem
erikvansebille Aug 14, 2025
a2feb9e
Fixing advection unit tests
erikvansebille Aug 14, 2025
654ae1a
Fixing xgrid unit test
erikvansebille Aug 14, 2025
316f15e
Fixing particleset_execution unit tests
erikvansebille Aug 14, 2025
9652bf7
Adding and extending unit tests on particle errors
erikvansebille Aug 14, 2025
081753d
merge commit
erikvansebille Aug 14, 2025
4c90bbc
Cleaning up Kernel for RK45
erikvansebille Aug 14, 2025
161ad9c
updating advection tests for faster runtime
erikvansebille Aug 15, 2025
558b466
Update parcels/_core/utils/time.py
erikvansebille Aug 19, 2025
d4fdc01
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 19, 2025
5d2bf9a
Reverting that key in eval can be a ParticleSet
erikvansebille Aug 19, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions parcels/_core/utils/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,9 @@ def __init__(self, left: T, right: T) -> None:
self.left = left
self.right = right

def __contains__(self, item: T) -> bool:
return self.left <= item <= self.right
def is_all_time_in_interval(self, time):
item = np.atleast_1d(time)
return (self.left <= item).all() and (item <= self.right).all()

def __repr__(self) -> str:
return f"TimeInterval(left={self.left!r}, right={self.right!r})"
Expand Down
107 changes: 33 additions & 74 deletions parcels/_index_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,36 +39,14 @@ def _search_time_index(field: Field, time: datetime):
if the sampled value is outside the time value range.
"""
if field.time_interval is None:
return 0, 0
return np.zeros(shape=time.shape, dtype=np.float32), np.zeros(shape=time.shape, dtype=np.int32)

if time not in field.time_interval:
if not field.time_interval.is_all_time_in_interval(time):
_raise_time_extrapolation_error(time, field=None)

time_index = field.data.time <= time

if time_index.all():
# If given time > last known field time, use
# the last field frame without interpolation
ti = len(field.data.time) - 1

elif np.logical_not(time_index).all():
# If given time < any time in the field, use
# the first field frame without interpolation
ti = 0
else:
ti = int(time_index.argmin() - 1) if time_index.any() else 0
if len(field.data.time) == 1:
tau = 0
elif ti == len(field.data.time) - 1:
tau = 1
else:
tau = (
(time - field.data.time[ti]).dt.total_seconds().values
/ (field.data.time[ti + 1] - field.data.time[ti]).dt.total_seconds().values
if field.data.time[ti] != field.data.time[ti + 1]
else 0
)
return tau, ti
ti = np.searchsorted(field.data.time.data, time, side="right") - 1
tau = (time - field.data.time.data[ti]) / (field.data.time.data[ti + 1] - field.data.time.data[ti])
return np.atleast_1d(tau), np.atleast_1d(ti)


def search_indices_vertical_z(depth, gridindexingtype: GridIndexingType, z: float):
Expand Down Expand Up @@ -273,13 +251,13 @@ def _search_indices_rectilinear(

def _search_indices_curvilinear_2d(
grid: XGrid, y: float, x: float, yi_guess: int | None = None, xi_guess: int | None = None
):
): # TODO fix typing instructions to make clear that y, x etc need to be ndarrays
yi, xi = yi_guess, xi_guess
if yi is None or xi is None:
faces = grid.get_spatial_hash().query(np.column_stack((y, x)))
yi, xi = faces[0]

xsi = eta = -1.0
xsi = eta = -1.0 * np.ones(len(x), dtype=float)
invA = np.array(
[
[1, 0, 0, 0],
Expand All @@ -303,7 +281,7 @@ def _search_indices_curvilinear_2d(
# if y < field.lonlat_minmax[2] or y > field.lonlat_minmax[3]:
# _raise_field_out_of_bound_error(z, y, x)

while xsi < -tol or xsi > 1 + tol or eta < -tol or eta > 1 + tol:
while np.any(xsi < -tol) or np.any(xsi > 1 + tol) or np.any(eta < -tol) or np.any(eta > 1 + tol):
px = np.array([grid.lon[yi, xi], grid.lon[yi, xi + 1], grid.lon[yi + 1, xi + 1], grid.lon[yi + 1, xi]])

py = np.array([grid.lat[yi, xi], grid.lat[yi, xi + 1], grid.lat[yi + 1, xi + 1], grid.lat[yi + 1, xi]])
Expand All @@ -313,40 +291,29 @@ def _search_indices_curvilinear_2d(
aa = a[3] * b[2] - a[2] * b[3]
bb = a[3] * b[0] - a[0] * b[3] + a[1] * b[2] - a[2] * b[1] + x * b[3] - y * a[3]
cc = a[1] * b[0] - a[0] * b[1] + x * b[1] - y * a[1]
if abs(aa) < 1e-12: # Rectilinear cell, or quasi
eta = -cc / bb
else:
det2 = bb * bb - 4 * aa * cc
if det2 > 0: # so, if det is nan we keep the xsi, eta from previous iter
det = np.sqrt(det2)
eta = (-bb + det) / (2 * aa)
if abs(a[1] + a[3] * eta) < 1e-12: # this happens when recti cell rotated of 90deg
xsi = ((y - py[0]) / (py[1] - py[0]) + (y - py[3]) / (py[2] - py[3])) * 0.5
else:
xsi = (x - a[0] - a[2] * eta) / (a[1] + a[3] * eta)
if xsi < 0 and eta < 0 and xi == 0 and yi == 0:
_raise_field_out_of_bound_error(0, y, x)
if xsi > 1 and eta > 1 and xi == grid.xdim - 1 and yi == grid.ydim - 1:
_raise_field_out_of_bound_error(0, y, x)
if xsi < -tol:
xi -= 1
elif xsi > 1 + tol:
xi += 1
if eta < -tol:
yi -= 1
elif eta > 1 + tol:
yi += 1

det2 = bb * bb - 4 * aa * cc
det = np.where(det2 > 0, np.sqrt(det2), eta)
eta = np.where(abs(aa) < 1e-12, -cc / bb, np.where(det2 > 0, (-bb + det) / (2 * aa), eta))

xsi = np.where(
abs(a[1] + a[3] * eta) < 1e-12,
((y - py[0]) / (py[1] - py[0]) + (y - py[3]) / (py[2] - py[3])) * 0.5,
(x - a[0] - a[2] * eta) / (a[1] + a[3] * eta),
)

xi = np.where(xsi < -tol, xi - 1, np.where(xsi > 1 + tol, xi + 1, xi))
yi = np.where(eta < -tol, yi - 1, np.where(eta > 1 + tol, yi + 1, yi))

(yi, xi) = _reconnect_bnd_indices(yi, xi, grid.ydim, grid.xdim, grid.mesh)
it += 1
if it > maxIterSearch:
print(f"Correct cell not found after {maxIterSearch} iterations")
_raise_field_out_of_bound_error(0, y, x)
xsi = max(0.0, xsi)
eta = max(0.0, eta)
xsi = min(1.0, xsi)
eta = min(1.0, eta)
xsi = np.where(xsi < 0.0, 0.0, np.where(xsi > 1.0, 1.0, xsi))
eta = np.where(eta < 0.0, 0.0, np.where(eta > 1.0, 1.0, eta))

if not ((0 <= xsi <= 1) and (0 <= eta <= 1)):
if np.any((xsi < 0) | (xsi > 1) | (eta < 0) | (eta > 1)):
_raise_field_sampling_error(y, x)
return (yi, eta, xi, xsi)

Expand Down Expand Up @@ -442,20 +409,12 @@ def _search_indices_curvilinear(field, time, z, y, x, ti, particle=None, search2


def _reconnect_bnd_indices(yi: int, xi: int, ydim: int, xdim: int, sphere_mesh: bool):
if xi < 0:
if sphere_mesh:
xi = xdim - 2
else:
xi = 0
if xi > xdim - 2:
if sphere_mesh:
xi = 0
else:
xi = xdim - 2
if yi < 0:
yi = 0
if yi > ydim - 2:
yi = ydim - 2
if sphere_mesh:
xi = xdim - xi
xi = np.where(xi < 0, (xdim - 2) if sphere_mesh else 0, xi)
xi = np.where(xi > xdim - 2, 0 if sphere_mesh else (xdim - 2), xi)

xi = np.where(yi > ydim - 2, xdim - xi if sphere_mesh else xi, xi)

yi = np.where(yi < 0, 0, yi)
yi = np.where(yi > ydim - 2, ydim - 2, yi)

return yi, xi
91 changes: 50 additions & 41 deletions parcels/application_kernels/advection.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ def AdvectionRK4(particle, fieldset, time): # pragma: no cover
dt = particle.dt / np.timedelta64(1, "s") # TODO: improve API for converting dt to seconds
(u1, v1) = fieldset.UV[particle]
lon1, lat1 = (particle.lon + u1 * 0.5 * dt, particle.lat + v1 * 0.5 * dt)
(u2, v2) = fieldset.UV[time + 0.5 * particle.dt, particle.depth, lat1, lon1, particle]
(u2, v2) = fieldset.UV[particle.time + 0.5 * particle.dt, particle.depth, lat1, lon1, particle]
lon2, lat2 = (particle.lon + u2 * 0.5 * dt, particle.lat + v2 * 0.5 * dt)
(u3, v3) = fieldset.UV[time + 0.5 * particle.dt, particle.depth, lat2, lon2, particle]
(u3, v3) = fieldset.UV[particle.time + 0.5 * particle.dt, particle.depth, lat2, lon2, particle]
lon3, lat3 = (particle.lon + u3 * dt, particle.lat + v3 * dt)
(u4, v4) = fieldset.UV[time + particle.dt, particle.depth, lat3, lon3, particle]
(u4, v4) = fieldset.UV[particle.time + particle.dt, particle.depth, lat3, lon3, particle]
particle.dlon += (u1 + 2 * u2 + 2 * u3 + u4) / 6.0 * dt
particle.dlat += (v1 + 2 * v2 + 2 * v3 + v4) / 6.0 * dt

Expand All @@ -37,15 +37,15 @@ def AdvectionRK4_3D(particle, fieldset, time): # pragma: no cover
lon1 = particle.lon + u1 * 0.5 * dt
lat1 = particle.lat + v1 * 0.5 * dt
dep1 = particle.depth + w1 * 0.5 * dt
(u2, v2, w2) = fieldset.UVW[time + 0.5 * particle.dt, dep1, lat1, lon1, particle]
(u2, v2, w2) = fieldset.UVW[particle.time + 0.5 * particle.dt, dep1, lat1, lon1, particle]
lon2 = particle.lon + u2 * 0.5 * dt
lat2 = particle.lat + v2 * 0.5 * dt
dep2 = particle.depth + w2 * 0.5 * dt
(u3, v3, w3) = fieldset.UVW[time + 0.5 * particle.dt, dep2, lat2, lon2, particle]
(u3, v3, w3) = fieldset.UVW[particle.time + 0.5 * particle.dt, dep2, lat2, lon2, particle]
lon3 = particle.lon + u3 * dt
lat3 = particle.lat + v3 * dt
dep3 = particle.depth + w3 * dt
(u4, v4, w4) = fieldset.UVW[time + particle.dt, dep3, lat3, lon3, particle]
(u4, v4, w4) = fieldset.UVW[particle.time + particle.dt, dep3, lat3, lon3, particle]
particle.dlon += (u1 + 2 * u2 + 2 * u3 + u4) / 6 * dt
particle.dlat += (v1 + 2 * v2 + 2 * v3 + v4) / 6 * dt
particle.ddepth += (w1 + 2 * w2 + 2 * w3 + w4) / 6 * dt
Expand All @@ -56,35 +56,35 @@ def AdvectionRK4_3D_CROCO(particle, fieldset, time): # pragma: no cover
This kernel assumes the vertical velocity is the 'w' field from CROCO output and works on sigma-layers.
"""
dt = particle.dt / np.timedelta64(1, "s") # TODO: improve API for converting dt to seconds
sig_dep = particle.depth / fieldset.H[time, 0, particle.lat, particle.lon]
sig_dep = particle.depth / fieldset.H[particle.time, 0, particle.lat, particle.lon]

(u1, v1, w1) = fieldset.UVW[time, particle.depth, particle.lat, particle.lon, particle]
w1 *= sig_dep / fieldset.H[time, 0, particle.lat, particle.lon]
(u1, v1, w1) = fieldset.UVW[particle.time, particle.depth, particle.lat, particle.lon, particle]
w1 *= sig_dep / fieldset.H[particle.time, 0, particle.lat, particle.lon]
lon1 = particle.lon + u1 * 0.5 * dt
lat1 = particle.lat + v1 * 0.5 * dt
sig_dep1 = sig_dep + w1 * 0.5 * dt
dep1 = sig_dep1 * fieldset.H[time, 0, lat1, lon1]
dep1 = sig_dep1 * fieldset.H[particle.time, 0, lat1, lon1]

(u2, v2, w2) = fieldset.UVW[time + 0.5 * particle.dt, dep1, lat1, lon1, particle]
w2 *= sig_dep1 / fieldset.H[time, 0, lat1, lon1]
(u2, v2, w2) = fieldset.UVW[particle.time + 0.5 * particle.dt, dep1, lat1, lon1, particle]
w2 *= sig_dep1 / fieldset.H[particle.time, 0, lat1, lon1]
lon2 = particle.lon + u2 * 0.5 * dt
lat2 = particle.lat + v2 * 0.5 * dt
sig_dep2 = sig_dep + w2 * 0.5 * dt
dep2 = sig_dep2 * fieldset.H[time, 0, lat2, lon2]
dep2 = sig_dep2 * fieldset.H[particle.time, 0, lat2, lon2]

(u3, v3, w3) = fieldset.UVW[time + 0.5 * particle.dt, dep2, lat2, lon2, particle]
w3 *= sig_dep2 / fieldset.H[time, 0, lat2, lon2]
(u3, v3, w3) = fieldset.UVW[particle.time + 0.5 * particle.dt, dep2, lat2, lon2, particle]
w3 *= sig_dep2 / fieldset.H[particle.time, 0, lat2, lon2]
lon3 = particle.lon + u3 * dt
lat3 = particle.lat + v3 * dt
sig_dep3 = sig_dep + w3 * dt
dep3 = sig_dep3 * fieldset.H[time, 0, lat3, lon3]
dep3 = sig_dep3 * fieldset.H[particle.time, 0, lat3, lon3]

(u4, v4, w4) = fieldset.UVW[time + particle.dt, dep3, lat3, lon3, particle]
w4 *= sig_dep3 / fieldset.H[time, 0, lat3, lon3]
(u4, v4, w4) = fieldset.UVW[particle.time + particle.dt, dep3, lat3, lon3, particle]
w4 *= sig_dep3 / fieldset.H[particle.time, 0, lat3, lon3]
lon4 = particle.lon + u4 * dt
lat4 = particle.lat + v4 * dt
sig_dep4 = sig_dep + w4 * dt
dep4 = sig_dep4 * fieldset.H[time, 0, lat4, lon4]
dep4 = sig_dep4 * fieldset.H[particle.time, 0, lat4, lon4]

particle.dlon += (u1 + 2 * u2 + 2 * u3 + u4) / 6 * dt
particle.dlat += (v1 + 2 * v2 + 2 * v3 + v4) / 6 * dt
Expand Down Expand Up @@ -115,14 +115,7 @@ def AdvectionRK45(particle, fieldset, time): # pragma: no cover
Time-step dt is halved if error is larger than fieldset.RK45_tol,
and doubled if error is smaller than 1/10th of tolerance.
"""
dt = particle.next_dt / np.timedelta64(1, "s") # TODO: improve API for converting dt to seconds
if dt > fieldset.RK45_max_dt:
dt = fieldset.RK45_max_dt
particle.next_dt = fieldset.RK45_max_dt * np.timedelta64(1, "s")
if dt < fieldset.RK45_min_dt:
particle.next_dt = fieldset.RK45_min_dt * np.timedelta64(1, "s")
return StatusCode.Repeat
particle.dt = particle.next_dt
dt = particle.dt / np.timedelta64(1, "s") # TODO: improve API for converting dt to seconds

c = [1.0 / 4.0, 3.0 / 8.0, 12.0 / 13.0, 1.0, 1.0 / 2.0]
A = [
Expand All @@ -137,42 +130,58 @@ def AdvectionRK45(particle, fieldset, time): # pragma: no cover

(u1, v1) = fieldset.UV[particle]
lon1, lat1 = (particle.lon + u1 * A[0][0] * dt, particle.lat + v1 * A[0][0] * dt)
(u2, v2) = fieldset.UV[time + c[0] * particle.dt, particle.depth, lat1, lon1, particle]
(u2, v2) = fieldset.UV[particle.time + c[0] * particle.dt, particle.depth, lat1, lon1, particle]
lon2, lat2 = (
particle.lon + (u1 * A[1][0] + u2 * A[1][1]) * dt,
particle.lat + (v1 * A[1][0] + v2 * A[1][1]) * dt,
)
(u3, v3) = fieldset.UV[time + c[1] * particle.dt, particle.depth, lat2, lon2, particle]
(u3, v3) = fieldset.UV[particle.time + c[1] * particle.dt, particle.depth, lat2, lon2, particle]
lon3, lat3 = (
particle.lon + (u1 * A[2][0] + u2 * A[2][1] + u3 * A[2][2]) * dt,
particle.lat + (v1 * A[2][0] + v2 * A[2][1] + v3 * A[2][2]) * dt,
)
(u4, v4) = fieldset.UV[time + c[2] * particle.dt, particle.depth, lat3, lon3, particle]
(u4, v4) = fieldset.UV[particle.time + c[2] * particle.dt, particle.depth, lat3, lon3, particle]
lon4, lat4 = (
particle.lon + (u1 * A[3][0] + u2 * A[3][1] + u3 * A[3][2] + u4 * A[3][3]) * dt,
particle.lat + (v1 * A[3][0] + v2 * A[3][1] + v3 * A[3][2] + v4 * A[3][3]) * dt,
)
(u5, v5) = fieldset.UV[time + c[3] * particle.dt, particle.depth, lat4, lon4, particle]
(u5, v5) = fieldset.UV[particle.time + c[3] * particle.dt, particle.depth, lat4, lon4, particle]
lon5, lat5 = (
particle.lon + (u1 * A[4][0] + u2 * A[4][1] + u3 * A[4][2] + u4 * A[4][3] + u5 * A[4][4]) * dt,
particle.lat + (v1 * A[4][0] + v2 * A[4][1] + v3 * A[4][2] + v4 * A[4][3] + v5 * A[4][4]) * dt,
)
(u6, v6) = fieldset.UV[time + c[4] * particle.dt, particle.depth, lat5, lon5, particle]
(u6, v6) = fieldset.UV[particle.time + c[4] * particle.dt, particle.depth, lat5, lon5, particle]

lon_4th = (u1 * b4[0] + u2 * b4[1] + u3 * b4[2] + u4 * b4[3] + u5 * b4[4]) * dt
lat_4th = (v1 * b4[0] + v2 * b4[1] + v3 * b4[2] + v4 * b4[3] + v5 * b4[4]) * dt
lon_5th = (u1 * b5[0] + u2 * b5[1] + u3 * b5[2] + u4 * b5[3] + u5 * b5[4] + u6 * b5[5]) * dt
lat_5th = (v1 * b5[0] + v2 * b5[1] + v3 * b5[2] + v4 * b5[3] + v5 * b5[4] + v6 * b5[5]) * dt

kappa = math.sqrt(math.pow(lon_5th - lon_4th, 2) + math.pow(lat_5th - lat_4th, 2))
if (kappa <= fieldset.RK45_tol) or (math.fabs(dt) < math.fabs(fieldset.RK45_min_dt)):
particle.dlon += lon_4th
particle.dlat += lat_4th
if (kappa <= fieldset.RK45_tol / 10) and (math.fabs(dt * 2) <= math.fabs(fieldset.RK45_max_dt)):
particle.next_dt *= 2
else:
particle.next_dt /= 2
return StatusCode.Repeat
kappa = np.sqrt(np.pow(lon_5th - lon_4th, 2) + np.pow(lat_5th - lat_4th, 2))

good_particles = (kappa <= fieldset.RK45_tol) | (np.fabs(dt) <= np.fabs(fieldset.RK45_min_dt))
particle.dlon += np.where(good_particles, lon_5th, 0)
particle.dlat += np.where(good_particles, lat_5th, 0)

increase_dt_particles = (
good_particles & (kappa <= fieldset.RK45_tol / 10) & (np.fabs(dt * 2) <= np.fabs(fieldset.RK45_max_dt))
)
particle.dt = np.where(increase_dt_particles, particle.dt * 2, particle.dt)
particle.dt = np.where(
particle.dt > fieldset.RK45_max_dt * np.timedelta64(1, "s"),
fieldset.RK45_max_dt * np.timedelta64(1, "s"),
particle.dt,
)
particle.state = np.where(good_particles, StatusCode.Success, particle.state)

repeat_particles = np.invert(good_particles)
particle.dt = np.where(repeat_particles, particle.dt / 2, particle.dt)
particle.dt = np.where(
particle.dt < fieldset.RK45_min_dt * np.timedelta64(1, "s"),
fieldset.RK45_min_dt * np.timedelta64(1, "s"),
particle.dt,
)
particle.state = np.where(repeat_particles, StatusCode.Repeat, particle.state)


def AdvectionAnalytical(particle, fieldset, time): # pragma: no cover
Expand Down
Loading
Loading