Skip to content

Commit

Permalink
Simplify some type inference code, implement type inference for reduceat
Browse files Browse the repository at this point in the history
  • Loading branch information
markflorisson committed Feb 1, 2013
1 parent 5221249 commit c2459c3
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 8 deletions.
22 changes: 22 additions & 0 deletions numba/tests/np/test_ufunc_type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,13 @@ def accumulate(ufunc, a):
def accumulate_dtype(ufunc, a, dtype):
return numba.typeof(ufunc.accumulate(a, dtype=dtype))

@autojit
def reduceat(ufunc, a):
return numba.typeof(ufunc.reduceat(a))

@autojit
def reduceat_dtype(ufunc, a, dtype):
return numba.typeof(ufunc.reduceat(a, dtype=dtype))

#------------------------------------------------------------------------
# Tests
Expand Down Expand Up @@ -98,7 +105,22 @@ def test_ufunc_accumulate():

equals(accumulate(np.logical_and, a), bool_[:])

# Test with dtype
equals(accumulate_dtype(np.add, a, np.double), double[:])

def test_ufunc_reduceat():
equals(reduceat(np.add, a), int32[:])
equals(reduceat(np.multiply, a), int32[:])

equals(reduceat(np.bitwise_and, a), int32[:])

equals(reduceat(np.logical_and, a), bool_[:])

# Test with dtype
equals(reduceat_dtype(np.add, a, np.double), double[:])

if __name__ == "__main__":
# test_binary_ufunc()
# test_ufunc_reduce()
test_ufunc_accumulate()
test_ufunc_reduceat()
6 changes: 6 additions & 0 deletions numba/type_inference/modules/numpymodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ def promote_to_array(dtype):
dtype = minitypes.ArrayType(dtype, 0)
return dtype

def demote_to_scalar(type):
"Demote 0d arrays to scalars"
if type and type.is_array and type.ndim == 0:
return type.dtype
return type

def array_from_object(a):
"""
object -> array type:
Expand Down
39 changes: 31 additions & 8 deletions numba/type_inference/modules/numpyufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,39 @@
from numba.type_inference.modules.numpymodule import (get_dtype,
array_from_type,
promote,
promote_to_array)
promote_to_array,
demote_to_scalar)

#----------------------------------------------------------------------------
# Utilities
#----------------------------------------------------------------------------

def array_of_dtype(a, dtype, static_dtype, out):
if out is not None:
return out

a = array_from_type(a)
if not a.is_object:
dtype = _dtype(a, dtype, static_dtype)
if dtype is not None:
return a.copy(dtype=dtype)

def _dtype(a, dtype, static_dtype):
if static_dtype:
return static_dtype
elif dtype:
return dtype
return dtype.dtype
elif a.is_array:
return a.dtype
elif not a.is_object:
return a
else:
return None

#----------------------------------------------------------------------------
# Ufunc type inference
#----------------------------------------------------------------------------

def binary_map(context, a, b, out):
if out is not None:
return out
Expand Down Expand Up @@ -67,16 +86,18 @@ def reduce_bool(a, axis, dtype, out):
return reduce_(a, axis, dtype, out, bool_)

def accumulate(a, axis, dtype, out, static_dtype=None):
if out is not None:
return out

dtype = _dtype(a, dtype, static_dtype)
if dtype:
return promote_to_array(a).copy(dtype=dtype)
return demote_to_scalar(array_of_dtype(a, dtype, static_dtype, out))

def accumulate_bool(a, axis, dtype, out):
return accumulate(a, axis, dtype, out, bool_)

def reduceat(a, indices, axis, dtype, out, static_dtype=None):
return accumulate(a, axis, dtype, out, static_dtype)

def reduceat_bool(a, indices, axis, dtype, out):
return reduceat(a, indices, axis, dtype, out, bool_)


#------------------------------------------------------------------------
# Binary Ufuncs
#------------------------------------------------------------------------
Expand Down Expand Up @@ -129,8 +150,10 @@ def accumulate_bool(a, axis, dtype, out):
register_inferer(np, binary_ufunc, binary_map)
register_unbound(np, binary_ufunc, "reduce", reduce_)
register_unbound(np, binary_ufunc, "accumulate", accumulate)
register_unbound(np, binary_ufunc, "reduceat", reduceat)

for binary_ufunc in binary_ufuncs_compare + binary_ufuncs_logical:
register_inferer(np, binary_ufunc, binary_map_bool)
register_unbound(np, binary_ufunc, "reduce", reduce_bool)
register_unbound(np, binary_ufunc, "accumulate", accumulate_bool)
register_unbound(np, binary_ufunc, "reduceat", reduceat_bool)

0 comments on commit c2459c3

Please sign in to comment.