32
32
from array_api_extra ._lib ._utils ._typing import Array , Device
33
33
from array_api_extra .testing import lazy_xp_function
34
34
35
+ from .conftest import NUMPY_VERSION
36
+
35
37
# some xp backends are untyped
36
38
# mypy: disable-error-code=no-untyped-def
37
39
48
50
lazy_xp_function (sinc , static_argnames = "xp" )
49
51
50
52
51
- NUMPY_GE2 = int (np .__version__ .split ("." )[0 ]) >= 2
52
-
53
-
54
- @pytest .mark .skip_xp_backend (
55
- Backend .SPARSE , reason = "read-only backend without .at support"
56
- )
57
53
class TestApplyWhere :
58
54
@staticmethod
59
55
def f1 (x : Array , y : Array | int = 10 ) -> Array :
@@ -153,6 +149,14 @@ def test_dont_overwrite_fill_value(self, xp: ModuleType):
153
149
xp_assert_equal (actual , xp .asarray ([100 , 12 ]))
154
150
xp_assert_equal (fill_value , xp .asarray ([100 , 200 ]))
155
151
152
+ @pytest .mark .skip_xp_backend (
153
+ Backend .ARRAY_API_STRICTEST ,
154
+ reason = "no boolean indexing -> run everywhere" ,
155
+ )
156
+ @pytest .mark .skip_xp_backend (
157
+ Backend .SPARSE ,
158
+ reason = "no indexing by sparse array -> run everywhere" ,
159
+ )
156
160
def test_dont_run_on_false (self , xp : ModuleType ):
157
161
x = xp .asarray ([1.0 , 2.0 , 0.0 ])
158
162
y = xp .asarray ([0.0 , 3.0 , 4.0 ])
@@ -192,6 +196,7 @@ def test_device(self, xp: ModuleType, device: Device):
192
196
y = apply_where (x % 2 == 0 , x , self .f1 , fill_value = x )
193
197
assert get_device (y ) == device
194
198
199
+ @pytest .mark .skip_xp_backend (Backend .SPARSE , reason = "no isdtype" )
195
200
@pytest .mark .filterwarnings ("ignore::RuntimeWarning" ) # overflows, etc.
196
201
@hypothesis .settings (
197
202
# The xp and library fixtures are not regenerated between hypothesis iterations
@@ -217,8 +222,8 @@ def test_hypothesis( # type: ignore[explicit-any,decorated-any]
217
222
library : Backend ,
218
223
):
219
224
if (
220
- library in (Backend .NUMPY , Backend . NUMPY_READONLY )
221
- and not NUMPY_GE2
225
+ library . like (Backend .NUMPY )
226
+ and NUMPY_VERSION < ( 2 , 0 )
222
227
and dtype is np .float32
223
228
):
224
229
pytest .xfail (reason = "NumPy 1.x dtype promotion for scalars" )
@@ -562,6 +567,9 @@ def test_xp(self, xp: ModuleType):
562
567
assert y .shape == (1 , 1 , 1 , 3 )
563
568
564
569
570
+ @pytest .mark .filterwarnings ( # array_api_strictest
571
+ "ignore:invalid value encountered:RuntimeWarning:array_api_strict"
572
+ )
565
573
@pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no isdtype" )
566
574
class TestIsClose :
567
575
@pytest .mark .parametrize ("swap" , [False , True ])
@@ -680,13 +688,15 @@ def test_bool_dtype(self, xp: ModuleType):
680
688
isclose (xp .asarray (True ), b , atol = 1 ), xp .asarray ([True , True , True ])
681
689
)
682
690
691
+ @pytest .mark .skip_xp_backend (Backend .ARRAY_API_STRICTEST , reason = "unknown shape" )
683
692
def test_none_shape (self , xp : ModuleType ):
684
693
a = xp .asarray ([1 , 5 , 0 ])
685
694
b = xp .asarray ([1 , 4 , 2 ])
686
695
b = b [a < 5 ]
687
696
a = a [a < 5 ]
688
697
xp_assert_equal (isclose (a , b ), xp .asarray ([True , False ]))
689
698
699
+ @pytest .mark .skip_xp_backend (Backend .ARRAY_API_STRICTEST , reason = "unknown shape" )
690
700
def test_none_shape_bool (self , xp : ModuleType ):
691
701
a = xp .asarray ([True , True , False ])
692
702
b = xp .asarray ([True , False , True ])
@@ -819,8 +829,27 @@ def test_empty(self, xp: ModuleType):
819
829
a = xp .asarray ([])
820
830
xp_assert_equal (nunique (a ), xp .asarray (0 ))
821
831
822
- def test_device (self , xp : ModuleType , device : Device ):
823
- a = xp .asarray (0.0 , device = device )
832
+ def test_size1 (self , xp : ModuleType ):
833
+ a = xp .asarray ([123 ])
834
+ xp_assert_equal (nunique (a ), xp .asarray (1 ))
835
+
836
+ def test_all_equal (self , xp : ModuleType ):
837
+ a = xp .asarray ([123 , 123 , 123 ])
838
+ xp_assert_equal (nunique (a ), xp .asarray (1 ))
839
+
840
+ @pytest .mark .xfail_xp_backend (Backend .DASK , reason = "No equal_nan kwarg in unique" )
841
+ @pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "sparse#855" )
842
+ def test_nan (self , xp : ModuleType , library : Backend ):
843
+ if library .like (Backend .NUMPY ) and NUMPY_VERSION < (1 , 24 ):
844
+ pytest .xfail ("NumPy <1.24 has no equal_nan kwarg in unique" )
845
+
846
+ # Each NaN is counted separately
847
+ a = xp .asarray ([xp .nan , 123.0 , xp .nan ])
848
+ xp_assert_equal (nunique (a ), xp .asarray (3 ))
849
+
850
+ @pytest .mark .parametrize ("size" , [0 , 1 , 2 ])
851
+ def test_device (self , xp : ModuleType , device : Device , size : int ):
852
+ a = xp .asarray ([0.0 ] * size , device = device )
824
853
assert get_device (nunique (a )) == device
825
854
826
855
def test_xp (self , xp : ModuleType ):
@@ -895,6 +924,7 @@ def test_sequence_of_tuples_width(self, xp: ModuleType):
895
924
896
925
897
926
@pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no argsort" )
927
+ @pytest .mark .skip_xp_backend (Backend .ARRAY_API_STRICTEST , reason = "no unique_values" )
898
928
class TestSetDiff1D :
899
929
@pytest .mark .xfail_xp_backend (Backend .DASK , reason = "NaN-shaped arrays" )
900
930
@pytest .mark .xfail_xp_backend (
0 commit comments