diff --git a/siuba/dply/verbs.py b/siuba/dply/verbs.py index 1a1f6a1c..7d573f78 100644 --- a/siuba/dply/verbs.py +++ b/siuba/dply/verbs.py @@ -108,6 +108,13 @@ def simple_varname(call): return None +def ordered_union(x, y): + # TODO: duplicated in sql file + dx = {el: True for el in x} + dy = {el: True for el in y} + + return tuple({**dx, **dy}) + # Symbolic Wrapper ============================================================ from functools import wraps @@ -257,7 +264,7 @@ def _mutate(__data, **kwargs): # Group By ==================================================================== @singledispatch2((pd.DataFrame, DataFrameGroupBy)) -def group_by(__data, *args, **kwargs): +def group_by(__data, *args, add = False, **kwargs): tmp_df = mutate(__data, **kwargs) if kwargs else __data by_vars = list(map(simple_varname, args)) @@ -266,6 +273,11 @@ def group_by(__data, *args, **kwargs): by_vars.extend(kwargs.keys()) + if isinstance(tmp_df, DataFrameGroupBy) and add: + prior_groups = [el.name for el in __data.grouper.groupings] + all_groups = ordered_union(prior_groups, by_vars) + return tmp_df.obj.groupby(list(all_groups)) + return tmp_df.groupby(by = by_vars) @@ -767,18 +779,38 @@ def _count_group(data, *args): return -@singledispatch2(pd.DataFrame) +@singledispatch2((pd.DataFrame, DataFrameGroupBy)) def count(__data, *args, wt = None, sort = False, **kwargs): - # TODO: if expr, works like mutate + """Return the number of rows for each grouping of data. + + Args: + __data: a DataFrame + *args: the names of columns to be used for grouping. Passed to group_by. + wt: the name of a column to use as a weighted for each row. + sort: whether to sort the results in descending order. + **kwargs: creates a new named column, and uses for grouping. Passed to group_by. + + """ + no_grouping_vars = not args and not kwargs and isinstance(__data, pd.DataFrame) - #group by args if wt is None: - counts = group_by(__data, *args, **kwargs).size().reset_index() + if no_grouping_vars: + # no groups, just use number of rows + counts = pd.DataFrame({'tmp': [__data.shape[0]]}) + else: + # tally rows for each group + counts = group_by(__data, *args, add = True, **kwargs).size().reset_index() else: wt_col = simple_varname(wt) if wt_col is None: raise Exception("wt argument has to be simple column name") - counts = group_by(__data, *args, **kwargs)[wt_col].sum().reset_index() + + if no_grouping_vars: + # no groups, sum weights + counts = pd.DataFrame({'tmp': [__data[wt_col].sum()]}) + else: + # do weighted tally + counts = group_by(__data, *args, add = True, **kwargs)[wt_col].sum().reset_index() # count col named, n. If that col already exists, add more "n"s... @@ -790,7 +822,7 @@ def count(__data, *args, wt = None, sort = False, **kwargs): counts.rename(columns = {counts.columns[-1]: out_col}, inplace = True) if sort: - return counts.sort_values(out_col, ascending = False) + return counts.sort_values(out_col, ascending = False).reset_index(drop = True) return counts diff --git a/siuba/sql/verbs.py b/siuba/sql/verbs.py index a60bb898..178b7672 100644 --- a/siuba/sql/verbs.py +++ b/siuba/sql/verbs.py @@ -510,11 +510,14 @@ def _create_order_by_clause(columns, *args): @count.register(LazyTbl) -def _count(__data, *args, sort = False, **kwargs): +def _count(__data, *args, sort = False, wt = None, **kwargs): # TODO: if already col named n, use name nn, etc.. get logic from tidy.py if kwargs: raise NotImplementedError("TODO") + if wt is not None: + raise NotImplementedError("TODO") + # similar to filter verb, we need two select statements, # an inner one for derived cols, and outer to group by them sel = __data.last_op.alias() diff --git a/siuba/tests/test_verb_count.py b/siuba/tests/test_verb_count.py index a63b8500..1d8a37f9 100644 --- a/siuba/tests/test_verb_count.py +++ b/siuba/tests/test_verb_count.py @@ -5,16 +5,23 @@ """ from siuba import _, group_by, summarize, count +import pandas as pd import pytest from .helpers import assert_equal_query, data_frame, backend_notimpl, backend_sql DATA = data_frame(x = [1,2,3,4], g = ['a', 'a', 'b', 'b']) +DATA2 = data_frame(x = [1,2,3,4], g = ['a', 'a', 'b', 'b'], h = ['c', 'c', 'd', 'd']) @pytest.fixture(scope = "module") def df(backend): return backend.load_df(DATA) +@pytest.fixture(scope = "module") +def df2(backend): + return backend.load_df(DATA2) + + @pytest.mark.parametrize("query, output", [ (count(_.g), data_frame(g = ['a', 'b'], n = [2, 2])), (count("g"), data_frame(g = ['a', 'b'], n = [2, 2])), @@ -42,3 +49,36 @@ def test_count_with_kwarg_expression(df): pd.DataFrame({"y": [0], "n": [4]}) ) +@backend_notimpl("sqlite", "postgresql") # see (#104) +def test_count_wt(backend, df): + assert_equal_query( + df, + count(_.g, wt = _.x), + pd.DataFrame({'g': ['a', 'b'], 'n': [1 + 2, 3 + 4]}) + ) + +def test_count_no_groups(df): + # count w/ no groups returns ttl + assert_equal_query( + df, + count(), + pd.DataFrame({'n': [4]}) + ) + +@backend_notimpl("sqlite", "postgresql") # see (#104) +def test_count_no_groups_wt(backend, df): + assert_equal_query( + df, + count(wt = _.x), + pd.DataFrame({'n': [sum([1,2,3,4])]}) + ) + + +def test_count_on_grouped_df(df2): + assert_equal_query( + df2, + group_by(_.g) >> count(_.h), + pd.DataFrame({'g': ['a', 'b'], 'h': ['c', 'd'], 'n': [2,2]}) + ) + +