Skip to content

Commit dd1de30

Browse files
committed
Clean up Backend tests
1 parent a8487a7 commit dd1de30

File tree

5 files changed

+32
-28
lines changed

5 files changed

+32
-28
lines changed

src/array_api_extra/_lib/_backends.py

+18-4
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
"""Backends with which array-api-extra interacts in delegation and testing."""
22

3+
from __future__ import annotations
4+
35
from collections.abc import Callable
46
from enum import Enum
57
from types import ModuleType
6-
from typing import cast
78

89
from ._utils import _compat
910

@@ -23,10 +24,14 @@ class Backend(Enum): # numpydoc ignore=PR01,PR02 # type: ignore[no-subclass-an
2324
corresponding to the backend.
2425
"""
2526

27+
# Use :<tag> to prevent Enum from deduplicating items with the same value
2628
ARRAY_API_STRICT = "array_api_strict", _compat.is_array_api_strict_namespace
27-
ARRAY_API_STRICTEST = "array_api_strictest", _compat.is_array_api_strict_namespace
29+
ARRAY_API_STRICTEST = (
30+
"array_api_strict:strictest",
31+
_compat.is_array_api_strict_namespace,
32+
)
2833
NUMPY = "numpy", _compat.is_numpy_namespace
29-
NUMPY_READONLY = "numpy_readonly", _compat.is_numpy_namespace
34+
NUMPY_READONLY = "numpy:readonly", _compat.is_numpy_namespace
3035
CUPY = "cupy", _compat.is_cupy_namespace
3136
TORCH = "torch", _compat.is_torch_namespace
3237
DASK = "dask.array", _compat.is_dask_namespace
@@ -49,4 +54,13 @@ def __init__(
4954

5055
def __str__(self) -> str: # type: ignore[explicit-override] # pyright: ignore[reportImplicitOverride] # numpydoc ignore=RT01
5156
"""Pretty-print parameterized test names."""
52-
return cast(str, self.value)
57+
return self.name.lower()
58+
59+
@property
60+
def modname(self) -> str: # numpydoc ignore=RT01
61+
"""Module name to be imported."""
62+
return self.value.split(":")[0]
63+
64+
def like(self, *others: Backend) -> bool: # numpydoc ignore=PR01,RT01
65+
"""Check if this backend uses the same module as others."""
66+
return any(self.modname == other.modname for other in others)

tests/conftest.py

+7-12
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
T = TypeVar("T")
2020
P = ParamSpec("P")
2121

22+
NUMPY_VERSION = tuple(int(v) for v in np.__version__.split(".")[2])
2223
np_compat = array_namespace(np.empty(0)) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
2324

2425

@@ -43,7 +44,7 @@ def library(request: pytest.FixtureRequest) -> Backend: # numpydoc ignore=PR01,
4344
msg = f"argument of {marker_name} must be a Backend enum"
4445
raise TypeError(msg)
4546
if library == elem:
46-
reason = library.value
47+
reason = str(library)
4748
with suppress(KeyError):
4849
reason += ":" + cast(str, marker.kwargs["reason"])
4950
skip_or_xfail(reason=reason)
@@ -116,14 +117,14 @@ def xp(
116117
yield NumPyReadOnly() # type: ignore[misc] # pyright: ignore[reportReturnType]
117118
return
118119

119-
if (
120-
library in (Backend.ARRAY_API_STRICT, Backend.ARRAY_API_STRICTEST)
121-
and np.__version__ < "1.26"
122-
):
120+
if library.like(Backend.ARRAY_API_STRICT) and NUMPY_VERSION < (1, 26):
123121
pytest.skip("array_api_strict is untested on NumPy <1.26")
124122

123+
xp = pytest.importorskip(library.modname)
124+
# Possibly wrap module with array_api_compat
125+
xp = array_namespace(xp.empty(0))
126+
125127
if library == Backend.ARRAY_API_STRICTEST:
126-
xp = pytest.importorskip("array_api_strict")
127128
with xp.ArrayAPIStrictFlags(
128129
boolean_indexing=False,
129130
data_dependent_shapes=False,
@@ -134,10 +135,6 @@ def xp(
134135
yield xp
135136
return
136137

137-
xp = pytest.importorskip(library.value)
138-
# Possibly wrap module with array_api_compat
139-
xp = array_namespace(xp.empty(0))
140-
141138
# On Dask and JAX, monkey-patch all functions tagged by `lazy_xp_function`
142139
# in the global scope of the module containing the test function.
143140
patch_lazy_xp_functions(request, monkeypatch, xp=xp)
@@ -147,8 +144,6 @@ def xp(
147144

148145
# suppress unused-ignore to run mypy in -e lint as well as -e dev
149146
jax.config.update("jax_enable_x64", True) # type: ignore[no-untyped-call,unused-ignore]
150-
yield xp
151-
return
152147

153148
yield xp
154149

tests/test_at.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -257,10 +257,7 @@ def test_incompatible_dtype(
257257
elif library is Backend.DASK:
258258
z = at_op(x, idx, op, 1.1, copy=copy)
259259

260-
elif (
261-
library in (Backend.ARRAY_API_STRICT, Backend.ARRAY_API_STRICTEST)
262-
and op is not _AtOp.SET
263-
):
260+
elif library.like(Backend.ARRAY_API_STRICT) and op is not _AtOp.SET:
264261
with pytest.raises(Exception, match=r"cast|promote|dtype"):
265262
_ = at_op(x, idx, op, 1.1, copy=copy)
266263

tests/test_funcs.py

+5-7
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
from array_api_extra._lib._utils._typing import Array, Device
3333
from array_api_extra.testing import lazy_xp_function
3434

35+
from .conftest import NUMPY_VERSION
36+
3537
# some xp backends are untyped
3638
# mypy: disable-error-code=no-untyped-def
3739

@@ -48,9 +50,6 @@
4850
lazy_xp_function(sinc, static_argnames="xp")
4951

5052

51-
NUMPY_VERSION = tuple(int(v) for v in np.__version__.split(".")[2])
52-
53-
5453
class TestApplyWhere:
5554
@staticmethod
5655
def f1(x: Array, y: Array | int = 10) -> Array:
@@ -197,7 +196,7 @@ def test_device(self, xp: ModuleType, device: Device):
197196
y = apply_where(x % 2 == 0, x, self.f1, fill_value=x)
198197
assert get_device(y) == device
199198

200-
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype")
199+
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no isdtype")
201200
@pytest.mark.filterwarnings("ignore::RuntimeWarning") # overflows, etc.
202201
@hypothesis.settings(
203202
# The xp and library fixtures are not regenerated between hypothesis iterations
@@ -223,7 +222,7 @@ def test_hypothesis( # type: ignore[explicit-any,decorated-any]
223222
library: Backend,
224223
):
225224
if (
226-
library in (Backend.NUMPY, Backend.NUMPY_READONLY)
225+
library.like(Backend.NUMPY)
227226
and NUMPY_VERSION < (2, 0)
228227
and dtype is np.float32
229228
):
@@ -843,8 +842,7 @@ def test_all_equal(self, xp: ModuleType):
843842
Backend.SPARSE, reason="Non-compliant equal_nan=True behaviour"
844843
)
845844
def test_nan(self, xp: ModuleType, library: Backend):
846-
is_numpy = library in (Backend.NUMPY, Backend.NUMPY_READONLY)
847-
if is_numpy and NUMPY_VERSION < (1, 24):
845+
if library.like(Backend.NUMPY) and NUMPY_VERSION < (1, 24):
848846
pytest.xfail("NumPy <1.24 has no equal_nan kwarg in unique")
849847

850848
# Each NaN is counted separately

tests/test_testing.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def test_lazy_xp_function_cython_ufuncs(xp: ModuleType, library: Backend):
223223
pytest.importorskip("scipy")
224224
assert erf is not None
225225
x = xp.asarray([6.0, 7.0])
226-
if library in (Backend.ARRAY_API_STRICT, Backend.ARRAY_API_STRICTEST, Backend.JAX):
226+
if library.like(Backend.ARRAY_API_STRICT, Backend.JAX):
227227
# array-api-strict arrays are auto-converted to NumPy
228228
# which results in an assertion error for mismatched namespaces
229229
# eager JAX arrays are auto-converted to NumPy in eager JAX

0 commit comments

Comments
 (0)