Skip to content

Commit

Permalink
Merge pull request #455 from machow/fix-across-grouped-sql
Browse files Browse the repository at this point in the history
fix: across handles sql grouped; handles pandas mutate with aggs
  • Loading branch information
machow authored Oct 25, 2022
2 parents 5f374a7 + 641e38f commit 652b4d9
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 5 deletions.
9 changes: 8 additions & 1 deletion siuba/dply/verbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,16 @@ def _mutate_cols(__data, args, kwargs):
if not isinstance(res_arg, pd.DataFrame):
raise NotImplementedError("Only across() can be used as positional argument.")

# unpack result
is_scalar = len(res_arg) == 1

for col_name, col_ser in res_arg.items():
# need to put on the frame so subsequent args, kwargs can use
df_tmp[col_name] = col_ser
if is_scalar:
df_tmp.loc[:, col_name] = col_ser.iloc[0]
else:
df_tmp.loc[:, col_name] = col_ser.array

result_names[col_name] = True

for col_name, expr in kwargs.items():
Expand Down
25 changes: 21 additions & 4 deletions siuba/sql/across.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,27 @@
from siuba.dply.across import across, _get_name_template, _across_setup_fns, ctx_verb_data, ctx_verb_window
from siuba.dply.tidyselect import var_select, var_create
from siuba.siu import FormulaContext, Call
from siuba.siu import FormulaContext, Call, FormulaArg
from siuba.siu.calls import str_to_getitem_call
from siuba.siu.visitors import CallListener

from .backend import LazyTbl
from .utils import _sql_select, _sql_column_collection

from sqlalchemy import sql


class ReplaceFx(CallListener):
    """Call-tree listener that swaps every ``FormulaArg`` node for a lookup call.

    Parameters
    ----------
    replacement:
        Value handed to ``str_to_getitem_call`` to build the substitute node.
        ``_across_sql_cols`` passes the name of the column being processed.
        NOTE(review): presumably this yields a ``_[replacement]`` style call;
        confirm against ``siuba.siu.calls.str_to_getitem_call``.
    """

    def __init__(self, replacement):
        self.replacement = replacement

    def exit(self, node):
        """Return the parent listener's result, unless it is a ``FormulaArg``.

        A ``FormulaArg`` result is discarded and replaced by
        ``str_to_getitem_call(self.replacement)``.
        """
        visited = super().exit(node)
        if not isinstance(visited, FormulaArg):
            return visited

        # a fresh call is built for each FormulaArg that is encountered
        return str_to_getitem_call(self.replacement)


@across.register(LazyTbl)
def _across_lazy_tbl(__data: LazyTbl, cols, fns, names: "str | None" = None) -> LazyTbl:
raise NotImplementedError(
Expand Down Expand Up @@ -49,20 +63,23 @@ def _across_sql_cols(
old_name = new_name

crnt_col = __data[old_name]
context = FormulaContext(Fx=crnt_col, _=__data)
#context = FormulaContext(Fx=crnt_col, _=__data)

# iterate over functions ----
for fn_name, fn in fns_map.items():
fmt_pars = {"fn": fn_name, "col": new_name}

fn_replaced = ReplaceFx(old_name).enter(fn)
new_call = lazy_tbl.shape_call(
fn,
fn_replaced,
window,
verb_name="Across",
arg_name = f"function {fn_name} of {len(fns_map)}"
)

res = new_call(context)
res, windows, _ = lazy_tbl.track_call_windows(new_call, __data)

#res = new_call(context)
res_name = name_template.format(**fmt_pars)
results.append(res.label(res_name))

Expand Down
49 changes: 49 additions & 0 deletions siuba/tests/test_verb_across.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,31 @@ def test_across_in_mutate_grouped_equiv_ungrouped(backend, df):
assert_equal_query2(ungroup(g_res), collect(dst))


def test_across_in_mutate_grouped_agg(backend):
    """Grouped mutate with an aggregating across() agrees with the pandas groupby path."""
    frame = pd.DataFrame({"x": [1, 2, 3], "y": [7, 8, 9], "g": [1, 1, 2]})
    grouped_src = group_by(backend.load_df(frame), "g")

    mean_across = across(_, _[_.x, _.y], f_mean)
    actual = mutate(grouped_src, mean_across)
    # the same expression run on a pandas groupby serves as the reference result
    expected = mutate(frame.groupby("g"), mean_across)

    assert_grouping_names(actual, ["g"])
    assert_equal_query2(ungroup(actual), ungroup(expected))


def test_across_in_mutate_grouped_transform(backend):
    """Grouped mutate with a row-wise across() (rank) is evaluated within each group."""
    frame = pd.DataFrame({"x": [1, 2, 3], "y": [7, 8, 9], "g": [1, 1, 2]})
    grouped_src = group_by(backend.load_df(frame), "g")

    rank_across = across(_, _[_.x, _.y], Fx.rank())
    actual = mutate(grouped_src, rank_across)
    # ranks restart for each value of g, so the lone row of group 2 ranks first
    expected = pd.DataFrame({"x": [1., 2, 1], "y": [1., 2, 1], "g": [1, 1, 2]})

    assert_grouping_names(actual, ["g"])
    assert_equal_query2(ungroup(actual), expected)

def test_across_in_summarize(backend, df):
src = backend.load_df(df)
res = summarize(src, across(_, _[_.a_x, _.a_y], f_mean))
Expand Down Expand Up @@ -207,6 +232,18 @@ def test_across_in_summarize_equiv_ungrouped(backend):
assert_frame_sort_equal(collected.drop(columns="g"), collect(dst))


def test_across_in_summarize_grouped(backend):
    """Grouped summarize with across() yields one row of column means per group."""
    frame = pd.DataFrame({"x": [1, 2, 3], "y": [7, 8, 9], "g": [1, 1, 2]})
    grouped_src = group_by(backend.load_df(frame), "g")

    mean_across = across(_, _[_.x, _.y], f_mean)
    actual = summarize(grouped_src, mean_across)
    # group 1 averages rows (1, 7) and (2, 8); group 2 holds the single row (3, 9)
    expected = pd.DataFrame({"g": [1, 2], "x": [1.5, 3], "y": [7.5, 9]})

    assert_equal_query2(ungroup(actual), expected)


def test_across_in_filter(backend, df):
src = backend.load_df(df)
res = filter(src, across(_, _[_.a_x, _.a_y], Fx % 2 > 0))
Expand All @@ -227,6 +264,18 @@ def test_across_in_filter_equiv_ungrouped(backend, df):
assert_equal_query2(g_res.obj, dst)


def test_across_in_filter_grouped(backend):
    """Grouped filter with across() compares each value against its group mean."""
    frame = pd.DataFrame({"x": [1, 2, 3], "y": [7, 8, 9], "g": [1, 1, 2]})
    grouped_src = group_by(backend.load_df(frame), "g")

    at_least_mean = across(_, _[_.x, _.y], Fx >= Fx.mean())
    actual = filter(grouped_src, at_least_mean)
    # rows at original positions 1 and 2 are the ones expected to survive
    expected = pd.DataFrame({"x": [2, 3], "y": [8, 9], "g": [1, 2]}, index=[1, 2])

    assert_equal_query2(ungroup(actual), expected)


@pytest.mark.parametrize("f", [
#(arrange),
(verbs.count),
Expand Down

0 comments on commit 652b4d9

Please sign in to comment.