Skip to content

Commit

Permalink
Merge pull request #28 from machow/feat-sql-tests
Browse files Browse the repository at this point in the history
Feat sql tests
  • Loading branch information
machow authored May 11, 2019
2 parents a9d1d31 + 548d919 commit 4921fd8
Show file tree
Hide file tree
Showing 12 changed files with 188 additions and 21 deletions.
5 changes: 5 additions & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,10 @@ install:
- pip install -r requirements.txt
- pip install .
# command to run tests
services:
- postgresql
env:
global:
- PGPORT=5432
script:
- make test-travis
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ NOTEBOOK_TESTS=$(addprefix examples/, examples-dplyr-funcs.ipynb case-iris-selec

test:
py.test --nbval $(NOTEBOOK_TESTS)
py.test
pytest --dbs="sqlite,postgresql" siuba/tests

test-travis:
py.test --nbval $(filter-out %postgres.ipynb, $(NOTEBOOK_TESTS))
py.test
pytest --dbs="sqlite,postgresql" siuba/tests

examples/%.ipynb:
jupyter nbconvert --to notebook --inplace --execute $@
Expand Down
2 changes: 1 addition & 1 deletion docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@ services:
image: postgres
restart: always
environment:
POSTGRES_PASSWORD: example
POSTGRES_PASSWORD: ""
ports:
- 5433:5432
17 changes: 6 additions & 11 deletions examples/examples-postgres.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,10 @@
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/machow/.virtualenvs/siuba/lib/python3.6/site-packages/psycopg2/__init__.py:144: UserWarning: The psycopg2 wheel package will be renamed from release 2.8; in order to keep installing from binary please use \"pip install psycopg2-binary\" instead. For details see: <http://initd.org/psycopg/docs/install.html#binary-install-from-pypi>.\n",
" \"\"\")\n"
]
},
{
"data": {
"text/plain": [
"<sqlalchemy.engine.result.ResultProxy at 0x10aba65f8>"
"<sqlalchemy.engine.result.ResultProxy at 0x1080c3be0>"
]
},
"execution_count": 1,
Expand All @@ -38,7 +30,10 @@
"from sqlalchemy import sql\n",
"from sqlalchemy import Table, Column, Integer, String, MetaData, ForeignKey\n",
"from sqlalchemy import create_engine\n",
"engine = create_engine('postgresql://postgres:example@localhost:5433/postgres', echo=False)\n",
"import os\n",
"\n",
"port = os.environ.get(\"PGPORT\", \"5433\")\n",
"engine = create_engine('postgresql://postgres:@localhost:%s/postgres'%port, echo=False)\n",
"\n",
"\n",
"metadata = MetaData()\n",
Expand Down Expand Up @@ -1009,7 +1004,7 @@
"data": {
"text/plain": [
"█─'__call__'\n",
"├─<function case_when at 0x11369ac80>\n",
"├─<function case_when at 0x11065c0d0>\n",
"├─_\n",
"└─█─'<lazy>'\n",
" └─█─'__call__'\n",
Expand Down
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ pytz==2018.9
six==1.12.0
SQLAlchemy==1.2.17
nbval==0.9.1
# tests
psycopg2==2.8.2
# only used for iris dataset
scikit-learn==0.20.2
# used for docs
Expand Down
9 changes: 7 additions & 2 deletions siuba/dply/verbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,7 @@ def arrange(__data, *args):

@singledispatch2(DataFrame)
def distinct(__data, *args, _keep_all = False, **kwargs):
# using dict as ordered set
cols = {simple_varname(x): True for x in args}
if None in cols:
raise Exception("positional arguments must be simple column, "
Expand All @@ -629,10 +630,14 @@ def distinct(__data, *args, _keep_all = False, **kwargs):

# mutate kwargs
cols.update(kwargs)
tmp_data = mutate(__data, **kwargs).drop_duplicates(cols)

# special case: use all variables when none are specified
if not len(cols): cols = __data.columns

tmp_data = mutate(__data, **kwargs).drop_duplicates(list(cols)).reset_index(drop = True)

if not _keep_all:
return tmp_data[cols]
return tmp_data[list(cols)]

return tmp_data

Expand Down
9 changes: 4 additions & 5 deletions siuba/sql/verbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,6 @@ def lift_inner_cols(tbl):

return sql.base.ImmutableColumnCollection(data, cols)

def is_grouped_sel(select):
return False

def has_windows(clause):
windows = []
append_win = lambda col: windows.append(col)
Expand Down Expand Up @@ -618,8 +615,8 @@ def _rename(__data, **kwargs):

@distinct.register(LazyTbl)
def _distinct(__data, *args, _keep_all = False, **kwargs):
if _keep_all:
raise NotImplementedError("Distinct in sql requires _keep_all = True")
if (args or kwargs) and _keep_all:
raise NotImplementedError("Distinct with variables specified in sql requires _keep_all = False")

inner_sel = mutate(__data, **kwargs).last_op if kwargs else __data.last_op

Expand All @@ -633,6 +630,8 @@ def _distinct(__data, *args, _keep_all = False, **kwargs):
"e.g. _.colname or _['colname']"
)

if not cols: cols = list(inner_sel.columns.keys())

sel_cols = lift_inner_cols(inner_sel)
distinct_cols = [sel_cols[k] for k in cols]

Expand Down
Empty file added siuba/tests/__init__.py
Empty file.
6 changes: 6 additions & 0 deletions siuba/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import pytest

def pytest_addoption(parser):
    """Register the ``--dbs`` command line option.

    The option stores a single string (comma separated database names)
    and falls back to "sqlite" when it is not given.
    """
    option_settings = dict(
        action="store",
        default="sqlite",
        help="databases tested against (comma separated)",
    )
    parser.addoption("--dbs", **option_settings)
67 changes: 67 additions & 0 deletions siuba/tests/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from sqlalchemy import create_engine, types
from siuba.sql import LazyTbl, collect
from pandas.testing import assert_frame_equal

class DbConRegistry:
    """Registry of named database engines that test data is loaded into.

    Engines are stored in ``self.connections`` keyed by the name they were
    registered under.
    """

    # Counter kept on the class (not the instance), so generated table names
    # keep increasing across every registry created in the same process.
    table_name_indx = 0

    def __init__(self):
        self.connections = {}

    def register(self, name, engine):
        """Store ``engine`` under ``name``, replacing any earlier entry."""
        self.connections[name] = engine

    def remove(self, name):
        """Release the engine registered under ``name`` and forget it.

        Returns the removed object. Raises ``KeyError`` when ``name`` was
        never registered.
        """
        con = self.connections[name]

        # SQLAlchemy engines are released with dispose() and have no close();
        # fall back to close() for connection-like objects.
        release = getattr(con, "dispose", None)
        if release is None:
            release = con.close
        release()

        del self.connections[name]

        return con

    @classmethod
    def unique_table_name(cls):
        """Return the next table name in the sequence siuba_001, siuba_002, ..."""
        cls.table_name_indx += 1
        return "siuba_{0:03d}".format(cls.table_name_indx)

    def load_df(self, df):
        """Copy ``df`` into every registered engine.

        Each copy goes to a freshly generated table name. Returns a list with
        one lazy table per engine, in registration order.
        """
        return [
            copy_to_sql(df, self.unique_table_name(), engine)
            for engine in self.connections.values()
        ]

def assert_frame_sort_equal(a, b):
    """Assert two DataFrames hold the same rows, ignoring row order."""
    def _canonical(frame):
        # order rows by every column, then throw away the original index
        ordered = frame.sort_values(by = frame.columns.tolist())
        return ordered.reset_index(drop = True)

    assert_frame_equal(_canonical(a), _canonical(b))

def assert_equal_query(tbls, lazy_query, target):
    """Check that ``lazy_query`` produces ``target`` for every table in ``tbls``.

    The query is applied to each table, collected, and compared against
    ``target`` without regard to row order.
    """
    for lazy_tbl in tbls:
        query = lazy_query(lazy_tbl)
        result = collect(query)
        assert_frame_sort_equal(result, target)


# Maps the part of a column name before its first underscore (e.g. "int" in
# "int_x") to a SQLAlchemy type. auto_types() uses this table to build the
# dtype mapping passed to DataFrame.to_sql.
PREFIX_TO_TYPE = {
    # for datetime, need to convert to pandas datetime column
    #"dt": types.DateTime,
    "int": types.Integer,
    "float": types.Float,
    "str": types.String
}

def auto_types(df):
    """Map column names to SQL types based on each name's prefix.

    The prefix is the text before the first underscore. Columns whose
    prefix is not a key of PREFIX_TO_TYPE are left out of the result.
    """
    prefix_of = lambda col: col.split('_')[0]

    return {
        col: PREFIX_TO_TYPE[prefix_of(col)]
        for col in df.columns
        if prefix_of(col) in PREFIX_TO_TYPE
    }


def copy_to_sql(df, name, engine):
    """Write ``df`` to table ``name`` on ``engine`` and return a LazyTbl for it.

    An existing table with that name is replaced, and the DataFrame index
    is not written.
    """
    write_options = dict(
        dtype = auto_types(df),
        index = False,
        if_exists = "replace",
    )
    df.to_sql(name, engine, **write_options)

    return LazyTbl(engine, name)


9 changes: 9 additions & 0 deletions siuba/tests/test_dply_verbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,15 @@ def test_varlist_multi_slice_negate(df1):
assert out.columns.tolist() == ["language", "stars", "x"]


# Distinct --------------------------------------------------------------------

from siuba.dply.verbs import distinct

def test_distinct_no_args():
    """distinct() without columns de-duplicates whole rows and resets the index."""
    frame = pd.DataFrame({'x': [1,1,2], 'y': [1,1,2]})
    expected = frame.drop_duplicates().reset_index(drop = True)

    assert_frame_equal(distinct(frame), expected)


# Nest ------------------------------------------------------------------------

from siuba.dply.verbs import nest, unnest
Expand Down
79 changes: 79 additions & 0 deletions siuba/tests/test_sql_verbs_distinct.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
"""
Note: this test file was heavily influenced by its dbplyr counterpart.
https://github.com/tidyverse/dbplyr/blob/master/tests/testthat/test-verb-distinct.R
"""

from siuba.sql import LazyTbl, collect
from siuba import _, distinct
import pandas as pd
import os

import pytest
from sqlalchemy import create_engine

from .helpers import assert_equal_query, DbConRegistry

# Shared input for the tests below: x is constant, y takes two values, and
# every (y, z) pair is unique, so all four rows are distinct.
DATA = pd.DataFrame({
    "x": [1,1,1,1],
    "y": [1,1,2,2],
    "z": [1,2,1,2]
    })

@pytest.fixture(scope = "module")
def dbs(request):
    """Yield a DbConRegistry with one engine per dialect named in ``--dbs``.

    Postgres is reached on localhost, using the port in the PGPORT
    environment variable (5433 when unset). All registered engines are
    disposed once the module's tests are done.
    """
    requested = set(request.config.getoption("--dbs").split(","))
    registry = DbConRegistry()

    if "sqlite" in requested:
        sqlite_engine = create_engine("sqlite:///:memory:")
        registry.register("sqlite", sqlite_engine)

    if "postgresql" in requested:
        pg_port = os.environ.get("PGPORT", "5433")
        pg_url = 'postgresql://postgres:@localhost:%s/postgres'%pg_port
        registry.register("postgresql", create_engine(pg_url))

    yield registry

    # cleanup
    for registered_engine in registry.connections.values():
        registered_engine.dispose()

@pytest.fixture(scope = "module")
def dfs(dbs):
    """Load DATA into every engine registered in ``dbs`` and yield the resulting list of tables."""
    yield dbs.load_df(DATA)

def test_distinct_no_args(dfs):
    """distinct() with no columns matches both pandas and the DataFrame distinct."""
    via_pandas = DATA.drop_duplicates()
    assert_equal_query(dfs, distinct(), via_pandas)

    via_siuba = distinct(DATA)
    assert_equal_query(dfs, distinct(), via_siuba)

def test_distinct_one_arg(dfs):
    """distinct(_.y) keeps only column y, de-duplicated."""
    unique_y = DATA.drop_duplicates(['y'])[['y']].reset_index(drop = True)
    assert_equal_query(dfs, distinct(_.y), unique_y)

    # the SQL result should also agree with the DataFrame implementation
    assert_equal_query(dfs, distinct(_.y), distinct(DATA, _.y))

def test_distinct_keep_all_not_impl(dfs):
    """Naming a column together with _keep_all = True must raise NotImplementedError for each table."""
    # TODO: should just mock LazyTbl
    for tbl in dfs:
        with pytest.raises(NotImplementedError):
            distinct(tbl, _.y, _keep_all = True) >> collect()


@pytest.mark.xfail
def test_distinct_via_group_by(dfs):
    """Placeholder for a test that has not been written; marked xfail so the unconditional failure is expected."""
    # NotImplemented
    assert False

def test_distinct_kwargs(dfs):
    """distinct(_.y, a = _.x) returns the unique (y, a) rows, with a taken from x."""
    deduped = DATA.drop_duplicates(['y', 'x'])
    renamed = deduped.rename(columns = {'x': 'a'})
    dst = renamed.reset_index(drop = True)[['y', 'a']]

    assert_equal_query(dfs, distinct(_.y, a = _.x), dst)




0 comments on commit 4921fd8

Please sign in to comment.