Commit 8896eec

Merge pull request #398 from machow/feat-siu-dispatch
Feat siu dispatch
2 parents 76d6abc + a7f82a3 commit 8896eec

30 files changed, +806 -399 lines changed

.github/workflows/ci.yml

+6-6
@@ -1,7 +1,9 @@
 name: CI

 on:
+  workflow_dispatch:
   push:
+    branches: ['main', 'dev-*']
   pull_request:
   release:
     types: [published]
@@ -13,15 +15,10 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: [3.6, 3.7, 3.8]
+        python-version: [3.7, 3.8]
         requirements: ['-r requirements.txt']
         include:
           # historical requirements
-          - name: "2020-early dependencies"
-            requirements: numpy==1.17.4 pandas~=0.25.3 SQLAlchemy~=1.3.11 psycopg2~=2.8.4 PyMySQL==1.0.2
-            pytest_flags: --ignore=siuba/dply/forcats.py siuba
-            python-version: 3.6
-          # current
           - name: "2020-mid dependencies"
             python-version: 3.8
             requirements: numpy~=1.19.1 pandas~=1.1.0 SQLAlchemy~=1.3.18 psycopg2~=2.8.5 PyMySQL==1.0.2
@@ -52,6 +49,7 @@ jobs:
           python -m pip install --upgrade pip
           python -m pip install $REQUIREMENTS
           python -m pip install -r requirements-test.txt
+          python -m pip install snowflake-sqlalchemy==1.3.3
           python -m pip install .
         env:
           REQUIREMENTS: ${{ matrix.requirements }}
@@ -61,6 +59,8 @@ jobs:
         env:
           SB_TEST_PGPORT: 5432
           PYTEST_FLAGS: ${{ matrix.pytest_flags }}
+          SB_TEST_SNOWFLAKEPASSWORD: ${{ secrets.SB_TEST_SNOWFLAKEPASSWORD }}
+          SB_TEST_SNOWFLAKEHOST: ${{ secrets.SB_TEST_SNOWFLAKEHOST }}

       # optional step for running bigquery tests ----
       - name: Set up Cloud SDK

docs/developer/backend_sql.Rmd

+33-11
@@ -5,20 +5,26 @@ jupyter:
       extension: .Rmd
       format_name: rmarkdown
       format_version: '1.2'
-      jupytext_version: 1.4.2
+      jupytext_version: 1.13.7
   kernelspec:
-    display_name: Python 3
+    display_name: Python 3 (ipykernel)
     language: python
     name: python3
 ---

 ```{python nbsphinx=hidden}
 import pandas as pd
 pd.set_option("display.max_rows", 5)
+
+from siuba.siu.format import Formatter
+
+show_tree = lambda x: print(Formatter().format(x))
 ```

 # SQL backend

+> ⚠️: This document is being revised (though the code runs correctly!).
+

 ## Step 1: Column Translation

@@ -28,7 +34,8 @@ Column translation requires three pieces:
 1. **Locals:** Functions for creating the sqlalchemy clause corresponding to an
    operation.
 2. **Column Data:** Classes representing columns under normal and aggregate settings.
-3. **Translator:** A class that can take a symbolic expression (e.g. `_.x.mean()`) and return the correct sqlachemy clause.
+3. **Translator:** A class that can take a symbolic expression (e.g. `_.x.mean()`) and return it in call form: `mean(_.x)`.
+4. **Codata visitor:** A class that takes the above call, and swaps in the sql dialect version of each call.


 ```{python}
@@ -77,9 +84,10 @@ aggregation = {
 from siuba.sql.translate import SqlTranslator

 translator = SqlTranslator.from_mappings(
-    scalar, window, aggregation,
     WowSqlColumn, WowSqlColumnAgg
 )
+
+# TODO: how to work in codata visitor?
 ```

 ## Column Data
@@ -96,7 +104,7 @@ The entries of each local dictionary are functions that take a sqlalchemy.sql.Cl
 ```{python}
 from sqlalchemy import sql

-expr_rank = window["rank"](sql.column("a_col"))
+expr_rank = window["rank"](WowSqlColumn(), sql.column("a_col"))
 expr_rank
 ```
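The hunks above describe a two-stage flow: the translator rewrites a method-style expression such as `_.x.mean()` into call form, and a codata visitor then swaps in the implementation registered for the dialect's column class, which each local now receives as its first argument. A minimal stdlib-only sketch of that idea follows. `ToySqlColumn`, `MethodCall`, `FuncCall`, `translate`, and `visit_codata` are hypothetical names chosen for illustration; they are not siuba's classes or API.

```python
from dataclasses import dataclass
from functools import singledispatch


class ToySqlColumn:
    """Hypothetical stand-in for a dialect's column-data (codata) class."""


@singledispatch
def mean(codata, col):
    # Default implementation: no backend has been registered.
    raise NotImplementedError("mean is not implemented for this backend")


@mean.register(ToySqlColumn)
def _mean_sql(codata, col):
    # Dialect-specific version; takes the codata instance as its first argument.
    return f"avg({col}) OVER ()"


@dataclass(frozen=True)
class MethodCall:
    """Method form, e.g. `_.x.mean()`: a method name applied to a column name."""
    column: str
    method: str


@dataclass(frozen=True)
class FuncCall:
    """Call form, e.g. `mean(_.x)`: a generic function applied to a column name."""
    func: object
    column: str


GENERICS = {"mean": mean}


def translate(node: MethodCall) -> FuncCall:
    """Stage 1: rewrite method form into call form using the generic function."""
    return FuncCall(GENERICS[node.method], node.column)


def visit_codata(node: FuncCall, codata_cls: type):
    """Stage 2: swap the generic for the implementation registered for codata_cls."""
    impl = node.func.dispatch(codata_cls)
    return lambda columns: impl(codata_cls(), columns[node.column])


call = translate(MethodCall("x", "mean"))
run = visit_codata(call, ToySqlColumn)
result = run({"x": "tbl.x"})
print(result)  # avg(tbl.x) OVER ()
```

Keeping the two stages separate means the call-form tree stays backend-neutral, and only the second stage needs to know which column class is in play.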

@@ -111,6 +119,8 @@ Below, we set up a sqlalchemy select statement in order to demonstrate the trans

 ```{python}
 from siuba import _
+
+
 from sqlalchemy.sql import column, select

 sel = select([column('x'), column('y')])
@@ -120,26 +130,38 @@ Then we feed the columns to the translated call.

 ```{python}
 call_add = translator.translate(_.x + _.y)
-call_add(sel.columns)
+
+show_tree(call_add)
 ```

 Note that behind the scenes, the translator goes down the call tree and swaps functions like `"__add__"` with the local translations.

+```{python}
+from siuba.siu.visitors import CodataVisitor
+codata = CodataVisitor(WowSqlColumn, object)
+
+call_add_final = codata.visit(call_add)
+
+show_tree(call_add_final)
+```
+
 ```{python}
 # the root node is __add__. shown as +.
 _.x + _.y
 ```

 ```{python}
 # We can see this in action by calling the translation directly.
-scalar["__add__"](sel.columns.x, sel.columns.y)
+scalar["__add__"](WowSqlColumn(), sel.columns.x, sel.columns.y)
 ```

 By default the translate method assumes the expression is using window functions, so operations like `.mean()` return SqlAlchemy Over clauses.

 ```{python}
 f_translate = translator.translate(_.x.mean())
-expr = f_translate(sel.columns)
+
+f_translate_co = codata.visit(f_translate)
+expr = f_translate_co(sel.columns)

 expr
 ```
@@ -166,7 +188,7 @@ from siuba.siu import _, symbolic_dispatch
 from sqlalchemy import sql

 @symbolic_dispatch(cls = WowSqlColumn)
-def round(col):
+def round(self, col):
     print("running round function")

     return sql.function.round(col)
@@ -224,8 +246,8 @@ tbl_cars
 Note that you can access a number of useful attributes.

 ```{python}
-# the underlying translator
-f_add = tbl_cars.translator.translate(_.mpg + _.hp)
+# calls the underlying translator and codata
+f_add = tbl_cars.shape_call(_.mpg + _.hp)
 f_add(tbl_cars.last_op.columns)
 ```

setup.py

+4-2
@@ -46,14 +46,16 @@
             "gapminder==0.1",
         ],
     },
-    python_requires=">=3.6",
+    python_requires=">=3.7",
     include_package_data=True,
     long_description=README,
     long_description_content_type="text/markdown",
     classifiers=[
         'Programming Language :: Python :: 3',
-        'Programming Language :: Python :: 3.6',
         'Programming Language :: Python :: 3.7',
+        'Programming Language :: Python :: 3.8',
+        'Programming Language :: Python :: 3.9',
+        'Programming Language :: Python :: 3.10',
     ],
 )

siuba/__init__.py

+1-1
@@ -1,5 +1,5 @@
 # version ---------------------------------------------------------------------
-__version__ = "0.1.2"
+__version__ = "0.2.0.dev3"

 # default imports--------------------------------------------------------------
 from .siu import _, Lam

siuba/dply/verbs.py

+7
@@ -66,6 +66,13 @@ def install_pd_siu():
     DataFrameGroupBy.__repr__ = _repr_grouped_df_console_

 def _repr_grouped_df_html_(self):
+    obj_repr = self.obj._repr_html_()
+
+    # user can config pandas not to return html representation, in which case
+    # the ipython behavior should fall back to repr
+    if obj_repr is None:
+        return None
+
     return "<div><p>(grouped data frame)</p>" + self.obj._repr_html_() + "</div>"

 def _repr_grouped_df_console_(self):
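The guard added to `_repr_grouped_df_html_` relies on the convention, stated in the diff's own comment, that a `_repr_html_` method returning `None` lets IPython fall back to the plain `repr`. A minimal sketch of that pattern is below; `Inner` and `Wrapper` are hypothetical classes used only for illustration, not pandas or siuba objects.

```python
class Inner:
    """Hypothetical object whose HTML repr can be switched off."""

    def __init__(self, html_enabled: bool):
        self.html_enabled = html_enabled

    def _repr_html_(self):
        # Returning None signals "no HTML representation available".
        if not self.html_enabled:
            return None
        return "<table><tr><td>1</td></tr></table>"


class Wrapper:
    """Hypothetical wrapper that decorates the inner object's HTML repr."""

    def __init__(self, obj: Inner):
        self.obj = obj

    def _repr_html_(self):
        obj_repr = self.obj._repr_html_()
        if obj_repr is None:
            # Propagate None so the frontend can fall back to repr().
            return None
        return "<div><p>(grouped data frame)</p>" + obj_repr + "</div>"


with_html = Wrapper(Inner(html_enabled=True))._repr_html_()
without_html = Wrapper(Inner(html_enabled=False))._repr_html_()
```

The sketch reuses `obj_repr` in the final concatenation, which avoids rendering the inner object a second time; the committed code calls `self.obj._repr_html_()` again instead.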

siuba/experimental/pd_groups/dialect.py

+6-1
@@ -1,4 +1,4 @@
-from siuba.siu import CallTreeLocal, FunctionLookupError
+from siuba.siu import CallTreeLocal, FunctionLookupError, ExecutionValidatorVisitor
 from .groupby import SeriesGroupBy

 from .translate import (
@@ -99,6 +99,8 @@ def register_method(ns, op_name, f, is_property = False, accessor = None):
         call_props = ALL_PROPERTIES
         )

+call_validator = ExecutionValidatorVisitor(GroupByAgg, SeriesGroupBy)
+

 # Fast group by verbs =========================================================

@@ -123,6 +125,8 @@ def grouped_eval(__data, expr, require_agg = False):
     if isinstance(expr, Call):
         try:
             call = call_listener.enter(expr)
+            call_validator.visit(call)
+
         except FunctionLookupError as e:
             fallback_warning(expr, str(e))
             call = expr
@@ -162,6 +166,7 @@ def _transform_args(args):
         elif isinstance(expr, Call):
             try:
                 call = call_listener.enter(expr)
+                call_validator.visit(call)
                 out.append(call)
             except FunctionLookupError as e:
                 fallback_warning(expr, str(e))
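Both hunks place `call_validator.visit(call)` inside the existing `try` block, so a validation failure would take the same fallback path as a failed translation, assuming the validator reports problems with `FunctionLookupError`. The control flow can be sketched with the standard library alone; `LookupFailure`, `FAST_METHODS`, `translate`, `validate`, `slow_path`, and `resolve` are hypothetical names, not siuba functions.

```python
import warnings


class LookupFailure(Exception):
    """Hypothetical stand-in for siuba's FunctionLookupError."""


# Hypothetical registry of fast implementations.
FAST_METHODS = {"mean": lambda xs: sum(xs) / len(xs)}


def translate(name):
    """Return a fast implementation, or raise if none is registered."""
    try:
        return FAST_METHODS[name]
    except KeyError:
        raise LookupFailure(f"no fast implementation for {name!r}") from None


def validate(func):
    """Reject translated calls that cannot be executed (here: non-callables)."""
    if not callable(func):
        raise LookupFailure("translated call is not executable")


def slow_path(name):
    """Generic fallback that is always available."""
    return {"mean": lambda xs: sum(xs) / len(xs), "total": sum}[name]


def resolve(name):
    try:
        func = translate(name)
        validate(func)  # inside the try, so a failure also triggers the fallback
    except LookupFailure as e:
        warnings.warn(f"falling back to slow path: {e}")
        func = slow_path(name)
    return func
```

Calling `resolve("mean")` returns the fast implementation silently, while `resolve("total")` emits one warning and returns the fallback.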

siuba/experimental/pd_groups/test_pd_groups.py

+2-2
@@ -122,7 +122,7 @@ def test_transform_args():
 def test_fast_grouped_custom_user_funcs():
     @symbolic_dispatch
     def f(x):
-        return x.mean()
+        raise NotImplementedError()

     @f.register(SeriesGroupBy)
     def _f_grouped(x) -> GroupByAgg:
@@ -149,7 +149,7 @@ def test_fast_grouped_custom_user_func_fail():
     def f(x):
         return x.mean()

-    @f.register(GroupByAgg)
+    @f.register(SeriesGroupBy)
     def _f_gser(x):
         # note, no return annotation, so translator will raise an error
         return GroupByAgg.from_result(x.mean(), x)

siuba/ops/__init__.py

+9
@@ -1,4 +1,13 @@
 from .generics import ALL_OPS, PLAIN_OPS
+from .utils import _register_series_default
+
+# register default series methods on all operations
+for _generic in ALL_OPS.values():
+    _register_series_default(_generic)
+
+del _generic
+del _register_series_default
+

 # import accessor generics. These are included in ALL_OPS, but since we want
 # users to be able to import from them, also need to be modules. Start their
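The loop above registers a default implementation on every generic at import time and then deletes the loop variable and helper so they do not leak into the package namespace. The body of `_register_series_default` is not part of this diff, so the following is only a guess at the general shape, using `str` methods in place of pandas Series methods. `TOY_OPS`, `make_generic`, and `register_str_default` are hypothetical names.

```python
from functools import singledispatch


def make_generic(name):
    """Create a generic whose default implementation always fails."""

    @singledispatch
    def generic(x, *args, **kwargs):
        raise NotImplementedError(f"{name} is not implemented for {type(x).__name__}")

    generic.__name__ = name
    return generic


# Hypothetical registry, playing the role of ALL_OPS.
TOY_OPS = {name: make_generic(name) for name in ("upper", "lower")}


def register_str_default(generic):
    """Hypothetical helper: forward the generic to the str method of the same name."""
    method_name = generic.__name__

    @generic.register(str)
    def _default(x, *args, **kwargs):
        return getattr(x, method_name)(*args, **kwargs)

    return generic


# Register the default on every generic, then clean up the loop variable.
for _generic in TOY_OPS.values():
    register_str_default(_generic)

del _generic
```

After the loop, `TOY_OPS["upper"]("abc")` dispatches to `str.upper`, while a non-string argument still reaches the failing default.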

siuba/ops/support/base.py

+43-10
@@ -8,7 +8,7 @@
 from siuba.siu import FunctionLookupBound
 from siuba.sql.utils import get_dialect_translator

-SQL_BACKENDS = ["postgresql", "redshift", "sqlite", "mysql", "bigquery"]
+SQL_BACKENDS = ["postgresql", "redshift", "sqlite", "mysql", "bigquery", "snowflake"]
 ALL_BACKENDS = SQL_BACKENDS + ["pandas"]

 methods = pd.DataFrame(
@@ -32,24 +32,57 @@ def read_dialect(name):


 def read_sql_op(name, backend, translator):
-    f_win = translator.window.local.get(name)
-    f_agg = translator.aggregate.local.get(name)
-
-    # check FunctionLookupBound, a sentinal class for not implemented funcs
-    support = not (f_win is None or isinstance(f_win, FunctionLookupBound))
-    metadata = getattr(f_win, "operation", {})
+    # TODO: MC-NOTE - cleanup this code
+    from siuba.siu.visitors import CodataVisitor, FunctionLookupError
+    from siuba.ops.utils import Operation
+    co_win = CodataVisitor(translator.window.dispatch_cls)
+    co_agg = CodataVisitor(translator.aggregate.dispatch_cls)
+
+    disp_win = translator.window.local[name]
+    disp_agg = translator.aggregate.local[name]
+
+    try:
+        f_win = co_win.validate_dispatcher(disp_win, strict=False)
+        if isinstance(f_win, FunctionLookupBound):
+            win_supported = False
+        elif disp_win.dispatch(object) is f_win:
+            win_supported = False
+        else:
+            win_supported = True
+    except FunctionLookupError:
+        f_win = None
+        win_supported = False
+
+
+    try:
+        f_agg = co_agg.validate_dispatcher(disp_agg)
+        if isinstance(f_agg, FunctionLookupBound):
+            agg_supported = False
+        else:
+            agg_supported = True
+    except FunctionLookupError:
+        agg_supported = False

     # window functions should be a superset of agg functions
-    if f_win is None and f_agg is not None:
+    if f_win is None and agg_supported:
         raise Exception("agg functions in %s without window funcs: %s" %(backend, name))

-    if support and isinstance(f_agg, FunctionLookupBound):
+    if win_supported and not agg_supported:
+        flags = "no_aggregate"
+    elif agg_supported and not win_supported:
         flags = "no_mutate"
     else:
         flags = ""

-    meta = {"is_supported": support, "flags": flags, **metadata}
+    if win_supported or agg_supported:
+        metadata = getattr(f_win, "operation", {})
+        if isinstance(metadata, Operation):
+            metadata = {**vars(metadata)}
+        meta = {"is_supported": True, "flags": flags, **metadata}

+    else:
+        meta = {"is_supported": False, "flags": flags}
+
     return {"full_name": name, "backend": backend, "metadata": meta}
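The new branching in `read_sql_op` reduces to a small truth table over `win_supported` and `agg_supported`. The sketch below isolates that table. `support_flags` and `support_meta` are hypothetical helper names, and the operation metadata that the real function merges in is omitted.

```python
def support_flags(win_supported: bool, agg_supported: bool) -> str:
    """Mirror the flag selection from the diff."""
    if win_supported and not agg_supported:
        return "no_aggregate"
    if agg_supported and not win_supported:
        return "no_mutate"
    return ""


def support_meta(win_supported: bool, agg_supported: bool) -> dict:
    """An operation counts as supported if either form is available."""
    return {
        "is_supported": win_supported or agg_supported,
        "flags": support_flags(win_supported, agg_supported),
    }


# Enumerate all four combinations of (window, aggregate) support.
table = {
    (win, agg): support_meta(win, agg)
    for win in (True, False)
    for agg in (True, False)
}
```

Compared with the removed code, which derived `is_supported` from the window function alone, this table also marks an operation as supported when only the aggregate form exists.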
5487

5588

0 commit comments

Comments
 (0)