Skip to content

Commit e8a9164

Browse files
authored
Merge pull request #4 from asmeurer/update-ci
Fix array-api-tests job
2 parents 52175c4 + cc8a438 commit e8a9164

File tree

5 files changed

+51
-17
lines changed

5 files changed

+51
-17
lines changed

.github/workflows/array-api-tests.yml

+7-4
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
name: Array API Tests
22

3+
on: [push, pull_request]
4+
35
env:
4-
PYTEST_ARGS: "-v -rxXfE --ci"
6+
PYTEST_ARGS: "-v -rxXfE --ci --hypothesis-disable-deadline"
57

68
jobs:
7-
tests:
9+
array-api-tests:
810
runs-on: ubuntu-latest
911
strategy:
1012
matrix:
@@ -37,13 +39,14 @@ jobs:
3739
else
3840
python -m pip install 'numpy>=1.26,<2.0';
3941
fi
42+
python -m pip install ${GITHUB_WORKSPACE}/array-api-strict
43+
python -m pip install -r ${GITHUB_WORKSPACE}/array-api-tests/requirements.txt
4044
- name: Run the array API testsuite
4145
env:
4246
ARRAY_API_TESTS_MODULE: array_api_strict
4347
# This enables the NEP 50 type promotion behavior (without it a lot of
4448
# tests fail in numpy 1.26 on bad scalar type promotion behavior)
4549
NPY_PROMOTION_STATE: weak
4650
run: |
47-
export PYTHONPATH="${GITHUB_WORKSPACE}/array-api-compat"
4851
cd ${GITHUB_WORKSPACE}/array-api-tests
49-
pytest array_api_tests/ --xfails-file ${PYTEST_ARGS}
52+
pytest array_api_tests/ --skips-file ${GITHUB_WORKSPACE}/array-api-strict/array-api-tests-xfails.txt ${PYTEST_ARGS}

array-api-tests-xfails.txt

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# copy=False is not yet implemented
2+
# https://github.com/numpy/numpy/pull/25168
3+
array_api_tests/test_creation_functions.py::test_asarray_arrays
4+
5+
# Some fft tests are currently wrong
6+
# (https://github.com/data-apis/array-api-tests/issues/231)
7+
array_api_tests/test_fft.py::test_fft
8+
array_api_tests/test_fft.py::test_ifft
9+
array_api_tests/test_fft.py::test_fftn
10+
array_api_tests/test_fft.py::test_ifftn
11+
array_api_tests/test_fft.py::test_rfft
12+
array_api_tests/test_fft.py::test_irfft
13+
array_api_tests/test_fft.py::test_rfftn
14+
array_api_tests/test_fft.py::test_irfftn
15+
array_api_tests/test_fft.py::test_hfft
16+
array_api_tests/test_fft.py::test_ihfft
17+
18+
# Known special case issue in NumPy. Not worth working around here
19+
# https://github.com/numpy/numpy/issues/21213
20+
array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity]
21+
array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0]
22+
23+
# The test suite is incorrectly checking sums that have loss of significance
24+
# (https://github.com/data-apis/array-api-tests/issues/168)
25+
array_api_tests/test_statistical_functions.py::test_sum

array_api_strict/_array_object.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
particular, type promotion rules are different (the standard has no
66
value-based casting). The standard also specifies a more limited subset of
77
array methods and functionalities than are implemented on ndarray. Since the
8-
goal of the array_api namespace is to be a minimal implementation of the array
9-
API standard, we need to define a separate wrapper class for the array_api
8+
goal of the array_api_strict namespace is to be a minimal implementation of the array
9+
API standard, we need to define a separate wrapper class for the array_api_strict
1010
namespace.
1111
1212
The standard compliant class is only a wrapper class. It is *not* a subclass
@@ -73,7 +73,7 @@ def _new(cls, x, /):
7373
This is a private method for initializing the array API Array
7474
object.
7575
76-
Functions outside of the array_api submodule should not use this
76+
Functions outside of the array_api_strict module should not use this
7777
method. Use one of the creation functions instead, such as
7878
``asarray``.
7979
@@ -86,7 +86,7 @@ def _new(cls, x, /):
8686
_dtype = _DType(x.dtype)
8787
if _dtype not in _all_dtypes:
8888
raise TypeError(
89-
f"The array_api namespace does not support the dtype '{x.dtype}'"
89+
f"The array_api_strict namespace does not support the dtype '{x.dtype}'"
9090
)
9191
obj._array = x
9292
obj._dtype = _dtype
@@ -95,7 +95,7 @@ def _new(cls, x, /):
9595
# Prevent Array() from working
9696
def __new__(cls, *args, **kwargs):
9797
raise TypeError(
98-
"The array_api Array object should not be instantiated directly. Use an array creation function, such as asarray(), instead."
98+
"The array_api_strict Array object should not be instantiated directly. Use an array creation function, such as asarray(), instead."
9999
)
100100

101101
# These functions are not required by the spec, but are implemented for
@@ -121,7 +121,7 @@ def __repr__(self: Array, /) -> str:
121121
return prefix + mid + suffix
122122

123123
# This function is not required by the spec, but we implement it here for
124-
# convenience so that np.asarray(np.array_api.Array) will work.
124+
# convenience so that np.asarray(array_api_strict.Array) will work.
125125
def __array__(self, dtype: None | np.dtype[Any] = None) -> npt.NDArray[Any]:
126126
"""
127127
Warning: this method is NOT part of the array API spec. Implementers
@@ -338,7 +338,7 @@ def _validate_index(self, key):
338338
if i is not None:
339339
nonexpanding_key.append(i)
340340
if isinstance(i, np.ndarray):
341-
raise IndexError("Index arrays for np.array_api must be np.array_api arrays")
341+
raise IndexError("Index arrays for array_api_strict must be array_api_strict arrays")
342342
if isinstance(i, Array):
343343
if i.dtype in _boolean_dtypes:
344344
key_has_mask = True
@@ -471,7 +471,7 @@ def __array_namespace__(
471471
if api_version is not None and not api_version.startswith("2021."):
472472
raise ValueError(f"Unrecognized array API version: {api_version!r}")
473473
import array_api_strict
474-
return array_api
474+
return array_api_strict
475475

476476
def __bool__(self: Array, /) -> bool:
477477
"""
@@ -571,7 +571,7 @@ def __getitem__(
571571
# docstring of _validate_index
572572
self._validate_index(key)
573573
if isinstance(key, Array):
574-
# Indexing self._array with array_api arrays can be erroneous
574+
# Indexing self._array with array_api_strict arrays can be erroneous
575575
key = key._array
576576
res = self._array.__getitem__(key)
577577
return self._new(res)
@@ -761,7 +761,7 @@ def __setitem__(
761761
# docstring of _validate_index
762762
self._validate_index(key)
763763
if isinstance(key, Array):
764-
# Indexing self._array with array_api arrays can be erroneous
764+
# Indexing self._array with array_api_strict arrays can be erroneous
765765
key = key._array
766766
self._array.__setitem__(key, asarray(value)._array)
767767

array_api_strict/_dtypes.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def __init__(self, np_dtype):
1111
self._np_dtype = np_dtype
1212

1313
def __repr__(self):
14-
return f"np.array_api.{self._np_dtype.name}"
14+
return f"array_api_strict.{self._np_dtype.name}"
1515

1616
def __eq__(self, other):
1717
# See https://github.com/numpy/numpy/pull/25370/files#r1423259515.

array_api_strict/linalg.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -305,10 +305,16 @@ def slogdet(x: Array, /) -> SlogdetResult:
305305
# To workaround this, the below is the code from np.linalg.solve except
306306
# only calling solve1 in the exactly 1D case.
307307
def _solve(a, b):
308-
from numpy.linalg._linalg import (
308+
try:
309+
from numpy.linalg._linalg import (
309310
_makearray, _assert_stacked_2d, _assert_stacked_square,
310311
_commonType, isComplexType, _raise_linalgerror_singular
311-
)
312+
)
313+
except ImportError:
314+
from numpy.linalg.linalg import (
315+
_makearray, _assert_stacked_2d, _assert_stacked_square,
316+
_commonType, isComplexType, _raise_linalgerror_singular
317+
)
312318
from numpy.linalg import _umath_linalg
313319

314320
a, _ = _makearray(a)

0 commit comments

Comments
 (0)