35 changes: 4 additions & 31 deletions parcels/_index_search.py
@@ -15,7 +15,6 @@
     _raise_field_out_of_bound_error,
     _raise_field_out_of_bound_surface_error,
     _raise_field_sampling_error,
-    _raise_time_extrapolation_error,
 )
 
 from .basegrid import GridType
@@ -38,36 +37,10 @@ def _search_time_index(field: Field, time: datetime):
     Note that we normalize to either the first or the last index
     if the sampled value is outside the time value range.
     """
-    if field.time_interval is None:
-        return 0, 0
-
-    if time not in field.time_interval:
-        _raise_time_extrapolation_error(time, field=None)
-
-    time_index = field.data.time <= time
-
-    if time_index.all():
-        # If given time > last known field time, use
-        # the last field frame without interpolation
-        ti = len(field.data.time) - 1
-
-    elif np.logical_not(time_index).all():
-        # If given time < any time in the field, use
-        # the first field frame without interpolation
-        ti = 0
-    else:
-        ti = int(time_index.argmin() - 1) if time_index.any() else 0
-    if len(field.data.time) == 1:
-        tau = 0
-    elif ti == len(field.data.time) - 1:
-        tau = 1
-    else:
-        tau = (
-            (time - field.data.time[ti]).dt.total_seconds()
-            / (field.data.time[ti + 1] - field.data.time[ti]).dt.total_seconds()
-            if field.data.time[ti] != field.data.time[ti + 1]
-            else 0
-        )
+    ti = np.argmin(field._time_float <= time) - 1
+    tau = (time - field._time_float[ti]) / (field._time_float[ti + 1] - field._time_float[ti])
+    if tau < 0 or tau > 1:  # TODO only for debugging; test can go?
+        raise ValueError(f"Time {time} is out of bounds for field time data {field.data.time.data}.")
     return tau, ti
 
 
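
For reference, a minimal standalone sketch of the new search logic (the `search_time_index` name and the hourly frame times are hypothetical; the real code operates on `field._time_float`, the field's frame times expressed as float seconds since the start of its time interval):

import numpy as np

def search_time_index(times: np.ndarray, t: float) -> tuple[float, int]:
    """Return (tau, ti) with times[ti] <= t <= times[ti + 1], where tau
    is the linear interpolation weight between the two frames."""
    # (times <= t) is True up to the last frame at or before t; argmin
    # returns the first False entry (the first frame strictly after t),
    # so subtracting 1 lands on the bracketing frame.
    ti = int(np.argmin(times <= t) - 1)
    tau = (t - times[ti]) / (times[ti + 1] - times[ti])
    if tau < 0 or tau > 1:
        raise ValueError(f"Time {t} is out of bounds for {times}.")
    return tau, ti

times = np.array([0.0, 3600.0, 7200.0])  # hourly frames, in seconds
print(search_time_index(times, 5400.0))  # (0.5, 1)
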
2 changes: 1 addition & 1 deletion parcels/application_kernels/advection.py
@@ -97,7 +97,7 @@ def AdvectionRK4_3D_CROCO(particle, fieldset, time):  # pragma: no cover
 
 def AdvectionEE(particle, fieldset, time):  # pragma: no cover
     """Advection of particles using Explicit Euler (aka Euler Forward) integration."""
-    dt = particle.dt / np.timedelta64(1, "s")  # noqa TODO improve API for converting dt to seconds
+    dt = particle.dt
     (u1, v1) = fieldset.UV[particle]
     particle_dlon += u1 * dt  # noqa
     particle_dlat += v1 * dt  # noqa
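
The kernel now assumes `particle.dt` is already a plain float in seconds, so no conversion is needed. A worked instance of one Euler-forward step under that assumption (all numbers hypothetical):

# u1, v1 as sampled from fieldset.UV and already converted to degrees/s
u1, v1 = 0.5e-5, -0.25e-5  # degrees/s
dt = 60.0                  # seconds, a plain float
dlon = u1 * dt             # 3.0e-4 degrees eastward
dlat = v1 * dt             # -1.5e-4 degrees northward
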
41 changes: 25 additions & 16 deletions parcels/field.py
@@ -150,7 +150,7 @@ def __init__(
             data = _transpose_xfield_data_to_tzyx(data, grid.xgcm_grid)
 
         self.name = name
-        self.data = data
+        self.data_full = data
         self.grid = grid
 
         try:
@@ -189,8 +189,8 @@ def __init__(
         else:
             raise ValueError("Unsupported mesh type in data array attributes. Choose either: 'spherical' or 'flat'")
 
-        if self.data.shape[0] > 1:
-            if "time" not in self.data.coords:
+        if data.shape[0] > 1:
+            if "time" not in data.coords:
                 raise ValueError("Field data is missing a 'time' coordinate.")
 
     @property
@@ -205,27 +205,27 @@ def units(self, value):
 
     @property
     def xdim(self):
-        if type(self.data) is xr.DataArray:
+        if type(self.data_full) is xr.DataArray:
             return self.grid.xdim
         else:
             raise NotImplementedError("xdim not implemented for unstructured grids")
 
     @property
     def ydim(self):
-        if type(self.data) is xr.DataArray:
+        if type(self.data_full) is xr.DataArray:
             return self.grid.ydim
         else:
             raise NotImplementedError("ydim not implemented for unstructured grids")
 
     @property
     def zdim(self):
-        if type(self.data) is xr.DataArray:
+        if type(self.data_full) is xr.DataArray:
             return self.grid.zdim
         else:
-            if "nz1" in self.data.dims:
-                return self.data.sizes["nz1"]
-            elif "nz" in self.data.dims:
-                return self.data.sizes["nz"]
+            if "nz1" in self.data_full.dims:
+                return self.data_full.sizes["nz1"]
+            elif "nz" in self.data_full.dims:
+                return self.data_full.sizes["nz"]
             else:
                 return 0
 
@@ -246,6 +246,19 @@ def _check_velocitysampling(self):
                 stacklevel=2,
             )
 
+    def _load_timesteps(self, time):
+        """Load the appropriate timesteps of a field."""
+        ti = np.argmin(self._time_float <= time) - 1  # TODO also implement dt < 0
+        if not hasattr(self, "data"):
+            self.data = self.data_full.isel({"time": slice(ti, ti + 2)}).load()
+        elif self.data_full.time.data[ti] == self.data.time.data[1]:
+            self.data = xr.concat([self.data[1, :], self.data_full.isel({"time": ti + 1}).load()], dim="time")
+        elif self.data_full.time.data[ti] != self.data.time.data[0]:
+            self.data = self.data_full.isel({"time": slice(ti, ti + 2)}).load()
+        assert (
+            len(self.data.time) == 2
+        ), f"Field {self.name} has not been loaded correctly. Expected 2 timesteps, but got {len(self.data.time)}."
+
     def eval(self, time: datetime, z, y, x, particle=None, applyConversion=True):
         """Interpolate field values in space and time.
 
@@ -266,17 +279,14 @@ def eval(self, time: datetime, z, y, x, particle=None, applyConversion=True):
             if np.isnan(value):
                 # Detect Out-of-bounds sampling and raise exception
                 _raise_field_out_of_bound_error(z, y, x)
-            else:
-                return value
 
         except (FieldSamplingError, FieldOutOfBoundError, FieldOutOfBoundSurfaceError) as e:
             e.add_note(f"Error interpolating field '{self.name}'.")
             raise e
 
         if applyConversion:
-            return self.units.to_target(value, z, y, x)
-        else:
-            return value
+            value = self.units.to_target(value, z, y, x)
+        return value
 
     def __getitem__(self, key):
         self._check_velocitysampling()
@@ -359,7 +369,6 @@ def eval(self, time: datetime, z, y, x, particle=None, applyConversion=True):
         else:
            (u, v, w) = self._vector_interp_method(self, ti, position, time, z, y, x)
 
-        # print(u,v)
         if applyConversion:
             u = self.U.units.to_target(u, z, y, x)
             v = self.V.units.to_target(v, z, y, x)
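
A toy sketch of the two-frame sliding window that `_load_timesteps` maintains in `self.data` (the `load_window` helper and the five-frame array are hypothetical; unlike the method above, this naive version reloads both frames instead of reusing the already-loaded one):

import numpy as np
import xarray as xr

data_full = xr.DataArray(
    np.arange(5 * 2 * 2, dtype=float).reshape(5, 2, 2),
    dims=("time", "lat", "lon"),
    coords={"time": np.arange(5.0)},  # float seconds, like _time_float
)

def load_window(data_full: xr.DataArray, t: float) -> xr.DataArray:
    """Return the two frames bracketing time t."""
    ti = int(np.argmin(data_full.time.data <= t) - 1)
    return data_full.isel(time=slice(ti, ti + 2)).load()

window = load_window(data_full, 2.5)
assert len(window.time) == 2 and window.time.data[0] == 2.0
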
7 changes: 7 additions & 0 deletions parcels/fieldset.py
@@ -82,6 +82,13 @@ def time_interval(self):
             return None
         return functools.reduce(lambda x, y: x.intersection(y), time_intervals)
 
+    def _load_timesteps(self, time):
+        """Load the appropriate timesteps of all fields in the fieldset."""
+        for fldname in self.fields:
+            field = self.fields[fldname]
+            if isinstance(field, Field):
+                field._load_timesteps(time)
+
     def add_field(self, field: Field, name: str | None = None):
         """Add a :class:`parcels.field.Field` object to the FieldSet.
 
2 changes: 1 addition & 1 deletion parcels/kernel.py
@@ -309,7 +309,7 @@ def execute(self, pset, endtime, dt):
         """Execute this Kernel over a ParticleSet for several timesteps."""
         pset._data["state"][:] = StatusCode.Evaluate
 
-        if abs(dt) < np.timedelta64(1000, "ns"):  # TODO still needed?
+        if abs(dt) < 1e-6:
             warnings.warn(
                 "'dt' is too small, causing numerical accuracy limit problems. Please chose a higher 'dt' and rather scale the 'time' axis of the field accordingly. (related issue #762)",
                 RuntimeWarning,
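
Since `dt` now arrives in float seconds, the new threshold is the same quantity as the old one:

import numpy as np

# 1000 ns expressed in seconds is exactly the new 1e-6 threshold.
assert np.timedelta64(1000, "ns") / np.timedelta64(1, "s") == 1e-6
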
44 changes: 31 additions & 13 deletions parcels/particleset.py
@@ -12,6 +12,7 @@
 from parcels._reprs import particleset_repr
 from parcels.application_kernels.advection import AdvectionRK4
 from parcels.basegrid import GridType
+from parcels.field import Field
 from parcels.interaction.interactionkernel import InteractionKernel
 from parcels.kernel import Kernel
 from parcels.particle import Particle, Variable
@@ -109,9 +110,11 @@ def __init__(
         assert lon.size == lat.size and lon.size == depth.size, "lon, lat, depth don't all have the same lenghts"
 
         if time is None or len(time) == 0:
-            time = np.datetime64("NaT", "ns")  # do not set a time yet (because sign_dt not known)
-        elif type(time[0]) in [np.datetime64, np.timedelta64]:
-            pass  # already in the right format
+            time = np.array([np.nan])  # do not set a time yet (because sign_dt not known)
+        elif type(time[0]) is np.datetime64:
+            time = time - self.fieldset.time_interval.left
+        elif type(time[0]) is np.timedelta64:
+            time = time / np.timedelta64(1, "s")
         else:
             raise TypeError("particle time must be a datetime, timedelta, or date object")
         time = np.repeat(time, lon.size) if time.size == 1 else time
@@ -140,7 +143,7 @@
                 "lat": lat.astype(lonlatdepth_dtype),
                 "depth": depth.astype(lonlatdepth_dtype),
                 "time": time,
-                "dt": np.timedelta64(1, "ns") * np.ones(len(trajectory_ids)),
+                "dt": np.ones(len(trajectory_ids), dtype=np.float64),
                 # "ei": (["trajectory", "ngrid"], np.zeros((len(trajectory_ids), len(fieldset.gridset)), dtype=np.int32)),
                 "state": np.zeros((len(trajectory_ids)), dtype=np.int32),
                 "lon_nextloop": lon.astype(lonlatdepth_dtype),
@@ -736,7 +739,10 @@ def execute(
         if (dt is not None) and (not isinstance(dt, np.timedelta64)):
             raise TypeError("dt must be a np.timedelta64 object")
         if dt is None or np.isnat(dt):
-            dt = np.timedelta64(1, "s")
+            dt = 1
+        else:
+            dt = dt / np.timedelta64(1, "s")
+
         self._data["dt"][:] = dt
         sign_dt = np.sign(dt).astype(int)
         if sign_dt not in [-1, 1]:
@@ -754,7 +760,7 @@ def execute(
                 raise TypeError("The runtime must be a np.timedelta64 object")
 
         else:
-            if not np.isnat(self._data["time_nextloop"]).any():
+            if not np.isnan(self._data["time_nextloop"]).any():
                 if sign_dt > 0:
                     start_time = self._data["time_nextloop"].min()
                 else:
@@ -787,8 +793,11 @@ def execute(
         else:
             end_time = start_time + runtime * sign_dt
 
+        start_time = (start_time - self.fieldset.time_interval.left) / np.timedelta64(1, "s")
+        end_time = (end_time - self.fieldset.time_interval.left) / np.timedelta64(1, "s")
+
         # Set the time of the particles if it hadn't been set on initialisation
-        if np.isnat(self._data["time"]).any():
+        if np.isnan(self._data["time"]).any():
             self._data["time"][:] = start_time
             self._data["time_nextloop"][:] = start_time
 
@@ -799,16 +808,25 @@ def execute(
             logger.info(f"Output files are stored in {output_file.fname}.")
 
         if verbose_progress:
-            pbar = tqdm(total=(end_time - start_time) / np.timedelta64(1, "s"), file=sys.stdout)
+            pbar = tqdm(total=(end_time - start_time), file=sys.stdout)
 
         next_output = outputdt if output_file else None
 
         time = start_time
 
+        for fldname in self.fieldset.fields:
+            field = self.fieldset.fields[fldname]
+            if isinstance(field, Field):
+                field._time_float = (field.data_full.time.data - field.time_interval.left) / np.timedelta64(1, "s")
+
         while sign_dt * (time - end_time) < 0:
+            # Load the appropriate timesteps of the fieldset
+            self.fieldset._load_timesteps(self._data["time_nextloop"][0])
+
             if sign_dt > 0:
-                next_time = end_time  # TODO update to min(next_output, end_time) when ParticleFile works
+                next_time = min(time + dt, end_time)
             else:
-                next_time = end_time  # TODO update to max(next_output, end_time) when ParticleFile works
+                next_time = max(time - dt, end_time)
             res = self._kernel.execute(self, endtime=next_time, dt=dt)
             if res == StatusCode.StopAllExecution:
                 return StatusCode.StopAllExecution
@@ -822,7 +840,7 @@ def execute(
                 next_output += outputdt
 
             if verbose_progress:
-                pbar.update((next_time - time) / np.timedelta64(1, "s"))
+                pbar.update(next_time - time)
 
             time = next_time
 
@@ -844,13 +862,13 @@ def _warn_outputdt_release_desync(outputdt: float, starttime: float, release_tim
 
 def _warn_particle_times_outside_fieldset_time_bounds(release_times: np.ndarray, time: TimeInterval):
     if np.any(release_times):
-        if np.any(release_times < time.left):
+        if np.any(release_times < 0):
             warnings.warn(
                 "Some particles are set to be released outside the FieldSet's executable time domain.",
                 ParticleSetWarning,
                 stacklevel=2,
             )
-        if np.any(release_times > time.right):
+        if np.any(release_times > (time.right - time.left) / np.timedelta64(1, "s")):
             warnings.warn(
                 "Some particles are set to be released after the fieldset's last time and the fields are not constant in time.",
                 ParticleSetWarning,
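
A small sketch of the normalisation this file now applies throughout: absolute datetimes become float seconds relative to the left edge of the fieldset's time interval (the dates below are hypothetical):

import numpy as np

left = np.datetime64("2000-01-01T00:00:00", "ns")  # fieldset.time_interval.left
release = np.array([np.datetime64("2000-01-01T06:00:00", "ns")])

time_float = (release - left) / np.timedelta64(1, "s")
print(time_float)  # [21600.] -> seconds since the fieldset's start
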
10 changes: 7 additions & 3 deletions parcels/xgrid.py
@@ -270,11 +270,11 @@ def _gtype(self):
     def search(self, z, y, x, ei=None):
         ds = self.xgcm_grid._ds
 
-        zi, zeta = _search_1d_array(ds.depth.values, z)
+        zi, zeta = _search_1d_array(ds.depth.data, z)
 
         if ds.lon.ndim == 1:
-            yi, eta = _search_1d_array(ds.lat.values, y)
-            xi, xsi = _search_1d_array(ds.lon.values, x)
+            yi, eta = _search_1d_array(ds.lat.data, y)
+            xi, xsi = _search_1d_array(ds.lon.data, x)
             return {"Z": (zi, zeta), "Y": (yi, eta), "X": (xi, xsi)}
 
         yi, xi = None, None
@@ -453,8 +453,12 @@ def _search_1d_array(
     float
         Barycentric coordinate.
     """
+    if len(arr) < 2:
+        return 0, 0.0
     i = np.argmin(arr <= x) - 1
     bcoord = (x - arr[i]) / (arr[i + 1] - arr[i])
+    if bcoord < 0 or bcoord > 1:  # TODO only for debugging; test can go?
+        raise ValueError(f"Position {x} is out of bounds for array {arr}.")
     return i, bcoord
 
 
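
A minimal usage sketch of the updated `_search_1d_array` logic on an ascending node array (values hypothetical):

import numpy as np

arr = np.array([0.0, 1.0, 2.0, 4.0])  # ascending nodes
x = 3.0
i = int(np.argmin(arr <= x) - 1)               # last node at or below x -> 2
bcoord = (x - arr[i]) / (arr[i + 1] - arr[i])  # (3 - 2) / (4 - 2) = 0.5
print(i, bcoord)  # 2 0.5
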