Skip to content

Commit 37bbb58

Browse files
authored
Merge pull request #196 from honno/test-fft
`test_fft.py`
2 parents b5ed713 + dc2d4b9 commit 37bbb58

7 files changed

+344
-31
lines changed

array_api_tests/_array_module.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def __repr__(self):
6363
_constants = ["e", "inf", "nan", "pi"]
6464
_funcs = [f.__name__ for funcs in stubs.category_to_funcs.values() for f in funcs]
6565
_funcs += ["take", "isdtype", "conj", "imag", "real"] # TODO: bump spec and update array-api-tests to new spec layout
66-
_top_level_attrs = _dtypes + _constants + _funcs + stubs.EXTENSIONS
66+
_top_level_attrs = _dtypes + _constants + _funcs + stubs.EXTENSIONS + ["fft"]
6767

6868
for attr in _top_level_attrs:
6969
try:

array_api_tests/hypothesis_helpers.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
integers, just, lists, none, one_of,
1212
sampled_from, shared)
1313

14-
from . import _array_module as xp
14+
from . import _array_module as xp, api_version
1515
from . import dtype_helpers as dh
1616
from . import shape_helpers as sh
1717
from . import xps
@@ -141,6 +141,13 @@ def oneway_broadcastable_shapes(draw) -> SearchStrategy[OnewayBroadcastableShape
141141
return OnewayBroadcastableShapes(input_shape, result_shape)
142142

143143

144+
def all_floating_dtypes() -> SearchStrategy[DataType]:
145+
strat = xps.floating_dtypes()
146+
if api_version >= "2022.12":
147+
strat |= xps.complex_dtypes()
148+
return strat
149+
150+
144151
# shared() allows us to draw either the function or the function name and they
145152
# will both correspond to the same function.
146153

array_api_tests/pytest_helpers.py

+13
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ def assert_dtype(
122122
>>> assert_dtype('sum', in_dtype=x, out_dtype=out.dtype, expected=default_int)
123123
124124
"""
125+
__tracebackhide__ = True
125126
in_dtypes = in_dtype if isinstance(in_dtype, Sequence) and not isinstance(in_dtype, str) else [in_dtype]
126127
f_in_dtypes = dh.fmt_types(tuple(in_dtypes))
127128
f_out_dtype = dh.dtype_to_name[out_dtype]
@@ -149,6 +150,7 @@ def assert_kw_dtype(
149150
>>> assert_kw_dtype('ones', kw_dtype=kw['dtype'], out_dtype=out.dtype)
150151
151152
"""
153+
__tracebackhide__ = True
152154
f_kw_dtype = dh.dtype_to_name[kw_dtype]
153155
f_out_dtype = dh.dtype_to_name[out_dtype]
154156
msg = (
@@ -166,6 +168,7 @@ def assert_default_float(func_name: str, out_dtype: DataType):
166168
>>> assert_default_float('ones', out.dtype)
167169
168170
"""
171+
__tracebackhide__ = True
169172
f_dtype = dh.dtype_to_name[out_dtype]
170173
f_default = dh.dtype_to_name[dh.default_float]
171174
msg = (
@@ -183,6 +186,7 @@ def assert_default_complex(func_name: str, out_dtype: DataType):
183186
>>> assert_default_complex('asarray', out.dtype)
184187
185188
"""
189+
__tracebackhide__ = True
186190
f_dtype = dh.dtype_to_name[out_dtype]
187191
f_default = dh.dtype_to_name[dh.default_complex]
188192
msg = (
@@ -200,6 +204,7 @@ def assert_default_int(func_name: str, out_dtype: DataType):
200204
>>> assert_default_int('full', out.dtype)
201205
202206
"""
207+
__tracebackhide__ = True
203208
f_dtype = dh.dtype_to_name[out_dtype]
204209
f_default = dh.dtype_to_name[dh.default_int]
205210
msg = (
@@ -217,6 +222,7 @@ def assert_default_index(func_name: str, out_dtype: DataType, repr_name="out.dty
217222
>>> assert_default_int('argmax', out.dtype)
218223
219224
"""
225+
__tracebackhide__ = True
220226
f_dtype = dh.dtype_to_name[out_dtype]
221227
msg = (
222228
f"{repr_name}={f_dtype}, should be the default index dtype, "
@@ -240,6 +246,7 @@ def assert_shape(
240246
>>> assert_shape('ones', out_shape=out.shape, expected=(3, 3, 3))
241247
242248
"""
249+
__tracebackhide__ = True
243250
if isinstance(out_shape, int):
244251
out_shape = (out_shape,)
245252
if isinstance(expected, int):
@@ -273,6 +280,7 @@ def assert_result_shape(
273280
>>> assert out.shape == (3, 3)
274281
275282
"""
283+
__tracebackhide__ = True
276284
if expected is None:
277285
expected = sh.broadcast_shapes(*in_shapes)
278286
f_in_shapes = " . ".join(str(s) for s in in_shapes)
@@ -307,6 +315,7 @@ def assert_keepdimable_shape(
307315
>>> assert out2.shape == (1, 1)
308316
309317
"""
318+
__tracebackhide__ = True
310319
if keepdims:
311320
shape = tuple(1 if axis in axes else side for axis, side in enumerate(in_shape))
312321
else:
@@ -337,6 +346,7 @@ def assert_0d_equals(
337346
>>> assert res[0] == x[0]
338347
339348
"""
349+
__tracebackhide__ = True
340350
msg = (
341351
f"{out_repr}={out_val}, but should be {x_repr}={x_val} "
342352
f"[{func_name}({fmt_kw(kw)})]"
@@ -369,6 +379,7 @@ def assert_scalar_equals(
369379
>>> assert int(out) == 5
370380
371381
"""
382+
__tracebackhide__ = True
372383
repr_name = repr_name if idx == () else f"{repr_name}[{idx}]"
373384
f_func = f"{func_name}({fmt_kw(kw)})"
374385
if type_ in [bool, int]:
@@ -401,6 +412,7 @@ def assert_fill(
401412
>>> assert xp.all(out == 42)
402413
403414
"""
415+
__tracebackhide__ = True
404416
msg = f"out not filled with {fill_value} [{func_name}({fmt_kw(kw)})]\n{out=}"
405417
if cmath.isnan(fill_value):
406418
assert xp.all(xp.isnan(out)), msg
@@ -443,6 +455,7 @@ def assert_array_elements(
443455
>>> assert xp.all(out == x)
444456
445457
"""
458+
__tracebackhide__ = True
446459
dh.result_type(out.dtype, expected.dtype) # sanity check
447460
assert_shape(func_name, out_shape=out.shape, expected=expected.shape, kw=kw) # sanity check
448461
f_func = f"[{func_name}({fmt_kw(kw)})]"

array_api_tests/shape_helpers.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import math
22
from itertools import product
3-
from typing import Iterator, List, Optional, Tuple, Union
3+
from typing import Iterator, List, Optional, Sequence, Tuple, Union
44

55
from ndindex import iter_indices as _iter_indices
66

@@ -66,10 +66,12 @@ def broadcast_shapes(*shapes: Shape):
6666

6767

6868
def normalise_axis(
69-
axis: Optional[Union[int, Tuple[int, ...]]], ndim: int
69+
axis: Optional[Union[int, Sequence[int]]], ndim: int
7070
) -> Tuple[int, ...]:
7171
if axis is None:
7272
return tuple(range(ndim))
73+
elif isinstance(axis, Sequence) and not isinstance(axis, tuple):
74+
axis = tuple(axis)
7375
axes = axis if isinstance(axis, tuple) else (axis,)
7476
axes = tuple(axis if axis >= 0 else ndim + axis for axis in axes)
7577
return axes

array_api_tests/stubs.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
all_funcs.extend(funcs)
5353
name_to_func: Dict[str, FunctionType] = {f.__name__: f for f in all_funcs}
5454

55-
EXTENSIONS: str = ["linalg"]
55+
EXTENSIONS: List[str] = ["linalg"] # TODO: add "fft" once stubs available
5656
extension_to_funcs: Dict[str, List[FunctionType]] = {}
5757
for ext in EXTENSIONS:
5858
mod = name_to_mod[ext]

0 commit comments

Comments
 (0)