From b3017a1deb6691b352bec2ac3ff08fadcb13dce1 Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Tue, 6 Sep 2022 21:54:20 -0400 Subject: [PATCH 01/27] feat(pandas): across works inside mutate, filter, summarize --- siuba/dply/across.py | 113 +++++++++++++++ siuba/dply/tidyselect.py | 44 +++++- siuba/dply/verbs.py | 92 ++++++++++--- siuba/siu/__init__.py | 3 + siuba/siu/calls.py | 42 +++++- siuba/tests/test_verb_across.py | 236 ++++++++++++++++++++++++++++++++ 6 files changed, 510 insertions(+), 20 deletions(-) create mode 100644 siuba/dply/across.py create mode 100644 siuba/tests/test_verb_across.py diff --git a/siuba/dply/across.py b/siuba/dply/across.py new file mode 100644 index 00000000..4fe6ed56 --- /dev/null +++ b/siuba/dply/across.py @@ -0,0 +1,113 @@ +import pandas as pd +from pandas.api import types as pd_types + +from pandas.core.groupby import DataFrameGroupBy +from .verbs import var_select, var_create +from ..siu import FormulaContext, Call, strip_symbolic, Fx, call +from ..siu.dispatchers import verb_dispatch, symbolic_dispatch + +from collections.abc import Mapping +from typing import Callable, Any + +DEFAULT_MULTI_FUNC_TEMPLATE = "{col}_{fn}" +DEFAULT_SINGLE_FUNC_TEMPLATE = "{col}" + + +# TODO: handle DataFrame manipulation in pandas / sql backends +class AcrossResult(Mapping): + def __init__(self, *args, **kwargs): + self.d = dict(*args, **kwargs) + + def __getitem__(self, k): + return self.d[k] + + def __iter__(self): + return iter(self.d) + + def __len__(self): + return len(self.d) + + +def _across_setup_fns(fns) -> "dict[str, Callable[[FormulaContext], Any]]": + final_calls = {} + if isinstance(fns, (list, tuple)): + raise NotImplementedError( + "Specifying functions as a list or tuple is not supported. " + "Please use a dictionary to define multiple functions to apply. \n\n" + "E.g. 
across(_[:], {'round': Fx.round(), 'round2': Fx.round() + 1})" + ) + elif isinstance(fns, dict): + for name, fn_call_raw in fns.items(): + # symbolics get stripped by default for arguments to verbs, but + # these are inside a dictionary, so need to strip manually. + fn_call = strip_symbolic(fn_call_raw) + + if not isinstance(fn_call, Call): + raise TypeError( + "All functions to be applied in across must be a siuba.siu.Call, " + f"but received a function of type {type(fn_call)}" + ) + + final_calls[name] = fn_call + + elif isinstance(fns, Call): + final_calls["fn1"] = fns + + elif callable(fns): + final_calls["fn1"] = call(fns, Fx) + + else: + raise NotImplementedError(f"Unsupported function type in across: {type(fns)}") + + return final_calls + + +def _get_name_template(fns, names: "str | None") -> str: + if names is not None: + return names + + if callable(fns): + return DEFAULT_SINGLE_FUNC_TEMPLATE + + return DEFAULT_MULTI_FUNC_TEMPLATE + +@verb_dispatch(pd.DataFrame) +def across(__data, cols, fns, names: "str | None" = None) -> pd.DataFrame: + + name_template = _get_name_template(fns, names) + selected_cols = var_select(__data.columns, *var_create(cols), data=__data) + + fns_map = _across_setup_fns(fns) + + results = {} + for old_name, new_name in selected_cols.items(): + if new_name is None: + new_name = old_name + + crnt_ser = __data[old_name] + context = FormulaContext(Fx=crnt_ser, _=__data) + + for fn_name, fn in fns_map.items(): + fmt_pars = {"fn": fn_name, "col": new_name} + + res = fn(context) + results[name_template.format(**fmt_pars)] = res + + # ensure at least one result is not a scalar, so we don't get the classic + # pandas error: "If using all scalar values, you must pass an index" + index = None + if results: + _, v = next(iter(results.items())) + if pd_types.is_scalar(v): + index = [0] + + return pd.DataFrame(results, index=index) + + +@symbolic_dispatch(cls = pd.Series) +def where(x) -> bool: + if not isinstance(x, bool): + raise 
TypeError("Result of where clause must be a boolean (True or False).") + + return x + diff --git a/siuba/dply/tidyselect.py b/siuba/dply/tidyselect.py index c86169dd..4df8e412 100644 --- a/siuba/dply/tidyselect.py +++ b/siuba/dply/tidyselect.py @@ -3,6 +3,9 @@ from siuba.siu import Call, MetaArg, BinaryOp from collections import OrderedDict from itertools import chain +from functools import singledispatch + +from typing import List class Var: def __init__(self, name: "str | int | slice | Call", negated = False, alias = None): @@ -137,7 +140,7 @@ def flatten_var(var): return [var] -def var_select(colnames, *args): +def var_select(colnames, *args, data=None): # TODO: don't erase named column if included again colnames = colnames if isinstance(colnames, pd.Series) else pd.Series(colnames) cols = OrderedDict() @@ -147,12 +150,15 @@ def var_select(colnames, *args): # Add entries in pandas.rename style {"orig_name": "new_name"} for ii, arg in enumerate(all_vars): + # strings are added directly if isinstance(arg, str): cols[arg] = None + # integers add colname at corresponding index elif isinstance(arg, int): cols[colnames.iloc[arg]] = None + # general var handling elif isinstance(arg, Var): # remove negated Vars, otherwise include them @@ -165,6 +171,7 @@ def var_select(colnames, *args): start, stop = var_slice(colnames, arg.name) for ii in range(start, stop): var_put_cols(colnames[ii], arg, cols) + # method calls like endswith() elif callable(arg.name): # TODO: not sure if this is a good idea... 
@@ -176,6 +183,14 @@ def var_select(colnames, *args): var_put_cols(colnames.iloc[arg.name], arg, cols) else: var_put_cols(arg.name, arg, cols) + elif callable(arg) and data is not None: + # TODO: call on the data + col_mask = colwise_eval(data, arg) + + for name in colnames[col_mask]: + cols[name] = None + + else: raise Exception("variable must be either a string or Var instance") @@ -186,14 +201,39 @@ def var_create(*args) -> "tuple[Var]": vl = VarList() all_vars = [] for arg in args: - if callable(arg) and not isinstance(arg, Var): + if isinstance(arg, Call): res = arg(vl) if isinstance(res, VarList): raise ValueError("Must select specific column. Did you pass `_` to select?") all_vars.append(res) elif isinstance(arg, Var): all_vars.append(arg) + elif callable(arg): + all_vars.append(arg) else: all_vars.append(Var(arg)) return tuple(all_vars) + + +@singledispatch +def colwise_eval(data, predicate): + raise NotImplementedError( + f"Cannot evaluate tidyselect predicate on data type: {type(data)}" + ) + + +@colwise_eval.register +def _colwise_eval_pd(data: pd.DataFrame, predicate) -> List[bool]: + mask = [] + for col_name in data: + res = predicate(data.loc[:, col_name]) + if not pd.api.types.is_bool(res): + raise TypeError("TODO") + + mask.append(res) + + return mask + + + diff --git a/siuba/dply/verbs.py b/siuba/dply/verbs.py index 4d709af9..327bf322 100644 --- a/siuba/dply/verbs.py +++ b/siuba/dply/verbs.py @@ -183,7 +183,7 @@ def show_query(__data, simplify = False): # TODO: support for unnamed args @singledispatch2(pd.DataFrame) -def mutate(__data, **kwargs): +def mutate(__data, *args, **kwargs): """Assign new variables to a DataFrame, while keeping existing ones. Parameters @@ -207,9 +207,25 @@ def mutate(__data, **kwargs): 1 6 21.0 110 12 24 """ - - orig_cols = __data.columns - result = __data.assign(**kwargs) + + args_result_df = __data.copy() + + # handle across ---- + for arg in args: + # TODO: make robust. validate input. validate output (e.g. shape). 
+ new_col_map = arg(args_result_df) + + if not isinstance(new_col_map, pd.DataFrame): + raise NotImplementedError("Only across() can be used as positional argument.") + + for col_name, col_ser in new_col_map.items(): + args_result_df[col_name] = col_ser + + # handle everything else ---- + # TODO: what if kw expr returns DataFrame? + + orig_cols = args_result_df.columns + result = args_result_df.assign(**kwargs) new_cols = result.columns[~result.columns.isin(orig_cols)] @@ -218,11 +234,13 @@ def mutate(__data, **kwargs): @mutate.register(DataFrameGroupBy) -def _mutate(__data, **kwargs): +def _mutate(__data, *args, **kwargs): groupings = __data.grouper.groupings orig_index = __data.obj.index - df = __data.apply(lambda d: d.assign(**kwargs)) + f_mutate = mutate.dispatch(pd.DataFrame) + + df = __data.apply(lambda d: f_mutate(d, *args, **kwargs)) # will drop all but original index group_by_lvls = list(range(df.index.nlevels - 1)) @@ -385,7 +403,14 @@ def filter(__data, *args): """ crnt_indx = True for arg in args: - crnt_indx &= arg(__data) if callable(arg) else arg + res = arg(__data) if callable(arg) else arg + + if isinstance(res, pd.DataFrame): + crnt_indx &= res.all(axis=1) + elif isinstance(res, pd.Series): + crnt_indx &= res + else: + crnt_indx &= res # use loc or iloc to subset, depending on crnt_indx ---- # the main issue here is that loc can't remove all rows using a slice @@ -398,6 +423,7 @@ def filter(__data, *args): return result + @filter.register(DataFrameGroupBy) def _filter(__data, *args): groupings = __data.grouper.groupings @@ -415,8 +441,9 @@ def _filter(__data, *args): # Summarize =================================================================== + @singledispatch2(DataFrame) -def summarize(__data, **kwargs): +def summarize(__data, *args, **kwargs): """Assign variables that are single number summaries of a DataFrame. Grouped DataFrames will produce one row for each group. 
Otherwise, summarize @@ -455,26 +482,57 @@ def summarize(__data, **kwargs): """ results = {} + + for ii, expr in enumerate(args): + if not callable(expr): + raise TypeError( + "Unnamed arguments to summarize must be callable, but argument number " + f"{ii} was type: {type(expr)}" + ) + + res = expr(__data) + if isinstance(res, DataFrame): + if len(res) != 1: + raise ValueError( + f"Summarize argument `{ii}` returned a DataFrame with {len(res)} rows." + " Result must only be a single row." + ) + + for col_name in res.columns: + results[col_name] = res[col_name].array + else: + raise ValueError( + "Unnamed arguments to summarize must return a DataFrame, but argument " + f"`{ii} returned type: {type(expr)}" + ) + + + for k, v in kwargs.items(): + # TODO: raise error if a named expression returns a DataFrame res = v(__data) if callable(v) else v - # validate operations returned single result - if not is_scalar(res) and len(res) > 1: - raise ValueError("Summarize argument, %s, must return result of length 1 or a scalar." 
% k) + if is_scalar(res) or len(res) == 1: + # keep result, but use underlying array to avoid crazy index issues + # on DataFrame construction (#138) + results[k] = res.array if isinstance(res, pd.Series) else res - # keep result, but use underlying array to avoid crazy index issues - # on DataFrame construction (#138) - results[k] = res.array if isinstance(res, pd.Series) else res + else: + raise ValueError( + f"Summarize argument `{k}` must return result of length 1 or a scalar.\n\n" + f"Result type: {type(res)}\n" + f"Result length: {len(res)}" + ) # must pass index, or raises error when using all scalar values return DataFrame(results, index = [0]) @summarize.register(DataFrameGroupBy) -def _summarize(__data, **kwargs): +def _summarize(__data, *args, **kwargs): df_summarize = summarize.registry[pd.DataFrame] - df = __data.apply(df_summarize, **kwargs) + df = __data.apply(df_summarize, *args, **kwargs) group_by_lvls = list(range(df.index.nlevels - 1)) out = df.reset_index(group_by_lvls) @@ -643,7 +701,7 @@ def select(__data, *args, **kwargs): ) var_list = var_create(*args) - od = var_select(__data.columns, *var_list) + od = var_select(__data.columns, *var_list, data=__data) to_rename = {k: v for k,v in od.items() if v is not None} diff --git a/siuba/siu/__init__.py b/siuba/siu/__init__.py index 5f753f9a..d4f8953f 100644 --- a/siuba/siu/__init__.py +++ b/siuba/siu/__init__.py @@ -6,6 +6,8 @@ BinaryOp, _SliceOpIndex, DictCall, + FormulaArg, + FormulaContext, str_to_getitem_call ) from .symbolic import Symbolic, strip_symbolic, create_sym_call, explain @@ -15,5 +17,6 @@ Lam = Lazy _ = Symbolic() +Fx = Symbolic(FormulaArg("Fx")) diff --git a/siuba/siu/calls.py b/siuba/siu/calls.py index fa295e37..374d443f 100644 --- a/siuba/siu/calls.py +++ b/siuba/siu/calls.py @@ -2,6 +2,7 @@ import operator from abc import ABC +from collections.abc import Mapping # TODO: symbolic formatting: __add__ -> "+" @@ -570,6 +571,22 @@ def __new__(cls, func, *args, **kwargs): 
SliceOp.register(_SliceOpExt) +# formulas --------------------------- + +class FormulaContext(Mapping): + def __init__(self, *args, **kwargs): + self.d = dict(*args, **kwargs) + + def __getitem__(self, k): + return self.d[k] + + def __iter__(self): + return iter(self.d) + + def __len__(self): + return len(self.d) + + # Special kinds of call arguments ---- # These functions insure that when using siu expressions generated by _, # that call.args[0] is always another call. This allows them to trivially @@ -579,7 +596,6 @@ def __new__(cls, func, *args, **kwargs): # set of behavior similar to theirs. # # TODO: validate that call.args[0] is a Call in tree visitors? - class MetaArg(Call): """Represent an argument, by returning the argument passed to __call__.""" @@ -592,8 +608,32 @@ def __repr__(self): return self.func def __call__(self, x): + if isinstance(x, FormulaContext): + return x[self.func] + return x + +# TODO: MetaArg should be a subclass of this? +class FormulaArg(MetaArg): + def __init__(self, func, *args, **kwargs): + self.func = func + self.args = tuple() + self.kwargs = {} + + def __repr__(self): + return f"FormulaArg({repr(self.func)})" + + def __call__(self, x: FormulaContext): + if isinstance(x, FormulaContext): + return x[self.func] + + raise TypeError( + f"The formula object {self.func} must receive a FormulaContext when called." 
+ ) + + + class FuncArg(Call): """Represent a function to be called.""" diff --git a/siuba/tests/test_verb_across.py b/siuba/tests/test_verb_across.py new file mode 100644 index 00000000..2117f9ea --- /dev/null +++ b/siuba/tests/test_verb_across.py @@ -0,0 +1,236 @@ +import pandas as pd +import pytest + +from pandas.testing import assert_frame_equal +from pandas.core.groupby import DataFrameGroupBy +from siuba.siu import symbolic_dispatch, Symbolic, Fx +from siuba.dply.verbs import mutate, filter, summarize +from siuba.dply.across import across + + +# Helpers ===================================================================== +_ = Symbolic() + +@symbolic_dispatch(cls = pd.Series) +def f_round(x) -> pd.Series: + return round(x) + + +@symbolic_dispatch(cls = pd.Series) +def f_mean(x) -> pd.Series: + return x.mean() + + +def assert_grouping_names(gdf, names): + assert isinstance(gdf, DataFrameGroupBy) + groupings = gdf.grouper.groupings + + assert len(groupings) == len(names) + + grouping_names = [g.name for g in groupings] + assert grouping_names == names + + +# Fixtures ==================================================================== + +@pytest.fixture +def df(): + return pd.DataFrame({ + "a_x": [1, 2], + "a_y": [3, 4], + "b_x": [5., 6.], # note the floats + "g": ["m", "n"] + }) + + +# Tests ======================================================================= + +@pytest.mark.parametrize("func", [ + (f_round), + (Fx.round()), + (f_round(Fx)), + (lambda x: x.round()), +]) +def test_across_func_transform(df, func): + res = across(df, _[_.a_x, _.a_y], func) + dst = pd.DataFrame({ + "a_x": df.a_x.round(), + "a_y": df.a_y.round() + }) + + assert_frame_equal(res, dst) + + +@pytest.mark.parametrize("func", [ + (f_mean), + (Fx.mean()), + (f_mean(Fx)), + (lambda x: x.mean()), +]) +def test_across_func_aggregate(df, func): + res = across(df, _[_.a_x, _.a_y], func) + dst = pd.DataFrame({ + "a_x": [df.a_x.mean()], + "a_y": [df.a_y.mean()] + }) + + 
assert_frame_equal(res, dst) + + +@pytest.mark.parametrize("func", [ + (lambda x: x % 2 > 1), + (Fx % 2 > 1), +]) +def test_across_func_bool(df, func): + res = across(df, _[_.a_x, _.a_y], func) + dst = pd.DataFrame({ + "a_x": df.a_x % 2 > 1, + "a_y": df.a_y % 2 > 1 + }) + + assert_frame_equal(res, dst) + + +@pytest.mark.parametrize("selection", [ + (_[0,1]), + (_[0:"a_y"]), + (lambda x: x.dtype == "int64"), + #(where(Fx.dtype == "int64")), + #(where(Fx.dtype != "float64") & ~_.g), + (~_[_.b_x, _.g]), + +]) +def test_across_selection(df, selection): + res = across(df, selection, lambda x: x + 1) + dst = df[["a_x", "a_y"]] + 1 + + assert_frame_equal(res, dst) + + +def test_across_selection_rename(df): + res = across(df, _.zzz == _.a_x, lambda x: x + 1) + assert res.columns.tolist() == ["zzz"] + + assert_frame_equal(res, (df[["a_x"]] + 1).rename(columns={"a_x": "zzz"})) + + +def test_across_in_mutate(df): + res_explicit = mutate(df, across(_, _[_.a_x, _.a_y], f_round)) + res_implicit = mutate(df, across(_[_.a_x, _.a_y], f_round)) + + dst = df.copy() + dst["a_x"] = df.a_x.round() + dst["a_y"] = df.a_y.round() + + assert_frame_equal(res_explicit, dst) + assert_frame_equal(res_implicit, dst) + + +def test_across_in_mutate_grouped_equiv_ungrouped(df): + gdf = df.groupby("g") + + expr_across = across(_, _[_.a_x, _.a_y], f_round) + g_res = mutate(gdf, expr_across) + dst = mutate(df, expr_across) + + assert_grouping_names(g_res, ["g"]) + assert_frame_equal(g_res.obj, dst) + + +def test_across_in_summarize(df): + res = summarize(df, across(_, _[_.a_x, _.a_y], f_mean)) + dst = pd.DataFrame({ + "a_x": [df.a_x.mean()], + "a_y": [df.a_y.mean()] + }) + + assert_frame_equal(res, dst) + + +def test_across_in_summarize_equiv_ungrouped(): + # note that summarize does not automatically regroup on any keys + src = pd.DataFrame({ + "a_x": [1, 2], + "a_y": [3, 4], + "b_x": [5., 6.], + "g": ["ZZ", "ZZ"] # Note: all groups the same + }) + + g_src = src.groupby("g") + + expr_across = 
across(_, _[_.a_x, _.a_y], f_mean) + g_res = summarize(g_src, expr_across) + dst = summarize(src, expr_across) + + assert g_res.columns.tolist() == ["g", "a_x", "a_y"] + assert g_res["g"].tolist() == ["ZZ"] + + assert_frame_equal(g_res.drop(columns="g"), dst) + + +def test_across_in_filter(df): + res = filter(df, across(_, _[_.a_x, _.a_y], lambda x: x % 2 > 0)) + + dst = df[(df[["a_x", "a_y"]] % 2 > 0).all(axis=1)] + + assert_frame_equal(res, dst) + + +def test_across_in_filter_equiv_ungrouped(df): + gdf = df.groupby("g") + + expr_across = across(_, _[_.a_x, _.a_y], lambda x: x % 2 > 0) + g_res = filter(gdf, expr_across) + dst = filter(df, expr_across) + + assert_grouping_names(g_res, ["g"]) + assert_frame_equal(g_res.obj, dst) + + +def test_across_formula_and_underscore(df): + res = across(df, _[_.a_x, _.a_y], f_round(Fx) / _.b_x) + + dst = pd.DataFrame({ + "a_x": df.a_x.round() / df.b_x, + "a_y": df.a_y.round() / df.b_x + }) + + assert_frame_equal(res, dst) + + +def test_across_names_arg(df): + res = across(df, _[_.a_x, _.a_y], Fx + 1, names="{col}_funkyname") + assert list(res.columns) == ["a_x_funkyname", "a_y_funkyname"] + + dst = (df[["a_x", "a_y"]] + 1).rename(columns = lambda s: s + "_funkyname") + assert_frame_equal(res, dst) + + +def test_across_func_dict(df): + res = across(df, _[_.a_x, _.a_y], {"plus1": Fx + 1, "plus2": Fx + 2}) + + dst = pd.DataFrame({ + "a_x_plus1": df.a_x + 1, + "a_x_plus2": df.a_x + 2, + "a_y_plus1": df.a_y + 1, + "a_y_plus2": df.a_y + 2 + }) + + assert_frame_equal(res, dst) + + +def test_across_func_dict_names_arg(df): + # TODO: also test aggregation + funcs = {"plus1": Fx + 1, "plus2": Fx + 2} + res = across(df, _[_.a_x, _.a_y], funcs, names="{fn}_{col}") + + dst = pd.DataFrame({ + "plus1_a_x": df.a_x + 1, + "plus2_a_x": df.a_x + 2, + "plus1_a_y": df.a_y + 1, + "plus2_a_y": df.a_y + 2, + }) + + assert_frame_equal(res, dst) + + From 5a4e58c8f88c47f9caabce86c7e8130e9a026c26 Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Fri, 30 
Sep 2022 14:32:39 -0400 Subject: [PATCH 02/27] feat(sql): basic sql mutate across --- siuba/dply/across.py | 43 ++++++++++++++++--- siuba/sql/__init__.py | 1 + siuba/sql/across.py | 67 +++++++++++++++++++++++++++++ siuba/sql/utils.py | 4 +- siuba/sql/verbs.py | 74 +++++++++++++++++++++++++++++---- siuba/tests/test_verb_across.py | 46 +++++++++++++++----- 6 files changed, 208 insertions(+), 27 deletions(-) create mode 100644 siuba/sql/across.py diff --git a/siuba/dply/across.py b/siuba/dply/across.py index 4fe6ed56..5f60fcf3 100644 --- a/siuba/dply/across.py +++ b/siuba/dply/across.py @@ -3,16 +3,42 @@ from pandas.core.groupby import DataFrameGroupBy from .verbs import var_select, var_create -from ..siu import FormulaContext, Call, strip_symbolic, Fx, call -from ..siu.dispatchers import verb_dispatch, symbolic_dispatch +from ..siu import FormulaContext, Call, strip_symbolic, Fx, FuncArg +from ..siu.dispatchers import verb_dispatch, symbolic_dispatch, create_eager_pipe_call from collections.abc import Mapping +from contextvars import ContextVar from typing import Callable, Any DEFAULT_MULTI_FUNC_TEMPLATE = "{col}_{fn}" DEFAULT_SINGLE_FUNC_TEMPLATE = "{col}" +ctx_verb_data = ContextVar("data") + + +def _is_symbolic_operator(f): + # TODO: consolidate these checks, make the result of symbolic_dispatch a class. 
+ return callable(f) and getattr(f, "_siu_symbolic_operator", False) + + +def _require_across(call, verb_name): + if not isinstance(call, Call) or (call.args and call.args[0] is across): + raise NotImplementedError( + "{verb_name} currently only allows a top-level across as an unnamed argument.\n\n" + "Example: {verb_name}(some_data, across(...))" + ) + + +def _eval_with_context(ctx, data, expr): + token = ctx_verb_data.set(ctx) + + try: + return expr(data) + finally: + ctx_verb_data.reset(token) + + # TODO: handle DataFrame manipulation in pandas / sql backends class AcrossResult(Mapping): def __init__(self, *args, **kwargs): @@ -42,19 +68,23 @@ def _across_setup_fns(fns) -> "dict[str, Callable[[FormulaContext], Any]]": # these are inside a dictionary, so need to strip manually. fn_call = strip_symbolic(fn_call_raw) - if not isinstance(fn_call, Call): + if isinstance(fn_call, Call): + final_calls[name] = fn_call + + elif callable(fn_call): + final_calls[name] = create_eager_pipe_call(FuncArg(fn_call), Fx) + + else: raise TypeError( "All functions to be applied in across must be a siuba.siu.Call, " f"but received a function of type {type(fn_call)}" ) - final_calls[name] = fn_call - elif isinstance(fns, Call): final_calls["fn1"] = fns elif callable(fns): - final_calls["fn1"] = call(fns, Fx) + final_calls["fn1"] = create_eager_pipe_call(FuncArg(fns), Fx) else: raise NotImplementedError(f"Unsupported function type in across: {type(fns)}") @@ -71,6 +101,7 @@ def _get_name_template(fns, names: "str | None") -> str: return DEFAULT_MULTI_FUNC_TEMPLATE + @verb_dispatch(pd.DataFrame) def across(__data, cols, fns, names: "str | None" = None) -> pd.DataFrame: diff --git a/siuba/sql/__init__.py b/siuba/sql/__init__.py index 2bd6e3aa..5af7053b 100644 --- a/siuba/sql/__init__.py +++ b/siuba/sql/__init__.py @@ -1,5 +1,6 @@ from .verbs import LazyTbl, sql_raw from .translate import SqlColumn, SqlColumnAgg, SqlFunctionLookupError +from . 
import across as _across # proceed w/ underscore so it isn't exported by default # we just want to register the singledispatch funcs diff --git a/siuba/sql/across.py b/siuba/sql/across.py new file mode 100644 index 00000000..9b7cbc08 --- /dev/null +++ b/siuba/sql/across.py @@ -0,0 +1,67 @@ +from siuba.dply.across import across, _get_name_template, _across_setup_fns, ctx_verb_data +from siuba.dply.tidyselect import var_select, var_create +from siuba.siu import FormulaContext, Call + +from . verbs import LazyTbl +from .utils import _sql_select, _sql_column_collection + +from sqlalchemy import sql + + +@across.register(LazyTbl) +def _across_lazy_tbl(__data: LazyTbl, cols, fns, names: "str | None" = None) -> LazyTbl: + raise NotImplementedError( + "across() cannot called directly on a LazyTbl. Please use it inside a verb, " + "like mutate(), summarize(), filter(), arrange(), group_by(), etc.." + ) + #selectable = __data.last_op + # + #columns = selectable.alias().columns + #if not isinstance(columns, ImmutableColumnCollection): + # raise TypeError(str(type(columns))) + + #res_cols = across(columns, cols, fns, names) + + #return __data.append_op(_sql_select(res_cols)) + + +@across.register(sql.base.ImmutableColumnCollection) +def _across_sql_cols( + __data: sql.base.ImmutableColumnCollection, + cols, + fns, + names: "str | None" = None +) -> sql.base.ImmutableColumnCollection: + + lazy_tbl = ctx_verb_data.get() + + name_template = _get_name_template(fns, names) + selected_cols = var_select(__data, *var_create(cols), data=__data) + + fns_map = _across_setup_fns(fns) + + results = [] + + # iterate over columns ---- + for new_name, old_name in selected_cols.items(): + if old_name is None: + old_name = new_name + + crnt_col = __data[old_name] + context = FormulaContext(Fx=crnt_col, _=__data) + + # iterate over functions ---- + for fn_name, fn in fns_map.items(): + fmt_pars = {"fn": fn_name, "col": new_name} + + new_call = lazy_tbl.shape_call( + fn, + verb_name="Across", + 
arg_name = f"function {fn_name} of {len(fns_map)}" + ) + + res = new_call(context) + res_name = name_template.format(**fmt_pars) + results.append(res.label(res_name)) + + return _sql_column_collection(results) diff --git a/siuba/sql/utils.py b/siuba/sql/utils.py index 7310b328..cf1d26e2 100644 --- a/siuba/sql/utils.py +++ b/siuba/sql/utils.py @@ -114,9 +114,11 @@ def _sql_select(columns, *args, **kwargs): return sql.select(*columns, *args, **kwargs) -def _sql_column_collection(data, columns): +def _sql_column_collection(columns): from sqlalchemy.sql.base import ColumnCollection, ImmutableColumnCollection + data = {col.key: col for col in columns} + if is_sqla_12() or is_sqla_13(): return ImmutableColumnCollection(data, columns) diff --git a/siuba/sql/verbs.py b/siuba/sql/verbs.py index c03580a7..2b49488b 100644 --- a/siuba/sql/verbs.py +++ b/siuba/sql/verbs.py @@ -53,7 +53,6 @@ from sqlalchemy.sql import schema - # TODO: # - distinct # - annotate functions using sel.prefix_with("\n/**/\n") ? @@ -154,6 +153,29 @@ def _get_over_clauses(clause): return windows +class SqlLabelReplacer: + """Create a visitor to replace source labels with destination. + + Note that this is meant to be used with sqlalchemy visitors. 
+ """ + + def __init__(self, src_columns, dst_columns): + self.src_labels = [x for x in src_columns if isinstance(x, sql.elements.Label)] + self.dst_columns = dst_columns + self.applied = False + + def __call__(self, clause): + return sql.util.visitors.replacement_traverse(clause, {}, self.visit) + + def visit(self, el): + if el in self.src_labels: + import pdb; pdb.set_trace() + self.applied = True + return self.dst_columns[el.name] + + return None + + #def track_call_windows(call, columns, group_by, order_by, window_cte = None): # listener = WindowReplacer(columns, group_by, order_by, window_cte) # col = listener.enter(call) @@ -205,9 +227,8 @@ def replace_call_windows(col_expr, group_by, order_by, window_cte = None): def lift_inner_cols(tbl): cols = list(tbl.inner_columns) - data = {col.key: col for col in cols} - return _sql_column_collection(data, cols) + return _sql_column_collection(cols) def col_expr_requires_cte(call, sel, is_mutate = False): """Return whether a variable assignment needs a CTE""" @@ -614,19 +635,44 @@ def _filter(__data, *args): @mutate.register(LazyTbl) -def _mutate(__data, **kwargs): - # Cases - # - work with group by - # - window functions +def _mutate(__data, *args, **kwargs): + from siuba.dply.across import _require_across, _eval_with_context + # TODO: verify it can follow a renaming select # track labeled columns in set - if not len(kwargs): + if not (len(args) or len(kwargs)): return __data.append_op(__data.last_op) - sel = __data.last_select + across_sel = __data.last_select + + # special support for across + for ii, func in enumerate(args): + _require_across(func, "Mutate") + _candidate_sel_alias = across_sel.alias() + + inner_cols = lift_inner_cols(across_sel) + + #new_call = __data.shape_call(func, verb_name = "Mutate", arg_name = f"*arg entry {ii}") + cols_result = _eval_with_context(__data, inner_cols, func) + + # TODO: remove or raise a more informative error + assert isinstance(cols_result, 
sql.base.ImmutableColumnCollection), type(cols_result) + + # replace any labels that require a subquery ---- + replacer = SqlLabelReplacer(set(inner_cols), _candidate_sel_alias.columns) + replaced_cols = list(map(replacer, cols_result)) + + if replacer.applied: + # TODO: use replace logic from _mutate_select + next_sel = _candidate_sel_alias.select() + else: + next_sel = across_sel + + across_sel = _sql_upsert_columns(across_sel, replaced_cols) # evaluate each call + sel = across_sel for colname, func in kwargs.items(): # keep set of columns labeled (aliased) in this select statement # need to use inner cols, since sel.columns uses ColumnClause, not Label @@ -638,6 +684,16 @@ def _mutate(__data, **kwargs): return __data.append_op(sel) +def _sql_upsert_columns(sel, new_columns: "list[base.Label | base.Column]"): + orig_cols = lift_inner_cols(sel) + replaced = {**orig_cols} + + for new_col in new_columns: + replaced[new_col.name] = new_col + return _sql_with_only_columns(sel, list(replaced.values())) + + + def _mutate_select(sel, colname, func, labs, __data): """Return select statement containing a new column, func expr as colname. 
diff --git a/siuba/tests/test_verb_across.py b/siuba/tests/test_verb_across.py index 2117f9ea..09cf8f51 100644 --- a/siuba/tests/test_verb_across.py +++ b/siuba/tests/test_verb_across.py @@ -7,15 +7,25 @@ from siuba.dply.verbs import mutate, filter, summarize from siuba.dply.across import across +from siuba.experimental.pivot.test_pivot import assert_equal_query2 +from siuba.sql.translate import SqlColumn, SqlColumnAgg, sql_scalar + # Helpers ===================================================================== + _ = Symbolic() +# round function ---- + @symbolic_dispatch(cls = pd.Series) def f_round(x) -> pd.Series: return round(x) +f_round.register(SqlColumn, sql_scalar("round")) + +# mean function ---- + @symbolic_dispatch(cls = pd.Series) def f_mean(x) -> pd.Series: return x.mean() @@ -45,12 +55,14 @@ def df(): # Tests ======================================================================= -@pytest.mark.parametrize("func", [ - (f_round), - (Fx.round()), - (f_round(Fx)), - (lambda x: x.round()), -]) +TRANSFORMATION_FUNCS = [ + f_round, + Fx.round(), + f_round(Fx), +] + + +@pytest.mark.parametrize("func", TRANSFORMATION_FUNCS) def test_across_func_transform(df, func): res = across(df, _[_.a_x, _.a_y], func) dst = pd.DataFrame({ @@ -58,6 +70,16 @@ def test_across_func_transform(df, func): "a_y": df.a_y.round() }) + assert_equal_query2(res, dst) + + +def test_across_func_transform_lambda(df): + res = across(df, _[_.a_x, _.a_y], lambda x: x.round()) + dst = pd.DataFrame({ + "a_x": df.a_x.round(), + "a_y": df.a_y.round() + }) + assert_frame_equal(res, dst) @@ -114,16 +136,18 @@ def test_across_selection_rename(df): assert_frame_equal(res, (df[["a_x"]] + 1).rename(columns={"a_x": "zzz"})) -def test_across_in_mutate(df): - res_explicit = mutate(df, across(_, _[_.a_x, _.a_y], f_round)) - res_implicit = mutate(df, across(_[_.a_x, _.a_y], f_round)) +@pytest.mark.parametrize("func", TRANSFORMATION_FUNCS) +def test_across_in_mutate(backend, df, func): + src = 
backend.load_df(df) + res_explicit = mutate(src, across(_, _[_.a_x, _.a_y], f_round)) + res_implicit = mutate(src, across(_[_.a_x, _.a_y], f_round)) dst = df.copy() dst["a_x"] = df.a_x.round() dst["a_y"] = df.a_y.round() - assert_frame_equal(res_explicit, dst) - assert_frame_equal(res_implicit, dst) + assert_equal_query2(res_explicit, dst, sql_kwargs={"check_dtype": False}) + assert_equal_query2(res_implicit, dst, sql_kwargs={"check_dtype": False}) def test_across_in_mutate_grouped_equiv_ungrouped(df): From 430cc0c2d5a3fa12626c4e03ad4b3e8c3f98f943 Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Sat, 1 Oct 2022 16:04:45 -0400 Subject: [PATCH 03/27] feat(sql): basic sql filter --- siuba/dply/across.py | 20 +++++++-- siuba/siu/calls.py | 10 +++++ siuba/siu/dispatchers.py | 10 ++--- siuba/sql/verbs.py | 72 +++++++++++++++++++++------------ siuba/tests/test_verb_across.py | 32 +++++++++------ 5 files changed, 97 insertions(+), 47 deletions(-) diff --git a/siuba/dply/across.py b/siuba/dply/across.py index 5f60fcf3..e72cf4fd 100644 --- a/siuba/dply/across.py +++ b/siuba/dply/across.py @@ -8,6 +8,7 @@ from collections.abc import Mapping from contextvars import ContextVar +from contextlib import contextmanager from typing import Callable, Any DEFAULT_MULTI_FUNC_TEMPLATE = "{col}_{fn}" @@ -23,10 +24,13 @@ def _is_symbolic_operator(f): def _require_across(call, verb_name): - if not isinstance(call, Call) or (call.args and call.args[0] is across): + if ( + not isinstance(call, Call) + or not (call.args and getattr(call.args[0], "__name__", None) == "across") + ): raise NotImplementedError( - "{verb_name} currently only allows a top-level across as an unnamed argument.\n\n" - "Example: {verb_name}(some_data, across(...))" + f"{verb_name} currently only allows a top-level across as an unnamed argument.\n\n" + f"Example: {verb_name}(some_data, across(...))" ) @@ -39,6 +43,16 @@ def _eval_with_context(ctx, data, expr): ctx_verb_data.reset(token) +@contextmanager +def 
_set_data_context(ctx): + try: + token = ctx_verb_data.set(ctx) + yield + finally: + ctx_verb_data.reset(token) + + + # TODO: handle DataFrame manipulation in pandas / sql backends class AcrossResult(Mapping): def __init__(self, *args, **kwargs): diff --git a/siuba/siu/calls.py b/siuba/siu/calls.py index 374d443f..e66224e8 100644 --- a/siuba/siu/calls.py +++ b/siuba/siu/calls.py @@ -330,6 +330,16 @@ def __call__(self, x, *args, **kwargs): return self.args[0] +class _Isolate(Lazy): + """Lazily return calls, and do dispatch visitors.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.func = "" + + def map_subcalls(self, f, args = tuple(), kwargs = None): + return self.args, {**self.kwargs} + class UnaryOp(Call): """Represent unary call operations.""" diff --git a/siuba/siu/dispatchers.py b/siuba/siu/dispatchers.py index bd71655b..889ff020 100644 --- a/siuba/siu/dispatchers.py +++ b/siuba/siu/dispatchers.py @@ -3,7 +3,7 @@ from functools import singledispatch, update_wrapper, wraps import inspect -from .calls import Call, FuncArg, MetaArg, Lazy, PipeCall +from .calls import Call, FuncArg, MetaArg, Lazy, PipeCall, _Isolate from .symbolic import Symbolic, create_sym_call, strip_symbolic from typing import Callable @@ -268,13 +268,13 @@ def __call__(self, x): res = f(res) return res -def _prep_lazy_args(*args): +def _prep_isolate_args(*args): result = [] for ii, arg in enumerate(args): if ii == 0: result.append(strip_symbolic(arg)) else: - result.append(Lazy(strip_symbolic(arg))) + result.append(_Isolate(strip_symbolic(arg))) return result @@ -282,13 +282,13 @@ def _prep_lazy_args(*args): def create_pipe_call(obj, *args, **kwargs) -> Call: """Return a Call of a function on its args and kwargs, wrapped in a Pipeable.""" - stripped_args = _prep_lazy_args(*args) + stripped_args = _prep_isolate_args(*args) return Call( "__call__", strip_symbolic(obj), *stripped_args, - **{k: Lazy(strip_symbolic(v)) for k,v in kwargs.items()} + **{k: 
_Isolate(strip_symbolic(v)) for k,v in kwargs.items()} ) def create_eager_pipe_call(obj, *args, **kwargs) -> Call: diff --git a/siuba/sql/verbs.py b/siuba/sql/verbs.py index 2b49488b..11931c35 100644 --- a/siuba/sql/verbs.py +++ b/siuba/sql/verbs.py @@ -50,6 +50,7 @@ from siuba.siu import Call, Lazy, FunctionLookupError, singledispatch2 # TODO: currently needed for select, but can we remove pandas? from pandas import Series +from functools import singledispatch from sqlalchemy.sql import schema @@ -195,10 +196,28 @@ def track_call_windows(call, columns, group_by, order_by, window_cte = None): +@singledispatch def replace_call_windows(col_expr, group_by, order_by, window_cte = None): + raise TypeError(str(type(col_expr))) + + +@replace_call_windows.register(sql.base.ImmutableColumnCollection) +def _(col_expr, group_by, order_by, window_cte = None): + all_over_clauses = [] + for col in col_expr: + _, over_clauses, window_cte = replace_call_windows( + col, + group_by, + order_by, + window_cte + ) + all_over_clauses.extend(over_clauses) - if not isinstance(col_expr, sql.elements.ClauseElement): - return col_expr + return col_expr, all_over_clauses, window_cte + + +@replace_call_windows.register(sql.elements.ClauseElement) +def _(col_expr, group_by, order_by, window_cte = None): over_clauses = WindowReplacer._get_over_clauses(col_expr) @@ -581,32 +600,39 @@ def _select(__data, *args, **kwargs): @filter.register(LazyTbl) def _filter(__data, *args): - # TODO: aggregate funcs + from siuba.dply.across import _require_across, _set_data_context + # Note: currently always produces 2 additional select statements, # 1 for window/aggs, and 1 for the where clause + sel = __data.last_op.alias() # original select win_sel = sel.select() conds = [] windows = [] - for ii, arg in enumerate(args): + with _set_data_context(__data): + for ii, arg in enumerate(args): + + if isinstance(arg, Call): + new_call = __data.shape_call(arg, verb_name = "Filter", arg_name = ii) + #var_cols = 
new_call.op_vars(attr_calls = False) + + # note that a new win_sel is returned, w/ window columns appended + col_expr, win_cols, win_sel = __data.track_call_windows( + new_call, + sel.columns, + window_cte = win_sel + ) - if isinstance(arg, Call): - new_call = __data.shape_call(arg, verb_name = "Filter", arg_name = ii) - #var_cols = new_call.op_vars(attr_calls = False) + if isinstance(col_expr, sql.base.ImmutableColumnCollection): + conds.extend(col_expr) + else: + conds.append(col_expr) - # note that a new win_sel is returned, w/ window columns appended - col_expr, win_cols, win_sel = __data.track_call_windows( - new_call, - sel.columns, - window_cte = win_sel - ) - - conds.append(col_expr) - windows.extend(win_cols) - - else: - conds.append(arg) + windows.extend(win_cols) + + else: + conds.append(arg) bool_clause = sql.and_(*conds) @@ -615,12 +641,6 @@ def _filter(__data, *args): win_alias = win_sel.alias() - # because track_call_windows in the loop above used select.append_column - # multiple times, sqlalchemy doesn't know our window columns are the ones - # in the final mutated for of win_sel - #col_key_map = {col.key: col for col in win_alias.columns.values()} - #equivalents = {col: [col_key_map[col.key]] for col in windows} - # move non-window functions to refer to win_sel clause (not the innermost) --- bool_clause = sql.util.ClauseAdapter(win_alias) \ .traverse(bool_clause) @@ -669,7 +689,7 @@ def _mutate(__data, *args, **kwargs): else: next_sel = across_sel - across_sel = _sql_upsert_columns(across_sel, replaced_cols) + across_sel = _sql_upsert_columns(next_sel, replaced_cols) # evaluate each call sel = across_sel diff --git a/siuba/tests/test_verb_across.py b/siuba/tests/test_verb_across.py index 09cf8f51..122e5d88 100644 --- a/siuba/tests/test_verb_across.py +++ b/siuba/tests/test_verb_across.py @@ -4,7 +4,7 @@ from pandas.testing import assert_frame_equal from pandas.core.groupby import DataFrameGroupBy from siuba.siu import symbolic_dispatch, Symbolic, 
Fx -from siuba.dply.verbs import mutate, filter, summarize +from siuba.dply.verbs import mutate, filter, summarize, group_by, collect, ungroup from siuba.dply.across import across from siuba.experimental.pivot.test_pivot import assert_equal_query2 @@ -32,12 +32,16 @@ def f_mean(x) -> pd.Series: def assert_grouping_names(gdf, names): - assert isinstance(gdf, DataFrameGroupBy) - groupings = gdf.grouper.groupings + from siuba.sql import LazyTbl - assert len(groupings) == len(names) + if isinstance(gdf, LazyTbl): + grouping_names = list(gdf.group_by) + else: + assert isinstance(gdf, DataFrameGroupBy) + groupings = gdf.grouper.groupings + grouping_names = [g.name for g in groupings] - grouping_names = [g.name for g in groupings] + assert len(grouping_names) == len(names) assert grouping_names == names @@ -150,15 +154,16 @@ def test_across_in_mutate(backend, df, func): assert_equal_query2(res_implicit, dst, sql_kwargs={"check_dtype": False}) -def test_across_in_mutate_grouped_equiv_ungrouped(df): - gdf = df.groupby("g") +def test_across_in_mutate_grouped_equiv_ungrouped(backend, df): + src = backend.load_df(df) + g_src = group_by(src, "g") expr_across = across(_, _[_.a_x, _.a_y], f_round) - g_res = mutate(gdf, expr_across) - dst = mutate(df, expr_across) + g_res = mutate(g_src, expr_across) + dst = mutate(src, expr_across) assert_grouping_names(g_res, ["g"]) - assert_frame_equal(g_res.obj, dst) + assert_equal_query2(ungroup(g_res), collect(dst)) def test_across_in_summarize(df): @@ -192,12 +197,13 @@ def test_across_in_summarize_equiv_ungrouped(): assert_frame_equal(g_res.drop(columns="g"), dst) -def test_across_in_filter(df): - res = filter(df, across(_, _[_.a_x, _.a_y], lambda x: x % 2 > 0)) +def test_across_in_filter(backend, df): + src = backend.load_df(df) + res = filter(src, across(_, _[_.a_x, _.a_y], Fx % 2 > 0)) dst = df[(df[["a_x", "a_y"]] % 2 > 0).all(axis=1)] - assert_frame_equal(res, dst) + assert_equal_query2(res, dst) def 
test_across_in_filter_equiv_ungrouped(df): From 42b79817831468be925f7bee55398af002ac3f35 Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Sat, 1 Oct 2022 18:07:59 -0400 Subject: [PATCH 04/27] refactor(pandas): move out mutate logic to be used in other verbs --- siuba/dply/verbs.py | 59 +++++++++++++++++++++++++++------------------ 1 file changed, 36 insertions(+), 23 deletions(-) diff --git a/siuba/dply/verbs.py b/siuba/dply/verbs.py index 327bf322..8364c96b 100644 --- a/siuba/dply/verbs.py +++ b/siuba/dply/verbs.py @@ -112,6 +112,40 @@ def _regroup(df): return df.groupby(level = grp_levels) +def _mutate_cols(__data, args, kwargs): + from pandas.core.common import apply_if_callable + + result_names = {} # used as ordered set + df_tmp = __data.copy() + + for arg in args: + + # case 1: a simple, existing name is a no-op ---- + simple_name = simple_varname(arg) + if simple_name is not None and simple_name in df_tmp.columns: + result_names[simple_name] = True + continue + + # case 2: across ---- + # TODO: make robust. validate input. validate output (e.g. shape). + res_arg = arg(df_tmp) + + if not isinstance(res_arg, pd.DataFrame): + raise NotImplementedError("Only across() can be used as positional argument.") + + for col_name, col_ser in res_arg.items(): + # need to put on the frame so subsequent args, kwargs can use + df_tmp[col_name] = col_ser + result_names[col_name] = True + + for col_name, expr in kwargs.items(): + # this is exactly what DataFrame.assign does + df_tmp[col_name] = apply_if_callable(expr, df_tmp) + result_names[col_name] = True + + return result_names, df_tmp + + MSG_TYPE_ERROR = "The first argument to {func} must be one of: {types}" def raise_type_error(f): @@ -208,29 +242,8 @@ def mutate(__data, *args, **kwargs): """ - args_result_df = __data.copy() - - # handle across ---- - for arg in args: - # TODO: make robust. validate input. validate output (e.g. shape). 
- new_col_map = arg(args_result_df) - - if not isinstance(new_col_map, pd.DataFrame): - raise NotImplementedError("Only across() can be used as positional argument.") - - for col_name, col_ser in new_col_map.items(): - args_result_df[col_name] = col_ser - - # handle everything else ---- - # TODO: what if kw expr returns DataFrame? - - orig_cols = args_result_df.columns - result = args_result_df.assign(**kwargs) - - new_cols = result.columns[~result.columns.isin(orig_cols)] - - return result.loc[:, [*orig_cols, *new_cols]] - + new_names, df_res = _mutate_cols(__data, args, kwargs) + return df_res @mutate.register(DataFrameGroupBy) From 80cd7818e31ffedc7940dd54be0d840eb5d5d161 Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Sun, 2 Oct 2022 20:30:11 -0400 Subject: [PATCH 05/27] feat(pandas): support across in group_by, transmute, count --- siuba/dply/verbs.py | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/siuba/dply/verbs.py b/siuba/dply/verbs.py index 8364c96b..195fbed9 100644 --- a/siuba/dply/verbs.py +++ b/siuba/dply/verbs.py @@ -143,7 +143,7 @@ def _mutate_cols(__data, args, kwargs): df_tmp[col_name] = apply_if_callable(expr, df_tmp) result_names[col_name] = True - return result_names, df_tmp + return list(result_names), df_tmp MSG_TYPE_ERROR = "The first argument to {func} must be one of: {types}" @@ -215,7 +215,7 @@ def show_query(__data, simplify = False): # Mutate ====================================================================== -# TODO: support for unnamed args + @singledispatch2(pd.DataFrame) def mutate(__data, *args, **kwargs): """Assign new variables to a DataFrame, while keeping existing ones. 
@@ -331,19 +331,24 @@ def group_by(__data, *args, add = False, **kwargs): 1 6 21.0 110 (20.2, 21.4] """ + + if isinstance(__data, DataFrameGroupBy): + tmp_df = __data.obj.copy() + else: + tmp_df = __data.copy() - tmp_df = mutate(__data, **kwargs) if kwargs else __data - - by_vars = list(map(simple_varname, args)) - for ii, name in enumerate(by_vars): - if name is None: raise Exception("group by variable %s is not a column name" %ii) + # TODO: super inefficient, since it makes multiple copies of data + # need way to get the by_vars and apply (grouped) computation + computed = ungroup(transmute(__data, *args, **kwargs)) + by_vars = list(computed.columns) - by_vars.extend(kwargs.keys()) + for k in by_vars: + tmp_df[k] = computed[k] - if isinstance(tmp_df, DataFrameGroupBy) and add: + if isinstance(__data, DataFrameGroupBy) and add: prior_groups = [el.name for el in __data.grouper.groupings] all_groups = ordered_union(prior_groups, by_vars) - return tmp_df.obj.groupby(list(all_groups)) + return tmp_df.groupby(list(all_groups)) return tmp_df.groupby(by = by_vars) @@ -595,14 +600,10 @@ def transmute(__data, *args, **kwargs): """ arg_vars = list(map(simple_varname, args)) - for ii, name in enumerate(arg_vars): - if name is None: raise Exception("complex, unnamed expression at pos %s not supported"%ii) - - f_mutate = mutate.registry[pd.DataFrame] - df = f_mutate(__data, **kwargs) + col_names, df_res = _mutate_cols(__data, args, kwargs) + return df_res[col_names] - return df[[*arg_vars, *kwargs.keys()]] @transmute.register(DataFrameGroupBy) def _transmute(__data, *args, **kwargs): From dde0800d770baad2aac45d3001bd24d272bd1f59 Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Sun, 2 Oct 2022 22:56:22 -0400 Subject: [PATCH 06/27] fix(pandas): better group handling; fix add_count when changing group vars --- siuba/dply/verbs.py | 127 ++++++++++++++++++++++++++++++++------------ 1 file changed, 93 insertions(+), 34 deletions(-) diff --git a/siuba/dply/verbs.py 
b/siuba/dply/verbs.py index 195fbed9..0d02ff69 100644 --- a/siuba/dply/verbs.py +++ b/siuba/dply/verbs.py @@ -146,6 +146,10 @@ def _mutate_cols(__data, args, kwargs): return list(result_names), df_tmp +def _make_groupby_safe(gdf): + return gdf.obj.groupby(gdf.grouper, group_keys=False) + + MSG_TYPE_ERROR = "The first argument to {func} must be one of: {types}" def raise_type_error(f): @@ -248,20 +252,20 @@ def mutate(__data, *args, **kwargs): @mutate.register(DataFrameGroupBy) def _mutate(__data, *args, **kwargs): - groupings = __data.grouper.groupings - orig_index = __data.obj.index + out = __data.obj.copy() + groupings = {ping.name: ping for ping in __data.grouper.groupings} - f_mutate = mutate.dispatch(pd.DataFrame) + f_transmute = transmute.dispatch(pd.DataFrame) - df = __data.apply(lambda d: f_mutate(d, *args, **kwargs)) - - # will drop all but original index - group_by_lvls = list(range(df.index.nlevels - 1)) - g_df = df.reset_index(group_by_lvls, drop = True).loc[orig_index].groupby(groupings) + df = _make_groupby_safe(__data).apply(lambda d: f_transmute(d, *args, **kwargs)) - return g_df + for varname, ser in df.items(): + if varname in groupings: + groupings[varname] = varname + out[varname] = ser + return out.groupby(list(groupings.values())) # Group By ==================================================================== @@ -339,16 +343,20 @@ def group_by(__data, *args, add = False, **kwargs): # TODO: super inefficient, since it makes multiple copies of data # need way to get the by_vars and apply (grouped) computation - computed = ungroup(transmute(__data, *args, **kwargs)) + computed = transmute(tmp_df, *args, **kwargs) by_vars = list(computed.columns) for k in by_vars: tmp_df[k] = computed[k] if isinstance(__data, DataFrameGroupBy) and add: - prior_groups = [el.name for el in __data.grouper.groupings] - all_groups = ordered_union(prior_groups, by_vars) - return tmp_df.groupby(list(all_groups)) + groupings = {el.name: el for el in 
__data.grouper.groupings} + + for varname in by_vars: + # ensures group levels are recalculated if varname was in transmute + groupings[varname] = varname + + return tmp_df.groupby(list(groupings.values())) return tmp_df.groupby(by = by_vars) @@ -376,10 +384,10 @@ def ungroup(__data): # the groupby? if isinstance(__data, pd.DataFrame): return __data - if isinstance(__data, pd.Series): - return __data.reset_index() - - return __data.obj.reset_index(drop = True) + elif isinstance(__data, DataFrameGroupBy): + return __data.obj + else: + raise TypeError(f"Unsupported type {type(__data)}") @@ -607,23 +615,20 @@ def transmute(__data, *args, **kwargs): @transmute.register(DataFrameGroupBy) def _transmute(__data, *args, **kwargs): - arg_vars = list(map(simple_varname, args)) - for ii, name in enumerate(arg_vars): - if name is None: raise Exception("complex, unnamed expression at pos %s not supported"%ii) - - f_mutate = mutate.registry[DataFrameGroupBy] + groupings = {ping.name: ping for ping in __data.grouper.groupings} - gdf = f_mutate(__data, **kwargs) - groupings = gdf.grouper.groupings + f_transmute = transmute.dispatch(pd.DataFrame) - group_names = [x.name for x in groupings] - if None in group_names: - raise ValueError("Passed a grouped DataFrame to transmute, but not all " - "its groups are named. 
Groups: %s" % group_names) + df = _make_groupby_safe(__data).apply(lambda d: f_transmute(d, *args, **kwargs)) - subset = ungroup(gdf)[[*group_names, *arg_vars, *kwargs.keys()]] + + for varname in reversed(list(groupings)): + if varname in df.columns: + groupings[varname] = varname + else: + df.insert(0, varname, __data.obj[varname]) - return subset.groupby(groupings) + return df.groupby(list(groupings.values())) @@ -1260,8 +1265,24 @@ def count(__data, *args, wt = None, sort = False, **kwargs): return counts +def _check_name(name, columns): + if name is None: + name = "n" + while name in columns: + name = name + "n" + + if name != "n": + # TODO: warning + pass + + elif not isinstance(name, str): + raise TypeError("`name` must be a single string.") + + return name + + @singledispatch2((pd.DataFrame, DataFrameGroupBy)) -def add_count(__data, *args, wt = None, sort = False, **kwargs): +def add_count(__data, *args, wt = None, sort = False, name = None, **kwargs): """Add a column that is the number of observations for each grouping of data. Note that this function is similar to count(), but does not aggregate. It's @@ -1309,10 +1330,48 @@ def add_count(__data, *args, wt = None, sort = False, **kwargs): """ - counts = count(__data, *args, wt = wt, sort = sort, **kwargs) - by = list(counts.columns)[:-1] - return inner_join(__data, counts, by = by) + no_grouping_vars = not args and not kwargs and isinstance(__data, pd.DataFrame) + + if no_grouping_vars: + out = __data + else: + out = group_by(__data, *args, add=True, **kwargs) + + var_names = ungroup(out).columns + name = _check_name(name, set(var_names)) + + if wt is None: + if no_grouping_vars: + # no groups, just use number of rows + counts = __data.copy() + counts[name] = counts.shape[0] + else: + # note that it's easy to transform tally using single grouped column, so + # we arbitrarily grab the first column.. 
+ counts = out.obj.copy() + counts[name] = out[var_names[0]].transform("size") + + else: + wt_col = simple_varname(wt) + if wt_col is None: + raise Exception("wt argument has to be simple column name") + + if no_grouping_vars: + # no groups, sum weights + counts = __data.copy() + counts[name] = counts[wt_col].sum() + else: + # TODO: should flip topmost if/else so grouped code is together + # do weighted tally + counts = out.obj.copy() + counts[name] = out[wt_col].transform("sum") + + if sort: + return counts.sort_values(out_col, ascending = False) + + return counts + From 67680fbba574eafe63a1cfe1a037a030985da49a Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Mon, 3 Oct 2022 13:02:12 -0400 Subject: [PATCH 07/27] feat(sql): support across in group_by, transmute --- siuba/sql/verbs.py | 165 +++++++++++++++++++++++++++------------------ 1 file changed, 100 insertions(+), 65 deletions(-) diff --git a/siuba/sql/verbs.py b/siuba/sql/verbs.py index 11931c35..4ed582f0 100644 --- a/siuba/sql/verbs.py +++ b/siuba/sql/verbs.py @@ -54,6 +54,8 @@ from sqlalchemy.sql import schema +from siuba.dply.across import _require_across, _set_data_context, _eval_with_context + # TODO: # - distinct # - annotate functions using sel.prefix_with("\n/**/\n") ? 
@@ -161,7 +163,8 @@ class SqlLabelReplacer: """ def __init__(self, src_columns, dst_columns): - self.src_labels = [x for x in src_columns if isinstance(x, sql.elements.Label)] + self.src_columns = src_columns + self.src_labels = set([x for x in src_columns if isinstance(x, sql.elements.Label)]) self.dst_columns = dst_columns self.applied = False @@ -169,10 +172,24 @@ def __call__(self, clause): return sql.util.visitors.replacement_traverse(clause, {}, self.visit) def visit(self, el): - if el in self.src_labels: - import pdb; pdb.set_trace() - self.applied = True - return self.dst_columns[el.name] + from sqlalchemy.sql.elements import ColumnClause, Label, ClauseElement, TypeClause + from sqlalchemy.sql.schema import Column + + if isinstance(el, TypeClause): + # TODO: for some reason this type throws an error if unguarded + return None + + if isinstance(el, ClauseElement): + if el in self.src_labels: + self.applied = True + return self.dst_columns[el.name] + elif el in self.src_columns: + return self.dst_columns[el.name] + + elif isinstance(el, ColumnClause) and not isinstance(el, Column): + # Raw SQL, which will need a subquery, but not substitution + if el.key != "*": + self.applied = True return None @@ -600,8 +617,6 @@ def _select(__data, *args, **kwargs): @filter.register(LazyTbl) def _filter(__data, *args): - from siuba.dply.across import _require_across, _set_data_context - # Note: currently always produces 2 additional select statements, # 1 for window/aggs, and 1 for the where clause @@ -656,61 +671,88 @@ def _filter(__data, *args): @mutate.register(LazyTbl) def _mutate(__data, *args, **kwargs): - from siuba.dply.across import _require_across, _eval_with_context - # TODO: verify it can follow a renaming select # track labeled columns in set if not (len(args) or len(kwargs)): return __data.append_op(__data.last_op) - across_sel = __data.last_select + names, sel_out = _mutate_cols(__data, args, kwargs, "Mutate") + return __data.append_op(sel_out) + + +def 
_sql_upsert_columns(sel, new_columns: "list[base.Label | base.Column]"): + orig_cols = lift_inner_cols(sel) + replaced = {**orig_cols} + + for new_col in new_columns: + replaced[new_col.name] = new_col + return _sql_with_only_columns(sel, list(replaced.values())) + + +def _select_mutate_result(src_sel, expr_result): + dst_alias = src_sel.alias() + src_columns = set(lift_inner_cols(src_sel)) + replacer = SqlLabelReplacer(set(src_columns), dst_alias.columns) + + if isinstance(expr_result, sql.base.ImmutableColumnCollection): + replaced_cols = list(map(replacer, expr_result)) + orig_cols = expr_result + #elif isinstance(expr_result, None): + # pass + else: + replaced_cols = [replacer(expr_result)] + orig_cols = [expr_result] + + if replacer.applied: + return _sql_upsert_columns(dst_alias.select(), replaced_cols) + + return _sql_upsert_columns(src_sel, orig_cols) + + +def _mutate_cols(__data, args, kwargs, verb_name): + result_names = {} # used as ordered set + sel = __data.last_select - # special support for across for ii, func in enumerate(args): - _require_across(func, "Mutate") - _candidate_sel_alias = across_sel.alias() + # case 1: simple names ---- + simple_name = simple_varname(func) + if simple_name is not None: + result_names[simple_name] = True + continue - inner_cols = lift_inner_cols(across_sel) + # case 2: across ---- + _require_across(func, verb_name) - #new_call = __data.shape_call(func, verb_name = "Mutate", arg_name = f"*arg entry {ii}") + inner_cols = lift_inner_cols(sel) cols_result = _eval_with_context(__data, inner_cols, func) # TODO: remove or raise a more informative error assert isinstance(cols_result, sql.base.ImmutableColumnCollection), type(cols_result) # replace any labels that require a subquery ---- - replacer = SqlLabelReplacer(set(inner_cols), _candidate_sel_alias.columns) - replaced_cols = list(map(replacer, cols_result)) + sel = _select_mutate_result(sel, cols_result) - if replacer.applied: - # TODO: use replace logic from 
_mutate_select - next_sel = _candidate_sel_alias.select() - else: - next_sel = across_sel + result_names.update({k: True for k in cols_result.keys()}) + + for new_name, func in kwargs.items(): + inner_cols = lift_inner_cols(sel) - across_sel = _sql_upsert_columns(next_sel, replaced_cols) + expr_shaped = __data.shape_call(func, verb_name = verb_name, arg_name = new_name) + new_col, windows, _ = __data.track_call_windows(expr_shaped, inner_cols) - # evaluate each call - sel = across_sel - for colname, func in kwargs.items(): - # keep set of columns labeled (aliased) in this select statement - # need to use inner cols, since sel.columns uses ColumnClause, not Label - labs = set(k for k,v in lift_inner_cols(sel).items() if isinstance(v, sql.elements.Label)) - new_call = __data.shape_call(func, verb_name = "Mutate", arg_name = colname) + if isinstance(new_col, sql.base.ImmutableColumnCollection): + raise TyepError( + f"{verb_name} named arguments must return a single column, but `{k}` " + "returned multiple columns." 
+ ) + + labeled = new_col.label(new_name) + sel = _select_mutate_result(sel, labeled) - sel = _mutate_select(sel, colname, new_call, labs, __data) + result_names[new_name] = True - return __data.append_op(sel) - - -def _sql_upsert_columns(sel, new_columns: "list[base.Label | base.Column]"): - orig_cols = lift_inner_cols(sel) - replaced = {**orig_cols} - - for new_col in new_columns: - replaced[new_col.name] = new_col - return _sql_with_only_columns(sel, list(replaced.values())) + return list(result_names), sel @@ -747,14 +789,13 @@ def _mutate_select(sel, colname, func, labs, __data): @transmute.register(LazyTbl) -def _transmute(__data, **kwargs): +def _transmute(__data, *args, **kwargs): # will use mutate, then select some cols - f_mutate = mutate.registry[type(__data)] + result_names, sel = _mutate_cols(__data, args, kwargs, "Transmute") # transmute keeps grouping cols, and any defined in kwargs - cols_to_keep = ordered_union(__data.group_by, kwargs) - - sel = f_mutate(__data, **kwargs).last_select + missing = [x for x in __data.group_by if x not in result_names] + cols_to_keep = [*missing, *result_names] columns = lift_inner_cols(sel) sel_stripped = sel.with_only_columns([columns[k] for k in cols_to_keep]) @@ -873,11 +914,8 @@ def _add_count(__data, *args, wt = None, sort = False, **kwargs): @summarize.register(LazyTbl) -def _summarize(__data, **kwargs): +def _summarize(__data, *args, **kwargs): # https://stackoverflow.com/questions/14754994/why-is-sqlalchemy-count-much-slower-than-the-raw-query - # what if windowed mutate or filter has been done? - # - filter is fine, since it uses a CTE - # - need to detect any window functions... 
old_sel = __data.last_select._clone() new_calls = {} @@ -938,26 +976,23 @@ def _summarize(__data, **kwargs): @group_by.register(LazyTbl) def _group_by(__data, *args, add = False, **kwargs): - if kwargs: - data = mutate(__data, **kwargs) - else: - data = __data + if not (args or kwargs): + return __data.copy() + + group_names, sel = _mutate_cols(__data, args, kwargs, "Group by") - # put kwarg grouping vars last, so similar order to function call - groups = tuple(simple_varname(arg) for arg in args) + tuple(kwargs) - if None in groups: - raise NotImplementedError("Complex expressions not supported in sql group_by") + if None in group_names: + raise NotImplementedError("Complex, unnamed expressions not supported in sql group_by") - # ensure group_by variables are in the select columns - cols = data.last_op.alias().columns - unmatched = set(groups) - set(cols.keys()) - if unmatched: - raise KeyError("group_by specifies columns missing from table: %s" %unmatched) + # check whether we can just use underlying table ---- + new_cols = lift_inner_cols(sel) + if set(new_cols).issubset(set(__data.last_op.columns)): + sel = __data.last_op if add: - groups = ordered_union(data.group_by, groups) + group_names = ordered_union(__data.group_by, group_names) - return data.copy(group_by = groups) + return __data.append_op(sel, group_by = tuple(group_names)) @ungroup.register(LazyTbl) From fbe8f6a71eced5ad0bdbf12a65105bd24976fad4 Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Mon, 3 Oct 2022 13:12:36 -0400 Subject: [PATCH 08/27] feat(sql): support across in count --- siuba/sql/verbs.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/siuba/sql/verbs.py b/siuba/sql/verbs.py index 4ed582f0..8b25d4b5 100644 --- a/siuba/sql/verbs.py +++ b/siuba/sql/verbs.py @@ -877,8 +877,7 @@ def _count(__data, *args, sort = False, wt = None, **kwargs): ) arg_names.append(name) - sel_inner = mutate(__data, **kwargs).last_op - group_cols = arg_names + list(kwargs) + 
result_names, sel_inner = _mutate_cols(__data, args, kwargs, "Count") # create outer select ---- # holds selected columns and tally (n) @@ -886,11 +885,11 @@ def _count(__data, *args, sort = False, wt = None, **kwargs): inner_cols = sel_inner_cte.columns # apply any group vars from a group_by verb call first - tbl_group_cols = [inner_cols[k] for k in __data.group_by] - count_group_cols = [inner_cols[k] for k in group_cols] + missing = [k for k in __data.group_by if k not in result_names] + + all_group_names = ordered_union(__data.group_by, result_names) + outer_group_cols = [inner_cols[k] for k in all_group_names] - # combine with any defined in the count verb call - outer_group_cols = ordered_union(tbl_group_cols, count_group_cols) # holds the actual count (e.g. n) count_col = sql.functions.count().label(res_name) From 93bc927ca10ac163eb43e5305cb33e8f60e0abce Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Mon, 3 Oct 2022 14:03:17 -0400 Subject: [PATCH 09/27] refactor(sql): prep _mutate_cols to support arrange --- siuba/sql/verbs.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/siuba/sql/verbs.py b/siuba/sql/verbs.py index 8b25d4b5..93c64067 100644 --- a/siuba/sql/verbs.py +++ b/siuba/sql/verbs.py @@ -710,8 +710,9 @@ def _select_mutate_result(src_sel, expr_result): return _sql_upsert_columns(src_sel, orig_cols) -def _mutate_cols(__data, args, kwargs, verb_name): +def _mutate_cols(__data, args, kwargs, verb_name, arrange_clause=False): result_names = {} # used as ordered set + result_expr = [] sel = __data.last_select for ii, func in enumerate(args): @@ -731,9 +732,12 @@ def _mutate_cols(__data, args, kwargs, verb_name): assert isinstance(cols_result, sql.base.ImmutableColumnCollection), type(cols_result) # replace any labels that require a subquery ---- - sel = _select_mutate_result(sel, cols_result) + if arrange_clause: + result_expr.extend(cols_result) + else: + sel = _select_mutate_result(sel, cols_result) + 
result_names.update({k: True for k in cols_result.keys()}) - result_names.update({k: True for k in cols_result.keys()}) for new_name, func in kwargs.items(): inner_cols = lift_inner_cols(sel) @@ -748,9 +752,16 @@ def _mutate_cols(__data, args, kwargs, verb_name): ) labeled = new_col.label(new_name) - sel = _select_mutate_result(sel, labeled) - result_names[new_name] = True + if arrange_clause: + result_expr.append(labeled) + else: + sel = _select_mutate_result(sel, labeled) + result_names[new_name] = True + + + if arrange_clause: + return result_expr, sel return list(result_names), sel @@ -813,7 +824,9 @@ def _arrange(__data, *args): last_sel = __data.last_select cols = lift_inner_cols(last_sel) - + # TODO: implement across in arrange + #exprs, _ = _mutate_cols(__data, args, kwargs, "Arrange", arrange_clause=True) + new_calls = [] for ii, expr in enumerate(args): if callable(expr): From e3865799fa4da432800d280f161ee469251f096f Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Mon, 3 Oct 2022 16:01:22 -0400 Subject: [PATCH 10/27] fix(sql): sqlalchemy 1.3 compat for simplify_select --- siuba/sql/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/siuba/sql/utils.py b/siuba/sql/utils.py index cf1d26e2..8f302efa 100644 --- a/siuba/sql/utils.py +++ b/siuba/sql/utils.py @@ -185,7 +185,7 @@ def simplify_sel(sel): # TODO: find simpler way to clone an element. We cannot use the visitors # argument of cloned_traverse, since it visits the inner-most element first. 
- clone_el = cloned_traverse(select, {}, {}) + clone_el = select._clone() # modify in-place traverse(clone_el, {}, {"select": simplify_sel}) From 0c327545d5e879f4c9b54144ab71697134a7fa88 Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Mon, 3 Oct 2022 17:18:37 -0400 Subject: [PATCH 11/27] tests: fix bigquery failing due to ordering --- siuba/tests/test_verb_across.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/siuba/tests/test_verb_across.py b/siuba/tests/test_verb_across.py index 122e5d88..9376344b 100644 --- a/siuba/tests/test_verb_across.py +++ b/siuba/tests/test_verb_across.py @@ -10,6 +10,13 @@ from siuba.experimental.pivot.test_pivot import assert_equal_query2 from siuba.sql.translate import SqlColumn, SqlColumnAgg, sql_scalar +# TODO: test transmute +# TODO: test verb(data, _.simple_name) +# TODO: test changing a group var (e.g. mutate, transmute, add_count), then summarizing +# TODO: group_by(cyl) >> count(cyl = cyl + 1) +# TODO: SQL mutate requires immediate CTE (e.g. 
due to GROUP BY clause) +# TODO: count "n" name + # Helpers ===================================================================== @@ -150,8 +157,9 @@ def test_across_in_mutate(backend, df, func): dst["a_x"] = df.a_x.round() dst["a_y"] = df.a_y.round() - assert_equal_query2(res_explicit, dst, sql_kwargs={"check_dtype": False}) - assert_equal_query2(res_implicit, dst, sql_kwargs={"check_dtype": False}) + sql_kwargs = {"check_dtype": False} + assert_equal_query2(res_explicit, dst, sql_kwargs=sql_kwargs, sql_ordered=False) + assert_equal_query2(res_implicit, dst, sql_kwargs=sql_kwargs, sql_ordered=False) def test_across_in_mutate_grouped_equiv_ungrouped(backend, df): From 30fb5a530e439b054ec53d5067e55ad6aac9aee0 Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Mon, 3 Oct 2022 20:24:10 -0400 Subject: [PATCH 12/27] tests: assert_equal_query2 defaults to sql_ordered False --- siuba/experimental/pivot/test_pivot.py | 8 ++++---- siuba/experimental/pivot/test_pivot_wide.py | 20 ++++++++++++-------- siuba/tests/test_verb_across.py | 4 ++-- 3 files changed, 18 insertions(+), 14 deletions(-) diff --git a/siuba/experimental/pivot/test_pivot.py b/siuba/experimental/pivot/test_pivot.py index 357cf96b..0195805c 100644 --- a/siuba/experimental/pivot/test_pivot.py +++ b/siuba/experimental/pivot/test_pivot.py @@ -24,7 +24,7 @@ -def assert_equal_query2(left, right, *args, sql_ordered=True, sql_kwargs=None, **kwargs): +def assert_equal_query2(left, right, *args, sql_ordered=False, sql_kwargs=None, **kwargs): from siuba.sql import LazyTbl from siuba.sql.utils import _is_dialect_duckdb from pandas.core.groupby import DataFrameGroupBy @@ -111,7 +111,7 @@ def test_preserves_original_keys(backend): pv = pivot_longer(remote, _["y":"z"]) - assert_equal_query2(pv, dst, sql_ordered=False) + assert_equal_query2(pv, dst) def test_can_drop_missing_values(backend): @@ -123,7 +123,7 @@ def test_can_drop_missing_values(backend): # TODO: sql databases sometimes return float. 
we need assertions to take # backend (or sql vs pandas) specific options. - assert_equal_query2(pv, dst, sql_ordered=False, check_dtype=False) + assert_equal_query2(pv, dst, check_dtype=False) def test_can_handle_missing_combinations(backend): @@ -139,7 +139,7 @@ def test_can_handle_missing_combinations(backend): src = backend.load_df(df) pv = pivot_longer(src, -_.id, names_to = (".value", "n"), names_sep = "_") - assert_equal_query2(pv, dst, sql_ordered=False) + assert_equal_query2(pv, dst) @pytest.mark.xfail diff --git a/siuba/experimental/pivot/test_pivot_wide.py b/siuba/experimental/pivot/test_pivot_wide.py index c877c277..75be28cf 100644 --- a/siuba/experimental/pivot/test_pivot_wide.py +++ b/siuba/experimental/pivot/test_pivot_wide.py @@ -23,7 +23,7 @@ def test_pivot_all_cols(backend): pv = pivot_wider(src, names_from=_.key, values_from=_.val) # Note: duckdb is ordered - assert_equal_query2(pv, dst, sql_kwargs = {"check_like": True}) + assert_equal_query2(pv, dst, sql_kwargs = {"check_like": True}, sql_ordered=True) def test_pivot_id_cols_default_preserve(backend): @@ -35,7 +35,7 @@ def test_pivot_id_cols_default_preserve(backend): pv = pivot_wider(src, names_from = _.key, values_from = _.val) # Note: duckdb is ordered - assert_equal_query2(pv, dst, sql_kwargs = {"check_like": True}) + assert_equal_query2(pv, dst, sql_kwargs = {"check_like": True}, sql_ordered=True) def test_pivot_implicit_missings_to_explicit(): @@ -44,7 +44,7 @@ def test_pivot_implicit_missings_to_explicit(): pv = pivot_wider(src, names_from = _.key, values_from = _.val) - assert_equal_query2(pv, dst) + assert_equal_query2(pv, dst, sql_ordered=True) def test_pivot_implicit_missings_to_explicit_from_spec(backend): @@ -96,7 +96,7 @@ def test_names_repair_unique(skip_backend, backend): pv = pivot_wider(src, names_repair=lambda x: [f"{v}_{ii}" for ii, v in enumerate(x)]) - assert_equal_query2(pv, dst) + assert_equal_query2(pv, dst, sql_ordered=True) def test_names_repair_minimal(): @@ -131,7 
+131,7 @@ def test_weird_column_name_select(skip_backend, backend): pv = pivot_wider(src, names_from = _["...8"], values_from = _.val) - assert_equal_query2(pv, dst, sql_kwargs = {"check_like": True}) + assert_equal_query2(pv, dst, sql_kwargs = {"check_like": True}, sql_ordered=True) @pytest.mark.skip("Won't do") @@ -429,7 +429,7 @@ def test_selecting_all_id_cols_excludes_names_from_values_from(backend): dst = data_frame(key = "x", a = 1) pv = pivot_wider(src, _[:]) - assert_equal_query2(pv, dst) + assert_equal_query2(pv, dst, sql_ordered=True) # TODO: also test pivot_wider_spec @@ -467,7 +467,12 @@ def test_pivot_zero_row_frame_id_excludes_values_from(backend): pv = pivot_wider(src, names_from = _.name, values_from = _.value) # SQL backends return a empty Index, pandas an empty RangeIndex ¯\_(ツ)_/¯ - assert_equal_query2(pv, dst, sql_kwargs = {"check_index_type": False, "check_dtype": False}) + assert_equal_query2( + pv, + dst, + sql_kwargs = {"check_index_type": False, "check_dtype": False}, + sql_ordered=True + ) #TODO: @@ -584,7 +589,6 @@ def test_values_fn_arg_str(backend): assert_equal_query2( pv, data_frame(a = [1, 2], x = [3, 3]), - sql_ordered=False, sql_kwargs={"check_dtype": False} ) diff --git a/siuba/tests/test_verb_across.py b/siuba/tests/test_verb_across.py index 9376344b..b977c0b9 100644 --- a/siuba/tests/test_verb_across.py +++ b/siuba/tests/test_verb_across.py @@ -158,8 +158,8 @@ def test_across_in_mutate(backend, df, func): dst["a_y"] = df.a_y.round() sql_kwargs = {"check_dtype": False} - assert_equal_query2(res_explicit, dst, sql_kwargs=sql_kwargs, sql_ordered=False) - assert_equal_query2(res_implicit, dst, sql_kwargs=sql_kwargs, sql_ordered=False) + assert_equal_query2(res_explicit, dst, sql_kwargs=sql_kwargs) + assert_equal_query2(res_implicit, dst, sql_kwargs=sql_kwargs) def test_across_in_mutate_grouped_equiv_ungrouped(backend, df): From 4e66fa0b0bce38a0a8b4b96b2ceffba061672de3 Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Tue, 4 Oct 2022 
15:24:08 -0400 Subject: [PATCH 13/27] feat(sql): extremely rough version of summarize across --- siuba/dply/across.py | 10 +- siuba/sql/across.py | 4 +- siuba/sql/utils.py | 10 +- siuba/sql/verbs.py | 276 ++++++++++++++++++-------------- siuba/tests/test_verb_across.py | 12 +- 5 files changed, 187 insertions(+), 125 deletions(-) diff --git a/siuba/dply/across.py b/siuba/dply/across.py index e72cf4fd..d52d652b 100644 --- a/siuba/dply/across.py +++ b/siuba/dply/across.py @@ -16,6 +16,7 @@ ctx_verb_data = ContextVar("data") +ctx_verb_window = ContextVar("window") def _is_symbolic_operator(f): @@ -34,22 +35,27 @@ def _require_across(call, verb_name): ) -def _eval_with_context(ctx, data, expr): +def _eval_with_context(ctx, window_ctx, data, expr): + # TODO: should just set the translator as context (e.g. agg translater, etc..) token = ctx_verb_data.set(ctx) + token_win = ctx_verb_window.set(window_ctx) try: return expr(data) finally: ctx_verb_data.reset(token) + ctx_verb_window.reset(token_win) @contextmanager -def _set_data_context(ctx): +def _set_data_context(ctx, window): try: token = ctx_verb_data.set(ctx) + token_win = ctx_verb_window.set(window) yield finally: ctx_verb_data.reset(token) + ctx_verb_window.reset(token_win) diff --git a/siuba/sql/across.py b/siuba/sql/across.py index 9b7cbc08..4fec8865 100644 --- a/siuba/sql/across.py +++ b/siuba/sql/across.py @@ -1,4 +1,4 @@ -from siuba.dply.across import across, _get_name_template, _across_setup_fns, ctx_verb_data +from siuba.dply.across import across, _get_name_template, _across_setup_fns, ctx_verb_data, ctx_verb_window from siuba.dply.tidyselect import var_select, var_create from siuba.siu import FormulaContext, Call @@ -34,6 +34,7 @@ def _across_sql_cols( ) -> sql.base.ImmutableColumnCollection: lazy_tbl = ctx_verb_data.get() + window = ctx_verb_window.get() name_template = _get_name_template(fns, names) selected_cols = var_select(__data, *var_create(cols), data=__data) @@ -56,6 +57,7 @@ def _across_sql_cols( 
new_call = lazy_tbl.shape_call( fn, + window, verb_name="Across", arg_name = f"function {fn_name} of {len(fns_map)}" ) diff --git a/siuba/sql/utils.py b/siuba/sql/utils.py index 8f302efa..68b9c7cd 100644 --- a/siuba/sql/utils.py +++ b/siuba/sql/utils.py @@ -136,9 +136,15 @@ def _sql_add_columns(select, columns): def _sql_with_only_columns(select, columns): if is_sqla_12() or is_sqla_13(): - return select.with_only_columns(columns) + out = select.with_only_columns(columns) + else: + out = select.with_only_columns(*columns) + + # ensure removing all columns doesn't remove from clause table reference + for _from in select.froms: + out = out.select_from(_from) - return select.with_only_columns(*columns) + return out def _sql_case(*args, **kwargs): diff --git a/siuba/sql/verbs.py b/siuba/sql/verbs.py index 93c64067..df7831fa 100644 --- a/siuba/sql/verbs.py +++ b/siuba/sql/verbs.py @@ -260,6 +260,16 @@ def _(col_expr, group_by, order_by, window_cte = None): return col_expr, over_clauses, window_cte +def get_single_from(sel): + froms = sel.froms + + n_froms = len(froms) + if n_froms != 1: + raise ValueError( + f"Expected a single table in the from clause, but found {n_froms}" + ) + + return froms[0] def lift_inner_cols(tbl): cols = list(tbl.inner_columns) @@ -625,7 +635,7 @@ def _filter(__data, *args): conds = [] windows = [] - with _set_data_context(__data): + with _set_data_context(__data, window=True): for ii, arg in enumerate(args): if isinstance(arg, Call): @@ -710,93 +720,64 @@ def _select_mutate_result(src_sel, expr_result): return _sql_upsert_columns(src_sel, orig_cols) -def _mutate_cols(__data, args, kwargs, verb_name, arrange_clause=False): - result_names = {} # used as ordered set - result_expr = [] - sel = __data.last_select +def _eval_expr_arg(__data, sel, func, verb_name, window=True): + inner_cols = lift_inner_cols(sel) - for ii, func in enumerate(args): - # case 1: simple names ---- - simple_name = simple_varname(func) - if simple_name is not None: - 
result_names[simple_name] = True - continue + # case 1: simple names ---- + simple_name = simple_varname(func) + if simple_name is not None: + return inner_cols[simple_name] - # case 2: across ---- - _require_across(func, verb_name) + # case 2: across ---- + _require_across(func, verb_name) - inner_cols = lift_inner_cols(sel) - cols_result = _eval_with_context(__data, inner_cols, func) + cols_result = _eval_with_context(__data, window, inner_cols, func) - # TODO: remove or raise a more informative error - assert isinstance(cols_result, sql.base.ImmutableColumnCollection), type(cols_result) + # TODO: remove or raise a more informative error + assert isinstance(cols_result, sql.base.ImmutableColumnCollection), type(cols_result) - # replace any labels that require a subquery ---- - if arrange_clause: - result_expr.extend(cols_result) - else: - sel = _select_mutate_result(sel, cols_result) - result_names.update({k: True for k in cols_result.keys()}) + return cols_result - - for new_name, func in kwargs.items(): - inner_cols = lift_inner_cols(sel) - expr_shaped = __data.shape_call(func, verb_name = verb_name, arg_name = new_name) - new_col, windows, _ = __data.track_call_windows(expr_shaped, inner_cols) +def _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=True): + inner_cols = lift_inner_cols(sel) - if isinstance(new_col, sql.base.ImmutableColumnCollection): - raise TyepError( - f"{verb_name} named arguments must return a single column, but `{k}` " - "returned multiple columns." 
- ) - - labeled = new_col.label(new_name) + expr_shaped = __data.shape_call(func, window, verb_name = verb_name, arg_name = new_name) + new_col, windows, _ = __data.track_call_windows(expr_shaped, inner_cols) - if arrange_clause: - result_expr.append(labeled) - else: - sel = _select_mutate_result(sel, labeled) - result_names[new_name] = True + if isinstance(new_col, sql.base.ImmutableColumnCollection): + raise TyepError( + f"{verb_name} named arguments must return a single column, but `{k}` " + "returned multiple columns." + ) + return new_col.label(new_name) - if arrange_clause: - return result_expr, sel - return list(result_names), sel +def _mutate_cols(__data, args, kwargs, verb_name): + result_names = {} # used as ordered set + sel = __data.last_select + for ii, func in enumerate(args): + cols_result = _eval_expr_arg(__data, sel, func, verb_name) + # replace any labels that require a subquery ---- + sel = _select_mutate_result(sel, cols_result) -def _mutate_select(sel, colname, func, labs, __data): - """Return select statement containing a new column, func expr as colname. - - Note: since a func can refer to previous columns generated by mutate, this - function handles whether to add a column to the existing select statement, - or to use it as a subquery. - """ - replace_col = colname in lift_inner_cols(sel) - # Call objects let us check whether column expr used a derived column - # e.g. 
SELECT a as b, b + 1 as c raises an error in SQL, so need subquery - if not col_expr_requires_cte(func, sel, is_mutate = True): - # New column may be able to modify existing select - columns = lift_inner_cols(sel) + if isinstance(cols_result, sql.base.ImmutableColumnCollection): + result_names.update({k: True for k in cols_result.keys()}) + else: + result_names[cols_result.name] = True - else: - # anything else requires a subquery - cte = sel.alias(None) - columns = cte.columns - sel = cte.select() + + for new_name, func in kwargs.items(): + labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name) - # evaluate call expr on columns, making sure to use group vars - new_col, windows, _ = __data.track_call_windows(func, columns) + sel = _select_mutate_result(sel, labeled) + result_names[new_name] = True - # replacing an existing column, so strip it from select statement - if replace_col: - replaced = {**columns} - replaced[colname] = new_col.label(colname) - return sel.with_only_columns(list(replaced.values())) - return _sql_add_columns(sel, [new_col.label(colname)]) + return list(result_names), sel @transmute.register(LazyTbl) @@ -928,62 +909,125 @@ def _add_count(__data, *args, wt = None, sort = False, **kwargs): @summarize.register(LazyTbl) def _summarize(__data, *args, **kwargs): # https://stackoverflow.com/questions/14754994/why-is-sqlalchemy-count-much-slower-than-the-raw-query - old_sel = __data.last_select._clone() - new_calls = {} - for k, expr in kwargs.items(): - new_calls[k] = __data.shape_call( - expr, window = False, - verb_name = "Summarize", arg_name = k - ) + # get query with correct from clause, and maybe unneeded subquery + safe_from = __data.last_select.alias() + result_names, sel = _aggregate_cols(__data, safe_from, args, kwargs, "Summarize") - needs_cte = [col_expr_requires_cte(call, old_sel) for call in new_calls.values()] - group_on_labels = set(__data.group_by) & get_inner_labels(old_sel) + # see if we can remove subquery + out_sel = 
_collapse_select(sel, safe_from) - # create select statement ---- + from_tbl = get_single_from(out_sel) + group_cols = [from_tbl.columns[k] for k in __data.group_by] - if any(needs_cte) or len(group_on_labels): - # need a cte, due to alias cols or existing group by - # current select stmt has group by clause, so need to make it subquery - cte = old_sel.alias() - columns = cte.columns - sel = sql.select().select_from(cte) - else: - # otherwise, can alter the existing select statement - columns = lift_inner_cols(old_sel) - sel = old_sel - - # explicitly add original from clause tables, since we will be limiting - # the columns this select uses, which causes sqlalchemy to remove - # unreferenced tables - for _from in sel.froms: - sel = sel.select_from(_from) - + final_sel = out_sel.group_by(*group_cols) - # add group by columns ---- - group_cols = [columns[k] for k in __data.group_by] + new_data = __data.append_op(final_sel, group_by = tuple(), order_by = tuple()) + return new_data - # add each aggregate column ---- - # TODO: can't do summarize(b = mean(a), c = b + mean(a)) - # since difficult for c to refer to agg and unagg cols in SQL - expr_cols = [] - for k, expr in new_calls.items(): - missing_cols = get_missing_columns(expr, columns) - if missing_cols: - raise NotImplementedError( - "Summarize cannot find the following columns: %s. " - "Note that it cannot refer to variables defined earlier in the " - "same summarize call." 
% missing_cols - ) - col = expr(columns).label(k) - expr_cols.append(col) +def _collapse_select(outer_sel, inner_alias): + # check whether any outer columns reference an inner label ---- + inner_sel = inner_alias.element + + columns = lift_inner_cols(outer_sel) + inner_cols = lift_inner_cols(inner_sel) + + inner_labels = set([ + x.name for x in inner_cols + if isinstance(x, sql.elements.Label) + ]) + + col_requires_cte = set(inner_alias.columns[k] for k in inner_labels) + + bad_refs = [] + + def collect_refs(el): + if el in col_requires_cte: + bad_refs.append(el) + + for col in columns: + sql.util.visitors.traverse(col, {}, {"column": collect_refs}) + + # if possible, remove the outer query ---- + if not (bad_refs or len(inner_sel._group_by_clause)): + from sqlalchemy.sql.elements import ColumnClause, Label + + from_obj = get_single_from(inner_sel) + adaptor = sql.util.ClauseAdapter( + from_obj, + adapt_on_names=True, + include_fn=lambda c: isinstance(c, (ColumnClause, Label)) + ) + + new_cols = [] + for col in columns: + if isinstance(col, Label): + res = adaptor.traverse(col.element).label(col.name) + new_cols.append(res) + + else: + new_cols.append(adaptor.traverse(col)) + #new_cols = list(map(adaptor.traverse, columns)) + + return _sql_with_only_columns(inner_sel, new_cols) + + return outer_sel + + +def _aggregate_cols(__data, subquery, args, kwargs, verb_name): + # cases: + # * grouping cols can not be overwritten (in dbplyr they can't be ref'd) + # * no existing labels referred to - can use same select + # * existing labels referred to - need 1 subquery tops + # * groups + summarize columns can replace everything + + def get_label_clauses(clause): + out = [] + sql.util.visitors.traverse(clause, {}, {"label": lambda c: out.append(c)}) + + return out + + def quote_varname(x): + return f"`{x}`" + + def validate_references(arg_name, expr, verb_name): + bad_varnames = get_label_clauses(expr) + repr_names = ", ".join(map(quote_varname, [el.name for el in 
bad_varnames])) + + if not bad_varnames: + return + + raise NotImplementedError( + f"In SQL, you cannot refer to a column created in the same {verb_name}. " + f"`{arg_name}` refers to columns created earlier: {repr_names}." + ) + + sel = subquery.select() + + final_cols = {k: subquery.columns[k] for k in __data.group_by} + + # handle args ---- + for ii, func in enumerate(args): + cols_result = _eval_expr_arg(__data, sel, func, verb_name, window=False) + for col in cols_result: + validate_references(col.name, col.element, verb_name) + final_cols[col.name] = col - all_cols = [*group_cols, *expr_cols] - final_sel = _sql_with_only_columns(sel, all_cols).group_by(*group_cols) - new_data = __data.append_op(final_sel, group_by = tuple(), order_by = tuple()) - return new_data + sel = _sql_upsert_columns(sel, cols_result) + + + # handle kwargs ---- + for new_name, func in kwargs.items(): + labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=False) + + validate_references(labeled.name, labeled.element, verb_name) + final_cols[new_name] = labeled + + sel = _sql_upsert_columns(sel, [labeled]) + + return list(final_cols), _sql_with_only_columns(sel, list(final_cols.values())) @group_by.register(LazyTbl) diff --git a/siuba/tests/test_verb_across.py b/siuba/tests/test_verb_across.py index b977c0b9..b6781ebf 100644 --- a/siuba/tests/test_verb_across.py +++ b/siuba/tests/test_verb_across.py @@ -8,7 +8,7 @@ from siuba.dply.across import across from siuba.experimental.pivot.test_pivot import assert_equal_query2 -from siuba.sql.translate import SqlColumn, SqlColumnAgg, sql_scalar +from siuba.sql.translate import SqlColumn, SqlColumnAgg, sql_scalar, win_agg, sql_agg # TODO: test transmute # TODO: test verb(data, _.simple_name) @@ -37,6 +37,9 @@ def f_round(x) -> pd.Series: def f_mean(x) -> pd.Series: return x.mean() +f_mean.register(SqlColumn, win_agg("avg")) +f_mean.register(SqlColumnAgg, sql_agg("avg")) + def assert_grouping_names(gdf, names): from siuba.sql 
import LazyTbl @@ -174,14 +177,15 @@ def test_across_in_mutate_grouped_equiv_ungrouped(backend, df): assert_equal_query2(ungroup(g_res), collect(dst)) -def test_across_in_summarize(df): - res = summarize(df, across(_, _[_.a_x, _.a_y], f_mean)) +def test_across_in_summarize(backend, df): + src = backend.load_df(df) + res = summarize(src, across(_, _[_.a_x, _.a_y], f_mean)) dst = pd.DataFrame({ "a_x": [df.a_x.mean()], "a_y": [df.a_y.mean()] }) - assert_frame_equal(res, dst) + assert_equal_query2(res, dst) def test_across_in_summarize_equiv_ungrouped(): From 27391d570af7e3129fdbf39e65a286daf9dc36be Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Tue, 4 Oct 2022 16:45:21 -0400 Subject: [PATCH 14/27] refactor(sql): prepare to split verbs file --- siuba/sql/{verbs.py => backend.py} | 0 siuba/sql/verbs/arrange.py | 1395 ++++++++++++++++++++++++++++ siuba/sql/verbs/compute.py | 1395 ++++++++++++++++++++++++++++ siuba/sql/verbs/conditional.py | 1395 ++++++++++++++++++++++++++++ siuba/sql/verbs/count.py | 1395 ++++++++++++++++++++++++++++ siuba/sql/verbs/distinct.py | 1395 ++++++++++++++++++++++++++++ siuba/sql/verbs/explain.py | 1395 ++++++++++++++++++++++++++++ siuba/sql/verbs/filter.py | 1395 ++++++++++++++++++++++++++++ siuba/sql/verbs/group_by.py | 1395 ++++++++++++++++++++++++++++ siuba/sql/verbs/head.py | 1395 ++++++++++++++++++++++++++++ siuba/sql/verbs/join.py | 1395 ++++++++++++++++++++++++++++ siuba/sql/verbs/mutate.py | 1395 ++++++++++++++++++++++++++++ siuba/sql/verbs/select.py | 1395 ++++++++++++++++++++++++++++ siuba/sql/verbs/summarize.py | 1395 ++++++++++++++++++++++++++++ siuba/sql/verbs/transmute.py | 1395 ++++++++++++++++++++++++++++ 15 files changed, 19530 insertions(+) rename siuba/sql/{verbs.py => backend.py} (100%) create mode 100644 siuba/sql/verbs/arrange.py create mode 100644 siuba/sql/verbs/compute.py create mode 100644 siuba/sql/verbs/conditional.py create mode 100644 siuba/sql/verbs/count.py create mode 100644 siuba/sql/verbs/distinct.py 
create mode 100644 siuba/sql/verbs/explain.py create mode 100644 siuba/sql/verbs/filter.py create mode 100644 siuba/sql/verbs/group_by.py create mode 100644 siuba/sql/verbs/head.py create mode 100644 siuba/sql/verbs/join.py create mode 100644 siuba/sql/verbs/mutate.py create mode 100644 siuba/sql/verbs/select.py create mode 100644 siuba/sql/verbs/summarize.py create mode 100644 siuba/sql/verbs/transmute.py diff --git a/siuba/sql/verbs.py b/siuba/sql/backend.py similarity index 100% rename from siuba/sql/verbs.py rename to siuba/sql/backend.py diff --git a/siuba/sql/verbs/arrange.py b/siuba/sql/verbs/arrange.py new file mode 100644 index 00000000..8dcbef7b --- /dev/null +++ b/siuba/sql/verbs/arrange.py @@ -0,0 +1,1395 @@ +""" +Implements LazyTbl to represent tables of SQL data, and registers it on verbs. + +This module is responsible for the handling of the "table" side of things, while +translate.py handles translating column operations. + + +""" + +import warnings + +from siuba.dply.verbs import ( + show_query, collect, + simple_varname, + select, + mutate, + transmute, + filter, + arrange, _call_strip_ascending, + summarize, + count, add_count, + group_by, ungroup, + case_when, + join, left_join, right_join, inner_join, semi_join, anti_join, + head, + rename, + distinct, + if_else, + _select_group_renames, + _var_select_simple + ) + +from siuba.dply.tidyselect import VarList, var_select + +from .translate import CustomOverClause, SqlColumn, SqlColumnAgg +from .utils import ( + get_dialect_translator, + _FixedSqlDatabase, + _is_dialect_duckdb, + _sql_select, + _sql_column_collection, + _sql_add_columns, + _sql_with_only_columns, + _sql_simplify_select, + MockConnection +) + +from sqlalchemy import sql +import sqlalchemy +from siuba.siu import Call, Lazy, FunctionLookupError, singledispatch2 +# TODO: currently needed for select, but can we remove pandas? 
+from pandas import Series +from functools import singledispatch + +from sqlalchemy.sql import schema + +from siuba.dply.across import _require_across, _set_data_context, _eval_with_context + +# TODO: +# - distinct +# - annotate functions using sel.prefix_with("\n/**/\n") ? + + +# Helpers --------------------------------------------------------------------- + +class SqlFunctionLookupError(FunctionLookupError): pass + + +class CallListener: + """Generic listener. Each exit is called on a node's copy.""" + def enter(self, node): + args, kwargs = node.map_subcalls(self.enter) + + return self.exit(node.__class__(node.func, *args, **kwargs)) + + def exit(self, node): + return node + + +class WindowReplacer(CallListener): + """Call tree listener. + + Produces 2 important behaviors via the enter method: + - returns evaluated sql call expression, with labels on all window expressions. + - stores all labeled window expressions via the windows property. + + TODO: could replace with a sqlalchemy transformer + """ + + def __init__(self, columns, group_by, order_by, window_cte = None): + self.columns = columns + self.group_by = group_by + self.order_by = order_by + self.window_cte = window_cte + self.windows = [] + + def exit(self, node): + col_expr = node(self.columns) + + if not isinstance(col_expr, sql.elements.ClauseElement): + return col_expr + + over_clauses = [x for x in self._get_over_clauses(col_expr) if isinstance(x, CustomOverClause)] + + # put groupings and orderings onto custom over clauses + for over in over_clauses: + # TODO: shouldn't mutate these over clauses + group_by = sql.elements.ClauseList( + *[self.columns[name] for name in self.group_by] + ) + order_by = sql.elements.ClauseList( + *_create_order_by_clause(self.columns, *self.order_by) + ) + + over.set_over(group_by, order_by) + + if len(over_clauses) and self.window_cte is not None: + # custom name, or parameters like "%(...)s" may nest and break psycopg2 + # with columns you can set a key to fix this, 
but it doesn't seem to + # be an option with labels + name = self._get_unique_name('win', lift_inner_cols(self.window_cte)) + label = col_expr.label(name) + + # put into CTE, and return its resulting column, so that subsequent + # operations will refer to the window column on window_cte. Note that + # the operations will use the actual column, so may need to use the + # ClauseAdaptor to make it a reference to the label + self.window_cte = _sql_add_columns(self.window_cte, [label]) + win_col = lift_inner_cols(self.window_cte).values()[-1] + self.windows.append(win_col) + + return win_col + + return col_expr + + @staticmethod + def _get_unique_name(prefix, columns): + column_names = set(columns.keys()) + + i = 1 + name = prefix + str(i) + while name in column_names: + i += 1 + name = prefix + str(i) + + + return name + + @staticmethod + def _get_over_clauses(clause): + windows = [] + append_win = lambda col: windows.append(col) + + sql.util.visitors.traverse(clause, {}, {"over": append_win}) + + return windows + + +class SqlLabelReplacer: + """Create a visitor to replace source labels with destination. + + Note that this is meant to be used with sqlalchemy visitors. 
+ """ + + def __init__(self, src_columns, dst_columns): + self.src_columns = src_columns + self.src_labels = set([x for x in src_columns if isinstance(x, sql.elements.Label)]) + self.dst_columns = dst_columns + self.applied = False + + def __call__(self, clause): + return sql.util.visitors.replacement_traverse(clause, {}, self.visit) + + def visit(self, el): + from sqlalchemy.sql.elements import ColumnClause, Label, ClauseElement, TypeClause + from sqlalchemy.sql.schema import Column + + if isinstance(el, TypeClause): + # TODO: for some reason this type throws an error if unguarded + return None + + if isinstance(el, ClauseElement): + if el in self.src_labels: + self.applied = True + return self.dst_columns[el.name] + elif el in self.src_columns: + return self.dst_columns[el.name] + + # TODO: should we create a subquery if the user passed raw text? + #elif isinstance(el, ColumnClause) and not isinstance(el, Column): + # # Raw SQL, which will need a subquery, but not substitution + # if el.key != "*": + # self.applied = True + + return None + + +#def track_call_windows(call, columns, group_by, order_by, window_cte = None): +# listener = WindowReplacer(columns, group_by, order_by, window_cte) +# col = listener.enter(call) +# return col, listener.windows, listener.window_cte + + +def track_call_windows(call, columns, group_by, order_by, window_cte = None): + col_expr = call(columns) + + crnt_group_by = sql.elements.ClauseList( + *[columns[name] for name in group_by] + ) + crnt_order_by = sql.elements.ClauseList( + *_create_order_by_clause(columns, *order_by) + ) + return replace_call_windows(col_expr, crnt_group_by, crnt_order_by, window_cte) + + + +@singledispatch +def replace_call_windows(col_expr, group_by, order_by, window_cte = None): + raise TypeError(str(type(col_expr))) + + +@replace_call_windows.register(sql.base.ImmutableColumnCollection) +def _(col_expr, group_by, order_by, window_cte = None): + all_over_clauses = [] + for col in col_expr: + _, 
over_clauses, window_cte = replace_call_windows( + col, + group_by, + order_by, + window_cte + ) + all_over_clauses.extend(over_clauses) + + return col_expr, all_over_clauses, window_cte + + +@replace_call_windows.register(sql.elements.ClauseElement) +def _(col_expr, group_by, order_by, window_cte = None): + + over_clauses = WindowReplacer._get_over_clauses(col_expr) + + for over in over_clauses: + # TODO: shouldn't mutate these over clauses + over.set_over(group_by, order_by) + + if len(over_clauses) and window_cte is not None: + # custom name, or parameters like "%(...)s" may nest and break psycopg2 + # with columns you can set a key to fix this, but it doesn't seem to + # be an option with labels + name = WindowReplacer._get_unique_name('win', lift_inner_cols(window_cte)) + label = col_expr.label(name) + + # put into CTE, and return its resulting column, so that subsequent + # operations will refer to the window column on window_cte. Note that + # the operations will use the actual column, so may need to use the + # ClauseAdaptor to make it a reference to the label + window_cte = _sql_add_columns(window_cte, [label]) + win_col = lift_inner_cols(window_cte).values()[-1] + + return win_col, over_clauses, window_cte + + return col_expr, over_clauses, window_cte + +def get_single_from(sel): + froms = sel.froms + + n_froms = len(froms) + if n_froms != 1: + raise ValueError( + f"Expected a single table in the from clause, but found {n_froms}" + ) + + return froms[0] + +def lift_inner_cols(tbl): + cols = list(tbl.inner_columns) + + return _sql_column_collection(cols) + +def col_expr_requires_cte(call, sel, is_mutate = False): + """Return whether a variable assignment needs a CTE""" + + call_vars = set(call.op_vars(attr_calls = False)) + + sel_labs = get_inner_labels(sel) + + # I use the acronym fwg sol (frog soul) to remember sql clause eval order + # from, where, group by, select, order by, limit + # group clause evaluated before select clause, so not issue for mutate 
def ordered_union(x, y):
    """Return the distinct elements of ``x`` then ``y`` as a tuple.

    Elements keep the position of their first appearance: everything from
    ``x`` comes first, followed by the elements of ``y`` not already seen.
    """
    # a dict is used as an insertion-ordered set
    seen = dict.fromkeys(x, True)
    seen.update(dict.fromkeys(y, True))

    return tuple(seen)
def get_ordered_col_names(self):
    """Return column names of the last operation, grouping columns first.

    The names in ``self.group_by`` lead the list, followed by the remaining
    keys of ``self.last_op.columns`` in their existing order.
    """
    remaining = []
    for col_name in self.last_op.columns.keys():
        if col_name not in self.group_by:
            remaining.append(col_name)

    return [*self.group_by, *remaining]
@property
def last_op(self) -> "sql.Table | sql.Select":
    """Return the most recent entry of ``self.ops``.

    Raises:
        TypeError: if the most recent entry is None.
    """
    op = self.ops[-1]

    if op is None:
        # previously raised with no message, which made the failure opaque
        raise TypeError(
            "The last operation on this table is None, but a sqlalchemy "
            "table or select was expected."
        )

    return op
def _repr_html_(self):
    """Return an HTML preview of the table, or None to fall back to ``repr``.

    The preview comes from ``self._get_preview()``. None is returned when that
    object has no ``_repr_html_`` method, or when its ``_repr_html_`` itself
    returns None.
    """
    # the markup tags of this template had been stripped, leaving broken
    # string literals; the header is preformatted, the preview follows it
    template = (
        "<div>"
        "<pre>"
        "# Source: lazy query\n"
        "# DB Conn: {}\n"
        "# Preview:\n"
        "</pre>"
        "{}"
        "<p># .. may have more rows</p>"
        "</div>"
    )

    data = self._get_preview()

    # _repr_html_ can not exist or return None, to signify that repr should be used
    if not hasattr(data, '_repr_html_'):
        return None

    html_data = data._repr_html_()
    if html_data is None:
        return None

    return template.format(self.source.engine, html_data)

(grouped data frame)

" + self._selected_obj._repr_html_() + "
" + + + +# Main Funcs +# ============================================================================= + +# sql raw -------------- + +sql_raw = sql.literal_column + +# show query ----------- + +@show_query.register(LazyTbl) +def _show_query(tbl, simplify = False, return_table = True): + #query = tbl.last_op #if not simplify else + compile_query = lambda query: query.compile( + dialect = tbl.source.dialect, + compile_kwargs = {"literal_binds": True} + ) + + + if simplify: + # try to strip table names and labels where unnecessary + simple_sel = _sql_simplify_select(tbl.last_select) + + explained = compile_query(simple_sel) + else: + # use a much more verbose query + explained = compile_query(tbl.last_select) + + if return_table: + print(str(explained)) + return tbl + + return str(explained) + + + +# collect ---------- + +@collect.register(LazyTbl) +def _collect(__data, as_df = True): + # TODO: maybe remove as_df options, always return dataframe + + if isinstance(__data.source, MockConnection): + # a mock sqlalchemy is being used to show_query, and echo queries. + # it doesn't return a result object or have a context handler, so + # we need to bail out early + return + + # compile query ---- + + if _is_dialect_duckdb(__data.source): + # TODO: can be removed once next release of duckdb fixes: + # https://github.com/duckdb/duckdb/issues/2972 + query = __data.last_select + compiled = query.compile( + dialect = __data.source.dialect, + compile_kwargs = {"literal_binds": True} + ) + else: + compiled = __data.last_select + + # execute query ---- + + with __data.source.connect() as conn: + if as_df: + sql_db = _FixedSqlDatabase(conn) + + if _is_dialect_duckdb(__data.source): + # TODO: pandas read_sql is very slow with duckdb. + # see https://github.com/pandas-dev/pandas/issues/45678 + # going to handle here for now. address once LazyTbl gets + # subclassed per backend. 
@select.register(LazyTbl)
def _select(__data, *args, **kwargs):
    """Keep (and optionally rename) columns of the last select statement.

    Positional arguments are column selectors; callables are first evaluated
    against a ``VarList``. Grouping columns missing from the selection are
    added back at the front, with a warning.

    Raises:
        NotImplementedError: if any keyword arguments are passed.
    """
    # see https://stackoverflow.com/questions/25914329/rearrange-columns-in-sqlalchemy-select-object
    if kwargs:
        raise NotImplementedError(
            "Using kwargs in select not currently supported. "
            "Use _.newname == _.oldname instead"
        )

    sel = __data.last_select
    col_lookup = {}
    for inner_col in sel.inner_columns:
        col_lookup[inner_col.key] = inner_col

    # resolve the selection the same way as for a DataFrame
    name_series = Series(list(col_lookup))
    var_list = VarList()
    selectors = [arg(var_list) if callable(arg) else arg for arg in args]
    selection = var_select(name_series, *selectors)

    missing_groups, group_keys = _select_group_renames(selection, __data.group_by)
    if missing_groups:
        _warn_missing(missing_groups)

    # grouping columns that were not selected come first, without a new name
    targets = dict.fromkeys(missing_groups)
    targets.update(selection)

    picked = [
        col_lookup[old] if new is None else col_lookup[old].label(new)
        for old, new in targets.items()
    ]

    return __data.append_op(sel.with_only_columns(picked), group_by = group_keys)
def _sql_upsert_columns(sel, new_columns: "list[base.Label | base.Column]"):
    """Return ``sel`` with ``new_columns`` added, replacing same-named columns.

    A new column whose ``name`` matches an existing inner column takes that
    column's position; other new columns are appended at the end.
    """
    by_name = {**lift_inner_cols(sel)}
    by_name.update((col.name, col) for col in new_columns)

    return _sql_with_only_columns(sel, [*by_name.values()])
def _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=True):
    """Evaluate a named verb argument against ``sel`` and label the result.

    Args:
        __data: the table, used to shape the call and track window expressions.
        sel: select statement whose inner columns the expression is evaluated on.
        func: the expression passed as a keyword argument.
        new_name: the keyword name, used as the label of the resulting column.
        verb_name: name of the calling verb, used in error messages.
        window: forwarded to ``shape_call``.

    Returns:
        The resulting column expression, labelled ``new_name``.

    Raises:
        TypeError: if the expression evaluates to a collection of columns.
    """
    inner_cols = lift_inner_cols(sel)

    expr_shaped = __data.shape_call(func, window, verb_name = verb_name, arg_name = new_name)
    new_col, _, _ = __data.track_call_windows(expr_shaped, inner_cols)

    if isinstance(new_col, sql.base.ImmutableColumnCollection):
        # a named argument can only define one column (e.g. across is not allowed)
        raise TypeError(
            f"{verb_name} named arguments must return a single column, but `{new_name}` "
            "returned multiple columns."
        )

    return new_col.label(new_name)
@count.register(LazyTbl)
def _count(__data, *args, sort = False, wt = None, **kwargs):
    """Count rows per combination of the grouping and specified columns.

    Args:
        __data: the table to count over.
        *args, **kwargs: columns to count by, evaluated as in mutate.
        sort: if True, order the result by the count, largest first.
        wt: not supported; must be None.

    Returns:
        A new table with the grouping columns and a count column named ``n``.
        Like summarize, the result carries no ``order_by`` state.

    Raises:
        NotImplementedError: if ``wt`` is not None.
    """
    # TODO: if already col named n, use name nn, etc.. get logic from tidy.py
    if wt is not None:
        raise NotImplementedError("TODO")

    res_name = "n"

    # similar to filter verb, we need two select statements,
    # an inner one for derived cols, and outer to group by them

    # inner select ----
    # holds any mutation style columns
    result_names, sel_inner = _mutate_cols(__data, args, kwargs, "Count")

    # remove unnecessary select, if we're operating on a table ----
    if set(lift_inner_cols(sel_inner)) == set(lift_inner_cols(__data.last_select)):
        sel_inner = __data.last_op

    # create outer select ----
    # holds selected columns and tally (n)
    sel_inner_cte = sel_inner.alias()
    inner_cols = sel_inner_cte.columns

    # apply any group vars from a group_by verb call first
    all_group_names = ordered_union(__data.group_by, result_names)
    outer_group_cols = [inner_cols[k] for k in all_group_names]

    # holds the actual count (e.g. n)
    count_col = sql.functions.count().label(res_name)

    sel_outer = _sql_select([*outer_group_cols, count_col]) \
        .select_from(sel_inner_cte) \
        .group_by(*outer_group_cols)

    # only order by the count when requested; previously sort was ignored
    # and the result was always ordered
    if sort:
        sel_outer = sel_outer.order_by(count_col.desc())

    # count is like summarize, so removes order_by
    return __data.append_op(sel_outer, order_by = tuple())
def _aggregate_cols(__data, subquery, args, kwargs, verb_name):
    """Evaluate aggregation expressions against ``subquery``.

    Grouping columns are always kept, followed by the columns produced by
    ``args`` and then ``kwargs``. An expression may not refer to a column
    created earlier in the same call.

    Returns:
        A tuple of (list of resulting column names, select statement holding
        only those columns).

    Raises:
        NotImplementedError: if an expression contains a label, i.e. refers to
            a column created in the same verb call.
    """
    # cases:
    #   * grouping cols can not be overwritten (in dbplyr they can't be ref'd)
    #   * no existing labels referred to - can use same select
    #   * existing labels referred to - need 1 subquery tops
    #   * groups + summarize columns can replace everything

    def ensure_no_label_refs(arg_name, expr):
        # collect every label found while walking the expression
        found = []
        sql.util.visitors.traverse(expr, {}, {"label": found.append})

        if found:
            repr_names = ", ".join(f"`{lab.name}`" for lab in found)
            raise NotImplementedError(
                f"In SQL, you cannot refer to a column created in the same {verb_name}. "
                f"`{arg_name}` refers to columns created earlier: {repr_names}."
            )

    sel = subquery.select()

    final_cols = {}
    for group_name in __data.group_by:
        final_cols[group_name] = subquery.columns[group_name]

    # positional arguments ----
    for func in args:
        cols_result = _eval_expr_arg(__data, sel, func, verb_name, window=False)

        for col in cols_result:
            ensure_no_label_refs(col.name, col.element)
            final_cols[col.name] = col

        sel = _sql_upsert_columns(sel, cols_result)

    # named arguments ----
    for new_name, func in kwargs.items():
        labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=False)

        ensure_no_label_refs(labeled.name, labeled.element)
        final_cols[new_name] = labeled

        sel = _sql_upsert_columns(sel, [labeled])

    out_names = list(final_cols)
    out_sel = _sql_with_only_columns(sel, list(final_cols.values()))

    return out_names, out_sel
def _joined_cols(left_cols, right_cols, on_keys, how, suffix = ("_x", "_y")):
    """Return labeled columns, according to selection rules for joins.

    Rules:
        1. For join keys, only the left table's column is kept. For a "full"
           join it is the coalesced left and right key; for a "right" join it
           is the right key, labeled with the left name.
        2. When a left column and a kept right column share a name, the two
           are labeled with ``suffix[0]`` and ``suffix[1]`` respectively.

    Args:
        left_cols: mapping of name to column for the left table.
        right_cols: mapping of name to column for the right table.
        on_keys: mapping of left key name to right key name.
        how: the join type.
        suffix: pair of suffixes for name clashes.

    Returns:
        A list with the left columns followed by the non-key right columns,
        each group in the order of its source mapping.
    """
    # keep right columns in their original order (iterating a set here made
    # the output column order non-deterministic)
    right_key_names = set(on_keys.values())
    right_cols_no_keys = {
        k: right_cols[k] for k in right_cols.keys() if k not in right_key_names
    }

    # when left and right cols have same name, suffix with _x / _y
    shared_labs = set(left_cols.keys()).intersection(right_cols_no_keys)

    # copy, so the caller's mapping is not modified below
    left_cols = {**left_cols}
    if how == "full":
        # for an outer join, have key columns coalesce values
        for lk, rk in on_keys.items():
            col = sql.functions.coalesce(left_cols[lk], right_cols[rk])
            left_cols[lk] = col.label(lk)
    elif how == "right":
        for lk, rk in on_keys.items():
            # Make left key columns actually be right ones (which contain left + extra)
            left_cols[lk] = right_cols[rk].label(lk)

    # create labels ----
    l_labs = _relabeled_cols(left_cols, shared_labs, suffix[0])
    r_labs = _relabeled_cols(right_cols_no_keys, shared_labs, suffix[1])

    return l_labs + r_labs


def _relabeled_cols(columns, keys, suffix):
    """Return the columns as a list, suffixing the label of those named in ``keys``."""
    return [
        col.label(name + str(suffix)) if name in keys else col
        for name, col in columns.items()
    ]
right.last_op.alias() + + # handle arguments ---- + on = _validate_join_arg_on(on, sql_on) + how = _validate_join_arg_how(how) + + # for equality join used to combine keys into single column + consolidate_keys = on if sql_on is None else {} + + if how == "right": + # switch joins, since sqlalchemy doesn't have right join arg + # see https://stackoverflow.com/q/11400307/1144523 + left_sel, right_sel = right_sel, left_sel + on = {v:k for k,v in on.items()} + + # create join conditions ---- + bool_clause = _create_join_conds(left_sel, right_sel, on) + + # create join ---- + join = left_sel.join( + right_sel, + onclause = bool_clause, + isouter = how != "inner", + full = how == "full" + ) + + # if right join, set selects back + if how == "right": + left_sel, right_sel = right_sel, left_sel + on = {v:k for k,v in on.items()} + + # note, shared_keys assumes on is a mapping... + # TODO: shared_keys appears to be for when on is not specified, but was unused + #shared_keys = [k for k,v in on.items() if k == v] + labeled_cols = _joined_cols( + left_sel.columns, + right_sel.columns, + on_keys = consolidate_keys, + how = how + ) + + sel = _sql_select(labeled_cols).select_from(join) + return left.append_op(sel, order_by = tuple()) + + +@semi_join.register(LazyTbl) +def _semi_join(left, right = None, on = None, *args, by = None, sql_on = None): + if on is None and by is not None: + on = by + + _raise_if_args(args) + + left_sel = left.last_op.alias() + right_sel = right.last_op.alias() + + # handle arguments ---- + on = _validate_join_arg_on(on, sql_on, left_sel, right_sel) + + # create join conditions ---- + bool_clause = _create_join_conds(left_sel, right_sel, on) + + # create inner join ---- + exists_clause = _sql_select([sql.literal(1)]) \ + .select_from(right_sel) \ + .where(bool_clause) + + # only keep left hand select's columns ---- + sel = _sql_select(left_sel.columns) \ + .select_from(left_sel) \ + .where(sql.exists(exists_clause)) + + return left.append_op(sel, order_by 
@anti_join.register(LazyTbl)
def _anti_join(left, right = None, on = None, *args, by = None, sql_on = None):
    """Keep the rows of ``left`` that have no match in ``right``.

    Args:
        left, right: the tables to join.
        on: join columns; ``by`` is used when ``on`` is None.
        *args: reserved; passing any raises NotImplementedError.
        sql_on: join condition passed through ``_validate_join_arg_on``, as an
            alternative to ``on``.

    Returns:
        A new table selecting from ``left`` where no matching row exists in
        ``right``, with ``order_by`` reset.
    """
    if on is None and by is not None:
        on = by

    _raise_if_args(args)

    left_sel = left.last_op.alias()
    right_sel = right.last_op.alias()

    # handle arguments ----
    # pass the aliased selectables, as _semi_join does: inferring join columns
    # reads `.columns`, which the LazyTbl objects themselves do not have
    on = _validate_join_arg_on(on, sql_on, left_sel, right_sel)

    # create join conditions ----
    bool_clause = _create_join_conds(left_sel, right_sel, on)

    # create inner join ----
    exists_clause = _sql_select([sql.literal(1)]) \
        .select_from(right_sel) \
        .where(bool_clause)

    sel = left_sel.select().where(~sql.exists(exists_clause))

    return left.append_op(sel, order_by = tuple())


def _raise_if_args(args):
    """Raise NotImplementedError if any extra positional arguments were passed."""
    if args:
        # NotImplemented is a constant, not an exception class; calling it
        # raised an unrelated TypeError
        raise NotImplementedError("*args is reserved for future arguments (e.g. suffix)")
def _validate_join_arg_how(how):
    """Return ``how`` unchanged if it is a supported join type.

    Raises:
        ValueError: if ``how`` is not one of "inner", "left", "right", "full".
    """
    how_options = ("inner", "left", "right", "full")
    if how not in how_options:
        # wrap the tuple, otherwise % treats its items as separate format
        # arguments and raises TypeError instead of this ValueError
        raise ValueError("how argument needs to be one of %s" % (how_options,))

    return how
@distinct.register(LazyTbl)
def _distinct(__data, *args, _keep_all = False, **kwargs):
    """Keep the distinct rows over the specified columns.

    Columns given as keyword arguments are first created with mutate. With no
    columns specified, all columns of the select are used. Grouping columns
    are always included, ahead of the others.

    Raises:
        NotImplementedError: if columns are specified with ``_keep_all`` set.
    """
    if _keep_all and (args or kwargs):
        raise NotImplementedError("Distinct with variables specified in sql requires _keep_all = False")

    if kwargs:
        inner_sel = mutate(__data, **kwargs).last_select
    else:
        inner_sel = __data.last_select

    # TODO: this is copied from the df distinct version
    # dicts below are used as ordered sets
    names = _var_select_simple(args)
    names.update(kwargs)

    # use all columns by default
    if not names:
        names = dict.fromkeys(lift_inner_cols(inner_sel).keys(), True)

    final_names = dict.fromkeys(__data.group_by, True)
    final_names.update(names)

    if len(inner_sel._order_by_clause):
        # select distinct has to include any columns in the order by clause,
        # so fall back to selecting from the statement as a subquery
        cte = inner_sel.alias()
        sel = _sql_select([cte.columns[k] for k in final_names]) \
            .select_from(cte) \
            .distinct()
    else:
        # no order by, so the existing statement can be modified safely
        available = lift_inner_cols(inner_sel)
        sel = inner_sel.with_only_columns([available[k] for k in final_names]).distinct()

    return __data.append_op(sel)
class CallListener:
    """Generic listener over a call tree.

    ``enter`` walks the tree depth-first, and ``exit`` receives a copy of each
    node, rebuilt from the results of its sub calls.
    """

    def enter(self, node):
        """Process the sub calls of ``node``, then pass a rebuilt copy to ``exit``."""
        new_args, new_kwargs = node.map_subcalls(self.enter)
        node_copy = node.__class__(node.func, *new_args, **new_kwargs)

        return self.exit(node_copy)

    def exit(self, node):
        """Return the node unchanged; subclasses override this hook."""
        return node
+ + Produces 2 important behaviors via the enter method: + - returns evaluated sql call expression, with labels on all window expressions. + - stores all labeled window expressions via the windows property. + + TODO: could replace with a sqlalchemy transformer + """ + + def __init__(self, columns, group_by, order_by, window_cte = None): + self.columns = columns + self.group_by = group_by + self.order_by = order_by + self.window_cte = window_cte + self.windows = [] + + def exit(self, node): + col_expr = node(self.columns) + + if not isinstance(col_expr, sql.elements.ClauseElement): + return col_expr + + over_clauses = [x for x in self._get_over_clauses(col_expr) if isinstance(x, CustomOverClause)] + + # put groupings and orderings onto custom over clauses + for over in over_clauses: + # TODO: shouldn't mutate these over clauses + group_by = sql.elements.ClauseList( + *[self.columns[name] for name in self.group_by] + ) + order_by = sql.elements.ClauseList( + *_create_order_by_clause(self.columns, *self.order_by) + ) + + over.set_over(group_by, order_by) + + if len(over_clauses) and self.window_cte is not None: + # custom name, or parameters like "%(...)s" may nest and break psycopg2 + # with columns you can set a key to fix this, but it doesn't seem to + # be an option with labels + name = self._get_unique_name('win', lift_inner_cols(self.window_cte)) + label = col_expr.label(name) + + # put into CTE, and return its resulting column, so that subsequent + # operations will refer to the window column on window_cte. 
Note that + # the operations will use the actual column, so may need to use the + # ClauseAdaptor to make it a reference to the label + self.window_cte = _sql_add_columns(self.window_cte, [label]) + win_col = lift_inner_cols(self.window_cte).values()[-1] + self.windows.append(win_col) + + return win_col + + return col_expr + + @staticmethod + def _get_unique_name(prefix, columns): + column_names = set(columns.keys()) + + i = 1 + name = prefix + str(i) + while name in column_names: + i += 1 + name = prefix + str(i) + + + return name + + @staticmethod + def _get_over_clauses(clause): + windows = [] + append_win = lambda col: windows.append(col) + + sql.util.visitors.traverse(clause, {}, {"over": append_win}) + + return windows + + +class SqlLabelReplacer: + """Create a visitor to replace source labels with destination. + + Note that this is meant to be used with sqlalchemy visitors. + """ + + def __init__(self, src_columns, dst_columns): + self.src_columns = src_columns + self.src_labels = set([x for x in src_columns if isinstance(x, sql.elements.Label)]) + self.dst_columns = dst_columns + self.applied = False + + def __call__(self, clause): + return sql.util.visitors.replacement_traverse(clause, {}, self.visit) + + def visit(self, el): + from sqlalchemy.sql.elements import ColumnClause, Label, ClauseElement, TypeClause + from sqlalchemy.sql.schema import Column + + if isinstance(el, TypeClause): + # TODO: for some reason this type throws an error if unguarded + return None + + if isinstance(el, ClauseElement): + if el in self.src_labels: + self.applied = True + return self.dst_columns[el.name] + elif el in self.src_columns: + return self.dst_columns[el.name] + + # TODO: should we create a subquery if the user passed raw text? 
+ #elif isinstance(el, ColumnClause) and not isinstance(el, Column): + # # Raw SQL, which will need a subquery, but not substitution + # if el.key != "*": + # self.applied = True + + return None + + +#def track_call_windows(call, columns, group_by, order_by, window_cte = None): +# listener = WindowReplacer(columns, group_by, order_by, window_cte) +# col = listener.enter(call) +# return col, listener.windows, listener.window_cte + + +def track_call_windows(call, columns, group_by, order_by, window_cte = None): + col_expr = call(columns) + + crnt_group_by = sql.elements.ClauseList( + *[columns[name] for name in group_by] + ) + crnt_order_by = sql.elements.ClauseList( + *_create_order_by_clause(columns, *order_by) + ) + return replace_call_windows(col_expr, crnt_group_by, crnt_order_by, window_cte) + + + +@singledispatch +def replace_call_windows(col_expr, group_by, order_by, window_cte = None): + raise TypeError(str(type(col_expr))) + + +@replace_call_windows.register(sql.base.ImmutableColumnCollection) +def _(col_expr, group_by, order_by, window_cte = None): + all_over_clauses = [] + for col in col_expr: + _, over_clauses, window_cte = replace_call_windows( + col, + group_by, + order_by, + window_cte + ) + all_over_clauses.extend(over_clauses) + + return col_expr, all_over_clauses, window_cte + + +@replace_call_windows.register(sql.elements.ClauseElement) +def _(col_expr, group_by, order_by, window_cte = None): + + over_clauses = WindowReplacer._get_over_clauses(col_expr) + + for over in over_clauses: + # TODO: shouldn't mutate these over clauses + over.set_over(group_by, order_by) + + if len(over_clauses) and window_cte is not None: + # custom name, or parameters like "%(...)s" may nest and break psycopg2 + # with columns you can set a key to fix this, but it doesn't seem to + # be an option with labels + name = WindowReplacer._get_unique_name('win', lift_inner_cols(window_cte)) + label = col_expr.label(name) + + # put into CTE, and return its resulting column, so 
that subsequent + # operations will refer to the window column on window_cte. Note that + # the operations will use the actual column, so may need to use the + # ClauseAdaptor to make it a reference to the label + window_cte = _sql_add_columns(window_cte, [label]) + win_col = lift_inner_cols(window_cte).values()[-1] + + return win_col, over_clauses, window_cte + + return col_expr, over_clauses, window_cte + +def get_single_from(sel): + froms = sel.froms + + n_froms = len(froms) + if n_froms != 1: + raise ValueError( + f"Expected a single table in the from clause, but found {n_froms}" + ) + + return froms[0] + +def lift_inner_cols(tbl): + cols = list(tbl.inner_columns) + + return _sql_column_collection(cols) + +def col_expr_requires_cte(call, sel, is_mutate = False): + """Return whether a variable assignment needs a CTE""" + + call_vars = set(call.op_vars(attr_calls = False)) + + sel_labs = get_inner_labels(sel) + + # I use the acronym fwg sol (frog soul) to remember sql clause eval order + # from, where, group by, select, order by, limit + # group clause evaluated before select clause, so not issue for mutate + group_needs_cte = not is_mutate and len(sel._group_by_clause) + + return ( group_needs_cte + # TODO: detect when a new var in mutate conflicts w/ order by + #or len(sel._order_by_clause) + or not sel_labs.isdisjoint(call_vars) + ) + +def get_inner_labels(sel): + columns = lift_inner_cols(sel) + sel_labs = set(k for k,v in columns.items() if isinstance(v, sql.elements.Label)) + return sel_labs + +def get_missing_columns(call, columns): + missing_cols = set(call.op_vars(attr_calls = False)) - set(columns.keys()) + return missing_cols + +def compile_el(tbl, el): + compiled = el.compile( + dialect = tbl.source.dialect, + compile_kwargs = {"literal_binds": True} + ) + return compiled + +# Misc utilities -------------------------------------------------------------- + +def ordered_union(x, y): + dx = {el: True for el in x} + dy = {el: True for el in y} + + return 
tuple({**dx, **dy}) + + +def _warn_missing(missing_groups): + warnings.warn(f"Adding missing grouping variables: {missing_groups}") + + +# Table ----------------------------------------------------------------------- + +class LazyTbl: + def __init__( + self, source, tbl, columns = None, + ops = None, group_by = tuple(), order_by = tuple(), + translator = None + ): + """Create a representation of a SQL table. + + Args: + source: a sqlalchemy.Engine or sqlalchemy.Connection instance. + tbl: table of form 'schema_name.table_name', 'table_name', or sqlalchemy.Table. + columns: if specified, a listlike of column names. + + Examples + -------- + + :: + from sqlalchemy import create_engine + from siuba.data import mtcars + + # create database and table + engine = create_engine("sqlite:///:memory:") + mtcars.to_sql('mtcars', engine) + + tbl_mtcars = LazyTbl(engine, 'mtcars') + + """ + + # connection and dialect specific functions + self.source = sqlalchemy.create_engine(source) if isinstance(source, str) else source + + # get dialect name + dialect = self.source.dialect.name + self.translator = get_dialect_translator(dialect) + + self.tbl = self._create_table(tbl, columns, self.source) + + # important states the query can be in (e.g. 
grouped) + self.ops = [self.tbl] if ops is None else ops + + self.group_by = group_by + self.order_by = order_by + + + def append_op(self, op, **kwargs): + cpy = self.copy(**kwargs) + cpy.ops = cpy.ops + [op] + return cpy + + def copy(self, **kwargs): + return self.__class__(**{**self.__dict__, **kwargs}) + + def shape_call( + self, + call, window = True, str_accessors = False, + verb_name = None, arg_name = None, + ): + return self.translator.shape_call(call, window, str_accessors, verb_name, arg_name) + + def track_call_windows(self, call, columns = None, window_cte = None): + """Returns tuple of (new column expression, list of window exprs)""" + + columns = self.last_op.columns if columns is None else columns + return track_call_windows(call, columns, self.group_by, self.order_by, window_cte) + + def get_ordered_col_names(self): + """Return columns from current select, with grouping columns first.""" + ungrouped = [k for k in self.last_op.columns.keys() if k not in self.group_by] + return list(self.group_by) + ungrouped + + #def label_breaks_order_by(self, name): + # """Returns True if a new column label would break the order by vars.""" + + # # TODO: arrange currently allows literals, which breaks this. it seems + # # better to only allow calls in arrange. + # order_by_vars = {c.op_vars(attr_calls=False) for c in self.order_by} + + + + + @property + def last_op(self) -> "sql.Table | sql.Select": + last_op = self.ops[-1] + + if last_op is None: + raise TypeError() + + return last_op + + @property + def last_select(self): + last_op = self.last_op + if not isinstance(last_op, sql.selectable.SelectBase): + return last_op.select() + + return last_op + + @staticmethod + def _create_table(tbl, columns = None, source = None): + """Return a sqlalchemy.Table, autoloading column info if needed. + + Arguments: + tbl: a sqlalchemy.Table or string of form 'table_name' or 'schema_name.table_name'. + columns: a tuple of column names for the table. Overrides source argument. 
+ source: a sqlalchemy engine, used to autoload columns. + + """ + if isinstance(tbl, sql.selectable.FromClause): + return tbl + + if not isinstance(tbl, str): + raise ValueError("tbl must be a sqlalchemy Table or string, but was %s" %type(tbl)) + + if columns is None and source is None: + raise ValueError("One of columns or source must be specified") + + schema, table_name = tbl.split('.') if '.' in tbl else [None, tbl] + + columns = map(sqlalchemy.Column, columns) if columns is not None else tuple() + + # TODO: pybigquery uses schema to mean project_id, so we cannot use + # siuba's classic breakdown "{schema}.{table_name}". Basically + # pybigquery uses "{schema=project_id}.{dataset_dot_table_name}" in its internal + # logic. An important side effect is that bigquery errors for + # `dataset`.`table`, but not `dataset.table`. + if source and source.dialect.name == "bigquery": + table_name = tbl + schema = None + + return sqlalchemy.Table( + table_name, + sqlalchemy.MetaData(bind = source), + *columns, + schema = schema, + autoload_with = source if not columns else None + ) + + def _get_preview(self): + # need to make prev op a cte, so we don't override any previous limit + new_sel = self.last_select.limit(5) + tbl_small = self.append_op(new_sel) + return collect(tbl_small) + + def __repr__(self): + template = ( + "# Source: lazy query\n" + "# DB Conn: {}\n" + "# Preview:\n{}\n" + "# .. may have more rows" + ) + + return template.format(repr(self.source.engine), repr(self._get_preview())) + + def _repr_html_(self): + template = ( + "
" + "
"
+                "# Source: lazy query\n"
+                "# DB Conn: {}\n"
+                "# Preview:\n"
+                "
" + "{}" + "

# .. may have more rows

" + "
" + ) + + data = self._get_preview() + + # _repr_html_ can not exist or return None, to signify that repr should be used + if not hasattr(data, '_repr_html_'): + return None + + html_data = data._repr_html_() + if html_data is None: + return None + + return template.format(self.source.engine, html_data) + + +def _repr_grouped_df_html_(self): + return "

(grouped data frame)

" + self._selected_obj._repr_html_() + "
" + + + +# Main Funcs +# ============================================================================= + +# sql raw -------------- + +sql_raw = sql.literal_column + +# show query ----------- + +@show_query.register(LazyTbl) +def _show_query(tbl, simplify = False, return_table = True): + #query = tbl.last_op #if not simplify else + compile_query = lambda query: query.compile( + dialect = tbl.source.dialect, + compile_kwargs = {"literal_binds": True} + ) + + + if simplify: + # try to strip table names and labels where unnecessary + simple_sel = _sql_simplify_select(tbl.last_select) + + explained = compile_query(simple_sel) + else: + # use a much more verbose query + explained = compile_query(tbl.last_select) + + if return_table: + print(str(explained)) + return tbl + + return str(explained) + + + +# collect ---------- + +@collect.register(LazyTbl) +def _collect(__data, as_df = True): + # TODO: maybe remove as_df options, always return dataframe + + if isinstance(__data.source, MockConnection): + # a mock sqlalchemy is being used to show_query, and echo queries. + # it doesn't return a result object or have a context handler, so + # we need to bail out early + return + + # compile query ---- + + if _is_dialect_duckdb(__data.source): + # TODO: can be removed once next release of duckdb fixes: + # https://github.com/duckdb/duckdb/issues/2972 + query = __data.last_select + compiled = query.compile( + dialect = __data.source.dialect, + compile_kwargs = {"literal_binds": True} + ) + else: + compiled = __data.last_select + + # execute query ---- + + with __data.source.connect() as conn: + if as_df: + sql_db = _FixedSqlDatabase(conn) + + if _is_dialect_duckdb(__data.source): + # TODO: pandas read_sql is very slow with duckdb. + # see https://github.com/pandas-dev/pandas/issues/45678 + # going to handle here for now. address once LazyTbl gets + # subclassed per backend. 
+ duckdb_con = conn.connection.c + return duckdb_con.query(str(compiled)).to_df() + else: + # + return sql_db.read_sql(compiled) + + return conn.execute(compiled) + + +@select.register(LazyTbl) +def _select(__data, *args, **kwargs): + # see https://stackoverflow.com/questions/25914329/rearrange-columns-in-sqlalchemy-select-object + if kwargs: + raise NotImplementedError( + "Using kwargs in select not currently supported. " + "Use _.newname == _.oldname instead" + ) + last_sel = __data.last_select + columns = {c.key: c for c in last_sel.inner_columns} + + # same as for DataFrame + colnames = Series(list(columns)) + vl = VarList() + evaluated = (arg(vl) if callable(arg) else arg for arg in args) + od = var_select(colnames, *evaluated) + + missing_groups, group_keys = _select_group_renames(od, __data.group_by) + + if missing_groups: + _warn_missing(missing_groups) + + final_od = {**{k: None for k in missing_groups}, **od} + + col_list = [] + for k,v in final_od.items(): + col = columns[k] + col_list.append(col if v is None else col.label(v)) + + return __data.append_op( + last_sel.with_only_columns(col_list), + group_by = group_keys + ) + + + +@filter.register(LazyTbl) +def _filter(__data, *args): + # Note: currently always produces 2 additional select statements, + # 1 for window/aggs, and 1 for the where clause + + sel = __data.last_op.alias() # original select + win_sel = sel.select() + + conds = [] + windows = [] + with _set_data_context(__data, window=True): + for ii, arg in enumerate(args): + + if isinstance(arg, Call): + new_call = __data.shape_call(arg, verb_name = "Filter", arg_name = ii) + #var_cols = new_call.op_vars(attr_calls = False) + + # note that a new win_sel is returned, w/ window columns appended + col_expr, win_cols, win_sel = __data.track_call_windows( + new_call, + sel.columns, + window_cte = win_sel + ) + + if isinstance(col_expr, sql.base.ImmutableColumnCollection): + conds.extend(col_expr) + else: + conds.append(col_expr) + + 
windows.extend(win_cols) + + else: + conds.append(arg) + + bool_clause = sql.and_(*conds) + + # first cte, windows ---- + if len(windows): + + win_alias = win_sel.alias() + + # move non-window functions to refer to win_sel clause (not the innermost) --- + bool_clause = sql.util.ClauseAdapter(win_alias) \ + .traverse(bool_clause) + + orig_cols = [win_alias.columns[k] for k in sel.columns.keys()] + else: + orig_cols = [sel] + + # create second cte ---- + filt_sel = _sql_select(orig_cols).where(bool_clause) + return __data.append_op(filt_sel) + + +@mutate.register(LazyTbl) +def _mutate(__data, *args, **kwargs): + # TODO: verify it can follow a renaming select + + # track labeled columns in set + if not (len(args) or len(kwargs)): + return __data.append_op(__data.last_op) + + names, sel_out = _mutate_cols(__data, args, kwargs, "Mutate") + return __data.append_op(sel_out) + + +def _sql_upsert_columns(sel, new_columns: "list[base.Label | base.Column]"): + orig_cols = lift_inner_cols(sel) + replaced = {**orig_cols} + + for new_col in new_columns: + replaced[new_col.name] = new_col + return _sql_with_only_columns(sel, list(replaced.values())) + + +def _select_mutate_result(src_sel, expr_result): + dst_alias = src_sel.alias() + src_columns = set(lift_inner_cols(src_sel)) + replacer = SqlLabelReplacer(set(src_columns), dst_alias.columns) + + if isinstance(expr_result, sql.base.ImmutableColumnCollection): + replaced_cols = list(map(replacer, expr_result)) + orig_cols = expr_result + #elif isinstance(expr_result, None): + # pass + else: + replaced_cols = [replacer(expr_result)] + orig_cols = [expr_result] + + if replacer.applied: + return _sql_upsert_columns(dst_alias.select(), replaced_cols) + + return _sql_upsert_columns(src_sel, orig_cols) + + +def _eval_expr_arg(__data, sel, func, verb_name, window=True): + inner_cols = lift_inner_cols(sel) + + # case 1: simple names ---- + simple_name = simple_varname(func) + if simple_name is not None: + return inner_cols[simple_name] + 
+ # case 2: across ---- + _require_across(func, verb_name) + + cols_result = _eval_with_context(__data, window, inner_cols, func) + + # TODO: remove or raise a more informative error + assert isinstance(cols_result, sql.base.ImmutableColumnCollection), type(cols_result) + + return cols_result + + +def _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=True): # evaluate one named verb argument; it must yield a single column, returned labeled as new_name + inner_cols = lift_inner_cols(sel) + + expr_shaped = __data.shape_call(func, window, verb_name = verb_name, arg_name = new_name) + new_col, windows, _ = __data.track_call_windows(expr_shaped, inner_cols) + + if isinstance(new_col, sql.base.ImmutableColumnCollection): + raise TypeError( + f"{verb_name} named arguments must return a single column, but `{new_name}` " + "returned multiple columns." + ) + + return new_col.label(new_name) + + +def _mutate_cols(__data, args, kwargs, verb_name): + result_names = {} # used as ordered set + sel = __data.last_select + + for ii, func in enumerate(args): + cols_result = _eval_expr_arg(__data, sel, func, verb_name) + + # replace any labels that require a subquery ---- + sel = _select_mutate_result(sel, cols_result) + + if isinstance(cols_result, sql.base.ImmutableColumnCollection): + result_names.update({k: True for k in cols_result.keys()}) + else: + result_names[cols_result.name] = True + + + for new_name, func in kwargs.items(): + labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name) + + sel = _select_mutate_result(sel, labeled) + result_names[new_name] = True + + + return list(result_names), sel + + +@transmute.register(LazyTbl) +def _transmute(__data, *args, **kwargs): + # will use mutate, then select some cols + result_names, sel = _mutate_cols(__data, args, kwargs, "Transmute") + + # transmute keeps grouping cols, and any defined in kwargs + missing = [x for x in __data.group_by if x not in result_names] + cols_to_keep = [*missing, *result_names] + + columns = lift_inner_cols(sel) + sel_stripped = sel.with_only_columns([columns[k] for k in 
cols_to_keep]) + + return __data.append_op(sel_stripped) + + +@arrange.register(LazyTbl) +def _arrange(__data, *args): + # Note that SQL databases often do not subquery order by clauses. Arrange + # sets order_by on the backend, so it can set order by in over elements, + # and handle when new columns are named the same as order by vars. + # see: https://dba.stackexchange.com/q/82930 + + last_sel = __data.last_select + cols = lift_inner_cols(last_sel) + + # TODO: implement across in arrange + #exprs, _ = _mutate_cols(__data, args, kwargs, "Arrange", arrange_clause=True) + + new_calls = [] + for ii, expr in enumerate(args): + if callable(expr): + + res = __data.shape_call( + expr, window = False, + verb_name = "Arrange", arg_name = ii + ) + + else: + res = expr + + new_calls.append(res) + + sort_cols = _create_order_by_clause(cols, *new_calls) + + order_by = __data.order_by + tuple(new_calls) + return __data.append_op(last_sel.order_by(*sort_cols), order_by = order_by) + + +# TODO: consolidate / pull expr handling funcs into own file? +def _create_order_by_clause(columns, *args): + sort_cols = [] + for arg in args: + # simple named column + if isinstance(arg, str): + sort_cols.append(columns[arg]) + # an expression + elif callable(arg): + # handle special case where -_.colname -> colname DESC + f, asc = _call_strip_ascending(arg) + col_op = f(columns) if asc else f(columns).desc() + #col_op = arg(columns) + sort_cols.append(col_op) + else: + raise NotImplementedError("Must be string or callable") + + return sort_cols + + + +@count.register(LazyTbl) +def _count(__data, *args, sort = False, wt = None, **kwargs): + # TODO: if already col named n, use name nn, etc.. 
get logic from tidy.py + if wt is not None: + raise NotImplementedError("TODO") + + res_name = "n" + # similar to filter verb, we need two select statements, + # an inner one for derived cols, and outer to group by them + + # inner select ---- + # holds any mutation style columns + #arg_names = [] + #for arg in args: + # name = simple_varname(arg) + # if name is None: + # raise NotImplementedError( + # "Count positional arguments must be single column name. " + # "Use a named argument to count using complex expressions." + # ) + # arg_names.append(name) + + result_names, sel_inner = _mutate_cols(__data, args, kwargs, "Count") + + # remove unnecessary select, if we're operating on a table ---- + if set(lift_inner_cols(sel_inner)) == set(lift_inner_cols(__data.last_select)): + sel_inner = __data.last_op + + # create outer select ---- + # holds selected columns and tally (n) + sel_inner_cte = sel_inner.alias() + inner_cols = sel_inner_cte.columns + + # apply any group vars from a group_by verb call first + missing = [k for k in __data.group_by if k not in result_names] + + all_group_names = ordered_union(__data.group_by, result_names) + outer_group_cols = [inner_cols[k] for k in all_group_names] + + # holds the actual count (e.g. 
n) + count_col = sql.functions.count().label(res_name) + + sel_outer = _sql_select([*outer_group_cols, count_col]) \ + .select_from(sel_inner_cte) \ + .group_by(*outer_group_cols) + + # count is like summarize, so removes order_by + return __data.append_op( + sel_outer.order_by(count_col.desc()), + order_by = tuple() + ) + + +@add_count.register(LazyTbl) +def _add_count(__data, *args, wt = None, sort = False, **kwargs): + counts = count(__data, *args, wt = wt, sort = sort, **kwargs) + by = list(c.name for c in counts.last_select.inner_columns)[:-1] + + return inner_join(__data, counts, by = by) + + +@summarize.register(LazyTbl) +def _summarize(__data, *args, **kwargs): + # https://stackoverflow.com/questions/14754994/why-is-sqlalchemy-count-much-slower-than-the-raw-query + + # get query with correct from clause, and maybe unneeded subquery + safe_from = __data.last_select.alias() + result_names, sel = _aggregate_cols(__data, safe_from, args, kwargs, "Summarize") + + # see if we can remove subquery + out_sel = _collapse_select(sel, safe_from) + + from_tbl = get_single_from(out_sel) + group_cols = [from_tbl.columns[k] for k in __data.group_by] + + final_sel = out_sel.group_by(*group_cols) + + new_data = __data.append_op(final_sel, group_by = tuple(), order_by = tuple()) + return new_data + + +def _collapse_select(outer_sel, inner_alias): + # check whether any outer columns reference an inner label ---- + inner_sel = inner_alias.element + + columns = lift_inner_cols(outer_sel) + inner_cols = lift_inner_cols(inner_sel) + + inner_labels = set([ + x.name for x in inner_cols + if isinstance(x, sql.elements.Label) + ]) + + col_requires_cte = set(inner_alias.columns[k] for k in inner_labels) + + bad_refs = [] + + def collect_refs(el): + if el in col_requires_cte: + bad_refs.append(el) + + for col in columns: + sql.util.visitors.traverse(col, {}, {"column": collect_refs}) + + # if possible, remove the outer query ---- + if not (bad_refs or len(inner_sel._group_by_clause)): + 
from sqlalchemy.sql.elements import ColumnClause, Label + + from_obj = get_single_from(inner_sel) + adaptor = sql.util.ClauseAdapter( + from_obj, + adapt_on_names=True, + include_fn=lambda c: isinstance(c, (ColumnClause, Label)) + ) + + new_cols = [] + for col in columns: + if isinstance(col, Label): + res = adaptor.traverse(col.element).label(col.name) + new_cols.append(res) + + else: + new_cols.append(adaptor.traverse(col)) + #new_cols = list(map(adaptor.traverse, columns)) + + return _sql_with_only_columns(inner_sel, new_cols) + + return outer_sel + + +def _aggregate_cols(__data, subquery, args, kwargs, verb_name): + # cases: + # * grouping cols can not be overwritten (in dbplyr they can't be ref'd) + # * no existing labels referred to - can use same select + # * existing labels referred to - need 1 subquery tops + # * groups + summarize columns can replace everything + + def get_label_clauses(clause): + out = [] + sql.util.visitors.traverse(clause, {}, {"label": lambda c: out.append(c)}) + + return out + + def quote_varname(x): + return f"`{x}`" + + def validate_references(arg_name, expr, verb_name): + bad_varnames = get_label_clauses(expr) + repr_names = ", ".join(map(quote_varname, [el.name for el in bad_varnames])) + + if not bad_varnames: + return + + raise NotImplementedError( + f"In SQL, you cannot refer to a column created in the same {verb_name}. " + f"`{arg_name}` refers to columns created earlier: {repr_names}." 
+ ) + + sel = subquery.select() + + final_cols = {k: subquery.columns[k] for k in __data.group_by} + + # handle args ---- + for ii, func in enumerate(args): + cols_result = _eval_expr_arg(__data, sel, func, verb_name, window=False) + + for col in cols_result: + validate_references(col.name, col.element, verb_name) + final_cols[col.name] = col + + sel = _sql_upsert_columns(sel, cols_result) + + + # handle kwargs ---- + for new_name, func in kwargs.items(): + labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=False) + + validate_references(labeled.name, labeled.element, verb_name) + final_cols[new_name] = labeled + + sel = _sql_upsert_columns(sel, [labeled]) + + return list(final_cols), _sql_with_only_columns(sel, list(final_cols.values())) + + +@group_by.register(LazyTbl) +def _group_by(__data, *args, add = False, **kwargs): + if not (args or kwargs): + return __data.copy() + + group_names, sel = _mutate_cols(__data, args, kwargs, "Group by") + + if None in group_names: + raise NotImplementedError("Complex, unnamed expressions not supported in sql group_by") + + # check whether we can just use underlying table ---- + new_cols = lift_inner_cols(sel) + if set(new_cols).issubset(set(__data.last_op.columns)): + sel = __data.last_op + + if add: + group_names = ordered_union(__data.group_by, group_names) + + return __data.append_op(sel, group_by = tuple(group_names)) + + +@ungroup.register(LazyTbl) +def _ungroup(__data): + return __data.copy(group_by = tuple()) + + +@case_when.register(sql.base.ImmutableColumnCollection) +def _case_when(__data, cases): + # TODO: will need listener to enter case statements, to handle when they use windows + if isinstance(cases, Call): + cases = cases(__data) + + whens = [] + case_items = list(cases.items()) + n_items = len(case_items) + + else_val = None + for ii, (expr, val) in enumerate(case_items): + # handle where val is a column expr + if callable(val): + val = val(__data) + + # handle when expressions + if ii+1 
== n_items and expr is True: + else_val = val + elif callable(expr): + whens.append((expr(__data), val)) + else: + whens.append((expr, val)) + + return sql.case(whens, else_ = else_val) + + +# Join ------------------------------------------------------------------------ + +from collections.abc import Mapping + +def _joined_cols(left_cols, right_cols, on_keys, how, suffix = ("_x", "_y")): + """Return labeled columns, according to selection rules for joins. + + Rules: + 1. For join keys, keep left table's column + 2. When keys have the same labels, add suffix + """ + + # TODO: remove sets, so uses stable ordering + # when left and right cols have same name, suffix with _x / _y + keep_right = set(right_cols.keys()) - set(on_keys.values()) + shared_labs = set(left_cols.keys()).intersection(keep_right) + + right_cols_no_keys = {k: right_cols[k] for k in keep_right} + + # for an outer join, have key columns coalesce values + + left_cols = {**left_cols} + if how == "full": + for lk, rk in on_keys.items(): + col = sql.functions.coalesce(left_cols[lk], right_cols[rk]) + left_cols[lk] = col.label(lk) + elif how == "right": + for lk, rk in on_keys.items(): + # Make left key columns actually be right ones (which contain left + extra) + left_cols[lk] = right_cols[rk].label(lk) + + + # create labels ---- + l_labs = _relabeled_cols(left_cols, shared_labs, suffix[0]) + r_labs = _relabeled_cols(right_cols_no_keys, shared_labs, suffix[1]) + + return l_labs + r_labs + + + +def _relabeled_cols(columns, keys, suffix): + # add a suffix to all columns with names in keys + cols = [] + for k, v in columns.items(): + new_col = v.label(k + str(suffix)) if k in keys else v + cols.append(new_col) + return cols + + +@join.register(LazyTbl) +def _join(left, right, on = None, *args, by = None, how = "inner", sql_on = None): + _raise_if_args(args) + + if on is None and by is not None: + on = by + + # Needs to be on the table, not the select + left_sel = left.last_op.alias() + right_sel = 
right.last_op.alias() + + # handle arguments ---- + on = _validate_join_arg_on(on, sql_on) + how = _validate_join_arg_how(how) + + # for equality join used to combine keys into single column + consolidate_keys = on if sql_on is None else {} + + if how == "right": + # switch joins, since sqlalchemy doesn't have right join arg + # see https://stackoverflow.com/q/11400307/1144523 + left_sel, right_sel = right_sel, left_sel + on = {v:k for k,v in on.items()} + + # create join conditions ---- + bool_clause = _create_join_conds(left_sel, right_sel, on) + + # create join ---- + join = left_sel.join( + right_sel, + onclause = bool_clause, + isouter = how != "inner", + full = how == "full" + ) + + # if right join, set selects back + if how == "right": + left_sel, right_sel = right_sel, left_sel + on = {v:k for k,v in on.items()} + + # note, shared_keys assumes on is a mapping... + # TODO: shared_keys appears to be for when on is not specified, but was unused + #shared_keys = [k for k,v in on.items() if k == v] + labeled_cols = _joined_cols( + left_sel.columns, + right_sel.columns, + on_keys = consolidate_keys, + how = how + ) + + sel = _sql_select(labeled_cols).select_from(join) + return left.append_op(sel, order_by = tuple()) + + +@semi_join.register(LazyTbl) +def _semi_join(left, right = None, on = None, *args, by = None, sql_on = None): + if on is None and by is not None: + on = by + + _raise_if_args(args) + + left_sel = left.last_op.alias() + right_sel = right.last_op.alias() + + # handle arguments ---- + on = _validate_join_arg_on(on, sql_on, left_sel, right_sel) + + # create join conditions ---- + bool_clause = _create_join_conds(left_sel, right_sel, on) + + # create inner join ---- + exists_clause = _sql_select([sql.literal(1)]) \ + .select_from(right_sel) \ + .where(bool_clause) + + # only keep left hand select's columns ---- + sel = _sql_select(left_sel.columns) \ + .select_from(left_sel) \ + .where(sql.exists(exists_clause)) + + return left.append_op(sel, order_by 
= tuple()) + + +@anti_join.register(LazyTbl) +def _anti_join(left, right = None, on = None, *args, by = None, sql_on = None): + if on is None and by is not None: + on = by + + _raise_if_args(args) + + left_sel = left.last_op.alias() + right_sel = right.last_op.alias() + + # handle arguments ---- + on = _validate_join_arg_on(on, sql_on, left, right) + + # create join conditions ---- + bool_clause = _create_join_conds(left_sel, right_sel, on) + + # create inner join ---- + #not_exists = ~sql.exists([1], from_obj = right_sel).where(bool_clause) + exists_clause = _sql_select([sql.literal(1)]) \ + .select_from(right_sel) \ + .where(bool_clause) + + sel = left_sel.select().where(~sql.exists(exists_clause)) + + return left.append_op(sel, order_by = tuple()) + +def _raise_if_args(args): + if len(args): + raise NotImplemented("*args is reserved for future arguments (e.g. suffix)") + +def _validate_join_arg_on(on, sql_on = None, lhs = None, rhs = None): + # handle sql on case + if sql_on is not None: + if on is not None: + raise ValueError("Cannot specify both on and sql_on") + + return sql_on + + # handle general cases + if on is None: + # TODO: currently, we check for lhs and rhs tables to indicate whether + # a verb supports inferring columns. Otherwise, raise an error. + if lhs is not None and rhs is not None: + # TODO: consolidate with duplicate logic in pandas verb code + warnings.warn( + "No on column passed to join. " + "Inferring join columns instead using shared column names." + ) + + on_cols = list(set(lhs.columns.keys()).intersection(set(rhs.columns.keys()))) + + if not on_cols: + raise ValueError( + "No join column specified, or shared column names in join." 
+ ) + + # trivial dict mapping shared names to themselves + warnings.warn("Detected shared columns: %s" % on_cols) + on = dict(zip(on_cols, on_cols)) + + else: + raise NotImplementedError("on arg currently cannot be None (default) for SQL") + elif isinstance(on, str): + on = {on: on} + elif isinstance(on, (list, tuple)): + on = dict(zip(on, on)) + + + if not isinstance(on, Mapping): + raise TypeError("on must be a Mapping (e.g. dict)") + + return on + +def _validate_join_arg_how(how): + how_options = ("inner", "left", "right", "full") + if how not in how_options: + raise ValueError("how argument needs to be one of %s" %how_options) + + return how + +def _create_join_conds(left_sel, right_sel, on): + left_cols = left_sel.columns #lift_inner_cols(left_sel) + right_cols = right_sel.columns #lift_inner_cols(right_sel) + + if callable(on): + # callable, like with sql_on arg + conds = [on(left_cols, right_cols)] + else: + # dict-like of form {left: right} + conds = [] + for l, r in on.items(): + col_expr = left_cols[l] == right_cols[r] + conds.append(col_expr) + + return sql.and_(*conds) + + +# Head ------------------------------------------------------------------------ + +@head.register(LazyTbl) +def _head(__data, n = 5): + sel = __data.last_select + + return __data.append_op(sel.limit(n)) + + +# Rename ---------------------------------------------------------------------- + +@rename.register(LazyTbl) +def _rename(__data, **kwargs): + sel = __data.last_select + columns = lift_inner_cols(sel) + + # old_keys uses dict as ordered set + old_to_new = {simple_varname(v):k for k,v in kwargs.items()} + + if None in old_to_new: + raise KeyError("positional arguments must be simple column, " + "e.g. 
_.colname or _['colname']" + ) + + labs = [c.label(old_to_new[k]) if k in old_to_new else c for k,c in columns.items()] + + new_sel = sel.with_only_columns(labs) + + missing_groups, group_keys = _select_group_renames(old_to_new, __data.group_by) + + return __data.append_op(new_sel, group_by=group_keys) + + +# Distinct -------------------------------------------------------------------- + +@distinct.register(LazyTbl) +def _distinct(__data, *args, _keep_all = False, **kwargs): + if (args or kwargs) and _keep_all: + raise NotImplementedError("Distinct with variables specified in sql requires _keep_all = False") + + inner_sel = mutate(__data, **kwargs).last_select if kwargs else __data.last_select + + # TODO: this is copied from the df distinct version + # cols dict below is used as ordered set + cols = _var_select_simple(args) + cols.update(kwargs) + + # use all columns by default + if not cols: + cols = {k: True for k in lift_inner_cols(inner_sel).keys()} + + final_names = {**{k: True for k in __data.group_by}, **cols} + + if not len(inner_sel._order_by_clause): + # select distinct has to include any columns in the order by clause, + # so can only safely modify existing statement when there's no order by + sel_cols = lift_inner_cols(inner_sel) + distinct_cols = [sel_cols[k] for k in final_names] + sel = inner_sel.with_only_columns(distinct_cols).distinct() + else: + # fallback to cte + cte = inner_sel.alias() + distinct_cols = [cte.columns[k] for k in final_names] + sel = _sql_select(distinct_cols).select_from(cte).distinct() + + return __data.append_op(sel) + + +# if_else --------------------------------------------------------------------- + +@if_else.register(sql.elements.ColumnElement) +def _if_else(cond, true_vals, false_vals): + whens = [(cond, true_vals)] + return sql.case(whens, else_ = false_vals) + + diff --git a/siuba/sql/verbs/conditional.py b/siuba/sql/verbs/conditional.py new file mode 100644 index 00000000..8dcbef7b --- /dev/null +++ 
b/siuba/sql/verbs/conditional.py @@ -0,0 +1,1395 @@ +""" +Implements LazyTbl to represent tables of SQL data, and registers it on verbs. + +This module is responsible for the handling of the "table" side of things, while +translate.py handles translating column operations. + + +""" + +import warnings + +from siuba.dply.verbs import ( + show_query, collect, + simple_varname, + select, + mutate, + transmute, + filter, + arrange, _call_strip_ascending, + summarize, + count, add_count, + group_by, ungroup, + case_when, + join, left_join, right_join, inner_join, semi_join, anti_join, + head, + rename, + distinct, + if_else, + _select_group_renames, + _var_select_simple + ) + +from siuba.dply.tidyselect import VarList, var_select + +from .translate import CustomOverClause, SqlColumn, SqlColumnAgg +from .utils import ( + get_dialect_translator, + _FixedSqlDatabase, + _is_dialect_duckdb, + _sql_select, + _sql_column_collection, + _sql_add_columns, + _sql_with_only_columns, + _sql_simplify_select, + MockConnection +) + +from sqlalchemy import sql +import sqlalchemy +from siuba.siu import Call, Lazy, FunctionLookupError, singledispatch2 +# TODO: currently needed for select, but can we remove pandas? +from pandas import Series +from functools import singledispatch + +from sqlalchemy.sql import schema + +from siuba.dply.across import _require_across, _set_data_context, _eval_with_context + +# TODO: +# - distinct +# - annotate functions using sel.prefix_with("\n/**/\n") ? + + +# Helpers --------------------------------------------------------------------- + +class SqlFunctionLookupError(FunctionLookupError): pass + + +class CallListener: + """Generic listener. Each exit is called on a node's copy.""" + def enter(self, node): + args, kwargs = node.map_subcalls(self.enter) + + return self.exit(node.__class__(node.func, *args, **kwargs)) + + def exit(self, node): + return node + + +class WindowReplacer(CallListener): + """Call tree listener. 
+ + Produces 2 important behaviors via the enter method: + - returns evaluated sql call expression, with labels on all window expressions. + - stores all labeled window expressions via the windows property. + + TODO: could replace with a sqlalchemy transformer + """ + + def __init__(self, columns, group_by, order_by, window_cte = None): + self.columns = columns + self.group_by = group_by + self.order_by = order_by + self.window_cte = window_cte + self.windows = [] + + def exit(self, node): + col_expr = node(self.columns) + + if not isinstance(col_expr, sql.elements.ClauseElement): + return col_expr + + over_clauses = [x for x in self._get_over_clauses(col_expr) if isinstance(x, CustomOverClause)] + + # put groupings and orderings onto custom over clauses + for over in over_clauses: + # TODO: shouldn't mutate these over clauses + group_by = sql.elements.ClauseList( + *[self.columns[name] for name in self.group_by] + ) + order_by = sql.elements.ClauseList( + *_create_order_by_clause(self.columns, *self.order_by) + ) + + over.set_over(group_by, order_by) + + if len(over_clauses) and self.window_cte is not None: + # custom name, or parameters like "%(...)s" may nest and break psycopg2 + # with columns you can set a key to fix this, but it doesn't seem to + # be an option with labels + name = self._get_unique_name('win', lift_inner_cols(self.window_cte)) + label = col_expr.label(name) + + # put into CTE, and return its resulting column, so that subsequent + # operations will refer to the window column on window_cte. 
Note that + # the operations will use the actual column, so may need to use the + # ClauseAdaptor to make it a reference to the label + self.window_cte = _sql_add_columns(self.window_cte, [label]) + win_col = lift_inner_cols(self.window_cte).values()[-1] + self.windows.append(win_col) + + return win_col + + return col_expr + + @staticmethod + def _get_unique_name(prefix, columns): + column_names = set(columns.keys()) + + i = 1 + name = prefix + str(i) + while name in column_names: + i += 1 + name = prefix + str(i) + + + return name + + @staticmethod + def _get_over_clauses(clause): + windows = [] + append_win = lambda col: windows.append(col) + + sql.util.visitors.traverse(clause, {}, {"over": append_win}) + + return windows + + +class SqlLabelReplacer: + """Create a visitor to replace source labels with destination. + + Note that this is meant to be used with sqlalchemy visitors. + """ + + def __init__(self, src_columns, dst_columns): + self.src_columns = src_columns + self.src_labels = set([x for x in src_columns if isinstance(x, sql.elements.Label)]) + self.dst_columns = dst_columns + self.applied = False + + def __call__(self, clause): + return sql.util.visitors.replacement_traverse(clause, {}, self.visit) + + def visit(self, el): + from sqlalchemy.sql.elements import ColumnClause, Label, ClauseElement, TypeClause + from sqlalchemy.sql.schema import Column + + if isinstance(el, TypeClause): + # TODO: for some reason this type throws an error if unguarded + return None + + if isinstance(el, ClauseElement): + if el in self.src_labels: + self.applied = True + return self.dst_columns[el.name] + elif el in self.src_columns: + return self.dst_columns[el.name] + + # TODO: should we create a subquery if the user passed raw text? 
+ #elif isinstance(el, ColumnClause) and not isinstance(el, Column): + # # Raw SQL, which will need a subquery, but not substitution + # if el.key != "*": + # self.applied = True + + return None + + +#def track_call_windows(call, columns, group_by, order_by, window_cte = None): +# listener = WindowReplacer(columns, group_by, order_by, window_cte) +# col = listener.enter(call) +# return col, listener.windows, listener.window_cte + + +def track_call_windows(call, columns, group_by, order_by, window_cte = None): + col_expr = call(columns) + + crnt_group_by = sql.elements.ClauseList( + *[columns[name] for name in group_by] + ) + crnt_order_by = sql.elements.ClauseList( + *_create_order_by_clause(columns, *order_by) + ) + return replace_call_windows(col_expr, crnt_group_by, crnt_order_by, window_cte) + + + +@singledispatch +def replace_call_windows(col_expr, group_by, order_by, window_cte = None): + raise TypeError(str(type(col_expr))) + + +@replace_call_windows.register(sql.base.ImmutableColumnCollection) +def _(col_expr, group_by, order_by, window_cte = None): + all_over_clauses = [] + for col in col_expr: + _, over_clauses, window_cte = replace_call_windows( + col, + group_by, + order_by, + window_cte + ) + all_over_clauses.extend(over_clauses) + + return col_expr, all_over_clauses, window_cte + + +@replace_call_windows.register(sql.elements.ClauseElement) +def _(col_expr, group_by, order_by, window_cte = None): + + over_clauses = WindowReplacer._get_over_clauses(col_expr) + + for over in over_clauses: + # TODO: shouldn't mutate these over clauses + over.set_over(group_by, order_by) + + if len(over_clauses) and window_cte is not None: + # custom name, or parameters like "%(...)s" may nest and break psycopg2 + # with columns you can set a key to fix this, but it doesn't seem to + # be an option with labels + name = WindowReplacer._get_unique_name('win', lift_inner_cols(window_cte)) + label = col_expr.label(name) + + # put into CTE, and return its resulting column, so 
that subsequent + # operations will refer to the window column on window_cte. Note that + # the operations will use the actual column, so may need to use the + # ClauseAdaptor to make it a reference to the label + window_cte = _sql_add_columns(window_cte, [label]) + win_col = lift_inner_cols(window_cte).values()[-1] + + return win_col, over_clauses, window_cte + + return col_expr, over_clauses, window_cte + +def get_single_from(sel): + froms = sel.froms + + n_froms = len(froms) + if n_froms != 1: + raise ValueError( + f"Expected a single table in the from clause, but found {n_froms}" + ) + + return froms[0] + +def lift_inner_cols(tbl): + cols = list(tbl.inner_columns) + + return _sql_column_collection(cols) + +def col_expr_requires_cte(call, sel, is_mutate = False): + """Return whether a variable assignment needs a CTE""" + + call_vars = set(call.op_vars(attr_calls = False)) + + sel_labs = get_inner_labels(sel) + + # I use the acronym fwg sol (frog soul) to remember sql clause eval order + # from, where, group by, select, order by, limit + # group clause evaluated before select clause, so not issue for mutate + group_needs_cte = not is_mutate and len(sel._group_by_clause) + + return ( group_needs_cte + # TODO: detect when a new var in mutate conflicts w/ order by + #or len(sel._order_by_clause) + or not sel_labs.isdisjoint(call_vars) + ) + +def get_inner_labels(sel): + columns = lift_inner_cols(sel) + sel_labs = set(k for k,v in columns.items() if isinstance(v, sql.elements.Label)) + return sel_labs + +def get_missing_columns(call, columns): + missing_cols = set(call.op_vars(attr_calls = False)) - set(columns.keys()) + return missing_cols + +def compile_el(tbl, el): + compiled = el.compile( + dialect = tbl.source.dialect, + compile_kwargs = {"literal_binds": True} + ) + return compiled + +# Misc utilities -------------------------------------------------------------- + +def ordered_union(x, y): + dx = {el: True for el in x} + dy = {el: True for el in y} + + return 
tuple({**dx, **dy}) + + +def _warn_missing(missing_groups): + warnings.warn(f"Adding missing grouping variables: {missing_groups}") + + +# Table ----------------------------------------------------------------------- + +class LazyTbl: + def __init__( + self, source, tbl, columns = None, + ops = None, group_by = tuple(), order_by = tuple(), + translator = None + ): + """Create a representation of a SQL table. + + Args: + source: a sqlalchemy.Engine or sqlalchemy.Connection instance. + tbl: table of form 'schema_name.table_name', 'table_name', or sqlalchemy.Table. + columns: if specified, a listlike of column names. + + Examples + -------- + + :: + from sqlalchemy import create_engine + from siuba.data import mtcars + + # create database and table + engine = create_engine("sqlite:///:memory:") + mtcars.to_sql('mtcars', engine) + + tbl_mtcars = LazyTbl(engine, 'mtcars') + + """ + + # connection and dialect specific functions + self.source = sqlalchemy.create_engine(source) if isinstance(source, str) else source + + # get dialect name + dialect = self.source.dialect.name + self.translator = get_dialect_translator(dialect) + + self.tbl = self._create_table(tbl, columns, self.source) + + # important states the query can be in (e.g. 
grouped) + self.ops = [self.tbl] if ops is None else ops + + self.group_by = group_by + self.order_by = order_by + + + def append_op(self, op, **kwargs): + cpy = self.copy(**kwargs) + cpy.ops = cpy.ops + [op] + return cpy + + def copy(self, **kwargs): + return self.__class__(**{**self.__dict__, **kwargs}) + + def shape_call( + self, + call, window = True, str_accessors = False, + verb_name = None, arg_name = None, + ): + return self.translator.shape_call(call, window, str_accessors, verb_name, arg_name) + + def track_call_windows(self, call, columns = None, window_cte = None): + """Returns tuple of (new column expression, list of window exprs)""" + + columns = self.last_op.columns if columns is None else columns + return track_call_windows(call, columns, self.group_by, self.order_by, window_cte) + + def get_ordered_col_names(self): + """Return columns from current select, with grouping columns first.""" + ungrouped = [k for k in self.last_op.columns.keys() if k not in self.group_by] + return list(self.group_by) + ungrouped + + #def label_breaks_order_by(self, name): + # """Returns True if a new column label would break the order by vars.""" + + # # TODO: arrange currently allows literals, which breaks this. it seems + # # better to only allow calls in arrange. + # order_by_vars = {c.op_vars(attr_calls=False) for c in self.order_by} + + + + + @property + def last_op(self) -> "sql.Table | sql.Select": + last_op = self.ops[-1] + + if last_op is None: + raise TypeError() + + return last_op + + @property + def last_select(self): + last_op = self.last_op + if not isinstance(last_op, sql.selectable.SelectBase): + return last_op.select() + + return last_op + + @staticmethod + def _create_table(tbl, columns = None, source = None): + """Return a sqlalchemy.Table, autoloading column info if needed. + + Arguments: + tbl: a sqlalchemy.Table or string of form 'table_name' or 'schema_name.table_name'. + columns: a tuple of column names for the table. Overrides source argument. 
+ source: a sqlalchemy engine, used to autoload columns. + + """ + if isinstance(tbl, sql.selectable.FromClause): + return tbl + + if not isinstance(tbl, str): + raise ValueError("tbl must be a sqlalchemy Table or string, but was %s" %type(tbl)) + + if columns is None and source is None: + raise ValueError("One of columns or source must be specified") + + schema, table_name = tbl.split('.') if '.' in tbl else [None, tbl] + + columns = map(sqlalchemy.Column, columns) if columns is not None else tuple() + + # TODO: pybigquery uses schema to mean project_id, so we cannot use + # siuba's classic breakdown "{schema}.{table_name}". Basically + # pybigquery uses "{schema=project_id}.{dataset_dot_table_name}" in its internal + # logic. An important side effect is that bigquery errors for + # `dataset`.`table`, but not `dataset.table`. + if source and source.dialect.name == "bigquery": + table_name = tbl + schema = None + + return sqlalchemy.Table( + table_name, + sqlalchemy.MetaData(bind = source), + *columns, + schema = schema, + autoload_with = source if not columns else None + ) + + def _get_preview(self): + # need to make prev op a cte, so we don't override any previous limit + new_sel = self.last_select.limit(5) + tbl_small = self.append_op(new_sel) + return collect(tbl_small) + + def __repr__(self): + template = ( + "# Source: lazy query\n" + "# DB Conn: {}\n" + "# Preview:\n{}\n" + "# .. may have more rows" + ) + + return template.format(repr(self.source.engine), repr(self._get_preview())) + + def _repr_html_(self): + template = ( + "
" + "
"
+                "# Source: lazy query\n"
+                "# DB Conn: {}\n"
+                "# Preview:\n"
+                "
" + "{}" + "

# .. may have more rows

" + "
" + ) + + data = self._get_preview() + + # _repr_html_ can not exist or return None, to signify that repr should be used + if not hasattr(data, '_repr_html_'): + return None + + html_data = data._repr_html_() + if html_data is None: + return None + + return template.format(self.source.engine, html_data) + + +def _repr_grouped_df_html_(self): + return "

(grouped data frame)

" + self._selected_obj._repr_html_() + "
" + + + +# Main Funcs +# ============================================================================= + +# sql raw -------------- + +sql_raw = sql.literal_column + +# show query ----------- + +@show_query.register(LazyTbl) +def _show_query(tbl, simplify = False, return_table = True): + #query = tbl.last_op #if not simplify else + compile_query = lambda query: query.compile( + dialect = tbl.source.dialect, + compile_kwargs = {"literal_binds": True} + ) + + + if simplify: + # try to strip table names and labels where unnecessary + simple_sel = _sql_simplify_select(tbl.last_select) + + explained = compile_query(simple_sel) + else: + # use a much more verbose query + explained = compile_query(tbl.last_select) + + if return_table: + print(str(explained)) + return tbl + + return str(explained) + + + +# collect ---------- + +@collect.register(LazyTbl) +def _collect(__data, as_df = True): + # TODO: maybe remove as_df options, always return dataframe + + if isinstance(__data.source, MockConnection): + # a mock sqlalchemy is being used to show_query, and echo queries. + # it doesn't return a result object or have a context handler, so + # we need to bail out early + return + + # compile query ---- + + if _is_dialect_duckdb(__data.source): + # TODO: can be removed once next release of duckdb fixes: + # https://github.com/duckdb/duckdb/issues/2972 + query = __data.last_select + compiled = query.compile( + dialect = __data.source.dialect, + compile_kwargs = {"literal_binds": True} + ) + else: + compiled = __data.last_select + + # execute query ---- + + with __data.source.connect() as conn: + if as_df: + sql_db = _FixedSqlDatabase(conn) + + if _is_dialect_duckdb(__data.source): + # TODO: pandas read_sql is very slow with duckdb. + # see https://github.com/pandas-dev/pandas/issues/45678 + # going to handle here for now. address once LazyTbl gets + # subclassed per backend. 
+ duckdb_con = conn.connection.c + return duckdb_con.query(str(compiled)).to_df() + else: + # + return sql_db.read_sql(compiled) + + return conn.execute(compiled) + + +@select.register(LazyTbl) +def _select(__data, *args, **kwargs): + # see https://stackoverflow.com/questions/25914329/rearrange-columns-in-sqlalchemy-select-object + if kwargs: + raise NotImplementedError( + "Using kwargs in select not currently supported. " + "Use _.newname == _.oldname instead" + ) + last_sel = __data.last_select + columns = {c.key: c for c in last_sel.inner_columns} + + # same as for DataFrame + colnames = Series(list(columns)) + vl = VarList() + evaluated = (arg(vl) if callable(arg) else arg for arg in args) + od = var_select(colnames, *evaluated) + + missing_groups, group_keys = _select_group_renames(od, __data.group_by) + + if missing_groups: + _warn_missing(missing_groups) + + final_od = {**{k: None for k in missing_groups}, **od} + + col_list = [] + for k,v in final_od.items(): + col = columns[k] + col_list.append(col if v is None else col.label(v)) + + return __data.append_op( + last_sel.with_only_columns(col_list), + group_by = group_keys + ) + + + +@filter.register(LazyTbl) +def _filter(__data, *args): + # Note: currently always produces 2 additional select statements, + # 1 for window/aggs, and 1 for the where clause + + sel = __data.last_op.alias() # original select + win_sel = sel.select() + + conds = [] + windows = [] + with _set_data_context(__data, window=True): + for ii, arg in enumerate(args): + + if isinstance(arg, Call): + new_call = __data.shape_call(arg, verb_name = "Filter", arg_name = ii) + #var_cols = new_call.op_vars(attr_calls = False) + + # note that a new win_sel is returned, w/ window columns appended + col_expr, win_cols, win_sel = __data.track_call_windows( + new_call, + sel.columns, + window_cte = win_sel + ) + + if isinstance(col_expr, sql.base.ImmutableColumnCollection): + conds.extend(col_expr) + else: + conds.append(col_expr) + + 
windows.extend(win_cols) + + else: + conds.append(arg) + + bool_clause = sql.and_(*conds) + + # first cte, windows ---- + if len(windows): + + win_alias = win_sel.alias() + + # move non-window functions to refer to win_sel clause (not the innermost) --- + bool_clause = sql.util.ClauseAdapter(win_alias) \ + .traverse(bool_clause) + + orig_cols = [win_alias.columns[k] for k in sel.columns.keys()] + else: + orig_cols = [sel] + + # create second cte ---- + filt_sel = _sql_select(orig_cols).where(bool_clause) + return __data.append_op(filt_sel) + + +@mutate.register(LazyTbl) +def _mutate(__data, *args, **kwargs): + # TODO: verify it can follow a renaming select + + # track labeled columns in set + if not (len(args) or len(kwargs)): + return __data.append_op(__data.last_op) + + names, sel_out = _mutate_cols(__data, args, kwargs, "Mutate") + return __data.append_op(sel_out) + + +def _sql_upsert_columns(sel, new_columns: "list[base.Label | base.Column]"): + orig_cols = lift_inner_cols(sel) + replaced = {**orig_cols} + + for new_col in new_columns: + replaced[new_col.name] = new_col + return _sql_with_only_columns(sel, list(replaced.values())) + + +def _select_mutate_result(src_sel, expr_result): + dst_alias = src_sel.alias() + src_columns = set(lift_inner_cols(src_sel)) + replacer = SqlLabelReplacer(set(src_columns), dst_alias.columns) + + if isinstance(expr_result, sql.base.ImmutableColumnCollection): + replaced_cols = list(map(replacer, expr_result)) + orig_cols = expr_result + #elif isinstance(expr_result, None): + # pass + else: + replaced_cols = [replacer(expr_result)] + orig_cols = [expr_result] + + if replacer.applied: + return _sql_upsert_columns(dst_alias.select(), replaced_cols) + + return _sql_upsert_columns(src_sel, orig_cols) + + +def _eval_expr_arg(__data, sel, func, verb_name, window=True): + inner_cols = lift_inner_cols(sel) + + # case 1: simple names ---- + simple_name = simple_varname(func) + if simple_name is not None: + return inner_cols[simple_name] + 
+ # case 2: across ---- + _require_across(func, verb_name) + + cols_result = _eval_with_context(__data, window, inner_cols, func) + + # TODO: remove or raise a more informative error + assert isinstance(cols_result, sql.base.ImmutableColumnCollection), type(cols_result) + + return cols_result + + +def _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=True): + inner_cols = lift_inner_cols(sel) + + expr_shaped = __data.shape_call(func, window, verb_name = verb_name, arg_name = new_name) + new_col, windows, _ = __data.track_call_windows(expr_shaped, inner_cols) + + if isinstance(new_col, sql.base.ImmutableColumnCollection): + raise TyepError( + f"{verb_name} named arguments must return a single column, but `{k}` " + "returned multiple columns." + ) + + return new_col.label(new_name) + + +def _mutate_cols(__data, args, kwargs, verb_name): + result_names = {} # used as ordered set + sel = __data.last_select + + for ii, func in enumerate(args): + cols_result = _eval_expr_arg(__data, sel, func, verb_name) + + # replace any labels that require a subquery ---- + sel = _select_mutate_result(sel, cols_result) + + if isinstance(cols_result, sql.base.ImmutableColumnCollection): + result_names.update({k: True for k in cols_result.keys()}) + else: + result_names[cols_result.name] = True + + + for new_name, func in kwargs.items(): + labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name) + + sel = _select_mutate_result(sel, labeled) + result_names[new_name] = True + + + return list(result_names), sel + + +@transmute.register(LazyTbl) +def _transmute(__data, *args, **kwargs): + # will use mutate, then select some cols + result_names, sel = _mutate_cols(__data, args, kwargs, "Transmute") + + # transmute keeps grouping cols, and any defined in kwargs + missing = [x for x in __data.group_by if x not in result_names] + cols_to_keep = [*missing, *result_names] + + columns = lift_inner_cols(sel) + sel_stripped = sel.with_only_columns([columns[k] for k in 
cols_to_keep]) + + return __data.append_op(sel_stripped) + + +@arrange.register(LazyTbl) +def _arrange(__data, *args): + # Note that SQL databases often do not subquery order by clauses. Arrange + # sets order_by on the backend, so it can set order by in over elements, + # and handle when new columns are named the same as order by vars. + # see: https://dba.stackexchange.com/q/82930 + + last_sel = __data.last_select + cols = lift_inner_cols(last_sel) + + # TODO: implement across in arrange + #exprs, _ = _mutate_cols(__data, args, kwargs, "Arrange", arrange_clause=True) + + new_calls = [] + for ii, expr in enumerate(args): + if callable(expr): + + res = __data.shape_call( + expr, window = False, + verb_name = "Arrange", arg_name = ii + ) + + else: + res = expr + + new_calls.append(res) + + sort_cols = _create_order_by_clause(cols, *new_calls) + + order_by = __data.order_by + tuple(new_calls) + return __data.append_op(last_sel.order_by(*sort_cols), order_by = order_by) + + +# TODO: consolidate / pull expr handling funcs into own file? +def _create_order_by_clause(columns, *args): + sort_cols = [] + for arg in args: + # simple named column + if isinstance(arg, str): + sort_cols.append(columns[arg]) + # an expression + elif callable(arg): + # handle special case where -_.colname -> colname DESC + f, asc = _call_strip_ascending(arg) + col_op = f(columns) if asc else f(columns).desc() + #col_op = arg(columns) + sort_cols.append(col_op) + else: + raise NotImplementedError("Must be string or callable") + + return sort_cols + + + +@count.register(LazyTbl) +def _count(__data, *args, sort = False, wt = None, **kwargs): + # TODO: if already col named n, use name nn, etc.. 
get logic from tidy.py + if wt is not None: + raise NotImplementedError("TODO") + + res_name = "n" + # similar to filter verb, we need two select statements, + # an inner one for derived cols, and outer to group by them + + # inner select ---- + # holds any mutation style columns + #arg_names = [] + #for arg in args: + # name = simple_varname(arg) + # if name is None: + # raise NotImplementedError( + # "Count positional arguments must be single column name. " + # "Use a named argument to count using complex expressions." + # ) + # arg_names.append(name) + + result_names, sel_inner = _mutate_cols(__data, args, kwargs, "Count") + + # remove unnecessary select, if we're operating on a table ---- + if set(lift_inner_cols(sel_inner)) == set(lift_inner_cols(__data.last_select)): + sel_inner = __data.last_op + + # create outer select ---- + # holds selected columns and tally (n) + sel_inner_cte = sel_inner.alias() + inner_cols = sel_inner_cte.columns + + # apply any group vars from a group_by verb call first + missing = [k for k in __data.group_by if k not in result_names] + + all_group_names = ordered_union(__data.group_by, result_names) + outer_group_cols = [inner_cols[k] for k in all_group_names] + + # holds the actual count (e.g. 
n) + count_col = sql.functions.count().label(res_name) + + sel_outer = _sql_select([*outer_group_cols, count_col]) \ + .select_from(sel_inner_cte) \ + .group_by(*outer_group_cols) + + # count is like summarize, so removes order_by + return __data.append_op( + sel_outer.order_by(count_col.desc()), + order_by = tuple() + ) + + +@add_count.register(LazyTbl) +def _add_count(__data, *args, wt = None, sort = False, **kwargs): + counts = count(__data, *args, wt = wt, sort = sort, **kwargs) + by = list(c.name for c in counts.last_select.inner_columns)[:-1] + + return inner_join(__data, counts, by = by) + + +@summarize.register(LazyTbl) +def _summarize(__data, *args, **kwargs): + # https://stackoverflow.com/questions/14754994/why-is-sqlalchemy-count-much-slower-than-the-raw-query + + # get query with correct from clause, and maybe unneeded subquery + safe_from = __data.last_select.alias() + result_names, sel = _aggregate_cols(__data, safe_from, args, kwargs, "Summarize") + + # see if we can remove subquery + out_sel = _collapse_select(sel, safe_from) + + from_tbl = get_single_from(out_sel) + group_cols = [from_tbl.columns[k] for k in __data.group_by] + + final_sel = out_sel.group_by(*group_cols) + + new_data = __data.append_op(final_sel, group_by = tuple(), order_by = tuple()) + return new_data + + +def _collapse_select(outer_sel, inner_alias): + # check whether any outer columns reference an inner label ---- + inner_sel = inner_alias.element + + columns = lift_inner_cols(outer_sel) + inner_cols = lift_inner_cols(inner_sel) + + inner_labels = set([ + x.name for x in inner_cols + if isinstance(x, sql.elements.Label) + ]) + + col_requires_cte = set(inner_alias.columns[k] for k in inner_labels) + + bad_refs = [] + + def collect_refs(el): + if el in col_requires_cte: + bad_refs.append(el) + + for col in columns: + sql.util.visitors.traverse(col, {}, {"column": collect_refs}) + + # if possible, remove the outer query ---- + if not (bad_refs or len(inner_sel._group_by_clause)): + 
from sqlalchemy.sql.elements import ColumnClause, Label + + from_obj = get_single_from(inner_sel) + adaptor = sql.util.ClauseAdapter( + from_obj, + adapt_on_names=True, + include_fn=lambda c: isinstance(c, (ColumnClause, Label)) + ) + + new_cols = [] + for col in columns: + if isinstance(col, Label): + res = adaptor.traverse(col.element).label(col.name) + new_cols.append(res) + + else: + new_cols.append(adaptor.traverse(col)) + #new_cols = list(map(adaptor.traverse, columns)) + + return _sql_with_only_columns(inner_sel, new_cols) + + return outer_sel + + +def _aggregate_cols(__data, subquery, args, kwargs, verb_name): + # cases: + # * grouping cols can not be overwritten (in dbplyr they can't be ref'd) + # * no existing labels referred to - can use same select + # * existing labels referred to - need 1 subquery tops + # * groups + summarize columns can replace everything + + def get_label_clauses(clause): + out = [] + sql.util.visitors.traverse(clause, {}, {"label": lambda c: out.append(c)}) + + return out + + def quote_varname(x): + return f"`{x}`" + + def validate_references(arg_name, expr, verb_name): + bad_varnames = get_label_clauses(expr) + repr_names = ", ".join(map(quote_varname, [el.name for el in bad_varnames])) + + if not bad_varnames: + return + + raise NotImplementedError( + f"In SQL, you cannot refer to a column created in the same {verb_name}. " + f"`{arg_name}` refers to columns created earlier: {repr_names}." 
+ ) + + sel = subquery.select() + + final_cols = {k: subquery.columns[k] for k in __data.group_by} + + # handle args ---- + for ii, func in enumerate(args): + cols_result = _eval_expr_arg(__data, sel, func, verb_name, window=False) + + for col in cols_result: + validate_references(col.name, col.element, verb_name) + final_cols[col.name] = col + + sel = _sql_upsert_columns(sel, cols_result) + + + # handle kwargs ---- + for new_name, func in kwargs.items(): + labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=False) + + validate_references(labeled.name, labeled.element, verb_name) + final_cols[new_name] = labeled + + sel = _sql_upsert_columns(sel, [labeled]) + + return list(final_cols), _sql_with_only_columns(sel, list(final_cols.values())) + + +@group_by.register(LazyTbl) +def _group_by(__data, *args, add = False, **kwargs): + if not (args or kwargs): + return __data.copy() + + group_names, sel = _mutate_cols(__data, args, kwargs, "Group by") + + if None in group_names: + raise NotImplementedError("Complex, unnamed expressions not supported in sql group_by") + + # check whether we can just use underlying table ---- + new_cols = lift_inner_cols(sel) + if set(new_cols).issubset(set(__data.last_op.columns)): + sel = __data.last_op + + if add: + group_names = ordered_union(__data.group_by, group_names) + + return __data.append_op(sel, group_by = tuple(group_names)) + + +@ungroup.register(LazyTbl) +def _ungroup(__data): + return __data.copy(group_by = tuple()) + + +@case_when.register(sql.base.ImmutableColumnCollection) +def _case_when(__data, cases): + # TODO: will need listener to enter case statements, to handle when they use windows + if isinstance(cases, Call): + cases = cases(__data) + + whens = [] + case_items = list(cases.items()) + n_items = len(case_items) + + else_val = None + for ii, (expr, val) in enumerate(case_items): + # handle where val is a column expr + if callable(val): + val = val(__data) + + # handle when expressions + if ii+1 
== n_items and expr is True: + else_val = val + elif callable(expr): + whens.append((expr(__data), val)) + else: + whens.append((expr, val)) + + return sql.case(whens, else_ = else_val) + + +# Join ------------------------------------------------------------------------ + +from collections.abc import Mapping + +def _joined_cols(left_cols, right_cols, on_keys, how, suffix = ("_x", "_y")): + """Return labeled columns, according to selection rules for joins. + + Rules: + 1. For join keys, keep left table's column + 2. When keys have the same labels, add suffix + """ + + # TODO: remove sets, so uses stable ordering + # when left and right cols have same name, suffix with _x / _y + keep_right = set(right_cols.keys()) - set(on_keys.values()) + shared_labs = set(left_cols.keys()).intersection(keep_right) + + right_cols_no_keys = {k: right_cols[k] for k in keep_right} + + # for an outer join, have key columns coalesce values + + left_cols = {**left_cols} + if how == "full": + for lk, rk in on_keys.items(): + col = sql.functions.coalesce(left_cols[lk], right_cols[rk]) + left_cols[lk] = col.label(lk) + elif how == "right": + for lk, rk in on_keys.items(): + # Make left key columns actually be right ones (which contain left + extra) + left_cols[lk] = right_cols[rk].label(lk) + + + # create labels ---- + l_labs = _relabeled_cols(left_cols, shared_labs, suffix[0]) + r_labs = _relabeled_cols(right_cols_no_keys, shared_labs, suffix[1]) + + return l_labs + r_labs + + + +def _relabeled_cols(columns, keys, suffix): + # add a suffix to all columns with names in keys + cols = [] + for k, v in columns.items(): + new_col = v.label(k + str(suffix)) if k in keys else v + cols.append(new_col) + return cols + + +@join.register(LazyTbl) +def _join(left, right, on = None, *args, by = None, how = "inner", sql_on = None): + _raise_if_args(args) + + if on is None and by is not None: + on = by + + # Needs to be on the table, not the select + left_sel = left.last_op.alias() + right_sel = 
right.last_op.alias() + + # handle arguments ---- + on = _validate_join_arg_on(on, sql_on) + how = _validate_join_arg_how(how) + + # for equality join used to combine keys into single column + consolidate_keys = on if sql_on is None else {} + + if how == "right": + # switch joins, since sqlalchemy doesn't have right join arg + # see https://stackoverflow.com/q/11400307/1144523 + left_sel, right_sel = right_sel, left_sel + on = {v:k for k,v in on.items()} + + # create join conditions ---- + bool_clause = _create_join_conds(left_sel, right_sel, on) + + # create join ---- + join = left_sel.join( + right_sel, + onclause = bool_clause, + isouter = how != "inner", + full = how == "full" + ) + + # if right join, set selects back + if how == "right": + left_sel, right_sel = right_sel, left_sel + on = {v:k for k,v in on.items()} + + # note, shared_keys assumes on is a mapping... + # TODO: shared_keys appears to be for when on is not specified, but was unused + #shared_keys = [k for k,v in on.items() if k == v] + labeled_cols = _joined_cols( + left_sel.columns, + right_sel.columns, + on_keys = consolidate_keys, + how = how + ) + + sel = _sql_select(labeled_cols).select_from(join) + return left.append_op(sel, order_by = tuple()) + + +@semi_join.register(LazyTbl) +def _semi_join(left, right = None, on = None, *args, by = None, sql_on = None): + if on is None and by is not None: + on = by + + _raise_if_args(args) + + left_sel = left.last_op.alias() + right_sel = right.last_op.alias() + + # handle arguments ---- + on = _validate_join_arg_on(on, sql_on, left_sel, right_sel) + + # create join conditions ---- + bool_clause = _create_join_conds(left_sel, right_sel, on) + + # create inner join ---- + exists_clause = _sql_select([sql.literal(1)]) \ + .select_from(right_sel) \ + .where(bool_clause) + + # only keep left hand select's columns ---- + sel = _sql_select(left_sel.columns) \ + .select_from(left_sel) \ + .where(sql.exists(exists_clause)) + + return left.append_op(sel, order_by 
= tuple()) + + +@anti_join.register(LazyTbl) +def _anti_join(left, right = None, on = None, *args, by = None, sql_on = None): + if on is None and by is not None: + on = by + + _raise_if_args(args) + + left_sel = left.last_op.alias() + right_sel = right.last_op.alias() + + # handle arguments ---- + on = _validate_join_arg_on(on, sql_on, left, right) + + # create join conditions ---- + bool_clause = _create_join_conds(left_sel, right_sel, on) + + # create inner join ---- + #not_exists = ~sql.exists([1], from_obj = right_sel).where(bool_clause) + exists_clause = _sql_select([sql.literal(1)]) \ + .select_from(right_sel) \ + .where(bool_clause) + + sel = left_sel.select().where(~sql.exists(exists_clause)) + + return left.append_op(sel, order_by = tuple()) + +def _raise_if_args(args): + if len(args): + raise NotImplemented("*args is reserved for future arguments (e.g. suffix)") + +def _validate_join_arg_on(on, sql_on = None, lhs = None, rhs = None): + # handle sql on case + if sql_on is not None: + if on is not None: + raise ValueError("Cannot specify both on and sql_on") + + return sql_on + + # handle general cases + if on is None: + # TODO: currently, we check for lhs and rhs tables to indicate whether + # a verb supports inferring columns. Otherwise, raise an error. + if lhs is not None and rhs is not None: + # TODO: consolidate with duplicate logic in pandas verb code + warnings.warn( + "No on column passed to join. " + "Inferring join columns instead using shared column names." + ) + + on_cols = list(set(lhs.columns.keys()).intersection(set(rhs.columns.keys()))) + + if not on_cols: + raise ValueError( + "No join column specified, or shared column names in join." 
+ ) + + # trivial dict mapping shared names to themselves + warnings.warn("Detected shared columns: %s" % on_cols) + on = dict(zip(on_cols, on_cols)) + + else: + raise NotImplementedError("on arg currently cannot be None (default) for SQL") + elif isinstance(on, str): + on = {on: on} + elif isinstance(on, (list, tuple)): + on = dict(zip(on, on)) + + + if not isinstance(on, Mapping): + raise TypeError("on must be a Mapping (e.g. dict)") + + return on + +def _validate_join_arg_how(how): + how_options = ("inner", "left", "right", "full") + if how not in how_options: + raise ValueError("how argument needs to be one of %s" %how_options) + + return how + +def _create_join_conds(left_sel, right_sel, on): + left_cols = left_sel.columns #lift_inner_cols(left_sel) + right_cols = right_sel.columns #lift_inner_cols(right_sel) + + if callable(on): + # callable, like with sql_on arg + conds = [on(left_cols, right_cols)] + else: + # dict-like of form {left: right} + conds = [] + for l, r in on.items(): + col_expr = left_cols[l] == right_cols[r] + conds.append(col_expr) + + return sql.and_(*conds) + + +# Head ------------------------------------------------------------------------ + +@head.register(LazyTbl) +def _head(__data, n = 5): + sel = __data.last_select + + return __data.append_op(sel.limit(n)) + + +# Rename ---------------------------------------------------------------------- + +@rename.register(LazyTbl) +def _rename(__data, **kwargs): + sel = __data.last_select + columns = lift_inner_cols(sel) + + # old_keys uses dict as ordered set + old_to_new = {simple_varname(v):k for k,v in kwargs.items()} + + if None in old_to_new: + raise KeyError("positional arguments must be simple column, " + "e.g. 
_.colname or _['colname']" + ) + + labs = [c.label(old_to_new[k]) if k in old_to_new else c for k,c in columns.items()] + + new_sel = sel.with_only_columns(labs) + + missing_groups, group_keys = _select_group_renames(old_to_new, __data.group_by) + + return __data.append_op(new_sel, group_by=group_keys) + + +# Distinct -------------------------------------------------------------------- + +@distinct.register(LazyTbl) +def _distinct(__data, *args, _keep_all = False, **kwargs): + if (args or kwargs) and _keep_all: + raise NotImplementedError("Distinct with variables specified in sql requires _keep_all = False") + + inner_sel = mutate(__data, **kwargs).last_select if kwargs else __data.last_select + + # TODO: this is copied from the df distinct version + # cols dict below is used as ordered set + cols = _var_select_simple(args) + cols.update(kwargs) + + # use all columns by default + if not cols: + cols = {k: True for k in lift_inner_cols(inner_sel).keys()} + + final_names = {**{k: True for k in __data.group_by}, **cols} + + if not len(inner_sel._order_by_clause): + # select distinct has to include any columns in the order by clause, + # so can only safely modify existing statement when there's no order by + sel_cols = lift_inner_cols(inner_sel) + distinct_cols = [sel_cols[k] for k in final_names] + sel = inner_sel.with_only_columns(distinct_cols).distinct() + else: + # fallback to cte + cte = inner_sel.alias() + distinct_cols = [cte.columns[k] for k in final_names] + sel = _sql_select(distinct_cols).select_from(cte).distinct() + + return __data.append_op(sel) + + +# if_else --------------------------------------------------------------------- + +@if_else.register(sql.elements.ColumnElement) +def _if_else(cond, true_vals, false_vals): + whens = [(cond, true_vals)] + return sql.case(whens, else_ = false_vals) + + diff --git a/siuba/sql/verbs/count.py b/siuba/sql/verbs/count.py new file mode 100644 index 00000000..8dcbef7b --- /dev/null +++ b/siuba/sql/verbs/count.py @@ 
-0,0 +1,1395 @@ +""" +Implements LazyTbl to represent tables of SQL data, and registers it on verbs. + +This module is responsible for the handling of the "table" side of things, while +translate.py handles translating column operations. + + +""" + +import warnings + +from siuba.dply.verbs import ( + show_query, collect, + simple_varname, + select, + mutate, + transmute, + filter, + arrange, _call_strip_ascending, + summarize, + count, add_count, + group_by, ungroup, + case_when, + join, left_join, right_join, inner_join, semi_join, anti_join, + head, + rename, + distinct, + if_else, + _select_group_renames, + _var_select_simple + ) + +from siuba.dply.tidyselect import VarList, var_select + +from .translate import CustomOverClause, SqlColumn, SqlColumnAgg +from .utils import ( + get_dialect_translator, + _FixedSqlDatabase, + _is_dialect_duckdb, + _sql_select, + _sql_column_collection, + _sql_add_columns, + _sql_with_only_columns, + _sql_simplify_select, + MockConnection +) + +from sqlalchemy import sql +import sqlalchemy +from siuba.siu import Call, Lazy, FunctionLookupError, singledispatch2 +# TODO: currently needed for select, but can we remove pandas? +from pandas import Series +from functools import singledispatch + +from sqlalchemy.sql import schema + +from siuba.dply.across import _require_across, _set_data_context, _eval_with_context + +# TODO: +# - distinct +# - annotate functions using sel.prefix_with("\n/**/\n") ? + + +# Helpers --------------------------------------------------------------------- + +class SqlFunctionLookupError(FunctionLookupError): pass + + +class CallListener: + """Generic listener. Each exit is called on a node's copy.""" + def enter(self, node): + args, kwargs = node.map_subcalls(self.enter) + + return self.exit(node.__class__(node.func, *args, **kwargs)) + + def exit(self, node): + return node + + +class WindowReplacer(CallListener): + """Call tree listener. 
+ + Produces 2 important behaviors via the enter method: + - returns evaluated sql call expression, with labels on all window expressions. + - stores all labeled window expressions via the windows property. + + TODO: could replace with a sqlalchemy transformer + """ + + def __init__(self, columns, group_by, order_by, window_cte = None): + self.columns = columns + self.group_by = group_by + self.order_by = order_by + self.window_cte = window_cte + self.windows = [] + + def exit(self, node): + col_expr = node(self.columns) + + if not isinstance(col_expr, sql.elements.ClauseElement): + return col_expr + + over_clauses = [x for x in self._get_over_clauses(col_expr) if isinstance(x, CustomOverClause)] + + # put groupings and orderings onto custom over clauses + for over in over_clauses: + # TODO: shouldn't mutate these over clauses + group_by = sql.elements.ClauseList( + *[self.columns[name] for name in self.group_by] + ) + order_by = sql.elements.ClauseList( + *_create_order_by_clause(self.columns, *self.order_by) + ) + + over.set_over(group_by, order_by) + + if len(over_clauses) and self.window_cte is not None: + # custom name, or parameters like "%(...)s" may nest and break psycopg2 + # with columns you can set a key to fix this, but it doesn't seem to + # be an option with labels + name = self._get_unique_name('win', lift_inner_cols(self.window_cte)) + label = col_expr.label(name) + + # put into CTE, and return its resulting column, so that subsequent + # operations will refer to the window column on window_cte. 
Note that + # the operations will use the actual column, so may need to use the + # ClauseAdaptor to make it a reference to the label + self.window_cte = _sql_add_columns(self.window_cte, [label]) + win_col = lift_inner_cols(self.window_cte).values()[-1] + self.windows.append(win_col) + + return win_col + + return col_expr + + @staticmethod + def _get_unique_name(prefix, columns): + column_names = set(columns.keys()) + + i = 1 + name = prefix + str(i) + while name in column_names: + i += 1 + name = prefix + str(i) + + + return name + + @staticmethod + def _get_over_clauses(clause): + windows = [] + append_win = lambda col: windows.append(col) + + sql.util.visitors.traverse(clause, {}, {"over": append_win}) + + return windows + + +class SqlLabelReplacer: + """Create a visitor to replace source labels with destination. + + Note that this is meant to be used with sqlalchemy visitors. + """ + + def __init__(self, src_columns, dst_columns): + self.src_columns = src_columns + self.src_labels = set([x for x in src_columns if isinstance(x, sql.elements.Label)]) + self.dst_columns = dst_columns + self.applied = False + + def __call__(self, clause): + return sql.util.visitors.replacement_traverse(clause, {}, self.visit) + + def visit(self, el): + from sqlalchemy.sql.elements import ColumnClause, Label, ClauseElement, TypeClause + from sqlalchemy.sql.schema import Column + + if isinstance(el, TypeClause): + # TODO: for some reason this type throws an error if unguarded + return None + + if isinstance(el, ClauseElement): + if el in self.src_labels: + self.applied = True + return self.dst_columns[el.name] + elif el in self.src_columns: + return self.dst_columns[el.name] + + # TODO: should we create a subquery if the user passed raw text? 
+ #elif isinstance(el, ColumnClause) and not isinstance(el, Column): + # # Raw SQL, which will need a subquery, but not substitution + # if el.key != "*": + # self.applied = True + + return None + + +#def track_call_windows(call, columns, group_by, order_by, window_cte = None): +# listener = WindowReplacer(columns, group_by, order_by, window_cte) +# col = listener.enter(call) +# return col, listener.windows, listener.window_cte + + +def track_call_windows(call, columns, group_by, order_by, window_cte = None): + col_expr = call(columns) + + crnt_group_by = sql.elements.ClauseList( + *[columns[name] for name in group_by] + ) + crnt_order_by = sql.elements.ClauseList( + *_create_order_by_clause(columns, *order_by) + ) + return replace_call_windows(col_expr, crnt_group_by, crnt_order_by, window_cte) + + + +@singledispatch +def replace_call_windows(col_expr, group_by, order_by, window_cte = None): + raise TypeError(str(type(col_expr))) + + +@replace_call_windows.register(sql.base.ImmutableColumnCollection) +def _(col_expr, group_by, order_by, window_cte = None): + all_over_clauses = [] + for col in col_expr: + _, over_clauses, window_cte = replace_call_windows( + col, + group_by, + order_by, + window_cte + ) + all_over_clauses.extend(over_clauses) + + return col_expr, all_over_clauses, window_cte + + +@replace_call_windows.register(sql.elements.ClauseElement) +def _(col_expr, group_by, order_by, window_cte = None): + + over_clauses = WindowReplacer._get_over_clauses(col_expr) + + for over in over_clauses: + # TODO: shouldn't mutate these over clauses + over.set_over(group_by, order_by) + + if len(over_clauses) and window_cte is not None: + # custom name, or parameters like "%(...)s" may nest and break psycopg2 + # with columns you can set a key to fix this, but it doesn't seem to + # be an option with labels + name = WindowReplacer._get_unique_name('win', lift_inner_cols(window_cte)) + label = col_expr.label(name) + + # put into CTE, and return its resulting column, so 
that subsequent + # operations will refer to the window column on window_cte. Note that + # the operations will use the actual column, so may need to use the + # ClauseAdaptor to make it a reference to the label + window_cte = _sql_add_columns(window_cte, [label]) + win_col = lift_inner_cols(window_cte).values()[-1] + + return win_col, over_clauses, window_cte + + return col_expr, over_clauses, window_cte + +def get_single_from(sel): + froms = sel.froms + + n_froms = len(froms) + if n_froms != 1: + raise ValueError( + f"Expected a single table in the from clause, but found {n_froms}" + ) + + return froms[0] + +def lift_inner_cols(tbl): + cols = list(tbl.inner_columns) + + return _sql_column_collection(cols) + +def col_expr_requires_cte(call, sel, is_mutate = False): + """Return whether a variable assignment needs a CTE""" + + call_vars = set(call.op_vars(attr_calls = False)) + + sel_labs = get_inner_labels(sel) + + # I use the acronym fwg sol (frog soul) to remember sql clause eval order + # from, where, group by, select, order by, limit + # group clause evaluated before select clause, so not issue for mutate + group_needs_cte = not is_mutate and len(sel._group_by_clause) + + return ( group_needs_cte + # TODO: detect when a new var in mutate conflicts w/ order by + #or len(sel._order_by_clause) + or not sel_labs.isdisjoint(call_vars) + ) + +def get_inner_labels(sel): + columns = lift_inner_cols(sel) + sel_labs = set(k for k,v in columns.items() if isinstance(v, sql.elements.Label)) + return sel_labs + +def get_missing_columns(call, columns): + missing_cols = set(call.op_vars(attr_calls = False)) - set(columns.keys()) + return missing_cols + +def compile_el(tbl, el): + compiled = el.compile( + dialect = tbl.source.dialect, + compile_kwargs = {"literal_binds": True} + ) + return compiled + +# Misc utilities -------------------------------------------------------------- + +def ordered_union(x, y): + dx = {el: True for el in x} + dy = {el: True for el in y} + + return 
tuple({**dx, **dy}) + + +def _warn_missing(missing_groups): + warnings.warn(f"Adding missing grouping variables: {missing_groups}") + + +# Table ----------------------------------------------------------------------- + +class LazyTbl: + def __init__( + self, source, tbl, columns = None, + ops = None, group_by = tuple(), order_by = tuple(), + translator = None + ): + """Create a representation of a SQL table. + + Args: + source: a sqlalchemy.Engine or sqlalchemy.Connection instance. + tbl: table of form 'schema_name.table_name', 'table_name', or sqlalchemy.Table. + columns: if specified, a listlike of column names. + + Examples + -------- + + :: + from sqlalchemy import create_engine + from siuba.data import mtcars + + # create database and table + engine = create_engine("sqlite:///:memory:") + mtcars.to_sql('mtcars', engine) + + tbl_mtcars = LazyTbl(engine, 'mtcars') + + """ + + # connection and dialect specific functions + self.source = sqlalchemy.create_engine(source) if isinstance(source, str) else source + + # get dialect name + dialect = self.source.dialect.name + self.translator = get_dialect_translator(dialect) + + self.tbl = self._create_table(tbl, columns, self.source) + + # important states the query can be in (e.g. 
grouped) + self.ops = [self.tbl] if ops is None else ops + + self.group_by = group_by + self.order_by = order_by + + + def append_op(self, op, **kwargs): + cpy = self.copy(**kwargs) + cpy.ops = cpy.ops + [op] + return cpy + + def copy(self, **kwargs): + return self.__class__(**{**self.__dict__, **kwargs}) + + def shape_call( + self, + call, window = True, str_accessors = False, + verb_name = None, arg_name = None, + ): + return self.translator.shape_call(call, window, str_accessors, verb_name, arg_name) + + def track_call_windows(self, call, columns = None, window_cte = None): + """Returns tuple of (new column expression, list of window exprs)""" + + columns = self.last_op.columns if columns is None else columns + return track_call_windows(call, columns, self.group_by, self.order_by, window_cte) + + def get_ordered_col_names(self): + """Return columns from current select, with grouping columns first.""" + ungrouped = [k for k in self.last_op.columns.keys() if k not in self.group_by] + return list(self.group_by) + ungrouped + + #def label_breaks_order_by(self, name): + # """Returns True if a new column label would break the order by vars.""" + + # # TODO: arrange currently allows literals, which breaks this. it seems + # # better to only allow calls in arrange. + # order_by_vars = {c.op_vars(attr_calls=False) for c in self.order_by} + + + + + @property + def last_op(self) -> "sql.Table | sql.Select": + last_op = self.ops[-1] + + if last_op is None: + raise TypeError() + + return last_op + + @property + def last_select(self): + last_op = self.last_op + if not isinstance(last_op, sql.selectable.SelectBase): + return last_op.select() + + return last_op + + @staticmethod + def _create_table(tbl, columns = None, source = None): + """Return a sqlalchemy.Table, autoloading column info if needed. + + Arguments: + tbl: a sqlalchemy.Table or string of form 'table_name' or 'schema_name.table_name'. + columns: a tuple of column names for the table. Overrides source argument. 
+ source: a sqlalchemy engine, used to autoload columns. + + """ + if isinstance(tbl, sql.selectable.FromClause): + return tbl + + if not isinstance(tbl, str): + raise ValueError("tbl must be a sqlalchemy Table or string, but was %s" %type(tbl)) + + if columns is None and source is None: + raise ValueError("One of columns or source must be specified") + + schema, table_name = tbl.split('.') if '.' in tbl else [None, tbl] + + columns = map(sqlalchemy.Column, columns) if columns is not None else tuple() + + # TODO: pybigquery uses schema to mean project_id, so we cannot use + # siuba's classic breakdown "{schema}.{table_name}". Basically + # pybigquery uses "{schema=project_id}.{dataset_dot_table_name}" in its internal + # logic. An important side effect is that bigquery errors for + # `dataset`.`table`, but not `dataset.table`. + if source and source.dialect.name == "bigquery": + table_name = tbl + schema = None + + return sqlalchemy.Table( + table_name, + sqlalchemy.MetaData(bind = source), + *columns, + schema = schema, + autoload_with = source if not columns else None + ) + + def _get_preview(self): + # need to make prev op a cte, so we don't override any previous limit + new_sel = self.last_select.limit(5) + tbl_small = self.append_op(new_sel) + return collect(tbl_small) + + def __repr__(self): + template = ( + "# Source: lazy query\n" + "# DB Conn: {}\n" + "# Preview:\n{}\n" + "# .. may have more rows" + ) + + return template.format(repr(self.source.engine), repr(self._get_preview())) + + def _repr_html_(self): + template = ( + "
" + "
"
+                "# Source: lazy query\n"
+                "# DB Conn: {}\n"
+                "# Preview:\n"
+                "
" + "{}" + "

# .. may have more rows

" + "
" + ) + + data = self._get_preview() + + # _repr_html_ can not exist or return None, to signify that repr should be used + if not hasattr(data, '_repr_html_'): + return None + + html_data = data._repr_html_() + if html_data is None: + return None + + return template.format(self.source.engine, html_data) + + +def _repr_grouped_df_html_(self): + return "

(grouped data frame)

" + self._selected_obj._repr_html_() + "
" + + + +# Main Funcs +# ============================================================================= + +# sql raw -------------- + +sql_raw = sql.literal_column + +# show query ----------- + +@show_query.register(LazyTbl) +def _show_query(tbl, simplify = False, return_table = True): + #query = tbl.last_op #if not simplify else + compile_query = lambda query: query.compile( + dialect = tbl.source.dialect, + compile_kwargs = {"literal_binds": True} + ) + + + if simplify: + # try to strip table names and labels where unnecessary + simple_sel = _sql_simplify_select(tbl.last_select) + + explained = compile_query(simple_sel) + else: + # use a much more verbose query + explained = compile_query(tbl.last_select) + + if return_table: + print(str(explained)) + return tbl + + return str(explained) + + + +# collect ---------- + +@collect.register(LazyTbl) +def _collect(__data, as_df = True): + # TODO: maybe remove as_df options, always return dataframe + + if isinstance(__data.source, MockConnection): + # a mock sqlalchemy is being used to show_query, and echo queries. + # it doesn't return a result object or have a context handler, so + # we need to bail out early + return + + # compile query ---- + + if _is_dialect_duckdb(__data.source): + # TODO: can be removed once next release of duckdb fixes: + # https://github.com/duckdb/duckdb/issues/2972 + query = __data.last_select + compiled = query.compile( + dialect = __data.source.dialect, + compile_kwargs = {"literal_binds": True} + ) + else: + compiled = __data.last_select + + # execute query ---- + + with __data.source.connect() as conn: + if as_df: + sql_db = _FixedSqlDatabase(conn) + + if _is_dialect_duckdb(__data.source): + # TODO: pandas read_sql is very slow with duckdb. + # see https://github.com/pandas-dev/pandas/issues/45678 + # going to handle here for now. address once LazyTbl gets + # subclassed per backend. 
+ duckdb_con = conn.connection.c + return duckdb_con.query(str(compiled)).to_df() + else: + # + return sql_db.read_sql(compiled) + + return conn.execute(compiled) + + +@select.register(LazyTbl) +def _select(__data, *args, **kwargs): + # see https://stackoverflow.com/questions/25914329/rearrange-columns-in-sqlalchemy-select-object + if kwargs: + raise NotImplementedError( + "Using kwargs in select not currently supported. " + "Use _.newname == _.oldname instead" + ) + last_sel = __data.last_select + columns = {c.key: c for c in last_sel.inner_columns} + + # same as for DataFrame + colnames = Series(list(columns)) + vl = VarList() + evaluated = (arg(vl) if callable(arg) else arg for arg in args) + od = var_select(colnames, *evaluated) + + missing_groups, group_keys = _select_group_renames(od, __data.group_by) + + if missing_groups: + _warn_missing(missing_groups) + + final_od = {**{k: None for k in missing_groups}, **od} + + col_list = [] + for k,v in final_od.items(): + col = columns[k] + col_list.append(col if v is None else col.label(v)) + + return __data.append_op( + last_sel.with_only_columns(col_list), + group_by = group_keys + ) + + + +@filter.register(LazyTbl) +def _filter(__data, *args): + # Note: currently always produces 2 additional select statements, + # 1 for window/aggs, and 1 for the where clause + + sel = __data.last_op.alias() # original select + win_sel = sel.select() + + conds = [] + windows = [] + with _set_data_context(__data, window=True): + for ii, arg in enumerate(args): + + if isinstance(arg, Call): + new_call = __data.shape_call(arg, verb_name = "Filter", arg_name = ii) + #var_cols = new_call.op_vars(attr_calls = False) + + # note that a new win_sel is returned, w/ window columns appended + col_expr, win_cols, win_sel = __data.track_call_windows( + new_call, + sel.columns, + window_cte = win_sel + ) + + if isinstance(col_expr, sql.base.ImmutableColumnCollection): + conds.extend(col_expr) + else: + conds.append(col_expr) + + 
windows.extend(win_cols) + + else: + conds.append(arg) + + bool_clause = sql.and_(*conds) + + # first cte, windows ---- + if len(windows): + + win_alias = win_sel.alias() + + # move non-window functions to refer to win_sel clause (not the innermost) --- + bool_clause = sql.util.ClauseAdapter(win_alias) \ + .traverse(bool_clause) + + orig_cols = [win_alias.columns[k] for k in sel.columns.keys()] + else: + orig_cols = [sel] + + # create second cte ---- + filt_sel = _sql_select(orig_cols).where(bool_clause) + return __data.append_op(filt_sel) + + +@mutate.register(LazyTbl) +def _mutate(__data, *args, **kwargs): + # TODO: verify it can follow a renaming select + + # track labeled columns in set + if not (len(args) or len(kwargs)): + return __data.append_op(__data.last_op) + + names, sel_out = _mutate_cols(__data, args, kwargs, "Mutate") + return __data.append_op(sel_out) + + +def _sql_upsert_columns(sel, new_columns: "list[base.Label | base.Column]"): + orig_cols = lift_inner_cols(sel) + replaced = {**orig_cols} + + for new_col in new_columns: + replaced[new_col.name] = new_col + return _sql_with_only_columns(sel, list(replaced.values())) + + +def _select_mutate_result(src_sel, expr_result): + dst_alias = src_sel.alias() + src_columns = set(lift_inner_cols(src_sel)) + replacer = SqlLabelReplacer(set(src_columns), dst_alias.columns) + + if isinstance(expr_result, sql.base.ImmutableColumnCollection): + replaced_cols = list(map(replacer, expr_result)) + orig_cols = expr_result + #elif isinstance(expr_result, None): + # pass + else: + replaced_cols = [replacer(expr_result)] + orig_cols = [expr_result] + + if replacer.applied: + return _sql_upsert_columns(dst_alias.select(), replaced_cols) + + return _sql_upsert_columns(src_sel, orig_cols) + + +def _eval_expr_arg(__data, sel, func, verb_name, window=True): + inner_cols = lift_inner_cols(sel) + + # case 1: simple names ---- + simple_name = simple_varname(func) + if simple_name is not None: + return inner_cols[simple_name] + 
+ # case 2: across ---- + _require_across(func, verb_name) + + cols_result = _eval_with_context(__data, window, inner_cols, func) + + # TODO: remove or raise a more informative error + assert isinstance(cols_result, sql.base.ImmutableColumnCollection), type(cols_result) + + return cols_result + + +def _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=True): + inner_cols = lift_inner_cols(sel) + + expr_shaped = __data.shape_call(func, window, verb_name = verb_name, arg_name = new_name) + new_col, windows, _ = __data.track_call_windows(expr_shaped, inner_cols) + + if isinstance(new_col, sql.base.ImmutableColumnCollection): + raise TyepError( + f"{verb_name} named arguments must return a single column, but `{k}` " + "returned multiple columns." + ) + + return new_col.label(new_name) + + +def _mutate_cols(__data, args, kwargs, verb_name): + result_names = {} # used as ordered set + sel = __data.last_select + + for ii, func in enumerate(args): + cols_result = _eval_expr_arg(__data, sel, func, verb_name) + + # replace any labels that require a subquery ---- + sel = _select_mutate_result(sel, cols_result) + + if isinstance(cols_result, sql.base.ImmutableColumnCollection): + result_names.update({k: True for k in cols_result.keys()}) + else: + result_names[cols_result.name] = True + + + for new_name, func in kwargs.items(): + labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name) + + sel = _select_mutate_result(sel, labeled) + result_names[new_name] = True + + + return list(result_names), sel + + +@transmute.register(LazyTbl) +def _transmute(__data, *args, **kwargs): + # will use mutate, then select some cols + result_names, sel = _mutate_cols(__data, args, kwargs, "Transmute") + + # transmute keeps grouping cols, and any defined in kwargs + missing = [x for x in __data.group_by if x not in result_names] + cols_to_keep = [*missing, *result_names] + + columns = lift_inner_cols(sel) + sel_stripped = sel.with_only_columns([columns[k] for k in 
cols_to_keep]) + + return __data.append_op(sel_stripped) + + +@arrange.register(LazyTbl) +def _arrange(__data, *args): + # Note that SQL databases often do not subquery order by clauses. Arrange + # sets order_by on the backend, so it can set order by in over elements, + # and handle when new columns are named the same as order by vars. + # see: https://dba.stackexchange.com/q/82930 + + last_sel = __data.last_select + cols = lift_inner_cols(last_sel) + + # TODO: implement across in arrange + #exprs, _ = _mutate_cols(__data, args, kwargs, "Arrange", arrange_clause=True) + + new_calls = [] + for ii, expr in enumerate(args): + if callable(expr): + + res = __data.shape_call( + expr, window = False, + verb_name = "Arrange", arg_name = ii + ) + + else: + res = expr + + new_calls.append(res) + + sort_cols = _create_order_by_clause(cols, *new_calls) + + order_by = __data.order_by + tuple(new_calls) + return __data.append_op(last_sel.order_by(*sort_cols), order_by = order_by) + + +# TODO: consolidate / pull expr handling funcs into own file? +def _create_order_by_clause(columns, *args): + sort_cols = [] + for arg in args: + # simple named column + if isinstance(arg, str): + sort_cols.append(columns[arg]) + # an expression + elif callable(arg): + # handle special case where -_.colname -> colname DESC + f, asc = _call_strip_ascending(arg) + col_op = f(columns) if asc else f(columns).desc() + #col_op = arg(columns) + sort_cols.append(col_op) + else: + raise NotImplementedError("Must be string or callable") + + return sort_cols + + + +@count.register(LazyTbl) +def _count(__data, *args, sort = False, wt = None, **kwargs): + # TODO: if already col named n, use name nn, etc.. 
get logic from tidy.py + if wt is not None: + raise NotImplementedError("TODO") + + res_name = "n" + # similar to filter verb, we need two select statements, + # an inner one for derived cols, and outer to group by them + + # inner select ---- + # holds any mutation style columns + #arg_names = [] + #for arg in args: + # name = simple_varname(arg) + # if name is None: + # raise NotImplementedError( + # "Count positional arguments must be single column name. " + # "Use a named argument to count using complex expressions." + # ) + # arg_names.append(name) + + result_names, sel_inner = _mutate_cols(__data, args, kwargs, "Count") + + # remove unnecessary select, if we're operating on a table ---- + if set(lift_inner_cols(sel_inner)) == set(lift_inner_cols(__data.last_select)): + sel_inner = __data.last_op + + # create outer select ---- + # holds selected columns and tally (n) + sel_inner_cte = sel_inner.alias() + inner_cols = sel_inner_cte.columns + + # apply any group vars from a group_by verb call first + missing = [k for k in __data.group_by if k not in result_names] + + all_group_names = ordered_union(__data.group_by, result_names) + outer_group_cols = [inner_cols[k] for k in all_group_names] + + # holds the actual count (e.g. 
n) + count_col = sql.functions.count().label(res_name) + + sel_outer = _sql_select([*outer_group_cols, count_col]) \ + .select_from(sel_inner_cte) \ + .group_by(*outer_group_cols) + + # count is like summarize, so removes order_by + return __data.append_op( + sel_outer.order_by(count_col.desc()), + order_by = tuple() + ) + + +@add_count.register(LazyTbl) +def _add_count(__data, *args, wt = None, sort = False, **kwargs): + counts = count(__data, *args, wt = wt, sort = sort, **kwargs) + by = list(c.name for c in counts.last_select.inner_columns)[:-1] + + return inner_join(__data, counts, by = by) + + +@summarize.register(LazyTbl) +def _summarize(__data, *args, **kwargs): + # https://stackoverflow.com/questions/14754994/why-is-sqlalchemy-count-much-slower-than-the-raw-query + + # get query with correct from clause, and maybe unneeded subquery + safe_from = __data.last_select.alias() + result_names, sel = _aggregate_cols(__data, safe_from, args, kwargs, "Summarize") + + # see if we can remove subquery + out_sel = _collapse_select(sel, safe_from) + + from_tbl = get_single_from(out_sel) + group_cols = [from_tbl.columns[k] for k in __data.group_by] + + final_sel = out_sel.group_by(*group_cols) + + new_data = __data.append_op(final_sel, group_by = tuple(), order_by = tuple()) + return new_data + + +def _collapse_select(outer_sel, inner_alias): + # check whether any outer columns reference an inner label ---- + inner_sel = inner_alias.element + + columns = lift_inner_cols(outer_sel) + inner_cols = lift_inner_cols(inner_sel) + + inner_labels = set([ + x.name for x in inner_cols + if isinstance(x, sql.elements.Label) + ]) + + col_requires_cte = set(inner_alias.columns[k] for k in inner_labels) + + bad_refs = [] + + def collect_refs(el): + if el in col_requires_cte: + bad_refs.append(el) + + for col in columns: + sql.util.visitors.traverse(col, {}, {"column": collect_refs}) + + # if possible, remove the outer query ---- + if not (bad_refs or len(inner_sel._group_by_clause)): + 
from sqlalchemy.sql.elements import ColumnClause, Label + + from_obj = get_single_from(inner_sel) + adaptor = sql.util.ClauseAdapter( + from_obj, + adapt_on_names=True, + include_fn=lambda c: isinstance(c, (ColumnClause, Label)) + ) + + new_cols = [] + for col in columns: + if isinstance(col, Label): + res = adaptor.traverse(col.element).label(col.name) + new_cols.append(res) + + else: + new_cols.append(adaptor.traverse(col)) + #new_cols = list(map(adaptor.traverse, columns)) + + return _sql_with_only_columns(inner_sel, new_cols) + + return outer_sel + + +def _aggregate_cols(__data, subquery, args, kwargs, verb_name): + # cases: + # * grouping cols can not be overwritten (in dbplyr they can't be ref'd) + # * no existing labels referred to - can use same select + # * existing labels referred to - need 1 subquery tops + # * groups + summarize columns can replace everything + + def get_label_clauses(clause): + out = [] + sql.util.visitors.traverse(clause, {}, {"label": lambda c: out.append(c)}) + + return out + + def quote_varname(x): + return f"`{x}`" + + def validate_references(arg_name, expr, verb_name): + bad_varnames = get_label_clauses(expr) + repr_names = ", ".join(map(quote_varname, [el.name for el in bad_varnames])) + + if not bad_varnames: + return + + raise NotImplementedError( + f"In SQL, you cannot refer to a column created in the same {verb_name}. " + f"`{arg_name}` refers to columns created earlier: {repr_names}." 
+ ) + + sel = subquery.select() + + final_cols = {k: subquery.columns[k] for k in __data.group_by} + + # handle args ---- + for ii, func in enumerate(args): + cols_result = _eval_expr_arg(__data, sel, func, verb_name, window=False) + + for col in cols_result: + validate_references(col.name, col.element, verb_name) + final_cols[col.name] = col + + sel = _sql_upsert_columns(sel, cols_result) + + + # handle kwargs ---- + for new_name, func in kwargs.items(): + labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=False) + + validate_references(labeled.name, labeled.element, verb_name) + final_cols[new_name] = labeled + + sel = _sql_upsert_columns(sel, [labeled]) + + return list(final_cols), _sql_with_only_columns(sel, list(final_cols.values())) + + +@group_by.register(LazyTbl) +def _group_by(__data, *args, add = False, **kwargs): + if not (args or kwargs): + return __data.copy() + + group_names, sel = _mutate_cols(__data, args, kwargs, "Group by") + + if None in group_names: + raise NotImplementedError("Complex, unnamed expressions not supported in sql group_by") + + # check whether we can just use underlying table ---- + new_cols = lift_inner_cols(sel) + if set(new_cols).issubset(set(__data.last_op.columns)): + sel = __data.last_op + + if add: + group_names = ordered_union(__data.group_by, group_names) + + return __data.append_op(sel, group_by = tuple(group_names)) + + +@ungroup.register(LazyTbl) +def _ungroup(__data): + return __data.copy(group_by = tuple()) + + +@case_when.register(sql.base.ImmutableColumnCollection) +def _case_when(__data, cases): + # TODO: will need listener to enter case statements, to handle when they use windows + if isinstance(cases, Call): + cases = cases(__data) + + whens = [] + case_items = list(cases.items()) + n_items = len(case_items) + + else_val = None + for ii, (expr, val) in enumerate(case_items): + # handle where val is a column expr + if callable(val): + val = val(__data) + + # handle when expressions + if ii+1 
== n_items and expr is True: + else_val = val + elif callable(expr): + whens.append((expr(__data), val)) + else: + whens.append((expr, val)) + + return sql.case(whens, else_ = else_val) + + +# Join ------------------------------------------------------------------------ + +from collections.abc import Mapping + +def _joined_cols(left_cols, right_cols, on_keys, how, suffix = ("_x", "_y")): + """Return labeled columns, according to selection rules for joins. + + Rules: + 1. For join keys, keep left table's column + 2. When keys have the same labels, add suffix + """ + + # TODO: remove sets, so uses stable ordering + # when left and right cols have same name, suffix with _x / _y + keep_right = set(right_cols.keys()) - set(on_keys.values()) + shared_labs = set(left_cols.keys()).intersection(keep_right) + + right_cols_no_keys = {k: right_cols[k] for k in keep_right} + + # for an outer join, have key columns coalesce values + + left_cols = {**left_cols} + if how == "full": + for lk, rk in on_keys.items(): + col = sql.functions.coalesce(left_cols[lk], right_cols[rk]) + left_cols[lk] = col.label(lk) + elif how == "right": + for lk, rk in on_keys.items(): + # Make left key columns actually be right ones (which contain left + extra) + left_cols[lk] = right_cols[rk].label(lk) + + + # create labels ---- + l_labs = _relabeled_cols(left_cols, shared_labs, suffix[0]) + r_labs = _relabeled_cols(right_cols_no_keys, shared_labs, suffix[1]) + + return l_labs + r_labs + + + +def _relabeled_cols(columns, keys, suffix): + # add a suffix to all columns with names in keys + cols = [] + for k, v in columns.items(): + new_col = v.label(k + str(suffix)) if k in keys else v + cols.append(new_col) + return cols + + +@join.register(LazyTbl) +def _join(left, right, on = None, *args, by = None, how = "inner", sql_on = None): + _raise_if_args(args) + + if on is None and by is not None: + on = by + + # Needs to be on the table, not the select + left_sel = left.last_op.alias() + right_sel = 
right.last_op.alias() + + # handle arguments ---- + on = _validate_join_arg_on(on, sql_on) + how = _validate_join_arg_how(how) + + # for equality join used to combine keys into single column + consolidate_keys = on if sql_on is None else {} + + if how == "right": + # switch joins, since sqlalchemy doesn't have right join arg + # see https://stackoverflow.com/q/11400307/1144523 + left_sel, right_sel = right_sel, left_sel + on = {v:k for k,v in on.items()} + + # create join conditions ---- + bool_clause = _create_join_conds(left_sel, right_sel, on) + + # create join ---- + join = left_sel.join( + right_sel, + onclause = bool_clause, + isouter = how != "inner", + full = how == "full" + ) + + # if right join, set selects back + if how == "right": + left_sel, right_sel = right_sel, left_sel + on = {v:k for k,v in on.items()} + + # note, shared_keys assumes on is a mapping... + # TODO: shared_keys appears to be for when on is not specified, but was unused + #shared_keys = [k for k,v in on.items() if k == v] + labeled_cols = _joined_cols( + left_sel.columns, + right_sel.columns, + on_keys = consolidate_keys, + how = how + ) + + sel = _sql_select(labeled_cols).select_from(join) + return left.append_op(sel, order_by = tuple()) + + +@semi_join.register(LazyTbl) +def _semi_join(left, right = None, on = None, *args, by = None, sql_on = None): + if on is None and by is not None: + on = by + + _raise_if_args(args) + + left_sel = left.last_op.alias() + right_sel = right.last_op.alias() + + # handle arguments ---- + on = _validate_join_arg_on(on, sql_on, left_sel, right_sel) + + # create join conditions ---- + bool_clause = _create_join_conds(left_sel, right_sel, on) + + # create inner join ---- + exists_clause = _sql_select([sql.literal(1)]) \ + .select_from(right_sel) \ + .where(bool_clause) + + # only keep left hand select's columns ---- + sel = _sql_select(left_sel.columns) \ + .select_from(left_sel) \ + .where(sql.exists(exists_clause)) + + return left.append_op(sel, order_by 
= tuple()) + + +@anti_join.register(LazyTbl) +def _anti_join(left, right = None, on = None, *args, by = None, sql_on = None): + if on is None and by is not None: + on = by + + _raise_if_args(args) + + left_sel = left.last_op.alias() + right_sel = right.last_op.alias() + + # handle arguments ---- + on = _validate_join_arg_on(on, sql_on, left, right) + + # create join conditions ---- + bool_clause = _create_join_conds(left_sel, right_sel, on) + + # create inner join ---- + #not_exists = ~sql.exists([1], from_obj = right_sel).where(bool_clause) + exists_clause = _sql_select([sql.literal(1)]) \ + .select_from(right_sel) \ + .where(bool_clause) + + sel = left_sel.select().where(~sql.exists(exists_clause)) + + return left.append_op(sel, order_by = tuple()) + +def _raise_if_args(args): + if len(args): + raise NotImplemented("*args is reserved for future arguments (e.g. suffix)") + +def _validate_join_arg_on(on, sql_on = None, lhs = None, rhs = None): + # handle sql on case + if sql_on is not None: + if on is not None: + raise ValueError("Cannot specify both on and sql_on") + + return sql_on + + # handle general cases + if on is None: + # TODO: currently, we check for lhs and rhs tables to indicate whether + # a verb supports inferring columns. Otherwise, raise an error. + if lhs is not None and rhs is not None: + # TODO: consolidate with duplicate logic in pandas verb code + warnings.warn( + "No on column passed to join. " + "Inferring join columns instead using shared column names." + ) + + on_cols = list(set(lhs.columns.keys()).intersection(set(rhs.columns.keys()))) + + if not on_cols: + raise ValueError( + "No join column specified, or shared column names in join." 
+ ) + + # trivial dict mapping shared names to themselves + warnings.warn("Detected shared columns: %s" % on_cols) + on = dict(zip(on_cols, on_cols)) + + else: + raise NotImplementedError("on arg currently cannot be None (default) for SQL") + elif isinstance(on, str): + on = {on: on} + elif isinstance(on, (list, tuple)): + on = dict(zip(on, on)) + + + if not isinstance(on, Mapping): + raise TypeError("on must be a Mapping (e.g. dict)") + + return on + +def _validate_join_arg_how(how): + how_options = ("inner", "left", "right", "full") + if how not in how_options: + raise ValueError("how argument needs to be one of %s" %how_options) + + return how + +def _create_join_conds(left_sel, right_sel, on): + left_cols = left_sel.columns #lift_inner_cols(left_sel) + right_cols = right_sel.columns #lift_inner_cols(right_sel) + + if callable(on): + # callable, like with sql_on arg + conds = [on(left_cols, right_cols)] + else: + # dict-like of form {left: right} + conds = [] + for l, r in on.items(): + col_expr = left_cols[l] == right_cols[r] + conds.append(col_expr) + + return sql.and_(*conds) + + +# Head ------------------------------------------------------------------------ + +@head.register(LazyTbl) +def _head(__data, n = 5): + sel = __data.last_select + + return __data.append_op(sel.limit(n)) + + +# Rename ---------------------------------------------------------------------- + +@rename.register(LazyTbl) +def _rename(__data, **kwargs): + sel = __data.last_select + columns = lift_inner_cols(sel) + + # old_keys uses dict as ordered set + old_to_new = {simple_varname(v):k for k,v in kwargs.items()} + + if None in old_to_new: + raise KeyError("positional arguments must be simple column, " + "e.g. 
_.colname or _['colname']" + ) + + labs = [c.label(old_to_new[k]) if k in old_to_new else c for k,c in columns.items()] + + new_sel = sel.with_only_columns(labs) + + missing_groups, group_keys = _select_group_renames(old_to_new, __data.group_by) + + return __data.append_op(new_sel, group_by=group_keys) + + +# Distinct -------------------------------------------------------------------- + +@distinct.register(LazyTbl) +def _distinct(__data, *args, _keep_all = False, **kwargs): + if (args or kwargs) and _keep_all: + raise NotImplementedError("Distinct with variables specified in sql requires _keep_all = False") + + inner_sel = mutate(__data, **kwargs).last_select if kwargs else __data.last_select + + # TODO: this is copied from the df distinct version + # cols dict below is used as ordered set + cols = _var_select_simple(args) + cols.update(kwargs) + + # use all columns by default + if not cols: + cols = {k: True for k in lift_inner_cols(inner_sel).keys()} + + final_names = {**{k: True for k in __data.group_by}, **cols} + + if not len(inner_sel._order_by_clause): + # select distinct has to include any columns in the order by clause, + # so can only safely modify existing statement when there's no order by + sel_cols = lift_inner_cols(inner_sel) + distinct_cols = [sel_cols[k] for k in final_names] + sel = inner_sel.with_only_columns(distinct_cols).distinct() + else: + # fallback to cte + cte = inner_sel.alias() + distinct_cols = [cte.columns[k] for k in final_names] + sel = _sql_select(distinct_cols).select_from(cte).distinct() + + return __data.append_op(sel) + + +# if_else --------------------------------------------------------------------- + +@if_else.register(sql.elements.ColumnElement) +def _if_else(cond, true_vals, false_vals): + whens = [(cond, true_vals)] + return sql.case(whens, else_ = false_vals) + + diff --git a/siuba/sql/verbs/distinct.py b/siuba/sql/verbs/distinct.py new file mode 100644 index 00000000..8dcbef7b --- /dev/null +++ 
b/siuba/sql/verbs/distinct.py @@ -0,0 +1,1395 @@ +""" +Implements LazyTbl to represent tables of SQL data, and registers it on verbs. + +This module is responsible for the handling of the "table" side of things, while +translate.py handles translating column operations. + + +""" + +import warnings + +from siuba.dply.verbs import ( + show_query, collect, + simple_varname, + select, + mutate, + transmute, + filter, + arrange, _call_strip_ascending, + summarize, + count, add_count, + group_by, ungroup, + case_when, + join, left_join, right_join, inner_join, semi_join, anti_join, + head, + rename, + distinct, + if_else, + _select_group_renames, + _var_select_simple + ) + +from siuba.dply.tidyselect import VarList, var_select + +from .translate import CustomOverClause, SqlColumn, SqlColumnAgg +from .utils import ( + get_dialect_translator, + _FixedSqlDatabase, + _is_dialect_duckdb, + _sql_select, + _sql_column_collection, + _sql_add_columns, + _sql_with_only_columns, + _sql_simplify_select, + MockConnection +) + +from sqlalchemy import sql +import sqlalchemy +from siuba.siu import Call, Lazy, FunctionLookupError, singledispatch2 +# TODO: currently needed for select, but can we remove pandas? +from pandas import Series +from functools import singledispatch + +from sqlalchemy.sql import schema + +from siuba.dply.across import _require_across, _set_data_context, _eval_with_context + +# TODO: +# - distinct +# - annotate functions using sel.prefix_with("\n/**/\n") ? + + +# Helpers --------------------------------------------------------------------- + +class SqlFunctionLookupError(FunctionLookupError): pass + + +class CallListener: + """Generic listener. Each exit is called on a node's copy.""" + def enter(self, node): + args, kwargs = node.map_subcalls(self.enter) + + return self.exit(node.__class__(node.func, *args, **kwargs)) + + def exit(self, node): + return node + + +class WindowReplacer(CallListener): + """Call tree listener. 
+ + Produces 2 important behaviors via the enter method: + - returns evaluated sql call expression, with labels on all window expressions. + - stores all labeled window expressions via the windows property. + + TODO: could replace with a sqlalchemy transformer + """ + + def __init__(self, columns, group_by, order_by, window_cte = None): + self.columns = columns + self.group_by = group_by + self.order_by = order_by + self.window_cte = window_cte + self.windows = [] + + def exit(self, node): + col_expr = node(self.columns) + + if not isinstance(col_expr, sql.elements.ClauseElement): + return col_expr + + over_clauses = [x for x in self._get_over_clauses(col_expr) if isinstance(x, CustomOverClause)] + + # put groupings and orderings onto custom over clauses + for over in over_clauses: + # TODO: shouldn't mutate these over clauses + group_by = sql.elements.ClauseList( + *[self.columns[name] for name in self.group_by] + ) + order_by = sql.elements.ClauseList( + *_create_order_by_clause(self.columns, *self.order_by) + ) + + over.set_over(group_by, order_by) + + if len(over_clauses) and self.window_cte is not None: + # custom name, or parameters like "%(...)s" may nest and break psycopg2 + # with columns you can set a key to fix this, but it doesn't seem to + # be an option with labels + name = self._get_unique_name('win', lift_inner_cols(self.window_cte)) + label = col_expr.label(name) + + # put into CTE, and return its resulting column, so that subsequent + # operations will refer to the window column on window_cte. 
Note that + # the operations will use the actual column, so may need to use the + # ClauseAdaptor to make it a reference to the label + self.window_cte = _sql_add_columns(self.window_cte, [label]) + win_col = lift_inner_cols(self.window_cte).values()[-1] + self.windows.append(win_col) + + return win_col + + return col_expr + + @staticmethod + def _get_unique_name(prefix, columns): + column_names = set(columns.keys()) + + i = 1 + name = prefix + str(i) + while name in column_names: + i += 1 + name = prefix + str(i) + + + return name + + @staticmethod + def _get_over_clauses(clause): + windows = [] + append_win = lambda col: windows.append(col) + + sql.util.visitors.traverse(clause, {}, {"over": append_win}) + + return windows + + +class SqlLabelReplacer: + """Create a visitor to replace source labels with destination. + + Note that this is meant to be used with sqlalchemy visitors. + """ + + def __init__(self, src_columns, dst_columns): + self.src_columns = src_columns + self.src_labels = set([x for x in src_columns if isinstance(x, sql.elements.Label)]) + self.dst_columns = dst_columns + self.applied = False + + def __call__(self, clause): + return sql.util.visitors.replacement_traverse(clause, {}, self.visit) + + def visit(self, el): + from sqlalchemy.sql.elements import ColumnClause, Label, ClauseElement, TypeClause + from sqlalchemy.sql.schema import Column + + if isinstance(el, TypeClause): + # TODO: for some reason this type throws an error if unguarded + return None + + if isinstance(el, ClauseElement): + if el in self.src_labels: + self.applied = True + return self.dst_columns[el.name] + elif el in self.src_columns: + return self.dst_columns[el.name] + + # TODO: should we create a subquery if the user passed raw text? 
+ #elif isinstance(el, ColumnClause) and not isinstance(el, Column): + # # Raw SQL, which will need a subquery, but not substitution + # if el.key != "*": + # self.applied = True + + return None + + +#def track_call_windows(call, columns, group_by, order_by, window_cte = None): +# listener = WindowReplacer(columns, group_by, order_by, window_cte) +# col = listener.enter(call) +# return col, listener.windows, listener.window_cte + + +def track_call_windows(call, columns, group_by, order_by, window_cte = None): + col_expr = call(columns) + + crnt_group_by = sql.elements.ClauseList( + *[columns[name] for name in group_by] + ) + crnt_order_by = sql.elements.ClauseList( + *_create_order_by_clause(columns, *order_by) + ) + return replace_call_windows(col_expr, crnt_group_by, crnt_order_by, window_cte) + + + +@singledispatch +def replace_call_windows(col_expr, group_by, order_by, window_cte = None): + raise TypeError(str(type(col_expr))) + + +@replace_call_windows.register(sql.base.ImmutableColumnCollection) +def _(col_expr, group_by, order_by, window_cte = None): + all_over_clauses = [] + for col in col_expr: + _, over_clauses, window_cte = replace_call_windows( + col, + group_by, + order_by, + window_cte + ) + all_over_clauses.extend(over_clauses) + + return col_expr, all_over_clauses, window_cte + + +@replace_call_windows.register(sql.elements.ClauseElement) +def _(col_expr, group_by, order_by, window_cte = None): + + over_clauses = WindowReplacer._get_over_clauses(col_expr) + + for over in over_clauses: + # TODO: shouldn't mutate these over clauses + over.set_over(group_by, order_by) + + if len(over_clauses) and window_cte is not None: + # custom name, or parameters like "%(...)s" may nest and break psycopg2 + # with columns you can set a key to fix this, but it doesn't seem to + # be an option with labels + name = WindowReplacer._get_unique_name('win', lift_inner_cols(window_cte)) + label = col_expr.label(name) + + # put into CTE, and return its resulting column, so 
that subsequent + # operations will refer to the window column on window_cte. Note that + # the operations will use the actual column, so may need to use the + # ClauseAdaptor to make it a reference to the label + window_cte = _sql_add_columns(window_cte, [label]) + win_col = lift_inner_cols(window_cte).values()[-1] + + return win_col, over_clauses, window_cte + + return col_expr, over_clauses, window_cte + +def get_single_from(sel): + froms = sel.froms + + n_froms = len(froms) + if n_froms != 1: + raise ValueError( + f"Expected a single table in the from clause, but found {n_froms}" + ) + + return froms[0] + +def lift_inner_cols(tbl): + cols = list(tbl.inner_columns) + + return _sql_column_collection(cols) + +def col_expr_requires_cte(call, sel, is_mutate = False): + """Return whether a variable assignment needs a CTE""" + + call_vars = set(call.op_vars(attr_calls = False)) + + sel_labs = get_inner_labels(sel) + + # I use the acronym fwg sol (frog soul) to remember sql clause eval order + # from, where, group by, select, order by, limit + # group clause evaluated before select clause, so not issue for mutate + group_needs_cte = not is_mutate and len(sel._group_by_clause) + + return ( group_needs_cte + # TODO: detect when a new var in mutate conflicts w/ order by + #or len(sel._order_by_clause) + or not sel_labs.isdisjoint(call_vars) + ) + +def get_inner_labels(sel): + columns = lift_inner_cols(sel) + sel_labs = set(k for k,v in columns.items() if isinstance(v, sql.elements.Label)) + return sel_labs + +def get_missing_columns(call, columns): + missing_cols = set(call.op_vars(attr_calls = False)) - set(columns.keys()) + return missing_cols + +def compile_el(tbl, el): + compiled = el.compile( + dialect = tbl.source.dialect, + compile_kwargs = {"literal_binds": True} + ) + return compiled + +# Misc utilities -------------------------------------------------------------- + +def ordered_union(x, y): + dx = {el: True for el in x} + dy = {el: True for el in y} + + return 
tuple({**dx, **dy}) + + +def _warn_missing(missing_groups): + warnings.warn(f"Adding missing grouping variables: {missing_groups}") + + +# Table ----------------------------------------------------------------------- + +class LazyTbl: + def __init__( + self, source, tbl, columns = None, + ops = None, group_by = tuple(), order_by = tuple(), + translator = None + ): + """Create a representation of a SQL table. + + Args: + source: a sqlalchemy.Engine or sqlalchemy.Connection instance. + tbl: table of form 'schema_name.table_name', 'table_name', or sqlalchemy.Table. + columns: if specified, a listlike of column names. + + Examples + -------- + + :: + from sqlalchemy import create_engine + from siuba.data import mtcars + + # create database and table + engine = create_engine("sqlite:///:memory:") + mtcars.to_sql('mtcars', engine) + + tbl_mtcars = LazyTbl(engine, 'mtcars') + + """ + + # connection and dialect specific functions + self.source = sqlalchemy.create_engine(source) if isinstance(source, str) else source + + # get dialect name + dialect = self.source.dialect.name + self.translator = get_dialect_translator(dialect) + + self.tbl = self._create_table(tbl, columns, self.source) + + # important states the query can be in (e.g. 
grouped) + self.ops = [self.tbl] if ops is None else ops + + self.group_by = group_by + self.order_by = order_by + + + def append_op(self, op, **kwargs): + cpy = self.copy(**kwargs) + cpy.ops = cpy.ops + [op] + return cpy + + def copy(self, **kwargs): + return self.__class__(**{**self.__dict__, **kwargs}) + + def shape_call( + self, + call, window = True, str_accessors = False, + verb_name = None, arg_name = None, + ): + return self.translator.shape_call(call, window, str_accessors, verb_name, arg_name) + + def track_call_windows(self, call, columns = None, window_cte = None): + """Returns tuple of (new column expression, list of window exprs)""" + + columns = self.last_op.columns if columns is None else columns + return track_call_windows(call, columns, self.group_by, self.order_by, window_cte) + + def get_ordered_col_names(self): + """Return columns from current select, with grouping columns first.""" + ungrouped = [k for k in self.last_op.columns.keys() if k not in self.group_by] + return list(self.group_by) + ungrouped + + #def label_breaks_order_by(self, name): + # """Returns True if a new column label would break the order by vars.""" + + # # TODO: arrange currently allows literals, which breaks this. it seems + # # better to only allow calls in arrange. + # order_by_vars = {c.op_vars(attr_calls=False) for c in self.order_by} + + + + + @property + def last_op(self) -> "sql.Table | sql.Select": + last_op = self.ops[-1] + + if last_op is None: + raise TypeError() + + return last_op + + @property + def last_select(self): + last_op = self.last_op + if not isinstance(last_op, sql.selectable.SelectBase): + return last_op.select() + + return last_op + + @staticmethod + def _create_table(tbl, columns = None, source = None): + """Return a sqlalchemy.Table, autoloading column info if needed. + + Arguments: + tbl: a sqlalchemy.Table or string of form 'table_name' or 'schema_name.table_name'. + columns: a tuple of column names for the table. Overrides source argument. 
+ source: a sqlalchemy engine, used to autoload columns. + + """ + if isinstance(tbl, sql.selectable.FromClause): + return tbl + + if not isinstance(tbl, str): + raise ValueError("tbl must be a sqlalchemy Table or string, but was %s" %type(tbl)) + + if columns is None and source is None: + raise ValueError("One of columns or source must be specified") + + schema, table_name = tbl.split('.') if '.' in tbl else [None, tbl] + + columns = map(sqlalchemy.Column, columns) if columns is not None else tuple() + + # TODO: pybigquery uses schema to mean project_id, so we cannot use + # siuba's classic breakdown "{schema}.{table_name}". Basically + # pybigquery uses "{schema=project_id}.{dataset_dot_table_name}" in its internal + # logic. An important side effect is that bigquery errors for + # `dataset`.`table`, but not `dataset.table`. + if source and source.dialect.name == "bigquery": + table_name = tbl + schema = None + + return sqlalchemy.Table( + table_name, + sqlalchemy.MetaData(bind = source), + *columns, + schema = schema, + autoload_with = source if not columns else None + ) + + def _get_preview(self): + # need to make prev op a cte, so we don't override any previous limit + new_sel = self.last_select.limit(5) + tbl_small = self.append_op(new_sel) + return collect(tbl_small) + + def __repr__(self): + template = ( + "# Source: lazy query\n" + "# DB Conn: {}\n" + "# Preview:\n{}\n" + "# .. may have more rows" + ) + + return template.format(repr(self.source.engine), repr(self._get_preview())) + + def _repr_html_(self): + template = ( + "
" + "
"
+                "# Source: lazy query\n"
+                "# DB Conn: {}\n"
+                "# Preview:\n"
+                "
" + "{}" + "

# .. may have more rows

" + "
" + ) + + data = self._get_preview() + + # _repr_html_ can not exist or return None, to signify that repr should be used + if not hasattr(data, '_repr_html_'): + return None + + html_data = data._repr_html_() + if html_data is None: + return None + + return template.format(self.source.engine, html_data) + + +def _repr_grouped_df_html_(self): + return "

(grouped data frame)

" + self._selected_obj._repr_html_() + "
" + + + +# Main Funcs +# ============================================================================= + +# sql raw -------------- + +sql_raw = sql.literal_column + +# show query ----------- + +@show_query.register(LazyTbl) +def _show_query(tbl, simplify = False, return_table = True): + #query = tbl.last_op #if not simplify else + compile_query = lambda query: query.compile( + dialect = tbl.source.dialect, + compile_kwargs = {"literal_binds": True} + ) + + + if simplify: + # try to strip table names and labels where unnecessary + simple_sel = _sql_simplify_select(tbl.last_select) + + explained = compile_query(simple_sel) + else: + # use a much more verbose query + explained = compile_query(tbl.last_select) + + if return_table: + print(str(explained)) + return tbl + + return str(explained) + + + +# collect ---------- + +@collect.register(LazyTbl) +def _collect(__data, as_df = True): + # TODO: maybe remove as_df options, always return dataframe + + if isinstance(__data.source, MockConnection): + # a mock sqlalchemy is being used to show_query, and echo queries. + # it doesn't return a result object or have a context handler, so + # we need to bail out early + return + + # compile query ---- + + if _is_dialect_duckdb(__data.source): + # TODO: can be removed once next release of duckdb fixes: + # https://github.com/duckdb/duckdb/issues/2972 + query = __data.last_select + compiled = query.compile( + dialect = __data.source.dialect, + compile_kwargs = {"literal_binds": True} + ) + else: + compiled = __data.last_select + + # execute query ---- + + with __data.source.connect() as conn: + if as_df: + sql_db = _FixedSqlDatabase(conn) + + if _is_dialect_duckdb(__data.source): + # TODO: pandas read_sql is very slow with duckdb. + # see https://github.com/pandas-dev/pandas/issues/45678 + # going to handle here for now. address once LazyTbl gets + # subclassed per backend. 
+ duckdb_con = conn.connection.c + return duckdb_con.query(str(compiled)).to_df() + else: + # + return sql_db.read_sql(compiled) + + return conn.execute(compiled) + + +@select.register(LazyTbl) +def _select(__data, *args, **kwargs): + # see https://stackoverflow.com/questions/25914329/rearrange-columns-in-sqlalchemy-select-object + if kwargs: + raise NotImplementedError( + "Using kwargs in select not currently supported. " + "Use _.newname == _.oldname instead" + ) + last_sel = __data.last_select + columns = {c.key: c for c in last_sel.inner_columns} + + # same as for DataFrame + colnames = Series(list(columns)) + vl = VarList() + evaluated = (arg(vl) if callable(arg) else arg for arg in args) + od = var_select(colnames, *evaluated) + + missing_groups, group_keys = _select_group_renames(od, __data.group_by) + + if missing_groups: + _warn_missing(missing_groups) + + final_od = {**{k: None for k in missing_groups}, **od} + + col_list = [] + for k,v in final_od.items(): + col = columns[k] + col_list.append(col if v is None else col.label(v)) + + return __data.append_op( + last_sel.with_only_columns(col_list), + group_by = group_keys + ) + + + +@filter.register(LazyTbl) +def _filter(__data, *args): + # Note: currently always produces 2 additional select statements, + # 1 for window/aggs, and 1 for the where clause + + sel = __data.last_op.alias() # original select + win_sel = sel.select() + + conds = [] + windows = [] + with _set_data_context(__data, window=True): + for ii, arg in enumerate(args): + + if isinstance(arg, Call): + new_call = __data.shape_call(arg, verb_name = "Filter", arg_name = ii) + #var_cols = new_call.op_vars(attr_calls = False) + + # note that a new win_sel is returned, w/ window columns appended + col_expr, win_cols, win_sel = __data.track_call_windows( + new_call, + sel.columns, + window_cte = win_sel + ) + + if isinstance(col_expr, sql.base.ImmutableColumnCollection): + conds.extend(col_expr) + else: + conds.append(col_expr) + + 
windows.extend(win_cols) + + else: + conds.append(arg) + + bool_clause = sql.and_(*conds) + + # first cte, windows ---- + if len(windows): + + win_alias = win_sel.alias() + + # move non-window functions to refer to win_sel clause (not the innermost) --- + bool_clause = sql.util.ClauseAdapter(win_alias) \ + .traverse(bool_clause) + + orig_cols = [win_alias.columns[k] for k in sel.columns.keys()] + else: + orig_cols = [sel] + + # create second cte ---- + filt_sel = _sql_select(orig_cols).where(bool_clause) + return __data.append_op(filt_sel) + + +@mutate.register(LazyTbl) +def _mutate(__data, *args, **kwargs): + # TODO: verify it can follow a renaming select + + # track labeled columns in set + if not (len(args) or len(kwargs)): + return __data.append_op(__data.last_op) + + names, sel_out = _mutate_cols(__data, args, kwargs, "Mutate") + return __data.append_op(sel_out) + + +def _sql_upsert_columns(sel, new_columns: "list[base.Label | base.Column]"): + orig_cols = lift_inner_cols(sel) + replaced = {**orig_cols} + + for new_col in new_columns: + replaced[new_col.name] = new_col + return _sql_with_only_columns(sel, list(replaced.values())) + + +def _select_mutate_result(src_sel, expr_result): + dst_alias = src_sel.alias() + src_columns = set(lift_inner_cols(src_sel)) + replacer = SqlLabelReplacer(set(src_columns), dst_alias.columns) + + if isinstance(expr_result, sql.base.ImmutableColumnCollection): + replaced_cols = list(map(replacer, expr_result)) + orig_cols = expr_result + #elif isinstance(expr_result, None): + # pass + else: + replaced_cols = [replacer(expr_result)] + orig_cols = [expr_result] + + if replacer.applied: + return _sql_upsert_columns(dst_alias.select(), replaced_cols) + + return _sql_upsert_columns(src_sel, orig_cols) + + +def _eval_expr_arg(__data, sel, func, verb_name, window=True): + inner_cols = lift_inner_cols(sel) + + # case 1: simple names ---- + simple_name = simple_varname(func) + if simple_name is not None: + return inner_cols[simple_name] + 
+ # case 2: across ---- + _require_across(func, verb_name) + + cols_result = _eval_with_context(__data, window, inner_cols, func) + + # TODO: remove or raise a more informative error + assert isinstance(cols_result, sql.base.ImmutableColumnCollection), type(cols_result) + + return cols_result + + +def _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=True): + inner_cols = lift_inner_cols(sel) + + expr_shaped = __data.shape_call(func, window, verb_name = verb_name, arg_name = new_name) + new_col, windows, _ = __data.track_call_windows(expr_shaped, inner_cols) + + if isinstance(new_col, sql.base.ImmutableColumnCollection): + raise TyepError( + f"{verb_name} named arguments must return a single column, but `{k}` " + "returned multiple columns." + ) + + return new_col.label(new_name) + + +def _mutate_cols(__data, args, kwargs, verb_name): + result_names = {} # used as ordered set + sel = __data.last_select + + for ii, func in enumerate(args): + cols_result = _eval_expr_arg(__data, sel, func, verb_name) + + # replace any labels that require a subquery ---- + sel = _select_mutate_result(sel, cols_result) + + if isinstance(cols_result, sql.base.ImmutableColumnCollection): + result_names.update({k: True for k in cols_result.keys()}) + else: + result_names[cols_result.name] = True + + + for new_name, func in kwargs.items(): + labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name) + + sel = _select_mutate_result(sel, labeled) + result_names[new_name] = True + + + return list(result_names), sel + + +@transmute.register(LazyTbl) +def _transmute(__data, *args, **kwargs): + # will use mutate, then select some cols + result_names, sel = _mutate_cols(__data, args, kwargs, "Transmute") + + # transmute keeps grouping cols, and any defined in kwargs + missing = [x for x in __data.group_by if x not in result_names] + cols_to_keep = [*missing, *result_names] + + columns = lift_inner_cols(sel) + sel_stripped = sel.with_only_columns([columns[k] for k in 
cols_to_keep]) + + return __data.append_op(sel_stripped) + + +@arrange.register(LazyTbl) +def _arrange(__data, *args): + # Note that SQL databases often do not subquery order by clauses. Arrange + # sets order_by on the backend, so it can set order by in over elements, + # and handle when new columns are named the same as order by vars. + # see: https://dba.stackexchange.com/q/82930 + + last_sel = __data.last_select + cols = lift_inner_cols(last_sel) + + # TODO: implement across in arrange + #exprs, _ = _mutate_cols(__data, args, kwargs, "Arrange", arrange_clause=True) + + new_calls = [] + for ii, expr in enumerate(args): + if callable(expr): + + res = __data.shape_call( + expr, window = False, + verb_name = "Arrange", arg_name = ii + ) + + else: + res = expr + + new_calls.append(res) + + sort_cols = _create_order_by_clause(cols, *new_calls) + + order_by = __data.order_by + tuple(new_calls) + return __data.append_op(last_sel.order_by(*sort_cols), order_by = order_by) + + +# TODO: consolidate / pull expr handling funcs into own file? +def _create_order_by_clause(columns, *args): + sort_cols = [] + for arg in args: + # simple named column + if isinstance(arg, str): + sort_cols.append(columns[arg]) + # an expression + elif callable(arg): + # handle special case where -_.colname -> colname DESC + f, asc = _call_strip_ascending(arg) + col_op = f(columns) if asc else f(columns).desc() + #col_op = arg(columns) + sort_cols.append(col_op) + else: + raise NotImplementedError("Must be string or callable") + + return sort_cols + + + +@count.register(LazyTbl) +def _count(__data, *args, sort = False, wt = None, **kwargs): + # TODO: if already col named n, use name nn, etc.. 
get logic from tidy.py + if wt is not None: + raise NotImplementedError("TODO") + + res_name = "n" + # similar to filter verb, we need two select statements, + # an inner one for derived cols, and outer to group by them + + # inner select ---- + # holds any mutation style columns + #arg_names = [] + #for arg in args: + # name = simple_varname(arg) + # if name is None: + # raise NotImplementedError( + # "Count positional arguments must be single column name. " + # "Use a named argument to count using complex expressions." + # ) + # arg_names.append(name) + + result_names, sel_inner = _mutate_cols(__data, args, kwargs, "Count") + + # remove unnecessary select, if we're operating on a table ---- + if set(lift_inner_cols(sel_inner)) == set(lift_inner_cols(__data.last_select)): + sel_inner = __data.last_op + + # create outer select ---- + # holds selected columns and tally (n) + sel_inner_cte = sel_inner.alias() + inner_cols = sel_inner_cte.columns + + # apply any group vars from a group_by verb call first + missing = [k for k in __data.group_by if k not in result_names] + + all_group_names = ordered_union(__data.group_by, result_names) + outer_group_cols = [inner_cols[k] for k in all_group_names] + + # holds the actual count (e.g. 
n) + count_col = sql.functions.count().label(res_name) + + sel_outer = _sql_select([*outer_group_cols, count_col]) \ + .select_from(sel_inner_cte) \ + .group_by(*outer_group_cols) + + # count is like summarize, so removes order_by + return __data.append_op( + sel_outer.order_by(count_col.desc()), + order_by = tuple() + ) + + +@add_count.register(LazyTbl) +def _add_count(__data, *args, wt = None, sort = False, **kwargs): + counts = count(__data, *args, wt = wt, sort = sort, **kwargs) + by = list(c.name for c in counts.last_select.inner_columns)[:-1] + + return inner_join(__data, counts, by = by) + + +@summarize.register(LazyTbl) +def _summarize(__data, *args, **kwargs): + # https://stackoverflow.com/questions/14754994/why-is-sqlalchemy-count-much-slower-than-the-raw-query + + # get query with correct from clause, and maybe unneeded subquery + safe_from = __data.last_select.alias() + result_names, sel = _aggregate_cols(__data, safe_from, args, kwargs, "Summarize") + + # see if we can remove subquery + out_sel = _collapse_select(sel, safe_from) + + from_tbl = get_single_from(out_sel) + group_cols = [from_tbl.columns[k] for k in __data.group_by] + + final_sel = out_sel.group_by(*group_cols) + + new_data = __data.append_op(final_sel, group_by = tuple(), order_by = tuple()) + return new_data + + +def _collapse_select(outer_sel, inner_alias): + # check whether any outer columns reference an inner label ---- + inner_sel = inner_alias.element + + columns = lift_inner_cols(outer_sel) + inner_cols = lift_inner_cols(inner_sel) + + inner_labels = set([ + x.name for x in inner_cols + if isinstance(x, sql.elements.Label) + ]) + + col_requires_cte = set(inner_alias.columns[k] for k in inner_labels) + + bad_refs = [] + + def collect_refs(el): + if el in col_requires_cte: + bad_refs.append(el) + + for col in columns: + sql.util.visitors.traverse(col, {}, {"column": collect_refs}) + + # if possible, remove the outer query ---- + if not (bad_refs or len(inner_sel._group_by_clause)): + 
from sqlalchemy.sql.elements import ColumnClause, Label + + from_obj = get_single_from(inner_sel) + adaptor = sql.util.ClauseAdapter( + from_obj, + adapt_on_names=True, + include_fn=lambda c: isinstance(c, (ColumnClause, Label)) + ) + + new_cols = [] + for col in columns: + if isinstance(col, Label): + res = adaptor.traverse(col.element).label(col.name) + new_cols.append(res) + + else: + new_cols.append(adaptor.traverse(col)) + #new_cols = list(map(adaptor.traverse, columns)) + + return _sql_with_only_columns(inner_sel, new_cols) + + return outer_sel + + +def _aggregate_cols(__data, subquery, args, kwargs, verb_name): + # cases: + # * grouping cols can not be overwritten (in dbplyr they can't be ref'd) + # * no existing labels referred to - can use same select + # * existing labels referred to - need 1 subquery tops + # * groups + summarize columns can replace everything + + def get_label_clauses(clause): + out = [] + sql.util.visitors.traverse(clause, {}, {"label": lambda c: out.append(c)}) + + return out + + def quote_varname(x): + return f"`{x}`" + + def validate_references(arg_name, expr, verb_name): + bad_varnames = get_label_clauses(expr) + repr_names = ", ".join(map(quote_varname, [el.name for el in bad_varnames])) + + if not bad_varnames: + return + + raise NotImplementedError( + f"In SQL, you cannot refer to a column created in the same {verb_name}. " + f"`{arg_name}` refers to columns created earlier: {repr_names}." 
+ ) + + sel = subquery.select() + + final_cols = {k: subquery.columns[k] for k in __data.group_by} + + # handle args ---- + for ii, func in enumerate(args): + cols_result = _eval_expr_arg(__data, sel, func, verb_name, window=False) + + for col in cols_result: + validate_references(col.name, col.element, verb_name) + final_cols[col.name] = col + + sel = _sql_upsert_columns(sel, cols_result) + + + # handle kwargs ---- + for new_name, func in kwargs.items(): + labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=False) + + validate_references(labeled.name, labeled.element, verb_name) + final_cols[new_name] = labeled + + sel = _sql_upsert_columns(sel, [labeled]) + + return list(final_cols), _sql_with_only_columns(sel, list(final_cols.values())) + + +@group_by.register(LazyTbl) +def _group_by(__data, *args, add = False, **kwargs): + if not (args or kwargs): + return __data.copy() + + group_names, sel = _mutate_cols(__data, args, kwargs, "Group by") + + if None in group_names: + raise NotImplementedError("Complex, unnamed expressions not supported in sql group_by") + + # check whether we can just use underlying table ---- + new_cols = lift_inner_cols(sel) + if set(new_cols).issubset(set(__data.last_op.columns)): + sel = __data.last_op + + if add: + group_names = ordered_union(__data.group_by, group_names) + + return __data.append_op(sel, group_by = tuple(group_names)) + + +@ungroup.register(LazyTbl) +def _ungroup(__data): + return __data.copy(group_by = tuple()) + + +@case_when.register(sql.base.ImmutableColumnCollection) +def _case_when(__data, cases): + # TODO: will need listener to enter case statements, to handle when they use windows + if isinstance(cases, Call): + cases = cases(__data) + + whens = [] + case_items = list(cases.items()) + n_items = len(case_items) + + else_val = None + for ii, (expr, val) in enumerate(case_items): + # handle where val is a column expr + if callable(val): + val = val(__data) + + # handle when expressions + if ii+1 
== n_items and expr is True: + else_val = val + elif callable(expr): + whens.append((expr(__data), val)) + else: + whens.append((expr, val)) + + return sql.case(whens, else_ = else_val) + + +# Join ------------------------------------------------------------------------ + +from collections.abc import Mapping + +def _joined_cols(left_cols, right_cols, on_keys, how, suffix = ("_x", "_y")): + """Return labeled columns, according to selection rules for joins. + + Rules: + 1. For join keys, keep left table's column + 2. When keys have the same labels, add suffix + """ + + # TODO: remove sets, so uses stable ordering + # when left and right cols have same name, suffix with _x / _y + keep_right = set(right_cols.keys()) - set(on_keys.values()) + shared_labs = set(left_cols.keys()).intersection(keep_right) + + right_cols_no_keys = {k: right_cols[k] for k in keep_right} + + # for an outer join, have key columns coalesce values + + left_cols = {**left_cols} + if how == "full": + for lk, rk in on_keys.items(): + col = sql.functions.coalesce(left_cols[lk], right_cols[rk]) + left_cols[lk] = col.label(lk) + elif how == "right": + for lk, rk in on_keys.items(): + # Make left key columns actually be right ones (which contain left + extra) + left_cols[lk] = right_cols[rk].label(lk) + + + # create labels ---- + l_labs = _relabeled_cols(left_cols, shared_labs, suffix[0]) + r_labs = _relabeled_cols(right_cols_no_keys, shared_labs, suffix[1]) + + return l_labs + r_labs + + + +def _relabeled_cols(columns, keys, suffix): + # add a suffix to all columns with names in keys + cols = [] + for k, v in columns.items(): + new_col = v.label(k + str(suffix)) if k in keys else v + cols.append(new_col) + return cols + + +@join.register(LazyTbl) +def _join(left, right, on = None, *args, by = None, how = "inner", sql_on = None): + _raise_if_args(args) + + if on is None and by is not None: + on = by + + # Needs to be on the table, not the select + left_sel = left.last_op.alias() + right_sel = 
right.last_op.alias() + + # handle arguments ---- + on = _validate_join_arg_on(on, sql_on) + how = _validate_join_arg_how(how) + + # for equality join used to combine keys into single column + consolidate_keys = on if sql_on is None else {} + + if how == "right": + # switch joins, since sqlalchemy doesn't have right join arg + # see https://stackoverflow.com/q/11400307/1144523 + left_sel, right_sel = right_sel, left_sel + on = {v:k for k,v in on.items()} + + # create join conditions ---- + bool_clause = _create_join_conds(left_sel, right_sel, on) + + # create join ---- + join = left_sel.join( + right_sel, + onclause = bool_clause, + isouter = how != "inner", + full = how == "full" + ) + + # if right join, set selects back + if how == "right": + left_sel, right_sel = right_sel, left_sel + on = {v:k for k,v in on.items()} + + # note, shared_keys assumes on is a mapping... + # TODO: shared_keys appears to be for when on is not specified, but was unused + #shared_keys = [k for k,v in on.items() if k == v] + labeled_cols = _joined_cols( + left_sel.columns, + right_sel.columns, + on_keys = consolidate_keys, + how = how + ) + + sel = _sql_select(labeled_cols).select_from(join) + return left.append_op(sel, order_by = tuple()) + + +@semi_join.register(LazyTbl) +def _semi_join(left, right = None, on = None, *args, by = None, sql_on = None): + if on is None and by is not None: + on = by + + _raise_if_args(args) + + left_sel = left.last_op.alias() + right_sel = right.last_op.alias() + + # handle arguments ---- + on = _validate_join_arg_on(on, sql_on, left_sel, right_sel) + + # create join conditions ---- + bool_clause = _create_join_conds(left_sel, right_sel, on) + + # create inner join ---- + exists_clause = _sql_select([sql.literal(1)]) \ + .select_from(right_sel) \ + .where(bool_clause) + + # only keep left hand select's columns ---- + sel = _sql_select(left_sel.columns) \ + .select_from(left_sel) \ + .where(sql.exists(exists_clause)) + + return left.append_op(sel, order_by 
= tuple()) + + +@anti_join.register(LazyTbl) +def _anti_join(left, right = None, on = None, *args, by = None, sql_on = None): + if on is None and by is not None: + on = by + + _raise_if_args(args) + + left_sel = left.last_op.alias() + right_sel = right.last_op.alias() + + # handle arguments ---- + on = _validate_join_arg_on(on, sql_on, left, right) + + # create join conditions ---- + bool_clause = _create_join_conds(left_sel, right_sel, on) + + # create inner join ---- + #not_exists = ~sql.exists([1], from_obj = right_sel).where(bool_clause) + exists_clause = _sql_select([sql.literal(1)]) \ + .select_from(right_sel) \ + .where(bool_clause) + + sel = left_sel.select().where(~sql.exists(exists_clause)) + + return left.append_op(sel, order_by = tuple()) + +def _raise_if_args(args): + if len(args): + raise NotImplemented("*args is reserved for future arguments (e.g. suffix)") + +def _validate_join_arg_on(on, sql_on = None, lhs = None, rhs = None): + # handle sql on case + if sql_on is not None: + if on is not None: + raise ValueError("Cannot specify both on and sql_on") + + return sql_on + + # handle general cases + if on is None: + # TODO: currently, we check for lhs and rhs tables to indicate whether + # a verb supports inferring columns. Otherwise, raise an error. + if lhs is not None and rhs is not None: + # TODO: consolidate with duplicate logic in pandas verb code + warnings.warn( + "No on column passed to join. " + "Inferring join columns instead using shared column names." + ) + + on_cols = list(set(lhs.columns.keys()).intersection(set(rhs.columns.keys()))) + + if not on_cols: + raise ValueError( + "No join column specified, or shared column names in join." 
+ ) + + # trivial dict mapping shared names to themselves + warnings.warn("Detected shared columns: %s" % on_cols) + on = dict(zip(on_cols, on_cols)) + + else: + raise NotImplementedError("on arg currently cannot be None (default) for SQL") + elif isinstance(on, str): + on = {on: on} + elif isinstance(on, (list, tuple)): + on = dict(zip(on, on)) + + + if not isinstance(on, Mapping): + raise TypeError("on must be a Mapping (e.g. dict)") + + return on + +def _validate_join_arg_how(how): + how_options = ("inner", "left", "right", "full") + if how not in how_options: + raise ValueError("how argument needs to be one of %s" %how_options) + + return how + +def _create_join_conds(left_sel, right_sel, on): + left_cols = left_sel.columns #lift_inner_cols(left_sel) + right_cols = right_sel.columns #lift_inner_cols(right_sel) + + if callable(on): + # callable, like with sql_on arg + conds = [on(left_cols, right_cols)] + else: + # dict-like of form {left: right} + conds = [] + for l, r in on.items(): + col_expr = left_cols[l] == right_cols[r] + conds.append(col_expr) + + return sql.and_(*conds) + + +# Head ------------------------------------------------------------------------ + +@head.register(LazyTbl) +def _head(__data, n = 5): + sel = __data.last_select + + return __data.append_op(sel.limit(n)) + + +# Rename ---------------------------------------------------------------------- + +@rename.register(LazyTbl) +def _rename(__data, **kwargs): + sel = __data.last_select + columns = lift_inner_cols(sel) + + # old_keys uses dict as ordered set + old_to_new = {simple_varname(v):k for k,v in kwargs.items()} + + if None in old_to_new: + raise KeyError("positional arguments must be simple column, " + "e.g. 
_.colname or _['colname']" + ) + + labs = [c.label(old_to_new[k]) if k in old_to_new else c for k,c in columns.items()] + + new_sel = sel.with_only_columns(labs) + + missing_groups, group_keys = _select_group_renames(old_to_new, __data.group_by) + + return __data.append_op(new_sel, group_by=group_keys) + + +# Distinct -------------------------------------------------------------------- + +@distinct.register(LazyTbl) +def _distinct(__data, *args, _keep_all = False, **kwargs): + if (args or kwargs) and _keep_all: + raise NotImplementedError("Distinct with variables specified in sql requires _keep_all = False") + + inner_sel = mutate(__data, **kwargs).last_select if kwargs else __data.last_select + + # TODO: this is copied from the df distinct version + # cols dict below is used as ordered set + cols = _var_select_simple(args) + cols.update(kwargs) + + # use all columns by default + if not cols: + cols = {k: True for k in lift_inner_cols(inner_sel).keys()} + + final_names = {**{k: True for k in __data.group_by}, **cols} + + if not len(inner_sel._order_by_clause): + # select distinct has to include any columns in the order by clause, + # so can only safely modify existing statement when there's no order by + sel_cols = lift_inner_cols(inner_sel) + distinct_cols = [sel_cols[k] for k in final_names] + sel = inner_sel.with_only_columns(distinct_cols).distinct() + else: + # fallback to cte + cte = inner_sel.alias() + distinct_cols = [cte.columns[k] for k in final_names] + sel = _sql_select(distinct_cols).select_from(cte).distinct() + + return __data.append_op(sel) + + +# if_else --------------------------------------------------------------------- + +@if_else.register(sql.elements.ColumnElement) +def _if_else(cond, true_vals, false_vals): + whens = [(cond, true_vals)] + return sql.case(whens, else_ = false_vals) + + diff --git a/siuba/sql/verbs/explain.py b/siuba/sql/verbs/explain.py new file mode 100644 index 00000000..8dcbef7b --- /dev/null +++ 
b/siuba/sql/verbs/explain.py @@ -0,0 +1,1395 @@ +""" +Implements LazyTbl to represent tables of SQL data, and registers it on verbs. + +This module is responsible for the handling of the "table" side of things, while +translate.py handles translating column operations. + + +""" + +import warnings + +from siuba.dply.verbs import ( + show_query, collect, + simple_varname, + select, + mutate, + transmute, + filter, + arrange, _call_strip_ascending, + summarize, + count, add_count, + group_by, ungroup, + case_when, + join, left_join, right_join, inner_join, semi_join, anti_join, + head, + rename, + distinct, + if_else, + _select_group_renames, + _var_select_simple + ) + +from siuba.dply.tidyselect import VarList, var_select + +from .translate import CustomOverClause, SqlColumn, SqlColumnAgg +from .utils import ( + get_dialect_translator, + _FixedSqlDatabase, + _is_dialect_duckdb, + _sql_select, + _sql_column_collection, + _sql_add_columns, + _sql_with_only_columns, + _sql_simplify_select, + MockConnection +) + +from sqlalchemy import sql +import sqlalchemy +from siuba.siu import Call, Lazy, FunctionLookupError, singledispatch2 +# TODO: currently needed for select, but can we remove pandas? +from pandas import Series +from functools import singledispatch + +from sqlalchemy.sql import schema + +from siuba.dply.across import _require_across, _set_data_context, _eval_with_context + +# TODO: +# - distinct +# - annotate functions using sel.prefix_with("\n/**/\n") ? + + +# Helpers --------------------------------------------------------------------- + +class SqlFunctionLookupError(FunctionLookupError): pass + + +class CallListener: + """Generic listener. Each exit is called on a node's copy.""" + def enter(self, node): + args, kwargs = node.map_subcalls(self.enter) + + return self.exit(node.__class__(node.func, *args, **kwargs)) + + def exit(self, node): + return node + + +class WindowReplacer(CallListener): + """Call tree listener. 
+ + Produces 2 important behaviors via the enter method: + - returns evaluated sql call expression, with labels on all window expressions. + - stores all labeled window expressions via the windows property. + + TODO: could replace with a sqlalchemy transformer + """ + + def __init__(self, columns, group_by, order_by, window_cte = None): + self.columns = columns + self.group_by = group_by + self.order_by = order_by + self.window_cte = window_cte + self.windows = [] + + def exit(self, node): + col_expr = node(self.columns) + + if not isinstance(col_expr, sql.elements.ClauseElement): + return col_expr + + over_clauses = [x for x in self._get_over_clauses(col_expr) if isinstance(x, CustomOverClause)] + + # put groupings and orderings onto custom over clauses + for over in over_clauses: + # TODO: shouldn't mutate these over clauses + group_by = sql.elements.ClauseList( + *[self.columns[name] for name in self.group_by] + ) + order_by = sql.elements.ClauseList( + *_create_order_by_clause(self.columns, *self.order_by) + ) + + over.set_over(group_by, order_by) + + if len(over_clauses) and self.window_cte is not None: + # custom name, or parameters like "%(...)s" may nest and break psycopg2 + # with columns you can set a key to fix this, but it doesn't seem to + # be an option with labels + name = self._get_unique_name('win', lift_inner_cols(self.window_cte)) + label = col_expr.label(name) + + # put into CTE, and return its resulting column, so that subsequent + # operations will refer to the window column on window_cte. 
Note that + # the operations will use the actual column, so may need to use the + # ClauseAdaptor to make it a reference to the label + self.window_cte = _sql_add_columns(self.window_cte, [label]) + win_col = lift_inner_cols(self.window_cte).values()[-1] + self.windows.append(win_col) + + return win_col + + return col_expr + + @staticmethod + def _get_unique_name(prefix, columns): + column_names = set(columns.keys()) + + i = 1 + name = prefix + str(i) + while name in column_names: + i += 1 + name = prefix + str(i) + + + return name + + @staticmethod + def _get_over_clauses(clause): + windows = [] + append_win = lambda col: windows.append(col) + + sql.util.visitors.traverse(clause, {}, {"over": append_win}) + + return windows + + +class SqlLabelReplacer: + """Create a visitor to replace source labels with destination. + + Note that this is meant to be used with sqlalchemy visitors. + """ + + def __init__(self, src_columns, dst_columns): + self.src_columns = src_columns + self.src_labels = set([x for x in src_columns if isinstance(x, sql.elements.Label)]) + self.dst_columns = dst_columns + self.applied = False + + def __call__(self, clause): + return sql.util.visitors.replacement_traverse(clause, {}, self.visit) + + def visit(self, el): + from sqlalchemy.sql.elements import ColumnClause, Label, ClauseElement, TypeClause + from sqlalchemy.sql.schema import Column + + if isinstance(el, TypeClause): + # TODO: for some reason this type throws an error if unguarded + return None + + if isinstance(el, ClauseElement): + if el in self.src_labels: + self.applied = True + return self.dst_columns[el.name] + elif el in self.src_columns: + return self.dst_columns[el.name] + + # TODO: should we create a subquery if the user passed raw text? 
+ #elif isinstance(el, ColumnClause) and not isinstance(el, Column): + # # Raw SQL, which will need a subquery, but not substitution + # if el.key != "*": + # self.applied = True + + return None + + +#def track_call_windows(call, columns, group_by, order_by, window_cte = None): +# listener = WindowReplacer(columns, group_by, order_by, window_cte) +# col = listener.enter(call) +# return col, listener.windows, listener.window_cte + + +def track_call_windows(call, columns, group_by, order_by, window_cte = None): + col_expr = call(columns) + + crnt_group_by = sql.elements.ClauseList( + *[columns[name] for name in group_by] + ) + crnt_order_by = sql.elements.ClauseList( + *_create_order_by_clause(columns, *order_by) + ) + return replace_call_windows(col_expr, crnt_group_by, crnt_order_by, window_cte) + + + +@singledispatch +def replace_call_windows(col_expr, group_by, order_by, window_cte = None): + raise TypeError(str(type(col_expr))) + + +@replace_call_windows.register(sql.base.ImmutableColumnCollection) +def _(col_expr, group_by, order_by, window_cte = None): + all_over_clauses = [] + for col in col_expr: + _, over_clauses, window_cte = replace_call_windows( + col, + group_by, + order_by, + window_cte + ) + all_over_clauses.extend(over_clauses) + + return col_expr, all_over_clauses, window_cte + + +@replace_call_windows.register(sql.elements.ClauseElement) +def _(col_expr, group_by, order_by, window_cte = None): + + over_clauses = WindowReplacer._get_over_clauses(col_expr) + + for over in over_clauses: + # TODO: shouldn't mutate these over clauses + over.set_over(group_by, order_by) + + if len(over_clauses) and window_cte is not None: + # custom name, or parameters like "%(...)s" may nest and break psycopg2 + # with columns you can set a key to fix this, but it doesn't seem to + # be an option with labels + name = WindowReplacer._get_unique_name('win', lift_inner_cols(window_cte)) + label = col_expr.label(name) + + # put into CTE, and return its resulting column, so 
that subsequent + # operations will refer to the window column on window_cte. Note that + # the operations will use the actual column, so may need to use the + # ClauseAdaptor to make it a reference to the label + window_cte = _sql_add_columns(window_cte, [label]) + win_col = lift_inner_cols(window_cte).values()[-1] + + return win_col, over_clauses, window_cte + + return col_expr, over_clauses, window_cte + +def get_single_from(sel): + froms = sel.froms + + n_froms = len(froms) + if n_froms != 1: + raise ValueError( + f"Expected a single table in the from clause, but found {n_froms}" + ) + + return froms[0] + +def lift_inner_cols(tbl): + cols = list(tbl.inner_columns) + + return _sql_column_collection(cols) + +def col_expr_requires_cte(call, sel, is_mutate = False): + """Return whether a variable assignment needs a CTE""" + + call_vars = set(call.op_vars(attr_calls = False)) + + sel_labs = get_inner_labels(sel) + + # I use the acronym fwg sol (frog soul) to remember sql clause eval order + # from, where, group by, select, order by, limit + # group clause evaluated before select clause, so not issue for mutate + group_needs_cte = not is_mutate and len(sel._group_by_clause) + + return ( group_needs_cte + # TODO: detect when a new var in mutate conflicts w/ order by + #or len(sel._order_by_clause) + or not sel_labs.isdisjoint(call_vars) + ) + +def get_inner_labels(sel): + columns = lift_inner_cols(sel) + sel_labs = set(k for k,v in columns.items() if isinstance(v, sql.elements.Label)) + return sel_labs + +def get_missing_columns(call, columns): + missing_cols = set(call.op_vars(attr_calls = False)) - set(columns.keys()) + return missing_cols + +def compile_el(tbl, el): + compiled = el.compile( + dialect = tbl.source.dialect, + compile_kwargs = {"literal_binds": True} + ) + return compiled + +# Misc utilities -------------------------------------------------------------- + +def ordered_union(x, y): + dx = {el: True for el in x} + dy = {el: True for el in y} + + return 
tuple({**dx, **dy}) + + +def _warn_missing(missing_groups): + warnings.warn(f"Adding missing grouping variables: {missing_groups}") + + +# Table ----------------------------------------------------------------------- + +class LazyTbl: + def __init__( + self, source, tbl, columns = None, + ops = None, group_by = tuple(), order_by = tuple(), + translator = None + ): + """Create a representation of a SQL table. + + Args: + source: a sqlalchemy.Engine or sqlalchemy.Connection instance. + tbl: table of form 'schema_name.table_name', 'table_name', or sqlalchemy.Table. + columns: if specified, a listlike of column names. + + Examples + -------- + + :: + from sqlalchemy import create_engine + from siuba.data import mtcars + + # create database and table + engine = create_engine("sqlite:///:memory:") + mtcars.to_sql('mtcars', engine) + + tbl_mtcars = LazyTbl(engine, 'mtcars') + + """ + + # connection and dialect specific functions + self.source = sqlalchemy.create_engine(source) if isinstance(source, str) else source + + # get dialect name + dialect = self.source.dialect.name + self.translator = get_dialect_translator(dialect) + + self.tbl = self._create_table(tbl, columns, self.source) + + # important states the query can be in (e.g. 
grouped) + self.ops = [self.tbl] if ops is None else ops + + self.group_by = group_by + self.order_by = order_by + + + def append_op(self, op, **kwargs): + cpy = self.copy(**kwargs) + cpy.ops = cpy.ops + [op] + return cpy + + def copy(self, **kwargs): + return self.__class__(**{**self.__dict__, **kwargs}) + + def shape_call( + self, + call, window = True, str_accessors = False, + verb_name = None, arg_name = None, + ): + return self.translator.shape_call(call, window, str_accessors, verb_name, arg_name) + + def track_call_windows(self, call, columns = None, window_cte = None): + """Returns tuple of (new column expression, list of window exprs)""" + + columns = self.last_op.columns if columns is None else columns + return track_call_windows(call, columns, self.group_by, self.order_by, window_cte) + + def get_ordered_col_names(self): + """Return columns from current select, with grouping columns first.""" + ungrouped = [k for k in self.last_op.columns.keys() if k not in self.group_by] + return list(self.group_by) + ungrouped + + #def label_breaks_order_by(self, name): + # """Returns True if a new column label would break the order by vars.""" + + # # TODO: arrange currently allows literals, which breaks this. it seems + # # better to only allow calls in arrange. + # order_by_vars = {c.op_vars(attr_calls=False) for c in self.order_by} + + + + + @property + def last_op(self) -> "sql.Table | sql.Select": + last_op = self.ops[-1] + + if last_op is None: + raise TypeError() + + return last_op + + @property + def last_select(self): + last_op = self.last_op + if not isinstance(last_op, sql.selectable.SelectBase): + return last_op.select() + + return last_op + + @staticmethod + def _create_table(tbl, columns = None, source = None): + """Return a sqlalchemy.Table, autoloading column info if needed. + + Arguments: + tbl: a sqlalchemy.Table or string of form 'table_name' or 'schema_name.table_name'. + columns: a tuple of column names for the table. Overrides source argument. 
+ source: a sqlalchemy engine, used to autoload columns. + + """ + if isinstance(tbl, sql.selectable.FromClause): + return tbl + + if not isinstance(tbl, str): + raise ValueError("tbl must be a sqlalchemy Table or string, but was %s" %type(tbl)) + + if columns is None and source is None: + raise ValueError("One of columns or source must be specified") + + schema, table_name = tbl.split('.') if '.' in tbl else [None, tbl] + + columns = map(sqlalchemy.Column, columns) if columns is not None else tuple() + + # TODO: pybigquery uses schema to mean project_id, so we cannot use + # siuba's classic breakdown "{schema}.{table_name}". Basically + # pybigquery uses "{schema=project_id}.{dataset_dot_table_name}" in its internal + # logic. An important side effect is that bigquery errors for + # `dataset`.`table`, but not `dataset.table`. + if source and source.dialect.name == "bigquery": + table_name = tbl + schema = None + + return sqlalchemy.Table( + table_name, + sqlalchemy.MetaData(bind = source), + *columns, + schema = schema, + autoload_with = source if not columns else None + ) + + def _get_preview(self): + # need to make prev op a cte, so we don't override any previous limit + new_sel = self.last_select.limit(5) + tbl_small = self.append_op(new_sel) + return collect(tbl_small) + + def __repr__(self): + template = ( + "# Source: lazy query\n" + "# DB Conn: {}\n" + "# Preview:\n{}\n" + "# .. may have more rows" + ) + + return template.format(repr(self.source.engine), repr(self._get_preview())) + + def _repr_html_(self): + template = ( + "
<div>"
+                "<pre>"
+                "# Source: lazy query\n"
+                "# DB Conn: {}\n"
+                "# Preview:\n"
+                "</pre>"
+                "{}"
+                "<p># .. may have more rows</p>"
+                "</div>
" + ) + + data = self._get_preview() + + # _repr_html_ can not exist or return None, to signify that repr should be used + if not hasattr(data, '_repr_html_'): + return None + + html_data = data._repr_html_() + if html_data is None: + return None + + return template.format(self.source.engine, html_data) + + +def _repr_grouped_df_html_(self): + return "

<div><p>(grouped data frame)</p>" + self._selected_obj._repr_html_() + "</div>
" + + + +# Main Funcs +# ============================================================================= + +# sql raw -------------- + +sql_raw = sql.literal_column + +# show query ----------- + +@show_query.register(LazyTbl) +def _show_query(tbl, simplify = False, return_table = True): + #query = tbl.last_op #if not simplify else + compile_query = lambda query: query.compile( + dialect = tbl.source.dialect, + compile_kwargs = {"literal_binds": True} + ) + + + if simplify: + # try to strip table names and labels where unnecessary + simple_sel = _sql_simplify_select(tbl.last_select) + + explained = compile_query(simple_sel) + else: + # use a much more verbose query + explained = compile_query(tbl.last_select) + + if return_table: + print(str(explained)) + return tbl + + return str(explained) + + + +# collect ---------- + +@collect.register(LazyTbl) +def _collect(__data, as_df = True): + # TODO: maybe remove as_df options, always return dataframe + + if isinstance(__data.source, MockConnection): + # a mock sqlalchemy is being used to show_query, and echo queries. + # it doesn't return a result object or have a context handler, so + # we need to bail out early + return + + # compile query ---- + + if _is_dialect_duckdb(__data.source): + # TODO: can be removed once next release of duckdb fixes: + # https://github.com/duckdb/duckdb/issues/2972 + query = __data.last_select + compiled = query.compile( + dialect = __data.source.dialect, + compile_kwargs = {"literal_binds": True} + ) + else: + compiled = __data.last_select + + # execute query ---- + + with __data.source.connect() as conn: + if as_df: + sql_db = _FixedSqlDatabase(conn) + + if _is_dialect_duckdb(__data.source): + # TODO: pandas read_sql is very slow with duckdb. + # see https://github.com/pandas-dev/pandas/issues/45678 + # going to handle here for now. address once LazyTbl gets + # subclassed per backend. 
+ duckdb_con = conn.connection.c + return duckdb_con.query(str(compiled)).to_df() + else: + # + return sql_db.read_sql(compiled) + + return conn.execute(compiled) + + +@select.register(LazyTbl) +def _select(__data, *args, **kwargs): + # see https://stackoverflow.com/questions/25914329/rearrange-columns-in-sqlalchemy-select-object + if kwargs: + raise NotImplementedError( + "Using kwargs in select not currently supported. " + "Use _.newname == _.oldname instead" + ) + last_sel = __data.last_select + columns = {c.key: c for c in last_sel.inner_columns} + + # same as for DataFrame + colnames = Series(list(columns)) + vl = VarList() + evaluated = (arg(vl) if callable(arg) else arg for arg in args) + od = var_select(colnames, *evaluated) + + missing_groups, group_keys = _select_group_renames(od, __data.group_by) + + if missing_groups: + _warn_missing(missing_groups) + + final_od = {**{k: None for k in missing_groups}, **od} + + col_list = [] + for k,v in final_od.items(): + col = columns[k] + col_list.append(col if v is None else col.label(v)) + + return __data.append_op( + last_sel.with_only_columns(col_list), + group_by = group_keys + ) + + + +@filter.register(LazyTbl) +def _filter(__data, *args): + # Note: currently always produces 2 additional select statements, + # 1 for window/aggs, and 1 for the where clause + + sel = __data.last_op.alias() # original select + win_sel = sel.select() + + conds = [] + windows = [] + with _set_data_context(__data, window=True): + for ii, arg in enumerate(args): + + if isinstance(arg, Call): + new_call = __data.shape_call(arg, verb_name = "Filter", arg_name = ii) + #var_cols = new_call.op_vars(attr_calls = False) + + # note that a new win_sel is returned, w/ window columns appended + col_expr, win_cols, win_sel = __data.track_call_windows( + new_call, + sel.columns, + window_cte = win_sel + ) + + if isinstance(col_expr, sql.base.ImmutableColumnCollection): + conds.extend(col_expr) + else: + conds.append(col_expr) + + 
windows.extend(win_cols) + + else: + conds.append(arg) + + bool_clause = sql.and_(*conds) + + # first cte, windows ---- + if len(windows): + + win_alias = win_sel.alias() + + # move non-window functions to refer to win_sel clause (not the innermost) --- + bool_clause = sql.util.ClauseAdapter(win_alias) \ + .traverse(bool_clause) + + orig_cols = [win_alias.columns[k] for k in sel.columns.keys()] + else: + orig_cols = [sel] + + # create second cte ---- + filt_sel = _sql_select(orig_cols).where(bool_clause) + return __data.append_op(filt_sel) + + +@mutate.register(LazyTbl) +def _mutate(__data, *args, **kwargs): + # TODO: verify it can follow a renaming select + + # track labeled columns in set + if not (len(args) or len(kwargs)): + return __data.append_op(__data.last_op) + + names, sel_out = _mutate_cols(__data, args, kwargs, "Mutate") + return __data.append_op(sel_out) + + +def _sql_upsert_columns(sel, new_columns: "list[base.Label | base.Column]"): + orig_cols = lift_inner_cols(sel) + replaced = {**orig_cols} + + for new_col in new_columns: + replaced[new_col.name] = new_col + return _sql_with_only_columns(sel, list(replaced.values())) + + +def _select_mutate_result(src_sel, expr_result): + dst_alias = src_sel.alias() + src_columns = set(lift_inner_cols(src_sel)) + replacer = SqlLabelReplacer(set(src_columns), dst_alias.columns) + + if isinstance(expr_result, sql.base.ImmutableColumnCollection): + replaced_cols = list(map(replacer, expr_result)) + orig_cols = expr_result + #elif isinstance(expr_result, None): + # pass + else: + replaced_cols = [replacer(expr_result)] + orig_cols = [expr_result] + + if replacer.applied: + return _sql_upsert_columns(dst_alias.select(), replaced_cols) + + return _sql_upsert_columns(src_sel, orig_cols) + + +def _eval_expr_arg(__data, sel, func, verb_name, window=True): + inner_cols = lift_inner_cols(sel) + + # case 1: simple names ---- + simple_name = simple_varname(func) + if simple_name is not None: + return inner_cols[simple_name] + 
+    # case 2: across ----
+    _require_across(func, verb_name)
+
+    cols_result = _eval_with_context(__data, window, inner_cols, func)
+
+    # TODO: remove or raise a more informative error
+    assert isinstance(cols_result, sql.base.ImmutableColumnCollection), type(cols_result)
+
+    return cols_result
+
+
+def _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=True):
+    """Evaluate one named verb argument against the columns of ``sel``.
+
+    ``func`` is shaped with ``__data.shape_call`` and evaluated via
+    ``__data.track_call_windows`` on the inner columns of ``sel``.
+
+    Returns the resulting column expression labeled as ``new_name``.
+
+    Raises TypeError if the expression evaluates to a column collection
+    rather than a single column.
+    """
+    inner_cols = lift_inner_cols(sel)
+
+    expr_shaped = __data.shape_call(func, window, verb_name = verb_name, arg_name = new_name)
+    new_col, windows, _ = __data.track_call_windows(expr_shaped, inner_cols)
+
+    if isinstance(new_col, sql.base.ImmutableColumnCollection):
+        # report the argument by the name the user gave it
+        raise TypeError(
+            f"{verb_name} named arguments must return a single column, but `{new_name}` "
+            "returned multiple columns."
+        )
+
+    return new_col.label(new_name)
+
+
+def _mutate_cols(__data, args, kwargs, verb_name):
+    result_names = {}     # used as ordered set
+    sel = __data.last_select
+
+    for ii, func in enumerate(args):
+        cols_result = _eval_expr_arg(__data, sel, func, verb_name)
+
+        # replace any labels that require a subquery ----
+        sel = _select_mutate_result(sel, cols_result)
+
+        if isinstance(cols_result, sql.base.ImmutableColumnCollection):
+            result_names.update({k: True for k in cols_result.keys()})
+        else:
+            result_names[cols_result.name] = True
+
+
+    for new_name, func in kwargs.items():
+        labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name)
+
+        sel = _select_mutate_result(sel, labeled)
+        result_names[new_name] = True
+
+
+    return list(result_names), sel
+
+
+@transmute.register(LazyTbl)
+def _transmute(__data, *args, **kwargs):
+    # will use mutate, then select some cols
+    result_names, sel = _mutate_cols(__data, args, kwargs, "Transmute")
+
+    # transmute keeps grouping cols, and any defined in kwargs
+    missing = [x for x in __data.group_by if x not in result_names]
+    cols_to_keep = [*missing, *result_names]
+
+    columns = lift_inner_cols(sel)
+    sel_stripped = sel.with_only_columns([columns[k] for k in
cols_to_keep]) + + return __data.append_op(sel_stripped) + + +@arrange.register(LazyTbl) +def _arrange(__data, *args): + # Note that SQL databases often do not subquery order by clauses. Arrange + # sets order_by on the backend, so it can set order by in over elements, + # and handle when new columns are named the same as order by vars. + # see: https://dba.stackexchange.com/q/82930 + + last_sel = __data.last_select + cols = lift_inner_cols(last_sel) + + # TODO: implement across in arrange + #exprs, _ = _mutate_cols(__data, args, kwargs, "Arrange", arrange_clause=True) + + new_calls = [] + for ii, expr in enumerate(args): + if callable(expr): + + res = __data.shape_call( + expr, window = False, + verb_name = "Arrange", arg_name = ii + ) + + else: + res = expr + + new_calls.append(res) + + sort_cols = _create_order_by_clause(cols, *new_calls) + + order_by = __data.order_by + tuple(new_calls) + return __data.append_op(last_sel.order_by(*sort_cols), order_by = order_by) + + +# TODO: consolidate / pull expr handling funcs into own file? +def _create_order_by_clause(columns, *args): + sort_cols = [] + for arg in args: + # simple named column + if isinstance(arg, str): + sort_cols.append(columns[arg]) + # an expression + elif callable(arg): + # handle special case where -_.colname -> colname DESC + f, asc = _call_strip_ascending(arg) + col_op = f(columns) if asc else f(columns).desc() + #col_op = arg(columns) + sort_cols.append(col_op) + else: + raise NotImplementedError("Must be string or callable") + + return sort_cols + + + +@count.register(LazyTbl) +def _count(__data, *args, sort = False, wt = None, **kwargs): + # TODO: if already col named n, use name nn, etc.. 
get logic from tidy.py + if wt is not None: + raise NotImplementedError("TODO") + + res_name = "n" + # similar to filter verb, we need two select statements, + # an inner one for derived cols, and outer to group by them + + # inner select ---- + # holds any mutation style columns + #arg_names = [] + #for arg in args: + # name = simple_varname(arg) + # if name is None: + # raise NotImplementedError( + # "Count positional arguments must be single column name. " + # "Use a named argument to count using complex expressions." + # ) + # arg_names.append(name) + + result_names, sel_inner = _mutate_cols(__data, args, kwargs, "Count") + + # remove unnecessary select, if we're operating on a table ---- + if set(lift_inner_cols(sel_inner)) == set(lift_inner_cols(__data.last_select)): + sel_inner = __data.last_op + + # create outer select ---- + # holds selected columns and tally (n) + sel_inner_cte = sel_inner.alias() + inner_cols = sel_inner_cte.columns + + # apply any group vars from a group_by verb call first + missing = [k for k in __data.group_by if k not in result_names] + + all_group_names = ordered_union(__data.group_by, result_names) + outer_group_cols = [inner_cols[k] for k in all_group_names] + + # holds the actual count (e.g. 
n) + count_col = sql.functions.count().label(res_name) + + sel_outer = _sql_select([*outer_group_cols, count_col]) \ + .select_from(sel_inner_cte) \ + .group_by(*outer_group_cols) + + # count is like summarize, so removes order_by + return __data.append_op( + sel_outer.order_by(count_col.desc()), + order_by = tuple() + ) + + +@add_count.register(LazyTbl) +def _add_count(__data, *args, wt = None, sort = False, **kwargs): + counts = count(__data, *args, wt = wt, sort = sort, **kwargs) + by = list(c.name for c in counts.last_select.inner_columns)[:-1] + + return inner_join(__data, counts, by = by) + + +@summarize.register(LazyTbl) +def _summarize(__data, *args, **kwargs): + # https://stackoverflow.com/questions/14754994/why-is-sqlalchemy-count-much-slower-than-the-raw-query + + # get query with correct from clause, and maybe unneeded subquery + safe_from = __data.last_select.alias() + result_names, sel = _aggregate_cols(__data, safe_from, args, kwargs, "Summarize") + + # see if we can remove subquery + out_sel = _collapse_select(sel, safe_from) + + from_tbl = get_single_from(out_sel) + group_cols = [from_tbl.columns[k] for k in __data.group_by] + + final_sel = out_sel.group_by(*group_cols) + + new_data = __data.append_op(final_sel, group_by = tuple(), order_by = tuple()) + return new_data + + +def _collapse_select(outer_sel, inner_alias): + # check whether any outer columns reference an inner label ---- + inner_sel = inner_alias.element + + columns = lift_inner_cols(outer_sel) + inner_cols = lift_inner_cols(inner_sel) + + inner_labels = set([ + x.name for x in inner_cols + if isinstance(x, sql.elements.Label) + ]) + + col_requires_cte = set(inner_alias.columns[k] for k in inner_labels) + + bad_refs = [] + + def collect_refs(el): + if el in col_requires_cte: + bad_refs.append(el) + + for col in columns: + sql.util.visitors.traverse(col, {}, {"column": collect_refs}) + + # if possible, remove the outer query ---- + if not (bad_refs or len(inner_sel._group_by_clause)): + 
from sqlalchemy.sql.elements import ColumnClause, Label + + from_obj = get_single_from(inner_sel) + adaptor = sql.util.ClauseAdapter( + from_obj, + adapt_on_names=True, + include_fn=lambda c: isinstance(c, (ColumnClause, Label)) + ) + + new_cols = [] + for col in columns: + if isinstance(col, Label): + res = adaptor.traverse(col.element).label(col.name) + new_cols.append(res) + + else: + new_cols.append(adaptor.traverse(col)) + #new_cols = list(map(adaptor.traverse, columns)) + + return _sql_with_only_columns(inner_sel, new_cols) + + return outer_sel + + +def _aggregate_cols(__data, subquery, args, kwargs, verb_name): + # cases: + # * grouping cols can not be overwritten (in dbplyr they can't be ref'd) + # * no existing labels referred to - can use same select + # * existing labels referred to - need 1 subquery tops + # * groups + summarize columns can replace everything + + def get_label_clauses(clause): + out = [] + sql.util.visitors.traverse(clause, {}, {"label": lambda c: out.append(c)}) + + return out + + def quote_varname(x): + return f"`{x}`" + + def validate_references(arg_name, expr, verb_name): + bad_varnames = get_label_clauses(expr) + repr_names = ", ".join(map(quote_varname, [el.name for el in bad_varnames])) + + if not bad_varnames: + return + + raise NotImplementedError( + f"In SQL, you cannot refer to a column created in the same {verb_name}. " + f"`{arg_name}` refers to columns created earlier: {repr_names}." 
+ ) + + sel = subquery.select() + + final_cols = {k: subquery.columns[k] for k in __data.group_by} + + # handle args ---- + for ii, func in enumerate(args): + cols_result = _eval_expr_arg(__data, sel, func, verb_name, window=False) + + for col in cols_result: + validate_references(col.name, col.element, verb_name) + final_cols[col.name] = col + + sel = _sql_upsert_columns(sel, cols_result) + + + # handle kwargs ---- + for new_name, func in kwargs.items(): + labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=False) + + validate_references(labeled.name, labeled.element, verb_name) + final_cols[new_name] = labeled + + sel = _sql_upsert_columns(sel, [labeled]) + + return list(final_cols), _sql_with_only_columns(sel, list(final_cols.values())) + + +@group_by.register(LazyTbl) +def _group_by(__data, *args, add = False, **kwargs): + if not (args or kwargs): + return __data.copy() + + group_names, sel = _mutate_cols(__data, args, kwargs, "Group by") + + if None in group_names: + raise NotImplementedError("Complex, unnamed expressions not supported in sql group_by") + + # check whether we can just use underlying table ---- + new_cols = lift_inner_cols(sel) + if set(new_cols).issubset(set(__data.last_op.columns)): + sel = __data.last_op + + if add: + group_names = ordered_union(__data.group_by, group_names) + + return __data.append_op(sel, group_by = tuple(group_names)) + + +@ungroup.register(LazyTbl) +def _ungroup(__data): + return __data.copy(group_by = tuple()) + + +@case_when.register(sql.base.ImmutableColumnCollection) +def _case_when(__data, cases): + # TODO: will need listener to enter case statements, to handle when they use windows + if isinstance(cases, Call): + cases = cases(__data) + + whens = [] + case_items = list(cases.items()) + n_items = len(case_items) + + else_val = None + for ii, (expr, val) in enumerate(case_items): + # handle where val is a column expr + if callable(val): + val = val(__data) + + # handle when expressions + if ii+1 
== n_items and expr is True: + else_val = val + elif callable(expr): + whens.append((expr(__data), val)) + else: + whens.append((expr, val)) + + return sql.case(whens, else_ = else_val) + + +# Join ------------------------------------------------------------------------ + +from collections.abc import Mapping + +def _joined_cols(left_cols, right_cols, on_keys, how, suffix = ("_x", "_y")): + """Return labeled columns, according to selection rules for joins. + + Rules: + 1. For join keys, keep left table's column + 2. When keys have the same labels, add suffix + """ + + # TODO: remove sets, so uses stable ordering + # when left and right cols have same name, suffix with _x / _y + keep_right = set(right_cols.keys()) - set(on_keys.values()) + shared_labs = set(left_cols.keys()).intersection(keep_right) + + right_cols_no_keys = {k: right_cols[k] for k in keep_right} + + # for an outer join, have key columns coalesce values + + left_cols = {**left_cols} + if how == "full": + for lk, rk in on_keys.items(): + col = sql.functions.coalesce(left_cols[lk], right_cols[rk]) + left_cols[lk] = col.label(lk) + elif how == "right": + for lk, rk in on_keys.items(): + # Make left key columns actually be right ones (which contain left + extra) + left_cols[lk] = right_cols[rk].label(lk) + + + # create labels ---- + l_labs = _relabeled_cols(left_cols, shared_labs, suffix[0]) + r_labs = _relabeled_cols(right_cols_no_keys, shared_labs, suffix[1]) + + return l_labs + r_labs + + + +def _relabeled_cols(columns, keys, suffix): + # add a suffix to all columns with names in keys + cols = [] + for k, v in columns.items(): + new_col = v.label(k + str(suffix)) if k in keys else v + cols.append(new_col) + return cols + + +@join.register(LazyTbl) +def _join(left, right, on = None, *args, by = None, how = "inner", sql_on = None): + _raise_if_args(args) + + if on is None and by is not None: + on = by + + # Needs to be on the table, not the select + left_sel = left.last_op.alias() + right_sel = 
right.last_op.alias() + + # handle arguments ---- + on = _validate_join_arg_on(on, sql_on) + how = _validate_join_arg_how(how) + + # for equality join used to combine keys into single column + consolidate_keys = on if sql_on is None else {} + + if how == "right": + # switch joins, since sqlalchemy doesn't have right join arg + # see https://stackoverflow.com/q/11400307/1144523 + left_sel, right_sel = right_sel, left_sel + on = {v:k for k,v in on.items()} + + # create join conditions ---- + bool_clause = _create_join_conds(left_sel, right_sel, on) + + # create join ---- + join = left_sel.join( + right_sel, + onclause = bool_clause, + isouter = how != "inner", + full = how == "full" + ) + + # if right join, set selects back + if how == "right": + left_sel, right_sel = right_sel, left_sel + on = {v:k for k,v in on.items()} + + # note, shared_keys assumes on is a mapping... + # TODO: shared_keys appears to be for when on is not specified, but was unused + #shared_keys = [k for k,v in on.items() if k == v] + labeled_cols = _joined_cols( + left_sel.columns, + right_sel.columns, + on_keys = consolidate_keys, + how = how + ) + + sel = _sql_select(labeled_cols).select_from(join) + return left.append_op(sel, order_by = tuple()) + + +@semi_join.register(LazyTbl) +def _semi_join(left, right = None, on = None, *args, by = None, sql_on = None): + if on is None and by is not None: + on = by + + _raise_if_args(args) + + left_sel = left.last_op.alias() + right_sel = right.last_op.alias() + + # handle arguments ---- + on = _validate_join_arg_on(on, sql_on, left_sel, right_sel) + + # create join conditions ---- + bool_clause = _create_join_conds(left_sel, right_sel, on) + + # create inner join ---- + exists_clause = _sql_select([sql.literal(1)]) \ + .select_from(right_sel) \ + .where(bool_clause) + + # only keep left hand select's columns ---- + sel = _sql_select(left_sel.columns) \ + .select_from(left_sel) \ + .where(sql.exists(exists_clause)) + + return left.append_op(sel, order_by 
= tuple()) + + +@anti_join.register(LazyTbl) +def _anti_join(left, right = None, on = None, *args, by = None, sql_on = None): + if on is None and by is not None: + on = by + + _raise_if_args(args) + + left_sel = left.last_op.alias() + right_sel = right.last_op.alias() + + # handle arguments ---- + on = _validate_join_arg_on(on, sql_on, left_sel, right_sel) + + # create join conditions ---- + bool_clause = _create_join_conds(left_sel, right_sel, on) + + # create inner join ---- + #not_exists = ~sql.exists([1], from_obj = right_sel).where(bool_clause) + exists_clause = _sql_select([sql.literal(1)]) \ + .select_from(right_sel) \ + .where(bool_clause) + + sel = left_sel.select().where(~sql.exists(exists_clause)) + + return left.append_op(sel, order_by = tuple()) + +def _raise_if_args(args): + if len(args): + raise NotImplementedError("*args is reserved for future arguments (e.g. suffix)") + +def _validate_join_arg_on(on, sql_on = None, lhs = None, rhs = None): + # handle sql on case + if sql_on is not None: + if on is not None: + raise ValueError("Cannot specify both on and sql_on") + + return sql_on + + # handle general cases + if on is None: + # TODO: currently, we check for lhs and rhs tables to indicate whether + # a verb supports inferring columns. Otherwise, raise an error. + if lhs is not None and rhs is not None: + # TODO: consolidate with duplicate logic in pandas verb code + warnings.warn( + "No on column passed to join. " + "Inferring join columns instead using shared column names." + ) + + on_cols = list(set(lhs.columns.keys()).intersection(set(rhs.columns.keys()))) + + if not on_cols: + raise ValueError( + "No join column specified, or shared column names in join."
+ ) + + # trivial dict mapping shared names to themselves + warnings.warn("Detected shared columns: %s" % on_cols) + on = dict(zip(on_cols, on_cols)) + + else: + raise NotImplementedError("on arg currently cannot be None (default) for SQL") + elif isinstance(on, str): + on = {on: on} + elif isinstance(on, (list, tuple)): + on = dict(zip(on, on)) + + + if not isinstance(on, Mapping): + raise TypeError("on must be a Mapping (e.g. dict)") + + return on + +def _validate_join_arg_how(how): + how_options = ("inner", "left", "right", "full") + if how not in how_options: + raise ValueError("how argument needs to be one of %s" %how_options) + + return how + +def _create_join_conds(left_sel, right_sel, on): + left_cols = left_sel.columns #lift_inner_cols(left_sel) + right_cols = right_sel.columns #lift_inner_cols(right_sel) + + if callable(on): + # callable, like with sql_on arg + conds = [on(left_cols, right_cols)] + else: + # dict-like of form {left: right} + conds = [] + for l, r in on.items(): + col_expr = left_cols[l] == right_cols[r] + conds.append(col_expr) + + return sql.and_(*conds) + + +# Head ------------------------------------------------------------------------ + +@head.register(LazyTbl) +def _head(__data, n = 5): + sel = __data.last_select + + return __data.append_op(sel.limit(n)) + + +# Rename ---------------------------------------------------------------------- + +@rename.register(LazyTbl) +def _rename(__data, **kwargs): + sel = __data.last_select + columns = lift_inner_cols(sel) + + # old_keys uses dict as ordered set + old_to_new = {simple_varname(v):k for k,v in kwargs.items()} + + if None in old_to_new: + raise KeyError("positional arguments must be simple column, " + "e.g. 
_.colname or _['colname']" + ) + + labs = [c.label(old_to_new[k]) if k in old_to_new else c for k,c in columns.items()] + + new_sel = sel.with_only_columns(labs) + + missing_groups, group_keys = _select_group_renames(old_to_new, __data.group_by) + + return __data.append_op(new_sel, group_by=group_keys) + + +# Distinct -------------------------------------------------------------------- + +@distinct.register(LazyTbl) +def _distinct(__data, *args, _keep_all = False, **kwargs): + if (args or kwargs) and _keep_all: + raise NotImplementedError("Distinct with variables specified in sql requires _keep_all = False") + + inner_sel = mutate(__data, **kwargs).last_select if kwargs else __data.last_select + + # TODO: this is copied from the df distinct version + # cols dict below is used as ordered set + cols = _var_select_simple(args) + cols.update(kwargs) + + # use all columns by default + if not cols: + cols = {k: True for k in lift_inner_cols(inner_sel).keys()} + + final_names = {**{k: True for k in __data.group_by}, **cols} + + if not len(inner_sel._order_by_clause): + # select distinct has to include any columns in the order by clause, + # so can only safely modify existing statement when there's no order by + sel_cols = lift_inner_cols(inner_sel) + distinct_cols = [sel_cols[k] for k in final_names] + sel = inner_sel.with_only_columns(distinct_cols).distinct() + else: + # fallback to cte + cte = inner_sel.alias() + distinct_cols = [cte.columns[k] for k in final_names] + sel = _sql_select(distinct_cols).select_from(cte).distinct() + + return __data.append_op(sel) + + +# if_else --------------------------------------------------------------------- + +@if_else.register(sql.elements.ColumnElement) +def _if_else(cond, true_vals, false_vals): + whens = [(cond, true_vals)] + return sql.case(whens, else_ = false_vals) + + diff --git a/siuba/sql/verbs/filter.py b/siuba/sql/verbs/filter.py new file mode 100644 index 00000000..8dcbef7b --- /dev/null +++ b/siuba/sql/verbs/filter.py 
@@ -0,0 +1,1395 @@ +""" +Implements LazyTbl to represent tables of SQL data, and registers it on verbs. + +This module is responsible for the handling of the "table" side of things, while +translate.py handles translating column operations. + + +""" + +import warnings + +from siuba.dply.verbs import ( + show_query, collect, + simple_varname, + select, + mutate, + transmute, + filter, + arrange, _call_strip_ascending, + summarize, + count, add_count, + group_by, ungroup, + case_when, + join, left_join, right_join, inner_join, semi_join, anti_join, + head, + rename, + distinct, + if_else, + _select_group_renames, + _var_select_simple + ) + +from siuba.dply.tidyselect import VarList, var_select + +from .translate import CustomOverClause, SqlColumn, SqlColumnAgg +from .utils import ( + get_dialect_translator, + _FixedSqlDatabase, + _is_dialect_duckdb, + _sql_select, + _sql_column_collection, + _sql_add_columns, + _sql_with_only_columns, + _sql_simplify_select, + MockConnection +) + +from sqlalchemy import sql +import sqlalchemy +from siuba.siu import Call, Lazy, FunctionLookupError, singledispatch2 +# TODO: currently needed for select, but can we remove pandas? +from pandas import Series +from functools import singledispatch + +from sqlalchemy.sql import schema + +from siuba.dply.across import _require_across, _set_data_context, _eval_with_context + +# TODO: +# - distinct +# - annotate functions using sel.prefix_with("\n/**/\n") ? + + +# Helpers --------------------------------------------------------------------- + +class SqlFunctionLookupError(FunctionLookupError): pass + + +class CallListener: + """Generic listener. Each exit is called on a node's copy.""" + def enter(self, node): + args, kwargs = node.map_subcalls(self.enter) + + return self.exit(node.__class__(node.func, *args, **kwargs)) + + def exit(self, node): + return node + + +class WindowReplacer(CallListener): + """Call tree listener. 
+ + Produces 2 important behaviors via the enter method: + - returns evaluated sql call expression, with labels on all window expressions. + - stores all labeled window expressions via the windows property. + + TODO: could replace with a sqlalchemy transformer + """ + + def __init__(self, columns, group_by, order_by, window_cte = None): + self.columns = columns + self.group_by = group_by + self.order_by = order_by + self.window_cte = window_cte + self.windows = [] + + def exit(self, node): + col_expr = node(self.columns) + + if not isinstance(col_expr, sql.elements.ClauseElement): + return col_expr + + over_clauses = [x for x in self._get_over_clauses(col_expr) if isinstance(x, CustomOverClause)] + + # put groupings and orderings onto custom over clauses + for over in over_clauses: + # TODO: shouldn't mutate these over clauses + group_by = sql.elements.ClauseList( + *[self.columns[name] for name in self.group_by] + ) + order_by = sql.elements.ClauseList( + *_create_order_by_clause(self.columns, *self.order_by) + ) + + over.set_over(group_by, order_by) + + if len(over_clauses) and self.window_cte is not None: + # custom name, or parameters like "%(...)s" may nest and break psycopg2 + # with columns you can set a key to fix this, but it doesn't seem to + # be an option with labels + name = self._get_unique_name('win', lift_inner_cols(self.window_cte)) + label = col_expr.label(name) + + # put into CTE, and return its resulting column, so that subsequent + # operations will refer to the window column on window_cte. 
Note that + # the operations will use the actual column, so may need to use the + # ClauseAdaptor to make it a reference to the label + self.window_cte = _sql_add_columns(self.window_cte, [label]) + win_col = lift_inner_cols(self.window_cte).values()[-1] + self.windows.append(win_col) + + return win_col + + return col_expr + + @staticmethod + def _get_unique_name(prefix, columns): + column_names = set(columns.keys()) + + i = 1 + name = prefix + str(i) + while name in column_names: + i += 1 + name = prefix + str(i) + + + return name + + @staticmethod + def _get_over_clauses(clause): + windows = [] + append_win = lambda col: windows.append(col) + + sql.util.visitors.traverse(clause, {}, {"over": append_win}) + + return windows + + +class SqlLabelReplacer: + """Create a visitor to replace source labels with destination. + + Note that this is meant to be used with sqlalchemy visitors. + """ + + def __init__(self, src_columns, dst_columns): + self.src_columns = src_columns + self.src_labels = set([x for x in src_columns if isinstance(x, sql.elements.Label)]) + self.dst_columns = dst_columns + self.applied = False + + def __call__(self, clause): + return sql.util.visitors.replacement_traverse(clause, {}, self.visit) + + def visit(self, el): + from sqlalchemy.sql.elements import ColumnClause, Label, ClauseElement, TypeClause + from sqlalchemy.sql.schema import Column + + if isinstance(el, TypeClause): + # TODO: for some reason this type throws an error if unguarded + return None + + if isinstance(el, ClauseElement): + if el in self.src_labels: + self.applied = True + return self.dst_columns[el.name] + elif el in self.src_columns: + return self.dst_columns[el.name] + + # TODO: should we create a subquery if the user passed raw text? 
+ #elif isinstance(el, ColumnClause) and not isinstance(el, Column): + # # Raw SQL, which will need a subquery, but not substitution + # if el.key != "*": + # self.applied = True + + return None + + +#def track_call_windows(call, columns, group_by, order_by, window_cte = None): +# listener = WindowReplacer(columns, group_by, order_by, window_cte) +# col = listener.enter(call) +# return col, listener.windows, listener.window_cte + + +def track_call_windows(call, columns, group_by, order_by, window_cte = None): + col_expr = call(columns) + + crnt_group_by = sql.elements.ClauseList( + *[columns[name] for name in group_by] + ) + crnt_order_by = sql.elements.ClauseList( + *_create_order_by_clause(columns, *order_by) + ) + return replace_call_windows(col_expr, crnt_group_by, crnt_order_by, window_cte) + + + +@singledispatch +def replace_call_windows(col_expr, group_by, order_by, window_cte = None): + raise TypeError(str(type(col_expr))) + + +@replace_call_windows.register(sql.base.ImmutableColumnCollection) +def _(col_expr, group_by, order_by, window_cte = None): + all_over_clauses = [] + for col in col_expr: + _, over_clauses, window_cte = replace_call_windows( + col, + group_by, + order_by, + window_cte + ) + all_over_clauses.extend(over_clauses) + + return col_expr, all_over_clauses, window_cte + + +@replace_call_windows.register(sql.elements.ClauseElement) +def _(col_expr, group_by, order_by, window_cte = None): + + over_clauses = WindowReplacer._get_over_clauses(col_expr) + + for over in over_clauses: + # TODO: shouldn't mutate these over clauses + over.set_over(group_by, order_by) + + if len(over_clauses) and window_cte is not None: + # custom name, or parameters like "%(...)s" may nest and break psycopg2 + # with columns you can set a key to fix this, but it doesn't seem to + # be an option with labels + name = WindowReplacer._get_unique_name('win', lift_inner_cols(window_cte)) + label = col_expr.label(name) + + # put into CTE, and return its resulting column, so 
that subsequent + # operations will refer to the window column on window_cte. Note that + # the operations will use the actual column, so may need to use the + # ClauseAdaptor to make it a reference to the label + window_cte = _sql_add_columns(window_cte, [label]) + win_col = lift_inner_cols(window_cte).values()[-1] + + return win_col, over_clauses, window_cte + + return col_expr, over_clauses, window_cte + +def get_single_from(sel): + froms = sel.froms + + n_froms = len(froms) + if n_froms != 1: + raise ValueError( + f"Expected a single table in the from clause, but found {n_froms}" + ) + + return froms[0] + +def lift_inner_cols(tbl): + cols = list(tbl.inner_columns) + + return _sql_column_collection(cols) + +def col_expr_requires_cte(call, sel, is_mutate = False): + """Return whether a variable assignment needs a CTE""" + + call_vars = set(call.op_vars(attr_calls = False)) + + sel_labs = get_inner_labels(sel) + + # I use the acronym fwg sol (frog soul) to remember sql clause eval order + # from, where, group by, select, order by, limit + # group clause evaluated before select clause, so not issue for mutate + group_needs_cte = not is_mutate and len(sel._group_by_clause) + + return ( group_needs_cte + # TODO: detect when a new var in mutate conflicts w/ order by + #or len(sel._order_by_clause) + or not sel_labs.isdisjoint(call_vars) + ) + +def get_inner_labels(sel): + columns = lift_inner_cols(sel) + sel_labs = set(k for k,v in columns.items() if isinstance(v, sql.elements.Label)) + return sel_labs + +def get_missing_columns(call, columns): + missing_cols = set(call.op_vars(attr_calls = False)) - set(columns.keys()) + return missing_cols + +def compile_el(tbl, el): + compiled = el.compile( + dialect = tbl.source.dialect, + compile_kwargs = {"literal_binds": True} + ) + return compiled + +# Misc utilities -------------------------------------------------------------- + +def ordered_union(x, y): + dx = {el: True for el in x} + dy = {el: True for el in y} + + return 
tuple({**dx, **dy}) + + +def _warn_missing(missing_groups): + warnings.warn(f"Adding missing grouping variables: {missing_groups}") + + +# Table ----------------------------------------------------------------------- + +class LazyTbl: + def __init__( + self, source, tbl, columns = None, + ops = None, group_by = tuple(), order_by = tuple(), + translator = None + ): + """Create a representation of a SQL table. + + Args: + source: a sqlalchemy.Engine or sqlalchemy.Connection instance. + tbl: table of form 'schema_name.table_name', 'table_name', or sqlalchemy.Table. + columns: if specified, a listlike of column names. + + Examples + -------- + + :: + from sqlalchemy import create_engine + from siuba.data import mtcars + + # create database and table + engine = create_engine("sqlite:///:memory:") + mtcars.to_sql('mtcars', engine) + + tbl_mtcars = LazyTbl(engine, 'mtcars') + + """ + + # connection and dialect specific functions + self.source = sqlalchemy.create_engine(source) if isinstance(source, str) else source + + # get dialect name + dialect = self.source.dialect.name + self.translator = get_dialect_translator(dialect) + + self.tbl = self._create_table(tbl, columns, self.source) + + # important states the query can be in (e.g. 
grouped) + self.ops = [self.tbl] if ops is None else ops + + self.group_by = group_by + self.order_by = order_by + + + def append_op(self, op, **kwargs): + cpy = self.copy(**kwargs) + cpy.ops = cpy.ops + [op] + return cpy + + def copy(self, **kwargs): + return self.__class__(**{**self.__dict__, **kwargs}) + + def shape_call( + self, + call, window = True, str_accessors = False, + verb_name = None, arg_name = None, + ): + return self.translator.shape_call(call, window, str_accessors, verb_name, arg_name) + + def track_call_windows(self, call, columns = None, window_cte = None): + """Returns tuple of (new column expression, list of window exprs)""" + + columns = self.last_op.columns if columns is None else columns + return track_call_windows(call, columns, self.group_by, self.order_by, window_cte) + + def get_ordered_col_names(self): + """Return columns from current select, with grouping columns first.""" + ungrouped = [k for k in self.last_op.columns.keys() if k not in self.group_by] + return list(self.group_by) + ungrouped + + #def label_breaks_order_by(self, name): + # """Returns True if a new column label would break the order by vars.""" + + # # TODO: arrange currently allows literals, which breaks this. it seems + # # better to only allow calls in arrange. + # order_by_vars = {c.op_vars(attr_calls=False) for c in self.order_by} + + + + + @property + def last_op(self) -> "sql.Table | sql.Select": + last_op = self.ops[-1] + + if last_op is None: + raise TypeError() + + return last_op + + @property + def last_select(self): + last_op = self.last_op + if not isinstance(last_op, sql.selectable.SelectBase): + return last_op.select() + + return last_op + + @staticmethod + def _create_table(tbl, columns = None, source = None): + """Return a sqlalchemy.Table, autoloading column info if needed. + + Arguments: + tbl: a sqlalchemy.Table or string of form 'table_name' or 'schema_name.table_name'. + columns: a tuple of column names for the table. Overrides source argument. 
+ source: a sqlalchemy engine, used to autoload columns. + + """ + if isinstance(tbl, sql.selectable.FromClause): + return tbl + + if not isinstance(tbl, str): + raise ValueError("tbl must be a sqlalchemy Table or string, but was %s" %type(tbl)) + + if columns is None and source is None: + raise ValueError("One of columns or source must be specified") + + schema, table_name = tbl.split('.') if '.' in tbl else [None, tbl] + + columns = map(sqlalchemy.Column, columns) if columns is not None else tuple() + + # TODO: pybigquery uses schema to mean project_id, so we cannot use + # siuba's classic breakdown "{schema}.{table_name}". Basically + # pybigquery uses "{schema=project_id}.{dataset_dot_table_name}" in its internal + # logic. An important side effect is that bigquery errors for + # `dataset`.`table`, but not `dataset.table`. + if source and source.dialect.name == "bigquery": + table_name = tbl + schema = None + + return sqlalchemy.Table( + table_name, + sqlalchemy.MetaData(bind = source), + *columns, + schema = schema, + autoload_with = source if not columns else None + ) + + def _get_preview(self): + # need to make prev op a cte, so we don't override any previous limit + new_sel = self.last_select.limit(5) + tbl_small = self.append_op(new_sel) + return collect(tbl_small) + + def __repr__(self): + template = ( + "# Source: lazy query\n" + "# DB Conn: {}\n" + "# Preview:\n{}\n" + "# .. may have more rows" + ) + + return template.format(repr(self.source.engine), repr(self._get_preview())) + + def _repr_html_(self): + template = ( + "
" + "
"
+                "# Source: lazy query\n"
+                "# DB Conn: {}\n"
+                "# Preview:\n"
+                "
" + "{}" + "

# .. may have more rows

" + "
" + ) + + data = self._get_preview() + + # _repr_html_ can not exist or return None, to signify that repr should be used + if not hasattr(data, '_repr_html_'): + return None + + html_data = data._repr_html_() + if html_data is None: + return None + + return template.format(self.source.engine, html_data) + + +def _repr_grouped_df_html_(self): + return "

(grouped data frame)

" + self._selected_obj._repr_html_() + "
" + + + +# Main Funcs +# ============================================================================= + +# sql raw -------------- + +sql_raw = sql.literal_column + +# show query ----------- + +@show_query.register(LazyTbl) +def _show_query(tbl, simplify = False, return_table = True): + #query = tbl.last_op #if not simplify else + compile_query = lambda query: query.compile( + dialect = tbl.source.dialect, + compile_kwargs = {"literal_binds": True} + ) + + + if simplify: + # try to strip table names and labels where unnecessary + simple_sel = _sql_simplify_select(tbl.last_select) + + explained = compile_query(simple_sel) + else: + # use a much more verbose query + explained = compile_query(tbl.last_select) + + if return_table: + print(str(explained)) + return tbl + + return str(explained) + + + +# collect ---------- + +@collect.register(LazyTbl) +def _collect(__data, as_df = True): + # TODO: maybe remove as_df options, always return dataframe + + if isinstance(__data.source, MockConnection): + # a mock sqlalchemy is being used to show_query, and echo queries. + # it doesn't return a result object or have a context handler, so + # we need to bail out early + return + + # compile query ---- + + if _is_dialect_duckdb(__data.source): + # TODO: can be removed once next release of duckdb fixes: + # https://github.com/duckdb/duckdb/issues/2972 + query = __data.last_select + compiled = query.compile( + dialect = __data.source.dialect, + compile_kwargs = {"literal_binds": True} + ) + else: + compiled = __data.last_select + + # execute query ---- + + with __data.source.connect() as conn: + if as_df: + sql_db = _FixedSqlDatabase(conn) + + if _is_dialect_duckdb(__data.source): + # TODO: pandas read_sql is very slow with duckdb. + # see https://github.com/pandas-dev/pandas/issues/45678 + # going to handle here for now. address once LazyTbl gets + # subclassed per backend. 
+ duckdb_con = conn.connection.c + return duckdb_con.query(str(compiled)).to_df() + else: + # + return sql_db.read_sql(compiled) + + return conn.execute(compiled) + + +@select.register(LazyTbl) +def _select(__data, *args, **kwargs): + # see https://stackoverflow.com/questions/25914329/rearrange-columns-in-sqlalchemy-select-object + if kwargs: + raise NotImplementedError( + "Using kwargs in select not currently supported. " + "Use _.newname == _.oldname instead" + ) + last_sel = __data.last_select + columns = {c.key: c for c in last_sel.inner_columns} + + # same as for DataFrame + colnames = Series(list(columns)) + vl = VarList() + evaluated = (arg(vl) if callable(arg) else arg for arg in args) + od = var_select(colnames, *evaluated) + + missing_groups, group_keys = _select_group_renames(od, __data.group_by) + + if missing_groups: + _warn_missing(missing_groups) + + final_od = {**{k: None for k in missing_groups}, **od} + + col_list = [] + for k,v in final_od.items(): + col = columns[k] + col_list.append(col if v is None else col.label(v)) + + return __data.append_op( + last_sel.with_only_columns(col_list), + group_by = group_keys + ) + + + +@filter.register(LazyTbl) +def _filter(__data, *args): + # Note: currently always produces 2 additional select statements, + # 1 for window/aggs, and 1 for the where clause + + sel = __data.last_op.alias() # original select + win_sel = sel.select() + + conds = [] + windows = [] + with _set_data_context(__data, window=True): + for ii, arg in enumerate(args): + + if isinstance(arg, Call): + new_call = __data.shape_call(arg, verb_name = "Filter", arg_name = ii) + #var_cols = new_call.op_vars(attr_calls = False) + + # note that a new win_sel is returned, w/ window columns appended + col_expr, win_cols, win_sel = __data.track_call_windows( + new_call, + sel.columns, + window_cte = win_sel + ) + + if isinstance(col_expr, sql.base.ImmutableColumnCollection): + conds.extend(col_expr) + else: + conds.append(col_expr) + + 
windows.extend(win_cols) + + else: + conds.append(arg) + + bool_clause = sql.and_(*conds) + + # first cte, windows ---- + if len(windows): + + win_alias = win_sel.alias() + + # move non-window functions to refer to win_sel clause (not the innermost) --- + bool_clause = sql.util.ClauseAdapter(win_alias) \ + .traverse(bool_clause) + + orig_cols = [win_alias.columns[k] for k in sel.columns.keys()] + else: + orig_cols = [sel] + + # create second cte ---- + filt_sel = _sql_select(orig_cols).where(bool_clause) + return __data.append_op(filt_sel) + + +@mutate.register(LazyTbl) +def _mutate(__data, *args, **kwargs): + # TODO: verify it can follow a renaming select + + # track labeled columns in set + if not (len(args) or len(kwargs)): + return __data.append_op(__data.last_op) + + names, sel_out = _mutate_cols(__data, args, kwargs, "Mutate") + return __data.append_op(sel_out) + + +def _sql_upsert_columns(sel, new_columns: "list[base.Label | base.Column]"): + orig_cols = lift_inner_cols(sel) + replaced = {**orig_cols} + + for new_col in new_columns: + replaced[new_col.name] = new_col + return _sql_with_only_columns(sel, list(replaced.values())) + + +def _select_mutate_result(src_sel, expr_result): + dst_alias = src_sel.alias() + src_columns = set(lift_inner_cols(src_sel)) + replacer = SqlLabelReplacer(set(src_columns), dst_alias.columns) + + if isinstance(expr_result, sql.base.ImmutableColumnCollection): + replaced_cols = list(map(replacer, expr_result)) + orig_cols = expr_result + #elif isinstance(expr_result, None): + # pass + else: + replaced_cols = [replacer(expr_result)] + orig_cols = [expr_result] + + if replacer.applied: + return _sql_upsert_columns(dst_alias.select(), replaced_cols) + + return _sql_upsert_columns(src_sel, orig_cols) + + +def _eval_expr_arg(__data, sel, func, verb_name, window=True): + inner_cols = lift_inner_cols(sel) + + # case 1: simple names ---- + simple_name = simple_varname(func) + if simple_name is not None: + return inner_cols[simple_name] + 
+ # case 2: across ---- + _require_across(func, verb_name) + + cols_result = _eval_with_context(__data, window, inner_cols, func) + + # TODO: remove or raise a more informative error + assert isinstance(cols_result, sql.base.ImmutableColumnCollection), type(cols_result) + + return cols_result + + +def _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=True): + inner_cols = lift_inner_cols(sel) + + expr_shaped = __data.shape_call(func, window, verb_name = verb_name, arg_name = new_name) + new_col, windows, _ = __data.track_call_windows(expr_shaped, inner_cols) + + if isinstance(new_col, sql.base.ImmutableColumnCollection): + raise TyepError( + f"{verb_name} named arguments must return a single column, but `{k}` " + "returned multiple columns." + ) + + return new_col.label(new_name) + + +def _mutate_cols(__data, args, kwargs, verb_name): + result_names = {} # used as ordered set + sel = __data.last_select + + for ii, func in enumerate(args): + cols_result = _eval_expr_arg(__data, sel, func, verb_name) + + # replace any labels that require a subquery ---- + sel = _select_mutate_result(sel, cols_result) + + if isinstance(cols_result, sql.base.ImmutableColumnCollection): + result_names.update({k: True for k in cols_result.keys()}) + else: + result_names[cols_result.name] = True + + + for new_name, func in kwargs.items(): + labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name) + + sel = _select_mutate_result(sel, labeled) + result_names[new_name] = True + + + return list(result_names), sel + + +@transmute.register(LazyTbl) +def _transmute(__data, *args, **kwargs): + # will use mutate, then select some cols + result_names, sel = _mutate_cols(__data, args, kwargs, "Transmute") + + # transmute keeps grouping cols, and any defined in kwargs + missing = [x for x in __data.group_by if x not in result_names] + cols_to_keep = [*missing, *result_names] + + columns = lift_inner_cols(sel) + sel_stripped = sel.with_only_columns([columns[k] for k in 
cols_to_keep]) + + return __data.append_op(sel_stripped) + + +@arrange.register(LazyTbl) +def _arrange(__data, *args): + # Note that SQL databases often do not subquery order by clauses. Arrange + # sets order_by on the backend, so it can set order by in over elements, + # and handle when new columns are named the same as order by vars. + # see: https://dba.stackexchange.com/q/82930 + + last_sel = __data.last_select + cols = lift_inner_cols(last_sel) + + # TODO: implement across in arrange + #exprs, _ = _mutate_cols(__data, args, kwargs, "Arrange", arrange_clause=True) + + new_calls = [] + for ii, expr in enumerate(args): + if callable(expr): + + res = __data.shape_call( + expr, window = False, + verb_name = "Arrange", arg_name = ii + ) + + else: + res = expr + + new_calls.append(res) + + sort_cols = _create_order_by_clause(cols, *new_calls) + + order_by = __data.order_by + tuple(new_calls) + return __data.append_op(last_sel.order_by(*sort_cols), order_by = order_by) + + +# TODO: consolidate / pull expr handling funcs into own file? +def _create_order_by_clause(columns, *args): + sort_cols = [] + for arg in args: + # simple named column + if isinstance(arg, str): + sort_cols.append(columns[arg]) + # an expression + elif callable(arg): + # handle special case where -_.colname -> colname DESC + f, asc = _call_strip_ascending(arg) + col_op = f(columns) if asc else f(columns).desc() + #col_op = arg(columns) + sort_cols.append(col_op) + else: + raise NotImplementedError("Must be string or callable") + + return sort_cols + + + +@count.register(LazyTbl) +def _count(__data, *args, sort = False, wt = None, **kwargs): + # TODO: if already col named n, use name nn, etc.. 
get logic from tidy.py + if wt is not None: + raise NotImplementedError("TODO") + + res_name = "n" + # similar to filter verb, we need two select statements, + # an inner one for derived cols, and outer to group by them + + # inner select ---- + # holds any mutation style columns + #arg_names = [] + #for arg in args: + # name = simple_varname(arg) + # if name is None: + # raise NotImplementedError( + # "Count positional arguments must be single column name. " + # "Use a named argument to count using complex expressions." + # ) + # arg_names.append(name) + + result_names, sel_inner = _mutate_cols(__data, args, kwargs, "Count") + + # remove unnecessary select, if we're operating on a table ---- + if set(lift_inner_cols(sel_inner)) == set(lift_inner_cols(__data.last_select)): + sel_inner = __data.last_op + + # create outer select ---- + # holds selected columns and tally (n) + sel_inner_cte = sel_inner.alias() + inner_cols = sel_inner_cte.columns + + # apply any group vars from a group_by verb call first + missing = [k for k in __data.group_by if k not in result_names] + + all_group_names = ordered_union(__data.group_by, result_names) + outer_group_cols = [inner_cols[k] for k in all_group_names] + + # holds the actual count (e.g. 
n) + count_col = sql.functions.count().label(res_name) + + sel_outer = _sql_select([*outer_group_cols, count_col]) \ + .select_from(sel_inner_cte) \ + .group_by(*outer_group_cols) + + # count is like summarize, so removes order_by + return __data.append_op( + sel_outer.order_by(count_col.desc()), + order_by = tuple() + ) + + +@add_count.register(LazyTbl) +def _add_count(__data, *args, wt = None, sort = False, **kwargs): + counts = count(__data, *args, wt = wt, sort = sort, **kwargs) + by = list(c.name for c in counts.last_select.inner_columns)[:-1] + + return inner_join(__data, counts, by = by) + + +@summarize.register(LazyTbl) +def _summarize(__data, *args, **kwargs): + # https://stackoverflow.com/questions/14754994/why-is-sqlalchemy-count-much-slower-than-the-raw-query + + # get query with correct from clause, and maybe unneeded subquery + safe_from = __data.last_select.alias() + result_names, sel = _aggregate_cols(__data, safe_from, args, kwargs, "Summarize") + + # see if we can remove subquery + out_sel = _collapse_select(sel, safe_from) + + from_tbl = get_single_from(out_sel) + group_cols = [from_tbl.columns[k] for k in __data.group_by] + + final_sel = out_sel.group_by(*group_cols) + + new_data = __data.append_op(final_sel, group_by = tuple(), order_by = tuple()) + return new_data + + +def _collapse_select(outer_sel, inner_alias): + # check whether any outer columns reference an inner label ---- + inner_sel = inner_alias.element + + columns = lift_inner_cols(outer_sel) + inner_cols = lift_inner_cols(inner_sel) + + inner_labels = set([ + x.name for x in inner_cols + if isinstance(x, sql.elements.Label) + ]) + + col_requires_cte = set(inner_alias.columns[k] for k in inner_labels) + + bad_refs = [] + + def collect_refs(el): + if el in col_requires_cte: + bad_refs.append(el) + + for col in columns: + sql.util.visitors.traverse(col, {}, {"column": collect_refs}) + + # if possible, remove the outer query ---- + if not (bad_refs or len(inner_sel._group_by_clause)): + 
from sqlalchemy.sql.elements import ColumnClause, Label + + from_obj = get_single_from(inner_sel) + adaptor = sql.util.ClauseAdapter( + from_obj, + adapt_on_names=True, + include_fn=lambda c: isinstance(c, (ColumnClause, Label)) + ) + + new_cols = [] + for col in columns: + if isinstance(col, Label): + res = adaptor.traverse(col.element).label(col.name) + new_cols.append(res) + + else: + new_cols.append(adaptor.traverse(col)) + #new_cols = list(map(adaptor.traverse, columns)) + + return _sql_with_only_columns(inner_sel, new_cols) + + return outer_sel + + +def _aggregate_cols(__data, subquery, args, kwargs, verb_name): + # cases: + # * grouping cols can not be overwritten (in dbplyr they can't be ref'd) + # * no existing labels referred to - can use same select + # * existing labels referred to - need 1 subquery tops + # * groups + summarize columns can replace everything + + def get_label_clauses(clause): + out = [] + sql.util.visitors.traverse(clause, {}, {"label": lambda c: out.append(c)}) + + return out + + def quote_varname(x): + return f"`{x}`" + + def validate_references(arg_name, expr, verb_name): + bad_varnames = get_label_clauses(expr) + repr_names = ", ".join(map(quote_varname, [el.name for el in bad_varnames])) + + if not bad_varnames: + return + + raise NotImplementedError( + f"In SQL, you cannot refer to a column created in the same {verb_name}. " + f"`{arg_name}` refers to columns created earlier: {repr_names}." 
+ ) + + sel = subquery.select() + + final_cols = {k: subquery.columns[k] for k in __data.group_by} + + # handle args ---- + for ii, func in enumerate(args): + cols_result = _eval_expr_arg(__data, sel, func, verb_name, window=False) + + for col in cols_result: + validate_references(col.name, col.element, verb_name) + final_cols[col.name] = col + + sel = _sql_upsert_columns(sel, cols_result) + + + # handle kwargs ---- + for new_name, func in kwargs.items(): + labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=False) + + validate_references(labeled.name, labeled.element, verb_name) + final_cols[new_name] = labeled + + sel = _sql_upsert_columns(sel, [labeled]) + + return list(final_cols), _sql_with_only_columns(sel, list(final_cols.values())) + + +@group_by.register(LazyTbl) +def _group_by(__data, *args, add = False, **kwargs): + if not (args or kwargs): + return __data.copy() + + group_names, sel = _mutate_cols(__data, args, kwargs, "Group by") + + if None in group_names: + raise NotImplementedError("Complex, unnamed expressions not supported in sql group_by") + + # check whether we can just use underlying table ---- + new_cols = lift_inner_cols(sel) + if set(new_cols).issubset(set(__data.last_op.columns)): + sel = __data.last_op + + if add: + group_names = ordered_union(__data.group_by, group_names) + + return __data.append_op(sel, group_by = tuple(group_names)) + + +@ungroup.register(LazyTbl) +def _ungroup(__data): + return __data.copy(group_by = tuple()) + + +@case_when.register(sql.base.ImmutableColumnCollection) +def _case_when(__data, cases): + # TODO: will need listener to enter case statements, to handle when they use windows + if isinstance(cases, Call): + cases = cases(__data) + + whens = [] + case_items = list(cases.items()) + n_items = len(case_items) + + else_val = None + for ii, (expr, val) in enumerate(case_items): + # handle where val is a column expr + if callable(val): + val = val(__data) + + # handle when expressions + if ii+1 
== n_items and expr is True: + else_val = val + elif callable(expr): + whens.append((expr(__data), val)) + else: + whens.append((expr, val)) + + return sql.case(whens, else_ = else_val) + + +# Join ------------------------------------------------------------------------ + +from collections.abc import Mapping + +def _joined_cols(left_cols, right_cols, on_keys, how, suffix = ("_x", "_y")): + """Return labeled columns, according to selection rules for joins. + + Rules: + 1. For join keys, keep left table's column + 2. When keys have the same labels, add suffix + """ + + # TODO: remove sets, so uses stable ordering + # when left and right cols have same name, suffix with _x / _y + keep_right = set(right_cols.keys()) - set(on_keys.values()) + shared_labs = set(left_cols.keys()).intersection(keep_right) + + right_cols_no_keys = {k: right_cols[k] for k in keep_right} + + # for an outer join, have key columns coalesce values + + left_cols = {**left_cols} + if how == "full": + for lk, rk in on_keys.items(): + col = sql.functions.coalesce(left_cols[lk], right_cols[rk]) + left_cols[lk] = col.label(lk) + elif how == "right": + for lk, rk in on_keys.items(): + # Make left key columns actually be right ones (which contain left + extra) + left_cols[lk] = right_cols[rk].label(lk) + + + # create labels ---- + l_labs = _relabeled_cols(left_cols, shared_labs, suffix[0]) + r_labs = _relabeled_cols(right_cols_no_keys, shared_labs, suffix[1]) + + return l_labs + r_labs + + + +def _relabeled_cols(columns, keys, suffix): + # add a suffix to all columns with names in keys + cols = [] + for k, v in columns.items(): + new_col = v.label(k + str(suffix)) if k in keys else v + cols.append(new_col) + return cols + + +@join.register(LazyTbl) +def _join(left, right, on = None, *args, by = None, how = "inner", sql_on = None): + _raise_if_args(args) + + if on is None and by is not None: + on = by + + # Needs to be on the table, not the select + left_sel = left.last_op.alias() + right_sel = 
right.last_op.alias() + + # handle arguments ---- + on = _validate_join_arg_on(on, sql_on) + how = _validate_join_arg_how(how) + + # for equality join used to combine keys into single column + consolidate_keys = on if sql_on is None else {} + + if how == "right": + # switch joins, since sqlalchemy doesn't have right join arg + # see https://stackoverflow.com/q/11400307/1144523 + left_sel, right_sel = right_sel, left_sel + on = {v:k for k,v in on.items()} + + # create join conditions ---- + bool_clause = _create_join_conds(left_sel, right_sel, on) + + # create join ---- + join = left_sel.join( + right_sel, + onclause = bool_clause, + isouter = how != "inner", + full = how == "full" + ) + + # if right join, set selects back + if how == "right": + left_sel, right_sel = right_sel, left_sel + on = {v:k for k,v in on.items()} + + # note, shared_keys assumes on is a mapping... + # TODO: shared_keys appears to be for when on is not specified, but was unused + #shared_keys = [k for k,v in on.items() if k == v] + labeled_cols = _joined_cols( + left_sel.columns, + right_sel.columns, + on_keys = consolidate_keys, + how = how + ) + + sel = _sql_select(labeled_cols).select_from(join) + return left.append_op(sel, order_by = tuple()) + + +@semi_join.register(LazyTbl) +def _semi_join(left, right = None, on = None, *args, by = None, sql_on = None): + if on is None and by is not None: + on = by + + _raise_if_args(args) + + left_sel = left.last_op.alias() + right_sel = right.last_op.alias() + + # handle arguments ---- + on = _validate_join_arg_on(on, sql_on, left_sel, right_sel) + + # create join conditions ---- + bool_clause = _create_join_conds(left_sel, right_sel, on) + + # create inner join ---- + exists_clause = _sql_select([sql.literal(1)]) \ + .select_from(right_sel) \ + .where(bool_clause) + + # only keep left hand select's columns ---- + sel = _sql_select(left_sel.columns) \ + .select_from(left_sel) \ + .where(sql.exists(exists_clause)) + + return left.append_op(sel, order_by 
= tuple()) + + +@anti_join.register(LazyTbl) +def _anti_join(left, right = None, on = None, *args, by = None, sql_on = None): + if on is None and by is not None: + on = by + + _raise_if_args(args) + + left_sel = left.last_op.alias() + right_sel = right.last_op.alias() + + # handle arguments ---- + on = _validate_join_arg_on(on, sql_on, left, right) + + # create join conditions ---- + bool_clause = _create_join_conds(left_sel, right_sel, on) + + # create inner join ---- + #not_exists = ~sql.exists([1], from_obj = right_sel).where(bool_clause) + exists_clause = _sql_select([sql.literal(1)]) \ + .select_from(right_sel) \ + .where(bool_clause) + + sel = left_sel.select().where(~sql.exists(exists_clause)) + + return left.append_op(sel, order_by = tuple()) + +def _raise_if_args(args): + if len(args): + raise NotImplemented("*args is reserved for future arguments (e.g. suffix)") + +def _validate_join_arg_on(on, sql_on = None, lhs = None, rhs = None): + # handle sql on case + if sql_on is not None: + if on is not None: + raise ValueError("Cannot specify both on and sql_on") + + return sql_on + + # handle general cases + if on is None: + # TODO: currently, we check for lhs and rhs tables to indicate whether + # a verb supports inferring columns. Otherwise, raise an error. + if lhs is not None and rhs is not None: + # TODO: consolidate with duplicate logic in pandas verb code + warnings.warn( + "No on column passed to join. " + "Inferring join columns instead using shared column names." + ) + + on_cols = list(set(lhs.columns.keys()).intersection(set(rhs.columns.keys()))) + + if not on_cols: + raise ValueError( + "No join column specified, or shared column names in join." 
+ ) + + # trivial dict mapping shared names to themselves + warnings.warn("Detected shared columns: %s" % on_cols) + on = dict(zip(on_cols, on_cols)) + + else: + raise NotImplementedError("on arg currently cannot be None (default) for SQL") + elif isinstance(on, str): + on = {on: on} + elif isinstance(on, (list, tuple)): + on = dict(zip(on, on)) + + + if not isinstance(on, Mapping): + raise TypeError("on must be a Mapping (e.g. dict)") + + return on + +def _validate_join_arg_how(how): + how_options = ("inner", "left", "right", "full") + if how not in how_options: + raise ValueError("how argument needs to be one of %s" %how_options) + + return how + +def _create_join_conds(left_sel, right_sel, on): + left_cols = left_sel.columns #lift_inner_cols(left_sel) + right_cols = right_sel.columns #lift_inner_cols(right_sel) + + if callable(on): + # callable, like with sql_on arg + conds = [on(left_cols, right_cols)] + else: + # dict-like of form {left: right} + conds = [] + for l, r in on.items(): + col_expr = left_cols[l] == right_cols[r] + conds.append(col_expr) + + return sql.and_(*conds) + + +# Head ------------------------------------------------------------------------ + +@head.register(LazyTbl) +def _head(__data, n = 5): + sel = __data.last_select + + return __data.append_op(sel.limit(n)) + + +# Rename ---------------------------------------------------------------------- + +@rename.register(LazyTbl) +def _rename(__data, **kwargs): + sel = __data.last_select + columns = lift_inner_cols(sel) + + # old_keys uses dict as ordered set + old_to_new = {simple_varname(v):k for k,v in kwargs.items()} + + if None in old_to_new: + raise KeyError("positional arguments must be simple column, " + "e.g. 
_.colname or _['colname']" + ) + + labs = [c.label(old_to_new[k]) if k in old_to_new else c for k,c in columns.items()] + + new_sel = sel.with_only_columns(labs) + + missing_groups, group_keys = _select_group_renames(old_to_new, __data.group_by) + + return __data.append_op(new_sel, group_by=group_keys) + + +# Distinct -------------------------------------------------------------------- + +@distinct.register(LazyTbl) +def _distinct(__data, *args, _keep_all = False, **kwargs): + if (args or kwargs) and _keep_all: + raise NotImplementedError("Distinct with variables specified in sql requires _keep_all = False") + + inner_sel = mutate(__data, **kwargs).last_select if kwargs else __data.last_select + + # TODO: this is copied from the df distinct version + # cols dict below is used as ordered set + cols = _var_select_simple(args) + cols.update(kwargs) + + # use all columns by default + if not cols: + cols = {k: True for k in lift_inner_cols(inner_sel).keys()} + + final_names = {**{k: True for k in __data.group_by}, **cols} + + if not len(inner_sel._order_by_clause): + # select distinct has to include any columns in the order by clause, + # so can only safely modify existing statement when there's no order by + sel_cols = lift_inner_cols(inner_sel) + distinct_cols = [sel_cols[k] for k in final_names] + sel = inner_sel.with_only_columns(distinct_cols).distinct() + else: + # fallback to cte + cte = inner_sel.alias() + distinct_cols = [cte.columns[k] for k in final_names] + sel = _sql_select(distinct_cols).select_from(cte).distinct() + + return __data.append_op(sel) + + +# if_else --------------------------------------------------------------------- + +@if_else.register(sql.elements.ColumnElement) +def _if_else(cond, true_vals, false_vals): + whens = [(cond, true_vals)] + return sql.case(whens, else_ = false_vals) + + diff --git a/siuba/sql/verbs/group_by.py b/siuba/sql/verbs/group_by.py new file mode 100644 index 00000000..8dcbef7b --- /dev/null +++ 
b/siuba/sql/verbs/group_by.py @@ -0,0 +1,1395 @@ +""" +Implements LazyTbl to represent tables of SQL data, and registers it on verbs. + +This module is responsible for the handling of the "table" side of things, while +translate.py handles translating column operations. + + +""" + +import warnings + +from siuba.dply.verbs import ( + show_query, collect, + simple_varname, + select, + mutate, + transmute, + filter, + arrange, _call_strip_ascending, + summarize, + count, add_count, + group_by, ungroup, + case_when, + join, left_join, right_join, inner_join, semi_join, anti_join, + head, + rename, + distinct, + if_else, + _select_group_renames, + _var_select_simple + ) + +from siuba.dply.tidyselect import VarList, var_select + +from .translate import CustomOverClause, SqlColumn, SqlColumnAgg +from .utils import ( + get_dialect_translator, + _FixedSqlDatabase, + _is_dialect_duckdb, + _sql_select, + _sql_column_collection, + _sql_add_columns, + _sql_with_only_columns, + _sql_simplify_select, + MockConnection +) + +from sqlalchemy import sql +import sqlalchemy +from siuba.siu import Call, Lazy, FunctionLookupError, singledispatch2 +# TODO: currently needed for select, but can we remove pandas? +from pandas import Series +from functools import singledispatch + +from sqlalchemy.sql import schema + +from siuba.dply.across import _require_across, _set_data_context, _eval_with_context + +# TODO: +# - distinct +# - annotate functions using sel.prefix_with("\n/**/\n") ? + + +# Helpers --------------------------------------------------------------------- + +class SqlFunctionLookupError(FunctionLookupError): pass + + +class CallListener: + """Generic listener. Each exit is called on a node's copy.""" + def enter(self, node): + args, kwargs = node.map_subcalls(self.enter) + + return self.exit(node.__class__(node.func, *args, **kwargs)) + + def exit(self, node): + return node + + +class WindowReplacer(CallListener): + """Call tree listener. 
+ + Produces 2 important behaviors via the enter method: + - returns evaluated sql call expression, with labels on all window expressions. + - stores all labeled window expressions via the windows property. + + TODO: could replace with a sqlalchemy transformer + """ + + def __init__(self, columns, group_by, order_by, window_cte = None): + self.columns = columns + self.group_by = group_by + self.order_by = order_by + self.window_cte = window_cte + self.windows = [] + + def exit(self, node): + col_expr = node(self.columns) + + if not isinstance(col_expr, sql.elements.ClauseElement): + return col_expr + + over_clauses = [x for x in self._get_over_clauses(col_expr) if isinstance(x, CustomOverClause)] + + # put groupings and orderings onto custom over clauses + for over in over_clauses: + # TODO: shouldn't mutate these over clauses + group_by = sql.elements.ClauseList( + *[self.columns[name] for name in self.group_by] + ) + order_by = sql.elements.ClauseList( + *_create_order_by_clause(self.columns, *self.order_by) + ) + + over.set_over(group_by, order_by) + + if len(over_clauses) and self.window_cte is not None: + # custom name, or parameters like "%(...)s" may nest and break psycopg2 + # with columns you can set a key to fix this, but it doesn't seem to + # be an option with labels + name = self._get_unique_name('win', lift_inner_cols(self.window_cte)) + label = col_expr.label(name) + + # put into CTE, and return its resulting column, so that subsequent + # operations will refer to the window column on window_cte. 
Note that + # the operations will use the actual column, so may need to use the + # ClauseAdaptor to make it a reference to the label + self.window_cte = _sql_add_columns(self.window_cte, [label]) + win_col = lift_inner_cols(self.window_cte).values()[-1] + self.windows.append(win_col) + + return win_col + + return col_expr + + @staticmethod + def _get_unique_name(prefix, columns): + column_names = set(columns.keys()) + + i = 1 + name = prefix + str(i) + while name in column_names: + i += 1 + name = prefix + str(i) + + + return name + + @staticmethod + def _get_over_clauses(clause): + windows = [] + append_win = lambda col: windows.append(col) + + sql.util.visitors.traverse(clause, {}, {"over": append_win}) + + return windows + + +class SqlLabelReplacer: + """Create a visitor to replace source labels with destination. + + Note that this is meant to be used with sqlalchemy visitors. + """ + + def __init__(self, src_columns, dst_columns): + self.src_columns = src_columns + self.src_labels = set([x for x in src_columns if isinstance(x, sql.elements.Label)]) + self.dst_columns = dst_columns + self.applied = False + + def __call__(self, clause): + return sql.util.visitors.replacement_traverse(clause, {}, self.visit) + + def visit(self, el): + from sqlalchemy.sql.elements import ColumnClause, Label, ClauseElement, TypeClause + from sqlalchemy.sql.schema import Column + + if isinstance(el, TypeClause): + # TODO: for some reason this type throws an error if unguarded + return None + + if isinstance(el, ClauseElement): + if el in self.src_labels: + self.applied = True + return self.dst_columns[el.name] + elif el in self.src_columns: + return self.dst_columns[el.name] + + # TODO: should we create a subquery if the user passed raw text? 
+ #elif isinstance(el, ColumnClause) and not isinstance(el, Column): + # # Raw SQL, which will need a subquery, but not substitution + # if el.key != "*": + # self.applied = True + + return None + + +#def track_call_windows(call, columns, group_by, order_by, window_cte = None): +# listener = WindowReplacer(columns, group_by, order_by, window_cte) +# col = listener.enter(call) +# return col, listener.windows, listener.window_cte + + +def track_call_windows(call, columns, group_by, order_by, window_cte = None): + col_expr = call(columns) + + crnt_group_by = sql.elements.ClauseList( + *[columns[name] for name in group_by] + ) + crnt_order_by = sql.elements.ClauseList( + *_create_order_by_clause(columns, *order_by) + ) + return replace_call_windows(col_expr, crnt_group_by, crnt_order_by, window_cte) + + + +@singledispatch +def replace_call_windows(col_expr, group_by, order_by, window_cte = None): + raise TypeError(str(type(col_expr))) + + +@replace_call_windows.register(sql.base.ImmutableColumnCollection) +def _(col_expr, group_by, order_by, window_cte = None): + all_over_clauses = [] + for col in col_expr: + _, over_clauses, window_cte = replace_call_windows( + col, + group_by, + order_by, + window_cte + ) + all_over_clauses.extend(over_clauses) + + return col_expr, all_over_clauses, window_cte + + +@replace_call_windows.register(sql.elements.ClauseElement) +def _(col_expr, group_by, order_by, window_cte = None): + + over_clauses = WindowReplacer._get_over_clauses(col_expr) + + for over in over_clauses: + # TODO: shouldn't mutate these over clauses + over.set_over(group_by, order_by) + + if len(over_clauses) and window_cte is not None: + # custom name, or parameters like "%(...)s" may nest and break psycopg2 + # with columns you can set a key to fix this, but it doesn't seem to + # be an option with labels + name = WindowReplacer._get_unique_name('win', lift_inner_cols(window_cte)) + label = col_expr.label(name) + + # put into CTE, and return its resulting column, so 
that subsequent + # operations will refer to the window column on window_cte. Note that + # the operations will use the actual column, so may need to use the + # ClauseAdaptor to make it a reference to the label + window_cte = _sql_add_columns(window_cte, [label]) + win_col = lift_inner_cols(window_cte).values()[-1] + + return win_col, over_clauses, window_cte + + return col_expr, over_clauses, window_cte + +def get_single_from(sel): + froms = sel.froms + + n_froms = len(froms) + if n_froms != 1: + raise ValueError( + f"Expected a single table in the from clause, but found {n_froms}" + ) + + return froms[0] + +def lift_inner_cols(tbl): + cols = list(tbl.inner_columns) + + return _sql_column_collection(cols) + +def col_expr_requires_cte(call, sel, is_mutate = False): + """Return whether a variable assignment needs a CTE""" + + call_vars = set(call.op_vars(attr_calls = False)) + + sel_labs = get_inner_labels(sel) + + # I use the acronym fwg sol (frog soul) to remember sql clause eval order + # from, where, group by, select, order by, limit + # group clause evaluated before select clause, so not issue for mutate + group_needs_cte = not is_mutate and len(sel._group_by_clause) + + return ( group_needs_cte + # TODO: detect when a new var in mutate conflicts w/ order by + #or len(sel._order_by_clause) + or not sel_labs.isdisjoint(call_vars) + ) + +def get_inner_labels(sel): + columns = lift_inner_cols(sel) + sel_labs = set(k for k,v in columns.items() if isinstance(v, sql.elements.Label)) + return sel_labs + +def get_missing_columns(call, columns): + missing_cols = set(call.op_vars(attr_calls = False)) - set(columns.keys()) + return missing_cols + +def compile_el(tbl, el): + compiled = el.compile( + dialect = tbl.source.dialect, + compile_kwargs = {"literal_binds": True} + ) + return compiled + +# Misc utilities -------------------------------------------------------------- + +def ordered_union(x, y): + dx = {el: True for el in x} + dy = {el: True for el in y} + + return 
tuple({**dx, **dy}) + + +def _warn_missing(missing_groups): + warnings.warn(f"Adding missing grouping variables: {missing_groups}") + + +# Table ----------------------------------------------------------------------- + +class LazyTbl: + def __init__( + self, source, tbl, columns = None, + ops = None, group_by = tuple(), order_by = tuple(), + translator = None + ): + """Create a representation of a SQL table. + + Args: + source: a sqlalchemy.Engine or sqlalchemy.Connection instance. + tbl: table of form 'schema_name.table_name', 'table_name', or sqlalchemy.Table. + columns: if specified, a listlike of column names. + + Examples + -------- + + :: + from sqlalchemy import create_engine + from siuba.data import mtcars + + # create database and table + engine = create_engine("sqlite:///:memory:") + mtcars.to_sql('mtcars', engine) + + tbl_mtcars = LazyTbl(engine, 'mtcars') + + """ + + # connection and dialect specific functions + self.source = sqlalchemy.create_engine(source) if isinstance(source, str) else source + + # get dialect name + dialect = self.source.dialect.name + self.translator = get_dialect_translator(dialect) + + self.tbl = self._create_table(tbl, columns, self.source) + + # important states the query can be in (e.g. 
grouped) + self.ops = [self.tbl] if ops is None else ops + + self.group_by = group_by + self.order_by = order_by + + + def append_op(self, op, **kwargs): + cpy = self.copy(**kwargs) + cpy.ops = cpy.ops + [op] + return cpy + + def copy(self, **kwargs): + return self.__class__(**{**self.__dict__, **kwargs}) + + def shape_call( + self, + call, window = True, str_accessors = False, + verb_name = None, arg_name = None, + ): + return self.translator.shape_call(call, window, str_accessors, verb_name, arg_name) + + def track_call_windows(self, call, columns = None, window_cte = None): + """Returns tuple of (new column expression, list of window exprs)""" + + columns = self.last_op.columns if columns is None else columns + return track_call_windows(call, columns, self.group_by, self.order_by, window_cte) + + def get_ordered_col_names(self): + """Return columns from current select, with grouping columns first.""" + ungrouped = [k for k in self.last_op.columns.keys() if k not in self.group_by] + return list(self.group_by) + ungrouped + + #def label_breaks_order_by(self, name): + # """Returns True if a new column label would break the order by vars.""" + + # # TODO: arrange currently allows literals, which breaks this. it seems + # # better to only allow calls in arrange. + # order_by_vars = {c.op_vars(attr_calls=False) for c in self.order_by} + + + + + @property + def last_op(self) -> "sql.Table | sql.Select": + last_op = self.ops[-1] + + if last_op is None: + raise TypeError() + + return last_op + + @property + def last_select(self): + last_op = self.last_op + if not isinstance(last_op, sql.selectable.SelectBase): + return last_op.select() + + return last_op + + @staticmethod + def _create_table(tbl, columns = None, source = None): + """Return a sqlalchemy.Table, autoloading column info if needed. + + Arguments: + tbl: a sqlalchemy.Table or string of form 'table_name' or 'schema_name.table_name'. + columns: a tuple of column names for the table. Overrides source argument. 
+ source: a sqlalchemy engine, used to autoload columns. + + """ + if isinstance(tbl, sql.selectable.FromClause): + return tbl + + if not isinstance(tbl, str): + raise ValueError("tbl must be a sqlalchemy Table or string, but was %s" %type(tbl)) + + if columns is None and source is None: + raise ValueError("One of columns or source must be specified") + + schema, table_name = tbl.split('.') if '.' in tbl else [None, tbl] + + columns = map(sqlalchemy.Column, columns) if columns is not None else tuple() + + # TODO: pybigquery uses schema to mean project_id, so we cannot use + # siuba's classic breakdown "{schema}.{table_name}". Basically + # pybigquery uses "{schema=project_id}.{dataset_dot_table_name}" in its internal + # logic. An important side effect is that bigquery errors for + # `dataset`.`table`, but not `dataset.table`. + if source and source.dialect.name == "bigquery": + table_name = tbl + schema = None + + return sqlalchemy.Table( + table_name, + sqlalchemy.MetaData(bind = source), + *columns, + schema = schema, + autoload_with = source if not columns else None + ) + + def _get_preview(self): + # need to make prev op a cte, so we don't override any previous limit + new_sel = self.last_select.limit(5) + tbl_small = self.append_op(new_sel) + return collect(tbl_small) + + def __repr__(self): + template = ( + "# Source: lazy query\n" + "# DB Conn: {}\n" + "# Preview:\n{}\n" + "# .. may have more rows" + ) + + return template.format(repr(self.source.engine), repr(self._get_preview())) + + def _repr_html_(self): + template = ( + "
" + "
"
+                "# Source: lazy query\n"
+                "# DB Conn: {}\n"
+                "# Preview:\n"
+                "
" + "{}" + "

# .. may have more rows

" + "
" + ) + + data = self._get_preview() + + # _repr_html_ can not exist or return None, to signify that repr should be used + if not hasattr(data, '_repr_html_'): + return None + + html_data = data._repr_html_() + if html_data is None: + return None + + return template.format(self.source.engine, html_data) + + +def _repr_grouped_df_html_(self): + return "

(grouped data frame)

" + self._selected_obj._repr_html_() + "
" + + + +# Main Funcs +# ============================================================================= + +# sql raw -------------- + +sql_raw = sql.literal_column + +# show query ----------- + +@show_query.register(LazyTbl) +def _show_query(tbl, simplify = False, return_table = True): + #query = tbl.last_op #if not simplify else + compile_query = lambda query: query.compile( + dialect = tbl.source.dialect, + compile_kwargs = {"literal_binds": True} + ) + + + if simplify: + # try to strip table names and labels where unnecessary + simple_sel = _sql_simplify_select(tbl.last_select) + + explained = compile_query(simple_sel) + else: + # use a much more verbose query + explained = compile_query(tbl.last_select) + + if return_table: + print(str(explained)) + return tbl + + return str(explained) + + + +# collect ---------- + +@collect.register(LazyTbl) +def _collect(__data, as_df = True): + # TODO: maybe remove as_df options, always return dataframe + + if isinstance(__data.source, MockConnection): + # a mock sqlalchemy is being used to show_query, and echo queries. + # it doesn't return a result object or have a context handler, so + # we need to bail out early + return + + # compile query ---- + + if _is_dialect_duckdb(__data.source): + # TODO: can be removed once next release of duckdb fixes: + # https://github.com/duckdb/duckdb/issues/2972 + query = __data.last_select + compiled = query.compile( + dialect = __data.source.dialect, + compile_kwargs = {"literal_binds": True} + ) + else: + compiled = __data.last_select + + # execute query ---- + + with __data.source.connect() as conn: + if as_df: + sql_db = _FixedSqlDatabase(conn) + + if _is_dialect_duckdb(__data.source): + # TODO: pandas read_sql is very slow with duckdb. + # see https://github.com/pandas-dev/pandas/issues/45678 + # going to handle here for now. address once LazyTbl gets + # subclassed per backend. 
+ duckdb_con = conn.connection.c + return duckdb_con.query(str(compiled)).to_df() + else: + # + return sql_db.read_sql(compiled) + + return conn.execute(compiled) + + +@select.register(LazyTbl) +def _select(__data, *args, **kwargs): + # see https://stackoverflow.com/questions/25914329/rearrange-columns-in-sqlalchemy-select-object + if kwargs: + raise NotImplementedError( + "Using kwargs in select not currently supported. " + "Use _.newname == _.oldname instead" + ) + last_sel = __data.last_select + columns = {c.key: c for c in last_sel.inner_columns} + + # same as for DataFrame + colnames = Series(list(columns)) + vl = VarList() + evaluated = (arg(vl) if callable(arg) else arg for arg in args) + od = var_select(colnames, *evaluated) + + missing_groups, group_keys = _select_group_renames(od, __data.group_by) + + if missing_groups: + _warn_missing(missing_groups) + + final_od = {**{k: None for k in missing_groups}, **od} + + col_list = [] + for k,v in final_od.items(): + col = columns[k] + col_list.append(col if v is None else col.label(v)) + + return __data.append_op( + last_sel.with_only_columns(col_list), + group_by = group_keys + ) + + + +@filter.register(LazyTbl) +def _filter(__data, *args): + # Note: currently always produces 2 additional select statements, + # 1 for window/aggs, and 1 for the where clause + + sel = __data.last_op.alias() # original select + win_sel = sel.select() + + conds = [] + windows = [] + with _set_data_context(__data, window=True): + for ii, arg in enumerate(args): + + if isinstance(arg, Call): + new_call = __data.shape_call(arg, verb_name = "Filter", arg_name = ii) + #var_cols = new_call.op_vars(attr_calls = False) + + # note that a new win_sel is returned, w/ window columns appended + col_expr, win_cols, win_sel = __data.track_call_windows( + new_call, + sel.columns, + window_cte = win_sel + ) + + if isinstance(col_expr, sql.base.ImmutableColumnCollection): + conds.extend(col_expr) + else: + conds.append(col_expr) + + 
windows.extend(win_cols) + + else: + conds.append(arg) + + bool_clause = sql.and_(*conds) + + # first cte, windows ---- + if len(windows): + + win_alias = win_sel.alias() + + # move non-window functions to refer to win_sel clause (not the innermost) --- + bool_clause = sql.util.ClauseAdapter(win_alias) \ + .traverse(bool_clause) + + orig_cols = [win_alias.columns[k] for k in sel.columns.keys()] + else: + orig_cols = [sel] + + # create second cte ---- + filt_sel = _sql_select(orig_cols).where(bool_clause) + return __data.append_op(filt_sel) + + +@mutate.register(LazyTbl) +def _mutate(__data, *args, **kwargs): + # TODO: verify it can follow a renaming select + + # track labeled columns in set + if not (len(args) or len(kwargs)): + return __data.append_op(__data.last_op) + + names, sel_out = _mutate_cols(__data, args, kwargs, "Mutate") + return __data.append_op(sel_out) + + +def _sql_upsert_columns(sel, new_columns: "list[base.Label | base.Column]"): + orig_cols = lift_inner_cols(sel) + replaced = {**orig_cols} + + for new_col in new_columns: + replaced[new_col.name] = new_col + return _sql_with_only_columns(sel, list(replaced.values())) + + +def _select_mutate_result(src_sel, expr_result): + dst_alias = src_sel.alias() + src_columns = set(lift_inner_cols(src_sel)) + replacer = SqlLabelReplacer(set(src_columns), dst_alias.columns) + + if isinstance(expr_result, sql.base.ImmutableColumnCollection): + replaced_cols = list(map(replacer, expr_result)) + orig_cols = expr_result + #elif isinstance(expr_result, None): + # pass + else: + replaced_cols = [replacer(expr_result)] + orig_cols = [expr_result] + + if replacer.applied: + return _sql_upsert_columns(dst_alias.select(), replaced_cols) + + return _sql_upsert_columns(src_sel, orig_cols) + + +def _eval_expr_arg(__data, sel, func, verb_name, window=True): + inner_cols = lift_inner_cols(sel) + + # case 1: simple names ---- + simple_name = simple_varname(func) + if simple_name is not None: + return inner_cols[simple_name] + 
+ # case 2: across ---- + _require_across(func, verb_name) + + cols_result = _eval_with_context(__data, window, inner_cols, func) + + # TODO: remove or raise a more informative error + assert isinstance(cols_result, sql.base.ImmutableColumnCollection), type(cols_result) + + return cols_result + + +def _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=True): + inner_cols = lift_inner_cols(sel) + + expr_shaped = __data.shape_call(func, window, verb_name = verb_name, arg_name = new_name) + new_col, windows, _ = __data.track_call_windows(expr_shaped, inner_cols) + + if isinstance(new_col, sql.base.ImmutableColumnCollection): + raise TypeError( + f"{verb_name} named arguments must return a single column, but `{new_name}` " + "returned multiple columns." + ) + + return new_col.label(new_name) + + +def _mutate_cols(__data, args, kwargs, verb_name): + result_names = {} # used as ordered set + sel = __data.last_select + + for ii, func in enumerate(args): + cols_result = _eval_expr_arg(__data, sel, func, verb_name) + + # replace any labels that require a subquery ---- + sel = _select_mutate_result(sel, cols_result) + + if isinstance(cols_result, sql.base.ImmutableColumnCollection): + result_names.update({k: True for k in cols_result.keys()}) + else: + result_names[cols_result.name] = True + + + for new_name, func in kwargs.items(): + labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name) + + sel = _select_mutate_result(sel, labeled) + result_names[new_name] = True + + + return list(result_names), sel + + +@transmute.register(LazyTbl) +def _transmute(__data, *args, **kwargs): + # will use mutate, then select some cols + result_names, sel = _mutate_cols(__data, args, kwargs, "Transmute") + + # transmute keeps grouping cols, and any defined in kwargs + missing = [x for x in __data.group_by if x not in result_names] + cols_to_keep = [*missing, *result_names] + + columns = lift_inner_cols(sel) + sel_stripped = sel.with_only_columns([columns[k] for k in
cols_to_keep]) + + return __data.append_op(sel_stripped) + + +@arrange.register(LazyTbl) +def _arrange(__data, *args): + # Note that SQL databases often do not subquery order by clauses. Arrange + # sets order_by on the backend, so it can set order by in over elements, + # and handle when new columns are named the same as order by vars. + # see: https://dba.stackexchange.com/q/82930 + + last_sel = __data.last_select + cols = lift_inner_cols(last_sel) + + # TODO: implement across in arrange + #exprs, _ = _mutate_cols(__data, args, kwargs, "Arrange", arrange_clause=True) + + new_calls = [] + for ii, expr in enumerate(args): + if callable(expr): + + res = __data.shape_call( + expr, window = False, + verb_name = "Arrange", arg_name = ii + ) + + else: + res = expr + + new_calls.append(res) + + sort_cols = _create_order_by_clause(cols, *new_calls) + + order_by = __data.order_by + tuple(new_calls) + return __data.append_op(last_sel.order_by(*sort_cols), order_by = order_by) + + +# TODO: consolidate / pull expr handling funcs into own file? +def _create_order_by_clause(columns, *args): + sort_cols = [] + for arg in args: + # simple named column + if isinstance(arg, str): + sort_cols.append(columns[arg]) + # an expression + elif callable(arg): + # handle special case where -_.colname -> colname DESC + f, asc = _call_strip_ascending(arg) + col_op = f(columns) if asc else f(columns).desc() + #col_op = arg(columns) + sort_cols.append(col_op) + else: + raise NotImplementedError("Must be string or callable") + + return sort_cols + + + +@count.register(LazyTbl) +def _count(__data, *args, sort = False, wt = None, **kwargs): + # TODO: if already col named n, use name nn, etc.. 
get logic from tidy.py + if wt is not None: + raise NotImplementedError("TODO") + + res_name = "n" + # similar to filter verb, we need two select statements, + # an inner one for derived cols, and outer to group by them + + # inner select ---- + # holds any mutation style columns + #arg_names = [] + #for arg in args: + # name = simple_varname(arg) + # if name is None: + # raise NotImplementedError( + # "Count positional arguments must be single column name. " + # "Use a named argument to count using complex expressions." + # ) + # arg_names.append(name) + + result_names, sel_inner = _mutate_cols(__data, args, kwargs, "Count") + + # remove unnecessary select, if we're operating on a table ---- + if set(lift_inner_cols(sel_inner)) == set(lift_inner_cols(__data.last_select)): + sel_inner = __data.last_op + + # create outer select ---- + # holds selected columns and tally (n) + sel_inner_cte = sel_inner.alias() + inner_cols = sel_inner_cte.columns + + # apply any group vars from a group_by verb call first + missing = [k for k in __data.group_by if k not in result_names] + + all_group_names = ordered_union(__data.group_by, result_names) + outer_group_cols = [inner_cols[k] for k in all_group_names] + + # holds the actual count (e.g. 
n) + count_col = sql.functions.count().label(res_name) + + sel_outer = _sql_select([*outer_group_cols, count_col]) \ + .select_from(sel_inner_cte) \ + .group_by(*outer_group_cols) + + # count is like summarize, so removes order_by + return __data.append_op( + sel_outer.order_by(count_col.desc()), + order_by = tuple() + ) + + +@add_count.register(LazyTbl) +def _add_count(__data, *args, wt = None, sort = False, **kwargs): + counts = count(__data, *args, wt = wt, sort = sort, **kwargs) + by = list(c.name for c in counts.last_select.inner_columns)[:-1] + + return inner_join(__data, counts, by = by) + + +@summarize.register(LazyTbl) +def _summarize(__data, *args, **kwargs): + # https://stackoverflow.com/questions/14754994/why-is-sqlalchemy-count-much-slower-than-the-raw-query + + # get query with correct from clause, and maybe unneeded subquery + safe_from = __data.last_select.alias() + result_names, sel = _aggregate_cols(__data, safe_from, args, kwargs, "Summarize") + + # see if we can remove subquery + out_sel = _collapse_select(sel, safe_from) + + from_tbl = get_single_from(out_sel) + group_cols = [from_tbl.columns[k] for k in __data.group_by] + + final_sel = out_sel.group_by(*group_cols) + + new_data = __data.append_op(final_sel, group_by = tuple(), order_by = tuple()) + return new_data + + +def _collapse_select(outer_sel, inner_alias): + # check whether any outer columns reference an inner label ---- + inner_sel = inner_alias.element + + columns = lift_inner_cols(outer_sel) + inner_cols = lift_inner_cols(inner_sel) + + inner_labels = set([ + x.name for x in inner_cols + if isinstance(x, sql.elements.Label) + ]) + + col_requires_cte = set(inner_alias.columns[k] for k in inner_labels) + + bad_refs = [] + + def collect_refs(el): + if el in col_requires_cte: + bad_refs.append(el) + + for col in columns: + sql.util.visitors.traverse(col, {}, {"column": collect_refs}) + + # if possible, remove the outer query ---- + if not (bad_refs or len(inner_sel._group_by_clause)): + 
from sqlalchemy.sql.elements import ColumnClause, Label + + from_obj = get_single_from(inner_sel) + adaptor = sql.util.ClauseAdapter( + from_obj, + adapt_on_names=True, + include_fn=lambda c: isinstance(c, (ColumnClause, Label)) + ) + + new_cols = [] + for col in columns: + if isinstance(col, Label): + res = adaptor.traverse(col.element).label(col.name) + new_cols.append(res) + + else: + new_cols.append(adaptor.traverse(col)) + #new_cols = list(map(adaptor.traverse, columns)) + + return _sql_with_only_columns(inner_sel, new_cols) + + return outer_sel + + +def _aggregate_cols(__data, subquery, args, kwargs, verb_name): + # cases: + # * grouping cols can not be overwritten (in dbplyr they can't be ref'd) + # * no existing labels referred to - can use same select + # * existing labels referred to - need 1 subquery tops + # * groups + summarize columns can replace everything + + def get_label_clauses(clause): + out = [] + sql.util.visitors.traverse(clause, {}, {"label": lambda c: out.append(c)}) + + return out + + def quote_varname(x): + return f"`{x}`" + + def validate_references(arg_name, expr, verb_name): + bad_varnames = get_label_clauses(expr) + repr_names = ", ".join(map(quote_varname, [el.name for el in bad_varnames])) + + if not bad_varnames: + return + + raise NotImplementedError( + f"In SQL, you cannot refer to a column created in the same {verb_name}. " + f"`{arg_name}` refers to columns created earlier: {repr_names}." 
+ ) + + sel = subquery.select() + + final_cols = {k: subquery.columns[k] for k in __data.group_by} + + # handle args ---- + for ii, func in enumerate(args): + cols_result = _eval_expr_arg(__data, sel, func, verb_name, window=False) + + for col in cols_result: + validate_references(col.name, col.element, verb_name) + final_cols[col.name] = col + + sel = _sql_upsert_columns(sel, cols_result) + + + # handle kwargs ---- + for new_name, func in kwargs.items(): + labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=False) + + validate_references(labeled.name, labeled.element, verb_name) + final_cols[new_name] = labeled + + sel = _sql_upsert_columns(sel, [labeled]) + + return list(final_cols), _sql_with_only_columns(sel, list(final_cols.values())) + + +@group_by.register(LazyTbl) +def _group_by(__data, *args, add = False, **kwargs): + if not (args or kwargs): + return __data.copy() + + group_names, sel = _mutate_cols(__data, args, kwargs, "Group by") + + if None in group_names: + raise NotImplementedError("Complex, unnamed expressions not supported in sql group_by") + + # check whether we can just use underlying table ---- + new_cols = lift_inner_cols(sel) + if set(new_cols).issubset(set(__data.last_op.columns)): + sel = __data.last_op + + if add: + group_names = ordered_union(__data.group_by, group_names) + + return __data.append_op(sel, group_by = tuple(group_names)) + + +@ungroup.register(LazyTbl) +def _ungroup(__data): + return __data.copy(group_by = tuple()) + + +@case_when.register(sql.base.ImmutableColumnCollection) +def _case_when(__data, cases): + # TODO: will need listener to enter case statements, to handle when they use windows + if isinstance(cases, Call): + cases = cases(__data) + + whens = [] + case_items = list(cases.items()) + n_items = len(case_items) + + else_val = None + for ii, (expr, val) in enumerate(case_items): + # handle where val is a column expr + if callable(val): + val = val(__data) + + # handle when expressions + if ii+1 
== n_items and expr is True: + else_val = val + elif callable(expr): + whens.append((expr(__data), val)) + else: + whens.append((expr, val)) + + return sql.case(whens, else_ = else_val) + + +# Join ------------------------------------------------------------------------ + +from collections.abc import Mapping + +def _joined_cols(left_cols, right_cols, on_keys, how, suffix = ("_x", "_y")): + """Return labeled columns, according to selection rules for joins. + + Rules: + 1. For join keys, keep left table's column + 2. When keys have the same labels, add suffix + """ + + # TODO: remove sets, so uses stable ordering + # when left and right cols have same name, suffix with _x / _y + keep_right = set(right_cols.keys()) - set(on_keys.values()) + shared_labs = set(left_cols.keys()).intersection(keep_right) + + right_cols_no_keys = {k: right_cols[k] for k in keep_right} + + # for an outer join, have key columns coalesce values + + left_cols = {**left_cols} + if how == "full": + for lk, rk in on_keys.items(): + col = sql.functions.coalesce(left_cols[lk], right_cols[rk]) + left_cols[lk] = col.label(lk) + elif how == "right": + for lk, rk in on_keys.items(): + # Make left key columns actually be right ones (which contain left + extra) + left_cols[lk] = right_cols[rk].label(lk) + + + # create labels ---- + l_labs = _relabeled_cols(left_cols, shared_labs, suffix[0]) + r_labs = _relabeled_cols(right_cols_no_keys, shared_labs, suffix[1]) + + return l_labs + r_labs + + + +def _relabeled_cols(columns, keys, suffix): + # add a suffix to all columns with names in keys + cols = [] + for k, v in columns.items(): + new_col = v.label(k + str(suffix)) if k in keys else v + cols.append(new_col) + return cols + + +@join.register(LazyTbl) +def _join(left, right, on = None, *args, by = None, how = "inner", sql_on = None): + _raise_if_args(args) + + if on is None and by is not None: + on = by + + # Needs to be on the table, not the select + left_sel = left.last_op.alias() + right_sel = 
right.last_op.alias() + + # handle arguments ---- + on = _validate_join_arg_on(on, sql_on) + how = _validate_join_arg_how(how) + + # for equality join used to combine keys into single column + consolidate_keys = on if sql_on is None else {} + + if how == "right": + # switch joins, since sqlalchemy doesn't have right join arg + # see https://stackoverflow.com/q/11400307/1144523 + left_sel, right_sel = right_sel, left_sel + on = {v:k for k,v in on.items()} + + # create join conditions ---- + bool_clause = _create_join_conds(left_sel, right_sel, on) + + # create join ---- + join = left_sel.join( + right_sel, + onclause = bool_clause, + isouter = how != "inner", + full = how == "full" + ) + + # if right join, set selects back + if how == "right": + left_sel, right_sel = right_sel, left_sel + on = {v:k for k,v in on.items()} + + # note, shared_keys assumes on is a mapping... + # TODO: shared_keys appears to be for when on is not specified, but was unused + #shared_keys = [k for k,v in on.items() if k == v] + labeled_cols = _joined_cols( + left_sel.columns, + right_sel.columns, + on_keys = consolidate_keys, + how = how + ) + + sel = _sql_select(labeled_cols).select_from(join) + return left.append_op(sel, order_by = tuple()) + + +@semi_join.register(LazyTbl) +def _semi_join(left, right = None, on = None, *args, by = None, sql_on = None): + if on is None and by is not None: + on = by + + _raise_if_args(args) + + left_sel = left.last_op.alias() + right_sel = right.last_op.alias() + + # handle arguments ---- + on = _validate_join_arg_on(on, sql_on, left_sel, right_sel) + + # create join conditions ---- + bool_clause = _create_join_conds(left_sel, right_sel, on) + + # create inner join ---- + exists_clause = _sql_select([sql.literal(1)]) \ + .select_from(right_sel) \ + .where(bool_clause) + + # only keep left hand select's columns ---- + sel = _sql_select(left_sel.columns) \ + .select_from(left_sel) \ + .where(sql.exists(exists_clause)) + + return left.append_op(sel, order_by 
= tuple()) + + +@anti_join.register(LazyTbl) +def _anti_join(left, right = None, on = None, *args, by = None, sql_on = None): + if on is None and by is not None: + on = by + + _raise_if_args(args) + + left_sel = left.last_op.alias() + right_sel = right.last_op.alias() + + # handle arguments ---- + on = _validate_join_arg_on(on, sql_on, left_sel, right_sel) + + # create join conditions ---- + bool_clause = _create_join_conds(left_sel, right_sel, on) + + # create inner join ---- + #not_exists = ~sql.exists([1], from_obj = right_sel).where(bool_clause) + exists_clause = _sql_select([sql.literal(1)]) \ + .select_from(right_sel) \ + .where(bool_clause) + + sel = left_sel.select().where(~sql.exists(exists_clause)) + + return left.append_op(sel, order_by = tuple()) + +def _raise_if_args(args): + if len(args): + raise NotImplementedError("*args is reserved for future arguments (e.g. suffix)") + +def _validate_join_arg_on(on, sql_on = None, lhs = None, rhs = None): + # handle sql on case + if sql_on is not None: + if on is not None: + raise ValueError("Cannot specify both on and sql_on") + + return sql_on + + # handle general cases + if on is None: + # TODO: currently, we check for lhs and rhs tables to indicate whether + # a verb supports inferring columns. Otherwise, raise an error. + if lhs is not None and rhs is not None: + # TODO: consolidate with duplicate logic in pandas verb code + warnings.warn( + "No on column passed to join. " + "Inferring join columns instead using shared column names." + ) + + on_cols = list(set(lhs.columns.keys()).intersection(set(rhs.columns.keys()))) + + if not on_cols: + raise ValueError( + "No join column specified, or shared column names in join."
+ ) + + # trivial dict mapping shared names to themselves + warnings.warn("Detected shared columns: %s" % on_cols) + on = dict(zip(on_cols, on_cols)) + + else: + raise NotImplementedError("on arg currently cannot be None (default) for SQL") + elif isinstance(on, str): + on = {on: on} + elif isinstance(on, (list, tuple)): + on = dict(zip(on, on)) + + + if not isinstance(on, Mapping): + raise TypeError("on must be a Mapping (e.g. dict)") + + return on + +def _validate_join_arg_how(how): + how_options = ("inner", "left", "right", "full") + if how not in how_options: + raise ValueError("how argument needs to be one of %s" %how_options) + + return how + +def _create_join_conds(left_sel, right_sel, on): + left_cols = left_sel.columns #lift_inner_cols(left_sel) + right_cols = right_sel.columns #lift_inner_cols(right_sel) + + if callable(on): + # callable, like with sql_on arg + conds = [on(left_cols, right_cols)] + else: + # dict-like of form {left: right} + conds = [] + for l, r in on.items(): + col_expr = left_cols[l] == right_cols[r] + conds.append(col_expr) + + return sql.and_(*conds) + + +# Head ------------------------------------------------------------------------ + +@head.register(LazyTbl) +def _head(__data, n = 5): + sel = __data.last_select + + return __data.append_op(sel.limit(n)) + + +# Rename ---------------------------------------------------------------------- + +@rename.register(LazyTbl) +def _rename(__data, **kwargs): + sel = __data.last_select + columns = lift_inner_cols(sel) + + # old_keys uses dict as ordered set + old_to_new = {simple_varname(v):k for k,v in kwargs.items()} + + if None in old_to_new: + raise KeyError("positional arguments must be simple column, " + "e.g. 
_.colname or _['colname']" + ) + + labs = [c.label(old_to_new[k]) if k in old_to_new else c for k,c in columns.items()] + + new_sel = sel.with_only_columns(labs) + + missing_groups, group_keys = _select_group_renames(old_to_new, __data.group_by) + + return __data.append_op(new_sel, group_by=group_keys) + + +# Distinct -------------------------------------------------------------------- + +@distinct.register(LazyTbl) +def _distinct(__data, *args, _keep_all = False, **kwargs): + if (args or kwargs) and _keep_all: + raise NotImplementedError("Distinct with variables specified in sql requires _keep_all = False") + + inner_sel = mutate(__data, **kwargs).last_select if kwargs else __data.last_select + + # TODO: this is copied from the df distinct version + # cols dict below is used as ordered set + cols = _var_select_simple(args) + cols.update(kwargs) + + # use all columns by default + if not cols: + cols = {k: True for k in lift_inner_cols(inner_sel).keys()} + + final_names = {**{k: True for k in __data.group_by}, **cols} + + if not len(inner_sel._order_by_clause): + # select distinct has to include any columns in the order by clause, + # so can only safely modify existing statement when there's no order by + sel_cols = lift_inner_cols(inner_sel) + distinct_cols = [sel_cols[k] for k in final_names] + sel = inner_sel.with_only_columns(distinct_cols).distinct() + else: + # fallback to cte + cte = inner_sel.alias() + distinct_cols = [cte.columns[k] for k in final_names] + sel = _sql_select(distinct_cols).select_from(cte).distinct() + + return __data.append_op(sel) + + +# if_else --------------------------------------------------------------------- + +@if_else.register(sql.elements.ColumnElement) +def _if_else(cond, true_vals, false_vals): + whens = [(cond, true_vals)] + return sql.case(whens, else_ = false_vals) + + diff --git a/siuba/sql/verbs/head.py b/siuba/sql/verbs/head.py new file mode 100644 index 00000000..8dcbef7b --- /dev/null +++ b/siuba/sql/verbs/head.py @@ 
-0,0 +1,1395 @@ +""" +Implements LazyTbl to represent tables of SQL data, and registers it on verbs. + +This module is responsible for the handling of the "table" side of things, while +translate.py handles translating column operations. + + +""" + +import warnings + +from siuba.dply.verbs import ( + show_query, collect, + simple_varname, + select, + mutate, + transmute, + filter, + arrange, _call_strip_ascending, + summarize, + count, add_count, + group_by, ungroup, + case_when, + join, left_join, right_join, inner_join, semi_join, anti_join, + head, + rename, + distinct, + if_else, + _select_group_renames, + _var_select_simple + ) + +from siuba.dply.tidyselect import VarList, var_select + +from .translate import CustomOverClause, SqlColumn, SqlColumnAgg +from .utils import ( + get_dialect_translator, + _FixedSqlDatabase, + _is_dialect_duckdb, + _sql_select, + _sql_column_collection, + _sql_add_columns, + _sql_with_only_columns, + _sql_simplify_select, + MockConnection +) + +from sqlalchemy import sql +import sqlalchemy +from siuba.siu import Call, Lazy, FunctionLookupError, singledispatch2 +# TODO: currently needed for select, but can we remove pandas? +from pandas import Series +from functools import singledispatch + +from sqlalchemy.sql import schema + +from siuba.dply.across import _require_across, _set_data_context, _eval_with_context + +# TODO: +# - distinct +# - annotate functions using sel.prefix_with("\n/**/\n") ? + + +# Helpers --------------------------------------------------------------------- + +class SqlFunctionLookupError(FunctionLookupError): pass + + +class CallListener: + """Generic listener. Each exit is called on a node's copy.""" + def enter(self, node): + args, kwargs = node.map_subcalls(self.enter) + + return self.exit(node.__class__(node.func, *args, **kwargs)) + + def exit(self, node): + return node + + +class WindowReplacer(CallListener): + """Call tree listener. 
+ + Produces 2 important behaviors via the enter method: + - returns evaluated sql call expression, with labels on all window expressions. + - stores all labeled window expressions via the windows property. + + TODO: could replace with a sqlalchemy transformer + """ + + def __init__(self, columns, group_by, order_by, window_cte = None): + self.columns = columns + self.group_by = group_by + self.order_by = order_by + self.window_cte = window_cte + self.windows = [] + + def exit(self, node): + col_expr = node(self.columns) + + if not isinstance(col_expr, sql.elements.ClauseElement): + return col_expr + + over_clauses = [x for x in self._get_over_clauses(col_expr) if isinstance(x, CustomOverClause)] + + # put groupings and orderings onto custom over clauses + for over in over_clauses: + # TODO: shouldn't mutate these over clauses + group_by = sql.elements.ClauseList( + *[self.columns[name] for name in self.group_by] + ) + order_by = sql.elements.ClauseList( + *_create_order_by_clause(self.columns, *self.order_by) + ) + + over.set_over(group_by, order_by) + + if len(over_clauses) and self.window_cte is not None: + # custom name, or parameters like "%(...)s" may nest and break psycopg2 + # with columns you can set a key to fix this, but it doesn't seem to + # be an option with labels + name = self._get_unique_name('win', lift_inner_cols(self.window_cte)) + label = col_expr.label(name) + + # put into CTE, and return its resulting column, so that subsequent + # operations will refer to the window column on window_cte. 
class SqlLabelReplacer:
    """Callable that swaps references to source columns for destination columns.

    Meant to be used with sqlalchemy visitors: calling an instance on a clause
    runs ``replacement_traverse`` with :meth:`visit` as the replacement
    function. ``applied`` is set to True once a source *label* was replaced.
    """

    def __init__(self, src_columns, dst_columns):
        self.src_columns = src_columns
        # only the labeled source columns flip the ``applied`` flag in visit
        self.src_labels = {
            col for col in src_columns if isinstance(col, sql.elements.Label)
        }
        self.dst_columns = dst_columns
        self.applied = False

    def __call__(self, clause):
        return sql.util.visitors.replacement_traverse(clause, {}, self.visit)

    def visit(self, el):
        """Return the destination column for ``el``, or None to leave it as is."""
        from sqlalchemy.sql.elements import ColumnClause, Label, ClauseElement, TypeClause
        from sqlalchemy.sql.schema import Column

        # TODO: for some reason TypeClause throws an error if unguarded
        if isinstance(el, TypeClause) or not isinstance(el, ClauseElement):
            return None

        if el in self.src_labels:
            self.applied = True
            return self.dst_columns[el.name]

        if el in self.src_columns:
            return self.dst_columns[el.name]

        # TODO: should we create a subquery if the user passed raw text?
        return None
def ordered_union(x, y):
    """Return a tuple with the unique elements of x, then the new ones from y.

    Elements keep the order in which they first appear; a dict is used as an
    ordered set, so every element must be hashable.
    """
    merged = dict.fromkeys(x, True)
    merged.update(dict.fromkeys(y, True))

    return tuple(merged)
grouped) + self.ops = [self.tbl] if ops is None else ops + + self.group_by = group_by + self.order_by = order_by + + + def append_op(self, op, **kwargs): + cpy = self.copy(**kwargs) + cpy.ops = cpy.ops + [op] + return cpy + + def copy(self, **kwargs): + return self.__class__(**{**self.__dict__, **kwargs}) + + def shape_call( + self, + call, window = True, str_accessors = False, + verb_name = None, arg_name = None, + ): + return self.translator.shape_call(call, window, str_accessors, verb_name, arg_name) + + def track_call_windows(self, call, columns = None, window_cte = None): + """Returns tuple of (new column expression, list of window exprs)""" + + columns = self.last_op.columns if columns is None else columns + return track_call_windows(call, columns, self.group_by, self.order_by, window_cte) + + def get_ordered_col_names(self): + """Return columns from current select, with grouping columns first.""" + ungrouped = [k for k in self.last_op.columns.keys() if k not in self.group_by] + return list(self.group_by) + ungrouped + + #def label_breaks_order_by(self, name): + # """Returns True if a new column label would break the order by vars.""" + + # # TODO: arrange currently allows literals, which breaks this. it seems + # # better to only allow calls in arrange. + # order_by_vars = {c.op_vars(attr_calls=False) for c in self.order_by} + + + + + @property + def last_op(self) -> "sql.Table | sql.Select": + last_op = self.ops[-1] + + if last_op is None: + raise TypeError() + + return last_op + + @property + def last_select(self): + last_op = self.last_op + if not isinstance(last_op, sql.selectable.SelectBase): + return last_op.select() + + return last_op + + @staticmethod + def _create_table(tbl, columns = None, source = None): + """Return a sqlalchemy.Table, autoloading column info if needed. + + Arguments: + tbl: a sqlalchemy.Table or string of form 'table_name' or 'schema_name.table_name'. + columns: a tuple of column names for the table. Overrides source argument. 
+ source: a sqlalchemy engine, used to autoload columns. + + """ + if isinstance(tbl, sql.selectable.FromClause): + return tbl + + if not isinstance(tbl, str): + raise ValueError("tbl must be a sqlalchemy Table or string, but was %s" %type(tbl)) + + if columns is None and source is None: + raise ValueError("One of columns or source must be specified") + + schema, table_name = tbl.split('.') if '.' in tbl else [None, tbl] + + columns = map(sqlalchemy.Column, columns) if columns is not None else tuple() + + # TODO: pybigquery uses schema to mean project_id, so we cannot use + # siuba's classic breakdown "{schema}.{table_name}". Basically + # pybigquery uses "{schema=project_id}.{dataset_dot_table_name}" in its internal + # logic. An important side effect is that bigquery errors for + # `dataset`.`table`, but not `dataset.table`. + if source and source.dialect.name == "bigquery": + table_name = tbl + schema = None + + return sqlalchemy.Table( + table_name, + sqlalchemy.MetaData(bind = source), + *columns, + schema = schema, + autoload_with = source if not columns else None + ) + + def _get_preview(self): + # need to make prev op a cte, so we don't override any previous limit + new_sel = self.last_select.limit(5) + tbl_small = self.append_op(new_sel) + return collect(tbl_small) + + def __repr__(self): + template = ( + "# Source: lazy query\n" + "# DB Conn: {}\n" + "# Preview:\n{}\n" + "# .. may have more rows" + ) + + return template.format(repr(self.source.engine), repr(self._get_preview())) + + def _repr_html_(self): + template = ( + "
" + "
"
+                "# Source: lazy query\n"
+                "# DB Conn: {}\n"
+                "# Preview:\n"
+                "
" + "{}" + "

# .. may have more rows

" + "
" + ) + + data = self._get_preview() + + # _repr_html_ can not exist or return None, to signify that repr should be used + if not hasattr(data, '_repr_html_'): + return None + + html_data = data._repr_html_() + if html_data is None: + return None + + return template.format(self.source.engine, html_data) + + +def _repr_grouped_df_html_(self): + return "

(grouped data frame)

" + self._selected_obj._repr_html_() + "
" + + + +# Main Funcs +# ============================================================================= + +# sql raw -------------- + +sql_raw = sql.literal_column + +# show query ----------- + +@show_query.register(LazyTbl) +def _show_query(tbl, simplify = False, return_table = True): + #query = tbl.last_op #if not simplify else + compile_query = lambda query: query.compile( + dialect = tbl.source.dialect, + compile_kwargs = {"literal_binds": True} + ) + + + if simplify: + # try to strip table names and labels where unnecessary + simple_sel = _sql_simplify_select(tbl.last_select) + + explained = compile_query(simple_sel) + else: + # use a much more verbose query + explained = compile_query(tbl.last_select) + + if return_table: + print(str(explained)) + return tbl + + return str(explained) + + + +# collect ---------- + +@collect.register(LazyTbl) +def _collect(__data, as_df = True): + # TODO: maybe remove as_df options, always return dataframe + + if isinstance(__data.source, MockConnection): + # a mock sqlalchemy is being used to show_query, and echo queries. + # it doesn't return a result object or have a context handler, so + # we need to bail out early + return + + # compile query ---- + + if _is_dialect_duckdb(__data.source): + # TODO: can be removed once next release of duckdb fixes: + # https://github.com/duckdb/duckdb/issues/2972 + query = __data.last_select + compiled = query.compile( + dialect = __data.source.dialect, + compile_kwargs = {"literal_binds": True} + ) + else: + compiled = __data.last_select + + # execute query ---- + + with __data.source.connect() as conn: + if as_df: + sql_db = _FixedSqlDatabase(conn) + + if _is_dialect_duckdb(__data.source): + # TODO: pandas read_sql is very slow with duckdb. + # see https://github.com/pandas-dev/pandas/issues/45678 + # going to handle here for now. address once LazyTbl gets + # subclassed per backend. 
+ duckdb_con = conn.connection.c + return duckdb_con.query(str(compiled)).to_df() + else: + # + return sql_db.read_sql(compiled) + + return conn.execute(compiled) + + +@select.register(LazyTbl) +def _select(__data, *args, **kwargs): + # see https://stackoverflow.com/questions/25914329/rearrange-columns-in-sqlalchemy-select-object + if kwargs: + raise NotImplementedError( + "Using kwargs in select not currently supported. " + "Use _.newname == _.oldname instead" + ) + last_sel = __data.last_select + columns = {c.key: c for c in last_sel.inner_columns} + + # same as for DataFrame + colnames = Series(list(columns)) + vl = VarList() + evaluated = (arg(vl) if callable(arg) else arg for arg in args) + od = var_select(colnames, *evaluated) + + missing_groups, group_keys = _select_group_renames(od, __data.group_by) + + if missing_groups: + _warn_missing(missing_groups) + + final_od = {**{k: None for k in missing_groups}, **od} + + col_list = [] + for k,v in final_od.items(): + col = columns[k] + col_list.append(col if v is None else col.label(v)) + + return __data.append_op( + last_sel.with_only_columns(col_list), + group_by = group_keys + ) + + + +@filter.register(LazyTbl) +def _filter(__data, *args): + # Note: currently always produces 2 additional select statements, + # 1 for window/aggs, and 1 for the where clause + + sel = __data.last_op.alias() # original select + win_sel = sel.select() + + conds = [] + windows = [] + with _set_data_context(__data, window=True): + for ii, arg in enumerate(args): + + if isinstance(arg, Call): + new_call = __data.shape_call(arg, verb_name = "Filter", arg_name = ii) + #var_cols = new_call.op_vars(attr_calls = False) + + # note that a new win_sel is returned, w/ window columns appended + col_expr, win_cols, win_sel = __data.track_call_windows( + new_call, + sel.columns, + window_cte = win_sel + ) + + if isinstance(col_expr, sql.base.ImmutableColumnCollection): + conds.extend(col_expr) + else: + conds.append(col_expr) + + 
windows.extend(win_cols) + + else: + conds.append(arg) + + bool_clause = sql.and_(*conds) + + # first cte, windows ---- + if len(windows): + + win_alias = win_sel.alias() + + # move non-window functions to refer to win_sel clause (not the innermost) --- + bool_clause = sql.util.ClauseAdapter(win_alias) \ + .traverse(bool_clause) + + orig_cols = [win_alias.columns[k] for k in sel.columns.keys()] + else: + orig_cols = [sel] + + # create second cte ---- + filt_sel = _sql_select(orig_cols).where(bool_clause) + return __data.append_op(filt_sel) + + +@mutate.register(LazyTbl) +def _mutate(__data, *args, **kwargs): + # TODO: verify it can follow a renaming select + + # track labeled columns in set + if not (len(args) or len(kwargs)): + return __data.append_op(__data.last_op) + + names, sel_out = _mutate_cols(__data, args, kwargs, "Mutate") + return __data.append_op(sel_out) + + +def _sql_upsert_columns(sel, new_columns: "list[base.Label | base.Column]"): + orig_cols = lift_inner_cols(sel) + replaced = {**orig_cols} + + for new_col in new_columns: + replaced[new_col.name] = new_col + return _sql_with_only_columns(sel, list(replaced.values())) + + +def _select_mutate_result(src_sel, expr_result): + dst_alias = src_sel.alias() + src_columns = set(lift_inner_cols(src_sel)) + replacer = SqlLabelReplacer(set(src_columns), dst_alias.columns) + + if isinstance(expr_result, sql.base.ImmutableColumnCollection): + replaced_cols = list(map(replacer, expr_result)) + orig_cols = expr_result + #elif isinstance(expr_result, None): + # pass + else: + replaced_cols = [replacer(expr_result)] + orig_cols = [expr_result] + + if replacer.applied: + return _sql_upsert_columns(dst_alias.select(), replaced_cols) + + return _sql_upsert_columns(src_sel, orig_cols) + + +def _eval_expr_arg(__data, sel, func, verb_name, window=True): + inner_cols = lift_inner_cols(sel) + + # case 1: simple names ---- + simple_name = simple_varname(func) + if simple_name is not None: + return inner_cols[simple_name] + 
def _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=True):
    """Evaluate a named verb argument against ``sel`` and label the result.

    Args:
        __data: object providing ``shape_call`` and ``track_call_windows``.
        sel: select whose inner columns the expression is evaluated on.
        func: the expression passed as the keyword argument's value.
        new_name: the keyword name; used as the label of the returned column.
        verb_name: verb name, forwarded to ``shape_call`` and used in errors.
        window: forwarded to ``shape_call``.

    Returns:
        The resulting column expression, labeled ``new_name``.

    Raises:
        TypeError: if the expression evaluates to a column collection rather
            than a single column.
    """
    inner_cols = lift_inner_cols(sel)

    expr_shaped = __data.shape_call(func, window, verb_name = verb_name, arg_name = new_name)
    # only the column expression is needed here; window bookkeeping is unused
    new_col, _windows, _ = __data.track_call_windows(expr_shaped, inner_cols)

    if isinstance(new_col, sql.base.ImmutableColumnCollection):
        # was `TyepError` with an undefined `k`, so this path raised NameError
        raise TypeError(
            f"{verb_name} named arguments must return a single column, but `{new_name}` "
            "returned multiple columns."
        )

    return new_col.label(new_name)
cols_to_keep]) + + return __data.append_op(sel_stripped) + + +@arrange.register(LazyTbl) +def _arrange(__data, *args): + # Note that SQL databases often do not subquery order by clauses. Arrange + # sets order_by on the backend, so it can set order by in over elements, + # and handle when new columns are named the same as order by vars. + # see: https://dba.stackexchange.com/q/82930 + + last_sel = __data.last_select + cols = lift_inner_cols(last_sel) + + # TODO: implement across in arrange + #exprs, _ = _mutate_cols(__data, args, kwargs, "Arrange", arrange_clause=True) + + new_calls = [] + for ii, expr in enumerate(args): + if callable(expr): + + res = __data.shape_call( + expr, window = False, + verb_name = "Arrange", arg_name = ii + ) + + else: + res = expr + + new_calls.append(res) + + sort_cols = _create_order_by_clause(cols, *new_calls) + + order_by = __data.order_by + tuple(new_calls) + return __data.append_op(last_sel.order_by(*sort_cols), order_by = order_by) + + +# TODO: consolidate / pull expr handling funcs into own file? +def _create_order_by_clause(columns, *args): + sort_cols = [] + for arg in args: + # simple named column + if isinstance(arg, str): + sort_cols.append(columns[arg]) + # an expression + elif callable(arg): + # handle special case where -_.colname -> colname DESC + f, asc = _call_strip_ascending(arg) + col_op = f(columns) if asc else f(columns).desc() + #col_op = arg(columns) + sort_cols.append(col_op) + else: + raise NotImplementedError("Must be string or callable") + + return sort_cols + + + +@count.register(LazyTbl) +def _count(__data, *args, sort = False, wt = None, **kwargs): + # TODO: if already col named n, use name nn, etc.. 
get logic from tidy.py + if wt is not None: + raise NotImplementedError("TODO") + + res_name = "n" + # similar to filter verb, we need two select statements, + # an inner one for derived cols, and outer to group by them + + # inner select ---- + # holds any mutation style columns + #arg_names = [] + #for arg in args: + # name = simple_varname(arg) + # if name is None: + # raise NotImplementedError( + # "Count positional arguments must be single column name. " + # "Use a named argument to count using complex expressions." + # ) + # arg_names.append(name) + + result_names, sel_inner = _mutate_cols(__data, args, kwargs, "Count") + + # remove unnecessary select, if we're operating on a table ---- + if set(lift_inner_cols(sel_inner)) == set(lift_inner_cols(__data.last_select)): + sel_inner = __data.last_op + + # create outer select ---- + # holds selected columns and tally (n) + sel_inner_cte = sel_inner.alias() + inner_cols = sel_inner_cte.columns + + # apply any group vars from a group_by verb call first + missing = [k for k in __data.group_by if k not in result_names] + + all_group_names = ordered_union(__data.group_by, result_names) + outer_group_cols = [inner_cols[k] for k in all_group_names] + + # holds the actual count (e.g. 
n) + count_col = sql.functions.count().label(res_name) + + sel_outer = _sql_select([*outer_group_cols, count_col]) \ + .select_from(sel_inner_cte) \ + .group_by(*outer_group_cols) + + # count is like summarize, so removes order_by + return __data.append_op( + sel_outer.order_by(count_col.desc()), + order_by = tuple() + ) + + +@add_count.register(LazyTbl) +def _add_count(__data, *args, wt = None, sort = False, **kwargs): + counts = count(__data, *args, wt = wt, sort = sort, **kwargs) + by = list(c.name for c in counts.last_select.inner_columns)[:-1] + + return inner_join(__data, counts, by = by) + + +@summarize.register(LazyTbl) +def _summarize(__data, *args, **kwargs): + # https://stackoverflow.com/questions/14754994/why-is-sqlalchemy-count-much-slower-than-the-raw-query + + # get query with correct from clause, and maybe unneeded subquery + safe_from = __data.last_select.alias() + result_names, sel = _aggregate_cols(__data, safe_from, args, kwargs, "Summarize") + + # see if we can remove subquery + out_sel = _collapse_select(sel, safe_from) + + from_tbl = get_single_from(out_sel) + group_cols = [from_tbl.columns[k] for k in __data.group_by] + + final_sel = out_sel.group_by(*group_cols) + + new_data = __data.append_op(final_sel, group_by = tuple(), order_by = tuple()) + return new_data + + +def _collapse_select(outer_sel, inner_alias): + # check whether any outer columns reference an inner label ---- + inner_sel = inner_alias.element + + columns = lift_inner_cols(outer_sel) + inner_cols = lift_inner_cols(inner_sel) + + inner_labels = set([ + x.name for x in inner_cols + if isinstance(x, sql.elements.Label) + ]) + + col_requires_cte = set(inner_alias.columns[k] for k in inner_labels) + + bad_refs = [] + + def collect_refs(el): + if el in col_requires_cte: + bad_refs.append(el) + + for col in columns: + sql.util.visitors.traverse(col, {}, {"column": collect_refs}) + + # if possible, remove the outer query ---- + if not (bad_refs or len(inner_sel._group_by_clause)): + 
from sqlalchemy.sql.elements import ColumnClause, Label + + from_obj = get_single_from(inner_sel) + adaptor = sql.util.ClauseAdapter( + from_obj, + adapt_on_names=True, + include_fn=lambda c: isinstance(c, (ColumnClause, Label)) + ) + + new_cols = [] + for col in columns: + if isinstance(col, Label): + res = adaptor.traverse(col.element).label(col.name) + new_cols.append(res) + + else: + new_cols.append(adaptor.traverse(col)) + #new_cols = list(map(adaptor.traverse, columns)) + + return _sql_with_only_columns(inner_sel, new_cols) + + return outer_sel + + +def _aggregate_cols(__data, subquery, args, kwargs, verb_name): + # cases: + # * grouping cols can not be overwritten (in dbplyr they can't be ref'd) + # * no existing labels referred to - can use same select + # * existing labels referred to - need 1 subquery tops + # * groups + summarize columns can replace everything + + def get_label_clauses(clause): + out = [] + sql.util.visitors.traverse(clause, {}, {"label": lambda c: out.append(c)}) + + return out + + def quote_varname(x): + return f"`{x}`" + + def validate_references(arg_name, expr, verb_name): + bad_varnames = get_label_clauses(expr) + repr_names = ", ".join(map(quote_varname, [el.name for el in bad_varnames])) + + if not bad_varnames: + return + + raise NotImplementedError( + f"In SQL, you cannot refer to a column created in the same {verb_name}. " + f"`{arg_name}` refers to columns created earlier: {repr_names}." 
+ ) + + sel = subquery.select() + + final_cols = {k: subquery.columns[k] for k in __data.group_by} + + # handle args ---- + for ii, func in enumerate(args): + cols_result = _eval_expr_arg(__data, sel, func, verb_name, window=False) + + for col in cols_result: + validate_references(col.name, col.element, verb_name) + final_cols[col.name] = col + + sel = _sql_upsert_columns(sel, cols_result) + + + # handle kwargs ---- + for new_name, func in kwargs.items(): + labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=False) + + validate_references(labeled.name, labeled.element, verb_name) + final_cols[new_name] = labeled + + sel = _sql_upsert_columns(sel, [labeled]) + + return list(final_cols), _sql_with_only_columns(sel, list(final_cols.values())) + + +@group_by.register(LazyTbl) +def _group_by(__data, *args, add = False, **kwargs): + if not (args or kwargs): + return __data.copy() + + group_names, sel = _mutate_cols(__data, args, kwargs, "Group by") + + if None in group_names: + raise NotImplementedError("Complex, unnamed expressions not supported in sql group_by") + + # check whether we can just use underlying table ---- + new_cols = lift_inner_cols(sel) + if set(new_cols).issubset(set(__data.last_op.columns)): + sel = __data.last_op + + if add: + group_names = ordered_union(__data.group_by, group_names) + + return __data.append_op(sel, group_by = tuple(group_names)) + + +@ungroup.register(LazyTbl) +def _ungroup(__data): + return __data.copy(group_by = tuple()) + + +@case_when.register(sql.base.ImmutableColumnCollection) +def _case_when(__data, cases): + # TODO: will need listener to enter case statements, to handle when they use windows + if isinstance(cases, Call): + cases = cases(__data) + + whens = [] + case_items = list(cases.items()) + n_items = len(case_items) + + else_val = None + for ii, (expr, val) in enumerate(case_items): + # handle where val is a column expr + if callable(val): + val = val(__data) + + # handle when expressions + if ii+1 
== n_items and expr is True: + else_val = val + elif callable(expr): + whens.append((expr(__data), val)) + else: + whens.append((expr, val)) + + return sql.case(whens, else_ = else_val) + + +# Join ------------------------------------------------------------------------ + +from collections.abc import Mapping + +def _joined_cols(left_cols, right_cols, on_keys, how, suffix = ("_x", "_y")): + """Return labeled columns, according to selection rules for joins. + + Rules: + 1. For join keys, keep left table's column + 2. When keys have the same labels, add suffix + """ + + # TODO: remove sets, so uses stable ordering + # when left and right cols have same name, suffix with _x / _y + keep_right = set(right_cols.keys()) - set(on_keys.values()) + shared_labs = set(left_cols.keys()).intersection(keep_right) + + right_cols_no_keys = {k: right_cols[k] for k in keep_right} + + # for an outer join, have key columns coalesce values + + left_cols = {**left_cols} + if how == "full": + for lk, rk in on_keys.items(): + col = sql.functions.coalesce(left_cols[lk], right_cols[rk]) + left_cols[lk] = col.label(lk) + elif how == "right": + for lk, rk in on_keys.items(): + # Make left key columns actually be right ones (which contain left + extra) + left_cols[lk] = right_cols[rk].label(lk) + + + # create labels ---- + l_labs = _relabeled_cols(left_cols, shared_labs, suffix[0]) + r_labs = _relabeled_cols(right_cols_no_keys, shared_labs, suffix[1]) + + return l_labs + r_labs + + + +def _relabeled_cols(columns, keys, suffix): + # add a suffix to all columns with names in keys + cols = [] + for k, v in columns.items(): + new_col = v.label(k + str(suffix)) if k in keys else v + cols.append(new_col) + return cols + + +@join.register(LazyTbl) +def _join(left, right, on = None, *args, by = None, how = "inner", sql_on = None): + _raise_if_args(args) + + if on is None and by is not None: + on = by + + # Needs to be on the table, not the select + left_sel = left.last_op.alias() + right_sel = 
right.last_op.alias() + + # handle arguments ---- + on = _validate_join_arg_on(on, sql_on) + how = _validate_join_arg_how(how) + + # for equality join used to combine keys into single column + consolidate_keys = on if sql_on is None else {} + + if how == "right": + # switch joins, since sqlalchemy doesn't have right join arg + # see https://stackoverflow.com/q/11400307/1144523 + left_sel, right_sel = right_sel, left_sel + on = {v:k for k,v in on.items()} + + # create join conditions ---- + bool_clause = _create_join_conds(left_sel, right_sel, on) + + # create join ---- + join = left_sel.join( + right_sel, + onclause = bool_clause, + isouter = how != "inner", + full = how == "full" + ) + + # if right join, set selects back + if how == "right": + left_sel, right_sel = right_sel, left_sel + on = {v:k for k,v in on.items()} + + # note, shared_keys assumes on is a mapping... + # TODO: shared_keys appears to be for when on is not specified, but was unused + #shared_keys = [k for k,v in on.items() if k == v] + labeled_cols = _joined_cols( + left_sel.columns, + right_sel.columns, + on_keys = consolidate_keys, + how = how + ) + + sel = _sql_select(labeled_cols).select_from(join) + return left.append_op(sel, order_by = tuple()) + + +@semi_join.register(LazyTbl) +def _semi_join(left, right = None, on = None, *args, by = None, sql_on = None): + if on is None and by is not None: + on = by + + _raise_if_args(args) + + left_sel = left.last_op.alias() + right_sel = right.last_op.alias() + + # handle arguments ---- + on = _validate_join_arg_on(on, sql_on, left_sel, right_sel) + + # create join conditions ---- + bool_clause = _create_join_conds(left_sel, right_sel, on) + + # create inner join ---- + exists_clause = _sql_select([sql.literal(1)]) \ + .select_from(right_sel) \ + .where(bool_clause) + + # only keep left hand select's columns ---- + sel = _sql_select(left_sel.columns) \ + .select_from(left_sel) \ + .where(sql.exists(exists_clause)) + + return left.append_op(sel, order_by 
= tuple()) + + +@anti_join.register(LazyTbl) +def _anti_join(left, right = None, on = None, *args, by = None, sql_on = None): + if on is None and by is not None: + on = by + + _raise_if_args(args) + + left_sel = left.last_op.alias() + right_sel = right.last_op.alias() + + # handle arguments ---- + on = _validate_join_arg_on(on, sql_on, left, right) + + # create join conditions ---- + bool_clause = _create_join_conds(left_sel, right_sel, on) + + # create inner join ---- + #not_exists = ~sql.exists([1], from_obj = right_sel).where(bool_clause) + exists_clause = _sql_select([sql.literal(1)]) \ + .select_from(right_sel) \ + .where(bool_clause) + + sel = left_sel.select().where(~sql.exists(exists_clause)) + + return left.append_op(sel, order_by = tuple()) + +def _raise_if_args(args): + if len(args): + raise NotImplemented("*args is reserved for future arguments (e.g. suffix)") + +def _validate_join_arg_on(on, sql_on = None, lhs = None, rhs = None): + # handle sql on case + if sql_on is not None: + if on is not None: + raise ValueError("Cannot specify both on and sql_on") + + return sql_on + + # handle general cases + if on is None: + # TODO: currently, we check for lhs and rhs tables to indicate whether + # a verb supports inferring columns. Otherwise, raise an error. + if lhs is not None and rhs is not None: + # TODO: consolidate with duplicate logic in pandas verb code + warnings.warn( + "No on column passed to join. " + "Inferring join columns instead using shared column names." + ) + + on_cols = list(set(lhs.columns.keys()).intersection(set(rhs.columns.keys()))) + + if not on_cols: + raise ValueError( + "No join column specified, or shared column names in join." 
+ ) + + # trivial dict mapping shared names to themselves + warnings.warn("Detected shared columns: %s" % on_cols) + on = dict(zip(on_cols, on_cols)) + + else: + raise NotImplementedError("on arg currently cannot be None (default) for SQL") + elif isinstance(on, str): + on = {on: on} + elif isinstance(on, (list, tuple)): + on = dict(zip(on, on)) + + + if not isinstance(on, Mapping): + raise TypeError("on must be a Mapping (e.g. dict)") + + return on + +def _validate_join_arg_how(how): + how_options = ("inner", "left", "right", "full") + if how not in how_options: + raise ValueError("how argument needs to be one of %s" %how_options) + + return how + +def _create_join_conds(left_sel, right_sel, on): + left_cols = left_sel.columns #lift_inner_cols(left_sel) + right_cols = right_sel.columns #lift_inner_cols(right_sel) + + if callable(on): + # callable, like with sql_on arg + conds = [on(left_cols, right_cols)] + else: + # dict-like of form {left: right} + conds = [] + for l, r in on.items(): + col_expr = left_cols[l] == right_cols[r] + conds.append(col_expr) + + return sql.and_(*conds) + + +# Head ------------------------------------------------------------------------ + +@head.register(LazyTbl) +def _head(__data, n = 5): + sel = __data.last_select + + return __data.append_op(sel.limit(n)) + + +# Rename ---------------------------------------------------------------------- + +@rename.register(LazyTbl) +def _rename(__data, **kwargs): + sel = __data.last_select + columns = lift_inner_cols(sel) + + # old_keys uses dict as ordered set + old_to_new = {simple_varname(v):k for k,v in kwargs.items()} + + if None in old_to_new: + raise KeyError("positional arguments must be simple column, " + "e.g. 
_.colname or _['colname']" + ) + + labs = [c.label(old_to_new[k]) if k in old_to_new else c for k,c in columns.items()] + + new_sel = sel.with_only_columns(labs) + + missing_groups, group_keys = _select_group_renames(old_to_new, __data.group_by) + + return __data.append_op(new_sel, group_by=group_keys) + + +# Distinct -------------------------------------------------------------------- + +@distinct.register(LazyTbl) +def _distinct(__data, *args, _keep_all = False, **kwargs): + if (args or kwargs) and _keep_all: + raise NotImplementedError("Distinct with variables specified in sql requires _keep_all = False") + + inner_sel = mutate(__data, **kwargs).last_select if kwargs else __data.last_select + + # TODO: this is copied from the df distinct version + # cols dict below is used as ordered set + cols = _var_select_simple(args) + cols.update(kwargs) + + # use all columns by default + if not cols: + cols = {k: True for k in lift_inner_cols(inner_sel).keys()} + + final_names = {**{k: True for k in __data.group_by}, **cols} + + if not len(inner_sel._order_by_clause): + # select distinct has to include any columns in the order by clause, + # so can only safely modify existing statement when there's no order by + sel_cols = lift_inner_cols(inner_sel) + distinct_cols = [sel_cols[k] for k in final_names] + sel = inner_sel.with_only_columns(distinct_cols).distinct() + else: + # fallback to cte + cte = inner_sel.alias() + distinct_cols = [cte.columns[k] for k in final_names] + sel = _sql_select(distinct_cols).select_from(cte).distinct() + + return __data.append_op(sel) + + +# if_else --------------------------------------------------------------------- + +@if_else.register(sql.elements.ColumnElement) +def _if_else(cond, true_vals, false_vals): + whens = [(cond, true_vals)] + return sql.case(whens, else_ = false_vals) + + diff --git a/siuba/sql/verbs/join.py b/siuba/sql/verbs/join.py new file mode 100644 index 00000000..8dcbef7b --- /dev/null +++ b/siuba/sql/verbs/join.py @@ 
-0,0 +1,1395 @@ +""" +Implements LazyTbl to represent tables of SQL data, and registers it on verbs. + +This module is responsible for the handling of the "table" side of things, while +translate.py handles translating column operations. + + +""" + +import warnings + +from siuba.dply.verbs import ( + show_query, collect, + simple_varname, + select, + mutate, + transmute, + filter, + arrange, _call_strip_ascending, + summarize, + count, add_count, + group_by, ungroup, + case_when, + join, left_join, right_join, inner_join, semi_join, anti_join, + head, + rename, + distinct, + if_else, + _select_group_renames, + _var_select_simple + ) + +from siuba.dply.tidyselect import VarList, var_select + +from .translate import CustomOverClause, SqlColumn, SqlColumnAgg +from .utils import ( + get_dialect_translator, + _FixedSqlDatabase, + _is_dialect_duckdb, + _sql_select, + _sql_column_collection, + _sql_add_columns, + _sql_with_only_columns, + _sql_simplify_select, + MockConnection +) + +from sqlalchemy import sql +import sqlalchemy +from siuba.siu import Call, Lazy, FunctionLookupError, singledispatch2 +# TODO: currently needed for select, but can we remove pandas? +from pandas import Series +from functools import singledispatch + +from sqlalchemy.sql import schema + +from siuba.dply.across import _require_across, _set_data_context, _eval_with_context + +# TODO: +# - distinct +# - annotate functions using sel.prefix_with("\n/**/\n") ? + + +# Helpers --------------------------------------------------------------------- + +class SqlFunctionLookupError(FunctionLookupError): pass + + +class CallListener: + """Generic listener. Each exit is called on a node's copy.""" + def enter(self, node): + args, kwargs = node.map_subcalls(self.enter) + + return self.exit(node.__class__(node.func, *args, **kwargs)) + + def exit(self, node): + return node + + +class WindowReplacer(CallListener): + """Call tree listener. 
+ + Produces 2 important behaviors via the enter method: + - returns evaluated sql call expression, with labels on all window expressions. + - stores all labeled window expressions via the windows property. + + TODO: could replace with a sqlalchemy transformer + """ + + def __init__(self, columns, group_by, order_by, window_cte = None): + self.columns = columns + self.group_by = group_by + self.order_by = order_by + self.window_cte = window_cte + self.windows = [] + + def exit(self, node): + col_expr = node(self.columns) + + if not isinstance(col_expr, sql.elements.ClauseElement): + return col_expr + + over_clauses = [x for x in self._get_over_clauses(col_expr) if isinstance(x, CustomOverClause)] + + # put groupings and orderings onto custom over clauses + for over in over_clauses: + # TODO: shouldn't mutate these over clauses + group_by = sql.elements.ClauseList( + *[self.columns[name] for name in self.group_by] + ) + order_by = sql.elements.ClauseList( + *_create_order_by_clause(self.columns, *self.order_by) + ) + + over.set_over(group_by, order_by) + + if len(over_clauses) and self.window_cte is not None: + # custom name, or parameters like "%(...)s" may nest and break psycopg2 + # with columns you can set a key to fix this, but it doesn't seem to + # be an option with labels + name = self._get_unique_name('win', lift_inner_cols(self.window_cte)) + label = col_expr.label(name) + + # put into CTE, and return its resulting column, so that subsequent + # operations will refer to the window column on window_cte. 
Note that + # the operations will use the actual column, so may need to use the + # ClauseAdaptor to make it a reference to the label + self.window_cte = _sql_add_columns(self.window_cte, [label]) + win_col = lift_inner_cols(self.window_cte).values()[-1] + self.windows.append(win_col) + + return win_col + + return col_expr + + @staticmethod + def _get_unique_name(prefix, columns): + column_names = set(columns.keys()) + + i = 1 + name = prefix + str(i) + while name in column_names: + i += 1 + name = prefix + str(i) + + + return name + + @staticmethod + def _get_over_clauses(clause): + windows = [] + append_win = lambda col: windows.append(col) + + sql.util.visitors.traverse(clause, {}, {"over": append_win}) + + return windows + + +class SqlLabelReplacer: + """Create a visitor to replace source labels with destination. + + Note that this is meant to be used with sqlalchemy visitors. + """ + + def __init__(self, src_columns, dst_columns): + self.src_columns = src_columns + self.src_labels = set([x for x in src_columns if isinstance(x, sql.elements.Label)]) + self.dst_columns = dst_columns + self.applied = False + + def __call__(self, clause): + return sql.util.visitors.replacement_traverse(clause, {}, self.visit) + + def visit(self, el): + from sqlalchemy.sql.elements import ColumnClause, Label, ClauseElement, TypeClause + from sqlalchemy.sql.schema import Column + + if isinstance(el, TypeClause): + # TODO: for some reason this type throws an error if unguarded + return None + + if isinstance(el, ClauseElement): + if el in self.src_labels: + self.applied = True + return self.dst_columns[el.name] + elif el in self.src_columns: + return self.dst_columns[el.name] + + # TODO: should we create a subquery if the user passed raw text? 
+ #elif isinstance(el, ColumnClause) and not isinstance(el, Column): + # # Raw SQL, which will need a subquery, but not substitution + # if el.key != "*": + # self.applied = True + + return None + + +#def track_call_windows(call, columns, group_by, order_by, window_cte = None): +# listener = WindowReplacer(columns, group_by, order_by, window_cte) +# col = listener.enter(call) +# return col, listener.windows, listener.window_cte + + +def track_call_windows(call, columns, group_by, order_by, window_cte = None): + col_expr = call(columns) + + crnt_group_by = sql.elements.ClauseList( + *[columns[name] for name in group_by] + ) + crnt_order_by = sql.elements.ClauseList( + *_create_order_by_clause(columns, *order_by) + ) + return replace_call_windows(col_expr, crnt_group_by, crnt_order_by, window_cte) + + + +@singledispatch +def replace_call_windows(col_expr, group_by, order_by, window_cte = None): + raise TypeError(str(type(col_expr))) + + +@replace_call_windows.register(sql.base.ImmutableColumnCollection) +def _(col_expr, group_by, order_by, window_cte = None): + all_over_clauses = [] + for col in col_expr: + _, over_clauses, window_cte = replace_call_windows( + col, + group_by, + order_by, + window_cte + ) + all_over_clauses.extend(over_clauses) + + return col_expr, all_over_clauses, window_cte + + +@replace_call_windows.register(sql.elements.ClauseElement) +def _(col_expr, group_by, order_by, window_cte = None): + + over_clauses = WindowReplacer._get_over_clauses(col_expr) + + for over in over_clauses: + # TODO: shouldn't mutate these over clauses + over.set_over(group_by, order_by) + + if len(over_clauses) and window_cte is not None: + # custom name, or parameters like "%(...)s" may nest and break psycopg2 + # with columns you can set a key to fix this, but it doesn't seem to + # be an option with labels + name = WindowReplacer._get_unique_name('win', lift_inner_cols(window_cte)) + label = col_expr.label(name) + + # put into CTE, and return its resulting column, so 
that subsequent + # operations will refer to the window column on window_cte. Note that + # the operations will use the actual column, so may need to use the + # ClauseAdaptor to make it a reference to the label + window_cte = _sql_add_columns(window_cte, [label]) + win_col = lift_inner_cols(window_cte).values()[-1] + + return win_col, over_clauses, window_cte + + return col_expr, over_clauses, window_cte + +def get_single_from(sel): + froms = sel.froms + + n_froms = len(froms) + if n_froms != 1: + raise ValueError( + f"Expected a single table in the from clause, but found {n_froms}" + ) + + return froms[0] + +def lift_inner_cols(tbl): + cols = list(tbl.inner_columns) + + return _sql_column_collection(cols) + +def col_expr_requires_cte(call, sel, is_mutate = False): + """Return whether a variable assignment needs a CTE""" + + call_vars = set(call.op_vars(attr_calls = False)) + + sel_labs = get_inner_labels(sel) + + # I use the acronym fwg sol (frog soul) to remember sql clause eval order + # from, where, group by, select, order by, limit + # group clause evaluated before select clause, so not issue for mutate + group_needs_cte = not is_mutate and len(sel._group_by_clause) + + return ( group_needs_cte + # TODO: detect when a new var in mutate conflicts w/ order by + #or len(sel._order_by_clause) + or not sel_labs.isdisjoint(call_vars) + ) + +def get_inner_labels(sel): + columns = lift_inner_cols(sel) + sel_labs = set(k for k,v in columns.items() if isinstance(v, sql.elements.Label)) + return sel_labs + +def get_missing_columns(call, columns): + missing_cols = set(call.op_vars(attr_calls = False)) - set(columns.keys()) + return missing_cols + +def compile_el(tbl, el): + compiled = el.compile( + dialect = tbl.source.dialect, + compile_kwargs = {"literal_binds": True} + ) + return compiled + +# Misc utilities -------------------------------------------------------------- + +def ordered_union(x, y): + dx = {el: True for el in x} + dy = {el: True for el in y} + + return 
tuple({**dx, **dy}) + + +def _warn_missing(missing_groups): + warnings.warn(f"Adding missing grouping variables: {missing_groups}") + + +# Table ----------------------------------------------------------------------- + +class LazyTbl: + def __init__( + self, source, tbl, columns = None, + ops = None, group_by = tuple(), order_by = tuple(), + translator = None + ): + """Create a representation of a SQL table. + + Args: + source: a sqlalchemy.Engine or sqlalchemy.Connection instance. + tbl: table of form 'schema_name.table_name', 'table_name', or sqlalchemy.Table. + columns: if specified, a listlike of column names. + + Examples + -------- + + :: + from sqlalchemy import create_engine + from siuba.data import mtcars + + # create database and table + engine = create_engine("sqlite:///:memory:") + mtcars.to_sql('mtcars', engine) + + tbl_mtcars = LazyTbl(engine, 'mtcars') + + """ + + # connection and dialect specific functions + self.source = sqlalchemy.create_engine(source) if isinstance(source, str) else source + + # get dialect name + dialect = self.source.dialect.name + self.translator = get_dialect_translator(dialect) + + self.tbl = self._create_table(tbl, columns, self.source) + + # important states the query can be in (e.g. 
grouped) + self.ops = [self.tbl] if ops is None else ops + + self.group_by = group_by + self.order_by = order_by + + + def append_op(self, op, **kwargs): + cpy = self.copy(**kwargs) + cpy.ops = cpy.ops + [op] + return cpy + + def copy(self, **kwargs): + return self.__class__(**{**self.__dict__, **kwargs}) + + def shape_call( + self, + call, window = True, str_accessors = False, + verb_name = None, arg_name = None, + ): + return self.translator.shape_call(call, window, str_accessors, verb_name, arg_name) + + def track_call_windows(self, call, columns = None, window_cte = None): + """Returns tuple of (new column expression, list of window exprs)""" + + columns = self.last_op.columns if columns is None else columns + return track_call_windows(call, columns, self.group_by, self.order_by, window_cte) + + def get_ordered_col_names(self): + """Return columns from current select, with grouping columns first.""" + ungrouped = [k for k in self.last_op.columns.keys() if k not in self.group_by] + return list(self.group_by) + ungrouped + + #def label_breaks_order_by(self, name): + # """Returns True if a new column label would break the order by vars.""" + + # # TODO: arrange currently allows literals, which breaks this. it seems + # # better to only allow calls in arrange. + # order_by_vars = {c.op_vars(attr_calls=False) for c in self.order_by} + + + + + @property + def last_op(self) -> "sql.Table | sql.Select": + last_op = self.ops[-1] + + if last_op is None: + raise TypeError() + + return last_op + + @property + def last_select(self): + last_op = self.last_op + if not isinstance(last_op, sql.selectable.SelectBase): + return last_op.select() + + return last_op + + @staticmethod + def _create_table(tbl, columns = None, source = None): + """Return a sqlalchemy.Table, autoloading column info if needed. + + Arguments: + tbl: a sqlalchemy.Table or string of form 'table_name' or 'schema_name.table_name'. + columns: a tuple of column names for the table. Overrides source argument. 
+ source: a sqlalchemy engine, used to autoload columns. + + """ + if isinstance(tbl, sql.selectable.FromClause): + return tbl + + if not isinstance(tbl, str): + raise ValueError("tbl must be a sqlalchemy Table or string, but was %s" %type(tbl)) + + if columns is None and source is None: + raise ValueError("One of columns or source must be specified") + + schema, table_name = tbl.split('.') if '.' in tbl else [None, tbl] + + columns = map(sqlalchemy.Column, columns) if columns is not None else tuple() + + # TODO: pybigquery uses schema to mean project_id, so we cannot use + # siuba's classic breakdown "{schema}.{table_name}". Basically + # pybigquery uses "{schema=project_id}.{dataset_dot_table_name}" in its internal + # logic. An important side effect is that bigquery errors for + # `dataset`.`table`, but not `dataset.table`. + if source and source.dialect.name == "bigquery": + table_name = tbl + schema = None + + return sqlalchemy.Table( + table_name, + sqlalchemy.MetaData(bind = source), + *columns, + schema = schema, + autoload_with = source if not columns else None + ) + + def _get_preview(self): + # need to make prev op a cte, so we don't override any previous limit + new_sel = self.last_select.limit(5) + tbl_small = self.append_op(new_sel) + return collect(tbl_small) + + def __repr__(self): + template = ( + "# Source: lazy query\n" + "# DB Conn: {}\n" + "# Preview:\n{}\n" + "# .. may have more rows" + ) + + return template.format(repr(self.source.engine), repr(self._get_preview())) + + def _repr_html_(self): + template = ( + "
" + "
"
+                "# Source: lazy query\n"
+                "# DB Conn: {}\n"
+                "# Preview:\n"
+                "
" + "{}" + "

# .. may have more rows

" + "
" + ) + + data = self._get_preview() + + # _repr_html_ can not exist or return None, to signify that repr should be used + if not hasattr(data, '_repr_html_'): + return None + + html_data = data._repr_html_() + if html_data is None: + return None + + return template.format(self.source.engine, html_data) + + +def _repr_grouped_df_html_(self): + return "

(grouped data frame)

" + self._selected_obj._repr_html_() + "
" + + + +# Main Funcs +# ============================================================================= + +# sql raw -------------- + +sql_raw = sql.literal_column + +# show query ----------- + +@show_query.register(LazyTbl) +def _show_query(tbl, simplify = False, return_table = True): + #query = tbl.last_op #if not simplify else + compile_query = lambda query: query.compile( + dialect = tbl.source.dialect, + compile_kwargs = {"literal_binds": True} + ) + + + if simplify: + # try to strip table names and labels where unnecessary + simple_sel = _sql_simplify_select(tbl.last_select) + + explained = compile_query(simple_sel) + else: + # use a much more verbose query + explained = compile_query(tbl.last_select) + + if return_table: + print(str(explained)) + return tbl + + return str(explained) + + + +# collect ---------- + +@collect.register(LazyTbl) +def _collect(__data, as_df = True): + # TODO: maybe remove as_df options, always return dataframe + + if isinstance(__data.source, MockConnection): + # a mock sqlalchemy is being used to show_query, and echo queries. + # it doesn't return a result object or have a context handler, so + # we need to bail out early + return + + # compile query ---- + + if _is_dialect_duckdb(__data.source): + # TODO: can be removed once next release of duckdb fixes: + # https://github.com/duckdb/duckdb/issues/2972 + query = __data.last_select + compiled = query.compile( + dialect = __data.source.dialect, + compile_kwargs = {"literal_binds": True} + ) + else: + compiled = __data.last_select + + # execute query ---- + + with __data.source.connect() as conn: + if as_df: + sql_db = _FixedSqlDatabase(conn) + + if _is_dialect_duckdb(__data.source): + # TODO: pandas read_sql is very slow with duckdb. + # see https://github.com/pandas-dev/pandas/issues/45678 + # going to handle here for now. address once LazyTbl gets + # subclassed per backend. 
+ duckdb_con = conn.connection.c + return duckdb_con.query(str(compiled)).to_df() + else: + # + return sql_db.read_sql(compiled) + + return conn.execute(compiled) + + +@select.register(LazyTbl) +def _select(__data, *args, **kwargs): + # see https://stackoverflow.com/questions/25914329/rearrange-columns-in-sqlalchemy-select-object + if kwargs: + raise NotImplementedError( + "Using kwargs in select not currently supported. " + "Use _.newname == _.oldname instead" + ) + last_sel = __data.last_select + columns = {c.key: c for c in last_sel.inner_columns} + + # same as for DataFrame + colnames = Series(list(columns)) + vl = VarList() + evaluated = (arg(vl) if callable(arg) else arg for arg in args) + od = var_select(colnames, *evaluated) + + missing_groups, group_keys = _select_group_renames(od, __data.group_by) + + if missing_groups: + _warn_missing(missing_groups) + + final_od = {**{k: None for k in missing_groups}, **od} + + col_list = [] + for k,v in final_od.items(): + col = columns[k] + col_list.append(col if v is None else col.label(v)) + + return __data.append_op( + last_sel.with_only_columns(col_list), + group_by = group_keys + ) + + + +@filter.register(LazyTbl) +def _filter(__data, *args): + # Note: currently always produces 2 additional select statements, + # 1 for window/aggs, and 1 for the where clause + + sel = __data.last_op.alias() # original select + win_sel = sel.select() + + conds = [] + windows = [] + with _set_data_context(__data, window=True): + for ii, arg in enumerate(args): + + if isinstance(arg, Call): + new_call = __data.shape_call(arg, verb_name = "Filter", arg_name = ii) + #var_cols = new_call.op_vars(attr_calls = False) + + # note that a new win_sel is returned, w/ window columns appended + col_expr, win_cols, win_sel = __data.track_call_windows( + new_call, + sel.columns, + window_cte = win_sel + ) + + if isinstance(col_expr, sql.base.ImmutableColumnCollection): + conds.extend(col_expr) + else: + conds.append(col_expr) + + 
windows.extend(win_cols) + + else: + conds.append(arg) + + bool_clause = sql.and_(*conds) + + # first cte, windows ---- + if len(windows): + + win_alias = win_sel.alias() + + # move non-window functions to refer to win_sel clause (not the innermost) --- + bool_clause = sql.util.ClauseAdapter(win_alias) \ + .traverse(bool_clause) + + orig_cols = [win_alias.columns[k] for k in sel.columns.keys()] + else: + orig_cols = [sel] + + # create second cte ---- + filt_sel = _sql_select(orig_cols).where(bool_clause) + return __data.append_op(filt_sel) + + +@mutate.register(LazyTbl) +def _mutate(__data, *args, **kwargs): + # TODO: verify it can follow a renaming select + + # track labeled columns in set + if not (len(args) or len(kwargs)): + return __data.append_op(__data.last_op) + + names, sel_out = _mutate_cols(__data, args, kwargs, "Mutate") + return __data.append_op(sel_out) + + +def _sql_upsert_columns(sel, new_columns: "list[base.Label | base.Column]"): + orig_cols = lift_inner_cols(sel) + replaced = {**orig_cols} + + for new_col in new_columns: + replaced[new_col.name] = new_col + return _sql_with_only_columns(sel, list(replaced.values())) + + +def _select_mutate_result(src_sel, expr_result): + dst_alias = src_sel.alias() + src_columns = set(lift_inner_cols(src_sel)) + replacer = SqlLabelReplacer(set(src_columns), dst_alias.columns) + + if isinstance(expr_result, sql.base.ImmutableColumnCollection): + replaced_cols = list(map(replacer, expr_result)) + orig_cols = expr_result + #elif isinstance(expr_result, None): + # pass + else: + replaced_cols = [replacer(expr_result)] + orig_cols = [expr_result] + + if replacer.applied: + return _sql_upsert_columns(dst_alias.select(), replaced_cols) + + return _sql_upsert_columns(src_sel, orig_cols) + + +def _eval_expr_arg(__data, sel, func, verb_name, window=True): + inner_cols = lift_inner_cols(sel) + + # case 1: simple names ---- + simple_name = simple_varname(func) + if simple_name is not None: + return inner_cols[simple_name] + 
+ # case 2: across ---- + _require_across(func, verb_name) + + cols_result = _eval_with_context(__data, window, inner_cols, func) + + # TODO: remove or raise a more informative error + assert isinstance(cols_result, sql.base.ImmutableColumnCollection), type(cols_result) + + return cols_result + + +def _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=True): + inner_cols = lift_inner_cols(sel) + + expr_shaped = __data.shape_call(func, window, verb_name = verb_name, arg_name = new_name) + new_col, windows, _ = __data.track_call_windows(expr_shaped, inner_cols) + + if isinstance(new_col, sql.base.ImmutableColumnCollection): + raise TyepError( + f"{verb_name} named arguments must return a single column, but `{k}` " + "returned multiple columns." + ) + + return new_col.label(new_name) + + +def _mutate_cols(__data, args, kwargs, verb_name): + result_names = {} # used as ordered set + sel = __data.last_select + + for ii, func in enumerate(args): + cols_result = _eval_expr_arg(__data, sel, func, verb_name) + + # replace any labels that require a subquery ---- + sel = _select_mutate_result(sel, cols_result) + + if isinstance(cols_result, sql.base.ImmutableColumnCollection): + result_names.update({k: True for k in cols_result.keys()}) + else: + result_names[cols_result.name] = True + + + for new_name, func in kwargs.items(): + labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name) + + sel = _select_mutate_result(sel, labeled) + result_names[new_name] = True + + + return list(result_names), sel + + +@transmute.register(LazyTbl) +def _transmute(__data, *args, **kwargs): + # will use mutate, then select some cols + result_names, sel = _mutate_cols(__data, args, kwargs, "Transmute") + + # transmute keeps grouping cols, and any defined in kwargs + missing = [x for x in __data.group_by if x not in result_names] + cols_to_keep = [*missing, *result_names] + + columns = lift_inner_cols(sel) + sel_stripped = sel.with_only_columns([columns[k] for k in 
cols_to_keep]) + + return __data.append_op(sel_stripped) + + +@arrange.register(LazyTbl) +def _arrange(__data, *args): + # Note that SQL databases often do not subquery order by clauses. Arrange + # sets order_by on the backend, so it can set order by in over elements, + # and handle when new columns are named the same as order by vars. + # see: https://dba.stackexchange.com/q/82930 + + last_sel = __data.last_select + cols = lift_inner_cols(last_sel) + + # TODO: implement across in arrange + #exprs, _ = _mutate_cols(__data, args, kwargs, "Arrange", arrange_clause=True) + + new_calls = [] + for ii, expr in enumerate(args): + if callable(expr): + + res = __data.shape_call( + expr, window = False, + verb_name = "Arrange", arg_name = ii + ) + + else: + res = expr + + new_calls.append(res) + + sort_cols = _create_order_by_clause(cols, *new_calls) + + order_by = __data.order_by + tuple(new_calls) + return __data.append_op(last_sel.order_by(*sort_cols), order_by = order_by) + + +# TODO: consolidate / pull expr handling funcs into own file? +def _create_order_by_clause(columns, *args): + sort_cols = [] + for arg in args: + # simple named column + if isinstance(arg, str): + sort_cols.append(columns[arg]) + # an expression + elif callable(arg): + # handle special case where -_.colname -> colname DESC + f, asc = _call_strip_ascending(arg) + col_op = f(columns) if asc else f(columns).desc() + #col_op = arg(columns) + sort_cols.append(col_op) + else: + raise NotImplementedError("Must be string or callable") + + return sort_cols + + + +@count.register(LazyTbl) +def _count(__data, *args, sort = False, wt = None, **kwargs): + # TODO: if already col named n, use name nn, etc.. 
get logic from tidy.py + if wt is not None: + raise NotImplementedError("TODO") + + res_name = "n" + # similar to filter verb, we need two select statements, + # an inner one for derived cols, and outer to group by them + + # inner select ---- + # holds any mutation style columns + #arg_names = [] + #for arg in args: + # name = simple_varname(arg) + # if name is None: + # raise NotImplementedError( + # "Count positional arguments must be single column name. " + # "Use a named argument to count using complex expressions." + # ) + # arg_names.append(name) + + result_names, sel_inner = _mutate_cols(__data, args, kwargs, "Count") + + # remove unnecessary select, if we're operating on a table ---- + if set(lift_inner_cols(sel_inner)) == set(lift_inner_cols(__data.last_select)): + sel_inner = __data.last_op + + # create outer select ---- + # holds selected columns and tally (n) + sel_inner_cte = sel_inner.alias() + inner_cols = sel_inner_cte.columns + + # apply any group vars from a group_by verb call first + missing = [k for k in __data.group_by if k not in result_names] + + all_group_names = ordered_union(__data.group_by, result_names) + outer_group_cols = [inner_cols[k] for k in all_group_names] + + # holds the actual count (e.g. 
n) + count_col = sql.functions.count().label(res_name) + + sel_outer = _sql_select([*outer_group_cols, count_col]) \ + .select_from(sel_inner_cte) \ + .group_by(*outer_group_cols) + + # count is like summarize, so removes order_by + return __data.append_op( + sel_outer.order_by(count_col.desc()), + order_by = tuple() + ) + + +@add_count.register(LazyTbl) +def _add_count(__data, *args, wt = None, sort = False, **kwargs): + counts = count(__data, *args, wt = wt, sort = sort, **kwargs) + by = list(c.name for c in counts.last_select.inner_columns)[:-1] + + return inner_join(__data, counts, by = by) + + +@summarize.register(LazyTbl) +def _summarize(__data, *args, **kwargs): + # https://stackoverflow.com/questions/14754994/why-is-sqlalchemy-count-much-slower-than-the-raw-query + + # get query with correct from clause, and maybe unneeded subquery + safe_from = __data.last_select.alias() + result_names, sel = _aggregate_cols(__data, safe_from, args, kwargs, "Summarize") + + # see if we can remove subquery + out_sel = _collapse_select(sel, safe_from) + + from_tbl = get_single_from(out_sel) + group_cols = [from_tbl.columns[k] for k in __data.group_by] + + final_sel = out_sel.group_by(*group_cols) + + new_data = __data.append_op(final_sel, group_by = tuple(), order_by = tuple()) + return new_data + + +def _collapse_select(outer_sel, inner_alias): + # check whether any outer columns reference an inner label ---- + inner_sel = inner_alias.element + + columns = lift_inner_cols(outer_sel) + inner_cols = lift_inner_cols(inner_sel) + + inner_labels = set([ + x.name for x in inner_cols + if isinstance(x, sql.elements.Label) + ]) + + col_requires_cte = set(inner_alias.columns[k] for k in inner_labels) + + bad_refs = [] + + def collect_refs(el): + if el in col_requires_cte: + bad_refs.append(el) + + for col in columns: + sql.util.visitors.traverse(col, {}, {"column": collect_refs}) + + # if possible, remove the outer query ---- + if not (bad_refs or len(inner_sel._group_by_clause)): + 
from sqlalchemy.sql.elements import ColumnClause, Label + + from_obj = get_single_from(inner_sel) + adaptor = sql.util.ClauseAdapter( + from_obj, + adapt_on_names=True, + include_fn=lambda c: isinstance(c, (ColumnClause, Label)) + ) + + new_cols = [] + for col in columns: + if isinstance(col, Label): + res = adaptor.traverse(col.element).label(col.name) + new_cols.append(res) + + else: + new_cols.append(adaptor.traverse(col)) + #new_cols = list(map(adaptor.traverse, columns)) + + return _sql_with_only_columns(inner_sel, new_cols) + + return outer_sel + + +def _aggregate_cols(__data, subquery, args, kwargs, verb_name): + # cases: + # * grouping cols can not be overwritten (in dbplyr they can't be ref'd) + # * no existing labels referred to - can use same select + # * existing labels referred to - need 1 subquery tops + # * groups + summarize columns can replace everything + + def get_label_clauses(clause): + out = [] + sql.util.visitors.traverse(clause, {}, {"label": lambda c: out.append(c)}) + + return out + + def quote_varname(x): + return f"`{x}`" + + def validate_references(arg_name, expr, verb_name): + bad_varnames = get_label_clauses(expr) + repr_names = ", ".join(map(quote_varname, [el.name for el in bad_varnames])) + + if not bad_varnames: + return + + raise NotImplementedError( + f"In SQL, you cannot refer to a column created in the same {verb_name}. " + f"`{arg_name}` refers to columns created earlier: {repr_names}." 
+ ) + + sel = subquery.select() + + final_cols = {k: subquery.columns[k] for k in __data.group_by} + + # handle args ---- + for ii, func in enumerate(args): + cols_result = _eval_expr_arg(__data, sel, func, verb_name, window=False) + + for col in cols_result: + validate_references(col.name, col.element, verb_name) + final_cols[col.name] = col + + sel = _sql_upsert_columns(sel, cols_result) + + + # handle kwargs ---- + for new_name, func in kwargs.items(): + labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=False) + + validate_references(labeled.name, labeled.element, verb_name) + final_cols[new_name] = labeled + + sel = _sql_upsert_columns(sel, [labeled]) + + return list(final_cols), _sql_with_only_columns(sel, list(final_cols.values())) + + +@group_by.register(LazyTbl) +def _group_by(__data, *args, add = False, **kwargs): + if not (args or kwargs): + return __data.copy() + + group_names, sel = _mutate_cols(__data, args, kwargs, "Group by") + + if None in group_names: + raise NotImplementedError("Complex, unnamed expressions not supported in sql group_by") + + # check whether we can just use underlying table ---- + new_cols = lift_inner_cols(sel) + if set(new_cols).issubset(set(__data.last_op.columns)): + sel = __data.last_op + + if add: + group_names = ordered_union(__data.group_by, group_names) + + return __data.append_op(sel, group_by = tuple(group_names)) + + +@ungroup.register(LazyTbl) +def _ungroup(__data): + return __data.copy(group_by = tuple()) + + +@case_when.register(sql.base.ImmutableColumnCollection) +def _case_when(__data, cases): + # TODO: will need listener to enter case statements, to handle when they use windows + if isinstance(cases, Call): + cases = cases(__data) + + whens = [] + case_items = list(cases.items()) + n_items = len(case_items) + + else_val = None + for ii, (expr, val) in enumerate(case_items): + # handle where val is a column expr + if callable(val): + val = val(__data) + + # handle when expressions + if ii+1 
== n_items and expr is True: + else_val = val + elif callable(expr): + whens.append((expr(__data), val)) + else: + whens.append((expr, val)) + + return sql.case(whens, else_ = else_val) + + +# Join ------------------------------------------------------------------------ + +from collections.abc import Mapping + +def _joined_cols(left_cols, right_cols, on_keys, how, suffix = ("_x", "_y")): + """Return labeled columns, according to selection rules for joins. + + Rules: + 1. For join keys, keep left table's column + 2. When keys have the same labels, add suffix + """ + + # TODO: remove sets, so uses stable ordering + # when left and right cols have same name, suffix with _x / _y + keep_right = set(right_cols.keys()) - set(on_keys.values()) + shared_labs = set(left_cols.keys()).intersection(keep_right) + + right_cols_no_keys = {k: right_cols[k] for k in keep_right} + + # for an outer join, have key columns coalesce values + + left_cols = {**left_cols} + if how == "full": + for lk, rk in on_keys.items(): + col = sql.functions.coalesce(left_cols[lk], right_cols[rk]) + left_cols[lk] = col.label(lk) + elif how == "right": + for lk, rk in on_keys.items(): + # Make left key columns actually be right ones (which contain left + extra) + left_cols[lk] = right_cols[rk].label(lk) + + + # create labels ---- + l_labs = _relabeled_cols(left_cols, shared_labs, suffix[0]) + r_labs = _relabeled_cols(right_cols_no_keys, shared_labs, suffix[1]) + + return l_labs + r_labs + + + +def _relabeled_cols(columns, keys, suffix): + # add a suffix to all columns with names in keys + cols = [] + for k, v in columns.items(): + new_col = v.label(k + str(suffix)) if k in keys else v + cols.append(new_col) + return cols + + +@join.register(LazyTbl) +def _join(left, right, on = None, *args, by = None, how = "inner", sql_on = None): + _raise_if_args(args) + + if on is None and by is not None: + on = by + + # Needs to be on the table, not the select + left_sel = left.last_op.alias() + right_sel = 
right.last_op.alias() + + # handle arguments ---- + on = _validate_join_arg_on(on, sql_on) + how = _validate_join_arg_how(how) + + # for equality join used to combine keys into single column + consolidate_keys = on if sql_on is None else {} + + if how == "right": + # switch joins, since sqlalchemy doesn't have right join arg + # see https://stackoverflow.com/q/11400307/1144523 + left_sel, right_sel = right_sel, left_sel + on = {v:k for k,v in on.items()} + + # create join conditions ---- + bool_clause = _create_join_conds(left_sel, right_sel, on) + + # create join ---- + join = left_sel.join( + right_sel, + onclause = bool_clause, + isouter = how != "inner", + full = how == "full" + ) + + # if right join, set selects back + if how == "right": + left_sel, right_sel = right_sel, left_sel + on = {v:k for k,v in on.items()} + + # note, shared_keys assumes on is a mapping... + # TODO: shared_keys appears to be for when on is not specified, but was unused + #shared_keys = [k for k,v in on.items() if k == v] + labeled_cols = _joined_cols( + left_sel.columns, + right_sel.columns, + on_keys = consolidate_keys, + how = how + ) + + sel = _sql_select(labeled_cols).select_from(join) + return left.append_op(sel, order_by = tuple()) + + +@semi_join.register(LazyTbl) +def _semi_join(left, right = None, on = None, *args, by = None, sql_on = None): + if on is None and by is not None: + on = by + + _raise_if_args(args) + + left_sel = left.last_op.alias() + right_sel = right.last_op.alias() + + # handle arguments ---- + on = _validate_join_arg_on(on, sql_on, left_sel, right_sel) + + # create join conditions ---- + bool_clause = _create_join_conds(left_sel, right_sel, on) + + # create inner join ---- + exists_clause = _sql_select([sql.literal(1)]) \ + .select_from(right_sel) \ + .where(bool_clause) + + # only keep left hand select's columns ---- + sel = _sql_select(left_sel.columns) \ + .select_from(left_sel) \ + .where(sql.exists(exists_clause)) + + return left.append_op(sel, order_by 
= tuple()) + + +@anti_join.register(LazyTbl) +def _anti_join(left, right = None, on = None, *args, by = None, sql_on = None): + if on is None and by is not None: + on = by + + _raise_if_args(args) + + left_sel = left.last_op.alias() + right_sel = right.last_op.alias() + + # handle arguments ---- + on = _validate_join_arg_on(on, sql_on, left, right) + + # create join conditions ---- + bool_clause = _create_join_conds(left_sel, right_sel, on) + + # create inner join ---- + #not_exists = ~sql.exists([1], from_obj = right_sel).where(bool_clause) + exists_clause = _sql_select([sql.literal(1)]) \ + .select_from(right_sel) \ + .where(bool_clause) + + sel = left_sel.select().where(~sql.exists(exists_clause)) + + return left.append_op(sel, order_by = tuple()) + +def _raise_if_args(args): + if len(args): + raise NotImplemented("*args is reserved for future arguments (e.g. suffix)") + +def _validate_join_arg_on(on, sql_on = None, lhs = None, rhs = None): + # handle sql on case + if sql_on is not None: + if on is not None: + raise ValueError("Cannot specify both on and sql_on") + + return sql_on + + # handle general cases + if on is None: + # TODO: currently, we check for lhs and rhs tables to indicate whether + # a verb supports inferring columns. Otherwise, raise an error. + if lhs is not None and rhs is not None: + # TODO: consolidate with duplicate logic in pandas verb code + warnings.warn( + "No on column passed to join. " + "Inferring join columns instead using shared column names." + ) + + on_cols = list(set(lhs.columns.keys()).intersection(set(rhs.columns.keys()))) + + if not on_cols: + raise ValueError( + "No join column specified, or shared column names in join." 
+ ) + + # trivial dict mapping shared names to themselves + warnings.warn("Detected shared columns: %s" % on_cols) + on = dict(zip(on_cols, on_cols)) + + else: + raise NotImplementedError("on arg currently cannot be None (default) for SQL") + elif isinstance(on, str): + on = {on: on} + elif isinstance(on, (list, tuple)): + on = dict(zip(on, on)) + + + if not isinstance(on, Mapping): + raise TypeError("on must be a Mapping (e.g. dict)") + + return on + +def _validate_join_arg_how(how): + how_options = ("inner", "left", "right", "full") + if how not in how_options: + raise ValueError("how argument needs to be one of %s" %how_options) + + return how + +def _create_join_conds(left_sel, right_sel, on): + left_cols = left_sel.columns #lift_inner_cols(left_sel) + right_cols = right_sel.columns #lift_inner_cols(right_sel) + + if callable(on): + # callable, like with sql_on arg + conds = [on(left_cols, right_cols)] + else: + # dict-like of form {left: right} + conds = [] + for l, r in on.items(): + col_expr = left_cols[l] == right_cols[r] + conds.append(col_expr) + + return sql.and_(*conds) + + +# Head ------------------------------------------------------------------------ + +@head.register(LazyTbl) +def _head(__data, n = 5): + sel = __data.last_select + + return __data.append_op(sel.limit(n)) + + +# Rename ---------------------------------------------------------------------- + +@rename.register(LazyTbl) +def _rename(__data, **kwargs): + sel = __data.last_select + columns = lift_inner_cols(sel) + + # old_keys uses dict as ordered set + old_to_new = {simple_varname(v):k for k,v in kwargs.items()} + + if None in old_to_new: + raise KeyError("positional arguments must be simple column, " + "e.g. 
_.colname or _['colname']" + ) + + labs = [c.label(old_to_new[k]) if k in old_to_new else c for k,c in columns.items()] + + new_sel = sel.with_only_columns(labs) + + missing_groups, group_keys = _select_group_renames(old_to_new, __data.group_by) + + return __data.append_op(new_sel, group_by=group_keys) + + +# Distinct -------------------------------------------------------------------- + +@distinct.register(LazyTbl) +def _distinct(__data, *args, _keep_all = False, **kwargs): + if (args or kwargs) and _keep_all: + raise NotImplementedError("Distinct with variables specified in sql requires _keep_all = False") + + inner_sel = mutate(__data, **kwargs).last_select if kwargs else __data.last_select + + # TODO: this is copied from the df distinct version + # cols dict below is used as ordered set + cols = _var_select_simple(args) + cols.update(kwargs) + + # use all columns by default + if not cols: + cols = {k: True for k in lift_inner_cols(inner_sel).keys()} + + final_names = {**{k: True for k in __data.group_by}, **cols} + + if not len(inner_sel._order_by_clause): + # select distinct has to include any columns in the order by clause, + # so can only safely modify existing statement when there's no order by + sel_cols = lift_inner_cols(inner_sel) + distinct_cols = [sel_cols[k] for k in final_names] + sel = inner_sel.with_only_columns(distinct_cols).distinct() + else: + # fallback to cte + cte = inner_sel.alias() + distinct_cols = [cte.columns[k] for k in final_names] + sel = _sql_select(distinct_cols).select_from(cte).distinct() + + return __data.append_op(sel) + + +# if_else --------------------------------------------------------------------- + +@if_else.register(sql.elements.ColumnElement) +def _if_else(cond, true_vals, false_vals): + whens = [(cond, true_vals)] + return sql.case(whens, else_ = false_vals) + + diff --git a/siuba/sql/verbs/mutate.py b/siuba/sql/verbs/mutate.py new file mode 100644 index 00000000..8dcbef7b --- /dev/null +++ b/siuba/sql/verbs/mutate.py 
@@ -0,0 +1,1395 @@ +""" +Implements LazyTbl to represent tables of SQL data, and registers it on verbs. + +This module is responsible for the handling of the "table" side of things, while +translate.py handles translating column operations. + + +""" + +import warnings + +from siuba.dply.verbs import ( + show_query, collect, + simple_varname, + select, + mutate, + transmute, + filter, + arrange, _call_strip_ascending, + summarize, + count, add_count, + group_by, ungroup, + case_when, + join, left_join, right_join, inner_join, semi_join, anti_join, + head, + rename, + distinct, + if_else, + _select_group_renames, + _var_select_simple + ) + +from siuba.dply.tidyselect import VarList, var_select + +from .translate import CustomOverClause, SqlColumn, SqlColumnAgg +from .utils import ( + get_dialect_translator, + _FixedSqlDatabase, + _is_dialect_duckdb, + _sql_select, + _sql_column_collection, + _sql_add_columns, + _sql_with_only_columns, + _sql_simplify_select, + MockConnection +) + +from sqlalchemy import sql +import sqlalchemy +from siuba.siu import Call, Lazy, FunctionLookupError, singledispatch2 +# TODO: currently needed for select, but can we remove pandas? +from pandas import Series +from functools import singledispatch + +from sqlalchemy.sql import schema + +from siuba.dply.across import _require_across, _set_data_context, _eval_with_context + +# TODO: +# - distinct +# - annotate functions using sel.prefix_with("\n/**/\n") ? + + +# Helpers --------------------------------------------------------------------- + +class SqlFunctionLookupError(FunctionLookupError): pass + + +class CallListener: + """Generic listener. Each exit is called on a node's copy.""" + def enter(self, node): + args, kwargs = node.map_subcalls(self.enter) + + return self.exit(node.__class__(node.func, *args, **kwargs)) + + def exit(self, node): + return node + + +class WindowReplacer(CallListener): + """Call tree listener. 
+ + Produces 2 important behaviors via the enter method: + - returns evaluated sql call expression, with labels on all window expressions. + - stores all labeled window expressions via the windows property. + + TODO: could replace with a sqlalchemy transformer + """ + + def __init__(self, columns, group_by, order_by, window_cte = None): + self.columns = columns + self.group_by = group_by + self.order_by = order_by + self.window_cte = window_cte + self.windows = [] + + def exit(self, node): + col_expr = node(self.columns) + + if not isinstance(col_expr, sql.elements.ClauseElement): + return col_expr + + over_clauses = [x for x in self._get_over_clauses(col_expr) if isinstance(x, CustomOverClause)] + + # put groupings and orderings onto custom over clauses + for over in over_clauses: + # TODO: shouldn't mutate these over clauses + group_by = sql.elements.ClauseList( + *[self.columns[name] for name in self.group_by] + ) + order_by = sql.elements.ClauseList( + *_create_order_by_clause(self.columns, *self.order_by) + ) + + over.set_over(group_by, order_by) + + if len(over_clauses) and self.window_cte is not None: + # custom name, or parameters like "%(...)s" may nest and break psycopg2 + # with columns you can set a key to fix this, but it doesn't seem to + # be an option with labels + name = self._get_unique_name('win', lift_inner_cols(self.window_cte)) + label = col_expr.label(name) + + # put into CTE, and return its resulting column, so that subsequent + # operations will refer to the window column on window_cte. 
Note that + # the operations will use the actual column, so may need to use the + # ClauseAdaptor to make it a reference to the label + self.window_cte = _sql_add_columns(self.window_cte, [label]) + win_col = lift_inner_cols(self.window_cte).values()[-1] + self.windows.append(win_col) + + return win_col + + return col_expr + + @staticmethod + def _get_unique_name(prefix, columns): + column_names = set(columns.keys()) + + i = 1 + name = prefix + str(i) + while name in column_names: + i += 1 + name = prefix + str(i) + + + return name + + @staticmethod + def _get_over_clauses(clause): + windows = [] + append_win = lambda col: windows.append(col) + + sql.util.visitors.traverse(clause, {}, {"over": append_win}) + + return windows + + +class SqlLabelReplacer: + """Create a visitor to replace source labels with destination. + + Note that this is meant to be used with sqlalchemy visitors. + """ + + def __init__(self, src_columns, dst_columns): + self.src_columns = src_columns + self.src_labels = set([x for x in src_columns if isinstance(x, sql.elements.Label)]) + self.dst_columns = dst_columns + self.applied = False + + def __call__(self, clause): + return sql.util.visitors.replacement_traverse(clause, {}, self.visit) + + def visit(self, el): + from sqlalchemy.sql.elements import ColumnClause, Label, ClauseElement, TypeClause + from sqlalchemy.sql.schema import Column + + if isinstance(el, TypeClause): + # TODO: for some reason this type throws an error if unguarded + return None + + if isinstance(el, ClauseElement): + if el in self.src_labels: + self.applied = True + return self.dst_columns[el.name] + elif el in self.src_columns: + return self.dst_columns[el.name] + + # TODO: should we create a subquery if the user passed raw text? 
+ #elif isinstance(el, ColumnClause) and not isinstance(el, Column): + # # Raw SQL, which will need a subquery, but not substitution + # if el.key != "*": + # self.applied = True + + return None + + +#def track_call_windows(call, columns, group_by, order_by, window_cte = None): +# listener = WindowReplacer(columns, group_by, order_by, window_cte) +# col = listener.enter(call) +# return col, listener.windows, listener.window_cte + + +def track_call_windows(call, columns, group_by, order_by, window_cte = None): + col_expr = call(columns) + + crnt_group_by = sql.elements.ClauseList( + *[columns[name] for name in group_by] + ) + crnt_order_by = sql.elements.ClauseList( + *_create_order_by_clause(columns, *order_by) + ) + return replace_call_windows(col_expr, crnt_group_by, crnt_order_by, window_cte) + + + +@singledispatch +def replace_call_windows(col_expr, group_by, order_by, window_cte = None): + raise TypeError(str(type(col_expr))) + + +@replace_call_windows.register(sql.base.ImmutableColumnCollection) +def _(col_expr, group_by, order_by, window_cte = None): + all_over_clauses = [] + for col in col_expr: + _, over_clauses, window_cte = replace_call_windows( + col, + group_by, + order_by, + window_cte + ) + all_over_clauses.extend(over_clauses) + + return col_expr, all_over_clauses, window_cte + + +@replace_call_windows.register(sql.elements.ClauseElement) +def _(col_expr, group_by, order_by, window_cte = None): + + over_clauses = WindowReplacer._get_over_clauses(col_expr) + + for over in over_clauses: + # TODO: shouldn't mutate these over clauses + over.set_over(group_by, order_by) + + if len(over_clauses) and window_cte is not None: + # custom name, or parameters like "%(...)s" may nest and break psycopg2 + # with columns you can set a key to fix this, but it doesn't seem to + # be an option with labels + name = WindowReplacer._get_unique_name('win', lift_inner_cols(window_cte)) + label = col_expr.label(name) + + # put into CTE, and return its resulting column, so 
that subsequent + # operations will refer to the window column on window_cte. Note that + # the operations will use the actual column, so may need to use the + # ClauseAdaptor to make it a reference to the label + window_cte = _sql_add_columns(window_cte, [label]) + win_col = lift_inner_cols(window_cte).values()[-1] + + return win_col, over_clauses, window_cte + + return col_expr, over_clauses, window_cte + +def get_single_from(sel): + froms = sel.froms + + n_froms = len(froms) + if n_froms != 1: + raise ValueError( + f"Expected a single table in the from clause, but found {n_froms}" + ) + + return froms[0] + +def lift_inner_cols(tbl): + cols = list(tbl.inner_columns) + + return _sql_column_collection(cols) + +def col_expr_requires_cte(call, sel, is_mutate = False): + """Return whether a variable assignment needs a CTE""" + + call_vars = set(call.op_vars(attr_calls = False)) + + sel_labs = get_inner_labels(sel) + + # I use the acronym fwg sol (frog soul) to remember sql clause eval order + # from, where, group by, select, order by, limit + # group clause evaluated before select clause, so not issue for mutate + group_needs_cte = not is_mutate and len(sel._group_by_clause) + + return ( group_needs_cte + # TODO: detect when a new var in mutate conflicts w/ order by + #or len(sel._order_by_clause) + or not sel_labs.isdisjoint(call_vars) + ) + +def get_inner_labels(sel): + columns = lift_inner_cols(sel) + sel_labs = set(k for k,v in columns.items() if isinstance(v, sql.elements.Label)) + return sel_labs + +def get_missing_columns(call, columns): + missing_cols = set(call.op_vars(attr_calls = False)) - set(columns.keys()) + return missing_cols + +def compile_el(tbl, el): + compiled = el.compile( + dialect = tbl.source.dialect, + compile_kwargs = {"literal_binds": True} + ) + return compiled + +# Misc utilities -------------------------------------------------------------- + +def ordered_union(x, y): + dx = {el: True for el in x} + dy = {el: True for el in y} + + return 
tuple({**dx, **dy}) + + +def _warn_missing(missing_groups): + warnings.warn(f"Adding missing grouping variables: {missing_groups}") + + +# Table ----------------------------------------------------------------------- + +class LazyTbl: + def __init__( + self, source, tbl, columns = None, + ops = None, group_by = tuple(), order_by = tuple(), + translator = None + ): + """Create a representation of a SQL table. + + Args: + source: a sqlalchemy.Engine or sqlalchemy.Connection instance. + tbl: table of form 'schema_name.table_name', 'table_name', or sqlalchemy.Table. + columns: if specified, a listlike of column names. + + Examples + -------- + + :: + from sqlalchemy import create_engine + from siuba.data import mtcars + + # create database and table + engine = create_engine("sqlite:///:memory:") + mtcars.to_sql('mtcars', engine) + + tbl_mtcars = LazyTbl(engine, 'mtcars') + + """ + + # connection and dialect specific functions + self.source = sqlalchemy.create_engine(source) if isinstance(source, str) else source + + # get dialect name + dialect = self.source.dialect.name + self.translator = get_dialect_translator(dialect) + + self.tbl = self._create_table(tbl, columns, self.source) + + # important states the query can be in (e.g. 
grouped) + self.ops = [self.tbl] if ops is None else ops + + self.group_by = group_by + self.order_by = order_by + + + def append_op(self, op, **kwargs): + cpy = self.copy(**kwargs) + cpy.ops = cpy.ops + [op] + return cpy + + def copy(self, **kwargs): + return self.__class__(**{**self.__dict__, **kwargs}) + + def shape_call( + self, + call, window = True, str_accessors = False, + verb_name = None, arg_name = None, + ): + return self.translator.shape_call(call, window, str_accessors, verb_name, arg_name) + + def track_call_windows(self, call, columns = None, window_cte = None): + """Returns tuple of (new column expression, list of window exprs)""" + + columns = self.last_op.columns if columns is None else columns + return track_call_windows(call, columns, self.group_by, self.order_by, window_cte) + + def get_ordered_col_names(self): + """Return columns from current select, with grouping columns first.""" + ungrouped = [k for k in self.last_op.columns.keys() if k not in self.group_by] + return list(self.group_by) + ungrouped + + #def label_breaks_order_by(self, name): + # """Returns True if a new column label would break the order by vars.""" + + # # TODO: arrange currently allows literals, which breaks this. it seems + # # better to only allow calls in arrange. + # order_by_vars = {c.op_vars(attr_calls=False) for c in self.order_by} + + + + + @property + def last_op(self) -> "sql.Table | sql.Select": + last_op = self.ops[-1] + + if last_op is None: + raise TypeError() + + return last_op + + @property + def last_select(self): + last_op = self.last_op + if not isinstance(last_op, sql.selectable.SelectBase): + return last_op.select() + + return last_op + + @staticmethod + def _create_table(tbl, columns = None, source = None): + """Return a sqlalchemy.Table, autoloading column info if needed. + + Arguments: + tbl: a sqlalchemy.Table or string of form 'table_name' or 'schema_name.table_name'. + columns: a tuple of column names for the table. Overrides source argument. 
+ source: a sqlalchemy engine, used to autoload columns. + + """ + if isinstance(tbl, sql.selectable.FromClause): + return tbl + + if not isinstance(tbl, str): + raise ValueError("tbl must be a sqlalchemy Table or string, but was %s" %type(tbl)) + + if columns is None and source is None: + raise ValueError("One of columns or source must be specified") + + schema, table_name = tbl.split('.') if '.' in tbl else [None, tbl] + + columns = map(sqlalchemy.Column, columns) if columns is not None else tuple() + + # TODO: pybigquery uses schema to mean project_id, so we cannot use + # siuba's classic breakdown "{schema}.{table_name}". Basically + # pybigquery uses "{schema=project_id}.{dataset_dot_table_name}" in its internal + # logic. An important side effect is that bigquery errors for + # `dataset`.`table`, but not `dataset.table`. + if source and source.dialect.name == "bigquery": + table_name = tbl + schema = None + + return sqlalchemy.Table( + table_name, + sqlalchemy.MetaData(bind = source), + *columns, + schema = schema, + autoload_with = source if not columns else None + ) + + def _get_preview(self): + # need to make prev op a cte, so we don't override any previous limit + new_sel = self.last_select.limit(5) + tbl_small = self.append_op(new_sel) + return collect(tbl_small) + + def __repr__(self): + template = ( + "# Source: lazy query\n" + "# DB Conn: {}\n" + "# Preview:\n{}\n" + "# .. may have more rows" + ) + + return template.format(repr(self.source.engine), repr(self._get_preview())) + + def _repr_html_(self): + template = ( + "
" + "
"
+                "# Source: lazy query\n"
+                "# DB Conn: {}\n"
+                "# Preview:\n"
+                "
" + "{}" + "

# .. may have more rows

" + "
" + ) + + data = self._get_preview() + + # _repr_html_ can not exist or return None, to signify that repr should be used + if not hasattr(data, '_repr_html_'): + return None + + html_data = data._repr_html_() + if html_data is None: + return None + + return template.format(self.source.engine, html_data) + + +def _repr_grouped_df_html_(self): + return "

(grouped data frame)

" + self._selected_obj._repr_html_() + "
" + + + +# Main Funcs +# ============================================================================= + +# sql raw -------------- + +sql_raw = sql.literal_column + +# show query ----------- + +@show_query.register(LazyTbl) +def _show_query(tbl, simplify = False, return_table = True): + #query = tbl.last_op #if not simplify else + compile_query = lambda query: query.compile( + dialect = tbl.source.dialect, + compile_kwargs = {"literal_binds": True} + ) + + + if simplify: + # try to strip table names and labels where unnecessary + simple_sel = _sql_simplify_select(tbl.last_select) + + explained = compile_query(simple_sel) + else: + # use a much more verbose query + explained = compile_query(tbl.last_select) + + if return_table: + print(str(explained)) + return tbl + + return str(explained) + + + +# collect ---------- + +@collect.register(LazyTbl) +def _collect(__data, as_df = True): + # TODO: maybe remove as_df options, always return dataframe + + if isinstance(__data.source, MockConnection): + # a mock sqlalchemy is being used to show_query, and echo queries. + # it doesn't return a result object or have a context handler, so + # we need to bail out early + return + + # compile query ---- + + if _is_dialect_duckdb(__data.source): + # TODO: can be removed once next release of duckdb fixes: + # https://github.com/duckdb/duckdb/issues/2972 + query = __data.last_select + compiled = query.compile( + dialect = __data.source.dialect, + compile_kwargs = {"literal_binds": True} + ) + else: + compiled = __data.last_select + + # execute query ---- + + with __data.source.connect() as conn: + if as_df: + sql_db = _FixedSqlDatabase(conn) + + if _is_dialect_duckdb(__data.source): + # TODO: pandas read_sql is very slow with duckdb. + # see https://github.com/pandas-dev/pandas/issues/45678 + # going to handle here for now. address once LazyTbl gets + # subclassed per backend. 
+ duckdb_con = conn.connection.c + return duckdb_con.query(str(compiled)).to_df() + else: + # + return sql_db.read_sql(compiled) + + return conn.execute(compiled) + + +@select.register(LazyTbl) +def _select(__data, *args, **kwargs): + # see https://stackoverflow.com/questions/25914329/rearrange-columns-in-sqlalchemy-select-object + if kwargs: + raise NotImplementedError( + "Using kwargs in select not currently supported. " + "Use _.newname == _.oldname instead" + ) + last_sel = __data.last_select + columns = {c.key: c for c in last_sel.inner_columns} + + # same as for DataFrame + colnames = Series(list(columns)) + vl = VarList() + evaluated = (arg(vl) if callable(arg) else arg for arg in args) + od = var_select(colnames, *evaluated) + + missing_groups, group_keys = _select_group_renames(od, __data.group_by) + + if missing_groups: + _warn_missing(missing_groups) + + final_od = {**{k: None for k in missing_groups}, **od} + + col_list = [] + for k,v in final_od.items(): + col = columns[k] + col_list.append(col if v is None else col.label(v)) + + return __data.append_op( + last_sel.with_only_columns(col_list), + group_by = group_keys + ) + + + +@filter.register(LazyTbl) +def _filter(__data, *args): + # Note: currently always produces 2 additional select statements, + # 1 for window/aggs, and 1 for the where clause + + sel = __data.last_op.alias() # original select + win_sel = sel.select() + + conds = [] + windows = [] + with _set_data_context(__data, window=True): + for ii, arg in enumerate(args): + + if isinstance(arg, Call): + new_call = __data.shape_call(arg, verb_name = "Filter", arg_name = ii) + #var_cols = new_call.op_vars(attr_calls = False) + + # note that a new win_sel is returned, w/ window columns appended + col_expr, win_cols, win_sel = __data.track_call_windows( + new_call, + sel.columns, + window_cte = win_sel + ) + + if isinstance(col_expr, sql.base.ImmutableColumnCollection): + conds.extend(col_expr) + else: + conds.append(col_expr) + + 
windows.extend(win_cols) + + else: + conds.append(arg) + + bool_clause = sql.and_(*conds) + + # first cte, windows ---- + if len(windows): + + win_alias = win_sel.alias() + + # move non-window functions to refer to win_sel clause (not the innermost) --- + bool_clause = sql.util.ClauseAdapter(win_alias) \ + .traverse(bool_clause) + + orig_cols = [win_alias.columns[k] for k in sel.columns.keys()] + else: + orig_cols = [sel] + + # create second cte ---- + filt_sel = _sql_select(orig_cols).where(bool_clause) + return __data.append_op(filt_sel) + + +@mutate.register(LazyTbl) +def _mutate(__data, *args, **kwargs): + # TODO: verify it can follow a renaming select + + # track labeled columns in set + if not (len(args) or len(kwargs)): + return __data.append_op(__data.last_op) + + names, sel_out = _mutate_cols(__data, args, kwargs, "Mutate") + return __data.append_op(sel_out) + + +def _sql_upsert_columns(sel, new_columns: "list[base.Label | base.Column]"): + orig_cols = lift_inner_cols(sel) + replaced = {**orig_cols} + + for new_col in new_columns: + replaced[new_col.name] = new_col + return _sql_with_only_columns(sel, list(replaced.values())) + + +def _select_mutate_result(src_sel, expr_result): + dst_alias = src_sel.alias() + src_columns = set(lift_inner_cols(src_sel)) + replacer = SqlLabelReplacer(set(src_columns), dst_alias.columns) + + if isinstance(expr_result, sql.base.ImmutableColumnCollection): + replaced_cols = list(map(replacer, expr_result)) + orig_cols = expr_result + #elif isinstance(expr_result, None): + # pass + else: + replaced_cols = [replacer(expr_result)] + orig_cols = [expr_result] + + if replacer.applied: + return _sql_upsert_columns(dst_alias.select(), replaced_cols) + + return _sql_upsert_columns(src_sel, orig_cols) + + +def _eval_expr_arg(__data, sel, func, verb_name, window=True): + inner_cols = lift_inner_cols(sel) + + # case 1: simple names ---- + simple_name = simple_varname(func) + if simple_name is not None: + return inner_cols[simple_name] + 
def _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=True):
    """Evaluate a named verb argument and return it as a labeled column.

    Args:
        __data: the LazyTbl the verb is operating on.
        sel: select statement whose inner columns the expression may refer to.
        func: the expression to evaluate.
        new_name: name of the keyword argument; used as the column label.
        verb_name: name of the calling verb, used in error messages.
        window: passed through to ``__data.shape_call``.

    Returns:
        The evaluated column expression, labeled with ``new_name``.

    Raises:
        TypeError: if the expression evaluates to a collection of columns
            rather than a single column.
    """
    inner_cols = lift_inner_cols(sel)

    expr_shaped = __data.shape_call(func, window, verb_name = verb_name, arg_name = new_name)

    # only the column expression is needed here; the window clauses and the
    # window cte returned alongside it are ignored
    new_col, _windows, _ = __data.track_call_windows(expr_shaped, inner_cols)

    if isinstance(new_col, sql.base.ImmutableColumnCollection):
        raise TypeError(
            f"{verb_name} named arguments must return a single column, but `{new_name}` "
            "returned multiple columns."
        )

    return new_col.label(new_name)
@count.register(LazyTbl)
def _count(__data, *args, sort = False, wt = None, **kwargs):
    """Count rows per combination of grouping and counted columns.

    Builds an inner select holding any mutate-style columns, then an outer
    select that groups by the table's existing group_by names plus the newly
    counted names, with the tally labeled ``n``.

    Args:
        __data: the LazyTbl to count over.
        *args: positional expressions, evaluated like mutate arguments.
        sort: accepted for interface compatibility.
        wt: weights; only None is supported.
        **kwargs: named expressions, evaluated like mutate arguments.

    Raises:
        NotImplementedError: if ``wt`` is not None.
    """
    # TODO: if already col named n, use name nn, etc.. get logic from tidy.py
    if wt is not None:
        raise NotImplementedError("TODO")

    res_name = "n"
    # similar to filter verb, we need two select statements,
    # an inner one for derived cols, and outer to group by them

    # inner select ----
    # holds any mutation style columns
    result_names, sel_inner = _mutate_cols(__data, args, kwargs, "Count")

    # remove unnecessary select, if we're operating on a table ----
    if set(lift_inner_cols(sel_inner)) == set(lift_inner_cols(__data.last_select)):
        sel_inner = __data.last_op

    # create outer select ----
    # holds selected columns and tally (n)
    sel_inner_cte = sel_inner.alias()
    inner_cols = sel_inner_cte.columns

    # apply any group vars from a group_by verb call first
    all_group_names = ordered_union(__data.group_by, result_names)
    outer_group_cols = [inner_cols[k] for k in all_group_names]

    # holds the actual count (e.g. n)
    count_col = sql.functions.count().label(res_name)

    sel_outer = _sql_select([*outer_group_cols, count_col]) \
            .select_from(sel_inner_cte) \
            .group_by(*outer_group_cols)

    # NOTE(review): `sort` is never consulted; the result is always ordered by
    # the count, descending. Confirm whether sort = False should skip this.
    # count is like summarize, so removes order_by
    return __data.append_op(
        sel_outer.order_by(count_col.desc()),
        order_by = tuple()
    )
def _aggregate_cols(__data, subquery, args, kwargs, verb_name):
    """Evaluate aggregation expressions against a subquery.

    Cases handled:
      * grouping cols can not be overwritten (in dbplyr they can't be ref'd)
      * no existing labels referred to - can use same select
      * existing labels referred to - need 1 subquery tops
      * groups + summarize columns can replace everything

    Args:
        __data: the LazyTbl being aggregated; its group_by names seed the result.
        subquery: aliased select that expressions are evaluated against.
        args: positional expressions (evaluated via _eval_expr_arg).
        kwargs: named expressions (evaluated via _eval_expr_kwarg).
        verb_name: name of the calling verb, used in error messages.

    Returns:
        A tuple of (list of result column names, select with only those columns).

    Raises:
        NotImplementedError: if an expression contains a label clause, i.e.
            refers to a column created earlier in the same verb call.
    """

    def _labels_in(clause):
        # collect every label element found while traversing the clause
        found = []
        sql.util.visitors.traverse(clause, {}, {"label": found.append})
        return found

    def _check_no_new_refs(arg_name, expr, verb_name):
        offending = _labels_in(expr)
        if not offending:
            return

        repr_names = ", ".join(f"`{el.name}`" for el in offending)
        raise NotImplementedError(
            f"In SQL, you cannot refer to a column created in the same {verb_name}. "
            f"`{arg_name}` refers to columns created earlier: {repr_names}."
        )

    sel = subquery.select()

    # grouping columns always come first in the result
    final_cols = {name: subquery.columns[name] for name in __data.group_by}

    # handle args ----
    for func in args:
        cols_result = _eval_expr_arg(__data, sel, func, verb_name, window=False)

        for col in cols_result:
            _check_no_new_refs(col.name, col.element, verb_name)
            final_cols[col.name] = col

        sel = _sql_upsert_columns(sel, cols_result)

    # handle kwargs ----
    for new_name, func in kwargs.items():
        labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=False)

        _check_no_new_refs(labeled.name, labeled.element, verb_name)
        final_cols[new_name] = labeled

        sel = _sql_upsert_columns(sel, [labeled])

    out_names = list(final_cols)
    out_sel = _sql_with_only_columns(sel, list(final_cols.values()))
    return out_names, out_sel
== n_items and expr is True: + else_val = val + elif callable(expr): + whens.append((expr(__data), val)) + else: + whens.append((expr, val)) + + return sql.case(whens, else_ = else_val) + + +# Join ------------------------------------------------------------------------ + +from collections.abc import Mapping + +def _joined_cols(left_cols, right_cols, on_keys, how, suffix = ("_x", "_y")): + """Return labeled columns, according to selection rules for joins. + + Rules: + 1. For join keys, keep left table's column + 2. When keys have the same labels, add suffix + """ + + # TODO: remove sets, so uses stable ordering + # when left and right cols have same name, suffix with _x / _y + keep_right = set(right_cols.keys()) - set(on_keys.values()) + shared_labs = set(left_cols.keys()).intersection(keep_right) + + right_cols_no_keys = {k: right_cols[k] for k in keep_right} + + # for an outer join, have key columns coalesce values + + left_cols = {**left_cols} + if how == "full": + for lk, rk in on_keys.items(): + col = sql.functions.coalesce(left_cols[lk], right_cols[rk]) + left_cols[lk] = col.label(lk) + elif how == "right": + for lk, rk in on_keys.items(): + # Make left key columns actually be right ones (which contain left + extra) + left_cols[lk] = right_cols[rk].label(lk) + + + # create labels ---- + l_labs = _relabeled_cols(left_cols, shared_labs, suffix[0]) + r_labs = _relabeled_cols(right_cols_no_keys, shared_labs, suffix[1]) + + return l_labs + r_labs + + + +def _relabeled_cols(columns, keys, suffix): + # add a suffix to all columns with names in keys + cols = [] + for k, v in columns.items(): + new_col = v.label(k + str(suffix)) if k in keys else v + cols.append(new_col) + return cols + + +@join.register(LazyTbl) +def _join(left, right, on = None, *args, by = None, how = "inner", sql_on = None): + _raise_if_args(args) + + if on is None and by is not None: + on = by + + # Needs to be on the table, not the select + left_sel = left.last_op.alias() + right_sel = 
right.last_op.alias() + + # handle arguments ---- + on = _validate_join_arg_on(on, sql_on) + how = _validate_join_arg_how(how) + + # for equality join used to combine keys into single column + consolidate_keys = on if sql_on is None else {} + + if how == "right": + # switch joins, since sqlalchemy doesn't have right join arg + # see https://stackoverflow.com/q/11400307/1144523 + left_sel, right_sel = right_sel, left_sel + on = {v:k for k,v in on.items()} + + # create join conditions ---- + bool_clause = _create_join_conds(left_sel, right_sel, on) + + # create join ---- + join = left_sel.join( + right_sel, + onclause = bool_clause, + isouter = how != "inner", + full = how == "full" + ) + + # if right join, set selects back + if how == "right": + left_sel, right_sel = right_sel, left_sel + on = {v:k for k,v in on.items()} + + # note, shared_keys assumes on is a mapping... + # TODO: shared_keys appears to be for when on is not specified, but was unused + #shared_keys = [k for k,v in on.items() if k == v] + labeled_cols = _joined_cols( + left_sel.columns, + right_sel.columns, + on_keys = consolidate_keys, + how = how + ) + + sel = _sql_select(labeled_cols).select_from(join) + return left.append_op(sel, order_by = tuple()) + + +@semi_join.register(LazyTbl) +def _semi_join(left, right = None, on = None, *args, by = None, sql_on = None): + if on is None and by is not None: + on = by + + _raise_if_args(args) + + left_sel = left.last_op.alias() + right_sel = right.last_op.alias() + + # handle arguments ---- + on = _validate_join_arg_on(on, sql_on, left_sel, right_sel) + + # create join conditions ---- + bool_clause = _create_join_conds(left_sel, right_sel, on) + + # create inner join ---- + exists_clause = _sql_select([sql.literal(1)]) \ + .select_from(right_sel) \ + .where(bool_clause) + + # only keep left hand select's columns ---- + sel = _sql_select(left_sel.columns) \ + .select_from(left_sel) \ + .where(sql.exists(exists_clause)) + + return left.append_op(sel, order_by 
= tuple()) + + +@anti_join.register(LazyTbl) +def _anti_join(left, right = None, on = None, *args, by = None, sql_on = None): + if on is None and by is not None: + on = by + + _raise_if_args(args) + + left_sel = left.last_op.alias() + right_sel = right.last_op.alias() + + # handle arguments ---- + on = _validate_join_arg_on(on, sql_on, left, right) + + # create join conditions ---- + bool_clause = _create_join_conds(left_sel, right_sel, on) + + # create inner join ---- + #not_exists = ~sql.exists([1], from_obj = right_sel).where(bool_clause) + exists_clause = _sql_select([sql.literal(1)]) \ + .select_from(right_sel) \ + .where(bool_clause) + + sel = left_sel.select().where(~sql.exists(exists_clause)) + + return left.append_op(sel, order_by = tuple()) + +def _raise_if_args(args): + if len(args): + raise NotImplemented("*args is reserved for future arguments (e.g. suffix)") + +def _validate_join_arg_on(on, sql_on = None, lhs = None, rhs = None): + # handle sql on case + if sql_on is not None: + if on is not None: + raise ValueError("Cannot specify both on and sql_on") + + return sql_on + + # handle general cases + if on is None: + # TODO: currently, we check for lhs and rhs tables to indicate whether + # a verb supports inferring columns. Otherwise, raise an error. + if lhs is not None and rhs is not None: + # TODO: consolidate with duplicate logic in pandas verb code + warnings.warn( + "No on column passed to join. " + "Inferring join columns instead using shared column names." + ) + + on_cols = list(set(lhs.columns.keys()).intersection(set(rhs.columns.keys()))) + + if not on_cols: + raise ValueError( + "No join column specified, or shared column names in join." 
+ ) + + # trivial dict mapping shared names to themselves + warnings.warn("Detected shared columns: %s" % on_cols) + on = dict(zip(on_cols, on_cols)) + + else: + raise NotImplementedError("on arg currently cannot be None (default) for SQL") + elif isinstance(on, str): + on = {on: on} + elif isinstance(on, (list, tuple)): + on = dict(zip(on, on)) + + + if not isinstance(on, Mapping): + raise TypeError("on must be a Mapping (e.g. dict)") + + return on + +def _validate_join_arg_how(how): + how_options = ("inner", "left", "right", "full") + if how not in how_options: + raise ValueError("how argument needs to be one of %s" %how_options) + + return how + +def _create_join_conds(left_sel, right_sel, on): + left_cols = left_sel.columns #lift_inner_cols(left_sel) + right_cols = right_sel.columns #lift_inner_cols(right_sel) + + if callable(on): + # callable, like with sql_on arg + conds = [on(left_cols, right_cols)] + else: + # dict-like of form {left: right} + conds = [] + for l, r in on.items(): + col_expr = left_cols[l] == right_cols[r] + conds.append(col_expr) + + return sql.and_(*conds) + + +# Head ------------------------------------------------------------------------ + +@head.register(LazyTbl) +def _head(__data, n = 5): + sel = __data.last_select + + return __data.append_op(sel.limit(n)) + + +# Rename ---------------------------------------------------------------------- + +@rename.register(LazyTbl) +def _rename(__data, **kwargs): + sel = __data.last_select + columns = lift_inner_cols(sel) + + # old_keys uses dict as ordered set + old_to_new = {simple_varname(v):k for k,v in kwargs.items()} + + if None in old_to_new: + raise KeyError("positional arguments must be simple column, " + "e.g. 
@distinct.register(LazyTbl)
def _distinct(__data, *args, _keep_all = False, **kwargs):
    """Keep distinct rows, optionally over a subset of columns.

    Args:
        __data: the LazyTbl to operate on.
        *args: columns to keep distinct combinations of.
        _keep_all: must be False when any columns are specified.
        **kwargs: named expressions, first added via mutate, then used as
            distinct columns.

    Raises:
        NotImplementedError: if columns are specified and ``_keep_all`` is True.
    """
    if _keep_all and (args or kwargs):
        raise NotImplementedError("Distinct with variables specified in sql requires _keep_all = False")

    if kwargs:
        inner_sel = mutate(__data, **kwargs).last_select
    else:
        inner_sel = __data.last_select

    # TODO: this is copied from the df distinct version
    # dicts below are used as ordered sets; only their keys matter
    wanted = _var_select_simple(args)
    wanted.update(kwargs)

    # use all columns by default
    if not wanted:
        wanted = {name: True for name in lift_inner_cols(inner_sel).keys()}

    # grouping columns come first, followed by the requested columns
    final_names = {name: True for name in __data.group_by}
    final_names.update(wanted)

    has_order_by = len(inner_sel._order_by_clause) > 0
    if has_order_by:
        # select distinct has to include any columns in the order by clause,
        # so fall back to a cte when an order by is present
        cte = inner_sel.alias()
        picked = [cte.columns[name] for name in final_names]
        sel = _sql_select(picked).select_from(cte).distinct()
    else:
        # with no order by, it is safe to modify the existing statement
        available = lift_inner_cols(inner_sel)
        picked = [available[name] for name in final_names]
        sel = inner_sel.with_only_columns(picked).distinct()

    return __data.append_op(sel)
class WindowReplacer(CallListener):
    """Call tree listener.

    Produces 2 important behaviors via the enter method:
      - returns evaluated sql call expression, with labels on all window expressions.
      - stores all labeled window expressions via the windows property.

    TODO: could replace with a sqlalchemy transformer
    """

    def __init__(self, columns, group_by, order_by, window_cte = None):
        self.columns = columns
        self.group_by = group_by
        self.order_by = order_by
        self.window_cte = window_cte
        # window columns that have been moved onto window_cte
        self.windows = []

    def exit(self, node):
        """Evaluate node against the columns, resolving any custom over clauses."""
        col_expr = node(self.columns)

        # results that are not sql clauses are passed through untouched
        if not isinstance(col_expr, sql.elements.ClauseElement):
            return col_expr

        custom_overs = [
            clause for clause in self._get_over_clauses(col_expr)
            if isinstance(clause, CustomOverClause)
        ]

        # put groupings and orderings onto custom over clauses
        for over in custom_overs:
            # TODO: shouldn't mutate these over clauses
            partition = sql.elements.ClauseList(
                *[self.columns[name] for name in self.group_by]
            )
            ordering = sql.elements.ClauseList(
                *_create_order_by_clause(self.columns, *self.order_by)
            )

            over.set_over(partition, ordering)

        if not custom_overs or self.window_cte is None:
            return col_expr

        # custom name, or parameters like "%(...)s" may nest and break psycopg2
        # with columns you can set a key to fix this, but it doesn't seem to
        # be an option with labels
        name = self._get_unique_name('win', lift_inner_cols(self.window_cte))
        label = col_expr.label(name)

        # put into CTE, and return its resulting column, so that subsequent
        # operations will refer to the window column on window_cte. Note that
        # the operations will use the actual column, so may need to use the
        # ClauseAdaptor to make it a reference to the label
        self.window_cte = _sql_add_columns(self.window_cte, [label])
        win_col = lift_inner_cols(self.window_cte).values()[-1]
        self.windows.append(win_col)

        return win_col

    @staticmethod
    def _get_unique_name(prefix, columns):
        """Return prefix followed by the smallest integer >= 1 not already a column name."""
        taken = set(columns.keys())

        n = 1
        while prefix + str(n) in taken:
            n += 1

        return prefix + str(n)

    @staticmethod
    def _get_over_clauses(clause):
        """Return a list of the over clauses visited while traversing clause."""
        found = []
        sql.util.visitors.traverse(clause, {}, {"over": found.append})

        return found
def track_call_windows(call, columns, group_by, order_by, window_cte = None):
    """Evaluate a call on columns, then resolve its window clauses.

    Args:
        call: callable evaluated as ``call(columns)``.
        columns: column collection the call is evaluated against.
        group_by: names of columns to use for grouping in over clauses.
        order_by: ordering arguments, converted by _create_order_by_clause.
        window_cte: optional select passed through to replace_call_windows.

    Returns:
        Whatever replace_call_windows returns for the evaluated expression.
    """
    col_expr = call(columns)

    partition_cols = [columns[name] for name in group_by]
    sort_cols = _create_order_by_clause(columns, *order_by)

    return replace_call_windows(
        col_expr,
        sql.elements.ClauseList(*partition_cols),
        sql.elements.ClauseList(*sort_cols),
        window_cte
    )
def ordered_union(x, y):
    """Return a tuple of the unique elements of x followed by those of y.

    Elements keep the order in which they first appear, with everything in
    ``x`` coming before anything new from ``y``.
    """
    # a dict is used as an ordered set
    return tuple(dict.fromkeys((*x, *y)))
grouped) + self.ops = [self.tbl] if ops is None else ops + + self.group_by = group_by + self.order_by = order_by + + + def append_op(self, op, **kwargs): + cpy = self.copy(**kwargs) + cpy.ops = cpy.ops + [op] + return cpy + + def copy(self, **kwargs): + return self.__class__(**{**self.__dict__, **kwargs}) + + def shape_call( + self, + call, window = True, str_accessors = False, + verb_name = None, arg_name = None, + ): + return self.translator.shape_call(call, window, str_accessors, verb_name, arg_name) + + def track_call_windows(self, call, columns = None, window_cte = None): + """Returns tuple of (new column expression, list of window exprs)""" + + columns = self.last_op.columns if columns is None else columns + return track_call_windows(call, columns, self.group_by, self.order_by, window_cte) + + def get_ordered_col_names(self): + """Return columns from current select, with grouping columns first.""" + ungrouped = [k for k in self.last_op.columns.keys() if k not in self.group_by] + return list(self.group_by) + ungrouped + + #def label_breaks_order_by(self, name): + # """Returns True if a new column label would break the order by vars.""" + + # # TODO: arrange currently allows literals, which breaks this. it seems + # # better to only allow calls in arrange. + # order_by_vars = {c.op_vars(attr_calls=False) for c in self.order_by} + + + + + @property + def last_op(self) -> "sql.Table | sql.Select": + last_op = self.ops[-1] + + if last_op is None: + raise TypeError() + + return last_op + + @property + def last_select(self): + last_op = self.last_op + if not isinstance(last_op, sql.selectable.SelectBase): + return last_op.select() + + return last_op + + @staticmethod + def _create_table(tbl, columns = None, source = None): + """Return a sqlalchemy.Table, autoloading column info if needed. + + Arguments: + tbl: a sqlalchemy.Table or string of form 'table_name' or 'schema_name.table_name'. + columns: a tuple of column names for the table. Overrides source argument. 
+ source: a sqlalchemy engine, used to autoload columns. + + """ + if isinstance(tbl, sql.selectable.FromClause): + return tbl + + if not isinstance(tbl, str): + raise ValueError("tbl must be a sqlalchemy Table or string, but was %s" %type(tbl)) + + if columns is None and source is None: + raise ValueError("One of columns or source must be specified") + + schema, table_name = tbl.split('.') if '.' in tbl else [None, tbl] + + columns = map(sqlalchemy.Column, columns) if columns is not None else tuple() + + # TODO: pybigquery uses schema to mean project_id, so we cannot use + # siuba's classic breakdown "{schema}.{table_name}". Basically + # pybigquery uses "{schema=project_id}.{dataset_dot_table_name}" in its internal + # logic. An important side effect is that bigquery errors for + # `dataset`.`table`, but not `dataset.table`. + if source and source.dialect.name == "bigquery": + table_name = tbl + schema = None + + return sqlalchemy.Table( + table_name, + sqlalchemy.MetaData(bind = source), + *columns, + schema = schema, + autoload_with = source if not columns else None + ) + + def _get_preview(self): + # need to make prev op a cte, so we don't override any previous limit + new_sel = self.last_select.limit(5) + tbl_small = self.append_op(new_sel) + return collect(tbl_small) + + def __repr__(self): + template = ( + "# Source: lazy query\n" + "# DB Conn: {}\n" + "# Preview:\n{}\n" + "# .. may have more rows" + ) + + return template.format(repr(self.source.engine), repr(self._get_preview())) + + def _repr_html_(self): + template = ( + "
" + "
"
+                "# Source: lazy query\n"
+                "# DB Conn: {}\n"
+                "# Preview:\n"
+                "
" + "{}" + "

# .. may have more rows

" + "
" + ) + + data = self._get_preview() + + # _repr_html_ can not exist or return None, to signify that repr should be used + if not hasattr(data, '_repr_html_'): + return None + + html_data = data._repr_html_() + if html_data is None: + return None + + return template.format(self.source.engine, html_data) + + +def _repr_grouped_df_html_(self): + return "

(grouped data frame)

" + self._selected_obj._repr_html_() + "
" + + + +# Main Funcs +# ============================================================================= + +# sql raw -------------- + +sql_raw = sql.literal_column + +# show query ----------- + +@show_query.register(LazyTbl) +def _show_query(tbl, simplify = False, return_table = True): + #query = tbl.last_op #if not simplify else + compile_query = lambda query: query.compile( + dialect = tbl.source.dialect, + compile_kwargs = {"literal_binds": True} + ) + + + if simplify: + # try to strip table names and labels where unnecessary + simple_sel = _sql_simplify_select(tbl.last_select) + + explained = compile_query(simple_sel) + else: + # use a much more verbose query + explained = compile_query(tbl.last_select) + + if return_table: + print(str(explained)) + return tbl + + return str(explained) + + + +# collect ---------- + +@collect.register(LazyTbl) +def _collect(__data, as_df = True): + # TODO: maybe remove as_df options, always return dataframe + + if isinstance(__data.source, MockConnection): + # a mock sqlalchemy is being used to show_query, and echo queries. + # it doesn't return a result object or have a context handler, so + # we need to bail out early + return + + # compile query ---- + + if _is_dialect_duckdb(__data.source): + # TODO: can be removed once next release of duckdb fixes: + # https://github.com/duckdb/duckdb/issues/2972 + query = __data.last_select + compiled = query.compile( + dialect = __data.source.dialect, + compile_kwargs = {"literal_binds": True} + ) + else: + compiled = __data.last_select + + # execute query ---- + + with __data.source.connect() as conn: + if as_df: + sql_db = _FixedSqlDatabase(conn) + + if _is_dialect_duckdb(__data.source): + # TODO: pandas read_sql is very slow with duckdb. + # see https://github.com/pandas-dev/pandas/issues/45678 + # going to handle here for now. address once LazyTbl gets + # subclassed per backend. 
+ duckdb_con = conn.connection.c + return duckdb_con.query(str(compiled)).to_df() + else: + # + return sql_db.read_sql(compiled) + + return conn.execute(compiled) + + +@select.register(LazyTbl) +def _select(__data, *args, **kwargs): + # see https://stackoverflow.com/questions/25914329/rearrange-columns-in-sqlalchemy-select-object + if kwargs: + raise NotImplementedError( + "Using kwargs in select not currently supported. " + "Use _.newname == _.oldname instead" + ) + last_sel = __data.last_select + columns = {c.key: c for c in last_sel.inner_columns} + + # same as for DataFrame + colnames = Series(list(columns)) + vl = VarList() + evaluated = (arg(vl) if callable(arg) else arg for arg in args) + od = var_select(colnames, *evaluated) + + missing_groups, group_keys = _select_group_renames(od, __data.group_by) + + if missing_groups: + _warn_missing(missing_groups) + + final_od = {**{k: None for k in missing_groups}, **od} + + col_list = [] + for k,v in final_od.items(): + col = columns[k] + col_list.append(col if v is None else col.label(v)) + + return __data.append_op( + last_sel.with_only_columns(col_list), + group_by = group_keys + ) + + + +@filter.register(LazyTbl) +def _filter(__data, *args): + # Note: currently always produces 2 additional select statements, + # 1 for window/aggs, and 1 for the where clause + + sel = __data.last_op.alias() # original select + win_sel = sel.select() + + conds = [] + windows = [] + with _set_data_context(__data, window=True): + for ii, arg in enumerate(args): + + if isinstance(arg, Call): + new_call = __data.shape_call(arg, verb_name = "Filter", arg_name = ii) + #var_cols = new_call.op_vars(attr_calls = False) + + # note that a new win_sel is returned, w/ window columns appended + col_expr, win_cols, win_sel = __data.track_call_windows( + new_call, + sel.columns, + window_cte = win_sel + ) + + if isinstance(col_expr, sql.base.ImmutableColumnCollection): + conds.extend(col_expr) + else: + conds.append(col_expr) + + 
windows.extend(win_cols) + + else: + conds.append(arg) + + bool_clause = sql.and_(*conds) + + # first cte, windows ---- + if len(windows): + + win_alias = win_sel.alias() + + # move non-window functions to refer to win_sel clause (not the innermost) --- + bool_clause = sql.util.ClauseAdapter(win_alias) \ + .traverse(bool_clause) + + orig_cols = [win_alias.columns[k] for k in sel.columns.keys()] + else: + orig_cols = [sel] + + # create second cte ---- + filt_sel = _sql_select(orig_cols).where(bool_clause) + return __data.append_op(filt_sel) + + +@mutate.register(LazyTbl) +def _mutate(__data, *args, **kwargs): + # TODO: verify it can follow a renaming select + + # track labeled columns in set + if not (len(args) or len(kwargs)): + return __data.append_op(__data.last_op) + + names, sel_out = _mutate_cols(__data, args, kwargs, "Mutate") + return __data.append_op(sel_out) + + +def _sql_upsert_columns(sel, new_columns: "list[base.Label | base.Column]"): + orig_cols = lift_inner_cols(sel) + replaced = {**orig_cols} + + for new_col in new_columns: + replaced[new_col.name] = new_col + return _sql_with_only_columns(sel, list(replaced.values())) + + +def _select_mutate_result(src_sel, expr_result): + dst_alias = src_sel.alias() + src_columns = set(lift_inner_cols(src_sel)) + replacer = SqlLabelReplacer(set(src_columns), dst_alias.columns) + + if isinstance(expr_result, sql.base.ImmutableColumnCollection): + replaced_cols = list(map(replacer, expr_result)) + orig_cols = expr_result + #elif isinstance(expr_result, None): + # pass + else: + replaced_cols = [replacer(expr_result)] + orig_cols = [expr_result] + + if replacer.applied: + return _sql_upsert_columns(dst_alias.select(), replaced_cols) + + return _sql_upsert_columns(src_sel, orig_cols) + + +def _eval_expr_arg(__data, sel, func, verb_name, window=True): + inner_cols = lift_inner_cols(sel) + + # case 1: simple names ---- + simple_name = simple_varname(func) + if simple_name is not None: + return inner_cols[simple_name] + 
+ # case 2: across ---- + _require_across(func, verb_name) + + cols_result = _eval_with_context(__data, window, inner_cols, func) + + # TODO: remove or raise a more informative error + assert isinstance(cols_result, sql.base.ImmutableColumnCollection), type(cols_result) + + return cols_result + + +def _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=True): + inner_cols = lift_inner_cols(sel) + + expr_shaped = __data.shape_call(func, window, verb_name = verb_name, arg_name = new_name) + new_col, windows, _ = __data.track_call_windows(expr_shaped, inner_cols) + + if isinstance(new_col, sql.base.ImmutableColumnCollection): + raise TyepError( + f"{verb_name} named arguments must return a single column, but `{k}` " + "returned multiple columns." + ) + + return new_col.label(new_name) + + +def _mutate_cols(__data, args, kwargs, verb_name): + result_names = {} # used as ordered set + sel = __data.last_select + + for ii, func in enumerate(args): + cols_result = _eval_expr_arg(__data, sel, func, verb_name) + + # replace any labels that require a subquery ---- + sel = _select_mutate_result(sel, cols_result) + + if isinstance(cols_result, sql.base.ImmutableColumnCollection): + result_names.update({k: True for k in cols_result.keys()}) + else: + result_names[cols_result.name] = True + + + for new_name, func in kwargs.items(): + labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name) + + sel = _select_mutate_result(sel, labeled) + result_names[new_name] = True + + + return list(result_names), sel + + +@transmute.register(LazyTbl) +def _transmute(__data, *args, **kwargs): + # will use mutate, then select some cols + result_names, sel = _mutate_cols(__data, args, kwargs, "Transmute") + + # transmute keeps grouping cols, and any defined in kwargs + missing = [x for x in __data.group_by if x not in result_names] + cols_to_keep = [*missing, *result_names] + + columns = lift_inner_cols(sel) + sel_stripped = sel.with_only_columns([columns[k] for k in 
cols_to_keep]) + + return __data.append_op(sel_stripped) + + +@arrange.register(LazyTbl) +def _arrange(__data, *args): + # Note that SQL databases often do not subquery order by clauses. Arrange + # sets order_by on the backend, so it can set order by in over elements, + # and handle when new columns are named the same as order by vars. + # see: https://dba.stackexchange.com/q/82930 + + last_sel = __data.last_select + cols = lift_inner_cols(last_sel) + + # TODO: implement across in arrange + #exprs, _ = _mutate_cols(__data, args, kwargs, "Arrange", arrange_clause=True) + + new_calls = [] + for ii, expr in enumerate(args): + if callable(expr): + + res = __data.shape_call( + expr, window = False, + verb_name = "Arrange", arg_name = ii + ) + + else: + res = expr + + new_calls.append(res) + + sort_cols = _create_order_by_clause(cols, *new_calls) + + order_by = __data.order_by + tuple(new_calls) + return __data.append_op(last_sel.order_by(*sort_cols), order_by = order_by) + + +# TODO: consolidate / pull expr handling funcs into own file? +def _create_order_by_clause(columns, *args): + sort_cols = [] + for arg in args: + # simple named column + if isinstance(arg, str): + sort_cols.append(columns[arg]) + # an expression + elif callable(arg): + # handle special case where -_.colname -> colname DESC + f, asc = _call_strip_ascending(arg) + col_op = f(columns) if asc else f(columns).desc() + #col_op = arg(columns) + sort_cols.append(col_op) + else: + raise NotImplementedError("Must be string or callable") + + return sort_cols + + + +@count.register(LazyTbl) +def _count(__data, *args, sort = False, wt = None, **kwargs): + # TODO: if already col named n, use name nn, etc.. 
get logic from tidy.py + if wt is not None: + raise NotImplementedError("TODO") + + res_name = "n" + # similar to filter verb, we need two select statements, + # an inner one for derived cols, and outer to group by them + + # inner select ---- + # holds any mutation style columns + #arg_names = [] + #for arg in args: + # name = simple_varname(arg) + # if name is None: + # raise NotImplementedError( + # "Count positional arguments must be single column name. " + # "Use a named argument to count using complex expressions." + # ) + # arg_names.append(name) + + result_names, sel_inner = _mutate_cols(__data, args, kwargs, "Count") + + # remove unnecessary select, if we're operating on a table ---- + if set(lift_inner_cols(sel_inner)) == set(lift_inner_cols(__data.last_select)): + sel_inner = __data.last_op + + # create outer select ---- + # holds selected columns and tally (n) + sel_inner_cte = sel_inner.alias() + inner_cols = sel_inner_cte.columns + + # apply any group vars from a group_by verb call first + missing = [k for k in __data.group_by if k not in result_names] + + all_group_names = ordered_union(__data.group_by, result_names) + outer_group_cols = [inner_cols[k] for k in all_group_names] + + # holds the actual count (e.g. 
n) + count_col = sql.functions.count().label(res_name) + + sel_outer = _sql_select([*outer_group_cols, count_col]) \ + .select_from(sel_inner_cte) \ + .group_by(*outer_group_cols) + + # count is like summarize, so removes order_by + return __data.append_op( + sel_outer.order_by(count_col.desc()), + order_by = tuple() + ) + + +@add_count.register(LazyTbl) +def _add_count(__data, *args, wt = None, sort = False, **kwargs): + counts = count(__data, *args, wt = wt, sort = sort, **kwargs) + by = list(c.name for c in counts.last_select.inner_columns)[:-1] + + return inner_join(__data, counts, by = by) + + +@summarize.register(LazyTbl) +def _summarize(__data, *args, **kwargs): + # https://stackoverflow.com/questions/14754994/why-is-sqlalchemy-count-much-slower-than-the-raw-query + + # get query with correct from clause, and maybe unneeded subquery + safe_from = __data.last_select.alias() + result_names, sel = _aggregate_cols(__data, safe_from, args, kwargs, "Summarize") + + # see if we can remove subquery + out_sel = _collapse_select(sel, safe_from) + + from_tbl = get_single_from(out_sel) + group_cols = [from_tbl.columns[k] for k in __data.group_by] + + final_sel = out_sel.group_by(*group_cols) + + new_data = __data.append_op(final_sel, group_by = tuple(), order_by = tuple()) + return new_data + + +def _collapse_select(outer_sel, inner_alias): + # check whether any outer columns reference an inner label ---- + inner_sel = inner_alias.element + + columns = lift_inner_cols(outer_sel) + inner_cols = lift_inner_cols(inner_sel) + + inner_labels = set([ + x.name for x in inner_cols + if isinstance(x, sql.elements.Label) + ]) + + col_requires_cte = set(inner_alias.columns[k] for k in inner_labels) + + bad_refs = [] + + def collect_refs(el): + if el in col_requires_cte: + bad_refs.append(el) + + for col in columns: + sql.util.visitors.traverse(col, {}, {"column": collect_refs}) + + # if possible, remove the outer query ---- + if not (bad_refs or len(inner_sel._group_by_clause)): + 
from sqlalchemy.sql.elements import ColumnClause, Label + + from_obj = get_single_from(inner_sel) + adaptor = sql.util.ClauseAdapter( + from_obj, + adapt_on_names=True, + include_fn=lambda c: isinstance(c, (ColumnClause, Label)) + ) + + new_cols = [] + for col in columns: + if isinstance(col, Label): + res = adaptor.traverse(col.element).label(col.name) + new_cols.append(res) + + else: + new_cols.append(adaptor.traverse(col)) + #new_cols = list(map(adaptor.traverse, columns)) + + return _sql_with_only_columns(inner_sel, new_cols) + + return outer_sel + + +def _aggregate_cols(__data, subquery, args, kwargs, verb_name): + # cases: + # * grouping cols can not be overwritten (in dbplyr they can't be ref'd) + # * no existing labels referred to - can use same select + # * existing labels referred to - need 1 subquery tops + # * groups + summarize columns can replace everything + + def get_label_clauses(clause): + out = [] + sql.util.visitors.traverse(clause, {}, {"label": lambda c: out.append(c)}) + + return out + + def quote_varname(x): + return f"`{x}`" + + def validate_references(arg_name, expr, verb_name): + bad_varnames = get_label_clauses(expr) + repr_names = ", ".join(map(quote_varname, [el.name for el in bad_varnames])) + + if not bad_varnames: + return + + raise NotImplementedError( + f"In SQL, you cannot refer to a column created in the same {verb_name}. " + f"`{arg_name}` refers to columns created earlier: {repr_names}." 
+ ) + + sel = subquery.select() + + final_cols = {k: subquery.columns[k] for k in __data.group_by} + + # handle args ---- + for ii, func in enumerate(args): + cols_result = _eval_expr_arg(__data, sel, func, verb_name, window=False) + + for col in cols_result: + validate_references(col.name, col.element, verb_name) + final_cols[col.name] = col + + sel = _sql_upsert_columns(sel, cols_result) + + + # handle kwargs ---- + for new_name, func in kwargs.items(): + labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=False) + + validate_references(labeled.name, labeled.element, verb_name) + final_cols[new_name] = labeled + + sel = _sql_upsert_columns(sel, [labeled]) + + return list(final_cols), _sql_with_only_columns(sel, list(final_cols.values())) + + +@group_by.register(LazyTbl) +def _group_by(__data, *args, add = False, **kwargs): + if not (args or kwargs): + return __data.copy() + + group_names, sel = _mutate_cols(__data, args, kwargs, "Group by") + + if None in group_names: + raise NotImplementedError("Complex, unnamed expressions not supported in sql group_by") + + # check whether we can just use underlying table ---- + new_cols = lift_inner_cols(sel) + if set(new_cols).issubset(set(__data.last_op.columns)): + sel = __data.last_op + + if add: + group_names = ordered_union(__data.group_by, group_names) + + return __data.append_op(sel, group_by = tuple(group_names)) + + +@ungroup.register(LazyTbl) +def _ungroup(__data): + return __data.copy(group_by = tuple()) + + +@case_when.register(sql.base.ImmutableColumnCollection) +def _case_when(__data, cases): + # TODO: will need listener to enter case statements, to handle when they use windows + if isinstance(cases, Call): + cases = cases(__data) + + whens = [] + case_items = list(cases.items()) + n_items = len(case_items) + + else_val = None + for ii, (expr, val) in enumerate(case_items): + # handle where val is a column expr + if callable(val): + val = val(__data) + + # handle when expressions + if ii+1 
== n_items and expr is True: + else_val = val + elif callable(expr): + whens.append((expr(__data), val)) + else: + whens.append((expr, val)) + + return sql.case(whens, else_ = else_val) + + +# Join ------------------------------------------------------------------------ + +from collections.abc import Mapping + +def _joined_cols(left_cols, right_cols, on_keys, how, suffix = ("_x", "_y")): + """Return labeled columns, according to selection rules for joins. + + Rules: + 1. For join keys, keep left table's column + 2. When keys have the same labels, add suffix + """ + + # TODO: remove sets, so uses stable ordering + # when left and right cols have same name, suffix with _x / _y + keep_right = set(right_cols.keys()) - set(on_keys.values()) + shared_labs = set(left_cols.keys()).intersection(keep_right) + + right_cols_no_keys = {k: right_cols[k] for k in keep_right} + + # for an outer join, have key columns coalesce values + + left_cols = {**left_cols} + if how == "full": + for lk, rk in on_keys.items(): + col = sql.functions.coalesce(left_cols[lk], right_cols[rk]) + left_cols[lk] = col.label(lk) + elif how == "right": + for lk, rk in on_keys.items(): + # Make left key columns actually be right ones (which contain left + extra) + left_cols[lk] = right_cols[rk].label(lk) + + + # create labels ---- + l_labs = _relabeled_cols(left_cols, shared_labs, suffix[0]) + r_labs = _relabeled_cols(right_cols_no_keys, shared_labs, suffix[1]) + + return l_labs + r_labs + + + +def _relabeled_cols(columns, keys, suffix): + # add a suffix to all columns with names in keys + cols = [] + for k, v in columns.items(): + new_col = v.label(k + str(suffix)) if k in keys else v + cols.append(new_col) + return cols + + +@join.register(LazyTbl) +def _join(left, right, on = None, *args, by = None, how = "inner", sql_on = None): + _raise_if_args(args) + + if on is None and by is not None: + on = by + + # Needs to be on the table, not the select + left_sel = left.last_op.alias() + right_sel = 
right.last_op.alias() + + # handle arguments ---- + on = _validate_join_arg_on(on, sql_on) + how = _validate_join_arg_how(how) + + # for equality join used to combine keys into single column + consolidate_keys = on if sql_on is None else {} + + if how == "right": + # switch joins, since sqlalchemy doesn't have right join arg + # see https://stackoverflow.com/q/11400307/1144523 + left_sel, right_sel = right_sel, left_sel + on = {v:k for k,v in on.items()} + + # create join conditions ---- + bool_clause = _create_join_conds(left_sel, right_sel, on) + + # create join ---- + join = left_sel.join( + right_sel, + onclause = bool_clause, + isouter = how != "inner", + full = how == "full" + ) + + # if right join, set selects back + if how == "right": + left_sel, right_sel = right_sel, left_sel + on = {v:k for k,v in on.items()} + + # note, shared_keys assumes on is a mapping... + # TODO: shared_keys appears to be for when on is not specified, but was unused + #shared_keys = [k for k,v in on.items() if k == v] + labeled_cols = _joined_cols( + left_sel.columns, + right_sel.columns, + on_keys = consolidate_keys, + how = how + ) + + sel = _sql_select(labeled_cols).select_from(join) + return left.append_op(sel, order_by = tuple()) + + +@semi_join.register(LazyTbl) +def _semi_join(left, right = None, on = None, *args, by = None, sql_on = None): + if on is None and by is not None: + on = by + + _raise_if_args(args) + + left_sel = left.last_op.alias() + right_sel = right.last_op.alias() + + # handle arguments ---- + on = _validate_join_arg_on(on, sql_on, left_sel, right_sel) + + # create join conditions ---- + bool_clause = _create_join_conds(left_sel, right_sel, on) + + # create inner join ---- + exists_clause = _sql_select([sql.literal(1)]) \ + .select_from(right_sel) \ + .where(bool_clause) + + # only keep left hand select's columns ---- + sel = _sql_select(left_sel.columns) \ + .select_from(left_sel) \ + .where(sql.exists(exists_clause)) + + return left.append_op(sel, order_by 
= tuple()) + + +@anti_join.register(LazyTbl) +def _anti_join(left, right = None, on = None, *args, by = None, sql_on = None): + if on is None and by is not None: + on = by + + _raise_if_args(args) + + left_sel = left.last_op.alias() + right_sel = right.last_op.alias() + + # handle arguments ---- + on = _validate_join_arg_on(on, sql_on, left, right) + + # create join conditions ---- + bool_clause = _create_join_conds(left_sel, right_sel, on) + + # create inner join ---- + #not_exists = ~sql.exists([1], from_obj = right_sel).where(bool_clause) + exists_clause = _sql_select([sql.literal(1)]) \ + .select_from(right_sel) \ + .where(bool_clause) + + sel = left_sel.select().where(~sql.exists(exists_clause)) + + return left.append_op(sel, order_by = tuple()) + +def _raise_if_args(args): + if len(args): + raise NotImplemented("*args is reserved for future arguments (e.g. suffix)") + +def _validate_join_arg_on(on, sql_on = None, lhs = None, rhs = None): + # handle sql on case + if sql_on is not None: + if on is not None: + raise ValueError("Cannot specify both on and sql_on") + + return sql_on + + # handle general cases + if on is None: + # TODO: currently, we check for lhs and rhs tables to indicate whether + # a verb supports inferring columns. Otherwise, raise an error. + if lhs is not None and rhs is not None: + # TODO: consolidate with duplicate logic in pandas verb code + warnings.warn( + "No on column passed to join. " + "Inferring join columns instead using shared column names." + ) + + on_cols = list(set(lhs.columns.keys()).intersection(set(rhs.columns.keys()))) + + if not on_cols: + raise ValueError( + "No join column specified, or shared column names in join." 
+ ) + + # trivial dict mapping shared names to themselves + warnings.warn("Detected shared columns: %s" % on_cols) + on = dict(zip(on_cols, on_cols)) + + else: + raise NotImplementedError("on arg currently cannot be None (default) for SQL") + elif isinstance(on, str): + on = {on: on} + elif isinstance(on, (list, tuple)): + on = dict(zip(on, on)) + + + if not isinstance(on, Mapping): + raise TypeError("on must be a Mapping (e.g. dict)") + + return on + +def _validate_join_arg_how(how): + how_options = ("inner", "left", "right", "full") + if how not in how_options: + raise ValueError("how argument needs to be one of %s" %how_options) + + return how + +def _create_join_conds(left_sel, right_sel, on): + left_cols = left_sel.columns #lift_inner_cols(left_sel) + right_cols = right_sel.columns #lift_inner_cols(right_sel) + + if callable(on): + # callable, like with sql_on arg + conds = [on(left_cols, right_cols)] + else: + # dict-like of form {left: right} + conds = [] + for l, r in on.items(): + col_expr = left_cols[l] == right_cols[r] + conds.append(col_expr) + + return sql.and_(*conds) + + +# Head ------------------------------------------------------------------------ + +@head.register(LazyTbl) +def _head(__data, n = 5): + sel = __data.last_select + + return __data.append_op(sel.limit(n)) + + +# Rename ---------------------------------------------------------------------- + +@rename.register(LazyTbl) +def _rename(__data, **kwargs): + sel = __data.last_select + columns = lift_inner_cols(sel) + + # old_keys uses dict as ordered set + old_to_new = {simple_varname(v):k for k,v in kwargs.items()} + + if None in old_to_new: + raise KeyError("positional arguments must be simple column, " + "e.g. 
_.colname or _['colname']" + ) + + labs = [c.label(old_to_new[k]) if k in old_to_new else c for k,c in columns.items()] + + new_sel = sel.with_only_columns(labs) + + missing_groups, group_keys = _select_group_renames(old_to_new, __data.group_by) + + return __data.append_op(new_sel, group_by=group_keys) + + +# Distinct -------------------------------------------------------------------- + +@distinct.register(LazyTbl) +def _distinct(__data, *args, _keep_all = False, **kwargs): + if (args or kwargs) and _keep_all: + raise NotImplementedError("Distinct with variables specified in sql requires _keep_all = False") + + inner_sel = mutate(__data, **kwargs).last_select if kwargs else __data.last_select + + # TODO: this is copied from the df distinct version + # cols dict below is used as ordered set + cols = _var_select_simple(args) + cols.update(kwargs) + + # use all columns by default + if not cols: + cols = {k: True for k in lift_inner_cols(inner_sel).keys()} + + final_names = {**{k: True for k in __data.group_by}, **cols} + + if not len(inner_sel._order_by_clause): + # select distinct has to include any columns in the order by clause, + # so can only safely modify existing statement when there's no order by + sel_cols = lift_inner_cols(inner_sel) + distinct_cols = [sel_cols[k] for k in final_names] + sel = inner_sel.with_only_columns(distinct_cols).distinct() + else: + # fallback to cte + cte = inner_sel.alias() + distinct_cols = [cte.columns[k] for k in final_names] + sel = _sql_select(distinct_cols).select_from(cte).distinct() + + return __data.append_op(sel) + + +# if_else --------------------------------------------------------------------- + +@if_else.register(sql.elements.ColumnElement) +def _if_else(cond, true_vals, false_vals): + whens = [(cond, true_vals)] + return sql.case(whens, else_ = false_vals) + + diff --git a/siuba/sql/verbs/summarize.py b/siuba/sql/verbs/summarize.py new file mode 100644 index 00000000..8dcbef7b --- /dev/null +++ 
b/siuba/sql/verbs/summarize.py @@ -0,0 +1,1395 @@ +""" +Implements LazyTbl to represent tables of SQL data, and registers it on verbs. + +This module is responsible for the handling of the "table" side of things, while +translate.py handles translating column operations. + + +""" + +import warnings + +from siuba.dply.verbs import ( + show_query, collect, + simple_varname, + select, + mutate, + transmute, + filter, + arrange, _call_strip_ascending, + summarize, + count, add_count, + group_by, ungroup, + case_when, + join, left_join, right_join, inner_join, semi_join, anti_join, + head, + rename, + distinct, + if_else, + _select_group_renames, + _var_select_simple + ) + +from siuba.dply.tidyselect import VarList, var_select + +from .translate import CustomOverClause, SqlColumn, SqlColumnAgg +from .utils import ( + get_dialect_translator, + _FixedSqlDatabase, + _is_dialect_duckdb, + _sql_select, + _sql_column_collection, + _sql_add_columns, + _sql_with_only_columns, + _sql_simplify_select, + MockConnection +) + +from sqlalchemy import sql +import sqlalchemy +from siuba.siu import Call, Lazy, FunctionLookupError, singledispatch2 +# TODO: currently needed for select, but can we remove pandas? +from pandas import Series +from functools import singledispatch + +from sqlalchemy.sql import schema + +from siuba.dply.across import _require_across, _set_data_context, _eval_with_context + +# TODO: +# - distinct +# - annotate functions using sel.prefix_with("\n/**/\n") ? + + +# Helpers --------------------------------------------------------------------- + +class SqlFunctionLookupError(FunctionLookupError): pass + + +class CallListener: + """Generic listener. Each exit is called on a node's copy.""" + def enter(self, node): + args, kwargs = node.map_subcalls(self.enter) + + return self.exit(node.__class__(node.func, *args, **kwargs)) + + def exit(self, node): + return node + + +class WindowReplacer(CallListener): + """Call tree listener. 
+ + Produces 2 important behaviors via the enter method: + - returns evaluated sql call expression, with labels on all window expressions. + - stores all labeled window expressions via the windows property. + + TODO: could replace with a sqlalchemy transformer + """ + + def __init__(self, columns, group_by, order_by, window_cte = None): + self.columns = columns + self.group_by = group_by + self.order_by = order_by + self.window_cte = window_cte + self.windows = [] + + def exit(self, node): + col_expr = node(self.columns) + + if not isinstance(col_expr, sql.elements.ClauseElement): + return col_expr + + over_clauses = [x for x in self._get_over_clauses(col_expr) if isinstance(x, CustomOverClause)] + + # put groupings and orderings onto custom over clauses + for over in over_clauses: + # TODO: shouldn't mutate these over clauses + group_by = sql.elements.ClauseList( + *[self.columns[name] for name in self.group_by] + ) + order_by = sql.elements.ClauseList( + *_create_order_by_clause(self.columns, *self.order_by) + ) + + over.set_over(group_by, order_by) + + if len(over_clauses) and self.window_cte is not None: + # custom name, or parameters like "%(...)s" may nest and break psycopg2 + # with columns you can set a key to fix this, but it doesn't seem to + # be an option with labels + name = self._get_unique_name('win', lift_inner_cols(self.window_cte)) + label = col_expr.label(name) + + # put into CTE, and return its resulting column, so that subsequent + # operations will refer to the window column on window_cte. 
Note that + # the operations will use the actual column, so may need to use the + # ClauseAdaptor to make it a reference to the label + self.window_cte = _sql_add_columns(self.window_cte, [label]) + win_col = lift_inner_cols(self.window_cte).values()[-1] + self.windows.append(win_col) + + return win_col + + return col_expr + + @staticmethod + def _get_unique_name(prefix, columns): + column_names = set(columns.keys()) + + i = 1 + name = prefix + str(i) + while name in column_names: + i += 1 + name = prefix + str(i) + + + return name + + @staticmethod + def _get_over_clauses(clause): + windows = [] + append_win = lambda col: windows.append(col) + + sql.util.visitors.traverse(clause, {}, {"over": append_win}) + + return windows + + +class SqlLabelReplacer: + """Create a visitor to replace source labels with destination. + + Note that this is meant to be used with sqlalchemy visitors. + """ + + def __init__(self, src_columns, dst_columns): + self.src_columns = src_columns + self.src_labels = set([x for x in src_columns if isinstance(x, sql.elements.Label)]) + self.dst_columns = dst_columns + self.applied = False + + def __call__(self, clause): + return sql.util.visitors.replacement_traverse(clause, {}, self.visit) + + def visit(self, el): + from sqlalchemy.sql.elements import ColumnClause, Label, ClauseElement, TypeClause + from sqlalchemy.sql.schema import Column + + if isinstance(el, TypeClause): + # TODO: for some reason this type throws an error if unguarded + return None + + if isinstance(el, ClauseElement): + if el in self.src_labels: + self.applied = True + return self.dst_columns[el.name] + elif el in self.src_columns: + return self.dst_columns[el.name] + + # TODO: should we create a subquery if the user passed raw text? 
+ #elif isinstance(el, ColumnClause) and not isinstance(el, Column): + # # Raw SQL, which will need a subquery, but not substitution + # if el.key != "*": + # self.applied = True + + return None + + +#def track_call_windows(call, columns, group_by, order_by, window_cte = None): +# listener = WindowReplacer(columns, group_by, order_by, window_cte) +# col = listener.enter(call) +# return col, listener.windows, listener.window_cte + + +def track_call_windows(call, columns, group_by, order_by, window_cte = None): + col_expr = call(columns) + + crnt_group_by = sql.elements.ClauseList( + *[columns[name] for name in group_by] + ) + crnt_order_by = sql.elements.ClauseList( + *_create_order_by_clause(columns, *order_by) + ) + return replace_call_windows(col_expr, crnt_group_by, crnt_order_by, window_cte) + + + +@singledispatch +def replace_call_windows(col_expr, group_by, order_by, window_cte = None): + raise TypeError(str(type(col_expr))) + + +@replace_call_windows.register(sql.base.ImmutableColumnCollection) +def _(col_expr, group_by, order_by, window_cte = None): + all_over_clauses = [] + for col in col_expr: + _, over_clauses, window_cte = replace_call_windows( + col, + group_by, + order_by, + window_cte + ) + all_over_clauses.extend(over_clauses) + + return col_expr, all_over_clauses, window_cte + + +@replace_call_windows.register(sql.elements.ClauseElement) +def _(col_expr, group_by, order_by, window_cte = None): + + over_clauses = WindowReplacer._get_over_clauses(col_expr) + + for over in over_clauses: + # TODO: shouldn't mutate these over clauses + over.set_over(group_by, order_by) + + if len(over_clauses) and window_cte is not None: + # custom name, or parameters like "%(...)s" may nest and break psycopg2 + # with columns you can set a key to fix this, but it doesn't seem to + # be an option with labels + name = WindowReplacer._get_unique_name('win', lift_inner_cols(window_cte)) + label = col_expr.label(name) + + # put into CTE, and return its resulting column, so 
that subsequent + # operations will refer to the window column on window_cte. Note that + # the operations will use the actual column, so may need to use the + # ClauseAdaptor to make it a reference to the label + window_cte = _sql_add_columns(window_cte, [label]) + win_col = lift_inner_cols(window_cte).values()[-1] + + return win_col, over_clauses, window_cte + + return col_expr, over_clauses, window_cte + +def get_single_from(sel): + froms = sel.froms + + n_froms = len(froms) + if n_froms != 1: + raise ValueError( + f"Expected a single table in the from clause, but found {n_froms}" + ) + + return froms[0] + +def lift_inner_cols(tbl): + cols = list(tbl.inner_columns) + + return _sql_column_collection(cols) + +def col_expr_requires_cte(call, sel, is_mutate = False): + """Return whether a variable assignment needs a CTE""" + + call_vars = set(call.op_vars(attr_calls = False)) + + sel_labs = get_inner_labels(sel) + + # I use the acronym fwg sol (frog soul) to remember sql clause eval order + # from, where, group by, select, order by, limit + # group clause evaluated before select clause, so not issue for mutate + group_needs_cte = not is_mutate and len(sel._group_by_clause) + + return ( group_needs_cte + # TODO: detect when a new var in mutate conflicts w/ order by + #or len(sel._order_by_clause) + or not sel_labs.isdisjoint(call_vars) + ) + +def get_inner_labels(sel): + columns = lift_inner_cols(sel) + sel_labs = set(k for k,v in columns.items() if isinstance(v, sql.elements.Label)) + return sel_labs + +def get_missing_columns(call, columns): + missing_cols = set(call.op_vars(attr_calls = False)) - set(columns.keys()) + return missing_cols + +def compile_el(tbl, el): + compiled = el.compile( + dialect = tbl.source.dialect, + compile_kwargs = {"literal_binds": True} + ) + return compiled + +# Misc utilities -------------------------------------------------------------- + +def ordered_union(x, y): + dx = {el: True for el in x} + dy = {el: True for el in y} + + return 
tuple({**dx, **dy}) + + +def _warn_missing(missing_groups): + warnings.warn(f"Adding missing grouping variables: {missing_groups}") + + +# Table ----------------------------------------------------------------------- + +class LazyTbl: + def __init__( + self, source, tbl, columns = None, + ops = None, group_by = tuple(), order_by = tuple(), + translator = None + ): + """Create a representation of a SQL table. + + Args: + source: a sqlalchemy.Engine or sqlalchemy.Connection instance. + tbl: table of form 'schema_name.table_name', 'table_name', or sqlalchemy.Table. + columns: if specified, a listlike of column names. + + Examples + -------- + + :: + from sqlalchemy import create_engine + from siuba.data import mtcars + + # create database and table + engine = create_engine("sqlite:///:memory:") + mtcars.to_sql('mtcars', engine) + + tbl_mtcars = LazyTbl(engine, 'mtcars') + + """ + + # connection and dialect specific functions + self.source = sqlalchemy.create_engine(source) if isinstance(source, str) else source + + # get dialect name + dialect = self.source.dialect.name + self.translator = get_dialect_translator(dialect) + + self.tbl = self._create_table(tbl, columns, self.source) + + # important states the query can be in (e.g. 
grouped) + self.ops = [self.tbl] if ops is None else ops + + self.group_by = group_by + self.order_by = order_by + + + def append_op(self, op, **kwargs): + cpy = self.copy(**kwargs) + cpy.ops = cpy.ops + [op] + return cpy + + def copy(self, **kwargs): + return self.__class__(**{**self.__dict__, **kwargs}) + + def shape_call( + self, + call, window = True, str_accessors = False, + verb_name = None, arg_name = None, + ): + return self.translator.shape_call(call, window, str_accessors, verb_name, arg_name) + + def track_call_windows(self, call, columns = None, window_cte = None): + """Returns tuple of (new column expression, list of window exprs)""" + + columns = self.last_op.columns if columns is None else columns + return track_call_windows(call, columns, self.group_by, self.order_by, window_cte) + + def get_ordered_col_names(self): + """Return columns from current select, with grouping columns first.""" + ungrouped = [k for k in self.last_op.columns.keys() if k not in self.group_by] + return list(self.group_by) + ungrouped + + #def label_breaks_order_by(self, name): + # """Returns True if a new column label would break the order by vars.""" + + # # TODO: arrange currently allows literals, which breaks this. it seems + # # better to only allow calls in arrange. + # order_by_vars = {c.op_vars(attr_calls=False) for c in self.order_by} + + + + + @property + def last_op(self) -> "sql.Table | sql.Select": + last_op = self.ops[-1] + + if last_op is None: + raise TypeError() + + return last_op + + @property + def last_select(self): + last_op = self.last_op + if not isinstance(last_op, sql.selectable.SelectBase): + return last_op.select() + + return last_op + + @staticmethod + def _create_table(tbl, columns = None, source = None): + """Return a sqlalchemy.Table, autoloading column info if needed. + + Arguments: + tbl: a sqlalchemy.Table or string of form 'table_name' or 'schema_name.table_name'. + columns: a tuple of column names for the table. Overrides source argument. 
+ source: a sqlalchemy engine, used to autoload columns. + + """ + if isinstance(tbl, sql.selectable.FromClause): + return tbl + + if not isinstance(tbl, str): + raise ValueError("tbl must be a sqlalchemy Table or string, but was %s" %type(tbl)) + + if columns is None and source is None: + raise ValueError("One of columns or source must be specified") + + schema, table_name = tbl.split('.') if '.' in tbl else [None, tbl] + + columns = map(sqlalchemy.Column, columns) if columns is not None else tuple() + + # TODO: pybigquery uses schema to mean project_id, so we cannot use + # siuba's classic breakdown "{schema}.{table_name}". Basically + # pybigquery uses "{schema=project_id}.{dataset_dot_table_name}" in its internal + # logic. An important side effect is that bigquery errors for + # `dataset`.`table`, but not `dataset.table`. + if source and source.dialect.name == "bigquery": + table_name = tbl + schema = None + + return sqlalchemy.Table( + table_name, + sqlalchemy.MetaData(bind = source), + *columns, + schema = schema, + autoload_with = source if not columns else None + ) + + def _get_preview(self): + # need to make prev op a cte, so we don't override any previous limit + new_sel = self.last_select.limit(5) + tbl_small = self.append_op(new_sel) + return collect(tbl_small) + + def __repr__(self): + template = ( + "# Source: lazy query\n" + "# DB Conn: {}\n" + "# Preview:\n{}\n" + "# .. may have more rows" + ) + + return template.format(repr(self.source.engine), repr(self._get_preview())) + + def _repr_html_(self): + template = ( + "
" + "
"
+                "# Source: lazy query\n"
+                "# DB Conn: {}\n"
+                "# Preview:\n"
+                "
" + "{}" + "

# .. may have more rows

" + "
" + ) + + data = self._get_preview() + + # _repr_html_ can not exist or return None, to signify that repr should be used + if not hasattr(data, '_repr_html_'): + return None + + html_data = data._repr_html_() + if html_data is None: + return None + + return template.format(self.source.engine, html_data) + + +def _repr_grouped_df_html_(self): + return "

(grouped data frame)

" + self._selected_obj._repr_html_() + "
" + + + +# Main Funcs +# ============================================================================= + +# sql raw -------------- + +sql_raw = sql.literal_column + +# show query ----------- + +@show_query.register(LazyTbl) +def _show_query(tbl, simplify = False, return_table = True): + #query = tbl.last_op #if not simplify else + compile_query = lambda query: query.compile( + dialect = tbl.source.dialect, + compile_kwargs = {"literal_binds": True} + ) + + + if simplify: + # try to strip table names and labels where unnecessary + simple_sel = _sql_simplify_select(tbl.last_select) + + explained = compile_query(simple_sel) + else: + # use a much more verbose query + explained = compile_query(tbl.last_select) + + if return_table: + print(str(explained)) + return tbl + + return str(explained) + + + +# collect ---------- + +@collect.register(LazyTbl) +def _collect(__data, as_df = True): + # TODO: maybe remove as_df options, always return dataframe + + if isinstance(__data.source, MockConnection): + # a mock sqlalchemy is being used to show_query, and echo queries. + # it doesn't return a result object or have a context handler, so + # we need to bail out early + return + + # compile query ---- + + if _is_dialect_duckdb(__data.source): + # TODO: can be removed once next release of duckdb fixes: + # https://github.com/duckdb/duckdb/issues/2972 + query = __data.last_select + compiled = query.compile( + dialect = __data.source.dialect, + compile_kwargs = {"literal_binds": True} + ) + else: + compiled = __data.last_select + + # execute query ---- + + with __data.source.connect() as conn: + if as_df: + sql_db = _FixedSqlDatabase(conn) + + if _is_dialect_duckdb(__data.source): + # TODO: pandas read_sql is very slow with duckdb. + # see https://github.com/pandas-dev/pandas/issues/45678 + # going to handle here for now. address once LazyTbl gets + # subclassed per backend. 
+ duckdb_con = conn.connection.c + return duckdb_con.query(str(compiled)).to_df() + else: + # + return sql_db.read_sql(compiled) + + return conn.execute(compiled) + + +@select.register(LazyTbl) +def _select(__data, *args, **kwargs): + # see https://stackoverflow.com/questions/25914329/rearrange-columns-in-sqlalchemy-select-object + if kwargs: + raise NotImplementedError( + "Using kwargs in select not currently supported. " + "Use _.newname == _.oldname instead" + ) + last_sel = __data.last_select + columns = {c.key: c for c in last_sel.inner_columns} + + # same as for DataFrame + colnames = Series(list(columns)) + vl = VarList() + evaluated = (arg(vl) if callable(arg) else arg for arg in args) + od = var_select(colnames, *evaluated) + + missing_groups, group_keys = _select_group_renames(od, __data.group_by) + + if missing_groups: + _warn_missing(missing_groups) + + final_od = {**{k: None for k in missing_groups}, **od} + + col_list = [] + for k,v in final_od.items(): + col = columns[k] + col_list.append(col if v is None else col.label(v)) + + return __data.append_op( + last_sel.with_only_columns(col_list), + group_by = group_keys + ) + + + +@filter.register(LazyTbl) +def _filter(__data, *args): + # Note: currently always produces 2 additional select statements, + # 1 for window/aggs, and 1 for the where clause + + sel = __data.last_op.alias() # original select + win_sel = sel.select() + + conds = [] + windows = [] + with _set_data_context(__data, window=True): + for ii, arg in enumerate(args): + + if isinstance(arg, Call): + new_call = __data.shape_call(arg, verb_name = "Filter", arg_name = ii) + #var_cols = new_call.op_vars(attr_calls = False) + + # note that a new win_sel is returned, w/ window columns appended + col_expr, win_cols, win_sel = __data.track_call_windows( + new_call, + sel.columns, + window_cte = win_sel + ) + + if isinstance(col_expr, sql.base.ImmutableColumnCollection): + conds.extend(col_expr) + else: + conds.append(col_expr) + + 
windows.extend(win_cols) + + else: + conds.append(arg) + + bool_clause = sql.and_(*conds) + + # first cte, windows ---- + if len(windows): + + win_alias = win_sel.alias() + + # move non-window functions to refer to win_sel clause (not the innermost) --- + bool_clause = sql.util.ClauseAdapter(win_alias) \ + .traverse(bool_clause) + + orig_cols = [win_alias.columns[k] for k in sel.columns.keys()] + else: + orig_cols = [sel] + + # create second cte ---- + filt_sel = _sql_select(orig_cols).where(bool_clause) + return __data.append_op(filt_sel) + + +@mutate.register(LazyTbl) +def _mutate(__data, *args, **kwargs): + # TODO: verify it can follow a renaming select + + # track labeled columns in set + if not (len(args) or len(kwargs)): + return __data.append_op(__data.last_op) + + names, sel_out = _mutate_cols(__data, args, kwargs, "Mutate") + return __data.append_op(sel_out) + + +def _sql_upsert_columns(sel, new_columns: "list[base.Label | base.Column]"): + orig_cols = lift_inner_cols(sel) + replaced = {**orig_cols} + + for new_col in new_columns: + replaced[new_col.name] = new_col + return _sql_with_only_columns(sel, list(replaced.values())) + + +def _select_mutate_result(src_sel, expr_result): + dst_alias = src_sel.alias() + src_columns = set(lift_inner_cols(src_sel)) + replacer = SqlLabelReplacer(set(src_columns), dst_alias.columns) + + if isinstance(expr_result, sql.base.ImmutableColumnCollection): + replaced_cols = list(map(replacer, expr_result)) + orig_cols = expr_result + #elif isinstance(expr_result, None): + # pass + else: + replaced_cols = [replacer(expr_result)] + orig_cols = [expr_result] + + if replacer.applied: + return _sql_upsert_columns(dst_alias.select(), replaced_cols) + + return _sql_upsert_columns(src_sel, orig_cols) + + +def _eval_expr_arg(__data, sel, func, verb_name, window=True): + inner_cols = lift_inner_cols(sel) + + # case 1: simple names ---- + simple_name = simple_varname(func) + if simple_name is not None: + return inner_cols[simple_name] + 
+ # case 2: across ---- + _require_across(func, verb_name) + + cols_result = _eval_with_context(__data, window, inner_cols, func) + + # TODO: remove or raise a more informative error + assert isinstance(cols_result, sql.base.ImmutableColumnCollection), type(cols_result) + + return cols_result + + +def _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=True): + """Evaluate one named verb argument and return the result labeled as new_name. + + Raises TypeError if the expression evaluates to a collection of columns. + """ + inner_cols = lift_inner_cols(sel) + + expr_shaped = __data.shape_call(func, window, verb_name = verb_name, arg_name = new_name) + new_col, windows, _ = __data.track_call_windows(expr_shaped, inner_cols) + + if isinstance(new_col, sql.base.ImmutableColumnCollection): + # a named argument produces exactly one labeled column, so a collection is rejected + raise TypeError( + f"{verb_name} named arguments must return a single column, but `{new_name}` " + "returned multiple columns." + ) + + return new_col.label(new_name) + + +def _mutate_cols(__data, args, kwargs, verb_name): + """Apply positional, then keyword, expressions starting from the last select. + + Returns a tuple of (list of resulting column names, updated select). + """ + result_names = {} # used as ordered set + sel = __data.last_select + + for ii, func in enumerate(args): + cols_result = _eval_expr_arg(__data, sel, func, verb_name) + + # replace any labels that require a subquery ---- + sel = _select_mutate_result(sel, cols_result) + + if isinstance(cols_result, sql.base.ImmutableColumnCollection): + result_names.update({k: True for k in cols_result.keys()}) + else: + result_names[cols_result.name] = True + + + for new_name, func in kwargs.items(): + labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name) + + sel = _select_mutate_result(sel, labeled) + result_names[new_name] = True + + + return list(result_names), sel + + +@transmute.register(LazyTbl) +def _transmute(__data, *args, **kwargs): + # will use mutate, then select some cols + result_names, sel = _mutate_cols(__data, args, kwargs, "Transmute") + + # transmute keeps grouping cols, and any defined in kwargs + missing = [x for x in __data.group_by if x not in result_names] + cols_to_keep = [*missing, *result_names] + + columns = lift_inner_cols(sel) + sel_stripped = sel.with_only_columns([columns[k] for k in
cols_to_keep]) + + return __data.append_op(sel_stripped) + + +@arrange.register(LazyTbl) +def _arrange(__data, *args): + # Note that SQL databases often do not subquery order by clauses. Arrange + # sets order_by on the backend, so it can set order by in over elements, + # and handle when new columns are named the same as order by vars. + # see: https://dba.stackexchange.com/q/82930 + + last_sel = __data.last_select + cols = lift_inner_cols(last_sel) + + # TODO: implement across in arrange + #exprs, _ = _mutate_cols(__data, args, kwargs, "Arrange", arrange_clause=True) + + new_calls = [] + for ii, expr in enumerate(args): + if callable(expr): + + res = __data.shape_call( + expr, window = False, + verb_name = "Arrange", arg_name = ii + ) + + else: + res = expr + + new_calls.append(res) + + sort_cols = _create_order_by_clause(cols, *new_calls) + + order_by = __data.order_by + tuple(new_calls) + return __data.append_op(last_sel.order_by(*sort_cols), order_by = order_by) + + +# TODO: consolidate / pull expr handling funcs into own file? +def _create_order_by_clause(columns, *args): + sort_cols = [] + for arg in args: + # simple named column + if isinstance(arg, str): + sort_cols.append(columns[arg]) + # an expression + elif callable(arg): + # handle special case where -_.colname -> colname DESC + f, asc = _call_strip_ascending(arg) + col_op = f(columns) if asc else f(columns).desc() + #col_op = arg(columns) + sort_cols.append(col_op) + else: + raise NotImplementedError("Must be string or callable") + + return sort_cols + + + +@count.register(LazyTbl) +def _count(__data, *args, sort = False, wt = None, **kwargs): + # TODO: if already col named n, use name nn, etc.. 
get logic from tidy.py + if wt is not None: + raise NotImplementedError("TODO") + + res_name = "n" + # similar to filter verb, we need two select statements, + # an inner one for derived cols, and outer to group by them + + # inner select ---- + # holds any mutation style columns + #arg_names = [] + #for arg in args: + # name = simple_varname(arg) + # if name is None: + # raise NotImplementedError( + # "Count positional arguments must be single column name. " + # "Use a named argument to count using complex expressions." + # ) + # arg_names.append(name) + + result_names, sel_inner = _mutate_cols(__data, args, kwargs, "Count") + + # remove unnecessary select, if we're operating on a table ---- + if set(lift_inner_cols(sel_inner)) == set(lift_inner_cols(__data.last_select)): + sel_inner = __data.last_op + + # create outer select ---- + # holds selected columns and tally (n) + sel_inner_cte = sel_inner.alias() + inner_cols = sel_inner_cte.columns + + # apply any group vars from a group_by verb call first + missing = [k for k in __data.group_by if k not in result_names] + + all_group_names = ordered_union(__data.group_by, result_names) + outer_group_cols = [inner_cols[k] for k in all_group_names] + + # holds the actual count (e.g. 
n) + count_col = sql.functions.count().label(res_name) + + sel_outer = _sql_select([*outer_group_cols, count_col]) \ + .select_from(sel_inner_cte) \ + .group_by(*outer_group_cols) + + # count is like summarize, so removes order_by + return __data.append_op( + sel_outer.order_by(count_col.desc()), + order_by = tuple() + ) + + +@add_count.register(LazyTbl) +def _add_count(__data, *args, wt = None, sort = False, **kwargs): + counts = count(__data, *args, wt = wt, sort = sort, **kwargs) + by = list(c.name for c in counts.last_select.inner_columns)[:-1] + + return inner_join(__data, counts, by = by) + + +@summarize.register(LazyTbl) +def _summarize(__data, *args, **kwargs): + # https://stackoverflow.com/questions/14754994/why-is-sqlalchemy-count-much-slower-than-the-raw-query + + # get query with correct from clause, and maybe unneeded subquery + safe_from = __data.last_select.alias() + result_names, sel = _aggregate_cols(__data, safe_from, args, kwargs, "Summarize") + + # see if we can remove subquery + out_sel = _collapse_select(sel, safe_from) + + from_tbl = get_single_from(out_sel) + group_cols = [from_tbl.columns[k] for k in __data.group_by] + + final_sel = out_sel.group_by(*group_cols) + + new_data = __data.append_op(final_sel, group_by = tuple(), order_by = tuple()) + return new_data + + +def _collapse_select(outer_sel, inner_alias): + # check whether any outer columns reference an inner label ---- + inner_sel = inner_alias.element + + columns = lift_inner_cols(outer_sel) + inner_cols = lift_inner_cols(inner_sel) + + inner_labels = set([ + x.name for x in inner_cols + if isinstance(x, sql.elements.Label) + ]) + + col_requires_cte = set(inner_alias.columns[k] for k in inner_labels) + + bad_refs = [] + + def collect_refs(el): + if el in col_requires_cte: + bad_refs.append(el) + + for col in columns: + sql.util.visitors.traverse(col, {}, {"column": collect_refs}) + + # if possible, remove the outer query ---- + if not (bad_refs or len(inner_sel._group_by_clause)): + 
from sqlalchemy.sql.elements import ColumnClause, Label + + from_obj = get_single_from(inner_sel) + adaptor = sql.util.ClauseAdapter( + from_obj, + adapt_on_names=True, + include_fn=lambda c: isinstance(c, (ColumnClause, Label)) + ) + + new_cols = [] + for col in columns: + if isinstance(col, Label): + res = adaptor.traverse(col.element).label(col.name) + new_cols.append(res) + + else: + new_cols.append(adaptor.traverse(col)) + #new_cols = list(map(adaptor.traverse, columns)) + + return _sql_with_only_columns(inner_sel, new_cols) + + return outer_sel + + +def _aggregate_cols(__data, subquery, args, kwargs, verb_name): + # cases: + # * grouping cols can not be overwritten (in dbplyr they can't be ref'd) + # * no existing labels referred to - can use same select + # * existing labels referred to - need 1 subquery tops + # * groups + summarize columns can replace everything + + def get_label_clauses(clause): + out = [] + sql.util.visitors.traverse(clause, {}, {"label": lambda c: out.append(c)}) + + return out + + def quote_varname(x): + return f"`{x}`" + + def validate_references(arg_name, expr, verb_name): + bad_varnames = get_label_clauses(expr) + repr_names = ", ".join(map(quote_varname, [el.name for el in bad_varnames])) + + if not bad_varnames: + return + + raise NotImplementedError( + f"In SQL, you cannot refer to a column created in the same {verb_name}. " + f"`{arg_name}` refers to columns created earlier: {repr_names}." 
+ ) + + sel = subquery.select() + + final_cols = {k: subquery.columns[k] for k in __data.group_by} + + # handle args ---- + for ii, func in enumerate(args): + cols_result = _eval_expr_arg(__data, sel, func, verb_name, window=False) + + for col in cols_result: + validate_references(col.name, col.element, verb_name) + final_cols[col.name] = col + + sel = _sql_upsert_columns(sel, cols_result) + + + # handle kwargs ---- + for new_name, func in kwargs.items(): + labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=False) + + validate_references(labeled.name, labeled.element, verb_name) + final_cols[new_name] = labeled + + sel = _sql_upsert_columns(sel, [labeled]) + + return list(final_cols), _sql_with_only_columns(sel, list(final_cols.values())) + + +@group_by.register(LazyTbl) +def _group_by(__data, *args, add = False, **kwargs): + if not (args or kwargs): + return __data.copy() + + group_names, sel = _mutate_cols(__data, args, kwargs, "Group by") + + if None in group_names: + raise NotImplementedError("Complex, unnamed expressions not supported in sql group_by") + + # check whether we can just use underlying table ---- + new_cols = lift_inner_cols(sel) + if set(new_cols).issubset(set(__data.last_op.columns)): + sel = __data.last_op + + if add: + group_names = ordered_union(__data.group_by, group_names) + + return __data.append_op(sel, group_by = tuple(group_names)) + + +@ungroup.register(LazyTbl) +def _ungroup(__data): + return __data.copy(group_by = tuple()) + + +@case_when.register(sql.base.ImmutableColumnCollection) +def _case_when(__data, cases): + # TODO: will need listener to enter case statements, to handle when they use windows + if isinstance(cases, Call): + cases = cases(__data) + + whens = [] + case_items = list(cases.items()) + n_items = len(case_items) + + else_val = None + for ii, (expr, val) in enumerate(case_items): + # handle where val is a column expr + if callable(val): + val = val(__data) + + # handle when expressions + if ii+1 
== n_items and expr is True: + else_val = val + elif callable(expr): + whens.append((expr(__data), val)) + else: + whens.append((expr, val)) + + return sql.case(whens, else_ = else_val) + + +# Join ------------------------------------------------------------------------ + +from collections.abc import Mapping + +def _joined_cols(left_cols, right_cols, on_keys, how, suffix = ("_x", "_y")): + """Return labeled columns, according to selection rules for joins. + + Rules: + 1. For join keys, keep left table's column + 2. When keys have the same labels, add suffix + """ + + # TODO: remove sets, so uses stable ordering + # when left and right cols have same name, suffix with _x / _y + keep_right = set(right_cols.keys()) - set(on_keys.values()) + shared_labs = set(left_cols.keys()).intersection(keep_right) + + right_cols_no_keys = {k: right_cols[k] for k in keep_right} + + # for an outer join, have key columns coalesce values + + left_cols = {**left_cols} + if how == "full": + for lk, rk in on_keys.items(): + col = sql.functions.coalesce(left_cols[lk], right_cols[rk]) + left_cols[lk] = col.label(lk) + elif how == "right": + for lk, rk in on_keys.items(): + # Make left key columns actually be right ones (which contain left + extra) + left_cols[lk] = right_cols[rk].label(lk) + + + # create labels ---- + l_labs = _relabeled_cols(left_cols, shared_labs, suffix[0]) + r_labs = _relabeled_cols(right_cols_no_keys, shared_labs, suffix[1]) + + return l_labs + r_labs + + + +def _relabeled_cols(columns, keys, suffix): + # add a suffix to all columns with names in keys + cols = [] + for k, v in columns.items(): + new_col = v.label(k + str(suffix)) if k in keys else v + cols.append(new_col) + return cols + + +@join.register(LazyTbl) +def _join(left, right, on = None, *args, by = None, how = "inner", sql_on = None): + _raise_if_args(args) + + if on is None and by is not None: + on = by + + # Needs to be on the table, not the select + left_sel = left.last_op.alias() + right_sel = 
right.last_op.alias() + + # handle arguments ---- + on = _validate_join_arg_on(on, sql_on) + how = _validate_join_arg_how(how) + + # for equality join used to combine keys into single column + consolidate_keys = on if sql_on is None else {} + + if how == "right": + # switch joins, since sqlalchemy doesn't have right join arg + # see https://stackoverflow.com/q/11400307/1144523 + left_sel, right_sel = right_sel, left_sel + on = {v:k for k,v in on.items()} + + # create join conditions ---- + bool_clause = _create_join_conds(left_sel, right_sel, on) + + # create join ---- + join = left_sel.join( + right_sel, + onclause = bool_clause, + isouter = how != "inner", + full = how == "full" + ) + + # if right join, set selects back + if how == "right": + left_sel, right_sel = right_sel, left_sel + on = {v:k for k,v in on.items()} + + # note, shared_keys assumes on is a mapping... + # TODO: shared_keys appears to be for when on is not specified, but was unused + #shared_keys = [k for k,v in on.items() if k == v] + labeled_cols = _joined_cols( + left_sel.columns, + right_sel.columns, + on_keys = consolidate_keys, + how = how + ) + + sel = _sql_select(labeled_cols).select_from(join) + return left.append_op(sel, order_by = tuple()) + + +@semi_join.register(LazyTbl) +def _semi_join(left, right = None, on = None, *args, by = None, sql_on = None): + if on is None and by is not None: + on = by + + _raise_if_args(args) + + left_sel = left.last_op.alias() + right_sel = right.last_op.alias() + + # handle arguments ---- + on = _validate_join_arg_on(on, sql_on, left_sel, right_sel) + + # create join conditions ---- + bool_clause = _create_join_conds(left_sel, right_sel, on) + + # create inner join ---- + exists_clause = _sql_select([sql.literal(1)]) \ + .select_from(right_sel) \ + .where(bool_clause) + + # only keep left hand select's columns ---- + sel = _sql_select(left_sel.columns) \ + .select_from(left_sel) \ + .where(sql.exists(exists_clause)) + + return left.append_op(sel, order_by 
= tuple()) + + +@anti_join.register(LazyTbl) +def _anti_join(left, right = None, on = None, *args, by = None, sql_on = None): + if on is None and by is not None: + on = by + + _raise_if_args(args) + + left_sel = left.last_op.alias() + right_sel = right.last_op.alias() + + # handle arguments ---- + # pass the aliased selectables (not the LazyTbl objects), since inferring + # join keys reads lhs.columns / rhs.columns + on = _validate_join_arg_on(on, sql_on, left_sel, right_sel) + + # create join conditions ---- + bool_clause = _create_join_conds(left_sel, right_sel, on) + + # create inner join ---- + #not_exists = ~sql.exists([1], from_obj = right_sel).where(bool_clause) + exists_clause = _sql_select([sql.literal(1)]) \ + .select_from(right_sel) \ + .where(bool_clause) + + sel = left_sel.select().where(~sql.exists(exists_clause)) + + return left.append_op(sel, order_by = tuple()) + +def _raise_if_args(args): + """Raise NotImplementedError if any extra positional arguments were passed.""" + if len(args): + raise NotImplementedError("*args is reserved for future arguments (e.g. suffix)") + +def _validate_join_arg_on(on, sql_on = None, lhs = None, rhs = None): + # handle sql on case + if sql_on is not None: + if on is not None: + raise ValueError("Cannot specify both on and sql_on") + + return sql_on + + # handle general cases + if on is None: + # TODO: currently, we check for lhs and rhs tables to indicate whether + # a verb supports inferring columns. Otherwise, raise an error. + if lhs is not None and rhs is not None: + # TODO: consolidate with duplicate logic in pandas verb code + warnings.warn( + "No on column passed to join. " + "Inferring join columns instead using shared column names." + ) + + on_cols = list(set(lhs.columns.keys()).intersection(set(rhs.columns.keys()))) + + if not on_cols: + raise ValueError( + "No join column specified, or shared column names in join."
+ ) + + # trivial dict mapping shared names to themselves + warnings.warn("Detected shared columns: %s" % on_cols) + on = dict(zip(on_cols, on_cols)) + + else: + raise NotImplementedError("on arg currently cannot be None (default) for SQL") + elif isinstance(on, str): + on = {on: on} + elif isinstance(on, (list, tuple)): + on = dict(zip(on, on)) + + + if not isinstance(on, Mapping): + raise TypeError("on must be a Mapping (e.g. dict)") + + return on + +def _validate_join_arg_how(how): + how_options = ("inner", "left", "right", "full") + if how not in how_options: + raise ValueError("how argument needs to be one of %s" %how_options) + + return how + +def _create_join_conds(left_sel, right_sel, on): + left_cols = left_sel.columns #lift_inner_cols(left_sel) + right_cols = right_sel.columns #lift_inner_cols(right_sel) + + if callable(on): + # callable, like with sql_on arg + conds = [on(left_cols, right_cols)] + else: + # dict-like of form {left: right} + conds = [] + for l, r in on.items(): + col_expr = left_cols[l] == right_cols[r] + conds.append(col_expr) + + return sql.and_(*conds) + + +# Head ------------------------------------------------------------------------ + +@head.register(LazyTbl) +def _head(__data, n = 5): + sel = __data.last_select + + return __data.append_op(sel.limit(n)) + + +# Rename ---------------------------------------------------------------------- + +@rename.register(LazyTbl) +def _rename(__data, **kwargs): + sel = __data.last_select + columns = lift_inner_cols(sel) + + # old_keys uses dict as ordered set + old_to_new = {simple_varname(v):k for k,v in kwargs.items()} + + if None in old_to_new: + raise KeyError("positional arguments must be simple column, " + "e.g. 
_.colname or _['colname']" + ) + + labs = [c.label(old_to_new[k]) if k in old_to_new else c for k,c in columns.items()] + + new_sel = sel.with_only_columns(labs) + + missing_groups, group_keys = _select_group_renames(old_to_new, __data.group_by) + + return __data.append_op(new_sel, group_by=group_keys) + + +# Distinct -------------------------------------------------------------------- + +@distinct.register(LazyTbl) +def _distinct(__data, *args, _keep_all = False, **kwargs): + if (args or kwargs) and _keep_all: + raise NotImplementedError("Distinct with variables specified in sql requires _keep_all = False") + + inner_sel = mutate(__data, **kwargs).last_select if kwargs else __data.last_select + + # TODO: this is copied from the df distinct version + # cols dict below is used as ordered set + cols = _var_select_simple(args) + cols.update(kwargs) + + # use all columns by default + if not cols: + cols = {k: True for k in lift_inner_cols(inner_sel).keys()} + + final_names = {**{k: True for k in __data.group_by}, **cols} + + if not len(inner_sel._order_by_clause): + # select distinct has to include any columns in the order by clause, + # so can only safely modify existing statement when there's no order by + sel_cols = lift_inner_cols(inner_sel) + distinct_cols = [sel_cols[k] for k in final_names] + sel = inner_sel.with_only_columns(distinct_cols).distinct() + else: + # fallback to cte + cte = inner_sel.alias() + distinct_cols = [cte.columns[k] for k in final_names] + sel = _sql_select(distinct_cols).select_from(cte).distinct() + + return __data.append_op(sel) + + +# if_else --------------------------------------------------------------------- + +@if_else.register(sql.elements.ColumnElement) +def _if_else(cond, true_vals, false_vals): + whens = [(cond, true_vals)] + return sql.case(whens, else_ = false_vals) + + diff --git a/siuba/sql/verbs/transmute.py b/siuba/sql/verbs/transmute.py new file mode 100644 index 00000000..8dcbef7b --- /dev/null +++ 
b/siuba/sql/verbs/transmute.py @@ -0,0 +1,1395 @@ +""" +Implements LazyTbl to represent tables of SQL data, and registers it on verbs. + +This module is responsible for the handling of the "table" side of things, while +translate.py handles translating column operations. + + +""" + +import warnings + +from siuba.dply.verbs import ( + show_query, collect, + simple_varname, + select, + mutate, + transmute, + filter, + arrange, _call_strip_ascending, + summarize, + count, add_count, + group_by, ungroup, + case_when, + join, left_join, right_join, inner_join, semi_join, anti_join, + head, + rename, + distinct, + if_else, + _select_group_renames, + _var_select_simple + ) + +from siuba.dply.tidyselect import VarList, var_select + +from .translate import CustomOverClause, SqlColumn, SqlColumnAgg +from .utils import ( + get_dialect_translator, + _FixedSqlDatabase, + _is_dialect_duckdb, + _sql_select, + _sql_column_collection, + _sql_add_columns, + _sql_with_only_columns, + _sql_simplify_select, + MockConnection +) + +from sqlalchemy import sql +import sqlalchemy +from siuba.siu import Call, Lazy, FunctionLookupError, singledispatch2 +# TODO: currently needed for select, but can we remove pandas? +from pandas import Series +from functools import singledispatch + +from sqlalchemy.sql import schema + +from siuba.dply.across import _require_across, _set_data_context, _eval_with_context + +# TODO: +# - distinct +# - annotate functions using sel.prefix_with("\n/**/\n") ? + + +# Helpers --------------------------------------------------------------------- + +class SqlFunctionLookupError(FunctionLookupError): pass + + +class CallListener: + """Generic listener. Each exit is called on a node's copy.""" + def enter(self, node): + args, kwargs = node.map_subcalls(self.enter) + + return self.exit(node.__class__(node.func, *args, **kwargs)) + + def exit(self, node): + return node + + +class WindowReplacer(CallListener): + """Call tree listener. 
+ + Produces 2 important behaviors via the enter method: + - returns evaluated sql call expression, with labels on all window expressions. + - stores all labeled window expressions via the windows property. + + TODO: could replace with a sqlalchemy transformer + """ + + def __init__(self, columns, group_by, order_by, window_cte = None): + self.columns = columns + self.group_by = group_by + self.order_by = order_by + self.window_cte = window_cte + self.windows = [] + + def exit(self, node): + col_expr = node(self.columns) + + if not isinstance(col_expr, sql.elements.ClauseElement): + return col_expr + + over_clauses = [x for x in self._get_over_clauses(col_expr) if isinstance(x, CustomOverClause)] + + # put groupings and orderings onto custom over clauses + for over in over_clauses: + # TODO: shouldn't mutate these over clauses + group_by = sql.elements.ClauseList( + *[self.columns[name] for name in self.group_by] + ) + order_by = sql.elements.ClauseList( + *_create_order_by_clause(self.columns, *self.order_by) + ) + + over.set_over(group_by, order_by) + + if len(over_clauses) and self.window_cte is not None: + # custom name, or parameters like "%(...)s" may nest and break psycopg2 + # with columns you can set a key to fix this, but it doesn't seem to + # be an option with labels + name = self._get_unique_name('win', lift_inner_cols(self.window_cte)) + label = col_expr.label(name) + + # put into CTE, and return its resulting column, so that subsequent + # operations will refer to the window column on window_cte. 
Note that + # the operations will use the actual column, so may need to use the + # ClauseAdaptor to make it a reference to the label + self.window_cte = _sql_add_columns(self.window_cte, [label]) + win_col = lift_inner_cols(self.window_cte).values()[-1] + self.windows.append(win_col) + + return win_col + + return col_expr + + @staticmethod + def _get_unique_name(prefix, columns): + column_names = set(columns.keys()) + + i = 1 + name = prefix + str(i) + while name in column_names: + i += 1 + name = prefix + str(i) + + + return name + + @staticmethod + def _get_over_clauses(clause): + windows = [] + append_win = lambda col: windows.append(col) + + sql.util.visitors.traverse(clause, {}, {"over": append_win}) + + return windows + + +class SqlLabelReplacer: + """Create a visitor to replace source labels with destination. + + Note that this is meant to be used with sqlalchemy visitors. + """ + + def __init__(self, src_columns, dst_columns): + self.src_columns = src_columns + self.src_labels = set([x for x in src_columns if isinstance(x, sql.elements.Label)]) + self.dst_columns = dst_columns + self.applied = False + + def __call__(self, clause): + return sql.util.visitors.replacement_traverse(clause, {}, self.visit) + + def visit(self, el): + from sqlalchemy.sql.elements import ColumnClause, Label, ClauseElement, TypeClause + from sqlalchemy.sql.schema import Column + + if isinstance(el, TypeClause): + # TODO: for some reason this type throws an error if unguarded + return None + + if isinstance(el, ClauseElement): + if el in self.src_labels: + self.applied = True + return self.dst_columns[el.name] + elif el in self.src_columns: + return self.dst_columns[el.name] + + # TODO: should we create a subquery if the user passed raw text? 
+ #elif isinstance(el, ColumnClause) and not isinstance(el, Column): + # # Raw SQL, which will need a subquery, but not substitution + # if el.key != "*": + # self.applied = True + + return None + + +#def track_call_windows(call, columns, group_by, order_by, window_cte = None): +# listener = WindowReplacer(columns, group_by, order_by, window_cte) +# col = listener.enter(call) +# return col, listener.windows, listener.window_cte + + +def track_call_windows(call, columns, group_by, order_by, window_cte = None): + col_expr = call(columns) + + crnt_group_by = sql.elements.ClauseList( + *[columns[name] for name in group_by] + ) + crnt_order_by = sql.elements.ClauseList( + *_create_order_by_clause(columns, *order_by) + ) + return replace_call_windows(col_expr, crnt_group_by, crnt_order_by, window_cte) + + + +@singledispatch +def replace_call_windows(col_expr, group_by, order_by, window_cte = None): + raise TypeError(str(type(col_expr))) + + +@replace_call_windows.register(sql.base.ImmutableColumnCollection) +def _(col_expr, group_by, order_by, window_cte = None): + all_over_clauses = [] + for col in col_expr: + _, over_clauses, window_cte = replace_call_windows( + col, + group_by, + order_by, + window_cte + ) + all_over_clauses.extend(over_clauses) + + return col_expr, all_over_clauses, window_cte + + +@replace_call_windows.register(sql.elements.ClauseElement) +def _(col_expr, group_by, order_by, window_cte = None): + + over_clauses = WindowReplacer._get_over_clauses(col_expr) + + for over in over_clauses: + # TODO: shouldn't mutate these over clauses + over.set_over(group_by, order_by) + + if len(over_clauses) and window_cte is not None: + # custom name, or parameters like "%(...)s" may nest and break psycopg2 + # with columns you can set a key to fix this, but it doesn't seem to + # be an option with labels + name = WindowReplacer._get_unique_name('win', lift_inner_cols(window_cte)) + label = col_expr.label(name) + + # put into CTE, and return its resulting column, so 
that subsequent + # operations will refer to the window column on window_cte. Note that + # the operations will use the actual column, so may need to use the + # ClauseAdaptor to make it a reference to the label + window_cte = _sql_add_columns(window_cte, [label]) + win_col = lift_inner_cols(window_cte).values()[-1] + + return win_col, over_clauses, window_cte + + return col_expr, over_clauses, window_cte + +def get_single_from(sel): + froms = sel.froms + + n_froms = len(froms) + if n_froms != 1: + raise ValueError( + f"Expected a single table in the from clause, but found {n_froms}" + ) + + return froms[0] + +def lift_inner_cols(tbl): + cols = list(tbl.inner_columns) + + return _sql_column_collection(cols) + +def col_expr_requires_cte(call, sel, is_mutate = False): + """Return whether a variable assignment needs a CTE""" + + call_vars = set(call.op_vars(attr_calls = False)) + + sel_labs = get_inner_labels(sel) + + # I use the acronym fwg sol (frog soul) to remember sql clause eval order + # from, where, group by, select, order by, limit + # group clause evaluated before select clause, so not issue for mutate + group_needs_cte = not is_mutate and len(sel._group_by_clause) + + return ( group_needs_cte + # TODO: detect when a new var in mutate conflicts w/ order by + #or len(sel._order_by_clause) + or not sel_labs.isdisjoint(call_vars) + ) + +def get_inner_labels(sel): + columns = lift_inner_cols(sel) + sel_labs = set(k for k,v in columns.items() if isinstance(v, sql.elements.Label)) + return sel_labs + +def get_missing_columns(call, columns): + missing_cols = set(call.op_vars(attr_calls = False)) - set(columns.keys()) + return missing_cols + +def compile_el(tbl, el): + compiled = el.compile( + dialect = tbl.source.dialect, + compile_kwargs = {"literal_binds": True} + ) + return compiled + +# Misc utilities -------------------------------------------------------------- + +def ordered_union(x, y): + dx = {el: True for el in x} + dy = {el: True for el in y} + + return 
tuple({**dx, **dy}) + + +def _warn_missing(missing_groups): + warnings.warn(f"Adding missing grouping variables: {missing_groups}") + + +# Table ----------------------------------------------------------------------- + +class LazyTbl: + def __init__( + self, source, tbl, columns = None, + ops = None, group_by = tuple(), order_by = tuple(), + translator = None + ): + """Create a representation of a SQL table. + + Args: + source: a sqlalchemy.Engine or sqlalchemy.Connection instance. + tbl: table of form 'schema_name.table_name', 'table_name', or sqlalchemy.Table. + columns: if specified, a listlike of column names. + + Examples + -------- + + :: + from sqlalchemy import create_engine + from siuba.data import mtcars + + # create database and table + engine = create_engine("sqlite:///:memory:") + mtcars.to_sql('mtcars', engine) + + tbl_mtcars = LazyTbl(engine, 'mtcars') + + """ + + # connection and dialect specific functions + self.source = sqlalchemy.create_engine(source) if isinstance(source, str) else source + + # get dialect name + dialect = self.source.dialect.name + self.translator = get_dialect_translator(dialect) + + self.tbl = self._create_table(tbl, columns, self.source) + + # important states the query can be in (e.g. 
grouped) + self.ops = [self.tbl] if ops is None else ops + + self.group_by = group_by + self.order_by = order_by + + + def append_op(self, op, **kwargs): + cpy = self.copy(**kwargs) + cpy.ops = cpy.ops + [op] + return cpy + + def copy(self, **kwargs): + return self.__class__(**{**self.__dict__, **kwargs}) + + def shape_call( + self, + call, window = True, str_accessors = False, + verb_name = None, arg_name = None, + ): + return self.translator.shape_call(call, window, str_accessors, verb_name, arg_name) + + def track_call_windows(self, call, columns = None, window_cte = None): + """Returns tuple of (new column expression, list of window exprs)""" + + columns = self.last_op.columns if columns is None else columns + return track_call_windows(call, columns, self.group_by, self.order_by, window_cte) + + def get_ordered_col_names(self): + """Return columns from current select, with grouping columns first.""" + ungrouped = [k for k in self.last_op.columns.keys() if k not in self.group_by] + return list(self.group_by) + ungrouped + + #def label_breaks_order_by(self, name): + # """Returns True if a new column label would break the order by vars.""" + + # # TODO: arrange currently allows literals, which breaks this. it seems + # # better to only allow calls in arrange. + # order_by_vars = {c.op_vars(attr_calls=False) for c in self.order_by} + + + + + @property + def last_op(self) -> "sql.Table | sql.Select": + last_op = self.ops[-1] + + if last_op is None: + raise TypeError() + + return last_op + + @property + def last_select(self): + last_op = self.last_op + if not isinstance(last_op, sql.selectable.SelectBase): + return last_op.select() + + return last_op + + @staticmethod + def _create_table(tbl, columns = None, source = None): + """Return a sqlalchemy.Table, autoloading column info if needed. + + Arguments: + tbl: a sqlalchemy.Table or string of form 'table_name' or 'schema_name.table_name'. + columns: a tuple of column names for the table. Overrides source argument. 
+ source: a sqlalchemy engine, used to autoload columns. + + """ + if isinstance(tbl, sql.selectable.FromClause): + return tbl + + if not isinstance(tbl, str): + raise ValueError("tbl must be a sqlalchemy Table or string, but was %s" %type(tbl)) + + if columns is None and source is None: + raise ValueError("One of columns or source must be specified") + + schema, table_name = tbl.split('.') if '.' in tbl else [None, tbl] + + columns = map(sqlalchemy.Column, columns) if columns is not None else tuple() + + # TODO: pybigquery uses schema to mean project_id, so we cannot use + # siuba's classic breakdown "{schema}.{table_name}". Basically + # pybigquery uses "{schema=project_id}.{dataset_dot_table_name}" in its internal + # logic. An important side effect is that bigquery errors for + # `dataset`.`table`, but not `dataset.table`. + if source and source.dialect.name == "bigquery": + table_name = tbl + schema = None + + return sqlalchemy.Table( + table_name, + sqlalchemy.MetaData(bind = source), + *columns, + schema = schema, + autoload_with = source if not columns else None + ) + + def _get_preview(self): + # need to make prev op a cte, so we don't override any previous limit + new_sel = self.last_select.limit(5) + tbl_small = self.append_op(new_sel) + return collect(tbl_small) + + def __repr__(self): + template = ( + "# Source: lazy query\n" + "# DB Conn: {}\n" + "# Preview:\n{}\n" + "# .. may have more rows" + ) + + return template.format(repr(self.source.engine), repr(self._get_preview())) + + def _repr_html_(self): + template = ( + "
" + "
"
+                "# Source: lazy query\n"
+                "# DB Conn: {}\n"
+                "# Preview:\n"
+                "
" + "{}" + "

# .. may have more rows

" + "
" + ) + + data = self._get_preview() + + # _repr_html_ can not exist or return None, to signify that repr should be used + if not hasattr(data, '_repr_html_'): + return None + + html_data = data._repr_html_() + if html_data is None: + return None + + return template.format(self.source.engine, html_data) + + +def _repr_grouped_df_html_(self): + return "

(grouped data frame)

" + self._selected_obj._repr_html_() + "
" + + + +# Main Funcs +# ============================================================================= + +# sql raw -------------- + +sql_raw = sql.literal_column + +# show query ----------- + +@show_query.register(LazyTbl) +def _show_query(tbl, simplify = False, return_table = True): + #query = tbl.last_op #if not simplify else + compile_query = lambda query: query.compile( + dialect = tbl.source.dialect, + compile_kwargs = {"literal_binds": True} + ) + + + if simplify: + # try to strip table names and labels where unnecessary + simple_sel = _sql_simplify_select(tbl.last_select) + + explained = compile_query(simple_sel) + else: + # use a much more verbose query + explained = compile_query(tbl.last_select) + + if return_table: + print(str(explained)) + return tbl + + return str(explained) + + + +# collect ---------- + +@collect.register(LazyTbl) +def _collect(__data, as_df = True): + # TODO: maybe remove as_df options, always return dataframe + + if isinstance(__data.source, MockConnection): + # a mock sqlalchemy is being used to show_query, and echo queries. + # it doesn't return a result object or have a context handler, so + # we need to bail out early + return + + # compile query ---- + + if _is_dialect_duckdb(__data.source): + # TODO: can be removed once next release of duckdb fixes: + # https://github.com/duckdb/duckdb/issues/2972 + query = __data.last_select + compiled = query.compile( + dialect = __data.source.dialect, + compile_kwargs = {"literal_binds": True} + ) + else: + compiled = __data.last_select + + # execute query ---- + + with __data.source.connect() as conn: + if as_df: + sql_db = _FixedSqlDatabase(conn) + + if _is_dialect_duckdb(__data.source): + # TODO: pandas read_sql is very slow with duckdb. + # see https://github.com/pandas-dev/pandas/issues/45678 + # going to handle here for now. address once LazyTbl gets + # subclassed per backend. 
+ duckdb_con = conn.connection.c + return duckdb_con.query(str(compiled)).to_df() + else: + # + return sql_db.read_sql(compiled) + + return conn.execute(compiled) + + +@select.register(LazyTbl) +def _select(__data, *args, **kwargs): + # see https://stackoverflow.com/questions/25914329/rearrange-columns-in-sqlalchemy-select-object + if kwargs: + raise NotImplementedError( + "Using kwargs in select not currently supported. " + "Use _.newname == _.oldname instead" + ) + last_sel = __data.last_select + columns = {c.key: c for c in last_sel.inner_columns} + + # same as for DataFrame + colnames = Series(list(columns)) + vl = VarList() + evaluated = (arg(vl) if callable(arg) else arg for arg in args) + od = var_select(colnames, *evaluated) + + missing_groups, group_keys = _select_group_renames(od, __data.group_by) + + if missing_groups: + _warn_missing(missing_groups) + + final_od = {**{k: None for k in missing_groups}, **od} + + col_list = [] + for k,v in final_od.items(): + col = columns[k] + col_list.append(col if v is None else col.label(v)) + + return __data.append_op( + last_sel.with_only_columns(col_list), + group_by = group_keys + ) + + + +@filter.register(LazyTbl) +def _filter(__data, *args): + # Note: currently always produces 2 additional select statements, + # 1 for window/aggs, and 1 for the where clause + + sel = __data.last_op.alias() # original select + win_sel = sel.select() + + conds = [] + windows = [] + with _set_data_context(__data, window=True): + for ii, arg in enumerate(args): + + if isinstance(arg, Call): + new_call = __data.shape_call(arg, verb_name = "Filter", arg_name = ii) + #var_cols = new_call.op_vars(attr_calls = False) + + # note that a new win_sel is returned, w/ window columns appended + col_expr, win_cols, win_sel = __data.track_call_windows( + new_call, + sel.columns, + window_cte = win_sel + ) + + if isinstance(col_expr, sql.base.ImmutableColumnCollection): + conds.extend(col_expr) + else: + conds.append(col_expr) + + 
windows.extend(win_cols) + + else: + conds.append(arg) + + bool_clause = sql.and_(*conds) + + # first cte, windows ---- + if len(windows): + + win_alias = win_sel.alias() + + # move non-window functions to refer to win_sel clause (not the innermost) --- + bool_clause = sql.util.ClauseAdapter(win_alias) \ + .traverse(bool_clause) + + orig_cols = [win_alias.columns[k] for k in sel.columns.keys()] + else: + orig_cols = [sel] + + # create second cte ---- + filt_sel = _sql_select(orig_cols).where(bool_clause) + return __data.append_op(filt_sel) + + +@mutate.register(LazyTbl) +def _mutate(__data, *args, **kwargs): + # TODO: verify it can follow a renaming select + + # track labeled columns in set + if not (len(args) or len(kwargs)): + return __data.append_op(__data.last_op) + + names, sel_out = _mutate_cols(__data, args, kwargs, "Mutate") + return __data.append_op(sel_out) + + +def _sql_upsert_columns(sel, new_columns: "list[base.Label | base.Column]"): + orig_cols = lift_inner_cols(sel) + replaced = {**orig_cols} + + for new_col in new_columns: + replaced[new_col.name] = new_col + return _sql_with_only_columns(sel, list(replaced.values())) + + +def _select_mutate_result(src_sel, expr_result): + dst_alias = src_sel.alias() + src_columns = set(lift_inner_cols(src_sel)) + replacer = SqlLabelReplacer(set(src_columns), dst_alias.columns) + + if isinstance(expr_result, sql.base.ImmutableColumnCollection): + replaced_cols = list(map(replacer, expr_result)) + orig_cols = expr_result + #elif isinstance(expr_result, None): + # pass + else: + replaced_cols = [replacer(expr_result)] + orig_cols = [expr_result] + + if replacer.applied: + return _sql_upsert_columns(dst_alias.select(), replaced_cols) + + return _sql_upsert_columns(src_sel, orig_cols) + + +def _eval_expr_arg(__data, sel, func, verb_name, window=True): + inner_cols = lift_inner_cols(sel) + + # case 1: simple names ---- + simple_name = simple_varname(func) + if simple_name is not None: + return inner_cols[simple_name] + 
+ # case 2: across ---- + _require_across(func, verb_name) + + cols_result = _eval_with_context(__data, window, inner_cols, func) + + # TODO: remove or raise a more informative error + assert isinstance(cols_result, sql.base.ImmutableColumnCollection), type(cols_result) + + return cols_result + + +def _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=True): + inner_cols = lift_inner_cols(sel) + + expr_shaped = __data.shape_call(func, window, verb_name = verb_name, arg_name = new_name) + new_col, windows, _ = __data.track_call_windows(expr_shaped, inner_cols) + + if isinstance(new_col, sql.base.ImmutableColumnCollection): + raise TyepError( + f"{verb_name} named arguments must return a single column, but `{k}` " + "returned multiple columns." + ) + + return new_col.label(new_name) + + +def _mutate_cols(__data, args, kwargs, verb_name): + result_names = {} # used as ordered set + sel = __data.last_select + + for ii, func in enumerate(args): + cols_result = _eval_expr_arg(__data, sel, func, verb_name) + + # replace any labels that require a subquery ---- + sel = _select_mutate_result(sel, cols_result) + + if isinstance(cols_result, sql.base.ImmutableColumnCollection): + result_names.update({k: True for k in cols_result.keys()}) + else: + result_names[cols_result.name] = True + + + for new_name, func in kwargs.items(): + labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name) + + sel = _select_mutate_result(sel, labeled) + result_names[new_name] = True + + + return list(result_names), sel + + +@transmute.register(LazyTbl) +def _transmute(__data, *args, **kwargs): + # will use mutate, then select some cols + result_names, sel = _mutate_cols(__data, args, kwargs, "Transmute") + + # transmute keeps grouping cols, and any defined in kwargs + missing = [x for x in __data.group_by if x not in result_names] + cols_to_keep = [*missing, *result_names] + + columns = lift_inner_cols(sel) + sel_stripped = sel.with_only_columns([columns[k] for k in 
cols_to_keep]) + + return __data.append_op(sel_stripped) + + +@arrange.register(LazyTbl) +def _arrange(__data, *args): + # Note that SQL databases often do not subquery order by clauses. Arrange + # sets order_by on the backend, so it can set order by in over elements, + # and handle when new columns are named the same as order by vars. + # see: https://dba.stackexchange.com/q/82930 + + last_sel = __data.last_select + cols = lift_inner_cols(last_sel) + + # TODO: implement across in arrange + #exprs, _ = _mutate_cols(__data, args, kwargs, "Arrange", arrange_clause=True) + + new_calls = [] + for ii, expr in enumerate(args): + if callable(expr): + + res = __data.shape_call( + expr, window = False, + verb_name = "Arrange", arg_name = ii + ) + + else: + res = expr + + new_calls.append(res) + + sort_cols = _create_order_by_clause(cols, *new_calls) + + order_by = __data.order_by + tuple(new_calls) + return __data.append_op(last_sel.order_by(*sort_cols), order_by = order_by) + + +# TODO: consolidate / pull expr handling funcs into own file? +def _create_order_by_clause(columns, *args): + sort_cols = [] + for arg in args: + # simple named column + if isinstance(arg, str): + sort_cols.append(columns[arg]) + # an expression + elif callable(arg): + # handle special case where -_.colname -> colname DESC + f, asc = _call_strip_ascending(arg) + col_op = f(columns) if asc else f(columns).desc() + #col_op = arg(columns) + sort_cols.append(col_op) + else: + raise NotImplementedError("Must be string or callable") + + return sort_cols + + + +@count.register(LazyTbl) +def _count(__data, *args, sort = False, wt = None, **kwargs): + # TODO: if already col named n, use name nn, etc.. 
get logic from tidy.py + if wt is not None: + raise NotImplementedError("TODO") + + res_name = "n" + # similar to filter verb, we need two select statements, + # an inner one for derived cols, and outer to group by them + + # inner select ---- + # holds any mutation style columns + #arg_names = [] + #for arg in args: + # name = simple_varname(arg) + # if name is None: + # raise NotImplementedError( + # "Count positional arguments must be single column name. " + # "Use a named argument to count using complex expressions." + # ) + # arg_names.append(name) + + result_names, sel_inner = _mutate_cols(__data, args, kwargs, "Count") + + # remove unnecessary select, if we're operating on a table ---- + if set(lift_inner_cols(sel_inner)) == set(lift_inner_cols(__data.last_select)): + sel_inner = __data.last_op + + # create outer select ---- + # holds selected columns and tally (n) + sel_inner_cte = sel_inner.alias() + inner_cols = sel_inner_cte.columns + + # apply any group vars from a group_by verb call first + missing = [k for k in __data.group_by if k not in result_names] + + all_group_names = ordered_union(__data.group_by, result_names) + outer_group_cols = [inner_cols[k] for k in all_group_names] + + # holds the actual count (e.g. 
n) + count_col = sql.functions.count().label(res_name) + + sel_outer = _sql_select([*outer_group_cols, count_col]) \ + .select_from(sel_inner_cte) \ + .group_by(*outer_group_cols) + + # count is like summarize, so removes order_by + return __data.append_op( + sel_outer.order_by(count_col.desc()), + order_by = tuple() + ) + + +@add_count.register(LazyTbl) +def _add_count(__data, *args, wt = None, sort = False, **kwargs): + counts = count(__data, *args, wt = wt, sort = sort, **kwargs) + by = list(c.name for c in counts.last_select.inner_columns)[:-1] + + return inner_join(__data, counts, by = by) + + +@summarize.register(LazyTbl) +def _summarize(__data, *args, **kwargs): + # https://stackoverflow.com/questions/14754994/why-is-sqlalchemy-count-much-slower-than-the-raw-query + + # get query with correct from clause, and maybe unneeded subquery + safe_from = __data.last_select.alias() + result_names, sel = _aggregate_cols(__data, safe_from, args, kwargs, "Summarize") + + # see if we can remove subquery + out_sel = _collapse_select(sel, safe_from) + + from_tbl = get_single_from(out_sel) + group_cols = [from_tbl.columns[k] for k in __data.group_by] + + final_sel = out_sel.group_by(*group_cols) + + new_data = __data.append_op(final_sel, group_by = tuple(), order_by = tuple()) + return new_data + + +def _collapse_select(outer_sel, inner_alias): + # check whether any outer columns reference an inner label ---- + inner_sel = inner_alias.element + + columns = lift_inner_cols(outer_sel) + inner_cols = lift_inner_cols(inner_sel) + + inner_labels = set([ + x.name for x in inner_cols + if isinstance(x, sql.elements.Label) + ]) + + col_requires_cte = set(inner_alias.columns[k] for k in inner_labels) + + bad_refs = [] + + def collect_refs(el): + if el in col_requires_cte: + bad_refs.append(el) + + for col in columns: + sql.util.visitors.traverse(col, {}, {"column": collect_refs}) + + # if possible, remove the outer query ---- + if not (bad_refs or len(inner_sel._group_by_clause)): + 
from sqlalchemy.sql.elements import ColumnClause, Label + + from_obj = get_single_from(inner_sel) + adaptor = sql.util.ClauseAdapter( + from_obj, + adapt_on_names=True, + include_fn=lambda c: isinstance(c, (ColumnClause, Label)) + ) + + new_cols = [] + for col in columns: + if isinstance(col, Label): + res = adaptor.traverse(col.element).label(col.name) + new_cols.append(res) + + else: + new_cols.append(adaptor.traverse(col)) + #new_cols = list(map(adaptor.traverse, columns)) + + return _sql_with_only_columns(inner_sel, new_cols) + + return outer_sel + + +def _aggregate_cols(__data, subquery, args, kwargs, verb_name): + # cases: + # * grouping cols can not be overwritten (in dbplyr they can't be ref'd) + # * no existing labels referred to - can use same select + # * existing labels referred to - need 1 subquery tops + # * groups + summarize columns can replace everything + + def get_label_clauses(clause): + out = [] + sql.util.visitors.traverse(clause, {}, {"label": lambda c: out.append(c)}) + + return out + + def quote_varname(x): + return f"`{x}`" + + def validate_references(arg_name, expr, verb_name): + bad_varnames = get_label_clauses(expr) + repr_names = ", ".join(map(quote_varname, [el.name for el in bad_varnames])) + + if not bad_varnames: + return + + raise NotImplementedError( + f"In SQL, you cannot refer to a column created in the same {verb_name}. " + f"`{arg_name}` refers to columns created earlier: {repr_names}." 
+ ) + + sel = subquery.select() + + final_cols = {k: subquery.columns[k] for k in __data.group_by} + + # handle args ---- + for ii, func in enumerate(args): + cols_result = _eval_expr_arg(__data, sel, func, verb_name, window=False) + + for col in cols_result: + validate_references(col.name, col.element, verb_name) + final_cols[col.name] = col + + sel = _sql_upsert_columns(sel, cols_result) + + + # handle kwargs ---- + for new_name, func in kwargs.items(): + labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=False) + + validate_references(labeled.name, labeled.element, verb_name) + final_cols[new_name] = labeled + + sel = _sql_upsert_columns(sel, [labeled]) + + return list(final_cols), _sql_with_only_columns(sel, list(final_cols.values())) + + +@group_by.register(LazyTbl) +def _group_by(__data, *args, add = False, **kwargs): + if not (args or kwargs): + return __data.copy() + + group_names, sel = _mutate_cols(__data, args, kwargs, "Group by") + + if None in group_names: + raise NotImplementedError("Complex, unnamed expressions not supported in sql group_by") + + # check whether we can just use underlying table ---- + new_cols = lift_inner_cols(sel) + if set(new_cols).issubset(set(__data.last_op.columns)): + sel = __data.last_op + + if add: + group_names = ordered_union(__data.group_by, group_names) + + return __data.append_op(sel, group_by = tuple(group_names)) + + +@ungroup.register(LazyTbl) +def _ungroup(__data): + return __data.copy(group_by = tuple()) + + +@case_when.register(sql.base.ImmutableColumnCollection) +def _case_when(__data, cases): + # TODO: will need listener to enter case statements, to handle when they use windows + if isinstance(cases, Call): + cases = cases(__data) + + whens = [] + case_items = list(cases.items()) + n_items = len(case_items) + + else_val = None + for ii, (expr, val) in enumerate(case_items): + # handle where val is a column expr + if callable(val): + val = val(__data) + + # handle when expressions + if ii+1 
== n_items and expr is True: + else_val = val + elif callable(expr): + whens.append((expr(__data), val)) + else: + whens.append((expr, val)) + + return sql.case(whens, else_ = else_val) + + +# Join ------------------------------------------------------------------------ + +from collections.abc import Mapping + +def _joined_cols(left_cols, right_cols, on_keys, how, suffix = ("_x", "_y")): + """Return labeled columns, according to selection rules for joins. + + Rules: + 1. For join keys, keep left table's column + 2. When keys have the same labels, add suffix + """ + + # TODO: remove sets, so uses stable ordering + # when left and right cols have same name, suffix with _x / _y + keep_right = set(right_cols.keys()) - set(on_keys.values()) + shared_labs = set(left_cols.keys()).intersection(keep_right) + + right_cols_no_keys = {k: right_cols[k] for k in keep_right} + + # for an outer join, have key columns coalesce values + + left_cols = {**left_cols} + if how == "full": + for lk, rk in on_keys.items(): + col = sql.functions.coalesce(left_cols[lk], right_cols[rk]) + left_cols[lk] = col.label(lk) + elif how == "right": + for lk, rk in on_keys.items(): + # Make left key columns actually be right ones (which contain left + extra) + left_cols[lk] = right_cols[rk].label(lk) + + + # create labels ---- + l_labs = _relabeled_cols(left_cols, shared_labs, suffix[0]) + r_labs = _relabeled_cols(right_cols_no_keys, shared_labs, suffix[1]) + + return l_labs + r_labs + + + +def _relabeled_cols(columns, keys, suffix): + # add a suffix to all columns with names in keys + cols = [] + for k, v in columns.items(): + new_col = v.label(k + str(suffix)) if k in keys else v + cols.append(new_col) + return cols + + +@join.register(LazyTbl) +def _join(left, right, on = None, *args, by = None, how = "inner", sql_on = None): + _raise_if_args(args) + + if on is None and by is not None: + on = by + + # Needs to be on the table, not the select + left_sel = left.last_op.alias() + right_sel = 
right.last_op.alias() + + # handle arguments ---- + on = _validate_join_arg_on(on, sql_on) + how = _validate_join_arg_how(how) + + # for equality join used to combine keys into single column + consolidate_keys = on if sql_on is None else {} + + if how == "right": + # switch joins, since sqlalchemy doesn't have right join arg + # see https://stackoverflow.com/q/11400307/1144523 + left_sel, right_sel = right_sel, left_sel + on = {v:k for k,v in on.items()} + + # create join conditions ---- + bool_clause = _create_join_conds(left_sel, right_sel, on) + + # create join ---- + join = left_sel.join( + right_sel, + onclause = bool_clause, + isouter = how != "inner", + full = how == "full" + ) + + # if right join, set selects back + if how == "right": + left_sel, right_sel = right_sel, left_sel + on = {v:k for k,v in on.items()} + + # note, shared_keys assumes on is a mapping... + # TODO: shared_keys appears to be for when on is not specified, but was unused + #shared_keys = [k for k,v in on.items() if k == v] + labeled_cols = _joined_cols( + left_sel.columns, + right_sel.columns, + on_keys = consolidate_keys, + how = how + ) + + sel = _sql_select(labeled_cols).select_from(join) + return left.append_op(sel, order_by = tuple()) + + +@semi_join.register(LazyTbl) +def _semi_join(left, right = None, on = None, *args, by = None, sql_on = None): + if on is None and by is not None: + on = by + + _raise_if_args(args) + + left_sel = left.last_op.alias() + right_sel = right.last_op.alias() + + # handle arguments ---- + on = _validate_join_arg_on(on, sql_on, left_sel, right_sel) + + # create join conditions ---- + bool_clause = _create_join_conds(left_sel, right_sel, on) + + # create inner join ---- + exists_clause = _sql_select([sql.literal(1)]) \ + .select_from(right_sel) \ + .where(bool_clause) + + # only keep left hand select's columns ---- + sel = _sql_select(left_sel.columns) \ + .select_from(left_sel) \ + .where(sql.exists(exists_clause)) + + return left.append_op(sel, order_by 
= tuple()) + + +@anti_join.register(LazyTbl) +def _anti_join(left, right = None, on = None, *args, by = None, sql_on = None): + if on is None and by is not None: + on = by + + _raise_if_args(args) + + left_sel = left.last_op.alias() + right_sel = right.last_op.alias() + + # handle arguments ---- + on = _validate_join_arg_on(on, sql_on, left, right) + + # create join conditions ---- + bool_clause = _create_join_conds(left_sel, right_sel, on) + + # create inner join ---- + #not_exists = ~sql.exists([1], from_obj = right_sel).where(bool_clause) + exists_clause = _sql_select([sql.literal(1)]) \ + .select_from(right_sel) \ + .where(bool_clause) + + sel = left_sel.select().where(~sql.exists(exists_clause)) + + return left.append_op(sel, order_by = tuple()) + +def _raise_if_args(args): + if len(args): + raise NotImplemented("*args is reserved for future arguments (e.g. suffix)") + +def _validate_join_arg_on(on, sql_on = None, lhs = None, rhs = None): + # handle sql on case + if sql_on is not None: + if on is not None: + raise ValueError("Cannot specify both on and sql_on") + + return sql_on + + # handle general cases + if on is None: + # TODO: currently, we check for lhs and rhs tables to indicate whether + # a verb supports inferring columns. Otherwise, raise an error. + if lhs is not None and rhs is not None: + # TODO: consolidate with duplicate logic in pandas verb code + warnings.warn( + "No on column passed to join. " + "Inferring join columns instead using shared column names." + ) + + on_cols = list(set(lhs.columns.keys()).intersection(set(rhs.columns.keys()))) + + if not on_cols: + raise ValueError( + "No join column specified, or shared column names in join." 
+ ) + + # trivial dict mapping shared names to themselves + warnings.warn("Detected shared columns: %s" % on_cols) + on = dict(zip(on_cols, on_cols)) + + else: + raise NotImplementedError("on arg currently cannot be None (default) for SQL") + elif isinstance(on, str): + on = {on: on} + elif isinstance(on, (list, tuple)): + on = dict(zip(on, on)) + + + if not isinstance(on, Mapping): + raise TypeError("on must be a Mapping (e.g. dict)") + + return on + +def _validate_join_arg_how(how): + how_options = ("inner", "left", "right", "full") + if how not in how_options: + raise ValueError("how argument needs to be one of %s" %how_options) + + return how + +def _create_join_conds(left_sel, right_sel, on): + left_cols = left_sel.columns #lift_inner_cols(left_sel) + right_cols = right_sel.columns #lift_inner_cols(right_sel) + + if callable(on): + # callable, like with sql_on arg + conds = [on(left_cols, right_cols)] + else: + # dict-like of form {left: right} + conds = [] + for l, r in on.items(): + col_expr = left_cols[l] == right_cols[r] + conds.append(col_expr) + + return sql.and_(*conds) + + +# Head ------------------------------------------------------------------------ + +@head.register(LazyTbl) +def _head(__data, n = 5): + sel = __data.last_select + + return __data.append_op(sel.limit(n)) + + +# Rename ---------------------------------------------------------------------- + +@rename.register(LazyTbl) +def _rename(__data, **kwargs): + sel = __data.last_select + columns = lift_inner_cols(sel) + + # old_keys uses dict as ordered set + old_to_new = {simple_varname(v):k for k,v in kwargs.items()} + + if None in old_to_new: + raise KeyError("positional arguments must be simple column, " + "e.g. 
_.colname or _['colname']" + ) + + labs = [c.label(old_to_new[k]) if k in old_to_new else c for k,c in columns.items()] + + new_sel = sel.with_only_columns(labs) + + missing_groups, group_keys = _select_group_renames(old_to_new, __data.group_by) + + return __data.append_op(new_sel, group_by=group_keys) + + +# Distinct -------------------------------------------------------------------- + +@distinct.register(LazyTbl) +def _distinct(__data, *args, _keep_all = False, **kwargs): + if (args or kwargs) and _keep_all: + raise NotImplementedError("Distinct with variables specified in sql requires _keep_all = False") + + inner_sel = mutate(__data, **kwargs).last_select if kwargs else __data.last_select + + # TODO: this is copied from the df distinct version + # cols dict below is used as ordered set + cols = _var_select_simple(args) + cols.update(kwargs) + + # use all columns by default + if not cols: + cols = {k: True for k in lift_inner_cols(inner_sel).keys()} + + final_names = {**{k: True for k in __data.group_by}, **cols} + + if not len(inner_sel._order_by_clause): + # select distinct has to include any columns in the order by clause, + # so can only safely modify existing statement when there's no order by + sel_cols = lift_inner_cols(inner_sel) + distinct_cols = [sel_cols[k] for k in final_names] + sel = inner_sel.with_only_columns(distinct_cols).distinct() + else: + # fallback to cte + cte = inner_sel.alias() + distinct_cols = [cte.columns[k] for k in final_names] + sel = _sql_select(distinct_cols).select_from(cte).distinct() + + return __data.append_op(sel) + + +# if_else --------------------------------------------------------------------- + +@if_else.register(sql.elements.ColumnElement) +def _if_else(cond, true_vals, false_vals): + whens = [(cond, true_vals)] + return sql.case(whens, else_ = false_vals) + + From d5be35b5908353712fddb793343427acfe9508eb Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Fri, 7 Oct 2022 18:35:37 -0400 Subject: [PATCH 15/27] 
refactor(sql): move verbs into own files --- siuba/sql/__init__.py | 3 +- siuba/sql/across.py | 2 +- siuba/sql/backend.py | 1001 +------------------ siuba/sql/utils.py | 7 + siuba/sql/verbs/__init__.py | 15 + siuba/sql/verbs/arrange.py | 1364 +------------------------- siuba/sql/verbs/compute.py | 1353 +------------------------- siuba/sql/verbs/conditional.py | 1359 +------------------------- siuba/sql/verbs/count.py | 1324 +------------------------ siuba/sql/verbs/distinct.py | 1364 +------------------------- siuba/sql/verbs/explain.py | 1363 +------------------------- siuba/sql/verbs/filter.py | 1342 +------------------------- siuba/sql/verbs/group_by.py | 1371 +------------------------- siuba/sql/verbs/head.py | 1390 +------------------------- siuba/sql/verbs/join.py | 1170 +--------------------- siuba/sql/verbs/mutate.py | 1275 +----------------------- siuba/sql/verbs/select.py | 1336 +------------------------ siuba/sql/verbs/summarize.py | 1271 +----------------------- siuba/sql/verbs/transmute.py | 1395 --------------------------- siuba/tests/test_sql_utils.py | 2 +- siuba/tests/test_sql_verbs.py | 8 - siuba/tests/test_verb_join.py | 2 +- siuba/tests/test_verb_show_query.py | 3 +- siuba/tests/test_verb_utils.py | 4 +- 24 files changed, 104 insertions(+), 19620 deletions(-) create mode 100644 siuba/sql/verbs/__init__.py delete mode 100644 siuba/sql/verbs/transmute.py diff --git a/siuba/sql/__init__.py b/siuba/sql/__init__.py index 5af7053b..9182ac46 100644 --- a/siuba/sql/__init__.py +++ b/siuba/sql/__init__.py @@ -1,8 +1,9 @@ -from .verbs import LazyTbl, sql_raw +from .backend import LazyTbl, sql_raw from .translate import SqlColumn, SqlColumnAgg, SqlFunctionLookupError from . import across as _across # proceed w/ underscore so it isn't exported by default # we just want to register the singledispatch funcs +from . 
import verbs as _verbs from .dply import vector as _vector from .dply import string as _string diff --git a/siuba/sql/across.py b/siuba/sql/across.py index 4fec8865..0b7aeedb 100644 --- a/siuba/sql/across.py +++ b/siuba/sql/across.py @@ -2,7 +2,7 @@ from siuba.dply.tidyselect import var_select, var_create from siuba.siu import FormulaContext, Call -from . verbs import LazyTbl +from .backend import LazyTbl from .utils import _sql_select, _sql_column_collection from sqlalchemy import sql diff --git a/siuba/sql/backend.py b/siuba/sql/backend.py index df7831fa..6c81b398 100644 --- a/siuba/sql/backend.py +++ b/siuba/sql/backend.py @@ -9,58 +9,18 @@ import warnings -from siuba.dply.verbs import ( - show_query, collect, - simple_varname, - select, - mutate, - transmute, - filter, - arrange, _call_strip_ascending, - summarize, - count, add_count, - group_by, ungroup, - case_when, - join, left_join, right_join, inner_join, semi_join, anti_join, - head, - rename, - distinct, - if_else, - _select_group_renames, - _var_select_simple - ) - -from siuba.dply.tidyselect import VarList, var_select - -from .translate import CustomOverClause, SqlColumn, SqlColumnAgg +from .translate import CustomOverClause from .utils import ( get_dialect_translator, - _FixedSqlDatabase, - _is_dialect_duckdb, - _sql_select, _sql_column_collection, _sql_add_columns, - _sql_with_only_columns, - _sql_simplify_select, - MockConnection ) from sqlalchemy import sql import sqlalchemy -from siuba.siu import Call, Lazy, FunctionLookupError, singledispatch2 -# TODO: currently needed for select, but can we remove pandas? -from pandas import Series +from siuba.siu import FunctionLookupError from functools import singledispatch -from sqlalchemy.sql import schema - -from siuba.dply.across import _require_across, _set_data_context, _eval_with_context - -# TODO: -# - distinct -# - annotate functions using sel.prefix_with("\n/**/\n") ? 
- - # Helpers --------------------------------------------------------------------- class SqlFunctionLookupError(FunctionLookupError): pass @@ -77,62 +37,8 @@ def exit(self, node): return node -class WindowReplacer(CallListener): - """Call tree listener. - - Produces 2 important behaviors via the enter method: - - returns evaluated sql call expression, with labels on all window expressions. - - stores all labeled window expressions via the windows property. - - TODO: could replace with a sqlalchemy transformer - """ +class WindowReplacer: - def __init__(self, columns, group_by, order_by, window_cte = None): - self.columns = columns - self.group_by = group_by - self.order_by = order_by - self.window_cte = window_cte - self.windows = [] - - def exit(self, node): - col_expr = node(self.columns) - - if not isinstance(col_expr, sql.elements.ClauseElement): - return col_expr - - over_clauses = [x for x in self._get_over_clauses(col_expr) if isinstance(x, CustomOverClause)] - - # put groupings and orderings onto custom over clauses - for over in over_clauses: - # TODO: shouldn't mutate these over clauses - group_by = sql.elements.ClauseList( - *[self.columns[name] for name in self.group_by] - ) - order_by = sql.elements.ClauseList( - *_create_order_by_clause(self.columns, *self.order_by) - ) - - over.set_over(group_by, order_by) - - if len(over_clauses) and self.window_cte is not None: - # custom name, or parameters like "%(...)s" may nest and break psycopg2 - # with columns you can set a key to fix this, but it doesn't seem to - # be an option with labels - name = self._get_unique_name('win', lift_inner_cols(self.window_cte)) - label = col_expr.label(name) - - # put into CTE, and return its resulting column, so that subsequent - # operations will refer to the window column on window_cte. 
Note that - # the operations will use the actual column, so may need to use the - # ClauseAdaptor to make it a reference to the label - self.window_cte = _sql_add_columns(self.window_cte, [label]) - win_col = lift_inner_cols(self.window_cte).values()[-1] - self.windows.append(win_col) - - return win_col - - return col_expr - @staticmethod def _get_unique_name(prefix, columns): column_names = set(columns.keys()) @@ -200,6 +106,27 @@ def visit(self, el): # return col, listener.windows, listener.window_cte +# TODO: consolidate / pull expr handling funcs into own file? +def _create_order_by_clause(columns, *args): + from siuba.dply.verbs import _call_strip_ascending + + sort_cols = [] + for arg in args: + # simple named column + if isinstance(arg, str): + sort_cols.append(columns[arg]) + # an expression + elif callable(arg): + # handle special case where -_.colname -> colname DESC + f, asc = _call_strip_ascending(arg) + col_op = f(columns) if asc else f(columns).desc() + #col_op = arg(columns) + sort_cols.append(col_op) + else: + raise NotImplementedError("Must be string or callable") + + return sort_cols + def track_call_windows(call, columns, group_by, order_by, window_cte = None): col_expr = call(columns) @@ -464,6 +391,8 @@ def _create_table(tbl, columns = None, source = None): def _get_preview(self): # need to make prev op a cte, so we don't override any previous limit + from siuba.dply.verbs import collect + new_sel = self.last_select.limit(5) tbl_small = self.append_op(new_sel) return collect(tbl_small) @@ -508,883 +437,5 @@ def _repr_grouped_df_html_(self): return "

(grouped data frame)

" + self._selected_obj._repr_html_() + "
" - -# Main Funcs -# ============================================================================= - -# sql raw -------------- - sql_raw = sql.literal_column -# show query ----------- - -@show_query.register(LazyTbl) -def _show_query(tbl, simplify = False, return_table = True): - #query = tbl.last_op #if not simplify else - compile_query = lambda query: query.compile( - dialect = tbl.source.dialect, - compile_kwargs = {"literal_binds": True} - ) - - - if simplify: - # try to strip table names and labels where unnecessary - simple_sel = _sql_simplify_select(tbl.last_select) - - explained = compile_query(simple_sel) - else: - # use a much more verbose query - explained = compile_query(tbl.last_select) - - if return_table: - print(str(explained)) - return tbl - - return str(explained) - - - -# collect ---------- - -@collect.register(LazyTbl) -def _collect(__data, as_df = True): - # TODO: maybe remove as_df options, always return dataframe - - if isinstance(__data.source, MockConnection): - # a mock sqlalchemy is being used to show_query, and echo queries. - # it doesn't return a result object or have a context handler, so - # we need to bail out early - return - - # compile query ---- - - if _is_dialect_duckdb(__data.source): - # TODO: can be removed once next release of duckdb fixes: - # https://github.com/duckdb/duckdb/issues/2972 - query = __data.last_select - compiled = query.compile( - dialect = __data.source.dialect, - compile_kwargs = {"literal_binds": True} - ) - else: - compiled = __data.last_select - - # execute query ---- - - with __data.source.connect() as conn: - if as_df: - sql_db = _FixedSqlDatabase(conn) - - if _is_dialect_duckdb(__data.source): - # TODO: pandas read_sql is very slow with duckdb. - # see https://github.com/pandas-dev/pandas/issues/45678 - # going to handle here for now. address once LazyTbl gets - # subclassed per backend. 
- duckdb_con = conn.connection.c - return duckdb_con.query(str(compiled)).to_df() - else: - # - return sql_db.read_sql(compiled) - - return conn.execute(compiled) - - -@select.register(LazyTbl) -def _select(__data, *args, **kwargs): - # see https://stackoverflow.com/questions/25914329/rearrange-columns-in-sqlalchemy-select-object - if kwargs: - raise NotImplementedError( - "Using kwargs in select not currently supported. " - "Use _.newname == _.oldname instead" - ) - last_sel = __data.last_select - columns = {c.key: c for c in last_sel.inner_columns} - - # same as for DataFrame - colnames = Series(list(columns)) - vl = VarList() - evaluated = (arg(vl) if callable(arg) else arg for arg in args) - od = var_select(colnames, *evaluated) - - missing_groups, group_keys = _select_group_renames(od, __data.group_by) - - if missing_groups: - _warn_missing(missing_groups) - - final_od = {**{k: None for k in missing_groups}, **od} - - col_list = [] - for k,v in final_od.items(): - col = columns[k] - col_list.append(col if v is None else col.label(v)) - - return __data.append_op( - last_sel.with_only_columns(col_list), - group_by = group_keys - ) - - - -@filter.register(LazyTbl) -def _filter(__data, *args): - # Note: currently always produces 2 additional select statements, - # 1 for window/aggs, and 1 for the where clause - - sel = __data.last_op.alias() # original select - win_sel = sel.select() - - conds = [] - windows = [] - with _set_data_context(__data, window=True): - for ii, arg in enumerate(args): - - if isinstance(arg, Call): - new_call = __data.shape_call(arg, verb_name = "Filter", arg_name = ii) - #var_cols = new_call.op_vars(attr_calls = False) - - # note that a new win_sel is returned, w/ window columns appended - col_expr, win_cols, win_sel = __data.track_call_windows( - new_call, - sel.columns, - window_cte = win_sel - ) - - if isinstance(col_expr, sql.base.ImmutableColumnCollection): - conds.extend(col_expr) - else: - conds.append(col_expr) - - 
windows.extend(win_cols) - - else: - conds.append(arg) - - bool_clause = sql.and_(*conds) - - # first cte, windows ---- - if len(windows): - - win_alias = win_sel.alias() - - # move non-window functions to refer to win_sel clause (not the innermost) --- - bool_clause = sql.util.ClauseAdapter(win_alias) \ - .traverse(bool_clause) - - orig_cols = [win_alias.columns[k] for k in sel.columns.keys()] - else: - orig_cols = [sel] - - # create second cte ---- - filt_sel = _sql_select(orig_cols).where(bool_clause) - return __data.append_op(filt_sel) - - -@mutate.register(LazyTbl) -def _mutate(__data, *args, **kwargs): - # TODO: verify it can follow a renaming select - - # track labeled columns in set - if not (len(args) or len(kwargs)): - return __data.append_op(__data.last_op) - - names, sel_out = _mutate_cols(__data, args, kwargs, "Mutate") - return __data.append_op(sel_out) - - -def _sql_upsert_columns(sel, new_columns: "list[base.Label | base.Column]"): - orig_cols = lift_inner_cols(sel) - replaced = {**orig_cols} - - for new_col in new_columns: - replaced[new_col.name] = new_col - return _sql_with_only_columns(sel, list(replaced.values())) - - -def _select_mutate_result(src_sel, expr_result): - dst_alias = src_sel.alias() - src_columns = set(lift_inner_cols(src_sel)) - replacer = SqlLabelReplacer(set(src_columns), dst_alias.columns) - - if isinstance(expr_result, sql.base.ImmutableColumnCollection): - replaced_cols = list(map(replacer, expr_result)) - orig_cols = expr_result - #elif isinstance(expr_result, None): - # pass - else: - replaced_cols = [replacer(expr_result)] - orig_cols = [expr_result] - - if replacer.applied: - return _sql_upsert_columns(dst_alias.select(), replaced_cols) - - return _sql_upsert_columns(src_sel, orig_cols) - - -def _eval_expr_arg(__data, sel, func, verb_name, window=True): - inner_cols = lift_inner_cols(sel) - - # case 1: simple names ---- - simple_name = simple_varname(func) - if simple_name is not None: - return inner_cols[simple_name] - 
- # case 2: across ---- - _require_across(func, verb_name) - - cols_result = _eval_with_context(__data, window, inner_cols, func) - - # TODO: remove or raise a more informative error - assert isinstance(cols_result, sql.base.ImmutableColumnCollection), type(cols_result) - - return cols_result - - -def _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=True): - inner_cols = lift_inner_cols(sel) - - expr_shaped = __data.shape_call(func, window, verb_name = verb_name, arg_name = new_name) - new_col, windows, _ = __data.track_call_windows(expr_shaped, inner_cols) - - if isinstance(new_col, sql.base.ImmutableColumnCollection): - raise TyepError( - f"{verb_name} named arguments must return a single column, but `{k}` " - "returned multiple columns." - ) - - return new_col.label(new_name) - - -def _mutate_cols(__data, args, kwargs, verb_name): - result_names = {} # used as ordered set - sel = __data.last_select - - for ii, func in enumerate(args): - cols_result = _eval_expr_arg(__data, sel, func, verb_name) - - # replace any labels that require a subquery ---- - sel = _select_mutate_result(sel, cols_result) - - if isinstance(cols_result, sql.base.ImmutableColumnCollection): - result_names.update({k: True for k in cols_result.keys()}) - else: - result_names[cols_result.name] = True - - - for new_name, func in kwargs.items(): - labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name) - - sel = _select_mutate_result(sel, labeled) - result_names[new_name] = True - - - return list(result_names), sel - - -@transmute.register(LazyTbl) -def _transmute(__data, *args, **kwargs): - # will use mutate, then select some cols - result_names, sel = _mutate_cols(__data, args, kwargs, "Transmute") - - # transmute keeps grouping cols, and any defined in kwargs - missing = [x for x in __data.group_by if x not in result_names] - cols_to_keep = [*missing, *result_names] - - columns = lift_inner_cols(sel) - sel_stripped = sel.with_only_columns([columns[k] for k in 
cols_to_keep]) - - return __data.append_op(sel_stripped) - - -@arrange.register(LazyTbl) -def _arrange(__data, *args): - # Note that SQL databases often do not subquery order by clauses. Arrange - # sets order_by on the backend, so it can set order by in over elements, - # and handle when new columns are named the same as order by vars. - # see: https://dba.stackexchange.com/q/82930 - - last_sel = __data.last_select - cols = lift_inner_cols(last_sel) - - # TODO: implement across in arrange - #exprs, _ = _mutate_cols(__data, args, kwargs, "Arrange", arrange_clause=True) - - new_calls = [] - for ii, expr in enumerate(args): - if callable(expr): - - res = __data.shape_call( - expr, window = False, - verb_name = "Arrange", arg_name = ii - ) - - else: - res = expr - - new_calls.append(res) - - sort_cols = _create_order_by_clause(cols, *new_calls) - - order_by = __data.order_by + tuple(new_calls) - return __data.append_op(last_sel.order_by(*sort_cols), order_by = order_by) - - -# TODO: consolidate / pull expr handling funcs into own file? -def _create_order_by_clause(columns, *args): - sort_cols = [] - for arg in args: - # simple named column - if isinstance(arg, str): - sort_cols.append(columns[arg]) - # an expression - elif callable(arg): - # handle special case where -_.colname -> colname DESC - f, asc = _call_strip_ascending(arg) - col_op = f(columns) if asc else f(columns).desc() - #col_op = arg(columns) - sort_cols.append(col_op) - else: - raise NotImplementedError("Must be string or callable") - - return sort_cols - - - -@count.register(LazyTbl) -def _count(__data, *args, sort = False, wt = None, **kwargs): - # TODO: if already col named n, use name nn, etc.. 
get logic from tidy.py - if wt is not None: - raise NotImplementedError("TODO") - - res_name = "n" - # similar to filter verb, we need two select statements, - # an inner one for derived cols, and outer to group by them - - # inner select ---- - # holds any mutation style columns - arg_names = [] - for arg in args: - name = simple_varname(arg) - if name is None: - raise NotImplementedError( - "Count positional arguments must be single column name. " - "Use a named argument to count using complex expressions." - ) - arg_names.append(name) - - result_names, sel_inner = _mutate_cols(__data, args, kwargs, "Count") - - # create outer select ---- - # holds selected columns and tally (n) - sel_inner_cte = sel_inner.alias() - inner_cols = sel_inner_cte.columns - - # apply any group vars from a group_by verb call first - missing = [k for k in __data.group_by if k not in result_names] - - all_group_names = ordered_union(__data.group_by, result_names) - outer_group_cols = [inner_cols[k] for k in all_group_names] - - # holds the actual count (e.g. 
n) - count_col = sql.functions.count().label(res_name) - - sel_outer = _sql_select([*outer_group_cols, count_col]) \ - .select_from(sel_inner_cte) \ - .group_by(*outer_group_cols) - - # count is like summarize, so removes order_by - return __data.append_op( - sel_outer.order_by(count_col.desc()), - order_by = tuple() - ) - - -@add_count.register(LazyTbl) -def _add_count(__data, *args, wt = None, sort = False, **kwargs): - counts = count(__data, *args, wt = wt, sort = sort, **kwargs) - by = list(c.name for c in counts.last_select.inner_columns)[:-1] - - return inner_join(__data, counts, by = by) - - -@summarize.register(LazyTbl) -def _summarize(__data, *args, **kwargs): - # https://stackoverflow.com/questions/14754994/why-is-sqlalchemy-count-much-slower-than-the-raw-query - - # get query with correct from clause, and maybe unneeded subquery - safe_from = __data.last_select.alias() - result_names, sel = _aggregate_cols(__data, safe_from, args, kwargs, "Summarize") - - # see if we can remove subquery - out_sel = _collapse_select(sel, safe_from) - - from_tbl = get_single_from(out_sel) - group_cols = [from_tbl.columns[k] for k in __data.group_by] - - final_sel = out_sel.group_by(*group_cols) - - new_data = __data.append_op(final_sel, group_by = tuple(), order_by = tuple()) - return new_data - - -def _collapse_select(outer_sel, inner_alias): - # check whether any outer columns reference an inner label ---- - inner_sel = inner_alias.element - - columns = lift_inner_cols(outer_sel) - inner_cols = lift_inner_cols(inner_sel) - - inner_labels = set([ - x.name for x in inner_cols - if isinstance(x, sql.elements.Label) - ]) - - col_requires_cte = set(inner_alias.columns[k] for k in inner_labels) - - bad_refs = [] - - def collect_refs(el): - if el in col_requires_cte: - bad_refs.append(el) - - for col in columns: - sql.util.visitors.traverse(col, {}, {"column": collect_refs}) - - # if possible, remove the outer query ---- - if not (bad_refs or len(inner_sel._group_by_clause)): - 
from sqlalchemy.sql.elements import ColumnClause, Label - - from_obj = get_single_from(inner_sel) - adaptor = sql.util.ClauseAdapter( - from_obj, - adapt_on_names=True, - include_fn=lambda c: isinstance(c, (ColumnClause, Label)) - ) - - new_cols = [] - for col in columns: - if isinstance(col, Label): - res = adaptor.traverse(col.element).label(col.name) - new_cols.append(res) - - else: - new_cols.append(adaptor.traverse(col)) - #new_cols = list(map(adaptor.traverse, columns)) - - return _sql_with_only_columns(inner_sel, new_cols) - - return outer_sel - - -def _aggregate_cols(__data, subquery, args, kwargs, verb_name): - # cases: - # * grouping cols can not be overwritten (in dbplyr they can't be ref'd) - # * no existing labels referred to - can use same select - # * existing labels referred to - need 1 subquery tops - # * groups + summarize columns can replace everything - - def get_label_clauses(clause): - out = [] - sql.util.visitors.traverse(clause, {}, {"label": lambda c: out.append(c)}) - - return out - - def quote_varname(x): - return f"`{x}`" - - def validate_references(arg_name, expr, verb_name): - bad_varnames = get_label_clauses(expr) - repr_names = ", ".join(map(quote_varname, [el.name for el in bad_varnames])) - - if not bad_varnames: - return - - raise NotImplementedError( - f"In SQL, you cannot refer to a column created in the same {verb_name}. " - f"`{arg_name}` refers to columns created earlier: {repr_names}." 
- ) - - sel = subquery.select() - - final_cols = {k: subquery.columns[k] for k in __data.group_by} - - # handle args ---- - for ii, func in enumerate(args): - cols_result = _eval_expr_arg(__data, sel, func, verb_name, window=False) - - for col in cols_result: - validate_references(col.name, col.element, verb_name) - final_cols[col.name] = col - - sel = _sql_upsert_columns(sel, cols_result) - - - # handle kwargs ---- - for new_name, func in kwargs.items(): - labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=False) - - validate_references(labeled.name, labeled.element, verb_name) - final_cols[new_name] = labeled - - sel = _sql_upsert_columns(sel, [labeled]) - - return list(final_cols), _sql_with_only_columns(sel, list(final_cols.values())) - - -@group_by.register(LazyTbl) -def _group_by(__data, *args, add = False, **kwargs): - if not (args or kwargs): - return __data.copy() - - group_names, sel = _mutate_cols(__data, args, kwargs, "Group by") - - if None in group_names: - raise NotImplementedError("Complex, unnamed expressions not supported in sql group_by") - - # check whether we can just use underlying table ---- - new_cols = lift_inner_cols(sel) - if set(new_cols).issubset(set(__data.last_op.columns)): - sel = __data.last_op - - if add: - group_names = ordered_union(__data.group_by, group_names) - - return __data.append_op(sel, group_by = tuple(group_names)) - - -@ungroup.register(LazyTbl) -def _ungroup(__data): - return __data.copy(group_by = tuple()) - - -@case_when.register(sql.base.ImmutableColumnCollection) -def _case_when(__data, cases): - # TODO: will need listener to enter case statements, to handle when they use windows - if isinstance(cases, Call): - cases = cases(__data) - - whens = [] - case_items = list(cases.items()) - n_items = len(case_items) - - else_val = None - for ii, (expr, val) in enumerate(case_items): - # handle where val is a column expr - if callable(val): - val = val(__data) - - # handle when expressions - if ii+1 
== n_items and expr is True: - else_val = val - elif callable(expr): - whens.append((expr(__data), val)) - else: - whens.append((expr, val)) - - return sql.case(whens, else_ = else_val) - - -# Join ------------------------------------------------------------------------ - -from collections.abc import Mapping - -def _joined_cols(left_cols, right_cols, on_keys, how, suffix = ("_x", "_y")): - """Return labeled columns, according to selection rules for joins. - - Rules: - 1. For join keys, keep left table's column - 2. When keys have the same labels, add suffix - """ - - # TODO: remove sets, so uses stable ordering - # when left and right cols have same name, suffix with _x / _y - keep_right = set(right_cols.keys()) - set(on_keys.values()) - shared_labs = set(left_cols.keys()).intersection(keep_right) - - right_cols_no_keys = {k: right_cols[k] for k in keep_right} - - # for an outer join, have key columns coalesce values - - left_cols = {**left_cols} - if how == "full": - for lk, rk in on_keys.items(): - col = sql.functions.coalesce(left_cols[lk], right_cols[rk]) - left_cols[lk] = col.label(lk) - elif how == "right": - for lk, rk in on_keys.items(): - # Make left key columns actually be right ones (which contain left + extra) - left_cols[lk] = right_cols[rk].label(lk) - - - # create labels ---- - l_labs = _relabeled_cols(left_cols, shared_labs, suffix[0]) - r_labs = _relabeled_cols(right_cols_no_keys, shared_labs, suffix[1]) - - return l_labs + r_labs - - - -def _relabeled_cols(columns, keys, suffix): - # add a suffix to all columns with names in keys - cols = [] - for k, v in columns.items(): - new_col = v.label(k + str(suffix)) if k in keys else v - cols.append(new_col) - return cols - - -@join.register(LazyTbl) -def _join(left, right, on = None, *args, by = None, how = "inner", sql_on = None): - _raise_if_args(args) - - if on is None and by is not None: - on = by - - # Needs to be on the table, not the select - left_sel = left.last_op.alias() - right_sel = 
right.last_op.alias() - - # handle arguments ---- - on = _validate_join_arg_on(on, sql_on) - how = _validate_join_arg_how(how) - - # for equality join used to combine keys into single column - consolidate_keys = on if sql_on is None else {} - - if how == "right": - # switch joins, since sqlalchemy doesn't have right join arg - # see https://stackoverflow.com/q/11400307/1144523 - left_sel, right_sel = right_sel, left_sel - on = {v:k for k,v in on.items()} - - # create join conditions ---- - bool_clause = _create_join_conds(left_sel, right_sel, on) - - # create join ---- - join = left_sel.join( - right_sel, - onclause = bool_clause, - isouter = how != "inner", - full = how == "full" - ) - - # if right join, set selects back - if how == "right": - left_sel, right_sel = right_sel, left_sel - on = {v:k for k,v in on.items()} - - # note, shared_keys assumes on is a mapping... - # TODO: shared_keys appears to be for when on is not specified, but was unused - #shared_keys = [k for k,v in on.items() if k == v] - labeled_cols = _joined_cols( - left_sel.columns, - right_sel.columns, - on_keys = consolidate_keys, - how = how - ) - - sel = _sql_select(labeled_cols).select_from(join) - return left.append_op(sel, order_by = tuple()) - - -@semi_join.register(LazyTbl) -def _semi_join(left, right = None, on = None, *args, by = None, sql_on = None): - if on is None and by is not None: - on = by - - _raise_if_args(args) - - left_sel = left.last_op.alias() - right_sel = right.last_op.alias() - - # handle arguments ---- - on = _validate_join_arg_on(on, sql_on, left_sel, right_sel) - - # create join conditions ---- - bool_clause = _create_join_conds(left_sel, right_sel, on) - - # create inner join ---- - exists_clause = _sql_select([sql.literal(1)]) \ - .select_from(right_sel) \ - .where(bool_clause) - - # only keep left hand select's columns ---- - sel = _sql_select(left_sel.columns) \ - .select_from(left_sel) \ - .where(sql.exists(exists_clause)) - - return left.append_op(sel, order_by 
= tuple()) - - -@anti_join.register(LazyTbl) -def _anti_join(left, right = None, on = None, *args, by = None, sql_on = None): - if on is None and by is not None: - on = by - - _raise_if_args(args) - - left_sel = left.last_op.alias() - right_sel = right.last_op.alias() - - # handle arguments ---- - on = _validate_join_arg_on(on, sql_on, left, right) - - # create join conditions ---- - bool_clause = _create_join_conds(left_sel, right_sel, on) - - # create inner join ---- - #not_exists = ~sql.exists([1], from_obj = right_sel).where(bool_clause) - exists_clause = _sql_select([sql.literal(1)]) \ - .select_from(right_sel) \ - .where(bool_clause) - - sel = left_sel.select().where(~sql.exists(exists_clause)) - - return left.append_op(sel, order_by = tuple()) - -def _raise_if_args(args): - if len(args): - raise NotImplemented("*args is reserved for future arguments (e.g. suffix)") - -def _validate_join_arg_on(on, sql_on = None, lhs = None, rhs = None): - # handle sql on case - if sql_on is not None: - if on is not None: - raise ValueError("Cannot specify both on and sql_on") - - return sql_on - - # handle general cases - if on is None: - # TODO: currently, we check for lhs and rhs tables to indicate whether - # a verb supports inferring columns. Otherwise, raise an error. - if lhs is not None and rhs is not None: - # TODO: consolidate with duplicate logic in pandas verb code - warnings.warn( - "No on column passed to join. " - "Inferring join columns instead using shared column names." - ) - - on_cols = list(set(lhs.columns.keys()).intersection(set(rhs.columns.keys()))) - - if not on_cols: - raise ValueError( - "No join column specified, or shared column names in join." 
- ) - - # trivial dict mapping shared names to themselves - warnings.warn("Detected shared columns: %s" % on_cols) - on = dict(zip(on_cols, on_cols)) - - else: - raise NotImplementedError("on arg currently cannot be None (default) for SQL") - elif isinstance(on, str): - on = {on: on} - elif isinstance(on, (list, tuple)): - on = dict(zip(on, on)) - - - if not isinstance(on, Mapping): - raise TypeError("on must be a Mapping (e.g. dict)") - - return on - -def _validate_join_arg_how(how): - how_options = ("inner", "left", "right", "full") - if how not in how_options: - raise ValueError("how argument needs to be one of %s" %how_options) - - return how - -def _create_join_conds(left_sel, right_sel, on): - left_cols = left_sel.columns #lift_inner_cols(left_sel) - right_cols = right_sel.columns #lift_inner_cols(right_sel) - - if callable(on): - # callable, like with sql_on arg - conds = [on(left_cols, right_cols)] - else: - # dict-like of form {left: right} - conds = [] - for l, r in on.items(): - col_expr = left_cols[l] == right_cols[r] - conds.append(col_expr) - - return sql.and_(*conds) - - -# Head ------------------------------------------------------------------------ - -@head.register(LazyTbl) -def _head(__data, n = 5): - sel = __data.last_select - - return __data.append_op(sel.limit(n)) - - -# Rename ---------------------------------------------------------------------- - -@rename.register(LazyTbl) -def _rename(__data, **kwargs): - sel = __data.last_select - columns = lift_inner_cols(sel) - - # old_keys uses dict as ordered set - old_to_new = {simple_varname(v):k for k,v in kwargs.items()} - - if None in old_to_new: - raise KeyError("positional arguments must be simple column, " - "e.g. 
_.colname or _['colname']" - ) - - labs = [c.label(old_to_new[k]) if k in old_to_new else c for k,c in columns.items()] - - new_sel = sel.with_only_columns(labs) - - missing_groups, group_keys = _select_group_renames(old_to_new, __data.group_by) - - return __data.append_op(new_sel, group_by=group_keys) - - -# Distinct -------------------------------------------------------------------- - -@distinct.register(LazyTbl) -def _distinct(__data, *args, _keep_all = False, **kwargs): - if (args or kwargs) and _keep_all: - raise NotImplementedError("Distinct with variables specified in sql requires _keep_all = False") - - inner_sel = mutate(__data, **kwargs).last_select if kwargs else __data.last_select - - # TODO: this is copied from the df distinct version - # cols dict below is used as ordered set - cols = _var_select_simple(args) - cols.update(kwargs) - - # use all columns by default - if not cols: - cols = {k: True for k in lift_inner_cols(inner_sel).keys()} - - final_names = {**{k: True for k in __data.group_by}, **cols} - - if not len(inner_sel._order_by_clause): - # select distinct has to include any columns in the order by clause, - # so can only safely modify existing statement when there's no order by - sel_cols = lift_inner_cols(inner_sel) - distinct_cols = [sel_cols[k] for k in final_names] - sel = inner_sel.with_only_columns(distinct_cols).distinct() - else: - # fallback to cte - cte = inner_sel.alias() - distinct_cols = [cte.columns[k] for k in final_names] - sel = _sql_select(distinct_cols).select_from(cte).distinct() - - return __data.append_op(sel) - - -# if_else --------------------------------------------------------------------- - -@if_else.register(sql.elements.ColumnElement) -def _if_else(cond, true_vals, false_vals): - whens = [(cond, true_vals)] - return sql.case(whens, else_ = false_vals) - - diff --git a/siuba/sql/utils.py b/siuba/sql/utils.py index 68b9c7cd..b022e309 100644 --- a/siuba/sql/utils.py +++ b/siuba/sql/utils.py @@ -198,3 +198,10 @@ def 
simplify_sel(sel): return clone_el + + +def lift_inner_cols(tbl): + cols = list(tbl.inner_columns) + + return _sql_column_collection(cols) + diff --git a/siuba/sql/verbs/__init__.py b/siuba/sql/verbs/__init__.py new file mode 100644 index 00000000..19d9c81a --- /dev/null +++ b/siuba/sql/verbs/__init__.py @@ -0,0 +1,15 @@ +from . import ( + arrange, + compute, + conditional, + count, + distinct, + explain, + filter, + group_by, + head, + join, + mutate, + select, + summarize, +) diff --git a/siuba/sql/verbs/arrange.py b/siuba/sql/verbs/arrange.py index 8dcbef7b..a981c63d 100644 --- a/siuba/sql/verbs/arrange.py +++ b/siuba/sql/verbs/arrange.py @@ -1,801 +1,9 @@ -""" -Implements LazyTbl to represent tables of SQL data, and registers it on verbs. - -This module is responsible for the handling of the "table" side of things, while -translate.py handles translating column operations. - - -""" - -import warnings - -from siuba.dply.verbs import ( - show_query, collect, - simple_varname, - select, - mutate, - transmute, - filter, - arrange, _call_strip_ascending, - summarize, - count, add_count, - group_by, ungroup, - case_when, - join, left_join, right_join, inner_join, semi_join, anti_join, - head, - rename, - distinct, - if_else, - _select_group_renames, - _var_select_simple - ) - -from siuba.dply.tidyselect import VarList, var_select - -from .translate import CustomOverClause, SqlColumn, SqlColumnAgg -from .utils import ( - get_dialect_translator, - _FixedSqlDatabase, - _is_dialect_duckdb, - _sql_select, - _sql_column_collection, - _sql_add_columns, - _sql_with_only_columns, - _sql_simplify_select, - MockConnection -) - -from sqlalchemy import sql -import sqlalchemy -from siuba.siu import Call, Lazy, FunctionLookupError, singledispatch2 -# TODO: currently needed for select, but can we remove pandas? 
-from pandas import Series -from functools import singledispatch - -from sqlalchemy.sql import schema - -from siuba.dply.across import _require_across, _set_data_context, _eval_with_context - -# TODO: -# - distinct -# - annotate functions using sel.prefix_with("\n/**/\n") ? - +from siuba.dply.verbs import arrange +from ..utils import lift_inner_cols +from ..backend import LazyTbl, _create_order_by_clause # Helpers --------------------------------------------------------------------- -class SqlFunctionLookupError(FunctionLookupError): pass - - -class CallListener: - """Generic listener. Each exit is called on a node's copy.""" - def enter(self, node): - args, kwargs = node.map_subcalls(self.enter) - - return self.exit(node.__class__(node.func, *args, **kwargs)) - - def exit(self, node): - return node - - -class WindowReplacer(CallListener): - """Call tree listener. - - Produces 2 important behaviors via the enter method: - - returns evaluated sql call expression, with labels on all window expressions. - - stores all labeled window expressions via the windows property. 
- - TODO: could replace with a sqlalchemy transformer - """ - - def __init__(self, columns, group_by, order_by, window_cte = None): - self.columns = columns - self.group_by = group_by - self.order_by = order_by - self.window_cte = window_cte - self.windows = [] - - def exit(self, node): - col_expr = node(self.columns) - - if not isinstance(col_expr, sql.elements.ClauseElement): - return col_expr - - over_clauses = [x for x in self._get_over_clauses(col_expr) if isinstance(x, CustomOverClause)] - - # put groupings and orderings onto custom over clauses - for over in over_clauses: - # TODO: shouldn't mutate these over clauses - group_by = sql.elements.ClauseList( - *[self.columns[name] for name in self.group_by] - ) - order_by = sql.elements.ClauseList( - *_create_order_by_clause(self.columns, *self.order_by) - ) - - over.set_over(group_by, order_by) - - if len(over_clauses) and self.window_cte is not None: - # custom name, or parameters like "%(...)s" may nest and break psycopg2 - # with columns you can set a key to fix this, but it doesn't seem to - # be an option with labels - name = self._get_unique_name('win', lift_inner_cols(self.window_cte)) - label = col_expr.label(name) - - # put into CTE, and return its resulting column, so that subsequent - # operations will refer to the window column on window_cte. 
Note that - # the operations will use the actual column, so may need to use the - # ClauseAdaptor to make it a reference to the label - self.window_cte = _sql_add_columns(self.window_cte, [label]) - win_col = lift_inner_cols(self.window_cte).values()[-1] - self.windows.append(win_col) - - return win_col - - return col_expr - - @staticmethod - def _get_unique_name(prefix, columns): - column_names = set(columns.keys()) - - i = 1 - name = prefix + str(i) - while name in column_names: - i += 1 - name = prefix + str(i) - - - return name - - @staticmethod - def _get_over_clauses(clause): - windows = [] - append_win = lambda col: windows.append(col) - - sql.util.visitors.traverse(clause, {}, {"over": append_win}) - - return windows - - -class SqlLabelReplacer: - """Create a visitor to replace source labels with destination. - - Note that this is meant to be used with sqlalchemy visitors. - """ - - def __init__(self, src_columns, dst_columns): - self.src_columns = src_columns - self.src_labels = set([x for x in src_columns if isinstance(x, sql.elements.Label)]) - self.dst_columns = dst_columns - self.applied = False - - def __call__(self, clause): - return sql.util.visitors.replacement_traverse(clause, {}, self.visit) - - def visit(self, el): - from sqlalchemy.sql.elements import ColumnClause, Label, ClauseElement, TypeClause - from sqlalchemy.sql.schema import Column - - if isinstance(el, TypeClause): - # TODO: for some reason this type throws an error if unguarded - return None - - if isinstance(el, ClauseElement): - if el in self.src_labels: - self.applied = True - return self.dst_columns[el.name] - elif el in self.src_columns: - return self.dst_columns[el.name] - - # TODO: should we create a subquery if the user passed raw text? 
- #elif isinstance(el, ColumnClause) and not isinstance(el, Column): - # # Raw SQL, which will need a subquery, but not substitution - # if el.key != "*": - # self.applied = True - - return None - - -#def track_call_windows(call, columns, group_by, order_by, window_cte = None): -# listener = WindowReplacer(columns, group_by, order_by, window_cte) -# col = listener.enter(call) -# return col, listener.windows, listener.window_cte - - -def track_call_windows(call, columns, group_by, order_by, window_cte = None): - col_expr = call(columns) - - crnt_group_by = sql.elements.ClauseList( - *[columns[name] for name in group_by] - ) - crnt_order_by = sql.elements.ClauseList( - *_create_order_by_clause(columns, *order_by) - ) - return replace_call_windows(col_expr, crnt_group_by, crnt_order_by, window_cte) - - - -@singledispatch -def replace_call_windows(col_expr, group_by, order_by, window_cte = None): - raise TypeError(str(type(col_expr))) - - -@replace_call_windows.register(sql.base.ImmutableColumnCollection) -def _(col_expr, group_by, order_by, window_cte = None): - all_over_clauses = [] - for col in col_expr: - _, over_clauses, window_cte = replace_call_windows( - col, - group_by, - order_by, - window_cte - ) - all_over_clauses.extend(over_clauses) - - return col_expr, all_over_clauses, window_cte - - -@replace_call_windows.register(sql.elements.ClauseElement) -def _(col_expr, group_by, order_by, window_cte = None): - - over_clauses = WindowReplacer._get_over_clauses(col_expr) - - for over in over_clauses: - # TODO: shouldn't mutate these over clauses - over.set_over(group_by, order_by) - - if len(over_clauses) and window_cte is not None: - # custom name, or parameters like "%(...)s" may nest and break psycopg2 - # with columns you can set a key to fix this, but it doesn't seem to - # be an option with labels - name = WindowReplacer._get_unique_name('win', lift_inner_cols(window_cte)) - label = col_expr.label(name) - - # put into CTE, and return its resulting column, so 
that subsequent - # operations will refer to the window column on window_cte. Note that - # the operations will use the actual column, so may need to use the - # ClauseAdaptor to make it a reference to the label - window_cte = _sql_add_columns(window_cte, [label]) - win_col = lift_inner_cols(window_cte).values()[-1] - - return win_col, over_clauses, window_cte - - return col_expr, over_clauses, window_cte - -def get_single_from(sel): - froms = sel.froms - - n_froms = len(froms) - if n_froms != 1: - raise ValueError( - f"Expected a single table in the from clause, but found {n_froms}" - ) - - return froms[0] - -def lift_inner_cols(tbl): - cols = list(tbl.inner_columns) - - return _sql_column_collection(cols) - -def col_expr_requires_cte(call, sel, is_mutate = False): - """Return whether a variable assignment needs a CTE""" - - call_vars = set(call.op_vars(attr_calls = False)) - - sel_labs = get_inner_labels(sel) - - # I use the acronym fwg sol (frog soul) to remember sql clause eval order - # from, where, group by, select, order by, limit - # group clause evaluated before select clause, so not issue for mutate - group_needs_cte = not is_mutate and len(sel._group_by_clause) - - return ( group_needs_cte - # TODO: detect when a new var in mutate conflicts w/ order by - #or len(sel._order_by_clause) - or not sel_labs.isdisjoint(call_vars) - ) - -def get_inner_labels(sel): - columns = lift_inner_cols(sel) - sel_labs = set(k for k,v in columns.items() if isinstance(v, sql.elements.Label)) - return sel_labs - -def get_missing_columns(call, columns): - missing_cols = set(call.op_vars(attr_calls = False)) - set(columns.keys()) - return missing_cols - -def compile_el(tbl, el): - compiled = el.compile( - dialect = tbl.source.dialect, - compile_kwargs = {"literal_binds": True} - ) - return compiled - -# Misc utilities -------------------------------------------------------------- - -def ordered_union(x, y): - dx = {el: True for el in x} - dy = {el: True for el in y} - - return 
tuple({**dx, **dy}) - - -def _warn_missing(missing_groups): - warnings.warn(f"Adding missing grouping variables: {missing_groups}") - - -# Table ----------------------------------------------------------------------- - -class LazyTbl: - def __init__( - self, source, tbl, columns = None, - ops = None, group_by = tuple(), order_by = tuple(), - translator = None - ): - """Create a representation of a SQL table. - - Args: - source: a sqlalchemy.Engine or sqlalchemy.Connection instance. - tbl: table of form 'schema_name.table_name', 'table_name', or sqlalchemy.Table. - columns: if specified, a listlike of column names. - - Examples - -------- - - :: - from sqlalchemy import create_engine - from siuba.data import mtcars - - # create database and table - engine = create_engine("sqlite:///:memory:") - mtcars.to_sql('mtcars', engine) - - tbl_mtcars = LazyTbl(engine, 'mtcars') - - """ - - # connection and dialect specific functions - self.source = sqlalchemy.create_engine(source) if isinstance(source, str) else source - - # get dialect name - dialect = self.source.dialect.name - self.translator = get_dialect_translator(dialect) - - self.tbl = self._create_table(tbl, columns, self.source) - - # important states the query can be in (e.g. 
grouped) - self.ops = [self.tbl] if ops is None else ops - - self.group_by = group_by - self.order_by = order_by - - - def append_op(self, op, **kwargs): - cpy = self.copy(**kwargs) - cpy.ops = cpy.ops + [op] - return cpy - - def copy(self, **kwargs): - return self.__class__(**{**self.__dict__, **kwargs}) - - def shape_call( - self, - call, window = True, str_accessors = False, - verb_name = None, arg_name = None, - ): - return self.translator.shape_call(call, window, str_accessors, verb_name, arg_name) - - def track_call_windows(self, call, columns = None, window_cte = None): - """Returns tuple of (new column expression, list of window exprs)""" - - columns = self.last_op.columns if columns is None else columns - return track_call_windows(call, columns, self.group_by, self.order_by, window_cte) - - def get_ordered_col_names(self): - """Return columns from current select, with grouping columns first.""" - ungrouped = [k for k in self.last_op.columns.keys() if k not in self.group_by] - return list(self.group_by) + ungrouped - - #def label_breaks_order_by(self, name): - # """Returns True if a new column label would break the order by vars.""" - - # # TODO: arrange currently allows literals, which breaks this. it seems - # # better to only allow calls in arrange. - # order_by_vars = {c.op_vars(attr_calls=False) for c in self.order_by} - - - - - @property - def last_op(self) -> "sql.Table | sql.Select": - last_op = self.ops[-1] - - if last_op is None: - raise TypeError() - - return last_op - - @property - def last_select(self): - last_op = self.last_op - if not isinstance(last_op, sql.selectable.SelectBase): - return last_op.select() - - return last_op - - @staticmethod - def _create_table(tbl, columns = None, source = None): - """Return a sqlalchemy.Table, autoloading column info if needed. - - Arguments: - tbl: a sqlalchemy.Table or string of form 'table_name' or 'schema_name.table_name'. - columns: a tuple of column names for the table. Overrides source argument. 
- source: a sqlalchemy engine, used to autoload columns. - - """ - if isinstance(tbl, sql.selectable.FromClause): - return tbl - - if not isinstance(tbl, str): - raise ValueError("tbl must be a sqlalchemy Table or string, but was %s" %type(tbl)) - - if columns is None and source is None: - raise ValueError("One of columns or source must be specified") - - schema, table_name = tbl.split('.') if '.' in tbl else [None, tbl] - - columns = map(sqlalchemy.Column, columns) if columns is not None else tuple() - - # TODO: pybigquery uses schema to mean project_id, so we cannot use - # siuba's classic breakdown "{schema}.{table_name}". Basically - # pybigquery uses "{schema=project_id}.{dataset_dot_table_name}" in its internal - # logic. An important side effect is that bigquery errors for - # `dataset`.`table`, but not `dataset.table`. - if source and source.dialect.name == "bigquery": - table_name = tbl - schema = None - - return sqlalchemy.Table( - table_name, - sqlalchemy.MetaData(bind = source), - *columns, - schema = schema, - autoload_with = source if not columns else None - ) - - def _get_preview(self): - # need to make prev op a cte, so we don't override any previous limit - new_sel = self.last_select.limit(5) - tbl_small = self.append_op(new_sel) - return collect(tbl_small) - - def __repr__(self): - template = ( - "# Source: lazy query\n" - "# DB Conn: {}\n" - "# Preview:\n{}\n" - "# .. may have more rows" - ) - - return template.format(repr(self.source.engine), repr(self._get_preview())) - - def _repr_html_(self): - template = ( - "
" - "
"
-                "# Source: lazy query\n"
-                "# DB Conn: {}\n"
-                "# Preview:\n"
-                "
" - "{}" - "

# .. may have more rows

" - "
" - ) - - data = self._get_preview() - - # _repr_html_ can not exist or return None, to signify that repr should be used - if not hasattr(data, '_repr_html_'): - return None - - html_data = data._repr_html_() - if html_data is None: - return None - - return template.format(self.source.engine, html_data) - - -def _repr_grouped_df_html_(self): - return "

(grouped data frame)

" + self._selected_obj._repr_html_() + "
" - - - -# Main Funcs -# ============================================================================= - -# sql raw -------------- - -sql_raw = sql.literal_column - -# show query ----------- - -@show_query.register(LazyTbl) -def _show_query(tbl, simplify = False, return_table = True): - #query = tbl.last_op #if not simplify else - compile_query = lambda query: query.compile( - dialect = tbl.source.dialect, - compile_kwargs = {"literal_binds": True} - ) - - - if simplify: - # try to strip table names and labels where unnecessary - simple_sel = _sql_simplify_select(tbl.last_select) - - explained = compile_query(simple_sel) - else: - # use a much more verbose query - explained = compile_query(tbl.last_select) - - if return_table: - print(str(explained)) - return tbl - - return str(explained) - - - -# collect ---------- - -@collect.register(LazyTbl) -def _collect(__data, as_df = True): - # TODO: maybe remove as_df options, always return dataframe - - if isinstance(__data.source, MockConnection): - # a mock sqlalchemy is being used to show_query, and echo queries. - # it doesn't return a result object or have a context handler, so - # we need to bail out early - return - - # compile query ---- - - if _is_dialect_duckdb(__data.source): - # TODO: can be removed once next release of duckdb fixes: - # https://github.com/duckdb/duckdb/issues/2972 - query = __data.last_select - compiled = query.compile( - dialect = __data.source.dialect, - compile_kwargs = {"literal_binds": True} - ) - else: - compiled = __data.last_select - - # execute query ---- - - with __data.source.connect() as conn: - if as_df: - sql_db = _FixedSqlDatabase(conn) - - if _is_dialect_duckdb(__data.source): - # TODO: pandas read_sql is very slow with duckdb. - # see https://github.com/pandas-dev/pandas/issues/45678 - # going to handle here for now. address once LazyTbl gets - # subclassed per backend. 
- duckdb_con = conn.connection.c - return duckdb_con.query(str(compiled)).to_df() - else: - # - return sql_db.read_sql(compiled) - - return conn.execute(compiled) - - -@select.register(LazyTbl) -def _select(__data, *args, **kwargs): - # see https://stackoverflow.com/questions/25914329/rearrange-columns-in-sqlalchemy-select-object - if kwargs: - raise NotImplementedError( - "Using kwargs in select not currently supported. " - "Use _.newname == _.oldname instead" - ) - last_sel = __data.last_select - columns = {c.key: c for c in last_sel.inner_columns} - - # same as for DataFrame - colnames = Series(list(columns)) - vl = VarList() - evaluated = (arg(vl) if callable(arg) else arg for arg in args) - od = var_select(colnames, *evaluated) - - missing_groups, group_keys = _select_group_renames(od, __data.group_by) - - if missing_groups: - _warn_missing(missing_groups) - - final_od = {**{k: None for k in missing_groups}, **od} - - col_list = [] - for k,v in final_od.items(): - col = columns[k] - col_list.append(col if v is None else col.label(v)) - - return __data.append_op( - last_sel.with_only_columns(col_list), - group_by = group_keys - ) - - - -@filter.register(LazyTbl) -def _filter(__data, *args): - # Note: currently always produces 2 additional select statements, - # 1 for window/aggs, and 1 for the where clause - - sel = __data.last_op.alias() # original select - win_sel = sel.select() - - conds = [] - windows = [] - with _set_data_context(__data, window=True): - for ii, arg in enumerate(args): - - if isinstance(arg, Call): - new_call = __data.shape_call(arg, verb_name = "Filter", arg_name = ii) - #var_cols = new_call.op_vars(attr_calls = False) - - # note that a new win_sel is returned, w/ window columns appended - col_expr, win_cols, win_sel = __data.track_call_windows( - new_call, - sel.columns, - window_cte = win_sel - ) - - if isinstance(col_expr, sql.base.ImmutableColumnCollection): - conds.extend(col_expr) - else: - conds.append(col_expr) - - 
windows.extend(win_cols) - - else: - conds.append(arg) - - bool_clause = sql.and_(*conds) - - # first cte, windows ---- - if len(windows): - - win_alias = win_sel.alias() - - # move non-window functions to refer to win_sel clause (not the innermost) --- - bool_clause = sql.util.ClauseAdapter(win_alias) \ - .traverse(bool_clause) - - orig_cols = [win_alias.columns[k] for k in sel.columns.keys()] - else: - orig_cols = [sel] - - # create second cte ---- - filt_sel = _sql_select(orig_cols).where(bool_clause) - return __data.append_op(filt_sel) - - -@mutate.register(LazyTbl) -def _mutate(__data, *args, **kwargs): - # TODO: verify it can follow a renaming select - - # track labeled columns in set - if not (len(args) or len(kwargs)): - return __data.append_op(__data.last_op) - - names, sel_out = _mutate_cols(__data, args, kwargs, "Mutate") - return __data.append_op(sel_out) - - -def _sql_upsert_columns(sel, new_columns: "list[base.Label | base.Column]"): - orig_cols = lift_inner_cols(sel) - replaced = {**orig_cols} - - for new_col in new_columns: - replaced[new_col.name] = new_col - return _sql_with_only_columns(sel, list(replaced.values())) - - -def _select_mutate_result(src_sel, expr_result): - dst_alias = src_sel.alias() - src_columns = set(lift_inner_cols(src_sel)) - replacer = SqlLabelReplacer(set(src_columns), dst_alias.columns) - - if isinstance(expr_result, sql.base.ImmutableColumnCollection): - replaced_cols = list(map(replacer, expr_result)) - orig_cols = expr_result - #elif isinstance(expr_result, None): - # pass - else: - replaced_cols = [replacer(expr_result)] - orig_cols = [expr_result] - - if replacer.applied: - return _sql_upsert_columns(dst_alias.select(), replaced_cols) - - return _sql_upsert_columns(src_sel, orig_cols) - - -def _eval_expr_arg(__data, sel, func, verb_name, window=True): - inner_cols = lift_inner_cols(sel) - - # case 1: simple names ---- - simple_name = simple_varname(func) - if simple_name is not None: - return inner_cols[simple_name] - 
- # case 2: across ---- - _require_across(func, verb_name) - - cols_result = _eval_with_context(__data, window, inner_cols, func) - - # TODO: remove or raise a more informative error - assert isinstance(cols_result, sql.base.ImmutableColumnCollection), type(cols_result) - - return cols_result - - -def _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=True): - inner_cols = lift_inner_cols(sel) - - expr_shaped = __data.shape_call(func, window, verb_name = verb_name, arg_name = new_name) - new_col, windows, _ = __data.track_call_windows(expr_shaped, inner_cols) - - if isinstance(new_col, sql.base.ImmutableColumnCollection): - raise TyepError( - f"{verb_name} named arguments must return a single column, but `{k}` " - "returned multiple columns." - ) - - return new_col.label(new_name) - - -def _mutate_cols(__data, args, kwargs, verb_name): - result_names = {} # used as ordered set - sel = __data.last_select - - for ii, func in enumerate(args): - cols_result = _eval_expr_arg(__data, sel, func, verb_name) - - # replace any labels that require a subquery ---- - sel = _select_mutate_result(sel, cols_result) - - if isinstance(cols_result, sql.base.ImmutableColumnCollection): - result_names.update({k: True for k in cols_result.keys()}) - else: - result_names[cols_result.name] = True - - - for new_name, func in kwargs.items(): - labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name) - - sel = _select_mutate_result(sel, labeled) - result_names[new_name] = True - - - return list(result_names), sel - - -@transmute.register(LazyTbl) -def _transmute(__data, *args, **kwargs): - # will use mutate, then select some cols - result_names, sel = _mutate_cols(__data, args, kwargs, "Transmute") - - # transmute keeps grouping cols, and any defined in kwargs - missing = [x for x in __data.group_by if x not in result_names] - cols_to_keep = [*missing, *result_names] - - columns = lift_inner_cols(sel) - sel_stripped = sel.with_only_columns([columns[k] for k in 
cols_to_keep]) - - return __data.append_op(sel_stripped) - - @arrange.register(LazyTbl) def _arrange(__data, *args): # Note that SQL databases often do not subquery order by clauses. Arrange @@ -827,569 +35,3 @@ def _arrange(__data, *args): order_by = __data.order_by + tuple(new_calls) return __data.append_op(last_sel.order_by(*sort_cols), order_by = order_by) - - -# TODO: consolidate / pull expr handling funcs into own file? -def _create_order_by_clause(columns, *args): - sort_cols = [] - for arg in args: - # simple named column - if isinstance(arg, str): - sort_cols.append(columns[arg]) - # an expression - elif callable(arg): - # handle special case where -_.colname -> colname DESC - f, asc = _call_strip_ascending(arg) - col_op = f(columns) if asc else f(columns).desc() - #col_op = arg(columns) - sort_cols.append(col_op) - else: - raise NotImplementedError("Must be string or callable") - - return sort_cols - - - -@count.register(LazyTbl) -def _count(__data, *args, sort = False, wt = None, **kwargs): - # TODO: if already col named n, use name nn, etc.. get logic from tidy.py - if wt is not None: - raise NotImplementedError("TODO") - - res_name = "n" - # similar to filter verb, we need two select statements, - # an inner one for derived cols, and outer to group by them - - # inner select ---- - # holds any mutation style columns - #arg_names = [] - #for arg in args: - # name = simple_varname(arg) - # if name is None: - # raise NotImplementedError( - # "Count positional arguments must be single column name. " - # "Use a named argument to count using complex expressions." 
- # ) - # arg_names.append(name) - - result_names, sel_inner = _mutate_cols(__data, args, kwargs, "Count") - - # remove unnecessary select, if we're operating on a table ---- - if set(lift_inner_cols(sel_inner)) == set(lift_inner_cols(__data.last_select)): - sel_inner = __data.last_op - - # create outer select ---- - # holds selected columns and tally (n) - sel_inner_cte = sel_inner.alias() - inner_cols = sel_inner_cte.columns - - # apply any group vars from a group_by verb call first - missing = [k for k in __data.group_by if k not in result_names] - - all_group_names = ordered_union(__data.group_by, result_names) - outer_group_cols = [inner_cols[k] for k in all_group_names] - - # holds the actual count (e.g. n) - count_col = sql.functions.count().label(res_name) - - sel_outer = _sql_select([*outer_group_cols, count_col]) \ - .select_from(sel_inner_cte) \ - .group_by(*outer_group_cols) - - # count is like summarize, so removes order_by - return __data.append_op( - sel_outer.order_by(count_col.desc()), - order_by = tuple() - ) - - -@add_count.register(LazyTbl) -def _add_count(__data, *args, wt = None, sort = False, **kwargs): - counts = count(__data, *args, wt = wt, sort = sort, **kwargs) - by = list(c.name for c in counts.last_select.inner_columns)[:-1] - - return inner_join(__data, counts, by = by) - - -@summarize.register(LazyTbl) -def _summarize(__data, *args, **kwargs): - # https://stackoverflow.com/questions/14754994/why-is-sqlalchemy-count-much-slower-than-the-raw-query - - # get query with correct from clause, and maybe unneeded subquery - safe_from = __data.last_select.alias() - result_names, sel = _aggregate_cols(__data, safe_from, args, kwargs, "Summarize") - - # see if we can remove subquery - out_sel = _collapse_select(sel, safe_from) - - from_tbl = get_single_from(out_sel) - group_cols = [from_tbl.columns[k] for k in __data.group_by] - - final_sel = out_sel.group_by(*group_cols) - - new_data = __data.append_op(final_sel, group_by = tuple(), order_by = 
tuple()) - return new_data - - -def _collapse_select(outer_sel, inner_alias): - # check whether any outer columns reference an inner label ---- - inner_sel = inner_alias.element - - columns = lift_inner_cols(outer_sel) - inner_cols = lift_inner_cols(inner_sel) - - inner_labels = set([ - x.name for x in inner_cols - if isinstance(x, sql.elements.Label) - ]) - - col_requires_cte = set(inner_alias.columns[k] for k in inner_labels) - - bad_refs = [] - - def collect_refs(el): - if el in col_requires_cte: - bad_refs.append(el) - - for col in columns: - sql.util.visitors.traverse(col, {}, {"column": collect_refs}) - - # if possible, remove the outer query ---- - if not (bad_refs or len(inner_sel._group_by_clause)): - from sqlalchemy.sql.elements import ColumnClause, Label - - from_obj = get_single_from(inner_sel) - adaptor = sql.util.ClauseAdapter( - from_obj, - adapt_on_names=True, - include_fn=lambda c: isinstance(c, (ColumnClause, Label)) - ) - - new_cols = [] - for col in columns: - if isinstance(col, Label): - res = adaptor.traverse(col.element).label(col.name) - new_cols.append(res) - - else: - new_cols.append(adaptor.traverse(col)) - #new_cols = list(map(adaptor.traverse, columns)) - - return _sql_with_only_columns(inner_sel, new_cols) - - return outer_sel - - -def _aggregate_cols(__data, subquery, args, kwargs, verb_name): - # cases: - # * grouping cols can not be overwritten (in dbplyr they can't be ref'd) - # * no existing labels referred to - can use same select - # * existing labels referred to - need 1 subquery tops - # * groups + summarize columns can replace everything - - def get_label_clauses(clause): - out = [] - sql.util.visitors.traverse(clause, {}, {"label": lambda c: out.append(c)}) - - return out - - def quote_varname(x): - return f"`{x}`" - - def validate_references(arg_name, expr, verb_name): - bad_varnames = get_label_clauses(expr) - repr_names = ", ".join(map(quote_varname, [el.name for el in bad_varnames])) - - if not bad_varnames: - return - - 
raise NotImplementedError( - f"In SQL, you cannot refer to a column created in the same {verb_name}. " - f"`{arg_name}` refers to columns created earlier: {repr_names}." - ) - - sel = subquery.select() - - final_cols = {k: subquery.columns[k] for k in __data.group_by} - - # handle args ---- - for ii, func in enumerate(args): - cols_result = _eval_expr_arg(__data, sel, func, verb_name, window=False) - - for col in cols_result: - validate_references(col.name, col.element, verb_name) - final_cols[col.name] = col - - sel = _sql_upsert_columns(sel, cols_result) - - - # handle kwargs ---- - for new_name, func in kwargs.items(): - labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=False) - - validate_references(labeled.name, labeled.element, verb_name) - final_cols[new_name] = labeled - - sel = _sql_upsert_columns(sel, [labeled]) - - return list(final_cols), _sql_with_only_columns(sel, list(final_cols.values())) - - -@group_by.register(LazyTbl) -def _group_by(__data, *args, add = False, **kwargs): - if not (args or kwargs): - return __data.copy() - - group_names, sel = _mutate_cols(__data, args, kwargs, "Group by") - - if None in group_names: - raise NotImplementedError("Complex, unnamed expressions not supported in sql group_by") - - # check whether we can just use underlying table ---- - new_cols = lift_inner_cols(sel) - if set(new_cols).issubset(set(__data.last_op.columns)): - sel = __data.last_op - - if add: - group_names = ordered_union(__data.group_by, group_names) - - return __data.append_op(sel, group_by = tuple(group_names)) - - -@ungroup.register(LazyTbl) -def _ungroup(__data): - return __data.copy(group_by = tuple()) - - -@case_when.register(sql.base.ImmutableColumnCollection) -def _case_when(__data, cases): - # TODO: will need listener to enter case statements, to handle when they use windows - if isinstance(cases, Call): - cases = cases(__data) - - whens = [] - case_items = list(cases.items()) - n_items = len(case_items) - - else_val = 
None - for ii, (expr, val) in enumerate(case_items): - # handle where val is a column expr - if callable(val): - val = val(__data) - - # handle when expressions - if ii+1 == n_items and expr is True: - else_val = val - elif callable(expr): - whens.append((expr(__data), val)) - else: - whens.append((expr, val)) - - return sql.case(whens, else_ = else_val) - - -# Join ------------------------------------------------------------------------ - -from collections.abc import Mapping - -def _joined_cols(left_cols, right_cols, on_keys, how, suffix = ("_x", "_y")): - """Return labeled columns, according to selection rules for joins. - - Rules: - 1. For join keys, keep left table's column - 2. When keys have the same labels, add suffix - """ - - # TODO: remove sets, so uses stable ordering - # when left and right cols have same name, suffix with _x / _y - keep_right = set(right_cols.keys()) - set(on_keys.values()) - shared_labs = set(left_cols.keys()).intersection(keep_right) - - right_cols_no_keys = {k: right_cols[k] for k in keep_right} - - # for an outer join, have key columns coalesce values - - left_cols = {**left_cols} - if how == "full": - for lk, rk in on_keys.items(): - col = sql.functions.coalesce(left_cols[lk], right_cols[rk]) - left_cols[lk] = col.label(lk) - elif how == "right": - for lk, rk in on_keys.items(): - # Make left key columns actually be right ones (which contain left + extra) - left_cols[lk] = right_cols[rk].label(lk) - - - # create labels ---- - l_labs = _relabeled_cols(left_cols, shared_labs, suffix[0]) - r_labs = _relabeled_cols(right_cols_no_keys, shared_labs, suffix[1]) - - return l_labs + r_labs - - - -def _relabeled_cols(columns, keys, suffix): - # add a suffix to all columns with names in keys - cols = [] - for k, v in columns.items(): - new_col = v.label(k + str(suffix)) if k in keys else v - cols.append(new_col) - return cols - - -@join.register(LazyTbl) -def _join(left, right, on = None, *args, by = None, how = "inner", sql_on = None): - 
_raise_if_args(args) - - if on is None and by is not None: - on = by - - # Needs to be on the table, not the select - left_sel = left.last_op.alias() - right_sel = right.last_op.alias() - - # handle arguments ---- - on = _validate_join_arg_on(on, sql_on) - how = _validate_join_arg_how(how) - - # for equality join used to combine keys into single column - consolidate_keys = on if sql_on is None else {} - - if how == "right": - # switch joins, since sqlalchemy doesn't have right join arg - # see https://stackoverflow.com/q/11400307/1144523 - left_sel, right_sel = right_sel, left_sel - on = {v:k for k,v in on.items()} - - # create join conditions ---- - bool_clause = _create_join_conds(left_sel, right_sel, on) - - # create join ---- - join = left_sel.join( - right_sel, - onclause = bool_clause, - isouter = how != "inner", - full = how == "full" - ) - - # if right join, set selects back - if how == "right": - left_sel, right_sel = right_sel, left_sel - on = {v:k for k,v in on.items()} - - # note, shared_keys assumes on is a mapping... 
- # TODO: shared_keys appears to be for when on is not specified, but was unused - #shared_keys = [k for k,v in on.items() if k == v] - labeled_cols = _joined_cols( - left_sel.columns, - right_sel.columns, - on_keys = consolidate_keys, - how = how - ) - - sel = _sql_select(labeled_cols).select_from(join) - return left.append_op(sel, order_by = tuple()) - - -@semi_join.register(LazyTbl) -def _semi_join(left, right = None, on = None, *args, by = None, sql_on = None): - if on is None and by is not None: - on = by - - _raise_if_args(args) - - left_sel = left.last_op.alias() - right_sel = right.last_op.alias() - - # handle arguments ---- - on = _validate_join_arg_on(on, sql_on, left_sel, right_sel) - - # create join conditions ---- - bool_clause = _create_join_conds(left_sel, right_sel, on) - - # create inner join ---- - exists_clause = _sql_select([sql.literal(1)]) \ - .select_from(right_sel) \ - .where(bool_clause) - - # only keep left hand select's columns ---- - sel = _sql_select(left_sel.columns) \ - .select_from(left_sel) \ - .where(sql.exists(exists_clause)) - - return left.append_op(sel, order_by = tuple()) - - -@anti_join.register(LazyTbl) -def _anti_join(left, right = None, on = None, *args, by = None, sql_on = None): - if on is None and by is not None: - on = by - - _raise_if_args(args) - - left_sel = left.last_op.alias() - right_sel = right.last_op.alias() - - # handle arguments ---- - on = _validate_join_arg_on(on, sql_on, left, right) - - # create join conditions ---- - bool_clause = _create_join_conds(left_sel, right_sel, on) - - # create inner join ---- - #not_exists = ~sql.exists([1], from_obj = right_sel).where(bool_clause) - exists_clause = _sql_select([sql.literal(1)]) \ - .select_from(right_sel) \ - .where(bool_clause) - - sel = left_sel.select().where(~sql.exists(exists_clause)) - - return left.append_op(sel, order_by = tuple()) - -def _raise_if_args(args): - if len(args): - raise NotImplemented("*args is reserved for future arguments (e.g. 
suffix)") - -def _validate_join_arg_on(on, sql_on = None, lhs = None, rhs = None): - # handle sql on case - if sql_on is not None: - if on is not None: - raise ValueError("Cannot specify both on and sql_on") - - return sql_on - - # handle general cases - if on is None: - # TODO: currently, we check for lhs and rhs tables to indicate whether - # a verb supports inferring columns. Otherwise, raise an error. - if lhs is not None and rhs is not None: - # TODO: consolidate with duplicate logic in pandas verb code - warnings.warn( - "No on column passed to join. " - "Inferring join columns instead using shared column names." - ) - - on_cols = list(set(lhs.columns.keys()).intersection(set(rhs.columns.keys()))) - - if not on_cols: - raise ValueError( - "No join column specified, or shared column names in join." - ) - - # trivial dict mapping shared names to themselves - warnings.warn("Detected shared columns: %s" % on_cols) - on = dict(zip(on_cols, on_cols)) - - else: - raise NotImplementedError("on arg currently cannot be None (default) for SQL") - elif isinstance(on, str): - on = {on: on} - elif isinstance(on, (list, tuple)): - on = dict(zip(on, on)) - - - if not isinstance(on, Mapping): - raise TypeError("on must be a Mapping (e.g. 
dict)") - - return on - -def _validate_join_arg_how(how): - how_options = ("inner", "left", "right", "full") - if how not in how_options: - raise ValueError("how argument needs to be one of %s" %how_options) - - return how - -def _create_join_conds(left_sel, right_sel, on): - left_cols = left_sel.columns #lift_inner_cols(left_sel) - right_cols = right_sel.columns #lift_inner_cols(right_sel) - - if callable(on): - # callable, like with sql_on arg - conds = [on(left_cols, right_cols)] - else: - # dict-like of form {left: right} - conds = [] - for l, r in on.items(): - col_expr = left_cols[l] == right_cols[r] - conds.append(col_expr) - - return sql.and_(*conds) - - -# Head ------------------------------------------------------------------------ - -@head.register(LazyTbl) -def _head(__data, n = 5): - sel = __data.last_select - - return __data.append_op(sel.limit(n)) - - -# Rename ---------------------------------------------------------------------- - -@rename.register(LazyTbl) -def _rename(__data, **kwargs): - sel = __data.last_select - columns = lift_inner_cols(sel) - - # old_keys uses dict as ordered set - old_to_new = {simple_varname(v):k for k,v in kwargs.items()} - - if None in old_to_new: - raise KeyError("positional arguments must be simple column, " - "e.g. 
_.colname or _['colname']" - ) - - labs = [c.label(old_to_new[k]) if k in old_to_new else c for k,c in columns.items()] - - new_sel = sel.with_only_columns(labs) - - missing_groups, group_keys = _select_group_renames(old_to_new, __data.group_by) - - return __data.append_op(new_sel, group_by=group_keys) - - -# Distinct -------------------------------------------------------------------- - -@distinct.register(LazyTbl) -def _distinct(__data, *args, _keep_all = False, **kwargs): - if (args or kwargs) and _keep_all: - raise NotImplementedError("Distinct with variables specified in sql requires _keep_all = False") - - inner_sel = mutate(__data, **kwargs).last_select if kwargs else __data.last_select - - # TODO: this is copied from the df distinct version - # cols dict below is used as ordered set - cols = _var_select_simple(args) - cols.update(kwargs) - - # use all columns by default - if not cols: - cols = {k: True for k in lift_inner_cols(inner_sel).keys()} - - final_names = {**{k: True for k in __data.group_by}, **cols} - - if not len(inner_sel._order_by_clause): - # select distinct has to include any columns in the order by clause, - # so can only safely modify existing statement when there's no order by - sel_cols = lift_inner_cols(inner_sel) - distinct_cols = [sel_cols[k] for k in final_names] - sel = inner_sel.with_only_columns(distinct_cols).distinct() - else: - # fallback to cte - cte = inner_sel.alias() - distinct_cols = [cte.columns[k] for k in final_names] - sel = _sql_select(distinct_cols).select_from(cte).distinct() - - return __data.append_op(sel) - - -# if_else --------------------------------------------------------------------- - -@if_else.register(sql.elements.ColumnElement) -def _if_else(cond, true_vals, false_vals): - whens = [(cond, true_vals)] - return sql.case(whens, else_ = false_vals) - - diff --git a/siuba/sql/verbs/compute.py b/siuba/sql/verbs/compute.py index 8dcbef7b..f01aaff8 100644 --- a/siuba/sql/verbs/compute.py +++ 
b/siuba/sql/verbs/compute.py @@ -1,549 +1,7 @@ -""" -Implements LazyTbl to represent tables of SQL data, and registers it on verbs. - -This module is responsible for the handling of the "table" side of things, while -translate.py handles translating column operations. - - -""" - -import warnings - -from siuba.dply.verbs import ( - show_query, collect, - simple_varname, - select, - mutate, - transmute, - filter, - arrange, _call_strip_ascending, - summarize, - count, add_count, - group_by, ungroup, - case_when, - join, left_join, right_join, inner_join, semi_join, anti_join, - head, - rename, - distinct, - if_else, - _select_group_renames, - _var_select_simple - ) - -from siuba.dply.tidyselect import VarList, var_select - -from .translate import CustomOverClause, SqlColumn, SqlColumnAgg -from .utils import ( - get_dialect_translator, - _FixedSqlDatabase, - _is_dialect_duckdb, - _sql_select, - _sql_column_collection, - _sql_add_columns, - _sql_with_only_columns, - _sql_simplify_select, - MockConnection -) - -from sqlalchemy import sql -import sqlalchemy -from siuba.siu import Call, Lazy, FunctionLookupError, singledispatch2 -# TODO: currently needed for select, but can we remove pandas? -from pandas import Series -from functools import singledispatch - -from sqlalchemy.sql import schema - -from siuba.dply.across import _require_across, _set_data_context, _eval_with_context - -# TODO: -# - distinct -# - annotate functions using sel.prefix_with("\n/**/\n") ? - - -# Helpers --------------------------------------------------------------------- - -class SqlFunctionLookupError(FunctionLookupError): pass - - -class CallListener: - """Generic listener. Each exit is called on a node's copy.""" - def enter(self, node): - args, kwargs = node.map_subcalls(self.enter) - - return self.exit(node.__class__(node.func, *args, **kwargs)) - - def exit(self, node): - return node - - -class WindowReplacer(CallListener): - """Call tree listener. 
- - Produces 2 important behaviors via the enter method: - - returns evaluated sql call expression, with labels on all window expressions. - - stores all labeled window expressions via the windows property. - - TODO: could replace with a sqlalchemy transformer - """ - - def __init__(self, columns, group_by, order_by, window_cte = None): - self.columns = columns - self.group_by = group_by - self.order_by = order_by - self.window_cte = window_cte - self.windows = [] - - def exit(self, node): - col_expr = node(self.columns) - - if not isinstance(col_expr, sql.elements.ClauseElement): - return col_expr - - over_clauses = [x for x in self._get_over_clauses(col_expr) if isinstance(x, CustomOverClause)] - - # put groupings and orderings onto custom over clauses - for over in over_clauses: - # TODO: shouldn't mutate these over clauses - group_by = sql.elements.ClauseList( - *[self.columns[name] for name in self.group_by] - ) - order_by = sql.elements.ClauseList( - *_create_order_by_clause(self.columns, *self.order_by) - ) - - over.set_over(group_by, order_by) - - if len(over_clauses) and self.window_cte is not None: - # custom name, or parameters like "%(...)s" may nest and break psycopg2 - # with columns you can set a key to fix this, but it doesn't seem to - # be an option with labels - name = self._get_unique_name('win', lift_inner_cols(self.window_cte)) - label = col_expr.label(name) - - # put into CTE, and return its resulting column, so that subsequent - # operations will refer to the window column on window_cte. 
Note that - # the operations will use the actual column, so may need to use the - # ClauseAdaptor to make it a reference to the label - self.window_cte = _sql_add_columns(self.window_cte, [label]) - win_col = lift_inner_cols(self.window_cte).values()[-1] - self.windows.append(win_col) - - return win_col - - return col_expr - - @staticmethod - def _get_unique_name(prefix, columns): - column_names = set(columns.keys()) - - i = 1 - name = prefix + str(i) - while name in column_names: - i += 1 - name = prefix + str(i) - - - return name - - @staticmethod - def _get_over_clauses(clause): - windows = [] - append_win = lambda col: windows.append(col) - - sql.util.visitors.traverse(clause, {}, {"over": append_win}) - - return windows - - -class SqlLabelReplacer: - """Create a visitor to replace source labels with destination. - - Note that this is meant to be used with sqlalchemy visitors. - """ - - def __init__(self, src_columns, dst_columns): - self.src_columns = src_columns - self.src_labels = set([x for x in src_columns if isinstance(x, sql.elements.Label)]) - self.dst_columns = dst_columns - self.applied = False - - def __call__(self, clause): - return sql.util.visitors.replacement_traverse(clause, {}, self.visit) - - def visit(self, el): - from sqlalchemy.sql.elements import ColumnClause, Label, ClauseElement, TypeClause - from sqlalchemy.sql.schema import Column - - if isinstance(el, TypeClause): - # TODO: for some reason this type throws an error if unguarded - return None - - if isinstance(el, ClauseElement): - if el in self.src_labels: - self.applied = True - return self.dst_columns[el.name] - elif el in self.src_columns: - return self.dst_columns[el.name] - - # TODO: should we create a subquery if the user passed raw text? 
- #elif isinstance(el, ColumnClause) and not isinstance(el, Column): - # # Raw SQL, which will need a subquery, but not substitution - # if el.key != "*": - # self.applied = True - - return None - - -#def track_call_windows(call, columns, group_by, order_by, window_cte = None): -# listener = WindowReplacer(columns, group_by, order_by, window_cte) -# col = listener.enter(call) -# return col, listener.windows, listener.window_cte - - -def track_call_windows(call, columns, group_by, order_by, window_cte = None): - col_expr = call(columns) - - crnt_group_by = sql.elements.ClauseList( - *[columns[name] for name in group_by] - ) - crnt_order_by = sql.elements.ClauseList( - *_create_order_by_clause(columns, *order_by) - ) - return replace_call_windows(col_expr, crnt_group_by, crnt_order_by, window_cte) - - - -@singledispatch -def replace_call_windows(col_expr, group_by, order_by, window_cte = None): - raise TypeError(str(type(col_expr))) - - -@replace_call_windows.register(sql.base.ImmutableColumnCollection) -def _(col_expr, group_by, order_by, window_cte = None): - all_over_clauses = [] - for col in col_expr: - _, over_clauses, window_cte = replace_call_windows( - col, - group_by, - order_by, - window_cte - ) - all_over_clauses.extend(over_clauses) - - return col_expr, all_over_clauses, window_cte - - -@replace_call_windows.register(sql.elements.ClauseElement) -def _(col_expr, group_by, order_by, window_cte = None): - - over_clauses = WindowReplacer._get_over_clauses(col_expr) - - for over in over_clauses: - # TODO: shouldn't mutate these over clauses - over.set_over(group_by, order_by) - - if len(over_clauses) and window_cte is not None: - # custom name, or parameters like "%(...)s" may nest and break psycopg2 - # with columns you can set a key to fix this, but it doesn't seem to - # be an option with labels - name = WindowReplacer._get_unique_name('win', lift_inner_cols(window_cte)) - label = col_expr.label(name) - - # put into CTE, and return its resulting column, so 
that subsequent - # operations will refer to the window column on window_cte. Note that - # the operations will use the actual column, so may need to use the - # ClauseAdaptor to make it a reference to the label - window_cte = _sql_add_columns(window_cte, [label]) - win_col = lift_inner_cols(window_cte).values()[-1] - - return win_col, over_clauses, window_cte - - return col_expr, over_clauses, window_cte - -def get_single_from(sel): - froms = sel.froms - - n_froms = len(froms) - if n_froms != 1: - raise ValueError( - f"Expected a single table in the from clause, but found {n_froms}" - ) - - return froms[0] - -def lift_inner_cols(tbl): - cols = list(tbl.inner_columns) - - return _sql_column_collection(cols) - -def col_expr_requires_cte(call, sel, is_mutate = False): - """Return whether a variable assignment needs a CTE""" - - call_vars = set(call.op_vars(attr_calls = False)) - - sel_labs = get_inner_labels(sel) - - # I use the acronym fwg sol (frog soul) to remember sql clause eval order - # from, where, group by, select, order by, limit - # group clause evaluated before select clause, so not issue for mutate - group_needs_cte = not is_mutate and len(sel._group_by_clause) - - return ( group_needs_cte - # TODO: detect when a new var in mutate conflicts w/ order by - #or len(sel._order_by_clause) - or not sel_labs.isdisjoint(call_vars) - ) - -def get_inner_labels(sel): - columns = lift_inner_cols(sel) - sel_labs = set(k for k,v in columns.items() if isinstance(v, sql.elements.Label)) - return sel_labs - -def get_missing_columns(call, columns): - missing_cols = set(call.op_vars(attr_calls = False)) - set(columns.keys()) - return missing_cols - -def compile_el(tbl, el): - compiled = el.compile( - dialect = tbl.source.dialect, - compile_kwargs = {"literal_binds": True} - ) - return compiled - -# Misc utilities -------------------------------------------------------------- - -def ordered_union(x, y): - dx = {el: True for el in x} - dy = {el: True for el in y} - - return 
tuple({**dx, **dy}) - - -def _warn_missing(missing_groups): - warnings.warn(f"Adding missing grouping variables: {missing_groups}") - - -# Table ----------------------------------------------------------------------- - -class LazyTbl: - def __init__( - self, source, tbl, columns = None, - ops = None, group_by = tuple(), order_by = tuple(), - translator = None - ): - """Create a representation of a SQL table. - - Args: - source: a sqlalchemy.Engine or sqlalchemy.Connection instance. - tbl: table of form 'schema_name.table_name', 'table_name', or sqlalchemy.Table. - columns: if specified, a listlike of column names. - - Examples - -------- - - :: - from sqlalchemy import create_engine - from siuba.data import mtcars - - # create database and table - engine = create_engine("sqlite:///:memory:") - mtcars.to_sql('mtcars', engine) - - tbl_mtcars = LazyTbl(engine, 'mtcars') - - """ - - # connection and dialect specific functions - self.source = sqlalchemy.create_engine(source) if isinstance(source, str) else source - - # get dialect name - dialect = self.source.dialect.name - self.translator = get_dialect_translator(dialect) - - self.tbl = self._create_table(tbl, columns, self.source) - - # important states the query can be in (e.g. 
grouped) - self.ops = [self.tbl] if ops is None else ops - - self.group_by = group_by - self.order_by = order_by - - - def append_op(self, op, **kwargs): - cpy = self.copy(**kwargs) - cpy.ops = cpy.ops + [op] - return cpy - - def copy(self, **kwargs): - return self.__class__(**{**self.__dict__, **kwargs}) - - def shape_call( - self, - call, window = True, str_accessors = False, - verb_name = None, arg_name = None, - ): - return self.translator.shape_call(call, window, str_accessors, verb_name, arg_name) - - def track_call_windows(self, call, columns = None, window_cte = None): - """Returns tuple of (new column expression, list of window exprs)""" - - columns = self.last_op.columns if columns is None else columns - return track_call_windows(call, columns, self.group_by, self.order_by, window_cte) - - def get_ordered_col_names(self): - """Return columns from current select, with grouping columns first.""" - ungrouped = [k for k in self.last_op.columns.keys() if k not in self.group_by] - return list(self.group_by) + ungrouped - - #def label_breaks_order_by(self, name): - # """Returns True if a new column label would break the order by vars.""" - - # # TODO: arrange currently allows literals, which breaks this. it seems - # # better to only allow calls in arrange. - # order_by_vars = {c.op_vars(attr_calls=False) for c in self.order_by} - - - - - @property - def last_op(self) -> "sql.Table | sql.Select": - last_op = self.ops[-1] - - if last_op is None: - raise TypeError() - - return last_op - - @property - def last_select(self): - last_op = self.last_op - if not isinstance(last_op, sql.selectable.SelectBase): - return last_op.select() - - return last_op - - @staticmethod - def _create_table(tbl, columns = None, source = None): - """Return a sqlalchemy.Table, autoloading column info if needed. - - Arguments: - tbl: a sqlalchemy.Table or string of form 'table_name' or 'schema_name.table_name'. - columns: a tuple of column names for the table. Overrides source argument. 
- source: a sqlalchemy engine, used to autoload columns. - - """ - if isinstance(tbl, sql.selectable.FromClause): - return tbl - - if not isinstance(tbl, str): - raise ValueError("tbl must be a sqlalchemy Table or string, but was %s" %type(tbl)) - - if columns is None and source is None: - raise ValueError("One of columns or source must be specified") - - schema, table_name = tbl.split('.') if '.' in tbl else [None, tbl] - - columns = map(sqlalchemy.Column, columns) if columns is not None else tuple() - - # TODO: pybigquery uses schema to mean project_id, so we cannot use - # siuba's classic breakdown "{schema}.{table_name}". Basically - # pybigquery uses "{schema=project_id}.{dataset_dot_table_name}" in its internal - # logic. An important side effect is that bigquery errors for - # `dataset`.`table`, but not `dataset.table`. - if source and source.dialect.name == "bigquery": - table_name = tbl - schema = None - - return sqlalchemy.Table( - table_name, - sqlalchemy.MetaData(bind = source), - *columns, - schema = schema, - autoload_with = source if not columns else None - ) - - def _get_preview(self): - # need to make prev op a cte, so we don't override any previous limit - new_sel = self.last_select.limit(5) - tbl_small = self.append_op(new_sel) - return collect(tbl_small) - - def __repr__(self): - template = ( - "# Source: lazy query\n" - "# DB Conn: {}\n" - "# Preview:\n{}\n" - "# .. may have more rows" - ) - - return template.format(repr(self.source.engine), repr(self._get_preview())) - - def _repr_html_(self): - template = ( - "
" - "
"
-                "# Source: lazy query\n"
-                "# DB Conn: {}\n"
-                "# Preview:\n"
-                "
" - "{}" - "

# .. may have more rows

" - "
" - ) - - data = self._get_preview() - - # _repr_html_ can not exist or return None, to signify that repr should be used - if not hasattr(data, '_repr_html_'): - return None - - html_data = data._repr_html_() - if html_data is None: - return None - - return template.format(self.source.engine, html_data) - - -def _repr_grouped_df_html_(self): - return "

(grouped data frame)

" + self._selected_obj._repr_html_() + "
" - - - -# Main Funcs -# ============================================================================= - -# sql raw -------------- - -sql_raw = sql.literal_column - -# show query ----------- - -@show_query.register(LazyTbl) -def _show_query(tbl, simplify = False, return_table = True): - #query = tbl.last_op #if not simplify else - compile_query = lambda query: query.compile( - dialect = tbl.source.dialect, - compile_kwargs = {"literal_binds": True} - ) - - - if simplify: - # try to strip table names and labels where unnecessary - simple_sel = _sql_simplify_select(tbl.last_select) - - explained = compile_query(simple_sel) - else: - # use a much more verbose query - explained = compile_query(tbl.last_select) - - if return_table: - print(str(explained)) - return tbl - - return str(explained) - +from siuba.dply.verbs import collect +from ..backend import LazyTbl +from ..utils import _FixedSqlDatabase, _is_dialect_duckdb, MockConnection # collect ---------- @@ -588,808 +46,3 @@ def _collect(__data, as_df = True): return sql_db.read_sql(compiled) return conn.execute(compiled) - - -@select.register(LazyTbl) -def _select(__data, *args, **kwargs): - # see https://stackoverflow.com/questions/25914329/rearrange-columns-in-sqlalchemy-select-object - if kwargs: - raise NotImplementedError( - "Using kwargs in select not currently supported. 
" - "Use _.newname == _.oldname instead" - ) - last_sel = __data.last_select - columns = {c.key: c for c in last_sel.inner_columns} - - # same as for DataFrame - colnames = Series(list(columns)) - vl = VarList() - evaluated = (arg(vl) if callable(arg) else arg for arg in args) - od = var_select(colnames, *evaluated) - - missing_groups, group_keys = _select_group_renames(od, __data.group_by) - - if missing_groups: - _warn_missing(missing_groups) - - final_od = {**{k: None for k in missing_groups}, **od} - - col_list = [] - for k,v in final_od.items(): - col = columns[k] - col_list.append(col if v is None else col.label(v)) - - return __data.append_op( - last_sel.with_only_columns(col_list), - group_by = group_keys - ) - - - -@filter.register(LazyTbl) -def _filter(__data, *args): - # Note: currently always produces 2 additional select statements, - # 1 for window/aggs, and 1 for the where clause - - sel = __data.last_op.alias() # original select - win_sel = sel.select() - - conds = [] - windows = [] - with _set_data_context(__data, window=True): - for ii, arg in enumerate(args): - - if isinstance(arg, Call): - new_call = __data.shape_call(arg, verb_name = "Filter", arg_name = ii) - #var_cols = new_call.op_vars(attr_calls = False) - - # note that a new win_sel is returned, w/ window columns appended - col_expr, win_cols, win_sel = __data.track_call_windows( - new_call, - sel.columns, - window_cte = win_sel - ) - - if isinstance(col_expr, sql.base.ImmutableColumnCollection): - conds.extend(col_expr) - else: - conds.append(col_expr) - - windows.extend(win_cols) - - else: - conds.append(arg) - - bool_clause = sql.and_(*conds) - - # first cte, windows ---- - if len(windows): - - win_alias = win_sel.alias() - - # move non-window functions to refer to win_sel clause (not the innermost) --- - bool_clause = sql.util.ClauseAdapter(win_alias) \ - .traverse(bool_clause) - - orig_cols = [win_alias.columns[k] for k in sel.columns.keys()] - else: - orig_cols = [sel] - - # create 
second cte ---- - filt_sel = _sql_select(orig_cols).where(bool_clause) - return __data.append_op(filt_sel) - - -@mutate.register(LazyTbl) -def _mutate(__data, *args, **kwargs): - # TODO: verify it can follow a renaming select - - # track labeled columns in set - if not (len(args) or len(kwargs)): - return __data.append_op(__data.last_op) - - names, sel_out = _mutate_cols(__data, args, kwargs, "Mutate") - return __data.append_op(sel_out) - - -def _sql_upsert_columns(sel, new_columns: "list[base.Label | base.Column]"): - orig_cols = lift_inner_cols(sel) - replaced = {**orig_cols} - - for new_col in new_columns: - replaced[new_col.name] = new_col - return _sql_with_only_columns(sel, list(replaced.values())) - - -def _select_mutate_result(src_sel, expr_result): - dst_alias = src_sel.alias() - src_columns = set(lift_inner_cols(src_sel)) - replacer = SqlLabelReplacer(set(src_columns), dst_alias.columns) - - if isinstance(expr_result, sql.base.ImmutableColumnCollection): - replaced_cols = list(map(replacer, expr_result)) - orig_cols = expr_result - #elif isinstance(expr_result, None): - # pass - else: - replaced_cols = [replacer(expr_result)] - orig_cols = [expr_result] - - if replacer.applied: - return _sql_upsert_columns(dst_alias.select(), replaced_cols) - - return _sql_upsert_columns(src_sel, orig_cols) - - -def _eval_expr_arg(__data, sel, func, verb_name, window=True): - inner_cols = lift_inner_cols(sel) - - # case 1: simple names ---- - simple_name = simple_varname(func) - if simple_name is not None: - return inner_cols[simple_name] - - # case 2: across ---- - _require_across(func, verb_name) - - cols_result = _eval_with_context(__data, window, inner_cols, func) - - # TODO: remove or raise a more informative error - assert isinstance(cols_result, sql.base.ImmutableColumnCollection), type(cols_result) - - return cols_result - - -def _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=True): - inner_cols = lift_inner_cols(sel) - - expr_shaped = 
__data.shape_call(func, window, verb_name = verb_name, arg_name = new_name) - new_col, windows, _ = __data.track_call_windows(expr_shaped, inner_cols) - - if isinstance(new_col, sql.base.ImmutableColumnCollection): - raise TyepError( - f"{verb_name} named arguments must return a single column, but `{k}` " - "returned multiple columns." - ) - - return new_col.label(new_name) - - -def _mutate_cols(__data, args, kwargs, verb_name): - result_names = {} # used as ordered set - sel = __data.last_select - - for ii, func in enumerate(args): - cols_result = _eval_expr_arg(__data, sel, func, verb_name) - - # replace any labels that require a subquery ---- - sel = _select_mutate_result(sel, cols_result) - - if isinstance(cols_result, sql.base.ImmutableColumnCollection): - result_names.update({k: True for k in cols_result.keys()}) - else: - result_names[cols_result.name] = True - - - for new_name, func in kwargs.items(): - labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name) - - sel = _select_mutate_result(sel, labeled) - result_names[new_name] = True - - - return list(result_names), sel - - -@transmute.register(LazyTbl) -def _transmute(__data, *args, **kwargs): - # will use mutate, then select some cols - result_names, sel = _mutate_cols(__data, args, kwargs, "Transmute") - - # transmute keeps grouping cols, and any defined in kwargs - missing = [x for x in __data.group_by if x not in result_names] - cols_to_keep = [*missing, *result_names] - - columns = lift_inner_cols(sel) - sel_stripped = sel.with_only_columns([columns[k] for k in cols_to_keep]) - - return __data.append_op(sel_stripped) - - -@arrange.register(LazyTbl) -def _arrange(__data, *args): - # Note that SQL databases often do not subquery order by clauses. Arrange - # sets order_by on the backend, so it can set order by in over elements, - # and handle when new columns are named the same as order by vars. 
- # see: https://dba.stackexchange.com/q/82930 - - last_sel = __data.last_select - cols = lift_inner_cols(last_sel) - - # TODO: implement across in arrange - #exprs, _ = _mutate_cols(__data, args, kwargs, "Arrange", arrange_clause=True) - - new_calls = [] - for ii, expr in enumerate(args): - if callable(expr): - - res = __data.shape_call( - expr, window = False, - verb_name = "Arrange", arg_name = ii - ) - - else: - res = expr - - new_calls.append(res) - - sort_cols = _create_order_by_clause(cols, *new_calls) - - order_by = __data.order_by + tuple(new_calls) - return __data.append_op(last_sel.order_by(*sort_cols), order_by = order_by) - - -# TODO: consolidate / pull expr handling funcs into own file? -def _create_order_by_clause(columns, *args): - sort_cols = [] - for arg in args: - # simple named column - if isinstance(arg, str): - sort_cols.append(columns[arg]) - # an expression - elif callable(arg): - # handle special case where -_.colname -> colname DESC - f, asc = _call_strip_ascending(arg) - col_op = f(columns) if asc else f(columns).desc() - #col_op = arg(columns) - sort_cols.append(col_op) - else: - raise NotImplementedError("Must be string or callable") - - return sort_cols - - - -@count.register(LazyTbl) -def _count(__data, *args, sort = False, wt = None, **kwargs): - # TODO: if already col named n, use name nn, etc.. get logic from tidy.py - if wt is not None: - raise NotImplementedError("TODO") - - res_name = "n" - # similar to filter verb, we need two select statements, - # an inner one for derived cols, and outer to group by them - - # inner select ---- - # holds any mutation style columns - #arg_names = [] - #for arg in args: - # name = simple_varname(arg) - # if name is None: - # raise NotImplementedError( - # "Count positional arguments must be single column name. " - # "Use a named argument to count using complex expressions." 
- # ) - # arg_names.append(name) - - result_names, sel_inner = _mutate_cols(__data, args, kwargs, "Count") - - # remove unnecessary select, if we're operating on a table ---- - if set(lift_inner_cols(sel_inner)) == set(lift_inner_cols(__data.last_select)): - sel_inner = __data.last_op - - # create outer select ---- - # holds selected columns and tally (n) - sel_inner_cte = sel_inner.alias() - inner_cols = sel_inner_cte.columns - - # apply any group vars from a group_by verb call first - missing = [k for k in __data.group_by if k not in result_names] - - all_group_names = ordered_union(__data.group_by, result_names) - outer_group_cols = [inner_cols[k] for k in all_group_names] - - # holds the actual count (e.g. n) - count_col = sql.functions.count().label(res_name) - - sel_outer = _sql_select([*outer_group_cols, count_col]) \ - .select_from(sel_inner_cte) \ - .group_by(*outer_group_cols) - - # count is like summarize, so removes order_by - return __data.append_op( - sel_outer.order_by(count_col.desc()), - order_by = tuple() - ) - - -@add_count.register(LazyTbl) -def _add_count(__data, *args, wt = None, sort = False, **kwargs): - counts = count(__data, *args, wt = wt, sort = sort, **kwargs) - by = list(c.name for c in counts.last_select.inner_columns)[:-1] - - return inner_join(__data, counts, by = by) - - -@summarize.register(LazyTbl) -def _summarize(__data, *args, **kwargs): - # https://stackoverflow.com/questions/14754994/why-is-sqlalchemy-count-much-slower-than-the-raw-query - - # get query with correct from clause, and maybe unneeded subquery - safe_from = __data.last_select.alias() - result_names, sel = _aggregate_cols(__data, safe_from, args, kwargs, "Summarize") - - # see if we can remove subquery - out_sel = _collapse_select(sel, safe_from) - - from_tbl = get_single_from(out_sel) - group_cols = [from_tbl.columns[k] for k in __data.group_by] - - final_sel = out_sel.group_by(*group_cols) - - new_data = __data.append_op(final_sel, group_by = tuple(), order_by = 
tuple()) - return new_data - - -def _collapse_select(outer_sel, inner_alias): - # check whether any outer columns reference an inner label ---- - inner_sel = inner_alias.element - - columns = lift_inner_cols(outer_sel) - inner_cols = lift_inner_cols(inner_sel) - - inner_labels = set([ - x.name for x in inner_cols - if isinstance(x, sql.elements.Label) - ]) - - col_requires_cte = set(inner_alias.columns[k] for k in inner_labels) - - bad_refs = [] - - def collect_refs(el): - if el in col_requires_cte: - bad_refs.append(el) - - for col in columns: - sql.util.visitors.traverse(col, {}, {"column": collect_refs}) - - # if possible, remove the outer query ---- - if not (bad_refs or len(inner_sel._group_by_clause)): - from sqlalchemy.sql.elements import ColumnClause, Label - - from_obj = get_single_from(inner_sel) - adaptor = sql.util.ClauseAdapter( - from_obj, - adapt_on_names=True, - include_fn=lambda c: isinstance(c, (ColumnClause, Label)) - ) - - new_cols = [] - for col in columns: - if isinstance(col, Label): - res = adaptor.traverse(col.element).label(col.name) - new_cols.append(res) - - else: - new_cols.append(adaptor.traverse(col)) - #new_cols = list(map(adaptor.traverse, columns)) - - return _sql_with_only_columns(inner_sel, new_cols) - - return outer_sel - - -def _aggregate_cols(__data, subquery, args, kwargs, verb_name): - # cases: - # * grouping cols can not be overwritten (in dbplyr they can't be ref'd) - # * no existing labels referred to - can use same select - # * existing labels referred to - need 1 subquery tops - # * groups + summarize columns can replace everything - - def get_label_clauses(clause): - out = [] - sql.util.visitors.traverse(clause, {}, {"label": lambda c: out.append(c)}) - - return out - - def quote_varname(x): - return f"`{x}`" - - def validate_references(arg_name, expr, verb_name): - bad_varnames = get_label_clauses(expr) - repr_names = ", ".join(map(quote_varname, [el.name for el in bad_varnames])) - - if not bad_varnames: - return - - 
raise NotImplementedError( - f"In SQL, you cannot refer to a column created in the same {verb_name}. " - f"`{arg_name}` refers to columns created earlier: {repr_names}." - ) - - sel = subquery.select() - - final_cols = {k: subquery.columns[k] for k in __data.group_by} - - # handle args ---- - for ii, func in enumerate(args): - cols_result = _eval_expr_arg(__data, sel, func, verb_name, window=False) - - for col in cols_result: - validate_references(col.name, col.element, verb_name) - final_cols[col.name] = col - - sel = _sql_upsert_columns(sel, cols_result) - - - # handle kwargs ---- - for new_name, func in kwargs.items(): - labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=False) - - validate_references(labeled.name, labeled.element, verb_name) - final_cols[new_name] = labeled - - sel = _sql_upsert_columns(sel, [labeled]) - - return list(final_cols), _sql_with_only_columns(sel, list(final_cols.values())) - - -@group_by.register(LazyTbl) -def _group_by(__data, *args, add = False, **kwargs): - if not (args or kwargs): - return __data.copy() - - group_names, sel = _mutate_cols(__data, args, kwargs, "Group by") - - if None in group_names: - raise NotImplementedError("Complex, unnamed expressions not supported in sql group_by") - - # check whether we can just use underlying table ---- - new_cols = lift_inner_cols(sel) - if set(new_cols).issubset(set(__data.last_op.columns)): - sel = __data.last_op - - if add: - group_names = ordered_union(__data.group_by, group_names) - - return __data.append_op(sel, group_by = tuple(group_names)) - - -@ungroup.register(LazyTbl) -def _ungroup(__data): - return __data.copy(group_by = tuple()) - - -@case_when.register(sql.base.ImmutableColumnCollection) -def _case_when(__data, cases): - # TODO: will need listener to enter case statements, to handle when they use windows - if isinstance(cases, Call): - cases = cases(__data) - - whens = [] - case_items = list(cases.items()) - n_items = len(case_items) - - else_val = 
None - for ii, (expr, val) in enumerate(case_items): - # handle where val is a column expr - if callable(val): - val = val(__data) - - # handle when expressions - if ii+1 == n_items and expr is True: - else_val = val - elif callable(expr): - whens.append((expr(__data), val)) - else: - whens.append((expr, val)) - - return sql.case(whens, else_ = else_val) - - -# Join ------------------------------------------------------------------------ - -from collections.abc import Mapping - -def _joined_cols(left_cols, right_cols, on_keys, how, suffix = ("_x", "_y")): - """Return labeled columns, according to selection rules for joins. - - Rules: - 1. For join keys, keep left table's column - 2. When keys have the same labels, add suffix - """ - - # TODO: remove sets, so uses stable ordering - # when left and right cols have same name, suffix with _x / _y - keep_right = set(right_cols.keys()) - set(on_keys.values()) - shared_labs = set(left_cols.keys()).intersection(keep_right) - - right_cols_no_keys = {k: right_cols[k] for k in keep_right} - - # for an outer join, have key columns coalesce values - - left_cols = {**left_cols} - if how == "full": - for lk, rk in on_keys.items(): - col = sql.functions.coalesce(left_cols[lk], right_cols[rk]) - left_cols[lk] = col.label(lk) - elif how == "right": - for lk, rk in on_keys.items(): - # Make left key columns actually be right ones (which contain left + extra) - left_cols[lk] = right_cols[rk].label(lk) - - - # create labels ---- - l_labs = _relabeled_cols(left_cols, shared_labs, suffix[0]) - r_labs = _relabeled_cols(right_cols_no_keys, shared_labs, suffix[1]) - - return l_labs + r_labs - - - -def _relabeled_cols(columns, keys, suffix): - # add a suffix to all columns with names in keys - cols = [] - for k, v in columns.items(): - new_col = v.label(k + str(suffix)) if k in keys else v - cols.append(new_col) - return cols - - -@join.register(LazyTbl) -def _join(left, right, on = None, *args, by = None, how = "inner", sql_on = None): - 
_raise_if_args(args) - - if on is None and by is not None: - on = by - - # Needs to be on the table, not the select - left_sel = left.last_op.alias() - right_sel = right.last_op.alias() - - # handle arguments ---- - on = _validate_join_arg_on(on, sql_on) - how = _validate_join_arg_how(how) - - # for equality join used to combine keys into single column - consolidate_keys = on if sql_on is None else {} - - if how == "right": - # switch joins, since sqlalchemy doesn't have right join arg - # see https://stackoverflow.com/q/11400307/1144523 - left_sel, right_sel = right_sel, left_sel - on = {v:k for k,v in on.items()} - - # create join conditions ---- - bool_clause = _create_join_conds(left_sel, right_sel, on) - - # create join ---- - join = left_sel.join( - right_sel, - onclause = bool_clause, - isouter = how != "inner", - full = how == "full" - ) - - # if right join, set selects back - if how == "right": - left_sel, right_sel = right_sel, left_sel - on = {v:k for k,v in on.items()} - - # note, shared_keys assumes on is a mapping... 
- # TODO: shared_keys appears to be for when on is not specified, but was unused - #shared_keys = [k for k,v in on.items() if k == v] - labeled_cols = _joined_cols( - left_sel.columns, - right_sel.columns, - on_keys = consolidate_keys, - how = how - ) - - sel = _sql_select(labeled_cols).select_from(join) - return left.append_op(sel, order_by = tuple()) - - -@semi_join.register(LazyTbl) -def _semi_join(left, right = None, on = None, *args, by = None, sql_on = None): - if on is None and by is not None: - on = by - - _raise_if_args(args) - - left_sel = left.last_op.alias() - right_sel = right.last_op.alias() - - # handle arguments ---- - on = _validate_join_arg_on(on, sql_on, left_sel, right_sel) - - # create join conditions ---- - bool_clause = _create_join_conds(left_sel, right_sel, on) - - # create inner join ---- - exists_clause = _sql_select([sql.literal(1)]) \ - .select_from(right_sel) \ - .where(bool_clause) - - # only keep left hand select's columns ---- - sel = _sql_select(left_sel.columns) \ - .select_from(left_sel) \ - .where(sql.exists(exists_clause)) - - return left.append_op(sel, order_by = tuple()) - - -@anti_join.register(LazyTbl) -def _anti_join(left, right = None, on = None, *args, by = None, sql_on = None): - if on is None and by is not None: - on = by - - _raise_if_args(args) - - left_sel = left.last_op.alias() - right_sel = right.last_op.alias() - - # handle arguments ---- - on = _validate_join_arg_on(on, sql_on, left, right) - - # create join conditions ---- - bool_clause = _create_join_conds(left_sel, right_sel, on) - - # create inner join ---- - #not_exists = ~sql.exists([1], from_obj = right_sel).where(bool_clause) - exists_clause = _sql_select([sql.literal(1)]) \ - .select_from(right_sel) \ - .where(bool_clause) - - sel = left_sel.select().where(~sql.exists(exists_clause)) - - return left.append_op(sel, order_by = tuple()) - -def _raise_if_args(args): - if len(args): - raise NotImplemented("*args is reserved for future arguments (e.g. 
suffix)") - -def _validate_join_arg_on(on, sql_on = None, lhs = None, rhs = None): - # handle sql on case - if sql_on is not None: - if on is not None: - raise ValueError("Cannot specify both on and sql_on") - - return sql_on - - # handle general cases - if on is None: - # TODO: currently, we check for lhs and rhs tables to indicate whether - # a verb supports inferring columns. Otherwise, raise an error. - if lhs is not None and rhs is not None: - # TODO: consolidate with duplicate logic in pandas verb code - warnings.warn( - "No on column passed to join. " - "Inferring join columns instead using shared column names." - ) - - on_cols = list(set(lhs.columns.keys()).intersection(set(rhs.columns.keys()))) - - if not on_cols: - raise ValueError( - "No join column specified, or shared column names in join." - ) - - # trivial dict mapping shared names to themselves - warnings.warn("Detected shared columns: %s" % on_cols) - on = dict(zip(on_cols, on_cols)) - - else: - raise NotImplementedError("on arg currently cannot be None (default) for SQL") - elif isinstance(on, str): - on = {on: on} - elif isinstance(on, (list, tuple)): - on = dict(zip(on, on)) - - - if not isinstance(on, Mapping): - raise TypeError("on must be a Mapping (e.g. 
dict)") - - return on - -def _validate_join_arg_how(how): - how_options = ("inner", "left", "right", "full") - if how not in how_options: - raise ValueError("how argument needs to be one of %s" %how_options) - - return how - -def _create_join_conds(left_sel, right_sel, on): - left_cols = left_sel.columns #lift_inner_cols(left_sel) - right_cols = right_sel.columns #lift_inner_cols(right_sel) - - if callable(on): - # callable, like with sql_on arg - conds = [on(left_cols, right_cols)] - else: - # dict-like of form {left: right} - conds = [] - for l, r in on.items(): - col_expr = left_cols[l] == right_cols[r] - conds.append(col_expr) - - return sql.and_(*conds) - - -# Head ------------------------------------------------------------------------ - -@head.register(LazyTbl) -def _head(__data, n = 5): - sel = __data.last_select - - return __data.append_op(sel.limit(n)) - - -# Rename ---------------------------------------------------------------------- - -@rename.register(LazyTbl) -def _rename(__data, **kwargs): - sel = __data.last_select - columns = lift_inner_cols(sel) - - # old_keys uses dict as ordered set - old_to_new = {simple_varname(v):k for k,v in kwargs.items()} - - if None in old_to_new: - raise KeyError("positional arguments must be simple column, " - "e.g. 
_.colname or _['colname']" - ) - - labs = [c.label(old_to_new[k]) if k in old_to_new else c for k,c in columns.items()] - - new_sel = sel.with_only_columns(labs) - - missing_groups, group_keys = _select_group_renames(old_to_new, __data.group_by) - - return __data.append_op(new_sel, group_by=group_keys) - - -# Distinct -------------------------------------------------------------------- - -@distinct.register(LazyTbl) -def _distinct(__data, *args, _keep_all = False, **kwargs): - if (args or kwargs) and _keep_all: - raise NotImplementedError("Distinct with variables specified in sql requires _keep_all = False") - - inner_sel = mutate(__data, **kwargs).last_select if kwargs else __data.last_select - - # TODO: this is copied from the df distinct version - # cols dict below is used as ordered set - cols = _var_select_simple(args) - cols.update(kwargs) - - # use all columns by default - if not cols: - cols = {k: True for k in lift_inner_cols(inner_sel).keys()} - - final_names = {**{k: True for k in __data.group_by}, **cols} - - if not len(inner_sel._order_by_clause): - # select distinct has to include any columns in the order by clause, - # so can only safely modify existing statement when there's no order by - sel_cols = lift_inner_cols(inner_sel) - distinct_cols = [sel_cols[k] for k in final_names] - sel = inner_sel.with_only_columns(distinct_cols).distinct() - else: - # fallback to cte - cte = inner_sel.alias() - distinct_cols = [cte.columns[k] for k in final_names] - sel = _sql_select(distinct_cols).select_from(cte).distinct() - - return __data.append_op(sel) - - -# if_else --------------------------------------------------------------------- - -@if_else.register(sql.elements.ColumnElement) -def _if_else(cond, true_vals, false_vals): - whens = [(cond, true_vals)] - return sql.case(whens, else_ = false_vals) - - diff --git a/siuba/sql/verbs/conditional.py b/siuba/sql/verbs/conditional.py index 8dcbef7b..de307c50 100644 --- a/siuba/sql/verbs/conditional.py +++ 
b/siuba/sql/verbs/conditional.py @@ -1,1065 +1,9 @@ -""" -Implements LazyTbl to represent tables of SQL data, and registers it on verbs. - -This module is responsible for the handling of the "table" side of things, while -translate.py handles translating column operations. - - -""" - import warnings -from siuba.dply.verbs import ( - show_query, collect, - simple_varname, - select, - mutate, - transmute, - filter, - arrange, _call_strip_ascending, - summarize, - count, add_count, - group_by, ungroup, - case_when, - join, left_join, right_join, inner_join, semi_join, anti_join, - head, - rename, - distinct, - if_else, - _select_group_renames, - _var_select_simple - ) - -from siuba.dply.tidyselect import VarList, var_select - -from .translate import CustomOverClause, SqlColumn, SqlColumnAgg -from .utils import ( - get_dialect_translator, - _FixedSqlDatabase, - _is_dialect_duckdb, - _sql_select, - _sql_column_collection, - _sql_add_columns, - _sql_with_only_columns, - _sql_simplify_select, - MockConnection -) - from sqlalchemy import sql -import sqlalchemy -from siuba.siu import Call, Lazy, FunctionLookupError, singledispatch2 -# TODO: currently needed for select, but can we remove pandas? -from pandas import Series -from functools import singledispatch - -from sqlalchemy.sql import schema - -from siuba.dply.across import _require_across, _set_data_context, _eval_with_context - -# TODO: -# - distinct -# - annotate functions using sel.prefix_with("\n/**/\n") ? - - -# Helpers --------------------------------------------------------------------- - -class SqlFunctionLookupError(FunctionLookupError): pass - - -class CallListener: - """Generic listener. Each exit is called on a node's copy.""" - def enter(self, node): - args, kwargs = node.map_subcalls(self.enter) - - return self.exit(node.__class__(node.func, *args, **kwargs)) - - def exit(self, node): - return node - - -class WindowReplacer(CallListener): - """Call tree listener. 
- - Produces 2 important behaviors via the enter method: - - returns evaluated sql call expression, with labels on all window expressions. - - stores all labeled window expressions via the windows property. - - TODO: could replace with a sqlalchemy transformer - """ - - def __init__(self, columns, group_by, order_by, window_cte = None): - self.columns = columns - self.group_by = group_by - self.order_by = order_by - self.window_cte = window_cte - self.windows = [] - - def exit(self, node): - col_expr = node(self.columns) - - if not isinstance(col_expr, sql.elements.ClauseElement): - return col_expr - - over_clauses = [x for x in self._get_over_clauses(col_expr) if isinstance(x, CustomOverClause)] - - # put groupings and orderings onto custom over clauses - for over in over_clauses: - # TODO: shouldn't mutate these over clauses - group_by = sql.elements.ClauseList( - *[self.columns[name] for name in self.group_by] - ) - order_by = sql.elements.ClauseList( - *_create_order_by_clause(self.columns, *self.order_by) - ) - - over.set_over(group_by, order_by) - - if len(over_clauses) and self.window_cte is not None: - # custom name, or parameters like "%(...)s" may nest and break psycopg2 - # with columns you can set a key to fix this, but it doesn't seem to - # be an option with labels - name = self._get_unique_name('win', lift_inner_cols(self.window_cte)) - label = col_expr.label(name) - - # put into CTE, and return its resulting column, so that subsequent - # operations will refer to the window column on window_cte. 
Note that - # the operations will use the actual column, so may need to use the - # ClauseAdaptor to make it a reference to the label - self.window_cte = _sql_add_columns(self.window_cte, [label]) - win_col = lift_inner_cols(self.window_cte).values()[-1] - self.windows.append(win_col) - - return win_col - - return col_expr - - @staticmethod - def _get_unique_name(prefix, columns): - column_names = set(columns.keys()) - - i = 1 - name = prefix + str(i) - while name in column_names: - i += 1 - name = prefix + str(i) - - - return name - - @staticmethod - def _get_over_clauses(clause): - windows = [] - append_win = lambda col: windows.append(col) - - sql.util.visitors.traverse(clause, {}, {"over": append_win}) - - return windows - - -class SqlLabelReplacer: - """Create a visitor to replace source labels with destination. - - Note that this is meant to be used with sqlalchemy visitors. - """ - - def __init__(self, src_columns, dst_columns): - self.src_columns = src_columns - self.src_labels = set([x for x in src_columns if isinstance(x, sql.elements.Label)]) - self.dst_columns = dst_columns - self.applied = False - - def __call__(self, clause): - return sql.util.visitors.replacement_traverse(clause, {}, self.visit) - - def visit(self, el): - from sqlalchemy.sql.elements import ColumnClause, Label, ClauseElement, TypeClause - from sqlalchemy.sql.schema import Column - - if isinstance(el, TypeClause): - # TODO: for some reason this type throws an error if unguarded - return None - - if isinstance(el, ClauseElement): - if el in self.src_labels: - self.applied = True - return self.dst_columns[el.name] - elif el in self.src_columns: - return self.dst_columns[el.name] - - # TODO: should we create a subquery if the user passed raw text? 
- #elif isinstance(el, ColumnClause) and not isinstance(el, Column): - # # Raw SQL, which will need a subquery, but not substitution - # if el.key != "*": - # self.applied = True - - return None - - -#def track_call_windows(call, columns, group_by, order_by, window_cte = None): -# listener = WindowReplacer(columns, group_by, order_by, window_cte) -# col = listener.enter(call) -# return col, listener.windows, listener.window_cte - - -def track_call_windows(call, columns, group_by, order_by, window_cte = None): - col_expr = call(columns) - - crnt_group_by = sql.elements.ClauseList( - *[columns[name] for name in group_by] - ) - crnt_order_by = sql.elements.ClauseList( - *_create_order_by_clause(columns, *order_by) - ) - return replace_call_windows(col_expr, crnt_group_by, crnt_order_by, window_cte) - - - -@singledispatch -def replace_call_windows(col_expr, group_by, order_by, window_cte = None): - raise TypeError(str(type(col_expr))) - - -@replace_call_windows.register(sql.base.ImmutableColumnCollection) -def _(col_expr, group_by, order_by, window_cte = None): - all_over_clauses = [] - for col in col_expr: - _, over_clauses, window_cte = replace_call_windows( - col, - group_by, - order_by, - window_cte - ) - all_over_clauses.extend(over_clauses) - - return col_expr, all_over_clauses, window_cte - - -@replace_call_windows.register(sql.elements.ClauseElement) -def _(col_expr, group_by, order_by, window_cte = None): - - over_clauses = WindowReplacer._get_over_clauses(col_expr) - - for over in over_clauses: - # TODO: shouldn't mutate these over clauses - over.set_over(group_by, order_by) - - if len(over_clauses) and window_cte is not None: - # custom name, or parameters like "%(...)s" may nest and break psycopg2 - # with columns you can set a key to fix this, but it doesn't seem to - # be an option with labels - name = WindowReplacer._get_unique_name('win', lift_inner_cols(window_cte)) - label = col_expr.label(name) - - # put into CTE, and return its resulting column, so 
that subsequent - # operations will refer to the window column on window_cte. Note that - # the operations will use the actual column, so may need to use the - # ClauseAdaptor to make it a reference to the label - window_cte = _sql_add_columns(window_cte, [label]) - win_col = lift_inner_cols(window_cte).values()[-1] - - return win_col, over_clauses, window_cte - - return col_expr, over_clauses, window_cte - -def get_single_from(sel): - froms = sel.froms - - n_froms = len(froms) - if n_froms != 1: - raise ValueError( - f"Expected a single table in the from clause, but found {n_froms}" - ) - - return froms[0] - -def lift_inner_cols(tbl): - cols = list(tbl.inner_columns) - - return _sql_column_collection(cols) - -def col_expr_requires_cte(call, sel, is_mutate = False): - """Return whether a variable assignment needs a CTE""" - - call_vars = set(call.op_vars(attr_calls = False)) - - sel_labs = get_inner_labels(sel) - - # I use the acronym fwg sol (frog soul) to remember sql clause eval order - # from, where, group by, select, order by, limit - # group clause evaluated before select clause, so not issue for mutate - group_needs_cte = not is_mutate and len(sel._group_by_clause) - - return ( group_needs_cte - # TODO: detect when a new var in mutate conflicts w/ order by - #or len(sel._order_by_clause) - or not sel_labs.isdisjoint(call_vars) - ) - -def get_inner_labels(sel): - columns = lift_inner_cols(sel) - sel_labs = set(k for k,v in columns.items() if isinstance(v, sql.elements.Label)) - return sel_labs - -def get_missing_columns(call, columns): - missing_cols = set(call.op_vars(attr_calls = False)) - set(columns.keys()) - return missing_cols - -def compile_el(tbl, el): - compiled = el.compile( - dialect = tbl.source.dialect, - compile_kwargs = {"literal_binds": True} - ) - return compiled - -# Misc utilities -------------------------------------------------------------- - -def ordered_union(x, y): - dx = {el: True for el in x} - dy = {el: True for el in y} - - return 
tuple({**dx, **dy}) - - -def _warn_missing(missing_groups): - warnings.warn(f"Adding missing grouping variables: {missing_groups}") - - -# Table ----------------------------------------------------------------------- - -class LazyTbl: - def __init__( - self, source, tbl, columns = None, - ops = None, group_by = tuple(), order_by = tuple(), - translator = None - ): - """Create a representation of a SQL table. - - Args: - source: a sqlalchemy.Engine or sqlalchemy.Connection instance. - tbl: table of form 'schema_name.table_name', 'table_name', or sqlalchemy.Table. - columns: if specified, a listlike of column names. - - Examples - -------- - - :: - from sqlalchemy import create_engine - from siuba.data import mtcars - - # create database and table - engine = create_engine("sqlite:///:memory:") - mtcars.to_sql('mtcars', engine) - - tbl_mtcars = LazyTbl(engine, 'mtcars') - - """ - - # connection and dialect specific functions - self.source = sqlalchemy.create_engine(source) if isinstance(source, str) else source - - # get dialect name - dialect = self.source.dialect.name - self.translator = get_dialect_translator(dialect) - - self.tbl = self._create_table(tbl, columns, self.source) - - # important states the query can be in (e.g. 
grouped) - self.ops = [self.tbl] if ops is None else ops - - self.group_by = group_by - self.order_by = order_by - - - def append_op(self, op, **kwargs): - cpy = self.copy(**kwargs) - cpy.ops = cpy.ops + [op] - return cpy - - def copy(self, **kwargs): - return self.__class__(**{**self.__dict__, **kwargs}) - - def shape_call( - self, - call, window = True, str_accessors = False, - verb_name = None, arg_name = None, - ): - return self.translator.shape_call(call, window, str_accessors, verb_name, arg_name) - - def track_call_windows(self, call, columns = None, window_cte = None): - """Returns tuple of (new column expression, list of window exprs)""" - - columns = self.last_op.columns if columns is None else columns - return track_call_windows(call, columns, self.group_by, self.order_by, window_cte) - - def get_ordered_col_names(self): - """Return columns from current select, with grouping columns first.""" - ungrouped = [k for k in self.last_op.columns.keys() if k not in self.group_by] - return list(self.group_by) + ungrouped - - #def label_breaks_order_by(self, name): - # """Returns True if a new column label would break the order by vars.""" - - # # TODO: arrange currently allows literals, which breaks this. it seems - # # better to only allow calls in arrange. - # order_by_vars = {c.op_vars(attr_calls=False) for c in self.order_by} - - - - - @property - def last_op(self) -> "sql.Table | sql.Select": - last_op = self.ops[-1] - - if last_op is None: - raise TypeError() - - return last_op - - @property - def last_select(self): - last_op = self.last_op - if not isinstance(last_op, sql.selectable.SelectBase): - return last_op.select() - - return last_op - - @staticmethod - def _create_table(tbl, columns = None, source = None): - """Return a sqlalchemy.Table, autoloading column info if needed. - - Arguments: - tbl: a sqlalchemy.Table or string of form 'table_name' or 'schema_name.table_name'. - columns: a tuple of column names for the table. Overrides source argument. 
- source: a sqlalchemy engine, used to autoload columns. - - """ - if isinstance(tbl, sql.selectable.FromClause): - return tbl - - if not isinstance(tbl, str): - raise ValueError("tbl must be a sqlalchemy Table or string, but was %s" %type(tbl)) - - if columns is None and source is None: - raise ValueError("One of columns or source must be specified") - - schema, table_name = tbl.split('.') if '.' in tbl else [None, tbl] - - columns = map(sqlalchemy.Column, columns) if columns is not None else tuple() - - # TODO: pybigquery uses schema to mean project_id, so we cannot use - # siuba's classic breakdown "{schema}.{table_name}". Basically - # pybigquery uses "{schema=project_id}.{dataset_dot_table_name}" in its internal - # logic. An important side effect is that bigquery errors for - # `dataset`.`table`, but not `dataset.table`. - if source and source.dialect.name == "bigquery": - table_name = tbl - schema = None - - return sqlalchemy.Table( - table_name, - sqlalchemy.MetaData(bind = source), - *columns, - schema = schema, - autoload_with = source if not columns else None - ) - - def _get_preview(self): - # need to make prev op a cte, so we don't override any previous limit - new_sel = self.last_select.limit(5) - tbl_small = self.append_op(new_sel) - return collect(tbl_small) - - def __repr__(self): - template = ( - "# Source: lazy query\n" - "# DB Conn: {}\n" - "# Preview:\n{}\n" - "# .. may have more rows" - ) - - return template.format(repr(self.source.engine), repr(self._get_preview())) - - def _repr_html_(self): - template = ( - "
" - "
"
-                "# Source: lazy query\n"
-                "# DB Conn: {}\n"
-                "# Preview:\n"
-                "
" - "{}" - "

# .. may have more rows

" - "
" - ) - - data = self._get_preview() - - # _repr_html_ can not exist or return None, to signify that repr should be used - if not hasattr(data, '_repr_html_'): - return None - - html_data = data._repr_html_() - if html_data is None: - return None - - return template.format(self.source.engine, html_data) - - -def _repr_grouped_df_html_(self): - return "

(grouped data frame)

" + self._selected_obj._repr_html_() + "
" - - - -# Main Funcs -# ============================================================================= - -# sql raw -------------- - -sql_raw = sql.literal_column - -# show query ----------- - -@show_query.register(LazyTbl) -def _show_query(tbl, simplify = False, return_table = True): - #query = tbl.last_op #if not simplify else - compile_query = lambda query: query.compile( - dialect = tbl.source.dialect, - compile_kwargs = {"literal_binds": True} - ) - - - if simplify: - # try to strip table names and labels where unnecessary - simple_sel = _sql_simplify_select(tbl.last_select) - - explained = compile_query(simple_sel) - else: - # use a much more verbose query - explained = compile_query(tbl.last_select) - - if return_table: - print(str(explained)) - return tbl - - return str(explained) - - - -# collect ---------- - -@collect.register(LazyTbl) -def _collect(__data, as_df = True): - # TODO: maybe remove as_df options, always return dataframe - - if isinstance(__data.source, MockConnection): - # a mock sqlalchemy is being used to show_query, and echo queries. - # it doesn't return a result object or have a context handler, so - # we need to bail out early - return - - # compile query ---- - - if _is_dialect_duckdb(__data.source): - # TODO: can be removed once next release of duckdb fixes: - # https://github.com/duckdb/duckdb/issues/2972 - query = __data.last_select - compiled = query.compile( - dialect = __data.source.dialect, - compile_kwargs = {"literal_binds": True} - ) - else: - compiled = __data.last_select - - # execute query ---- - - with __data.source.connect() as conn: - if as_df: - sql_db = _FixedSqlDatabase(conn) - - if _is_dialect_duckdb(__data.source): - # TODO: pandas read_sql is very slow with duckdb. - # see https://github.com/pandas-dev/pandas/issues/45678 - # going to handle here for now. address once LazyTbl gets - # subclassed per backend. 
- duckdb_con = conn.connection.c - return duckdb_con.query(str(compiled)).to_df() - else: - # - return sql_db.read_sql(compiled) - - return conn.execute(compiled) - - -@select.register(LazyTbl) -def _select(__data, *args, **kwargs): - # see https://stackoverflow.com/questions/25914329/rearrange-columns-in-sqlalchemy-select-object - if kwargs: - raise NotImplementedError( - "Using kwargs in select not currently supported. " - "Use _.newname == _.oldname instead" - ) - last_sel = __data.last_select - columns = {c.key: c for c in last_sel.inner_columns} - - # same as for DataFrame - colnames = Series(list(columns)) - vl = VarList() - evaluated = (arg(vl) if callable(arg) else arg for arg in args) - od = var_select(colnames, *evaluated) - - missing_groups, group_keys = _select_group_renames(od, __data.group_by) - - if missing_groups: - _warn_missing(missing_groups) - - final_od = {**{k: None for k in missing_groups}, **od} - - col_list = [] - for k,v in final_od.items(): - col = columns[k] - col_list.append(col if v is None else col.label(v)) - - return __data.append_op( - last_sel.with_only_columns(col_list), - group_by = group_keys - ) - - - -@filter.register(LazyTbl) -def _filter(__data, *args): - # Note: currently always produces 2 additional select statements, - # 1 for window/aggs, and 1 for the where clause - - sel = __data.last_op.alias() # original select - win_sel = sel.select() - - conds = [] - windows = [] - with _set_data_context(__data, window=True): - for ii, arg in enumerate(args): - - if isinstance(arg, Call): - new_call = __data.shape_call(arg, verb_name = "Filter", arg_name = ii) - #var_cols = new_call.op_vars(attr_calls = False) - - # note that a new win_sel is returned, w/ window columns appended - col_expr, win_cols, win_sel = __data.track_call_windows( - new_call, - sel.columns, - window_cte = win_sel - ) - - if isinstance(col_expr, sql.base.ImmutableColumnCollection): - conds.extend(col_expr) - else: - conds.append(col_expr) - - 
windows.extend(win_cols) - - else: - conds.append(arg) - - bool_clause = sql.and_(*conds) - - # first cte, windows ---- - if len(windows): - - win_alias = win_sel.alias() - - # move non-window functions to refer to win_sel clause (not the innermost) --- - bool_clause = sql.util.ClauseAdapter(win_alias) \ - .traverse(bool_clause) - - orig_cols = [win_alias.columns[k] for k in sel.columns.keys()] - else: - orig_cols = [sel] - - # create second cte ---- - filt_sel = _sql_select(orig_cols).where(bool_clause) - return __data.append_op(filt_sel) - - -@mutate.register(LazyTbl) -def _mutate(__data, *args, **kwargs): - # TODO: verify it can follow a renaming select - - # track labeled columns in set - if not (len(args) or len(kwargs)): - return __data.append_op(__data.last_op) - - names, sel_out = _mutate_cols(__data, args, kwargs, "Mutate") - return __data.append_op(sel_out) - - -def _sql_upsert_columns(sel, new_columns: "list[base.Label | base.Column]"): - orig_cols = lift_inner_cols(sel) - replaced = {**orig_cols} - - for new_col in new_columns: - replaced[new_col.name] = new_col - return _sql_with_only_columns(sel, list(replaced.values())) - - -def _select_mutate_result(src_sel, expr_result): - dst_alias = src_sel.alias() - src_columns = set(lift_inner_cols(src_sel)) - replacer = SqlLabelReplacer(set(src_columns), dst_alias.columns) - - if isinstance(expr_result, sql.base.ImmutableColumnCollection): - replaced_cols = list(map(replacer, expr_result)) - orig_cols = expr_result - #elif isinstance(expr_result, None): - # pass - else: - replaced_cols = [replacer(expr_result)] - orig_cols = [expr_result] - - if replacer.applied: - return _sql_upsert_columns(dst_alias.select(), replaced_cols) - - return _sql_upsert_columns(src_sel, orig_cols) - - -def _eval_expr_arg(__data, sel, func, verb_name, window=True): - inner_cols = lift_inner_cols(sel) - - # case 1: simple names ---- - simple_name = simple_varname(func) - if simple_name is not None: - return inner_cols[simple_name] - 
- # case 2: across ---- - _require_across(func, verb_name) - - cols_result = _eval_with_context(__data, window, inner_cols, func) - - # TODO: remove or raise a more informative error - assert isinstance(cols_result, sql.base.ImmutableColumnCollection), type(cols_result) - - return cols_result - - -def _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=True): - inner_cols = lift_inner_cols(sel) - - expr_shaped = __data.shape_call(func, window, verb_name = verb_name, arg_name = new_name) - new_col, windows, _ = __data.track_call_windows(expr_shaped, inner_cols) - - if isinstance(new_col, sql.base.ImmutableColumnCollection): - raise TyepError( - f"{verb_name} named arguments must return a single column, but `{k}` " - "returned multiple columns." - ) - - return new_col.label(new_name) - - -def _mutate_cols(__data, args, kwargs, verb_name): - result_names = {} # used as ordered set - sel = __data.last_select - - for ii, func in enumerate(args): - cols_result = _eval_expr_arg(__data, sel, func, verb_name) - - # replace any labels that require a subquery ---- - sel = _select_mutate_result(sel, cols_result) - - if isinstance(cols_result, sql.base.ImmutableColumnCollection): - result_names.update({k: True for k in cols_result.keys()}) - else: - result_names[cols_result.name] = True - - - for new_name, func in kwargs.items(): - labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name) - - sel = _select_mutate_result(sel, labeled) - result_names[new_name] = True - - - return list(result_names), sel - - -@transmute.register(LazyTbl) -def _transmute(__data, *args, **kwargs): - # will use mutate, then select some cols - result_names, sel = _mutate_cols(__data, args, kwargs, "Transmute") - - # transmute keeps grouping cols, and any defined in kwargs - missing = [x for x in __data.group_by if x not in result_names] - cols_to_keep = [*missing, *result_names] - - columns = lift_inner_cols(sel) - sel_stripped = sel.with_only_columns([columns[k] for k in 
cols_to_keep]) - - return __data.append_op(sel_stripped) - - -@arrange.register(LazyTbl) -def _arrange(__data, *args): - # Note that SQL databases often do not subquery order by clauses. Arrange - # sets order_by on the backend, so it can set order by in over elements, - # and handle when new columns are named the same as order by vars. - # see: https://dba.stackexchange.com/q/82930 - - last_sel = __data.last_select - cols = lift_inner_cols(last_sel) - - # TODO: implement across in arrange - #exprs, _ = _mutate_cols(__data, args, kwargs, "Arrange", arrange_clause=True) - - new_calls = [] - for ii, expr in enumerate(args): - if callable(expr): - - res = __data.shape_call( - expr, window = False, - verb_name = "Arrange", arg_name = ii - ) - - else: - res = expr - - new_calls.append(res) - - sort_cols = _create_order_by_clause(cols, *new_calls) - - order_by = __data.order_by + tuple(new_calls) - return __data.append_op(last_sel.order_by(*sort_cols), order_by = order_by) - - -# TODO: consolidate / pull expr handling funcs into own file? -def _create_order_by_clause(columns, *args): - sort_cols = [] - for arg in args: - # simple named column - if isinstance(arg, str): - sort_cols.append(columns[arg]) - # an expression - elif callable(arg): - # handle special case where -_.colname -> colname DESC - f, asc = _call_strip_ascending(arg) - col_op = f(columns) if asc else f(columns).desc() - #col_op = arg(columns) - sort_cols.append(col_op) - else: - raise NotImplementedError("Must be string or callable") - - return sort_cols - - - -@count.register(LazyTbl) -def _count(__data, *args, sort = False, wt = None, **kwargs): - # TODO: if already col named n, use name nn, etc.. 
get logic from tidy.py - if wt is not None: - raise NotImplementedError("TODO") - - res_name = "n" - # similar to filter verb, we need two select statements, - # an inner one for derived cols, and outer to group by them - - # inner select ---- - # holds any mutation style columns - #arg_names = [] - #for arg in args: - # name = simple_varname(arg) - # if name is None: - # raise NotImplementedError( - # "Count positional arguments must be single column name. " - # "Use a named argument to count using complex expressions." - # ) - # arg_names.append(name) - - result_names, sel_inner = _mutate_cols(__data, args, kwargs, "Count") - - # remove unnecessary select, if we're operating on a table ---- - if set(lift_inner_cols(sel_inner)) == set(lift_inner_cols(__data.last_select)): - sel_inner = __data.last_op - - # create outer select ---- - # holds selected columns and tally (n) - sel_inner_cte = sel_inner.alias() - inner_cols = sel_inner_cte.columns - - # apply any group vars from a group_by verb call first - missing = [k for k in __data.group_by if k not in result_names] - - all_group_names = ordered_union(__data.group_by, result_names) - outer_group_cols = [inner_cols[k] for k in all_group_names] - - # holds the actual count (e.g. 
n) - count_col = sql.functions.count().label(res_name) - - sel_outer = _sql_select([*outer_group_cols, count_col]) \ - .select_from(sel_inner_cte) \ - .group_by(*outer_group_cols) - - # count is like summarize, so removes order_by - return __data.append_op( - sel_outer.order_by(count_col.desc()), - order_by = tuple() - ) - - -@add_count.register(LazyTbl) -def _add_count(__data, *args, wt = None, sort = False, **kwargs): - counts = count(__data, *args, wt = wt, sort = sort, **kwargs) - by = list(c.name for c in counts.last_select.inner_columns)[:-1] - - return inner_join(__data, counts, by = by) - - -@summarize.register(LazyTbl) -def _summarize(__data, *args, **kwargs): - # https://stackoverflow.com/questions/14754994/why-is-sqlalchemy-count-much-slower-than-the-raw-query - - # get query with correct from clause, and maybe unneeded subquery - safe_from = __data.last_select.alias() - result_names, sel = _aggregate_cols(__data, safe_from, args, kwargs, "Summarize") - - # see if we can remove subquery - out_sel = _collapse_select(sel, safe_from) - - from_tbl = get_single_from(out_sel) - group_cols = [from_tbl.columns[k] for k in __data.group_by] - - final_sel = out_sel.group_by(*group_cols) - - new_data = __data.append_op(final_sel, group_by = tuple(), order_by = tuple()) - return new_data - - -def _collapse_select(outer_sel, inner_alias): - # check whether any outer columns reference an inner label ---- - inner_sel = inner_alias.element - - columns = lift_inner_cols(outer_sel) - inner_cols = lift_inner_cols(inner_sel) - - inner_labels = set([ - x.name for x in inner_cols - if isinstance(x, sql.elements.Label) - ]) - - col_requires_cte = set(inner_alias.columns[k] for k in inner_labels) - - bad_refs = [] - - def collect_refs(el): - if el in col_requires_cte: - bad_refs.append(el) - - for col in columns: - sql.util.visitors.traverse(col, {}, {"column": collect_refs}) - - # if possible, remove the outer query ---- - if not (bad_refs or len(inner_sel._group_by_clause)): - 
from sqlalchemy.sql.elements import ColumnClause, Label - - from_obj = get_single_from(inner_sel) - adaptor = sql.util.ClauseAdapter( - from_obj, - adapt_on_names=True, - include_fn=lambda c: isinstance(c, (ColumnClause, Label)) - ) - - new_cols = [] - for col in columns: - if isinstance(col, Label): - res = adaptor.traverse(col.element).label(col.name) - new_cols.append(res) - - else: - new_cols.append(adaptor.traverse(col)) - #new_cols = list(map(adaptor.traverse, columns)) - - return _sql_with_only_columns(inner_sel, new_cols) - - return outer_sel - - -def _aggregate_cols(__data, subquery, args, kwargs, verb_name): - # cases: - # * grouping cols can not be overwritten (in dbplyr they can't be ref'd) - # * no existing labels referred to - can use same select - # * existing labels referred to - need 1 subquery tops - # * groups + summarize columns can replace everything - - def get_label_clauses(clause): - out = [] - sql.util.visitors.traverse(clause, {}, {"label": lambda c: out.append(c)}) - - return out - - def quote_varname(x): - return f"`{x}`" - - def validate_references(arg_name, expr, verb_name): - bad_varnames = get_label_clauses(expr) - repr_names = ", ".join(map(quote_varname, [el.name for el in bad_varnames])) - - if not bad_varnames: - return - - raise NotImplementedError( - f"In SQL, you cannot refer to a column created in the same {verb_name}. " - f"`{arg_name}` refers to columns created earlier: {repr_names}." 
- ) - - sel = subquery.select() - - final_cols = {k: subquery.columns[k] for k in __data.group_by} - - # handle args ---- - for ii, func in enumerate(args): - cols_result = _eval_expr_arg(__data, sel, func, verb_name, window=False) - - for col in cols_result: - validate_references(col.name, col.element, verb_name) - final_cols[col.name] = col - - sel = _sql_upsert_columns(sel, cols_result) - - - # handle kwargs ---- - for new_name, func in kwargs.items(): - labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=False) - - validate_references(labeled.name, labeled.element, verb_name) - final_cols[new_name] = labeled - - sel = _sql_upsert_columns(sel, [labeled]) - - return list(final_cols), _sql_with_only_columns(sel, list(final_cols.values())) - - -@group_by.register(LazyTbl) -def _group_by(__data, *args, add = False, **kwargs): - if not (args or kwargs): - return __data.copy() - - group_names, sel = _mutate_cols(__data, args, kwargs, "Group by") - - if None in group_names: - raise NotImplementedError("Complex, unnamed expressions not supported in sql group_by") - - # check whether we can just use underlying table ---- - new_cols = lift_inner_cols(sel) - if set(new_cols).issubset(set(__data.last_op.columns)): - sel = __data.last_op - - if add: - group_names = ordered_union(__data.group_by, group_names) - - return __data.append_op(sel, group_by = tuple(group_names)) - - -@ungroup.register(LazyTbl) -def _ungroup(__data): - return __data.copy(group_by = tuple()) +from siuba.dply.verbs import case_when, if_else +from siuba.siu import Call @case_when.register(sql.base.ImmutableColumnCollection) def _case_when(__data, cases): @@ -1088,308 +32,9 @@ def _case_when(__data, cases): return sql.case(whens, else_ = else_val) -# Join ------------------------------------------------------------------------ - -from collections.abc import Mapping - -def _joined_cols(left_cols, right_cols, on_keys, how, suffix = ("_x", "_y")): - """Return labeled columns, according 
to selection rules for joins. - - Rules: - 1. For join keys, keep left table's column - 2. When keys have the same labels, add suffix - """ - - # TODO: remove sets, so uses stable ordering - # when left and right cols have same name, suffix with _x / _y - keep_right = set(right_cols.keys()) - set(on_keys.values()) - shared_labs = set(left_cols.keys()).intersection(keep_right) - - right_cols_no_keys = {k: right_cols[k] for k in keep_right} - - # for an outer join, have key columns coalesce values - - left_cols = {**left_cols} - if how == "full": - for lk, rk in on_keys.items(): - col = sql.functions.coalesce(left_cols[lk], right_cols[rk]) - left_cols[lk] = col.label(lk) - elif how == "right": - for lk, rk in on_keys.items(): - # Make left key columns actually be right ones (which contain left + extra) - left_cols[lk] = right_cols[rk].label(lk) - - - # create labels ---- - l_labs = _relabeled_cols(left_cols, shared_labs, suffix[0]) - r_labs = _relabeled_cols(right_cols_no_keys, shared_labs, suffix[1]) - - return l_labs + r_labs - - - -def _relabeled_cols(columns, keys, suffix): - # add a suffix to all columns with names in keys - cols = [] - for k, v in columns.items(): - new_col = v.label(k + str(suffix)) if k in keys else v - cols.append(new_col) - return cols - - -@join.register(LazyTbl) -def _join(left, right, on = None, *args, by = None, how = "inner", sql_on = None): - _raise_if_args(args) - - if on is None and by is not None: - on = by - - # Needs to be on the table, not the select - left_sel = left.last_op.alias() - right_sel = right.last_op.alias() - - # handle arguments ---- - on = _validate_join_arg_on(on, sql_on) - how = _validate_join_arg_how(how) - - # for equality join used to combine keys into single column - consolidate_keys = on if sql_on is None else {} - - if how == "right": - # switch joins, since sqlalchemy doesn't have right join arg - # see https://stackoverflow.com/q/11400307/1144523 - left_sel, right_sel = right_sel, left_sel - on = {v:k for 
k,v in on.items()} - - # create join conditions ---- - bool_clause = _create_join_conds(left_sel, right_sel, on) - - # create join ---- - join = left_sel.join( - right_sel, - onclause = bool_clause, - isouter = how != "inner", - full = how == "full" - ) - - # if right join, set selects back - if how == "right": - left_sel, right_sel = right_sel, left_sel - on = {v:k for k,v in on.items()} - - # note, shared_keys assumes on is a mapping... - # TODO: shared_keys appears to be for when on is not specified, but was unused - #shared_keys = [k for k,v in on.items() if k == v] - labeled_cols = _joined_cols( - left_sel.columns, - right_sel.columns, - on_keys = consolidate_keys, - how = how - ) - - sel = _sql_select(labeled_cols).select_from(join) - return left.append_op(sel, order_by = tuple()) - - -@semi_join.register(LazyTbl) -def _semi_join(left, right = None, on = None, *args, by = None, sql_on = None): - if on is None and by is not None: - on = by - - _raise_if_args(args) - - left_sel = left.last_op.alias() - right_sel = right.last_op.alias() - - # handle arguments ---- - on = _validate_join_arg_on(on, sql_on, left_sel, right_sel) - - # create join conditions ---- - bool_clause = _create_join_conds(left_sel, right_sel, on) - - # create inner join ---- - exists_clause = _sql_select([sql.literal(1)]) \ - .select_from(right_sel) \ - .where(bool_clause) - - # only keep left hand select's columns ---- - sel = _sql_select(left_sel.columns) \ - .select_from(left_sel) \ - .where(sql.exists(exists_clause)) - - return left.append_op(sel, order_by = tuple()) - - -@anti_join.register(LazyTbl) -def _anti_join(left, right = None, on = None, *args, by = None, sql_on = None): - if on is None and by is not None: - on = by - - _raise_if_args(args) - - left_sel = left.last_op.alias() - right_sel = right.last_op.alias() - - # handle arguments ---- - on = _validate_join_arg_on(on, sql_on, left, right) - - # create join conditions ---- - bool_clause = _create_join_conds(left_sel, 
right_sel, on) - - # create inner join ---- - #not_exists = ~sql.exists([1], from_obj = right_sel).where(bool_clause) - exists_clause = _sql_select([sql.literal(1)]) \ - .select_from(right_sel) \ - .where(bool_clause) - - sel = left_sel.select().where(~sql.exists(exists_clause)) - - return left.append_op(sel, order_by = tuple()) - -def _raise_if_args(args): - if len(args): - raise NotImplemented("*args is reserved for future arguments (e.g. suffix)") - -def _validate_join_arg_on(on, sql_on = None, lhs = None, rhs = None): - # handle sql on case - if sql_on is not None: - if on is not None: - raise ValueError("Cannot specify both on and sql_on") - - return sql_on - - # handle general cases - if on is None: - # TODO: currently, we check for lhs and rhs tables to indicate whether - # a verb supports inferring columns. Otherwise, raise an error. - if lhs is not None and rhs is not None: - # TODO: consolidate with duplicate logic in pandas verb code - warnings.warn( - "No on column passed to join. " - "Inferring join columns instead using shared column names." - ) - - on_cols = list(set(lhs.columns.keys()).intersection(set(rhs.columns.keys()))) - - if not on_cols: - raise ValueError( - "No join column specified, or shared column names in join." - ) - - # trivial dict mapping shared names to themselves - warnings.warn("Detected shared columns: %s" % on_cols) - on = dict(zip(on_cols, on_cols)) - - else: - raise NotImplementedError("on arg currently cannot be None (default) for SQL") - elif isinstance(on, str): - on = {on: on} - elif isinstance(on, (list, tuple)): - on = dict(zip(on, on)) - - - if not isinstance(on, Mapping): - raise TypeError("on must be a Mapping (e.g. 
dict)") - - return on - -def _validate_join_arg_how(how): - how_options = ("inner", "left", "right", "full") - if how not in how_options: - raise ValueError("how argument needs to be one of %s" %how_options) - - return how - -def _create_join_conds(left_sel, right_sel, on): - left_cols = left_sel.columns #lift_inner_cols(left_sel) - right_cols = right_sel.columns #lift_inner_cols(right_sel) - - if callable(on): - # callable, like with sql_on arg - conds = [on(left_cols, right_cols)] - else: - # dict-like of form {left: right} - conds = [] - for l, r in on.items(): - col_expr = left_cols[l] == right_cols[r] - conds.append(col_expr) - - return sql.and_(*conds) - - -# Head ------------------------------------------------------------------------ - -@head.register(LazyTbl) -def _head(__data, n = 5): - sel = __data.last_select - - return __data.append_op(sel.limit(n)) - - -# Rename ---------------------------------------------------------------------- - -@rename.register(LazyTbl) -def _rename(__data, **kwargs): - sel = __data.last_select - columns = lift_inner_cols(sel) - - # old_keys uses dict as ordered set - old_to_new = {simple_varname(v):k for k,v in kwargs.items()} - - if None in old_to_new: - raise KeyError("positional arguments must be simple column, " - "e.g. 
_.colname or _['colname']" - ) - - labs = [c.label(old_to_new[k]) if k in old_to_new else c for k,c in columns.items()] - - new_sel = sel.with_only_columns(labs) - - missing_groups, group_keys = _select_group_renames(old_to_new, __data.group_by) - - return __data.append_op(new_sel, group_by=group_keys) - - -# Distinct -------------------------------------------------------------------- - -@distinct.register(LazyTbl) -def _distinct(__data, *args, _keep_all = False, **kwargs): - if (args or kwargs) and _keep_all: - raise NotImplementedError("Distinct with variables specified in sql requires _keep_all = False") - - inner_sel = mutate(__data, **kwargs).last_select if kwargs else __data.last_select - - # TODO: this is copied from the df distinct version - # cols dict below is used as ordered set - cols = _var_select_simple(args) - cols.update(kwargs) - - # use all columns by default - if not cols: - cols = {k: True for k in lift_inner_cols(inner_sel).keys()} - - final_names = {**{k: True for k in __data.group_by}, **cols} - - if not len(inner_sel._order_by_clause): - # select distinct has to include any columns in the order by clause, - # so can only safely modify existing statement when there's no order by - sel_cols = lift_inner_cols(inner_sel) - distinct_cols = [sel_cols[k] for k in final_names] - sel = inner_sel.with_only_columns(distinct_cols).distinct() - else: - # fallback to cte - cte = inner_sel.alias() - distinct_cols = [cte.columns[k] for k in final_names] - sel = _sql_select(distinct_cols).select_from(cte).distinct() - - return __data.append_op(sel) - - # if_else --------------------------------------------------------------------- @if_else.register(sql.elements.ColumnElement) def _if_else(cond, true_vals, false_vals): whens = [(cond, true_vals)] return sql.case(whens, else_ = false_vals) - - diff --git a/siuba/sql/verbs/count.py b/siuba/sql/verbs/count.py index 8dcbef7b..c68c9f2d 100644 --- a/siuba/sql/verbs/count.py +++ b/siuba/sql/verbs/count.py @@ -7,846 
+7,14 @@ """ -import warnings - -from siuba.dply.verbs import ( - show_query, collect, - simple_varname, - select, - mutate, - transmute, - filter, - arrange, _call_strip_ascending, - summarize, - count, add_count, - group_by, ungroup, - case_when, - join, left_join, right_join, inner_join, semi_join, anti_join, - head, - rename, - distinct, - if_else, - _select_group_renames, - _var_select_simple - ) - -from siuba.dply.tidyselect import VarList, var_select - -from .translate import CustomOverClause, SqlColumn, SqlColumnAgg -from .utils import ( - get_dialect_translator, - _FixedSqlDatabase, - _is_dialect_duckdb, - _sql_select, - _sql_column_collection, - _sql_add_columns, - _sql_with_only_columns, - _sql_simplify_select, - MockConnection -) - from sqlalchemy import sql -import sqlalchemy -from siuba.siu import Call, Lazy, FunctionLookupError, singledispatch2 -# TODO: currently needed for select, but can we remove pandas? -from pandas import Series -from functools import singledispatch - -from sqlalchemy.sql import schema - -from siuba.dply.across import _require_across, _set_data_context, _eval_with_context - -# TODO: -# - distinct -# - annotate functions using sel.prefix_with("\n/**/\n") ? - - -# Helpers --------------------------------------------------------------------- - -class SqlFunctionLookupError(FunctionLookupError): pass - - -class CallListener: - """Generic listener. Each exit is called on a node's copy.""" - def enter(self, node): - args, kwargs = node.map_subcalls(self.enter) - - return self.exit(node.__class__(node.func, *args, **kwargs)) - - def exit(self, node): - return node - - -class WindowReplacer(CallListener): - """Call tree listener. - - Produces 2 important behaviors via the enter method: - - returns evaluated sql call expression, with labels on all window expressions. - - stores all labeled window expressions via the windows property. 
- - TODO: could replace with a sqlalchemy transformer - """ - - def __init__(self, columns, group_by, order_by, window_cte = None): - self.columns = columns - self.group_by = group_by - self.order_by = order_by - self.window_cte = window_cte - self.windows = [] - - def exit(self, node): - col_expr = node(self.columns) - - if not isinstance(col_expr, sql.elements.ClauseElement): - return col_expr - - over_clauses = [x for x in self._get_over_clauses(col_expr) if isinstance(x, CustomOverClause)] - - # put groupings and orderings onto custom over clauses - for over in over_clauses: - # TODO: shouldn't mutate these over clauses - group_by = sql.elements.ClauseList( - *[self.columns[name] for name in self.group_by] - ) - order_by = sql.elements.ClauseList( - *_create_order_by_clause(self.columns, *self.order_by) - ) - - over.set_over(group_by, order_by) - - if len(over_clauses) and self.window_cte is not None: - # custom name, or parameters like "%(...)s" may nest and break psycopg2 - # with columns you can set a key to fix this, but it doesn't seem to - # be an option with labels - name = self._get_unique_name('win', lift_inner_cols(self.window_cte)) - label = col_expr.label(name) - - # put into CTE, and return its resulting column, so that subsequent - # operations will refer to the window column on window_cte. 
Note that - # the operations will use the actual column, so may need to use the - # ClauseAdaptor to make it a reference to the label - self.window_cte = _sql_add_columns(self.window_cte, [label]) - win_col = lift_inner_cols(self.window_cte).values()[-1] - self.windows.append(win_col) - - return win_col - - return col_expr - - @staticmethod - def _get_unique_name(prefix, columns): - column_names = set(columns.keys()) - - i = 1 - name = prefix + str(i) - while name in column_names: - i += 1 - name = prefix + str(i) - - - return name - - @staticmethod - def _get_over_clauses(clause): - windows = [] - append_win = lambda col: windows.append(col) - - sql.util.visitors.traverse(clause, {}, {"over": append_win}) - - return windows - - -class SqlLabelReplacer: - """Create a visitor to replace source labels with destination. - - Note that this is meant to be used with sqlalchemy visitors. - """ - - def __init__(self, src_columns, dst_columns): - self.src_columns = src_columns - self.src_labels = set([x for x in src_columns if isinstance(x, sql.elements.Label)]) - self.dst_columns = dst_columns - self.applied = False - - def __call__(self, clause): - return sql.util.visitors.replacement_traverse(clause, {}, self.visit) - - def visit(self, el): - from sqlalchemy.sql.elements import ColumnClause, Label, ClauseElement, TypeClause - from sqlalchemy.sql.schema import Column - - if isinstance(el, TypeClause): - # TODO: for some reason this type throws an error if unguarded - return None - - if isinstance(el, ClauseElement): - if el in self.src_labels: - self.applied = True - return self.dst_columns[el.name] - elif el in self.src_columns: - return self.dst_columns[el.name] - - # TODO: should we create a subquery if the user passed raw text? 
- #elif isinstance(el, ColumnClause) and not isinstance(el, Column): - # # Raw SQL, which will need a subquery, but not substitution - # if el.key != "*": - # self.applied = True - - return None - - -#def track_call_windows(call, columns, group_by, order_by, window_cte = None): -# listener = WindowReplacer(columns, group_by, order_by, window_cte) -# col = listener.enter(call) -# return col, listener.windows, listener.window_cte - - -def track_call_windows(call, columns, group_by, order_by, window_cte = None): - col_expr = call(columns) - - crnt_group_by = sql.elements.ClauseList( - *[columns[name] for name in group_by] - ) - crnt_order_by = sql.elements.ClauseList( - *_create_order_by_clause(columns, *order_by) - ) - return replace_call_windows(col_expr, crnt_group_by, crnt_order_by, window_cte) - - - -@singledispatch -def replace_call_windows(col_expr, group_by, order_by, window_cte = None): - raise TypeError(str(type(col_expr))) - - -@replace_call_windows.register(sql.base.ImmutableColumnCollection) -def _(col_expr, group_by, order_by, window_cte = None): - all_over_clauses = [] - for col in col_expr: - _, over_clauses, window_cte = replace_call_windows( - col, - group_by, - order_by, - window_cte - ) - all_over_clauses.extend(over_clauses) - - return col_expr, all_over_clauses, window_cte - - -@replace_call_windows.register(sql.elements.ClauseElement) -def _(col_expr, group_by, order_by, window_cte = None): - - over_clauses = WindowReplacer._get_over_clauses(col_expr) - - for over in over_clauses: - # TODO: shouldn't mutate these over clauses - over.set_over(group_by, order_by) - - if len(over_clauses) and window_cte is not None: - # custom name, or parameters like "%(...)s" may nest and break psycopg2 - # with columns you can set a key to fix this, but it doesn't seem to - # be an option with labels - name = WindowReplacer._get_unique_name('win', lift_inner_cols(window_cte)) - label = col_expr.label(name) - - # put into CTE, and return its resulting column, so 
that subsequent - # operations will refer to the window column on window_cte. Note that - # the operations will use the actual column, so may need to use the - # ClauseAdaptor to make it a reference to the label - window_cte = _sql_add_columns(window_cte, [label]) - win_col = lift_inner_cols(window_cte).values()[-1] - - return win_col, over_clauses, window_cte - - return col_expr, over_clauses, window_cte - -def get_single_from(sel): - froms = sel.froms - - n_froms = len(froms) - if n_froms != 1: - raise ValueError( - f"Expected a single table in the from clause, but found {n_froms}" - ) - - return froms[0] - -def lift_inner_cols(tbl): - cols = list(tbl.inner_columns) - - return _sql_column_collection(cols) - -def col_expr_requires_cte(call, sel, is_mutate = False): - """Return whether a variable assignment needs a CTE""" - - call_vars = set(call.op_vars(attr_calls = False)) - - sel_labs = get_inner_labels(sel) - - # I use the acronym fwg sol (frog soul) to remember sql clause eval order - # from, where, group by, select, order by, limit - # group clause evaluated before select clause, so not issue for mutate - group_needs_cte = not is_mutate and len(sel._group_by_clause) - - return ( group_needs_cte - # TODO: detect when a new var in mutate conflicts w/ order by - #or len(sel._order_by_clause) - or not sel_labs.isdisjoint(call_vars) - ) - -def get_inner_labels(sel): - columns = lift_inner_cols(sel) - sel_labs = set(k for k,v in columns.items() if isinstance(v, sql.elements.Label)) - return sel_labs - -def get_missing_columns(call, columns): - missing_cols = set(call.op_vars(attr_calls = False)) - set(columns.keys()) - return missing_cols - -def compile_el(tbl, el): - compiled = el.compile( - dialect = tbl.source.dialect, - compile_kwargs = {"literal_binds": True} - ) - return compiled - -# Misc utilities -------------------------------------------------------------- - -def ordered_union(x, y): - dx = {el: True for el in x} - dy = {el: True for el in y} - - return 
tuple({**dx, **dy}) - - -def _warn_missing(missing_groups): - warnings.warn(f"Adding missing grouping variables: {missing_groups}") - - -# Table ----------------------------------------------------------------------- - -class LazyTbl: - def __init__( - self, source, tbl, columns = None, - ops = None, group_by = tuple(), order_by = tuple(), - translator = None - ): - """Create a representation of a SQL table. - - Args: - source: a sqlalchemy.Engine or sqlalchemy.Connection instance. - tbl: table of form 'schema_name.table_name', 'table_name', or sqlalchemy.Table. - columns: if specified, a listlike of column names. - - Examples - -------- - - :: - from sqlalchemy import create_engine - from siuba.data import mtcars - - # create database and table - engine = create_engine("sqlite:///:memory:") - mtcars.to_sql('mtcars', engine) - - tbl_mtcars = LazyTbl(engine, 'mtcars') - - """ - - # connection and dialect specific functions - self.source = sqlalchemy.create_engine(source) if isinstance(source, str) else source - - # get dialect name - dialect = self.source.dialect.name - self.translator = get_dialect_translator(dialect) - - self.tbl = self._create_table(tbl, columns, self.source) - - # important states the query can be in (e.g. 
grouped) - self.ops = [self.tbl] if ops is None else ops - - self.group_by = group_by - self.order_by = order_by - - - def append_op(self, op, **kwargs): - cpy = self.copy(**kwargs) - cpy.ops = cpy.ops + [op] - return cpy - - def copy(self, **kwargs): - return self.__class__(**{**self.__dict__, **kwargs}) - - def shape_call( - self, - call, window = True, str_accessors = False, - verb_name = None, arg_name = None, - ): - return self.translator.shape_call(call, window, str_accessors, verb_name, arg_name) - - def track_call_windows(self, call, columns = None, window_cte = None): - """Returns tuple of (new column expression, list of window exprs)""" - - columns = self.last_op.columns if columns is None else columns - return track_call_windows(call, columns, self.group_by, self.order_by, window_cte) - - def get_ordered_col_names(self): - """Return columns from current select, with grouping columns first.""" - ungrouped = [k for k in self.last_op.columns.keys() if k not in self.group_by] - return list(self.group_by) + ungrouped - - #def label_breaks_order_by(self, name): - # """Returns True if a new column label would break the order by vars.""" - - # # TODO: arrange currently allows literals, which breaks this. it seems - # # better to only allow calls in arrange. - # order_by_vars = {c.op_vars(attr_calls=False) for c in self.order_by} - - - - - @property - def last_op(self) -> "sql.Table | sql.Select": - last_op = self.ops[-1] - - if last_op is None: - raise TypeError() - - return last_op - - @property - def last_select(self): - last_op = self.last_op - if not isinstance(last_op, sql.selectable.SelectBase): - return last_op.select() - - return last_op - - @staticmethod - def _create_table(tbl, columns = None, source = None): - """Return a sqlalchemy.Table, autoloading column info if needed. - - Arguments: - tbl: a sqlalchemy.Table or string of form 'table_name' or 'schema_name.table_name'. - columns: a tuple of column names for the table. Overrides source argument. 
- source: a sqlalchemy engine, used to autoload columns. - - """ - if isinstance(tbl, sql.selectable.FromClause): - return tbl - - if not isinstance(tbl, str): - raise ValueError("tbl must be a sqlalchemy Table or string, but was %s" %type(tbl)) - - if columns is None and source is None: - raise ValueError("One of columns or source must be specified") - - schema, table_name = tbl.split('.') if '.' in tbl else [None, tbl] - - columns = map(sqlalchemy.Column, columns) if columns is not None else tuple() - - # TODO: pybigquery uses schema to mean project_id, so we cannot use - # siuba's classic breakdown "{schema}.{table_name}". Basically - # pybigquery uses "{schema=project_id}.{dataset_dot_table_name}" in its internal - # logic. An important side effect is that bigquery errors for - # `dataset`.`table`, but not `dataset.table`. - if source and source.dialect.name == "bigquery": - table_name = tbl - schema = None - - return sqlalchemy.Table( - table_name, - sqlalchemy.MetaData(bind = source), - *columns, - schema = schema, - autoload_with = source if not columns else None - ) - - def _get_preview(self): - # need to make prev op a cte, so we don't override any previous limit - new_sel = self.last_select.limit(5) - tbl_small = self.append_op(new_sel) - return collect(tbl_small) - - def __repr__(self): - template = ( - "# Source: lazy query\n" - "# DB Conn: {}\n" - "# Preview:\n{}\n" - "# .. may have more rows" - ) - - return template.format(repr(self.source.engine), repr(self._get_preview())) - - def _repr_html_(self): - template = ( - "
" - "
"
-                "# Source: lazy query\n"
-                "# DB Conn: {}\n"
-                "# Preview:\n"
-                "
" - "{}" - "

# .. may have more rows

" - "
" - ) - - data = self._get_preview() - - # _repr_html_ can not exist or return None, to signify that repr should be used - if not hasattr(data, '_repr_html_'): - return None - - html_data = data._repr_html_() - if html_data is None: - return None - - return template.format(self.source.engine, html_data) - - -def _repr_grouped_df_html_(self): - return "

(grouped data frame)

" + self._selected_obj._repr_html_() + "
" - - - -# Main Funcs -# ============================================================================= - -# sql raw -------------- - -sql_raw = sql.literal_column - -# show query ----------- - -@show_query.register(LazyTbl) -def _show_query(tbl, simplify = False, return_table = True): - #query = tbl.last_op #if not simplify else - compile_query = lambda query: query.compile( - dialect = tbl.source.dialect, - compile_kwargs = {"literal_binds": True} - ) - - - if simplify: - # try to strip table names and labels where unnecessary - simple_sel = _sql_simplify_select(tbl.last_select) - - explained = compile_query(simple_sel) - else: - # use a much more verbose query - explained = compile_query(tbl.last_select) - - if return_table: - print(str(explained)) - return tbl - - return str(explained) - - - -# collect ---------- - -@collect.register(LazyTbl) -def _collect(__data, as_df = True): - # TODO: maybe remove as_df options, always return dataframe - - if isinstance(__data.source, MockConnection): - # a mock sqlalchemy is being used to show_query, and echo queries. - # it doesn't return a result object or have a context handler, so - # we need to bail out early - return - - # compile query ---- - - if _is_dialect_duckdb(__data.source): - # TODO: can be removed once next release of duckdb fixes: - # https://github.com/duckdb/duckdb/issues/2972 - query = __data.last_select - compiled = query.compile( - dialect = __data.source.dialect, - compile_kwargs = {"literal_binds": True} - ) - else: - compiled = __data.last_select - - # execute query ---- - - with __data.source.connect() as conn: - if as_df: - sql_db = _FixedSqlDatabase(conn) - - if _is_dialect_duckdb(__data.source): - # TODO: pandas read_sql is very slow with duckdb. - # see https://github.com/pandas-dev/pandas/issues/45678 - # going to handle here for now. address once LazyTbl gets - # subclassed per backend. 
- duckdb_con = conn.connection.c - return duckdb_con.query(str(compiled)).to_df() - else: - # - return sql_db.read_sql(compiled) - - return conn.execute(compiled) - - -@select.register(LazyTbl) -def _select(__data, *args, **kwargs): - # see https://stackoverflow.com/questions/25914329/rearrange-columns-in-sqlalchemy-select-object - if kwargs: - raise NotImplementedError( - "Using kwargs in select not currently supported. " - "Use _.newname == _.oldname instead" - ) - last_sel = __data.last_select - columns = {c.key: c for c in last_sel.inner_columns} - - # same as for DataFrame - colnames = Series(list(columns)) - vl = VarList() - evaluated = (arg(vl) if callable(arg) else arg for arg in args) - od = var_select(colnames, *evaluated) - - missing_groups, group_keys = _select_group_renames(od, __data.group_by) - - if missing_groups: - _warn_missing(missing_groups) - - final_od = {**{k: None for k in missing_groups}, **od} - - col_list = [] - for k,v in final_od.items(): - col = columns[k] - col_list.append(col if v is None else col.label(v)) - - return __data.append_op( - last_sel.with_only_columns(col_list), - group_by = group_keys - ) - - - -@filter.register(LazyTbl) -def _filter(__data, *args): - # Note: currently always produces 2 additional select statements, - # 1 for window/aggs, and 1 for the where clause - - sel = __data.last_op.alias() # original select - win_sel = sel.select() - - conds = [] - windows = [] - with _set_data_context(__data, window=True): - for ii, arg in enumerate(args): - - if isinstance(arg, Call): - new_call = __data.shape_call(arg, verb_name = "Filter", arg_name = ii) - #var_cols = new_call.op_vars(attr_calls = False) - - # note that a new win_sel is returned, w/ window columns appended - col_expr, win_cols, win_sel = __data.track_call_windows( - new_call, - sel.columns, - window_cte = win_sel - ) - - if isinstance(col_expr, sql.base.ImmutableColumnCollection): - conds.extend(col_expr) - else: - conds.append(col_expr) - - 
windows.extend(win_cols) - - else: - conds.append(arg) - - bool_clause = sql.and_(*conds) - - # first cte, windows ---- - if len(windows): - - win_alias = win_sel.alias() - - # move non-window functions to refer to win_sel clause (not the innermost) --- - bool_clause = sql.util.ClauseAdapter(win_alias) \ - .traverse(bool_clause) - - orig_cols = [win_alias.columns[k] for k in sel.columns.keys()] - else: - orig_cols = [sel] - - # create second cte ---- - filt_sel = _sql_select(orig_cols).where(bool_clause) - return __data.append_op(filt_sel) - - -@mutate.register(LazyTbl) -def _mutate(__data, *args, **kwargs): - # TODO: verify it can follow a renaming select - - # track labeled columns in set - if not (len(args) or len(kwargs)): - return __data.append_op(__data.last_op) - names, sel_out = _mutate_cols(__data, args, kwargs, "Mutate") - return __data.append_op(sel_out) +from siuba.dply.verbs import count, add_count, inner_join +from ..utils import _sql_select, lift_inner_cols +from ..backend import LazyTbl, ordered_union -def _sql_upsert_columns(sel, new_columns: "list[base.Label | base.Column]"): - orig_cols = lift_inner_cols(sel) - replaced = {**orig_cols} - - for new_col in new_columns: - replaced[new_col.name] = new_col - return _sql_with_only_columns(sel, list(replaced.values())) - - -def _select_mutate_result(src_sel, expr_result): - dst_alias = src_sel.alias() - src_columns = set(lift_inner_cols(src_sel)) - replacer = SqlLabelReplacer(set(src_columns), dst_alias.columns) - - if isinstance(expr_result, sql.base.ImmutableColumnCollection): - replaced_cols = list(map(replacer, expr_result)) - orig_cols = expr_result - #elif isinstance(expr_result, None): - # pass - else: - replaced_cols = [replacer(expr_result)] - orig_cols = [expr_result] - - if replacer.applied: - return _sql_upsert_columns(dst_alias.select(), replaced_cols) - - return _sql_upsert_columns(src_sel, orig_cols) - - -def _eval_expr_arg(__data, sel, func, verb_name, window=True): - inner_cols = 
lift_inner_cols(sel) - - # case 1: simple names ---- - simple_name = simple_varname(func) - if simple_name is not None: - return inner_cols[simple_name] - - # case 2: across ---- - _require_across(func, verb_name) - - cols_result = _eval_with_context(__data, window, inner_cols, func) - - # TODO: remove or raise a more informative error - assert isinstance(cols_result, sql.base.ImmutableColumnCollection), type(cols_result) - - return cols_result - - -def _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=True): - inner_cols = lift_inner_cols(sel) - - expr_shaped = __data.shape_call(func, window, verb_name = verb_name, arg_name = new_name) - new_col, windows, _ = __data.track_call_windows(expr_shaped, inner_cols) - - if isinstance(new_col, sql.base.ImmutableColumnCollection): - raise TyepError( - f"{verb_name} named arguments must return a single column, but `{k}` " - "returned multiple columns." - ) - - return new_col.label(new_name) - - -def _mutate_cols(__data, args, kwargs, verb_name): - result_names = {} # used as ordered set - sel = __data.last_select - - for ii, func in enumerate(args): - cols_result = _eval_expr_arg(__data, sel, func, verb_name) - - # replace any labels that require a subquery ---- - sel = _select_mutate_result(sel, cols_result) - - if isinstance(cols_result, sql.base.ImmutableColumnCollection): - result_names.update({k: True for k in cols_result.keys()}) - else: - result_names[cols_result.name] = True - - - for new_name, func in kwargs.items(): - labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name) - - sel = _select_mutate_result(sel, labeled) - result_names[new_name] = True - - - return list(result_names), sel - - -@transmute.register(LazyTbl) -def _transmute(__data, *args, **kwargs): - # will use mutate, then select some cols - result_names, sel = _mutate_cols(__data, args, kwargs, "Transmute") - - # transmute keeps grouping cols, and any defined in kwargs - missing = [x for x in __data.group_by if x not in 
result_names] - cols_to_keep = [*missing, *result_names] - - columns = lift_inner_cols(sel) - sel_stripped = sel.with_only_columns([columns[k] for k in cols_to_keep]) - - return __data.append_op(sel_stripped) - - -@arrange.register(LazyTbl) -def _arrange(__data, *args): - # Note that SQL databases often do not subquery order by clauses. Arrange - # sets order_by on the backend, so it can set order by in over elements, - # and handle when new columns are named the same as order by vars. - # see: https://dba.stackexchange.com/q/82930 - - last_sel = __data.last_select - cols = lift_inner_cols(last_sel) - - # TODO: implement across in arrange - #exprs, _ = _mutate_cols(__data, args, kwargs, "Arrange", arrange_clause=True) - - new_calls = [] - for ii, expr in enumerate(args): - if callable(expr): - - res = __data.shape_call( - expr, window = False, - verb_name = "Arrange", arg_name = ii - ) - - else: - res = expr - - new_calls.append(res) - - sort_cols = _create_order_by_clause(cols, *new_calls) - - order_by = __data.order_by + tuple(new_calls) - return __data.append_op(last_sel.order_by(*sort_cols), order_by = order_by) - - -# TODO: consolidate / pull expr handling funcs into own file? 
-def _create_order_by_clause(columns, *args): - sort_cols = [] - for arg in args: - # simple named column - if isinstance(arg, str): - sort_cols.append(columns[arg]) - # an expression - elif callable(arg): - # handle special case where -_.colname -> colname DESC - f, asc = _call_strip_ascending(arg) - col_op = f(columns) if asc else f(columns).desc() - #col_op = arg(columns) - sort_cols.append(col_op) - else: - raise NotImplementedError("Must be string or callable") - - return sort_cols +from .mutate import _mutate_cols @@ -909,487 +77,3 @@ def _add_count(__data, *args, wt = None, sort = False, **kwargs): by = list(c.name for c in counts.last_select.inner_columns)[:-1] return inner_join(__data, counts, by = by) - - -@summarize.register(LazyTbl) -def _summarize(__data, *args, **kwargs): - # https://stackoverflow.com/questions/14754994/why-is-sqlalchemy-count-much-slower-than-the-raw-query - - # get query with correct from clause, and maybe unneeded subquery - safe_from = __data.last_select.alias() - result_names, sel = _aggregate_cols(__data, safe_from, args, kwargs, "Summarize") - - # see if we can remove subquery - out_sel = _collapse_select(sel, safe_from) - - from_tbl = get_single_from(out_sel) - group_cols = [from_tbl.columns[k] for k in __data.group_by] - - final_sel = out_sel.group_by(*group_cols) - - new_data = __data.append_op(final_sel, group_by = tuple(), order_by = tuple()) - return new_data - - -def _collapse_select(outer_sel, inner_alias): - # check whether any outer columns reference an inner label ---- - inner_sel = inner_alias.element - - columns = lift_inner_cols(outer_sel) - inner_cols = lift_inner_cols(inner_sel) - - inner_labels = set([ - x.name for x in inner_cols - if isinstance(x, sql.elements.Label) - ]) - - col_requires_cte = set(inner_alias.columns[k] for k in inner_labels) - - bad_refs = [] - - def collect_refs(el): - if el in col_requires_cte: - bad_refs.append(el) - - for col in columns: - sql.util.visitors.traverse(col, {}, {"column": 
collect_refs}) - - # if possible, remove the outer query ---- - if not (bad_refs or len(inner_sel._group_by_clause)): - from sqlalchemy.sql.elements import ColumnClause, Label - - from_obj = get_single_from(inner_sel) - adaptor = sql.util.ClauseAdapter( - from_obj, - adapt_on_names=True, - include_fn=lambda c: isinstance(c, (ColumnClause, Label)) - ) - - new_cols = [] - for col in columns: - if isinstance(col, Label): - res = adaptor.traverse(col.element).label(col.name) - new_cols.append(res) - - else: - new_cols.append(adaptor.traverse(col)) - #new_cols = list(map(adaptor.traverse, columns)) - - return _sql_with_only_columns(inner_sel, new_cols) - - return outer_sel - - -def _aggregate_cols(__data, subquery, args, kwargs, verb_name): - # cases: - # * grouping cols can not be overwritten (in dbplyr they can't be ref'd) - # * no existing labels referred to - can use same select - # * existing labels referred to - need 1 subquery tops - # * groups + summarize columns can replace everything - - def get_label_clauses(clause): - out = [] - sql.util.visitors.traverse(clause, {}, {"label": lambda c: out.append(c)}) - - return out - - def quote_varname(x): - return f"`{x}`" - - def validate_references(arg_name, expr, verb_name): - bad_varnames = get_label_clauses(expr) - repr_names = ", ".join(map(quote_varname, [el.name for el in bad_varnames])) - - if not bad_varnames: - return - - raise NotImplementedError( - f"In SQL, you cannot refer to a column created in the same {verb_name}. " - f"`{arg_name}` refers to columns created earlier: {repr_names}." 
- ) - - sel = subquery.select() - - final_cols = {k: subquery.columns[k] for k in __data.group_by} - - # handle args ---- - for ii, func in enumerate(args): - cols_result = _eval_expr_arg(__data, sel, func, verb_name, window=False) - - for col in cols_result: - validate_references(col.name, col.element, verb_name) - final_cols[col.name] = col - - sel = _sql_upsert_columns(sel, cols_result) - - - # handle kwargs ---- - for new_name, func in kwargs.items(): - labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=False) - - validate_references(labeled.name, labeled.element, verb_name) - final_cols[new_name] = labeled - - sel = _sql_upsert_columns(sel, [labeled]) - - return list(final_cols), _sql_with_only_columns(sel, list(final_cols.values())) - - -@group_by.register(LazyTbl) -def _group_by(__data, *args, add = False, **kwargs): - if not (args or kwargs): - return __data.copy() - - group_names, sel = _mutate_cols(__data, args, kwargs, "Group by") - - if None in group_names: - raise NotImplementedError("Complex, unnamed expressions not supported in sql group_by") - - # check whether we can just use underlying table ---- - new_cols = lift_inner_cols(sel) - if set(new_cols).issubset(set(__data.last_op.columns)): - sel = __data.last_op - - if add: - group_names = ordered_union(__data.group_by, group_names) - - return __data.append_op(sel, group_by = tuple(group_names)) - - -@ungroup.register(LazyTbl) -def _ungroup(__data): - return __data.copy(group_by = tuple()) - - -@case_when.register(sql.base.ImmutableColumnCollection) -def _case_when(__data, cases): - # TODO: will need listener to enter case statements, to handle when they use windows - if isinstance(cases, Call): - cases = cases(__data) - - whens = [] - case_items = list(cases.items()) - n_items = len(case_items) - - else_val = None - for ii, (expr, val) in enumerate(case_items): - # handle where val is a column expr - if callable(val): - val = val(__data) - - # handle when expressions - if ii+1 
== n_items and expr is True: - else_val = val - elif callable(expr): - whens.append((expr(__data), val)) - else: - whens.append((expr, val)) - - return sql.case(whens, else_ = else_val) - - -# Join ------------------------------------------------------------------------ - -from collections.abc import Mapping - -def _joined_cols(left_cols, right_cols, on_keys, how, suffix = ("_x", "_y")): - """Return labeled columns, according to selection rules for joins. - - Rules: - 1. For join keys, keep left table's column - 2. When keys have the same labels, add suffix - """ - - # TODO: remove sets, so uses stable ordering - # when left and right cols have same name, suffix with _x / _y - keep_right = set(right_cols.keys()) - set(on_keys.values()) - shared_labs = set(left_cols.keys()).intersection(keep_right) - - right_cols_no_keys = {k: right_cols[k] for k in keep_right} - - # for an outer join, have key columns coalesce values - - left_cols = {**left_cols} - if how == "full": - for lk, rk in on_keys.items(): - col = sql.functions.coalesce(left_cols[lk], right_cols[rk]) - left_cols[lk] = col.label(lk) - elif how == "right": - for lk, rk in on_keys.items(): - # Make left key columns actually be right ones (which contain left + extra) - left_cols[lk] = right_cols[rk].label(lk) - - - # create labels ---- - l_labs = _relabeled_cols(left_cols, shared_labs, suffix[0]) - r_labs = _relabeled_cols(right_cols_no_keys, shared_labs, suffix[1]) - - return l_labs + r_labs - - - -def _relabeled_cols(columns, keys, suffix): - # add a suffix to all columns with names in keys - cols = [] - for k, v in columns.items(): - new_col = v.label(k + str(suffix)) if k in keys else v - cols.append(new_col) - return cols - - -@join.register(LazyTbl) -def _join(left, right, on = None, *args, by = None, how = "inner", sql_on = None): - _raise_if_args(args) - - if on is None and by is not None: - on = by - - # Needs to be on the table, not the select - left_sel = left.last_op.alias() - right_sel = 
right.last_op.alias() - - # handle arguments ---- - on = _validate_join_arg_on(on, sql_on) - how = _validate_join_arg_how(how) - - # for equality join used to combine keys into single column - consolidate_keys = on if sql_on is None else {} - - if how == "right": - # switch joins, since sqlalchemy doesn't have right join arg - # see https://stackoverflow.com/q/11400307/1144523 - left_sel, right_sel = right_sel, left_sel - on = {v:k for k,v in on.items()} - - # create join conditions ---- - bool_clause = _create_join_conds(left_sel, right_sel, on) - - # create join ---- - join = left_sel.join( - right_sel, - onclause = bool_clause, - isouter = how != "inner", - full = how == "full" - ) - - # if right join, set selects back - if how == "right": - left_sel, right_sel = right_sel, left_sel - on = {v:k for k,v in on.items()} - - # note, shared_keys assumes on is a mapping... - # TODO: shared_keys appears to be for when on is not specified, but was unused - #shared_keys = [k for k,v in on.items() if k == v] - labeled_cols = _joined_cols( - left_sel.columns, - right_sel.columns, - on_keys = consolidate_keys, - how = how - ) - - sel = _sql_select(labeled_cols).select_from(join) - return left.append_op(sel, order_by = tuple()) - - -@semi_join.register(LazyTbl) -def _semi_join(left, right = None, on = None, *args, by = None, sql_on = None): - if on is None and by is not None: - on = by - - _raise_if_args(args) - - left_sel = left.last_op.alias() - right_sel = right.last_op.alias() - - # handle arguments ---- - on = _validate_join_arg_on(on, sql_on, left_sel, right_sel) - - # create join conditions ---- - bool_clause = _create_join_conds(left_sel, right_sel, on) - - # create inner join ---- - exists_clause = _sql_select([sql.literal(1)]) \ - .select_from(right_sel) \ - .where(bool_clause) - - # only keep left hand select's columns ---- - sel = _sql_select(left_sel.columns) \ - .select_from(left_sel) \ - .where(sql.exists(exists_clause)) - - return left.append_op(sel, order_by 
= tuple()) - - -@anti_join.register(LazyTbl) -def _anti_join(left, right = None, on = None, *args, by = None, sql_on = None): - if on is None and by is not None: - on = by - - _raise_if_args(args) - - left_sel = left.last_op.alias() - right_sel = right.last_op.alias() - - # handle arguments ---- - on = _validate_join_arg_on(on, sql_on, left, right) - - # create join conditions ---- - bool_clause = _create_join_conds(left_sel, right_sel, on) - - # create inner join ---- - #not_exists = ~sql.exists([1], from_obj = right_sel).where(bool_clause) - exists_clause = _sql_select([sql.literal(1)]) \ - .select_from(right_sel) \ - .where(bool_clause) - - sel = left_sel.select().where(~sql.exists(exists_clause)) - - return left.append_op(sel, order_by = tuple()) - -def _raise_if_args(args): - if len(args): - raise NotImplemented("*args is reserved for future arguments (e.g. suffix)") - -def _validate_join_arg_on(on, sql_on = None, lhs = None, rhs = None): - # handle sql on case - if sql_on is not None: - if on is not None: - raise ValueError("Cannot specify both on and sql_on") - - return sql_on - - # handle general cases - if on is None: - # TODO: currently, we check for lhs and rhs tables to indicate whether - # a verb supports inferring columns. Otherwise, raise an error. - if lhs is not None and rhs is not None: - # TODO: consolidate with duplicate logic in pandas verb code - warnings.warn( - "No on column passed to join. " - "Inferring join columns instead using shared column names." - ) - - on_cols = list(set(lhs.columns.keys()).intersection(set(rhs.columns.keys()))) - - if not on_cols: - raise ValueError( - "No join column specified, or shared column names in join." 
- ) - - # trivial dict mapping shared names to themselves - warnings.warn("Detected shared columns: %s" % on_cols) - on = dict(zip(on_cols, on_cols)) - - else: - raise NotImplementedError("on arg currently cannot be None (default) for SQL") - elif isinstance(on, str): - on = {on: on} - elif isinstance(on, (list, tuple)): - on = dict(zip(on, on)) - - - if not isinstance(on, Mapping): - raise TypeError("on must be a Mapping (e.g. dict)") - - return on - -def _validate_join_arg_how(how): - how_options = ("inner", "left", "right", "full") - if how not in how_options: - raise ValueError("how argument needs to be one of %s" %how_options) - - return how - -def _create_join_conds(left_sel, right_sel, on): - left_cols = left_sel.columns #lift_inner_cols(left_sel) - right_cols = right_sel.columns #lift_inner_cols(right_sel) - - if callable(on): - # callable, like with sql_on arg - conds = [on(left_cols, right_cols)] - else: - # dict-like of form {left: right} - conds = [] - for l, r in on.items(): - col_expr = left_cols[l] == right_cols[r] - conds.append(col_expr) - - return sql.and_(*conds) - - -# Head ------------------------------------------------------------------------ - -@head.register(LazyTbl) -def _head(__data, n = 5): - sel = __data.last_select - - return __data.append_op(sel.limit(n)) - - -# Rename ---------------------------------------------------------------------- - -@rename.register(LazyTbl) -def _rename(__data, **kwargs): - sel = __data.last_select - columns = lift_inner_cols(sel) - - # old_keys uses dict as ordered set - old_to_new = {simple_varname(v):k for k,v in kwargs.items()} - - if None in old_to_new: - raise KeyError("positional arguments must be simple column, " - "e.g. 
_.colname or _['colname']" - ) - - labs = [c.label(old_to_new[k]) if k in old_to_new else c for k,c in columns.items()] - - new_sel = sel.with_only_columns(labs) - - missing_groups, group_keys = _select_group_renames(old_to_new, __data.group_by) - - return __data.append_op(new_sel, group_by=group_keys) - - -# Distinct -------------------------------------------------------------------- - -@distinct.register(LazyTbl) -def _distinct(__data, *args, _keep_all = False, **kwargs): - if (args or kwargs) and _keep_all: - raise NotImplementedError("Distinct with variables specified in sql requires _keep_all = False") - - inner_sel = mutate(__data, **kwargs).last_select if kwargs else __data.last_select - - # TODO: this is copied from the df distinct version - # cols dict below is used as ordered set - cols = _var_select_simple(args) - cols.update(kwargs) - - # use all columns by default - if not cols: - cols = {k: True for k in lift_inner_cols(inner_sel).keys()} - - final_names = {**{k: True for k in __data.group_by}, **cols} - - if not len(inner_sel._order_by_clause): - # select distinct has to include any columns in the order by clause, - # so can only safely modify existing statement when there's no order by - sel_cols = lift_inner_cols(inner_sel) - distinct_cols = [sel_cols[k] for k in final_names] - sel = inner_sel.with_only_columns(distinct_cols).distinct() - else: - # fallback to cte - cte = inner_sel.alias() - distinct_cols = [cte.columns[k] for k in final_names] - sel = _sql_select(distinct_cols).select_from(cte).distinct() - - return __data.append_op(sel) - - -# if_else --------------------------------------------------------------------- - -@if_else.register(sql.elements.ColumnElement) -def _if_else(cond, true_vals, false_vals): - whens = [(cond, true_vals)] - return sql.case(whens, else_ = false_vals) - - diff --git a/siuba/sql/verbs/distinct.py b/siuba/sql/verbs/distinct.py index 8dcbef7b..c56563d7 100644 --- a/siuba/sql/verbs/distinct.py +++ 
b/siuba/sql/verbs/distinct.py @@ -1,1357 +1,9 @@ -""" -Implements LazyTbl to represent tables of SQL data, and registers it on verbs. +from siuba.dply.verbs import distinct, mutate, _var_select_simple -This module is responsible for the handling of the "table" side of things, while -translate.py handles translating column operations. +from ..backend import LazyTbl +from ..utils import _sql_select, lift_inner_cols -""" - -import warnings - -from siuba.dply.verbs import ( - show_query, collect, - simple_varname, - select, - mutate, - transmute, - filter, - arrange, _call_strip_ascending, - summarize, - count, add_count, - group_by, ungroup, - case_when, - join, left_join, right_join, inner_join, semi_join, anti_join, - head, - rename, - distinct, - if_else, - _select_group_renames, - _var_select_simple - ) - -from siuba.dply.tidyselect import VarList, var_select - -from .translate import CustomOverClause, SqlColumn, SqlColumnAgg -from .utils import ( - get_dialect_translator, - _FixedSqlDatabase, - _is_dialect_duckdb, - _sql_select, - _sql_column_collection, - _sql_add_columns, - _sql_with_only_columns, - _sql_simplify_select, - MockConnection -) - -from sqlalchemy import sql -import sqlalchemy -from siuba.siu import Call, Lazy, FunctionLookupError, singledispatch2 -# TODO: currently needed for select, but can we remove pandas? -from pandas import Series -from functools import singledispatch - -from sqlalchemy.sql import schema - -from siuba.dply.across import _require_across, _set_data_context, _eval_with_context - -# TODO: -# - distinct -# - annotate functions using sel.prefix_with("\n/**/\n") ? - - -# Helpers --------------------------------------------------------------------- - -class SqlFunctionLookupError(FunctionLookupError): pass - - -class CallListener: - """Generic listener. 
Each exit is called on a node's copy.""" - def enter(self, node): - args, kwargs = node.map_subcalls(self.enter) - - return self.exit(node.__class__(node.func, *args, **kwargs)) - - def exit(self, node): - return node - - -class WindowReplacer(CallListener): - """Call tree listener. - - Produces 2 important behaviors via the enter method: - - returns evaluated sql call expression, with labels on all window expressions. - - stores all labeled window expressions via the windows property. - - TODO: could replace with a sqlalchemy transformer - """ - - def __init__(self, columns, group_by, order_by, window_cte = None): - self.columns = columns - self.group_by = group_by - self.order_by = order_by - self.window_cte = window_cte - self.windows = [] - - def exit(self, node): - col_expr = node(self.columns) - - if not isinstance(col_expr, sql.elements.ClauseElement): - return col_expr - - over_clauses = [x for x in self._get_over_clauses(col_expr) if isinstance(x, CustomOverClause)] - - # put groupings and orderings onto custom over clauses - for over in over_clauses: - # TODO: shouldn't mutate these over clauses - group_by = sql.elements.ClauseList( - *[self.columns[name] for name in self.group_by] - ) - order_by = sql.elements.ClauseList( - *_create_order_by_clause(self.columns, *self.order_by) - ) - - over.set_over(group_by, order_by) - - if len(over_clauses) and self.window_cte is not None: - # custom name, or parameters like "%(...)s" may nest and break psycopg2 - # with columns you can set a key to fix this, but it doesn't seem to - # be an option with labels - name = self._get_unique_name('win', lift_inner_cols(self.window_cte)) - label = col_expr.label(name) - - # put into CTE, and return its resulting column, so that subsequent - # operations will refer to the window column on window_cte. 
Note that - # the operations will use the actual column, so may need to use the - # ClauseAdaptor to make it a reference to the label - self.window_cte = _sql_add_columns(self.window_cte, [label]) - win_col = lift_inner_cols(self.window_cte).values()[-1] - self.windows.append(win_col) - - return win_col - - return col_expr - - @staticmethod - def _get_unique_name(prefix, columns): - column_names = set(columns.keys()) - - i = 1 - name = prefix + str(i) - while name in column_names: - i += 1 - name = prefix + str(i) - - - return name - - @staticmethod - def _get_over_clauses(clause): - windows = [] - append_win = lambda col: windows.append(col) - - sql.util.visitors.traverse(clause, {}, {"over": append_win}) - - return windows - - -class SqlLabelReplacer: - """Create a visitor to replace source labels with destination. - - Note that this is meant to be used with sqlalchemy visitors. - """ - - def __init__(self, src_columns, dst_columns): - self.src_columns = src_columns - self.src_labels = set([x for x in src_columns if isinstance(x, sql.elements.Label)]) - self.dst_columns = dst_columns - self.applied = False - - def __call__(self, clause): - return sql.util.visitors.replacement_traverse(clause, {}, self.visit) - - def visit(self, el): - from sqlalchemy.sql.elements import ColumnClause, Label, ClauseElement, TypeClause - from sqlalchemy.sql.schema import Column - - if isinstance(el, TypeClause): - # TODO: for some reason this type throws an error if unguarded - return None - - if isinstance(el, ClauseElement): - if el in self.src_labels: - self.applied = True - return self.dst_columns[el.name] - elif el in self.src_columns: - return self.dst_columns[el.name] - - # TODO: should we create a subquery if the user passed raw text? 
- #elif isinstance(el, ColumnClause) and not isinstance(el, Column): - # # Raw SQL, which will need a subquery, but not substitution - # if el.key != "*": - # self.applied = True - - return None - - -#def track_call_windows(call, columns, group_by, order_by, window_cte = None): -# listener = WindowReplacer(columns, group_by, order_by, window_cte) -# col = listener.enter(call) -# return col, listener.windows, listener.window_cte - - -def track_call_windows(call, columns, group_by, order_by, window_cte = None): - col_expr = call(columns) - - crnt_group_by = sql.elements.ClauseList( - *[columns[name] for name in group_by] - ) - crnt_order_by = sql.elements.ClauseList( - *_create_order_by_clause(columns, *order_by) - ) - return replace_call_windows(col_expr, crnt_group_by, crnt_order_by, window_cte) - - - -@singledispatch -def replace_call_windows(col_expr, group_by, order_by, window_cte = None): - raise TypeError(str(type(col_expr))) - - -@replace_call_windows.register(sql.base.ImmutableColumnCollection) -def _(col_expr, group_by, order_by, window_cte = None): - all_over_clauses = [] - for col in col_expr: - _, over_clauses, window_cte = replace_call_windows( - col, - group_by, - order_by, - window_cte - ) - all_over_clauses.extend(over_clauses) - - return col_expr, all_over_clauses, window_cte - - -@replace_call_windows.register(sql.elements.ClauseElement) -def _(col_expr, group_by, order_by, window_cte = None): - - over_clauses = WindowReplacer._get_over_clauses(col_expr) - - for over in over_clauses: - # TODO: shouldn't mutate these over clauses - over.set_over(group_by, order_by) - - if len(over_clauses) and window_cte is not None: - # custom name, or parameters like "%(...)s" may nest and break psycopg2 - # with columns you can set a key to fix this, but it doesn't seem to - # be an option with labels - name = WindowReplacer._get_unique_name('win', lift_inner_cols(window_cte)) - label = col_expr.label(name) - - # put into CTE, and return its resulting column, so 
that subsequent - # operations will refer to the window column on window_cte. Note that - # the operations will use the actual column, so may need to use the - # ClauseAdaptor to make it a reference to the label - window_cte = _sql_add_columns(window_cte, [label]) - win_col = lift_inner_cols(window_cte).values()[-1] - - return win_col, over_clauses, window_cte - - return col_expr, over_clauses, window_cte - -def get_single_from(sel): - froms = sel.froms - - n_froms = len(froms) - if n_froms != 1: - raise ValueError( - f"Expected a single table in the from clause, but found {n_froms}" - ) - - return froms[0] - -def lift_inner_cols(tbl): - cols = list(tbl.inner_columns) - - return _sql_column_collection(cols) - -def col_expr_requires_cte(call, sel, is_mutate = False): - """Return whether a variable assignment needs a CTE""" - - call_vars = set(call.op_vars(attr_calls = False)) - - sel_labs = get_inner_labels(sel) - - # I use the acronym fwg sol (frog soul) to remember sql clause eval order - # from, where, group by, select, order by, limit - # group clause evaluated before select clause, so not issue for mutate - group_needs_cte = not is_mutate and len(sel._group_by_clause) - - return ( group_needs_cte - # TODO: detect when a new var in mutate conflicts w/ order by - #or len(sel._order_by_clause) - or not sel_labs.isdisjoint(call_vars) - ) - -def get_inner_labels(sel): - columns = lift_inner_cols(sel) - sel_labs = set(k for k,v in columns.items() if isinstance(v, sql.elements.Label)) - return sel_labs - -def get_missing_columns(call, columns): - missing_cols = set(call.op_vars(attr_calls = False)) - set(columns.keys()) - return missing_cols - -def compile_el(tbl, el): - compiled = el.compile( - dialect = tbl.source.dialect, - compile_kwargs = {"literal_binds": True} - ) - return compiled - -# Misc utilities -------------------------------------------------------------- - -def ordered_union(x, y): - dx = {el: True for el in x} - dy = {el: True for el in y} - - return 
tuple({**dx, **dy}) - - -def _warn_missing(missing_groups): - warnings.warn(f"Adding missing grouping variables: {missing_groups}") - - -# Table ----------------------------------------------------------------------- - -class LazyTbl: - def __init__( - self, source, tbl, columns = None, - ops = None, group_by = tuple(), order_by = tuple(), - translator = None - ): - """Create a representation of a SQL table. - - Args: - source: a sqlalchemy.Engine or sqlalchemy.Connection instance. - tbl: table of form 'schema_name.table_name', 'table_name', or sqlalchemy.Table. - columns: if specified, a listlike of column names. - - Examples - -------- - - :: - from sqlalchemy import create_engine - from siuba.data import mtcars - - # create database and table - engine = create_engine("sqlite:///:memory:") - mtcars.to_sql('mtcars', engine) - - tbl_mtcars = LazyTbl(engine, 'mtcars') - - """ - - # connection and dialect specific functions - self.source = sqlalchemy.create_engine(source) if isinstance(source, str) else source - - # get dialect name - dialect = self.source.dialect.name - self.translator = get_dialect_translator(dialect) - - self.tbl = self._create_table(tbl, columns, self.source) - - # important states the query can be in (e.g. 
grouped) - self.ops = [self.tbl] if ops is None else ops - - self.group_by = group_by - self.order_by = order_by - - - def append_op(self, op, **kwargs): - cpy = self.copy(**kwargs) - cpy.ops = cpy.ops + [op] - return cpy - - def copy(self, **kwargs): - return self.__class__(**{**self.__dict__, **kwargs}) - - def shape_call( - self, - call, window = True, str_accessors = False, - verb_name = None, arg_name = None, - ): - return self.translator.shape_call(call, window, str_accessors, verb_name, arg_name) - - def track_call_windows(self, call, columns = None, window_cte = None): - """Returns tuple of (new column expression, list of window exprs)""" - - columns = self.last_op.columns if columns is None else columns - return track_call_windows(call, columns, self.group_by, self.order_by, window_cte) - - def get_ordered_col_names(self): - """Return columns from current select, with grouping columns first.""" - ungrouped = [k for k in self.last_op.columns.keys() if k not in self.group_by] - return list(self.group_by) + ungrouped - - #def label_breaks_order_by(self, name): - # """Returns True if a new column label would break the order by vars.""" - - # # TODO: arrange currently allows literals, which breaks this. it seems - # # better to only allow calls in arrange. - # order_by_vars = {c.op_vars(attr_calls=False) for c in self.order_by} - - - - - @property - def last_op(self) -> "sql.Table | sql.Select": - last_op = self.ops[-1] - - if last_op is None: - raise TypeError() - - return last_op - - @property - def last_select(self): - last_op = self.last_op - if not isinstance(last_op, sql.selectable.SelectBase): - return last_op.select() - - return last_op - - @staticmethod - def _create_table(tbl, columns = None, source = None): - """Return a sqlalchemy.Table, autoloading column info if needed. - - Arguments: - tbl: a sqlalchemy.Table or string of form 'table_name' or 'schema_name.table_name'. - columns: a tuple of column names for the table. Overrides source argument. 
- source: a sqlalchemy engine, used to autoload columns. - - """ - if isinstance(tbl, sql.selectable.FromClause): - return tbl - - if not isinstance(tbl, str): - raise ValueError("tbl must be a sqlalchemy Table or string, but was %s" %type(tbl)) - - if columns is None and source is None: - raise ValueError("One of columns or source must be specified") - - schema, table_name = tbl.split('.') if '.' in tbl else [None, tbl] - - columns = map(sqlalchemy.Column, columns) if columns is not None else tuple() - - # TODO: pybigquery uses schema to mean project_id, so we cannot use - # siuba's classic breakdown "{schema}.{table_name}". Basically - # pybigquery uses "{schema=project_id}.{dataset_dot_table_name}" in its internal - # logic. An important side effect is that bigquery errors for - # `dataset`.`table`, but not `dataset.table`. - if source and source.dialect.name == "bigquery": - table_name = tbl - schema = None - - return sqlalchemy.Table( - table_name, - sqlalchemy.MetaData(bind = source), - *columns, - schema = schema, - autoload_with = source if not columns else None - ) - - def _get_preview(self): - # need to make prev op a cte, so we don't override any previous limit - new_sel = self.last_select.limit(5) - tbl_small = self.append_op(new_sel) - return collect(tbl_small) - - def __repr__(self): - template = ( - "# Source: lazy query\n" - "# DB Conn: {}\n" - "# Preview:\n{}\n" - "# .. may have more rows" - ) - - return template.format(repr(self.source.engine), repr(self._get_preview())) - - def _repr_html_(self): - template = ( - "
" - "
"
-                "# Source: lazy query\n"
-                "# DB Conn: {}\n"
-                "# Preview:\n"
-                "
" - "{}" - "

# .. may have more rows

" - "
" - ) - - data = self._get_preview() - - # _repr_html_ can not exist or return None, to signify that repr should be used - if not hasattr(data, '_repr_html_'): - return None - - html_data = data._repr_html_() - if html_data is None: - return None - - return template.format(self.source.engine, html_data) - - -def _repr_grouped_df_html_(self): - return "

(grouped data frame)

" + self._selected_obj._repr_html_() + "
" - - - -# Main Funcs -# ============================================================================= - -# sql raw -------------- - -sql_raw = sql.literal_column - -# show query ----------- - -@show_query.register(LazyTbl) -def _show_query(tbl, simplify = False, return_table = True): - #query = tbl.last_op #if not simplify else - compile_query = lambda query: query.compile( - dialect = tbl.source.dialect, - compile_kwargs = {"literal_binds": True} - ) - - - if simplify: - # try to strip table names and labels where unnecessary - simple_sel = _sql_simplify_select(tbl.last_select) - - explained = compile_query(simple_sel) - else: - # use a much more verbose query - explained = compile_query(tbl.last_select) - - if return_table: - print(str(explained)) - return tbl - - return str(explained) - - - -# collect ---------- - -@collect.register(LazyTbl) -def _collect(__data, as_df = True): - # TODO: maybe remove as_df options, always return dataframe - - if isinstance(__data.source, MockConnection): - # a mock sqlalchemy is being used to show_query, and echo queries. - # it doesn't return a result object or have a context handler, so - # we need to bail out early - return - - # compile query ---- - - if _is_dialect_duckdb(__data.source): - # TODO: can be removed once next release of duckdb fixes: - # https://github.com/duckdb/duckdb/issues/2972 - query = __data.last_select - compiled = query.compile( - dialect = __data.source.dialect, - compile_kwargs = {"literal_binds": True} - ) - else: - compiled = __data.last_select - - # execute query ---- - - with __data.source.connect() as conn: - if as_df: - sql_db = _FixedSqlDatabase(conn) - - if _is_dialect_duckdb(__data.source): - # TODO: pandas read_sql is very slow with duckdb. - # see https://github.com/pandas-dev/pandas/issues/45678 - # going to handle here for now. address once LazyTbl gets - # subclassed per backend. 
- duckdb_con = conn.connection.c - return duckdb_con.query(str(compiled)).to_df() - else: - # - return sql_db.read_sql(compiled) - - return conn.execute(compiled) - - -@select.register(LazyTbl) -def _select(__data, *args, **kwargs): - # see https://stackoverflow.com/questions/25914329/rearrange-columns-in-sqlalchemy-select-object - if kwargs: - raise NotImplementedError( - "Using kwargs in select not currently supported. " - "Use _.newname == _.oldname instead" - ) - last_sel = __data.last_select - columns = {c.key: c for c in last_sel.inner_columns} - - # same as for DataFrame - colnames = Series(list(columns)) - vl = VarList() - evaluated = (arg(vl) if callable(arg) else arg for arg in args) - od = var_select(colnames, *evaluated) - - missing_groups, group_keys = _select_group_renames(od, __data.group_by) - - if missing_groups: - _warn_missing(missing_groups) - - final_od = {**{k: None for k in missing_groups}, **od} - - col_list = [] - for k,v in final_od.items(): - col = columns[k] - col_list.append(col if v is None else col.label(v)) - - return __data.append_op( - last_sel.with_only_columns(col_list), - group_by = group_keys - ) - - - -@filter.register(LazyTbl) -def _filter(__data, *args): - # Note: currently always produces 2 additional select statements, - # 1 for window/aggs, and 1 for the where clause - - sel = __data.last_op.alias() # original select - win_sel = sel.select() - - conds = [] - windows = [] - with _set_data_context(__data, window=True): - for ii, arg in enumerate(args): - - if isinstance(arg, Call): - new_call = __data.shape_call(arg, verb_name = "Filter", arg_name = ii) - #var_cols = new_call.op_vars(attr_calls = False) - - # note that a new win_sel is returned, w/ window columns appended - col_expr, win_cols, win_sel = __data.track_call_windows( - new_call, - sel.columns, - window_cte = win_sel - ) - - if isinstance(col_expr, sql.base.ImmutableColumnCollection): - conds.extend(col_expr) - else: - conds.append(col_expr) - - 
windows.extend(win_cols) - - else: - conds.append(arg) - - bool_clause = sql.and_(*conds) - - # first cte, windows ---- - if len(windows): - - win_alias = win_sel.alias() - - # move non-window functions to refer to win_sel clause (not the innermost) --- - bool_clause = sql.util.ClauseAdapter(win_alias) \ - .traverse(bool_clause) - - orig_cols = [win_alias.columns[k] for k in sel.columns.keys()] - else: - orig_cols = [sel] - - # create second cte ---- - filt_sel = _sql_select(orig_cols).where(bool_clause) - return __data.append_op(filt_sel) - - -@mutate.register(LazyTbl) -def _mutate(__data, *args, **kwargs): - # TODO: verify it can follow a renaming select - - # track labeled columns in set - if not (len(args) or len(kwargs)): - return __data.append_op(__data.last_op) - - names, sel_out = _mutate_cols(__data, args, kwargs, "Mutate") - return __data.append_op(sel_out) - - -def _sql_upsert_columns(sel, new_columns: "list[base.Label | base.Column]"): - orig_cols = lift_inner_cols(sel) - replaced = {**orig_cols} - - for new_col in new_columns: - replaced[new_col.name] = new_col - return _sql_with_only_columns(sel, list(replaced.values())) - - -def _select_mutate_result(src_sel, expr_result): - dst_alias = src_sel.alias() - src_columns = set(lift_inner_cols(src_sel)) - replacer = SqlLabelReplacer(set(src_columns), dst_alias.columns) - - if isinstance(expr_result, sql.base.ImmutableColumnCollection): - replaced_cols = list(map(replacer, expr_result)) - orig_cols = expr_result - #elif isinstance(expr_result, None): - # pass - else: - replaced_cols = [replacer(expr_result)] - orig_cols = [expr_result] - - if replacer.applied: - return _sql_upsert_columns(dst_alias.select(), replaced_cols) - - return _sql_upsert_columns(src_sel, orig_cols) - - -def _eval_expr_arg(__data, sel, func, verb_name, window=True): - inner_cols = lift_inner_cols(sel) - - # case 1: simple names ---- - simple_name = simple_varname(func) - if simple_name is not None: - return inner_cols[simple_name] - 
- # case 2: across ---- - _require_across(func, verb_name) - - cols_result = _eval_with_context(__data, window, inner_cols, func) - - # TODO: remove or raise a more informative error - assert isinstance(cols_result, sql.base.ImmutableColumnCollection), type(cols_result) - - return cols_result - - -def _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=True): - inner_cols = lift_inner_cols(sel) - - expr_shaped = __data.shape_call(func, window, verb_name = verb_name, arg_name = new_name) - new_col, windows, _ = __data.track_call_windows(expr_shaped, inner_cols) - - if isinstance(new_col, sql.base.ImmutableColumnCollection): - raise TyepError( - f"{verb_name} named arguments must return a single column, but `{k}` " - "returned multiple columns." - ) - - return new_col.label(new_name) - - -def _mutate_cols(__data, args, kwargs, verb_name): - result_names = {} # used as ordered set - sel = __data.last_select - - for ii, func in enumerate(args): - cols_result = _eval_expr_arg(__data, sel, func, verb_name) - - # replace any labels that require a subquery ---- - sel = _select_mutate_result(sel, cols_result) - - if isinstance(cols_result, sql.base.ImmutableColumnCollection): - result_names.update({k: True for k in cols_result.keys()}) - else: - result_names[cols_result.name] = True - - - for new_name, func in kwargs.items(): - labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name) - - sel = _select_mutate_result(sel, labeled) - result_names[new_name] = True - - - return list(result_names), sel - - -@transmute.register(LazyTbl) -def _transmute(__data, *args, **kwargs): - # will use mutate, then select some cols - result_names, sel = _mutate_cols(__data, args, kwargs, "Transmute") - - # transmute keeps grouping cols, and any defined in kwargs - missing = [x for x in __data.group_by if x not in result_names] - cols_to_keep = [*missing, *result_names] - - columns = lift_inner_cols(sel) - sel_stripped = sel.with_only_columns([columns[k] for k in 
cols_to_keep]) - - return __data.append_op(sel_stripped) - - -@arrange.register(LazyTbl) -def _arrange(__data, *args): - # Note that SQL databases often do not subquery order by clauses. Arrange - # sets order_by on the backend, so it can set order by in over elements, - # and handle when new columns are named the same as order by vars. - # see: https://dba.stackexchange.com/q/82930 - - last_sel = __data.last_select - cols = lift_inner_cols(last_sel) - - # TODO: implement across in arrange - #exprs, _ = _mutate_cols(__data, args, kwargs, "Arrange", arrange_clause=True) - - new_calls = [] - for ii, expr in enumerate(args): - if callable(expr): - - res = __data.shape_call( - expr, window = False, - verb_name = "Arrange", arg_name = ii - ) - - else: - res = expr - - new_calls.append(res) - - sort_cols = _create_order_by_clause(cols, *new_calls) - - order_by = __data.order_by + tuple(new_calls) - return __data.append_op(last_sel.order_by(*sort_cols), order_by = order_by) - - -# TODO: consolidate / pull expr handling funcs into own file? -def _create_order_by_clause(columns, *args): - sort_cols = [] - for arg in args: - # simple named column - if isinstance(arg, str): - sort_cols.append(columns[arg]) - # an expression - elif callable(arg): - # handle special case where -_.colname -> colname DESC - f, asc = _call_strip_ascending(arg) - col_op = f(columns) if asc else f(columns).desc() - #col_op = arg(columns) - sort_cols.append(col_op) - else: - raise NotImplementedError("Must be string or callable") - - return sort_cols - - - -@count.register(LazyTbl) -def _count(__data, *args, sort = False, wt = None, **kwargs): - # TODO: if already col named n, use name nn, etc.. 
get logic from tidy.py - if wt is not None: - raise NotImplementedError("TODO") - - res_name = "n" - # similar to filter verb, we need two select statements, - # an inner one for derived cols, and outer to group by them - - # inner select ---- - # holds any mutation style columns - #arg_names = [] - #for arg in args: - # name = simple_varname(arg) - # if name is None: - # raise NotImplementedError( - # "Count positional arguments must be single column name. " - # "Use a named argument to count using complex expressions." - # ) - # arg_names.append(name) - - result_names, sel_inner = _mutate_cols(__data, args, kwargs, "Count") - - # remove unnecessary select, if we're operating on a table ---- - if set(lift_inner_cols(sel_inner)) == set(lift_inner_cols(__data.last_select)): - sel_inner = __data.last_op - - # create outer select ---- - # holds selected columns and tally (n) - sel_inner_cte = sel_inner.alias() - inner_cols = sel_inner_cte.columns - - # apply any group vars from a group_by verb call first - missing = [k for k in __data.group_by if k not in result_names] - - all_group_names = ordered_union(__data.group_by, result_names) - outer_group_cols = [inner_cols[k] for k in all_group_names] - - # holds the actual count (e.g. 
n) - count_col = sql.functions.count().label(res_name) - - sel_outer = _sql_select([*outer_group_cols, count_col]) \ - .select_from(sel_inner_cte) \ - .group_by(*outer_group_cols) - - # count is like summarize, so removes order_by - return __data.append_op( - sel_outer.order_by(count_col.desc()), - order_by = tuple() - ) - - -@add_count.register(LazyTbl) -def _add_count(__data, *args, wt = None, sort = False, **kwargs): - counts = count(__data, *args, wt = wt, sort = sort, **kwargs) - by = list(c.name for c in counts.last_select.inner_columns)[:-1] - - return inner_join(__data, counts, by = by) - - -@summarize.register(LazyTbl) -def _summarize(__data, *args, **kwargs): - # https://stackoverflow.com/questions/14754994/why-is-sqlalchemy-count-much-slower-than-the-raw-query - - # get query with correct from clause, and maybe unneeded subquery - safe_from = __data.last_select.alias() - result_names, sel = _aggregate_cols(__data, safe_from, args, kwargs, "Summarize") - - # see if we can remove subquery - out_sel = _collapse_select(sel, safe_from) - - from_tbl = get_single_from(out_sel) - group_cols = [from_tbl.columns[k] for k in __data.group_by] - - final_sel = out_sel.group_by(*group_cols) - - new_data = __data.append_op(final_sel, group_by = tuple(), order_by = tuple()) - return new_data - - -def _collapse_select(outer_sel, inner_alias): - # check whether any outer columns reference an inner label ---- - inner_sel = inner_alias.element - - columns = lift_inner_cols(outer_sel) - inner_cols = lift_inner_cols(inner_sel) - - inner_labels = set([ - x.name for x in inner_cols - if isinstance(x, sql.elements.Label) - ]) - - col_requires_cte = set(inner_alias.columns[k] for k in inner_labels) - - bad_refs = [] - - def collect_refs(el): - if el in col_requires_cte: - bad_refs.append(el) - - for col in columns: - sql.util.visitors.traverse(col, {}, {"column": collect_refs}) - - # if possible, remove the outer query ---- - if not (bad_refs or len(inner_sel._group_by_clause)): - 
from sqlalchemy.sql.elements import ColumnClause, Label - - from_obj = get_single_from(inner_sel) - adaptor = sql.util.ClauseAdapter( - from_obj, - adapt_on_names=True, - include_fn=lambda c: isinstance(c, (ColumnClause, Label)) - ) - - new_cols = [] - for col in columns: - if isinstance(col, Label): - res = adaptor.traverse(col.element).label(col.name) - new_cols.append(res) - - else: - new_cols.append(adaptor.traverse(col)) - #new_cols = list(map(adaptor.traverse, columns)) - - return _sql_with_only_columns(inner_sel, new_cols) - - return outer_sel - - -def _aggregate_cols(__data, subquery, args, kwargs, verb_name): - # cases: - # * grouping cols can not be overwritten (in dbplyr they can't be ref'd) - # * no existing labels referred to - can use same select - # * existing labels referred to - need 1 subquery tops - # * groups + summarize columns can replace everything - - def get_label_clauses(clause): - out = [] - sql.util.visitors.traverse(clause, {}, {"label": lambda c: out.append(c)}) - - return out - - def quote_varname(x): - return f"`{x}`" - - def validate_references(arg_name, expr, verb_name): - bad_varnames = get_label_clauses(expr) - repr_names = ", ".join(map(quote_varname, [el.name for el in bad_varnames])) - - if not bad_varnames: - return - - raise NotImplementedError( - f"In SQL, you cannot refer to a column created in the same {verb_name}. " - f"`{arg_name}` refers to columns created earlier: {repr_names}." 
- ) - - sel = subquery.select() - - final_cols = {k: subquery.columns[k] for k in __data.group_by} - - # handle args ---- - for ii, func in enumerate(args): - cols_result = _eval_expr_arg(__data, sel, func, verb_name, window=False) - - for col in cols_result: - validate_references(col.name, col.element, verb_name) - final_cols[col.name] = col - - sel = _sql_upsert_columns(sel, cols_result) - - - # handle kwargs ---- - for new_name, func in kwargs.items(): - labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=False) - - validate_references(labeled.name, labeled.element, verb_name) - final_cols[new_name] = labeled - - sel = _sql_upsert_columns(sel, [labeled]) - - return list(final_cols), _sql_with_only_columns(sel, list(final_cols.values())) - - -@group_by.register(LazyTbl) -def _group_by(__data, *args, add = False, **kwargs): - if not (args or kwargs): - return __data.copy() - - group_names, sel = _mutate_cols(__data, args, kwargs, "Group by") - - if None in group_names: - raise NotImplementedError("Complex, unnamed expressions not supported in sql group_by") - - # check whether we can just use underlying table ---- - new_cols = lift_inner_cols(sel) - if set(new_cols).issubset(set(__data.last_op.columns)): - sel = __data.last_op - - if add: - group_names = ordered_union(__data.group_by, group_names) - - return __data.append_op(sel, group_by = tuple(group_names)) - - -@ungroup.register(LazyTbl) -def _ungroup(__data): - return __data.copy(group_by = tuple()) - - -@case_when.register(sql.base.ImmutableColumnCollection) -def _case_when(__data, cases): - # TODO: will need listener to enter case statements, to handle when they use windows - if isinstance(cases, Call): - cases = cases(__data) - - whens = [] - case_items = list(cases.items()) - n_items = len(case_items) - - else_val = None - for ii, (expr, val) in enumerate(case_items): - # handle where val is a column expr - if callable(val): - val = val(__data) - - # handle when expressions - if ii+1 
== n_items and expr is True: - else_val = val - elif callable(expr): - whens.append((expr(__data), val)) - else: - whens.append((expr, val)) - - return sql.case(whens, else_ = else_val) - - -# Join ------------------------------------------------------------------------ - -from collections.abc import Mapping - -def _joined_cols(left_cols, right_cols, on_keys, how, suffix = ("_x", "_y")): - """Return labeled columns, according to selection rules for joins. - - Rules: - 1. For join keys, keep left table's column - 2. When keys have the same labels, add suffix - """ - - # TODO: remove sets, so uses stable ordering - # when left and right cols have same name, suffix with _x / _y - keep_right = set(right_cols.keys()) - set(on_keys.values()) - shared_labs = set(left_cols.keys()).intersection(keep_right) - - right_cols_no_keys = {k: right_cols[k] for k in keep_right} - - # for an outer join, have key columns coalesce values - - left_cols = {**left_cols} - if how == "full": - for lk, rk in on_keys.items(): - col = sql.functions.coalesce(left_cols[lk], right_cols[rk]) - left_cols[lk] = col.label(lk) - elif how == "right": - for lk, rk in on_keys.items(): - # Make left key columns actually be right ones (which contain left + extra) - left_cols[lk] = right_cols[rk].label(lk) - - - # create labels ---- - l_labs = _relabeled_cols(left_cols, shared_labs, suffix[0]) - r_labs = _relabeled_cols(right_cols_no_keys, shared_labs, suffix[1]) - - return l_labs + r_labs - - - -def _relabeled_cols(columns, keys, suffix): - # add a suffix to all columns with names in keys - cols = [] - for k, v in columns.items(): - new_col = v.label(k + str(suffix)) if k in keys else v - cols.append(new_col) - return cols - - -@join.register(LazyTbl) -def _join(left, right, on = None, *args, by = None, how = "inner", sql_on = None): - _raise_if_args(args) - - if on is None and by is not None: - on = by - - # Needs to be on the table, not the select - left_sel = left.last_op.alias() - right_sel = 
right.last_op.alias() - - # handle arguments ---- - on = _validate_join_arg_on(on, sql_on) - how = _validate_join_arg_how(how) - - # for equality join used to combine keys into single column - consolidate_keys = on if sql_on is None else {} - - if how == "right": - # switch joins, since sqlalchemy doesn't have right join arg - # see https://stackoverflow.com/q/11400307/1144523 - left_sel, right_sel = right_sel, left_sel - on = {v:k for k,v in on.items()} - - # create join conditions ---- - bool_clause = _create_join_conds(left_sel, right_sel, on) - - # create join ---- - join = left_sel.join( - right_sel, - onclause = bool_clause, - isouter = how != "inner", - full = how == "full" - ) - - # if right join, set selects back - if how == "right": - left_sel, right_sel = right_sel, left_sel - on = {v:k for k,v in on.items()} - - # note, shared_keys assumes on is a mapping... - # TODO: shared_keys appears to be for when on is not specified, but was unused - #shared_keys = [k for k,v in on.items() if k == v] - labeled_cols = _joined_cols( - left_sel.columns, - right_sel.columns, - on_keys = consolidate_keys, - how = how - ) - - sel = _sql_select(labeled_cols).select_from(join) - return left.append_op(sel, order_by = tuple()) - - -@semi_join.register(LazyTbl) -def _semi_join(left, right = None, on = None, *args, by = None, sql_on = None): - if on is None and by is not None: - on = by - - _raise_if_args(args) - - left_sel = left.last_op.alias() - right_sel = right.last_op.alias() - - # handle arguments ---- - on = _validate_join_arg_on(on, sql_on, left_sel, right_sel) - - # create join conditions ---- - bool_clause = _create_join_conds(left_sel, right_sel, on) - - # create inner join ---- - exists_clause = _sql_select([sql.literal(1)]) \ - .select_from(right_sel) \ - .where(bool_clause) - - # only keep left hand select's columns ---- - sel = _sql_select(left_sel.columns) \ - .select_from(left_sel) \ - .where(sql.exists(exists_clause)) - - return left.append_op(sel, order_by 
= tuple()) - - -@anti_join.register(LazyTbl) -def _anti_join(left, right = None, on = None, *args, by = None, sql_on = None): - if on is None and by is not None: - on = by - - _raise_if_args(args) - - left_sel = left.last_op.alias() - right_sel = right.last_op.alias() - - # handle arguments ---- - on = _validate_join_arg_on(on, sql_on, left, right) - - # create join conditions ---- - bool_clause = _create_join_conds(left_sel, right_sel, on) - - # create inner join ---- - #not_exists = ~sql.exists([1], from_obj = right_sel).where(bool_clause) - exists_clause = _sql_select([sql.literal(1)]) \ - .select_from(right_sel) \ - .where(bool_clause) - - sel = left_sel.select().where(~sql.exists(exists_clause)) - - return left.append_op(sel, order_by = tuple()) - -def _raise_if_args(args): - if len(args): - raise NotImplemented("*args is reserved for future arguments (e.g. suffix)") - -def _validate_join_arg_on(on, sql_on = None, lhs = None, rhs = None): - # handle sql on case - if sql_on is not None: - if on is not None: - raise ValueError("Cannot specify both on and sql_on") - - return sql_on - - # handle general cases - if on is None: - # TODO: currently, we check for lhs and rhs tables to indicate whether - # a verb supports inferring columns. Otherwise, raise an error. - if lhs is not None and rhs is not None: - # TODO: consolidate with duplicate logic in pandas verb code - warnings.warn( - "No on column passed to join. " - "Inferring join columns instead using shared column names." - ) - - on_cols = list(set(lhs.columns.keys()).intersection(set(rhs.columns.keys()))) - - if not on_cols: - raise ValueError( - "No join column specified, or shared column names in join." 
- ) - - # trivial dict mapping shared names to themselves - warnings.warn("Detected shared columns: %s" % on_cols) - on = dict(zip(on_cols, on_cols)) - - else: - raise NotImplementedError("on arg currently cannot be None (default) for SQL") - elif isinstance(on, str): - on = {on: on} - elif isinstance(on, (list, tuple)): - on = dict(zip(on, on)) - - - if not isinstance(on, Mapping): - raise TypeError("on must be a Mapping (e.g. dict)") - - return on - -def _validate_join_arg_how(how): - how_options = ("inner", "left", "right", "full") - if how not in how_options: - raise ValueError("how argument needs to be one of %s" %how_options) - - return how - -def _create_join_conds(left_sel, right_sel, on): - left_cols = left_sel.columns #lift_inner_cols(left_sel) - right_cols = right_sel.columns #lift_inner_cols(right_sel) - - if callable(on): - # callable, like with sql_on arg - conds = [on(left_cols, right_cols)] - else: - # dict-like of form {left: right} - conds = [] - for l, r in on.items(): - col_expr = left_cols[l] == right_cols[r] - conds.append(col_expr) - - return sql.and_(*conds) - - -# Head ------------------------------------------------------------------------ - -@head.register(LazyTbl) -def _head(__data, n = 5): - sel = __data.last_select - - return __data.append_op(sel.limit(n)) - - -# Rename ---------------------------------------------------------------------- - -@rename.register(LazyTbl) -def _rename(__data, **kwargs): - sel = __data.last_select - columns = lift_inner_cols(sel) - - # old_keys uses dict as ordered set - old_to_new = {simple_varname(v):k for k,v in kwargs.items()} - - if None in old_to_new: - raise KeyError("positional arguments must be simple column, " - "e.g. 
_.colname or _['colname']" - ) - - labs = [c.label(old_to_new[k]) if k in old_to_new else c for k,c in columns.items()] - - new_sel = sel.with_only_columns(labs) - - missing_groups, group_keys = _select_group_renames(old_to_new, __data.group_by) - - return __data.append_op(new_sel, group_by=group_keys) - - -# Distinct -------------------------------------------------------------------- - @distinct.register(LazyTbl) def _distinct(__data, *args, _keep_all = False, **kwargs): if (args or kwargs) and _keep_all: @@ -1383,13 +35,3 @@ def _distinct(__data, *args, _keep_all = False, **kwargs): sel = _sql_select(distinct_cols).select_from(cte).distinct() return __data.append_op(sel) - - -# if_else --------------------------------------------------------------------- - -@if_else.register(sql.elements.ColumnElement) -def _if_else(cond, true_vals, false_vals): - whens = [(cond, true_vals)] - return sql.case(whens, else_ = false_vals) - - diff --git a/siuba/sql/verbs/explain.py b/siuba/sql/verbs/explain.py index 8dcbef7b..c1d14f03 100644 --- a/siuba/sql/verbs/explain.py +++ b/siuba/sql/verbs/explain.py @@ -7,517 +7,11 @@ """ -import warnings +from ..backend import LazyTbl +from ..utils import _sql_simplify_select -from siuba.dply.verbs import ( - show_query, collect, - simple_varname, - select, - mutate, - transmute, - filter, - arrange, _call_strip_ascending, - summarize, - count, add_count, - group_by, ungroup, - case_when, - join, left_join, right_join, inner_join, semi_join, anti_join, - head, - rename, - distinct, - if_else, - _select_group_renames, - _var_select_simple - ) +from siuba.dply.verbs import show_query -from siuba.dply.tidyselect import VarList, var_select - -from .translate import CustomOverClause, SqlColumn, SqlColumnAgg -from .utils import ( - get_dialect_translator, - _FixedSqlDatabase, - _is_dialect_duckdb, - _sql_select, - _sql_column_collection, - _sql_add_columns, - _sql_with_only_columns, - _sql_simplify_select, - MockConnection -) - -from sqlalchemy 
import sql -import sqlalchemy -from siuba.siu import Call, Lazy, FunctionLookupError, singledispatch2 -# TODO: currently needed for select, but can we remove pandas? -from pandas import Series -from functools import singledispatch - -from sqlalchemy.sql import schema - -from siuba.dply.across import _require_across, _set_data_context, _eval_with_context - -# TODO: -# - distinct -# - annotate functions using sel.prefix_with("\n/**/\n") ? - - -# Helpers --------------------------------------------------------------------- - -class SqlFunctionLookupError(FunctionLookupError): pass - - -class CallListener: - """Generic listener. Each exit is called on a node's copy.""" - def enter(self, node): - args, kwargs = node.map_subcalls(self.enter) - - return self.exit(node.__class__(node.func, *args, **kwargs)) - - def exit(self, node): - return node - - -class WindowReplacer(CallListener): - """Call tree listener. - - Produces 2 important behaviors via the enter method: - - returns evaluated sql call expression, with labels on all window expressions. - - stores all labeled window expressions via the windows property. 
- - TODO: could replace with a sqlalchemy transformer - """ - - def __init__(self, columns, group_by, order_by, window_cte = None): - self.columns = columns - self.group_by = group_by - self.order_by = order_by - self.window_cte = window_cte - self.windows = [] - - def exit(self, node): - col_expr = node(self.columns) - - if not isinstance(col_expr, sql.elements.ClauseElement): - return col_expr - - over_clauses = [x for x in self._get_over_clauses(col_expr) if isinstance(x, CustomOverClause)] - - # put groupings and orderings onto custom over clauses - for over in over_clauses: - # TODO: shouldn't mutate these over clauses - group_by = sql.elements.ClauseList( - *[self.columns[name] for name in self.group_by] - ) - order_by = sql.elements.ClauseList( - *_create_order_by_clause(self.columns, *self.order_by) - ) - - over.set_over(group_by, order_by) - - if len(over_clauses) and self.window_cte is not None: - # custom name, or parameters like "%(...)s" may nest and break psycopg2 - # with columns you can set a key to fix this, but it doesn't seem to - # be an option with labels - name = self._get_unique_name('win', lift_inner_cols(self.window_cte)) - label = col_expr.label(name) - - # put into CTE, and return its resulting column, so that subsequent - # operations will refer to the window column on window_cte. 
Note that - # the operations will use the actual column, so may need to use the - # ClauseAdaptor to make it a reference to the label - self.window_cte = _sql_add_columns(self.window_cte, [label]) - win_col = lift_inner_cols(self.window_cte).values()[-1] - self.windows.append(win_col) - - return win_col - - return col_expr - - @staticmethod - def _get_unique_name(prefix, columns): - column_names = set(columns.keys()) - - i = 1 - name = prefix + str(i) - while name in column_names: - i += 1 - name = prefix + str(i) - - - return name - - @staticmethod - def _get_over_clauses(clause): - windows = [] - append_win = lambda col: windows.append(col) - - sql.util.visitors.traverse(clause, {}, {"over": append_win}) - - return windows - - -class SqlLabelReplacer: - """Create a visitor to replace source labels with destination. - - Note that this is meant to be used with sqlalchemy visitors. - """ - - def __init__(self, src_columns, dst_columns): - self.src_columns = src_columns - self.src_labels = set([x for x in src_columns if isinstance(x, sql.elements.Label)]) - self.dst_columns = dst_columns - self.applied = False - - def __call__(self, clause): - return sql.util.visitors.replacement_traverse(clause, {}, self.visit) - - def visit(self, el): - from sqlalchemy.sql.elements import ColumnClause, Label, ClauseElement, TypeClause - from sqlalchemy.sql.schema import Column - - if isinstance(el, TypeClause): - # TODO: for some reason this type throws an error if unguarded - return None - - if isinstance(el, ClauseElement): - if el in self.src_labels: - self.applied = True - return self.dst_columns[el.name] - elif el in self.src_columns: - return self.dst_columns[el.name] - - # TODO: should we create a subquery if the user passed raw text? 
- #elif isinstance(el, ColumnClause) and not isinstance(el, Column): - # # Raw SQL, which will need a subquery, but not substitution - # if el.key != "*": - # self.applied = True - - return None - - -#def track_call_windows(call, columns, group_by, order_by, window_cte = None): -# listener = WindowReplacer(columns, group_by, order_by, window_cte) -# col = listener.enter(call) -# return col, listener.windows, listener.window_cte - - -def track_call_windows(call, columns, group_by, order_by, window_cte = None): - col_expr = call(columns) - - crnt_group_by = sql.elements.ClauseList( - *[columns[name] for name in group_by] - ) - crnt_order_by = sql.elements.ClauseList( - *_create_order_by_clause(columns, *order_by) - ) - return replace_call_windows(col_expr, crnt_group_by, crnt_order_by, window_cte) - - - -@singledispatch -def replace_call_windows(col_expr, group_by, order_by, window_cte = None): - raise TypeError(str(type(col_expr))) - - -@replace_call_windows.register(sql.base.ImmutableColumnCollection) -def _(col_expr, group_by, order_by, window_cte = None): - all_over_clauses = [] - for col in col_expr: - _, over_clauses, window_cte = replace_call_windows( - col, - group_by, - order_by, - window_cte - ) - all_over_clauses.extend(over_clauses) - - return col_expr, all_over_clauses, window_cte - - -@replace_call_windows.register(sql.elements.ClauseElement) -def _(col_expr, group_by, order_by, window_cte = None): - - over_clauses = WindowReplacer._get_over_clauses(col_expr) - - for over in over_clauses: - # TODO: shouldn't mutate these over clauses - over.set_over(group_by, order_by) - - if len(over_clauses) and window_cte is not None: - # custom name, or parameters like "%(...)s" may nest and break psycopg2 - # with columns you can set a key to fix this, but it doesn't seem to - # be an option with labels - name = WindowReplacer._get_unique_name('win', lift_inner_cols(window_cte)) - label = col_expr.label(name) - - # put into CTE, and return its resulting column, so 
that subsequent - # operations will refer to the window column on window_cte. Note that - # the operations will use the actual column, so may need to use the - # ClauseAdaptor to make it a reference to the label - window_cte = _sql_add_columns(window_cte, [label]) - win_col = lift_inner_cols(window_cte).values()[-1] - - return win_col, over_clauses, window_cte - - return col_expr, over_clauses, window_cte - -def get_single_from(sel): - froms = sel.froms - - n_froms = len(froms) - if n_froms != 1: - raise ValueError( - f"Expected a single table in the from clause, but found {n_froms}" - ) - - return froms[0] - -def lift_inner_cols(tbl): - cols = list(tbl.inner_columns) - - return _sql_column_collection(cols) - -def col_expr_requires_cte(call, sel, is_mutate = False): - """Return whether a variable assignment needs a CTE""" - - call_vars = set(call.op_vars(attr_calls = False)) - - sel_labs = get_inner_labels(sel) - - # I use the acronym fwg sol (frog soul) to remember sql clause eval order - # from, where, group by, select, order by, limit - # group clause evaluated before select clause, so not issue for mutate - group_needs_cte = not is_mutate and len(sel._group_by_clause) - - return ( group_needs_cte - # TODO: detect when a new var in mutate conflicts w/ order by - #or len(sel._order_by_clause) - or not sel_labs.isdisjoint(call_vars) - ) - -def get_inner_labels(sel): - columns = lift_inner_cols(sel) - sel_labs = set(k for k,v in columns.items() if isinstance(v, sql.elements.Label)) - return sel_labs - -def get_missing_columns(call, columns): - missing_cols = set(call.op_vars(attr_calls = False)) - set(columns.keys()) - return missing_cols - -def compile_el(tbl, el): - compiled = el.compile( - dialect = tbl.source.dialect, - compile_kwargs = {"literal_binds": True} - ) - return compiled - -# Misc utilities -------------------------------------------------------------- - -def ordered_union(x, y): - dx = {el: True for el in x} - dy = {el: True for el in y} - - return 
tuple({**dx, **dy}) - - -def _warn_missing(missing_groups): - warnings.warn(f"Adding missing grouping variables: {missing_groups}") - - -# Table ----------------------------------------------------------------------- - -class LazyTbl: - def __init__( - self, source, tbl, columns = None, - ops = None, group_by = tuple(), order_by = tuple(), - translator = None - ): - """Create a representation of a SQL table. - - Args: - source: a sqlalchemy.Engine or sqlalchemy.Connection instance. - tbl: table of form 'schema_name.table_name', 'table_name', or sqlalchemy.Table. - columns: if specified, a listlike of column names. - - Examples - -------- - - :: - from sqlalchemy import create_engine - from siuba.data import mtcars - - # create database and table - engine = create_engine("sqlite:///:memory:") - mtcars.to_sql('mtcars', engine) - - tbl_mtcars = LazyTbl(engine, 'mtcars') - - """ - - # connection and dialect specific functions - self.source = sqlalchemy.create_engine(source) if isinstance(source, str) else source - - # get dialect name - dialect = self.source.dialect.name - self.translator = get_dialect_translator(dialect) - - self.tbl = self._create_table(tbl, columns, self.source) - - # important states the query can be in (e.g. 
grouped) - self.ops = [self.tbl] if ops is None else ops - - self.group_by = group_by - self.order_by = order_by - - - def append_op(self, op, **kwargs): - cpy = self.copy(**kwargs) - cpy.ops = cpy.ops + [op] - return cpy - - def copy(self, **kwargs): - return self.__class__(**{**self.__dict__, **kwargs}) - - def shape_call( - self, - call, window = True, str_accessors = False, - verb_name = None, arg_name = None, - ): - return self.translator.shape_call(call, window, str_accessors, verb_name, arg_name) - - def track_call_windows(self, call, columns = None, window_cte = None): - """Returns tuple of (new column expression, list of window exprs)""" - - columns = self.last_op.columns if columns is None else columns - return track_call_windows(call, columns, self.group_by, self.order_by, window_cte) - - def get_ordered_col_names(self): - """Return columns from current select, with grouping columns first.""" - ungrouped = [k for k in self.last_op.columns.keys() if k not in self.group_by] - return list(self.group_by) + ungrouped - - #def label_breaks_order_by(self, name): - # """Returns True if a new column label would break the order by vars.""" - - # # TODO: arrange currently allows literals, which breaks this. it seems - # # better to only allow calls in arrange. - # order_by_vars = {c.op_vars(attr_calls=False) for c in self.order_by} - - - - - @property - def last_op(self) -> "sql.Table | sql.Select": - last_op = self.ops[-1] - - if last_op is None: - raise TypeError() - - return last_op - - @property - def last_select(self): - last_op = self.last_op - if not isinstance(last_op, sql.selectable.SelectBase): - return last_op.select() - - return last_op - - @staticmethod - def _create_table(tbl, columns = None, source = None): - """Return a sqlalchemy.Table, autoloading column info if needed. - - Arguments: - tbl: a sqlalchemy.Table or string of form 'table_name' or 'schema_name.table_name'. - columns: a tuple of column names for the table. Overrides source argument. 
- source: a sqlalchemy engine, used to autoload columns. - - """ - if isinstance(tbl, sql.selectable.FromClause): - return tbl - - if not isinstance(tbl, str): - raise ValueError("tbl must be a sqlalchemy Table or string, but was %s" %type(tbl)) - - if columns is None and source is None: - raise ValueError("One of columns or source must be specified") - - schema, table_name = tbl.split('.') if '.' in tbl else [None, tbl] - - columns = map(sqlalchemy.Column, columns) if columns is not None else tuple() - - # TODO: pybigquery uses schema to mean project_id, so we cannot use - # siuba's classic breakdown "{schema}.{table_name}". Basically - # pybigquery uses "{schema=project_id}.{dataset_dot_table_name}" in its internal - # logic. An important side effect is that bigquery errors for - # `dataset`.`table`, but not `dataset.table`. - if source and source.dialect.name == "bigquery": - table_name = tbl - schema = None - - return sqlalchemy.Table( - table_name, - sqlalchemy.MetaData(bind = source), - *columns, - schema = schema, - autoload_with = source if not columns else None - ) - - def _get_preview(self): - # need to make prev op a cte, so we don't override any previous limit - new_sel = self.last_select.limit(5) - tbl_small = self.append_op(new_sel) - return collect(tbl_small) - - def __repr__(self): - template = ( - "# Source: lazy query\n" - "# DB Conn: {}\n" - "# Preview:\n{}\n" - "# .. may have more rows" - ) - - return template.format(repr(self.source.engine), repr(self._get_preview())) - - def _repr_html_(self): - template = ( - "
" - "
"
-                "# Source: lazy query\n"
-                "# DB Conn: {}\n"
-                "# Preview:\n"
-                "
" - "{}" - "

# .. may have more rows

" - "
" - ) - - data = self._get_preview() - - # _repr_html_ can not exist or return None, to signify that repr should be used - if not hasattr(data, '_repr_html_'): - return None - - html_data = data._repr_html_() - if html_data is None: - return None - - return template.format(self.source.engine, html_data) - - -def _repr_grouped_df_html_(self): - return "

(grouped data frame)

" + self._selected_obj._repr_html_() + "
" - - - -# Main Funcs -# ============================================================================= - -# sql raw -------------- - -sql_raw = sql.literal_column - -# show query ----------- @show_query.register(LazyTbl) def _show_query(tbl, simplify = False, return_table = True): @@ -542,854 +36,3 @@ def _show_query(tbl, simplify = False, return_table = True): return tbl return str(explained) - - - -# collect ---------- - -@collect.register(LazyTbl) -def _collect(__data, as_df = True): - # TODO: maybe remove as_df options, always return dataframe - - if isinstance(__data.source, MockConnection): - # a mock sqlalchemy is being used to show_query, and echo queries. - # it doesn't return a result object or have a context handler, so - # we need to bail out early - return - - # compile query ---- - - if _is_dialect_duckdb(__data.source): - # TODO: can be removed once next release of duckdb fixes: - # https://github.com/duckdb/duckdb/issues/2972 - query = __data.last_select - compiled = query.compile( - dialect = __data.source.dialect, - compile_kwargs = {"literal_binds": True} - ) - else: - compiled = __data.last_select - - # execute query ---- - - with __data.source.connect() as conn: - if as_df: - sql_db = _FixedSqlDatabase(conn) - - if _is_dialect_duckdb(__data.source): - # TODO: pandas read_sql is very slow with duckdb. - # see https://github.com/pandas-dev/pandas/issues/45678 - # going to handle here for now. address once LazyTbl gets - # subclassed per backend. - duckdb_con = conn.connection.c - return duckdb_con.query(str(compiled)).to_df() - else: - # - return sql_db.read_sql(compiled) - - return conn.execute(compiled) - - -@select.register(LazyTbl) -def _select(__data, *args, **kwargs): - # see https://stackoverflow.com/questions/25914329/rearrange-columns-in-sqlalchemy-select-object - if kwargs: - raise NotImplementedError( - "Using kwargs in select not currently supported. 
" - "Use _.newname == _.oldname instead" - ) - last_sel = __data.last_select - columns = {c.key: c for c in last_sel.inner_columns} - - # same as for DataFrame - colnames = Series(list(columns)) - vl = VarList() - evaluated = (arg(vl) if callable(arg) else arg for arg in args) - od = var_select(colnames, *evaluated) - - missing_groups, group_keys = _select_group_renames(od, __data.group_by) - - if missing_groups: - _warn_missing(missing_groups) - - final_od = {**{k: None for k in missing_groups}, **od} - - col_list = [] - for k,v in final_od.items(): - col = columns[k] - col_list.append(col if v is None else col.label(v)) - - return __data.append_op( - last_sel.with_only_columns(col_list), - group_by = group_keys - ) - - - -@filter.register(LazyTbl) -def _filter(__data, *args): - # Note: currently always produces 2 additional select statements, - # 1 for window/aggs, and 1 for the where clause - - sel = __data.last_op.alias() # original select - win_sel = sel.select() - - conds = [] - windows = [] - with _set_data_context(__data, window=True): - for ii, arg in enumerate(args): - - if isinstance(arg, Call): - new_call = __data.shape_call(arg, verb_name = "Filter", arg_name = ii) - #var_cols = new_call.op_vars(attr_calls = False) - - # note that a new win_sel is returned, w/ window columns appended - col_expr, win_cols, win_sel = __data.track_call_windows( - new_call, - sel.columns, - window_cte = win_sel - ) - - if isinstance(col_expr, sql.base.ImmutableColumnCollection): - conds.extend(col_expr) - else: - conds.append(col_expr) - - windows.extend(win_cols) - - else: - conds.append(arg) - - bool_clause = sql.and_(*conds) - - # first cte, windows ---- - if len(windows): - - win_alias = win_sel.alias() - - # move non-window functions to refer to win_sel clause (not the innermost) --- - bool_clause = sql.util.ClauseAdapter(win_alias) \ - .traverse(bool_clause) - - orig_cols = [win_alias.columns[k] for k in sel.columns.keys()] - else: - orig_cols = [sel] - - # create 
second cte ---- - filt_sel = _sql_select(orig_cols).where(bool_clause) - return __data.append_op(filt_sel) - - -@mutate.register(LazyTbl) -def _mutate(__data, *args, **kwargs): - # TODO: verify it can follow a renaming select - - # track labeled columns in set - if not (len(args) or len(kwargs)): - return __data.append_op(__data.last_op) - - names, sel_out = _mutate_cols(__data, args, kwargs, "Mutate") - return __data.append_op(sel_out) - - -def _sql_upsert_columns(sel, new_columns: "list[base.Label | base.Column]"): - orig_cols = lift_inner_cols(sel) - replaced = {**orig_cols} - - for new_col in new_columns: - replaced[new_col.name] = new_col - return _sql_with_only_columns(sel, list(replaced.values())) - - -def _select_mutate_result(src_sel, expr_result): - dst_alias = src_sel.alias() - src_columns = set(lift_inner_cols(src_sel)) - replacer = SqlLabelReplacer(set(src_columns), dst_alias.columns) - - if isinstance(expr_result, sql.base.ImmutableColumnCollection): - replaced_cols = list(map(replacer, expr_result)) - orig_cols = expr_result - #elif isinstance(expr_result, None): - # pass - else: - replaced_cols = [replacer(expr_result)] - orig_cols = [expr_result] - - if replacer.applied: - return _sql_upsert_columns(dst_alias.select(), replaced_cols) - - return _sql_upsert_columns(src_sel, orig_cols) - - -def _eval_expr_arg(__data, sel, func, verb_name, window=True): - inner_cols = lift_inner_cols(sel) - - # case 1: simple names ---- - simple_name = simple_varname(func) - if simple_name is not None: - return inner_cols[simple_name] - - # case 2: across ---- - _require_across(func, verb_name) - - cols_result = _eval_with_context(__data, window, inner_cols, func) - - # TODO: remove or raise a more informative error - assert isinstance(cols_result, sql.base.ImmutableColumnCollection), type(cols_result) - - return cols_result - - -def _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=True): - inner_cols = lift_inner_cols(sel) - - expr_shaped = 
__data.shape_call(func, window, verb_name = verb_name, arg_name = new_name) - new_col, windows, _ = __data.track_call_windows(expr_shaped, inner_cols) - - if isinstance(new_col, sql.base.ImmutableColumnCollection): - raise TyepError( - f"{verb_name} named arguments must return a single column, but `{k}` " - "returned multiple columns." - ) - - return new_col.label(new_name) - - -def _mutate_cols(__data, args, kwargs, verb_name): - result_names = {} # used as ordered set - sel = __data.last_select - - for ii, func in enumerate(args): - cols_result = _eval_expr_arg(__data, sel, func, verb_name) - - # replace any labels that require a subquery ---- - sel = _select_mutate_result(sel, cols_result) - - if isinstance(cols_result, sql.base.ImmutableColumnCollection): - result_names.update({k: True for k in cols_result.keys()}) - else: - result_names[cols_result.name] = True - - - for new_name, func in kwargs.items(): - labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name) - - sel = _select_mutate_result(sel, labeled) - result_names[new_name] = True - - - return list(result_names), sel - - -@transmute.register(LazyTbl) -def _transmute(__data, *args, **kwargs): - # will use mutate, then select some cols - result_names, sel = _mutate_cols(__data, args, kwargs, "Transmute") - - # transmute keeps grouping cols, and any defined in kwargs - missing = [x for x in __data.group_by if x not in result_names] - cols_to_keep = [*missing, *result_names] - - columns = lift_inner_cols(sel) - sel_stripped = sel.with_only_columns([columns[k] for k in cols_to_keep]) - - return __data.append_op(sel_stripped) - - -@arrange.register(LazyTbl) -def _arrange(__data, *args): - # Note that SQL databases often do not subquery order by clauses. Arrange - # sets order_by on the backend, so it can set order by in over elements, - # and handle when new columns are named the same as order by vars. 
- # see: https://dba.stackexchange.com/q/82930 - - last_sel = __data.last_select - cols = lift_inner_cols(last_sel) - - # TODO: implement across in arrange - #exprs, _ = _mutate_cols(__data, args, kwargs, "Arrange", arrange_clause=True) - - new_calls = [] - for ii, expr in enumerate(args): - if callable(expr): - - res = __data.shape_call( - expr, window = False, - verb_name = "Arrange", arg_name = ii - ) - - else: - res = expr - - new_calls.append(res) - - sort_cols = _create_order_by_clause(cols, *new_calls) - - order_by = __data.order_by + tuple(new_calls) - return __data.append_op(last_sel.order_by(*sort_cols), order_by = order_by) - - -# TODO: consolidate / pull expr handling funcs into own file? -def _create_order_by_clause(columns, *args): - sort_cols = [] - for arg in args: - # simple named column - if isinstance(arg, str): - sort_cols.append(columns[arg]) - # an expression - elif callable(arg): - # handle special case where -_.colname -> colname DESC - f, asc = _call_strip_ascending(arg) - col_op = f(columns) if asc else f(columns).desc() - #col_op = arg(columns) - sort_cols.append(col_op) - else: - raise NotImplementedError("Must be string or callable") - - return sort_cols - - - -@count.register(LazyTbl) -def _count(__data, *args, sort = False, wt = None, **kwargs): - # TODO: if already col named n, use name nn, etc.. get logic from tidy.py - if wt is not None: - raise NotImplementedError("TODO") - - res_name = "n" - # similar to filter verb, we need two select statements, - # an inner one for derived cols, and outer to group by them - - # inner select ---- - # holds any mutation style columns - #arg_names = [] - #for arg in args: - # name = simple_varname(arg) - # if name is None: - # raise NotImplementedError( - # "Count positional arguments must be single column name. " - # "Use a named argument to count using complex expressions." 
- # ) - # arg_names.append(name) - - result_names, sel_inner = _mutate_cols(__data, args, kwargs, "Count") - - # remove unnecessary select, if we're operating on a table ---- - if set(lift_inner_cols(sel_inner)) == set(lift_inner_cols(__data.last_select)): - sel_inner = __data.last_op - - # create outer select ---- - # holds selected columns and tally (n) - sel_inner_cte = sel_inner.alias() - inner_cols = sel_inner_cte.columns - - # apply any group vars from a group_by verb call first - missing = [k for k in __data.group_by if k not in result_names] - - all_group_names = ordered_union(__data.group_by, result_names) - outer_group_cols = [inner_cols[k] for k in all_group_names] - - # holds the actual count (e.g. n) - count_col = sql.functions.count().label(res_name) - - sel_outer = _sql_select([*outer_group_cols, count_col]) \ - .select_from(sel_inner_cte) \ - .group_by(*outer_group_cols) - - # count is like summarize, so removes order_by - return __data.append_op( - sel_outer.order_by(count_col.desc()), - order_by = tuple() - ) - - -@add_count.register(LazyTbl) -def _add_count(__data, *args, wt = None, sort = False, **kwargs): - counts = count(__data, *args, wt = wt, sort = sort, **kwargs) - by = list(c.name for c in counts.last_select.inner_columns)[:-1] - - return inner_join(__data, counts, by = by) - - -@summarize.register(LazyTbl) -def _summarize(__data, *args, **kwargs): - # https://stackoverflow.com/questions/14754994/why-is-sqlalchemy-count-much-slower-than-the-raw-query - - # get query with correct from clause, and maybe unneeded subquery - safe_from = __data.last_select.alias() - result_names, sel = _aggregate_cols(__data, safe_from, args, kwargs, "Summarize") - - # see if we can remove subquery - out_sel = _collapse_select(sel, safe_from) - - from_tbl = get_single_from(out_sel) - group_cols = [from_tbl.columns[k] for k in __data.group_by] - - final_sel = out_sel.group_by(*group_cols) - - new_data = __data.append_op(final_sel, group_by = tuple(), order_by = 
tuple()) - return new_data - - -def _collapse_select(outer_sel, inner_alias): - # check whether any outer columns reference an inner label ---- - inner_sel = inner_alias.element - - columns = lift_inner_cols(outer_sel) - inner_cols = lift_inner_cols(inner_sel) - - inner_labels = set([ - x.name for x in inner_cols - if isinstance(x, sql.elements.Label) - ]) - - col_requires_cte = set(inner_alias.columns[k] for k in inner_labels) - - bad_refs = [] - - def collect_refs(el): - if el in col_requires_cte: - bad_refs.append(el) - - for col in columns: - sql.util.visitors.traverse(col, {}, {"column": collect_refs}) - - # if possible, remove the outer query ---- - if not (bad_refs or len(inner_sel._group_by_clause)): - from sqlalchemy.sql.elements import ColumnClause, Label - - from_obj = get_single_from(inner_sel) - adaptor = sql.util.ClauseAdapter( - from_obj, - adapt_on_names=True, - include_fn=lambda c: isinstance(c, (ColumnClause, Label)) - ) - - new_cols = [] - for col in columns: - if isinstance(col, Label): - res = adaptor.traverse(col.element).label(col.name) - new_cols.append(res) - - else: - new_cols.append(adaptor.traverse(col)) - #new_cols = list(map(adaptor.traverse, columns)) - - return _sql_with_only_columns(inner_sel, new_cols) - - return outer_sel - - -def _aggregate_cols(__data, subquery, args, kwargs, verb_name): - # cases: - # * grouping cols can not be overwritten (in dbplyr they can't be ref'd) - # * no existing labels referred to - can use same select - # * existing labels referred to - need 1 subquery tops - # * groups + summarize columns can replace everything - - def get_label_clauses(clause): - out = [] - sql.util.visitors.traverse(clause, {}, {"label": lambda c: out.append(c)}) - - return out - - def quote_varname(x): - return f"`{x}`" - - def validate_references(arg_name, expr, verb_name): - bad_varnames = get_label_clauses(expr) - repr_names = ", ".join(map(quote_varname, [el.name for el in bad_varnames])) - - if not bad_varnames: - return - - 
raise NotImplementedError( - f"In SQL, you cannot refer to a column created in the same {verb_name}. " - f"`{arg_name}` refers to columns created earlier: {repr_names}." - ) - - sel = subquery.select() - - final_cols = {k: subquery.columns[k] for k in __data.group_by} - - # handle args ---- - for ii, func in enumerate(args): - cols_result = _eval_expr_arg(__data, sel, func, verb_name, window=False) - - for col in cols_result: - validate_references(col.name, col.element, verb_name) - final_cols[col.name] = col - - sel = _sql_upsert_columns(sel, cols_result) - - - # handle kwargs ---- - for new_name, func in kwargs.items(): - labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=False) - - validate_references(labeled.name, labeled.element, verb_name) - final_cols[new_name] = labeled - - sel = _sql_upsert_columns(sel, [labeled]) - - return list(final_cols), _sql_with_only_columns(sel, list(final_cols.values())) - - -@group_by.register(LazyTbl) -def _group_by(__data, *args, add = False, **kwargs): - if not (args or kwargs): - return __data.copy() - - group_names, sel = _mutate_cols(__data, args, kwargs, "Group by") - - if None in group_names: - raise NotImplementedError("Complex, unnamed expressions not supported in sql group_by") - - # check whether we can just use underlying table ---- - new_cols = lift_inner_cols(sel) - if set(new_cols).issubset(set(__data.last_op.columns)): - sel = __data.last_op - - if add: - group_names = ordered_union(__data.group_by, group_names) - - return __data.append_op(sel, group_by = tuple(group_names)) - - -@ungroup.register(LazyTbl) -def _ungroup(__data): - return __data.copy(group_by = tuple()) - - -@case_when.register(sql.base.ImmutableColumnCollection) -def _case_when(__data, cases): - # TODO: will need listener to enter case statements, to handle when they use windows - if isinstance(cases, Call): - cases = cases(__data) - - whens = [] - case_items = list(cases.items()) - n_items = len(case_items) - - else_val = 
None - for ii, (expr, val) in enumerate(case_items): - # handle where val is a column expr - if callable(val): - val = val(__data) - - # handle when expressions - if ii+1 == n_items and expr is True: - else_val = val - elif callable(expr): - whens.append((expr(__data), val)) - else: - whens.append((expr, val)) - - return sql.case(whens, else_ = else_val) - - -# Join ------------------------------------------------------------------------ - -from collections.abc import Mapping - -def _joined_cols(left_cols, right_cols, on_keys, how, suffix = ("_x", "_y")): - """Return labeled columns, according to selection rules for joins. - - Rules: - 1. For join keys, keep left table's column - 2. When keys have the same labels, add suffix - """ - - # TODO: remove sets, so uses stable ordering - # when left and right cols have same name, suffix with _x / _y - keep_right = set(right_cols.keys()) - set(on_keys.values()) - shared_labs = set(left_cols.keys()).intersection(keep_right) - - right_cols_no_keys = {k: right_cols[k] for k in keep_right} - - # for an outer join, have key columns coalesce values - - left_cols = {**left_cols} - if how == "full": - for lk, rk in on_keys.items(): - col = sql.functions.coalesce(left_cols[lk], right_cols[rk]) - left_cols[lk] = col.label(lk) - elif how == "right": - for lk, rk in on_keys.items(): - # Make left key columns actually be right ones (which contain left + extra) - left_cols[lk] = right_cols[rk].label(lk) - - - # create labels ---- - l_labs = _relabeled_cols(left_cols, shared_labs, suffix[0]) - r_labs = _relabeled_cols(right_cols_no_keys, shared_labs, suffix[1]) - - return l_labs + r_labs - - - -def _relabeled_cols(columns, keys, suffix): - # add a suffix to all columns with names in keys - cols = [] - for k, v in columns.items(): - new_col = v.label(k + str(suffix)) if k in keys else v - cols.append(new_col) - return cols - - -@join.register(LazyTbl) -def _join(left, right, on = None, *args, by = None, how = "inner", sql_on = None): - 
_raise_if_args(args) - - if on is None and by is not None: - on = by - - # Needs to be on the table, not the select - left_sel = left.last_op.alias() - right_sel = right.last_op.alias() - - # handle arguments ---- - on = _validate_join_arg_on(on, sql_on) - how = _validate_join_arg_how(how) - - # for equality join used to combine keys into single column - consolidate_keys = on if sql_on is None else {} - - if how == "right": - # switch joins, since sqlalchemy doesn't have right join arg - # see https://stackoverflow.com/q/11400307/1144523 - left_sel, right_sel = right_sel, left_sel - on = {v:k for k,v in on.items()} - - # create join conditions ---- - bool_clause = _create_join_conds(left_sel, right_sel, on) - - # create join ---- - join = left_sel.join( - right_sel, - onclause = bool_clause, - isouter = how != "inner", - full = how == "full" - ) - - # if right join, set selects back - if how == "right": - left_sel, right_sel = right_sel, left_sel - on = {v:k for k,v in on.items()} - - # note, shared_keys assumes on is a mapping... 
- # TODO: shared_keys appears to be for when on is not specified, but was unused - #shared_keys = [k for k,v in on.items() if k == v] - labeled_cols = _joined_cols( - left_sel.columns, - right_sel.columns, - on_keys = consolidate_keys, - how = how - ) - - sel = _sql_select(labeled_cols).select_from(join) - return left.append_op(sel, order_by = tuple()) - - -@semi_join.register(LazyTbl) -def _semi_join(left, right = None, on = None, *args, by = None, sql_on = None): - if on is None and by is not None: - on = by - - _raise_if_args(args) - - left_sel = left.last_op.alias() - right_sel = right.last_op.alias() - - # handle arguments ---- - on = _validate_join_arg_on(on, sql_on, left_sel, right_sel) - - # create join conditions ---- - bool_clause = _create_join_conds(left_sel, right_sel, on) - - # create inner join ---- - exists_clause = _sql_select([sql.literal(1)]) \ - .select_from(right_sel) \ - .where(bool_clause) - - # only keep left hand select's columns ---- - sel = _sql_select(left_sel.columns) \ - .select_from(left_sel) \ - .where(sql.exists(exists_clause)) - - return left.append_op(sel, order_by = tuple()) - - -@anti_join.register(LazyTbl) -def _anti_join(left, right = None, on = None, *args, by = None, sql_on = None): - if on is None and by is not None: - on = by - - _raise_if_args(args) - - left_sel = left.last_op.alias() - right_sel = right.last_op.alias() - - # handle arguments ---- - on = _validate_join_arg_on(on, sql_on, left, right) - - # create join conditions ---- - bool_clause = _create_join_conds(left_sel, right_sel, on) - - # create inner join ---- - #not_exists = ~sql.exists([1], from_obj = right_sel).where(bool_clause) - exists_clause = _sql_select([sql.literal(1)]) \ - .select_from(right_sel) \ - .where(bool_clause) - - sel = left_sel.select().where(~sql.exists(exists_clause)) - - return left.append_op(sel, order_by = tuple()) - -def _raise_if_args(args): - if len(args): - raise NotImplemented("*args is reserved for future arguments (e.g. 
suffix)") - -def _validate_join_arg_on(on, sql_on = None, lhs = None, rhs = None): - # handle sql on case - if sql_on is not None: - if on is not None: - raise ValueError("Cannot specify both on and sql_on") - - return sql_on - - # handle general cases - if on is None: - # TODO: currently, we check for lhs and rhs tables to indicate whether - # a verb supports inferring columns. Otherwise, raise an error. - if lhs is not None and rhs is not None: - # TODO: consolidate with duplicate logic in pandas verb code - warnings.warn( - "No on column passed to join. " - "Inferring join columns instead using shared column names." - ) - - on_cols = list(set(lhs.columns.keys()).intersection(set(rhs.columns.keys()))) - - if not on_cols: - raise ValueError( - "No join column specified, or shared column names in join." - ) - - # trivial dict mapping shared names to themselves - warnings.warn("Detected shared columns: %s" % on_cols) - on = dict(zip(on_cols, on_cols)) - - else: - raise NotImplementedError("on arg currently cannot be None (default) for SQL") - elif isinstance(on, str): - on = {on: on} - elif isinstance(on, (list, tuple)): - on = dict(zip(on, on)) - - - if not isinstance(on, Mapping): - raise TypeError("on must be a Mapping (e.g. 
dict)") - - return on - -def _validate_join_arg_how(how): - how_options = ("inner", "left", "right", "full") - if how not in how_options: - raise ValueError("how argument needs to be one of %s" %how_options) - - return how - -def _create_join_conds(left_sel, right_sel, on): - left_cols = left_sel.columns #lift_inner_cols(left_sel) - right_cols = right_sel.columns #lift_inner_cols(right_sel) - - if callable(on): - # callable, like with sql_on arg - conds = [on(left_cols, right_cols)] - else: - # dict-like of form {left: right} - conds = [] - for l, r in on.items(): - col_expr = left_cols[l] == right_cols[r] - conds.append(col_expr) - - return sql.and_(*conds) - - -# Head ------------------------------------------------------------------------ - -@head.register(LazyTbl) -def _head(__data, n = 5): - sel = __data.last_select - - return __data.append_op(sel.limit(n)) - - -# Rename ---------------------------------------------------------------------- - -@rename.register(LazyTbl) -def _rename(__data, **kwargs): - sel = __data.last_select - columns = lift_inner_cols(sel) - - # old_keys uses dict as ordered set - old_to_new = {simple_varname(v):k for k,v in kwargs.items()} - - if None in old_to_new: - raise KeyError("positional arguments must be simple column, " - "e.g. 
_.colname or _['colname']" - ) - - labs = [c.label(old_to_new[k]) if k in old_to_new else c for k,c in columns.items()] - - new_sel = sel.with_only_columns(labs) - - missing_groups, group_keys = _select_group_renames(old_to_new, __data.group_by) - - return __data.append_op(new_sel, group_by=group_keys) - - -# Distinct -------------------------------------------------------------------- - -@distinct.register(LazyTbl) -def _distinct(__data, *args, _keep_all = False, **kwargs): - if (args or kwargs) and _keep_all: - raise NotImplementedError("Distinct with variables specified in sql requires _keep_all = False") - - inner_sel = mutate(__data, **kwargs).last_select if kwargs else __data.last_select - - # TODO: this is copied from the df distinct version - # cols dict below is used as ordered set - cols = _var_select_simple(args) - cols.update(kwargs) - - # use all columns by default - if not cols: - cols = {k: True for k in lift_inner_cols(inner_sel).keys()} - - final_names = {**{k: True for k in __data.group_by}, **cols} - - if not len(inner_sel._order_by_clause): - # select distinct has to include any columns in the order by clause, - # so can only safely modify existing statement when there's no order by - sel_cols = lift_inner_cols(inner_sel) - distinct_cols = [sel_cols[k] for k in final_names] - sel = inner_sel.with_only_columns(distinct_cols).distinct() - else: - # fallback to cte - cte = inner_sel.alias() - distinct_cols = [cte.columns[k] for k in final_names] - sel = _sql_select(distinct_cols).select_from(cte).distinct() - - return __data.append_op(sel) - - -# if_else --------------------------------------------------------------------- - -@if_else.register(sql.elements.ColumnElement) -def _if_else(cond, true_vals, false_vals): - whens = [(cond, true_vals)] - return sql.case(whens, else_ = false_vals) - - diff --git a/siuba/sql/verbs/filter.py b/siuba/sql/verbs/filter.py index 8dcbef7b..9275c651 100644 --- a/siuba/sql/verbs/filter.py +++ 
b/siuba/sql/verbs/filter.py @@ -1,629 +1,12 @@ -""" -Implements LazyTbl to represent tables of SQL data, and registers it on verbs. +from siuba.dply.verbs import filter -This module is responsible for the handling of the "table" side of things, while -translate.py handles translating column operations. - - -""" - -import warnings - -from siuba.dply.verbs import ( - show_query, collect, - simple_varname, - select, - mutate, - transmute, - filter, - arrange, _call_strip_ascending, - summarize, - count, add_count, - group_by, ungroup, - case_when, - join, left_join, right_join, inner_join, semi_join, anti_join, - head, - rename, - distinct, - if_else, - _select_group_renames, - _var_select_simple - ) - -from siuba.dply.tidyselect import VarList, var_select - -from .translate import CustomOverClause, SqlColumn, SqlColumnAgg -from .utils import ( - get_dialect_translator, - _FixedSqlDatabase, - _is_dialect_duckdb, - _sql_select, - _sql_column_collection, - _sql_add_columns, - _sql_with_only_columns, - _sql_simplify_select, - MockConnection -) +from ..backend import LazyTbl +from ..utils import _sql_select from sqlalchemy import sql -import sqlalchemy -from siuba.siu import Call, Lazy, FunctionLookupError, singledispatch2 -# TODO: currently needed for select, but can we remove pandas? -from pandas import Series -from functools import singledispatch - -from sqlalchemy.sql import schema - -from siuba.dply.across import _require_across, _set_data_context, _eval_with_context - -# TODO: -# - distinct -# - annotate functions using sel.prefix_with("\n/**/\n") ? - - -# Helpers --------------------------------------------------------------------- - -class SqlFunctionLookupError(FunctionLookupError): pass - - -class CallListener: - """Generic listener. 
Each exit is called on a node's copy.""" - def enter(self, node): - args, kwargs = node.map_subcalls(self.enter) - - return self.exit(node.__class__(node.func, *args, **kwargs)) - - def exit(self, node): - return node - - -class WindowReplacer(CallListener): - """Call tree listener. - - Produces 2 important behaviors via the enter method: - - returns evaluated sql call expression, with labels on all window expressions. - - stores all labeled window expressions via the windows property. - - TODO: could replace with a sqlalchemy transformer - """ - - def __init__(self, columns, group_by, order_by, window_cte = None): - self.columns = columns - self.group_by = group_by - self.order_by = order_by - self.window_cte = window_cte - self.windows = [] - - def exit(self, node): - col_expr = node(self.columns) - - if not isinstance(col_expr, sql.elements.ClauseElement): - return col_expr - - over_clauses = [x for x in self._get_over_clauses(col_expr) if isinstance(x, CustomOverClause)] - - # put groupings and orderings onto custom over clauses - for over in over_clauses: - # TODO: shouldn't mutate these over clauses - group_by = sql.elements.ClauseList( - *[self.columns[name] for name in self.group_by] - ) - order_by = sql.elements.ClauseList( - *_create_order_by_clause(self.columns, *self.order_by) - ) - - over.set_over(group_by, order_by) - - if len(over_clauses) and self.window_cte is not None: - # custom name, or parameters like "%(...)s" may nest and break psycopg2 - # with columns you can set a key to fix this, but it doesn't seem to - # be an option with labels - name = self._get_unique_name('win', lift_inner_cols(self.window_cte)) - label = col_expr.label(name) - - # put into CTE, and return its resulting column, so that subsequent - # operations will refer to the window column on window_cte. 
Note that - # the operations will use the actual column, so may need to use the - # ClauseAdaptor to make it a reference to the label - self.window_cte = _sql_add_columns(self.window_cte, [label]) - win_col = lift_inner_cols(self.window_cte).values()[-1] - self.windows.append(win_col) - - return win_col - - return col_expr - - @staticmethod - def _get_unique_name(prefix, columns): - column_names = set(columns.keys()) - - i = 1 - name = prefix + str(i) - while name in column_names: - i += 1 - name = prefix + str(i) - - - return name - - @staticmethod - def _get_over_clauses(clause): - windows = [] - append_win = lambda col: windows.append(col) - - sql.util.visitors.traverse(clause, {}, {"over": append_win}) - - return windows - - -class SqlLabelReplacer: - """Create a visitor to replace source labels with destination. - - Note that this is meant to be used with sqlalchemy visitors. - """ - - def __init__(self, src_columns, dst_columns): - self.src_columns = src_columns - self.src_labels = set([x for x in src_columns if isinstance(x, sql.elements.Label)]) - self.dst_columns = dst_columns - self.applied = False - - def __call__(self, clause): - return sql.util.visitors.replacement_traverse(clause, {}, self.visit) - - def visit(self, el): - from sqlalchemy.sql.elements import ColumnClause, Label, ClauseElement, TypeClause - from sqlalchemy.sql.schema import Column - - if isinstance(el, TypeClause): - # TODO: for some reason this type throws an error if unguarded - return None - - if isinstance(el, ClauseElement): - if el in self.src_labels: - self.applied = True - return self.dst_columns[el.name] - elif el in self.src_columns: - return self.dst_columns[el.name] - - # TODO: should we create a subquery if the user passed raw text? 
- #elif isinstance(el, ColumnClause) and not isinstance(el, Column): - # # Raw SQL, which will need a subquery, but not substitution - # if el.key != "*": - # self.applied = True - - return None - - -#def track_call_windows(call, columns, group_by, order_by, window_cte = None): -# listener = WindowReplacer(columns, group_by, order_by, window_cte) -# col = listener.enter(call) -# return col, listener.windows, listener.window_cte - - -def track_call_windows(call, columns, group_by, order_by, window_cte = None): - col_expr = call(columns) - - crnt_group_by = sql.elements.ClauseList( - *[columns[name] for name in group_by] - ) - crnt_order_by = sql.elements.ClauseList( - *_create_order_by_clause(columns, *order_by) - ) - return replace_call_windows(col_expr, crnt_group_by, crnt_order_by, window_cte) - - - -@singledispatch -def replace_call_windows(col_expr, group_by, order_by, window_cte = None): - raise TypeError(str(type(col_expr))) - - -@replace_call_windows.register(sql.base.ImmutableColumnCollection) -def _(col_expr, group_by, order_by, window_cte = None): - all_over_clauses = [] - for col in col_expr: - _, over_clauses, window_cte = replace_call_windows( - col, - group_by, - order_by, - window_cte - ) - all_over_clauses.extend(over_clauses) - - return col_expr, all_over_clauses, window_cte - - -@replace_call_windows.register(sql.elements.ClauseElement) -def _(col_expr, group_by, order_by, window_cte = None): - - over_clauses = WindowReplacer._get_over_clauses(col_expr) - - for over in over_clauses: - # TODO: shouldn't mutate these over clauses - over.set_over(group_by, order_by) - - if len(over_clauses) and window_cte is not None: - # custom name, or parameters like "%(...)s" may nest and break psycopg2 - # with columns you can set a key to fix this, but it doesn't seem to - # be an option with labels - name = WindowReplacer._get_unique_name('win', lift_inner_cols(window_cte)) - label = col_expr.label(name) - - # put into CTE, and return its resulting column, so 
that subsequent - # operations will refer to the window column on window_cte. Note that - # the operations will use the actual column, so may need to use the - # ClauseAdaptor to make it a reference to the label - window_cte = _sql_add_columns(window_cte, [label]) - win_col = lift_inner_cols(window_cte).values()[-1] - - return win_col, over_clauses, window_cte - - return col_expr, over_clauses, window_cte - -def get_single_from(sel): - froms = sel.froms - - n_froms = len(froms) - if n_froms != 1: - raise ValueError( - f"Expected a single table in the from clause, but found {n_froms}" - ) - - return froms[0] - -def lift_inner_cols(tbl): - cols = list(tbl.inner_columns) - - return _sql_column_collection(cols) - -def col_expr_requires_cte(call, sel, is_mutate = False): - """Return whether a variable assignment needs a CTE""" - - call_vars = set(call.op_vars(attr_calls = False)) - - sel_labs = get_inner_labels(sel) - - # I use the acronym fwg sol (frog soul) to remember sql clause eval order - # from, where, group by, select, order by, limit - # group clause evaluated before select clause, so not issue for mutate - group_needs_cte = not is_mutate and len(sel._group_by_clause) - - return ( group_needs_cte - # TODO: detect when a new var in mutate conflicts w/ order by - #or len(sel._order_by_clause) - or not sel_labs.isdisjoint(call_vars) - ) - -def get_inner_labels(sel): - columns = lift_inner_cols(sel) - sel_labs = set(k for k,v in columns.items() if isinstance(v, sql.elements.Label)) - return sel_labs - -def get_missing_columns(call, columns): - missing_cols = set(call.op_vars(attr_calls = False)) - set(columns.keys()) - return missing_cols - -def compile_el(tbl, el): - compiled = el.compile( - dialect = tbl.source.dialect, - compile_kwargs = {"literal_binds": True} - ) - return compiled - -# Misc utilities -------------------------------------------------------------- - -def ordered_union(x, y): - dx = {el: True for el in x} - dy = {el: True for el in y} - - return 
tuple({**dx, **dy}) - - -def _warn_missing(missing_groups): - warnings.warn(f"Adding missing grouping variables: {missing_groups}") - - -# Table ----------------------------------------------------------------------- - -class LazyTbl: - def __init__( - self, source, tbl, columns = None, - ops = None, group_by = tuple(), order_by = tuple(), - translator = None - ): - """Create a representation of a SQL table. - - Args: - source: a sqlalchemy.Engine or sqlalchemy.Connection instance. - tbl: table of form 'schema_name.table_name', 'table_name', or sqlalchemy.Table. - columns: if specified, a listlike of column names. - - Examples - -------- - - :: - from sqlalchemy import create_engine - from siuba.data import mtcars - - # create database and table - engine = create_engine("sqlite:///:memory:") - mtcars.to_sql('mtcars', engine) - - tbl_mtcars = LazyTbl(engine, 'mtcars') - - """ - - # connection and dialect specific functions - self.source = sqlalchemy.create_engine(source) if isinstance(source, str) else source - - # get dialect name - dialect = self.source.dialect.name - self.translator = get_dialect_translator(dialect) - - self.tbl = self._create_table(tbl, columns, self.source) - - # important states the query can be in (e.g. 
grouped) - self.ops = [self.tbl] if ops is None else ops - - self.group_by = group_by - self.order_by = order_by - - - def append_op(self, op, **kwargs): - cpy = self.copy(**kwargs) - cpy.ops = cpy.ops + [op] - return cpy - - def copy(self, **kwargs): - return self.__class__(**{**self.__dict__, **kwargs}) - - def shape_call( - self, - call, window = True, str_accessors = False, - verb_name = None, arg_name = None, - ): - return self.translator.shape_call(call, window, str_accessors, verb_name, arg_name) - - def track_call_windows(self, call, columns = None, window_cte = None): - """Returns tuple of (new column expression, list of window exprs)""" - - columns = self.last_op.columns if columns is None else columns - return track_call_windows(call, columns, self.group_by, self.order_by, window_cte) - - def get_ordered_col_names(self): - """Return columns from current select, with grouping columns first.""" - ungrouped = [k for k in self.last_op.columns.keys() if k not in self.group_by] - return list(self.group_by) + ungrouped - - #def label_breaks_order_by(self, name): - # """Returns True if a new column label would break the order by vars.""" - - # # TODO: arrange currently allows literals, which breaks this. it seems - # # better to only allow calls in arrange. - # order_by_vars = {c.op_vars(attr_calls=False) for c in self.order_by} - - - - - @property - def last_op(self) -> "sql.Table | sql.Select": - last_op = self.ops[-1] - - if last_op is None: - raise TypeError() - - return last_op - - @property - def last_select(self): - last_op = self.last_op - if not isinstance(last_op, sql.selectable.SelectBase): - return last_op.select() - - return last_op - - @staticmethod - def _create_table(tbl, columns = None, source = None): - """Return a sqlalchemy.Table, autoloading column info if needed. - - Arguments: - tbl: a sqlalchemy.Table or string of form 'table_name' or 'schema_name.table_name'. - columns: a tuple of column names for the table. Overrides source argument. 
- source: a sqlalchemy engine, used to autoload columns. - - """ - if isinstance(tbl, sql.selectable.FromClause): - return tbl - - if not isinstance(tbl, str): - raise ValueError("tbl must be a sqlalchemy Table or string, but was %s" %type(tbl)) - - if columns is None and source is None: - raise ValueError("One of columns or source must be specified") - - schema, table_name = tbl.split('.') if '.' in tbl else [None, tbl] - - columns = map(sqlalchemy.Column, columns) if columns is not None else tuple() - - # TODO: pybigquery uses schema to mean project_id, so we cannot use - # siuba's classic breakdown "{schema}.{table_name}". Basically - # pybigquery uses "{schema=project_id}.{dataset_dot_table_name}" in its internal - # logic. An important side effect is that bigquery errors for - # `dataset`.`table`, but not `dataset.table`. - if source and source.dialect.name == "bigquery": - table_name = tbl - schema = None - - return sqlalchemy.Table( - table_name, - sqlalchemy.MetaData(bind = source), - *columns, - schema = schema, - autoload_with = source if not columns else None - ) - - def _get_preview(self): - # need to make prev op a cte, so we don't override any previous limit - new_sel = self.last_select.limit(5) - tbl_small = self.append_op(new_sel) - return collect(tbl_small) - - def __repr__(self): - template = ( - "# Source: lazy query\n" - "# DB Conn: {}\n" - "# Preview:\n{}\n" - "# .. may have more rows" - ) - - return template.format(repr(self.source.engine), repr(self._get_preview())) - - def _repr_html_(self): - template = ( - "
" - "
"
-                "# Source: lazy query\n"
-                "# DB Conn: {}\n"
-                "# Preview:\n"
-                "
" - "{}" - "

# .. may have more rows

" - "
" - ) - - data = self._get_preview() - - # _repr_html_ can not exist or return None, to signify that repr should be used - if not hasattr(data, '_repr_html_'): - return None - - html_data = data._repr_html_() - if html_data is None: - return None - - return template.format(self.source.engine, html_data) - - -def _repr_grouped_df_html_(self): - return "

(grouped data frame)

" + self._selected_obj._repr_html_() + "
" - - - -# Main Funcs -# ============================================================================= - -# sql raw -------------- - -sql_raw = sql.literal_column - -# show query ----------- - -@show_query.register(LazyTbl) -def _show_query(tbl, simplify = False, return_table = True): - #query = tbl.last_op #if not simplify else - compile_query = lambda query: query.compile( - dialect = tbl.source.dialect, - compile_kwargs = {"literal_binds": True} - ) - - - if simplify: - # try to strip table names and labels where unnecessary - simple_sel = _sql_simplify_select(tbl.last_select) - - explained = compile_query(simple_sel) - else: - # use a much more verbose query - explained = compile_query(tbl.last_select) - - if return_table: - print(str(explained)) - return tbl - - return str(explained) - - - -# collect ---------- - -@collect.register(LazyTbl) -def _collect(__data, as_df = True): - # TODO: maybe remove as_df options, always return dataframe - - if isinstance(__data.source, MockConnection): - # a mock sqlalchemy is being used to show_query, and echo queries. - # it doesn't return a result object or have a context handler, so - # we need to bail out early - return - - # compile query ---- - - if _is_dialect_duckdb(__data.source): - # TODO: can be removed once next release of duckdb fixes: - # https://github.com/duckdb/duckdb/issues/2972 - query = __data.last_select - compiled = query.compile( - dialect = __data.source.dialect, - compile_kwargs = {"literal_binds": True} - ) - else: - compiled = __data.last_select - - # execute query ---- - - with __data.source.connect() as conn: - if as_df: - sql_db = _FixedSqlDatabase(conn) - - if _is_dialect_duckdb(__data.source): - # TODO: pandas read_sql is very slow with duckdb. - # see https://github.com/pandas-dev/pandas/issues/45678 - # going to handle here for now. address once LazyTbl gets - # subclassed per backend. 
- duckdb_con = conn.connection.c - return duckdb_con.query(str(compiled)).to_df() - else: - # - return sql_db.read_sql(compiled) - - return conn.execute(compiled) - - -@select.register(LazyTbl) -def _select(__data, *args, **kwargs): - # see https://stackoverflow.com/questions/25914329/rearrange-columns-in-sqlalchemy-select-object - if kwargs: - raise NotImplementedError( - "Using kwargs in select not currently supported. " - "Use _.newname == _.oldname instead" - ) - last_sel = __data.last_select - columns = {c.key: c for c in last_sel.inner_columns} - - # same as for DataFrame - colnames = Series(list(columns)) - vl = VarList() - evaluated = (arg(vl) if callable(arg) else arg for arg in args) - od = var_select(colnames, *evaluated) - - missing_groups, group_keys = _select_group_renames(od, __data.group_by) - - if missing_groups: - _warn_missing(missing_groups) - - final_od = {**{k: None for k in missing_groups}, **od} - - col_list = [] - for k,v in final_od.items(): - col = columns[k] - col_list.append(col if v is None else col.label(v)) - - return __data.append_op( - last_sel.with_only_columns(col_list), - group_by = group_keys - ) +from siuba.siu import Call +from siuba.dply.across import _set_data_context @filter.register(LazyTbl) @@ -678,718 +61,3 @@ def _filter(__data, *args): # create second cte ---- filt_sel = _sql_select(orig_cols).where(bool_clause) return __data.append_op(filt_sel) - - -@mutate.register(LazyTbl) -def _mutate(__data, *args, **kwargs): - # TODO: verify it can follow a renaming select - - # track labeled columns in set - if not (len(args) or len(kwargs)): - return __data.append_op(__data.last_op) - - names, sel_out = _mutate_cols(__data, args, kwargs, "Mutate") - return __data.append_op(sel_out) - - -def _sql_upsert_columns(sel, new_columns: "list[base.Label | base.Column]"): - orig_cols = lift_inner_cols(sel) - replaced = {**orig_cols} - - for new_col in new_columns: - replaced[new_col.name] = new_col - return _sql_with_only_columns(sel, 
list(replaced.values())) - - -def _select_mutate_result(src_sel, expr_result): - dst_alias = src_sel.alias() - src_columns = set(lift_inner_cols(src_sel)) - replacer = SqlLabelReplacer(set(src_columns), dst_alias.columns) - - if isinstance(expr_result, sql.base.ImmutableColumnCollection): - replaced_cols = list(map(replacer, expr_result)) - orig_cols = expr_result - #elif isinstance(expr_result, None): - # pass - else: - replaced_cols = [replacer(expr_result)] - orig_cols = [expr_result] - - if replacer.applied: - return _sql_upsert_columns(dst_alias.select(), replaced_cols) - - return _sql_upsert_columns(src_sel, orig_cols) - - -def _eval_expr_arg(__data, sel, func, verb_name, window=True): - inner_cols = lift_inner_cols(sel) - - # case 1: simple names ---- - simple_name = simple_varname(func) - if simple_name is not None: - return inner_cols[simple_name] - - # case 2: across ---- - _require_across(func, verb_name) - - cols_result = _eval_with_context(__data, window, inner_cols, func) - - # TODO: remove or raise a more informative error - assert isinstance(cols_result, sql.base.ImmutableColumnCollection), type(cols_result) - - return cols_result - - -def _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=True): - inner_cols = lift_inner_cols(sel) - - expr_shaped = __data.shape_call(func, window, verb_name = verb_name, arg_name = new_name) - new_col, windows, _ = __data.track_call_windows(expr_shaped, inner_cols) - - if isinstance(new_col, sql.base.ImmutableColumnCollection): - raise TyepError( - f"{verb_name} named arguments must return a single column, but `{k}` " - "returned multiple columns." 
- ) - - return new_col.label(new_name) - - -def _mutate_cols(__data, args, kwargs, verb_name): - result_names = {} # used as ordered set - sel = __data.last_select - - for ii, func in enumerate(args): - cols_result = _eval_expr_arg(__data, sel, func, verb_name) - - # replace any labels that require a subquery ---- - sel = _select_mutate_result(sel, cols_result) - - if isinstance(cols_result, sql.base.ImmutableColumnCollection): - result_names.update({k: True for k in cols_result.keys()}) - else: - result_names[cols_result.name] = True - - - for new_name, func in kwargs.items(): - labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name) - - sel = _select_mutate_result(sel, labeled) - result_names[new_name] = True - - - return list(result_names), sel - - -@transmute.register(LazyTbl) -def _transmute(__data, *args, **kwargs): - # will use mutate, then select some cols - result_names, sel = _mutate_cols(__data, args, kwargs, "Transmute") - - # transmute keeps grouping cols, and any defined in kwargs - missing = [x for x in __data.group_by if x not in result_names] - cols_to_keep = [*missing, *result_names] - - columns = lift_inner_cols(sel) - sel_stripped = sel.with_only_columns([columns[k] for k in cols_to_keep]) - - return __data.append_op(sel_stripped) - - -@arrange.register(LazyTbl) -def _arrange(__data, *args): - # Note that SQL databases often do not subquery order by clauses. Arrange - # sets order_by on the backend, so it can set order by in over elements, - # and handle when new columns are named the same as order by vars. 
- # see: https://dba.stackexchange.com/q/82930 - - last_sel = __data.last_select - cols = lift_inner_cols(last_sel) - - # TODO: implement across in arrange - #exprs, _ = _mutate_cols(__data, args, kwargs, "Arrange", arrange_clause=True) - - new_calls = [] - for ii, expr in enumerate(args): - if callable(expr): - - res = __data.shape_call( - expr, window = False, - verb_name = "Arrange", arg_name = ii - ) - - else: - res = expr - - new_calls.append(res) - - sort_cols = _create_order_by_clause(cols, *new_calls) - - order_by = __data.order_by + tuple(new_calls) - return __data.append_op(last_sel.order_by(*sort_cols), order_by = order_by) - - -# TODO: consolidate / pull expr handling funcs into own file? -def _create_order_by_clause(columns, *args): - sort_cols = [] - for arg in args: - # simple named column - if isinstance(arg, str): - sort_cols.append(columns[arg]) - # an expression - elif callable(arg): - # handle special case where -_.colname -> colname DESC - f, asc = _call_strip_ascending(arg) - col_op = f(columns) if asc else f(columns).desc() - #col_op = arg(columns) - sort_cols.append(col_op) - else: - raise NotImplementedError("Must be string or callable") - - return sort_cols - - - -@count.register(LazyTbl) -def _count(__data, *args, sort = False, wt = None, **kwargs): - # TODO: if already col named n, use name nn, etc.. get logic from tidy.py - if wt is not None: - raise NotImplementedError("TODO") - - res_name = "n" - # similar to filter verb, we need two select statements, - # an inner one for derived cols, and outer to group by them - - # inner select ---- - # holds any mutation style columns - #arg_names = [] - #for arg in args: - # name = simple_varname(arg) - # if name is None: - # raise NotImplementedError( - # "Count positional arguments must be single column name. " - # "Use a named argument to count using complex expressions." 
- # ) - # arg_names.append(name) - - result_names, sel_inner = _mutate_cols(__data, args, kwargs, "Count") - - # remove unnecessary select, if we're operating on a table ---- - if set(lift_inner_cols(sel_inner)) == set(lift_inner_cols(__data.last_select)): - sel_inner = __data.last_op - - # create outer select ---- - # holds selected columns and tally (n) - sel_inner_cte = sel_inner.alias() - inner_cols = sel_inner_cte.columns - - # apply any group vars from a group_by verb call first - missing = [k for k in __data.group_by if k not in result_names] - - all_group_names = ordered_union(__data.group_by, result_names) - outer_group_cols = [inner_cols[k] for k in all_group_names] - - # holds the actual count (e.g. n) - count_col = sql.functions.count().label(res_name) - - sel_outer = _sql_select([*outer_group_cols, count_col]) \ - .select_from(sel_inner_cte) \ - .group_by(*outer_group_cols) - - # count is like summarize, so removes order_by - return __data.append_op( - sel_outer.order_by(count_col.desc()), - order_by = tuple() - ) - - -@add_count.register(LazyTbl) -def _add_count(__data, *args, wt = None, sort = False, **kwargs): - counts = count(__data, *args, wt = wt, sort = sort, **kwargs) - by = list(c.name for c in counts.last_select.inner_columns)[:-1] - - return inner_join(__data, counts, by = by) - - -@summarize.register(LazyTbl) -def _summarize(__data, *args, **kwargs): - # https://stackoverflow.com/questions/14754994/why-is-sqlalchemy-count-much-slower-than-the-raw-query - - # get query with correct from clause, and maybe unneeded subquery - safe_from = __data.last_select.alias() - result_names, sel = _aggregate_cols(__data, safe_from, args, kwargs, "Summarize") - - # see if we can remove subquery - out_sel = _collapse_select(sel, safe_from) - - from_tbl = get_single_from(out_sel) - group_cols = [from_tbl.columns[k] for k in __data.group_by] - - final_sel = out_sel.group_by(*group_cols) - - new_data = __data.append_op(final_sel, group_by = tuple(), order_by = 
tuple()) - return new_data - - -def _collapse_select(outer_sel, inner_alias): - # check whether any outer columns reference an inner label ---- - inner_sel = inner_alias.element - - columns = lift_inner_cols(outer_sel) - inner_cols = lift_inner_cols(inner_sel) - - inner_labels = set([ - x.name for x in inner_cols - if isinstance(x, sql.elements.Label) - ]) - - col_requires_cte = set(inner_alias.columns[k] for k in inner_labels) - - bad_refs = [] - - def collect_refs(el): - if el in col_requires_cte: - bad_refs.append(el) - - for col in columns: - sql.util.visitors.traverse(col, {}, {"column": collect_refs}) - - # if possible, remove the outer query ---- - if not (bad_refs or len(inner_sel._group_by_clause)): - from sqlalchemy.sql.elements import ColumnClause, Label - - from_obj = get_single_from(inner_sel) - adaptor = sql.util.ClauseAdapter( - from_obj, - adapt_on_names=True, - include_fn=lambda c: isinstance(c, (ColumnClause, Label)) - ) - - new_cols = [] - for col in columns: - if isinstance(col, Label): - res = adaptor.traverse(col.element).label(col.name) - new_cols.append(res) - - else: - new_cols.append(adaptor.traverse(col)) - #new_cols = list(map(adaptor.traverse, columns)) - - return _sql_with_only_columns(inner_sel, new_cols) - - return outer_sel - - -def _aggregate_cols(__data, subquery, args, kwargs, verb_name): - # cases: - # * grouping cols can not be overwritten (in dbplyr they can't be ref'd) - # * no existing labels referred to - can use same select - # * existing labels referred to - need 1 subquery tops - # * groups + summarize columns can replace everything - - def get_label_clauses(clause): - out = [] - sql.util.visitors.traverse(clause, {}, {"label": lambda c: out.append(c)}) - - return out - - def quote_varname(x): - return f"`{x}`" - - def validate_references(arg_name, expr, verb_name): - bad_varnames = get_label_clauses(expr) - repr_names = ", ".join(map(quote_varname, [el.name for el in bad_varnames])) - - if not bad_varnames: - return - - 
raise NotImplementedError( - f"In SQL, you cannot refer to a column created in the same {verb_name}. " - f"`{arg_name}` refers to columns created earlier: {repr_names}." - ) - - sel = subquery.select() - - final_cols = {k: subquery.columns[k] for k in __data.group_by} - - # handle args ---- - for ii, func in enumerate(args): - cols_result = _eval_expr_arg(__data, sel, func, verb_name, window=False) - - for col in cols_result: - validate_references(col.name, col.element, verb_name) - final_cols[col.name] = col - - sel = _sql_upsert_columns(sel, cols_result) - - - # handle kwargs ---- - for new_name, func in kwargs.items(): - labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=False) - - validate_references(labeled.name, labeled.element, verb_name) - final_cols[new_name] = labeled - - sel = _sql_upsert_columns(sel, [labeled]) - - return list(final_cols), _sql_with_only_columns(sel, list(final_cols.values())) - - -@group_by.register(LazyTbl) -def _group_by(__data, *args, add = False, **kwargs): - if not (args or kwargs): - return __data.copy() - - group_names, sel = _mutate_cols(__data, args, kwargs, "Group by") - - if None in group_names: - raise NotImplementedError("Complex, unnamed expressions not supported in sql group_by") - - # check whether we can just use underlying table ---- - new_cols = lift_inner_cols(sel) - if set(new_cols).issubset(set(__data.last_op.columns)): - sel = __data.last_op - - if add: - group_names = ordered_union(__data.group_by, group_names) - - return __data.append_op(sel, group_by = tuple(group_names)) - - -@ungroup.register(LazyTbl) -def _ungroup(__data): - return __data.copy(group_by = tuple()) - - -@case_when.register(sql.base.ImmutableColumnCollection) -def _case_when(__data, cases): - # TODO: will need listener to enter case statements, to handle when they use windows - if isinstance(cases, Call): - cases = cases(__data) - - whens = [] - case_items = list(cases.items()) - n_items = len(case_items) - - else_val = 
None - for ii, (expr, val) in enumerate(case_items): - # handle where val is a column expr - if callable(val): - val = val(__data) - - # handle when expressions - if ii+1 == n_items and expr is True: - else_val = val - elif callable(expr): - whens.append((expr(__data), val)) - else: - whens.append((expr, val)) - - return sql.case(whens, else_ = else_val) - - -# Join ------------------------------------------------------------------------ - -from collections.abc import Mapping - -def _joined_cols(left_cols, right_cols, on_keys, how, suffix = ("_x", "_y")): - """Return labeled columns, according to selection rules for joins. - - Rules: - 1. For join keys, keep left table's column - 2. When keys have the same labels, add suffix - """ - - # TODO: remove sets, so uses stable ordering - # when left and right cols have same name, suffix with _x / _y - keep_right = set(right_cols.keys()) - set(on_keys.values()) - shared_labs = set(left_cols.keys()).intersection(keep_right) - - right_cols_no_keys = {k: right_cols[k] for k in keep_right} - - # for an outer join, have key columns coalesce values - - left_cols = {**left_cols} - if how == "full": - for lk, rk in on_keys.items(): - col = sql.functions.coalesce(left_cols[lk], right_cols[rk]) - left_cols[lk] = col.label(lk) - elif how == "right": - for lk, rk in on_keys.items(): - # Make left key columns actually be right ones (which contain left + extra) - left_cols[lk] = right_cols[rk].label(lk) - - - # create labels ---- - l_labs = _relabeled_cols(left_cols, shared_labs, suffix[0]) - r_labs = _relabeled_cols(right_cols_no_keys, shared_labs, suffix[1]) - - return l_labs + r_labs - - - -def _relabeled_cols(columns, keys, suffix): - # add a suffix to all columns with names in keys - cols = [] - for k, v in columns.items(): - new_col = v.label(k + str(suffix)) if k in keys else v - cols.append(new_col) - return cols - - -@join.register(LazyTbl) -def _join(left, right, on = None, *args, by = None, how = "inner", sql_on = None): - 
_raise_if_args(args) - - if on is None and by is not None: - on = by - - # Needs to be on the table, not the select - left_sel = left.last_op.alias() - right_sel = right.last_op.alias() - - # handle arguments ---- - on = _validate_join_arg_on(on, sql_on) - how = _validate_join_arg_how(how) - - # for equality join used to combine keys into single column - consolidate_keys = on if sql_on is None else {} - - if how == "right": - # switch joins, since sqlalchemy doesn't have right join arg - # see https://stackoverflow.com/q/11400307/1144523 - left_sel, right_sel = right_sel, left_sel - on = {v:k for k,v in on.items()} - - # create join conditions ---- - bool_clause = _create_join_conds(left_sel, right_sel, on) - - # create join ---- - join = left_sel.join( - right_sel, - onclause = bool_clause, - isouter = how != "inner", - full = how == "full" - ) - - # if right join, set selects back - if how == "right": - left_sel, right_sel = right_sel, left_sel - on = {v:k for k,v in on.items()} - - # note, shared_keys assumes on is a mapping... 
- # TODO: shared_keys appears to be for when on is not specified, but was unused - #shared_keys = [k for k,v in on.items() if k == v] - labeled_cols = _joined_cols( - left_sel.columns, - right_sel.columns, - on_keys = consolidate_keys, - how = how - ) - - sel = _sql_select(labeled_cols).select_from(join) - return left.append_op(sel, order_by = tuple()) - - -@semi_join.register(LazyTbl) -def _semi_join(left, right = None, on = None, *args, by = None, sql_on = None): - if on is None and by is not None: - on = by - - _raise_if_args(args) - - left_sel = left.last_op.alias() - right_sel = right.last_op.alias() - - # handle arguments ---- - on = _validate_join_arg_on(on, sql_on, left_sel, right_sel) - - # create join conditions ---- - bool_clause = _create_join_conds(left_sel, right_sel, on) - - # create inner join ---- - exists_clause = _sql_select([sql.literal(1)]) \ - .select_from(right_sel) \ - .where(bool_clause) - - # only keep left hand select's columns ---- - sel = _sql_select(left_sel.columns) \ - .select_from(left_sel) \ - .where(sql.exists(exists_clause)) - - return left.append_op(sel, order_by = tuple()) - - -@anti_join.register(LazyTbl) -def _anti_join(left, right = None, on = None, *args, by = None, sql_on = None): - if on is None and by is not None: - on = by - - _raise_if_args(args) - - left_sel = left.last_op.alias() - right_sel = right.last_op.alias() - - # handle arguments ---- - on = _validate_join_arg_on(on, sql_on, left, right) - - # create join conditions ---- - bool_clause = _create_join_conds(left_sel, right_sel, on) - - # create inner join ---- - #not_exists = ~sql.exists([1], from_obj = right_sel).where(bool_clause) - exists_clause = _sql_select([sql.literal(1)]) \ - .select_from(right_sel) \ - .where(bool_clause) - - sel = left_sel.select().where(~sql.exists(exists_clause)) - - return left.append_op(sel, order_by = tuple()) - -def _raise_if_args(args): - if len(args): - raise NotImplemented("*args is reserved for future arguments (e.g. 
suffix)") - -def _validate_join_arg_on(on, sql_on = None, lhs = None, rhs = None): - # handle sql on case - if sql_on is not None: - if on is not None: - raise ValueError("Cannot specify both on and sql_on") - - return sql_on - - # handle general cases - if on is None: - # TODO: currently, we check for lhs and rhs tables to indicate whether - # a verb supports inferring columns. Otherwise, raise an error. - if lhs is not None and rhs is not None: - # TODO: consolidate with duplicate logic in pandas verb code - warnings.warn( - "No on column passed to join. " - "Inferring join columns instead using shared column names." - ) - - on_cols = list(set(lhs.columns.keys()).intersection(set(rhs.columns.keys()))) - - if not on_cols: - raise ValueError( - "No join column specified, or shared column names in join." - ) - - # trivial dict mapping shared names to themselves - warnings.warn("Detected shared columns: %s" % on_cols) - on = dict(zip(on_cols, on_cols)) - - else: - raise NotImplementedError("on arg currently cannot be None (default) for SQL") - elif isinstance(on, str): - on = {on: on} - elif isinstance(on, (list, tuple)): - on = dict(zip(on, on)) - - - if not isinstance(on, Mapping): - raise TypeError("on must be a Mapping (e.g. 
dict)") - - return on - -def _validate_join_arg_how(how): - how_options = ("inner", "left", "right", "full") - if how not in how_options: - raise ValueError("how argument needs to be one of %s" %how_options) - - return how - -def _create_join_conds(left_sel, right_sel, on): - left_cols = left_sel.columns #lift_inner_cols(left_sel) - right_cols = right_sel.columns #lift_inner_cols(right_sel) - - if callable(on): - # callable, like with sql_on arg - conds = [on(left_cols, right_cols)] - else: - # dict-like of form {left: right} - conds = [] - for l, r in on.items(): - col_expr = left_cols[l] == right_cols[r] - conds.append(col_expr) - - return sql.and_(*conds) - - -# Head ------------------------------------------------------------------------ - -@head.register(LazyTbl) -def _head(__data, n = 5): - sel = __data.last_select - - return __data.append_op(sel.limit(n)) - - -# Rename ---------------------------------------------------------------------- - -@rename.register(LazyTbl) -def _rename(__data, **kwargs): - sel = __data.last_select - columns = lift_inner_cols(sel) - - # old_keys uses dict as ordered set - old_to_new = {simple_varname(v):k for k,v in kwargs.items()} - - if None in old_to_new: - raise KeyError("positional arguments must be simple column, " - "e.g. 
_.colname or _['colname']" - ) - - labs = [c.label(old_to_new[k]) if k in old_to_new else c for k,c in columns.items()] - - new_sel = sel.with_only_columns(labs) - - missing_groups, group_keys = _select_group_renames(old_to_new, __data.group_by) - - return __data.append_op(new_sel, group_by=group_keys) - - -# Distinct -------------------------------------------------------------------- - -@distinct.register(LazyTbl) -def _distinct(__data, *args, _keep_all = False, **kwargs): - if (args or kwargs) and _keep_all: - raise NotImplementedError("Distinct with variables specified in sql requires _keep_all = False") - - inner_sel = mutate(__data, **kwargs).last_select if kwargs else __data.last_select - - # TODO: this is copied from the df distinct version - # cols dict below is used as ordered set - cols = _var_select_simple(args) - cols.update(kwargs) - - # use all columns by default - if not cols: - cols = {k: True for k in lift_inner_cols(inner_sel).keys()} - - final_names = {**{k: True for k in __data.group_by}, **cols} - - if not len(inner_sel._order_by_clause): - # select distinct has to include any columns in the order by clause, - # so can only safely modify existing statement when there's no order by - sel_cols = lift_inner_cols(inner_sel) - distinct_cols = [sel_cols[k] for k in final_names] - sel = inner_sel.with_only_columns(distinct_cols).distinct() - else: - # fallback to cte - cte = inner_sel.alias() - distinct_cols = [cte.columns[k] for k in final_names] - sel = _sql_select(distinct_cols).select_from(cte).distinct() - - return __data.append_op(sel) - - -# if_else --------------------------------------------------------------------- - -@if_else.register(sql.elements.ColumnElement) -def _if_else(cond, true_vals, false_vals): - whens = [(cond, true_vals)] - return sql.case(whens, else_ = false_vals) - - diff --git a/siuba/sql/verbs/group_by.py b/siuba/sql/verbs/group_by.py index 8dcbef7b..d74e4e56 100644 --- a/siuba/sql/verbs/group_by.py +++ 
b/siuba/sql/verbs/group_by.py @@ -1,1038 +1,9 @@ -""" -Implements LazyTbl to represent tables of SQL data, and registers it on verbs. +from siuba.dply.verbs import group_by, ungroup -This module is responsible for the handling of the "table" side of things, while -translate.py handles translating column operations. +from ..backend import LazyTbl, ordered_union +from ..utils import lift_inner_cols - -""" - -import warnings - -from siuba.dply.verbs import ( - show_query, collect, - simple_varname, - select, - mutate, - transmute, - filter, - arrange, _call_strip_ascending, - summarize, - count, add_count, - group_by, ungroup, - case_when, - join, left_join, right_join, inner_join, semi_join, anti_join, - head, - rename, - distinct, - if_else, - _select_group_renames, - _var_select_simple - ) - -from siuba.dply.tidyselect import VarList, var_select - -from .translate import CustomOverClause, SqlColumn, SqlColumnAgg -from .utils import ( - get_dialect_translator, - _FixedSqlDatabase, - _is_dialect_duckdb, - _sql_select, - _sql_column_collection, - _sql_add_columns, - _sql_with_only_columns, - _sql_simplify_select, - MockConnection -) - -from sqlalchemy import sql -import sqlalchemy -from siuba.siu import Call, Lazy, FunctionLookupError, singledispatch2 -# TODO: currently needed for select, but can we remove pandas? -from pandas import Series -from functools import singledispatch - -from sqlalchemy.sql import schema - -from siuba.dply.across import _require_across, _set_data_context, _eval_with_context - -# TODO: -# - distinct -# - annotate functions using sel.prefix_with("\n/**/\n") ? - - -# Helpers --------------------------------------------------------------------- - -class SqlFunctionLookupError(FunctionLookupError): pass - - -class CallListener: - """Generic listener. 
Each exit is called on a node's copy.""" - def enter(self, node): - args, kwargs = node.map_subcalls(self.enter) - - return self.exit(node.__class__(node.func, *args, **kwargs)) - - def exit(self, node): - return node - - -class WindowReplacer(CallListener): - """Call tree listener. - - Produces 2 important behaviors via the enter method: - - returns evaluated sql call expression, with labels on all window expressions. - - stores all labeled window expressions via the windows property. - - TODO: could replace with a sqlalchemy transformer - """ - - def __init__(self, columns, group_by, order_by, window_cte = None): - self.columns = columns - self.group_by = group_by - self.order_by = order_by - self.window_cte = window_cte - self.windows = [] - - def exit(self, node): - col_expr = node(self.columns) - - if not isinstance(col_expr, sql.elements.ClauseElement): - return col_expr - - over_clauses = [x for x in self._get_over_clauses(col_expr) if isinstance(x, CustomOverClause)] - - # put groupings and orderings onto custom over clauses - for over in over_clauses: - # TODO: shouldn't mutate these over clauses - group_by = sql.elements.ClauseList( - *[self.columns[name] for name in self.group_by] - ) - order_by = sql.elements.ClauseList( - *_create_order_by_clause(self.columns, *self.order_by) - ) - - over.set_over(group_by, order_by) - - if len(over_clauses) and self.window_cte is not None: - # custom name, or parameters like "%(...)s" may nest and break psycopg2 - # with columns you can set a key to fix this, but it doesn't seem to - # be an option with labels - name = self._get_unique_name('win', lift_inner_cols(self.window_cte)) - label = col_expr.label(name) - - # put into CTE, and return its resulting column, so that subsequent - # operations will refer to the window column on window_cte. 
Note that - # the operations will use the actual column, so may need to use the - # ClauseAdaptor to make it a reference to the label - self.window_cte = _sql_add_columns(self.window_cte, [label]) - win_col = lift_inner_cols(self.window_cte).values()[-1] - self.windows.append(win_col) - - return win_col - - return col_expr - - @staticmethod - def _get_unique_name(prefix, columns): - column_names = set(columns.keys()) - - i = 1 - name = prefix + str(i) - while name in column_names: - i += 1 - name = prefix + str(i) - - - return name - - @staticmethod - def _get_over_clauses(clause): - windows = [] - append_win = lambda col: windows.append(col) - - sql.util.visitors.traverse(clause, {}, {"over": append_win}) - - return windows - - -class SqlLabelReplacer: - """Create a visitor to replace source labels with destination. - - Note that this is meant to be used with sqlalchemy visitors. - """ - - def __init__(self, src_columns, dst_columns): - self.src_columns = src_columns - self.src_labels = set([x for x in src_columns if isinstance(x, sql.elements.Label)]) - self.dst_columns = dst_columns - self.applied = False - - def __call__(self, clause): - return sql.util.visitors.replacement_traverse(clause, {}, self.visit) - - def visit(self, el): - from sqlalchemy.sql.elements import ColumnClause, Label, ClauseElement, TypeClause - from sqlalchemy.sql.schema import Column - - if isinstance(el, TypeClause): - # TODO: for some reason this type throws an error if unguarded - return None - - if isinstance(el, ClauseElement): - if el in self.src_labels: - self.applied = True - return self.dst_columns[el.name] - elif el in self.src_columns: - return self.dst_columns[el.name] - - # TODO: should we create a subquery if the user passed raw text? 
- #elif isinstance(el, ColumnClause) and not isinstance(el, Column): - # # Raw SQL, which will need a subquery, but not substitution - # if el.key != "*": - # self.applied = True - - return None - - -#def track_call_windows(call, columns, group_by, order_by, window_cte = None): -# listener = WindowReplacer(columns, group_by, order_by, window_cte) -# col = listener.enter(call) -# return col, listener.windows, listener.window_cte - - -def track_call_windows(call, columns, group_by, order_by, window_cte = None): - col_expr = call(columns) - - crnt_group_by = sql.elements.ClauseList( - *[columns[name] for name in group_by] - ) - crnt_order_by = sql.elements.ClauseList( - *_create_order_by_clause(columns, *order_by) - ) - return replace_call_windows(col_expr, crnt_group_by, crnt_order_by, window_cte) - - - -@singledispatch -def replace_call_windows(col_expr, group_by, order_by, window_cte = None): - raise TypeError(str(type(col_expr))) - - -@replace_call_windows.register(sql.base.ImmutableColumnCollection) -def _(col_expr, group_by, order_by, window_cte = None): - all_over_clauses = [] - for col in col_expr: - _, over_clauses, window_cte = replace_call_windows( - col, - group_by, - order_by, - window_cte - ) - all_over_clauses.extend(over_clauses) - - return col_expr, all_over_clauses, window_cte - - -@replace_call_windows.register(sql.elements.ClauseElement) -def _(col_expr, group_by, order_by, window_cte = None): - - over_clauses = WindowReplacer._get_over_clauses(col_expr) - - for over in over_clauses: - # TODO: shouldn't mutate these over clauses - over.set_over(group_by, order_by) - - if len(over_clauses) and window_cte is not None: - # custom name, or parameters like "%(...)s" may nest and break psycopg2 - # with columns you can set a key to fix this, but it doesn't seem to - # be an option with labels - name = WindowReplacer._get_unique_name('win', lift_inner_cols(window_cte)) - label = col_expr.label(name) - - # put into CTE, and return its resulting column, so 
that subsequent - # operations will refer to the window column on window_cte. Note that - # the operations will use the actual column, so may need to use the - # ClauseAdaptor to make it a reference to the label - window_cte = _sql_add_columns(window_cte, [label]) - win_col = lift_inner_cols(window_cte).values()[-1] - - return win_col, over_clauses, window_cte - - return col_expr, over_clauses, window_cte - -def get_single_from(sel): - froms = sel.froms - - n_froms = len(froms) - if n_froms != 1: - raise ValueError( - f"Expected a single table in the from clause, but found {n_froms}" - ) - - return froms[0] - -def lift_inner_cols(tbl): - cols = list(tbl.inner_columns) - - return _sql_column_collection(cols) - -def col_expr_requires_cte(call, sel, is_mutate = False): - """Return whether a variable assignment needs a CTE""" - - call_vars = set(call.op_vars(attr_calls = False)) - - sel_labs = get_inner_labels(sel) - - # I use the acronym fwg sol (frog soul) to remember sql clause eval order - # from, where, group by, select, order by, limit - # group clause evaluated before select clause, so not issue for mutate - group_needs_cte = not is_mutate and len(sel._group_by_clause) - - return ( group_needs_cte - # TODO: detect when a new var in mutate conflicts w/ order by - #or len(sel._order_by_clause) - or not sel_labs.isdisjoint(call_vars) - ) - -def get_inner_labels(sel): - columns = lift_inner_cols(sel) - sel_labs = set(k for k,v in columns.items() if isinstance(v, sql.elements.Label)) - return sel_labs - -def get_missing_columns(call, columns): - missing_cols = set(call.op_vars(attr_calls = False)) - set(columns.keys()) - return missing_cols - -def compile_el(tbl, el): - compiled = el.compile( - dialect = tbl.source.dialect, - compile_kwargs = {"literal_binds": True} - ) - return compiled - -# Misc utilities -------------------------------------------------------------- - -def ordered_union(x, y): - dx = {el: True for el in x} - dy = {el: True for el in y} - - return 
tuple({**dx, **dy}) - - -def _warn_missing(missing_groups): - warnings.warn(f"Adding missing grouping variables: {missing_groups}") - - -# Table ----------------------------------------------------------------------- - -class LazyTbl: - def __init__( - self, source, tbl, columns = None, - ops = None, group_by = tuple(), order_by = tuple(), - translator = None - ): - """Create a representation of a SQL table. - - Args: - source: a sqlalchemy.Engine or sqlalchemy.Connection instance. - tbl: table of form 'schema_name.table_name', 'table_name', or sqlalchemy.Table. - columns: if specified, a listlike of column names. - - Examples - -------- - - :: - from sqlalchemy import create_engine - from siuba.data import mtcars - - # create database and table - engine = create_engine("sqlite:///:memory:") - mtcars.to_sql('mtcars', engine) - - tbl_mtcars = LazyTbl(engine, 'mtcars') - - """ - - # connection and dialect specific functions - self.source = sqlalchemy.create_engine(source) if isinstance(source, str) else source - - # get dialect name - dialect = self.source.dialect.name - self.translator = get_dialect_translator(dialect) - - self.tbl = self._create_table(tbl, columns, self.source) - - # important states the query can be in (e.g. 
grouped) - self.ops = [self.tbl] if ops is None else ops - - self.group_by = group_by - self.order_by = order_by - - - def append_op(self, op, **kwargs): - cpy = self.copy(**kwargs) - cpy.ops = cpy.ops + [op] - return cpy - - def copy(self, **kwargs): - return self.__class__(**{**self.__dict__, **kwargs}) - - def shape_call( - self, - call, window = True, str_accessors = False, - verb_name = None, arg_name = None, - ): - return self.translator.shape_call(call, window, str_accessors, verb_name, arg_name) - - def track_call_windows(self, call, columns = None, window_cte = None): - """Returns tuple of (new column expression, list of window exprs)""" - - columns = self.last_op.columns if columns is None else columns - return track_call_windows(call, columns, self.group_by, self.order_by, window_cte) - - def get_ordered_col_names(self): - """Return columns from current select, with grouping columns first.""" - ungrouped = [k for k in self.last_op.columns.keys() if k not in self.group_by] - return list(self.group_by) + ungrouped - - #def label_breaks_order_by(self, name): - # """Returns True if a new column label would break the order by vars.""" - - # # TODO: arrange currently allows literals, which breaks this. it seems - # # better to only allow calls in arrange. - # order_by_vars = {c.op_vars(attr_calls=False) for c in self.order_by} - - - - - @property - def last_op(self) -> "sql.Table | sql.Select": - last_op = self.ops[-1] - - if last_op is None: - raise TypeError() - - return last_op - - @property - def last_select(self): - last_op = self.last_op - if not isinstance(last_op, sql.selectable.SelectBase): - return last_op.select() - - return last_op - - @staticmethod - def _create_table(tbl, columns = None, source = None): - """Return a sqlalchemy.Table, autoloading column info if needed. - - Arguments: - tbl: a sqlalchemy.Table or string of form 'table_name' or 'schema_name.table_name'. - columns: a tuple of column names for the table. Overrides source argument. 
- source: a sqlalchemy engine, used to autoload columns. - - """ - if isinstance(tbl, sql.selectable.FromClause): - return tbl - - if not isinstance(tbl, str): - raise ValueError("tbl must be a sqlalchemy Table or string, but was %s" %type(tbl)) - - if columns is None and source is None: - raise ValueError("One of columns or source must be specified") - - schema, table_name = tbl.split('.') if '.' in tbl else [None, tbl] - - columns = map(sqlalchemy.Column, columns) if columns is not None else tuple() - - # TODO: pybigquery uses schema to mean project_id, so we cannot use - # siuba's classic breakdown "{schema}.{table_name}". Basically - # pybigquery uses "{schema=project_id}.{dataset_dot_table_name}" in its internal - # logic. An important side effect is that bigquery errors for - # `dataset`.`table`, but not `dataset.table`. - if source and source.dialect.name == "bigquery": - table_name = tbl - schema = None - - return sqlalchemy.Table( - table_name, - sqlalchemy.MetaData(bind = source), - *columns, - schema = schema, - autoload_with = source if not columns else None - ) - - def _get_preview(self): - # need to make prev op a cte, so we don't override any previous limit - new_sel = self.last_select.limit(5) - tbl_small = self.append_op(new_sel) - return collect(tbl_small) - - def __repr__(self): - template = ( - "# Source: lazy query\n" - "# DB Conn: {}\n" - "# Preview:\n{}\n" - "# .. may have more rows" - ) - - return template.format(repr(self.source.engine), repr(self._get_preview())) - - def _repr_html_(self): - template = ( - "
" - "
"
-                "# Source: lazy query\n"
-                "# DB Conn: {}\n"
-                "# Preview:\n"
-                "
" - "{}" - "

# .. may have more rows

" - "
" - ) - - data = self._get_preview() - - # _repr_html_ can not exist or return None, to signify that repr should be used - if not hasattr(data, '_repr_html_'): - return None - - html_data = data._repr_html_() - if html_data is None: - return None - - return template.format(self.source.engine, html_data) - - -def _repr_grouped_df_html_(self): - return "

(grouped data frame)

" + self._selected_obj._repr_html_() + "
" - - - -# Main Funcs -# ============================================================================= - -# sql raw -------------- - -sql_raw = sql.literal_column - -# show query ----------- - -@show_query.register(LazyTbl) -def _show_query(tbl, simplify = False, return_table = True): - #query = tbl.last_op #if not simplify else - compile_query = lambda query: query.compile( - dialect = tbl.source.dialect, - compile_kwargs = {"literal_binds": True} - ) - - - if simplify: - # try to strip table names and labels where unnecessary - simple_sel = _sql_simplify_select(tbl.last_select) - - explained = compile_query(simple_sel) - else: - # use a much more verbose query - explained = compile_query(tbl.last_select) - - if return_table: - print(str(explained)) - return tbl - - return str(explained) - - - -# collect ---------- - -@collect.register(LazyTbl) -def _collect(__data, as_df = True): - # TODO: maybe remove as_df options, always return dataframe - - if isinstance(__data.source, MockConnection): - # a mock sqlalchemy is being used to show_query, and echo queries. - # it doesn't return a result object or have a context handler, so - # we need to bail out early - return - - # compile query ---- - - if _is_dialect_duckdb(__data.source): - # TODO: can be removed once next release of duckdb fixes: - # https://github.com/duckdb/duckdb/issues/2972 - query = __data.last_select - compiled = query.compile( - dialect = __data.source.dialect, - compile_kwargs = {"literal_binds": True} - ) - else: - compiled = __data.last_select - - # execute query ---- - - with __data.source.connect() as conn: - if as_df: - sql_db = _FixedSqlDatabase(conn) - - if _is_dialect_duckdb(__data.source): - # TODO: pandas read_sql is very slow with duckdb. - # see https://github.com/pandas-dev/pandas/issues/45678 - # going to handle here for now. address once LazyTbl gets - # subclassed per backend. 
- duckdb_con = conn.connection.c - return duckdb_con.query(str(compiled)).to_df() - else: - # - return sql_db.read_sql(compiled) - - return conn.execute(compiled) - - -@select.register(LazyTbl) -def _select(__data, *args, **kwargs): - # see https://stackoverflow.com/questions/25914329/rearrange-columns-in-sqlalchemy-select-object - if kwargs: - raise NotImplementedError( - "Using kwargs in select not currently supported. " - "Use _.newname == _.oldname instead" - ) - last_sel = __data.last_select - columns = {c.key: c for c in last_sel.inner_columns} - - # same as for DataFrame - colnames = Series(list(columns)) - vl = VarList() - evaluated = (arg(vl) if callable(arg) else arg for arg in args) - od = var_select(colnames, *evaluated) - - missing_groups, group_keys = _select_group_renames(od, __data.group_by) - - if missing_groups: - _warn_missing(missing_groups) - - final_od = {**{k: None for k in missing_groups}, **od} - - col_list = [] - for k,v in final_od.items(): - col = columns[k] - col_list.append(col if v is None else col.label(v)) - - return __data.append_op( - last_sel.with_only_columns(col_list), - group_by = group_keys - ) - - - -@filter.register(LazyTbl) -def _filter(__data, *args): - # Note: currently always produces 2 additional select statements, - # 1 for window/aggs, and 1 for the where clause - - sel = __data.last_op.alias() # original select - win_sel = sel.select() - - conds = [] - windows = [] - with _set_data_context(__data, window=True): - for ii, arg in enumerate(args): - - if isinstance(arg, Call): - new_call = __data.shape_call(arg, verb_name = "Filter", arg_name = ii) - #var_cols = new_call.op_vars(attr_calls = False) - - # note that a new win_sel is returned, w/ window columns appended - col_expr, win_cols, win_sel = __data.track_call_windows( - new_call, - sel.columns, - window_cte = win_sel - ) - - if isinstance(col_expr, sql.base.ImmutableColumnCollection): - conds.extend(col_expr) - else: - conds.append(col_expr) - - 
windows.extend(win_cols) - - else: - conds.append(arg) - - bool_clause = sql.and_(*conds) - - # first cte, windows ---- - if len(windows): - - win_alias = win_sel.alias() - - # move non-window functions to refer to win_sel clause (not the innermost) --- - bool_clause = sql.util.ClauseAdapter(win_alias) \ - .traverse(bool_clause) - - orig_cols = [win_alias.columns[k] for k in sel.columns.keys()] - else: - orig_cols = [sel] - - # create second cte ---- - filt_sel = _sql_select(orig_cols).where(bool_clause) - return __data.append_op(filt_sel) - - -@mutate.register(LazyTbl) -def _mutate(__data, *args, **kwargs): - # TODO: verify it can follow a renaming select - - # track labeled columns in set - if not (len(args) or len(kwargs)): - return __data.append_op(__data.last_op) - - names, sel_out = _mutate_cols(__data, args, kwargs, "Mutate") - return __data.append_op(sel_out) - - -def _sql_upsert_columns(sel, new_columns: "list[base.Label | base.Column]"): - orig_cols = lift_inner_cols(sel) - replaced = {**orig_cols} - - for new_col in new_columns: - replaced[new_col.name] = new_col - return _sql_with_only_columns(sel, list(replaced.values())) - - -def _select_mutate_result(src_sel, expr_result): - dst_alias = src_sel.alias() - src_columns = set(lift_inner_cols(src_sel)) - replacer = SqlLabelReplacer(set(src_columns), dst_alias.columns) - - if isinstance(expr_result, sql.base.ImmutableColumnCollection): - replaced_cols = list(map(replacer, expr_result)) - orig_cols = expr_result - #elif isinstance(expr_result, None): - # pass - else: - replaced_cols = [replacer(expr_result)] - orig_cols = [expr_result] - - if replacer.applied: - return _sql_upsert_columns(dst_alias.select(), replaced_cols) - - return _sql_upsert_columns(src_sel, orig_cols) - - -def _eval_expr_arg(__data, sel, func, verb_name, window=True): - inner_cols = lift_inner_cols(sel) - - # case 1: simple names ---- - simple_name = simple_varname(func) - if simple_name is not None: - return inner_cols[simple_name] - 
- # case 2: across ---- - _require_across(func, verb_name) - - cols_result = _eval_with_context(__data, window, inner_cols, func) - - # TODO: remove or raise a more informative error - assert isinstance(cols_result, sql.base.ImmutableColumnCollection), type(cols_result) - - return cols_result - - -def _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=True): - inner_cols = lift_inner_cols(sel) - - expr_shaped = __data.shape_call(func, window, verb_name = verb_name, arg_name = new_name) - new_col, windows, _ = __data.track_call_windows(expr_shaped, inner_cols) - - if isinstance(new_col, sql.base.ImmutableColumnCollection): - raise TyepError( - f"{verb_name} named arguments must return a single column, but `{k}` " - "returned multiple columns." - ) - - return new_col.label(new_name) - - -def _mutate_cols(__data, args, kwargs, verb_name): - result_names = {} # used as ordered set - sel = __data.last_select - - for ii, func in enumerate(args): - cols_result = _eval_expr_arg(__data, sel, func, verb_name) - - # replace any labels that require a subquery ---- - sel = _select_mutate_result(sel, cols_result) - - if isinstance(cols_result, sql.base.ImmutableColumnCollection): - result_names.update({k: True for k in cols_result.keys()}) - else: - result_names[cols_result.name] = True - - - for new_name, func in kwargs.items(): - labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name) - - sel = _select_mutate_result(sel, labeled) - result_names[new_name] = True - - - return list(result_names), sel - - -@transmute.register(LazyTbl) -def _transmute(__data, *args, **kwargs): - # will use mutate, then select some cols - result_names, sel = _mutate_cols(__data, args, kwargs, "Transmute") - - # transmute keeps grouping cols, and any defined in kwargs - missing = [x for x in __data.group_by if x not in result_names] - cols_to_keep = [*missing, *result_names] - - columns = lift_inner_cols(sel) - sel_stripped = sel.with_only_columns([columns[k] for k in 
cols_to_keep]) - - return __data.append_op(sel_stripped) - - -@arrange.register(LazyTbl) -def _arrange(__data, *args): - # Note that SQL databases often do not subquery order by clauses. Arrange - # sets order_by on the backend, so it can set order by in over elements, - # and handle when new columns are named the same as order by vars. - # see: https://dba.stackexchange.com/q/82930 - - last_sel = __data.last_select - cols = lift_inner_cols(last_sel) - - # TODO: implement across in arrange - #exprs, _ = _mutate_cols(__data, args, kwargs, "Arrange", arrange_clause=True) - - new_calls = [] - for ii, expr in enumerate(args): - if callable(expr): - - res = __data.shape_call( - expr, window = False, - verb_name = "Arrange", arg_name = ii - ) - - else: - res = expr - - new_calls.append(res) - - sort_cols = _create_order_by_clause(cols, *new_calls) - - order_by = __data.order_by + tuple(new_calls) - return __data.append_op(last_sel.order_by(*sort_cols), order_by = order_by) - - -# TODO: consolidate / pull expr handling funcs into own file? -def _create_order_by_clause(columns, *args): - sort_cols = [] - for arg in args: - # simple named column - if isinstance(arg, str): - sort_cols.append(columns[arg]) - # an expression - elif callable(arg): - # handle special case where -_.colname -> colname DESC - f, asc = _call_strip_ascending(arg) - col_op = f(columns) if asc else f(columns).desc() - #col_op = arg(columns) - sort_cols.append(col_op) - else: - raise NotImplementedError("Must be string or callable") - - return sort_cols - - - -@count.register(LazyTbl) -def _count(__data, *args, sort = False, wt = None, **kwargs): - # TODO: if already col named n, use name nn, etc.. 
get logic from tidy.py - if wt is not None: - raise NotImplementedError("TODO") - - res_name = "n" - # similar to filter verb, we need two select statements, - # an inner one for derived cols, and outer to group by them - - # inner select ---- - # holds any mutation style columns - #arg_names = [] - #for arg in args: - # name = simple_varname(arg) - # if name is None: - # raise NotImplementedError( - # "Count positional arguments must be single column name. " - # "Use a named argument to count using complex expressions." - # ) - # arg_names.append(name) - - result_names, sel_inner = _mutate_cols(__data, args, kwargs, "Count") - - # remove unnecessary select, if we're operating on a table ---- - if set(lift_inner_cols(sel_inner)) == set(lift_inner_cols(__data.last_select)): - sel_inner = __data.last_op - - # create outer select ---- - # holds selected columns and tally (n) - sel_inner_cte = sel_inner.alias() - inner_cols = sel_inner_cte.columns - - # apply any group vars from a group_by verb call first - missing = [k for k in __data.group_by if k not in result_names] - - all_group_names = ordered_union(__data.group_by, result_names) - outer_group_cols = [inner_cols[k] for k in all_group_names] - - # holds the actual count (e.g. 
n) - count_col = sql.functions.count().label(res_name) - - sel_outer = _sql_select([*outer_group_cols, count_col]) \ - .select_from(sel_inner_cte) \ - .group_by(*outer_group_cols) - - # count is like summarize, so removes order_by - return __data.append_op( - sel_outer.order_by(count_col.desc()), - order_by = tuple() - ) - - -@add_count.register(LazyTbl) -def _add_count(__data, *args, wt = None, sort = False, **kwargs): - counts = count(__data, *args, wt = wt, sort = sort, **kwargs) - by = list(c.name for c in counts.last_select.inner_columns)[:-1] - - return inner_join(__data, counts, by = by) - - -@summarize.register(LazyTbl) -def _summarize(__data, *args, **kwargs): - # https://stackoverflow.com/questions/14754994/why-is-sqlalchemy-count-much-slower-than-the-raw-query - - # get query with correct from clause, and maybe unneeded subquery - safe_from = __data.last_select.alias() - result_names, sel = _aggregate_cols(__data, safe_from, args, kwargs, "Summarize") - - # see if we can remove subquery - out_sel = _collapse_select(sel, safe_from) - - from_tbl = get_single_from(out_sel) - group_cols = [from_tbl.columns[k] for k in __data.group_by] - - final_sel = out_sel.group_by(*group_cols) - - new_data = __data.append_op(final_sel, group_by = tuple(), order_by = tuple()) - return new_data - - -def _collapse_select(outer_sel, inner_alias): - # check whether any outer columns reference an inner label ---- - inner_sel = inner_alias.element - - columns = lift_inner_cols(outer_sel) - inner_cols = lift_inner_cols(inner_sel) - - inner_labels = set([ - x.name for x in inner_cols - if isinstance(x, sql.elements.Label) - ]) - - col_requires_cte = set(inner_alias.columns[k] for k in inner_labels) - - bad_refs = [] - - def collect_refs(el): - if el in col_requires_cte: - bad_refs.append(el) - - for col in columns: - sql.util.visitors.traverse(col, {}, {"column": collect_refs}) - - # if possible, remove the outer query ---- - if not (bad_refs or len(inner_sel._group_by_clause)): - 
from sqlalchemy.sql.elements import ColumnClause, Label - - from_obj = get_single_from(inner_sel) - adaptor = sql.util.ClauseAdapter( - from_obj, - adapt_on_names=True, - include_fn=lambda c: isinstance(c, (ColumnClause, Label)) - ) - - new_cols = [] - for col in columns: - if isinstance(col, Label): - res = adaptor.traverse(col.element).label(col.name) - new_cols.append(res) - - else: - new_cols.append(adaptor.traverse(col)) - #new_cols = list(map(adaptor.traverse, columns)) - - return _sql_with_only_columns(inner_sel, new_cols) - - return outer_sel - - -def _aggregate_cols(__data, subquery, args, kwargs, verb_name): - # cases: - # * grouping cols can not be overwritten (in dbplyr they can't be ref'd) - # * no existing labels referred to - can use same select - # * existing labels referred to - need 1 subquery tops - # * groups + summarize columns can replace everything - - def get_label_clauses(clause): - out = [] - sql.util.visitors.traverse(clause, {}, {"label": lambda c: out.append(c)}) - - return out - - def quote_varname(x): - return f"`{x}`" - - def validate_references(arg_name, expr, verb_name): - bad_varnames = get_label_clauses(expr) - repr_names = ", ".join(map(quote_varname, [el.name for el in bad_varnames])) - - if not bad_varnames: - return - - raise NotImplementedError( - f"In SQL, you cannot refer to a column created in the same {verb_name}. " - f"`{arg_name}` refers to columns created earlier: {repr_names}." 
- ) - - sel = subquery.select() - - final_cols = {k: subquery.columns[k] for k in __data.group_by} - - # handle args ---- - for ii, func in enumerate(args): - cols_result = _eval_expr_arg(__data, sel, func, verb_name, window=False) - - for col in cols_result: - validate_references(col.name, col.element, verb_name) - final_cols[col.name] = col - - sel = _sql_upsert_columns(sel, cols_result) - - - # handle kwargs ---- - for new_name, func in kwargs.items(): - labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=False) - - validate_references(labeled.name, labeled.element, verb_name) - final_cols[new_name] = labeled - - sel = _sql_upsert_columns(sel, [labeled]) - - return list(final_cols), _sql_with_only_columns(sel, list(final_cols.values())) +from .mutate import _mutate_cols @group_by.register(LazyTbl) @@ -1059,337 +30,3 @@ def _group_by(__data, *args, add = False, **kwargs): @ungroup.register(LazyTbl) def _ungroup(__data): return __data.copy(group_by = tuple()) - - -@case_when.register(sql.base.ImmutableColumnCollection) -def _case_when(__data, cases): - # TODO: will need listener to enter case statements, to handle when they use windows - if isinstance(cases, Call): - cases = cases(__data) - - whens = [] - case_items = list(cases.items()) - n_items = len(case_items) - - else_val = None - for ii, (expr, val) in enumerate(case_items): - # handle where val is a column expr - if callable(val): - val = val(__data) - - # handle when expressions - if ii+1 == n_items and expr is True: - else_val = val - elif callable(expr): - whens.append((expr(__data), val)) - else: - whens.append((expr, val)) - - return sql.case(whens, else_ = else_val) - - -# Join ------------------------------------------------------------------------ - -from collections.abc import Mapping - -def _joined_cols(left_cols, right_cols, on_keys, how, suffix = ("_x", "_y")): - """Return labeled columns, according to selection rules for joins. - - Rules: - 1. 
For join keys, keep left table's column - 2. When keys have the same labels, add suffix - """ - - # TODO: remove sets, so uses stable ordering - # when left and right cols have same name, suffix with _x / _y - keep_right = set(right_cols.keys()) - set(on_keys.values()) - shared_labs = set(left_cols.keys()).intersection(keep_right) - - right_cols_no_keys = {k: right_cols[k] for k in keep_right} - - # for an outer join, have key columns coalesce values - - left_cols = {**left_cols} - if how == "full": - for lk, rk in on_keys.items(): - col = sql.functions.coalesce(left_cols[lk], right_cols[rk]) - left_cols[lk] = col.label(lk) - elif how == "right": - for lk, rk in on_keys.items(): - # Make left key columns actually be right ones (which contain left + extra) - left_cols[lk] = right_cols[rk].label(lk) - - - # create labels ---- - l_labs = _relabeled_cols(left_cols, shared_labs, suffix[0]) - r_labs = _relabeled_cols(right_cols_no_keys, shared_labs, suffix[1]) - - return l_labs + r_labs - - - -def _relabeled_cols(columns, keys, suffix): - # add a suffix to all columns with names in keys - cols = [] - for k, v in columns.items(): - new_col = v.label(k + str(suffix)) if k in keys else v - cols.append(new_col) - return cols - - -@join.register(LazyTbl) -def _join(left, right, on = None, *args, by = None, how = "inner", sql_on = None): - _raise_if_args(args) - - if on is None and by is not None: - on = by - - # Needs to be on the table, not the select - left_sel = left.last_op.alias() - right_sel = right.last_op.alias() - - # handle arguments ---- - on = _validate_join_arg_on(on, sql_on) - how = _validate_join_arg_how(how) - - # for equality join used to combine keys into single column - consolidate_keys = on if sql_on is None else {} - - if how == "right": - # switch joins, since sqlalchemy doesn't have right join arg - # see https://stackoverflow.com/q/11400307/1144523 - left_sel, right_sel = right_sel, left_sel - on = {v:k for k,v in on.items()} - - # create join 
conditions ---- - bool_clause = _create_join_conds(left_sel, right_sel, on) - - # create join ---- - join = left_sel.join( - right_sel, - onclause = bool_clause, - isouter = how != "inner", - full = how == "full" - ) - - # if right join, set selects back - if how == "right": - left_sel, right_sel = right_sel, left_sel - on = {v:k for k,v in on.items()} - - # note, shared_keys assumes on is a mapping... - # TODO: shared_keys appears to be for when on is not specified, but was unused - #shared_keys = [k for k,v in on.items() if k == v] - labeled_cols = _joined_cols( - left_sel.columns, - right_sel.columns, - on_keys = consolidate_keys, - how = how - ) - - sel = _sql_select(labeled_cols).select_from(join) - return left.append_op(sel, order_by = tuple()) - - -@semi_join.register(LazyTbl) -def _semi_join(left, right = None, on = None, *args, by = None, sql_on = None): - if on is None and by is not None: - on = by - - _raise_if_args(args) - - left_sel = left.last_op.alias() - right_sel = right.last_op.alias() - - # handle arguments ---- - on = _validate_join_arg_on(on, sql_on, left_sel, right_sel) - - # create join conditions ---- - bool_clause = _create_join_conds(left_sel, right_sel, on) - - # create inner join ---- - exists_clause = _sql_select([sql.literal(1)]) \ - .select_from(right_sel) \ - .where(bool_clause) - - # only keep left hand select's columns ---- - sel = _sql_select(left_sel.columns) \ - .select_from(left_sel) \ - .where(sql.exists(exists_clause)) - - return left.append_op(sel, order_by = tuple()) - - -@anti_join.register(LazyTbl) -def _anti_join(left, right = None, on = None, *args, by = None, sql_on = None): - if on is None and by is not None: - on = by - - _raise_if_args(args) - - left_sel = left.last_op.alias() - right_sel = right.last_op.alias() - - # handle arguments ---- - on = _validate_join_arg_on(on, sql_on, left, right) - - # create join conditions ---- - bool_clause = _create_join_conds(left_sel, right_sel, on) - - # create inner join ---- - 
#not_exists = ~sql.exists([1], from_obj = right_sel).where(bool_clause) - exists_clause = _sql_select([sql.literal(1)]) \ - .select_from(right_sel) \ - .where(bool_clause) - - sel = left_sel.select().where(~sql.exists(exists_clause)) - - return left.append_op(sel, order_by = tuple()) - -def _raise_if_args(args): - if len(args): - raise NotImplemented("*args is reserved for future arguments (e.g. suffix)") - -def _validate_join_arg_on(on, sql_on = None, lhs = None, rhs = None): - # handle sql on case - if sql_on is not None: - if on is not None: - raise ValueError("Cannot specify both on and sql_on") - - return sql_on - - # handle general cases - if on is None: - # TODO: currently, we check for lhs and rhs tables to indicate whether - # a verb supports inferring columns. Otherwise, raise an error. - if lhs is not None and rhs is not None: - # TODO: consolidate with duplicate logic in pandas verb code - warnings.warn( - "No on column passed to join. " - "Inferring join columns instead using shared column names." - ) - - on_cols = list(set(lhs.columns.keys()).intersection(set(rhs.columns.keys()))) - - if not on_cols: - raise ValueError( - "No join column specified, or shared column names in join." - ) - - # trivial dict mapping shared names to themselves - warnings.warn("Detected shared columns: %s" % on_cols) - on = dict(zip(on_cols, on_cols)) - - else: - raise NotImplementedError("on arg currently cannot be None (default) for SQL") - elif isinstance(on, str): - on = {on: on} - elif isinstance(on, (list, tuple)): - on = dict(zip(on, on)) - - - if not isinstance(on, Mapping): - raise TypeError("on must be a Mapping (e.g. 
dict)") - - return on - -def _validate_join_arg_how(how): - how_options = ("inner", "left", "right", "full") - if how not in how_options: - raise ValueError("how argument needs to be one of %s" %how_options) - - return how - -def _create_join_conds(left_sel, right_sel, on): - left_cols = left_sel.columns #lift_inner_cols(left_sel) - right_cols = right_sel.columns #lift_inner_cols(right_sel) - - if callable(on): - # callable, like with sql_on arg - conds = [on(left_cols, right_cols)] - else: - # dict-like of form {left: right} - conds = [] - for l, r in on.items(): - col_expr = left_cols[l] == right_cols[r] - conds.append(col_expr) - - return sql.and_(*conds) - - -# Head ------------------------------------------------------------------------ - -@head.register(LazyTbl) -def _head(__data, n = 5): - sel = __data.last_select - - return __data.append_op(sel.limit(n)) - - -# Rename ---------------------------------------------------------------------- - -@rename.register(LazyTbl) -def _rename(__data, **kwargs): - sel = __data.last_select - columns = lift_inner_cols(sel) - - # old_keys uses dict as ordered set - old_to_new = {simple_varname(v):k for k,v in kwargs.items()} - - if None in old_to_new: - raise KeyError("positional arguments must be simple column, " - "e.g. 
_.colname or _['colname']" - ) - - labs = [c.label(old_to_new[k]) if k in old_to_new else c for k,c in columns.items()] - - new_sel = sel.with_only_columns(labs) - - missing_groups, group_keys = _select_group_renames(old_to_new, __data.group_by) - - return __data.append_op(new_sel, group_by=group_keys) - - -# Distinct -------------------------------------------------------------------- - -@distinct.register(LazyTbl) -def _distinct(__data, *args, _keep_all = False, **kwargs): - if (args or kwargs) and _keep_all: - raise NotImplementedError("Distinct with variables specified in sql requires _keep_all = False") - - inner_sel = mutate(__data, **kwargs).last_select if kwargs else __data.last_select - - # TODO: this is copied from the df distinct version - # cols dict below is used as ordered set - cols = _var_select_simple(args) - cols.update(kwargs) - - # use all columns by default - if not cols: - cols = {k: True for k in lift_inner_cols(inner_sel).keys()} - - final_names = {**{k: True for k in __data.group_by}, **cols} - - if not len(inner_sel._order_by_clause): - # select distinct has to include any columns in the order by clause, - # so can only safely modify existing statement when there's no order by - sel_cols = lift_inner_cols(inner_sel) - distinct_cols = [sel_cols[k] for k in final_names] - sel = inner_sel.with_only_columns(distinct_cols).distinct() - else: - # fallback to cte - cte = inner_sel.alias() - distinct_cols = [cte.columns[k] for k in final_names] - sel = _sql_select(distinct_cols).select_from(cte).distinct() - - return __data.append_op(sel) - - -# if_else --------------------------------------------------------------------- - -@if_else.register(sql.elements.ColumnElement) -def _if_else(cond, true_vals, false_vals): - whens = [(cond, true_vals)] - return sql.case(whens, else_ = false_vals) - - diff --git a/siuba/sql/verbs/head.py b/siuba/sql/verbs/head.py index 8dcbef7b..2fcda317 100644 --- a/siuba/sql/verbs/head.py +++ b/siuba/sql/verbs/head.py @@ 
-1,1395 +1,9 @@ -""" -Implements LazyTbl to represent tables of SQL data, and registers it on verbs. +from siuba.dply.verbs import head -This module is responsible for the handling of the "table" side of things, while -translate.py handles translating column operations. - - -""" - -import warnings - -from siuba.dply.verbs import ( - show_query, collect, - simple_varname, - select, - mutate, - transmute, - filter, - arrange, _call_strip_ascending, - summarize, - count, add_count, - group_by, ungroup, - case_when, - join, left_join, right_join, inner_join, semi_join, anti_join, - head, - rename, - distinct, - if_else, - _select_group_renames, - _var_select_simple - ) - -from siuba.dply.tidyselect import VarList, var_select - -from .translate import CustomOverClause, SqlColumn, SqlColumnAgg -from .utils import ( - get_dialect_translator, - _FixedSqlDatabase, - _is_dialect_duckdb, - _sql_select, - _sql_column_collection, - _sql_add_columns, - _sql_with_only_columns, - _sql_simplify_select, - MockConnection -) - -from sqlalchemy import sql -import sqlalchemy -from siuba.siu import Call, Lazy, FunctionLookupError, singledispatch2 -# TODO: currently needed for select, but can we remove pandas? -from pandas import Series -from functools import singledispatch - -from sqlalchemy.sql import schema - -from siuba.dply.across import _require_across, _set_data_context, _eval_with_context - -# TODO: -# - distinct -# - annotate functions using sel.prefix_with("\n/**/\n") ? - - -# Helpers --------------------------------------------------------------------- - -class SqlFunctionLookupError(FunctionLookupError): pass - - -class CallListener: - """Generic listener. Each exit is called on a node's copy.""" - def enter(self, node): - args, kwargs = node.map_subcalls(self.enter) - - return self.exit(node.__class__(node.func, *args, **kwargs)) - - def exit(self, node): - return node - - -class WindowReplacer(CallListener): - """Call tree listener. 
- - Produces 2 important behaviors via the enter method: - - returns evaluated sql call expression, with labels on all window expressions. - - stores all labeled window expressions via the windows property. - - TODO: could replace with a sqlalchemy transformer - """ - - def __init__(self, columns, group_by, order_by, window_cte = None): - self.columns = columns - self.group_by = group_by - self.order_by = order_by - self.window_cte = window_cte - self.windows = [] - - def exit(self, node): - col_expr = node(self.columns) - - if not isinstance(col_expr, sql.elements.ClauseElement): - return col_expr - - over_clauses = [x for x in self._get_over_clauses(col_expr) if isinstance(x, CustomOverClause)] - - # put groupings and orderings onto custom over clauses - for over in over_clauses: - # TODO: shouldn't mutate these over clauses - group_by = sql.elements.ClauseList( - *[self.columns[name] for name in self.group_by] - ) - order_by = sql.elements.ClauseList( - *_create_order_by_clause(self.columns, *self.order_by) - ) - - over.set_over(group_by, order_by) - - if len(over_clauses) and self.window_cte is not None: - # custom name, or parameters like "%(...)s" may nest and break psycopg2 - # with columns you can set a key to fix this, but it doesn't seem to - # be an option with labels - name = self._get_unique_name('win', lift_inner_cols(self.window_cte)) - label = col_expr.label(name) - - # put into CTE, and return its resulting column, so that subsequent - # operations will refer to the window column on window_cte. 
Note that - # the operations will use the actual column, so may need to use the - # ClauseAdaptor to make it a reference to the label - self.window_cte = _sql_add_columns(self.window_cte, [label]) - win_col = lift_inner_cols(self.window_cte).values()[-1] - self.windows.append(win_col) - - return win_col - - return col_expr - - @staticmethod - def _get_unique_name(prefix, columns): - column_names = set(columns.keys()) - - i = 1 - name = prefix + str(i) - while name in column_names: - i += 1 - name = prefix + str(i) - - - return name - - @staticmethod - def _get_over_clauses(clause): - windows = [] - append_win = lambda col: windows.append(col) - - sql.util.visitors.traverse(clause, {}, {"over": append_win}) - - return windows - - -class SqlLabelReplacer: - """Create a visitor to replace source labels with destination. - - Note that this is meant to be used with sqlalchemy visitors. - """ - - def __init__(self, src_columns, dst_columns): - self.src_columns = src_columns - self.src_labels = set([x for x in src_columns if isinstance(x, sql.elements.Label)]) - self.dst_columns = dst_columns - self.applied = False - - def __call__(self, clause): - return sql.util.visitors.replacement_traverse(clause, {}, self.visit) - - def visit(self, el): - from sqlalchemy.sql.elements import ColumnClause, Label, ClauseElement, TypeClause - from sqlalchemy.sql.schema import Column - - if isinstance(el, TypeClause): - # TODO: for some reason this type throws an error if unguarded - return None - - if isinstance(el, ClauseElement): - if el in self.src_labels: - self.applied = True - return self.dst_columns[el.name] - elif el in self.src_columns: - return self.dst_columns[el.name] - - # TODO: should we create a subquery if the user passed raw text? 
- #elif isinstance(el, ColumnClause) and not isinstance(el, Column): - # # Raw SQL, which will need a subquery, but not substitution - # if el.key != "*": - # self.applied = True - - return None - - -#def track_call_windows(call, columns, group_by, order_by, window_cte = None): -# listener = WindowReplacer(columns, group_by, order_by, window_cte) -# col = listener.enter(call) -# return col, listener.windows, listener.window_cte - - -def track_call_windows(call, columns, group_by, order_by, window_cte = None): - col_expr = call(columns) - - crnt_group_by = sql.elements.ClauseList( - *[columns[name] for name in group_by] - ) - crnt_order_by = sql.elements.ClauseList( - *_create_order_by_clause(columns, *order_by) - ) - return replace_call_windows(col_expr, crnt_group_by, crnt_order_by, window_cte) - - - -@singledispatch -def replace_call_windows(col_expr, group_by, order_by, window_cte = None): - raise TypeError(str(type(col_expr))) - - -@replace_call_windows.register(sql.base.ImmutableColumnCollection) -def _(col_expr, group_by, order_by, window_cte = None): - all_over_clauses = [] - for col in col_expr: - _, over_clauses, window_cte = replace_call_windows( - col, - group_by, - order_by, - window_cte - ) - all_over_clauses.extend(over_clauses) - - return col_expr, all_over_clauses, window_cte - - -@replace_call_windows.register(sql.elements.ClauseElement) -def _(col_expr, group_by, order_by, window_cte = None): - - over_clauses = WindowReplacer._get_over_clauses(col_expr) - - for over in over_clauses: - # TODO: shouldn't mutate these over clauses - over.set_over(group_by, order_by) - - if len(over_clauses) and window_cte is not None: - # custom name, or parameters like "%(...)s" may nest and break psycopg2 - # with columns you can set a key to fix this, but it doesn't seem to - # be an option with labels - name = WindowReplacer._get_unique_name('win', lift_inner_cols(window_cte)) - label = col_expr.label(name) - - # put into CTE, and return its resulting column, so 
that subsequent - # operations will refer to the window column on window_cte. Note that - # the operations will use the actual column, so may need to use the - # ClauseAdaptor to make it a reference to the label - window_cte = _sql_add_columns(window_cte, [label]) - win_col = lift_inner_cols(window_cte).values()[-1] - - return win_col, over_clauses, window_cte - - return col_expr, over_clauses, window_cte - -def get_single_from(sel): - froms = sel.froms - - n_froms = len(froms) - if n_froms != 1: - raise ValueError( - f"Expected a single table in the from clause, but found {n_froms}" - ) - - return froms[0] - -def lift_inner_cols(tbl): - cols = list(tbl.inner_columns) - - return _sql_column_collection(cols) - -def col_expr_requires_cte(call, sel, is_mutate = False): - """Return whether a variable assignment needs a CTE""" - - call_vars = set(call.op_vars(attr_calls = False)) - - sel_labs = get_inner_labels(sel) - - # I use the acronym fwg sol (frog soul) to remember sql clause eval order - # from, where, group by, select, order by, limit - # group clause evaluated before select clause, so not issue for mutate - group_needs_cte = not is_mutate and len(sel._group_by_clause) - - return ( group_needs_cte - # TODO: detect when a new var in mutate conflicts w/ order by - #or len(sel._order_by_clause) - or not sel_labs.isdisjoint(call_vars) - ) - -def get_inner_labels(sel): - columns = lift_inner_cols(sel) - sel_labs = set(k for k,v in columns.items() if isinstance(v, sql.elements.Label)) - return sel_labs - -def get_missing_columns(call, columns): - missing_cols = set(call.op_vars(attr_calls = False)) - set(columns.keys()) - return missing_cols - -def compile_el(tbl, el): - compiled = el.compile( - dialect = tbl.source.dialect, - compile_kwargs = {"literal_binds": True} - ) - return compiled - -# Misc utilities -------------------------------------------------------------- - -def ordered_union(x, y): - dx = {el: True for el in x} - dy = {el: True for el in y} - - return 
tuple({**dx, **dy}) - - -def _warn_missing(missing_groups): - warnings.warn(f"Adding missing grouping variables: {missing_groups}") - - -# Table ----------------------------------------------------------------------- - -class LazyTbl: - def __init__( - self, source, tbl, columns = None, - ops = None, group_by = tuple(), order_by = tuple(), - translator = None - ): - """Create a representation of a SQL table. - - Args: - source: a sqlalchemy.Engine or sqlalchemy.Connection instance. - tbl: table of form 'schema_name.table_name', 'table_name', or sqlalchemy.Table. - columns: if specified, a listlike of column names. - - Examples - -------- - - :: - from sqlalchemy import create_engine - from siuba.data import mtcars - - # create database and table - engine = create_engine("sqlite:///:memory:") - mtcars.to_sql('mtcars', engine) - - tbl_mtcars = LazyTbl(engine, 'mtcars') - - """ - - # connection and dialect specific functions - self.source = sqlalchemy.create_engine(source) if isinstance(source, str) else source - - # get dialect name - dialect = self.source.dialect.name - self.translator = get_dialect_translator(dialect) - - self.tbl = self._create_table(tbl, columns, self.source) - - # important states the query can be in (e.g. 
grouped) - self.ops = [self.tbl] if ops is None else ops - - self.group_by = group_by - self.order_by = order_by - - - def append_op(self, op, **kwargs): - cpy = self.copy(**kwargs) - cpy.ops = cpy.ops + [op] - return cpy - - def copy(self, **kwargs): - return self.__class__(**{**self.__dict__, **kwargs}) - - def shape_call( - self, - call, window = True, str_accessors = False, - verb_name = None, arg_name = None, - ): - return self.translator.shape_call(call, window, str_accessors, verb_name, arg_name) - - def track_call_windows(self, call, columns = None, window_cte = None): - """Returns tuple of (new column expression, list of window exprs)""" - - columns = self.last_op.columns if columns is None else columns - return track_call_windows(call, columns, self.group_by, self.order_by, window_cte) - - def get_ordered_col_names(self): - """Return columns from current select, with grouping columns first.""" - ungrouped = [k for k in self.last_op.columns.keys() if k not in self.group_by] - return list(self.group_by) + ungrouped - - #def label_breaks_order_by(self, name): - # """Returns True if a new column label would break the order by vars.""" - - # # TODO: arrange currently allows literals, which breaks this. it seems - # # better to only allow calls in arrange. - # order_by_vars = {c.op_vars(attr_calls=False) for c in self.order_by} - - - - - @property - def last_op(self) -> "sql.Table | sql.Select": - last_op = self.ops[-1] - - if last_op is None: - raise TypeError() - - return last_op - - @property - def last_select(self): - last_op = self.last_op - if not isinstance(last_op, sql.selectable.SelectBase): - return last_op.select() - - return last_op - - @staticmethod - def _create_table(tbl, columns = None, source = None): - """Return a sqlalchemy.Table, autoloading column info if needed. - - Arguments: - tbl: a sqlalchemy.Table or string of form 'table_name' or 'schema_name.table_name'. - columns: a tuple of column names for the table. Overrides source argument. 
- source: a sqlalchemy engine, used to autoload columns. - - """ - if isinstance(tbl, sql.selectable.FromClause): - return tbl - - if not isinstance(tbl, str): - raise ValueError("tbl must be a sqlalchemy Table or string, but was %s" %type(tbl)) - - if columns is None and source is None: - raise ValueError("One of columns or source must be specified") - - schema, table_name = tbl.split('.') if '.' in tbl else [None, tbl] - - columns = map(sqlalchemy.Column, columns) if columns is not None else tuple() - - # TODO: pybigquery uses schema to mean project_id, so we cannot use - # siuba's classic breakdown "{schema}.{table_name}". Basically - # pybigquery uses "{schema=project_id}.{dataset_dot_table_name}" in its internal - # logic. An important side effect is that bigquery errors for - # `dataset`.`table`, but not `dataset.table`. - if source and source.dialect.name == "bigquery": - table_name = tbl - schema = None - - return sqlalchemy.Table( - table_name, - sqlalchemy.MetaData(bind = source), - *columns, - schema = schema, - autoload_with = source if not columns else None - ) - - def _get_preview(self): - # need to make prev op a cte, so we don't override any previous limit - new_sel = self.last_select.limit(5) - tbl_small = self.append_op(new_sel) - return collect(tbl_small) - - def __repr__(self): - template = ( - "# Source: lazy query\n" - "# DB Conn: {}\n" - "# Preview:\n{}\n" - "# .. may have more rows" - ) - - return template.format(repr(self.source.engine), repr(self._get_preview())) - - def _repr_html_(self): - template = ( - "
" - "
"
-                "# Source: lazy query\n"
-                "# DB Conn: {}\n"
-                "# Preview:\n"
-                "
" - "{}" - "

# .. may have more rows

" - "
" - ) - - data = self._get_preview() - - # _repr_html_ can not exist or return None, to signify that repr should be used - if not hasattr(data, '_repr_html_'): - return None - - html_data = data._repr_html_() - if html_data is None: - return None - - return template.format(self.source.engine, html_data) - - -def _repr_grouped_df_html_(self): - return "

(grouped data frame)

" + self._selected_obj._repr_html_() + "
" - - - -# Main Funcs -# ============================================================================= - -# sql raw -------------- - -sql_raw = sql.literal_column - -# show query ----------- - -@show_query.register(LazyTbl) -def _show_query(tbl, simplify = False, return_table = True): - #query = tbl.last_op #if not simplify else - compile_query = lambda query: query.compile( - dialect = tbl.source.dialect, - compile_kwargs = {"literal_binds": True} - ) - - - if simplify: - # try to strip table names and labels where unnecessary - simple_sel = _sql_simplify_select(tbl.last_select) - - explained = compile_query(simple_sel) - else: - # use a much more verbose query - explained = compile_query(tbl.last_select) - - if return_table: - print(str(explained)) - return tbl - - return str(explained) - - - -# collect ---------- - -@collect.register(LazyTbl) -def _collect(__data, as_df = True): - # TODO: maybe remove as_df options, always return dataframe - - if isinstance(__data.source, MockConnection): - # a mock sqlalchemy is being used to show_query, and echo queries. - # it doesn't return a result object or have a context handler, so - # we need to bail out early - return - - # compile query ---- - - if _is_dialect_duckdb(__data.source): - # TODO: can be removed once next release of duckdb fixes: - # https://github.com/duckdb/duckdb/issues/2972 - query = __data.last_select - compiled = query.compile( - dialect = __data.source.dialect, - compile_kwargs = {"literal_binds": True} - ) - else: - compiled = __data.last_select - - # execute query ---- - - with __data.source.connect() as conn: - if as_df: - sql_db = _FixedSqlDatabase(conn) - - if _is_dialect_duckdb(__data.source): - # TODO: pandas read_sql is very slow with duckdb. - # see https://github.com/pandas-dev/pandas/issues/45678 - # going to handle here for now. address once LazyTbl gets - # subclassed per backend. 
- duckdb_con = conn.connection.c - return duckdb_con.query(str(compiled)).to_df() - else: - # - return sql_db.read_sql(compiled) - - return conn.execute(compiled) - - -@select.register(LazyTbl) -def _select(__data, *args, **kwargs): - # see https://stackoverflow.com/questions/25914329/rearrange-columns-in-sqlalchemy-select-object - if kwargs: - raise NotImplementedError( - "Using kwargs in select not currently supported. " - "Use _.newname == _.oldname instead" - ) - last_sel = __data.last_select - columns = {c.key: c for c in last_sel.inner_columns} - - # same as for DataFrame - colnames = Series(list(columns)) - vl = VarList() - evaluated = (arg(vl) if callable(arg) else arg for arg in args) - od = var_select(colnames, *evaluated) - - missing_groups, group_keys = _select_group_renames(od, __data.group_by) - - if missing_groups: - _warn_missing(missing_groups) - - final_od = {**{k: None for k in missing_groups}, **od} - - col_list = [] - for k,v in final_od.items(): - col = columns[k] - col_list.append(col if v is None else col.label(v)) - - return __data.append_op( - last_sel.with_only_columns(col_list), - group_by = group_keys - ) - - - -@filter.register(LazyTbl) -def _filter(__data, *args): - # Note: currently always produces 2 additional select statements, - # 1 for window/aggs, and 1 for the where clause - - sel = __data.last_op.alias() # original select - win_sel = sel.select() - - conds = [] - windows = [] - with _set_data_context(__data, window=True): - for ii, arg in enumerate(args): - - if isinstance(arg, Call): - new_call = __data.shape_call(arg, verb_name = "Filter", arg_name = ii) - #var_cols = new_call.op_vars(attr_calls = False) - - # note that a new win_sel is returned, w/ window columns appended - col_expr, win_cols, win_sel = __data.track_call_windows( - new_call, - sel.columns, - window_cte = win_sel - ) - - if isinstance(col_expr, sql.base.ImmutableColumnCollection): - conds.extend(col_expr) - else: - conds.append(col_expr) - - 
windows.extend(win_cols) - - else: - conds.append(arg) - - bool_clause = sql.and_(*conds) - - # first cte, windows ---- - if len(windows): - - win_alias = win_sel.alias() - - # move non-window functions to refer to win_sel clause (not the innermost) --- - bool_clause = sql.util.ClauseAdapter(win_alias) \ - .traverse(bool_clause) - - orig_cols = [win_alias.columns[k] for k in sel.columns.keys()] - else: - orig_cols = [sel] - - # create second cte ---- - filt_sel = _sql_select(orig_cols).where(bool_clause) - return __data.append_op(filt_sel) - - -@mutate.register(LazyTbl) -def _mutate(__data, *args, **kwargs): - # TODO: verify it can follow a renaming select - - # track labeled columns in set - if not (len(args) or len(kwargs)): - return __data.append_op(__data.last_op) - - names, sel_out = _mutate_cols(__data, args, kwargs, "Mutate") - return __data.append_op(sel_out) - - -def _sql_upsert_columns(sel, new_columns: "list[base.Label | base.Column]"): - orig_cols = lift_inner_cols(sel) - replaced = {**orig_cols} - - for new_col in new_columns: - replaced[new_col.name] = new_col - return _sql_with_only_columns(sel, list(replaced.values())) - - -def _select_mutate_result(src_sel, expr_result): - dst_alias = src_sel.alias() - src_columns = set(lift_inner_cols(src_sel)) - replacer = SqlLabelReplacer(set(src_columns), dst_alias.columns) - - if isinstance(expr_result, sql.base.ImmutableColumnCollection): - replaced_cols = list(map(replacer, expr_result)) - orig_cols = expr_result - #elif isinstance(expr_result, None): - # pass - else: - replaced_cols = [replacer(expr_result)] - orig_cols = [expr_result] - - if replacer.applied: - return _sql_upsert_columns(dst_alias.select(), replaced_cols) - - return _sql_upsert_columns(src_sel, orig_cols) - - -def _eval_expr_arg(__data, sel, func, verb_name, window=True): - inner_cols = lift_inner_cols(sel) - - # case 1: simple names ---- - simple_name = simple_varname(func) - if simple_name is not None: - return inner_cols[simple_name] - 
- # case 2: across ---- - _require_across(func, verb_name) - - cols_result = _eval_with_context(__data, window, inner_cols, func) - - # TODO: remove or raise a more informative error - assert isinstance(cols_result, sql.base.ImmutableColumnCollection), type(cols_result) - - return cols_result - - -def _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=True): - inner_cols = lift_inner_cols(sel) - - expr_shaped = __data.shape_call(func, window, verb_name = verb_name, arg_name = new_name) - new_col, windows, _ = __data.track_call_windows(expr_shaped, inner_cols) - - if isinstance(new_col, sql.base.ImmutableColumnCollection): - raise TyepError( - f"{verb_name} named arguments must return a single column, but `{k}` " - "returned multiple columns." - ) - - return new_col.label(new_name) - - -def _mutate_cols(__data, args, kwargs, verb_name): - result_names = {} # used as ordered set - sel = __data.last_select - - for ii, func in enumerate(args): - cols_result = _eval_expr_arg(__data, sel, func, verb_name) - - # replace any labels that require a subquery ---- - sel = _select_mutate_result(sel, cols_result) - - if isinstance(cols_result, sql.base.ImmutableColumnCollection): - result_names.update({k: True for k in cols_result.keys()}) - else: - result_names[cols_result.name] = True - - - for new_name, func in kwargs.items(): - labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name) - - sel = _select_mutate_result(sel, labeled) - result_names[new_name] = True - - - return list(result_names), sel - - -@transmute.register(LazyTbl) -def _transmute(__data, *args, **kwargs): - # will use mutate, then select some cols - result_names, sel = _mutate_cols(__data, args, kwargs, "Transmute") - - # transmute keeps grouping cols, and any defined in kwargs - missing = [x for x in __data.group_by if x not in result_names] - cols_to_keep = [*missing, *result_names] - - columns = lift_inner_cols(sel) - sel_stripped = sel.with_only_columns([columns[k] for k in 
cols_to_keep]) - - return __data.append_op(sel_stripped) - - -@arrange.register(LazyTbl) -def _arrange(__data, *args): - # Note that SQL databases often do not subquery order by clauses. Arrange - # sets order_by on the backend, so it can set order by in over elements, - # and handle when new columns are named the same as order by vars. - # see: https://dba.stackexchange.com/q/82930 - - last_sel = __data.last_select - cols = lift_inner_cols(last_sel) - - # TODO: implement across in arrange - #exprs, _ = _mutate_cols(__data, args, kwargs, "Arrange", arrange_clause=True) - - new_calls = [] - for ii, expr in enumerate(args): - if callable(expr): - - res = __data.shape_call( - expr, window = False, - verb_name = "Arrange", arg_name = ii - ) - - else: - res = expr - - new_calls.append(res) - - sort_cols = _create_order_by_clause(cols, *new_calls) - - order_by = __data.order_by + tuple(new_calls) - return __data.append_op(last_sel.order_by(*sort_cols), order_by = order_by) - - -# TODO: consolidate / pull expr handling funcs into own file? -def _create_order_by_clause(columns, *args): - sort_cols = [] - for arg in args: - # simple named column - if isinstance(arg, str): - sort_cols.append(columns[arg]) - # an expression - elif callable(arg): - # handle special case where -_.colname -> colname DESC - f, asc = _call_strip_ascending(arg) - col_op = f(columns) if asc else f(columns).desc() - #col_op = arg(columns) - sort_cols.append(col_op) - else: - raise NotImplementedError("Must be string or callable") - - return sort_cols - - - -@count.register(LazyTbl) -def _count(__data, *args, sort = False, wt = None, **kwargs): - # TODO: if already col named n, use name nn, etc.. 
get logic from tidy.py - if wt is not None: - raise NotImplementedError("TODO") - - res_name = "n" - # similar to filter verb, we need two select statements, - # an inner one for derived cols, and outer to group by them - - # inner select ---- - # holds any mutation style columns - #arg_names = [] - #for arg in args: - # name = simple_varname(arg) - # if name is None: - # raise NotImplementedError( - # "Count positional arguments must be single column name. " - # "Use a named argument to count using complex expressions." - # ) - # arg_names.append(name) - - result_names, sel_inner = _mutate_cols(__data, args, kwargs, "Count") - - # remove unnecessary select, if we're operating on a table ---- - if set(lift_inner_cols(sel_inner)) == set(lift_inner_cols(__data.last_select)): - sel_inner = __data.last_op - - # create outer select ---- - # holds selected columns and tally (n) - sel_inner_cte = sel_inner.alias() - inner_cols = sel_inner_cte.columns - - # apply any group vars from a group_by verb call first - missing = [k for k in __data.group_by if k not in result_names] - - all_group_names = ordered_union(__data.group_by, result_names) - outer_group_cols = [inner_cols[k] for k in all_group_names] - - # holds the actual count (e.g. 
n) - count_col = sql.functions.count().label(res_name) - - sel_outer = _sql_select([*outer_group_cols, count_col]) \ - .select_from(sel_inner_cte) \ - .group_by(*outer_group_cols) - - # count is like summarize, so removes order_by - return __data.append_op( - sel_outer.order_by(count_col.desc()), - order_by = tuple() - ) - - -@add_count.register(LazyTbl) -def _add_count(__data, *args, wt = None, sort = False, **kwargs): - counts = count(__data, *args, wt = wt, sort = sort, **kwargs) - by = list(c.name for c in counts.last_select.inner_columns)[:-1] - - return inner_join(__data, counts, by = by) - - -@summarize.register(LazyTbl) -def _summarize(__data, *args, **kwargs): - # https://stackoverflow.com/questions/14754994/why-is-sqlalchemy-count-much-slower-than-the-raw-query - - # get query with correct from clause, and maybe unneeded subquery - safe_from = __data.last_select.alias() - result_names, sel = _aggregate_cols(__data, safe_from, args, kwargs, "Summarize") - - # see if we can remove subquery - out_sel = _collapse_select(sel, safe_from) - - from_tbl = get_single_from(out_sel) - group_cols = [from_tbl.columns[k] for k in __data.group_by] - - final_sel = out_sel.group_by(*group_cols) - - new_data = __data.append_op(final_sel, group_by = tuple(), order_by = tuple()) - return new_data - - -def _collapse_select(outer_sel, inner_alias): - # check whether any outer columns reference an inner label ---- - inner_sel = inner_alias.element - - columns = lift_inner_cols(outer_sel) - inner_cols = lift_inner_cols(inner_sel) - - inner_labels = set([ - x.name for x in inner_cols - if isinstance(x, sql.elements.Label) - ]) - - col_requires_cte = set(inner_alias.columns[k] for k in inner_labels) - - bad_refs = [] - - def collect_refs(el): - if el in col_requires_cte: - bad_refs.append(el) - - for col in columns: - sql.util.visitors.traverse(col, {}, {"column": collect_refs}) - - # if possible, remove the outer query ---- - if not (bad_refs or len(inner_sel._group_by_clause)): - 
from sqlalchemy.sql.elements import ColumnClause, Label - - from_obj = get_single_from(inner_sel) - adaptor = sql.util.ClauseAdapter( - from_obj, - adapt_on_names=True, - include_fn=lambda c: isinstance(c, (ColumnClause, Label)) - ) - - new_cols = [] - for col in columns: - if isinstance(col, Label): - res = adaptor.traverse(col.element).label(col.name) - new_cols.append(res) - - else: - new_cols.append(adaptor.traverse(col)) - #new_cols = list(map(adaptor.traverse, columns)) - - return _sql_with_only_columns(inner_sel, new_cols) - - return outer_sel - - -def _aggregate_cols(__data, subquery, args, kwargs, verb_name): - # cases: - # * grouping cols can not be overwritten (in dbplyr they can't be ref'd) - # * no existing labels referred to - can use same select - # * existing labels referred to - need 1 subquery tops - # * groups + summarize columns can replace everything - - def get_label_clauses(clause): - out = [] - sql.util.visitors.traverse(clause, {}, {"label": lambda c: out.append(c)}) - - return out - - def quote_varname(x): - return f"`{x}`" - - def validate_references(arg_name, expr, verb_name): - bad_varnames = get_label_clauses(expr) - repr_names = ", ".join(map(quote_varname, [el.name for el in bad_varnames])) - - if not bad_varnames: - return - - raise NotImplementedError( - f"In SQL, you cannot refer to a column created in the same {verb_name}. " - f"`{arg_name}` refers to columns created earlier: {repr_names}." 
- ) - - sel = subquery.select() - - final_cols = {k: subquery.columns[k] for k in __data.group_by} - - # handle args ---- - for ii, func in enumerate(args): - cols_result = _eval_expr_arg(__data, sel, func, verb_name, window=False) - - for col in cols_result: - validate_references(col.name, col.element, verb_name) - final_cols[col.name] = col - - sel = _sql_upsert_columns(sel, cols_result) - - - # handle kwargs ---- - for new_name, func in kwargs.items(): - labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=False) - - validate_references(labeled.name, labeled.element, verb_name) - final_cols[new_name] = labeled - - sel = _sql_upsert_columns(sel, [labeled]) - - return list(final_cols), _sql_with_only_columns(sel, list(final_cols.values())) - - -@group_by.register(LazyTbl) -def _group_by(__data, *args, add = False, **kwargs): - if not (args or kwargs): - return __data.copy() - - group_names, sel = _mutate_cols(__data, args, kwargs, "Group by") - - if None in group_names: - raise NotImplementedError("Complex, unnamed expressions not supported in sql group_by") - - # check whether we can just use underlying table ---- - new_cols = lift_inner_cols(sel) - if set(new_cols).issubset(set(__data.last_op.columns)): - sel = __data.last_op - - if add: - group_names = ordered_union(__data.group_by, group_names) - - return __data.append_op(sel, group_by = tuple(group_names)) - - -@ungroup.register(LazyTbl) -def _ungroup(__data): - return __data.copy(group_by = tuple()) - - -@case_when.register(sql.base.ImmutableColumnCollection) -def _case_when(__data, cases): - # TODO: will need listener to enter case statements, to handle when they use windows - if isinstance(cases, Call): - cases = cases(__data) - - whens = [] - case_items = list(cases.items()) - n_items = len(case_items) - - else_val = None - for ii, (expr, val) in enumerate(case_items): - # handle where val is a column expr - if callable(val): - val = val(__data) - - # handle when expressions - if ii+1 
== n_items and expr is True: - else_val = val - elif callable(expr): - whens.append((expr(__data), val)) - else: - whens.append((expr, val)) - - return sql.case(whens, else_ = else_val) - - -# Join ------------------------------------------------------------------------ - -from collections.abc import Mapping - -def _joined_cols(left_cols, right_cols, on_keys, how, suffix = ("_x", "_y")): - """Return labeled columns, according to selection rules for joins. - - Rules: - 1. For join keys, keep left table's column - 2. When keys have the same labels, add suffix - """ - - # TODO: remove sets, so uses stable ordering - # when left and right cols have same name, suffix with _x / _y - keep_right = set(right_cols.keys()) - set(on_keys.values()) - shared_labs = set(left_cols.keys()).intersection(keep_right) - - right_cols_no_keys = {k: right_cols[k] for k in keep_right} - - # for an outer join, have key columns coalesce values - - left_cols = {**left_cols} - if how == "full": - for lk, rk in on_keys.items(): - col = sql.functions.coalesce(left_cols[lk], right_cols[rk]) - left_cols[lk] = col.label(lk) - elif how == "right": - for lk, rk in on_keys.items(): - # Make left key columns actually be right ones (which contain left + extra) - left_cols[lk] = right_cols[rk].label(lk) - - - # create labels ---- - l_labs = _relabeled_cols(left_cols, shared_labs, suffix[0]) - r_labs = _relabeled_cols(right_cols_no_keys, shared_labs, suffix[1]) - - return l_labs + r_labs - - - -def _relabeled_cols(columns, keys, suffix): - # add a suffix to all columns with names in keys - cols = [] - for k, v in columns.items(): - new_col = v.label(k + str(suffix)) if k in keys else v - cols.append(new_col) - return cols - - -@join.register(LazyTbl) -def _join(left, right, on = None, *args, by = None, how = "inner", sql_on = None): - _raise_if_args(args) - - if on is None and by is not None: - on = by - - # Needs to be on the table, not the select - left_sel = left.last_op.alias() - right_sel = 
right.last_op.alias() - - # handle arguments ---- - on = _validate_join_arg_on(on, sql_on) - how = _validate_join_arg_how(how) - - # for equality join used to combine keys into single column - consolidate_keys = on if sql_on is None else {} - - if how == "right": - # switch joins, since sqlalchemy doesn't have right join arg - # see https://stackoverflow.com/q/11400307/1144523 - left_sel, right_sel = right_sel, left_sel - on = {v:k for k,v in on.items()} - - # create join conditions ---- - bool_clause = _create_join_conds(left_sel, right_sel, on) - - # create join ---- - join = left_sel.join( - right_sel, - onclause = bool_clause, - isouter = how != "inner", - full = how == "full" - ) - - # if right join, set selects back - if how == "right": - left_sel, right_sel = right_sel, left_sel - on = {v:k for k,v in on.items()} - - # note, shared_keys assumes on is a mapping... - # TODO: shared_keys appears to be for when on is not specified, but was unused - #shared_keys = [k for k,v in on.items() if k == v] - labeled_cols = _joined_cols( - left_sel.columns, - right_sel.columns, - on_keys = consolidate_keys, - how = how - ) - - sel = _sql_select(labeled_cols).select_from(join) - return left.append_op(sel, order_by = tuple()) - - -@semi_join.register(LazyTbl) -def _semi_join(left, right = None, on = None, *args, by = None, sql_on = None): - if on is None and by is not None: - on = by - - _raise_if_args(args) - - left_sel = left.last_op.alias() - right_sel = right.last_op.alias() - - # handle arguments ---- - on = _validate_join_arg_on(on, sql_on, left_sel, right_sel) - - # create join conditions ---- - bool_clause = _create_join_conds(left_sel, right_sel, on) - - # create inner join ---- - exists_clause = _sql_select([sql.literal(1)]) \ - .select_from(right_sel) \ - .where(bool_clause) - - # only keep left hand select's columns ---- - sel = _sql_select(left_sel.columns) \ - .select_from(left_sel) \ - .where(sql.exists(exists_clause)) - - return left.append_op(sel, order_by 
= tuple()) - - -@anti_join.register(LazyTbl) -def _anti_join(left, right = None, on = None, *args, by = None, sql_on = None): - if on is None and by is not None: - on = by - - _raise_if_args(args) - - left_sel = left.last_op.alias() - right_sel = right.last_op.alias() - - # handle arguments ---- - on = _validate_join_arg_on(on, sql_on, left, right) - - # create join conditions ---- - bool_clause = _create_join_conds(left_sel, right_sel, on) - - # create inner join ---- - #not_exists = ~sql.exists([1], from_obj = right_sel).where(bool_clause) - exists_clause = _sql_select([sql.literal(1)]) \ - .select_from(right_sel) \ - .where(bool_clause) - - sel = left_sel.select().where(~sql.exists(exists_clause)) - - return left.append_op(sel, order_by = tuple()) - -def _raise_if_args(args): - if len(args): - raise NotImplemented("*args is reserved for future arguments (e.g. suffix)") - -def _validate_join_arg_on(on, sql_on = None, lhs = None, rhs = None): - # handle sql on case - if sql_on is not None: - if on is not None: - raise ValueError("Cannot specify both on and sql_on") - - return sql_on - - # handle general cases - if on is None: - # TODO: currently, we check for lhs and rhs tables to indicate whether - # a verb supports inferring columns. Otherwise, raise an error. - if lhs is not None and rhs is not None: - # TODO: consolidate with duplicate logic in pandas verb code - warnings.warn( - "No on column passed to join. " - "Inferring join columns instead using shared column names." - ) - - on_cols = list(set(lhs.columns.keys()).intersection(set(rhs.columns.keys()))) - - if not on_cols: - raise ValueError( - "No join column specified, or shared column names in join." 
- ) - - # trivial dict mapping shared names to themselves - warnings.warn("Detected shared columns: %s" % on_cols) - on = dict(zip(on_cols, on_cols)) - - else: - raise NotImplementedError("on arg currently cannot be None (default) for SQL") - elif isinstance(on, str): - on = {on: on} - elif isinstance(on, (list, tuple)): - on = dict(zip(on, on)) - - - if not isinstance(on, Mapping): - raise TypeError("on must be a Mapping (e.g. dict)") - - return on - -def _validate_join_arg_how(how): - how_options = ("inner", "left", "right", "full") - if how not in how_options: - raise ValueError("how argument needs to be one of %s" %how_options) - - return how - -def _create_join_conds(left_sel, right_sel, on): - left_cols = left_sel.columns #lift_inner_cols(left_sel) - right_cols = right_sel.columns #lift_inner_cols(right_sel) - - if callable(on): - # callable, like with sql_on arg - conds = [on(left_cols, right_cols)] - else: - # dict-like of form {left: right} - conds = [] - for l, r in on.items(): - col_expr = left_cols[l] == right_cols[r] - conds.append(col_expr) - - return sql.and_(*conds) - - -# Head ------------------------------------------------------------------------ +from ..backend import LazyTbl @head.register(LazyTbl) def _head(__data, n = 5): sel = __data.last_select return __data.append_op(sel.limit(n)) - - -# Rename ---------------------------------------------------------------------- - -@rename.register(LazyTbl) -def _rename(__data, **kwargs): - sel = __data.last_select - columns = lift_inner_cols(sel) - - # old_keys uses dict as ordered set - old_to_new = {simple_varname(v):k for k,v in kwargs.items()} - - if None in old_to_new: - raise KeyError("positional arguments must be simple column, " - "e.g. 
_.colname or _['colname']" - ) - - labs = [c.label(old_to_new[k]) if k in old_to_new else c for k,c in columns.items()] - - new_sel = sel.with_only_columns(labs) - - missing_groups, group_keys = _select_group_renames(old_to_new, __data.group_by) - - return __data.append_op(new_sel, group_by=group_keys) - - -# Distinct -------------------------------------------------------------------- - -@distinct.register(LazyTbl) -def _distinct(__data, *args, _keep_all = False, **kwargs): - if (args or kwargs) and _keep_all: - raise NotImplementedError("Distinct with variables specified in sql requires _keep_all = False") - - inner_sel = mutate(__data, **kwargs).last_select if kwargs else __data.last_select - - # TODO: this is copied from the df distinct version - # cols dict below is used as ordered set - cols = _var_select_simple(args) - cols.update(kwargs) - - # use all columns by default - if not cols: - cols = {k: True for k in lift_inner_cols(inner_sel).keys()} - - final_names = {**{k: True for k in __data.group_by}, **cols} - - if not len(inner_sel._order_by_clause): - # select distinct has to include any columns in the order by clause, - # so can only safely modify existing statement when there's no order by - sel_cols = lift_inner_cols(inner_sel) - distinct_cols = [sel_cols[k] for k in final_names] - sel = inner_sel.with_only_columns(distinct_cols).distinct() - else: - # fallback to cte - cte = inner_sel.alias() - distinct_cols = [cte.columns[k] for k in final_names] - sel = _sql_select(distinct_cols).select_from(cte).distinct() - - return __data.append_op(sel) - - -# if_else --------------------------------------------------------------------- - -@if_else.register(sql.elements.ColumnElement) -def _if_else(cond, true_vals, false_vals): - whens = [(cond, true_vals)] - return sql.case(whens, else_ = false_vals) - - diff --git a/siuba/sql/verbs/join.py b/siuba/sql/verbs/join.py index 8dcbef7b..112066c4 100644 --- a/siuba/sql/verbs/join.py +++ b/siuba/sql/verbs/join.py @@ 
-1,1096 +1,12 @@ -""" -Implements LazyTbl to represent tables of SQL data, and registers it on verbs. - -This module is responsible for the handling of the "table" side of things, while -translate.py handles translating column operations. - - -""" - import warnings -from siuba.dply.verbs import ( - show_query, collect, - simple_varname, - select, - mutate, - transmute, - filter, - arrange, _call_strip_ascending, - summarize, - count, add_count, - group_by, ungroup, - case_when, - join, left_join, right_join, inner_join, semi_join, anti_join, - head, - rename, - distinct, - if_else, - _select_group_renames, - _var_select_simple - ) - -from siuba.dply.tidyselect import VarList, var_select - -from .translate import CustomOverClause, SqlColumn, SqlColumnAgg -from .utils import ( - get_dialect_translator, - _FixedSqlDatabase, - _is_dialect_duckdb, - _sql_select, - _sql_column_collection, - _sql_add_columns, - _sql_with_only_columns, - _sql_simplify_select, - MockConnection -) - +from collections.abc import Mapping from sqlalchemy import sql -import sqlalchemy -from siuba.siu import Call, Lazy, FunctionLookupError, singledispatch2 -# TODO: currently needed for select, but can we remove pandas? -from pandas import Series -from functools import singledispatch - -from sqlalchemy.sql import schema - -from siuba.dply.across import _require_across, _set_data_context, _eval_with_context - -# TODO: -# - distinct -# - annotate functions using sel.prefix_with("\n/**/\n") ? - - -# Helpers --------------------------------------------------------------------- - -class SqlFunctionLookupError(FunctionLookupError): pass - - -class CallListener: - """Generic listener. Each exit is called on a node's copy.""" - def enter(self, node): - args, kwargs = node.map_subcalls(self.enter) - - return self.exit(node.__class__(node.func, *args, **kwargs)) - - def exit(self, node): - return node - - -class WindowReplacer(CallListener): - """Call tree listener. 
- - Produces 2 important behaviors via the enter method: - - returns evaluated sql call expression, with labels on all window expressions. - - stores all labeled window expressions via the windows property. - - TODO: could replace with a sqlalchemy transformer - """ - - def __init__(self, columns, group_by, order_by, window_cte = None): - self.columns = columns - self.group_by = group_by - self.order_by = order_by - self.window_cte = window_cte - self.windows = [] - - def exit(self, node): - col_expr = node(self.columns) - - if not isinstance(col_expr, sql.elements.ClauseElement): - return col_expr - - over_clauses = [x for x in self._get_over_clauses(col_expr) if isinstance(x, CustomOverClause)] - - # put groupings and orderings onto custom over clauses - for over in over_clauses: - # TODO: shouldn't mutate these over clauses - group_by = sql.elements.ClauseList( - *[self.columns[name] for name in self.group_by] - ) - order_by = sql.elements.ClauseList( - *_create_order_by_clause(self.columns, *self.order_by) - ) - - over.set_over(group_by, order_by) - - if len(over_clauses) and self.window_cte is not None: - # custom name, or parameters like "%(...)s" may nest and break psycopg2 - # with columns you can set a key to fix this, but it doesn't seem to - # be an option with labels - name = self._get_unique_name('win', lift_inner_cols(self.window_cte)) - label = col_expr.label(name) - - # put into CTE, and return its resulting column, so that subsequent - # operations will refer to the window column on window_cte. 
Note that - # the operations will use the actual column, so may need to use the - # ClauseAdaptor to make it a reference to the label - self.window_cte = _sql_add_columns(self.window_cte, [label]) - win_col = lift_inner_cols(self.window_cte).values()[-1] - self.windows.append(win_col) - - return win_col - - return col_expr - - @staticmethod - def _get_unique_name(prefix, columns): - column_names = set(columns.keys()) - - i = 1 - name = prefix + str(i) - while name in column_names: - i += 1 - name = prefix + str(i) - - - return name - - @staticmethod - def _get_over_clauses(clause): - windows = [] - append_win = lambda col: windows.append(col) - - sql.util.visitors.traverse(clause, {}, {"over": append_win}) - - return windows - - -class SqlLabelReplacer: - """Create a visitor to replace source labels with destination. - - Note that this is meant to be used with sqlalchemy visitors. - """ - - def __init__(self, src_columns, dst_columns): - self.src_columns = src_columns - self.src_labels = set([x for x in src_columns if isinstance(x, sql.elements.Label)]) - self.dst_columns = dst_columns - self.applied = False - - def __call__(self, clause): - return sql.util.visitors.replacement_traverse(clause, {}, self.visit) - - def visit(self, el): - from sqlalchemy.sql.elements import ColumnClause, Label, ClauseElement, TypeClause - from sqlalchemy.sql.schema import Column - - if isinstance(el, TypeClause): - # TODO: for some reason this type throws an error if unguarded - return None - - if isinstance(el, ClauseElement): - if el in self.src_labels: - self.applied = True - return self.dst_columns[el.name] - elif el in self.src_columns: - return self.dst_columns[el.name] - - # TODO: should we create a subquery if the user passed raw text? 
- #elif isinstance(el, ColumnClause) and not isinstance(el, Column): - # # Raw SQL, which will need a subquery, but not substitution - # if el.key != "*": - # self.applied = True - - return None - - -#def track_call_windows(call, columns, group_by, order_by, window_cte = None): -# listener = WindowReplacer(columns, group_by, order_by, window_cte) -# col = listener.enter(call) -# return col, listener.windows, listener.window_cte - - -def track_call_windows(call, columns, group_by, order_by, window_cte = None): - col_expr = call(columns) - - crnt_group_by = sql.elements.ClauseList( - *[columns[name] for name in group_by] - ) - crnt_order_by = sql.elements.ClauseList( - *_create_order_by_clause(columns, *order_by) - ) - return replace_call_windows(col_expr, crnt_group_by, crnt_order_by, window_cte) - - - -@singledispatch -def replace_call_windows(col_expr, group_by, order_by, window_cte = None): - raise TypeError(str(type(col_expr))) - - -@replace_call_windows.register(sql.base.ImmutableColumnCollection) -def _(col_expr, group_by, order_by, window_cte = None): - all_over_clauses = [] - for col in col_expr: - _, over_clauses, window_cte = replace_call_windows( - col, - group_by, - order_by, - window_cte - ) - all_over_clauses.extend(over_clauses) - - return col_expr, all_over_clauses, window_cte - - -@replace_call_windows.register(sql.elements.ClauseElement) -def _(col_expr, group_by, order_by, window_cte = None): - - over_clauses = WindowReplacer._get_over_clauses(col_expr) - - for over in over_clauses: - # TODO: shouldn't mutate these over clauses - over.set_over(group_by, order_by) - - if len(over_clauses) and window_cte is not None: - # custom name, or parameters like "%(...)s" may nest and break psycopg2 - # with columns you can set a key to fix this, but it doesn't seem to - # be an option with labels - name = WindowReplacer._get_unique_name('win', lift_inner_cols(window_cte)) - label = col_expr.label(name) - - # put into CTE, and return its resulting column, so 
that subsequent - # operations will refer to the window column on window_cte. Note that - # the operations will use the actual column, so may need to use the - # ClauseAdaptor to make it a reference to the label - window_cte = _sql_add_columns(window_cte, [label]) - win_col = lift_inner_cols(window_cte).values()[-1] - - return win_col, over_clauses, window_cte - - return col_expr, over_clauses, window_cte - -def get_single_from(sel): - froms = sel.froms - - n_froms = len(froms) - if n_froms != 1: - raise ValueError( - f"Expected a single table in the from clause, but found {n_froms}" - ) - - return froms[0] - -def lift_inner_cols(tbl): - cols = list(tbl.inner_columns) - - return _sql_column_collection(cols) - -def col_expr_requires_cte(call, sel, is_mutate = False): - """Return whether a variable assignment needs a CTE""" - - call_vars = set(call.op_vars(attr_calls = False)) - - sel_labs = get_inner_labels(sel) - - # I use the acronym fwg sol (frog soul) to remember sql clause eval order - # from, where, group by, select, order by, limit - # group clause evaluated before select clause, so not issue for mutate - group_needs_cte = not is_mutate and len(sel._group_by_clause) - - return ( group_needs_cte - # TODO: detect when a new var in mutate conflicts w/ order by - #or len(sel._order_by_clause) - or not sel_labs.isdisjoint(call_vars) - ) - -def get_inner_labels(sel): - columns = lift_inner_cols(sel) - sel_labs = set(k for k,v in columns.items() if isinstance(v, sql.elements.Label)) - return sel_labs - -def get_missing_columns(call, columns): - missing_cols = set(call.op_vars(attr_calls = False)) - set(columns.keys()) - return missing_cols - -def compile_el(tbl, el): - compiled = el.compile( - dialect = tbl.source.dialect, - compile_kwargs = {"literal_binds": True} - ) - return compiled - -# Misc utilities -------------------------------------------------------------- - -def ordered_union(x, y): - dx = {el: True for el in x} - dy = {el: True for el in y} - - return 
tuple({**dx, **dy}) - - -def _warn_missing(missing_groups): - warnings.warn(f"Adding missing grouping variables: {missing_groups}") - - -# Table ----------------------------------------------------------------------- - -class LazyTbl: - def __init__( - self, source, tbl, columns = None, - ops = None, group_by = tuple(), order_by = tuple(), - translator = None - ): - """Create a representation of a SQL table. - - Args: - source: a sqlalchemy.Engine or sqlalchemy.Connection instance. - tbl: table of form 'schema_name.table_name', 'table_name', or sqlalchemy.Table. - columns: if specified, a listlike of column names. - - Examples - -------- - - :: - from sqlalchemy import create_engine - from siuba.data import mtcars - - # create database and table - engine = create_engine("sqlite:///:memory:") - mtcars.to_sql('mtcars', engine) - - tbl_mtcars = LazyTbl(engine, 'mtcars') - - """ - - # connection and dialect specific functions - self.source = sqlalchemy.create_engine(source) if isinstance(source, str) else source - - # get dialect name - dialect = self.source.dialect.name - self.translator = get_dialect_translator(dialect) - - self.tbl = self._create_table(tbl, columns, self.source) - - # important states the query can be in (e.g. 
grouped) - self.ops = [self.tbl] if ops is None else ops - - self.group_by = group_by - self.order_by = order_by - - - def append_op(self, op, **kwargs): - cpy = self.copy(**kwargs) - cpy.ops = cpy.ops + [op] - return cpy - - def copy(self, **kwargs): - return self.__class__(**{**self.__dict__, **kwargs}) - - def shape_call( - self, - call, window = True, str_accessors = False, - verb_name = None, arg_name = None, - ): - return self.translator.shape_call(call, window, str_accessors, verb_name, arg_name) - - def track_call_windows(self, call, columns = None, window_cte = None): - """Returns tuple of (new column expression, list of window exprs)""" - - columns = self.last_op.columns if columns is None else columns - return track_call_windows(call, columns, self.group_by, self.order_by, window_cte) - - def get_ordered_col_names(self): - """Return columns from current select, with grouping columns first.""" - ungrouped = [k for k in self.last_op.columns.keys() if k not in self.group_by] - return list(self.group_by) + ungrouped - - #def label_breaks_order_by(self, name): - # """Returns True if a new column label would break the order by vars.""" - - # # TODO: arrange currently allows literals, which breaks this. it seems - # # better to only allow calls in arrange. - # order_by_vars = {c.op_vars(attr_calls=False) for c in self.order_by} - - - - - @property - def last_op(self) -> "sql.Table | sql.Select": - last_op = self.ops[-1] - - if last_op is None: - raise TypeError() - - return last_op - - @property - def last_select(self): - last_op = self.last_op - if not isinstance(last_op, sql.selectable.SelectBase): - return last_op.select() - - return last_op - - @staticmethod - def _create_table(tbl, columns = None, source = None): - """Return a sqlalchemy.Table, autoloading column info if needed. - - Arguments: - tbl: a sqlalchemy.Table or string of form 'table_name' or 'schema_name.table_name'. - columns: a tuple of column names for the table. Overrides source argument. 
- source: a sqlalchemy engine, used to autoload columns. - - """ - if isinstance(tbl, sql.selectable.FromClause): - return tbl - - if not isinstance(tbl, str): - raise ValueError("tbl must be a sqlalchemy Table or string, but was %s" %type(tbl)) - - if columns is None and source is None: - raise ValueError("One of columns or source must be specified") - - schema, table_name = tbl.split('.') if '.' in tbl else [None, tbl] - - columns = map(sqlalchemy.Column, columns) if columns is not None else tuple() - - # TODO: pybigquery uses schema to mean project_id, so we cannot use - # siuba's classic breakdown "{schema}.{table_name}". Basically - # pybigquery uses "{schema=project_id}.{dataset_dot_table_name}" in its internal - # logic. An important side effect is that bigquery errors for - # `dataset`.`table`, but not `dataset.table`. - if source and source.dialect.name == "bigquery": - table_name = tbl - schema = None - - return sqlalchemy.Table( - table_name, - sqlalchemy.MetaData(bind = source), - *columns, - schema = schema, - autoload_with = source if not columns else None - ) - - def _get_preview(self): - # need to make prev op a cte, so we don't override any previous limit - new_sel = self.last_select.limit(5) - tbl_small = self.append_op(new_sel) - return collect(tbl_small) - - def __repr__(self): - template = ( - "# Source: lazy query\n" - "# DB Conn: {}\n" - "# Preview:\n{}\n" - "# .. may have more rows" - ) - - return template.format(repr(self.source.engine), repr(self._get_preview())) - - def _repr_html_(self): - template = ( - "
" - "
"
-                "# Source: lazy query\n"
-                "# DB Conn: {}\n"
-                "# Preview:\n"
-                "
" - "{}" - "

# .. may have more rows

" - "
" - ) - - data = self._get_preview() - - # _repr_html_ can not exist or return None, to signify that repr should be used - if not hasattr(data, '_repr_html_'): - return None - - html_data = data._repr_html_() - if html_data is None: - return None - - return template.format(self.source.engine, html_data) - - -def _repr_grouped_df_html_(self): - return "

(grouped data frame)

" + self._selected_obj._repr_html_() + "
" - - - -# Main Funcs -# ============================================================================= - -# sql raw -------------- - -sql_raw = sql.literal_column - -# show query ----------- - -@show_query.register(LazyTbl) -def _show_query(tbl, simplify = False, return_table = True): - #query = tbl.last_op #if not simplify else - compile_query = lambda query: query.compile( - dialect = tbl.source.dialect, - compile_kwargs = {"literal_binds": True} - ) - - - if simplify: - # try to strip table names and labels where unnecessary - simple_sel = _sql_simplify_select(tbl.last_select) - - explained = compile_query(simple_sel) - else: - # use a much more verbose query - explained = compile_query(tbl.last_select) - - if return_table: - print(str(explained)) - return tbl - - return str(explained) - - - -# collect ---------- - -@collect.register(LazyTbl) -def _collect(__data, as_df = True): - # TODO: maybe remove as_df options, always return dataframe - - if isinstance(__data.source, MockConnection): - # a mock sqlalchemy is being used to show_query, and echo queries. - # it doesn't return a result object or have a context handler, so - # we need to bail out early - return - - # compile query ---- - - if _is_dialect_duckdb(__data.source): - # TODO: can be removed once next release of duckdb fixes: - # https://github.com/duckdb/duckdb/issues/2972 - query = __data.last_select - compiled = query.compile( - dialect = __data.source.dialect, - compile_kwargs = {"literal_binds": True} - ) - else: - compiled = __data.last_select - - # execute query ---- - - with __data.source.connect() as conn: - if as_df: - sql_db = _FixedSqlDatabase(conn) - - if _is_dialect_duckdb(__data.source): - # TODO: pandas read_sql is very slow with duckdb. - # see https://github.com/pandas-dev/pandas/issues/45678 - # going to handle here for now. address once LazyTbl gets - # subclassed per backend. 
- duckdb_con = conn.connection.c - return duckdb_con.query(str(compiled)).to_df() - else: - # - return sql_db.read_sql(compiled) - - return conn.execute(compiled) - - -@select.register(LazyTbl) -def _select(__data, *args, **kwargs): - # see https://stackoverflow.com/questions/25914329/rearrange-columns-in-sqlalchemy-select-object - if kwargs: - raise NotImplementedError( - "Using kwargs in select not currently supported. " - "Use _.newname == _.oldname instead" - ) - last_sel = __data.last_select - columns = {c.key: c for c in last_sel.inner_columns} - - # same as for DataFrame - colnames = Series(list(columns)) - vl = VarList() - evaluated = (arg(vl) if callable(arg) else arg for arg in args) - od = var_select(colnames, *evaluated) - - missing_groups, group_keys = _select_group_renames(od, __data.group_by) - - if missing_groups: - _warn_missing(missing_groups) - - final_od = {**{k: None for k in missing_groups}, **od} +from siuba.dply.verbs import join, left_join, right_join, inner_join, semi_join, anti_join - col_list = [] - for k,v in final_od.items(): - col = columns[k] - col_list.append(col if v is None else col.label(v)) +from ..backend import LazyTbl +from ..utils import _sql_select - return __data.append_op( - last_sel.with_only_columns(col_list), - group_by = group_keys - ) - - - -@filter.register(LazyTbl) -def _filter(__data, *args): - # Note: currently always produces 2 additional select statements, - # 1 for window/aggs, and 1 for the where clause - - sel = __data.last_op.alias() # original select - win_sel = sel.select() - - conds = [] - windows = [] - with _set_data_context(__data, window=True): - for ii, arg in enumerate(args): - - if isinstance(arg, Call): - new_call = __data.shape_call(arg, verb_name = "Filter", arg_name = ii) - #var_cols = new_call.op_vars(attr_calls = False) - - # note that a new win_sel is returned, w/ window columns appended - col_expr, win_cols, win_sel = __data.track_call_windows( - new_call, - sel.columns, - window_cte = 
win_sel - ) - - if isinstance(col_expr, sql.base.ImmutableColumnCollection): - conds.extend(col_expr) - else: - conds.append(col_expr) - - windows.extend(win_cols) - - else: - conds.append(arg) - - bool_clause = sql.and_(*conds) - - # first cte, windows ---- - if len(windows): - - win_alias = win_sel.alias() - - # move non-window functions to refer to win_sel clause (not the innermost) --- - bool_clause = sql.util.ClauseAdapter(win_alias) \ - .traverse(bool_clause) - - orig_cols = [win_alias.columns[k] for k in sel.columns.keys()] - else: - orig_cols = [sel] - - # create second cte ---- - filt_sel = _sql_select(orig_cols).where(bool_clause) - return __data.append_op(filt_sel) - - -@mutate.register(LazyTbl) -def _mutate(__data, *args, **kwargs): - # TODO: verify it can follow a renaming select - - # track labeled columns in set - if not (len(args) or len(kwargs)): - return __data.append_op(__data.last_op) - - names, sel_out = _mutate_cols(__data, args, kwargs, "Mutate") - return __data.append_op(sel_out) - - -def _sql_upsert_columns(sel, new_columns: "list[base.Label | base.Column]"): - orig_cols = lift_inner_cols(sel) - replaced = {**orig_cols} - - for new_col in new_columns: - replaced[new_col.name] = new_col - return _sql_with_only_columns(sel, list(replaced.values())) - - -def _select_mutate_result(src_sel, expr_result): - dst_alias = src_sel.alias() - src_columns = set(lift_inner_cols(src_sel)) - replacer = SqlLabelReplacer(set(src_columns), dst_alias.columns) - - if isinstance(expr_result, sql.base.ImmutableColumnCollection): - replaced_cols = list(map(replacer, expr_result)) - orig_cols = expr_result - #elif isinstance(expr_result, None): - # pass - else: - replaced_cols = [replacer(expr_result)] - orig_cols = [expr_result] - - if replacer.applied: - return _sql_upsert_columns(dst_alias.select(), replaced_cols) - - return _sql_upsert_columns(src_sel, orig_cols) - - -def _eval_expr_arg(__data, sel, func, verb_name, window=True): - inner_cols = 
lift_inner_cols(sel) - - # case 1: simple names ---- - simple_name = simple_varname(func) - if simple_name is not None: - return inner_cols[simple_name] - - # case 2: across ---- - _require_across(func, verb_name) - - cols_result = _eval_with_context(__data, window, inner_cols, func) - - # TODO: remove or raise a more informative error - assert isinstance(cols_result, sql.base.ImmutableColumnCollection), type(cols_result) - - return cols_result - - -def _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=True): - inner_cols = lift_inner_cols(sel) - - expr_shaped = __data.shape_call(func, window, verb_name = verb_name, arg_name = new_name) - new_col, windows, _ = __data.track_call_windows(expr_shaped, inner_cols) - - if isinstance(new_col, sql.base.ImmutableColumnCollection): - raise TyepError( - f"{verb_name} named arguments must return a single column, but `{k}` " - "returned multiple columns." - ) - - return new_col.label(new_name) - - -def _mutate_cols(__data, args, kwargs, verb_name): - result_names = {} # used as ordered set - sel = __data.last_select - - for ii, func in enumerate(args): - cols_result = _eval_expr_arg(__data, sel, func, verb_name) - - # replace any labels that require a subquery ---- - sel = _select_mutate_result(sel, cols_result) - - if isinstance(cols_result, sql.base.ImmutableColumnCollection): - result_names.update({k: True for k in cols_result.keys()}) - else: - result_names[cols_result.name] = True - - - for new_name, func in kwargs.items(): - labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name) - - sel = _select_mutate_result(sel, labeled) - result_names[new_name] = True - - - return list(result_names), sel - - -@transmute.register(LazyTbl) -def _transmute(__data, *args, **kwargs): - # will use mutate, then select some cols - result_names, sel = _mutate_cols(__data, args, kwargs, "Transmute") - - # transmute keeps grouping cols, and any defined in kwargs - missing = [x for x in __data.group_by if x not in 
result_names] - cols_to_keep = [*missing, *result_names] - - columns = lift_inner_cols(sel) - sel_stripped = sel.with_only_columns([columns[k] for k in cols_to_keep]) - - return __data.append_op(sel_stripped) - - -@arrange.register(LazyTbl) -def _arrange(__data, *args): - # Note that SQL databases often do not subquery order by clauses. Arrange - # sets order_by on the backend, so it can set order by in over elements, - # and handle when new columns are named the same as order by vars. - # see: https://dba.stackexchange.com/q/82930 - - last_sel = __data.last_select - cols = lift_inner_cols(last_sel) - - # TODO: implement across in arrange - #exprs, _ = _mutate_cols(__data, args, kwargs, "Arrange", arrange_clause=True) - - new_calls = [] - for ii, expr in enumerate(args): - if callable(expr): - - res = __data.shape_call( - expr, window = False, - verb_name = "Arrange", arg_name = ii - ) - - else: - res = expr - - new_calls.append(res) - - sort_cols = _create_order_by_clause(cols, *new_calls) - - order_by = __data.order_by + tuple(new_calls) - return __data.append_op(last_sel.order_by(*sort_cols), order_by = order_by) - - -# TODO: consolidate / pull expr handling funcs into own file? -def _create_order_by_clause(columns, *args): - sort_cols = [] - for arg in args: - # simple named column - if isinstance(arg, str): - sort_cols.append(columns[arg]) - # an expression - elif callable(arg): - # handle special case where -_.colname -> colname DESC - f, asc = _call_strip_ascending(arg) - col_op = f(columns) if asc else f(columns).desc() - #col_op = arg(columns) - sort_cols.append(col_op) - else: - raise NotImplementedError("Must be string or callable") - - return sort_cols - - - -@count.register(LazyTbl) -def _count(__data, *args, sort = False, wt = None, **kwargs): - # TODO: if already col named n, use name nn, etc.. 
get logic from tidy.py - if wt is not None: - raise NotImplementedError("TODO") - - res_name = "n" - # similar to filter verb, we need two select statements, - # an inner one for derived cols, and outer to group by them - - # inner select ---- - # holds any mutation style columns - #arg_names = [] - #for arg in args: - # name = simple_varname(arg) - # if name is None: - # raise NotImplementedError( - # "Count positional arguments must be single column name. " - # "Use a named argument to count using complex expressions." - # ) - # arg_names.append(name) - - result_names, sel_inner = _mutate_cols(__data, args, kwargs, "Count") - - # remove unnecessary select, if we're operating on a table ---- - if set(lift_inner_cols(sel_inner)) == set(lift_inner_cols(__data.last_select)): - sel_inner = __data.last_op - - # create outer select ---- - # holds selected columns and tally (n) - sel_inner_cte = sel_inner.alias() - inner_cols = sel_inner_cte.columns - - # apply any group vars from a group_by verb call first - missing = [k for k in __data.group_by if k not in result_names] - - all_group_names = ordered_union(__data.group_by, result_names) - outer_group_cols = [inner_cols[k] for k in all_group_names] - - # holds the actual count (e.g. 
n) - count_col = sql.functions.count().label(res_name) - - sel_outer = _sql_select([*outer_group_cols, count_col]) \ - .select_from(sel_inner_cte) \ - .group_by(*outer_group_cols) - - # count is like summarize, so removes order_by - return __data.append_op( - sel_outer.order_by(count_col.desc()), - order_by = tuple() - ) - - -@add_count.register(LazyTbl) -def _add_count(__data, *args, wt = None, sort = False, **kwargs): - counts = count(__data, *args, wt = wt, sort = sort, **kwargs) - by = list(c.name for c in counts.last_select.inner_columns)[:-1] - - return inner_join(__data, counts, by = by) - - -@summarize.register(LazyTbl) -def _summarize(__data, *args, **kwargs): - # https://stackoverflow.com/questions/14754994/why-is-sqlalchemy-count-much-slower-than-the-raw-query - - # get query with correct from clause, and maybe unneeded subquery - safe_from = __data.last_select.alias() - result_names, sel = _aggregate_cols(__data, safe_from, args, kwargs, "Summarize") - - # see if we can remove subquery - out_sel = _collapse_select(sel, safe_from) - - from_tbl = get_single_from(out_sel) - group_cols = [from_tbl.columns[k] for k in __data.group_by] - - final_sel = out_sel.group_by(*group_cols) - - new_data = __data.append_op(final_sel, group_by = tuple(), order_by = tuple()) - return new_data - - -def _collapse_select(outer_sel, inner_alias): - # check whether any outer columns reference an inner label ---- - inner_sel = inner_alias.element - - columns = lift_inner_cols(outer_sel) - inner_cols = lift_inner_cols(inner_sel) - - inner_labels = set([ - x.name for x in inner_cols - if isinstance(x, sql.elements.Label) - ]) - - col_requires_cte = set(inner_alias.columns[k] for k in inner_labels) - - bad_refs = [] - - def collect_refs(el): - if el in col_requires_cte: - bad_refs.append(el) - - for col in columns: - sql.util.visitors.traverse(col, {}, {"column": collect_refs}) - - # if possible, remove the outer query ---- - if not (bad_refs or len(inner_sel._group_by_clause)): - 
from sqlalchemy.sql.elements import ColumnClause, Label - - from_obj = get_single_from(inner_sel) - adaptor = sql.util.ClauseAdapter( - from_obj, - adapt_on_names=True, - include_fn=lambda c: isinstance(c, (ColumnClause, Label)) - ) - - new_cols = [] - for col in columns: - if isinstance(col, Label): - res = adaptor.traverse(col.element).label(col.name) - new_cols.append(res) - - else: - new_cols.append(adaptor.traverse(col)) - #new_cols = list(map(adaptor.traverse, columns)) - - return _sql_with_only_columns(inner_sel, new_cols) - - return outer_sel - - -def _aggregate_cols(__data, subquery, args, kwargs, verb_name): - # cases: - # * grouping cols can not be overwritten (in dbplyr they can't be ref'd) - # * no existing labels referred to - can use same select - # * existing labels referred to - need 1 subquery tops - # * groups + summarize columns can replace everything - - def get_label_clauses(clause): - out = [] - sql.util.visitors.traverse(clause, {}, {"label": lambda c: out.append(c)}) - - return out - - def quote_varname(x): - return f"`{x}`" - - def validate_references(arg_name, expr, verb_name): - bad_varnames = get_label_clauses(expr) - repr_names = ", ".join(map(quote_varname, [el.name for el in bad_varnames])) - - if not bad_varnames: - return - - raise NotImplementedError( - f"In SQL, you cannot refer to a column created in the same {verb_name}. " - f"`{arg_name}` refers to columns created earlier: {repr_names}." 
- ) - - sel = subquery.select() - - final_cols = {k: subquery.columns[k] for k in __data.group_by} - - # handle args ---- - for ii, func in enumerate(args): - cols_result = _eval_expr_arg(__data, sel, func, verb_name, window=False) - - for col in cols_result: - validate_references(col.name, col.element, verb_name) - final_cols[col.name] = col - - sel = _sql_upsert_columns(sel, cols_result) - - - # handle kwargs ---- - for new_name, func in kwargs.items(): - labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=False) - - validate_references(labeled.name, labeled.element, verb_name) - final_cols[new_name] = labeled - - sel = _sql_upsert_columns(sel, [labeled]) - - return list(final_cols), _sql_with_only_columns(sel, list(final_cols.values())) - - -@group_by.register(LazyTbl) -def _group_by(__data, *args, add = False, **kwargs): - if not (args or kwargs): - return __data.copy() - - group_names, sel = _mutate_cols(__data, args, kwargs, "Group by") - - if None in group_names: - raise NotImplementedError("Complex, unnamed expressions not supported in sql group_by") - - # check whether we can just use underlying table ---- - new_cols = lift_inner_cols(sel) - if set(new_cols).issubset(set(__data.last_op.columns)): - sel = __data.last_op - - if add: - group_names = ordered_union(__data.group_by, group_names) - - return __data.append_op(sel, group_by = tuple(group_names)) - - -@ungroup.register(LazyTbl) -def _ungroup(__data): - return __data.copy(group_by = tuple()) - - -@case_when.register(sql.base.ImmutableColumnCollection) -def _case_when(__data, cases): - # TODO: will need listener to enter case statements, to handle when they use windows - if isinstance(cases, Call): - cases = cases(__data) - - whens = [] - case_items = list(cases.items()) - n_items = len(case_items) - - else_val = None - for ii, (expr, val) in enumerate(case_items): - # handle where val is a column expr - if callable(val): - val = val(__data) - - # handle when expressions - if ii+1 
== n_items and expr is True: - else_val = val - elif callable(expr): - whens.append((expr(__data), val)) - else: - whens.append((expr, val)) - - return sql.case(whens, else_ = else_val) - - -# Join ------------------------------------------------------------------------ - -from collections.abc import Mapping def _joined_cols(left_cols, right_cols, on_keys, how, suffix = ("_x", "_y")): """Return labeled columns, according to selection rules for joins. @@ -1315,81 +231,3 @@ def _create_join_conds(left_sel, right_sel, on): conds.append(col_expr) return sql.and_(*conds) - - -# Head ------------------------------------------------------------------------ - -@head.register(LazyTbl) -def _head(__data, n = 5): - sel = __data.last_select - - return __data.append_op(sel.limit(n)) - - -# Rename ---------------------------------------------------------------------- - -@rename.register(LazyTbl) -def _rename(__data, **kwargs): - sel = __data.last_select - columns = lift_inner_cols(sel) - - # old_keys uses dict as ordered set - old_to_new = {simple_varname(v):k for k,v in kwargs.items()} - - if None in old_to_new: - raise KeyError("positional arguments must be simple column, " - "e.g. 
_.colname or _['colname']" - ) - - labs = [c.label(old_to_new[k]) if k in old_to_new else c for k,c in columns.items()] - - new_sel = sel.with_only_columns(labs) - - missing_groups, group_keys = _select_group_renames(old_to_new, __data.group_by) - - return __data.append_op(new_sel, group_by=group_keys) - - -# Distinct -------------------------------------------------------------------- - -@distinct.register(LazyTbl) -def _distinct(__data, *args, _keep_all = False, **kwargs): - if (args or kwargs) and _keep_all: - raise NotImplementedError("Distinct with variables specified in sql requires _keep_all = False") - - inner_sel = mutate(__data, **kwargs).last_select if kwargs else __data.last_select - - # TODO: this is copied from the df distinct version - # cols dict below is used as ordered set - cols = _var_select_simple(args) - cols.update(kwargs) - - # use all columns by default - if not cols: - cols = {k: True for k in lift_inner_cols(inner_sel).keys()} - - final_names = {**{k: True for k in __data.group_by}, **cols} - - if not len(inner_sel._order_by_clause): - # select distinct has to include any columns in the order by clause, - # so can only safely modify existing statement when there's no order by - sel_cols = lift_inner_cols(inner_sel) - distinct_cols = [sel_cols[k] for k in final_names] - sel = inner_sel.with_only_columns(distinct_cols).distinct() - else: - # fallback to cte - cte = inner_sel.alias() - distinct_cols = [cte.columns[k] for k in final_names] - sel = _sql_select(distinct_cols).select_from(cte).distinct() - - return __data.append_op(sel) - - -# if_else --------------------------------------------------------------------- - -@if_else.register(sql.elements.ColumnElement) -def _if_else(cond, true_vals, false_vals): - whens = [(cond, true_vals)] - return sql.case(whens, else_ = false_vals) - - diff --git a/siuba/sql/verbs/mutate.py b/siuba/sql/verbs/mutate.py index 8dcbef7b..5f4e8da4 100644 --- a/siuba/sql/verbs/mutate.py +++ 
b/siuba/sql/verbs/mutate.py @@ -1,683 +1,19 @@ -""" -Implements LazyTbl to represent tables of SQL data, and registers it on verbs. - -This module is responsible for the handling of the "table" side of things, while -translate.py handles translating column operations. - - -""" - -import warnings - from siuba.dply.verbs import ( - show_query, collect, simple_varname, - select, mutate, transmute, - filter, - arrange, _call_strip_ascending, - summarize, - count, add_count, - group_by, ungroup, - case_when, - join, left_join, right_join, inner_join, semi_join, anti_join, - head, - rename, - distinct, - if_else, - _select_group_renames, - _var_select_simple ) -from siuba.dply.tidyselect import VarList, var_select - -from .translate import CustomOverClause, SqlColumn, SqlColumnAgg -from .utils import ( - get_dialect_translator, - _FixedSqlDatabase, - _is_dialect_duckdb, - _sql_select, - _sql_column_collection, - _sql_add_columns, +from ..backend import LazyTbl, SqlLabelReplacer +from ..utils import ( _sql_with_only_columns, - _sql_simplify_select, - MockConnection + lift_inner_cols ) from sqlalchemy import sql -import sqlalchemy -from siuba.siu import Call, Lazy, FunctionLookupError, singledispatch2 # TODO: currently needed for select, but can we remove pandas? -from pandas import Series -from functools import singledispatch - -from sqlalchemy.sql import schema - -from siuba.dply.across import _require_across, _set_data_context, _eval_with_context - -# TODO: -# - distinct -# - annotate functions using sel.prefix_with("\n/**/\n") ? - - -# Helpers --------------------------------------------------------------------- - -class SqlFunctionLookupError(FunctionLookupError): pass - - -class CallListener: - """Generic listener. 
Each exit is called on a node's copy.""" - def enter(self, node): - args, kwargs = node.map_subcalls(self.enter) - - return self.exit(node.__class__(node.func, *args, **kwargs)) - - def exit(self, node): - return node - - -class WindowReplacer(CallListener): - """Call tree listener. - - Produces 2 important behaviors via the enter method: - - returns evaluated sql call expression, with labels on all window expressions. - - stores all labeled window expressions via the windows property. - - TODO: could replace with a sqlalchemy transformer - """ - - def __init__(self, columns, group_by, order_by, window_cte = None): - self.columns = columns - self.group_by = group_by - self.order_by = order_by - self.window_cte = window_cte - self.windows = [] - - def exit(self, node): - col_expr = node(self.columns) - - if not isinstance(col_expr, sql.elements.ClauseElement): - return col_expr - - over_clauses = [x for x in self._get_over_clauses(col_expr) if isinstance(x, CustomOverClause)] - - # put groupings and orderings onto custom over clauses - for over in over_clauses: - # TODO: shouldn't mutate these over clauses - group_by = sql.elements.ClauseList( - *[self.columns[name] for name in self.group_by] - ) - order_by = sql.elements.ClauseList( - *_create_order_by_clause(self.columns, *self.order_by) - ) - - over.set_over(group_by, order_by) - - if len(over_clauses) and self.window_cte is not None: - # custom name, or parameters like "%(...)s" may nest and break psycopg2 - # with columns you can set a key to fix this, but it doesn't seem to - # be an option with labels - name = self._get_unique_name('win', lift_inner_cols(self.window_cte)) - label = col_expr.label(name) - - # put into CTE, and return its resulting column, so that subsequent - # operations will refer to the window column on window_cte. 
Note that - # the operations will use the actual column, so may need to use the - # ClauseAdaptor to make it a reference to the label - self.window_cte = _sql_add_columns(self.window_cte, [label]) - win_col = lift_inner_cols(self.window_cte).values()[-1] - self.windows.append(win_col) - - return win_col - - return col_expr - - @staticmethod - def _get_unique_name(prefix, columns): - column_names = set(columns.keys()) - - i = 1 - name = prefix + str(i) - while name in column_names: - i += 1 - name = prefix + str(i) - - - return name - - @staticmethod - def _get_over_clauses(clause): - windows = [] - append_win = lambda col: windows.append(col) - - sql.util.visitors.traverse(clause, {}, {"over": append_win}) - - return windows - - -class SqlLabelReplacer: - """Create a visitor to replace source labels with destination. - - Note that this is meant to be used with sqlalchemy visitors. - """ - - def __init__(self, src_columns, dst_columns): - self.src_columns = src_columns - self.src_labels = set([x for x in src_columns if isinstance(x, sql.elements.Label)]) - self.dst_columns = dst_columns - self.applied = False - - def __call__(self, clause): - return sql.util.visitors.replacement_traverse(clause, {}, self.visit) - - def visit(self, el): - from sqlalchemy.sql.elements import ColumnClause, Label, ClauseElement, TypeClause - from sqlalchemy.sql.schema import Column - - if isinstance(el, TypeClause): - # TODO: for some reason this type throws an error if unguarded - return None - - if isinstance(el, ClauseElement): - if el in self.src_labels: - self.applied = True - return self.dst_columns[el.name] - elif el in self.src_columns: - return self.dst_columns[el.name] - - # TODO: should we create a subquery if the user passed raw text? 
- #elif isinstance(el, ColumnClause) and not isinstance(el, Column): - # # Raw SQL, which will need a subquery, but not substitution - # if el.key != "*": - # self.applied = True - - return None - - -#def track_call_windows(call, columns, group_by, order_by, window_cte = None): -# listener = WindowReplacer(columns, group_by, order_by, window_cte) -# col = listener.enter(call) -# return col, listener.windows, listener.window_cte - - -def track_call_windows(call, columns, group_by, order_by, window_cte = None): - col_expr = call(columns) - - crnt_group_by = sql.elements.ClauseList( - *[columns[name] for name in group_by] - ) - crnt_order_by = sql.elements.ClauseList( - *_create_order_by_clause(columns, *order_by) - ) - return replace_call_windows(col_expr, crnt_group_by, crnt_order_by, window_cte) - - - -@singledispatch -def replace_call_windows(col_expr, group_by, order_by, window_cte = None): - raise TypeError(str(type(col_expr))) - - -@replace_call_windows.register(sql.base.ImmutableColumnCollection) -def _(col_expr, group_by, order_by, window_cte = None): - all_over_clauses = [] - for col in col_expr: - _, over_clauses, window_cte = replace_call_windows( - col, - group_by, - order_by, - window_cte - ) - all_over_clauses.extend(over_clauses) - - return col_expr, all_over_clauses, window_cte - - -@replace_call_windows.register(sql.elements.ClauseElement) -def _(col_expr, group_by, order_by, window_cte = None): - - over_clauses = WindowReplacer._get_over_clauses(col_expr) - - for over in over_clauses: - # TODO: shouldn't mutate these over clauses - over.set_over(group_by, order_by) - - if len(over_clauses) and window_cte is not None: - # custom name, or parameters like "%(...)s" may nest and break psycopg2 - # with columns you can set a key to fix this, but it doesn't seem to - # be an option with labels - name = WindowReplacer._get_unique_name('win', lift_inner_cols(window_cte)) - label = col_expr.label(name) - - # put into CTE, and return its resulting column, so 
that subsequent - # operations will refer to the window column on window_cte. Note that - # the operations will use the actual column, so may need to use the - # ClauseAdaptor to make it a reference to the label - window_cte = _sql_add_columns(window_cte, [label]) - win_col = lift_inner_cols(window_cte).values()[-1] - - return win_col, over_clauses, window_cte - - return col_expr, over_clauses, window_cte - -def get_single_from(sel): - froms = sel.froms - - n_froms = len(froms) - if n_froms != 1: - raise ValueError( - f"Expected a single table in the from clause, but found {n_froms}" - ) - - return froms[0] - -def lift_inner_cols(tbl): - cols = list(tbl.inner_columns) - - return _sql_column_collection(cols) - -def col_expr_requires_cte(call, sel, is_mutate = False): - """Return whether a variable assignment needs a CTE""" - - call_vars = set(call.op_vars(attr_calls = False)) - - sel_labs = get_inner_labels(sel) - - # I use the acronym fwg sol (frog soul) to remember sql clause eval order - # from, where, group by, select, order by, limit - # group clause evaluated before select clause, so not issue for mutate - group_needs_cte = not is_mutate and len(sel._group_by_clause) - - return ( group_needs_cte - # TODO: detect when a new var in mutate conflicts w/ order by - #or len(sel._order_by_clause) - or not sel_labs.isdisjoint(call_vars) - ) - -def get_inner_labels(sel): - columns = lift_inner_cols(sel) - sel_labs = set(k for k,v in columns.items() if isinstance(v, sql.elements.Label)) - return sel_labs - -def get_missing_columns(call, columns): - missing_cols = set(call.op_vars(attr_calls = False)) - set(columns.keys()) - return missing_cols - -def compile_el(tbl, el): - compiled = el.compile( - dialect = tbl.source.dialect, - compile_kwargs = {"literal_binds": True} - ) - return compiled - -# Misc utilities -------------------------------------------------------------- - -def ordered_union(x, y): - dx = {el: True for el in x} - dy = {el: True for el in y} - - return 
tuple({**dx, **dy}) - - -def _warn_missing(missing_groups): - warnings.warn(f"Adding missing grouping variables: {missing_groups}") - - -# Table ----------------------------------------------------------------------- - -class LazyTbl: - def __init__( - self, source, tbl, columns = None, - ops = None, group_by = tuple(), order_by = tuple(), - translator = None - ): - """Create a representation of a SQL table. - - Args: - source: a sqlalchemy.Engine or sqlalchemy.Connection instance. - tbl: table of form 'schema_name.table_name', 'table_name', or sqlalchemy.Table. - columns: if specified, a listlike of column names. - - Examples - -------- - - :: - from sqlalchemy import create_engine - from siuba.data import mtcars - - # create database and table - engine = create_engine("sqlite:///:memory:") - mtcars.to_sql('mtcars', engine) - tbl_mtcars = LazyTbl(engine, 'mtcars') - - """ - - # connection and dialect specific functions - self.source = sqlalchemy.create_engine(source) if isinstance(source, str) else source - - # get dialect name - dialect = self.source.dialect.name - self.translator = get_dialect_translator(dialect) - - self.tbl = self._create_table(tbl, columns, self.source) - - # important states the query can be in (e.g. 
grouped) - self.ops = [self.tbl] if ops is None else ops - - self.group_by = group_by - self.order_by = order_by - - - def append_op(self, op, **kwargs): - cpy = self.copy(**kwargs) - cpy.ops = cpy.ops + [op] - return cpy - - def copy(self, **kwargs): - return self.__class__(**{**self.__dict__, **kwargs}) - - def shape_call( - self, - call, window = True, str_accessors = False, - verb_name = None, arg_name = None, - ): - return self.translator.shape_call(call, window, str_accessors, verb_name, arg_name) - - def track_call_windows(self, call, columns = None, window_cte = None): - """Returns tuple of (new column expression, list of window exprs)""" - - columns = self.last_op.columns if columns is None else columns - return track_call_windows(call, columns, self.group_by, self.order_by, window_cte) - - def get_ordered_col_names(self): - """Return columns from current select, with grouping columns first.""" - ungrouped = [k for k in self.last_op.columns.keys() if k not in self.group_by] - return list(self.group_by) + ungrouped - - #def label_breaks_order_by(self, name): - # """Returns True if a new column label would break the order by vars.""" - - # # TODO: arrange currently allows literals, which breaks this. it seems - # # better to only allow calls in arrange. - # order_by_vars = {c.op_vars(attr_calls=False) for c in self.order_by} - - - - - @property - def last_op(self) -> "sql.Table | sql.Select": - last_op = self.ops[-1] - - if last_op is None: - raise TypeError() - - return last_op - - @property - def last_select(self): - last_op = self.last_op - if not isinstance(last_op, sql.selectable.SelectBase): - return last_op.select() - - return last_op - - @staticmethod - def _create_table(tbl, columns = None, source = None): - """Return a sqlalchemy.Table, autoloading column info if needed. - - Arguments: - tbl: a sqlalchemy.Table or string of form 'table_name' or 'schema_name.table_name'. - columns: a tuple of column names for the table. Overrides source argument. 
- source: a sqlalchemy engine, used to autoload columns. - - """ - if isinstance(tbl, sql.selectable.FromClause): - return tbl - - if not isinstance(tbl, str): - raise ValueError("tbl must be a sqlalchemy Table or string, but was %s" %type(tbl)) - - if columns is None and source is None: - raise ValueError("One of columns or source must be specified") - - schema, table_name = tbl.split('.') if '.' in tbl else [None, tbl] - - columns = map(sqlalchemy.Column, columns) if columns is not None else tuple() - - # TODO: pybigquery uses schema to mean project_id, so we cannot use - # siuba's classic breakdown "{schema}.{table_name}". Basically - # pybigquery uses "{schema=project_id}.{dataset_dot_table_name}" in its internal - # logic. An important side effect is that bigquery errors for - # `dataset`.`table`, but not `dataset.table`. - if source and source.dialect.name == "bigquery": - table_name = tbl - schema = None - - return sqlalchemy.Table( - table_name, - sqlalchemy.MetaData(bind = source), - *columns, - schema = schema, - autoload_with = source if not columns else None - ) - - def _get_preview(self): - # need to make prev op a cte, so we don't override any previous limit - new_sel = self.last_select.limit(5) - tbl_small = self.append_op(new_sel) - return collect(tbl_small) - - def __repr__(self): - template = ( - "# Source: lazy query\n" - "# DB Conn: {}\n" - "# Preview:\n{}\n" - "# .. may have more rows" - ) - - return template.format(repr(self.source.engine), repr(self._get_preview())) - - def _repr_html_(self): - template = ( - "
" - "
"
-                "# Source: lazy query\n"
-                "# DB Conn: {}\n"
-                "# Preview:\n"
-                "
" - "{}" - "

# .. may have more rows

" - "
" - ) - - data = self._get_preview() - - # _repr_html_ can not exist or return None, to signify that repr should be used - if not hasattr(data, '_repr_html_'): - return None - - html_data = data._repr_html_() - if html_data is None: - return None - - return template.format(self.source.engine, html_data) - - -def _repr_grouped_df_html_(self): - return "

(grouped data frame)

" + self._selected_obj._repr_html_() + "
" - - - -# Main Funcs -# ============================================================================= - -# sql raw -------------- - -sql_raw = sql.literal_column - -# show query ----------- - -@show_query.register(LazyTbl) -def _show_query(tbl, simplify = False, return_table = True): - #query = tbl.last_op #if not simplify else - compile_query = lambda query: query.compile( - dialect = tbl.source.dialect, - compile_kwargs = {"literal_binds": True} - ) - - - if simplify: - # try to strip table names and labels where unnecessary - simple_sel = _sql_simplify_select(tbl.last_select) - - explained = compile_query(simple_sel) - else: - # use a much more verbose query - explained = compile_query(tbl.last_select) - - if return_table: - print(str(explained)) - return tbl - - return str(explained) - - - -# collect ---------- - -@collect.register(LazyTbl) -def _collect(__data, as_df = True): - # TODO: maybe remove as_df options, always return dataframe - - if isinstance(__data.source, MockConnection): - # a mock sqlalchemy is being used to show_query, and echo queries. - # it doesn't return a result object or have a context handler, so - # we need to bail out early - return - - # compile query ---- - - if _is_dialect_duckdb(__data.source): - # TODO: can be removed once next release of duckdb fixes: - # https://github.com/duckdb/duckdb/issues/2972 - query = __data.last_select - compiled = query.compile( - dialect = __data.source.dialect, - compile_kwargs = {"literal_binds": True} - ) - else: - compiled = __data.last_select - - # execute query ---- - - with __data.source.connect() as conn: - if as_df: - sql_db = _FixedSqlDatabase(conn) - - if _is_dialect_duckdb(__data.source): - # TODO: pandas read_sql is very slow with duckdb. - # see https://github.com/pandas-dev/pandas/issues/45678 - # going to handle here for now. address once LazyTbl gets - # subclassed per backend. 
- duckdb_con = conn.connection.c - return duckdb_con.query(str(compiled)).to_df() - else: - # - return sql_db.read_sql(compiled) - - return conn.execute(compiled) - - -@select.register(LazyTbl) -def _select(__data, *args, **kwargs): - # see https://stackoverflow.com/questions/25914329/rearrange-columns-in-sqlalchemy-select-object - if kwargs: - raise NotImplementedError( - "Using kwargs in select not currently supported. " - "Use _.newname == _.oldname instead" - ) - last_sel = __data.last_select - columns = {c.key: c for c in last_sel.inner_columns} - - # same as for DataFrame - colnames = Series(list(columns)) - vl = VarList() - evaluated = (arg(vl) if callable(arg) else arg for arg in args) - od = var_select(colnames, *evaluated) - - missing_groups, group_keys = _select_group_renames(od, __data.group_by) - - if missing_groups: - _warn_missing(missing_groups) - - final_od = {**{k: None for k in missing_groups}, **od} - - col_list = [] - for k,v in final_od.items(): - col = columns[k] - col_list.append(col if v is None else col.label(v)) - - return __data.append_op( - last_sel.with_only_columns(col_list), - group_by = group_keys - ) - - - -@filter.register(LazyTbl) -def _filter(__data, *args): - # Note: currently always produces 2 additional select statements, - # 1 for window/aggs, and 1 for the where clause - - sel = __data.last_op.alias() # original select - win_sel = sel.select() - - conds = [] - windows = [] - with _set_data_context(__data, window=True): - for ii, arg in enumerate(args): - - if isinstance(arg, Call): - new_call = __data.shape_call(arg, verb_name = "Filter", arg_name = ii) - #var_cols = new_call.op_vars(attr_calls = False) - - # note that a new win_sel is returned, w/ window columns appended - col_expr, win_cols, win_sel = __data.track_call_windows( - new_call, - sel.columns, - window_cte = win_sel - ) - - if isinstance(col_expr, sql.base.ImmutableColumnCollection): - conds.extend(col_expr) - else: - conds.append(col_expr) - - 
windows.extend(win_cols) - - else: - conds.append(arg) - - bool_clause = sql.and_(*conds) - - # first cte, windows ---- - if len(windows): - - win_alias = win_sel.alias() - - # move non-window functions to refer to win_sel clause (not the innermost) --- - bool_clause = sql.util.ClauseAdapter(win_alias) \ - .traverse(bool_clause) - - orig_cols = [win_alias.columns[k] for k in sel.columns.keys()] - else: - orig_cols = [sel] - - # create second cte ---- - filt_sel = _sql_select(orig_cols).where(bool_clause) - return __data.append_op(filt_sel) +from siuba.dply.across import _require_across, _eval_with_context @mutate.register(LazyTbl) @@ -747,8 +83,8 @@ def _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=True): new_col, windows, _ = __data.track_call_windows(expr_shaped, inner_cols) if isinstance(new_col, sql.base.ImmutableColumnCollection): - raise TyepError( - f"{verb_name} named arguments must return a single column, but `{k}` " + raise TypeError( + f"{verb_name} named arguments must return a single column, but `{new_name}` " "returned multiple columns." ) @@ -794,602 +130,3 @@ def _transmute(__data, *args, **kwargs): sel_stripped = sel.with_only_columns([columns[k] for k in cols_to_keep]) return __data.append_op(sel_stripped) - - -@arrange.register(LazyTbl) -def _arrange(__data, *args): - # Note that SQL databases often do not subquery order by clauses. Arrange - # sets order_by on the backend, so it can set order by in over elements, - # and handle when new columns are named the same as order by vars. 
- # see: https://dba.stackexchange.com/q/82930 - - last_sel = __data.last_select - cols = lift_inner_cols(last_sel) - - # TODO: implement across in arrange - #exprs, _ = _mutate_cols(__data, args, kwargs, "Arrange", arrange_clause=True) - - new_calls = [] - for ii, expr in enumerate(args): - if callable(expr): - - res = __data.shape_call( - expr, window = False, - verb_name = "Arrange", arg_name = ii - ) - - else: - res = expr - - new_calls.append(res) - - sort_cols = _create_order_by_clause(cols, *new_calls) - - order_by = __data.order_by + tuple(new_calls) - return __data.append_op(last_sel.order_by(*sort_cols), order_by = order_by) - - -# TODO: consolidate / pull expr handling funcs into own file? -def _create_order_by_clause(columns, *args): - sort_cols = [] - for arg in args: - # simple named column - if isinstance(arg, str): - sort_cols.append(columns[arg]) - # an expression - elif callable(arg): - # handle special case where -_.colname -> colname DESC - f, asc = _call_strip_ascending(arg) - col_op = f(columns) if asc else f(columns).desc() - #col_op = arg(columns) - sort_cols.append(col_op) - else: - raise NotImplementedError("Must be string or callable") - - return sort_cols - - - -@count.register(LazyTbl) -def _count(__data, *args, sort = False, wt = None, **kwargs): - # TODO: if already col named n, use name nn, etc.. get logic from tidy.py - if wt is not None: - raise NotImplementedError("TODO") - - res_name = "n" - # similar to filter verb, we need two select statements, - # an inner one for derived cols, and outer to group by them - - # inner select ---- - # holds any mutation style columns - #arg_names = [] - #for arg in args: - # name = simple_varname(arg) - # if name is None: - # raise NotImplementedError( - # "Count positional arguments must be single column name. " - # "Use a named argument to count using complex expressions." 
- # ) - # arg_names.append(name) - - result_names, sel_inner = _mutate_cols(__data, args, kwargs, "Count") - - # remove unnecessary select, if we're operating on a table ---- - if set(lift_inner_cols(sel_inner)) == set(lift_inner_cols(__data.last_select)): - sel_inner = __data.last_op - - # create outer select ---- - # holds selected columns and tally (n) - sel_inner_cte = sel_inner.alias() - inner_cols = sel_inner_cte.columns - - # apply any group vars from a group_by verb call first - missing = [k for k in __data.group_by if k not in result_names] - - all_group_names = ordered_union(__data.group_by, result_names) - outer_group_cols = [inner_cols[k] for k in all_group_names] - - # holds the actual count (e.g. n) - count_col = sql.functions.count().label(res_name) - - sel_outer = _sql_select([*outer_group_cols, count_col]) \ - .select_from(sel_inner_cte) \ - .group_by(*outer_group_cols) - - # count is like summarize, so removes order_by - return __data.append_op( - sel_outer.order_by(count_col.desc()), - order_by = tuple() - ) - - -@add_count.register(LazyTbl) -def _add_count(__data, *args, wt = None, sort = False, **kwargs): - counts = count(__data, *args, wt = wt, sort = sort, **kwargs) - by = list(c.name for c in counts.last_select.inner_columns)[:-1] - - return inner_join(__data, counts, by = by) - - -@summarize.register(LazyTbl) -def _summarize(__data, *args, **kwargs): - # https://stackoverflow.com/questions/14754994/why-is-sqlalchemy-count-much-slower-than-the-raw-query - - # get query with correct from clause, and maybe unneeded subquery - safe_from = __data.last_select.alias() - result_names, sel = _aggregate_cols(__data, safe_from, args, kwargs, "Summarize") - - # see if we can remove subquery - out_sel = _collapse_select(sel, safe_from) - - from_tbl = get_single_from(out_sel) - group_cols = [from_tbl.columns[k] for k in __data.group_by] - - final_sel = out_sel.group_by(*group_cols) - - new_data = __data.append_op(final_sel, group_by = tuple(), order_by = 
tuple()) - return new_data - - -def _collapse_select(outer_sel, inner_alias): - # check whether any outer columns reference an inner label ---- - inner_sel = inner_alias.element - - columns = lift_inner_cols(outer_sel) - inner_cols = lift_inner_cols(inner_sel) - - inner_labels = set([ - x.name for x in inner_cols - if isinstance(x, sql.elements.Label) - ]) - - col_requires_cte = set(inner_alias.columns[k] for k in inner_labels) - - bad_refs = [] - - def collect_refs(el): - if el in col_requires_cte: - bad_refs.append(el) - - for col in columns: - sql.util.visitors.traverse(col, {}, {"column": collect_refs}) - - # if possible, remove the outer query ---- - if not (bad_refs or len(inner_sel._group_by_clause)): - from sqlalchemy.sql.elements import ColumnClause, Label - - from_obj = get_single_from(inner_sel) - adaptor = sql.util.ClauseAdapter( - from_obj, - adapt_on_names=True, - include_fn=lambda c: isinstance(c, (ColumnClause, Label)) - ) - - new_cols = [] - for col in columns: - if isinstance(col, Label): - res = adaptor.traverse(col.element).label(col.name) - new_cols.append(res) - - else: - new_cols.append(adaptor.traverse(col)) - #new_cols = list(map(adaptor.traverse, columns)) - - return _sql_with_only_columns(inner_sel, new_cols) - - return outer_sel - - -def _aggregate_cols(__data, subquery, args, kwargs, verb_name): - # cases: - # * grouping cols can not be overwritten (in dbplyr they can't be ref'd) - # * no existing labels referred to - can use same select - # * existing labels referred to - need 1 subquery tops - # * groups + summarize columns can replace everything - - def get_label_clauses(clause): - out = [] - sql.util.visitors.traverse(clause, {}, {"label": lambda c: out.append(c)}) - - return out - - def quote_varname(x): - return f"`{x}`" - - def validate_references(arg_name, expr, verb_name): - bad_varnames = get_label_clauses(expr) - repr_names = ", ".join(map(quote_varname, [el.name for el in bad_varnames])) - - if not bad_varnames: - return - - 
raise NotImplementedError( - f"In SQL, you cannot refer to a column created in the same {verb_name}. " - f"`{arg_name}` refers to columns created earlier: {repr_names}." - ) - - sel = subquery.select() - - final_cols = {k: subquery.columns[k] for k in __data.group_by} - - # handle args ---- - for ii, func in enumerate(args): - cols_result = _eval_expr_arg(__data, sel, func, verb_name, window=False) - - for col in cols_result: - validate_references(col.name, col.element, verb_name) - final_cols[col.name] = col - - sel = _sql_upsert_columns(sel, cols_result) - - - # handle kwargs ---- - for new_name, func in kwargs.items(): - labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=False) - - validate_references(labeled.name, labeled.element, verb_name) - final_cols[new_name] = labeled - - sel = _sql_upsert_columns(sel, [labeled]) - - return list(final_cols), _sql_with_only_columns(sel, list(final_cols.values())) - - -@group_by.register(LazyTbl) -def _group_by(__data, *args, add = False, **kwargs): - if not (args or kwargs): - return __data.copy() - - group_names, sel = _mutate_cols(__data, args, kwargs, "Group by") - - if None in group_names: - raise NotImplementedError("Complex, unnamed expressions not supported in sql group_by") - - # check whether we can just use underlying table ---- - new_cols = lift_inner_cols(sel) - if set(new_cols).issubset(set(__data.last_op.columns)): - sel = __data.last_op - - if add: - group_names = ordered_union(__data.group_by, group_names) - - return __data.append_op(sel, group_by = tuple(group_names)) - - -@ungroup.register(LazyTbl) -def _ungroup(__data): - return __data.copy(group_by = tuple()) - - -@case_when.register(sql.base.ImmutableColumnCollection) -def _case_when(__data, cases): - # TODO: will need listener to enter case statements, to handle when they use windows - if isinstance(cases, Call): - cases = cases(__data) - - whens = [] - case_items = list(cases.items()) - n_items = len(case_items) - - else_val = 
None - for ii, (expr, val) in enumerate(case_items): - # handle where val is a column expr - if callable(val): - val = val(__data) - - # handle when expressions - if ii+1 == n_items and expr is True: - else_val = val - elif callable(expr): - whens.append((expr(__data), val)) - else: - whens.append((expr, val)) - - return sql.case(whens, else_ = else_val) - - -# Join ------------------------------------------------------------------------ - -from collections.abc import Mapping - -def _joined_cols(left_cols, right_cols, on_keys, how, suffix = ("_x", "_y")): - """Return labeled columns, according to selection rules for joins. - - Rules: - 1. For join keys, keep left table's column - 2. When keys have the same labels, add suffix - """ - - # TODO: remove sets, so uses stable ordering - # when left and right cols have same name, suffix with _x / _y - keep_right = set(right_cols.keys()) - set(on_keys.values()) - shared_labs = set(left_cols.keys()).intersection(keep_right) - - right_cols_no_keys = {k: right_cols[k] for k in keep_right} - - # for an outer join, have key columns coalesce values - - left_cols = {**left_cols} - if how == "full": - for lk, rk in on_keys.items(): - col = sql.functions.coalesce(left_cols[lk], right_cols[rk]) - left_cols[lk] = col.label(lk) - elif how == "right": - for lk, rk in on_keys.items(): - # Make left key columns actually be right ones (which contain left + extra) - left_cols[lk] = right_cols[rk].label(lk) - - - # create labels ---- - l_labs = _relabeled_cols(left_cols, shared_labs, suffix[0]) - r_labs = _relabeled_cols(right_cols_no_keys, shared_labs, suffix[1]) - - return l_labs + r_labs - - - -def _relabeled_cols(columns, keys, suffix): - # add a suffix to all columns with names in keys - cols = [] - for k, v in columns.items(): - new_col = v.label(k + str(suffix)) if k in keys else v - cols.append(new_col) - return cols - - -@join.register(LazyTbl) -def _join(left, right, on = None, *args, by = None, how = "inner", sql_on = None): - 
_raise_if_args(args) - - if on is None and by is not None: - on = by - - # Needs to be on the table, not the select - left_sel = left.last_op.alias() - right_sel = right.last_op.alias() - - # handle arguments ---- - on = _validate_join_arg_on(on, sql_on) - how = _validate_join_arg_how(how) - - # for equality join used to combine keys into single column - consolidate_keys = on if sql_on is None else {} - - if how == "right": - # switch joins, since sqlalchemy doesn't have right join arg - # see https://stackoverflow.com/q/11400307/1144523 - left_sel, right_sel = right_sel, left_sel - on = {v:k for k,v in on.items()} - - # create join conditions ---- - bool_clause = _create_join_conds(left_sel, right_sel, on) - - # create join ---- - join = left_sel.join( - right_sel, - onclause = bool_clause, - isouter = how != "inner", - full = how == "full" - ) - - # if right join, set selects back - if how == "right": - left_sel, right_sel = right_sel, left_sel - on = {v:k for k,v in on.items()} - - # note, shared_keys assumes on is a mapping... 
- # TODO: shared_keys appears to be for when on is not specified, but was unused - #shared_keys = [k for k,v in on.items() if k == v] - labeled_cols = _joined_cols( - left_sel.columns, - right_sel.columns, - on_keys = consolidate_keys, - how = how - ) - - sel = _sql_select(labeled_cols).select_from(join) - return left.append_op(sel, order_by = tuple()) - - -@semi_join.register(LazyTbl) -def _semi_join(left, right = None, on = None, *args, by = None, sql_on = None): - if on is None and by is not None: - on = by - - _raise_if_args(args) - - left_sel = left.last_op.alias() - right_sel = right.last_op.alias() - - # handle arguments ---- - on = _validate_join_arg_on(on, sql_on, left_sel, right_sel) - - # create join conditions ---- - bool_clause = _create_join_conds(left_sel, right_sel, on) - - # create inner join ---- - exists_clause = _sql_select([sql.literal(1)]) \ - .select_from(right_sel) \ - .where(bool_clause) - - # only keep left hand select's columns ---- - sel = _sql_select(left_sel.columns) \ - .select_from(left_sel) \ - .where(sql.exists(exists_clause)) - - return left.append_op(sel, order_by = tuple()) - - -@anti_join.register(LazyTbl) -def _anti_join(left, right = None, on = None, *args, by = None, sql_on = None): - if on is None and by is not None: - on = by - - _raise_if_args(args) - - left_sel = left.last_op.alias() - right_sel = right.last_op.alias() - - # handle arguments ---- - on = _validate_join_arg_on(on, sql_on, left, right) - - # create join conditions ---- - bool_clause = _create_join_conds(left_sel, right_sel, on) - - # create inner join ---- - #not_exists = ~sql.exists([1], from_obj = right_sel).where(bool_clause) - exists_clause = _sql_select([sql.literal(1)]) \ - .select_from(right_sel) \ - .where(bool_clause) - - sel = left_sel.select().where(~sql.exists(exists_clause)) - - return left.append_op(sel, order_by = tuple()) - -def _raise_if_args(args): - if len(args): - raise NotImplemented("*args is reserved for future arguments (e.g. 
suffix)") - -def _validate_join_arg_on(on, sql_on = None, lhs = None, rhs = None): - # handle sql on case - if sql_on is not None: - if on is not None: - raise ValueError("Cannot specify both on and sql_on") - - return sql_on - - # handle general cases - if on is None: - # TODO: currently, we check for lhs and rhs tables to indicate whether - # a verb supports inferring columns. Otherwise, raise an error. - if lhs is not None and rhs is not None: - # TODO: consolidate with duplicate logic in pandas verb code - warnings.warn( - "No on column passed to join. " - "Inferring join columns instead using shared column names." - ) - - on_cols = list(set(lhs.columns.keys()).intersection(set(rhs.columns.keys()))) - - if not on_cols: - raise ValueError( - "No join column specified, or shared column names in join." - ) - - # trivial dict mapping shared names to themselves - warnings.warn("Detected shared columns: %s" % on_cols) - on = dict(zip(on_cols, on_cols)) - - else: - raise NotImplementedError("on arg currently cannot be None (default) for SQL") - elif isinstance(on, str): - on = {on: on} - elif isinstance(on, (list, tuple)): - on = dict(zip(on, on)) - - - if not isinstance(on, Mapping): - raise TypeError("on must be a Mapping (e.g. 
dict)") - - return on - -def _validate_join_arg_how(how): - how_options = ("inner", "left", "right", "full") - if how not in how_options: - raise ValueError("how argument needs to be one of %s" %how_options) - - return how - -def _create_join_conds(left_sel, right_sel, on): - left_cols = left_sel.columns #lift_inner_cols(left_sel) - right_cols = right_sel.columns #lift_inner_cols(right_sel) - - if callable(on): - # callable, like with sql_on arg - conds = [on(left_cols, right_cols)] - else: - # dict-like of form {left: right} - conds = [] - for l, r in on.items(): - col_expr = left_cols[l] == right_cols[r] - conds.append(col_expr) - - return sql.and_(*conds) - - -# Head ------------------------------------------------------------------------ - -@head.register(LazyTbl) -def _head(__data, n = 5): - sel = __data.last_select - - return __data.append_op(sel.limit(n)) - - -# Rename ---------------------------------------------------------------------- - -@rename.register(LazyTbl) -def _rename(__data, **kwargs): - sel = __data.last_select - columns = lift_inner_cols(sel) - - # old_keys uses dict as ordered set - old_to_new = {simple_varname(v):k for k,v in kwargs.items()} - - if None in old_to_new: - raise KeyError("positional arguments must be simple column, " - "e.g. 
_.colname or _['colname']" - ) - - labs = [c.label(old_to_new[k]) if k in old_to_new else c for k,c in columns.items()] - - new_sel = sel.with_only_columns(labs) - - missing_groups, group_keys = _select_group_renames(old_to_new, __data.group_by) - - return __data.append_op(new_sel, group_by=group_keys) - - -# Distinct -------------------------------------------------------------------- - -@distinct.register(LazyTbl) -def _distinct(__data, *args, _keep_all = False, **kwargs): - if (args or kwargs) and _keep_all: - raise NotImplementedError("Distinct with variables specified in sql requires _keep_all = False") - - inner_sel = mutate(__data, **kwargs).last_select if kwargs else __data.last_select - - # TODO: this is copied from the df distinct version - # cols dict below is used as ordered set - cols = _var_select_simple(args) - cols.update(kwargs) - - # use all columns by default - if not cols: - cols = {k: True for k in lift_inner_cols(inner_sel).keys()} - - final_names = {**{k: True for k in __data.group_by}, **cols} - - if not len(inner_sel._order_by_clause): - # select distinct has to include any columns in the order by clause, - # so can only safely modify existing statement when there's no order by - sel_cols = lift_inner_cols(inner_sel) - distinct_cols = [sel_cols[k] for k in final_names] - sel = inner_sel.with_only_columns(distinct_cols).distinct() - else: - # fallback to cte - cte = inner_sel.alias() - distinct_cols = [cte.columns[k] for k in final_names] - sel = _sql_select(distinct_cols).select_from(cte).distinct() - - return __data.append_op(sel) - - -# if_else --------------------------------------------------------------------- - -@if_else.register(sql.elements.ColumnElement) -def _if_else(cond, true_vals, false_vals): - whens = [(cond, true_vals)] - return sql.case(whens, else_ = false_vals) - - diff --git a/siuba/sql/verbs/select.py b/siuba/sql/verbs/select.py index 8dcbef7b..dafb293f 100644 --- a/siuba/sql/verbs/select.py +++ 
b/siuba/sql/verbs/select.py @@ -1,593 +1,11 @@ -""" -Implements LazyTbl to represent tables of SQL data, and registers it on verbs. - -This module is responsible for the handling of the "table" side of things, while -translate.py handles translating column operations. - - -""" - -import warnings - -from siuba.dply.verbs import ( - show_query, collect, - simple_varname, - select, - mutate, - transmute, - filter, - arrange, _call_strip_ascending, - summarize, - count, add_count, - group_by, ungroup, - case_when, - join, left_join, right_join, inner_join, semi_join, anti_join, - head, - rename, - distinct, - if_else, - _select_group_renames, - _var_select_simple - ) - +from siuba.dply.verbs import select, rename, _select_group_renames from siuba.dply.tidyselect import VarList, var_select +from siuba.dply.verbs import simple_varname -from .translate import CustomOverClause, SqlColumn, SqlColumnAgg -from .utils import ( - get_dialect_translator, - _FixedSqlDatabase, - _is_dialect_duckdb, - _sql_select, - _sql_column_collection, - _sql_add_columns, - _sql_with_only_columns, - _sql_simplify_select, - MockConnection -) - -from sqlalchemy import sql -import sqlalchemy -from siuba.siu import Call, Lazy, FunctionLookupError, singledispatch2 -# TODO: currently needed for select, but can we remove pandas? from pandas import Series -from functools import singledispatch - -from sqlalchemy.sql import schema - -from siuba.dply.across import _require_across, _set_data_context, _eval_with_context - -# TODO: -# - distinct -# - annotate functions using sel.prefix_with("\n/**/\n") ? - - -# Helpers --------------------------------------------------------------------- - -class SqlFunctionLookupError(FunctionLookupError): pass - - -class CallListener: - """Generic listener. 
Each exit is called on a node's copy.""" - def enter(self, node): - args, kwargs = node.map_subcalls(self.enter) - - return self.exit(node.__class__(node.func, *args, **kwargs)) - - def exit(self, node): - return node - - -class WindowReplacer(CallListener): - """Call tree listener. - - Produces 2 important behaviors via the enter method: - - returns evaluated sql call expression, with labels on all window expressions. - - stores all labeled window expressions via the windows property. - - TODO: could replace with a sqlalchemy transformer - """ - - def __init__(self, columns, group_by, order_by, window_cte = None): - self.columns = columns - self.group_by = group_by - self.order_by = order_by - self.window_cte = window_cte - self.windows = [] - - def exit(self, node): - col_expr = node(self.columns) - - if not isinstance(col_expr, sql.elements.ClauseElement): - return col_expr - - over_clauses = [x for x in self._get_over_clauses(col_expr) if isinstance(x, CustomOverClause)] - - # put groupings and orderings onto custom over clauses - for over in over_clauses: - # TODO: shouldn't mutate these over clauses - group_by = sql.elements.ClauseList( - *[self.columns[name] for name in self.group_by] - ) - order_by = sql.elements.ClauseList( - *_create_order_by_clause(self.columns, *self.order_by) - ) - - over.set_over(group_by, order_by) - - if len(over_clauses) and self.window_cte is not None: - # custom name, or parameters like "%(...)s" may nest and break psycopg2 - # with columns you can set a key to fix this, but it doesn't seem to - # be an option with labels - name = self._get_unique_name('win', lift_inner_cols(self.window_cte)) - label = col_expr.label(name) - - # put into CTE, and return its resulting column, so that subsequent - # operations will refer to the window column on window_cte. 
Note that - # the operations will use the actual column, so may need to use the - # ClauseAdaptor to make it a reference to the label - self.window_cte = _sql_add_columns(self.window_cte, [label]) - win_col = lift_inner_cols(self.window_cte).values()[-1] - self.windows.append(win_col) - - return win_col - - return col_expr - - @staticmethod - def _get_unique_name(prefix, columns): - column_names = set(columns.keys()) - - i = 1 - name = prefix + str(i) - while name in column_names: - i += 1 - name = prefix + str(i) - - - return name - - @staticmethod - def _get_over_clauses(clause): - windows = [] - append_win = lambda col: windows.append(col) - - sql.util.visitors.traverse(clause, {}, {"over": append_win}) - - return windows - - -class SqlLabelReplacer: - """Create a visitor to replace source labels with destination. - - Note that this is meant to be used with sqlalchemy visitors. - """ - - def __init__(self, src_columns, dst_columns): - self.src_columns = src_columns - self.src_labels = set([x for x in src_columns if isinstance(x, sql.elements.Label)]) - self.dst_columns = dst_columns - self.applied = False - - def __call__(self, clause): - return sql.util.visitors.replacement_traverse(clause, {}, self.visit) - - def visit(self, el): - from sqlalchemy.sql.elements import ColumnClause, Label, ClauseElement, TypeClause - from sqlalchemy.sql.schema import Column - - if isinstance(el, TypeClause): - # TODO: for some reason this type throws an error if unguarded - return None - - if isinstance(el, ClauseElement): - if el in self.src_labels: - self.applied = True - return self.dst_columns[el.name] - elif el in self.src_columns: - return self.dst_columns[el.name] - - # TODO: should we create a subquery if the user passed raw text? 
- #elif isinstance(el, ColumnClause) and not isinstance(el, Column): - # # Raw SQL, which will need a subquery, but not substitution - # if el.key != "*": - # self.applied = True - - return None - - -#def track_call_windows(call, columns, group_by, order_by, window_cte = None): -# listener = WindowReplacer(columns, group_by, order_by, window_cte) -# col = listener.enter(call) -# return col, listener.windows, listener.window_cte - - -def track_call_windows(call, columns, group_by, order_by, window_cte = None): - col_expr = call(columns) - - crnt_group_by = sql.elements.ClauseList( - *[columns[name] for name in group_by] - ) - crnt_order_by = sql.elements.ClauseList( - *_create_order_by_clause(columns, *order_by) - ) - return replace_call_windows(col_expr, crnt_group_by, crnt_order_by, window_cte) - - - -@singledispatch -def replace_call_windows(col_expr, group_by, order_by, window_cte = None): - raise TypeError(str(type(col_expr))) - - -@replace_call_windows.register(sql.base.ImmutableColumnCollection) -def _(col_expr, group_by, order_by, window_cte = None): - all_over_clauses = [] - for col in col_expr: - _, over_clauses, window_cte = replace_call_windows( - col, - group_by, - order_by, - window_cte - ) - all_over_clauses.extend(over_clauses) - - return col_expr, all_over_clauses, window_cte - - -@replace_call_windows.register(sql.elements.ClauseElement) -def _(col_expr, group_by, order_by, window_cte = None): - - over_clauses = WindowReplacer._get_over_clauses(col_expr) - - for over in over_clauses: - # TODO: shouldn't mutate these over clauses - over.set_over(group_by, order_by) - - if len(over_clauses) and window_cte is not None: - # custom name, or parameters like "%(...)s" may nest and break psycopg2 - # with columns you can set a key to fix this, but it doesn't seem to - # be an option with labels - name = WindowReplacer._get_unique_name('win', lift_inner_cols(window_cte)) - label = col_expr.label(name) - - # put into CTE, and return its resulting column, so 
that subsequent - # operations will refer to the window column on window_cte. Note that - # the operations will use the actual column, so may need to use the - # ClauseAdaptor to make it a reference to the label - window_cte = _sql_add_columns(window_cte, [label]) - win_col = lift_inner_cols(window_cte).values()[-1] - - return win_col, over_clauses, window_cte - - return col_expr, over_clauses, window_cte - -def get_single_from(sel): - froms = sel.froms - - n_froms = len(froms) - if n_froms != 1: - raise ValueError( - f"Expected a single table in the from clause, but found {n_froms}" - ) - - return froms[0] - -def lift_inner_cols(tbl): - cols = list(tbl.inner_columns) - - return _sql_column_collection(cols) - -def col_expr_requires_cte(call, sel, is_mutate = False): - """Return whether a variable assignment needs a CTE""" - - call_vars = set(call.op_vars(attr_calls = False)) - - sel_labs = get_inner_labels(sel) - - # I use the acronym fwg sol (frog soul) to remember sql clause eval order - # from, where, group by, select, order by, limit - # group clause evaluated before select clause, so not issue for mutate - group_needs_cte = not is_mutate and len(sel._group_by_clause) - - return ( group_needs_cte - # TODO: detect when a new var in mutate conflicts w/ order by - #or len(sel._order_by_clause) - or not sel_labs.isdisjoint(call_vars) - ) - -def get_inner_labels(sel): - columns = lift_inner_cols(sel) - sel_labs = set(k for k,v in columns.items() if isinstance(v, sql.elements.Label)) - return sel_labs - -def get_missing_columns(call, columns): - missing_cols = set(call.op_vars(attr_calls = False)) - set(columns.keys()) - return missing_cols - -def compile_el(tbl, el): - compiled = el.compile( - dialect = tbl.source.dialect, - compile_kwargs = {"literal_binds": True} - ) - return compiled - -# Misc utilities -------------------------------------------------------------- - -def ordered_union(x, y): - dx = {el: True for el in x} - dy = {el: True for el in y} - - return 
tuple({**dx, **dy}) - - -def _warn_missing(missing_groups): - warnings.warn(f"Adding missing grouping variables: {missing_groups}") - - -# Table ----------------------------------------------------------------------- - -class LazyTbl: - def __init__( - self, source, tbl, columns = None, - ops = None, group_by = tuple(), order_by = tuple(), - translator = None - ): - """Create a representation of a SQL table. - - Args: - source: a sqlalchemy.Engine or sqlalchemy.Connection instance. - tbl: table of form 'schema_name.table_name', 'table_name', or sqlalchemy.Table. - columns: if specified, a listlike of column names. - - Examples - -------- - - :: - from sqlalchemy import create_engine - from siuba.data import mtcars - - # create database and table - engine = create_engine("sqlite:///:memory:") - mtcars.to_sql('mtcars', engine) - - tbl_mtcars = LazyTbl(engine, 'mtcars') - - """ - - # connection and dialect specific functions - self.source = sqlalchemy.create_engine(source) if isinstance(source, str) else source - - # get dialect name - dialect = self.source.dialect.name - self.translator = get_dialect_translator(dialect) - - self.tbl = self._create_table(tbl, columns, self.source) - - # important states the query can be in (e.g. 
grouped) - self.ops = [self.tbl] if ops is None else ops - self.group_by = group_by - self.order_by = order_by - - - def append_op(self, op, **kwargs): - cpy = self.copy(**kwargs) - cpy.ops = cpy.ops + [op] - return cpy - - def copy(self, **kwargs): - return self.__class__(**{**self.__dict__, **kwargs}) - - def shape_call( - self, - call, window = True, str_accessors = False, - verb_name = None, arg_name = None, - ): - return self.translator.shape_call(call, window, str_accessors, verb_name, arg_name) - - def track_call_windows(self, call, columns = None, window_cte = None): - """Returns tuple of (new column expression, list of window exprs)""" - - columns = self.last_op.columns if columns is None else columns - return track_call_windows(call, columns, self.group_by, self.order_by, window_cte) - - def get_ordered_col_names(self): - """Return columns from current select, with grouping columns first.""" - ungrouped = [k for k in self.last_op.columns.keys() if k not in self.group_by] - return list(self.group_by) + ungrouped - - #def label_breaks_order_by(self, name): - # """Returns True if a new column label would break the order by vars.""" - - # # TODO: arrange currently allows literals, which breaks this. it seems - # # better to only allow calls in arrange. - # order_by_vars = {c.op_vars(attr_calls=False) for c in self.order_by} - - - - - @property - def last_op(self) -> "sql.Table | sql.Select": - last_op = self.ops[-1] - - if last_op is None: - raise TypeError() - - return last_op - - @property - def last_select(self): - last_op = self.last_op - if not isinstance(last_op, sql.selectable.SelectBase): - return last_op.select() - - return last_op - - @staticmethod - def _create_table(tbl, columns = None, source = None): - """Return a sqlalchemy.Table, autoloading column info if needed. - - Arguments: - tbl: a sqlalchemy.Table or string of form 'table_name' or 'schema_name.table_name'. - columns: a tuple of column names for the table. Overrides source argument. 
- source: a sqlalchemy engine, used to autoload columns. - - """ - if isinstance(tbl, sql.selectable.FromClause): - return tbl - - if not isinstance(tbl, str): - raise ValueError("tbl must be a sqlalchemy Table or string, but was %s" %type(tbl)) - - if columns is None and source is None: - raise ValueError("One of columns or source must be specified") - - schema, table_name = tbl.split('.') if '.' in tbl else [None, tbl] - - columns = map(sqlalchemy.Column, columns) if columns is not None else tuple() - - # TODO: pybigquery uses schema to mean project_id, so we cannot use - # siuba's classic breakdown "{schema}.{table_name}". Basically - # pybigquery uses "{schema=project_id}.{dataset_dot_table_name}" in its internal - # logic. An important side effect is that bigquery errors for - # `dataset`.`table`, but not `dataset.table`. - if source and source.dialect.name == "bigquery": - table_name = tbl - schema = None - - return sqlalchemy.Table( - table_name, - sqlalchemy.MetaData(bind = source), - *columns, - schema = schema, - autoload_with = source if not columns else None - ) - - def _get_preview(self): - # need to make prev op a cte, so we don't override any previous limit - new_sel = self.last_select.limit(5) - tbl_small = self.append_op(new_sel) - return collect(tbl_small) - - def __repr__(self): - template = ( - "# Source: lazy query\n" - "# DB Conn: {}\n" - "# Preview:\n{}\n" - "# .. may have more rows" - ) - - return template.format(repr(self.source.engine), repr(self._get_preview())) - - def _repr_html_(self): - template = ( - "
" - "
"
-                "# Source: lazy query\n"
-                "# DB Conn: {}\n"
-                "# Preview:\n"
-                "
" - "{}" - "

# .. may have more rows

" - "
" - ) - - data = self._get_preview() - - # _repr_html_ can not exist or return None, to signify that repr should be used - if not hasattr(data, '_repr_html_'): - return None - - html_data = data._repr_html_() - if html_data is None: - return None - - return template.format(self.source.engine, html_data) - - -def _repr_grouped_df_html_(self): - return "

(grouped data frame)

" + self._selected_obj._repr_html_() + "
" - - - -# Main Funcs -# ============================================================================= - -# sql raw -------------- - -sql_raw = sql.literal_column - -# show query ----------- - -@show_query.register(LazyTbl) -def _show_query(tbl, simplify = False, return_table = True): - #query = tbl.last_op #if not simplify else - compile_query = lambda query: query.compile( - dialect = tbl.source.dialect, - compile_kwargs = {"literal_binds": True} - ) - - - if simplify: - # try to strip table names and labels where unnecessary - simple_sel = _sql_simplify_select(tbl.last_select) - - explained = compile_query(simple_sel) - else: - # use a much more verbose query - explained = compile_query(tbl.last_select) - - if return_table: - print(str(explained)) - return tbl - - return str(explained) - - - -# collect ---------- - -@collect.register(LazyTbl) -def _collect(__data, as_df = True): - # TODO: maybe remove as_df options, always return dataframe - - if isinstance(__data.source, MockConnection): - # a mock sqlalchemy is being used to show_query, and echo queries. - # it doesn't return a result object or have a context handler, so - # we need to bail out early - return - - # compile query ---- - - if _is_dialect_duckdb(__data.source): - # TODO: can be removed once next release of duckdb fixes: - # https://github.com/duckdb/duckdb/issues/2972 - query = __data.last_select - compiled = query.compile( - dialect = __data.source.dialect, - compile_kwargs = {"literal_binds": True} - ) - else: - compiled = __data.last_select - - # execute query ---- - - with __data.source.connect() as conn: - if as_df: - sql_db = _FixedSqlDatabase(conn) - - if _is_dialect_duckdb(__data.source): - # TODO: pandas read_sql is very slow with duckdb. - # see https://github.com/pandas-dev/pandas/issues/45678 - # going to handle here for now. address once LazyTbl gets - # subclassed per backend. 
- duckdb_con = conn.connection.c - return duckdb_con.query(str(compiled)).to_df() - else: - # - return sql_db.read_sql(compiled) - - return conn.execute(compiled) +from ..backend import LazyTbl, _warn_missing +from ..utils import lift_inner_cols @select.register(LazyTbl) @@ -625,709 +43,6 @@ def _select(__data, *args, **kwargs): ) - -@filter.register(LazyTbl) -def _filter(__data, *args): - # Note: currently always produces 2 additional select statements, - # 1 for window/aggs, and 1 for the where clause - - sel = __data.last_op.alias() # original select - win_sel = sel.select() - - conds = [] - windows = [] - with _set_data_context(__data, window=True): - for ii, arg in enumerate(args): - - if isinstance(arg, Call): - new_call = __data.shape_call(arg, verb_name = "Filter", arg_name = ii) - #var_cols = new_call.op_vars(attr_calls = False) - - # note that a new win_sel is returned, w/ window columns appended - col_expr, win_cols, win_sel = __data.track_call_windows( - new_call, - sel.columns, - window_cte = win_sel - ) - - if isinstance(col_expr, sql.base.ImmutableColumnCollection): - conds.extend(col_expr) - else: - conds.append(col_expr) - - windows.extend(win_cols) - - else: - conds.append(arg) - - bool_clause = sql.and_(*conds) - - # first cte, windows ---- - if len(windows): - - win_alias = win_sel.alias() - - # move non-window functions to refer to win_sel clause (not the innermost) --- - bool_clause = sql.util.ClauseAdapter(win_alias) \ - .traverse(bool_clause) - - orig_cols = [win_alias.columns[k] for k in sel.columns.keys()] - else: - orig_cols = [sel] - - # create second cte ---- - filt_sel = _sql_select(orig_cols).where(bool_clause) - return __data.append_op(filt_sel) - - -@mutate.register(LazyTbl) -def _mutate(__data, *args, **kwargs): - # TODO: verify it can follow a renaming select - - # track labeled columns in set - if not (len(args) or len(kwargs)): - return __data.append_op(__data.last_op) - - names, sel_out = _mutate_cols(__data, args, kwargs, 
"Mutate") - return __data.append_op(sel_out) - - -def _sql_upsert_columns(sel, new_columns: "list[base.Label | base.Column]"): - orig_cols = lift_inner_cols(sel) - replaced = {**orig_cols} - - for new_col in new_columns: - replaced[new_col.name] = new_col - return _sql_with_only_columns(sel, list(replaced.values())) - - -def _select_mutate_result(src_sel, expr_result): - dst_alias = src_sel.alias() - src_columns = set(lift_inner_cols(src_sel)) - replacer = SqlLabelReplacer(set(src_columns), dst_alias.columns) - - if isinstance(expr_result, sql.base.ImmutableColumnCollection): - replaced_cols = list(map(replacer, expr_result)) - orig_cols = expr_result - #elif isinstance(expr_result, None): - # pass - else: - replaced_cols = [replacer(expr_result)] - orig_cols = [expr_result] - - if replacer.applied: - return _sql_upsert_columns(dst_alias.select(), replaced_cols) - - return _sql_upsert_columns(src_sel, orig_cols) - - -def _eval_expr_arg(__data, sel, func, verb_name, window=True): - inner_cols = lift_inner_cols(sel) - - # case 1: simple names ---- - simple_name = simple_varname(func) - if simple_name is not None: - return inner_cols[simple_name] - - # case 2: across ---- - _require_across(func, verb_name) - - cols_result = _eval_with_context(__data, window, inner_cols, func) - - # TODO: remove or raise a more informative error - assert isinstance(cols_result, sql.base.ImmutableColumnCollection), type(cols_result) - - return cols_result - - -def _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=True): - inner_cols = lift_inner_cols(sel) - - expr_shaped = __data.shape_call(func, window, verb_name = verb_name, arg_name = new_name) - new_col, windows, _ = __data.track_call_windows(expr_shaped, inner_cols) - - if isinstance(new_col, sql.base.ImmutableColumnCollection): - raise TyepError( - f"{verb_name} named arguments must return a single column, but `{k}` " - "returned multiple columns." 
- ) - - return new_col.label(new_name) - - -def _mutate_cols(__data, args, kwargs, verb_name): - result_names = {} # used as ordered set - sel = __data.last_select - - for ii, func in enumerate(args): - cols_result = _eval_expr_arg(__data, sel, func, verb_name) - - # replace any labels that require a subquery ---- - sel = _select_mutate_result(sel, cols_result) - - if isinstance(cols_result, sql.base.ImmutableColumnCollection): - result_names.update({k: True for k in cols_result.keys()}) - else: - result_names[cols_result.name] = True - - - for new_name, func in kwargs.items(): - labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name) - - sel = _select_mutate_result(sel, labeled) - result_names[new_name] = True - - - return list(result_names), sel - - -@transmute.register(LazyTbl) -def _transmute(__data, *args, **kwargs): - # will use mutate, then select some cols - result_names, sel = _mutate_cols(__data, args, kwargs, "Transmute") - - # transmute keeps grouping cols, and any defined in kwargs - missing = [x for x in __data.group_by if x not in result_names] - cols_to_keep = [*missing, *result_names] - - columns = lift_inner_cols(sel) - sel_stripped = sel.with_only_columns([columns[k] for k in cols_to_keep]) - - return __data.append_op(sel_stripped) - - -@arrange.register(LazyTbl) -def _arrange(__data, *args): - # Note that SQL databases often do not subquery order by clauses. Arrange - # sets order_by on the backend, so it can set order by in over elements, - # and handle when new columns are named the same as order by vars. 
- # see: https://dba.stackexchange.com/q/82930 - - last_sel = __data.last_select - cols = lift_inner_cols(last_sel) - - # TODO: implement across in arrange - #exprs, _ = _mutate_cols(__data, args, kwargs, "Arrange", arrange_clause=True) - - new_calls = [] - for ii, expr in enumerate(args): - if callable(expr): - - res = __data.shape_call( - expr, window = False, - verb_name = "Arrange", arg_name = ii - ) - - else: - res = expr - - new_calls.append(res) - - sort_cols = _create_order_by_clause(cols, *new_calls) - - order_by = __data.order_by + tuple(new_calls) - return __data.append_op(last_sel.order_by(*sort_cols), order_by = order_by) - - -# TODO: consolidate / pull expr handling funcs into own file? -def _create_order_by_clause(columns, *args): - sort_cols = [] - for arg in args: - # simple named column - if isinstance(arg, str): - sort_cols.append(columns[arg]) - # an expression - elif callable(arg): - # handle special case where -_.colname -> colname DESC - f, asc = _call_strip_ascending(arg) - col_op = f(columns) if asc else f(columns).desc() - #col_op = arg(columns) - sort_cols.append(col_op) - else: - raise NotImplementedError("Must be string or callable") - - return sort_cols - - - -@count.register(LazyTbl) -def _count(__data, *args, sort = False, wt = None, **kwargs): - # TODO: if already col named n, use name nn, etc.. get logic from tidy.py - if wt is not None: - raise NotImplementedError("TODO") - - res_name = "n" - # similar to filter verb, we need two select statements, - # an inner one for derived cols, and outer to group by them - - # inner select ---- - # holds any mutation style columns - #arg_names = [] - #for arg in args: - # name = simple_varname(arg) - # if name is None: - # raise NotImplementedError( - # "Count positional arguments must be single column name. " - # "Use a named argument to count using complex expressions." 
- # ) - # arg_names.append(name) - - result_names, sel_inner = _mutate_cols(__data, args, kwargs, "Count") - - # remove unnecessary select, if we're operating on a table ---- - if set(lift_inner_cols(sel_inner)) == set(lift_inner_cols(__data.last_select)): - sel_inner = __data.last_op - - # create outer select ---- - # holds selected columns and tally (n) - sel_inner_cte = sel_inner.alias() - inner_cols = sel_inner_cte.columns - - # apply any group vars from a group_by verb call first - missing = [k for k in __data.group_by if k not in result_names] - - all_group_names = ordered_union(__data.group_by, result_names) - outer_group_cols = [inner_cols[k] for k in all_group_names] - - # holds the actual count (e.g. n) - count_col = sql.functions.count().label(res_name) - - sel_outer = _sql_select([*outer_group_cols, count_col]) \ - .select_from(sel_inner_cte) \ - .group_by(*outer_group_cols) - - # count is like summarize, so removes order_by - return __data.append_op( - sel_outer.order_by(count_col.desc()), - order_by = tuple() - ) - - -@add_count.register(LazyTbl) -def _add_count(__data, *args, wt = None, sort = False, **kwargs): - counts = count(__data, *args, wt = wt, sort = sort, **kwargs) - by = list(c.name for c in counts.last_select.inner_columns)[:-1] - - return inner_join(__data, counts, by = by) - - -@summarize.register(LazyTbl) -def _summarize(__data, *args, **kwargs): - # https://stackoverflow.com/questions/14754994/why-is-sqlalchemy-count-much-slower-than-the-raw-query - - # get query with correct from clause, and maybe unneeded subquery - safe_from = __data.last_select.alias() - result_names, sel = _aggregate_cols(__data, safe_from, args, kwargs, "Summarize") - - # see if we can remove subquery - out_sel = _collapse_select(sel, safe_from) - - from_tbl = get_single_from(out_sel) - group_cols = [from_tbl.columns[k] for k in __data.group_by] - - final_sel = out_sel.group_by(*group_cols) - - new_data = __data.append_op(final_sel, group_by = tuple(), order_by = 
tuple()) - return new_data - - -def _collapse_select(outer_sel, inner_alias): - # check whether any outer columns reference an inner label ---- - inner_sel = inner_alias.element - - columns = lift_inner_cols(outer_sel) - inner_cols = lift_inner_cols(inner_sel) - - inner_labels = set([ - x.name for x in inner_cols - if isinstance(x, sql.elements.Label) - ]) - - col_requires_cte = set(inner_alias.columns[k] for k in inner_labels) - - bad_refs = [] - - def collect_refs(el): - if el in col_requires_cte: - bad_refs.append(el) - - for col in columns: - sql.util.visitors.traverse(col, {}, {"column": collect_refs}) - - # if possible, remove the outer query ---- - if not (bad_refs or len(inner_sel._group_by_clause)): - from sqlalchemy.sql.elements import ColumnClause, Label - - from_obj = get_single_from(inner_sel) - adaptor = sql.util.ClauseAdapter( - from_obj, - adapt_on_names=True, - include_fn=lambda c: isinstance(c, (ColumnClause, Label)) - ) - - new_cols = [] - for col in columns: - if isinstance(col, Label): - res = adaptor.traverse(col.element).label(col.name) - new_cols.append(res) - - else: - new_cols.append(adaptor.traverse(col)) - #new_cols = list(map(adaptor.traverse, columns)) - - return _sql_with_only_columns(inner_sel, new_cols) - - return outer_sel - - -def _aggregate_cols(__data, subquery, args, kwargs, verb_name): - # cases: - # * grouping cols can not be overwritten (in dbplyr they can't be ref'd) - # * no existing labels referred to - can use same select - # * existing labels referred to - need 1 subquery tops - # * groups + summarize columns can replace everything - - def get_label_clauses(clause): - out = [] - sql.util.visitors.traverse(clause, {}, {"label": lambda c: out.append(c)}) - - return out - - def quote_varname(x): - return f"`{x}`" - - def validate_references(arg_name, expr, verb_name): - bad_varnames = get_label_clauses(expr) - repr_names = ", ".join(map(quote_varname, [el.name for el in bad_varnames])) - - if not bad_varnames: - return - - 
raise NotImplementedError( - f"In SQL, you cannot refer to a column created in the same {verb_name}. " - f"`{arg_name}` refers to columns created earlier: {repr_names}." - ) - - sel = subquery.select() - - final_cols = {k: subquery.columns[k] for k in __data.group_by} - - # handle args ---- - for ii, func in enumerate(args): - cols_result = _eval_expr_arg(__data, sel, func, verb_name, window=False) - - for col in cols_result: - validate_references(col.name, col.element, verb_name) - final_cols[col.name] = col - - sel = _sql_upsert_columns(sel, cols_result) - - - # handle kwargs ---- - for new_name, func in kwargs.items(): - labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=False) - - validate_references(labeled.name, labeled.element, verb_name) - final_cols[new_name] = labeled - - sel = _sql_upsert_columns(sel, [labeled]) - - return list(final_cols), _sql_with_only_columns(sel, list(final_cols.values())) - - -@group_by.register(LazyTbl) -def _group_by(__data, *args, add = False, **kwargs): - if not (args or kwargs): - return __data.copy() - - group_names, sel = _mutate_cols(__data, args, kwargs, "Group by") - - if None in group_names: - raise NotImplementedError("Complex, unnamed expressions not supported in sql group_by") - - # check whether we can just use underlying table ---- - new_cols = lift_inner_cols(sel) - if set(new_cols).issubset(set(__data.last_op.columns)): - sel = __data.last_op - - if add: - group_names = ordered_union(__data.group_by, group_names) - - return __data.append_op(sel, group_by = tuple(group_names)) - - -@ungroup.register(LazyTbl) -def _ungroup(__data): - return __data.copy(group_by = tuple()) - - -@case_when.register(sql.base.ImmutableColumnCollection) -def _case_when(__data, cases): - # TODO: will need listener to enter case statements, to handle when they use windows - if isinstance(cases, Call): - cases = cases(__data) - - whens = [] - case_items = list(cases.items()) - n_items = len(case_items) - - else_val = 
None - for ii, (expr, val) in enumerate(case_items): - # handle where val is a column expr - if callable(val): - val = val(__data) - - # handle when expressions - if ii+1 == n_items and expr is True: - else_val = val - elif callable(expr): - whens.append((expr(__data), val)) - else: - whens.append((expr, val)) - - return sql.case(whens, else_ = else_val) - - -# Join ------------------------------------------------------------------------ - -from collections.abc import Mapping - -def _joined_cols(left_cols, right_cols, on_keys, how, suffix = ("_x", "_y")): - """Return labeled columns, according to selection rules for joins. - - Rules: - 1. For join keys, keep left table's column - 2. When keys have the same labels, add suffix - """ - - # TODO: remove sets, so uses stable ordering - # when left and right cols have same name, suffix with _x / _y - keep_right = set(right_cols.keys()) - set(on_keys.values()) - shared_labs = set(left_cols.keys()).intersection(keep_right) - - right_cols_no_keys = {k: right_cols[k] for k in keep_right} - - # for an outer join, have key columns coalesce values - - left_cols = {**left_cols} - if how == "full": - for lk, rk in on_keys.items(): - col = sql.functions.coalesce(left_cols[lk], right_cols[rk]) - left_cols[lk] = col.label(lk) - elif how == "right": - for lk, rk in on_keys.items(): - # Make left key columns actually be right ones (which contain left + extra) - left_cols[lk] = right_cols[rk].label(lk) - - - # create labels ---- - l_labs = _relabeled_cols(left_cols, shared_labs, suffix[0]) - r_labs = _relabeled_cols(right_cols_no_keys, shared_labs, suffix[1]) - - return l_labs + r_labs - - - -def _relabeled_cols(columns, keys, suffix): - # add a suffix to all columns with names in keys - cols = [] - for k, v in columns.items(): - new_col = v.label(k + str(suffix)) if k in keys else v - cols.append(new_col) - return cols - - -@join.register(LazyTbl) -def _join(left, right, on = None, *args, by = None, how = "inner", sql_on = None): - 
_raise_if_args(args) - - if on is None and by is not None: - on = by - - # Needs to be on the table, not the select - left_sel = left.last_op.alias() - right_sel = right.last_op.alias() - - # handle arguments ---- - on = _validate_join_arg_on(on, sql_on) - how = _validate_join_arg_how(how) - - # for equality join used to combine keys into single column - consolidate_keys = on if sql_on is None else {} - - if how == "right": - # switch joins, since sqlalchemy doesn't have right join arg - # see https://stackoverflow.com/q/11400307/1144523 - left_sel, right_sel = right_sel, left_sel - on = {v:k for k,v in on.items()} - - # create join conditions ---- - bool_clause = _create_join_conds(left_sel, right_sel, on) - - # create join ---- - join = left_sel.join( - right_sel, - onclause = bool_clause, - isouter = how != "inner", - full = how == "full" - ) - - # if right join, set selects back - if how == "right": - left_sel, right_sel = right_sel, left_sel - on = {v:k for k,v in on.items()} - - # note, shared_keys assumes on is a mapping... 
- # TODO: shared_keys appears to be for when on is not specified, but was unused - #shared_keys = [k for k,v in on.items() if k == v] - labeled_cols = _joined_cols( - left_sel.columns, - right_sel.columns, - on_keys = consolidate_keys, - how = how - ) - - sel = _sql_select(labeled_cols).select_from(join) - return left.append_op(sel, order_by = tuple()) - - -@semi_join.register(LazyTbl) -def _semi_join(left, right = None, on = None, *args, by = None, sql_on = None): - if on is None and by is not None: - on = by - - _raise_if_args(args) - - left_sel = left.last_op.alias() - right_sel = right.last_op.alias() - - # handle arguments ---- - on = _validate_join_arg_on(on, sql_on, left_sel, right_sel) - - # create join conditions ---- - bool_clause = _create_join_conds(left_sel, right_sel, on) - - # create inner join ---- - exists_clause = _sql_select([sql.literal(1)]) \ - .select_from(right_sel) \ - .where(bool_clause) - - # only keep left hand select's columns ---- - sel = _sql_select(left_sel.columns) \ - .select_from(left_sel) \ - .where(sql.exists(exists_clause)) - - return left.append_op(sel, order_by = tuple()) - - -@anti_join.register(LazyTbl) -def _anti_join(left, right = None, on = None, *args, by = None, sql_on = None): - if on is None and by is not None: - on = by - - _raise_if_args(args) - - left_sel = left.last_op.alias() - right_sel = right.last_op.alias() - - # handle arguments ---- - on = _validate_join_arg_on(on, sql_on, left, right) - - # create join conditions ---- - bool_clause = _create_join_conds(left_sel, right_sel, on) - - # create inner join ---- - #not_exists = ~sql.exists([1], from_obj = right_sel).where(bool_clause) - exists_clause = _sql_select([sql.literal(1)]) \ - .select_from(right_sel) \ - .where(bool_clause) - - sel = left_sel.select().where(~sql.exists(exists_clause)) - - return left.append_op(sel, order_by = tuple()) - -def _raise_if_args(args): - if len(args): - raise NotImplemented("*args is reserved for future arguments (e.g. 
suffix)") - -def _validate_join_arg_on(on, sql_on = None, lhs = None, rhs = None): - # handle sql on case - if sql_on is not None: - if on is not None: - raise ValueError("Cannot specify both on and sql_on") - - return sql_on - - # handle general cases - if on is None: - # TODO: currently, we check for lhs and rhs tables to indicate whether - # a verb supports inferring columns. Otherwise, raise an error. - if lhs is not None and rhs is not None: - # TODO: consolidate with duplicate logic in pandas verb code - warnings.warn( - "No on column passed to join. " - "Inferring join columns instead using shared column names." - ) - - on_cols = list(set(lhs.columns.keys()).intersection(set(rhs.columns.keys()))) - - if not on_cols: - raise ValueError( - "No join column specified, or shared column names in join." - ) - - # trivial dict mapping shared names to themselves - warnings.warn("Detected shared columns: %s" % on_cols) - on = dict(zip(on_cols, on_cols)) - - else: - raise NotImplementedError("on arg currently cannot be None (default) for SQL") - elif isinstance(on, str): - on = {on: on} - elif isinstance(on, (list, tuple)): - on = dict(zip(on, on)) - - - if not isinstance(on, Mapping): - raise TypeError("on must be a Mapping (e.g. 
dict)") - - return on - -def _validate_join_arg_how(how): - how_options = ("inner", "left", "right", "full") - if how not in how_options: - raise ValueError("how argument needs to be one of %s" %how_options) - - return how - -def _create_join_conds(left_sel, right_sel, on): - left_cols = left_sel.columns #lift_inner_cols(left_sel) - right_cols = right_sel.columns #lift_inner_cols(right_sel) - - if callable(on): - # callable, like with sql_on arg - conds = [on(left_cols, right_cols)] - else: - # dict-like of form {left: right} - conds = [] - for l, r in on.items(): - col_expr = left_cols[l] == right_cols[r] - conds.append(col_expr) - - return sql.and_(*conds) - - -# Head ------------------------------------------------------------------------ - -@head.register(LazyTbl) -def _head(__data, n = 5): - sel = __data.last_select - - return __data.append_op(sel.limit(n)) - - -# Rename ---------------------------------------------------------------------- - @rename.register(LazyTbl) def _rename(__data, **kwargs): sel = __data.last_select @@ -1350,46 +65,3 @@ def _rename(__data, **kwargs): return __data.append_op(new_sel, group_by=group_keys) -# Distinct -------------------------------------------------------------------- - -@distinct.register(LazyTbl) -def _distinct(__data, *args, _keep_all = False, **kwargs): - if (args or kwargs) and _keep_all: - raise NotImplementedError("Distinct with variables specified in sql requires _keep_all = False") - - inner_sel = mutate(__data, **kwargs).last_select if kwargs else __data.last_select - - # TODO: this is copied from the df distinct version - # cols dict below is used as ordered set - cols = _var_select_simple(args) - cols.update(kwargs) - - # use all columns by default - if not cols: - cols = {k: True for k in lift_inner_cols(inner_sel).keys()} - - final_names = {**{k: True for k in __data.group_by}, **cols} - - if not len(inner_sel._order_by_clause): - # select distinct has to include any columns in the order by clause, - # so 
can only safely modify existing statement when there's no order by - sel_cols = lift_inner_cols(inner_sel) - distinct_cols = [sel_cols[k] for k in final_names] - sel = inner_sel.with_only_columns(distinct_cols).distinct() - else: - # fallback to cte - cte = inner_sel.alias() - distinct_cols = [cte.columns[k] for k in final_names] - sel = _sql_select(distinct_cols).select_from(cte).distinct() - - return __data.append_op(sel) - - -# if_else --------------------------------------------------------------------- - -@if_else.register(sql.elements.ColumnElement) -def _if_else(cond, true_vals, false_vals): - whens = [(cond, true_vals)] - return sql.case(whens, else_ = false_vals) - - diff --git a/siuba/sql/verbs/summarize.py b/siuba/sql/verbs/summarize.py index 8dcbef7b..c3c9b03a 100644 --- a/siuba/sql/verbs/summarize.py +++ b/siuba/sql/verbs/summarize.py @@ -1,914 +1,11 @@ -""" -Implements LazyTbl to represent tables of SQL data, and registers it on verbs. - -This module is responsible for the handling of the "table" side of things, while -translate.py handles translating column operations. 
- - -""" - -import warnings - -from siuba.dply.verbs import ( - show_query, collect, - simple_varname, - select, - mutate, - transmute, - filter, - arrange, _call_strip_ascending, - summarize, - count, add_count, - group_by, ungroup, - case_when, - join, left_join, right_join, inner_join, semi_join, anti_join, - head, - rename, - distinct, - if_else, - _select_group_renames, - _var_select_simple - ) - -from siuba.dply.tidyselect import VarList, var_select - -from .translate import CustomOverClause, SqlColumn, SqlColumnAgg -from .utils import ( - get_dialect_translator, - _FixedSqlDatabase, - _is_dialect_duckdb, - _sql_select, - _sql_column_collection, - _sql_add_columns, - _sql_with_only_columns, - _sql_simplify_select, - MockConnection -) - from sqlalchemy import sql -import sqlalchemy -from siuba.siu import Call, Lazy, FunctionLookupError, singledispatch2 -# TODO: currently needed for select, but can we remove pandas? -from pandas import Series -from functools import singledispatch - -from sqlalchemy.sql import schema - -from siuba.dply.across import _require_across, _set_data_context, _eval_with_context - -# TODO: -# - distinct -# - annotate functions using sel.prefix_with("\n/**/\n") ? - - -# Helpers --------------------------------------------------------------------- - -class SqlFunctionLookupError(FunctionLookupError): pass - - -class CallListener: - """Generic listener. Each exit is called on a node's copy.""" - def enter(self, node): - args, kwargs = node.map_subcalls(self.enter) - - return self.exit(node.__class__(node.func, *args, **kwargs)) - - def exit(self, node): - return node - - -class WindowReplacer(CallListener): - """Call tree listener. - - Produces 2 important behaviors via the enter method: - - returns evaluated sql call expression, with labels on all window expressions. - - stores all labeled window expressions via the windows property. 
- - TODO: could replace with a sqlalchemy transformer - """ - - def __init__(self, columns, group_by, order_by, window_cte = None): - self.columns = columns - self.group_by = group_by - self.order_by = order_by - self.window_cte = window_cte - self.windows = [] - - def exit(self, node): - col_expr = node(self.columns) - - if not isinstance(col_expr, sql.elements.ClauseElement): - return col_expr - - over_clauses = [x for x in self._get_over_clauses(col_expr) if isinstance(x, CustomOverClause)] - - # put groupings and orderings onto custom over clauses - for over in over_clauses: - # TODO: shouldn't mutate these over clauses - group_by = sql.elements.ClauseList( - *[self.columns[name] for name in self.group_by] - ) - order_by = sql.elements.ClauseList( - *_create_order_by_clause(self.columns, *self.order_by) - ) - - over.set_over(group_by, order_by) - - if len(over_clauses) and self.window_cte is not None: - # custom name, or parameters like "%(...)s" may nest and break psycopg2 - # with columns you can set a key to fix this, but it doesn't seem to - # be an option with labels - name = self._get_unique_name('win', lift_inner_cols(self.window_cte)) - label = col_expr.label(name) - - # put into CTE, and return its resulting column, so that subsequent - # operations will refer to the window column on window_cte. 
Note that - # the operations will use the actual column, so may need to use the - # ClauseAdaptor to make it a reference to the label - self.window_cte = _sql_add_columns(self.window_cte, [label]) - win_col = lift_inner_cols(self.window_cte).values()[-1] - self.windows.append(win_col) - - return win_col - - return col_expr - - @staticmethod - def _get_unique_name(prefix, columns): - column_names = set(columns.keys()) - - i = 1 - name = prefix + str(i) - while name in column_names: - i += 1 - name = prefix + str(i) - - - return name - - @staticmethod - def _get_over_clauses(clause): - windows = [] - append_win = lambda col: windows.append(col) - - sql.util.visitors.traverse(clause, {}, {"over": append_win}) - - return windows - - -class SqlLabelReplacer: - """Create a visitor to replace source labels with destination. - - Note that this is meant to be used with sqlalchemy visitors. - """ - - def __init__(self, src_columns, dst_columns): - self.src_columns = src_columns - self.src_labels = set([x for x in src_columns if isinstance(x, sql.elements.Label)]) - self.dst_columns = dst_columns - self.applied = False - - def __call__(self, clause): - return sql.util.visitors.replacement_traverse(clause, {}, self.visit) - - def visit(self, el): - from sqlalchemy.sql.elements import ColumnClause, Label, ClauseElement, TypeClause - from sqlalchemy.sql.schema import Column - - if isinstance(el, TypeClause): - # TODO: for some reason this type throws an error if unguarded - return None - - if isinstance(el, ClauseElement): - if el in self.src_labels: - self.applied = True - return self.dst_columns[el.name] - elif el in self.src_columns: - return self.dst_columns[el.name] - - # TODO: should we create a subquery if the user passed raw text? 
- #elif isinstance(el, ColumnClause) and not isinstance(el, Column): - # # Raw SQL, which will need a subquery, but not substitution - # if el.key != "*": - # self.applied = True - - return None - - -#def track_call_windows(call, columns, group_by, order_by, window_cte = None): -# listener = WindowReplacer(columns, group_by, order_by, window_cte) -# col = listener.enter(call) -# return col, listener.windows, listener.window_cte - - -def track_call_windows(call, columns, group_by, order_by, window_cte = None): - col_expr = call(columns) - - crnt_group_by = sql.elements.ClauseList( - *[columns[name] for name in group_by] - ) - crnt_order_by = sql.elements.ClauseList( - *_create_order_by_clause(columns, *order_by) - ) - return replace_call_windows(col_expr, crnt_group_by, crnt_order_by, window_cte) - - - -@singledispatch -def replace_call_windows(col_expr, group_by, order_by, window_cte = None): - raise TypeError(str(type(col_expr))) - - -@replace_call_windows.register(sql.base.ImmutableColumnCollection) -def _(col_expr, group_by, order_by, window_cte = None): - all_over_clauses = [] - for col in col_expr: - _, over_clauses, window_cte = replace_call_windows( - col, - group_by, - order_by, - window_cte - ) - all_over_clauses.extend(over_clauses) - - return col_expr, all_over_clauses, window_cte - - -@replace_call_windows.register(sql.elements.ClauseElement) -def _(col_expr, group_by, order_by, window_cte = None): - - over_clauses = WindowReplacer._get_over_clauses(col_expr) - - for over in over_clauses: - # TODO: shouldn't mutate these over clauses - over.set_over(group_by, order_by) - - if len(over_clauses) and window_cte is not None: - # custom name, or parameters like "%(...)s" may nest and break psycopg2 - # with columns you can set a key to fix this, but it doesn't seem to - # be an option with labels - name = WindowReplacer._get_unique_name('win', lift_inner_cols(window_cte)) - label = col_expr.label(name) - - # put into CTE, and return its resulting column, so 
that subsequent - # operations will refer to the window column on window_cte. Note that - # the operations will use the actual column, so may need to use the - # ClauseAdaptor to make it a reference to the label - window_cte = _sql_add_columns(window_cte, [label]) - win_col = lift_inner_cols(window_cte).values()[-1] - - return win_col, over_clauses, window_cte - - return col_expr, over_clauses, window_cte - -def get_single_from(sel): - froms = sel.froms - - n_froms = len(froms) - if n_froms != 1: - raise ValueError( - f"Expected a single table in the from clause, but found {n_froms}" - ) - - return froms[0] - -def lift_inner_cols(tbl): - cols = list(tbl.inner_columns) - - return _sql_column_collection(cols) - -def col_expr_requires_cte(call, sel, is_mutate = False): - """Return whether a variable assignment needs a CTE""" - - call_vars = set(call.op_vars(attr_calls = False)) - - sel_labs = get_inner_labels(sel) - - # I use the acronym fwg sol (frog soul) to remember sql clause eval order - # from, where, group by, select, order by, limit - # group clause evaluated before select clause, so not issue for mutate - group_needs_cte = not is_mutate and len(sel._group_by_clause) - - return ( group_needs_cte - # TODO: detect when a new var in mutate conflicts w/ order by - #or len(sel._order_by_clause) - or not sel_labs.isdisjoint(call_vars) - ) - -def get_inner_labels(sel): - columns = lift_inner_cols(sel) - sel_labs = set(k for k,v in columns.items() if isinstance(v, sql.elements.Label)) - return sel_labs - -def get_missing_columns(call, columns): - missing_cols = set(call.op_vars(attr_calls = False)) - set(columns.keys()) - return missing_cols - -def compile_el(tbl, el): - compiled = el.compile( - dialect = tbl.source.dialect, - compile_kwargs = {"literal_binds": True} - ) - return compiled - -# Misc utilities -------------------------------------------------------------- - -def ordered_union(x, y): - dx = {el: True for el in x} - dy = {el: True for el in y} - - return 
tuple({**dx, **dy}) - - -def _warn_missing(missing_groups): - warnings.warn(f"Adding missing grouping variables: {missing_groups}") - - -# Table ----------------------------------------------------------------------- - -class LazyTbl: - def __init__( - self, source, tbl, columns = None, - ops = None, group_by = tuple(), order_by = tuple(), - translator = None - ): - """Create a representation of a SQL table. - - Args: - source: a sqlalchemy.Engine or sqlalchemy.Connection instance. - tbl: table of form 'schema_name.table_name', 'table_name', or sqlalchemy.Table. - columns: if specified, a listlike of column names. - - Examples - -------- - - :: - from sqlalchemy import create_engine - from siuba.data import mtcars - - # create database and table - engine = create_engine("sqlite:///:memory:") - mtcars.to_sql('mtcars', engine) - - tbl_mtcars = LazyTbl(engine, 'mtcars') - - """ - - # connection and dialect specific functions - self.source = sqlalchemy.create_engine(source) if isinstance(source, str) else source - - # get dialect name - dialect = self.source.dialect.name - self.translator = get_dialect_translator(dialect) - - self.tbl = self._create_table(tbl, columns, self.source) - - # important states the query can be in (e.g. 
grouped) - self.ops = [self.tbl] if ops is None else ops - - self.group_by = group_by - self.order_by = order_by - - - def append_op(self, op, **kwargs): - cpy = self.copy(**kwargs) - cpy.ops = cpy.ops + [op] - return cpy - - def copy(self, **kwargs): - return self.__class__(**{**self.__dict__, **kwargs}) - - def shape_call( - self, - call, window = True, str_accessors = False, - verb_name = None, arg_name = None, - ): - return self.translator.shape_call(call, window, str_accessors, verb_name, arg_name) - - def track_call_windows(self, call, columns = None, window_cte = None): - """Returns tuple of (new column expression, list of window exprs)""" - - columns = self.last_op.columns if columns is None else columns - return track_call_windows(call, columns, self.group_by, self.order_by, window_cte) - - def get_ordered_col_names(self): - """Return columns from current select, with grouping columns first.""" - ungrouped = [k for k in self.last_op.columns.keys() if k not in self.group_by] - return list(self.group_by) + ungrouped - - #def label_breaks_order_by(self, name): - # """Returns True if a new column label would break the order by vars.""" - - # # TODO: arrange currently allows literals, which breaks this. it seems - # # better to only allow calls in arrange. - # order_by_vars = {c.op_vars(attr_calls=False) for c in self.order_by} - - - - - @property - def last_op(self) -> "sql.Table | sql.Select": - last_op = self.ops[-1] - - if last_op is None: - raise TypeError() - - return last_op - - @property - def last_select(self): - last_op = self.last_op - if not isinstance(last_op, sql.selectable.SelectBase): - return last_op.select() - - return last_op - - @staticmethod - def _create_table(tbl, columns = None, source = None): - """Return a sqlalchemy.Table, autoloading column info if needed. - - Arguments: - tbl: a sqlalchemy.Table or string of form 'table_name' or 'schema_name.table_name'. - columns: a tuple of column names for the table. Overrides source argument. 
- source: a sqlalchemy engine, used to autoload columns. - - """ - if isinstance(tbl, sql.selectable.FromClause): - return tbl - - if not isinstance(tbl, str): - raise ValueError("tbl must be a sqlalchemy Table or string, but was %s" %type(tbl)) - - if columns is None and source is None: - raise ValueError("One of columns or source must be specified") - - schema, table_name = tbl.split('.') if '.' in tbl else [None, tbl] - - columns = map(sqlalchemy.Column, columns) if columns is not None else tuple() - - # TODO: pybigquery uses schema to mean project_id, so we cannot use - # siuba's classic breakdown "{schema}.{table_name}". Basically - # pybigquery uses "{schema=project_id}.{dataset_dot_table_name}" in its internal - # logic. An important side effect is that bigquery errors for - # `dataset`.`table`, but not `dataset.table`. - if source and source.dialect.name == "bigquery": - table_name = tbl - schema = None - - return sqlalchemy.Table( - table_name, - sqlalchemy.MetaData(bind = source), - *columns, - schema = schema, - autoload_with = source if not columns else None - ) - - def _get_preview(self): - # need to make prev op a cte, so we don't override any previous limit - new_sel = self.last_select.limit(5) - tbl_small = self.append_op(new_sel) - return collect(tbl_small) - - def __repr__(self): - template = ( - "# Source: lazy query\n" - "# DB Conn: {}\n" - "# Preview:\n{}\n" - "# .. may have more rows" - ) - - return template.format(repr(self.source.engine), repr(self._get_preview())) - - def _repr_html_(self): - template = ( - "
" - "
"
-                "# Source: lazy query\n"
-                "# DB Conn: {}\n"
-                "# Preview:\n"
-                "
" - "{}" - "

# .. may have more rows

" - "
" - ) - - data = self._get_preview() - - # _repr_html_ can not exist or return None, to signify that repr should be used - if not hasattr(data, '_repr_html_'): - return None - - html_data = data._repr_html_() - if html_data is None: - return None - - return template.format(self.source.engine, html_data) - - -def _repr_grouped_df_html_(self): - return "

(grouped data frame)

" + self._selected_obj._repr_html_() + "
" - - - -# Main Funcs -# ============================================================================= - -# sql raw -------------- - -sql_raw = sql.literal_column - -# show query ----------- - -@show_query.register(LazyTbl) -def _show_query(tbl, simplify = False, return_table = True): - #query = tbl.last_op #if not simplify else - compile_query = lambda query: query.compile( - dialect = tbl.source.dialect, - compile_kwargs = {"literal_binds": True} - ) - - - if simplify: - # try to strip table names and labels where unnecessary - simple_sel = _sql_simplify_select(tbl.last_select) - - explained = compile_query(simple_sel) - else: - # use a much more verbose query - explained = compile_query(tbl.last_select) - - if return_table: - print(str(explained)) - return tbl - - return str(explained) - - - -# collect ---------- - -@collect.register(LazyTbl) -def _collect(__data, as_df = True): - # TODO: maybe remove as_df options, always return dataframe - - if isinstance(__data.source, MockConnection): - # a mock sqlalchemy is being used to show_query, and echo queries. - # it doesn't return a result object or have a context handler, so - # we need to bail out early - return - - # compile query ---- - - if _is_dialect_duckdb(__data.source): - # TODO: can be removed once next release of duckdb fixes: - # https://github.com/duckdb/duckdb/issues/2972 - query = __data.last_select - compiled = query.compile( - dialect = __data.source.dialect, - compile_kwargs = {"literal_binds": True} - ) - else: - compiled = __data.last_select - - # execute query ---- - - with __data.source.connect() as conn: - if as_df: - sql_db = _FixedSqlDatabase(conn) - - if _is_dialect_duckdb(__data.source): - # TODO: pandas read_sql is very slow with duckdb. - # see https://github.com/pandas-dev/pandas/issues/45678 - # going to handle here for now. address once LazyTbl gets - # subclassed per backend. 
- duckdb_con = conn.connection.c - return duckdb_con.query(str(compiled)).to_df() - else: - # - return sql_db.read_sql(compiled) - - return conn.execute(compiled) - - -@select.register(LazyTbl) -def _select(__data, *args, **kwargs): - # see https://stackoverflow.com/questions/25914329/rearrange-columns-in-sqlalchemy-select-object - if kwargs: - raise NotImplementedError( - "Using kwargs in select not currently supported. " - "Use _.newname == _.oldname instead" - ) - last_sel = __data.last_select - columns = {c.key: c for c in last_sel.inner_columns} - - # same as for DataFrame - colnames = Series(list(columns)) - vl = VarList() - evaluated = (arg(vl) if callable(arg) else arg for arg in args) - od = var_select(colnames, *evaluated) - - missing_groups, group_keys = _select_group_renames(od, __data.group_by) - - if missing_groups: - _warn_missing(missing_groups) - - final_od = {**{k: None for k in missing_groups}, **od} - - col_list = [] - for k,v in final_od.items(): - col = columns[k] - col_list.append(col if v is None else col.label(v)) - - return __data.append_op( - last_sel.with_only_columns(col_list), - group_by = group_keys - ) - - - -@filter.register(LazyTbl) -def _filter(__data, *args): - # Note: currently always produces 2 additional select statements, - # 1 for window/aggs, and 1 for the where clause - sel = __data.last_op.alias() # original select - win_sel = sel.select() +from siuba.dply.verbs import summarize - conds = [] - windows = [] - with _set_data_context(__data, window=True): - for ii, arg in enumerate(args): +from .mutate import _sql_upsert_columns, _eval_expr_arg, _eval_expr_kwarg - if isinstance(arg, Call): - new_call = __data.shape_call(arg, verb_name = "Filter", arg_name = ii) - #var_cols = new_call.op_vars(attr_calls = False) - - # note that a new win_sel is returned, w/ window columns appended - col_expr, win_cols, win_sel = __data.track_call_windows( - new_call, - sel.columns, - window_cte = win_sel - ) - - if isinstance(col_expr, 
sql.base.ImmutableColumnCollection): - conds.extend(col_expr) - else: - conds.append(col_expr) - - windows.extend(win_cols) - - else: - conds.append(arg) - - bool_clause = sql.and_(*conds) - - # first cte, windows ---- - if len(windows): - - win_alias = win_sel.alias() - - # move non-window functions to refer to win_sel clause (not the innermost) --- - bool_clause = sql.util.ClauseAdapter(win_alias) \ - .traverse(bool_clause) - - orig_cols = [win_alias.columns[k] for k in sel.columns.keys()] - else: - orig_cols = [sel] - - # create second cte ---- - filt_sel = _sql_select(orig_cols).where(bool_clause) - return __data.append_op(filt_sel) - - -@mutate.register(LazyTbl) -def _mutate(__data, *args, **kwargs): - # TODO: verify it can follow a renaming select - - # track labeled columns in set - if not (len(args) or len(kwargs)): - return __data.append_op(__data.last_op) - - names, sel_out = _mutate_cols(__data, args, kwargs, "Mutate") - return __data.append_op(sel_out) - - -def _sql_upsert_columns(sel, new_columns: "list[base.Label | base.Column]"): - orig_cols = lift_inner_cols(sel) - replaced = {**orig_cols} - - for new_col in new_columns: - replaced[new_col.name] = new_col - return _sql_with_only_columns(sel, list(replaced.values())) - - -def _select_mutate_result(src_sel, expr_result): - dst_alias = src_sel.alias() - src_columns = set(lift_inner_cols(src_sel)) - replacer = SqlLabelReplacer(set(src_columns), dst_alias.columns) - - if isinstance(expr_result, sql.base.ImmutableColumnCollection): - replaced_cols = list(map(replacer, expr_result)) - orig_cols = expr_result - #elif isinstance(expr_result, None): - # pass - else: - replaced_cols = [replacer(expr_result)] - orig_cols = [expr_result] - - if replacer.applied: - return _sql_upsert_columns(dst_alias.select(), replaced_cols) - - return _sql_upsert_columns(src_sel, orig_cols) - - -def _eval_expr_arg(__data, sel, func, verb_name, window=True): - inner_cols = lift_inner_cols(sel) - - # case 1: simple names ---- - 
simple_name = simple_varname(func) - if simple_name is not None: - return inner_cols[simple_name] - - # case 2: across ---- - _require_across(func, verb_name) - - cols_result = _eval_with_context(__data, window, inner_cols, func) - - # TODO: remove or raise a more informative error - assert isinstance(cols_result, sql.base.ImmutableColumnCollection), type(cols_result) - - return cols_result - - -def _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=True): - inner_cols = lift_inner_cols(sel) - - expr_shaped = __data.shape_call(func, window, verb_name = verb_name, arg_name = new_name) - new_col, windows, _ = __data.track_call_windows(expr_shaped, inner_cols) - - if isinstance(new_col, sql.base.ImmutableColumnCollection): - raise TyepError( - f"{verb_name} named arguments must return a single column, but `{k}` " - "returned multiple columns." - ) - - return new_col.label(new_name) - - -def _mutate_cols(__data, args, kwargs, verb_name): - result_names = {} # used as ordered set - sel = __data.last_select - - for ii, func in enumerate(args): - cols_result = _eval_expr_arg(__data, sel, func, verb_name) - - # replace any labels that require a subquery ---- - sel = _select_mutate_result(sel, cols_result) - - if isinstance(cols_result, sql.base.ImmutableColumnCollection): - result_names.update({k: True for k in cols_result.keys()}) - else: - result_names[cols_result.name] = True - - - for new_name, func in kwargs.items(): - labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name) - - sel = _select_mutate_result(sel, labeled) - result_names[new_name] = True - - - return list(result_names), sel - - -@transmute.register(LazyTbl) -def _transmute(__data, *args, **kwargs): - # will use mutate, then select some cols - result_names, sel = _mutate_cols(__data, args, kwargs, "Transmute") - - # transmute keeps grouping cols, and any defined in kwargs - missing = [x for x in __data.group_by if x not in result_names] - cols_to_keep = [*missing, *result_names] - - 
columns = lift_inner_cols(sel) - sel_stripped = sel.with_only_columns([columns[k] for k in cols_to_keep]) - - return __data.append_op(sel_stripped) - - -@arrange.register(LazyTbl) -def _arrange(__data, *args): - # Note that SQL databases often do not subquery order by clauses. Arrange - # sets order_by on the backend, so it can set order by in over elements, - # and handle when new columns are named the same as order by vars. - # see: https://dba.stackexchange.com/q/82930 - - last_sel = __data.last_select - cols = lift_inner_cols(last_sel) - - # TODO: implement across in arrange - #exprs, _ = _mutate_cols(__data, args, kwargs, "Arrange", arrange_clause=True) - - new_calls = [] - for ii, expr in enumerate(args): - if callable(expr): - - res = __data.shape_call( - expr, window = False, - verb_name = "Arrange", arg_name = ii - ) - - else: - res = expr - - new_calls.append(res) - - sort_cols = _create_order_by_clause(cols, *new_calls) - - order_by = __data.order_by + tuple(new_calls) - return __data.append_op(last_sel.order_by(*sort_cols), order_by = order_by) - - -# TODO: consolidate / pull expr handling funcs into own file? -def _create_order_by_clause(columns, *args): - sort_cols = [] - for arg in args: - # simple named column - if isinstance(arg, str): - sort_cols.append(columns[arg]) - # an expression - elif callable(arg): - # handle special case where -_.colname -> colname DESC - f, asc = _call_strip_ascending(arg) - col_op = f(columns) if asc else f(columns).desc() - #col_op = arg(columns) - sort_cols.append(col_op) - else: - raise NotImplementedError("Must be string or callable") - - return sort_cols - - - -@count.register(LazyTbl) -def _count(__data, *args, sort = False, wt = None, **kwargs): - # TODO: if already col named n, use name nn, etc.. 
get logic from tidy.py - if wt is not None: - raise NotImplementedError("TODO") - - res_name = "n" - # similar to filter verb, we need two select statements, - # an inner one for derived cols, and outer to group by them - - # inner select ---- - # holds any mutation style columns - #arg_names = [] - #for arg in args: - # name = simple_varname(arg) - # if name is None: - # raise NotImplementedError( - # "Count positional arguments must be single column name. " - # "Use a named argument to count using complex expressions." - # ) - # arg_names.append(name) - - result_names, sel_inner = _mutate_cols(__data, args, kwargs, "Count") - - # remove unnecessary select, if we're operating on a table ---- - if set(lift_inner_cols(sel_inner)) == set(lift_inner_cols(__data.last_select)): - sel_inner = __data.last_op - - # create outer select ---- - # holds selected columns and tally (n) - sel_inner_cte = sel_inner.alias() - inner_cols = sel_inner_cte.columns - - # apply any group vars from a group_by verb call first - missing = [k for k in __data.group_by if k not in result_names] - - all_group_names = ordered_union(__data.group_by, result_names) - outer_group_cols = [inner_cols[k] for k in all_group_names] - - # holds the actual count (e.g. 
n) - count_col = sql.functions.count().label(res_name) - - sel_outer = _sql_select([*outer_group_cols, count_col]) \ - .select_from(sel_inner_cte) \ - .group_by(*outer_group_cols) - - # count is like summarize, so removes order_by - return __data.append_op( - sel_outer.order_by(count_col.desc()), - order_by = tuple() - ) - - -@add_count.register(LazyTbl) -def _add_count(__data, *args, wt = None, sort = False, **kwargs): - counts = count(__data, *args, wt = wt, sort = sort, **kwargs) - by = list(c.name for c in counts.last_select.inner_columns)[:-1] - - return inner_join(__data, counts, by = by) +from ..utils import lift_inner_cols, _sql_with_only_columns +from ..backend import LazyTbl, get_single_from @summarize.register(LazyTbl) @@ -1033,363 +130,3 @@ def validate_references(arg_name, expr, verb_name): sel = _sql_upsert_columns(sel, [labeled]) return list(final_cols), _sql_with_only_columns(sel, list(final_cols.values())) - - -@group_by.register(LazyTbl) -def _group_by(__data, *args, add = False, **kwargs): - if not (args or kwargs): - return __data.copy() - - group_names, sel = _mutate_cols(__data, args, kwargs, "Group by") - - if None in group_names: - raise NotImplementedError("Complex, unnamed expressions not supported in sql group_by") - - # check whether we can just use underlying table ---- - new_cols = lift_inner_cols(sel) - if set(new_cols).issubset(set(__data.last_op.columns)): - sel = __data.last_op - - if add: - group_names = ordered_union(__data.group_by, group_names) - - return __data.append_op(sel, group_by = tuple(group_names)) - - -@ungroup.register(LazyTbl) -def _ungroup(__data): - return __data.copy(group_by = tuple()) - - -@case_when.register(sql.base.ImmutableColumnCollection) -def _case_when(__data, cases): - # TODO: will need listener to enter case statements, to handle when they use windows - if isinstance(cases, Call): - cases = cases(__data) - - whens = [] - case_items = list(cases.items()) - n_items = len(case_items) - - else_val = None 
- for ii, (expr, val) in enumerate(case_items): - # handle where val is a column expr - if callable(val): - val = val(__data) - - # handle when expressions - if ii+1 == n_items and expr is True: - else_val = val - elif callable(expr): - whens.append((expr(__data), val)) - else: - whens.append((expr, val)) - - return sql.case(whens, else_ = else_val) - - -# Join ------------------------------------------------------------------------ - -from collections.abc import Mapping - -def _joined_cols(left_cols, right_cols, on_keys, how, suffix = ("_x", "_y")): - """Return labeled columns, according to selection rules for joins. - - Rules: - 1. For join keys, keep left table's column - 2. When keys have the same labels, add suffix - """ - - # TODO: remove sets, so uses stable ordering - # when left and right cols have same name, suffix with _x / _y - keep_right = set(right_cols.keys()) - set(on_keys.values()) - shared_labs = set(left_cols.keys()).intersection(keep_right) - - right_cols_no_keys = {k: right_cols[k] for k in keep_right} - - # for an outer join, have key columns coalesce values - - left_cols = {**left_cols} - if how == "full": - for lk, rk in on_keys.items(): - col = sql.functions.coalesce(left_cols[lk], right_cols[rk]) - left_cols[lk] = col.label(lk) - elif how == "right": - for lk, rk in on_keys.items(): - # Make left key columns actually be right ones (which contain left + extra) - left_cols[lk] = right_cols[rk].label(lk) - - - # create labels ---- - l_labs = _relabeled_cols(left_cols, shared_labs, suffix[0]) - r_labs = _relabeled_cols(right_cols_no_keys, shared_labs, suffix[1]) - - return l_labs + r_labs - - - -def _relabeled_cols(columns, keys, suffix): - # add a suffix to all columns with names in keys - cols = [] - for k, v in columns.items(): - new_col = v.label(k + str(suffix)) if k in keys else v - cols.append(new_col) - return cols - - -@join.register(LazyTbl) -def _join(left, right, on = None, *args, by = None, how = "inner", sql_on = None): - 
_raise_if_args(args) - - if on is None and by is not None: - on = by - - # Needs to be on the table, not the select - left_sel = left.last_op.alias() - right_sel = right.last_op.alias() - - # handle arguments ---- - on = _validate_join_arg_on(on, sql_on) - how = _validate_join_arg_how(how) - - # for equality join used to combine keys into single column - consolidate_keys = on if sql_on is None else {} - - if how == "right": - # switch joins, since sqlalchemy doesn't have right join arg - # see https://stackoverflow.com/q/11400307/1144523 - left_sel, right_sel = right_sel, left_sel - on = {v:k for k,v in on.items()} - - # create join conditions ---- - bool_clause = _create_join_conds(left_sel, right_sel, on) - - # create join ---- - join = left_sel.join( - right_sel, - onclause = bool_clause, - isouter = how != "inner", - full = how == "full" - ) - - # if right join, set selects back - if how == "right": - left_sel, right_sel = right_sel, left_sel - on = {v:k for k,v in on.items()} - - # note, shared_keys assumes on is a mapping... 
- # TODO: shared_keys appears to be for when on is not specified, but was unused - #shared_keys = [k for k,v in on.items() if k == v] - labeled_cols = _joined_cols( - left_sel.columns, - right_sel.columns, - on_keys = consolidate_keys, - how = how - ) - - sel = _sql_select(labeled_cols).select_from(join) - return left.append_op(sel, order_by = tuple()) - - -@semi_join.register(LazyTbl) -def _semi_join(left, right = None, on = None, *args, by = None, sql_on = None): - if on is None and by is not None: - on = by - - _raise_if_args(args) - - left_sel = left.last_op.alias() - right_sel = right.last_op.alias() - - # handle arguments ---- - on = _validate_join_arg_on(on, sql_on, left_sel, right_sel) - - # create join conditions ---- - bool_clause = _create_join_conds(left_sel, right_sel, on) - - # create inner join ---- - exists_clause = _sql_select([sql.literal(1)]) \ - .select_from(right_sel) \ - .where(bool_clause) - - # only keep left hand select's columns ---- - sel = _sql_select(left_sel.columns) \ - .select_from(left_sel) \ - .where(sql.exists(exists_clause)) - - return left.append_op(sel, order_by = tuple()) - - -@anti_join.register(LazyTbl) -def _anti_join(left, right = None, on = None, *args, by = None, sql_on = None): - if on is None and by is not None: - on = by - - _raise_if_args(args) - - left_sel = left.last_op.alias() - right_sel = right.last_op.alias() - - # handle arguments ---- - on = _validate_join_arg_on(on, sql_on, left, right) - - # create join conditions ---- - bool_clause = _create_join_conds(left_sel, right_sel, on) - - # create inner join ---- - #not_exists = ~sql.exists([1], from_obj = right_sel).where(bool_clause) - exists_clause = _sql_select([sql.literal(1)]) \ - .select_from(right_sel) \ - .where(bool_clause) - - sel = left_sel.select().where(~sql.exists(exists_clause)) - - return left.append_op(sel, order_by = tuple()) - -def _raise_if_args(args): - if len(args): - raise NotImplemented("*args is reserved for future arguments (e.g. 
suffix)") - -def _validate_join_arg_on(on, sql_on = None, lhs = None, rhs = None): - # handle sql on case - if sql_on is not None: - if on is not None: - raise ValueError("Cannot specify both on and sql_on") - - return sql_on - - # handle general cases - if on is None: - # TODO: currently, we check for lhs and rhs tables to indicate whether - # a verb supports inferring columns. Otherwise, raise an error. - if lhs is not None and rhs is not None: - # TODO: consolidate with duplicate logic in pandas verb code - warnings.warn( - "No on column passed to join. " - "Inferring join columns instead using shared column names." - ) - - on_cols = list(set(lhs.columns.keys()).intersection(set(rhs.columns.keys()))) - - if not on_cols: - raise ValueError( - "No join column specified, or shared column names in join." - ) - - # trivial dict mapping shared names to themselves - warnings.warn("Detected shared columns: %s" % on_cols) - on = dict(zip(on_cols, on_cols)) - - else: - raise NotImplementedError("on arg currently cannot be None (default) for SQL") - elif isinstance(on, str): - on = {on: on} - elif isinstance(on, (list, tuple)): - on = dict(zip(on, on)) - - - if not isinstance(on, Mapping): - raise TypeError("on must be a Mapping (e.g. 
dict)") - - return on - -def _validate_join_arg_how(how): - how_options = ("inner", "left", "right", "full") - if how not in how_options: - raise ValueError("how argument needs to be one of %s" %how_options) - - return how - -def _create_join_conds(left_sel, right_sel, on): - left_cols = left_sel.columns #lift_inner_cols(left_sel) - right_cols = right_sel.columns #lift_inner_cols(right_sel) - - if callable(on): - # callable, like with sql_on arg - conds = [on(left_cols, right_cols)] - else: - # dict-like of form {left: right} - conds = [] - for l, r in on.items(): - col_expr = left_cols[l] == right_cols[r] - conds.append(col_expr) - - return sql.and_(*conds) - - -# Head ------------------------------------------------------------------------ - -@head.register(LazyTbl) -def _head(__data, n = 5): - sel = __data.last_select - - return __data.append_op(sel.limit(n)) - - -# Rename ---------------------------------------------------------------------- - -@rename.register(LazyTbl) -def _rename(__data, **kwargs): - sel = __data.last_select - columns = lift_inner_cols(sel) - - # old_keys uses dict as ordered set - old_to_new = {simple_varname(v):k for k,v in kwargs.items()} - - if None in old_to_new: - raise KeyError("positional arguments must be simple column, " - "e.g. 
_.colname or _['colname']" - ) - - labs = [c.label(old_to_new[k]) if k in old_to_new else c for k,c in columns.items()] - - new_sel = sel.with_only_columns(labs) - - missing_groups, group_keys = _select_group_renames(old_to_new, __data.group_by) - - return __data.append_op(new_sel, group_by=group_keys) - - -# Distinct -------------------------------------------------------------------- - -@distinct.register(LazyTbl) -def _distinct(__data, *args, _keep_all = False, **kwargs): - if (args or kwargs) and _keep_all: - raise NotImplementedError("Distinct with variables specified in sql requires _keep_all = False") - - inner_sel = mutate(__data, **kwargs).last_select if kwargs else __data.last_select - - # TODO: this is copied from the df distinct version - # cols dict below is used as ordered set - cols = _var_select_simple(args) - cols.update(kwargs) - - # use all columns by default - if not cols: - cols = {k: True for k in lift_inner_cols(inner_sel).keys()} - - final_names = {**{k: True for k in __data.group_by}, **cols} - - if not len(inner_sel._order_by_clause): - # select distinct has to include any columns in the order by clause, - # so can only safely modify existing statement when there's no order by - sel_cols = lift_inner_cols(inner_sel) - distinct_cols = [sel_cols[k] for k in final_names] - sel = inner_sel.with_only_columns(distinct_cols).distinct() - else: - # fallback to cte - cte = inner_sel.alias() - distinct_cols = [cte.columns[k] for k in final_names] - sel = _sql_select(distinct_cols).select_from(cte).distinct() - - return __data.append_op(sel) - - -# if_else --------------------------------------------------------------------- - -@if_else.register(sql.elements.ColumnElement) -def _if_else(cond, true_vals, false_vals): - whens = [(cond, true_vals)] - return sql.case(whens, else_ = false_vals) - - diff --git a/siuba/sql/verbs/transmute.py b/siuba/sql/verbs/transmute.py deleted file mode 100644 index 8dcbef7b..00000000 --- a/siuba/sql/verbs/transmute.py 
+++ /dev/null @@ -1,1395 +0,0 @@ -""" -Implements LazyTbl to represent tables of SQL data, and registers it on verbs. - -This module is responsible for the handling of the "table" side of things, while -translate.py handles translating column operations. - - -""" - -import warnings - -from siuba.dply.verbs import ( - show_query, collect, - simple_varname, - select, - mutate, - transmute, - filter, - arrange, _call_strip_ascending, - summarize, - count, add_count, - group_by, ungroup, - case_when, - join, left_join, right_join, inner_join, semi_join, anti_join, - head, - rename, - distinct, - if_else, - _select_group_renames, - _var_select_simple - ) - -from siuba.dply.tidyselect import VarList, var_select - -from .translate import CustomOverClause, SqlColumn, SqlColumnAgg -from .utils import ( - get_dialect_translator, - _FixedSqlDatabase, - _is_dialect_duckdb, - _sql_select, - _sql_column_collection, - _sql_add_columns, - _sql_with_only_columns, - _sql_simplify_select, - MockConnection -) - -from sqlalchemy import sql -import sqlalchemy -from siuba.siu import Call, Lazy, FunctionLookupError, singledispatch2 -# TODO: currently needed for select, but can we remove pandas? -from pandas import Series -from functools import singledispatch - -from sqlalchemy.sql import schema - -from siuba.dply.across import _require_across, _set_data_context, _eval_with_context - -# TODO: -# - distinct -# - annotate functions using sel.prefix_with("\n/**/\n") ? - - -# Helpers --------------------------------------------------------------------- - -class SqlFunctionLookupError(FunctionLookupError): pass - - -class CallListener: - """Generic listener. Each exit is called on a node's copy.""" - def enter(self, node): - args, kwargs = node.map_subcalls(self.enter) - - return self.exit(node.__class__(node.func, *args, **kwargs)) - - def exit(self, node): - return node - - -class WindowReplacer(CallListener): - """Call tree listener. 
- - Produces 2 important behaviors via the enter method: - - returns evaluated sql call expression, with labels on all window expressions. - - stores all labeled window expressions via the windows property. - - TODO: could replace with a sqlalchemy transformer - """ - - def __init__(self, columns, group_by, order_by, window_cte = None): - self.columns = columns - self.group_by = group_by - self.order_by = order_by - self.window_cte = window_cte - self.windows = [] - - def exit(self, node): - col_expr = node(self.columns) - - if not isinstance(col_expr, sql.elements.ClauseElement): - return col_expr - - over_clauses = [x for x in self._get_over_clauses(col_expr) if isinstance(x, CustomOverClause)] - - # put groupings and orderings onto custom over clauses - for over in over_clauses: - # TODO: shouldn't mutate these over clauses - group_by = sql.elements.ClauseList( - *[self.columns[name] for name in self.group_by] - ) - order_by = sql.elements.ClauseList( - *_create_order_by_clause(self.columns, *self.order_by) - ) - - over.set_over(group_by, order_by) - - if len(over_clauses) and self.window_cte is not None: - # custom name, or parameters like "%(...)s" may nest and break psycopg2 - # with columns you can set a key to fix this, but it doesn't seem to - # be an option with labels - name = self._get_unique_name('win', lift_inner_cols(self.window_cte)) - label = col_expr.label(name) - - # put into CTE, and return its resulting column, so that subsequent - # operations will refer to the window column on window_cte. 
Note that - # the operations will use the actual column, so may need to use the - # ClauseAdaptor to make it a reference to the label - self.window_cte = _sql_add_columns(self.window_cte, [label]) - win_col = lift_inner_cols(self.window_cte).values()[-1] - self.windows.append(win_col) - - return win_col - - return col_expr - - @staticmethod - def _get_unique_name(prefix, columns): - column_names = set(columns.keys()) - - i = 1 - name = prefix + str(i) - while name in column_names: - i += 1 - name = prefix + str(i) - - - return name - - @staticmethod - def _get_over_clauses(clause): - windows = [] - append_win = lambda col: windows.append(col) - - sql.util.visitors.traverse(clause, {}, {"over": append_win}) - - return windows - - -class SqlLabelReplacer: - """Create a visitor to replace source labels with destination. - - Note that this is meant to be used with sqlalchemy visitors. - """ - - def __init__(self, src_columns, dst_columns): - self.src_columns = src_columns - self.src_labels = set([x for x in src_columns if isinstance(x, sql.elements.Label)]) - self.dst_columns = dst_columns - self.applied = False - - def __call__(self, clause): - return sql.util.visitors.replacement_traverse(clause, {}, self.visit) - - def visit(self, el): - from sqlalchemy.sql.elements import ColumnClause, Label, ClauseElement, TypeClause - from sqlalchemy.sql.schema import Column - - if isinstance(el, TypeClause): - # TODO: for some reason this type throws an error if unguarded - return None - - if isinstance(el, ClauseElement): - if el in self.src_labels: - self.applied = True - return self.dst_columns[el.name] - elif el in self.src_columns: - return self.dst_columns[el.name] - - # TODO: should we create a subquery if the user passed raw text? 
- #elif isinstance(el, ColumnClause) and not isinstance(el, Column): - # # Raw SQL, which will need a subquery, but not substitution - # if el.key != "*": - # self.applied = True - - return None - - -#def track_call_windows(call, columns, group_by, order_by, window_cte = None): -# listener = WindowReplacer(columns, group_by, order_by, window_cte) -# col = listener.enter(call) -# return col, listener.windows, listener.window_cte - - -def track_call_windows(call, columns, group_by, order_by, window_cte = None): - col_expr = call(columns) - - crnt_group_by = sql.elements.ClauseList( - *[columns[name] for name in group_by] - ) - crnt_order_by = sql.elements.ClauseList( - *_create_order_by_clause(columns, *order_by) - ) - return replace_call_windows(col_expr, crnt_group_by, crnt_order_by, window_cte) - - - -@singledispatch -def replace_call_windows(col_expr, group_by, order_by, window_cte = None): - raise TypeError(str(type(col_expr))) - - -@replace_call_windows.register(sql.base.ImmutableColumnCollection) -def _(col_expr, group_by, order_by, window_cte = None): - all_over_clauses = [] - for col in col_expr: - _, over_clauses, window_cte = replace_call_windows( - col, - group_by, - order_by, - window_cte - ) - all_over_clauses.extend(over_clauses) - - return col_expr, all_over_clauses, window_cte - - -@replace_call_windows.register(sql.elements.ClauseElement) -def _(col_expr, group_by, order_by, window_cte = None): - - over_clauses = WindowReplacer._get_over_clauses(col_expr) - - for over in over_clauses: - # TODO: shouldn't mutate these over clauses - over.set_over(group_by, order_by) - - if len(over_clauses) and window_cte is not None: - # custom name, or parameters like "%(...)s" may nest and break psycopg2 - # with columns you can set a key to fix this, but it doesn't seem to - # be an option with labels - name = WindowReplacer._get_unique_name('win', lift_inner_cols(window_cte)) - label = col_expr.label(name) - - # put into CTE, and return its resulting column, so 
that subsequent - # operations will refer to the window column on window_cte. Note that - # the operations will use the actual column, so may need to use the - # ClauseAdaptor to make it a reference to the label - window_cte = _sql_add_columns(window_cte, [label]) - win_col = lift_inner_cols(window_cte).values()[-1] - - return win_col, over_clauses, window_cte - - return col_expr, over_clauses, window_cte - -def get_single_from(sel): - froms = sel.froms - - n_froms = len(froms) - if n_froms != 1: - raise ValueError( - f"Expected a single table in the from clause, but found {n_froms}" - ) - - return froms[0] - -def lift_inner_cols(tbl): - cols = list(tbl.inner_columns) - - return _sql_column_collection(cols) - -def col_expr_requires_cte(call, sel, is_mutate = False): - """Return whether a variable assignment needs a CTE""" - - call_vars = set(call.op_vars(attr_calls = False)) - - sel_labs = get_inner_labels(sel) - - # I use the acronym fwg sol (frog soul) to remember sql clause eval order - # from, where, group by, select, order by, limit - # group clause evaluated before select clause, so not issue for mutate - group_needs_cte = not is_mutate and len(sel._group_by_clause) - - return ( group_needs_cte - # TODO: detect when a new var in mutate conflicts w/ order by - #or len(sel._order_by_clause) - or not sel_labs.isdisjoint(call_vars) - ) - -def get_inner_labels(sel): - columns = lift_inner_cols(sel) - sel_labs = set(k for k,v in columns.items() if isinstance(v, sql.elements.Label)) - return sel_labs - -def get_missing_columns(call, columns): - missing_cols = set(call.op_vars(attr_calls = False)) - set(columns.keys()) - return missing_cols - -def compile_el(tbl, el): - compiled = el.compile( - dialect = tbl.source.dialect, - compile_kwargs = {"literal_binds": True} - ) - return compiled - -# Misc utilities -------------------------------------------------------------- - -def ordered_union(x, y): - dx = {el: True for el in x} - dy = {el: True for el in y} - - return 
tuple({**dx, **dy}) - - -def _warn_missing(missing_groups): - warnings.warn(f"Adding missing grouping variables: {missing_groups}") - - -# Table ----------------------------------------------------------------------- - -class LazyTbl: - def __init__( - self, source, tbl, columns = None, - ops = None, group_by = tuple(), order_by = tuple(), - translator = None - ): - """Create a representation of a SQL table. - - Args: - source: a sqlalchemy.Engine or sqlalchemy.Connection instance. - tbl: table of form 'schema_name.table_name', 'table_name', or sqlalchemy.Table. - columns: if specified, a listlike of column names. - - Examples - -------- - - :: - from sqlalchemy import create_engine - from siuba.data import mtcars - - # create database and table - engine = create_engine("sqlite:///:memory:") - mtcars.to_sql('mtcars', engine) - - tbl_mtcars = LazyTbl(engine, 'mtcars') - - """ - - # connection and dialect specific functions - self.source = sqlalchemy.create_engine(source) if isinstance(source, str) else source - - # get dialect name - dialect = self.source.dialect.name - self.translator = get_dialect_translator(dialect) - - self.tbl = self._create_table(tbl, columns, self.source) - - # important states the query can be in (e.g. 
grouped) - self.ops = [self.tbl] if ops is None else ops - - self.group_by = group_by - self.order_by = order_by - - - def append_op(self, op, **kwargs): - cpy = self.copy(**kwargs) - cpy.ops = cpy.ops + [op] - return cpy - - def copy(self, **kwargs): - return self.__class__(**{**self.__dict__, **kwargs}) - - def shape_call( - self, - call, window = True, str_accessors = False, - verb_name = None, arg_name = None, - ): - return self.translator.shape_call(call, window, str_accessors, verb_name, arg_name) - - def track_call_windows(self, call, columns = None, window_cte = None): - """Returns tuple of (new column expression, list of window exprs)""" - - columns = self.last_op.columns if columns is None else columns - return track_call_windows(call, columns, self.group_by, self.order_by, window_cte) - - def get_ordered_col_names(self): - """Return columns from current select, with grouping columns first.""" - ungrouped = [k for k in self.last_op.columns.keys() if k not in self.group_by] - return list(self.group_by) + ungrouped - - #def label_breaks_order_by(self, name): - # """Returns True if a new column label would break the order by vars.""" - - # # TODO: arrange currently allows literals, which breaks this. it seems - # # better to only allow calls in arrange. - # order_by_vars = {c.op_vars(attr_calls=False) for c in self.order_by} - - - - - @property - def last_op(self) -> "sql.Table | sql.Select": - last_op = self.ops[-1] - - if last_op is None: - raise TypeError() - - return last_op - - @property - def last_select(self): - last_op = self.last_op - if not isinstance(last_op, sql.selectable.SelectBase): - return last_op.select() - - return last_op - - @staticmethod - def _create_table(tbl, columns = None, source = None): - """Return a sqlalchemy.Table, autoloading column info if needed. - - Arguments: - tbl: a sqlalchemy.Table or string of form 'table_name' or 'schema_name.table_name'. - columns: a tuple of column names for the table. Overrides source argument. 
- source: a sqlalchemy engine, used to autoload columns. - - """ - if isinstance(tbl, sql.selectable.FromClause): - return tbl - - if not isinstance(tbl, str): - raise ValueError("tbl must be a sqlalchemy Table or string, but was %s" %type(tbl)) - - if columns is None and source is None: - raise ValueError("One of columns or source must be specified") - - schema, table_name = tbl.split('.') if '.' in tbl else [None, tbl] - - columns = map(sqlalchemy.Column, columns) if columns is not None else tuple() - - # TODO: pybigquery uses schema to mean project_id, so we cannot use - # siuba's classic breakdown "{schema}.{table_name}". Basically - # pybigquery uses "{schema=project_id}.{dataset_dot_table_name}" in its internal - # logic. An important side effect is that bigquery errors for - # `dataset`.`table`, but not `dataset.table`. - if source and source.dialect.name == "bigquery": - table_name = tbl - schema = None - - return sqlalchemy.Table( - table_name, - sqlalchemy.MetaData(bind = source), - *columns, - schema = schema, - autoload_with = source if not columns else None - ) - - def _get_preview(self): - # need to make prev op a cte, so we don't override any previous limit - new_sel = self.last_select.limit(5) - tbl_small = self.append_op(new_sel) - return collect(tbl_small) - - def __repr__(self): - template = ( - "# Source: lazy query\n" - "# DB Conn: {}\n" - "# Preview:\n{}\n" - "# .. may have more rows" - ) - - return template.format(repr(self.source.engine), repr(self._get_preview())) - - def _repr_html_(self): - template = ( - "
" - "
"
-                "# Source: lazy query\n"
-                "# DB Conn: {}\n"
-                "# Preview:\n"
-                "
" - "{}" - "

# .. may have more rows

" - "
" - ) - - data = self._get_preview() - - # _repr_html_ can not exist or return None, to signify that repr should be used - if not hasattr(data, '_repr_html_'): - return None - - html_data = data._repr_html_() - if html_data is None: - return None - - return template.format(self.source.engine, html_data) - - -def _repr_grouped_df_html_(self): - return "

(grouped data frame)

" + self._selected_obj._repr_html_() + "
" - - - -# Main Funcs -# ============================================================================= - -# sql raw -------------- - -sql_raw = sql.literal_column - -# show query ----------- - -@show_query.register(LazyTbl) -def _show_query(tbl, simplify = False, return_table = True): - #query = tbl.last_op #if not simplify else - compile_query = lambda query: query.compile( - dialect = tbl.source.dialect, - compile_kwargs = {"literal_binds": True} - ) - - - if simplify: - # try to strip table names and labels where unnecessary - simple_sel = _sql_simplify_select(tbl.last_select) - - explained = compile_query(simple_sel) - else: - # use a much more verbose query - explained = compile_query(tbl.last_select) - - if return_table: - print(str(explained)) - return tbl - - return str(explained) - - - -# collect ---------- - -@collect.register(LazyTbl) -def _collect(__data, as_df = True): - # TODO: maybe remove as_df options, always return dataframe - - if isinstance(__data.source, MockConnection): - # a mock sqlalchemy is being used to show_query, and echo queries. - # it doesn't return a result object or have a context handler, so - # we need to bail out early - return - - # compile query ---- - - if _is_dialect_duckdb(__data.source): - # TODO: can be removed once next release of duckdb fixes: - # https://github.com/duckdb/duckdb/issues/2972 - query = __data.last_select - compiled = query.compile( - dialect = __data.source.dialect, - compile_kwargs = {"literal_binds": True} - ) - else: - compiled = __data.last_select - - # execute query ---- - - with __data.source.connect() as conn: - if as_df: - sql_db = _FixedSqlDatabase(conn) - - if _is_dialect_duckdb(__data.source): - # TODO: pandas read_sql is very slow with duckdb. - # see https://github.com/pandas-dev/pandas/issues/45678 - # going to handle here for now. address once LazyTbl gets - # subclassed per backend. 
- duckdb_con = conn.connection.c - return duckdb_con.query(str(compiled)).to_df() - else: - # - return sql_db.read_sql(compiled) - - return conn.execute(compiled) - - -@select.register(LazyTbl) -def _select(__data, *args, **kwargs): - # see https://stackoverflow.com/questions/25914329/rearrange-columns-in-sqlalchemy-select-object - if kwargs: - raise NotImplementedError( - "Using kwargs in select not currently supported. " - "Use _.newname == _.oldname instead" - ) - last_sel = __data.last_select - columns = {c.key: c for c in last_sel.inner_columns} - - # same as for DataFrame - colnames = Series(list(columns)) - vl = VarList() - evaluated = (arg(vl) if callable(arg) else arg for arg in args) - od = var_select(colnames, *evaluated) - - missing_groups, group_keys = _select_group_renames(od, __data.group_by) - - if missing_groups: - _warn_missing(missing_groups) - - final_od = {**{k: None for k in missing_groups}, **od} - - col_list = [] - for k,v in final_od.items(): - col = columns[k] - col_list.append(col if v is None else col.label(v)) - - return __data.append_op( - last_sel.with_only_columns(col_list), - group_by = group_keys - ) - - - -@filter.register(LazyTbl) -def _filter(__data, *args): - # Note: currently always produces 2 additional select statements, - # 1 for window/aggs, and 1 for the where clause - - sel = __data.last_op.alias() # original select - win_sel = sel.select() - - conds = [] - windows = [] - with _set_data_context(__data, window=True): - for ii, arg in enumerate(args): - - if isinstance(arg, Call): - new_call = __data.shape_call(arg, verb_name = "Filter", arg_name = ii) - #var_cols = new_call.op_vars(attr_calls = False) - - # note that a new win_sel is returned, w/ window columns appended - col_expr, win_cols, win_sel = __data.track_call_windows( - new_call, - sel.columns, - window_cte = win_sel - ) - - if isinstance(col_expr, sql.base.ImmutableColumnCollection): - conds.extend(col_expr) - else: - conds.append(col_expr) - - 
windows.extend(win_cols) - - else: - conds.append(arg) - - bool_clause = sql.and_(*conds) - - # first cte, windows ---- - if len(windows): - - win_alias = win_sel.alias() - - # move non-window functions to refer to win_sel clause (not the innermost) --- - bool_clause = sql.util.ClauseAdapter(win_alias) \ - .traverse(bool_clause) - - orig_cols = [win_alias.columns[k] for k in sel.columns.keys()] - else: - orig_cols = [sel] - - # create second cte ---- - filt_sel = _sql_select(orig_cols).where(bool_clause) - return __data.append_op(filt_sel) - - -@mutate.register(LazyTbl) -def _mutate(__data, *args, **kwargs): - # TODO: verify it can follow a renaming select - - # track labeled columns in set - if not (len(args) or len(kwargs)): - return __data.append_op(__data.last_op) - - names, sel_out = _mutate_cols(__data, args, kwargs, "Mutate") - return __data.append_op(sel_out) - - -def _sql_upsert_columns(sel, new_columns: "list[base.Label | base.Column]"): - orig_cols = lift_inner_cols(sel) - replaced = {**orig_cols} - - for new_col in new_columns: - replaced[new_col.name] = new_col - return _sql_with_only_columns(sel, list(replaced.values())) - - -def _select_mutate_result(src_sel, expr_result): - dst_alias = src_sel.alias() - src_columns = set(lift_inner_cols(src_sel)) - replacer = SqlLabelReplacer(set(src_columns), dst_alias.columns) - - if isinstance(expr_result, sql.base.ImmutableColumnCollection): - replaced_cols = list(map(replacer, expr_result)) - orig_cols = expr_result - #elif isinstance(expr_result, None): - # pass - else: - replaced_cols = [replacer(expr_result)] - orig_cols = [expr_result] - - if replacer.applied: - return _sql_upsert_columns(dst_alias.select(), replaced_cols) - - return _sql_upsert_columns(src_sel, orig_cols) - - -def _eval_expr_arg(__data, sel, func, verb_name, window=True): - inner_cols = lift_inner_cols(sel) - - # case 1: simple names ---- - simple_name = simple_varname(func) - if simple_name is not None: - return inner_cols[simple_name] - 
- # case 2: across ---- - _require_across(func, verb_name) - - cols_result = _eval_with_context(__data, window, inner_cols, func) - - # TODO: remove or raise a more informative error - assert isinstance(cols_result, sql.base.ImmutableColumnCollection), type(cols_result) - - return cols_result - - -def _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=True): - inner_cols = lift_inner_cols(sel) - - expr_shaped = __data.shape_call(func, window, verb_name = verb_name, arg_name = new_name) - new_col, windows, _ = __data.track_call_windows(expr_shaped, inner_cols) - - if isinstance(new_col, sql.base.ImmutableColumnCollection): - raise TyepError( - f"{verb_name} named arguments must return a single column, but `{k}` " - "returned multiple columns." - ) - - return new_col.label(new_name) - - -def _mutate_cols(__data, args, kwargs, verb_name): - result_names = {} # used as ordered set - sel = __data.last_select - - for ii, func in enumerate(args): - cols_result = _eval_expr_arg(__data, sel, func, verb_name) - - # replace any labels that require a subquery ---- - sel = _select_mutate_result(sel, cols_result) - - if isinstance(cols_result, sql.base.ImmutableColumnCollection): - result_names.update({k: True for k in cols_result.keys()}) - else: - result_names[cols_result.name] = True - - - for new_name, func in kwargs.items(): - labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name) - - sel = _select_mutate_result(sel, labeled) - result_names[new_name] = True - - - return list(result_names), sel - - -@transmute.register(LazyTbl) -def _transmute(__data, *args, **kwargs): - # will use mutate, then select some cols - result_names, sel = _mutate_cols(__data, args, kwargs, "Transmute") - - # transmute keeps grouping cols, and any defined in kwargs - missing = [x for x in __data.group_by if x not in result_names] - cols_to_keep = [*missing, *result_names] - - columns = lift_inner_cols(sel) - sel_stripped = sel.with_only_columns([columns[k] for k in 
cols_to_keep]) - - return __data.append_op(sel_stripped) - - -@arrange.register(LazyTbl) -def _arrange(__data, *args): - # Note that SQL databases often do not subquery order by clauses. Arrange - # sets order_by on the backend, so it can set order by in over elements, - # and handle when new columns are named the same as order by vars. - # see: https://dba.stackexchange.com/q/82930 - - last_sel = __data.last_select - cols = lift_inner_cols(last_sel) - - # TODO: implement across in arrange - #exprs, _ = _mutate_cols(__data, args, kwargs, "Arrange", arrange_clause=True) - - new_calls = [] - for ii, expr in enumerate(args): - if callable(expr): - - res = __data.shape_call( - expr, window = False, - verb_name = "Arrange", arg_name = ii - ) - - else: - res = expr - - new_calls.append(res) - - sort_cols = _create_order_by_clause(cols, *new_calls) - - order_by = __data.order_by + tuple(new_calls) - return __data.append_op(last_sel.order_by(*sort_cols), order_by = order_by) - - -# TODO: consolidate / pull expr handling funcs into own file? -def _create_order_by_clause(columns, *args): - sort_cols = [] - for arg in args: - # simple named column - if isinstance(arg, str): - sort_cols.append(columns[arg]) - # an expression - elif callable(arg): - # handle special case where -_.colname -> colname DESC - f, asc = _call_strip_ascending(arg) - col_op = f(columns) if asc else f(columns).desc() - #col_op = arg(columns) - sort_cols.append(col_op) - else: - raise NotImplementedError("Must be string or callable") - - return sort_cols - - - -@count.register(LazyTbl) -def _count(__data, *args, sort = False, wt = None, **kwargs): - # TODO: if already col named n, use name nn, etc.. 
get logic from tidy.py - if wt is not None: - raise NotImplementedError("TODO") - - res_name = "n" - # similar to filter verb, we need two select statements, - # an inner one for derived cols, and outer to group by them - - # inner select ---- - # holds any mutation style columns - #arg_names = [] - #for arg in args: - # name = simple_varname(arg) - # if name is None: - # raise NotImplementedError( - # "Count positional arguments must be single column name. " - # "Use a named argument to count using complex expressions." - # ) - # arg_names.append(name) - - result_names, sel_inner = _mutate_cols(__data, args, kwargs, "Count") - - # remove unnecessary select, if we're operating on a table ---- - if set(lift_inner_cols(sel_inner)) == set(lift_inner_cols(__data.last_select)): - sel_inner = __data.last_op - - # create outer select ---- - # holds selected columns and tally (n) - sel_inner_cte = sel_inner.alias() - inner_cols = sel_inner_cte.columns - - # apply any group vars from a group_by verb call first - missing = [k for k in __data.group_by if k not in result_names] - - all_group_names = ordered_union(__data.group_by, result_names) - outer_group_cols = [inner_cols[k] for k in all_group_names] - - # holds the actual count (e.g. 
n) - count_col = sql.functions.count().label(res_name) - - sel_outer = _sql_select([*outer_group_cols, count_col]) \ - .select_from(sel_inner_cte) \ - .group_by(*outer_group_cols) - - # count is like summarize, so removes order_by - return __data.append_op( - sel_outer.order_by(count_col.desc()), - order_by = tuple() - ) - - -@add_count.register(LazyTbl) -def _add_count(__data, *args, wt = None, sort = False, **kwargs): - counts = count(__data, *args, wt = wt, sort = sort, **kwargs) - by = list(c.name for c in counts.last_select.inner_columns)[:-1] - - return inner_join(__data, counts, by = by) - - -@summarize.register(LazyTbl) -def _summarize(__data, *args, **kwargs): - # https://stackoverflow.com/questions/14754994/why-is-sqlalchemy-count-much-slower-than-the-raw-query - - # get query with correct from clause, and maybe unneeded subquery - safe_from = __data.last_select.alias() - result_names, sel = _aggregate_cols(__data, safe_from, args, kwargs, "Summarize") - - # see if we can remove subquery - out_sel = _collapse_select(sel, safe_from) - - from_tbl = get_single_from(out_sel) - group_cols = [from_tbl.columns[k] for k in __data.group_by] - - final_sel = out_sel.group_by(*group_cols) - - new_data = __data.append_op(final_sel, group_by = tuple(), order_by = tuple()) - return new_data - - -def _collapse_select(outer_sel, inner_alias): - # check whether any outer columns reference an inner label ---- - inner_sel = inner_alias.element - - columns = lift_inner_cols(outer_sel) - inner_cols = lift_inner_cols(inner_sel) - - inner_labels = set([ - x.name for x in inner_cols - if isinstance(x, sql.elements.Label) - ]) - - col_requires_cte = set(inner_alias.columns[k] for k in inner_labels) - - bad_refs = [] - - def collect_refs(el): - if el in col_requires_cte: - bad_refs.append(el) - - for col in columns: - sql.util.visitors.traverse(col, {}, {"column": collect_refs}) - - # if possible, remove the outer query ---- - if not (bad_refs or len(inner_sel._group_by_clause)): - 
from sqlalchemy.sql.elements import ColumnClause, Label - - from_obj = get_single_from(inner_sel) - adaptor = sql.util.ClauseAdapter( - from_obj, - adapt_on_names=True, - include_fn=lambda c: isinstance(c, (ColumnClause, Label)) - ) - - new_cols = [] - for col in columns: - if isinstance(col, Label): - res = adaptor.traverse(col.element).label(col.name) - new_cols.append(res) - - else: - new_cols.append(adaptor.traverse(col)) - #new_cols = list(map(adaptor.traverse, columns)) - - return _sql_with_only_columns(inner_sel, new_cols) - - return outer_sel - - -def _aggregate_cols(__data, subquery, args, kwargs, verb_name): - # cases: - # * grouping cols can not be overwritten (in dbplyr they can't be ref'd) - # * no existing labels referred to - can use same select - # * existing labels referred to - need 1 subquery tops - # * groups + summarize columns can replace everything - - def get_label_clauses(clause): - out = [] - sql.util.visitors.traverse(clause, {}, {"label": lambda c: out.append(c)}) - - return out - - def quote_varname(x): - return f"`{x}`" - - def validate_references(arg_name, expr, verb_name): - bad_varnames = get_label_clauses(expr) - repr_names = ", ".join(map(quote_varname, [el.name for el in bad_varnames])) - - if not bad_varnames: - return - - raise NotImplementedError( - f"In SQL, you cannot refer to a column created in the same {verb_name}. " - f"`{arg_name}` refers to columns created earlier: {repr_names}." 
- ) - - sel = subquery.select() - - final_cols = {k: subquery.columns[k] for k in __data.group_by} - - # handle args ---- - for ii, func in enumerate(args): - cols_result = _eval_expr_arg(__data, sel, func, verb_name, window=False) - - for col in cols_result: - validate_references(col.name, col.element, verb_name) - final_cols[col.name] = col - - sel = _sql_upsert_columns(sel, cols_result) - - - # handle kwargs ---- - for new_name, func in kwargs.items(): - labeled = _eval_expr_kwarg(__data, sel, func, new_name, verb_name, window=False) - - validate_references(labeled.name, labeled.element, verb_name) - final_cols[new_name] = labeled - - sel = _sql_upsert_columns(sel, [labeled]) - - return list(final_cols), _sql_with_only_columns(sel, list(final_cols.values())) - - -@group_by.register(LazyTbl) -def _group_by(__data, *args, add = False, **kwargs): - if not (args or kwargs): - return __data.copy() - - group_names, sel = _mutate_cols(__data, args, kwargs, "Group by") - - if None in group_names: - raise NotImplementedError("Complex, unnamed expressions not supported in sql group_by") - - # check whether we can just use underlying table ---- - new_cols = lift_inner_cols(sel) - if set(new_cols).issubset(set(__data.last_op.columns)): - sel = __data.last_op - - if add: - group_names = ordered_union(__data.group_by, group_names) - - return __data.append_op(sel, group_by = tuple(group_names)) - - -@ungroup.register(LazyTbl) -def _ungroup(__data): - return __data.copy(group_by = tuple()) - - -@case_when.register(sql.base.ImmutableColumnCollection) -def _case_when(__data, cases): - # TODO: will need listener to enter case statements, to handle when they use windows - if isinstance(cases, Call): - cases = cases(__data) - - whens = [] - case_items = list(cases.items()) - n_items = len(case_items) - - else_val = None - for ii, (expr, val) in enumerate(case_items): - # handle where val is a column expr - if callable(val): - val = val(__data) - - # handle when expressions - if ii+1 
== n_items and expr is True: - else_val = val - elif callable(expr): - whens.append((expr(__data), val)) - else: - whens.append((expr, val)) - - return sql.case(whens, else_ = else_val) - - -# Join ------------------------------------------------------------------------ - -from collections.abc import Mapping - -def _joined_cols(left_cols, right_cols, on_keys, how, suffix = ("_x", "_y")): - """Return labeled columns, according to selection rules for joins. - - Rules: - 1. For join keys, keep left table's column - 2. When keys have the same labels, add suffix - """ - - # TODO: remove sets, so uses stable ordering - # when left and right cols have same name, suffix with _x / _y - keep_right = set(right_cols.keys()) - set(on_keys.values()) - shared_labs = set(left_cols.keys()).intersection(keep_right) - - right_cols_no_keys = {k: right_cols[k] for k in keep_right} - - # for an outer join, have key columns coalesce values - - left_cols = {**left_cols} - if how == "full": - for lk, rk in on_keys.items(): - col = sql.functions.coalesce(left_cols[lk], right_cols[rk]) - left_cols[lk] = col.label(lk) - elif how == "right": - for lk, rk in on_keys.items(): - # Make left key columns actually be right ones (which contain left + extra) - left_cols[lk] = right_cols[rk].label(lk) - - - # create labels ---- - l_labs = _relabeled_cols(left_cols, shared_labs, suffix[0]) - r_labs = _relabeled_cols(right_cols_no_keys, shared_labs, suffix[1]) - - return l_labs + r_labs - - - -def _relabeled_cols(columns, keys, suffix): - # add a suffix to all columns with names in keys - cols = [] - for k, v in columns.items(): - new_col = v.label(k + str(suffix)) if k in keys else v - cols.append(new_col) - return cols - - -@join.register(LazyTbl) -def _join(left, right, on = None, *args, by = None, how = "inner", sql_on = None): - _raise_if_args(args) - - if on is None and by is not None: - on = by - - # Needs to be on the table, not the select - left_sel = left.last_op.alias() - right_sel = 
right.last_op.alias() - - # handle arguments ---- - on = _validate_join_arg_on(on, sql_on) - how = _validate_join_arg_how(how) - - # for equality join used to combine keys into single column - consolidate_keys = on if sql_on is None else {} - - if how == "right": - # switch joins, since sqlalchemy doesn't have right join arg - # see https://stackoverflow.com/q/11400307/1144523 - left_sel, right_sel = right_sel, left_sel - on = {v:k for k,v in on.items()} - - # create join conditions ---- - bool_clause = _create_join_conds(left_sel, right_sel, on) - - # create join ---- - join = left_sel.join( - right_sel, - onclause = bool_clause, - isouter = how != "inner", - full = how == "full" - ) - - # if right join, set selects back - if how == "right": - left_sel, right_sel = right_sel, left_sel - on = {v:k for k,v in on.items()} - - # note, shared_keys assumes on is a mapping... - # TODO: shared_keys appears to be for when on is not specified, but was unused - #shared_keys = [k for k,v in on.items() if k == v] - labeled_cols = _joined_cols( - left_sel.columns, - right_sel.columns, - on_keys = consolidate_keys, - how = how - ) - - sel = _sql_select(labeled_cols).select_from(join) - return left.append_op(sel, order_by = tuple()) - - -@semi_join.register(LazyTbl) -def _semi_join(left, right = None, on = None, *args, by = None, sql_on = None): - if on is None and by is not None: - on = by - - _raise_if_args(args) - - left_sel = left.last_op.alias() - right_sel = right.last_op.alias() - - # handle arguments ---- - on = _validate_join_arg_on(on, sql_on, left_sel, right_sel) - - # create join conditions ---- - bool_clause = _create_join_conds(left_sel, right_sel, on) - - # create inner join ---- - exists_clause = _sql_select([sql.literal(1)]) \ - .select_from(right_sel) \ - .where(bool_clause) - - # only keep left hand select's columns ---- - sel = _sql_select(left_sel.columns) \ - .select_from(left_sel) \ - .where(sql.exists(exists_clause)) - - return left.append_op(sel, order_by 
= tuple()) - - -@anti_join.register(LazyTbl) -def _anti_join(left, right = None, on = None, *args, by = None, sql_on = None): - if on is None and by is not None: - on = by - - _raise_if_args(args) - - left_sel = left.last_op.alias() - right_sel = right.last_op.alias() - - # handle arguments ---- - on = _validate_join_arg_on(on, sql_on, left, right) - - # create join conditions ---- - bool_clause = _create_join_conds(left_sel, right_sel, on) - - # create inner join ---- - #not_exists = ~sql.exists([1], from_obj = right_sel).where(bool_clause) - exists_clause = _sql_select([sql.literal(1)]) \ - .select_from(right_sel) \ - .where(bool_clause) - - sel = left_sel.select().where(~sql.exists(exists_clause)) - - return left.append_op(sel, order_by = tuple()) - -def _raise_if_args(args): - if len(args): - raise NotImplemented("*args is reserved for future arguments (e.g. suffix)") - -def _validate_join_arg_on(on, sql_on = None, lhs = None, rhs = None): - # handle sql on case - if sql_on is not None: - if on is not None: - raise ValueError("Cannot specify both on and sql_on") - - return sql_on - - # handle general cases - if on is None: - # TODO: currently, we check for lhs and rhs tables to indicate whether - # a verb supports inferring columns. Otherwise, raise an error. - if lhs is not None and rhs is not None: - # TODO: consolidate with duplicate logic in pandas verb code - warnings.warn( - "No on column passed to join. " - "Inferring join columns instead using shared column names." - ) - - on_cols = list(set(lhs.columns.keys()).intersection(set(rhs.columns.keys()))) - - if not on_cols: - raise ValueError( - "No join column specified, or shared column names in join." 
- ) - - # trivial dict mapping shared names to themselves - warnings.warn("Detected shared columns: %s" % on_cols) - on = dict(zip(on_cols, on_cols)) - - else: - raise NotImplementedError("on arg currently cannot be None (default) for SQL") - elif isinstance(on, str): - on = {on: on} - elif isinstance(on, (list, tuple)): - on = dict(zip(on, on)) - - - if not isinstance(on, Mapping): - raise TypeError("on must be a Mapping (e.g. dict)") - - return on - -def _validate_join_arg_how(how): - how_options = ("inner", "left", "right", "full") - if how not in how_options: - raise ValueError("how argument needs to be one of %s" %how_options) - - return how - -def _create_join_conds(left_sel, right_sel, on): - left_cols = left_sel.columns #lift_inner_cols(left_sel) - right_cols = right_sel.columns #lift_inner_cols(right_sel) - - if callable(on): - # callable, like with sql_on arg - conds = [on(left_cols, right_cols)] - else: - # dict-like of form {left: right} - conds = [] - for l, r in on.items(): - col_expr = left_cols[l] == right_cols[r] - conds.append(col_expr) - - return sql.and_(*conds) - - -# Head ------------------------------------------------------------------------ - -@head.register(LazyTbl) -def _head(__data, n = 5): - sel = __data.last_select - - return __data.append_op(sel.limit(n)) - - -# Rename ---------------------------------------------------------------------- - -@rename.register(LazyTbl) -def _rename(__data, **kwargs): - sel = __data.last_select - columns = lift_inner_cols(sel) - - # old_keys uses dict as ordered set - old_to_new = {simple_varname(v):k for k,v in kwargs.items()} - - if None in old_to_new: - raise KeyError("positional arguments must be simple column, " - "e.g. 
_.colname or _['colname']" - ) - - labs = [c.label(old_to_new[k]) if k in old_to_new else c for k,c in columns.items()] - - new_sel = sel.with_only_columns(labs) - - missing_groups, group_keys = _select_group_renames(old_to_new, __data.group_by) - - return __data.append_op(new_sel, group_by=group_keys) - - -# Distinct -------------------------------------------------------------------- - -@distinct.register(LazyTbl) -def _distinct(__data, *args, _keep_all = False, **kwargs): - if (args or kwargs) and _keep_all: - raise NotImplementedError("Distinct with variables specified in sql requires _keep_all = False") - - inner_sel = mutate(__data, **kwargs).last_select if kwargs else __data.last_select - - # TODO: this is copied from the df distinct version - # cols dict below is used as ordered set - cols = _var_select_simple(args) - cols.update(kwargs) - - # use all columns by default - if not cols: - cols = {k: True for k in lift_inner_cols(inner_sel).keys()} - - final_names = {**{k: True for k in __data.group_by}, **cols} - - if not len(inner_sel._order_by_clause): - # select distinct has to include any columns in the order by clause, - # so can only safely modify existing statement when there's no order by - sel_cols = lift_inner_cols(inner_sel) - distinct_cols = [sel_cols[k] for k in final_names] - sel = inner_sel.with_only_columns(distinct_cols).distinct() - else: - # fallback to cte - cte = inner_sel.alias() - distinct_cols = [cte.columns[k] for k in final_names] - sel = _sql_select(distinct_cols).select_from(cte).distinct() - - return __data.append_op(sel) - - -# if_else --------------------------------------------------------------------- - -@if_else.register(sql.elements.ColumnElement) -def _if_else(cond, true_vals, false_vals): - whens = [(cond, true_vals)] - return sql.case(whens, else_ = false_vals) - - diff --git a/siuba/tests/test_sql_utils.py b/siuba/tests/test_sql_utils.py index 232e0acf..1d7efa96 100644 --- a/siuba/tests/test_sql_utils.py +++ 
b/siuba/tests/test_sql_utils.py @@ -1,5 +1,5 @@ from siuba.sql.utils import get_dialect_translator, mock_sqlalchemy_engine -from siuba.sql.verbs import collect +from siuba.dply.verbs import collect from siuba.sql import LazyTbl import pytest diff --git a/siuba/tests/test_sql_verbs.py b/siuba/tests/test_sql_verbs.py index 4870ef16..0c7e428a 100644 --- a/siuba/tests/test_sql_verbs.py +++ b/siuba/tests/test_sql_verbs.py @@ -76,14 +76,6 @@ def test_lazy_tbl_shape_call_error(db): # track_call_windows ---------------------------------------------------------- -from siuba.sql.verbs import track_call_windows -from siuba.sql.translate import win_over - -def test_track_call_windows_basic(): - pass - - - # TODO: remove these old tests? should be redundant =========================== diff --git a/siuba/tests/test_verb_join.py b/siuba/tests/test_verb_join.py index 20cb39b5..64224bd4 100644 --- a/siuba/tests/test_verb_join.py +++ b/siuba/tests/test_verb_join.py @@ -10,7 +10,7 @@ semi_join, anti_join ) from siuba.dply.vector import row_number, n -from siuba.sql.verbs import collect +from siuba.dply.verbs import collect import pytest from .helpers import assert_equal_query, assert_frame_sort_equal, data_frame, backend_notimpl, backend_sql diff --git a/siuba/tests/test_verb_show_query.py b/siuba/tests/test_verb_show_query.py index 8504cc6a..49ea89a3 100644 --- a/siuba/tests/test_verb_show_query.py +++ b/siuba/tests/test_verb_show_query.py @@ -1,4 +1,5 @@ -from siuba.sql.verbs import collect, show_query, mutate, LazyTbl +from siuba.dply.verbs import collect, show_query, mutate +from siuba.sql import LazyTbl from siuba.dply.verbs import Pipeable from siuba.tests.helpers import SqlBackend, data_frame from siuba import _ diff --git a/siuba/tests/test_verb_utils.py b/siuba/tests/test_verb_utils.py index a048bfad..35f381a5 100644 --- a/siuba/tests/test_verb_utils.py +++ b/siuba/tests/test_verb_utils.py @@ -1,6 +1,6 @@ from siuba.siu import Symbolic -from siuba.sql.verbs import collect, 
show_query, LazyTbl -from siuba.dply.verbs import Call +from siuba.dply.verbs import collect, show_query, Call +from siuba.sql import LazyTbl from .helpers import data_frame import pandas as pd From f675ac1357fa0e82716e2b23ab319418757149eb Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Tue, 4 Oct 2022 18:33:24 -0400 Subject: [PATCH 16/27] refactor(sql): remove unused functions --- siuba/sql/backend.py | 52 -------------------------------------------- 1 file changed, 52 deletions(-) diff --git a/siuba/sql/backend.py b/siuba/sql/backend.py index 6c81b398..02d6eed2 100644 --- a/siuba/sql/backend.py +++ b/siuba/sql/backend.py @@ -25,18 +25,6 @@ class SqlFunctionLookupError(FunctionLookupError): pass - -class CallListener: - """Generic listener. Each exit is called on a node's copy.""" - def enter(self, node): - args, kwargs = node.map_subcalls(self.enter) - - return self.exit(node.__class__(node.func, *args, **kwargs)) - - def exit(self, node): - return node - - class WindowReplacer: @staticmethod @@ -100,12 +88,6 @@ def visit(self, el): return None -#def track_call_windows(call, columns, group_by, order_by, window_cte = None): -# listener = WindowReplacer(columns, group_by, order_by, window_cte) -# col = listener.enter(call) -# return col, listener.windows, listener.window_cte - - # TODO: consolidate / pull expr handling funcs into own file? 
def _create_order_by_clause(columns, *args): from siuba.dply.verbs import _call_strip_ascending @@ -203,40 +185,6 @@ def lift_inner_cols(tbl): return _sql_column_collection(cols) -def col_expr_requires_cte(call, sel, is_mutate = False): - """Return whether a variable assignment needs a CTE""" - - call_vars = set(call.op_vars(attr_calls = False)) - - sel_labs = get_inner_labels(sel) - - # I use the acronym fwg sol (frog soul) to remember sql clause eval order - # from, where, group by, select, order by, limit - # group clause evaluated before select clause, so not issue for mutate - group_needs_cte = not is_mutate and len(sel._group_by_clause) - - return ( group_needs_cte - # TODO: detect when a new var in mutate conflicts w/ order by - #or len(sel._order_by_clause) - or not sel_labs.isdisjoint(call_vars) - ) - -def get_inner_labels(sel): - columns = lift_inner_cols(sel) - sel_labs = set(k for k,v in columns.items() if isinstance(v, sql.elements.Label)) - return sel_labs - -def get_missing_columns(call, columns): - missing_cols = set(call.op_vars(attr_calls = False)) - set(columns.keys()) - return missing_cols - -def compile_el(tbl, el): - compiled = el.compile( - dialect = tbl.source.dialect, - compile_kwargs = {"literal_binds": True} - ) - return compiled - # Misc utilities -------------------------------------------------------------- def ordered_union(x, y): From e12f4a24c960c58505c0426cf71357287bf4b4fe Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Wed, 5 Oct 2022 11:09:05 -0400 Subject: [PATCH 17/27] tests: across in verbs, bare columns in verbs --- siuba/tests/test_verb_across.py | 58 +++++++++++++++++++++--------- siuba/tests/test_verb_count.py | 7 ++++ siuba/tests/test_verb_mutate.py | 8 +++-- siuba/tests/test_verb_summarize.py | 13 +++++++ 4 files changed, 66 insertions(+), 20 deletions(-) diff --git a/siuba/tests/test_verb_across.py b/siuba/tests/test_verb_across.py index b6781ebf..7c97e5cb 100644 --- a/siuba/tests/test_verb_across.py +++ 
b/siuba/tests/test_verb_across.py @@ -4,18 +4,13 @@ from pandas.testing import assert_frame_equal from pandas.core.groupby import DataFrameGroupBy from siuba.siu import symbolic_dispatch, Symbolic, Fx +from siuba.dply import verbs from siuba.dply.verbs import mutate, filter, summarize, group_by, collect, ungroup from siuba.dply.across import across from siuba.experimental.pivot.test_pivot import assert_equal_query2 from siuba.sql.translate import SqlColumn, SqlColumnAgg, sql_scalar, win_agg, sql_agg - -# TODO: test transmute -# TODO: test verb(data, _.simple_name) -# TODO: test changing a group var (e.g. mutate, transmute, add_count), then summarizing -# TODO: group_by(cyl) >> count(cyl = cyl + 1) -# TODO: SQL mutate requires immediate CTE (e.g. due to GROUP BY clause) -# TODO: count "n" name +from siuba.tests.helpers import assert_frame_sort_equal # Helpers ===================================================================== @@ -188,25 +183,28 @@ def test_across_in_summarize(backend, df): assert_equal_query2(res, dst) -def test_across_in_summarize_equiv_ungrouped(): +def test_across_in_summarize_equiv_ungrouped(backend): # note that summarize does not automatically regroup on any keys - src = pd.DataFrame({ + df = pd.DataFrame({ "a_x": [1, 2], "a_y": [3, 4], "b_x": [5., 6.], "g": ["ZZ", "ZZ"] # Note: all groups the same }) + src = backend.load_df(df) - g_src = src.groupby("g") + g_src = group_by(src, _.g) expr_across = across(_, _[_.a_x, _.a_y], f_mean) g_res = summarize(g_src, expr_across) dst = summarize(src, expr_across) - assert g_res.columns.tolist() == ["g", "a_x", "a_y"] - assert g_res["g"].tolist() == ["ZZ"] + + collected = collect(g_res) + assert collected.columns.tolist() == ["g", "a_x", "a_y"] + assert collected["g"].tolist() == ["ZZ"] - assert_frame_equal(g_res.drop(columns="g"), dst) + assert_frame_sort_equal(collected.drop(columns="g"), collect(dst)) def test_across_in_filter(backend, df): @@ -218,15 +216,41 @@ def test_across_in_filter(backend, 
df): assert_equal_query2(res, dst) -def test_across_in_filter_equiv_ungrouped(df): - gdf = df.groupby("g") +def test_across_in_filter_equiv_ungrouped(backend, df): + src = backend.load_df(df) expr_across = across(_, _[_.a_x, _.a_y], lambda x: x % 2 > 0) - g_res = filter(gdf, expr_across) + g_res = filter(group_by(df, _.g), expr_across) dst = filter(df, expr_across) assert_grouping_names(g_res, ["g"]) - assert_frame_equal(g_res.obj, dst) + assert_equal_query2(g_res.obj, dst) + + +@pytest.mark.parametrize("f", [ + #(arrange), + (verbs.count), + #(add_count), + #(verbs.distinct), + (verbs.group_by), + (verbs.transmute), + +]) +def test_across_in_verb(backend, df, f): + src = backend.load_df(df) + expr_across = across(_, _[_.a_x, _.a_y], Fx % 2 > 0) + expr_manual = {"a_x": _.a_x % 2 > 0, "a_y": _.a_y % 2 > 0} + + res = f(src, expr_across) + dst = f(df, **expr_manual) + + if isinstance(dst, DataFrameGroupBy): + assert_grouping_names(res, ["a_x", "a_y"]) + + assert_equal_query2(ungroup(res), ungroup(dst)) + +# TODO: test verb(data, _.simple_name) +# TODO: count "n" name def test_across_formula_and_underscore(df): diff --git a/siuba/tests/test_verb_count.py b/siuba/tests/test_verb_count.py index 4e8e0518..6816d3b5 100644 --- a/siuba/tests/test_verb_count.py +++ b/siuba/tests/test_verb_count.py @@ -82,3 +82,10 @@ def test_count_on_grouped_df(df2): ) +def test_count_on_grouped_df_when_mutating_group_key(df): + assert_equal_query( + df, + group_by(_.g) >> count(g = _.g + "z"), + pd.DataFrame({"g": ["az", "bz"], "n": [2, 2]}) + ) + diff --git a/siuba/tests/test_verb_mutate.py b/siuba/tests/test_verb_mutate.py index 38d00843..decb2ddf 100644 --- a/siuba/tests/test_verb_mutate.py +++ b/siuba/tests/test_verb_mutate.py @@ -22,7 +22,9 @@ def dfs(backend): (mutate(x = _.a + _.b) >> summarize(ttl = _.x.sum().astype(float)), data_frame(ttl = 30.0)), (mutate(x = _.a + 1, y = _.b - 1), DATA.assign(x = [2,3,4], y = [8,7,6])), (mutate(x = _.a + 1) >> mutate(y = _.b - 1), DATA.assign(x = 
[2,3,4], y = [8,7,6])), - (mutate(x = _.a + 1, y = _.x + 1), DATA.assign(x = [2,3,4], y = [3,4,5])) + (mutate(x = _.a + 1, y = _.x + 1), DATA.assign(x = [2,3,4], y = [3,4,5])), + (mutate(_, _.a), DATA.copy()), + (mutate(_, _.a, _.a), DATA.copy()), ]) def test_mutate_basic(dfs, query, output): assert_equal_query(dfs, query, output) @@ -30,7 +32,7 @@ def test_mutate_basic(dfs, query, output): @pytest.mark.parametrize("query, output", [ (mutate(x = 1), DATA.assign(x = 1)), (mutate(x = "a"), DATA.assign(x = "a")), - (mutate(x = 1.2), DATA.assign(x = 1.2)) + (mutate(x = 1.2), DATA.assign(x = 1.2)), ]) def test_mutate_literal(dfs, query, output): assert_equal_query(dfs, query, output) @@ -66,7 +68,6 @@ def test_mutate_reassign_column_ordering(dfs): data_frame(a = [1,1,1], b = [9,8,7], c = [3,3,3]) ) -@pytest.mark.skip("TODO: in SQL this returns a table with 1 row") def test_mutate_reassign_all_cols_keeps_rowsize(dfs): assert_equal_query( dfs, @@ -119,3 +120,4 @@ def test_mutate_overwrites_prev(backend): + diff --git a/siuba/tests/test_verb_summarize.py b/siuba/tests/test_verb_summarize.py index eaa26e4f..d32ff4ed 100644 --- a/siuba/tests/test_verb_summarize.py +++ b/siuba/tests/test_verb_summarize.py @@ -148,3 +148,16 @@ def test_summarize_subquery_op_vars(backend, df): text = str(query(df).last_op) assert text.count('FROM') == 2 + +@backend_sql +def test_summarize_back_to_back(backend, df): + query = group_by(_.g) >> summarize(low=_.x.min()) >> summarize(high=_.low.max()) + assert_equal_query( + df, + query, + data_frame(high = 3) + ) + + # low defined in first query, high in second + text = str(query(df).last_op) + assert text.count('FROM') == 2 From fb575671ca0a480a38ac502171e45b54ce9023d1 Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Wed, 5 Oct 2022 11:31:01 -0400 Subject: [PATCH 18/27] fix(sql): case_when now passes pandas tests --- siuba/sql/verbs/conditional.py | 18 ++++++++++++++++-- siuba/tests/test_verb_case_when.py | 21 ++++++++++++++++----- 2 files 
changed, 32 insertions(+), 7 deletions(-) diff --git a/siuba/sql/verbs/conditional.py b/siuba/sql/verbs/conditional.py index de307c50..d9acd73a 100644 --- a/siuba/sql/verbs/conditional.py +++ b/siuba/sql/verbs/conditional.py @@ -5,6 +5,9 @@ from siuba.dply.verbs import case_when, if_else from siuba.siu import Call +from ..backend import LazyTbl + + @case_when.register(sql.base.ImmutableColumnCollection) def _case_when(__data, cases): # TODO: will need listener to enter case statements, to handle when they use windows @@ -22,14 +25,25 @@ def _case_when(__data, cases): val = val(__data) # handle when expressions - if ii+1 == n_items and expr is True: - else_val = val + #if ii+1 == n_items and expr is True: + # else_val = val + if expr is True: + # note: only sqlalchemy v1.3 requires wrapping in literal + whens.append((sql.literal(expr), val)) elif callable(expr): whens.append((expr(__data), val)) else: whens.append((expr, val)) return sql.case(whens, else_ = else_val) + + +@case_when.register(LazyTbl) +def _case_when(__data, cases): + raise NotImplementedError( + "`case_when()` must be used inside a verb like `mutate()`, when using a " + "SQL backend." 
+ ) # if_else --------------------------------------------------------------------- diff --git a/siuba/tests/test_verb_case_when.py b/siuba/tests/test_verb_case_when.py index c13f5285..bbb22834 100644 --- a/siuba/tests/test_verb_case_when.py +++ b/siuba/tests/test_verb_case_when.py @@ -2,10 +2,13 @@ import numpy as np import pytest -from siuba.dply.verbs import case_when from pandas.testing import assert_series_equal from numpy.testing import assert_equal +from siuba.tests.helpers import assert_equal_query + from siuba.siu import _ +from siuba.dply.verbs import case_when, mutate + DATA = pd.DataFrame({ 'x': [0,1,2], @@ -18,7 +21,7 @@ def data(): return DATA.copy() -@pytest.mark.parametrize("k,v, res", [ +@pytest.mark.parametrize("k,v, dst", [ (True, 1, [1]*3), (True, False, [False]*3), (True, _.y, [10, 11, 12]), @@ -29,10 +32,18 @@ def data(): (lambda _: _.x < 2, 0, [0, 0, None]), #(np.array([True, True, False]), 0, [0, 0, None]) ]) -def test_case_when_single_cond(k, v, res, data): - out = case_when(data, {k: v}) +def test_case_when_single_cond(backend, data, k, v, dst): + src = backend.load_df(data) + query = mutate(_, res = case_when(_, {k: v})) + + assert_equal_query(src, query, data.assign(res = dst)) + + +def test_case_when_multiple_clauses(backend, data): + src = backend.load_df(data) + query = mutate(_, res = case_when({_.x == 0: "zero", _.x > 1: "big", True: "small"})) - assert_series_equal(out, pd.Series(res)) + assert_equal_query(src, query, data.assign(res = ["zero", "small", "big"])) def test_case_when_cond_order(data): From 7e7144aa2c4d9bcf4f6c503194044297a05f6dbe Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Thu, 6 Oct 2022 10:32:34 -0400 Subject: [PATCH 19/27] feat: expose Fx, across as top-level imports --- siuba/__init__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/siuba/__init__.py b/siuba/__init__.py index c333a133..3ab55dd2 100644 --- a/siuba/__init__.py +++ b/siuba/__init__.py @@ -2,9 +2,10 @@ __version__ = 
"0.3.0" # default imports-------------------------------------------------------------- -from .siu import _, Lam +from .siu import _, Fx, Lam +from .dply.across import across from .dply.verbs import * from .dply.verbs import __all__ as ALL_DPLY # necessary, since _ won't be exposed in import * by default -__all__ = ['_', *ALL_DPLY] +__all__ = ['_', "Fx", "across", *ALL_DPLY] From 7e04927d24a6e5ea71c3c4643bfa370bd8a0ce8b Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Thu, 6 Oct 2022 13:10:23 -0400 Subject: [PATCH 20/27] feat(sql): add grouped distinct, improve tests --- siuba/dply/verbs.py | 51 ++++++++++--------------------- siuba/sql/verbs/distinct.py | 33 ++++++++++---------- siuba/tests/test_verb_across.py | 7 ++--- siuba/tests/test_verb_distinct.py | 29 +++++++++++++++--- 4 files changed, 60 insertions(+), 60 deletions(-) diff --git a/siuba/dply/verbs.py b/siuba/dply/verbs.py index 0d02ff69..78385103 100644 --- a/siuba/dply/verbs.py +++ b/siuba/dply/verbs.py @@ -922,18 +922,6 @@ def _arrange(__data, *args): # Distinct ==================================================================== -def _var_select_simple(args) -> "dict[str, bool]": - """Return an 'ordered set' of selected column names.""" - cols = {simple_varname(x): True for x in args} - if None in cols: - raise Exception( - "Positional arguments must be simple column. " - "e.g. _.colname or _['colname']\n\n" - f"Received: {repr(cols[None])}" - ) - - return cols - @singledispatch2(DataFrame) def distinct(__data, *args, _keep_all = False, **kwargs): """Keep only distinct (unique) rows from a table. 
@@ -977,45 +965,38 @@ def distinct(__data, *args, _keep_all = False, **kwargs): 1 Gentoo Biscoe 46.1 13.2 2 Chinstrap Dream 46.5 17.9 """ - # using dict as ordered set - cols = _var_select_simple(args) - # mutate kwargs - cols.update(kwargs) + if not (args or kwargs): + return __data.drop_duplicates().reset_index(drop=True) - # special case: use all variables when none are specified - if not len(cols): cols = __data.columns - - tmp_data = mutate(__data, **kwargs).drop_duplicates(list(cols)).reset_index(drop = True) + new_names, df_res = _mutate_cols(__data, args, kwargs) + tmp_data = df_res.drop_duplicates(new_names).reset_index(drop=True) if not _keep_all: - return tmp_data[list(cols)] + return tmp_data[new_names] return tmp_data - + @distinct.register(DataFrameGroupBy) def _distinct(__data, *args, _keep_all = False, **kwargs): - cols = _var_select_simple(args) - cols.update(kwargs) + group_names = [ping.name for ping in __data.grouper.groupings] - # special case: use all variables when none are specified - if not len(cols): cols = __data.columns - group_cols_ordered = {ping.name: True for ping in __data.grouper.groupings} - final_cols = list({**group_cols_ordered, **cols, **kwargs}) + f_distinct = distinct.dispatch(type(__data.obj)) - mutated = mutate(__data, **kwargs).obj + tmp_data = (__data + .apply(f_distinct, *args, _keep_all=_keep_all, **kwargs) + ) - if not _keep_all: - pre_df = mutated[final_cols] - else: - pre_df = mutated + index_keys = tmp_data.index.names[:-1] + keys_to_drop = [k for k in index_keys if k in tmp_data.columns] + keys_to_keep = [k for k in index_keys if k not in tmp_data.columns] - res = pre_df.drop_duplicates(list(final_cols)).reset_index(drop = True) - return res.groupby(list(group_cols_ordered)) + final = tmp_data.reset_index(keys_to_drop, drop=True).reset_index(keys_to_keep) + return final.groupby(group_names) # if_else, case_when ========================================================== diff --git a/siuba/sql/verbs/distinct.py 
b/siuba/sql/verbs/distinct.py index c56563d7..0cf5ca12 100644 --- a/siuba/sql/verbs/distinct.py +++ b/siuba/sql/verbs/distinct.py @@ -1,7 +1,9 @@ -from siuba.dply.verbs import distinct, mutate, _var_select_simple +from siuba.dply.verbs import distinct, mutate -from ..backend import LazyTbl -from ..utils import _sql_select, lift_inner_cols +from ..backend import LazyTbl, ordered_union +from ..utils import _sql_select, _sql_with_only_columns, lift_inner_cols + +from .mutate import _mutate_cols @distinct.register(LazyTbl) @@ -9,25 +11,24 @@ def _distinct(__data, *args, _keep_all = False, **kwargs): if (args or kwargs) and _keep_all: raise NotImplementedError("Distinct with variables specified in sql requires _keep_all = False") - inner_sel = mutate(__data, **kwargs).last_select if kwargs else __data.last_select - - # TODO: this is copied from the df distinct version - # cols dict below is used as ordered set - cols = _var_select_simple(args) - cols.update(kwargs) - - # use all columns by default - if not cols: - cols = {k: True for k in lift_inner_cols(inner_sel).keys()} + result_names, inner_sel = _mutate_cols(__data, args, kwargs, "Distinct") + + # create list of final column names ---- + missing = [name for name in __data.group_by if name not in result_names] + if not result_names: + # use all columns if none passed to distinct + all_names = list(lift_inner_cols(inner_sel).keys()) + final_names = ordered_union(missing, all_names) + else: + final_names = ordered_union(missing, result_names) - final_names = {**{k: True for k in __data.group_by}, **cols} - if not len(inner_sel._order_by_clause): + if not (len(inner_sel._order_by_clause) or len(inner_sel._group_by_clause)): # select distinct has to include any columns in the order by clause, # so can only safely modify existing statement when there's no order by sel_cols = lift_inner_cols(inner_sel) distinct_cols = [sel_cols[k] for k in final_names] - sel = inner_sel.with_only_columns(distinct_cols).distinct() + sel = 
_sql_with_only_columns(inner_sel, distinct_cols).distinct() else: # fallback to cte cte = inner_sel.alias() diff --git a/siuba/tests/test_verb_across.py b/siuba/tests/test_verb_across.py index 7c97e5cb..e4791dc6 100644 --- a/siuba/tests/test_verb_across.py +++ b/siuba/tests/test_verb_across.py @@ -230,8 +230,8 @@ def test_across_in_filter_equiv_ungrouped(backend, df): @pytest.mark.parametrize("f", [ #(arrange), (verbs.count), - #(add_count), - #(verbs.distinct), + #(verbs.add_count), + (verbs.distinct), (verbs.group_by), (verbs.transmute), @@ -249,9 +249,6 @@ def test_across_in_verb(backend, df, f): assert_equal_query2(ungroup(res), ungroup(dst)) -# TODO: test verb(data, _.simple_name) -# TODO: count "n" name - def test_across_formula_and_underscore(df): res = across(df, _[_.a_x, _.a_y], f_round(Fx) / _.b_x) diff --git a/siuba/tests/test_verb_distinct.py b/siuba/tests/test_verb_distinct.py index e848b6d6..b70de6bc 100644 --- a/siuba/tests/test_verb_distinct.py +++ b/siuba/tests/test_verb_distinct.py @@ -42,10 +42,31 @@ def test_distinct_keep_all_not_impl(backend, df): distinct(df, _.y, _keep_all = True) >> collect() -@pytest.mark.xfail -def test_distinct_via_group_by(df): - # NotImplemented - assert False +def test_distinct_via_group_by_single_col(backend): + data = pd.DataFrame({"g": ["a", "a", "b", "b"], "x": [1, 1, 1, 2]}) + + src = backend.load_df(data) + query = group_by(_, _.g) >> distinct(_.x) + + assert_equal_query( + src, + query, + pd.DataFrame({"g": ["a", "b", "b"], "x": [1, 1, 2]}) + ) + + +def test_distinct_via_group_by_group_key_as_arg(backend): + data = pd.DataFrame({"g": ["a", "a", "b", "b"], "x": [1, 1, 1, 2]}) + + src = backend.load_df(data) + query = group_by(_, _.g) >> distinct(_.x, _.g) + + assert_equal_query( + src, + query, + pd.DataFrame({"x": [1, 1, 2], "g": ["a", "b", "b"]}) + ) + def test_distinct_after_summarize(df): From 2862f6d3dd5953db51ae725b36b30ea72acaaa50 Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Thu, 6 Oct 2022 13:21:58 
-0400 Subject: [PATCH 21/27] fix(sql): add_count more robust, can mutate group cols, supports across --- siuba/sql/verbs/count.py | 31 +++++++++++++++++++++++++++---- siuba/tests/test_verb_across.py | 2 +- 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/siuba/sql/verbs/count.py b/siuba/sql/verbs/count.py index c68c9f2d..51d0367f 100644 --- a/siuba/sql/verbs/count.py +++ b/siuba/sql/verbs/count.py @@ -11,8 +11,9 @@ from siuba.dply.verbs import count, add_count, inner_join -from ..utils import _sql_select, lift_inner_cols +from ..utils import _sql_select, _sql_add_columns, lift_inner_cols from ..backend import LazyTbl, ordered_union +from ..translate import AggOver from .mutate import _mutate_cols @@ -73,7 +74,29 @@ def _count(__data, *args, sort = False, wt = None, **kwargs): @add_count.register(LazyTbl) def _add_count(__data, *args, wt = None, sort = False, **kwargs): - counts = count(__data, *args, wt = wt, sort = sort, **kwargs) - by = list(c.name for c in counts.last_select.inner_columns)[:-1] - return inner_join(__data, counts, by = by) + res_name = "n" + + result_names, sel_inner = _mutate_cols(__data, args, kwargs, "Count") + + # TODO: if clause copied from count + # remove unnecessary select, if we're operating on a table ---- + if set(lift_inner_cols(sel_inner)) == set(lift_inner_cols(__data.last_select)): + sel_inner = __data.last_select + + inner_cols = lift_inner_cols(sel_inner) + + + # TODO: this code to append groups to columns copied a lot inside verbs + # apply any group vars from a group_by verb call first + missing = [k for k in __data.group_by if k not in result_names] + + all_group_names = ordered_union(__data.group_by, result_names) + outer_group_cols = [inner_cols[k] for k in all_group_names] + + + count_col = AggOver(sql.functions.count(), partition_by=outer_group_cols) + + sel_appended = _sql_add_columns(sel_inner, [count_col.label(res_name)]) + + return __data.append_op(sel_appended) diff --git 
a/siuba/tests/test_verb_across.py b/siuba/tests/test_verb_across.py index e4791dc6..2a51fb30 100644 --- a/siuba/tests/test_verb_across.py +++ b/siuba/tests/test_verb_across.py @@ -230,7 +230,7 @@ def test_across_in_filter_equiv_ungrouped(backend, df): @pytest.mark.parametrize("f", [ #(arrange), (verbs.count), - #(verbs.add_count), + (verbs.add_count), (verbs.distinct), (verbs.group_by), (verbs.transmute), From 73d1b6ecd3834c8e70d4a2c04342ec0e690dbb07 Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Thu, 6 Oct 2022 16:21:37 -0400 Subject: [PATCH 22/27] fix: count, add_count proper name arg support --- siuba/dply/verbs.py | 21 +++------ siuba/sql/verbs/count.py | 36 ++++----------- siuba/tests/test_verb_count.py | 83 +++++++++++++++++++++++++++++++++- 3 files changed, 97 insertions(+), 43 deletions(-) diff --git a/siuba/dply/verbs.py b/siuba/dply/verbs.py index 78385103..32d44313 100644 --- a/siuba/dply/verbs.py +++ b/siuba/dply/verbs.py @@ -1154,16 +1154,8 @@ def _case_when(__data, cases): # Count ======================================================================= -def _count_group(data, *args): - crnt_cols = set(data.columns) - out_col = "n" - while out_col in crnt_cols: out_col = out_col + "n" - - return - - @singledispatch2((pd.DataFrame, DataFrameGroupBy)) -def count(__data, *args, wt = None, sort = False, **kwargs): +def count(__data, *args, wt = None, sort = False, name=None, **kwargs): """Summarize data with the number of rows for each grouping of data. Parameters @@ -1233,9 +1225,7 @@ def count(__data, *args, wt = None, sort = False, **kwargs): # count col named, n. If that col already exists, add more "n"s... 
- crnt_cols = set(counts.columns) - out_col = "n" - while out_col in crnt_cols: out_col = out_col + "n" + out_col = _check_name(name, set(counts.columns)) # rename the tally column to correct name counts.rename(columns = {counts.columns[-1]: out_col}, inplace = True) @@ -1252,9 +1242,10 @@ def _check_name(name, columns): while name in columns: name = name + "n" - if name != "n": - # TODO: warning - pass + elif name != "n" and name in columns: + raise ValueError( + f"Column name `{name}` specified for count name, but is already present in data." + ) elif not isinstance(name, str): raise TypeError("`name` must be a single string.") diff --git a/siuba/sql/verbs/count.py b/siuba/sql/verbs/count.py index 51d0367f..861cf4d0 100644 --- a/siuba/sql/verbs/count.py +++ b/siuba/sql/verbs/count.py @@ -9,7 +9,7 @@ from sqlalchemy import sql -from siuba.dply.verbs import count, add_count, inner_join +from siuba.dply.verbs import count, add_count, inner_join, _check_name from ..utils import _sql_select, _sql_add_columns, lift_inner_cols from ..backend import LazyTbl, ordered_union @@ -18,28 +18,10 @@ from .mutate import _mutate_cols - @count.register(LazyTbl) -def _count(__data, *args, sort = False, wt = None, **kwargs): - # TODO: if already col named n, use name nn, etc.. get logic from tidy.py +def _count(__data, *args, sort = False, wt = None, name=None, **kwargs): if wt is not None: - raise NotImplementedError("TODO") - - res_name = "n" - # similar to filter verb, we need two select statements, - # an inner one for derived cols, and outer to group by them - - # inner select ---- - # holds any mutation style columns - #arg_names = [] - #for arg in args: - # name = simple_varname(arg) - # if name is None: - # raise NotImplementedError( - # "Count positional arguments must be single column name. " - # "Use a named argument to count using complex expressions." 
- # ) - # arg_names.append(name) + raise NotImplementedError("wt argument is currently not implemented") result_names, sel_inner = _mutate_cols(__data, args, kwargs, "Count") @@ -59,7 +41,8 @@ def _count(__data, *args, sort = False, wt = None, **kwargs): outer_group_cols = [inner_cols[k] for k in all_group_names] # holds the actual count (e.g. n) - count_col = sql.functions.count().label(res_name) + label_n = _check_name(name, set(inner_cols.keys())) + count_col = sql.functions.count().label(label_n) sel_outer = _sql_select([*outer_group_cols, count_col]) \ .select_from(sel_inner_cte) \ @@ -73,9 +56,9 @@ def _count(__data, *args, sort = False, wt = None, **kwargs): @add_count.register(LazyTbl) -def _add_count(__data, *args, wt = None, sort = False, **kwargs): - - res_name = "n" +def _add_count(__data, *args, wt = None, sort = False, name=None, **kwargs): + if wt is not None: + raise NotImplementedError("wt argument is currently not implemented") result_names, sel_inner = _mutate_cols(__data, args, kwargs, "Count") @@ -96,7 +79,8 @@ def _add_count(__data, *args, wt = None, sort = False, **kwargs): count_col = AggOver(sql.functions.count(), partition_by=outer_group_cols) + label_n = _check_name(name, set(inner_cols.keys())) - sel_appended = _sql_add_columns(sel_inner, [count_col.label(res_name)]) + sel_appended = _sql_add_columns(sel_inner, [count_col.label(label_n)]) return __data.append_op(sel_appended) diff --git a/siuba/tests/test_verb_count.py b/siuba/tests/test_verb_count.py index 6816d3b5..3117ae44 100644 --- a/siuba/tests/test_verb_count.py +++ b/siuba/tests/test_verb_count.py @@ -4,7 +4,7 @@ https://github.com/tidyverse/dbplyr/blob/master/tests/testthat/test-verb-mutate.R """ -from siuba import _, group_by, summarize, count +from siuba import _, group_by, summarize, count, add_count, collect import pandas as pd import pytest @@ -41,7 +41,6 @@ def test_count_with_expression(df): ) -@pytest.mark.skip("TODO: sql support kwargs in count (#68)") def 
test_count_with_kwarg_expression(df): assert_equal_query( df, @@ -49,6 +48,15 @@ def test_count_with_kwarg_expression(df): pd.DataFrame({"y": [0], "n": [4]}) ) + +def test_add_count_with_kwarg_expression(df): + assert_equal_query( + df, + add_count(y = _.x - _.x), + DATA.assign(y = 0, n = 4) + ) + + @backend_notimpl("sql") # see (#104) def test_count_wt(backend, df): assert_equal_query( @@ -57,6 +65,16 @@ def test_count_wt(backend, df): pd.DataFrame({'g': ['a', 'b'], 'n': [1 + 2, 3 + 4]}) ) + +@backend_notimpl("sql") # see (#104) +def test_add_count_wt(backend, df): + assert_equal_query( + df, + add_count(_.g, wt = _.x), + DATA.assign(n = [3, 3, 7, 7]) + ) + + def test_count_no_groups(df): # count w/ no groups returns ttl assert_equal_query( @@ -65,6 +83,14 @@ def test_count_no_groups(df): pd.DataFrame({'n': [4]}) ) + +def test_add_count_no_groups(df): + assert_equal_query( + df, + add_count(), + DATA.assign(n = 4), + ) + @backend_notimpl("sql") # see (#104) def test_count_no_groups_wt(backend, df): assert_equal_query( @@ -82,6 +108,14 @@ def test_count_on_grouped_df(df2): ) +def test_add_count_on_grouped_df(df2): + assert_equal_query( + df2, + group_by(_.g) >> add_count(_.h), + DATA2.assign(n = [2]*4) + ) + + def test_count_on_grouped_df_when_mutating_group_key(df): assert_equal_query( df, @@ -89,3 +123,48 @@ def test_count_on_grouped_df_when_mutating_group_key(df): pd.DataFrame({"g": ["az", "bz"], "n": [2, 2]}) ) + +def test_add_count_on_grouped_df_when_mutating_group_key(df): + assert_equal_query( + df, + group_by(_.g) >> add_count(g = _.g + "z"), + pd.DataFrame(DATA.assign(g = ["az", "az", "bz", "bz"], n = [2]*4)) + ) + + +def test_count_name_unique(backend): + df = data_frame(x = [1, 2], n = [3, 3]) + src = backend.load_df(df) + + res = data_frame(n = [3], nn = [2]) + + assert_equal_query( + df, + count(_, _.n), + res + ) + + +def test_add_count_name_unique(backend): + df = data_frame(x = [1, 2], n = [3, 3]) + src = backend.load_df(df) + + res = data_frame(x = 
[1, 2], n = [3, 3], nn = [2, 2]) + + assert_equal_query( + df, + add_count(_, _.n), + res + ) + + +def test_count_name_manual_conflict(backend): + df = data_frame(x = [1, 2], n = [3, 3]) + src = backend.load_df(df) + + res = data_frame(n = [3], nn = [2]) + + with pytest.raises(ValueError) as exc_info: + df >> count(_, _.x, name = "x") >> collect() + + assert "Column name `x` specified for count name, but" in exc_info.value.args[0] From 3e5502cd916581ae337f34099b0c1840ed76c607 Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Fri, 7 Oct 2022 19:30:36 -0400 Subject: [PATCH 23/27] fix(sql): do not create subquery for custom sql_raw --- siuba/sql/backend.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/siuba/sql/backend.py b/siuba/sql/backend.py index 02d6eed2..ea2b2c9d 100644 --- a/siuba/sql/backend.py +++ b/siuba/sql/backend.py @@ -80,10 +80,10 @@ def visit(self, el): elif el in self.src_columns: return self.dst_columns[el.name] - elif isinstance(el, ColumnClause) and not isinstance(el, Column): - # Raw SQL, which will need a subquery, but not substitution - if el.key != "*": - self.applied = True + #elif isinstance(el, ColumnClause) and not isinstance(el, Column): + # # Raw SQL, which will need a subquery, but not substitution + # if el.key != "*": + # self.applied = True return None From 95b6dbd1e8aed487556412deaabec5293660bcce Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Fri, 7 Oct 2022 19:32:09 -0400 Subject: [PATCH 24/27] refactor!: remove tests of using function to tidyselect columns --- siuba/tests/test_dply_verbs.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/siuba/tests/test_dply_verbs.py b/siuba/tests/test_dply_verbs.py index 0b70c35a..0a3ff940 100644 --- a/siuba/tests/test_dply_verbs.py +++ b/siuba/tests/test_dply_verbs.py @@ -107,20 +107,6 @@ def test_VarList_getitem(): assert res[1].name == "c" - -# Select ---------------------------------------------------------------------- - -from siuba.dply.verbs 
import select - -def test_varlist_multi_slice(df1): - out = select(df1, lambda _: _["repo", "owner"]) - assert out.columns.tolist() == ["repo", "owner"] - -def test_varlist_multi_slice_negate(df1): - out = select(df1, lambda _: -_["repo", "owner"]) - assert out.columns.tolist() == ["language", "stars", "x"] - - # Distinct -------------------------------------------------------------------- from siuba.dply.verbs import distinct From 570fd9f610d755ce6825c0f1958375cd4356ce63 Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Mon, 10 Oct 2022 10:21:33 -0400 Subject: [PATCH 25/27] fix: clean up arrange, raise on across for now --- siuba/dply/verbs.py | 11 +++++++- siuba/sql/backend.py | 34 +++++++----------------- siuba/sql/verbs/arrange.py | 46 ++++++++++++++++++++++----------- siuba/tests/test_verb_across.py | 11 ++++++++ 4 files changed, 61 insertions(+), 41 deletions(-) diff --git a/siuba/dply/verbs.py b/siuba/dply/verbs.py index 32d44313..c7eba6fb 100644 --- a/siuba/dply/verbs.py +++ b/siuba/dply/verbs.py @@ -884,6 +884,7 @@ def arrange(__data, *args): ascending = [] for ii, arg in enumerate(args): f, asc = _call_strip_ascending(arg) + ascending.append(asc) col = simple_varname(f) @@ -894,7 +895,15 @@ def arrange(__data, *args): sort_cols.append(n_cols + ii) tmp_cols.append(n_cols + ii) - df[n_cols + ii] = f(df) + res = f(df) + + if isinstance(res, pd.DataFrame): + raise NotImplementedError( + f"`arrange()` expression {ii} of {len(args)} returned a " + "DataFrame, which is currently unsupported." + ) + + df[n_cols + ii] = res return df.sort_values(by = sort_cols, kind = "mergesort", ascending = ascending) \ diff --git a/siuba/sql/backend.py b/siuba/sql/backend.py index ea2b2c9d..d2ed3cc6 100644 --- a/siuba/sql/backend.py +++ b/siuba/sql/backend.py @@ -88,36 +88,15 @@ def visit(self, el): return None -# TODO: consolidate / pull expr handling funcs into own file? 
-def _create_order_by_clause(columns, *args): - from siuba.dply.verbs import _call_strip_ascending - - sort_cols = [] - for arg in args: - # simple named column - if isinstance(arg, str): - sort_cols.append(columns[arg]) - # an expression - elif callable(arg): - # handle special case where -_.colname -> colname DESC - f, asc = _call_strip_ascending(arg) - col_op = f(columns) if asc else f(columns).desc() - #col_op = arg(columns) - sort_cols.append(col_op) - else: - raise NotImplementedError("Must be string or callable") - - return sort_cols - def track_call_windows(call, columns, group_by, order_by, window_cte = None): col_expr = call(columns) crnt_group_by = sql.elements.ClauseList( *[columns[name] for name in group_by] ) - crnt_order_by = sql.elements.ClauseList( - *_create_order_by_clause(columns, *order_by) - ) + + crnt_order_by = sql.elements.ClauseList(*order_by) + return replace_call_windows(col_expr, crnt_group_by, crnt_order_by, window_cte) @@ -262,8 +241,13 @@ def shape_call( def track_call_windows(self, call, columns = None, window_cte = None): """Returns tuple of (new column expression, list of window exprs)""" + from .verbs.arrange import _eval_arrange_args + columns = self.last_op.columns if columns is None else columns - return track_call_windows(call, columns, self.group_by, self.order_by, window_cte) + + order_by = _eval_arrange_args(self, self.order_by, columns) + + return track_call_windows(call, columns, self.group_by, order_by, window_cte) def get_ordered_col_names(self): """Return columns from current select, with grouping columns first.""" diff --git a/siuba/sql/verbs/arrange.py b/siuba/sql/verbs/arrange.py index a981c63d..51048dd4 100644 --- a/siuba/sql/verbs/arrange.py +++ b/siuba/sql/verbs/arrange.py @@ -1,6 +1,10 @@ -from siuba.dply.verbs import arrange +from sqlalchemy.sql.base import ImmutableColumnCollection + +from siuba.dply.verbs import arrange, _call_strip_ascending +from siuba.dply.across import _set_data_context + from ..utils 
import lift_inner_cols -from ..backend import LazyTbl, _create_order_by_clause +from ..backend import LazyTbl # Helpers --------------------------------------------------------------------- @@ -15,23 +19,35 @@ def _arrange(__data, *args): cols = lift_inner_cols(last_sel) # TODO: implement across in arrange - #exprs, _ = _mutate_cols(__data, args, kwargs, "Arrange", arrange_clause=True) + sort_cols = _eval_arrange_args(__data, args, cols) + + order_by = __data.order_by + tuple(args) + return __data.append_op(last_sel.order_by(*sort_cols), order_by = order_by) - new_calls = [] + +def _eval_arrange_args(__data, args, cols): + sort_cols = [] for ii, expr in enumerate(args): - if callable(expr): + shaped = __data.shape_call( + expr, window = False, str_accessors = True, + verb_name = "Arrange", arg_name = ii, + ) + + new_call, ascending = _call_strip_ascending(shaped) - res = __data.shape_call( - expr, window = False, - verb_name = "Arrange", arg_name = ii - ) + with _set_data_context(__data, window=True): + res = new_call(cols) - else: - res = expr + if isinstance(res, ImmutableColumnCollection): + raise NotImplementedError( + f"`arrange()` expression {ii} of {len(args)} returned multiple columns, " + "which is currently unsupported." 
+ ) - new_calls.append(res) + if not ascending: + res = res.desc() - sort_cols = _create_order_by_clause(cols, *new_calls) + sort_cols.append(res) + + return sort_cols - order_by = __data.order_by + tuple(new_calls) - return __data.append_op(last_sel.order_by(*sort_cols), order_by = order_by) diff --git a/siuba/tests/test_verb_across.py b/siuba/tests/test_verb_across.py index 2a51fb30..ddad1fd5 100644 --- a/siuba/tests/test_verb_across.py +++ b/siuba/tests/test_verb_across.py @@ -250,6 +250,17 @@ def test_across_in_verb(backend, df, f): assert_equal_query2(ungroup(res), ungroup(dst)) +def test_across_in_arrange_unsupported(backend, df): + src = backend.load_df(df) + expr_across = across(_, _[_.a_x, _.a_y], Fx % 2 > 0) + + with pytest.raises(NotImplementedError) as exc_info: + res = verbs.arrange(src, expr_across) + + assert "`arrange()` expression 0 of 1 returned" in exc_info.value.args[0] + + + def test_across_formula_and_underscore(df): res = across(df, _[_.a_x, _.a_y], f_round(Fx) / _.b_x) From 13ad7f64f1a6033d7065c9e2158ce0390edc5491 Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Mon, 10 Oct 2022 10:32:41 -0400 Subject: [PATCH 26/27] fix(sql)!: arrange resets order_by vars, matches dbplyr --- siuba/sql/verbs/arrange.py | 4 ++-- siuba/tests/test_verb_arrange.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/siuba/sql/verbs/arrange.py b/siuba/sql/verbs/arrange.py index 51048dd4..2fd7bd98 100644 --- a/siuba/sql/verbs/arrange.py +++ b/siuba/sql/verbs/arrange.py @@ -21,8 +21,8 @@ def _arrange(__data, *args): # TODO: implement across in arrange sort_cols = _eval_arrange_args(__data, args, cols) - order_by = __data.order_by + tuple(args) - return __data.append_op(last_sel.order_by(*sort_cols), order_by = order_by) + final_sel = last_sel.order_by(None).order_by(*sort_cols) + return __data.append_op(final_sel, order_by = tuple(args)) def _eval_arrange_args(__data, args, cols): diff --git a/siuba/tests/test_verb_arrange.py 
b/siuba/tests/test_verb_arrange.py index ca588fdc..49fa36f3 100644 --- a/siuba/tests/test_verb_arrange.py +++ b/siuba/tests/test_verb_arrange.py @@ -83,7 +83,7 @@ def test_arranges_back_to_back(backend): lazy_tbl = dfs >> arrange(_.x) >> arrange(_.g) order_by_vars = tuple(simple_varname(call) for call in lazy_tbl.order_by) - assert order_by_vars == ("x", "g") - assert [c.name for c in lazy_tbl.last_op._order_by_clause] == ["x", "g"] + assert order_by_vars == ("g",) + assert [c.name for c in lazy_tbl.last_op._order_by_clause] == ["g"] From 3d4a79a29e1d54362c055d3d1e3e0e8cd50ac14f Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Tue, 11 Oct 2022 19:38:55 -0400 Subject: [PATCH 27/27] tests: more tests of mutating after a summarize --- siuba/tests/test_verb_mutate.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/siuba/tests/test_verb_mutate.py b/siuba/tests/test_verb_mutate.py index decb2ddf..6a4359f6 100644 --- a/siuba/tests/test_verb_mutate.py +++ b/siuba/tests/test_verb_mutate.py @@ -119,5 +119,35 @@ def test_mutate_overwrites_prev(backend): ) +def test_mutate_after_summarize_on_non_derived_column(backend): + dfs = backend.load_df(data_frame(x = range(1, 5), g = [1,2,2,2])) + query = group_by(_.g) >> summarize(avg = _.x.min()) >> mutate(avg_g = _.g.mean()) + assert_equal_query( + dfs, + query, + data_frame(g = [1,2], avg = [1,2], avg_g = 1.5) + ) + + +def test_mutate_after_summarize_on_derived_column(backend): + dfs = backend.load_df(data_frame(x = range(1, 5), g = [1,2,2,2])) + + query = group_by(_.g) >> summarize(avg = _.x.min()) >> mutate(avg_avg = _.avg.mean()) + assert_equal_query( + dfs, + query, + data_frame(g = [1,2], avg = [1,2], avg_avg = 1.5) + ) + + +def test_mutate_after_summarize_limits_column_access(backend): + dfs = backend.load_df(data_frame(x = range(1, 5), g = [1,2,2,2])) + query = group_by(_.g) >> summarize(avg = _.x.min()) >> mutate(x2 = _.x + 1) + + with pytest.raises(AttributeError) as exc_info: + 
query(dfs) + + + assert "x" in exc_info.value.args[0]