diff --git a/array_api_tests/_array_module.py b/array_api_tests/_array_module.py index 899a2591..8a7c7887 100644 --- a/array_api_tests/_array_module.py +++ b/array_api_tests/_array_module.py @@ -63,7 +63,7 @@ def __repr__(self): _constants = ["e", "inf", "nan", "pi"] _funcs = [f.__name__ for funcs in stubs.category_to_funcs.values() for f in funcs] _funcs += ["take", "isdtype", "conj", "imag", "real"] # TODO: bump spec and update array-api-tests to new spec layout -_top_level_attrs = _dtypes + _constants + _funcs + stubs.EXTENSIONS +_top_level_attrs = _dtypes + _constants + _funcs + stubs.EXTENSIONS + ["fft"] for attr in _top_level_attrs: try: diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index 31f1e153..c4235ba1 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -11,7 +11,7 @@ integers, just, lists, none, one_of, sampled_from, shared) -from . import _array_module as xp +from . import _array_module as xp, api_version from . import dtype_helpers as dh from . import shape_helpers as sh from . import xps @@ -141,6 +141,13 @@ def oneway_broadcastable_shapes(draw) -> SearchStrategy[OnewayBroadcastableShape return OnewayBroadcastableShapes(input_shape, result_shape) +def all_floating_dtypes() -> SearchStrategy[DataType]: + strat = xps.floating_dtypes() + if api_version >= "2022.12": + strat |= xps.complex_dtypes() + return strat + + # shared() allows us to draw either the function or the function name and they # will both correspond to the same function. diff --git a/array_api_tests/pytest_helpers.py b/array_api_tests/pytest_helpers.py index c51b14a6..e6ede7b2 100644 --- a/array_api_tests/pytest_helpers.py +++ b/array_api_tests/pytest_helpers.py @@ -122,6 +122,7 @@ def assert_dtype( >>> assert_dtype('sum', in_dtype=x, out_dtype=out.dtype, expected=default_int) """ + __tracebackhide__ = True in_dtypes = in_dtype if isinstance(in_dtype, Sequence) and not isinstance(in_dtype, str) else [in_dtype] f_in_dtypes = dh.fmt_types(tuple(in_dtypes)) f_out_dtype = dh.dtype_to_name[out_dtype] @@ -149,6 +150,7 @@ def assert_kw_dtype( >>> assert_kw_dtype('ones', kw_dtype=kw['dtype'], out_dtype=out.dtype) """ + __tracebackhide__ = True f_kw_dtype = dh.dtype_to_name[kw_dtype] f_out_dtype = dh.dtype_to_name[out_dtype] msg = ( @@ -166,6 +168,7 @@ def assert_default_float(func_name: str, out_dtype: DataType): >>> assert_default_float('ones', out.dtype) """ + __tracebackhide__ = True f_dtype = dh.dtype_to_name[out_dtype] f_default = dh.dtype_to_name[dh.default_float] msg = ( @@ -183,6 +186,7 @@ def assert_default_complex(func_name: str, out_dtype: DataType): >>> assert_default_complex('asarray', out.dtype) """ + __tracebackhide__ = True f_dtype = dh.dtype_to_name[out_dtype] f_default = dh.dtype_to_name[dh.default_complex] msg = ( @@ -200,6 +204,7 @@ def assert_default_int(func_name: str, out_dtype: DataType): >>> assert_default_int('full', out.dtype) """ + __tracebackhide__ = True f_dtype = dh.dtype_to_name[out_dtype] f_default = dh.dtype_to_name[dh.default_int] msg = ( @@ -217,6 +222,7 @@ def assert_default_index(func_name: str, out_dtype: DataType, repr_name="out.dty >>> assert_default_int('argmax', out.dtype) """ + __tracebackhide__ = True f_dtype = dh.dtype_to_name[out_dtype] msg = ( f"{repr_name}={f_dtype}, should be the default index dtype, " @@ -240,6 +246,7 @@ def assert_shape( >>> assert_shape('ones', out_shape=out.shape, expected=(3, 3, 3)) """ + __tracebackhide__ = True if isinstance(out_shape, int): out_shape = (out_shape,) if isinstance(expected, int): @@ -273,6 +280,7 @@ def assert_result_shape( >>> assert out.shape == (3, 3) """ + __tracebackhide__ = True if expected is None: expected = sh.broadcast_shapes(*in_shapes) f_in_shapes = " . ".join(str(s) for s in in_shapes) @@ -307,6 +315,7 @@ def assert_keepdimable_shape( >>> assert out2.shape == (1, 1) """ + __tracebackhide__ = True if keepdims: shape = tuple(1 if axis in axes else side for axis, side in enumerate(in_shape)) else: @@ -337,6 +346,7 @@ def assert_0d_equals( >>> assert res[0] == x[0] """ + __tracebackhide__ = True msg = ( f"{out_repr}={out_val}, but should be {x_repr}={x_val} " f"[{func_name}({fmt_kw(kw)})]" @@ -369,6 +379,7 @@ def assert_scalar_equals( >>> assert int(out) == 5 """ + __tracebackhide__ = True repr_name = repr_name if idx == () else f"{repr_name}[{idx}]" f_func = f"{func_name}({fmt_kw(kw)})" if type_ in [bool, int]: @@ -401,6 +412,7 @@ def assert_fill( >>> assert xp.all(out == 42) """ + __tracebackhide__ = True msg = f"out not filled with {fill_value} [{func_name}({fmt_kw(kw)})]\n{out=}" if cmath.isnan(fill_value): assert xp.all(xp.isnan(out)), msg @@ -443,6 +455,7 @@ def assert_array_elements( >>> assert xp.all(out == x) """ + __tracebackhide__ = True dh.result_type(out.dtype, expected.dtype) # sanity check assert_shape(func_name, out_shape=out.shape, expected=expected.shape, kw=kw) # sanity check f_func = f"[{func_name}({fmt_kw(kw)})]" diff --git a/array_api_tests/shape_helpers.py b/array_api_tests/shape_helpers.py index ba7d994e..6a0bdfde 100644 --- a/array_api_tests/shape_helpers.py +++ b/array_api_tests/shape_helpers.py @@ -1,6 +1,6 @@ import math from itertools import product -from typing import Iterator, List, Optional, Tuple, Union +from typing import Iterator, List, Optional, Sequence, Tuple, Union from ndindex import iter_indices as _iter_indices @@ -66,10 +66,12 @@ def broadcast_shapes(*shapes: Shape): def normalise_axis( - axis: Optional[Union[int, Tuple[int, ...]]], ndim: int + axis: Optional[Union[int, Sequence[int]]], ndim: int ) -> Tuple[int, ...]: if axis is None: return tuple(range(ndim)) + elif isinstance(axis, Sequence) and not isinstance(axis, tuple): + axis = tuple(axis) axes = axis if isinstance(axis, tuple) else (axis,) axes = tuple(axis if axis >= 0 else ndim + axis for axis in axes) return axes diff --git a/array_api_tests/stubs.py b/array_api_tests/stubs.py index 69ec886d..0134765b 100644 --- a/array_api_tests/stubs.py +++ b/array_api_tests/stubs.py @@ -52,7 +52,7 @@ all_funcs.extend(funcs) name_to_func: Dict[str, FunctionType] = {f.__name__: f for f in all_funcs} -EXTENSIONS: str = ["linalg"] +EXTENSIONS: List[str] = ["linalg"] # TODO: add "fft" once stubs available extension_to_funcs: Dict[str, List[FunctionType]] = {} for ext in EXTENSIONS: mod = name_to_mod[ext] diff --git a/array_api_tests/test_fft.py b/array_api_tests/test_fft.py new file mode 100644 index 00000000..7dc70d56 --- /dev/null +++ b/array_api_tests/test_fft.py @@ -0,0 +1,298 @@ +import math +from typing import List, Optional +from unittest.mock import MagicMock + +import pytest +from hypothesis import assume, given +from hypothesis import strategies as st + +from array_api_tests.typing import Array, DataType + +from . import api_version +from . import dtype_helpers as dh +from . import hypothesis_helpers as hh +from . import pytest_helpers as ph +from . import shape_helpers as sh +from . import xps +from ._array_module import mod as xp + +pytestmark = [ + pytest.mark.ci, + pytest.mark.xp_extension("fft"), + pytest.mark.min_version("2022.12"), +] + + +# Using xps.complex_dtypes() raises an AttributeError for 2021.12 instances of +# xps, hence this hack. TODO: figure out a better way to manage this! +if api_version < "2022.12": + xps = MagicMock(xps) + +fft_shapes_strat = hh.shapes(min_dims=1).filter(lambda s: math.prod(s) > 1) + + +def draw_n_axis_norm_kwargs(x: Array, data: st.DataObject, *, size_gt_1=False) -> tuple: + size = math.prod(x.shape) + n = data.draw( + st.none() | st.integers((size // 2), math.ceil(size * 1.5)), label="n" + ) + axis = data.draw(st.integers(-1, x.ndim - 1), label="axis") + if size_gt_1: + _axis = x.ndim - 1 if axis == -1 else axis + assume(x.shape[_axis] > 1) + norm = data.draw(st.sampled_from(["backward", "ortho", "forward"]), label="norm") + kwargs = data.draw( + hh.specified_kwargs( + ("n", n, None), + ("axis", axis, -1), + ("norm", norm, "backward"), + ), + label="kwargs", + ) + return n, axis, norm, kwargs + + +def draw_s_axes_norm_kwargs(x: Array, data: st.DataObject, *, size_gt_1=False) -> tuple: + all_axes = list(range(x.ndim)) + axes = data.draw( + st.none() | st.lists(st.sampled_from(all_axes), min_size=1, unique=True), + label="axes", + ) + _axes = all_axes if axes is None else axes + axes_sides = [x.shape[axis] for axis in _axes] + s_strat = st.tuples( + *[st.integers(max(side // 2, 1), math.ceil(side * 1.5)) for side in axes_sides] + ) + if axes is None: + s_strat = st.none() | s_strat + s = data.draw(s_strat, label="s") + if size_gt_1: + _s = x.shape if s is None else s + for i in range(x.ndim): + if i in _axes: + side = _s[_axes.index(i)] + else: + side = x.shape[i] + assume(side > 1) + norm = data.draw(st.sampled_from(["backward", "ortho", "forward"]), label="norm") + kwargs = data.draw( + hh.specified_kwargs( + ("s", s, None), + ("axes", axes, None), + ("norm", norm, "backward"), + ), + label="kwargs", + ) + return s, axes, norm, kwargs + + +def assert_fft_dtype(func_name: str, *, in_dtype: DataType, out_dtype: DataType): + if in_dtype == xp.float32: + expected = xp.complex64 + elif in_dtype == xp.float64: + expected = xp.complex128 + else: + assert dh.is_float_dtype(in_dtype) # sanity check + expected = in_dtype + ph.assert_dtype( + func_name, in_dtype=in_dtype, out_dtype=out_dtype, expected=expected + ) + + +def assert_n_axis_shape( + func_name: str, + *, + x: Array, + n: Optional[int], + axis: int, + out: Array, + size_gt_1: bool = False, +): + _axis = len(x.shape) - 1 if axis == -1 else axis + if n is None: + if size_gt_1: + axis_side = 2 * (x.shape[_axis] - 1) + else: + axis_side = x.shape[_axis] + else: + axis_side = n + expected = x.shape[:_axis] + (axis_side,) + x.shape[_axis + 1 :] + ph.assert_shape(func_name, out_shape=out.shape, expected=expected) + + +def assert_s_axes_shape( + func_name: str, + *, + x: Array, + s: Optional[List[int]], + axes: Optional[List[int]], + out: Array, + size_gt_1: bool = False, +): + _axes = sh.normalise_axis(axes, x.ndim) + _s = x.shape if s is None else s + expected = [] + for i in range(x.ndim): + if i in _axes: + side = _s[_axes.index(i)] + else: + side = x.shape[i] + expected.append(side) + if size_gt_1: + last_axis = _axes[-1] + expected[last_axis] = 2 * (expected[last_axis] - 1) + assume(expected[last_axis] > 0) # TODO: generate valid examples + ph.assert_shape(func_name, out_shape=out.shape, expected=tuple(expected)) + + +@given( + x=xps.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat), + data=st.data(), +) +def test_fft(x, data): + n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data) + + out = xp.fft.fft(x, **kwargs) + + assert_fft_dtype("fft", in_dtype=x.dtype, out_dtype=out.dtype) + assert_n_axis_shape("fft", x=x, n=n, axis=axis, out=out) + + +@given( + x=xps.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat), + data=st.data(), +) +def test_ifft(x, data): + n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data) + + out = xp.fft.ifft(x, **kwargs) + + assert_fft_dtype("ifft", in_dtype=x.dtype, out_dtype=out.dtype) + assert_n_axis_shape("ifft", x=x, n=n, axis=axis, out=out) + + +@given( + x=xps.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat), + data=st.data(), +) +def test_fftn(x, data): + s, axes, norm, kwargs = draw_s_axes_norm_kwargs(x, data) + + out = xp.fft.fftn(x, **kwargs) + + assert_fft_dtype("fftn", in_dtype=x.dtype, out_dtype=out.dtype) + assert_s_axes_shape("fftn", x=x, s=s, axes=axes, out=out) + + +@given( + x=xps.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat), + data=st.data(), +) +def test_ifftn(x, data): + s, axes, norm, kwargs = draw_s_axes_norm_kwargs(x, data) + + out = xp.fft.ifftn(x, **kwargs) + + assert_fft_dtype("ifftn", in_dtype=x.dtype, out_dtype=out.dtype) + assert_s_axes_shape("ifftn", x=x, s=s, axes=axes, out=out) + + +@given( + x=xps.arrays(dtype=xps.complex_dtypes(), shape=fft_shapes_strat), + data=st.data(), +) +def test_rfft(x, data): + n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data) + + out = xp.fft.rfft(x, **kwargs) + + assert_fft_dtype("rfft", in_dtype=x.dtype, out_dtype=out.dtype) + assert_n_axis_shape("rfft", x=x, n=n, axis=axis, out=out) + + +@given( + x=xps.arrays(dtype=xps.complex_dtypes(), shape=fft_shapes_strat), + data=st.data(), +) +def test_irfft(x, data): + n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data, size_gt_1=True) + + out = xp.fft.irfft(x, **kwargs) + + assert_fft_dtype("irfft", in_dtype=x.dtype, out_dtype=out.dtype) + + _axis = x.ndim - 1 if axis == -1 else axis + if n is None: + axis_side = 2 * (x.shape[_axis] - 1) + else: + axis_side = n + expected_shape = x.shape[:_axis] + (axis_side,) + x.shape[_axis + 1 :] + ph.assert_shape("irfft", out_shape=out.shape, expected=expected_shape) + + +@given( + x=xps.arrays(dtype=xps.complex_dtypes(), shape=fft_shapes_strat), + data=st.data(), +) +def test_rfftn(x, data): + s, axes, norm, kwargs = draw_s_axes_norm_kwargs(x, data) + + out = xp.fft.rfftn(x, **kwargs) + + assert_fft_dtype("rfftn", in_dtype=x.dtype, out_dtype=out.dtype) + assert_s_axes_shape("rfftn", x=x, s=s, axes=axes, out=out) + + +@given( + x=xps.arrays( + dtype=xps.complex_dtypes(), shape=fft_shapes_strat.filter(lambda s: s[-1] > 1) + ), + data=st.data(), +) +def test_irfftn(x, data): + s, axes, norm, kwargs = draw_s_axes_norm_kwargs(x, data, size_gt_1=True) + + out = xp.fft.irfftn(x, **kwargs) + + assert_fft_dtype("irfftn", in_dtype=x.dtype, out_dtype=out.dtype) + assert_s_axes_shape("rfftn", x=x, s=s, axes=axes, out=out, size_gt_1=True) + + +@given( + x=xps.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat), + data=st.data(), +) +def test_hfft(x, data): + n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data, size_gt_1=True) + + out = xp.fft.hfft(x, **kwargs) + + assert_fft_dtype("hfft", in_dtype=x.dtype, out_dtype=out.dtype) + + _axis = x.ndim - 1 if axis == -1 else axis + if n is None: + axis_side = 2 * (x.shape[_axis] - 1) + else: + axis_side = n + expected_shape = x.shape[:_axis] + (axis_side,) + x.shape[_axis + 1 :] + ph.assert_shape("hfft", out_shape=out.shape, expected=expected_shape) + + +@given( + x=xps.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat), + data=st.data(), +) +def test_ihfft(x, data): + n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data) + + out = xp.fft.ihfft(x, **kwargs) + + assert_fft_dtype("ihfft", in_dtype=x.dtype, out_dtype=out.dtype) + assert_n_axis_shape("ihfft", x=x, n=n, axis=axis, out=out, size_gt_1=True) + + +# TODO: +# fftfreq +# rfftfreq +# fftshift +# ifftshift diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 4d803bb0..39905456 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -33,13 +33,6 @@ def boolean_and_all_integer_dtypes() -> st.SearchStrategy[DataType]: return xps.boolean_dtypes() | all_integer_dtypes() -def all_floating_dtypes() -> st.SearchStrategy[DataType]: - strat = xps.floating_dtypes() - if api_version >= "2022.12": - strat |= xps.complex_dtypes() - return strat - - def mock_int_dtype(n: int, dtype: DataType) -> int: """Returns equivalent of `n` that mocks `dtype` behaviour.""" nbits = dh.dtype_nbits[dtype] @@ -714,7 +707,7 @@ def test_abs(ctx, data): ) -@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_acos(x): out = xp.acos(x) ph.assert_dtype("acos", in_dtype=x.dtype, out_dtype=out.dtype) @@ -724,7 +717,7 @@ def test_acos(x): ) -@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_acosh(x): out = xp.acosh(x) ph.assert_dtype("acosh", in_dtype=x.dtype, out_dtype=out.dtype) @@ -748,7 +741,7 @@ def test_add(ctx, data): binary_param_assert_against_refimpl(ctx, left, right, res, "+", operator.add) -@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_asin(x): out = xp.asin(x) ph.assert_dtype("asin", in_dtype=x.dtype, out_dtype=out.dtype) @@ -758,7 +751,7 @@ def test_asin(x): ) -@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_asinh(x): out = xp.asinh(x) ph.assert_dtype("asinh", in_dtype=x.dtype, out_dtype=out.dtype) @@ -766,7 +759,7 @@ def test_asinh(x): unary_assert_against_refimpl("asinh", x, out, math.asinh) -@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_atan(x): out = xp.atan(x) ph.assert_dtype("atan", in_dtype=x.dtype, out_dtype=out.dtype) @@ -782,7 +775,7 @@ def test_atan2(x1, x2): binary_assert_against_refimpl("atan2", x1, x2, out, math.atan2) -@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_atanh(x): out = xp.atanh(x) ph.assert_dtype("atanh", in_dtype=x.dtype, out_dtype=out.dtype) @@ -932,7 +925,7 @@ def test_conj(x): unary_assert_against_refimpl("conj", x, out, operator.methodcaller("conjugate")) -@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_cos(x): out = xp.cos(x) ph.assert_dtype("cos", in_dtype=x.dtype, out_dtype=out.dtype) @@ -940,7 +933,7 @@ def test_cos(x): unary_assert_against_refimpl("cos", x, out, math.cos) -@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_cosh(x): out = xp.cosh(x) ph.assert_dtype("cosh", in_dtype=x.dtype, out_dtype=out.dtype) @@ -1001,7 +994,7 @@ def test_equal(ctx, data): ) -@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_exp(x): out = xp.exp(x) ph.assert_dtype("exp", in_dtype=x.dtype, out_dtype=out.dtype) @@ -1009,7 +1002,7 @@ def test_exp(x): unary_assert_against_refimpl("exp", x, out, math.exp) -@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_expm1(x): out = xp.expm1(x) ph.assert_dtype("expm1", in_dtype=x.dtype, out_dtype=out.dtype) @@ -1158,7 +1151,7 @@ def test_less_equal(ctx, data): ) -@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_log(x): out = xp.log(x) ph.assert_dtype("log", in_dtype=x.dtype, out_dtype=out.dtype) @@ -1168,7 +1161,7 @@ def test_log(x): ) -@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_log1p(x): out = xp.log1p(x) ph.assert_dtype("log1p", in_dtype=x.dtype, out_dtype=out.dtype) @@ -1178,7 +1171,7 @@ def test_log1p(x): ) -@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_log2(x): out = xp.log2(x) ph.assert_dtype("log2", in_dtype=x.dtype, out_dtype=out.dtype) @@ -1188,7 +1181,7 @@ def test_log2(x): ) -@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_log10(x): out = xp.log10(x) ph.assert_dtype("log10", in_dtype=x.dtype, out_dtype=out.dtype) @@ -1379,7 +1372,7 @@ def test_sign(x): ) -@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_sin(x): out = xp.sin(x) ph.assert_dtype("sin", in_dtype=x.dtype, out_dtype=out.dtype) @@ -1387,7 +1380,7 @@ def test_sin(x): unary_assert_against_refimpl("sin", x, out, math.sin) -@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_sinh(x): out = xp.sinh(x) ph.assert_dtype("sinh", in_dtype=x.dtype, out_dtype=out.dtype) @@ -1405,7 +1398,7 @@ def test_square(x): ) -@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_sqrt(x): out = xp.sqrt(x) ph.assert_dtype("sqrt", in_dtype=x.dtype, out_dtype=out.dtype) @@ -1429,7 +1422,7 @@ def test_subtract(ctx, data): binary_param_assert_against_refimpl(ctx, left, right, res, "-", operator.sub) -@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_tan(x): out = xp.tan(x) ph.assert_dtype("tan", in_dtype=x.dtype, out_dtype=out.dtype) @@ -1437,7 +1430,7 @@ def test_tan(x): unary_assert_against_refimpl("tan", x, out, math.tan) -@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_tanh(x): out = xp.tanh(x) ph.assert_dtype("tanh", in_dtype=x.dtype, out_dtype=out.dtype)