Skip to content

Commit

Permalink
Implement type inference for np.outer and ufunc.outer
Browse files Browse the repository at this point in the history
  • Loading branch information
markflorisson committed Feb 1, 2013
1 parent c2459c3 commit e267ce4
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 1 deletion.
13 changes: 13 additions & 0 deletions numba/tests/np/test_ufunc_type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ def reduceat(ufunc, a):
def reduceat_dtype(ufunc, a, dtype):
return numba.typeof(ufunc.reduceat(a, dtype=dtype))

@autojit
def outer(ufunc, a):
return numba.typeof(ufunc.outer(a, a))

#------------------------------------------------------------------------
# Tests
#------------------------------------------------------------------------
Expand Down Expand Up @@ -119,8 +123,17 @@ def test_ufunc_reduceat():
# Test with dtype
equals(reduceat_dtype(np.add, a, np.double), double[:])

def test_ufunc_outer():
equals(outer(np.add, a), int32[:, :])
equals(outer(np.multiply, a), int32[:, :])

equals(outer(np.bitwise_and, a), int32[:, :])

equals(outer(np.logical_and, a), bool_[:, :])

if __name__ == "__main__":
# test_binary_ufunc()
# test_ufunc_reduce()
test_ufunc_accumulate()
test_ufunc_reduceat()
test_ufunc_outer()
4 changes: 3 additions & 1 deletion numba/type_inference/modules/numpymodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,9 @@ def inner(context, a, b):

@register(np)
def outer(context, a, b):
raise NotImplementedError("XXX")
result_type = promote(context, a, b)
if result_type.is_array:
return result_type.dtype[:, :]

@register(np)
def tensordot(context, a, b):
Expand Down
9 changes: 9 additions & 0 deletions numba/type_inference/modules/numpyufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,13 @@ def reduceat(a, indices, axis, dtype, out, static_dtype=None):
def reduceat_bool(a, indices, axis, dtype, out):
return reduceat(a, indices, axis, dtype, out, bool_)

def outer(context, a, b, static_dtype=None):
a = array_of_dtype(a, None, static_dtype, out=None)
if a and a.is_array:
return a.dtype[:, :]

def outer_bool(context, a, b):
return outer(context, a, b, bool_)

#------------------------------------------------------------------------
# Binary Ufuncs
Expand Down Expand Up @@ -151,9 +158,11 @@ def reduceat_bool(a, indices, axis, dtype, out):
register_unbound(np, binary_ufunc, "reduce", reduce_)
register_unbound(np, binary_ufunc, "accumulate", accumulate)
register_unbound(np, binary_ufunc, "reduceat", reduceat)
register_unbound(np, binary_ufunc, "outer", outer)

for binary_ufunc in binary_ufuncs_compare + binary_ufuncs_logical:
register_inferer(np, binary_ufunc, binary_map_bool)
register_unbound(np, binary_ufunc, "reduce", reduce_bool)
register_unbound(np, binary_ufunc, "accumulate", accumulate_bool)
register_unbound(np, binary_ufunc, "reduceat", reduceat_bool)
register_unbound(np, binary_ufunc, "outer", outer_bool)

0 comments on commit e267ce4

Please sign in to comment.