diff --git a/numba/tests/np/test_ufunc_type_inference.py b/numba/tests/np/test_ufunc_type_inference.py index 55499973f6d..35288826cf2 100644 --- a/numba/tests/np/test_ufunc_type_inference.py +++ b/numba/tests/np/test_ufunc_type_inference.py @@ -51,6 +51,13 @@ def accumulate(ufunc, a): def accumulate_dtype(ufunc, a, dtype): return numba.typeof(ufunc.accumulate(a, dtype=dtype)) +@autojit +def reduceat(ufunc, a): + return numba.typeof(ufunc.reduceat(a)) + +@autojit +def reduceat_dtype(ufunc, a, dtype): + return numba.typeof(ufunc.reduceat(a, dtype=dtype)) #------------------------------------------------------------------------ # Tests @@ -98,7 +105,22 @@ def test_ufunc_accumulate(): equals(accumulate(np.logical_and, a), bool_[:]) + # Test with dtype + equals(accumulate_dtype(np.add, a, np.double), double[:]) + +def test_ufunc_reduceat(): + equals(reduceat(np.add, a), int32[:]) + equals(reduceat(np.multiply, a), int32[:]) + + equals(reduceat(np.bitwise_and, a), int32[:]) + + equals(reduceat(np.logical_and, a), bool_[:]) + + # Test with dtype + equals(reduceat_dtype(np.add, a, np.double), double[:]) + if __name__ == "__main__": # test_binary_ufunc() # test_ufunc_reduce() test_ufunc_accumulate() + test_ufunc_reduceat() diff --git a/numba/type_inference/modules/numpymodule.py b/numba/type_inference/modules/numpymodule.py index 1405c7f6d94..17883704ffe 100644 --- a/numba/type_inference/modules/numpymodule.py +++ b/numba/type_inference/modules/numpymodule.py @@ -64,6 +64,12 @@ def promote_to_array(dtype): dtype = minitypes.ArrayType(dtype, 0) return dtype +def demote_to_scalar(type): + "Demote 0d arrays to scalars" + if type and type.is_array and type.ndim == 0: + return type.dtype + return type + def array_from_object(a): """ object -> array type: diff --git a/numba/type_inference/modules/numpyufuncs.py b/numba/type_inference/modules/numpyufuncs.py index d5a0fec0edb..6fef4dc5a54 100644 --- a/numba/type_inference/modules/numpyufuncs.py +++ b/numba/type_inference/modules/numpyufuncs.py @@ -14,13 +14,28 @@ from numba.type_inference.modules.numpymodule import (get_dtype, array_from_type, promote, - promote_to_array) + promote_to_array, + demote_to_scalar) + +#---------------------------------------------------------------------------- +# Utilities +#---------------------------------------------------------------------------- + +def array_of_dtype(a, dtype, static_dtype, out): + if out is not None: + return out + + a = array_from_type(a) + if not a.is_object: + dtype = _dtype(a, dtype, static_dtype) + if dtype is not None: + return a.copy(dtype=dtype) def _dtype(a, dtype, static_dtype): if static_dtype: return static_dtype elif dtype: - return dtype + return dtype.dtype elif a.is_array: return a.dtype elif not a.is_object: @@ -28,6 +43,10 @@ def _dtype(a, dtype, static_dtype): else: return None +#---------------------------------------------------------------------------- +# Ufunc type inference +#---------------------------------------------------------------------------- + def binary_map(context, a, b, out): if out is not None: return out @@ -67,16 +86,18 @@ def reduce_bool(a, axis, dtype, out): return reduce_(a, axis, dtype, out, bool_) def accumulate(a, axis, dtype, out, static_dtype=None): - if out is not None: - return out - - dtype = _dtype(a, dtype, static_dtype) - if dtype: - return promote_to_array(a).copy(dtype=dtype) + return demote_to_scalar(array_of_dtype(a, dtype, static_dtype, out)) def accumulate_bool(a, axis, dtype, out): return accumulate(a, axis, dtype, out, bool_) +def reduceat(a, indices, axis, dtype, out, static_dtype=None): + return accumulate(a, axis, dtype, out, static_dtype) + +def reduceat_bool(a, indices, axis, dtype, out): + return reduceat(a, indices, axis, dtype, out, bool_) + + #------------------------------------------------------------------------ # Binary Ufuncs #------------------------------------------------------------------------ @@ -129,8 +150,10 @@ def accumulate_bool(a, axis, dtype, out): register_inferer(np, binary_ufunc, binary_map) register_unbound(np, binary_ufunc, "reduce", reduce_) register_unbound(np, binary_ufunc, "accumulate", accumulate) + register_unbound(np, binary_ufunc, "reduceat", reduceat) for binary_ufunc in binary_ufuncs_compare + binary_ufuncs_logical: register_inferer(np, binary_ufunc, binary_map_bool) register_unbound(np, binary_ufunc, "reduce", reduce_bool) register_unbound(np, binary_ufunc, "accumulate", accumulate_bool) + register_unbound(np, binary_ufunc, "reduceat", reduceat_bool)