Skip to content

Commit

Permalink
Merge pull request numba#8431 from guilhermeleobas/refactor_np_linspa…
Browse files Browse the repository at this point in the history
…ce_take

Replace `@overload_glue` by `@overload` for `np.linspace` and `np.take`
  • Loading branch information
sklam authored Sep 13, 2022
2 parents e82b042 + 3f0a158 commit c87e22e
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 133 deletions.
17 changes: 0 additions & 17 deletions numba/core/typing/arraydecl.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,23 +483,6 @@ def resolve_flatten(self, ary, args, kws):
assert not args
return signature(ary.copy(ndim=1, layout='C'))

@bound_function("array.take")
def resolve_take(self, ary, args, kws):
if kws:
raise NumbaAssertionError("kws not supported")
argty, = args
if isinstance(argty, types.Integer):
sig = signature(ary.dtype, *args)
elif isinstance(argty, types.Array):
sig = signature(argty.copy(layout='C', dtype=ary.dtype), *args)
elif isinstance(argty, types.List): # 1d lists only
sig = signature(types.Array(ary.dtype, 1, 'C'), *args)
elif isinstance(argty, types.BaseTuple):
sig = signature(types.Array(ary.dtype, np.ndim(argty), 'C'), *args)
else:
raise TypeError("take(%s) not supported for %s" % argty)
return sig

def generic_resolve(self, ary, attr):
# Resolution of other attributes, for record arrays
if isinstance(ary.dtype, types.Record):
Expand Down
45 changes: 0 additions & 45 deletions numba/core/typing/npydecl.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,30 +589,6 @@ def _infer_dtype_from_inputs(inputs):
return dtype


@glue_typing(np.linspace)
class NdLinspace(AbstractTemplate):

def generic(self, args, kws):
assert not kws
bounds = args[:2]
if not all(isinstance(arg, types.Number) for arg in bounds):
return
if len(args) >= 3:
num = args[2]
if not isinstance(num, types.Integer):
return
if len(args) >= 4:
# Not supporting the other arguments as it would require
# keyword arguments for reasonable use.
return
if any(isinstance(arg, types.Complex) for arg in bounds):
dtype = types.complex128
else:
dtype = types.float64
return_type = types.Array(ndim=1, dtype=dtype, layout='C')
return signature(return_type, *args)


@glue_typing(np.frombuffer)
class NdFromBuffer(CallableTemplate):

Expand Down Expand Up @@ -1118,27 +1094,6 @@ def typer(ref, k=0):
return typer


@glue_typing(np.take)
class Take(AbstractTemplate):

def generic(self, args, kws):
if kws:
raise NumbaAssertionError("kws not supported")
if len(args) != 2:
raise NumbaAssertionError("two arguments are required")
arr, ind = args
if isinstance(ind, types.Number):
retty = arr.dtype
elif isinstance(ind, types.Array):
retty = types.Array(ndim=ind.ndim, dtype=arr.dtype, layout='C')
elif isinstance(ind, types.List):
retty = types.Array(ndim=1, dtype=arr.dtype, layout='C')
elif isinstance(ind, types.BaseTuple):
retty = types.Array(ndim=np.ndim(ind), dtype=arr.dtype, layout='C')
else:
return None

return signature(retty, *args)

# -----------------------------------------------------------------------------
# Numba helpers
Expand Down
126 changes: 55 additions & 71 deletions numba/np/arrayobj.py
Original file line number Diff line number Diff line change
Expand Up @@ -4389,66 +4389,52 @@ def diag_impl(arr, k=0):
return impl_ret_new_ref(context, builder, sig.return_type, res)


@lower_builtin('array.take', types.Array, types.Integer)
@glue_lowering(np.take, types.Array, types.Integer)
def numpy_take_1(context, builder, sig, args):
@overload(np.take)
@overload_method(types.Array, 'take')
def numpy_take(a, indices):

def take_impl(a, indices):
if indices > (a.size - 1) or indices < -a.size:
raise IndexError("Index out of bounds")
return a.ravel()[indices]

res = context.compile_internal(builder, take_impl, sig, args)
return impl_ret_new_ref(context, builder, sig.return_type, res)


@lower_builtin('array.take', types.Array, types.Array)
@glue_lowering(np.take, types.Array, types.Array)
def numpy_take_2(context, builder, sig, args):

F_order = sig.args[1].layout == 'F'

def take_impl(a, indices):
ret = np.empty(indices.size, dtype=a.dtype)
if F_order:
walker = indices.copy() # get C order
else:
walker = indices
it = np.nditer(walker)
i = 0
flat = a.ravel()
for x in it:
if x > (a.size - 1) or x < -a.size:
if isinstance(a, types.Array) and isinstance(indices, types.Integer):
def take_impl(a, indices):
if indices > (a.size - 1) or indices < -a.size:
raise IndexError("Index out of bounds")
ret[i] = flat[x]
i = i + 1
return ret.reshape(indices.shape)
return a.ravel()[indices]
return take_impl

res = context.compile_internal(builder, take_impl, sig, args)
return impl_ret_new_ref(context, builder, sig.return_type, res)


@lower_builtin('array.take', types.Array, types.List)
@glue_lowering(np.take, types.Array, types.List)
@lower_builtin('array.take', types.Array, types.BaseTuple)
@glue_lowering(np.take, types.Array, types.BaseTuple)
def numpy_take_3(context, builder, sig, args):

def take_impl(a, indices):
convert = np.array(indices)
ret = np.empty(convert.size, dtype=a.dtype)
it = np.nditer(convert)
i = 0
flat = a.ravel()
for x in it:
if x > (a.size - 1) or x < -a.size:
raise IndexError("Index out of bounds")
ret[i] = flat[x]
i = i + 1
return ret.reshape(convert.shape)
if all(isinstance(arg, types.Array) for arg in [a, indices]):
F_order = indices.layout == 'F'

res = context.compile_internal(builder, take_impl, sig, args)
return impl_ret_new_ref(context, builder, sig.return_type, res)
def take_impl(a, indices):
ret = np.empty(indices.size, dtype=a.dtype)
if F_order:
walker = indices.copy() # get C order
else:
walker = indices
it = np.nditer(walker)
i = 0
flat = a.ravel()
for x in it:
if x > (a.size - 1) or x < -a.size:
raise IndexError("Index out of bounds")
ret[i] = flat[x]
i = i + 1
return ret.reshape(indices.shape)
return take_impl

if isinstance(a, types.Array) and \
isinstance(indices, (types.List, types.BaseTuple)):
def take_impl(a, indices):
convert = np.array(indices)
ret = np.empty(convert.size, dtype=a.dtype)
it = np.nditer(convert)
i = 0
flat = a.ravel()
for x in it:
if x > (a.size - 1) or x < -a.size:
raise IndexError("Index out of bounds")
ret[i] = flat[x]
i = i + 1
return ret.reshape(convert.shape)
return take_impl


def _arange_dtype(*args):
Expand Down Expand Up @@ -4545,22 +4531,22 @@ def impl(start, stop=None, step=None, dtype=None):
return impl


@glue_lowering(np.linspace, types.Number, types.Number)
def numpy_linspace_2(context, builder, sig, args):

def linspace(start, stop):
return np.linspace(start, stop, 50)

res = context.compile_internal(builder, linspace, sig, args)
return impl_ret_new_ref(context, builder, sig.return_type, res)
@overload(np.linspace)
def numpy_linspace(start, stop, num=50):
if not all(isinstance(arg, types.Number) for arg in [start, stop]):
return

if not isinstance(num, (int, types.Integer)):
msg = 'The argument "num" must be an integer'
raise errors.TypingError(msg)

@glue_lowering(np.linspace, types.Number, types.Number, types.Integer)
def numpy_linspace_3(context, builder, sig, args):
dtype = as_dtype(sig.return_type.dtype)
if any(isinstance(arg, types.Complex) for arg in [start, stop]):
dtype = types.complex128
else:
dtype = types.float64

# Implementation based on https://github.com/numpy/numpy/blob/v1.20.0/numpy/core/function_base.py#L24 # noqa: E501
def linspace(start, stop, num):
def linspace(start, stop, num=50):
arr = np.empty(num, dtype)
# The multiply by 1.0 mirrors
# https://github.com/numpy/numpy/blob/v1.20.0/numpy/core/function_base.py#L125-L128 # noqa: E501
Expand All @@ -4582,9 +4568,7 @@ def linspace(start, stop, num):
if num > 1:
arr[-1] = stop
return arr

res = context.compile_internal(builder, linspace, sig, args)
return impl_ret_new_ref(context, builder, sig.return_type, res)
return linspace


def _array_copy(context, builder, sig, args):
Expand Down

0 comments on commit c87e22e

Please sign in to comment.