Skip to content

Commit

Permalink
Merge pull request #419 from machow/fix-array-function
Browse files Browse the repository at this point in the history
fix: correctly dispatch array_function
  • Loading branch information
machow authored Apr 30, 2022
2 parents a9bac54 + f3d39cb commit 0c8d3e2
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 8 deletions.
4 changes: 2 additions & 2 deletions siuba/siu/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
"""Handle numpy universal functions. E.g. np.sqrt(_)."""
return array_ufunc(self, ufunc, method, *inputs, **kwargs)

def __array_function__(self, func, types, *args, **kwargs):
def __array_function__(self, func, types, args, kwargs):
return array_function(self, func, types, *args, **kwargs)

# allowed methods ----
Expand Down Expand Up @@ -128,7 +128,7 @@ def array_function(self, func, types, *args, **kwargs):

@array_function.register(Call)
def _array_function_call(self, func, types, *args, **kwargs):
return Call("__call__", FuncArg(array_function), self, func, *args, **kwargs)
return Call("__call__", FuncArg(array_function), self, func, types, *args, **kwargs)


@array_function.register(Symbolic)
Expand Down
38 changes: 32 additions & 6 deletions siuba/tests/test_siu_symbolic.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import pytest

from siuba.siu import strip_symbolic, FunctionLookupError, Symbolic, MetaArg, Call
from siuba.siu import strip_symbolic, FunctionLookupError, Symbolic, MetaArg, Call, _ as D


# Note that currently tests are split across the test_siu.py, and this module.
Expand Down Expand Up @@ -84,14 +84,40 @@ def test_siu_symbolic_array_ufunc_sql_raises(_, func):
assert "Numpy function sql translation" in exc_info.value.args[0]
assert "not supported" in exc_info.value.args[0]

def test_siu_symbolic_array_ufunc_pandas(_):


@pytest.mark.parametrize("sym, res", [
(np.sqrt(D), lambda ser: np.sqrt(ser)), # ufunc
(np.add(D, 1), lambda ser: np.add(ser, 1)),
(np.add(1, D), lambda ser : np.add(1, ser)),
(np.add(D, D), lambda ser : np.add(ser, ser)),
])
def test_siu_symbolic_array_ufunc_pandas(_, sym, res):
import pandas as pd
lhs = pd.Series([1,2])
ser = pd.Series([1,2])

sym = np.add(_, 1)
expr = strip_symbolic(sym)

src = expr(lhs)
src = expr(ser)
dst = res(ser)

assert isinstance(src, pd.Series)
assert src.equals(lhs + 1)
assert src.equals(dst)


@pytest.mark.parametrize("sym, res", [
(np.mean(D), lambda ser : np.mean(ser)), # __array_function__
(np.sum(D), lambda ser: np.sum(ser)),
(np.sqrt(np.mean(D)), lambda ser: np.sqrt(np.mean(ser))),
])
def test_siu_symbolic_array_function_pandas(_, sym, res):
import pandas as pd
ser = pd.Series([1,2])

expr = strip_symbolic(sym)

src = expr(ser)
dst = res(ser)

# note that all examples currently are aggregates
assert src == dst

0 comments on commit 0c8d3e2

Please sign in to comment.