Skip to content

Commit 8ac507c

Browse files
authored
Implement GroupBy.idxmin and GroupBy.idxmax (#585)
1 parent 75dce0b commit 8ac507c

File tree

2 files changed

+55
-2
lines changed

2 files changed

+55
-2
lines changed

dask_expr/_groupby.py

+53-1
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,16 @@ class Size(SingleAggregation):
402402
groupby_aggregate = M.sum
403403

404404

405+
class IdxMin(SingleAggregation):
406+
groupby_chunk = M.idxmin
407+
groupby_aggregate = M.first
408+
409+
410+
class IdxMax(IdxMin):
411+
groupby_chunk = M.idxmax
412+
groupby_aggregate = M.first
413+
414+
405415
class ValueCounts(SingleAggregation):
406416
groupby_chunk = staticmethod(_value_counts)
407417
groupby_aggregate = staticmethod(_value_counts_aggregate)
@@ -1043,7 +1053,7 @@ def __init__(
10431053

10441054
def _numeric_only_kwargs(self, numeric_only):
10451055
kwargs = {"numeric_only": numeric_only}
1046-
return {"chunk_kwargs": kwargs, "aggregate_kwargs": kwargs}
1056+
return {"chunk_kwargs": kwargs.copy(), "aggregate_kwargs": kwargs.copy()}
10471057

10481058
def _single_agg(
10491059
self,
@@ -1183,6 +1193,26 @@ def size(self, **kwargs):
11831193
def value_counts(self, **kwargs):
11841194
return self._single_agg(ValueCounts, **kwargs)
11851195

1196+
def idxmin(
1197+
self, split_every=None, split_out=1, skipna=True, numeric_only=False, **kwargs
1198+
):
1199+
# TODO: Add shuffle and remove kwargs
1200+
numeric_kwargs = self._numeric_only_kwargs(numeric_only)
1201+
numeric_kwargs["chunk_kwargs"]["skipna"] = skipna
1202+
return self._single_agg(
1203+
IdxMin, split_every=split_every, split_out=split_out, **numeric_kwargs
1204+
)
1205+
1206+
def idxmax(
1207+
self, split_every=None, split_out=1, skipna=True, numeric_only=False, **kwargs
1208+
):
1209+
# TODO: Add shuffle and remove kwargs
1210+
numeric_kwargs = self._numeric_only_kwargs(numeric_only)
1211+
numeric_kwargs["chunk_kwargs"]["skipna"] = skipna
1212+
return self._single_agg(
1213+
IdxMax, split_every=split_every, split_out=split_out, **numeric_kwargs
1214+
)
1215+
11861216
def head(self, n=5, split_every=None, split_out=1):
11871217
chunk_kwargs = {"n": n}
11881218
aggregate_kwargs = {
@@ -1431,6 +1461,28 @@ def aggregate(self, arg=None, split_every=8, split_out=1, **kwargs):
14311461

14321462
agg = aggregate
14331463

1464+
def idxmin(
1465+
self, split_every=None, split_out=1, skipna=True, numeric_only=False, **kwargs
1466+
):
1467+
# pandas doesn't support numeric_only here, which is odd
1468+
return self._single_agg(
1469+
IdxMin,
1470+
split_every=None,
1471+
split_out=split_out,
1472+
chunk_kwargs=dict(skipna=skipna),
1473+
)
1474+
1475+
def idxmax(
1476+
self, split_every=None, split_out=1, skipna=True, numeric_only=False, **kwargs
1477+
):
1478+
# pandas doesn't support numeric_only here, which is odd
1479+
return self._single_agg(
1480+
IdxMax,
1481+
split_every=split_every,
1482+
split_out=split_out,
1483+
chunk_kwargs=dict(skipna=skipna),
1484+
)
1485+
14341486
def nunique(self, split_every=None, split_out=True):
14351487
slice = self._slice or self.obj.name
14361488
return new_collection(

dask_expr/tests/test_groupby.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ def test_groupby_unsupported_by(pdf, df):
3434

3535
@pytest.mark.parametrize("split_every", [None, 5])
3636
@pytest.mark.parametrize(
37-
"api", ["sum", "mean", "min", "max", "prod", "first", "last", "var", "std"]
37+
"api",
38+
["sum", "mean", "min", "max", "prod", "first", "last", "var", "std", "idxmin"],
3839
)
3940
@pytest.mark.parametrize(
4041
"numeric_only",

0 commit comments

Comments
 (0)