diff --git a/siuba/sql/utils.py b/siuba/sql/utils.py index 653db5df..c6391690 100644 --- a/siuba/sql/utils.py +++ b/siuba/sql/utils.py @@ -1,5 +1,14 @@ import importlib +try: + # once we drop sqlalchemy 1.2, can use create_mock_engine function + from sqlalchemy.engine.mock import MockConnection +except ImportError: + # monkey patch old sqlalchemy mock, so it can be a context handler + from sqlalchemy.engine.strategies import MockEngineStrategy + MockConnection = MockEngineStrategy.MockConnection + + def get_dialect_translator(name): mod = importlib.import_module('siuba.sql.dialects.{}'.format(name)) return mod.translator @@ -37,11 +46,13 @@ def mock_sqlalchemy_engine(dialect): show_query(query) """ + from sqlalchemy.engine import Engine from sqlalchemy.dialects import registry - dialect_cls = registry.load('postgresql') - return Engine(None, dialect_cls(), '') + dialect_cls = registry.load(dialect) + + return MockConnection(dialect_cls(), lambda *args, **kwargs: None) # Temporary fix for pandas bug (https://github.com/pandas-dev/pandas/issues/35484) diff --git a/siuba/sql/verbs.py b/siuba/sql/verbs.py index ba9e42ad..300150df 100644 --- a/siuba/sql/verbs.py +++ b/siuba/sql/verbs.py @@ -29,7 +29,15 @@ ) from .translate import CustomOverClause, SqlColumn, SqlColumnAgg -from .utils import get_dialect_translator, _FixedSqlDatabase, _sql_select, _sql_column_collection, _sql_add_columns, _sql_with_only_columns +from .utils import ( + get_dialect_translator, + _FixedSqlDatabase, + _sql_select, + _sql_column_collection, + _sql_add_columns, + _sql_with_only_columns, + MockConnection +) from sqlalchemy import sql import sqlalchemy @@ -467,6 +475,12 @@ def _collect(__data, as_df = True): # compile_kwargs = {"literal_binds": True} #) + if isinstance(__data.source, MockConnection): + # a mock sqlalchemy is being used to show_query, and echo queries. + # it doesn't return a result object or have a context handler, so + # we need to bail out early + return + with __data.source.connect() as conn: if as_df: sql_db = _FixedSqlDatabase(conn) diff --git a/siuba/tests/test_sql_utils.py b/siuba/tests/test_sql_utils.py index 0253566c..232e0acf 100644 --- a/siuba/tests/test_sql_utils.py +++ b/siuba/tests/test_sql_utils.py @@ -1,4 +1,6 @@ -from siuba.sql.utils import get_dialect_translator +from siuba.sql.utils import get_dialect_translator, mock_sqlalchemy_engine +from siuba.sql.verbs import collect +from siuba.sql import LazyTbl import pytest @pytest.mark.parametrize('name', [ @@ -8,3 +10,15 @@ ]) def test_get_dialect_translator(name): get_dialect_translator(name) + +def test_mock_sqlalchemy_engine_dialect(): + engine = mock_sqlalchemy_engine("postgresql") + assert engine.dialect.name == "postgresql" + + engine = mock_sqlalchemy_engine("sqlite") + assert engine.dialect.name == "sqlite" + +def test_mock_sqlalchemy_engine_no_collect(): + engine = mock_sqlalchemy_engine("sqlite") + tbl = LazyTbl(engine, "some_table", ["x"]) + assert collect(tbl) is None