18
18
import inspect
19
19
import warnings
20
20
21
- def _is_jax_zero_gradient_array (x ) :
21
+ def _is_jax_zero_gradient_array (x : object ) -> bool :
22
22
"""Return True if `x` is a zero-gradient array.
23
23
24
24
These arrays are a design quirk of Jax that may one day be removed.
@@ -32,7 +32,8 @@ def _is_jax_zero_gradient_array(x):
32
32
33
33
return isinstance (x , np .ndarray ) and x .dtype == jax .float0
34
34
35
- def is_numpy_array (x ):
35
+
36
+ def is_numpy_array (x : object ) -> bool :
36
37
"""
37
38
Return True if `x` is a NumPy array.
38
39
@@ -63,7 +64,8 @@ def is_numpy_array(x):
63
64
return (isinstance (x , (np .ndarray , np .generic ))
64
65
and not _is_jax_zero_gradient_array (x ))
65
66
66
- def is_cupy_array (x ):
67
+
68
+ def is_cupy_array (x : object ) -> bool :
67
69
"""
68
70
Return True if `x` is a CuPy array.
69
71
@@ -93,7 +95,8 @@ def is_cupy_array(x):
93
95
# TODO: Should we reject ndarray subclasses?
94
96
return isinstance (x , cp .ndarray )
95
97
96
- def is_torch_array (x ):
98
+
99
+ def is_torch_array (x : object ) -> bool :
97
100
"""
98
101
Return True if `x` is a PyTorch tensor.
99
102
@@ -120,7 +123,8 @@ def is_torch_array(x):
120
123
# TODO: Should we reject ndarray subclasses?
121
124
return isinstance (x , torch .Tensor )
122
125
123
- def is_ndonnx_array (x ):
126
+
127
+ def is_ndonnx_array (x : object ) -> bool :
124
128
"""
125
129
Return True if `x` is a ndonnx Array.
126
130
@@ -147,7 +151,8 @@ def is_ndonnx_array(x):
147
151
148
152
return isinstance (x , ndx .Array )
149
153
150
- def is_dask_array (x ):
154
+
155
+ def is_dask_array (x : object ) -> bool :
151
156
"""
152
157
Return True if `x` is a dask.array Array.
153
158
@@ -174,7 +179,8 @@ def is_dask_array(x):
174
179
175
180
return isinstance (x , dask .array .Array )
176
181
177
- def is_jax_array (x ):
182
+
183
+ def is_jax_array (x : object ) -> bool :
178
184
"""
179
185
Return True if `x` is a JAX array.
180
186
@@ -202,6 +208,7 @@ def is_jax_array(x):
202
208
203
209
return isinstance (x , jax .Array ) or _is_jax_zero_gradient_array (x )
204
210
211
+
205
212
def is_pydata_sparse_array (x ) -> bool :
206
213
"""
207
214
Return True if `x` is an array from the `sparse` package.
@@ -231,7 +238,8 @@ def is_pydata_sparse_array(x) -> bool:
231
238
# TODO: Account for other backends.
232
239
return isinstance (x , sparse .SparseArray )
233
240
234
- def is_array_api_obj (x ):
241
+
242
+ def is_array_api_obj (x : object ) -> bool :
235
243
"""
236
244
Return True if `x` is an array API compatible array object.
237
245
@@ -254,10 +262,12 @@ def is_array_api_obj(x):
254
262
or is_pydata_sparse_array (x ) \
255
263
or hasattr (x , '__array_namespace__' )
256
264
257
- def _compat_module_name ():
265
+
266
+ def _compat_module_name () -> str :
258
267
assert __name__ .endswith ('.common._helpers' )
259
268
return __name__ .removesuffix ('.common._helpers' )
260
269
270
+
261
271
def is_numpy_namespace (xp ) -> bool :
262
272
"""
263
273
Returns True if `xp` is a NumPy namespace.
@@ -278,6 +288,7 @@ def is_numpy_namespace(xp) -> bool:
278
288
"""
279
289
return xp .__name__ in {'numpy' , _compat_module_name () + '.numpy' }
280
290
291
+
281
292
def is_cupy_namespace (xp ) -> bool :
282
293
"""
283
294
Returns True if `xp` is a CuPy namespace.
@@ -298,6 +309,7 @@ def is_cupy_namespace(xp) -> bool:
298
309
"""
299
310
return xp .__name__ in {'cupy' , _compat_module_name () + '.cupy' }
300
311
312
+
301
313
def is_torch_namespace (xp ) -> bool :
302
314
"""
303
315
Returns True if `xp` is a PyTorch namespace.
@@ -319,7 +331,7 @@ def is_torch_namespace(xp) -> bool:
319
331
return xp .__name__ in {'torch' , _compat_module_name () + '.torch' }
320
332
321
333
322
- def is_ndonnx_namespace (xp ):
334
+ def is_ndonnx_namespace (xp ) -> bool :
323
335
"""
324
336
Returns True if `xp` is an NDONNX namespace.
325
337
@@ -337,7 +349,8 @@ def is_ndonnx_namespace(xp):
337
349
"""
338
350
return xp .__name__ == 'ndonnx'
339
351
340
- def is_dask_namespace (xp ):
352
+
353
+ def is_dask_namespace (xp ) -> bool :
341
354
"""
342
355
Returns True if `xp` is a Dask namespace.
343
356
@@ -357,7 +370,8 @@ def is_dask_namespace(xp):
357
370
"""
358
371
return xp .__name__ in {'dask.array' , _compat_module_name () + '.dask.array' }
359
372
360
- def is_jax_namespace (xp ):
373
+
374
+ def is_jax_namespace (xp ) -> bool :
361
375
"""
362
376
Returns True if `xp` is a JAX namespace.
363
377
@@ -378,7 +392,8 @@ def is_jax_namespace(xp):
378
392
"""
379
393
return xp .__name__ in {'jax.numpy' , 'jax.experimental.array_api' }
380
394
381
- def is_pydata_sparse_namespace (xp ):
395
+
396
+ def is_pydata_sparse_namespace (xp ) -> bool :
382
397
"""
383
398
Returns True if `xp` is a pydata/sparse namespace.
384
399
@@ -396,7 +411,8 @@ def is_pydata_sparse_namespace(xp):
396
411
"""
397
412
return xp .__name__ == 'sparse'
398
413
399
- def is_array_api_strict_namespace (xp ):
414
+
415
+ def is_array_api_strict_namespace (xp ) -> bool :
400
416
"""
401
417
Returns True if `xp` is an array-api-strict namespace.
402
418
@@ -414,13 +430,15 @@ def is_array_api_strict_namespace(xp):
414
430
"""
415
431
return xp .__name__ == 'array_api_strict'
416
432
417
- def _check_api_version (api_version ):
433
+
434
+ def _check_api_version (api_version : str ) -> None :
418
435
if api_version in ['2021.12' , '2022.12' ]:
419
436
warnings .warn (f"The { api_version } version of the array API specification was requested but the returned namespace is actually version 2023.12" )
420
437
elif api_version is not None and api_version not in ['2021.12' , '2022.12' ,
421
438
'2023.12' ]:
422
439
raise ValueError ("Only the 2023.12 version of the array API specification is currently supported" )
423
440
441
+
424
442
def array_namespace (* xs , api_version = None , use_compat = None ):
425
443
"""
426
444
Get the array API compatible namespace for the arrays `xs`.
@@ -631,13 +649,9 @@ def device(x: Array, /) -> Device:
631
649
return "cpu"
632
650
elif is_dask_array (x ):
633
651
# Peek at the metadata of the dask array to determine type
634
- try :
635
- import numpy as np
636
- if isinstance (x ._meta , np .ndarray ):
637
- # Must be on CPU since backed by numpy
638
- return "cpu"
639
- except ImportError :
640
- pass
652
+ if is_numpy_array (x ._meta ):
653
+ # Must be on CPU since backed by numpy
654
+ return "cpu"
641
655
return _DASK_DEVICE
642
656
elif is_jax_array (x ):
643
657
# JAX has .device() as a method, but it is being deprecated so that it
@@ -788,24 +802,30 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
788
802
return x .to_device (device , stream = stream )
789
803
790
804
791
- def size (x ) :
805
+ def size (x : Array ) -> int | None :
792
806
"""
793
807
Return the total number of elements of x.
794
808
795
809
This is equivalent to `x.size` according to the `standard
796
810
<https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.size.html>`__.
811
+
797
812
This helper is included because PyTorch defines `size` in an
798
813
:external+torch:meth:`incompatible way <torch.Tensor.size>`.
799
-
814
+ It also fixes dask.array's behaviour which returns nan for unknown sizes, whereas
815
+ the standard requires None.
800
816
"""
817
+ # Lazy API compliant arrays, such as ndonnx, can contain None in their shape
801
818
if None in x .shape :
802
819
return None
803
- return math .prod (x .shape )
820
+ out = math .prod (x .shape )
821
+ # dask.array.Array.shape can contain NaN
822
+ return None if math .isnan (out ) else out
804
823
805
824
806
- def is_writeable_array (x ) -> bool :
825
+ def is_writeable_array (x : object ) -> bool :
807
826
"""
808
827
Return False if ``x.__setitem__`` is expected to raise; True otherwise.
828
+ Return False if `x` is not an array API compatible object.
809
829
810
830
Warning
811
831
-------
@@ -816,7 +836,67 @@ def is_writeable_array(x) -> bool:
816
836
return x .flags .writeable
817
837
if is_jax_array (x ) or is_pydata_sparse_array (x ):
818
838
return False
819
- return True
839
+ return is_array_api_obj (x )
840
+
841
+
842
def is_lazy_array(x: object) -> bool:
    """Return True if x is potentially a future or it may be otherwise impossible or
    expensive to eagerly read its contents, regardless of their size, e.g. by
    calling ``bool(x)`` or ``float(x)``.

    Return False otherwise; e.g. ``bool(x)`` etc. is guaranteed to succeed and to be
    cheap as long as the array has the right dtype and size.
    Objects that are not array API compatible also yield False.

    Note
    ----
    This function errs on the side of caution for array types that may or may not be
    lazy, e.g. JAX arrays, by always returning True for them.
    """
    # Backends that are always treated as eager.
    eager_checks = (
        is_numpy_array,
        is_cupy_array,
        is_torch_array,
        is_pydata_sparse_array,
    )
    if any(check(x) for check in eager_checks):
        return False

    # **JAX note:** while it is possible to determine if you're inside or outside
    # jax.jit by testing the subclass of a jax.Array object, as well as testing bool()
    # as we do below for unknown arrays, this is not recommended by JAX best practices.

    # **Dask note:** Dask eagerly computes the graph on __bool__, __float__, and so on.
    # This behaviour, while impossible to change without breaking backwards
    # compatibility, is highly detrimental to performance as the whole graph will end
    # up being computed multiple times.

    # Backends that are always treated as lazy.
    lazy_checks = (is_jax_array, is_dask_array, is_ndonnx_array)
    if any(check(x) for check in lazy_checks):
        return True

    if not is_array_api_obj(x):
        return False

    # Unknown Array API compatible object. Note that this test may have dire consequences
    # in terms of performance, e.g. for a lazy object that eagerly computes the graph
    # on __bool__ (dask is one such example, which however is special-cased above).

    n_elements = size(x)
    if n_elements is None:
        # A size that is not known up front is taken as a sign of laziness.
        return True

    xp = array_namespace(x)
    # Select a single point of the array
    probe = xp.reshape(x, (-1,))[0] if n_elements > 1 else x
    # Cast to dtype=bool and deal with size 0 arrays
    probe = xp.any(probe)

    try:
        bool(probe)
    # The Array API standard dictates that __bool__ should raise TypeError if the
    # output cannot be defined.
    # Here we allow for it to raise arbitrary exceptions, e.g. like Dask does.
    except Exception:
        return True
    else:
        return False
820
900
821
901
822
902
__all__ = [
@@ -840,6 +920,7 @@ def is_writeable_array(x) -> bool:
840
920
"is_pydata_sparse_array" ,
841
921
"is_pydata_sparse_namespace" ,
842
922
"is_writeable_array" ,
923
+ "is_lazy_array" ,
843
924
"size" ,
844
925
"to_device" ,
845
926
]
0 commit comments