Skip to content

Commit 528c172

Browse files
authored
Merge branch 'main' into 2024.12
2 parents a448710 + 73f6426 commit 528c172

20 files changed

+472
-123
lines changed

array_api_compat/common/_aliases.py

+1-6
Original file line numberDiff line numberDiff line change
@@ -233,11 +233,6 @@ def unique_values(x: ndarray, /, xp) -> ndarray:
233233
**kwargs,
234234
)
235235

236-
def astype(x: ndarray, dtype: Dtype, /, *, copy: bool = True) -> ndarray:
237-
if not copy and dtype == x.dtype:
238-
return x
239-
return x.astype(dtype=dtype, copy=copy)
240-
241236
# These functions have different keyword argument names
242237

243238
def std(
@@ -579,7 +574,7 @@ def sign(x: ndarray, /, xp, **kwargs) -> ndarray:
579574
'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
580575
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
581576
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
582-
'astype', 'std', 'var', 'cumulative_sum', 'cumulative_prod', 'clip', 'permute_dims',
577+
'std', 'var', 'cumulative_sum', 'cumulative_prod', 'clip', 'permute_dims',
583578
'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc',
584579
'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype',
585580
'unstack', 'sign']

array_api_compat/common/_helpers.py

+108-27
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import inspect
1919
import warnings
2020

21-
def _is_jax_zero_gradient_array(x):
21+
def _is_jax_zero_gradient_array(x: object) -> bool:
2222
"""Return True if `x` is a zero-gradient array.
2323
2424
These arrays are a design quirk of Jax that may one day be removed.
@@ -32,7 +32,8 @@ def _is_jax_zero_gradient_array(x):
3232

3333
return isinstance(x, np.ndarray) and x.dtype == jax.float0
3434

35-
def is_numpy_array(x):
35+
36+
def is_numpy_array(x: object) -> bool:
3637
"""
3738
Return True if `x` is a NumPy array.
3839
@@ -63,7 +64,8 @@ def is_numpy_array(x):
6364
return (isinstance(x, (np.ndarray, np.generic))
6465
and not _is_jax_zero_gradient_array(x))
6566

66-
def is_cupy_array(x):
67+
68+
def is_cupy_array(x: object) -> bool:
6769
"""
6870
Return True if `x` is a CuPy array.
6971
@@ -93,7 +95,8 @@ def is_cupy_array(x):
9395
# TODO: Should we reject ndarray subclasses?
9496
return isinstance(x, cp.ndarray)
9597

96-
def is_torch_array(x):
98+
99+
def is_torch_array(x: object) -> bool:
97100
"""
98101
Return True if `x` is a PyTorch tensor.
99102
@@ -120,7 +123,8 @@ def is_torch_array(x):
120123
# TODO: Should we reject torch.Tensor subclasses?
121124
return isinstance(x, torch.Tensor)
122125

123-
def is_ndonnx_array(x):
126+
127+
def is_ndonnx_array(x: object) -> bool:
124128
"""
125129
Return True if `x` is a ndonnx Array.
126130
@@ -147,7 +151,8 @@ def is_ndonnx_array(x):
147151

148152
return isinstance(x, ndx.Array)
149153

150-
def is_dask_array(x):
154+
155+
def is_dask_array(x: object) -> bool:
151156
"""
152157
Return True if `x` is a dask.array Array.
153158
@@ -174,7 +179,8 @@ def is_dask_array(x):
174179

175180
return isinstance(x, dask.array.Array)
176181

177-
def is_jax_array(x):
182+
183+
def is_jax_array(x: object) -> bool:
178184
"""
179185
Return True if `x` is a JAX array.
180186
@@ -202,6 +208,7 @@ def is_jax_array(x):
202208

203209
return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x)
204210

211+
205212
def is_pydata_sparse_array(x) -> bool:
206213
"""
207214
Return True if `x` is an array from the `sparse` package.
@@ -231,7 +238,8 @@ def is_pydata_sparse_array(x) -> bool:
231238
# TODO: Account for other backends.
232239
return isinstance(x, sparse.SparseArray)
233240

234-
def is_array_api_obj(x):
241+
242+
def is_array_api_obj(x: object) -> bool:
235243
"""
236244
Return True if `x` is an array API compatible array object.
237245
@@ -254,10 +262,12 @@ def is_array_api_obj(x):
254262
or is_pydata_sparse_array(x) \
255263
or hasattr(x, '__array_namespace__')
256264

257-
def _compat_module_name():
265+
266+
def _compat_module_name() -> str:
258267
assert __name__.endswith('.common._helpers')
259268
return __name__.removesuffix('.common._helpers')
260269

270+
261271
def is_numpy_namespace(xp) -> bool:
262272
"""
263273
Returns True if `xp` is a NumPy namespace.
@@ -278,6 +288,7 @@ def is_numpy_namespace(xp) -> bool:
278288
"""
279289
return xp.__name__ in {'numpy', _compat_module_name() + '.numpy'}
280290

291+
281292
def is_cupy_namespace(xp) -> bool:
282293
"""
283294
Returns True if `xp` is a CuPy namespace.
@@ -298,6 +309,7 @@ def is_cupy_namespace(xp) -> bool:
298309
"""
299310
return xp.__name__ in {'cupy', _compat_module_name() + '.cupy'}
300311

312+
301313
def is_torch_namespace(xp) -> bool:
302314
"""
303315
Returns True if `xp` is a PyTorch namespace.
@@ -319,7 +331,7 @@ def is_torch_namespace(xp) -> bool:
319331
return xp.__name__ in {'torch', _compat_module_name() + '.torch'}
320332

321333

322-
def is_ndonnx_namespace(xp):
334+
def is_ndonnx_namespace(xp) -> bool:
323335
"""
324336
Returns True if `xp` is an NDONNX namespace.
325337
@@ -337,7 +349,8 @@ def is_ndonnx_namespace(xp):
337349
"""
338350
return xp.__name__ == 'ndonnx'
339351

340-
def is_dask_namespace(xp):
352+
353+
def is_dask_namespace(xp) -> bool:
341354
"""
342355
Returns True if `xp` is a Dask namespace.
343356
@@ -357,7 +370,8 @@ def is_dask_namespace(xp):
357370
"""
358371
return xp.__name__ in {'dask.array', _compat_module_name() + '.dask.array'}
359372

360-
def is_jax_namespace(xp):
373+
374+
def is_jax_namespace(xp) -> bool:
361375
"""
362376
Returns True if `xp` is a JAX namespace.
363377
@@ -378,7 +392,8 @@ def is_jax_namespace(xp):
378392
"""
379393
return xp.__name__ in {'jax.numpy', 'jax.experimental.array_api'}
380394

381-
def is_pydata_sparse_namespace(xp):
395+
396+
def is_pydata_sparse_namespace(xp) -> bool:
382397
"""
383398
Returns True if `xp` is a pydata/sparse namespace.
384399
@@ -396,7 +411,8 @@ def is_pydata_sparse_namespace(xp):
396411
"""
397412
return xp.__name__ == 'sparse'
398413

399-
def is_array_api_strict_namespace(xp):
414+
415+
def is_array_api_strict_namespace(xp) -> bool:
400416
"""
401417
Returns True if `xp` is an array-api-strict namespace.
402418
@@ -414,13 +430,15 @@ def is_array_api_strict_namespace(xp):
414430
"""
415431
return xp.__name__ == 'array_api_strict'
416432

417-
def _check_api_version(api_version):
433+
434+
def _check_api_version(api_version: str) -> None:
418435
if api_version in ['2021.12', '2022.12']:
419436
warnings.warn(f"The {api_version} version of the array API specification was requested but the returned namespace is actually version 2023.12")
420437
elif api_version is not None and api_version not in ['2021.12', '2022.12',
421438
'2023.12']:
422439
raise ValueError("Only the 2023.12 version of the array API specification is currently supported")
423440

441+
424442
def array_namespace(*xs, api_version=None, use_compat=None):
425443
"""
426444
Get the array API compatible namespace for the arrays `xs`.
@@ -631,13 +649,9 @@ def device(x: Array, /) -> Device:
631649
return "cpu"
632650
elif is_dask_array(x):
633651
# Peek at the metadata of the dask array to determine type
634-
try:
635-
import numpy as np
636-
if isinstance(x._meta, np.ndarray):
637-
# Must be on CPU since backed by numpy
638-
return "cpu"
639-
except ImportError:
640-
pass
652+
if is_numpy_array(x._meta):
653+
# Must be on CPU since backed by numpy
654+
return "cpu"
641655
return _DASK_DEVICE
642656
elif is_jax_array(x):
643657
# JAX has .device() as a method, but it is being deprecated so that it
@@ -788,24 +802,30 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
788802
return x.to_device(device, stream=stream)
789803

790804

791-
def size(x):
805+
def size(x: Array) -> int | None:
792806
"""
793807
Return the total number of elements of x.
794808
795809
This is equivalent to `x.size` according to the `standard
796810
<https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.size.html>`__.
811+
797812
This helper is included because PyTorch defines `size` in an
798813
:external+torch:meth:`incompatible way <torch.Tensor.size>`.
799-
814+
It also fixes dask.array's behaviour which returns nan for unknown sizes, whereas
815+
the standard requires None.
800816
"""
817+
# Lazy API compliant arrays, such as ndonnx, can contain None in their shape
801818
if None in x.shape:
802819
return None
803-
return math.prod(x.shape)
820+
out = math.prod(x.shape)
821+
# dask.array.Array.shape can contain NaN
822+
return None if math.isnan(out) else out
804823

805824

806-
def is_writeable_array(x) -> bool:
825+
def is_writeable_array(x: object) -> bool:
807826
"""
808827
Return False if ``x.__setitem__`` is expected to raise; True otherwise.
828+
Return False if `x` is not an array API compatible object.
809829
810830
Warning
811831
-------
@@ -816,7 +836,67 @@ def is_writeable_array(x) -> bool:
816836
return x.flags.writeable
817837
if is_jax_array(x) or is_pydata_sparse_array(x):
818838
return False
819-
return True
839+
return is_array_api_obj(x)
840+
841+
842+
def is_lazy_array(x: object) -> bool:
843+
"""Return True if x is potentially a future or it may be otherwise impossible or
844+
expensive to eagerly read its contents, regardless of their size, e.g. by
845+
calling ``bool(x)`` or ``float(x)``.
846+
847+
Return False otherwise; e.g. ``bool(x)`` etc. is guaranteed to succeed and to be
848+
cheap as long as the array has the right dtype and size.
849+
850+
Note
851+
----
852+
This function errs on the side of caution for array types that may or may not be
853+
lazy, e.g. JAX arrays, by always returning True for them.
854+
"""
855+
if (
856+
is_numpy_array(x)
857+
or is_cupy_array(x)
858+
or is_torch_array(x)
859+
or is_pydata_sparse_array(x)
860+
):
861+
return False
862+
863+
# **JAX note:** while it is possible to determine if you're inside or outside
864+
# jax.jit by testing the subclass of a jax.Array object, as well as testing bool()
865+
# as we do below for unknown arrays, this is not recommended by JAX best practices.
866+
867+
# **Dask note:** Dask eagerly computes the graph on __bool__, __float__, and so on.
868+
# This behaviour, while impossible to change without breaking backwards
869+
# compatibility, is highly detrimental to performance as the whole graph will end
870+
# up being computed multiple times.
871+
872+
if is_jax_array(x) or is_dask_array(x) or is_ndonnx_array(x):
873+
return True
874+
875+
if not is_array_api_obj(x):
876+
return False
877+
878+
# Unknown Array API compatible object. Note that this test may have dire consequences
879+
# in terms of performance, e.g. for a lazy object that eagerly computes the graph
880+
# on __bool__ (dask is one such example, which however is special-cased above).
881+
882+
# Select a single point of the array
883+
s = size(x)
884+
if s is None:
885+
return True
886+
xp = array_namespace(x)
887+
if s > 1:
888+
x = xp.reshape(x, (-1,))[0]
889+
# Cast to dtype=bool and deal with size 0 arrays
890+
x = xp.any(x)
891+
892+
try:
893+
bool(x)
894+
return False
895+
# The Array API standard dictates that __bool__ should raise TypeError if the
896+
# output cannot be defined.
897+
# Here we allow for it to raise arbitrary exceptions, e.g. like Dask does.
898+
except Exception:
899+
return True
820900

821901

822902
__all__ = [
@@ -840,6 +920,7 @@ def is_writeable_array(x) -> bool:
840920
"is_pydata_sparse_array",
841921
"is_pydata_sparse_namespace",
842922
"is_writeable_array",
923+
"is_lazy_array",
843924
"size",
844925
"to_device",
845926
]

array_api_compat/cupy/_aliases.py

+18-4
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import cupy as cp
44

5-
from ..common import _aliases
5+
from ..common import _aliases, _helpers
66
from .._internal import get_xp
77

88
from ._info import __array_namespace_info__
@@ -46,7 +46,6 @@
4646
unique_counts = get_xp(cp)(_aliases.unique_counts)
4747
unique_inverse = get_xp(cp)(_aliases.unique_inverse)
4848
unique_values = get_xp(cp)(_aliases.unique_values)
49-
astype = _aliases.astype
5049
std = get_xp(cp)(_aliases.std)
5150
var = get_xp(cp)(_aliases.var)
5251
cumulative_sum = get_xp(cp)(_aliases.cumulative_sum)
@@ -111,6 +110,21 @@ def asarray(
111110

112111
return cp.array(obj, dtype=dtype, **kwargs)
113112

113+
114+
def astype(
115+
x: ndarray,
116+
dtype: Dtype,
117+
/,
118+
*,
119+
copy: bool = True,
120+
device: Optional[Device] = None,
121+
) -> ndarray:
122+
if device is None:
123+
return x.astype(dtype=dtype, copy=copy)
124+
out = _helpers.to_device(x.astype(dtype=dtype, copy=False), device)
125+
return out.copy() if copy and out is x else out
126+
127+
114128
# These functions are completely new here. If the library already has them
115129
# (i.e., numpy 2.0), use the library version instead of our wrapper.
116130
if hasattr(cp, 'vecdot'):
@@ -128,10 +142,10 @@ def asarray(
128142
else:
129143
unstack = get_xp(cp)(_aliases.unstack)
130144

131-
__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'bool',
145+
__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'astype',
132146
'acos', 'acosh', 'asin', 'asinh', 'atan',
133147
'atan2', 'atanh', 'bitwise_left_shift',
134148
'bitwise_invert', 'bitwise_right_shift',
135-
'concat', 'pow', 'sign']
149+
'bool', 'concat', 'pow', 'sign']
136150

137151
_all_ignore = ['cp', 'get_xp']

0 commit comments

Comments
 (0)