Skip to content

Commit

Permalink
Merge pull request #234 from machow/feat-sql-count-kwargs
Browse files Browse the repository at this point in the history
Feat sql count kwargs
  • Loading branch information
machow authored May 20, 2020
2 parents 00c910b + eb02025 commit a3ca711
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 20 deletions.
11 changes: 11 additions & 0 deletions siuba/dply/verbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,17 @@
# * separate_rows
# * tally

def install_siu_methods(cls):
"""This function attaches siuba's table verbs on a class, to use as methods.
"""
func_dict = globals()
for func_name in DPLY_FUNCTIONS:
f = func_dict[func_name]

method_name = "siu_{}".format(func_name)
setattr(cls, method_name, f)

def install_pd_siu():
# https://github.com/coursera/pandas-ply/blob/master/pandas_ply/methods.py
func_dict = globals()
Expand Down
1 change: 1 addition & 0 deletions siuba/sql/dialects/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def sql_func_contains(col, pat, case = True, flags = 0, na = None, regex = True)
concat = sql.func.concat,
cat = sql.func.concat,
str_c = sql.func.concat,
__floordiv__ = lambda x, y: sql.cast(x / y, sa_types.Integer())
)

aggregate = SqlTranslator(
Expand Down
40 changes: 20 additions & 20 deletions siuba/sql/verbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,32 +589,28 @@ def _create_order_by_clause(columns, *args):
@count.register(LazyTbl)
def _count(__data, *args, sort = False, wt = None, **kwargs):
# TODO: if already col named n, use name nn, etc.. get logic from tidy.py
if kwargs:
raise NotImplementedError("TODO")

if wt is not None:
raise NotImplementedError("TODO")

res_name = "n"
# similar to filter verb, we need two select statements,
# an inner one for derived cols, and outer to group by them
sel = __data.last_op.alias()
sel_inner = sql.select([sel], from_obj = sel)

# inner select ----
# holds any mutation style columns
group_cols = []
arg_names = []
for arg in args:
col_name = simple_varname(arg)
if col_name is None:
# evaluate call
col_expr = arg(sel.columns) if callable(arg) else arg

# compile, so we can use the expr as its name (e.g. "id + 1")
col_name = str(compile_el(__data, col_expr))
label = col_expr.label(col_name)
sel_inner.append_column(label)
name = simple_varname(arg)
if name is None:
raise NotImplementedError(
"Count positional arguments must be single column name. "
"Use a named argument to count using complex expressions."
)
arg_names.append(name)

group_cols.append(col_name)
tbl_inner = mutate(__data, **kwargs)
sel_inner = tbl_inner.last_op
group_cols = arg_names + list(kwargs)

# outer select ----
# holds selected columns and tally (n)
Expand All @@ -623,7 +619,7 @@ def _count(__data, *args, sort = False, wt = None, **kwargs):
sel_outer = sql.select(from_obj = sel_inner_cte)

# apply any group vars from a group_by verb call first
prev_group_cols = [inner_cols[k] for k in __data.group_by]
prev_group_cols = [inner_cols[k] for k in tbl_inner.group_by]
if prev_group_cols:
sel_outer.append_group_by(*prev_group_cols)
sel_outer.append_column(*prev_group_cols)
Expand All @@ -633,10 +629,14 @@ def _count(__data, *args, sort = False, wt = None, **kwargs):
sel_outer.append_group_by(inner_cols[k])
sel_outer.append_column(inner_cols[k])

sel_outer.append_column(sql.functions.count().label("n"))
count_col = sql.functions.count().label(res_name)
sel_outer.append_column(count_col)

return __data.append_op(sel_outer)

# count is like summarize, so removes order_by
return tbl_inner.append_op(
sel_outer.order_by(count_col.desc()),
order_by = tuple()
)


@summarize.register(LazyTbl)
Expand Down

0 comments on commit a3ca711

Please sign in to comment.