-
Notifications
You must be signed in to change notification settings - Fork 49
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #28 from machow/feat-sql-tests
Feat sql tests
- Loading branch information
Showing
12 changed files
with
188 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
import pytest | ||
|
||
def pytest_addoption(parser): | ||
parser.addoption( | ||
"--dbs", action="store", default="sqlite", help="databases tested against (comma separated)" | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
from sqlalchemy import create_engine, types | ||
from siuba.sql import LazyTbl, collect | ||
from pandas.testing import assert_frame_equal | ||
|
||
class DbConRegistry: | ||
table_name_indx = 0 | ||
|
||
def __init__(self): | ||
self.connections = {} | ||
|
||
def register(self, name, engine): | ||
self.connections[name] = engine | ||
|
||
def remove(self, name): | ||
con = self.connections[name] | ||
con.close() | ||
del self.connections[name] | ||
|
||
return con | ||
|
||
@classmethod | ||
def unique_table_name(cls): | ||
cls.table_name_indx += 1 | ||
return "siuba_{0:03d}".format(cls.table_name_indx) | ||
|
||
def load_df(self, df): | ||
out = [] | ||
for k, engine in self.connections.items(): | ||
lazy_tbl = copy_to_sql(df, self.unique_table_name(), engine) | ||
out.append(lazy_tbl) | ||
return out | ||
|
||
def assert_frame_sort_equal(a, b): | ||
"""Tests that DataFrames are equal, even if rows are in different order""" | ||
sorted_a = a.sort_values(by = a.columns.tolist()).reset_index(drop = True) | ||
sorted_b = b.sort_values(by = b.columns.tolist()).reset_index(drop = True) | ||
|
||
assert_frame_equal(sorted_a, sorted_b) | ||
|
||
def assert_equal_query(tbls, lazy_query, target): | ||
for tbl in tbls: | ||
out = collect(lazy_query(tbl)) | ||
assert_frame_sort_equal(out, target) | ||
|
||
|
||
PREFIX_TO_TYPE = { | ||
# for datetime, need to convert to pandas datetime column | ||
#"dt": types.DateTime, | ||
"int": types.Integer, | ||
"float": types.Float, | ||
"str": types.String | ||
} | ||
|
||
def auto_types(df): | ||
dtype = {} | ||
for k in df.columns: | ||
pref, *_ = k.split('_') | ||
if pref in PREFIX_TO_TYPE: | ||
dtype[k] = PREFIX_TO_TYPE[pref] | ||
return dtype | ||
|
||
|
||
def copy_to_sql(df, name, engine): | ||
df.to_sql(name, engine, dtype = auto_types(df), index = False, if_exists = "replace") | ||
return LazyTbl(engine, name) | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
""" | ||
Note: this test file was heavily influenced by its dbplyr counterpart. | ||
https://github.com/tidyverse/dbplyr/blob/master/tests/testthat/test-verb-distinct.R | ||
""" | ||
|
||
from siuba.sql import LazyTbl, collect | ||
from siuba import _, distinct | ||
import pandas as pd | ||
import os | ||
|
||
import pytest | ||
from sqlalchemy import create_engine | ||
|
||
from .helpers import assert_equal_query, DbConRegistry | ||
|
||
DATA = pd.DataFrame({ | ||
"x": [1,1,1,1], | ||
"y": [1,1,2,2], | ||
"z": [1,2,1,2] | ||
}) | ||
|
||
@pytest.fixture(scope = "module") | ||
def dbs(request): | ||
dialects = set(request.config.getoption("--dbs").split(",")) | ||
dbs = DbConRegistry() | ||
|
||
if "sqlite" in dialects: | ||
dbs.register("sqlite", create_engine("sqlite:///:memory:")) | ||
if "postgresql" in dialects: | ||
port = os.environ.get("PGPORT", "5433") | ||
dbs.register("postgresql", create_engine('postgresql://postgres:@localhost:%s/postgres'%port)) | ||
|
||
|
||
yield dbs | ||
|
||
# cleanup | ||
for engine in dbs.connections.values(): | ||
engine.dispose() | ||
|
||
@pytest.fixture(scope = "module") | ||
def dfs(dbs): | ||
yield dbs.load_df(DATA) | ||
|
||
def test_distinct_no_args(dfs): | ||
assert_equal_query(dfs, distinct(), DATA.drop_duplicates()) | ||
assert_equal_query(dfs, distinct(), distinct(DATA)) | ||
|
||
def test_distinct_one_arg(dfs): | ||
assert_equal_query( | ||
dfs, | ||
distinct(_.y), | ||
DATA.drop_duplicates(['y'])[['y']].reset_index(drop = True) | ||
) | ||
|
||
assert_equal_query(dfs, distinct(_.y), distinct(DATA, _.y)) | ||
|
||
def test_distinct_keep_all_not_impl(dfs): | ||
# TODO: should just mock LazyTbl | ||
for tbl in dfs: | ||
with pytest.raises(NotImplementedError): | ||
distinct(tbl, _.y, _keep_all = True) >> collect() | ||
|
||
|
||
@pytest.mark.xfail | ||
def test_distinct_via_group_by(dfs): | ||
# NotImplemented | ||
assert False | ||
|
||
def test_distinct_kwargs(dfs): | ||
dst = DATA.drop_duplicates(['y', 'x']) \ | ||
.rename(columns = {'x': 'a'}) \ | ||
.reset_index(drop = True)[['y', 'a']] | ||
|
||
assert_equal_query(dfs, distinct(_.y, a = _.x), dst) | ||
|
||
|
||
|
||
|