Skip to content

Commit bc20889

Browse files
Adding an option to change the compression type of the parquet file (#2611)
1 parent b8bab51 commit bc20889

2 files changed

Lines changed: 28 additions & 2 deletions

File tree

src/parcels/_core/particlefile.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,18 +58,23 @@ class ParticleFile:
5858
Interval which dictates the update frequency of file output
5959
while ParticleFile is given as an argument of ParticleSet.execute()
6060
It is either a numpy.timedelta64, a datimetime.timedelta object or a positive float (in seconds).
61+
compression : {"zstd", "gzip", "snappy", "brotli", None}, optional
62+
Compression algorithm to use for the Parquet file. Default is "zstd".
6163
6264
Returns
6365
-------
6466
ParticleFile
6567
ParticleFile object that can be used to write particle data to file
6668
"""
6769

68-
def __init__(self, path: PathLike, outputdt):
70+
def __init__(
71+
self, path: PathLike, outputdt, compression: Literal["zstd", "gzip", "snappy", "brotli", None] = "zstd"
72+
):
6973
if not isinstance(outputdt, (np.timedelta64, timedelta, float)):
7074
raise ValueError(
7175
f"Expected outputdt to be a np.timedelta64, datetime.timedelta or float (in seconds), got {type(outputdt)}"
7276
)
77+
self._compression = compression
7378

7479
outputdt = timedelta_to_float(outputdt)
7580
path = Path(path)
@@ -133,7 +138,11 @@ def write(self, pset: ParticleSet, time, indices=None):
133138

134139
if self._writer is None:
135140
assert not self.path.exists(), "If the file exists, the writer should already be set"
136-
self._writer = pq.ParquetWriter(self.path, _get_schema(pclass, self.metadata, pset.fieldset.time_interval))
141+
self._writer = pq.ParquetWriter(
142+
self.path,
143+
_get_schema(pclass, self.metadata, pset.fieldset.time_interval),
144+
compression=self._compression,
145+
)
137146

138147
if isinstance(time, (np.timedelta64, np.datetime64)):
139148
time = timedelta_to_float(time - time_interval.left)

tests/test_particlefile.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,23 @@ def test_metadata(fieldset, tmp_parquet):
5757
assert tab.schema.metadata[b"parcels_kernels"].decode().lower() == "DoNothing".lower()
5858

5959

60+
@pytest.mark.parametrize("compression", ["zstd", "gzip", "snappy", "brotli", None])
61+
def test_compression(fieldset, tmp_parquet, compression):
62+
pset = ParticleSet(fieldset, pclass=Particle, lon=0, lat=0)
63+
64+
ofile = ParticleFile(tmp_parquet, outputdt=np.timedelta64(1, "s"), compression=compression)
65+
pset.execute(DoNothing, runtime=np.timedelta64(1, "s"), dt=np.timedelta64(1, "s"), output_file=ofile)
66+
67+
tab = pq.ParquetFile(tmp_parquet)
68+
for i in range(tab.num_row_groups):
69+
row_group = tab.metadata.row_group(i)
70+
for j in range(row_group.num_columns):
71+
col = row_group.column(j)
72+
assert col.compression.lower() == compression or (
73+
compression is None and col.compression.lower() == "uncompressed"
74+
)
75+
76+
6077
def test_write_fieldset_without_time(tmp_parquet):
6178
ds = peninsula_dataset() # DataSet without time
6279
assert "time" not in ds.dims

0 commit comments

Comments
 (0)