Skip to content

Commit f82c7bc

Browse files
authored
Merge pull request #189 from asmeurer/2022-fix
Some fixes for v2022.12
2 parents fb49802 + 9064d5d commit f82c7bc

6 files changed

+30
-13
lines changed

Diff for: array_api_tests/dtype_helpers.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -157,19 +157,18 @@ def is_int_dtype(dtype):
157157
return dtype in all_int_dtypes
158158

159159

160-
def is_float_dtype(dtype):
160+
def is_float_dtype(dtype, *, include_complex=True):
161161
# None equals NumPy's xp.float64 object, so we specifically check it here.
162162
# xp.float64 is in fact an alias of np.dtype('float64'), and its equality
163163
# with None is meant to be deprecated at some point.
164164
# See https://github.com/numpy/numpy/issues/18434
165165
if dtype is None:
166166
return False
167167
valid_dtypes = real_float_dtypes
168-
if api_version > "2021.12":
168+
if api_version > "2021.12" and include_complex:
169169
valid_dtypes += complex_dtypes
170170
return dtype in valid_dtypes
171171

172-
173172
def get_scalar_type(dtype: DataType) -> ScalarType:
174173
if dtype in all_int_dtypes:
175174
return int

Diff for: array_api_tests/pytest_helpers.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -464,8 +464,8 @@ def assert_array_elements(
464464
f"{sh.fmt_idx(out_repr, idx)}={at_out}, should be {at_expected} "
465465
f"{f_func}"
466466
)
467-
_assert_float_element(at_out.real, at_expected.real, msg)
468-
_assert_float_element(at_out.imag, at_expected.imag, msg)
467+
_assert_float_element(xp.real(at_out), xp.real(at_expected), msg)
468+
_assert_float_element(xp.imag(at_out), xp.imag(at_expected), msg)
469469
else:
470470
assert xp.all(
471471
out == expected

Diff for: array_api_tests/test_data_type_functions.py

+2
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ def test_finfo(dtype_name):
162162
assert isinstance(
163163
value, stype
164164
), f"type(out.{attr})={type(value)!r}, but should be {stype.__name__} {f_func}"
165+
assert hasattr(out, "dtype"), f"out has no attribute 'dtype' {f_func}"
165166
# TODO: test values
166167

167168

@@ -179,6 +180,7 @@ def test_iinfo(dtype_name):
179180
assert isinstance(
180181
value, int
181182
), f"type(out.{attr})={type(value)!r}, but should be int {f_func}"
183+
assert hasattr(out, "dtype"), f"out has no attribute 'dtype' {f_func}"
182184
# TODO: test values
183185

184186

Diff for: array_api_tests/test_indexing_functions.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def test_take(x, data):
4949
f_axis_idx = sh.fmt_idx("x", axis_idx)
5050
for i in _indices:
5151
f_take_idx = sh.fmt_idx(f_axis_idx, i)
52-
indexed_x = x[axis_idx][i]
52+
indexed_x = x[axis_idx][i, ...]
5353
for at_idx in sh.ndindex(indexed_x.shape):
5454
out_idx = next(out_indices)
5555
ph.assert_0d_equals(

Diff for: array_api_tests/test_set_functions.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def test_unique_all(x):
112112

113113
if dh.is_float_dtype(out.values.dtype):
114114
assume(math.prod(x.shape) <= 128) # may not be representable
115-
expected = sum(v for k, v in counts.items() if math.isnan(k))
115+
expected = sum(v for k, v in counts.items() if cmath.isnan(k))
116116
assert nans == expected, f"{nans} NaNs in out, but should be {expected}"
117117

118118

@@ -137,7 +137,7 @@ def test_unique_counts(x):
137137
for idx in sh.ndindex(out.values.shape):
138138
val = scalar_type(out.values[idx])
139139
count = int(out.counts[idx])
140-
if math.isnan(val):
140+
if cmath.isnan(val):
141141
nans += 1
142142
assert count == 1, (
143143
f"out.counts[{idx}]={count} for out.values[{idx}]={val}, "
@@ -159,7 +159,7 @@ def test_unique_counts(x):
159159
vals_idx[val] = idx
160160
if dh.is_float_dtype(out.values.dtype):
161161
assume(math.prod(x.shape) <= 128) # may not be representable
162-
expected = sum(v for k, v in counts.items() if math.isnan(k))
162+
expected = sum(v for k, v in counts.items() if cmath.isnan(k))
163163
assert nans == expected, f"{nans} NaNs in out, but should be {expected}"
164164

165165

@@ -188,7 +188,7 @@ def test_unique_inverse(x):
188188
nans = 0
189189
for idx in sh.ndindex(out.values.shape):
190190
val = scalar_type(out.values[idx])
191-
if math.isnan(val):
191+
if cmath.isnan(val):
192192
nans += 1
193193
else:
194194
assert (

Diff for: array_api_tests/test_statistical_functions.py

+19-3
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from . import hypothesis_helpers as hh
1212
from . import pytest_helpers as ph
1313
from . import shape_helpers as sh
14-
from . import xps
14+
from . import xps, api_version
1515
from ._array_module import _UndefinedStub
1616
from .typing import DataType
1717

@@ -145,11 +145,19 @@ def test_prod(x, data):
145145
_dtype = x.dtype
146146
else:
147147
_dtype = default_dtype
148-
else:
148+
elif dh.is_float_dtype(x.dtype, include_complex=False):
149149
if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_float]:
150150
_dtype = x.dtype
151151
else:
152152
_dtype = dh.default_float
153+
elif api_version > "2021.12":
154+
# Complex dtype
155+
if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_complex]:
156+
_dtype = x.dtype
157+
else:
158+
_dtype = dh.default_complex
159+
else:
160+
raise RuntimeError("Unexpected dtype. This indicates a bug in the test suite.")
153161
else:
154162
_dtype = dtype
155163
if _dtype is None:
@@ -253,11 +261,19 @@ def test_sum(x, data):
253261
_dtype = x.dtype
254262
else:
255263
_dtype = default_dtype
256-
else:
264+
elif dh.is_float_dtype(x.dtype, include_complex=False):
257265
if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_float]:
258266
_dtype = x.dtype
259267
else:
260268
_dtype = dh.default_float
269+
elif api_version > "2021.12":
270+
# Complex dtype
271+
if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_complex]:
272+
_dtype = x.dtype
273+
else:
274+
_dtype = dh.default_complex
275+
else:
276+
raise RuntimeError("Unexpected dtype. This indicates a bug in the test suite.")
261277
else:
262278
_dtype = dtype
263279
if _dtype is None:

0 commit comments

Comments
 (0)