Skip to content

Commit b41da86

Browse files
authored
ENH/TST: tougher restrictions on array_api_strict (#179)
* ENH: tougher restrictions on array_api_strict * Clean up Backend tests * Code review
1 parent 8d7e3a9 commit b41da86

File tree

10 files changed

+184
-57
lines changed

10 files changed

+184
-57
lines changed

src/array_api_extra/_lib/_at.py

+1
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,7 @@ def _op(
344344
msg = f"Can't update read-only array {x}"
345345
raise ValueError(msg)
346346

347+
# Backends without boolean indexing (other than JAX) crash here
347348
if in_place_op: # add(), subtract(), ...
348349
x[idx] = in_place_op(x[idx], y)
349350
else: # set()

src/array_api_extra/_lib/_backends.py

+18-3
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
"""Backends with which array-api-extra interacts in delegation and testing."""
22

3+
from __future__ import annotations
4+
35
from collections.abc import Callable
46
from enum import Enum
57
from types import ModuleType
6-
from typing import cast
78

89
from ._utils import _compat
910

@@ -23,9 +24,14 @@ class Backend(Enum): # numpydoc ignore=PR01,PR02 # type: ignore[no-subclass-an
2324
corresponding to the backend.
2425
"""
2526

27+
# Use :<tag> to prevent Enum from deduplicating items with the same value
2628
ARRAY_API_STRICT = "array_api_strict", _compat.is_array_api_strict_namespace
29+
ARRAY_API_STRICTEST = (
30+
"array_api_strict:strictest",
31+
_compat.is_array_api_strict_namespace,
32+
)
2733
NUMPY = "numpy", _compat.is_numpy_namespace
28-
NUMPY_READONLY = "numpy_readonly", _compat.is_numpy_namespace
34+
NUMPY_READONLY = "numpy:readonly", _compat.is_numpy_namespace
2935
CUPY = "cupy", _compat.is_cupy_namespace
3036
TORCH = "torch", _compat.is_torch_namespace
3137
DASK = "dask.array", _compat.is_dask_namespace
@@ -48,4 +54,13 @@ def __init__(
4854

4955
def __str__(self) -> str: # type: ignore[explicit-override] # pyright: ignore[reportImplicitOverride] # numpydoc ignore=RT01
5056
"""Pretty-print parameterized test names."""
51-
return cast(str, self.value)
57+
return self.name.lower()
58+
59+
@property
60+
def modname(self) -> str: # numpydoc ignore=RT01
61+
"""Module name to be imported."""
62+
return self.value.split(":")[0]
63+
64+
def like(self, *others: Backend) -> bool: # numpydoc ignore=PR01,RT01
65+
"""Check if this backend uses the same module as others."""
66+
return any(self.modname == other.modname for other in others)

src/array_api_extra/_lib/_funcs.py

+36-15
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,14 @@
88

99
from ._at import at
1010
from ._utils import _compat, _helpers
11-
from ._utils._compat import (
12-
array_namespace,
13-
is_dask_namespace,
14-
is_jax_array,
15-
is_jax_namespace,
11+
from ._utils._compat import array_namespace, is_dask_namespace, is_jax_array
12+
from ._utils._helpers import (
13+
asarrays,
14+
capabilities,
15+
eager_shape,
16+
meta_namespace,
17+
ndindex,
1618
)
17-
from ._utils._helpers import asarrays, eager_shape, meta_namespace, ndindex
1819
from ._utils._typing import Array
1920

2021
__all__ = [
@@ -152,7 +153,7 @@ def _apply_where( # type: ignore[explicit-any] # numpydoc ignore=PR01,RT01
152153
) -> Array:
153154
"""Helper of `apply_where`. On Dask, this runs on a single chunk."""
154155

155-
if is_jax_namespace(xp):
156+
if not capabilities(xp)["boolean indexing"]:
156157
# jax.jit does not support assignment by boolean mask
157158
return xp.where(cond, f1(*args), f2(*args) if f2 is not None else fill_value)
158159

@@ -708,14 +709,34 @@ def nunique(x: Array, /, *, xp: ModuleType | None = None) -> Array:
708709
# size= is JAX-specific
709710
# https://github.com/data-apis/array-api/issues/883
710711
_, counts = xp.unique_counts(x, size=_compat.size(x))
711-
return xp.astype(counts, xp.bool).sum()
712-
713-
_, counts = xp.unique_counts(x)
714-
n = _compat.size(counts)
715-
# FIXME https://github.com/data-apis/array-api-compat/pull/231
716-
if n is None: # e.g. Dask, ndonnx
717-
return xp.astype(counts, xp.bool).sum()
718-
return xp.asarray(n, device=_compat.device(x))
712+
return (counts > 0).sum()
713+
714+
# There are 3 general use cases:
715+
# 1. backend has unique_counts and it returns an array with known shape
716+
# 2. backend has unique_counts and it returns a None-sized array;
717+
# e.g. Dask, ndonnx
718+
# 3. backend does not have unique_counts; e.g. wrapped JAX
719+
if capabilities(xp)["data-dependent shapes"]:
720+
# xp has unique_counts; O(n) complexity
721+
_, counts = xp.unique_counts(x)
722+
n = _compat.size(counts)
723+
if n is None:
724+
return xp.sum(xp.ones_like(counts))
725+
return xp.asarray(n, device=_compat.device(x))
726+
727+
# xp does not have unique_counts; O(n*logn) complexity
728+
x = xp.sort(xp.reshape(x, -1))
729+
mask = x != xp.roll(x, -1)
730+
default_int = xp.__array_namespace_info__().default_dtypes(
731+
device=_compat.device(x)
732+
)["integral"]
733+
return xp.maximum(
734+
# Special cases:
735+
# - array is size 0
736+
# - array has all elements equal to each other
737+
xp.astype(xp.any(~mask), default_int),
738+
xp.sum(xp.astype(mask, default_int)),
739+
)
719740

720741

721742
def pad(

src/array_api_extra/_lib/_utils/_helpers.py

+36
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
array_namespace,
1313
is_array_api_obj,
1414
is_dask_namespace,
15+
is_jax_namespace,
1516
is_numpy_array,
17+
is_pydata_sparse_namespace,
1618
)
1719
from ._typing import Array
1820

@@ -23,6 +25,7 @@
2325

2426
__all__ = [
2527
"asarrays",
28+
"capabilities",
2629
"eager_shape",
2730
"in1d",
2831
"is_python_scalar",
@@ -270,3 +273,36 @@ def meta_namespace(
270273
# Quietly skip scalars and None's
271274
metas = [cast(Array | None, getattr(a, "_meta", None)) for a in arrays]
272275
return array_namespace(*metas)
276+
277+
278+
def capabilities(xp: ModuleType) -> dict[str, int]:
279+
"""
280+
Return patched ``xp.__array_namespace_info__().capabilities()``.
281+
282+
TODO this helper should be eventually removed once all the special cases
283+
it handles are fixed in the respective backends.
284+
285+
Parameters
286+
----------
287+
xp : array_namespace
288+
The standard-compatible namespace.
289+
290+
Returns
291+
-------
292+
dict
293+
Capabilities of the namespace.
294+
"""
295+
if is_pydata_sparse_namespace(xp):
296+
# No __array_namespace_info__(); no indexing by sparse arrays
297+
return {"boolean indexing": False, "data-dependent shapes": True}
298+
out = xp.__array_namespace_info__().capabilities()
299+
if is_jax_namespace(xp):
300+
# FIXME https://github.com/jax-ml/jax/issues/27418
301+
out = out.copy()
302+
out["boolean indexing"] = False
303+
if is_dask_namespace(xp):
304+
# FIXME https://github.com/data-apis/array-api-compat/pull/290
305+
out = out.copy()
306+
out["boolean indexing"] = True
307+
out["data-dependent shapes"] = True
308+
return out

tests/conftest.py

+23-9
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Pytest fixtures."""
22

3-
from collections.abc import Callable
3+
from collections.abc import Callable, Generator
44
from contextlib import suppress
55
from functools import partial, wraps
66
from types import ModuleType
@@ -19,6 +19,7 @@
1919
T = TypeVar("T")
2020
P = ParamSpec("P")
2121

22+
NUMPY_VERSION = tuple(int(v) for v in np.__version__.split(".")[2])
2223
np_compat = array_namespace(np.empty(0)) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
2324

2425

@@ -43,7 +44,7 @@ def library(request: pytest.FixtureRequest) -> Backend: # numpydoc ignore=PR01,
4344
msg = f"argument of {marker_name} must be a Backend enum"
4445
raise TypeError(msg)
4546
if library == elem:
46-
reason = library.value
47+
reason = str(library)
4748
with suppress(KeyError):
4849
reason += ":" + cast(str, marker.kwargs["reason"])
4950
skip_or_xfail(reason=reason)
@@ -104,7 +105,7 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08
104105
@pytest.fixture
105106
def xp(
106107
library: Backend, request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch
107-
) -> ModuleType: # numpydoc ignore=PR01,RT03
108+
) -> Generator[ModuleType]: # numpydoc ignore=PR01,RT03
108109
"""
109110
Parameterized fixture that iterates on all libraries.
110111
@@ -113,25 +114,38 @@ def xp(
113114
The current array namespace.
114115
"""
115116
if library == Backend.NUMPY_READONLY:
116-
return NumPyReadOnly() # type: ignore[return-value] # pyright: ignore[reportReturnType]
117-
xp = pytest.importorskip(library.value)
117+
yield NumPyReadOnly() # type: ignore[misc] # pyright: ignore[reportReturnType]
118+
return
119+
120+
if library.like(Backend.ARRAY_API_STRICT) and NUMPY_VERSION < (1, 26):
121+
pytest.skip("array_api_strict is untested on NumPy <1.26")
122+
123+
xp = pytest.importorskip(library.modname)
118124
# Possibly wrap module with array_api_compat
119125
xp = array_namespace(xp.empty(0))
120126

127+
if library == Backend.ARRAY_API_STRICTEST:
128+
with xp.ArrayAPIStrictFlags(
129+
boolean_indexing=False,
130+
data_dependent_shapes=False,
131+
# writeable=False, # TODO implement in array-api-strict
132+
# lazy=True, # TODO implement in array-api-strict
133+
enabled_extensions=(),
134+
):
135+
yield xp
136+
return
137+
121138
# On Dask and JAX, monkey-patch all functions tagged by `lazy_xp_function`
122139
# in the global scope of the module containing the test function.
123140
patch_lazy_xp_functions(request, monkeypatch, xp=xp)
124141

125-
if library == Backend.ARRAY_API_STRICT and np.__version__ < "1.26":
126-
pytest.skip("array_api_strict is untested on NumPy <1.26")
127-
128142
if library == Backend.JAX:
129143
import jax
130144

131145
# suppress unused-ignore to run mypy in -e lint as well as -e dev
132146
jax.config.update("jax_enable_x64", True) # type: ignore[no-untyped-call,unused-ignore]
133147

134-
return xp
148+
yield xp
135149

136150

137151
@pytest.fixture(params=[Backend.DASK]) # Can select the test with `pytest -k dask`

tests/test_at.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
pytestmark = [
2121
pytest.mark.skip_xp_backend(
2222
Backend.SPARSE, reason="read-only backend without .at support"
23-
)
23+
),
24+
pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="boolean indexing"),
2425
]
2526

2627

@@ -256,7 +257,7 @@ def test_incompatible_dtype(
256257
elif library is Backend.DASK:
257258
z = at_op(x, idx, op, 1.1, copy=copy)
258259

259-
elif library is Backend.ARRAY_API_STRICT and op is not _AtOp.SET:
260+
elif library.like(Backend.ARRAY_API_STRICT) and op is not _AtOp.SET:
260261
with pytest.raises(Exception, match=r"cast|promote|dtype"):
261262
_ = at_op(x, idx, op, 1.1, copy=copy)
262263

tests/test_funcs.py

+40-10
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
from array_api_extra._lib._utils._typing import Array, Device
3333
from array_api_extra.testing import lazy_xp_function
3434

35+
from .conftest import NUMPY_VERSION
36+
3537
# some xp backends are untyped
3638
# mypy: disable-error-code=no-untyped-def
3739

@@ -48,12 +50,6 @@
4850
lazy_xp_function(sinc, static_argnames="xp")
4951

5052

51-
NUMPY_GE2 = int(np.__version__.split(".")[0]) >= 2
52-
53-
54-
@pytest.mark.skip_xp_backend(
55-
Backend.SPARSE, reason="read-only backend without .at support"
56-
)
5753
class TestApplyWhere:
5854
@staticmethod
5955
def f1(x: Array, y: Array | int = 10) -> Array:
@@ -153,6 +149,14 @@ def test_dont_overwrite_fill_value(self, xp: ModuleType):
153149
xp_assert_equal(actual, xp.asarray([100, 12]))
154150
xp_assert_equal(fill_value, xp.asarray([100, 200]))
155151

152+
@pytest.mark.skip_xp_backend(
153+
Backend.ARRAY_API_STRICTEST,
154+
reason="no boolean indexing -> run everywhere",
155+
)
156+
@pytest.mark.skip_xp_backend(
157+
Backend.SPARSE,
158+
reason="no indexing by sparse array -> run everywhere",
159+
)
156160
def test_dont_run_on_false(self, xp: ModuleType):
157161
x = xp.asarray([1.0, 2.0, 0.0])
158162
y = xp.asarray([0.0, 3.0, 4.0])
@@ -192,6 +196,7 @@ def test_device(self, xp: ModuleType, device: Device):
192196
y = apply_where(x % 2 == 0, x, self.f1, fill_value=x)
193197
assert get_device(y) == device
194198

199+
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no isdtype")
195200
@pytest.mark.filterwarnings("ignore::RuntimeWarning") # overflows, etc.
196201
@hypothesis.settings(
197202
# The xp and library fixtures are not regenerated between hypothesis iterations
@@ -217,8 +222,8 @@ def test_hypothesis( # type: ignore[explicit-any,decorated-any]
217222
library: Backend,
218223
):
219224
if (
220-
library in (Backend.NUMPY, Backend.NUMPY_READONLY)
221-
and not NUMPY_GE2
225+
library.like(Backend.NUMPY)
226+
and NUMPY_VERSION < (2, 0)
222227
and dtype is np.float32
223228
):
224229
pytest.xfail(reason="NumPy 1.x dtype promotion for scalars")
@@ -562,6 +567,9 @@ def test_xp(self, xp: ModuleType):
562567
assert y.shape == (1, 1, 1, 3)
563568

564569

570+
@pytest.mark.filterwarnings( # array_api_strictest
571+
"ignore:invalid value encountered:RuntimeWarning:array_api_strict"
572+
)
565573
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype")
566574
class TestIsClose:
567575
@pytest.mark.parametrize("swap", [False, True])
@@ -680,13 +688,15 @@ def test_bool_dtype(self, xp: ModuleType):
680688
isclose(xp.asarray(True), b, atol=1), xp.asarray([True, True, True])
681689
)
682690

691+
@pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="unknown shape")
683692
def test_none_shape(self, xp: ModuleType):
684693
a = xp.asarray([1, 5, 0])
685694
b = xp.asarray([1, 4, 2])
686695
b = b[a < 5]
687696
a = a[a < 5]
688697
xp_assert_equal(isclose(a, b), xp.asarray([True, False]))
689698

699+
@pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="unknown shape")
690700
def test_none_shape_bool(self, xp: ModuleType):
691701
a = xp.asarray([True, True, False])
692702
b = xp.asarray([True, False, True])
@@ -819,8 +829,27 @@ def test_empty(self, xp: ModuleType):
819829
a = xp.asarray([])
820830
xp_assert_equal(nunique(a), xp.asarray(0))
821831

822-
def test_device(self, xp: ModuleType, device: Device):
823-
a = xp.asarray(0.0, device=device)
832+
def test_size1(self, xp: ModuleType):
833+
a = xp.asarray([123])
834+
xp_assert_equal(nunique(a), xp.asarray(1))
835+
836+
def test_all_equal(self, xp: ModuleType):
837+
a = xp.asarray([123, 123, 123])
838+
xp_assert_equal(nunique(a), xp.asarray(1))
839+
840+
@pytest.mark.xfail_xp_backend(Backend.DASK, reason="No equal_nan kwarg in unique")
841+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="sparse#855")
842+
def test_nan(self, xp: ModuleType, library: Backend):
843+
if library.like(Backend.NUMPY) and NUMPY_VERSION < (1, 24):
844+
pytest.xfail("NumPy <1.24 has no equal_nan kwarg in unique")
845+
846+
# Each NaN is counted separately
847+
a = xp.asarray([xp.nan, 123.0, xp.nan])
848+
xp_assert_equal(nunique(a), xp.asarray(3))
849+
850+
@pytest.mark.parametrize("size", [0, 1, 2])
851+
def test_device(self, xp: ModuleType, device: Device, size: int):
852+
a = xp.asarray([0.0] * size, device=device)
824853
assert get_device(nunique(a)) == device
825854

826855
def test_xp(self, xp: ModuleType):
@@ -895,6 +924,7 @@ def test_sequence_of_tuples_width(self, xp: ModuleType):
895924

896925

897926
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no argsort")
927+
@pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="no unique_values")
898928
class TestSetDiff1D:
899929
@pytest.mark.xfail_xp_backend(Backend.DASK, reason="NaN-shaped arrays")
900930
@pytest.mark.xfail_xp_backend(

0 commit comments

Comments
 (0)