
Commit 19c8bf9

Authored Jun 12, 2023
Refactor P2P rechunk validation (#7890)
1 parent 618f5ac · commit 19c8bf9

2 files changed (+1, -38 lines)

2 files changed

+1
-38
lines changed
 

distributed/shuffle/_rechunk.py (+1, -30)

```diff
@@ -1,8 +1,7 @@
 from __future__ import annotations
 
-import math
 from collections import defaultdict
-from itertools import compress, product
+from itertools import product
 from typing import TYPE_CHECKING, NamedTuple
 
 import dask
@@ -73,34 +72,6 @@ def rechunk_p2p(x: da.Array, chunks: ChunkedAxes) -> da.Array:
         # Special case for empty array, as the algorithm below does not behave correctly
         return da.empty(x.shape, chunks=chunks, dtype=x.dtype)
 
-    old_chunks = x.chunks
-    new_chunks = chunks
-
-    def is_unknown(dim: ChunkedAxis) -> bool:
-        return any(math.isnan(chunk) for chunk in dim)
-
-    old_is_unknown = [is_unknown(dim) for dim in old_chunks]
-    new_is_unknown = [is_unknown(dim) for dim in new_chunks]
-
-    if old_is_unknown != new_is_unknown or any(
-        new != old for new, old in compress(zip(old_chunks, new_chunks), old_is_unknown)
-    ):
-        raise ValueError(
-            "Chunks must be unchanging along dimensions with missing values.\n\n"
-            "A possible solution:\n  x.compute_chunk_sizes()"
-        )
-
-    old_known = [dim for dim, unknown in zip(old_chunks, old_is_unknown) if not unknown]
-    new_known = [dim for dim, unknown in zip(new_chunks, new_is_unknown) if not unknown]
-
-    old_sizes = [sum(o) for o in old_known]
-    new_sizes = [sum(n) for n in new_known]
-
-    if old_sizes != new_sizes:
-        raise ValueError(
-            f"Cannot change dimensions from {old_sizes!r} to {new_sizes!r}"
-        )
-
     dsk: dict = {}
     token = tokenize(x, chunks)
     _barrier_key = barrier_key(ShuffleId(token))
```
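The removed block validated two things: dimensions with unknown (NaN) chunk sizes must keep their chunks unchanged, and the total size along known dimensions must not change. Its error text matches the validation in `dask.array.rechunk`, so `rechunk_p2p` presumably no longer needs its own copy. As a minimal sketch (not part of this commit; assumes `dask[array]` and `numpy` are installed), this is how unknown chunk sizes arise and what the removed error message suggested:

```python
import dask.array as da
import numpy as np

x = da.from_array(np.arange(10), chunks=5)
y = x[x % 2 == 0]   # boolean indexing produces unknown chunk sizes
print(y.chunks)     # ((nan, nan),)

# The removed validation rejected rechunking along NaN dimensions and pointed
# users at compute_chunk_sizes(), which evaluates enough of the graph to
# replace the NaNs with concrete sizes:
y = y.compute_chunk_sizes()
print(y.chunks)     # ((3, 2),)
```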

distributed/shuffle/_worker_extension.py (-8)

```diff
@@ -314,14 +314,6 @@ def __init__(
             memory_limiter_comms=memory_limiter_comms,
             memory_limiter_disk=memory_limiter_disk,
         )
-        from dask.array.core import normalize_chunks
-
-        # We rely on a canonical `np.nan` in `dask.array.rechunk.old_to_new`
-        # that passes an implicit identity check when testing for list equality.
-        # This does not work with (de)serialization, so we have to normalize the chunks
-        # here again to canonicalize `nan`s.
-        old = normalize_chunks(old)
-        new = normalize_chunks(new)
         self.old = old
         self.new = new
         partitions_of = defaultdict(list)
```
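The removed comment alludes to a CPython detail: container equality checks object identity before calling `==`, so a single shared `nan` object passes a list comparison even though `nan != nan`. A standalone sketch (not part of this commit) of the behavior that the dropped normalization worked around:

```python
nan = float("nan")

# Identity shortcut: the same NaN object compares equal inside containers.
print(nan == nan)          # False -- NaN is never equal to itself
print([nan] == [nan])      # True  -- list equality short-circuits on identity

# Distinct NaN objects, e.g. after a serialization round-trip, do not.
print([float("nan")] == [float("nan")])  # False
```

After deserialization the chunk tuples no longer share one canonical NaN object, which is what the removed `normalize_chunks` calls compensated for; with the NaN-dependent validation gone from `rechunk_p2p`, this commit drops that re-normalization step as well.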

Comments (0)