Skip to content

Commit 2eb7cb2

Browse files
authored
Merge pull request #103 from machow/feat-case-when-call-vals
Feat case when call vals
2 parents 6208f4b + 4984762 commit 2eb7cb2

File tree

2 files changed

+69
-5
lines changed

2 files changed

+69
-5
lines changed

siuba/dply/verbs.py

+24-5
Original file line numberDiff line numberDiff line change
@@ -709,21 +709,40 @@ def _if_else(cond, true_vals, false_vals):
709709
# TODO: evaluate this non-table verb approach
710710
from siuba.siu import DictCall
711711

712+
def _val_call(call, data, n, indx = None):
713+
if not callable(call):
714+
return call
715+
716+
arr = call(data)
717+
if arr.shape != (n,):
718+
raise ValueError("Expected call to return array of shape {}"
719+
"but it returned shape {}".format(n, arr.shape))
720+
721+
return arr[indx] if indx is not None else arr
722+
723+
712724
@singledispatch2((pd.DataFrame,pd.Series))
713725
def case_when(__data, cases):
714726
if isinstance(cases, Call):
715727
cases = cases(__data)
716728
# TODO: handle when receive list of (k,v) pairs for py < 3.5 compat?
717-
out = np.repeat(None, len(__data))
718-
for k, v in reversed(list(cases.items())):
729+
730+
stripped_cases = {strip_symbolic(k): strip_symbolic(v) for k,v in cases.items()}
731+
n = len(__data)
732+
out = np.repeat(None, n)
733+
for k, v in reversed(list(stripped_cases.items())):
719734
if callable(k):
720-
result = k(__data)
735+
result = _val_call(k, __data, n)
721736
indx = np.where(result)[0]
722-
out[indx] = v
737+
738+
val_res = _val_call(v, __data, n, indx)
739+
out[indx] = val_res
723740
elif k:
724741
# e.g. k is just True, etc..
725-
out[:] = v
742+
val_res = _val_call(v, __data, n)
743+
out[:] = val_res
726744

745+
# by recreating an array, attempts to cast as best dtype
727746
return np.array(list(out))
728747

729748
@case_when.register(Symbolic)

siuba/tests/test_verb_case_when.py

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import pandas as pd
2+
import numpy as np
3+
import pytest
4+
5+
from siuba.dply.verbs import case_when
6+
from numpy.testing import assert_equal
7+
from siuba.siu import _
8+
9+
DATA = pd.DataFrame({
10+
'x': [0,1,2],
11+
'y': [10, 11, 12]
12+
})
13+
14+
15+
@pytest.fixture
16+
def data():
17+
return DATA.copy()
18+
19+
20+
@pytest.mark.parametrize("k,v, res", [
21+
(True, 1, [1]*3),
22+
(True, False, [False]*3),
23+
(True, _.y, [10, 11, 12]),
24+
(True, lambda _: _.y, [10, 11, 12]),
25+
(_.x < 2, 0, [0, 0, None]),
26+
(_.x < 2, "small", ["small", "small", None]),
27+
(_.x < 2, _.y, [10, 11, None]),
28+
(lambda _: _.x < 2, 0, [0, 0, None]),
29+
#(np.array([True, True, False]), 0, [0, 0, None])
30+
])
31+
def test_case_when_single_cond(k, v, res, data):
32+
arr_res = np.array(res)
33+
out = case_when(data, {k: v})
34+
35+
assert_equal(out, arr_res)
36+
37+
38+
def test_case_when_cond_order(data):
39+
out = case_when(data, {
40+
lambda _: _.x < 2 : 0,
41+
True : 999
42+
})
43+
44+
assert_equal(out, np.array([0, 0, 999]))
45+

0 commit comments

Comments
 (0)