Skip to content

Commit

Permalink
Merge pull request #406 from machow/feat-sqlite-dialect
Browse files Browse the repository at this point in the history
feat: support sqlite window functions
  • Loading branch information
machow authored Mar 29, 2022
2 parents 8896eec + e8bdd46 commit b284d75
Show file tree
Hide file tree
Showing 12 changed files with 207 additions and 80 deletions.
43 changes: 34 additions & 9 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ jobs:
python -m pip install --upgrade pip
python -m pip install $REQUIREMENTS
python -m pip install -r requirements-test.txt
python -m pip install snowflake-sqlalchemy==1.3.3
python -m pip install .
env:
REQUIREMENTS: ${{ matrix.requirements }}
Expand All @@ -59,8 +58,6 @@ jobs:
env:
SB_TEST_PGPORT: 5432
PYTEST_FLAGS: ${{ matrix.pytest_flags }}
SB_TEST_SNOWFLAKEPASSWORD: ${{ secrets.SB_TEST_SNOWFLAKEPASSWORD }}
SB_TEST_SNOWFLAKEHOST: ${{ secrets.SB_TEST_SNOWFLAKEHOST }}

# optional step for running bigquery tests ----
- name: Set up Cloud SDK
Expand All @@ -78,19 +75,20 @@ jobs:
test-bigquery:
name: "Test BigQuery"
runs-on: ubuntu-latest
if: contains(github.ref, 'bigquery') || contains(github.ref, 'refs/tags')
if: ${{ contains('bigquery', github.ref) || !github.event.pull_request.draft }}
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
python-version: "3.8"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install -r requirements.txt
python -m pip install -r requirements-test.txt
python -m pip install git+https://github.com/machow/pybigquery.git pandas-gbq==0.15.0
python -m pip install pytest-parallel
python -m pip install sqlalchemy-bigquery==1.3.0 pandas-gbq==0.15.0
python -m pip install .
- name: Set up Cloud SDK
uses: google-github-actions/setup-gcloud@v0
Expand All @@ -100,10 +98,37 @@ jobs:
export_default_credentials: true
- name: Test with pytest
run: |
pytest siuba -m bigquery
# tests are mostly waiting on http requests to bigquery api
# note that test backends can cache data, so more processes
# is not always faster
pytest siuba -m bigquery --workers 2 --tests-per-worker 20
env:
SB_TEST_BQDATABASE: "ci_github"

test-snowflake:
name: "Test snowflake"
runs-on: ubuntu-latest
if: ${{ contains('snowflake', github.ref) || !github.event.pull_request.draft }}
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: "3.8"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install -r requirements.txt
python -m pip install -r requirements-test.txt
python -m pip install pytest-parallel
python -m pip install snowflake-sqlalchemy==1.3.3
python -m pip install .
- name: Test with pytest
run: |
pytest siuba -m snowflake --workers 2 --tests-per-worker 20
env:
SB_TEST_SNOWFLAKEPASSWORD: ${{ secrets.SB_TEST_SNOWFLAKEPASSWORD }}
SB_TEST_SNOWFLAKEHOST: ${{ secrets.SB_TEST_SNOWFLAKEHOST }}

build-docs:
name: "Build Documentation"
Expand Down Expand Up @@ -185,7 +210,7 @@ jobs:
name: "Deploy to PyPI"
runs-on: ubuntu-latest
if: github.event_name == 'release'
needs: [checks, test-bigquery]
needs: [checks, test-bigquery, test-snowflake]
steps:
- uses: actions/checkout@v2
- name: "Set up Python 3.8"
Expand Down
2 changes: 1 addition & 1 deletion pytest.ini
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[pytest]
# since bigquery takes a long time to execute,
# the tests are disabled by default.
addopts = --doctest-modules -m 'not bigquery'
addopts = --doctest-modules -m 'not bigquery and not snowflake'
markers =
skip_backend
38 changes: 13 additions & 25 deletions siuba/sql/dialects/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import sqlalchemy.sql.sqltypes as sa_types
from sqlalchemy import sql
from sqlalchemy.sql import func as fn
from . import _dt_generics as _dt

# Custom dispatching in call trees ============================================

Expand All @@ -23,30 +24,24 @@ def sql_floordiv(_, x, y):

# datetime ----

def _date_trunc(_, col, name):
@_dt.date_trunc.register
def _date_trunc(_: BigqueryColumn, col, name):
return fn.datetime_trunc(col, sql.text(name))

def sql_extract(field):
return lambda _, col: fn.extract(field, col)

def sql_is_first_of(name, reference):
def f(codata, col):
return _date_trunc(codata, col, name) == _date_trunc(codata, col, reference)

return f

def sql_func_last_day_in_period(_, col, period):
@_dt.sql_func_last_day_in_period.register
def sql_func_last_day_in_period(_: BigqueryColumn, col, period):
return fn.last_day(col, sql.text(period))

def sql_func_days_in_month(_, col):
return fn.extract('DAY', sql_func_last_day_in_period(col, 'MONTH'))

def sql_is_last_day_of(period):
def f(codata, col):
last_day = sql_func_last_day_in_period(codata, col, period)
return _date_trunc(codata, col, "DAY") == last_day
@_dt.sql_is_last_day_of.register
def sql_is_last_day_of(codata: BigqueryColumn, col, period):
last_day = sql_func_last_day_in_period(codata, col, period)
return _date_trunc(codata, col, "DAY") == last_day

return f

def sql_extract(field):
return lambda _, col: fn.extract(field, col)


# string ----
Expand Down Expand Up @@ -115,13 +110,6 @@ def f(_, col):
# bigquery has Sunday as 1, pandas wants Monday as 0
"dt.dayofweek": lambda _, col: fn.extract("DAYOFWEEK", col) - 2,
"dt.dayofyear": sql_extract("DAYOFYEAR"),
"dt.daysinmonth": sql_func_days_in_month,
"dt.days_in_month": sql_func_days_in_month,
"dt.is_month_end": sql_is_last_day_of("MONTH"),
"dt.is_month_start": sql_is_first_of("DAY", "MONTH"),
"dt.is_quarter_start": sql_is_first_of("DAY", "QUARTER"),
"dt.is_year_end": sql_is_last_day_of("YEAR"),
"dt.is_year_start": sql_is_first_of("DAY", "YEAR"),
"dt.month_name": lambda _, col: fn.format_date("%B", col),
"dt.week": sql_extract("ISOWEEK"),
"dt.weekday": lambda _, col: fn.extract("DAYOFWEEK", col) - 2,
Expand Down Expand Up @@ -154,7 +142,7 @@ def f(_, col):
all = sql_all(window = True),
count = lambda _, col: AggOver(fn.count(col)),
cumsum = win_cumul("sum"),
median = lambda _, col: RankOver(sql_median(col)),
median = lambda _, col: RankOver(fn.percentile_cont(col, .5)),
nunique = lambda _, col: AggOver(fn.count(fn.distinct(col))),
quantile = lambda _, col, q: RankOver(fn.percentile_cont(col, q)),
std = win_agg("stddev"),
Expand Down
159 changes: 151 additions & 8 deletions siuba/sql/dialects/sqlite.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
# sqlvariant, allow defining 3 namespaces to override defaults
from ..translate import (
SqlColumn, SqlColumnAgg, extend_base,
SqlTranslator
SqlTranslator,
sql_not_impl,
win_cumul,
win_agg,
annotate,
wrap_annotate
)

from .base import base_nowin
#from .postgresql import PostgresqlColumn as SqlColumn, PostgresqlColumnAgg as SqlColumnAgg
from . import _dt_generics as _dt

import sqlalchemy.sql.sqltypes as sa_types
from sqlalchemy import sql
from sqlalchemy.sql import func as fn

# Custom dispatching in call trees ============================================

Expand All @@ -16,18 +24,153 @@
class SqliteColumn(SqlColumn): pass
class SqliteColumnAgg(SqlColumnAgg, SqliteColumn): pass

scalar = extend_base(

# Translations ================================================================

# fix some annotations --------------------------------------------------------

# Note this is taken from the postgres dialect, but it seems that there are 2 key points
# compared to postgresql, which always returns a float
# * sqlite date parts are returned as floats
# * sqlite time parts are returned as integers
def returns_float(func_names):
# TODO: MC-NOTE - shift all translations to directly register
# TODO: MC-NOTE - make an AliasAnnotated class or something, that signals
# it is using another method, but w/ an updated annotation.
from siuba.ops import ALL_OPS

for name in func_names:
generic = ALL_OPS[name]
f_concrete = generic.dispatch(SqlColumn)
f_annotated = wrap_annotate(f_concrete, result_type="float")
generic.register(SqliteColumn, f_annotated)

# detect first and last date (similar to the mysql dialect) -------------------

@annotate(return_type="float")
def sql_extract(name):
if name == "quarter":
# division in sqlite automatically rounds down
# so for jan, 1 + 2 = 3, and 3 / 1 is Q1
return lambda _, col: (fn.strftime("%m", col) + 2) / 3
return lambda _, col: fn.extract(name, col)


@_dt.sql_is_last_day_of.register
def _sql_is_last_day_of(codata: SqliteColumn, col, period):
valid_periods = {"month", "year"}
if period not in valid_periods:
raise ValueError(f"Period must be one of {valid_periods}")

incr = f"+1 {period}"

target_date = fn.date(col, f'start of {period}', incr, "-1 day")
return col == target_date


@_dt.sql_is_first_day_of.register
def _sql_is_first_day_of(codata: SqliteColumn, col, period):
valid_periods = {"month", "year"}
if period not in valid_periods:
raise ValueError(f"Period must be one of {valid_periods}")

target_date = fn.date(col, f'start of {period}')
return fn.date(col) == target_date


# date part of period calculations --------------------------------------------

def sql_days_in_month(_, col):
date_last_day = fn.date(col, 'start of month', '+1 month', '-1 day')
return fn.strftime("%d", date_last_day).cast(sa_types.Integer())


def sql_week_of_year(_, col):
# convert sqlite week to ISO week
# adapted from: https://stackoverflow.com/a/15511864
iso_dow = (fn.strftime("%j", fn.date(col, "-3 days", "weekday 4")) - 1)

return (iso_dow / 7) + 1


# misc ------------------------------------------------------------------------

@annotate(result_type = "float")
def sql_round(_, col, n):
return sql.func.round(col, n)


def sql_func_truediv(_, x, y):
return sql.cast(x, sa_types.Float()) / y


def between(_, col, x, y):
res = col.between(x, y)

# tell sqlalchemy the result is a boolean. this causes it to be correctly
# converted from an integer to bool when the results are collected.
# note that this is consistent with what col == col returns
res.type = sa_types.Boolean()
return res

def sql_str_capitalize(_, col):
# capitalize first letter, then concatenate with lowercased rest
first_upper = fn.upper(fn.substr(col, 1, 1))
rest_lower = fn.lower(fn.substr(col, 2))
return first_upper.concat(rest_lower)

extend_base(
SqliteColumn,
)

aggregate = extend_base(
SqliteColumnAgg,
)
between = between,
clip = sql_not_impl("sqlite does not have a least or greatest function."),

div = sql_func_truediv,
divide = sql_func_truediv,
rdiv = lambda _, x,y: sql_func_truediv(_, y, x),

window = extend_base(
__truediv__ = sql_func_truediv,
truediv = sql_func_truediv,
__rtruediv__ = lambda _, x, y: sql_func_truediv(_, y, x),

round = sql_round,
__round__ = sql_round,

**{
"str.title": sql_not_impl("TODO"),
"str.capitalize": sql_str_capitalize,
},

**{
"dt.quarter": sql_extract("quarter"),
"dt.is_quarter_start": sql_not_impl("TODO"),
"dt.is_quarter_end": sql_not_impl("TODO"),
"dt.days_in_month": sql_days_in_month,
"dt.daysinmonth": sql_days_in_month,
"dt.week": sql_week_of_year,
"dt.weekofyear": sql_week_of_year,

}
)

returns_float([
"dt.dayofweek",
"dt.weekday",
])


extend_base(
SqliteColumn,
**base_nowin
# TODO: should check sqlite version, since < 3.25 can't use windows
cumsum = win_cumul("sum"),

quantile = sql_not_impl("sqlite does not support ordered set aggregates"),
sum = win_agg("sum"),
)

extend_base(
SqliteColumnAgg,
quantile = sql_not_impl("sqlite does not support ordered set aggregates"),
)


Expand Down
Loading

0 comments on commit b284d75

Please sign in to comment.