Skip to content

Commit

Permalink
Implement type inference for ufunc.accumulate
Browse files Browse the repository at this point in the history
  • Loading branch information
markflorisson committed Feb 1, 2013
1 parent c63bde0 commit 5221249
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 16 deletions.
20 changes: 19 additions & 1 deletion numba/tests/np/test_ufunc_type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,15 @@ def add_reduce(a):
def add_reduce_axis(a, axis):
return numba.typeof(np.add.reduce(a, axis=axis))

@autojit
def accumulate(ufunc, a):
return numba.typeof(ufunc.accumulate(a))

@autojit
def accumulate_dtype(ufunc, a, dtype):
return numba.typeof(ufunc.accumulate(a, dtype=dtype))


#------------------------------------------------------------------------
# Tests
#------------------------------------------------------------------------
Expand Down Expand Up @@ -81,6 +90,15 @@ def test_ufunc_reduce():
equals(add_reduce(a), int32)
equals(add_reduce_axis(b, 1), int64[:])

def test_ufunc_accumulate():
equals(accumulate(np.add, a), int32[:])
equals(accumulate(np.multiply, a), int32[:])

equals(accumulate(np.bitwise_and, a), int32[:])

equals(accumulate(np.logical_and, a), bool_[:])

if __name__ == "__main__":
test_binary_ufunc()
# test_binary_ufunc()
# test_ufunc_reduce()
test_ufunc_accumulate()
53 changes: 38 additions & 15 deletions numba/type_inference/modules/numpyufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,20 @@
from numba.typesystem import get_type
from numba.type_inference.modules.numpymodule import (get_dtype,
array_from_type,
promote)
promote,
promote_to_array)

def _dtype(a, dtype, static_dtype):
if static_dtype:
return static_dtype
elif dtype:
return dtype
elif a.is_array:
return a.dtype
elif not a.is_object:
return a
else:
return None

def binary_map(context, a, b, out):
if out is not None:
Expand All @@ -32,29 +45,37 @@ def reduce_(a, axis, dtype, out, static_dtype=None):
if out is not None:
return out

if static_dtype:
dtype_type = static_dtype
else:
dtype_type = get_dtype(dtype, default_dtype=a.dtype).dtype
dtype_type = _dtype(a, dtype, static_dtype)

if axis is None:
# Return the scalar type
return dtype_type

# Handle the axis parameter
if axis.is_tuple and axis.is_sized:
# axis=(tuple with a constant size)
return typesystem.array(dtype_type, a.ndim - axis.size)
elif axis.is_int:
# axis=1
return typesystem.array(dtype_type, a.ndim - 1)
else:
# axis=(something unknown)
return object_
if dtype_type:
# Handle the axis parameter
if axis.is_tuple and axis.is_sized:
# axis=(tuple with a constant size)
return typesystem.array(dtype_type, a.ndim - axis.size)
elif axis.is_int:
# axis=1
return typesystem.array(dtype_type, a.ndim - 1)
else:
# axis=(something unknown)
return object_

def reduce_bool(a, axis, dtype, out):
return reduce_(a, axis, dtype, out, bool_)

def accumulate(a, axis, dtype, out, static_dtype=None):
if out is not None:
return out

dtype = _dtype(a, dtype, static_dtype)
if dtype:
return promote_to_array(a).copy(dtype=dtype)

def accumulate_bool(a, axis, dtype, out):
return accumulate(a, axis, dtype, out, bool_)

#------------------------------------------------------------------------
# Binary Ufuncs
Expand Down Expand Up @@ -107,7 +128,9 @@ def reduce_bool(a, axis, dtype, out):
for binary_ufunc in binary_ufuncs_bitwise + binary_ufuncs_arithmetic:
register_inferer(np, binary_ufunc, binary_map)
register_unbound(np, binary_ufunc, "reduce", reduce_)
register_unbound(np, binary_ufunc, "accumulate", accumulate)

for binary_ufunc in binary_ufuncs_compare + binary_ufuncs_logical:
register_inferer(np, binary_ufunc, binary_map_bool)
register_unbound(np, binary_ufunc, "reduce", reduce_bool)
register_unbound(np, binary_ufunc, "accumulate", accumulate_bool)

0 comments on commit 5221249

Please sign in to comment.