
Commit 6baf346

Merge pull request #2094 from OceanParcels/particledata_as_dict
Performance: Using a dict for ParticleSet._data (instead of xarray DataSet)
2 parents e576d39 + 0158fe4 commit 6baf346
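
For context: the change swaps the ParticleSet's internal `xr.Dataset` for a plain dict of NumPy arrays, so hot-path lookups such as `self._data["lon"]` avoid xarray's variable and coordinate machinery. The snippet below is a rough, hypothetical micro-benchmark of that difference (array size, field names and timings are illustrative only, not taken from the PR):

    # Hypothetical micro-benchmark (not from the PR): per-element access cost of an
    # xarray Dataset versus a plain dict of NumPy arrays.
    import timeit

    import numpy as np
    import xarray as xr

    n = 10_000
    lon = np.random.rand(n)
    ds = xr.Dataset({"lon": (["trajectory"], lon)}, coords={"trajectory": np.arange(n)})
    dct = {"lon": lon, "trajectory": np.arange(n)}

    t_ds = timeit.timeit(lambda: ds["lon"].values[42], number=10_000)
    t_dct = timeit.timeit(lambda: dct["lon"][42], number=10_000)
    print(f"Dataset access: {t_ds:.3f}s, dict-of-arrays access: {t_dct:.3f}s")

Most of the per-access cost in the xarray case comes from building the DataArray wrapper on every lookup; the dict lookup returns the underlying ndarray directly.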

4 files changed: 47 additions & 38 deletions


parcels/particle.py

Lines changed: 1 addition & 1 deletion
@@ -115,7 +115,7 @@ def __init__(self, data: xr.Dataset, index: int):
         self._index = index
 
     def __getattr__(self, name):
-        return self._data[name].values[self._index]
+        return self._data[name][self._index]
 
     def __setattr__(self, name, value):
         if name in ["_data", "_index"]:
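
With `_data` now a dict of NumPy arrays, the particle view no longer needs xarray's `.values` accessor: `self._data[name]` already is an ndarray, so indexing it directly yields the value for this particle. A stripped-down sketch of the pattern (a hypothetical `ParticleView` class, not the actual `parcels` implementation):

    import numpy as np

    class ParticleView:
        """Hypothetical view of one row in a dict-of-arrays particle store."""

        def __init__(self, data: dict, index: int):
            # bypass __setattr__ for the two bookkeeping attributes
            object.__setattr__(self, "_data", data)
            object.__setattr__(self, "_index", index)

        def __getattr__(self, name):
            # data[name] is already an ndarray, so no .values indirection is needed
            return self._data[name][self._index]

        def __setattr__(self, name, value):
            self._data[name][self._index] = value

    data = {"lon": np.array([0.0, 1.0]), "lat": np.array([2.0, 3.0])}
    p = ParticleView(data, 1)
    p.lon = 5.0
    print(p.lon, data["lon"])  # 5.0 [0. 5.]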

parcels/particleset.py

Lines changed: 42 additions & 33 deletions
@@ -135,36 +135,29 @@ def __init__(
                 lon.size == kwargs[kwvar].size
             ), f"{kwvar} and positions (lon, lat, depth) don't have the same lengths."
 
-        self._data = xr.Dataset(
-            {
-                "lon": (["trajectory"], lon.astype(lonlatdepth_dtype)),
-                "lat": (["trajectory"], lat.astype(lonlatdepth_dtype)),
-                "depth": (["trajectory"], depth.astype(lonlatdepth_dtype)),
-                "time": (["trajectory"], time),
-                "dt": (["trajectory"], np.timedelta64(1, "ns") * np.ones(len(trajectory_ids))),
-                "ei": (["trajectory", "ngrid"], np.zeros((len(trajectory_ids), len(fieldset.gridset)), dtype=np.int32)),
-                "state": (["trajectory"], np.zeros((len(trajectory_ids)), dtype=np.int32)),
-                "lon_nextloop": (["trajectory"], lon.astype(lonlatdepth_dtype)),
-                "lat_nextloop": (["trajectory"], lat.astype(lonlatdepth_dtype)),
-                "depth_nextloop": (["trajectory"], depth.astype(lonlatdepth_dtype)),
-                "time_nextloop": (["trajectory"], time),
-            },
-            coords={
-                "trajectory": ("trajectory", trajectory_ids),
-            },
-            attrs={
-                "ngrid": len(fieldset.gridset),
-                "ptype": pclass.getPType(),
-            },
-        )
+        self._data = {
+            "lon": lon.astype(lonlatdepth_dtype),
+            "lat": lat.astype(lonlatdepth_dtype),
+            "depth": depth.astype(lonlatdepth_dtype),
+            "time": time,
+            "dt": np.timedelta64(1, "ns") * np.ones(len(trajectory_ids)),
+            # "ei": (["trajectory", "ngrid"], np.zeros((len(trajectory_ids), len(fieldset.gridset)), dtype=np.int32)),
+            "state": np.zeros((len(trajectory_ids)), dtype=np.int32),
+            "lon_nextloop": lon.astype(lonlatdepth_dtype),
+            "lat_nextloop": lat.astype(lonlatdepth_dtype),
+            "depth_nextloop": depth.astype(lonlatdepth_dtype),
+            "time_nextloop": time,
+            "trajectory": trajectory_ids,
+        }
+        self._ptype = pclass.getPType()
         # add extra fields from the custom Particle class
         for v in pclass.__dict__.values():
             if isinstance(v, Variable):
                 if isinstance(v.initial, attrgetter):
-                    initial = v.initial(self).values
+                    initial = v.initial(self)
                 else:
                     initial = v.initial * np.ones(len(trajectory_ids), dtype=v.dtype)
-                self._data[v.name] = (["trajectory"], initial)
+                self._data[v.name] = initial
 
         # update initial values provided on ParticleSet creation
         for kwvar, kwval in kwargs.items():
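
The Dataset's data variables and its `trajectory` coordinate collapse into one flat dict of equal-length arrays, the `ptype` attribute moves to a separate `self._ptype`, and the `"ei"` entry is left commented out in the diff. A rough sketch of building such a structure-of-arrays store (hypothetical `make_store` helper, illustrative field names only):

    import numpy as np

    def make_store(lon, lat, trajectory_ids, dtype=np.float32):
        """Hypothetical structure-of-arrays store: one equal-length array per field."""
        n = len(trajectory_ids)
        return {
            "lon": np.asarray(lon, dtype=dtype),
            "lat": np.asarray(lat, dtype=dtype),
            "state": np.zeros(n, dtype=np.int32),
            "trajectory": np.asarray(trajectory_ids),
        }

    store = make_store([0.0, 1.0, 2.0], [10.0, 11.0, 12.0], np.arange(3))
    store["lon"] += 0.5  # column-wise updates stay fully vectorised
    assert all(len(v) == 3 for v in store.values())
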
@@ -238,13 +231,28 @@ def add(self, particles):
         The current ParticleSet
 
         """
+        assert (
+            particles is not None
+        ), f"Trying to add another {type(self)} to this one, but the other one is None - invalid operation."
+        assert type(particles) is type(self)
+
+        if len(particles) == 0:
+            return
+
+        if len(self) == 0:
+            self._data = particles._data
+            return
+
         if isinstance(particles, type(self)):
             if len(self._data["trajectory"]) > 0:
-                offset = self._data["trajectory"].values.max() + 1
+                offset = self._data["trajectory"].max() + 1
             else:
                 offset = 0
-            particles._data["trajectory"] = particles._data["trajectory"].values + offset
-            self._data = xr.concat([self._data, particles._data], dim="trajectory")
+            particles._data["trajectory"] = particles._data["trajectory"] + offset
+
+            for d in self._data:
+                self._data[d] = np.concatenate((self._data[d], particles._data[d]))
+
         # Adding particles invalidates the neighbor search structure.
         self._dirty_neighbor = True
         return self
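
`xr.concat` along the `trajectory` dimension becomes a key-by-key `np.concatenate`, after offsetting the incoming trajectory IDs to keep them unique; this relies on both sets holding exactly the same keys. A hypothetical standalone version of the same pattern:

    import numpy as np

    def merge_stores(a: dict, b: dict) -> dict:
        """Hypothetical merge of two dict-of-arrays stores (identical keys assumed)."""
        b = dict(b)  # avoid mutating the caller's dict
        offset = a["trajectory"].max() + 1 if len(a["trajectory"]) else 0
        b["trajectory"] = b["trajectory"] + offset
        return {k: np.concatenate((a[k], b[k])) for k in a}

    a = {"lon": np.array([0.0, 1.0]), "trajectory": np.array([0, 1])}
    b = {"lon": np.array([2.0]), "trajectory": np.array([0])}
    print(merge_stores(a, b)["trajectory"])  # [0 1 2]
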
@@ -270,7 +278,8 @@ def __iadd__(self, particles):
 
     def remove_indices(self, indices):
         """Method to remove particles from the ParticleSet, based on their `indices`."""
-        self._data = self._data.drop_sel(trajectory=indices)
+        for d in self._data:
+            self._data[d] = np.delete(self._data[d], indices, axis=0)
 
     def _active_particles_mask(self, time, dt):
         active_indices = (time - self._data["time"]) / dt >= 0
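
One semantic detail worth noting: xarray's `drop_sel(trajectory=indices)` removed rows by trajectory label, whereas `np.delete(..., indices, axis=0)` removes rows by position, so the two only coincide while positions and trajectory IDs line up. A small illustration of the positional behaviour (toy data, not from the PR):

    import numpy as np

    store = {"lon": np.array([0.0, 1.0, 2.0, 3.0]), "trajectory": np.array([10, 11, 12, 13])}
    for k in store:
        store[k] = np.delete(store[k], [1, 3], axis=0)  # drops positions 1 and 3
    print(store["trajectory"])  # [10 12]
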
@@ -591,19 +600,19 @@ def Kernel(self, pyfunc):
         if isinstance(pyfunc, list):
             return Kernel.from_list(
                 self.fieldset,
-                self._data.ptype,
+                self._ptype,
                 pyfunc,
             )
         return Kernel(
             self.fieldset,
-            self._data.ptype,
+            self._ptype,
             pyfunc=pyfunc,
         )
 
     def InteractionKernel(self, pyfunc_inter):
         if pyfunc_inter is None:
             return None
-        return InteractionKernel(self.fieldset, self._data.ptype, pyfunc=pyfunc_inter)
+        return InteractionKernel(self.fieldset, self._ptype, pyfunc=pyfunc_inter)
 
     def ParticleFile(self, *args, **kwargs):
         """Wrapper method to initialise a :class:`parcels.particlefile.ParticleFile` object from the ParticleSet."""
@@ -747,9 +756,9 @@ def execute(
         else:
             if not np.isnat(self._data["time_nextloop"]).any():
                 if sign_dt > 0:
-                    start_time = self._data["time_nextloop"].min().values
+                    start_time = self._data["time_nextloop"].min()
                 else:
-                    start_time = self._data["time_nextloop"].max().values
+                    start_time = self._data["time_nextloop"].max()
             else:
                 if sign_dt > 0:
                     start_time = self.fieldset.time_interval.left
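
In `execute`, the `.values` accessor after `min()`/`max()` is no longer needed because reducing a plain NumPy `datetime64` array already returns a NumPy scalar rather than a zero-dimensional DataArray. For example:

    import numpy as np

    times = np.array(["2020-01-01", "2020-01-03"], dtype="datetime64[ns]")
    print(times.min(), type(times.min()))  # 2020-01-01T00:00:00.000000000 <class 'numpy.datetime64'>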

tests/v4/test_kernel.py

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@ def test_unknown_var_in_kernel(fieldset):
     def ErrorKernel(particle, fieldset, time):  # pragma: no cover
         particle.unknown_varname += 0.2
 
-    with pytest.raises(KeyError, match="No variable named 'unknown_varname'"):
+    with pytest.raises(KeyError, match="'unknown_varname'"):
         pset.execute(ErrorKernel, runtime=np.timedelta64(2, "s"))
 
 
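The expected message changes because a plain dict raises a bare `KeyError: 'unknown_varname'`, whereas the xarray Dataset raised "No variable named 'unknown_varname'". The dict behaviour in isolation (minimal sketch, assuming pytest is available):

    import pytest

    data = {"lon": 0.0}
    with pytest.raises(KeyError, match="'unknown_varname'"):
        data["unknown_varname"]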

tests/v4/test_particleset.py

Lines changed: 3 additions & 3 deletions
@@ -197,17 +197,17 @@ def test_pset_add_explicit(fieldset):
     assert len(pset) == npart
     assert np.allclose([p.lon for p in pset], lon, atol=1e-12)
     assert np.allclose([p.lat for p in pset], lat, atol=1e-12)
-    assert np.allclose(np.diff(pset._data.trajectory), np.ones(pset._data.trajectory.size - 1), atol=1e-12)
+    assert np.allclose(np.diff(pset._data["trajectory"]), np.ones(pset._data["trajectory"].size - 1), atol=1e-12)
 
 
 def test_pset_add_implicit(fieldset):
     pset = ParticleSet(fieldset, lon=np.zeros(3), lat=np.ones(3), pclass=Particle)
     pset += ParticleSet(fieldset, lon=np.ones(4), lat=np.zeros(4), pclass=Particle)
     assert len(pset) == 7
-    assert np.allclose(np.diff(pset._data.trajectory), np.ones(6), atol=1e-12)
+    assert np.allclose(np.diff(pset._data["trajectory"]), np.ones(6), atol=1e-12)
 
 
-def test_pset_add_implicit(fieldset, npart=10):
+def test_pset_add_implicit_in_loop(fieldset, npart=10):
     pset = ParticleSet(fieldset, lon=[], lat=[])
     for _ in range(npart):
         pset += ParticleSet(pclass=Particle, lon=0.1, lat=0.1, fieldset=fieldset)
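
The rename in the last hunk also removes a name clash: with two functions called `test_pset_add_implicit` in one module, the second definition rebinds the name and pytest collects only that one, so the earlier test silently never ran. In miniature (illustrative only):

    def test_pset_add_implicit():
        return "first"

    def test_pset_add_implicit():  # same name: this definition replaces the one above
        return "second"

    print(test_pset_add_implicit())  # prints "second"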
