diff --git a/doc/api.rst b/doc/api.rst
index 0c30ddc4c20..85ef46ca6ba 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -1210,6 +1210,7 @@ Dataset
    DatasetGroupBy.var
    DatasetGroupBy.dims
    DatasetGroupBy.groups
+   DatasetGroupBy.shuffle_to_chunks

 DataArray
 ---------
@@ -1241,6 +1242,7 @@ DataArray
    DataArrayGroupBy.var
    DataArrayGroupBy.dims
    DataArrayGroupBy.groups
+   DataArrayGroupBy.shuffle_to_chunks

 Grouper Objects
 ---------------
diff --git a/doc/user-guide/groupby.rst b/doc/user-guide/groupby.rst
index f7d76edadf8..7cb4e883347 100644
--- a/doc/user-guide/groupby.rst
+++ b/doc/user-guide/groupby.rst
@@ -330,3 +330,24 @@ Different groupers can be combined to construct sophisticated GroupBy operations
     from xarray.groupers import BinGrouper

     ds.groupby(x=BinGrouper(bins=[5, 15, 25]), letters=UniqueGrouper()).sum()
+
+
+Shuffling
+~~~~~~~~~
+
+Shuffling is a generalization of sorting a DataArray or Dataset by another DataArray, named ``label`` for example, that follows from the idea of grouping by ``label``.
+Shuffling reorders the DataArray or the DataArrays in a Dataset such that all members of a group occur sequentially.
+Shuffle the object using either :py:class:`DatasetGroupBy` or :py:class:`DataArrayGroupBy` as appropriate. For example:
+
+.. ipython:: python
+
+    da = xr.DataArray(
+        dims="x",
+        data=[1, 2, 3, 4, 5, 6],
+        coords={"label": ("x", "a b c a b c".split(" "))},
+    )
+    da.groupby("label").shuffle_to_chunks()
+
+
+For chunked array types (e.g. dask or cubed), shuffling may result in a more optimized communication pattern when compared to direct indexing by the appropriate indexer.
+Shuffling also makes GroupBy operations on chunked arrays an embarrassingly parallel problem, and may significantly improve workloads that use :py:meth:`DatasetGroupBy.map` or :py:meth:`DataArrayGroupBy.map`.
diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py
index 1448145183f..eae11c0c491 100644
--- a/xarray/core/dataarray.py
+++ b/xarray/core/dataarray.py
@@ -63,6 +63,7 @@
         Bins,
         DaCompatible,
         NetcdfWriteModes,
+        T_Chunks,
         T_DataArray,
         T_DataArrayOrSet,
         ZarrWriteModes,
@@ -105,6 +106,7 @@
         Dims,
         ErrorOptions,
         ErrorOptionsWithWarn,
+        GroupIndices,
         GroupInput,
         InterpOptions,
         PadModeOptions,
@@ -1687,6 +1689,12 @@ def sel(
         )
         return self._from_temp_dataset(ds)

+    def _shuffle(
+        self, dim: Hashable, *, indices: GroupIndices, chunks: T_Chunks
+    ) -> Self:
+        ds = self._to_temp_dataset()._shuffle(dim=dim, indices=indices, chunks=chunks)
+        return self._from_temp_dataset(ds)
+
     def head(
         self,
         indexers: Mapping[Any, int] | int | None = None,
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index cc34a8cc04b..5056a80af9a 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -155,6 +155,7 @@
         DsCompatible,
         ErrorOptions,
         ErrorOptionsWithWarn,
+        GroupIndices,
         GroupInput,
         InterpOptions,
         JoinOptions,
@@ -166,6 +167,7 @@
         ResampleCompatible,
         SideOptions,
         T_ChunkDimFreq,
+        T_Chunks,
         T_DatasetPadConstantValues,
         T_Xarray,
     )
@@ -3237,6 +3239,38 @@ def sel(
         result = self.isel(indexers=query_results.dim_indexers, drop=drop)
         return result._overwrite_indexes(*query_results.as_tuple()[1:])

+    def _shuffle(self, dim, *, indices: GroupIndices, chunks: T_Chunks) -> Self:
+        # Shuffling is only different from `isel` for chunked arrays.
+        # Extract them out, and treat them specially. The rest, we route through isel.
+        # This makes it easy to ensure correct handling of indexes.
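+        # `indices` may mix slices and lists of integers (see GroupIndices), so
+        # first normalize everything to non-empty lists of integers; both the
+        # isel path and the chunked-array shuffle path then see a single format.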
+        is_chunked = {
+            name: var
+            for name, var in self._variables.items()
+            if is_chunked_array(var._data)
+        }
+        subset = self[[name for name in self._variables if name not in is_chunked]]
+
+        no_slices: list[list[int]] = [
+            list(range(*idx.indices(self.sizes[dim])))
+            if isinstance(idx, slice)
+            else idx
+            for idx in indices
+        ]
+        no_slices = [idx for idx in no_slices if idx]
+
+        shuffled = (
+            subset
+            if dim not in subset.dims
+            else subset.isel({dim: np.concatenate(no_slices)})
+        )
+        for name, var in is_chunked.items():
+            shuffled[name] = var._shuffle(
+                indices=no_slices,
+                dim=dim,
+                chunks=chunks,
+            )
+        return shuffled
+
     def head(
         self,
         indexers: Mapping[Any, int] | int | None = None,
diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py
index 7a32cd7b1db..9596d19e735 100644
--- a/xarray/core/groupby.py
+++ b/xarray/core/groupby.py
@@ -57,7 +57,13 @@
     from xarray.core.dataarray import DataArray
     from xarray.core.dataset import Dataset
-    from xarray.core.types import GroupIndex, GroupIndices, GroupInput, GroupKey
+    from xarray.core.types import (
+        GroupIndex,
+        GroupIndices,
+        GroupInput,
+        GroupKey,
+        T_Chunks,
+    )
     from xarray.core.utils import Frozen
     from xarray.groupers import EncodedGroups, Grouper

@@ -676,6 +682,76 @@ def sizes(self) -> Mapping[Hashable, int]:
             self._sizes = self._obj.isel({self._group_dim: index}).sizes
         return self._sizes

+    def shuffle_to_chunks(self, chunks: T_Chunks = None) -> T_Xarray:
+        """
+        Sort or "shuffle" the underlying object.
+
+        "Shuffle" means the object is sorted so that all group members occur sequentially,
+        in the same chunk. Multiple groups may occur in the same chunk.
+        This method is particularly useful for chunked arrays (e.g. dask, cubed),
+        when you need to map a function that requires all members of a group
+        to be present in a single chunk. For chunked array types, the order of appearance
+        is not guaranteed, but will depend on the input chunking.
+
+        Parameters
+        ----------
+        chunks : int, tuple of int, "auto" or mapping of hashable to int or tuple of int, optional
+            How to adjust chunks along dimensions not present in the array being grouped by.
+
+        Returns
+        -------
+        DataArray or Dataset
+
+        Examples
+        --------
+        >>> import dask.array
+        >>> da = xr.DataArray(
+        ...     dims="x",
+        ...     data=dask.array.arange(10, chunks=3),
+        ...     coords={"x": [1, 2, 3, 1, 2, 3, 1, 2, 3, 0]},
+        ...     name="a",
+        ... )
+        >>> shuffled = da.groupby("x").shuffle_to_chunks()
+        >>> shuffled
+        <xarray.DataArray 'a' (x: 10)> Size: 80B
+        dask.array<shuffle, shape=(10,), dtype=int64, chunksize=(3,), chunktype=numpy.ndarray>
+        Coordinates:
+          * x        (x) int64 80B 0 1 1 1 2 2 2 3 3 3
+
+        >>> shuffled.groupby("x").quantile(q=0.5).compute()
+        <xarray.DataArray 'a' (x: 4)> Size: 32B
+        array([9., 3., 4., 5.])
+        Coordinates:
+            quantile  float64 8B 0.5
+          * x        (x) int64 32B 0 1 2 3
+
+        See Also
+        --------
+        dask.dataframe.DataFrame.shuffle
+        dask.array.shuffle
+        """
+        self._raise_if_by_is_chunked()
+        return self._shuffle_obj(chunks)
+
+    def _shuffle_obj(self, chunks: T_Chunks) -> T_Xarray:
+        from xarray.core.dataarray import DataArray
+
+        was_array = isinstance(self._obj, DataArray)
+        as_dataset = self._obj._to_temp_dataset() if was_array else self._obj
+
+        for grouper in self.groupers:
+            if grouper.name not in as_dataset._variables:
+                as_dataset.coords[grouper.name] = grouper.group
+
+        shuffled = as_dataset._shuffle(
+            dim=self._group_dim, indices=self.encoded.group_indices, chunks=chunks
+        )
+        unstacked: Dataset = self._maybe_unstack(shuffled)
+        if was_array:
+            return self._obj._from_temp_dataset(unstacked)
+        else:
+            return unstacked  # type: ignore[return-value]
+
     def map(
         self,
         func: Callable,
@@ -896,7 +972,9 @@ def _maybe_unstack(self, obj):
         # and `inserted_dims`
         # if multiple groupers all share the same single dimension, then
         # we don't stack/unstack. Do that manually now.
-        obj = obj.unstack(*self.encoded.unique_coord.dims)
+        dims_to_unstack = self.encoded.unique_coord.dims
+        if all(dim in obj.dims for dim in dims_to_unstack):
+            obj = obj.unstack(*dims_to_unstack)
         to_drop = [
             grouper.name
             for grouper in self.groupers
diff --git a/xarray/core/resample.py b/xarray/core/resample.py
index 9dd91d86a47..5cc98c9651c 100644
--- a/xarray/core/resample.py
+++ b/xarray/core/resample.py
@@ -14,6 +14,7 @@
 if TYPE_CHECKING:
     from xarray.core.dataarray import DataArray
     from xarray.core.dataset import Dataset
+    from xarray.core.types import T_Chunks

 from xarray.groupers import RESAMPLE_DIM

@@ -58,6 +59,50 @@ def _flox_reduce(
             result = result.rename({RESAMPLE_DIM: self._group_dim})
         return result

+    def shuffle_to_chunks(self, chunks: T_Chunks = None):
+        """
+        Sort or "shuffle" the underlying object.
+
+        "Shuffle" means the object is sorted so that all group members occur sequentially,
+        in the same chunk. Multiple groups may occur in the same chunk.
+        This method is particularly useful for chunked arrays (e.g. dask, cubed),
+        when you need to map a function that requires all members of a group
+        to be present in a single chunk. For chunked array types, the order of appearance
+        is not guaranteed, but will depend on the input chunking.
+
+        Parameters
+        ----------
+        chunks : int, tuple of int, "auto" or mapping of hashable to int or tuple of int, optional
+            How to adjust chunks along dimensions not present in the array being grouped by.
+
+        Returns
+        -------
+        DataArray or Dataset
+
+        Examples
+        --------
+        >>> import dask.array
+        >>> da = xr.DataArray(
+        ...     dims="time",
+        ...     data=dask.array.arange(10, chunks=1),
+        ...     coords={"time": xr.date_range("2001-01-01", freq="12h", periods=10)},
+        ...     name="a",
+        ... )
+        >>> shuffled = da.resample(time="2D").shuffle_to_chunks()
+        >>> shuffled
+        <xarray.DataArray 'a' (time: 10)> Size: 80B
+        dask.array<shuffle, shape=(10,), dtype=int64, chunksize=(4,), chunktype=numpy.ndarray>
+        Coordinates:
+          * time     (time) datetime64[ns] 80B 2001-01-01 ... 2001-01-05T12:00:00
+
+        See Also
+        --------
+        dask.dataframe.DataFrame.shuffle
+        dask.array.shuffle
+        """
+        (grouper,) = self.groupers
+        return self._shuffle_obj(chunks).drop_vars(RESAMPLE_DIM)
+
     def _drop_coords(self) -> T_Xarray:
         """Drop non-dimension coordinates along the resampled dimension."""
         obj = self._obj
diff --git a/xarray/core/types.py b/xarray/core/types.py
index 56d45ddfed6..11b26da033f 100644
--- a/xarray/core/types.py
+++ b/xarray/core/types.py
@@ -362,7 +362,7 @@ def read(self, __n: int = ...) -> AnyStr_co: ...
 ZarrWriteModes = Literal["w", "w-", "a", "a-", "r+", "r"]

 GroupKey = Any
-GroupIndex = Union[int, slice, list[int]]
+GroupIndex = Union[slice, list[int]]
 GroupIndices = tuple[GroupIndex, ...]
 Bins = Union[
     int, Sequence[int], Sequence[float], Sequence[pd.Timestamp], np.ndarray, pd.Index
diff --git a/xarray/core/variable.py b/xarray/core/variable.py
index 3b41b7867d0..9f660d0878a 100644
--- a/xarray/core/variable.py
+++ b/xarray/core/variable.py
@@ -45,7 +45,13 @@
     maybe_coerce_to_str,
 )
 from xarray.namedarray.core import NamedArray, _raise_if_any_duplicate_dimensions
-from xarray.namedarray.pycompat import integer_types, is_0d_dask_array, to_duck_array
+from xarray.namedarray.parallelcompat import get_chunked_array_type
+from xarray.namedarray.pycompat import (
+    integer_types,
+    is_0d_dask_array,
+    is_chunked_array,
+    to_duck_array,
+)
 from xarray.namedarray.utils import module_available
 from xarray.util.deprecation_helpers import _deprecate_positional_args, deprecate_dims

@@ -1019,6 +1025,24 @@ def compute(self, **kwargs):
         new = self.copy(deep=False)
         return new.load(**kwargs)

+    def _shuffle(
+        self, indices: list[list[int]], dim: Hashable, chunks: T_Chunks
+    ) -> Self:
+        # TODO (dcherian): consider making this public API
+        array = self._data
+        if is_chunked_array(array):
+            chunkmanager = get_chunked_array_type(array)
+            return self._replace(
+                data=chunkmanager.shuffle(
+                    array,
+                    indexer=indices,
+                    axis=self.get_axis_num(dim),
+                    chunks=chunks,
+                )
+            )
+        else:
+            return self.isel({dim: np.concatenate(indices)})
+
     def isel(
         self,
         indexers: Mapping[Any, Any] | None = None,
diff --git a/xarray/groupers.py b/xarray/groupers.py
index 89b189e582e..dac4c4309de 100644
--- a/xarray/groupers.py
+++ b/xarray/groupers.py
@@ -30,6 +30,7 @@
     DatetimeLike,
     GroupIndices,
     ResampleCompatible,
+    Self,
     SideOptions,
 )
 from xarray.core.variable import Variable
@@ -139,6 +140,13 @@ def factorize(self, group: T_Group) -> EncodedGroups:
         """
         pass

+    @abstractmethod
+    def reset(self) -> Self:
+        """
+        Create a new version of this Grouper, clearing any caches.
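+
+        This is useful when a Grouper must be reused on a new object, since
+        ``factorize`` may cache state on the Grouper instance.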
+        """
+        pass
+

 class Resampler(Grouper):
     """
@@ -177,6 +185,9 @@ def group_as_index(self) -> pd.Index:
             self._group_as_index = pd.Index(np.array(self.group).ravel())
         return self._group_as_index

+    def reset(self) -> Self:
+        return type(self)()
+
     def factorize(self, group: T_Group) -> EncodedGroups:
         self.group = group

@@ -325,6 +336,16 @@ class BinGrouper(Grouper):
     include_lowest: bool = False
     duplicates: Literal["raise", "drop"] = "raise"

+    def reset(self) -> Self:
+        return type(self)(
+            bins=self.bins,
+            right=self.right,
+            labels=self.labels,
+            precision=self.precision,
+            include_lowest=self.include_lowest,
+            duplicates=self.duplicates,
+        )
+
     def __post_init__(self) -> None:
         if duck_array_ops.isnull(self.bins).all():
             raise ValueError("All bin edges are NaN.")
@@ -427,6 +448,15 @@ class TimeResampler(Resampler):
     index_grouper: CFTimeGrouper | pd.Grouper = field(init=False, repr=False)
     group_as_index: pd.Index = field(init=False, repr=False)

+    def reset(self) -> Self:
+        return type(self)(
+            freq=self.freq,
+            closed=self.closed,
+            label=self.label,
+            origin=self.origin,
+            offset=self.offset,
+        )
+
     def _init_properties(self, group: T_Group) -> None:
         from xarray import CFTimeIndex

diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py
index 90c442d2e1f..95e7d7adfc3 100644
--- a/xarray/namedarray/_typing.py
+++ b/xarray/namedarray/_typing.py
@@ -78,7 +78,8 @@ def dtype(self) -> _DType_co: ...
 _Chunks = tuple[_Shape, ...]
 _NormalizedChunks = tuple[tuple[int, ...], ...]
 # FYI in some cases we don't allow `None`, which this doesn't take account of.
-T_ChunkDim: TypeAlias = int | Literal["auto"] | None | tuple[int, ...]
+# FYI the `str` is for a size string, e.g. "16MB", supported by dask.
+T_ChunkDim: TypeAlias = str | int | Literal["auto"] | None | tuple[int, ...]
 # We allow the tuple form of this (though arguably we could transition to named dims only)
 T_Chunks: TypeAlias = T_ChunkDim | Mapping[Any, T_ChunkDim]
diff --git a/xarray/namedarray/daskmanager.py b/xarray/namedarray/daskmanager.py
index 9725e341083..6485ba375f5 100644
--- a/xarray/namedarray/daskmanager.py
+++ b/xarray/namedarray/daskmanager.py
@@ -249,3 +249,18 @@ def store(
             targets=targets,
             **kwargs,
         )
+
+    def shuffle(
+        self, x: DaskArray, indexer: list[list[int]], axis: int, chunks: T_Chunks
+    ) -> DaskArray:
+        import dask.array
+
+        if not module_available("dask", minversion="2024.08.1"):
+            raise ValueError(
+                "This method is very inefficient on dask<2024.08.1. Please upgrade."
+ ) + if chunks is None: + chunks = "auto" + if chunks != "auto": + raise NotImplementedError("Only chunks='auto' is supported at present.") + return dask.array.shuffle(x, indexer, axis, chunks="auto") diff --git a/xarray/namedarray/parallelcompat.py b/xarray/namedarray/parallelcompat.py index 6ecc9ae9a2b..69dd4ab5f93 100644 --- a/xarray/namedarray/parallelcompat.py +++ b/xarray/namedarray/parallelcompat.py @@ -20,6 +20,7 @@ if TYPE_CHECKING: from xarray.namedarray._typing import ( + T_Chunks, _Chunks, _DType, _DType_co, @@ -356,6 +357,11 @@ def compute( """ raise NotImplementedError() + def shuffle( + self, x: T_ChunkedArray, indexer: list[list[int]], axis: int, chunks: T_Chunks + ) -> T_ChunkedArray: + raise NotImplementedError() + def persist( self, *data: T_ChunkedArray | Any, **kwargs: Any ) -> tuple[T_ChunkedArray | Any, ...]: diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index f086e3ca90e..a917f5deb58 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -107,6 +107,9 @@ def _importorskip( has_h5netcdf, requires_h5netcdf = _importorskip("h5netcdf") has_cftime, requires_cftime = _importorskip("cftime") has_dask, requires_dask = _importorskip("dask") +has_dask_ge_2024_08_1, requires_dask_ge_2024_08_1 = _importorskip( + "dask", minversion="2024.08.1" +) has_dask_ge_2024_11_0, requires_dask_ge_2024_11_0 = _importorskip("dask", "2024.11.0") with warnings.catch_warnings(): warnings.filterwarnings( @@ -152,7 +155,7 @@ def _importorskip( not has_numbagg_or_bottleneck, reason="requires numbagg or bottleneck" ) has_numpy_2, requires_numpy_2 = _importorskip("numpy", "2.0.0") -_, requires_flox_0_9_12 = _importorskip("flox", "0.9.12") +has_flox_0_9_12, requires_flox_0_9_12 = _importorskip("flox", "0.9.12") has_array_api_strict, requires_array_api_strict = _importorskip("array_api_strict") diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 3fc7fcac132..2d5d6c4c16c 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -4,6 +4,7 @@ import operator import warnings from itertools import pairwise +from typing import Literal from unittest import mock import numpy as np @@ -32,11 +33,13 @@ create_test_data, has_cftime, has_dask, + has_dask_ge_2024_08_1, has_flox, has_pandas_ge_2_2, raise_if_dask_computes, requires_cftime, requires_dask, + requires_dask_ge_2024_08_1, requires_flox, requires_flox_0_9_12, requires_pandas_ge_2_2, @@ -632,10 +635,25 @@ def test_groupby_repr_datetime(obj) -> None: assert actual == expected +@pytest.mark.filterwarnings("ignore:No index created for dimension id:UserWarning") @pytest.mark.filterwarnings("ignore:Converting non-nanosecond") @pytest.mark.filterwarnings("ignore:invalid value encountered in divide:RuntimeWarning") -@pytest.mark.filterwarnings("ignore:No index created for dimension id:UserWarning") -def test_groupby_drops_nans() -> None: +@pytest.mark.parametrize("shuffle", [True, False]) +@pytest.mark.parametrize( + "chunk", + [ + pytest.param( + dict(lat=1), marks=pytest.mark.skipif(not has_dask, reason="no dask") + ), + pytest.param( + dict(lat=2, lon=2), marks=pytest.mark.skipif(not has_dask, reason="no dask") + ), + False, + ], +) +def test_groupby_drops_nans(shuffle: bool, chunk: Literal[False] | dict) -> None: + if shuffle and chunk and not has_dask_ge_2024_08_1: + pytest.skip() # GH2383 # nan in 2D data variable (requires stacking) ds = xr.Dataset( @@ -650,13 +668,17 @@ def test_groupby_drops_nans() -> None: ds["id"].values[3, 0] = np.nan ds["id"].values[-1, -1] = np.nan 
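+    # chunking the data (with or without a preceding shuffle_to_chunks) exercises
+    # the chunked-array code paths for groupby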
+ if chunk: + ds["variable"] = ds["variable"].chunk(chunk) grouped = ds.groupby(ds.id) + if shuffle: + grouped = grouped.shuffle_to_chunks().groupby(ds.id) # non reduction operation expected1 = ds.copy() - expected1.variable.values[0, 0, :] = np.nan - expected1.variable.values[-1, -1, :] = np.nan - expected1.variable.values[3, 0, :] = np.nan + expected1.variable.data[0, 0, :] = np.nan + expected1.variable.data[-1, -1, :] = np.nan + expected1.variable.data[3, 0, :] = np.nan actual1 = grouped.map(lambda x: x).transpose(*ds.variable.dims) assert_identical(actual1, expected1) @@ -1355,11 +1377,27 @@ def test_groupby_sum(self) -> None: assert_allclose(expected_sum_axis1, grouped.reduce(np.sum, "y")) assert_allclose(expected_sum_axis1, grouped.sum("y")) + @pytest.mark.parametrize("use_flox", [True, False]) + @pytest.mark.parametrize("shuffle", [True, False]) + @pytest.mark.parametrize( + "chunk", + [ + pytest.param( + True, marks=pytest.mark.skipif(not has_dask, reason="no dask") + ), + False, + ], + ) @pytest.mark.parametrize("method", ["sum", "mean", "median"]) - def test_groupby_reductions(self, method) -> None: - array = self.da - grouped = array.groupby("abc") + def test_groupby_reductions( + self, use_flox: bool, method: str, shuffle: bool, chunk: bool + ) -> None: + if shuffle and chunk and not has_dask_ge_2024_08_1: + pytest.skip() + array = self.da + if chunk: + array.data = array.chunk({"y": 5}).data reduction = getattr(np, method) expected = Dataset( { @@ -1377,14 +1415,14 @@ def test_groupby_reductions(self, method) -> None: } )["foo"] - with xr.set_options(use_flox=False): - actual_legacy = getattr(grouped, method)(dim="y") - - with xr.set_options(use_flox=True): - actual_npg = getattr(grouped, method)(dim="y") + with raise_if_dask_computes(): + grouped = array.groupby("abc") + if shuffle: + grouped = grouped.shuffle_to_chunks().groupby("abc") - assert_allclose(expected, actual_legacy) - assert_allclose(expected, actual_npg) + with xr.set_options(use_flox=use_flox): + actual = getattr(grouped, method)(dim="y") + assert_allclose(expected, actual) def test_groupby_count(self) -> None: array = DataArray( @@ -1648,13 +1686,17 @@ def test_groupby_bins( ) with xr.set_options(use_flox=use_flox): - actual = array.groupby_bins("dim_0", bins=bins, **cut_kwargs).sum() + gb = array.groupby_bins("dim_0", bins=bins, **cut_kwargs) + shuffled = gb.shuffle_to_chunks().groupby_bins( + "dim_0", bins=bins, **cut_kwargs + ) + actual = gb.sum() assert_identical(expected, actual) + assert_identical(expected, shuffled.sum()) - actual = array.groupby_bins("dim_0", bins=bins, **cut_kwargs).map( - lambda x: x.sum() - ) + actual = gb.map(lambda x: x.sum()) assert_identical(expected, actual) + assert_identical(expected, shuffled.map(lambda x: x.sum())) # make sure original array dims are unchanged assert len(array.dim_0) == 4 @@ -1799,6 +1841,7 @@ def test_groupby_fastpath_for_monotonic(self, use_flox: bool) -> None: class TestDataArrayResample: + @pytest.mark.parametrize("shuffle", [True, False]) @pytest.mark.parametrize("use_cftime", [True, False]) @pytest.mark.parametrize( "resample_freq", @@ -1813,7 +1856,7 @@ class TestDataArrayResample: ], ) def test_resample( - self, use_cftime: bool, resample_freq: ResampleCompatible + self, use_cftime: bool, shuffle: bool, resample_freq: ResampleCompatible ) -> None: if use_cftime and not has_cftime: pytest.skip() @@ -1836,16 +1879,22 @@ def resample_as_pandas(array, *args, **kwargs): array = DataArray(np.arange(10), [("time", times)]) - actual = 
array.resample(time=resample_freq).mean() + rs = array.resample(time=resample_freq) + shuffled = rs.shuffle_to_chunks().resample(time=resample_freq) + actual = rs.mean() expected = resample_as_pandas(array, resample_freq) assert_identical(expected, actual) + assert_identical(expected, shuffled.mean()) - actual = array.resample(time=resample_freq).reduce(np.mean) - assert_identical(expected, actual) + assert_identical(expected, rs.reduce(np.mean)) + assert_identical(expected, shuffled.reduce(np.mean)) - actual = array.resample(time=resample_freq, closed="right").mean() - expected = resample_as_pandas(array, resample_freq, closed="right") + rs = array.resample(time="24h", closed="right") + actual = rs.mean() + shuffled = rs.shuffle_to_chunks().resample(time="24h", closed="right") + expected = resample_as_pandas(array, "24h", closed="right") assert_identical(expected, actual) + assert_identical(expected, shuffled.mean()) with pytest.raises(ValueError, match=r"Index must be monotonic"): array[[2, 0, 1]].resample(time=resample_freq) @@ -2673,6 +2722,9 @@ def factorize(self, group) -> EncodedGroups: codes = group.copy(data=codes_).rename("year") return EncodedGroups(codes=codes, full_index=pd.Index(uniques)) + def reset(self): + return type(self)() + da = xr.DataArray( dims="time", data=np.arange(20), @@ -2769,8 +2821,9 @@ def test_multiple_groupers_string(as_dataset) -> None: obj.groupby("labels1", foo=UniqueGrouper()) +@pytest.mark.parametrize("shuffle", [True, False]) @pytest.mark.parametrize("use_flox", [True, False]) -def test_multiple_groupers(use_flox) -> None: +def test_multiple_groupers(use_flox: bool, shuffle: bool) -> None: da = DataArray( np.array([1, 2, 3, 0, 2, np.nan]), dims="d", @@ -2781,7 +2834,11 @@ def test_multiple_groupers(use_flox) -> None: name="foo", ) - gb = da.groupby(labels1=UniqueGrouper(), labels2=UniqueGrouper()) + groupers: dict[str, Grouper] + groupers = dict(labels1=UniqueGrouper(), labels2=UniqueGrouper()) + gb = da.groupby(groupers) + if shuffle: + gb = gb.shuffle_to_chunks().groupby(groupers) repr(gb) expected = DataArray( @@ -2800,7 +2857,10 @@ def test_multiple_groupers(use_flox) -> None: # ------- coords = {"a": ("x", [0, 0, 1, 1]), "b": ("y", [0, 0, 1, 1])} square = DataArray(np.arange(16).reshape(4, 4), coords=coords, dims=["x", "y"]) - gb = square.groupby(a=UniqueGrouper(), b=UniqueGrouper()) + groupers = dict(a=UniqueGrouper(), b=UniqueGrouper()) + gb = square.groupby(groupers) + if shuffle: + gb = gb.shuffle_to_chunks().groupby(groupers) repr(gb) with xr.set_options(use_flox=use_flox): actual = gb.mean() @@ -2823,12 +2883,18 @@ def test_multiple_groupers(use_flox) -> None: coords={"xy": (("x", "y"), [["a", "b", "c"], ["b", "c", "c"]], {"foo": "bar"})}, dims=["x", "y", "z"], ) - gb = b.groupby(x=UniqueGrouper(), y=UniqueGrouper()) + groupers = dict(x=UniqueGrouper(), y=UniqueGrouper()) + gb = b.groupby(groupers) + if shuffle: + gb = gb.shuffle_to_chunks().groupby(groupers) repr(gb) with xr.set_options(use_flox=use_flox): assert_identical(gb.mean("z"), b.mean("z")) - gb = b.groupby(x=UniqueGrouper(), xy=UniqueGrouper()) + groupers = dict(x=UniqueGrouper(), xy=UniqueGrouper()) + gb = b.groupby(groupers) + if shuffle: + gb = gb.shuffle_to_chunks().groupby(groupers) repr(gb) with xr.set_options(use_flox=use_flox): actual = gb.mean() @@ -2873,13 +2939,19 @@ def test_multiple_groupers(use_flox) -> None: @pytest.mark.parametrize("use_flox", [True, False]) -def test_multiple_groupers_mixed(use_flox) -> None: +@pytest.mark.parametrize("shuffle", [True, False]) 
+def test_multiple_groupers_mixed(use_flox: bool, shuffle: bool) -> None: # This groupby has missing groups ds = xr.Dataset( {"foo": (("x", "y"), np.arange(12).reshape((4, 3)))}, coords={"x": [10, 20, 30, 40], "letters": ("x", list("abba"))}, ) - gb = ds.groupby(x=BinGrouper(bins=[5, 15, 25]), letters=UniqueGrouper()) + groupers: dict[str, Grouper] = dict( + x=BinGrouper(bins=[5, 15, 25]), letters=UniqueGrouper() + ) + gb = ds.groupby(groupers) + if shuffle: + gb = gb.shuffle_to_chunks().groupby(groupers) expected_data = np.array( [ [[0.0, np.nan], [np.nan, 3.0]], @@ -3107,6 +3179,50 @@ def test_groupby_multiple_bin_grouper_missing_groups(): assert_identical(actual, expected) +@requires_dask_ge_2024_08_1 +def test_shuffle_simple() -> None: + import dask + + da = xr.DataArray( + dims="x", + data=dask.array.from_array([1, 2, 3, 4, 5, 6], chunks=2), + coords={"label": ("x", "a b c a b c".split(" "))}, + ) + actual = da.groupby(label=UniqueGrouper()).shuffle_to_chunks() + expected = da.isel(x=[0, 3, 1, 4, 2, 5]) + assert_identical(actual, expected) + + with pytest.raises(ValueError): + da.chunk(x=2, eagerly_load_group=False).groupby("label").shuffle_to_chunks() + + +@requires_dask_ge_2024_08_1 +@pytest.mark.parametrize( + "chunks, expected_chunks", + [ + ((1,), (1, 3, 3, 3)), + ((10,), (10,)), + ], +) +def test_shuffle_by(chunks, expected_chunks): + import dask.array + + from xarray.groupers import UniqueGrouper + + da = xr.DataArray( + dims="x", + data=dask.array.arange(10, chunks=chunks), + coords={"x": [1, 2, 3, 1, 2, 3, 1, 2, 3, 0]}, + name="a", + ) + ds = da.to_dataset() + + for obj in [ds, da]: + actual = obj.groupby(x=UniqueGrouper()).shuffle_to_chunks() + assert_identical(actual, obj.sortby("x")) + assert actual.chunksizes["x"] == expected_chunks + + @requires_dask def test_groupby_dask_eager_load_warnings(): ds = xr.Dataset(