Skip to content

Commit

Permalink
Merge pull request #449 from machow/feat-sql-last-select
Browse files: browse the repository at this point in the history
feat(sql): add LazyTbl.last_select to simplify queries
  • Loading branch information
machow authored Sep 28, 2022
2 parents 1faa215 + 11c8327 commit ddc63af
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 53 deletions.
10 changes: 6 additions & 4 deletions siuba/dply/verbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2463,12 +2463,14 @@ def tbl(src, *args, **kwargs):
You can analyze a mock table
>>> from sqlalchemy import create_mock_engine
>>> from siuba import _
>>> mock_engine = create_mock_engine("postgresql:///", lambda *args, **kwargs: None)
>>> tbl_mock = tbl(mock_engine, "some_table", columns = ["a", "b", "c"])
>>> q = tbl_mock >> count() >> show_query() # doctest: +NORMALIZE_WHITESPACE
SELECT count(*) AS n
FROM (SELECT some_table.a AS a, some_table.b AS b, some_table.c AS c
FROM some_table) AS anon_1 ORDER BY n DESC
>>> q = tbl_mock >> count(_.a) >> show_query() # doctest: +NORMALIZE_WHITESPACE
SELECT some_table_1.a, count(*) AS n
FROM some_table AS some_table_1 GROUP BY some_table_1.a ORDER BY n DESC
"""

return src
Expand Down
18 changes: 0 additions & 18 deletions siuba/sql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,21 +190,3 @@ def simplify_sel(sel):

return clone_el


@contextmanager
def _use_simple_names():
from sqlalchemy import sql
from sqlalchemy.ext.compiler import compiles, deregister

get_col_name = lambda el, *args, **kwargs: str(el.element.name)
get_lab_name = lambda el, *args, **kwargs: str(el.element.name)
get_col_name = lambda el, *args, **kwargs: str(el.name)
compiles(sql.compiler._CompileLabel)(get_lab_name)
compiles(sql.elements.ColumnClause)(get_col_name)
compiles(sql.schema.Column)(get_col_name)
try:
yield 1
except:
pass
finally:
deregister(sql.compiler._CompileLabel)
69 changes: 41 additions & 28 deletions siuba/sql/verbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
_sql_add_columns,
_sql_with_only_columns,
_sql_simplify_select,
_use_simple_names,
MockConnection
)

Expand Down Expand Up @@ -297,7 +296,7 @@ def __init__(
self.tbl = self._create_table(tbl, columns, self.source)

# important states the query can be in (e.g. grouped)
self.ops = [self.tbl.select()] if ops is None else ops
self.ops = [self.tbl] if ops is None else ops

self.group_by = group_by
self.order_by = order_by
Expand Down Expand Up @@ -340,8 +339,21 @@ def get_ordered_col_names(self):


@property
def last_op(self):
return self.ops[-1] if len(self.ops) else None
def last_op(self) -> "sql.Table | sql.Select":
last_op = self.ops[-1]

if last_op is None:
raise TypeError()

return last_op

@property
def last_select(self):
last_op = self.last_op
if not isinstance(last_op, sql.selectable.SelectBase):
return last_op.select()

return last_op

@staticmethod
def _create_table(tbl, columns = None, source = None):
Expand Down Expand Up @@ -385,7 +397,7 @@ def _create_table(tbl, columns = None, source = None):

def _get_preview(self):
# need to make prev op a cte, so we don't override any previous limit
new_sel = self.last_op.alias().select().limit(5)
new_sel = self.last_select.limit(5)
tbl_small = self.append_op(new_sel)
return collect(tbl_small)

Expand Down Expand Up @@ -450,13 +462,12 @@ def _show_query(tbl, simplify = False, return_table = True):

if simplify:
# try to strip table names and labels where unnecessary
simple_sel = _sql_simplify_select(tbl.last_op)
simple_sel = _sql_simplify_select(tbl.last_select)

with _use_simple_names():
explained = compile_query(simple_sel)
explained = compile_query(simple_sel)
else:
# use a much more verbose query
explained = compile_query(tbl.last_op)
explained = compile_query(tbl.last_select)

if return_table:
print(str(explained))
Expand All @@ -483,13 +494,13 @@ def _collect(__data, as_df = True):
if _is_dialect_duckdb(__data.source):
# TODO: can be removed once next release of duckdb fixes:
# https://github.com/duckdb/duckdb/issues/2972
query = __data.last_op
query = __data.last_select
compiled = query.compile(
dialect = __data.source.dialect,
compile_kwargs = {"literal_binds": True}
)
else:
compiled = __data.last_op
compiled = __data.last_select

# execute query ----

Expand Down Expand Up @@ -519,8 +530,8 @@ def _select(__data, *args, **kwargs):
"Using kwargs in select not currently supported. "
"Use _.newname == _.oldname instead"
)
last_op = __data.last_op
columns = {c.key: c for c in last_op.inner_columns}
last_sel = __data.last_select
columns = {c.key: c for c in last_sel.inner_columns}

# same as for DataFrame
colnames = Series(list(columns))
Expand All @@ -541,7 +552,7 @@ def _select(__data, *args, **kwargs):
col_list.append(col if v is None else col.label(v))

return __data.append_op(
last_op.with_only_columns(col_list),
last_sel.with_only_columns(col_list),
group_by = group_keys
)

Expand Down Expand Up @@ -610,7 +621,10 @@ def _mutate(__data, **kwargs):
# TODO: verify it can follow a renaming select

# track labeled columns in set
sel = __data.last_op
if not len(kwargs):
return __data.append_op(__data.last_op)

sel = __data.last_select

# evaluate each call
for colname, func in kwargs.items():
Expand Down Expand Up @@ -664,7 +678,7 @@ def _transmute(__data, **kwargs):
# transmute keeps grouping cols, and any defined in kwargs
cols_to_keep = ordered_union(__data.group_by, kwargs)

sel = f_mutate(__data, **kwargs).last_op
sel = f_mutate(__data, **kwargs).last_select

columns = lift_inner_cols(sel)
sel_stripped = sel.with_only_columns([columns[k] for k in cols_to_keep])
Expand All @@ -679,8 +693,8 @@ def _arrange(__data, *args):
# and handle when new columns are named the same as order by vars.
# see: https://dba.stackexchange.com/q/82930

last_op = __data.last_op
cols = lift_inner_cols(last_op)
last_sel = __data.last_select
cols = lift_inner_cols(last_sel)


new_calls = []
Expand All @@ -700,7 +714,7 @@ def _arrange(__data, *args):
sort_cols = _create_order_by_clause(cols, *new_calls)

order_by = __data.order_by + tuple(new_calls)
return __data.append_op(last_op.order_by(*sort_cols), order_by = order_by)
return __data.append_op(last_sel.order_by(*sort_cols), order_by = order_by)


# TODO: consolidate / pull expr handling funcs into own file?
Expand Down Expand Up @@ -746,8 +760,7 @@ def _count(__data, *args, sort = False, wt = None, **kwargs):
)
arg_names.append(name)

tbl_inner = mutate(__data, **kwargs)
sel_inner = tbl_inner.last_op
sel_inner = mutate(__data, **kwargs).last_op
group_cols = arg_names + list(kwargs)

# create outer select ----
Expand All @@ -756,7 +769,7 @@ def _count(__data, *args, sort = False, wt = None, **kwargs):
inner_cols = sel_inner_cte.columns

# apply any group vars from a group_by verb call first
tbl_group_cols = [inner_cols[k] for k in tbl_inner.group_by]
tbl_group_cols = [inner_cols[k] for k in __data.group_by]
count_group_cols = [inner_cols[k] for k in group_cols]

# combine with any defined in the count verb call
Expand All @@ -769,7 +782,7 @@ def _count(__data, *args, sort = False, wt = None, **kwargs):
.group_by(*outer_group_cols)

# count is like summarize, so removes order_by
return tbl_inner.append_op(
return __data.append_op(
sel_outer.order_by(count_col.desc()),
order_by = tuple()
)
Expand All @@ -778,7 +791,7 @@ def _count(__data, *args, sort = False, wt = None, **kwargs):
@add_count.register(LazyTbl)
def _add_count(__data, *args, wt = None, sort = False, **kwargs):
counts = count(__data, *args, wt = wt, sort = sort, **kwargs)
by = list(c.name for c in counts.last_op.inner_columns)[:-1]
by = list(c.name for c in counts.last_select.inner_columns)[:-1]

return inner_join(__data, counts, by = by)

Expand All @@ -789,7 +802,7 @@ def _summarize(__data, **kwargs):
# what if windowed mutate or filter has been done?
# - filter is fine, since it uses a CTE
# - need to detect any window functions...
old_sel = __data.last_op._clone()
old_sel = __data.last_select._clone()

new_calls = {}
for k, expr in kwargs.items():
Expand Down Expand Up @@ -1136,7 +1149,7 @@ def _create_join_conds(left_sel, right_sel, on):

@head.register(LazyTbl)
def _head(__data, n = 5):
sel = __data.last_op
sel = __data.last_select

return __data.append_op(sel.limit(n))

Expand All @@ -1145,7 +1158,7 @@ def _head(__data, n = 5):

@rename.register(LazyTbl)
def _rename(__data, **kwargs):
sel = __data.last_op
sel = __data.last_select
columns = lift_inner_cols(sel)

# old_keys uses dict as ordered set
Expand All @@ -1172,7 +1185,7 @@ def _distinct(__data, *args, _keep_all = False, **kwargs):
if (args or kwargs) and _keep_all:
raise NotImplementedError("Distinct with variables specified in sql requires _keep_all = False")

inner_sel = mutate(__data, **kwargs).last_op if kwargs else __data.last_op
inner_sel = mutate(__data, **kwargs).last_select if kwargs else __data.last_select

# TODO: this is copied from the df distinct version
# cols dict below is used as ordered set
Expand Down
6 changes: 3 additions & 3 deletions siuba/tests/test_verb_show_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,15 @@ def test_show_query_basic_simplify(df_tiny):
q = df_tiny >> mutate(a = _.x.mean()) >> show_query(return_table = False, simplify=True)

assert rename_source(str(q), df_tiny) == """\
SELECT *, avg(x) OVER () AS a
SELECT *, avg(SRC_TBL.x) OVER () AS a
FROM SRC_TBL"""

def test_show_query_complex_simplify(df_wide):
q = df_wide >> mutate(a = _.x.mean(), b = _.a.mean())
res = q >> show_query(return_table = False, simplify=True)

assert rename_source(str(res), df_wide) == """\
SELECT *, avg(a) OVER () AS b
FROM (SELECT *, avg(x) OVER () AS a
SELECT *, avg(anon_1.a) OVER () AS b
FROM (SELECT *, avg(SRC_TBL.x) OVER () AS a
FROM SRC_TBL) AS anon_1"""

0 comments on commit ddc63af

Please sign in to comment.