Skip to content

Commit

Permalink
TST: adapt tests for the lack of __array__
Browse files Browse the repository at this point in the history
  • Loading branch information
ev-br committed Dec 26, 2024
1 parent a3c04ab commit ddc14d8
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 32 deletions.
17 changes: 2 additions & 15 deletions array_api_strict/_creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,6 @@

import numpy as np

@contextmanager
def allow_array():
"""
Temporarily enable Array.__array__. This is needed for np.array to parse
list of lists of Array objects.
"""
from . import _array_object
original_value = _array_object._allow_array
try:
_array_object._allow_array = True
yield
finally:
_array_object._allow_array = original_value

def _check_valid_dtype(dtype):
# Note: Only spelling dtypes as the dtype objects is supported.
Expand Down Expand Up @@ -112,8 +99,8 @@ def asarray(
# Give a better error message in this case. NumPy would convert this
# to an object array. TODO: This won't handle large integers in lists.
raise OverflowError("Integer out of bounds for array dtypes")
with allow_array():
res = np.array(obj, dtype=_np_dtype, copy=copy)

res = np.array(obj, dtype=_np_dtype, copy=copy)
return Array._new(res, device=device)


Expand Down
18 changes: 1 addition & 17 deletions array_api_strict/tests/test_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,25 +361,9 @@ def test_array_conversion():

for device in ("device1", "device2"):
a = ones((2, 3), device=array_api_strict.Device(device))
with pytest.raises(RuntimeError, match="Can not convert array"):
with pytest.raises((RuntimeError, TypeError)):
asarray([a])

def test__array__():
# __array__ should work for now
a = ones((2, 3))
np.array(a)

# Test the _allow_array private global flag for disabling it in the
# future.
from .. import _array_object
original_value = _array_object._allow_array
try:
_array_object._allow_array = False
a = ones((2, 3))
with pytest.raises(ValueError, match="Conversion from an array_api_strict array to a NumPy ndarray is not supported"):
np.array(a)
finally:
_array_object._allow_array = original_value

def test_allow_newaxis():
a = ones(5)
Expand Down

0 comments on commit ddc14d8

Please sign in to comment.