Commit fdba4c5

Merge pull request #2113 from OceanParcels/adding_unit_tests
Adding unit tests
2 parents c6f91d6 + 1293758 commit fdba4c5

16 files changed: +725 / -770 lines

(new file; path not shown in this view)

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+import numpy as np
+import xarray as xr
+
+
+def simple_UV_dataset(dims=(360, 2, 30, 4), maxdepth=1, mesh_type="spherical"):
+    max_lon = 180.0 if mesh_type == "spherical" else 1e6
+
+    return xr.Dataset(
+        {"U": (["time", "depth", "YG", "XG"], np.zeros(dims)), "V": (["time", "depth", "YG", "XG"], np.zeros(dims))},
+        coords={
+            "time": (["time"], xr.date_range("2000", "2001", dims[0]), {"axis": "T"}),
+            "depth": (["depth"], np.linspace(0, maxdepth, dims[1]), {"axis": "Z"}),
+            "YC": (["YC"], np.arange(dims[2]) + 0.5, {"axis": "Y"}),
+            "YG": (["YG"], np.arange(dims[2]), {"axis": "Y", "c_grid_axis_shift": -0.5}),
+            "XC": (["XC"], np.arange(dims[3]) + 0.5, {"axis": "X"}),
+            "XG": (["XG"], np.arange(dims[3]), {"axis": "X", "c_grid_axis_shift": -0.5}),
+            "lat": (["YG"], np.linspace(-90, 90, dims[2]), {"axis": "Y", "c_grid_axis_shift": 0.5}),
+            "lon": (["XG"], np.linspace(-max_lon, max_lon, dims[3]), {"axis": "X", "c_grid_axis_shift": -0.5}),
+        },
+    )
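
This new test helper builds an all-zeros C-grid UV dataset whose horizontal extent depends on the mesh type. A quick usage sketch (the values below are illustrative, not part of the commit):

    ds = simple_UV_dataset(dims=(360, 2, 30, 4), maxdepth=10, mesh_type="flat")
    assert ds["U"].shape == (360, 2, 30, 4)   # (time, depth, YG, XG), all zeros
    assert float(ds["lon"].max()) == 1e6      # "flat" meshes span +/-1e6 m instead of +/-180 degrees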

parcels/_index_search.py

Lines changed: 2 additions & 2 deletions
@@ -63,8 +63,8 @@ def _search_time_index(field: Field, time: datetime):
        tau = 1
    else:
        tau = (
-            (time - field.data.time[ti]).dt.total_seconds()
-            / (field.data.time[ti + 1] - field.data.time[ti]).dt.total_seconds()
+            (time - field.data.time[ti]).dt.total_seconds().values
+            / (field.data.time[ti + 1] - field.data.time[ti]).dt.total_seconds().values
            if field.data.time[ti] != field.data.time[ti + 1]
            else 0
        )
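
The added `.values` strips the xarray wrapper so that `tau` ends up as a plain NumPy value rather than a 0-d `DataArray`. A rough standalone sketch of the difference, assuming a hypothetical two-entry time coordinate:

    import numpy as np
    import xarray as xr

    times = xr.DataArray(np.array(["2000-01-01", "2000-01-02"], dtype="datetime64[ns]"), dims="time")
    t = np.datetime64("2000-01-01T12:00", "ns")
    tau = (t - times[0]).dt.total_seconds() / (times[1] - times[0]).dt.total_seconds()
    print(type(tau))         # 0-d xarray.DataArray wrapping 0.5
    print(type(tau.values))  # bare numpy value, no xarray wrapper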

parcels/_reprs.py

Lines changed: 1 addition & 2 deletions
@@ -55,8 +55,7 @@ def particleset_repr(pset: ParticleSet) -> str:
    out = f"""<{type(pset).__name__}>
    fieldset :
{textwrap.indent(repr(pset.fieldset), " " * 8)}
-    pclass : {pset.pclass}
-    repeatdt : {pset.repeatdt}
+    ptype : {pset._ptype}
    # particles: {len(pset)}
    particles : {_format_list_items_multiline(particles, level=2)}
"""

parcels/application_kernels/advection.py

Lines changed: 1 addition & 1 deletion
@@ -113,7 +113,7 @@ def AdvectionRK45(particle, fieldset, time):  # pragma: no cover
    Time-step dt is halved if error is larger than fieldset.RK45_tol,
    and doubled if error is smaller than 1/10th of tolerance.
    """
-    dt = min(particle.next_dt, fieldset.RK45_max_dt) / np.timedelta64(1, "s")  # noqa TODO improve API for converting dt to seconds
+    dt = min(particle.next_dt / np.timedelta64(1, "s"), fieldset.RK45_max_dt)  # noqa TODO improve API for converting dt to seconds
    c = [1.0 / 4.0, 3.0 / 8.0, 12.0 / 13.0, 1.0, 1.0 / 2.0]
    A = [
        [1.0 / 4.0, 0.0, 0.0, 0.0, 0.0],
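
With this reordering, `particle.next_dt` (a `np.timedelta64`) is converted to seconds before the comparison, which implies `fieldset.RK45_max_dt` is treated as a number of seconds. The conversion idiom in isolation, with made-up values:

    import numpy as np

    next_dt = np.timedelta64(90, "s")   # hypothetical particle.next_dt
    RK45_max_dt = 60.0                  # hypothetical maximum step, in seconds
    dt = min(next_dt / np.timedelta64(1, "s"), RK45_max_dt)
    print(dt)                           # 60.0 -> the step is capped at the maximum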

parcels/application_kernels/advectiondiffusion.py

Lines changed: 13 additions & 10 deletions
@@ -24,9 +24,10 @@ def AdvectionDiffusionM1(particle, fieldset, time):  # pragma: no cover
    The Wiener increment `dW` is normally distributed with zero
    mean and a standard deviation of sqrt(dt).
    """
+    dt = particle.dt / np.timedelta64(1, "s")  # noqa TODO improve API for converting dt to seconds
    # Wiener increment with zero mean and std of sqrt(dt)
-    dWx = random.normalvariate(0, math.sqrt(math.fabs(particle.dt)))
-    dWy = random.normalvariate(0, math.sqrt(math.fabs(particle.dt)))
+    dWx = random.normalvariate(0, math.sqrt(math.fabs(dt)))
+    dWy = random.normalvariate(0, math.sqrt(math.fabs(dt)))

    Kxp1 = fieldset.Kh_zonal[time, particle.depth, particle.lat, particle.lon + fieldset.dres]
    Kxm1 = fieldset.Kh_zonal[time, particle.depth, particle.lat, particle.lon - fieldset.dres]
@@ -42,8 +43,8 @@ def AdvectionDiffusionM1(particle, fieldset, time):  # pragma: no cover
    by = math.sqrt(2 * fieldset.Kh_meridional[time, particle.depth, particle.lat, particle.lon])

    # Particle positions are updated only after evaluating all terms.
-    particle_dlon += u * particle.dt + 0.5 * dKdx * (dWx**2 + particle.dt) + bx * dWx  # noqa
-    particle_dlat += v * particle.dt + 0.5 * dKdy * (dWy**2 + particle.dt) + by * dWy  # noqa
+    particle_dlon += u * dt + 0.5 * dKdx * (dWx**2 + dt) + bx * dWx  # noqa
+    particle_dlat += v * dt + 0.5 * dKdy * (dWy**2 + dt) + by * dWy  # noqa


def AdvectionDiffusionEM(particle, fieldset, time):  # pragma: no cover
@@ -59,9 +60,10 @@ def AdvectionDiffusionEM(particle, fieldset, time):  # pragma: no cover
    The Wiener increment `dW` is normally distributed with zero
    mean and a standard deviation of sqrt(dt).
    """
+    dt = particle.dt / np.timedelta64(1, "s")  # noqa TODO improve API for converting dt to seconds
    # Wiener increment with zero mean and std of sqrt(dt)
-    dWx = random.normalvariate(0, math.sqrt(math.fabs(particle.dt)))
-    dWy = random.normalvariate(0, math.sqrt(math.fabs(particle.dt)))
+    dWx = random.normalvariate(0, math.sqrt(math.fabs(dt)))
+    dWy = random.normalvariate(0, math.sqrt(math.fabs(dt)))

    u, v = fieldset.UV[time, particle.depth, particle.lat, particle.lon]

@@ -78,8 +80,8 @@ def AdvectionDiffusionEM(particle, fieldset, time):  # pragma: no cover
    by = math.sqrt(2 * fieldset.Kh_meridional[time, particle.depth, particle.lat, particle.lon])

    # Particle positions are updated only after evaluating all terms.
-    particle_dlon += ax * particle.dt + bx * dWx  # noqa
-    particle_dlat += ay * particle.dt + by * dWy  # noqa
+    particle_dlon += ax * dt + bx * dWx  # noqa
+    particle_dlat += ay * dt + by * dWy  # noqa


def DiffusionUniformKh(particle, fieldset, time):  # pragma: no cover
@@ -100,9 +102,10 @@ def DiffusionUniformKh(particle, fieldset, time):  # pragma: no cover
    The Wiener increment `dW` is normally distributed with zero
    mean and a standard deviation of sqrt(dt).
    """
+    dt = particle.dt / np.timedelta64(1, "s")  # noqa TODO improve API for converting dt to seconds
    # Wiener increment with zero mean and std of sqrt(dt)
-    dWx = random.normalvariate(0, math.sqrt(math.fabs(particle.dt)))
-    dWy = random.normalvariate(0, math.sqrt(math.fabs(particle.dt)))
+    dWx = random.normalvariate(0, math.sqrt(math.fabs(dt)))
+    dWy = random.normalvariate(0, math.sqrt(math.fabs(dt)))

    bx = math.sqrt(2 * fieldset.Kh_zonal[particle])
    by = math.sqrt(2 * fieldset.Kh_meridional[particle])
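
All three kernels now convert `particle.dt` to float seconds before drawing the Wiener increment, so `dW ~ N(0, sqrt(dt))` is sampled in seconds. A self-contained sketch of one Euler-Maruyama step under that convention (all numbers hypothetical):

    import math
    import random

    import numpy as np

    dt = np.timedelta64(600, "s") / np.timedelta64(1, "s")    # particle.dt as float seconds
    dWx = random.normalvariate(0, math.sqrt(math.fabs(dt)))   # Wiener increment, std sqrt(dt)
    ax, bx = 0.1, math.sqrt(2 * 10.0)                         # hypothetical drift and noise amplitude (Kh = 10 m2/s)
    dlon = ax * dt + bx * dWx                                  # zonal displacement for this step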

parcels/application_kernels/interpolation.py

Lines changed: 119 additions & 0 deletions
@@ -7,16 +7,135 @@
import numpy as np

from parcels.field import Field
+from parcels.tools.statuscodes import (
+    FieldOutOfBoundError,
+)

if TYPE_CHECKING:
    from parcels.uxgrid import _UXGRID_AXES
+    from parcels.xgrid import _XGRID_AXES

__all__ = [
    "UXPiecewiseConstantFace",
    "UXPiecewiseLinearNode",
+    "XBiLinear",
+    "XBiLinearPeriodic",
+    "XTriLinear",
]


+def XBiLinear(
+    field: Field,
+    ti: int,
+    position: dict[_XGRID_AXES, tuple[int, float | np.ndarray]],
+    tau: np.float32 | np.float64,
+    t: np.float32 | np.float64,
+    z: np.float32 | np.float64,
+    y: np.float32 | np.float64,
+    x: np.float32 | np.float64,
+):
+    """Bilinear interpolation on a regular grid."""
+    xi, xsi = position["X"]
+    yi, eta = position["Y"]
+    zi, _ = position["Z"]
+
+    data = field.data.data[:, zi, yi : yi + 2, xi : xi + 2]
+    data = (1 - tau) * data[ti, :, :] + tau * data[ti + 1, :, :]
+
+    return (
+        (1 - xsi) * (1 - eta) * data[0, 0]
+        + xsi * (1 - eta) * data[0, 1]
+        + xsi * eta * data[1, 1]
+        + (1 - xsi) * eta * data[1, 0]
+    )
+
+
+def XBiLinearPeriodic(
+    field: Field,
+    ti: int,
+    position: dict[_XGRID_AXES, tuple[int, float | np.ndarray]],
+    tau: np.float32 | np.float64,
+    t: np.float32 | np.float64,
+    z: np.float32 | np.float64,
+    y: np.float32 | np.float64,
+    x: np.float32 | np.float64,
+):
+    """Bilinear interpolation on a regular grid with periodic boundary conditions in horizontal directions."""
+    xi, xsi = position["X"]
+    yi, eta = position["Y"]
+    zi, _ = position["Z"]
+
+    if xi < 0:
+        xi = 0
+        xsi = (x - field.grid.lon[xi]) / (field.grid.lon[xi + 1] - field.grid.lon[xi])
+    if yi < 0:
+        yi = 0
+        eta = (y - field.grid.lat[yi]) / (field.grid.lat[yi + 1] - field.grid.lat[yi])
+
+    data = field.data.data[:, zi, yi : yi + 2, xi : xi + 2]
+    data = (1 - tau) * data[ti, :, :] + tau * data[ti + 1, :, :]
+
+    xsi = 0 if not np.isfinite(xsi) else xsi
+    eta = 0 if not np.isfinite(eta) else eta
+
+    if xsi > 0 and eta > 0:
+        return (
+            (1 - xsi) * (1 - eta) * data[0, 0]
+            + xsi * (1 - eta) * data[0, 1]
+            + xsi * eta * data[1, 1]
+            + (1 - xsi) * eta * data[1, 0]
+        )
+    elif xsi > 0 and eta == 0:
+        return (1 - xsi) * data[0, 0] + xsi * data[0, 1]
+    elif xsi == 0 and eta > 0:
+        return (1 - eta) * data[0, 0] + eta * data[1, 0]
+    else:
+        return data[0, 0]
+
+
+def XTriLinear(
+    field: Field,
+    ti: int,
+    position: dict[_XGRID_AXES, tuple[int, float | np.ndarray]],
+    tau: np.float32 | np.float64,
+    t: np.float32 | np.float64,
+    z: np.float32 | np.float64,
+    y: np.float32 | np.float64,
+    x: np.float32 | np.float64,
+):
+    """Trilinear interpolation on a regular grid."""
+    xi, xsi = position["X"]
+    yi, eta = position["Y"]
+    zi, zeta = position["Z"]
+
+    if zi < 0 or xi < 0 or yi < 0:
+        raise FieldOutOfBoundError
+
+    data = field.data.data[:, zi : zi + 2, yi : yi + 2, xi : xi + 2]
+    data = (1 - tau) * data[ti, :, :, :] + tau * data[ti + 1, :, :, :]
+    if zeta > 0:
+        data = (1 - zeta) * data[0, :, :] + zeta * data[1, :, :]
+    else:
+        data = data[0, :, :]
+
+    xsi = 0 if not np.isfinite(xsi) else xsi
+    eta = 0 if not np.isfinite(eta) else eta
+
+    if xsi > 0 and eta > 0:
+        return (
+            (1 - xsi) * (1 - eta) * data[0, 0]
+            + xsi * (1 - eta) * data[0, 1]
+            + xsi * eta * data[1, 1]
+            + (1 - xsi) * eta * data[1, 0]
+        )
+    elif xsi > 0 and eta == 0:
+        return (1 - xsi) * data[0, 0] + xsi * data[0, 1]
+    elif xsi == 0 and eta > 0:
+        return (1 - eta) * data[0, 0] + eta * data[1, 0]
+    else:
+        return data[0, 0]
+
+
def UXPiecewiseConstantFace(
    field: Field,
    ti: int,
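
The new `XBiLinear` weights are the standard bilinear ones and sum to 1 for any `xsi`, `eta` in [0, 1]. A small self-contained check of the same weight formula on a 2x2 block of made-up corner values:

    import numpy as np

    corners = np.array([[1.0, 2.0],   # data[0, 0], data[0, 1]
                        [3.0, 4.0]])  # data[1, 0], data[1, 1]
    xsi, eta = 0.25, 0.5              # fractional position within the cell
    value = (
        (1 - xsi) * (1 - eta) * corners[0, 0]
        + xsi * (1 - eta) * corners[0, 1]
        + xsi * eta * corners[1, 1]
        + (1 - xsi) * eta * corners[1, 0]
    )
    print(value)  # 2.25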

parcels/kernel.py

Lines changed: 10 additions & 12 deletions
@@ -12,8 +12,6 @@

from parcels.application_kernels.advection import (
    AdvectionAnalytical,
-    AdvectionRK4_3D,
-    AdvectionRK4_3D_CROCO,
    AdvectionRK45,
)
from parcels.basegrid import GridType
@@ -124,9 +122,10 @@ def __init__(
        # Derive meta information from pyfunc, if not given
        self.check_fieldsets_in_kernels(pyfunc)

-        if (pyfunc is AdvectionRK4_3D) and fieldset.U.gridindexingtype == "croco":
-            pyfunc = AdvectionRK4_3D_CROCO
-            self.funcname = "AdvectionRK4_3D_CROCO"
+        # # TODO will be implemented when we support CROCO again
+        # if (pyfunc is AdvectionRK4_3D) and fieldset.U.gridindexingtype == "croco":
+        #     pyfunc = AdvectionRK4_3D_CROCO
+        #     self.funcname = "AdvectionRK4_3D_CROCO"

        if funcvars is not None:  # TODO v4: check if needed from here onwards
            self.funcvars = funcvars
@@ -385,13 +384,12 @@ def evaluate_particle(self, p, endtime):
            return p

        pre_dt = p.dt
-        # TODO implement below later again
-        # try:  # Use next_dt from AdvectionRK45 if it is set
-        #     if abs(endtime - p.time_nextloop) < abs(p.next_dt) - 1e-6:
-        #         p.next_dt = abs(endtime - p.time_nextloop) * sign_dt
-        # except AttributeError:
-        if sign_dt * (endtime - p.time_nextloop) <= p.dt:
-            p.dt = sign_dt * (endtime - p.time_nextloop)
+        try:  # Use next_dt from AdvectionRK45 if it is set
+            if abs(endtime - p.time_nextloop) < abs(p.next_dt) - np.timedelta64(1000, "ns"):
+                p.next_dt = sign_dt * (endtime - p.time_nextloop)
+        except KeyError:
+            if sign_dt * (endtime - p.time_nextloop) <= p.dt:
+                p.dt = sign_dt * (endtime - p.time_nextloop)
        res = self._pyfunc(p, self._fieldset, p.time_nextloop)

        if res is None:
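
The restored `try`/`except` prefers `particle.next_dt` (set by `AdvectionRK45`) and otherwise clamps `p.dt` so the final step lands exactly on `endtime`. The fallback clamp in isolation, with hypothetical numpy datetimes:

    import numpy as np

    sign_dt = 1
    endtime = np.datetime64("2000-01-01T01:00")
    time_nextloop = np.datetime64("2000-01-01T00:50")
    dt = np.timedelta64(15, "m")

    if sign_dt * (endtime - time_nextloop) <= dt:
        dt = sign_dt * (endtime - time_nextloop)  # shrink the final step to the remaining 10 minutes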

parcels/particleset.py

Lines changed: 21 additions & 15 deletions
@@ -109,7 +109,9 @@ def __init__(
        assert lon.size == lat.size and lon.size == depth.size, "lon, lat, depth don't all have the same lenghts"

        if time is None or len(time) == 0:
-            time = np.datetime64("NaT", "ns")  # do not set a time yet (because sign_dt not known)
+            time = type(fieldset.U.data.time[0].values)(
+                "NaT", "ns"
+            )  # do not set a time yet (because sign_dt not known)
        elif type(time[0]) in [np.datetime64, np.timedelta64]:
            pass  # already in the right format
        else:
@@ -156,7 +158,7 @@ def __init__(
            if isinstance(v.initial, attrgetter):
                initial = v.initial(self)
            else:
-                initial = v.initial * np.ones(len(trajectory_ids), dtype=v.dtype)
+                initial = [np.array(v.initial, dtype=v.dtype)] * len(trajectory_ids)
            self._data[v.name] = initial

        # update initial values provided on ParticleSet creation
@@ -843,16 +845,20 @@ def _warn_outputdt_release_desync(outputdt: float, starttime: float, release_tim


def _warn_particle_times_outside_fieldset_time_bounds(release_times: np.ndarray, time: TimeInterval):
-    if np.any(release_times):
-        if np.any(release_times < time.left):
-            warnings.warn(
-                "Some particles are set to be released outside the FieldSet's executable time domain.",
-                ParticleSetWarning,
-                stacklevel=2,
-            )
-        if np.any(release_times > time.right):
-            warnings.warn(
-                "Some particles are set to be released after the fieldset's last time and the fields are not constant in time.",
-                ParticleSetWarning,
-                stacklevel=2,
-            )
+    if np.isnat(release_times).all():
+        return
+
+    if isinstance(time.left, np.datetime64) and isinstance(release_times[0], np.timedelta64):
+        release_times = np.array([t + time.left for t in release_times])
+    if np.any(release_times < time.left):
+        warnings.warn(
+            "Some particles are set to be released outside the FieldSet's executable time domain.",
+            ParticleSetWarning,
+            stacklevel=2,
+        )
+    if np.any(release_times > time.right):
+        warnings.warn(
+            "Some particles are set to be released after the fieldset's last time and the fields are not constant in time.",
+            ParticleSetWarning,
+            stacklevel=2,
+        )
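
The reworked check returns early when all release times are NaT, then converts relative (`timedelta64`) release times to absolute ones by adding the left bound of the fieldset's time interval before comparing. Roughly, with made-up values:

    import numpy as np

    left = np.datetime64("2000-01-01")                            # hypothetical time.left
    right = np.datetime64("2001-01-01")                           # hypothetical time.right
    release_times = np.array([np.timedelta64(0, "D"), np.timedelta64(400, "D")])
    release_times = np.array([t + left for t in release_times])   # now absolute datetime64 values
    print(release_times > right)                                   # [False  True] -> second release triggers the warning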

parcels/xgrid.py

Lines changed: 13 additions & 0 deletions
@@ -9,6 +9,7 @@
from parcels import xgcm
from parcels._index_search import _search_indices_curvilinear_2d
from parcels.basegrid import BaseGrid
+from parcels.tools.statuscodes import FieldOutOfBoundError, FieldOutOfBoundSurfaceError

_XGRID_AXES = Literal["X", "Y", "Z"]
_XGRID_AXES_ORDERING: Sequence[_XGRID_AXES] = "ZYX"
@@ -271,6 +272,15 @@ def search(self, z, y, x, ei=None):
        ds = self.xgcm_grid._ds

        zi, zeta = _search_1d_array(ds.depth.values, z)
+        if zi == -1:
+            if zeta < 0:
+                raise FieldOutOfBoundError(
+                    f"Depth {z} is out of bounds for the grid with depth values {ds.depth.values}."
+                )
+            elif zeta > 1:
+                raise FieldOutOfBoundSurfaceError(
+                    f"Depth {z} is out of the surface for the grid with depth values {ds.depth.values}."
+                )

        if ds.lon.ndim == 1:
            yi, eta = _search_1d_array(ds.lat.values, y)
@@ -453,6 +463,9 @@
    float
        Barycentric coordinate.
    """
+    # TODO v4: We probably rework this to deal with 0D arrays before this point (as we already know field dimensionality)
+    if len(arr) < 2:
+        return 0, 0.0
    i = np.argmin(arr <= x) - 1
    bcoord = (x - arr[i]) / (arr[i + 1] - arr[i])
    return i, bcoord
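
For context, `_search_1d_array` returns the index of the cell whose left edge lies just below `x`, together with the fractional position inside that cell; when `x` falls outside the array on either side the index comes out as -1, which is what the new `zi == -1` guard in `search` catches. A quick sketch of the core on a hypothetical coordinate array:

    import numpy as np

    arr = np.array([0.0, 1.0, 2.0, 3.0])
    x = 1.25
    i = np.argmin(arr <= x) - 1                    # index of first False, minus one -> cell index 1
    bcoord = (x - arr[i]) / (arr[i + 1] - arr[i])  # fractional position 0.25 within that cell
    print(i, bcoord)                               # 1 0.25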
