From f3d39cbdf2c0a4025911e58804e6b1742657971b Mon Sep 17 00:00:00 2001 From: Michael Chow Date: Sat, 30 Apr 2022 17:15:11 -0600 Subject: [PATCH] fix: correctly dispatch array_function --- siuba/siu/symbolic.py | 4 ++-- siuba/tests/test_siu_symbolic.py | 38 +++++++++++++++++++++++++++----- 2 files changed, 34 insertions(+), 8 deletions(-) diff --git a/siuba/siu/symbolic.py b/siuba/siu/symbolic.py index d518de40..483c0086 100644 --- a/siuba/siu/symbolic.py +++ b/siuba/siu/symbolic.py @@ -15,7 +15,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): """Handle numpy universal functions. E.g. np.sqrt(_).""" return array_ufunc(self, ufunc, method, *inputs, **kwargs) - def __array_function__(self, func, types, *args, **kwargs): + def __array_function__(self, func, types, args, kwargs): return array_function(self, func, types, *args, **kwargs) # allowed methods ---- @@ -128,7 +128,7 @@ def array_function(self, func, types, *args, **kwargs): @array_function.register(Call) def _array_function_call(self, func, types, *args, **kwargs): - return Call("__call__", FuncArg(array_function), self, func, *args, **kwargs) + return Call("__call__", FuncArg(array_function), self, func, types, *args, **kwargs) @array_function.register(Symbolic) diff --git a/siuba/tests/test_siu_symbolic.py b/siuba/tests/test_siu_symbolic.py index 7c5fb69d..0cdfa317 100644 --- a/siuba/tests/test_siu_symbolic.py +++ b/siuba/tests/test_siu_symbolic.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from siuba.siu import strip_symbolic, FunctionLookupError, Symbolic, MetaArg, Call +from siuba.siu import strip_symbolic, FunctionLookupError, Symbolic, MetaArg, Call, _ as D # Note that currently tests are split across the test_siu.py, and this module. @@ -84,14 +84,40 @@ def test_siu_symbolic_array_ufunc_sql_raises(_, func): assert "Numpy function sql translation" in exc_info.value.args[0] assert "not supported" in exc_info.value.args[0] -def test_siu_symbolic_array_ufunc_pandas(_): + + +@pytest.mark.parametrize("sym, res", [ + (np.sqrt(D), lambda ser: np.sqrt(ser)), # ufunc + (np.add(D, 1), lambda ser: np.add(ser, 1)), + (np.add(1, D), lambda ser : np.add(1, ser)), + (np.add(D, D), lambda ser : np.add(ser, ser)), +]) +def test_siu_symbolic_array_ufunc_pandas(_, sym, res): import pandas as pd - lhs = pd.Series([1,2]) + ser = pd.Series([1,2]) - sym = np.add(_, 1) expr = strip_symbolic(sym) - src = expr(lhs) + src = expr(ser) + dst = res(ser) + assert isinstance(src, pd.Series) - assert src.equals(lhs + 1) + assert src.equals(dst) + + +@pytest.mark.parametrize("sym, res", [ + (np.mean(D), lambda ser : np.mean(ser)), # __array_function__ + (np.sum(D), lambda ser: np.sum(ser)), + (np.sqrt(np.mean(D)), lambda ser: np.sqrt(np.mean(ser))), +]) +def test_siu_symbolic_array_function_pandas(_, sym, res): + import pandas as pd + ser = pd.Series([1,2]) + + expr = strip_symbolic(sym) + + src = expr(ser) + dst = res(ser) + # note that all examples currently are aggregates + assert src == dst