diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py
index 8993c136ba6..25bd86177df 100644
--- a/xarray/core/duck_array_ops.py
+++ b/xarray/core/duck_array_ops.py
@@ -831,3 +831,18 @@ def chunked_nanfirst(darray, axis):
 
 def chunked_nanlast(darray, axis):
     return _chunked_first_or_last(darray, axis, op=nputils.nanlast)
+
+
+def shuffle_array(array, indices: list[list[int]], axis: int):
+    # TODO: do chunk manager dance here.
+    if is_duck_dask_array(array):
+        if not module_available("dask", minversion="2024.08.0"):
+            raise ValueError(
+                "This method is very inefficient on dask<2024.08.0. Please upgrade."
+            )
+        # TODO: handle dimensions
+        return array.shuffle(indexer=indices, axis=axis)
+    else:
+        indexer = np.concatenate(indices)
+        # TODO: Do the array API thing here.
+        return np.take(array, indices=indexer, axis=axis)
diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py
index 9b0758d030b..9fbf6778aea 100644
--- a/xarray/core/groupby.py
+++ b/xarray/core/groupby.py
@@ -517,6 +517,54 @@ def sizes(self) -> Mapping[Hashable, int]:
         self._sizes = self._obj.isel({self._group_dim: index}).sizes
         return self._sizes
 
+    def shuffle(self) -> None:
+        """
+        Shuffle the underlying object so that all members in a group occur sequentially.
+
+        The order of appearance is not guaranteed. This method modifies the underlying Xarray
+        object in place.
+
+        Use this method first if you need to map a function that requires all members of a group
+        be in a single chunk.
+        """
+        from xarray.core.dataarray import DataArray
+        from xarray.core.dataset import Dataset
+        from xarray.core.duck_array_ops import shuffle_array
+
+        (grouper,) = self.groupers
+        dim = self._group_dim
+
+        # Slices mean this is already sorted. E.g. resampling ops, _DummyGroup
+        if all(isinstance(idx, slice) for idx in self._group_indices):
+            return
+
+        was_array = isinstance(self._obj, DataArray)
+        as_dataset = self._obj._to_temp_dataset() if was_array else self._obj
+
+        shuffled = Dataset()
+        for name, var in as_dataset._variables.items():
+            if dim not in var.dims:
+                shuffled[name] = var
+                continue
+            shuffled_data = shuffle_array(
+                var._data, list(self._group_indices), axis=var.get_axis_num(dim)
+            )
+            shuffled[name] = var._replace(data=shuffled_data)
+
+        # Replace self._group_indices with slices
+        slices = []
+        start = 0
+        for idxr in self._group_indices:
+            slices.append(slice(start, start + len(idxr)))
+            start += len(idxr)
+        # TODO: we have now broken the invariant
+        # self._group_indices ≠ self.groupers[0].group_indices
+        self._group_indices = tuple(slices)
+        if was_array:
+            self._obj = self._obj._from_temp_dataset(shuffled)
+        else:
+            self._obj = shuffled
+
     def map(
         self,
         func: Callable,
diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py
index 0caab6e8247..31d8e88dde1 100644
--- a/xarray/tests/__init__.py
+++ b/xarray/tests/__init__.py
@@ -106,6 +106,7 @@ def _importorskip(
 has_h5netcdf, requires_h5netcdf = _importorskip("h5netcdf")
 has_cftime, requires_cftime = _importorskip("cftime")
 has_dask, requires_dask = _importorskip("dask")
+has_dask_ge_2024_08_0, _ = _importorskip("dask", minversion="2024.08.0")
 with warnings.catch_warnings():
     warnings.filterwarnings(
         "ignore",
diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py
index 6c9254966d9..c41086cdf97 100644
--- a/xarray/tests/test_groupby.py
+++ b/xarray/tests/test_groupby.py
@@ -21,6 +21,7 @@
     assert_identical,
     create_test_data,
     has_cftime,
+    has_dask_ge_2024_08_0,
     has_flox,
     requires_cftime,
     requires_dask,
@@ -1293,11 +1294,26 @@ def test_groupby_sum(self) -> None:
         assert_allclose(expected_sum_axis1, grouped.reduce(np.sum, "y"))
         assert_allclose(expected_sum_axis1, grouped.sum("y"))
 
+    @pytest.mark.parametrize(
+        "shuffle",
+        [
+            pytest.param(
+                True,
+                marks=pytest.mark.skipif(
+                    not has_dask_ge_2024_08_0, reason="dask too old"
+                ),
+            ),
+            False,
+        ],
+    )
     @pytest.mark.parametrize("method", ["sum", "mean", "median"])
-    def test_groupby_reductions(self, method) -> None:
+    def test_groupby_reductions(self, method: str, shuffle: bool) -> None:
         array = self.da
         grouped = array.groupby("abc")
 
+        if shuffle:
+            grouped.shuffle()
+
         reduction = getattr(np, method)
         expected = Dataset(
             {