Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion parcels/particle.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def __init__(self, data: xr.Dataset, index: int):
self._index = index

def __getattr__(self, name):
return self._data[name].values[self._index]
return self._data[name][self._index]

def __setattr__(self, name, value):
if name in ["_data", "_index"]:
Expand Down
75 changes: 42 additions & 33 deletions parcels/particleset.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,36 +135,29 @@ def __init__(
lon.size == kwargs[kwvar].size
), f"{kwvar} and positions (lon, lat, depth) don't have the same lengths."

self._data = xr.Dataset(
{
"lon": (["trajectory"], lon.astype(lonlatdepth_dtype)),
"lat": (["trajectory"], lat.astype(lonlatdepth_dtype)),
"depth": (["trajectory"], depth.astype(lonlatdepth_dtype)),
"time": (["trajectory"], time),
"dt": (["trajectory"], np.timedelta64(1, "ns") * np.ones(len(trajectory_ids))),
"ei": (["trajectory", "ngrid"], np.zeros((len(trajectory_ids), len(fieldset.gridset)), dtype=np.int32)),
"state": (["trajectory"], np.zeros((len(trajectory_ids)), dtype=np.int32)),
"lon_nextloop": (["trajectory"], lon.astype(lonlatdepth_dtype)),
"lat_nextloop": (["trajectory"], lat.astype(lonlatdepth_dtype)),
"depth_nextloop": (["trajectory"], depth.astype(lonlatdepth_dtype)),
"time_nextloop": (["trajectory"], time),
},
coords={
"trajectory": ("trajectory", trajectory_ids),
},
attrs={
"ngrid": len(fieldset.gridset),
"ptype": pclass.getPType(),
},
)
self._data = {
"lon": lon.astype(lonlatdepth_dtype),
"lat": lat.astype(lonlatdepth_dtype),
"depth": depth.astype(lonlatdepth_dtype),
"time": time,
"dt": np.timedelta64(1, "ns") * np.ones(len(trajectory_ids)),
# "ei": (["trajectory", "ngrid"], np.zeros((len(trajectory_ids), len(fieldset.gridset)), dtype=np.int32)),
"state": np.zeros((len(trajectory_ids)), dtype=np.int32),
"lon_nextloop": lon.astype(lonlatdepth_dtype),
"lat_nextloop": lat.astype(lonlatdepth_dtype),
"depth_nextloop": depth.astype(lonlatdepth_dtype),
"time_nextloop": time,
"trajectory": trajectory_ids,
}
self._ptype = pclass.getPType()
# add extra fields from the custom Particle class
for v in pclass.__dict__.values():
if isinstance(v, Variable):
if isinstance(v.initial, attrgetter):
initial = v.initial(self).values
initial = v.initial(self)
else:
initial = v.initial * np.ones(len(trajectory_ids), dtype=v.dtype)
self._data[v.name] = (["trajectory"], initial)
self._data[v.name] = initial

# update initial values provided on ParticleSet creation
for kwvar, kwval in kwargs.items():
Expand Down Expand Up @@ -238,13 +231,28 @@ def add(self, particles):
The current ParticleSet

"""
assert (
particles is not None
), f"Trying to add another {type(self)} to this one, but the other one is None - invalid operation."
assert type(particles) is type(self)

if len(particles) == 0:
return

if len(self) == 0:
self._data = particles._data
return

if isinstance(particles, type(self)):
if len(self._data["trajectory"]) > 0:
offset = self._data["trajectory"].values.max() + 1
offset = self._data["trajectory"].max() + 1
else:
offset = 0
particles._data["trajectory"] = particles._data["trajectory"].values + offset
self._data = xr.concat([self._data, particles._data], dim="trajectory")
particles._data["trajectory"] = particles._data["trajectory"] + offset

for d in self._data:
self._data[d] = np.concatenate((self._data[d], particles._data[d]))

# Adding particles invalidates the neighbor search structure.
self._dirty_neighbor = True
return self
Expand All @@ -270,7 +278,8 @@ def __iadd__(self, particles):

def remove_indices(self, indices):
"""Method to remove particles from the ParticleSet, based on their `indices`."""
self._data = self._data.drop_sel(trajectory=indices)
for d in self._data:
self._data[d] = np.delete(self._data[d], indices, axis=0)

def _active_particles_mask(self, time, dt):
active_indices = (time - self._data["time"]) / dt >= 0
Expand Down Expand Up @@ -591,19 +600,19 @@ def Kernel(self, pyfunc):
if isinstance(pyfunc, list):
return Kernel.from_list(
self.fieldset,
self._data.ptype,
self._ptype,
pyfunc,
)
return Kernel(
self.fieldset,
self._data.ptype,
self._ptype,
pyfunc=pyfunc,
)

def InteractionKernel(self, pyfunc_inter):
if pyfunc_inter is None:
return None
return InteractionKernel(self.fieldset, self._data.ptype, pyfunc=pyfunc_inter)
return InteractionKernel(self.fieldset, self._ptype, pyfunc=pyfunc_inter)

def ParticleFile(self, *args, **kwargs):
"""Wrapper method to initialise a :class:`parcels.particlefile.ParticleFile` object from the ParticleSet."""
Expand Down Expand Up @@ -747,9 +756,9 @@ def execute(
else:
if not np.isnat(self._data["time_nextloop"]).any():
if sign_dt > 0:
start_time = self._data["time_nextloop"].min().values
start_time = self._data["time_nextloop"].min()
else:
start_time = self._data["time_nextloop"].max().values
start_time = self._data["time_nextloop"].max()
else:
if sign_dt > 0:
start_time = self.fieldset.time_interval.left
Expand Down
2 changes: 1 addition & 1 deletion tests/v4/test_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_unknown_var_in_kernel(fieldset):
def ErrorKernel(particle, fieldset, time): # pragma: no cover
particle.unknown_varname += 0.2

with pytest.raises(KeyError, match="No variable named 'unknown_varname'"):
with pytest.raises(KeyError, match="'unknown_varname'"):
pset.execute(ErrorKernel, runtime=np.timedelta64(2, "s"))


Expand Down
6 changes: 3 additions & 3 deletions tests/v4/test_particleset.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,17 +197,17 @@ def test_pset_add_explicit(fieldset):
assert len(pset) == npart
assert np.allclose([p.lon for p in pset], lon, atol=1e-12)
assert np.allclose([p.lat for p in pset], lat, atol=1e-12)
assert np.allclose(np.diff(pset._data.trajectory), np.ones(pset._data.trajectory.size - 1), atol=1e-12)
assert np.allclose(np.diff(pset._data["trajectory"]), np.ones(pset._data["trajectory"].size - 1), atol=1e-12)


def test_pset_add_implicit(fieldset):
pset = ParticleSet(fieldset, lon=np.zeros(3), lat=np.ones(3), pclass=Particle)
pset += ParticleSet(fieldset, lon=np.ones(4), lat=np.zeros(4), pclass=Particle)
assert len(pset) == 7
assert np.allclose(np.diff(pset._data.trajectory), np.ones(6), atol=1e-12)
assert np.allclose(np.diff(pset._data["trajectory"]), np.ones(6), atol=1e-12)


def test_pset_add_implicit(fieldset, npart=10):
def test_pset_add_implicit_in_loop(fieldset, npart=10):
pset = ParticleSet(fieldset, lon=[], lat=[])
for _ in range(npart):
pset += ParticleSet(pclass=Particle, lon=0.1, lat=0.1, fieldset=fieldset)
Expand Down
Loading