diff --git a/numba/tests/np/test_ufunc_type_inference.py b/numba/tests/np/test_ufunc_type_inference.py index 35288826cf2..992ed4fe517 100644 --- a/numba/tests/np/test_ufunc_type_inference.py +++ b/numba/tests/np/test_ufunc_type_inference.py @@ -59,6 +59,10 @@ def reduceat(ufunc, a): def reduceat_dtype(ufunc, a, dtype): return numba.typeof(ufunc.reduceat(a, dtype=dtype)) +@autojit +def outer(ufunc, a): + return numba.typeof(ufunc.outer(a, a)) + #------------------------------------------------------------------------ # Tests #------------------------------------------------------------------------ @@ -119,8 +123,17 @@ def test_ufunc_reduceat(): # Test with dtype equals(reduceat_dtype(np.add, a, np.double), double[:]) +def test_ufunc_outer(): + equals(outer(np.add, a), int32[:, :]) + equals(outer(np.multiply, a), int32[:, :]) + + equals(outer(np.bitwise_and, a), int32[:, :]) + + equals(outer(np.logical_and, a), bool_[:, :]) + if __name__ == "__main__": # test_binary_ufunc() # test_ufunc_reduce() test_ufunc_accumulate() test_ufunc_reduceat() + test_ufunc_outer() diff --git a/numba/type_inference/modules/numpymodule.py b/numba/type_inference/modules/numpymodule.py index 17883704ffe..4827399c40c 100644 --- a/numba/type_inference/modules/numpymodule.py +++ b/numba/type_inference/modules/numpymodule.py @@ -220,7 +220,9 @@ def inner(context, a, b): @register(np) def outer(context, a, b): - raise NotImplementedError("XXX") + result_type = promote(context, a, b) + if result_type.is_array: + return result_type.dtype[:, :] @register(np) def tensordot(context, a, b): diff --git a/numba/type_inference/modules/numpyufuncs.py b/numba/type_inference/modules/numpyufuncs.py index 6fef4dc5a54..8135241c037 100644 --- a/numba/type_inference/modules/numpyufuncs.py +++ b/numba/type_inference/modules/numpyufuncs.py @@ -97,6 +97,13 @@ def reduceat(a, indices, axis, dtype, out, static_dtype=None): def reduceat_bool(a, indices, axis, dtype, out): return reduceat(a, indices, axis, dtype, out, bool_) +def outer(context, a, b, static_dtype=None): + a = array_of_dtype(a, None, static_dtype, out=None) + if a and a.is_array: + return a.dtype[:, :] + +def outer_bool(context, a, b): + return outer(context, a, b, bool_) #------------------------------------------------------------------------ # Binary Ufuncs @@ -151,9 +158,11 @@ def reduceat_bool(a, indices, axis, dtype, out): register_unbound(np, binary_ufunc, "reduce", reduce_) register_unbound(np, binary_ufunc, "accumulate", accumulate) register_unbound(np, binary_ufunc, "reduceat", reduceat) + register_unbound(np, binary_ufunc, "outer", outer) for binary_ufunc in binary_ufuncs_compare + binary_ufuncs_logical: register_inferer(np, binary_ufunc, binary_map_bool) register_unbound(np, binary_ufunc, "reduce", reduce_bool) register_unbound(np, binary_ufunc, "accumulate", accumulate_bool) register_unbound(np, binary_ufunc, "reduceat", reduceat_bool) + register_unbound(np, binary_ufunc, "outer", outer_bool)