Skip to content
forked from pydata/xarray

Commit

Permalink
Handle multiple groupers
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Sep 20, 2024
1 parent dfdc96a commit a15b04d
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 13 deletions.
27 changes: 17 additions & 10 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from xarray.core.common import ImplementsArrayReduce, ImplementsDatasetReduce
from xarray.core.concat import concat
from xarray.core.coordinates import Coordinates
from xarray.core.duck_array_ops import where
from xarray.core.formatting import format_array_flat
from xarray.core.indexes import (
PandasIndex,
Expand Down Expand Up @@ -462,20 +463,26 @@ def factorize(self) -> EncodedGroups:
# NaNs; as well as values outside the bins are coded by -1
# Restore these after the raveling
mask = functools.reduce(np.logical_or, [(code == -1) for code in broadcasted_codes]) # type: ignore[arg-type]
_flatcodes[mask] = -1

midx = pd.MultiIndex.from_product(
(grouper.unique_coord.data for grouper in groupers),
names=tuple(grouper.name for grouper in groupers),
)
# Constructing an index from the product is wrong when there are missing groups
# (e.g. binning, resampling). Account for that now.
midx = midx[np.sort(pd.unique(_flatcodes[~mask]))]
_flatcodes = where(mask, -1, _flatcodes)

full_index = pd.MultiIndex.from_product(
(grouper.full_index.values for grouper in groupers),
names=tuple(grouper.name for grouper in groupers),
)
# This will be unused when grouping by dask arrays, so skip..
if not is_chunked_array(_flatcodes):
midx = pd.MultiIndex.from_product(
(grouper.unique_coord.data for grouper in groupers),
names=tuple(grouper.name for grouper in groupers),
)
# Constructing an index from the product is wrong when there are missing groups
# (e.g. binning, resampling). Account for that now.
midx = midx[np.sort(pd.unique(_flatcodes[~mask]))]
group_indices = _codes_to_group_indices(_flatcodes.ravel(), len(full_index))
else:
midx = full_index
group_indices = None

dim_name = "stacked_" + "_".join(str(grouper.name) for grouper in groupers)

coords = Coordinates.from_pandas_multiindex(midx, dim=dim_name)
Expand All @@ -484,7 +491,7 @@ def factorize(self) -> EncodedGroups:
return EncodedGroups(
codes=first_codes.copy(data=_flatcodes),
full_index=full_index,
group_indices=_codes_to_group_indices(_flatcodes.ravel(), len(full_index)),
group_indices=group_indices,
unique_coord=Variable(dims=(dim_name,), data=midx.values),
coords=coords,
)
Expand Down
5 changes: 4 additions & 1 deletion xarray/groupers.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,10 @@ def factorize(self, group: T_Group) -> EncodedGroups:
self.group = group

if is_chunked_array(group.data) and self.labels is None:
raise ValueError("When grouping by a dask array, `labels` must be passed.")
raise ValueError(
"When grouping by a dask array, `labels` must be passed using "
"a UniqueGrouper object."
)
if self.labels is not None:
return self._factorize_given_labels(group)

Expand Down
20 changes: 18 additions & 2 deletions xarray/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,15 @@
TimeResampler,
UniqueGrouper,
)
from xarray.namedarray.pycompat import is_chunked_array
from xarray.tests import (
InaccessibleArray,
assert_allclose,
assert_equal,
assert_identical,
create_test_data,
has_cftime,
has_dask,
has_flox,
has_pandas_ge_2_2,
raise_if_dask_computes,
Expand Down Expand Up @@ -2796,7 +2798,7 @@ def test_multiple_groupers(use_flox) -> None:

b = xr.DataArray(
np.random.RandomState(0).randn(2, 3, 4),
coords={"xy": (("x", "y"), [["a", "b", "c"], ["b", "c", "c"]])},
coords={"xy": (("x", "y"), [["a", "b", "c"], ["b", "c", "c"]], {"foo": "bar"})},
dims=["x", "y", "z"],
)
gb = b.groupby(x=UniqueGrouper(), y=UniqueGrouper())
Expand All @@ -2813,10 +2815,24 @@ def test_multiple_groupers(use_flox) -> None:
expected.loc[dict(x=1, xy=1)] = expected.sel(x=1, xy=0).data
expected.loc[dict(x=1, xy=0)] = np.nan
expected.loc[dict(x=1, xy=2)] = newval
expected["xy"] = ("xy", ["a", "b", "c"])
expected["xy"] = ("xy", ["a", "b", "c"], {"foo": "bar"})
# TODO: is order of dims correct?
assert_identical(actual, expected.transpose("z", "x", "xy"))

if has_dask:
b["xy"] = b["xy"].chunk()
with raise_if_dask_computes():
gb = b.groupby(x=UniqueGrouper(), xy=UniqueGrouper(labels=["a", "b", "c"]))

expected = xr.DataArray(
[[[1, 1, 1], [0, 1, 2]]] * 4,
dims=("z", "x", "xy"),
coords={"xy": ("xy", ["a", "b", "c"], {"foo": "bar"})},
)
assert_identical(gb.count(), expected)
assert is_chunked_array(gb.encoded.codes.data)
assert not gb.encoded.group_indices


@pytest.mark.parametrize("use_flox", [True, False])
def test_multiple_groupers_mixed(use_flox) -> None:
Expand Down

0 comments on commit a15b04d

Please sign in to comment.