diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index cb256df6..148dc82a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -56,7 +56,7 @@ jobs: run: | make test-travis env: - SB_TEST_PGPORT: 5432 + SB_TEST_PGPORT: 5433 PYTEST_FLAGS: ${{ matrix.pytest_flags }} # optional step for running bigquery tests ---- diff --git a/docker-compose.yml b/docker-compose.yml index f50013b9..eb6d2f7e 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -3,9 +3,10 @@ version: '3.1' services: db_mysql: - image: mysql + image: mysql/mysql-server restart: always environment: + MYSQL_ROOT_HOST: "%" MYSQL_ROOT_PASSWORD: "" MYSQL_ALLOW_EMPTY_PASSWORD: 1 MYSQL_DATABASE: "public" @@ -21,4 +22,4 @@ services: POSTGRES_PASSWORD: "" POSTGRES_HOST_AUTH_METHOD: "trust" ports: - - 5432:5432 + - 5433:5432 diff --git a/siuba/ops/utils.py b/siuba/ops/utils.py index 75132720..d113c272 100644 --- a/siuba/ops/utils.py +++ b/siuba/ops/utils.py @@ -69,7 +69,7 @@ def _register_series_default(generic): generic.register(pd.Series, partial(_default_pd_series, generic.operation)) -def _default_pd_series(__op, self, args = tuple(), kwargs = {}): +def _default_pd_series(__op, self, *args, **kwargs): # Once we drop python 3.7 dependency, could make __op position only if __op.accessor is not None: method = getattr(getattr(self, __op.accessor), __op.name) diff --git a/siuba/siu/symbolic.py b/siuba/siu/symbolic.py index 032c7ff7..d518de40 100644 --- a/siuba/siu/symbolic.py +++ b/siuba/siu/symbolic.py @@ -1,3 +1,5 @@ +from functools import singledispatch + from .calls import BINARY_OPS, UNARY_OPS, Call, BinaryOp, BinaryRightOp, MetaArg, UnaryOp, SliceOp, FuncArg from .format import Formatter @@ -9,6 +11,12 @@ def __init__(self, source = None, ready_to_call = False): self.__source = MetaArg("_") if source is None else source self.__ready_to_call = ready_to_call + def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): + """Handle numpy universal functions. E.g. np.sqrt(_).""" + return array_ufunc(self, ufunc, method, *inputs, **kwargs) + + def __array_function__(self, func, types, *args, **kwargs): + return array_function(self, func, types, *args, **kwargs) # allowed methods ---- @@ -108,7 +116,63 @@ def explain(symbol): return str(symbol) -# Do some gnarly method setting ----------------------------------------------- +# Special numpy ufunc dispatcher +# ============================================================================= +# note that this is essentially what dispatchers.symbolic_dispatch does... +# details on numpy array dispatch: https://github.com/numpy/numpy/issues/21387 + +@singledispatch +def array_function(self, func, types, *args, **kwargs): + return func(*args, **kwargs) + + +@array_function.register(Call) +def _array_function_call(self, func, types, *args, **kwargs): + return Call("__call__", FuncArg(array_function), self, func, *args, **kwargs) + + +@array_function.register(Symbolic) +def _array_function_sym(self, func, types, *args, **kwargs): + f_concrete = array_function.dispatch(Call) + + call = f_concrete( + strip_symbolic(self), + func, + types, + *map(strip_symbolic, args), + **{k: strip_symbolic(v) for k, v in kwargs.items()} + ) + + return Symbolic(call) + + +@singledispatch +def array_ufunc(self, ufunc, method, *inputs, **kwargs): + return getattr(ufunc, method)(*inputs, **kwargs) + +@array_ufunc.register(Call) +def _array_ufunc_call(self, ufunc, method, *inputs, **kwargs): + + return Call("__call__", FuncArg(array_ufunc), self, ufunc, method, *inputs, **kwargs) + + +@array_ufunc.register(Symbolic) +def _array_ufunc_sym(self, ufunc, method, *inputs, **kwargs): + f_concrete = array_ufunc.dispatch(Call) + + call = f_concrete( + strip_symbolic(self), + ufunc, + method, + *map(strip_symbolic, inputs), + **{k: strip_symbolic(v) for k, v in kwargs.items()} + ) + + return Symbolic(call) + + +# Do some gnarly method setting on Symbolic ----------------------------------- +# ============================================================================= def create_binary_op(op_name, left_op = True): def _binary_op(self, x): diff --git a/siuba/siu/visitors.py b/siuba/siu/visitors.py index 215d6b3b..cfa28c8a 100644 --- a/siuba/siu/visitors.py +++ b/siuba/siu/visitors.py @@ -9,7 +9,7 @@ class FunctionLookupBound: def __init__(self, msg): self.msg = msg - def __call__(self): + def __call__(self, *args, **kwargs): raise NotImplementedError(self.msg) diff --git a/siuba/sql/dialects/base.py b/siuba/sql/dialects/base.py index 4650b86f..babf6d75 100644 --- a/siuba/sql/dialects/base.py +++ b/siuba/sql/dialects/base.py @@ -39,7 +39,8 @@ annotate, RankOver, CumlOver, - SqlTranslator + SqlTranslator, + FunctionLookupBound ) @@ -122,6 +123,19 @@ def sql_func_capitalize(_, col): return sql.functions.concat(first_char, rest) +# Numpy ufuncs ---------------------------------------------------------------- +# symbolic objects have a generic dispatch for when _.__array_ufunc__ is called, +# in order to support things like np.sqrt(_.x). In theory this wouldn't be crazy +# to support, but most ufuncs have existing pandas methods already. + +from siuba.siu.symbolic import array_ufunc, array_function + +_f_err = FunctionLookupBound("Numpy function sql translation (e.g. np.sqrt) not supported.") + +array_ufunc.register(SqlColumn, _f_err) +array_function.register(SqlColumn, _f_err) + + # Misc implementations -------------------------------------------------------- def sql_func_astype(_, col, _type): diff --git a/siuba/tests/helpers.py b/siuba/tests/helpers.py index 75a77e5f..09c90b0e 100644 --- a/siuba/tests/helpers.py +++ b/siuba/tests/helpers.py @@ -26,7 +26,7 @@ def data_frame(*args, _index = None, **kwargs): "dialect": "postgresql", "driver": "", "dbname": ["SB_TEST_PGDATABASE", "postgres"], - "port": ["SB_TEST_PGPORT", "5432"], + "port": ["SB_TEST_PGPORT", "5433"], "user": ["SB_TEST_PGUSER", "postgres"], "password": ["SB_TEST_PGPASSWORD", ""], "host": ["SB_TEST_PGHOST", "localhost"], diff --git a/siuba/tests/test_dply_series_methods.py b/siuba/tests/test_dply_series_methods.py index c3394a5f..a9127020 100644 --- a/siuba/tests/test_dply_series_methods.py +++ b/siuba/tests/test_dply_series_methods.py @@ -362,6 +362,7 @@ def test_pandas_grouped_frame_fast_summarize(agg_entry): # Edge Cases ================================================================== +@pytest.mark.postgresql def test_frame_set_aggregates_postgresql(): # TODO: probably shouldn't be creating backend here backend = SqlBackend("postgresql") diff --git a/siuba/tests/test_siu.py b/siuba/tests/test_siu.py index aef59b0e..01a1408f 100644 --- a/siuba/tests/test_siu.py +++ b/siuba/tests/test_siu.py @@ -23,10 +23,40 @@ def _(): return Symbolic() -def test_source_attr(_): - sym = _.source + +# Symbolic class ============================================================== + +def test_symbolic_source_attr(_): + sym = _.__source + assert isinstance(sym, Symbolic) + assert explain(sym) == "_.__source" + + +def test_symbolic_numpy_ufunc(_): + from siuba.siu.symbolic import array_ufunc + import numpy as np + + # should have form... + # █─'__call__' + # ├─█─'__custom_func__' + # │ └─ + # ├─_ + # ├─ + # ├─'__call__' + # └─_ + + sym = np.sqrt(_) + expr = strip_symbolic(sym) + assert isinstance(sym, Symbolic) - assert explain(sym) == "_.source" + + # check we are doing a call over a custom dispatch function ---- + assert expr.func == "__call__" + + dispatcher = expr.args[0] + assert isinstance(dispatcher, FuncArg) + assert dispatcher.args[0] is array_ufunc # could check .dispatch() method + def test_op_vars_slice(_): assert strip_symbolic(_.a[_.b:_.c]).op_vars() == {'a', 'b', 'c'} diff --git a/siuba/tests/test_siu_symbolic.py b/siuba/tests/test_siu_symbolic.py new file mode 100644 index 00000000..7c5fb69d --- /dev/null +++ b/siuba/tests/test_siu_symbolic.py @@ -0,0 +1,97 @@ +import numpy as np +import pytest + +from siuba.siu import strip_symbolic, FunctionLookupError, Symbolic, MetaArg, Call + + +# Note that currently tests are split across the test_siu.py, and this module. + +@pytest.fixture +def _(): + return Symbolic() + +def test_siu_symbolic_np_array_ufunc_call(_): + sym = np.add(_, 1) + expr = strip_symbolic(sym) + + # structure: + # █─'__call__' + # ├─█─'__custom_func__' + # │ └─ + # ├─_ + # ├─ + # ├─'__call__' + # ├─_ + # └─1 + + assert len(expr.args) == 6 + assert expr.args[1] is strip_symbolic(_) # original dispatch obj + assert expr.args[2] is np.add # ufunc object + assert expr.args[3] == "__call__" # its method to use + assert expr.args[4] is strip_symbolic(_) # lhs input + assert expr.args[5] == 1 # rhs input + + +def test_siu_symbolic_np_array_ufunc_inputs_lhs(_): + lhs = np.array([1,2]) + rhs = np.array([3,4]) + res = lhs + rhs + + # symbol on lhs ---- + + sym = np.add(_, rhs) + expr = strip_symbolic(sym) + + assert np.array_equal(expr(lhs), res) + + +def test_siu_symbolic_np_array_ufunc_inputs_rhs(_): + lhs = np.array([1,2]) + rhs = np.array([3,4]) + res = lhs + rhs + + # symbol on rhs ---- + + sym2 = np.add(lhs, _) + expr2 = strip_symbolic(sym2) + + assert np.array_equal(expr2(rhs), res) + + +@pytest.mark.xfail +def test_siu_symbolic_np_array_function(_): + # Note that np.sum is not a ufunc, but sort of reduces on a ufunc under the + # hood, so fails when called on a symbol + sym = np.sum(_) + expr = strip_symbolic(sym) + + assert expr(np.array([1,2])) == 3 + + +@pytest.mark.parametrize("func", [ + np.absolute, # a ufunc + np.sum # dispatched by __array_function__ + ]) +def test_siu_symbolic_array_ufunc_sql_raises(_, func): + from siuba.sql.utils import mock_sqlalchemy_engine + from siuba.sql import LazyTbl + from siuba.sql import SqlFunctionLookupError + + lazy_tbl = LazyTbl(mock_sqlalchemy_engine("postgresql"), "somedata", ["x", "y"]) + with pytest.raises(SqlFunctionLookupError) as exc_info: + lazy_tbl.shape_call(strip_symbolic(func(_.x))) + + assert "Numpy function sql translation" in exc_info.value.args[0] + assert "not supported" in exc_info.value.args[0] + +def test_siu_symbolic_array_ufunc_pandas(_): + import pandas as pd + lhs = pd.Series([1,2]) + + sym = np.add(_, 1) + expr = strip_symbolic(sym) + + src = expr(lhs) + assert isinstance(src, pd.Series) + assert src.equals(lhs + 1) +