diff --git a/siuba/dply/verbs.py b/siuba/dply/verbs.py index 7ece9ad1..e1bdf55a 100644 --- a/siuba/dply/verbs.py +++ b/siuba/dply/verbs.py @@ -4,6 +4,7 @@ import numpy as np from pandas.core.groupby import DataFrameGroupBy +from pandas.core.dtypes.inference import is_scalar from siuba.siu import Symbolic, Call, strip_symbolic, MetaArg, BinaryOp, create_sym_call, Lazy DPLY_FUNCTIONS = ( @@ -391,10 +392,15 @@ def summarize(__data, **kwargs): for k, v in kwargs.items(): res = v(__data) if callable(v) else v - # TODO: validation? + # validate operations returned single result + if not is_scalar(res) and len(res) > 1: + raise ValueError("Summarize argument, %s, must return result of length 1 or a scalar." % k) - results[k] = res + # keep result, but use underlying array to avoid crazy index issues + # on DataFrame construction (#138) + results[k] = res.array if isinstance(res, pd.Series) else res + # must pass index, or raises error when using all scalar values return DataFrame(results, index = [0]) diff --git a/siuba/tests/test_verb_summarize.py b/siuba/tests/test_verb_summarize.py index b645a59b..aaa80633 100644 --- a/siuba/tests/test_verb_summarize.py +++ b/siuba/tests/test_verb_summarize.py @@ -99,10 +99,25 @@ def test_summarize_unnamed_args(df): ) -@pytest.mark.skip("TODO: Summarize should fail when result len > 1 (#138)") +def test_summarize_validates_length(): + with pytest.raises(ValueError): + summarize(data_frame(x = [1,2]), res = _.x + 1) + + def test_frame_mode_returns_many(): + # related to length validation above with pytest.raises(ValueError): df = data_frame(x = [1, 2, 3]) res = summarize(df, result = _.x.mode()) +def test_summarize_removes_series_index(): + # Note: currently wouldn't work in postgresql, since _.x + _.y not an agg func + df = data_frame(g = ['a', 'b', 'c'], x = [1,2,3], y = [4,5,6]) + + assert_equal_query( + df, + group_by(_.g) >> summarize(res = _.x + _.y), + df.assign(res = df.x + df.y).drop(columns = ["x", "y"]) + ) +