# indexing_functions.py — from a fork of data-apis/array-api
# (captured from the GitHub file view: 37 lines, 26 loc, 1.48 KB).
# Names exported by ``from <module> import *``; this module exposes only ``take``.
__all__ = ["take"]
from ._types import Union, Optional, array
def take(x: array, indices: array, /, *, axis: Optional[int] = None) -> array:
    """
    Returns elements of an array along an axis.

    .. note::
       Conceptually, ``take(x, indices, axis=3)`` is equivalent to ``x[:,:,:,indices,...]``; however, explicit indexing via arrays of indices is not currently supported in this specification due to concerns regarding ``__setitem__`` and array mutation semantics.

    Parameters
    ----------
    x: array
        input array.
    indices: array
        array indices. The array must be one-dimensional and have an integer data type.

        .. note::
           This specification does not require bounds checking. The behavior for out-of-bounds indices is left unspecified.

    axis: Optional[int]
        axis over which to select values. If ``axis`` is negative, the function must determine the axis along which to select values by counting from the last dimension.

        If ``x`` is a one-dimensional array, providing an ``axis`` is optional; however, if ``x`` has more than one dimension, providing an ``axis`` is required. Default: ``None``.

    Returns
    -------
    out: array
        an array having the same data type as ``x``. The output array must have the same rank (i.e., number of dimensions) as ``x`` and must have the same shape as ``x``, except for the axis specified by ``axis`` whose size must equal the number of elements in ``indices``.

    Notes
    -----

    .. versionadded:: 2022.12
    """
    # Specification stub: no implementation is provided here; the body is
    # intentionally only this docstring.