From 1f8b251b6eddc8c2b75fb2e46b2a5efaf5df6a81 Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Tue, 27 Sep 2022 14:21:44 -0400 Subject: [PATCH 1/2] feat(sql): add LazyTbl.last_select to simplify queries --- siuba/dply/verbs.py | 7 +++-- siuba/sql/verbs.py | 65 ++++++++++++++++++++++++++++----------------- 2 files changed, 43 insertions(+), 29 deletions(-) diff --git a/siuba/dply/verbs.py b/siuba/dply/verbs.py index 401698b0..024c9133 100644 --- a/siuba/dply/verbs.py +++ b/siuba/dply/verbs.py @@ -2449,10 +2449,9 @@ def tbl(src, *args, **kwargs): >>> from sqlalchemy import create_mock_engine >>> mock_engine = create_mock_engine("postgresql:///", lambda *args, **kwargs: None) >>> tbl_mock = tbl(mock_engine, "some_table", columns = ["a", "b", "c"]) - >>> q = tbl_mock >> count() >> show_query() # doctest: +NORMALIZE_WHITESPACE - SELECT count(*) AS n - FROM (SELECT some_table.a AS a, some_table.b AS b, some_table.c AS c - FROM some_table) AS anon_1 ORDER BY n DESC + >>> q = tbl_mock >> count(_.a) >> show_query() # doctest: +NORMALIZE_WHITESPACE + SELECT a, count(*) AS n + FROM some_table AS some_table_1 GROUP BY a ORDER BY n DESC """ return src diff --git a/siuba/sql/verbs.py b/siuba/sql/verbs.py index a1c82b03..1fc8dc88 100644 --- a/siuba/sql/verbs.py +++ b/siuba/sql/verbs.py @@ -297,7 +297,7 @@ def __init__( self.tbl = self._create_table(tbl, columns, self.source) # important states the query can be in (e.g. grouped) - self.ops = [self.tbl.select()] if ops is None else ops + self.ops = [self.tbl] if ops is None else ops self.group_by = group_by self.order_by = order_by @@ -340,8 +340,21 @@ def get_ordered_col_names(self): @property - def last_op(self): - return self.ops[-1] if len(self.ops) else None + def last_op(self) -> "sql.Table | sql.Select": + last_op = self.ops[-1] + + if last_op is None: + raise TypeError() + + return last_op + + @property + def last_select(self): + last_op = self.last_op + if not isinstance(last_op, sql.selectable.SelectBase): + return last_op.select() + + return last_op @staticmethod def _create_table(tbl, columns = None, source = None): @@ -385,7 +398,7 @@ def _create_table(tbl, columns = None, source = None): def _get_preview(self): # need to make prev op a cte, so we don't override any previous limit - new_sel = self.last_op.alias().select().limit(5) + new_sel = self.last_select.limit(5) tbl_small = self.append_op(new_sel) return collect(tbl_small) @@ -450,13 +463,13 @@ def _show_query(tbl, simplify = False, return_table = True): if simplify: # try to strip table names and labels where unnecessary - simple_sel = _sql_simplify_select(tbl.last_op) + simple_sel = _sql_simplify_select(tbl.last_select) with _use_simple_names(): explained = compile_query(simple_sel) else: # use a much more verbose query - explained = compile_query(tbl.last_op) + explained = compile_query(tbl.last_select) if return_table: print(str(explained)) @@ -483,13 +496,13 @@ def _collect(__data, as_df = True): if _is_dialect_duckdb(__data.source): # TODO: can be removed once next release of duckdb fixes: # https://github.com/duckdb/duckdb/issues/2972 - query = __data.last_op + query = __data.last_select compiled = query.compile( dialect = __data.source.dialect, compile_kwargs = {"literal_binds": True} ) else: - compiled = __data.last_op + compiled = __data.last_select # execute query ---- @@ -519,8 +532,8 @@ def _select(__data, *args, **kwargs): "Using kwargs in select not currently supported. " "Use _.newname == _.oldname instead" ) - last_op = __data.last_op - columns = {c.key: c for c in last_op.inner_columns} + last_sel = __data.last_select + columns = {c.key: c for c in last_sel.inner_columns} # same as for DataFrame colnames = Series(list(columns)) @@ -541,7 +554,7 @@ def _select(__data, *args, **kwargs): col_list.append(col if v is None else col.label(v)) return __data.append_op( - last_op.with_only_columns(col_list), + last_sel.with_only_columns(col_list), group_by = group_keys ) @@ -610,7 +623,10 @@ def _mutate(__data, **kwargs): # TODO: verify it can follow a renaming select # track labeled columns in set - sel = __data.last_op + if not len(kwargs): + return __data.append_op(__data.last_op) + + sel = __data.last_select # evaluate each call for colname, func in kwargs.items(): @@ -664,7 +680,7 @@ def _transmute(__data, **kwargs): # transmute keeps grouping cols, and any defined in kwargs cols_to_keep = ordered_union(__data.group_by, kwargs) - sel = f_mutate(__data, **kwargs).last_op + sel = f_mutate(__data, **kwargs).last_select columns = lift_inner_cols(sel) sel_stripped = sel.with_only_columns([columns[k] for k in cols_to_keep]) @@ -679,8 +695,8 @@ def _arrange(__data, *args): # and handle when new columns are named the same as order by vars. # see: https://dba.stackexchange.com/q/82930 - last_op = __data.last_op - cols = lift_inner_cols(last_op) + last_sel = __data.last_select + cols = lift_inner_cols(last_sel) new_calls = [] @@ -700,7 +716,7 @@ def _arrange(__data, *args): sort_cols = _create_order_by_clause(cols, *new_calls) order_by = __data.order_by + tuple(new_calls) - return __data.append_op(last_op.order_by(*sort_cols), order_by = order_by) + return __data.append_op(last_sel.order_by(*sort_cols), order_by = order_by) # TODO: consolidate / pull expr handling funcs into own file? @@ -746,8 +762,7 @@ def _count(__data, *args, sort = False, wt = None, **kwargs): ) arg_names.append(name) - tbl_inner = mutate(__data, **kwargs) - sel_inner = tbl_inner.last_op + sel_inner = mutate(__data, **kwargs).last_op group_cols = arg_names + list(kwargs) # create outer select ---- @@ -756,7 +771,7 @@ def _count(__data, *args, sort = False, wt = None, **kwargs): inner_cols = sel_inner_cte.columns # apply any group vars from a group_by verb call first - tbl_group_cols = [inner_cols[k] for k in tbl_inner.group_by] + tbl_group_cols = [inner_cols[k] for k in __data.group_by] count_group_cols = [inner_cols[k] for k in group_cols] # combine with any defined in the count verb call @@ -769,7 +784,7 @@ def _count(__data, *args, sort = False, wt = None, **kwargs): .group_by(*outer_group_cols) # count is like summarize, so removes order_by - return tbl_inner.append_op( + return __data.append_op( sel_outer.order_by(count_col.desc()), order_by = tuple() ) @@ -778,7 +793,7 @@ def _count(__data, *args, sort = False, wt = None, **kwargs): @add_count.register(LazyTbl) def _add_count(__data, *args, wt = None, sort = False, **kwargs): counts = count(__data, *args, wt = wt, sort = sort, **kwargs) - by = list(c.name for c in counts.last_op.inner_columns)[:-1] + by = list(c.name for c in counts.last_select.inner_columns)[:-1] return inner_join(__data, counts, by = by) @@ -789,7 +804,7 @@ def _summarize(__data, **kwargs): # what if windowed mutate or filter has been done? # - filter is fine, since it uses a CTE # - need to detect any window functions... - old_sel = __data.last_op._clone() + old_sel = __data.last_select._clone() new_calls = {} for k, expr in kwargs.items(): @@ -1136,7 +1151,7 @@ def _create_join_conds(left_sel, right_sel, on): @head.register(LazyTbl) def _head(__data, n = 5): - sel = __data.last_op + sel = __data.last_select return __data.append_op(sel.limit(n)) @@ -1145,7 +1160,7 @@ def _head(__data, n = 5): @rename.register(LazyTbl) def _rename(__data, **kwargs): - sel = __data.last_op + sel = __data.last_select columns = lift_inner_cols(sel) # old_keys uses dict as ordered set @@ -1172,7 +1187,7 @@ def _distinct(__data, *args, _keep_all = False, **kwargs): if (args or kwargs) and _keep_all: raise NotImplementedError("Distinct with variables specified in sql requires _keep_all = False") - inner_sel = mutate(__data, **kwargs).last_op if kwargs else __data.last_op + inner_sel = mutate(__data, **kwargs).last_select if kwargs else __data.last_select # TODO: this is copied from the df distinct version # cols dict below is used as ordered set From 11c8327072ed916a42056c10dcd324a354ab2fbe Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Tue, 27 Sep 2022 17:31:25 -0400 Subject: [PATCH 2/2] fix(sql)!: simplify_query no longer strips table names This functionality both produced queries that were invalid, and involved futzing with the sqlalchemy compiler, which appeared to break for certain sqlalchemy versions. --- siuba/dply/verbs.py | 7 +++++-- siuba/sql/utils.py | 18 ------------------ siuba/sql/verbs.py | 4 +--- siuba/tests/test_verb_show_query.py | 6 +++--- 4 files changed, 9 insertions(+), 26 deletions(-) diff --git a/siuba/dply/verbs.py b/siuba/dply/verbs.py index 024c9133..ed2835a4 100644 --- a/siuba/dply/verbs.py +++ b/siuba/dply/verbs.py @@ -2447,11 +2447,14 @@ def tbl(src, *args, **kwargs): You can analyze a mock table >>> from sqlalchemy import create_mock_engine + >>> from siuba import _ + >>> mock_engine = create_mock_engine("postgresql:///", lambda *args, **kwargs: None) >>> tbl_mock = tbl(mock_engine, "some_table", columns = ["a", "b", "c"]) + >>> q = tbl_mock >> count(_.a) >> show_query() # doctest: +NORMALIZE_WHITESPACE - SELECT a, count(*) AS n - FROM some_table AS some_table_1 GROUP BY a ORDER BY n DESC + SELECT some_table_1.a, count(*) AS n + FROM some_table AS some_table_1 GROUP BY some_table_1.a ORDER BY n DESC """ return src diff --git a/siuba/sql/utils.py b/siuba/sql/utils.py index ed5b6a1a..7310b328 100644 --- a/siuba/sql/utils.py +++ b/siuba/sql/utils.py @@ -190,21 +190,3 @@ def simplify_sel(sel): return clone_el - -@contextmanager -def _use_simple_names(): - from sqlalchemy import sql - from sqlalchemy.ext.compiler import compiles, deregister - - get_col_name = lambda el, *args, **kwargs: str(el.element.name) - get_lab_name = lambda el, *args, **kwargs: str(el.element.name) - get_col_name = lambda el, *args, **kwargs: str(el.name) - compiles(sql.compiler._CompileLabel)(get_lab_name) - compiles(sql.elements.ColumnClause)(get_col_name) - compiles(sql.schema.Column)(get_col_name) - try: - yield 1 - except: - pass - finally: - deregister(sql.compiler._CompileLabel) diff --git a/siuba/sql/verbs.py b/siuba/sql/verbs.py index 1fc8dc88..c03580a7 100644 --- a/siuba/sql/verbs.py +++ b/siuba/sql/verbs.py @@ -42,7 +42,6 @@ _sql_add_columns, _sql_with_only_columns, _sql_simplify_select, - _use_simple_names, MockConnection ) @@ -465,8 +464,7 @@ def _show_query(tbl, simplify = False, return_table = True): # try to strip table names and labels where unnecessary simple_sel = _sql_simplify_select(tbl.last_select) - with _use_simple_names(): - explained = compile_query(simple_sel) + explained = compile_query(simple_sel) else: # use a much more verbose query explained = compile_query(tbl.last_select) diff --git a/siuba/tests/test_verb_show_query.py b/siuba/tests/test_verb_show_query.py index 4065bf15..8504cc6a 100644 --- a/siuba/tests/test_verb_show_query.py +++ b/siuba/tests/test_verb_show_query.py @@ -32,7 +32,7 @@ def test_show_query_basic_simplify(df_tiny): q = df_tiny >> mutate(a = _.x.mean()) >> show_query(return_table = False, simplify=True) assert rename_source(str(q), df_tiny) == """\ -SELECT *, avg(x) OVER () AS a +SELECT *, avg(SRC_TBL.x) OVER () AS a FROM SRC_TBL""" def test_show_query_complex_simplify(df_wide): @@ -40,7 +40,7 @@ def test_show_query_complex_simplify(df_wide): res = q >> show_query(return_table = False, simplify=True) assert rename_source(str(res), df_wide) == """\ -SELECT *, avg(a) OVER () AS b -FROM (SELECT *, avg(x) OVER () AS a +SELECT *, avg(anon_1.a) OVER () AS b +FROM (SELECT *, avg(SRC_TBL.x) OVER () AS a FROM SRC_TBL) AS anon_1"""