Skip to content

Commit 40c27d1

Browse files
align_chunks not working for datasets (#10516)
* The align_chunks parameter was not being sent on the to_zarr method of the datasets
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* Add a note on the whats-new.rst about the error of the align_chunks for datasets
* Fix a ValueError on the test_dataset_to_zarr_align_chunks_true
* Fix the case when enc_chunks are bigger than the dask chunks
* Linter
* Fix small reintroduced issue when the region is None
* Fix mypy issues
* Update whats-new.rst
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* Use "v" instead of "var" to follow the naming convention used on the rest of Xarray, move the modification of the enc_chunks to the build_grid_chunks function, add additional test to cover the scenario where the chunk is bigger than the size of the array
* Update the whats-new.rst
* Fix whats-new.rst
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent cc1de6b commit 40c27d1

File tree

6 files changed

+115
-56
lines changed

6 files changed

+115
-56
lines changed

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ Deprecations
2929
Bug fixes
3030
~~~~~~~~~
3131

32+
- Fix the ``align_chunks`` parameter on the :py:meth:`~xarray.Dataset.to_zarr` method, which was not being
33+
passed to the underlying :py:meth:`~xarray.backends.api` method (:issue:`10501`, :pull:`10516`).
3234

3335
Documentation
3436
~~~~~~~~~~~~~

xarray/backends/chunks.py

Lines changed: 55 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,18 @@
44

55

66
def align_nd_chunks(
7-
nd_var_chunks: tuple[tuple[int, ...], ...],
7+
nd_v_chunks: tuple[tuple[int, ...], ...],
88
nd_backend_chunks: tuple[tuple[int, ...], ...],
99
) -> tuple[tuple[int, ...], ...]:
10-
if len(nd_backend_chunks) != len(nd_var_chunks):
10+
if len(nd_backend_chunks) != len(nd_v_chunks):
1111
raise ValueError(
1212
"The number of dimensions on the backend and the variable must be the same."
1313
)
1414

1515
nd_aligned_chunks: list[tuple[int, ...]] = []
16-
for backend_chunks, var_chunks in zip(
17-
nd_backend_chunks, nd_var_chunks, strict=True
18-
):
16+
for backend_chunks, v_chunks in zip(nd_backend_chunks, nd_v_chunks, strict=True):
1917
# Validate that they have the same number of elements
20-
if sum(backend_chunks) != sum(var_chunks):
18+
if sum(backend_chunks) != sum(v_chunks):
2119
raise ValueError(
2220
"The number of elements in the backend does not "
2321
"match the number of elements in the variable. "
@@ -42,39 +40,39 @@ def align_nd_chunks(
4240
nd_aligned_chunks.append(backend_chunks)
4341
continue
4442

45-
if len(var_chunks) == 1:
46-
nd_aligned_chunks.append(var_chunks)
43+
if len(v_chunks) == 1:
44+
nd_aligned_chunks.append(v_chunks)
4745
continue
4846

4947
# Size of the chunk on the backend
5048
fixed_chunk = max(backend_chunks)
5149

5250
# The ideal size of the chunks is the maximum of the two; this would avoid
5351
# that we use more memory than expected
54-
max_chunk = max(fixed_chunk, *var_chunks)
52+
max_chunk = max(fixed_chunk, *v_chunks)
5553

5654
# The algorithm assumes that the chunks on this array are aligned except the last one
5755
# because it can be considered a partial one
5856
aligned_chunks: list[int] = []
5957

6058
# For simplicity of the algorithm, let's transform the Array chunks in such a way that
6159
# we remove the partial chunks. To achieve this, we add artificial data to the borders
62-
t_var_chunks = list(var_chunks)
63-
t_var_chunks[0] += fixed_chunk - backend_chunks[0]
64-
t_var_chunks[-1] += fixed_chunk - backend_chunks[-1]
60+
t_v_chunks = list(v_chunks)
61+
t_v_chunks[0] += fixed_chunk - backend_chunks[0]
62+
t_v_chunks[-1] += fixed_chunk - backend_chunks[-1]
6563

6664
# The unfilled_size is the amount of space that has not been filled on the last
6765
# processed chunk; this is equivalent to the amount of data that would need to be
6866
# added to a partial Zarr chunk to fill it up to the fixed_chunk size
6967
unfilled_size = 0
7068

71-
for var_chunk in t_var_chunks:
69+
for v_chunk in t_v_chunks:
7270
# Ideally, we should try to preserve the original Dask chunks, but this is only
7371
# possible if the last processed chunk was aligned (unfilled_size == 0)
74-
ideal_chunk = var_chunk
72+
ideal_chunk = v_chunk
7573
if unfilled_size:
7674
# If that scenario is not possible, the best option is to merge the chunks
77-
ideal_chunk = var_chunk + aligned_chunks[-1]
75+
ideal_chunk = v_chunk + aligned_chunks[-1]
7876

7977
while ideal_chunk:
8078
if not unfilled_size:
@@ -105,27 +103,27 @@ def align_nd_chunks(
105103
border_size = fixed_chunk - backend_chunks[::order][0]
106104
aligned_chunks = aligned_chunks[::order]
107105
aligned_chunks[0] -= border_size
108-
t_var_chunks = t_var_chunks[::order]
109-
t_var_chunks[0] -= border_size
106+
t_v_chunks = t_v_chunks[::order]
107+
t_v_chunks[0] -= border_size
110108
if (
111109
len(aligned_chunks) >= 2
112110
and aligned_chunks[0] + aligned_chunks[1] <= max_chunk
113-
and aligned_chunks[0] != t_var_chunks[0]
111+
and aligned_chunks[0] != t_v_chunks[0]
114112
):
115113
# The artificial data added to the border can introduce inefficient chunks
116114
# on the borders, for that reason, we will check if we can merge them or not
117115
# Example:
118116
# backend_chunks = [6, 6, 1]
119-
# var_chunks = [6, 7]
120-
# t_var_chunks = [6, 12]
121-
# The ideal output should preserve the same var_chunks, but the previous loop
117+
# v_chunks = [6, 7]
118+
# t_v_chunks = [6, 12]
119+
# The ideal output should preserve the same v_chunks, but the previous loop
122120
# is going to produce aligned_chunks = [6, 6, 6]
123121
# And after removing the artificial data, we will end up with aligned_chunks = [6, 6, 1]
124122
# which is not ideal and can be merged into a single chunk
125123
aligned_chunks[1] += aligned_chunks[0]
126124
aligned_chunks = aligned_chunks[1:]
127125

128-
t_var_chunks = t_var_chunks[::order]
126+
t_v_chunks = t_v_chunks[::order]
129127
aligned_chunks = aligned_chunks[::order]
130128

131129
nd_aligned_chunks.append(tuple(aligned_chunks))
@@ -144,6 +142,11 @@ def build_grid_chunks(
144142
region_start = region.start or 0
145143
# Generate the zarr chunks inside the region of this dim
146144
chunks_on_region = [chunk_size - (region_start % chunk_size)]
145+
if chunks_on_region[0] >= size:
146+
# This is useful for the scenarios where the chunk_size is bigger
147+
# than the variable chunks, which can happen when the user specifies
148+
# the enc_chunks manually.
149+
return (size,)
147150
chunks_on_region.extend([chunk_size] * ((size - chunks_on_region[0]) // chunk_size))
148151
if (size - chunks_on_region[0]) % chunk_size != 0:
149152
chunks_on_region.append((size - chunks_on_region[0]) % chunk_size)
@@ -155,45 +158,45 @@ def grid_rechunk(
155158
enc_chunks: tuple[int, ...],
156159
region: tuple[slice, ...],
157160
) -> Variable:
158-
nd_var_chunks = v.chunks
159-
if not nd_var_chunks:
161+
nd_v_chunks = v.chunks
162+
if not nd_v_chunks:
160163
return v
161164

162165
nd_grid_chunks = tuple(
163166
build_grid_chunks(
164-
sum(var_chunks),
167+
v_size,
165168
region=interval,
166169
chunk_size=chunk_size,
167170
)
168-
for var_chunks, chunk_size, interval in zip(
169-
nd_var_chunks, enc_chunks, region, strict=True
171+
for v_size, chunk_size, interval in zip(
172+
v.shape, enc_chunks, region, strict=True
170173
)
171174
)
172175

173176
nd_aligned_chunks = align_nd_chunks(
174-
nd_var_chunks=nd_var_chunks,
177+
nd_v_chunks=nd_v_chunks,
175178
nd_backend_chunks=nd_grid_chunks,
176179
)
177180
v = v.chunk(dict(zip(v.dims, nd_aligned_chunks, strict=True)))
178181
return v
179182

180183

181184
def validate_grid_chunks_alignment(
182-
nd_var_chunks: tuple[tuple[int, ...], ...] | None,
185+
nd_v_chunks: tuple[tuple[int, ...], ...] | None,
183186
enc_chunks: tuple[int, ...],
184187
backend_shape: tuple[int, ...],
185188
region: tuple[slice, ...],
186189
allow_partial_chunks: bool,
187190
name: str,
188191
):
189-
if nd_var_chunks is None:
192+
if nd_v_chunks is None:
190193
return
191194
base_error = (
192195
"Specified Zarr chunks encoding['chunks']={enc_chunks!r} for "
193196
"variable named {name!r} would overlap multiple Dask chunks. "
194-
"Check the chunk at position {var_chunk_pos}, which has a size of "
195-
"{var_chunk_size} on dimension {dim_i}. It is unaligned with "
196-
"backend chunks of size {chunk_size} in region {region}. "
197+
"Please check the Dask chunks at position {v_chunk_pos} and "
198+
"{v_chunk_pos_next}, on axis {axis}, they are overlapped "
199+
"on the same Zarr chunk in the region {region}. "
197200
"Writing this array in parallel with Dask could lead to corrupted data. "
198201
"To resolve this issue, consider one of the following options: "
199202
"- Rechunk the array using `chunk()`. "
@@ -202,22 +205,23 @@ def validate_grid_chunks_alignment(
202205
"- Enable automatic chunks alignment with `align_chunks=True`."
203206
)
204207

205-
for dim_i, chunk_size, var_chunks, interval, size in zip(
208+
for axis, chunk_size, v_chunks, interval, size in zip(
206209
range(len(enc_chunks)),
207210
enc_chunks,
208-
nd_var_chunks,
211+
nd_v_chunks,
209212
region,
210213
backend_shape,
211214
strict=True,
212215
):
213-
for i, chunk in enumerate(var_chunks[1:-1]):
216+
for i, chunk in enumerate(v_chunks[1:-1]):
214217
if chunk % chunk_size:
215218
raise ValueError(
216219
base_error.format(
217-
var_chunk_pos=i + 1,
218-
var_chunk_size=chunk,
220+
v_chunk_pos=i + 1,
221+
v_chunk_pos_next=i + 2,
222+
v_chunk_size=chunk,
223+
axis=axis,
219224
name=name,
220-
dim_i=dim_i,
221225
chunk_size=chunk_size,
222226
region=interval,
223227
enc_chunks=enc_chunks,
@@ -226,20 +230,21 @@ def validate_grid_chunks_alignment(
226230

227231
interval_start = interval.start or 0
228232

229-
if len(var_chunks) > 1:
233+
if len(v_chunks) > 1:
230234
# The first border size is the amount of data that needs to be updated on the
231235
# first chunk taking into account the region slice.
232236
first_border_size = chunk_size
233237
if allow_partial_chunks:
234238
first_border_size = chunk_size - interval_start % chunk_size
235239

236-
if (var_chunks[0] - first_border_size) % chunk_size:
240+
if (v_chunks[0] - first_border_size) % chunk_size:
237241
raise ValueError(
238242
base_error.format(
239-
var_chunk_pos=0,
240-
var_chunk_size=var_chunks[0],
243+
v_chunk_pos=0,
244+
v_chunk_pos_next=0,
245+
v_chunk_size=v_chunks[0],
246+
axis=axis,
241247
name=name,
242-
dim_i=dim_i,
243248
chunk_size=chunk_size,
244249
region=interval,
245250
enc_chunks=enc_chunks,
@@ -250,10 +255,11 @@ def validate_grid_chunks_alignment(
250255
region_stop = interval.stop or size
251256

252257
error_on_last_chunk = base_error.format(
253-
var_chunk_pos=len(var_chunks) - 1,
254-
var_chunk_size=var_chunks[-1],
258+
v_chunk_pos=len(v_chunks) - 1,
259+
v_chunk_pos_next=len(v_chunks) - 1,
260+
v_chunk_size=v_chunks[-1],
261+
axis=axis,
255262
name=name,
256-
dim_i=dim_i,
257263
chunk_size=chunk_size,
258264
region=interval,
259265
enc_chunks=enc_chunks,
@@ -267,7 +273,7 @@ def validate_grid_chunks_alignment(
267273
# If the region is covering the last chunk then check
268274
# if the remainder with the default chunk size
269275
# is equal to the size of the last chunk
270-
if var_chunks[-1] % chunk_size != size % chunk_size:
276+
if v_chunks[-1] % chunk_size != size % chunk_size:
271277
raise ValueError(error_on_last_chunk)
272-
elif var_chunks[-1] % chunk_size:
278+
elif v_chunks[-1] % chunk_size:
273279
raise ValueError(error_on_last_chunk)

xarray/backends/zarr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1249,7 +1249,7 @@ def set_variables(
12491249
# threads
12501250
shape = zarr_shape or v.shape
12511251
validate_grid_chunks_alignment(
1252-
nd_var_chunks=v.chunks,
1252+
nd_v_chunks=v.chunks,
12531253
enc_chunks=encoding["chunks"],
12541254
region=region,
12551255
allow_partial_chunks=self._mode != "r+",

xarray/core/dataset.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2375,6 +2375,7 @@ def to_zarr(
23752375
append_dim=append_dim,
23762376
region=region,
23772377
safe_chunks=safe_chunks,
2378+
align_chunks=align_chunks,
23782379
zarr_version=zarr_version,
23792380
zarr_format=zarr_format,
23802381
write_empty_chunks=write_empty_chunks,

xarray/tests/test_backends.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7720,6 +7720,54 @@ def test_zarr_safe_chunk_region(self, mode: Literal["r+", "a"]):
77207720
chunk = chunk.chunk()
77217721
self.save(store, chunk.chunk(), region=region)
77227722

7723+
@requires_dask
7724+
def test_dataset_to_zarr_align_chunks_true(self, tmp_store) -> None:
7725+
# This test is a replica of the one in `test_dataarray_to_zarr_align_chunks_true`
7726+
# but for datasets
7727+
with self.create_zarr_target() as store:
7728+
ds = (
7729+
DataArray(
7730+
np.arange(4).reshape((2, 2)),
7731+
dims=["a", "b"],
7732+
coords={
7733+
"a": np.arange(2),
7734+
"b": np.arange(2),
7735+
},
7736+
)
7737+
.chunk(a=(1, 1), b=(1, 1))
7738+
.to_dataset(name="foo")
7739+
)
7740+
7741+
self.save(
7742+
store,
7743+
ds,
7744+
align_chunks=True,
7745+
encoding={"foo": {"chunks": (3, 3)}},
7746+
mode="w",
7747+
)
7748+
assert_identical(ds, xr.open_zarr(store))
7749+
7750+
ds = (
7751+
DataArray(
7752+
np.arange(4, 8).reshape((2, 2)),
7753+
dims=["a", "b"],
7754+
coords={
7755+
"a": np.arange(2),
7756+
"b": np.arange(2),
7757+
},
7758+
)
7759+
.chunk(a=(1, 1), b=(1, 1))
7760+
.to_dataset(name="foo")
7761+
)
7762+
7763+
self.save(
7764+
store,
7765+
ds,
7766+
align_chunks=True,
7767+
region="auto",
7768+
)
7769+
assert_identical(ds, xr.open_zarr(store))
7770+
77237771

77247772
@requires_h5netcdf
77257773
@requires_fsspec

xarray/tests/test_backends_chunks.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
(10, 3, None, (3, 3, 3, 1)),
1515
(10, 3, slice(None, 10), (3, 3, 3, 1)),
1616
(10, 3, slice(0, None), (3, 3, 3, 1)),
17+
(2, 10, slice(0, 3), (2,)),
18+
(4, 10, slice(7, 10), (3, 1)),
1719
],
1820
)
1921
def test_build_grid_chunks(size, chunk_size, region, expected_chunks):
@@ -26,24 +28,24 @@ def test_build_grid_chunks(size, chunk_size, region, expected_chunks):
2628

2729

2830
@pytest.mark.parametrize(
29-
"nd_var_chunks, nd_backend_chunks, expected_chunks",
31+
"nd_v_chunks, nd_backend_chunks, expected_chunks",
3032
[
3133
(((2, 2, 2, 2),), ((3, 3, 2),), ((3, 3, 2),)),
3234
# ND cases
3335
(((2, 4), (2, 3)), ((2, 2, 2), (3, 2)), ((2, 4), (3, 2))),
3436
],
3537
)
36-
def test_align_nd_chunks(nd_var_chunks, nd_backend_chunks, expected_chunks):
38+
def test_align_nd_chunks(nd_v_chunks, nd_backend_chunks, expected_chunks):
3739
aligned_nd_chunks = align_nd_chunks(
38-
nd_var_chunks=nd_var_chunks,
40+
nd_v_chunks=nd_v_chunks,
3941
nd_backend_chunks=nd_backend_chunks,
4042
)
4143
assert aligned_nd_chunks == expected_chunks
4244

4345

4446
@requires_dask
4547
@pytest.mark.parametrize(
46-
"enc_chunks, region, nd_var_chunks, expected_chunks",
48+
"enc_chunks, region, nd_v_chunks, expected_chunks",
4749
[
4850
(
4951
(3,),
@@ -93,7 +95,7 @@ def test_align_nd_chunks(nd_var_chunks, nd_backend_chunks, expected_chunks):
9395
),
9496
],
9597
)
96-
def test_grid_rechunk(enc_chunks, region, nd_var_chunks, expected_chunks):
98+
def test_grid_rechunk(enc_chunks, region, nd_v_chunks, expected_chunks):
9799
dims = [f"dim_{i}" for i in range(len(region))]
98100
coords = {
99101
dim: list(range(r.start, r.stop)) for dim, r in zip(dims, region, strict=False)
@@ -104,7 +106,7 @@ def test_grid_rechunk(enc_chunks, region, nd_var_chunks, expected_chunks):
104106
dims=dims,
105107
coords=coords,
106108
)
107-
arr = arr.chunk(dict(zip(dims, nd_var_chunks, strict=False)))
109+
arr = arr.chunk(dict(zip(dims, nd_v_chunks, strict=False)))
108110

109111
result = grid_rechunk(
110112
arr.variable,

0 commit comments

Comments
 (0)