Skip to content

Commit ea7fcb8

Browse files
committed
Move chunks-related functions to a new file
Part of pydata#10089
1 parent 4174aa1 commit ea7fcb8

File tree

10 files changed

+262
-92
lines changed

10 files changed

+262
-92
lines changed

Diff for: xarray/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
cross,
2525
dot,
2626
polyval,
27-
unify_chunks,
2827
where,
2928
)
3029
from xarray.conventions import SerializationWarning, decode_cf
@@ -52,6 +51,7 @@
5251
from xarray.core.variable import IndexVariable, Variable, as_variable
5352
from xarray.namedarray.core import NamedArray
5453
from xarray.structure.alignment import align, broadcast
54+
from xarray.structure.chunks import unify_chunks
5555
from xarray.structure.combine import combine_by_coords, combine_nested
5656
from xarray.structure.concat import concat
5757
from xarray.structure.merge import Context, MergeError, merge

Diff for: xarray/backends/api.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
from xarray.backends.locks import _get_scheduler
3636
from xarray.coders import CFDatetimeCoder, CFTimedeltaCoder
3737
from xarray.core import indexing
38-
from xarray.core.chunk import _get_chunk, _maybe_chunk
3938
from xarray.core.dataarray import DataArray
4039
from xarray.core.dataset import Dataset
4140
from xarray.core.datatree import DataTree
@@ -45,6 +44,7 @@
4544
from xarray.core.utils import is_remote_uri
4645
from xarray.namedarray.daskmanager import DaskManager
4746
from xarray.namedarray.parallelcompat import guess_chunkmanager
47+
from xarray.structure.chunks import _get_chunk, _maybe_chunk
4848
from xarray.structure.combine import (
4949
_infer_concat_order_from_positions,
5050
_nested_combine,

Diff for: xarray/coding/calendar_ops.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
_should_cftime_be_used,
1010
convert_times,
1111
)
12-
from xarray.computation.apply_ufunc import apply_ufunc
12+
1313
from xarray.core.common import (
1414
_contains_datetime_like_objects,
1515
full_like,
@@ -333,6 +333,8 @@ def _decimal_year(times):
333333
else:
334334
function = _decimal_year_numpy
335335
kwargs = {"dtype": times.dtype}
336+
from xarray.computation.apply_ufunc import apply_ufunc
337+
336338
return apply_ufunc(
337339
function,
338340
times,

Diff for: xarray/computation/computation.py

+7-79
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,11 @@
1313
Callable,
1414
Hashable,
1515
)
16-
from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union, cast, overload
16+
from typing import TYPE_CHECKING, Any, Literal, cast, overload
1717

1818
import numpy as np
1919

2020
from xarray.compat.array_api_compat import to_like_array
21-
from xarray.computation.apply_ufunc import apply_ufunc
2221
from xarray.core import dtypes, duck_array_ops, utils
2322
from xarray.core.common import zeros_like
2423
from xarray.core.duck_array_ops import datetime_to_numeric
@@ -467,6 +466,8 @@ def cross(
467466
" dimensions without coordinates must have have a length of 2 or 3"
468467
)
469468

469+
from xarray.computation.apply_ufunc import apply_ufunc
470+
470471
c = apply_ufunc(
471472
duck_array_ops.cross,
472473
a,
@@ -629,6 +630,8 @@ def dot(
629630
# subscripts should be passed to np.einsum as arg, not as kwargs. We need
630631
# to construct a partial function for apply_ufunc to work.
631632
func = functools.partial(duck_array_ops.einsum, subscripts, **kwargs)
633+
from xarray.computation.apply_ufunc import apply_ufunc
634+
632635
result = apply_ufunc(
633636
func,
634637
*arrays,
@@ -729,6 +732,8 @@ def where(cond, x, y, keep_attrs=None):
729732
keep_attrs = _get_keep_attrs(default=False)
730733

731734
# alignment for three arguments is complicated, so don't support it yet
735+
from xarray.computation.apply_ufunc import apply_ufunc
736+
732737
result = apply_ufunc(
733738
duck_array_ops.where,
734739
cond,
@@ -951,80 +956,3 @@ def _calc_idxminmax(
951956
res.attrs = indx.attrs
952957

953958
return res
954-
955-
956-
_T = TypeVar("_T", bound=Union["Dataset", "DataArray"])
957-
_U = TypeVar("_U", bound=Union["Dataset", "DataArray"])
958-
_V = TypeVar("_V", bound=Union["Dataset", "DataArray"])
959-
960-
961-
@overload
962-
def unify_chunks(__obj: _T) -> tuple[_T]: ...
963-
964-
965-
@overload
966-
def unify_chunks(__obj1: _T, __obj2: _U) -> tuple[_T, _U]: ...
967-
968-
969-
@overload
970-
def unify_chunks(__obj1: _T, __obj2: _U, __obj3: _V) -> tuple[_T, _U, _V]: ...
971-
972-
973-
@overload
974-
def unify_chunks(*objects: Dataset | DataArray) -> tuple[Dataset | DataArray, ...]: ...
975-
976-
977-
def unify_chunks(*objects: Dataset | DataArray) -> tuple[Dataset | DataArray, ...]:
978-
"""
979-
Given any number of Dataset and/or DataArray objects, returns
980-
new objects with unified chunk size along all chunked dimensions.
981-
982-
Returns
983-
-------
984-
unified (DataArray or Dataset) – Tuple of objects with the same type as
985-
*objects with consistent chunk sizes for all dask-array variables
986-
987-
See Also
988-
--------
989-
dask.array.core.unify_chunks
990-
"""
991-
from xarray.core.dataarray import DataArray
992-
993-
# Convert all objects to datasets
994-
datasets = [
995-
obj._to_temp_dataset() if isinstance(obj, DataArray) else obj.copy()
996-
for obj in objects
997-
]
998-
999-
# Get arguments to pass into dask.array.core.unify_chunks
1000-
unify_chunks_args = []
1001-
sizes: dict[Hashable, int] = {}
1002-
for ds in datasets:
1003-
for v in ds._variables.values():
1004-
if v.chunks is not None:
1005-
# Check that sizes match across different datasets
1006-
for dim, size in v.sizes.items():
1007-
try:
1008-
if sizes[dim] != size:
1009-
raise ValueError(
1010-
f"Dimension {dim!r} size mismatch: {sizes[dim]} != {size}"
1011-
)
1012-
except KeyError:
1013-
sizes[dim] = size
1014-
unify_chunks_args += [v._data, v._dims]
1015-
1016-
# No dask arrays: Return inputs
1017-
if not unify_chunks_args:
1018-
return objects
1019-
1020-
chunkmanager = get_chunked_array_type(*list(unify_chunks_args))
1021-
_, chunked_data = chunkmanager.unify_chunks(*unify_chunks_args)
1022-
chunked_data_iter = iter(chunked_data)
1023-
out: list[Dataset | DataArray] = []
1024-
for obj, ds in zip(objects, datasets, strict=True):
1025-
for k, v in ds._variables.items():
1026-
if v.chunks is not None:
1027-
ds._variables[k] = v.copy(data=next(chunked_data_iter))
1028-
out.append(obj._from_temp_dataset(ds) if isinstance(obj, DataArray) else ds)
1029-
1030-
return tuple(out)

Diff for: xarray/core/accessor_str.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@
5151

5252
import numpy as np
5353

54-
from xarray.computation.computation import apply_ufunc
5554
from xarray.core import duck_array_ops
5655
from xarray.core.types import T_DataArray
5756

@@ -127,6 +126,8 @@ def _apply_str_ufunc(
127126
if output_sizes is not None:
128127
dask_gufunc_kwargs["output_sizes"] = output_sizes
129128

129+
from xarray.computation.apply_ufunc import apply_ufunc
130+
130131
return apply_ufunc(
131132
func,
132133
obj,

Diff for: xarray/core/common.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,7 @@ def clip(
491491
--------
492492
numpy.clip : equivalent function
493493
"""
494-
from xarray.computation.computation import apply_ufunc
494+
from xarray.computation.apply_ufunc import apply_ufunc
495495

496496
if keep_attrs is None:
497497
# When this was a unary func, the default was True, so retaining the
@@ -1311,7 +1311,7 @@ def isnull(self, keep_attrs: bool | None = None) -> Self:
13111311
array([False, True, False])
13121312
Dimensions without coordinates: x
13131313
"""
1314-
from xarray.computation.computation import apply_ufunc
1314+
from xarray.computation.apply_ufunc import apply_ufunc
13151315

13161316
if keep_attrs is None:
13171317
keep_attrs = _get_keep_attrs(default=False)
@@ -1354,7 +1354,7 @@ def notnull(self, keep_attrs: bool | None = None) -> Self:
13541354
array([ True, False, True])
13551355
Dimensions without coordinates: x
13561356
"""
1357-
from xarray.computation.computation import apply_ufunc
1357+
from xarray.computation.apply_ufunc import apply_ufunc
13581358

13591359
if keep_attrs is None:
13601360
keep_attrs = _get_keep_attrs(default=False)
@@ -1393,7 +1393,7 @@ def isin(self, test_elements: Any) -> Self:
13931393
--------
13941394
numpy.isin
13951395
"""
1396-
from xarray.computation.computation import apply_ufunc
1396+
from xarray.computation.apply_ufunc import apply_ufunc
13971397
from xarray.core.dataarray import DataArray
13981398
from xarray.core.dataset import Dataset
13991399
from xarray.core.variable import Variable
@@ -1476,7 +1476,7 @@ def astype(
14761476
dask.array.Array.astype
14771477
sparse.COO.astype
14781478
"""
1479-
from xarray.computation.computation import apply_ufunc
1479+
from xarray.computation.apply_ufunc import apply_ufunc
14801480

14811481
kwargs = dict(order=order, casting=casting, subok=subok, copy=copy)
14821482
kwargs = {k: v for k, v in kwargs.items() if v is not None}

Diff for: xarray/core/dataarray.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
from xarray.coding.cftimeindex import CFTimeIndex
3232
from xarray.computation import computation, ops
3333
from xarray.computation.arithmetic import DataArrayArithmetic
34-
from xarray.computation.computation import unify_chunks
3534
from xarray.core import dtypes, indexing, utils
3635
from xarray.core._aggregations import DataArrayAggregations
3736
from xarray.core.accessor_dt import CombinedDatetimelikeAccessor
@@ -88,6 +87,7 @@
8887
_get_broadcast_dims_map_common_coords,
8988
align,
9089
)
90+
from xarray.structure.chunks import unify_chunks
9191
from xarray.structure.merge import PANDAS_TYPES, MergeError
9292
from xarray.util.deprecation_helpers import _deprecate_positional_args, deprecate_dims
9393

Diff for: xarray/core/dataset.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@
2727
import numpy as np
2828
from pandas.api.types import is_extension_array_dtype
2929

30-
from xarray.core.chunk import _maybe_chunk
3130
from xarray.core.dataset_utils import _get_virtual_variable, _LocIndexer
3231
from xarray.core.dataset_variables import DataVariables
32+
from xarray.structure.chunks import _maybe_chunk
3333

3434
# remove once numpy 2.0 is the oldest supported version
3535
try:
@@ -44,7 +44,7 @@
4444
from xarray.compat.array_api_compat import to_like_array
4545
from xarray.computation import ops
4646
from xarray.computation.arithmetic import DatasetArithmetic
47-
from xarray.computation.computation import _ensure_numeric, unify_chunks
47+
from xarray.computation.computation import _ensure_numeric
4848
from xarray.core import dtypes as xrdtypes
4949
from xarray.core import (
5050
duck_array_ops,
@@ -125,6 +125,7 @@
125125
_get_broadcast_dims_map_common_coords,
126126
align,
127127
)
128+
from xarray.structure.chunks import unify_chunks
128129
from xarray.structure.merge import (
129130
dataset_merge_method,
130131
dataset_update_method,

Diff for: xarray/groupers.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from numpy.typing import ArrayLike
1818

1919
from xarray.coding.cftime_offsets import BaseCFTimeOffset, _new_to_legacy_freq
20-
from xarray.computation.computation import apply_ufunc
20+
from xarray.computation.apply_ufunc import apply_ufunc
2121
from xarray.core.coordinates import Coordinates, _coordinates_from_variable
2222
from xarray.core.dataarray import DataArray
2323
from xarray.core.duck_array_ops import array_all, isnull

0 commit comments

Comments
 (0)