From 5b5515e14453d3b4d5b7b476d19f2000459980a7 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Fri, 9 Sep 2022 15:48:35 -0300 Subject: [PATCH 1/3] Replace overload_glue by @overload for np.linspace and np.take --- numba/core/typing/arraydecl.py | 17 ----- numba/core/typing/npydecl.py | 45 ----------- numba/np/arrayobj.py | 132 +++++++++++++++------------------ 3 files changed, 61 insertions(+), 133 deletions(-) diff --git a/numba/core/typing/arraydecl.py b/numba/core/typing/arraydecl.py index 629263aad90..60d4d0201c4 100644 --- a/numba/core/typing/arraydecl.py +++ b/numba/core/typing/arraydecl.py @@ -483,23 +483,6 @@ def resolve_flatten(self, ary, args, kws): assert not args return signature(ary.copy(ndim=1, layout='C')) - @bound_function("array.take") - def resolve_take(self, ary, args, kws): - if kws: - raise NumbaAssertionError("kws not supported") - argty, = args - if isinstance(argty, types.Integer): - sig = signature(ary.dtype, *args) - elif isinstance(argty, types.Array): - sig = signature(argty.copy(layout='C', dtype=ary.dtype), *args) - elif isinstance(argty, types.List): # 1d lists only - sig = signature(types.Array(ary.dtype, 1, 'C'), *args) - elif isinstance(argty, types.BaseTuple): - sig = signature(types.Array(ary.dtype, np.ndim(argty), 'C'), *args) - else: - raise TypeError("take(%s) not supported for %s" % argty) - return sig - def generic_resolve(self, ary, attr): # Resolution of other attributes, for record arrays if isinstance(ary.dtype, types.Record): diff --git a/numba/core/typing/npydecl.py b/numba/core/typing/npydecl.py index 7758c7c8623..24df235b0ee 100644 --- a/numba/core/typing/npydecl.py +++ b/numba/core/typing/npydecl.py @@ -589,30 +589,6 @@ def _infer_dtype_from_inputs(inputs): return dtype -@glue_typing(np.linspace) -class NdLinspace(AbstractTemplate): - - def generic(self, args, kws): - assert not kws - bounds = args[:2] - if not all(isinstance(arg, types.Number) for arg in bounds): - return - if len(args) >= 3: - num = args[2] - if not isinstance(num, types.Integer): - return - if len(args) >= 4: - # Not supporting the other arguments as it would require - # keyword arguments for reasonable use. - return - if any(isinstance(arg, types.Complex) for arg in bounds): - dtype = types.complex128 - else: - dtype = types.float64 - return_type = types.Array(ndim=1, dtype=dtype, layout='C') - return signature(return_type, *args) - - @glue_typing(np.frombuffer) class NdFromBuffer(CallableTemplate): @@ -1118,27 +1094,6 @@ def typer(ref, k=0): return typer -@glue_typing(np.take) -class Take(AbstractTemplate): - - def generic(self, args, kws): - if kws: - raise NumbaAssertionError("kws not supported") - if len(args) != 2: - raise NumbaAssertionError("two arguments are required") - arr, ind = args - if isinstance(ind, types.Number): - retty = arr.dtype - elif isinstance(ind, types.Array): - retty = types.Array(ndim=ind.ndim, dtype=arr.dtype, layout='C') - elif isinstance(ind, types.List): - retty = types.Array(ndim=1, dtype=arr.dtype, layout='C') - elif isinstance(ind, types.BaseTuple): - retty = types.Array(ndim=np.ndim(ind), dtype=arr.dtype, layout='C') - else: - return None - - return signature(retty, *args) # ----------------------------------------------------------------------------- # Numba helpers diff --git a/numba/np/arrayobj.py b/numba/np/arrayobj.py index 760784de31e..268f0702b1d 100644 --- a/numba/np/arrayobj.py +++ b/numba/np/arrayobj.py @@ -4389,66 +4389,51 @@ def diag_impl(arr, k=0): return impl_ret_new_ref(context, builder, sig.return_type, res) -@lower_builtin('array.take', types.Array, types.Integer) -@glue_lowering(np.take, types.Array, types.Integer) -def numpy_take_1(context, builder, sig, args): +@overload(np.take) +@overload_method(types.Array, 'take') +def numpy_take(a, indices): - def take_impl(a, indices): - if indices > (a.size - 1) or indices < -a.size: - raise IndexError("Index out of bounds") - return a.ravel()[indices] - - res = context.compile_internal(builder, take_impl, sig, args) - return impl_ret_new_ref(context, builder, sig.return_type, res) - - -@lower_builtin('array.take', types.Array, types.Array) -@glue_lowering(np.take, types.Array, types.Array) -def numpy_take_2(context, builder, sig, args): - - F_order = sig.args[1].layout == 'F' - - def take_impl(a, indices): - ret = np.empty(indices.size, dtype=a.dtype) - if F_order: - walker = indices.copy() # get C order - else: - walker = indices - it = np.nditer(walker) - i = 0 - flat = a.ravel() - for x in it: - if x > (a.size - 1) or x < -a.size: + if isinstance(a, types.Array) and isinstance(indices, types.Integer): + def take_impl(a, indices): + if indices > (a.size - 1) or indices < -a.size: raise IndexError("Index out of bounds") - ret[i] = flat[x] - i = i + 1 - return ret.reshape(indices.shape) - - res = context.compile_internal(builder, take_impl, sig, args) - return impl_ret_new_ref(context, builder, sig.return_type, res) - - -@lower_builtin('array.take', types.Array, types.List) -@glue_lowering(np.take, types.Array, types.List) -@lower_builtin('array.take', types.Array, types.BaseTuple) -@glue_lowering(np.take, types.Array, types.BaseTuple) -def numpy_take_3(context, builder, sig, args): - - def take_impl(a, indices): - convert = np.array(indices) - ret = np.empty(convert.size, dtype=a.dtype) - it = np.nditer(convert) - i = 0 - flat = a.ravel() - for x in it: - if x > (a.size - 1) or x < -a.size: - raise IndexError("Index out of bounds") - ret[i] = flat[x] - i = i + 1 - return ret.reshape(convert.shape) - - res = context.compile_internal(builder, take_impl, sig, args) - return impl_ret_new_ref(context, builder, sig.return_type, res) + return a.ravel()[indices] + return take_impl + + if all(isinstance(arg, types.Array) for arg in [a, indices]): + F_order = indices.layout == 'F' + def take_impl(a, indices): + ret = np.empty(indices.size, dtype=a.dtype) + if F_order: + walker = indices.copy() # get C order + else: + walker = indices + it = np.nditer(walker) + i = 0 + flat = a.ravel() + for x in it: + if x > (a.size - 1) or x < -a.size: + raise IndexError("Index out of bounds") + ret[i] = flat[x] + i = i + 1 + return ret.reshape(indices.shape) + return take_impl + + if isinstance(a, types.Array) and \ + isinstance(indices, (types.List, types.BaseTuple)): + def take_impl(a, indices): + convert = np.array(indices) + ret = np.empty(convert.size, dtype=a.dtype) + it = np.nditer(convert) + i = 0 + flat = a.ravel() + for x in it: + if x > (a.size - 1) or x < -a.size: + raise IndexError("Index out of bounds") + ret[i] = flat[x] + i = i + 1 + return ret.reshape(convert.shape) + return take_impl def _arange_dtype(*args): @@ -4545,21 +4530,28 @@ def impl(start, stop=None, step=None, dtype=None): return impl -@glue_lowering(np.linspace, types.Number, types.Number) -def numpy_linspace_2(context, builder, sig, args): +@overload(np.linspace) +def numpy_linspace_2(start, stop): + if isinstance(start, types.Number) and isinstance(stop, types.Number): + def impl(start, stop): + return np.linspace(start, stop, 50) + return impl - def linspace(start, stop): - return np.linspace(start, stop, 50) - res = context.compile_internal(builder, linspace, sig, args) - return impl_ret_new_ref(context, builder, sig.return_type, res) +@overload(np.linspace) +def numpy_linspace_3(start, stop, num): + if not all(isinstance(arg, types.Number) for arg in [start, stop]): + return + if not isinstance(num, types.Integer): + return -@glue_lowering(np.linspace, types.Number, types.Number, types.Integer) -def numpy_linspace_3(context, builder, sig, args): - dtype = as_dtype(sig.return_type.dtype) + if any(isinstance(arg, types.Complex) for arg in [start, stop]): + dtype = types.complex128 + else: + dtype = types.float64 + dtype = as_dtype(dtype) - # Implementation based on https://github.com/numpy/numpy/blob/v1.20.0/numpy/core/function_base.py#L24 # noqa: E501 def linspace(start, stop, num): arr = np.empty(num, dtype) # The multiply by 1.0 mirrors @@ -4582,9 +4574,7 @@ def linspace(start, stop, num): if num > 1: arr[-1] = stop return arr - - res = context.compile_internal(builder, linspace, sig, args) - return impl_ret_new_ref(context, builder, sig.return_type, res) + return linspace def _array_copy(context, builder, sig, args): From 1f536e764e8e11e02336411c88550c0f716e3bfa Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Fri, 9 Sep 2022 17:28:15 -0300 Subject: [PATCH 2/3] flake8 --- numba/np/arrayobj.py | 1 + 1 file changed, 1 insertion(+) diff --git a/numba/np/arrayobj.py b/numba/np/arrayobj.py index 268f0702b1d..632acb24978 100644 --- a/numba/np/arrayobj.py +++ b/numba/np/arrayobj.py @@ -4402,6 +4402,7 @@ def take_impl(a, indices): if all(isinstance(arg, types.Array) for arg in [a, indices]): F_order = indices.layout == 'F' + def take_impl(a, indices): ret = np.empty(indices.size, dtype=a.dtype) if F_order: From 3f0a1586221153e9d7764d8680464995b4dafde9 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Mon, 12 Sep 2022 14:27:42 -0300 Subject: [PATCH 3/3] address reviewer comments --- numba/np/arrayobj.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/numba/np/arrayobj.py b/numba/np/arrayobj.py index 632acb24978..53a824518f7 100644 --- a/numba/np/arrayobj.py +++ b/numba/np/arrayobj.py @@ -4532,28 +4532,21 @@ def impl(start, stop=None, step=None, dtype=None): @overload(np.linspace) -def numpy_linspace_2(start, stop): - if isinstance(start, types.Number) and isinstance(stop, types.Number): - def impl(start, stop): - return np.linspace(start, stop, 50) - return impl - - -@overload(np.linspace) -def numpy_linspace_3(start, stop, num): +def numpy_linspace(start, stop, num=50): if not all(isinstance(arg, types.Number) for arg in [start, stop]): return - if not isinstance(num, types.Integer): - return + if not isinstance(num, (int, types.Integer)): + msg = 'The argument "num" must be an integer' + raise errors.TypingError(msg) if any(isinstance(arg, types.Complex) for arg in [start, stop]): dtype = types.complex128 else: dtype = types.float64 - dtype = as_dtype(dtype) - def linspace(start, stop, num): + # Implementation based on https://github.com/numpy/numpy/blob/v1.20.0/numpy/core/function_base.py#L24 # noqa: E501 + def linspace(start, stop, num=50): arr = np.empty(num, dtype) # The multiply by 1.0 mirrors # https://github.com/numpy/numpy/blob/v1.20.0/numpy/core/function_base.py#L125-L128 # noqa: E501