Skip to content

Commit

Permalink
Add GroupBy.shuffle()
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Aug 7, 2024
1 parent e2981d3 commit 3bc51bd
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 1 deletion.
15 changes: 15 additions & 0 deletions xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,3 +831,18 @@ def chunked_nanfirst(darray, axis):

def chunked_nanlast(darray, axis):
    """Chunk-aware "last" reduction: delegates to the shared first/last
    helper with the ``nanlast`` op from ``nputils``."""
    last_op = nputils.nanlast
    return _chunked_first_or_last(darray, axis, op=last_op)


def shuffle_array(array, indices: list[list[int]], axis: int):
    """Gather ``array`` along ``axis`` so each group's members are contiguous.

    ``indices`` holds one sequence of integer positions per group; the
    result is the concatenation of those gathers, in group order.
    """
    # TODO: do chunk manager dance here.
    if not is_duck_dask_array(array):
        # Non-dask path: flatten the per-group indices and do a single take.
        # TODO: Do the array API thing here.
        flat_indexer = np.concatenate(indices)
        return np.take(array, indices=flat_indexer, axis=axis)

    # Dask path: rely on dask's native shuffle, available from 2024.08.0 on.
    if not module_available("dask", minversion="2024.08.0"):
        raise ValueError(
            "This method is very inefficient on dask<2024.08.0. Please upgrade."
        )
    # TODO: handle dimensions
    return array.shuffle(indexer=indices, axis=axis)
48 changes: 48 additions & 0 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,54 @@ def sizes(self) -> Mapping[Hashable, int]:
self._sizes = self._obj.isel({self._group_dim: index}).sizes
return self._sizes

def shuffle(self) -> None:
    """
    Shuffle the underlying object so that all members in a group occur sequentially.
    The order of appearance is not guaranteed. This method modifies the underlying Xarray
    object in place.
    Use this method first if you need to map a function that requires all members of a group
    be in a single chunk.
    """
    from xarray.core.dataarray import DataArray
    from xarray.core.dataset import Dataset
    from xarray.core.duck_array_ops import shuffle_array

    # Unpacking asserts there is exactly one grouper.
    (grouper,) = self.groupers
    dim = self._group_dim

    # Slices mean this is already sorted. E.g. resampling ops, _DummyGroup
    if all(isinstance(idx, slice) for idx in self._group_indices):
        return

    is_dataarray = isinstance(self._obj, DataArray)
    dataset = self._obj._to_temp_dataset() if is_dataarray else self._obj

    # Rebuild every variable, gathering along the group dimension so each
    # group becomes a contiguous run; variables without that dim pass through.
    shuffled = Dataset()
    for name, variable in dataset._variables.items():
        if dim not in variable.dims:
            shuffled[name] = variable
        else:
            gathered = shuffle_array(
                variable._data,
                list(self._group_indices),
                axis=variable.get_axis_num(dim),
            )
            shuffled[name] = variable._replace(data=gathered)

    # After shuffling, every group occupies a contiguous run, so replace the
    # integer indexers with equivalent slices (cumulative group lengths).
    offsets = [0]
    for group_idx in self._group_indices:
        offsets.append(offsets[-1] + len(group_idx))
    # TODO: we have now broken the invariant
    # self._group_indices ≠ self.groupers[0].group_indices
    self._group_indices = tuple(
        slice(lo, hi) for lo, hi in zip(offsets[:-1], offsets[1:])
    )

    if is_dataarray:
        self._obj = self._obj._from_temp_dataset(shuffled)
    else:
        self._obj = shuffled

def map(
self,
func: Callable,
Expand Down
1 change: 1 addition & 0 deletions xarray/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def _importorskip(
has_h5netcdf, requires_h5netcdf = _importorskip("h5netcdf")
has_cftime, requires_cftime = _importorskip("cftime")
has_dask, requires_dask = _importorskip("dask")
# Version flag only (the skip marker is discarded): used to gate tests that
# exercise dask's shuffle support, introduced in dask 2024.08.0.
has_dask_ge_2024_08_0, _ = _importorskip("dask", minversion="2024.08.0")
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
Expand Down
18 changes: 17 additions & 1 deletion xarray/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
assert_identical,
create_test_data,
has_cftime,
has_dask_ge_2024_08_0,
has_flox,
requires_cftime,
requires_dask,
Expand Down Expand Up @@ -1293,11 +1294,26 @@ def test_groupby_sum(self) -> None:
assert_allclose(expected_sum_axis1, grouped.reduce(np.sum, "y"))
assert_allclose(expected_sum_axis1, grouped.sum("y"))

@pytest.mark.parametrize(
"shuffle",
[
pytest.param(
True,
marks=pytest.mark.skipif(
not has_dask_ge_2024_08_0, reason="dask too old"
),
),
False,
],
)
@pytest.mark.parametrize("method", ["sum", "mean", "median"])
def test_groupby_reductions(self, method) -> None:
def test_groupby_reductions(self, method: str, shuffle: bool) -> None:
array = self.da
grouped = array.groupby("abc")

if shuffle:
grouped.shuffle()

reduction = getattr(np, method)
expected = Dataset(
{
Expand Down

0 comments on commit 3bc51bd

Please sign in to comment.