Skip to content

Commit

Permalink
Merge pull request #105 from machow/feat-count-total
Browse files Browse the repository at this point in the history
feat: count with no groups is like one big group
  • Loading branch information
machow authored Aug 8, 2019
2 parents 17e078b + 5f1b001 commit f5b5eb4
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 8 deletions.
46 changes: 39 additions & 7 deletions siuba/dply/verbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,13 @@ def simple_varname(call):
return None


def ordered_union(x, y):
# TODO: duplicated in sql file
dx = {el: True for el in x}
dy = {el: True for el in y}

return tuple({**dx, **dy})

# Symbolic Wrapper ============================================================

from functools import wraps
Expand Down Expand Up @@ -257,7 +264,7 @@ def _mutate(__data, **kwargs):
# Group By ====================================================================

@singledispatch2((pd.DataFrame, DataFrameGroupBy))
def group_by(__data, *args, **kwargs):
def group_by(__data, *args, add = False, **kwargs):
tmp_df = mutate(__data, **kwargs) if kwargs else __data

by_vars = list(map(simple_varname, args))
Expand All @@ -266,6 +273,11 @@ def group_by(__data, *args, **kwargs):

by_vars.extend(kwargs.keys())

if isinstance(tmp_df, DataFrameGroupBy) and add:
prior_groups = [el.name for el in __data.grouper.groupings]
all_groups = ordered_union(prior_groups, by_vars)
return tmp_df.obj.groupby(list(all_groups))

return tmp_df.groupby(by = by_vars)


Expand Down Expand Up @@ -767,18 +779,38 @@ def _count_group(data, *args):
return


@singledispatch2(pd.DataFrame)
@singledispatch2((pd.DataFrame, DataFrameGroupBy))
def count(__data, *args, wt = None, sort = False, **kwargs):
# TODO: if expr, works like mutate
"""Return the number of rows for each grouping of data.
Args:
__data: a DataFrame
*args: the names of columns to be used for grouping. Passed to group_by.
wt: the name of a column to use as a weighted for each row.
sort: whether to sort the results in descending order.
**kwargs: creates a new named column, and uses for grouping. Passed to group_by.
"""
no_grouping_vars = not args and not kwargs and isinstance(__data, pd.DataFrame)

#group by args
if wt is None:
counts = group_by(__data, *args, **kwargs).size().reset_index()
if no_grouping_vars:
# no groups, just use number of rows
counts = pd.DataFrame({'tmp': [__data.shape[0]]})
else:
# tally rows for each group
counts = group_by(__data, *args, add = True, **kwargs).size().reset_index()
else:
wt_col = simple_varname(wt)
if wt_col is None:
raise Exception("wt argument has to be simple column name")
counts = group_by(__data, *args, **kwargs)[wt_col].sum().reset_index()

if no_grouping_vars:
# no groups, sum weights
counts = pd.DataFrame({'tmp': [__data[wt_col].sum()]})
else:
# do weighted tally
counts = group_by(__data, *args, add = True, **kwargs)[wt_col].sum().reset_index()


# count col named, n. If that col already exists, add more "n"s...
Expand All @@ -790,7 +822,7 @@ def count(__data, *args, wt = None, sort = False, **kwargs):
counts.rename(columns = {counts.columns[-1]: out_col}, inplace = True)

if sort:
return counts.sort_values(out_col, ascending = False)
return counts.sort_values(out_col, ascending = False).reset_index(drop = True)

return counts

Expand Down
5 changes: 4 additions & 1 deletion siuba/sql/verbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,11 +510,14 @@ def _create_order_by_clause(columns, *args):


@count.register(LazyTbl)
def _count(__data, *args, sort = False, **kwargs):
def _count(__data, *args, sort = False, wt = None, **kwargs):
# TODO: if already col named n, use name nn, etc.. get logic from tidy.py
if kwargs:
raise NotImplementedError("TODO")

if wt is not None:
raise NotImplementedError("TODO")

# similar to filter verb, we need two select statements,
# an inner one for derived cols, and outer to group by them
sel = __data.last_op.alias()
Expand Down
40 changes: 40 additions & 0 deletions siuba/tests/test_verb_count.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,23 @@
"""

from siuba import _, group_by, summarize, count
import pandas as pd

import pytest
from .helpers import assert_equal_query, data_frame, backend_notimpl, backend_sql

DATA = data_frame(x = [1,2,3,4], g = ['a', 'a', 'b', 'b'])
DATA2 = data_frame(x = [1,2,3,4], g = ['a', 'a', 'b', 'b'], h = ['c', 'c', 'd', 'd'])

@pytest.fixture(scope = "module")
def df(backend):
return backend.load_df(DATA)

@pytest.fixture(scope = "module")
def df2(backend):
return backend.load_df(DATA2)


@pytest.mark.parametrize("query, output", [
(count(_.g), data_frame(g = ['a', 'b'], n = [2, 2])),
(count("g"), data_frame(g = ['a', 'b'], n = [2, 2])),
Expand Down Expand Up @@ -42,3 +49,36 @@ def test_count_with_kwarg_expression(df):
pd.DataFrame({"y": [0], "n": [4]})
)

@backend_notimpl("sqlite", "postgresql") # see (#104)
def test_count_wt(backend, df):
assert_equal_query(
df,
count(_.g, wt = _.x),
pd.DataFrame({'g': ['a', 'b'], 'n': [1 + 2, 3 + 4]})
)

def test_count_no_groups(df):
# count w/ no groups returns ttl
assert_equal_query(
df,
count(),
pd.DataFrame({'n': [4]})
)

@backend_notimpl("sqlite", "postgresql") # see (#104)
def test_count_no_groups_wt(backend, df):
assert_equal_query(
df,
count(wt = _.x),
pd.DataFrame({'n': [sum([1,2,3,4])]})
)


def test_count_on_grouped_df(df2):
assert_equal_query(
df2,
group_by(_.g) >> count(_.h),
pd.DataFrame({'g': ['a', 'b'], 'h': ['c', 'd'], 'n': [2,2]})
)


0 comments on commit f5b5eb4

Please sign in to comment.