Skip to content

Commit 162fdb9

Browse files
committed
Fixes #79
1 parent c292eee commit 162fdb9

File tree

4 files changed

+24
-4
lines changed

4 files changed

+24
-4
lines changed

numpy_groupies/aggregate_numpy.py

+3
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,7 @@ def _nancumsum(group_idx, a, size, fill_value=None, dtype=None):
272272
generic=_generic_callable,
273273
)
274274
_impl_dict.update(("nan" + k, v) for k, v in list(_impl_dict.items()) if k not in funcs_no_separate_nan)
275+
_impl_dict["nancumsum"] = _nancumsum
275276

276277

277278
def _aggregate_base(
@@ -308,6 +309,8 @@ def _aggregate_base(
308309
if "nan" in func:
309310
if "arg" in func:
310311
kwargs["_nansqueeze"] = True
312+
elif "cum" in func:
313+
pass
311314
else:
312315
good = ~np.isnan(a)
313316
if "len" not in func or is_pandas:

numpy_groupies/tests/test_compare.py

-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
may throw NotImplementedError in order to show missing functionality without throwing
55
test errors.
66
"""
7-
import sys
87
from itertools import product
98

109
import numpy as np

numpy_groupies/tests/test_generic.py

+17
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,12 @@ def _deselect_purepy(aggregate_all, *args, **kwargs):
2424
return aggregate_all.__name__.endswith("purepy")
2525

2626

27+
def _deselect_purepy_and_pandas(aggregate_all, *args, **kwargs):
28+
# purepy and pandas implementation handle some nan cases differently.
29+
# So they need to be excluded from several tests."""
30+
return aggregate_all.__name__.endswith(("pandas", "purepy"))
31+
32+
2733
def _deselect_purepy_and_invalid_axis(aggregate_all, size, axis, *args, **kwargs):
2834
if axis >= len(size):
2935
return True
@@ -358,6 +364,17 @@ def test_cumsum(aggregate_all):
358364
np.testing.assert_array_equal(res, ref)
359365

360366

367+
@pytest.mark.deselect_if(func=_deselect_purepy_and_pandas)
368+
def test_nancumsum(aggregate_all):
369+
# https://github.com/ml31415/numpy-groupies/issues/79
370+
group_idx = [0, 0, 0, 1, 1, 0, 0]
371+
a = [2, 2, np.nan, 2, 2, 2, 2]
372+
ref = [2., 4., 4., 2., 4., 6., 8.]
373+
374+
res = aggregate_all(group_idx, a, func="nancumsum")
375+
np.testing.assert_array_equal(res, ref)
376+
377+
361378
def test_cummax(aggregate_all):
362379
group_idx = np.array([4, 3, 3, 4, 4, 1, 1, 1, 7, 8, 7, 4, 3, 3, 1, 1])
363380
a = np.array([3, 4, 1, 3, 9, 9, 6, 7, 7, 0, 8, 2, 1, 8, 9, 8])

numpy_groupies/utils.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,8 @@
117117
np.array: "array",
118118
np.asarray: "array",
119119
np.sort: "sort",
120+
np.cumsum: "cumsum",
121+
np.cumprod: "cumprod",
120122
np.nansum: "nansum",
121123
np.nanprod: "nanprod",
122124
np.nanmean: "nanmean",
@@ -126,8 +128,7 @@
126128
np.nanstd: "nanstd",
127129
np.nanargmax: "nanargmax",
128130
np.nanargmin: "nanargmin",
129-
np.cumsum: "cumsum",
130-
np.cumprod: "cumprod",
131+
np.nancumsum: "nancumsum",
131132
}
132133

133134

@@ -150,7 +151,7 @@ def get_aliasing(*extra):
150151
alias.update((k, k) for k in set(alias.values()))
151152
# Treat nan-functions as firstclass member and add them directly
152153
for key in set(alias.values()):
153-
if key not in funcs_no_separate_nan:
154+
if key not in funcs_no_separate_nan and not key.startswith("nan"):
154155
key = "nan" + key
155156
alias[key] = key
156157
return alias

0 commit comments

Comments
 (0)