diff --git a/siuba/dply/verbs.py b/siuba/dply/verbs.py index c7eba6fb..f138df8d 100644 --- a/siuba/dply/verbs.py +++ b/siuba/dply/verbs.py @@ -133,9 +133,16 @@ def _mutate_cols(__data, args, kwargs): if not isinstance(res_arg, pd.DataFrame): raise NotImplementedError("Only across() can be used as positional argument.") + # unpack result + is_scalar = len(res_arg) == 1 + for col_name, col_ser in res_arg.items(): # need to put on the frame so subsequent args, kwargs can use - df_tmp[col_name] = col_ser + if is_scalar: + df_tmp.loc[:, col_name] = col_ser.iloc[0] + else: + df_tmp.loc[:, col_name] = col_ser.array + result_names[col_name] = True for col_name, expr in kwargs.items(): diff --git a/siuba/sql/across.py b/siuba/sql/across.py index 0b7aeedb..59a2ba5e 100644 --- a/siuba/sql/across.py +++ b/siuba/sql/across.py @@ -1,6 +1,8 @@ from siuba.dply.across import across, _get_name_template, _across_setup_fns, ctx_verb_data, ctx_verb_window from siuba.dply.tidyselect import var_select, var_create -from siuba.siu import FormulaContext, Call +from siuba.siu import FormulaContext, Call, FormulaArg +from siuba.siu.calls import str_to_getitem_call +from siuba.siu.visitors import CallListener from .backend import LazyTbl from .utils import _sql_select, _sql_column_collection @@ -8,6 +10,18 @@ from sqlalchemy import sql +class ReplaceFx(CallListener): + def __init__(self, replacement): + self.replacement = replacement + + def exit(self, node): + res = super().exit(node) + if isinstance(res, FormulaArg): + return str_to_getitem_call(self.replacement) + + return res + + @across.register(LazyTbl) def _across_lazy_tbl(__data: LazyTbl, cols, fns, names: "str | None" = None) -> LazyTbl: raise NotImplementedError( @@ -49,20 +63,23 @@ def _across_sql_cols( old_name = new_name crnt_col = __data[old_name] - context = FormulaContext(Fx=crnt_col, _=__data) + #context = FormulaContext(Fx=crnt_col, _=__data) # iterate over functions ---- for fn_name, fn in fns_map.items(): fmt_pars = {"fn": fn_name, "col": new_name} + fn_replaced = ReplaceFx(old_name).enter(fn) new_call = lazy_tbl.shape_call( - fn, + fn_replaced, window, verb_name="Across", arg_name = f"function {fn_name} of {len(fns_map)}" ) - res = new_call(context) + res, windows, _ = lazy_tbl.track_call_windows(new_call, __data) + + #res = new_call(context) res_name = name_template.format(**fmt_pars) results.append(res.label(res_name)) diff --git a/siuba/tests/test_verb_across.py b/siuba/tests/test_verb_across.py index ddad1fd5..6d954524 100644 --- a/siuba/tests/test_verb_across.py +++ b/siuba/tests/test_verb_across.py @@ -172,6 +172,31 @@ def test_across_in_mutate_grouped_equiv_ungrouped(backend, df): assert_equal_query2(ungroup(g_res), collect(dst)) +def test_across_in_mutate_grouped_agg(backend): + df = pd.DataFrame({"x": [1, 2, 3], "y": [7, 8, 9], "g": [1, 1, 2]}) + src = backend.load_df(df) + g_src = group_by(src, "g") + + expr_across = across(_, _[_.x, _.y], f_mean) + g_res = mutate(g_src, expr_across) + dst = mutate(df.groupby("g"), expr_across) + + assert_grouping_names(g_res, ["g"]) + assert_equal_query2(ungroup(g_res), ungroup(dst)) + + +def test_across_in_mutate_grouped_transform(backend): + df = pd.DataFrame({"x": [1, 2, 3], "y": [7, 8, 9], "g": [1, 1, 2]}) + src = backend.load_df(df) + g_src = group_by(src, "g") + + expr_across = across(_, _[_.x, _.y], Fx.rank()) + g_res = mutate(g_src, expr_across) + dst = pd.DataFrame({"x": [1., 2, 1], "y": [1., 2, 1], "g": [1, 1, 2]}) + + assert_grouping_names(g_res, ["g"]) + assert_equal_query2(ungroup(g_res), dst) + def test_across_in_summarize(backend, df): src = backend.load_df(df) res = summarize(src, across(_, _[_.a_x, _.a_y], f_mean)) @@ -207,6 +232,18 @@ def test_across_in_summarize_equiv_ungrouped(backend): assert_frame_sort_equal(collected.drop(columns="g"), collect(dst)) +def test_across_in_summarize_grouped(backend): + df = pd.DataFrame({"x": [1, 2, 3], "y": [7, 8, 9], "g": [1, 1, 2]}) + src = backend.load_df(df) + g_src = group_by(src, "g") + + expr_across = across(_, _[_.x, _.y], f_mean) + g_res = summarize(g_src, expr_across) + dst = pd.DataFrame({"g": [1, 2], "x": [1.5, 3], "y": [7.5, 9]}) + + assert_equal_query2(ungroup(g_res), dst) + + def test_across_in_filter(backend, df): src = backend.load_df(df) res = filter(src, across(_, _[_.a_x, _.a_y], Fx % 2 > 0)) @@ -227,6 +264,18 @@ def test_across_in_filter_equiv_ungrouped(backend, df): assert_equal_query2(g_res.obj, dst) +def test_across_in_filter_grouped(backend): + df = pd.DataFrame({"x": [1, 2, 3], "y": [7, 8, 9], "g": [1, 1, 2]}) + src = backend.load_df(df) + g_src = group_by(src, "g") + + expr_across = across(_, _[_.x, _.y], Fx >= Fx.mean()) + g_res = filter(g_src, expr_across) + dst = pd.DataFrame({"x": [2, 3], "y": [8, 9], "g": [1, 2]}, index=[1, 2]) + + assert_equal_query2(ungroup(g_res), dst) + + @pytest.mark.parametrize("f", [ #(arrange), (verbs.count),