diff --git a/siuba/dply/verbs.py b/siuba/dply/verbs.py index b1ee575a..e7583f81 100644 --- a/siuba/dply/verbs.py +++ b/siuba/dply/verbs.py @@ -657,19 +657,28 @@ def _distinct(__data, *args, _keep_all = False, **kwargs): def if_else(__data, *args, **kwargs): """ Example: - >>> ser1 = pd.Series([1,2,3,4]) - >>> if_else(ser1 > 2, np.nan, ser1) # doctest: +SKIP - array([ 1., 2., nan, nan]) + >>> ser1 = pd.Series([1,2,3]) + >>> if_else(ser1 > 2, np.nan, ser1) + 0 1.0 + 1 2.0 + 2 NaN + dtype: float64 >>> from siuba import _ - >>> f = if_else(_ < 3, _, 3) + >>> f = if_else(_ < 2, _, 2) >>> f(ser1) - array([1, 2, 3, 3]) + 0 1 + 1 2 + 2 2 + dtype: int64 >>> import numpy as np >>> ser2 = pd.Series(['NA', 'a', 'b']) >>> if_else(ser2 == 'NA', np.nan, ser2) - array([nan, 'a', 'b'], dtype=object) + 0 NaN + 1 a + 2 b + dtype: object """ raise_type_error(__data) @@ -683,9 +692,7 @@ def _if_else(__data, *args, **kwargs): def _if_else(cond, true_vals, false_vals): result = np.where(cond.fillna(False), true_vals, false_vals) - # TODO: should functions that take a Series, return a Series? - # for now, just return "O" type. Sort out once better research. - return result + return pd.Series(result) # case_when ---------------- @@ -729,7 +736,7 @@ def case_when(__data, cases): out[:] = val_res # by recreating an array, attempts to cast as best dtype - return np.array(list(out)) + return pd.Series(list(out)) @case_when.register(Symbolic) @case_when.register(Call) diff --git a/siuba/tests/test_verb_case_when.py b/siuba/tests/test_verb_case_when.py index de9998b7..c13f5285 100644 --- a/siuba/tests/test_verb_case_when.py +++ b/siuba/tests/test_verb_case_when.py @@ -3,6 +3,7 @@ import pytest from siuba.dply.verbs import case_when +from pandas.testing import assert_series_equal from numpy.testing import assert_equal from siuba.siu import _ @@ -29,10 +30,9 @@ def data(): #(np.array([True, True, False]), 0, [0, 0, None]) ]) def test_case_when_single_cond(k, v, res, data): - arr_res = np.array(res) out = case_when(data, {k: v}) - assert_equal(out, arr_res) + assert_series_equal(out, pd.Series(res)) def test_case_when_cond_order(data): @@ -41,5 +41,5 @@ def test_case_when_cond_order(data): True : 999 }) - assert_equal(out, np.array([0, 0, 999])) + assert_series_equal(out, pd.Series([0, 0, 999]))