diff --git a/dask_expr/_groupby.py b/dask_expr/_groupby.py
index b18f54f51..c534a24df 100644
--- a/dask_expr/_groupby.py
+++ b/dask_expr/_groupby.py
@@ -399,6 +399,16 @@ class Size(SingleAggregation):
     groupby_aggregate = M.sum
 
 
+class IdxMin(SingleAggregation):
+    groupby_chunk = M.idxmin
+    groupby_aggregate = M.first
+
+
+class IdxMax(IdxMin):
+    groupby_chunk = M.idxmax
+    groupby_aggregate = M.first
+
+
 class ValueCounts(SingleAggregation):
     groupby_chunk = staticmethod(_value_counts)
     groupby_aggregate = staticmethod(_value_counts_aggregate)
@@ -1010,7 +1020,7 @@ def __init__(
 
     def _numeric_only_kwargs(self, numeric_only):
         kwargs = {"numeric_only": numeric_only}
-        return {"chunk_kwargs": kwargs, "aggregate_kwargs": kwargs}
+        return {"chunk_kwargs": kwargs.copy(), "aggregate_kwargs": kwargs.copy()}
 
     def _single_agg(
         self,
@@ -1128,6 +1138,26 @@ def size(self, **kwargs):
     def value_counts(self, **kwargs):
         return self._single_agg(ValueCounts, **kwargs)
 
+    def idxmin(
+        self, split_every=None, split_out=1, skipna=True, numeric_only=False, **kwargs
+    ):
+        # TODO: Add shuffle and remove kwargs
+        numeric_kwargs = self._numeric_only_kwargs(numeric_only)
+        numeric_kwargs["chunk_kwargs"]["skipna"] = skipna
+        return self._single_agg(
+            IdxMin, split_every=split_every, split_out=split_out, **numeric_kwargs
+        )
+
+    def idxmax(
+        self, split_every=None, split_out=1, skipna=True, numeric_only=False, **kwargs
+    ):
+        # TODO: Add shuffle and remove kwargs
+        numeric_kwargs = self._numeric_only_kwargs(numeric_only)
+        numeric_kwargs["chunk_kwargs"]["skipna"] = skipna
+        return self._single_agg(
+            IdxMax, split_every=split_every, split_out=split_out, **numeric_kwargs
+        )
+
     def head(self, n=5, split_every=None, split_out=1):
         chunk_kwargs = {"n": n}
         aggregate_kwargs = {
@@ -1372,6 +1402,28 @@ def aggregate(self, arg=None, split_every=8, split_out=1, **kwargs):
 
     agg = aggregate
 
+    def idxmin(
+        self, split_every=None, split_out=1, skipna=True, numeric_only=False, **kwargs
+    ):
+        # pandas doesn't support numeric_only here, which is odd
+        return self._single_agg(
+            IdxMin,
+            split_every=split_every,
+            split_out=split_out,
+            chunk_kwargs=dict(skipna=skipna),
+        )
+
+    def idxmax(
+        self, split_every=None, split_out=1, skipna=True, numeric_only=False, **kwargs
+    ):
+        # pandas doesn't support numeric_only here, which is odd
+        return self._single_agg(
+            IdxMax,
+            split_every=split_every,
+            split_out=split_out,
+            chunk_kwargs=dict(skipna=skipna),
+        )
+
     def nunique(self, split_every=None, split_out=True):
         slice = self._slice or self.obj.name
         return new_collection(
diff --git a/dask_expr/tests/test_groupby.py b/dask_expr/tests/test_groupby.py
index f368377dd..06e129430 100644
--- a/dask_expr/tests/test_groupby.py
+++ b/dask_expr/tests/test_groupby.py
@@ -34,7 +34,8 @@ def test_groupby_unsupported_by(pdf, df):
 
 @pytest.mark.parametrize("split_every", [None, 5])
 @pytest.mark.parametrize(
-    "api", ["sum", "mean", "min", "max", "prod", "first", "last", "var", "std"]
+    "api",
+    ["sum", "mean", "min", "max", "prod", "first", "last", "var", "std", "idxmin"],
 )
 @pytest.mark.parametrize(
     "numeric_only",