Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test_fft.py #196

Merged
merged 13 commits into from
Nov 10, 2023
2 changes: 1 addition & 1 deletion array_api_tests/_array_module.py
Original file line number Diff line number Diff line change
@@ -63,7 +63,7 @@ def __repr__(self):
_constants = ["e", "inf", "nan", "pi"]
_funcs = [f.__name__ for funcs in stubs.category_to_funcs.values() for f in funcs]
_funcs += ["take", "isdtype", "conj", "imag", "real"] # TODO: bump spec and update array-api-tests to new spec layout
_top_level_attrs = _dtypes + _constants + _funcs + stubs.EXTENSIONS
_top_level_attrs = _dtypes + _constants + _funcs + stubs.EXTENSIONS + ["fft"]

for attr in _top_level_attrs:
try:
9 changes: 8 additions & 1 deletion array_api_tests/hypothesis_helpers.py
Original file line number Diff line number Diff line change
@@ -11,7 +11,7 @@
integers, just, lists, none, one_of,
sampled_from, shared)

from . import _array_module as xp
from . import _array_module as xp, api_version
from . import dtype_helpers as dh
from . import shape_helpers as sh
from . import xps
@@ -141,6 +141,13 @@ def oneway_broadcastable_shapes(draw) -> SearchStrategy[OnewayBroadcastableShape
return OnewayBroadcastableShapes(input_shape, result_shape)


def all_floating_dtypes() -> SearchStrategy[DataType]:
strat = xps.floating_dtypes()
if api_version >= "2022.12":
strat |= xps.complex_dtypes()
return strat


# shared() allows us to draw either the function or the function name and they
# will both correspond to the same function.

13 changes: 13 additions & 0 deletions array_api_tests/pytest_helpers.py
Original file line number Diff line number Diff line change
@@ -122,6 +122,7 @@ def assert_dtype(
>>> assert_dtype('sum', in_dtype=x, out_dtype=out.dtype, expected=default_int)

"""
__tracebackhide__ = True
in_dtypes = in_dtype if isinstance(in_dtype, Sequence) and not isinstance(in_dtype, str) else [in_dtype]
f_in_dtypes = dh.fmt_types(tuple(in_dtypes))
f_out_dtype = dh.dtype_to_name[out_dtype]
@@ -149,6 +150,7 @@ def assert_kw_dtype(
>>> assert_kw_dtype('ones', kw_dtype=kw['dtype'], out_dtype=out.dtype)

"""
__tracebackhide__ = True
f_kw_dtype = dh.dtype_to_name[kw_dtype]
f_out_dtype = dh.dtype_to_name[out_dtype]
msg = (
@@ -166,6 +168,7 @@ def assert_default_float(func_name: str, out_dtype: DataType):
>>> assert_default_float('ones', out.dtype)

"""
__tracebackhide__ = True
f_dtype = dh.dtype_to_name[out_dtype]
f_default = dh.dtype_to_name[dh.default_float]
msg = (
@@ -183,6 +186,7 @@ def assert_default_complex(func_name: str, out_dtype: DataType):
>>> assert_default_complex('asarray', out.dtype)

"""
__tracebackhide__ = True
f_dtype = dh.dtype_to_name[out_dtype]
f_default = dh.dtype_to_name[dh.default_complex]
msg = (
@@ -200,6 +204,7 @@ def assert_default_int(func_name: str, out_dtype: DataType):
>>> assert_default_int('full', out.dtype)

"""
__tracebackhide__ = True
f_dtype = dh.dtype_to_name[out_dtype]
f_default = dh.dtype_to_name[dh.default_int]
msg = (
@@ -217,6 +222,7 @@ def assert_default_index(func_name: str, out_dtype: DataType, repr_name="out.dty
>>> assert_default_int('argmax', out.dtype)

"""
__tracebackhide__ = True
f_dtype = dh.dtype_to_name[out_dtype]
msg = (
f"{repr_name}={f_dtype}, should be the default index dtype, "
@@ -240,6 +246,7 @@ def assert_shape(
>>> assert_shape('ones', out_shape=out.shape, expected=(3, 3, 3))

"""
__tracebackhide__ = True
if isinstance(out_shape, int):
out_shape = (out_shape,)
if isinstance(expected, int):
@@ -273,6 +280,7 @@ def assert_result_shape(
>>> assert out.shape == (3, 3)

"""
__tracebackhide__ = True
if expected is None:
expected = sh.broadcast_shapes(*in_shapes)
f_in_shapes = " . ".join(str(s) for s in in_shapes)
@@ -307,6 +315,7 @@ def assert_keepdimable_shape(
>>> assert out2.shape == (1, 1)

"""
__tracebackhide__ = True
if keepdims:
shape = tuple(1 if axis in axes else side for axis, side in enumerate(in_shape))
else:
@@ -337,6 +346,7 @@ def assert_0d_equals(
>>> assert res[0] == x[0]

"""
__tracebackhide__ = True
msg = (
f"{out_repr}={out_val}, but should be {x_repr}={x_val} "
f"[{func_name}({fmt_kw(kw)})]"
@@ -369,6 +379,7 @@ def assert_scalar_equals(
>>> assert int(out) == 5

"""
__tracebackhide__ = True
repr_name = repr_name if idx == () else f"{repr_name}[{idx}]"
f_func = f"{func_name}({fmt_kw(kw)})"
if type_ in [bool, int]:
@@ -401,6 +412,7 @@ def assert_fill(
>>> assert xp.all(out == 42)

"""
__tracebackhide__ = True
msg = f"out not filled with {fill_value} [{func_name}({fmt_kw(kw)})]\n{out=}"
if cmath.isnan(fill_value):
assert xp.all(xp.isnan(out)), msg
@@ -443,6 +455,7 @@ def assert_array_elements(
>>> assert xp.all(out == x)

"""
__tracebackhide__ = True
dh.result_type(out.dtype, expected.dtype) # sanity check
assert_shape(func_name, out_shape=out.shape, expected=expected.shape, kw=kw) # sanity check
f_func = f"[{func_name}({fmt_kw(kw)})]"
6 changes: 4 additions & 2 deletions array_api_tests/shape_helpers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import math
from itertools import product
from typing import Iterator, List, Optional, Tuple, Union
from typing import Iterator, List, Optional, Sequence, Tuple, Union

from ndindex import iter_indices as _iter_indices

@@ -66,10 +66,12 @@ def broadcast_shapes(*shapes: Shape):


def normalise_axis(
axis: Optional[Union[int, Tuple[int, ...]]], ndim: int
axis: Optional[Union[int, Sequence[int]]], ndim: int
) -> Tuple[int, ...]:
if axis is None:
return tuple(range(ndim))
elif isinstance(axis, Sequence) and not isinstance(axis, tuple):
axis = tuple(axis)
axes = axis if isinstance(axis, tuple) else (axis,)
axes = tuple(axis if axis >= 0 else ndim + axis for axis in axes)
return axes
2 changes: 1 addition & 1 deletion array_api_tests/stubs.py
Original file line number Diff line number Diff line change
@@ -52,7 +52,7 @@
all_funcs.extend(funcs)
name_to_func: Dict[str, FunctionType] = {f.__name__: f for f in all_funcs}

EXTENSIONS: str = ["linalg"]
EXTENSIONS: List[str] = ["linalg"] # TODO: add "fft" once stubs available
extension_to_funcs: Dict[str, List[FunctionType]] = {}
for ext in EXTENSIONS:
mod = name_to_mod[ext]
298 changes: 298 additions & 0 deletions array_api_tests/test_fft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,298 @@
import math
from typing import List, Optional
from unittest.mock import MagicMock

import pytest
from hypothesis import assume, given
from hypothesis import strategies as st

from array_api_tests.typing import Array, DataType

from . import api_version
from . import dtype_helpers as dh
from . import hypothesis_helpers as hh
from . import pytest_helpers as ph
from . import shape_helpers as sh
from . import xps
from ._array_module import mod as xp

pytestmark = [
pytest.mark.ci,
pytest.mark.xp_extension("fft"),
pytest.mark.min_version("2022.12"),
]


# Using xps.complex_dtypes() raises an AttributeError for 2021.12 instances of
# xps, hence this hack. TODO: figure out a better way to manage this!
if api_version < "2022.12":
xps = MagicMock(xps)

fft_shapes_strat = hh.shapes(min_dims=1).filter(lambda s: math.prod(s) > 1)


def draw_n_axis_norm_kwargs(x: Array, data: st.DataObject, *, size_gt_1=False) -> tuple:
size = math.prod(x.shape)
n = data.draw(
st.none() | st.integers((size // 2), math.ceil(size * 1.5)), label="n"
)
axis = data.draw(st.integers(-1, x.ndim - 1), label="axis")
if size_gt_1:
_axis = x.ndim - 1 if axis == -1 else axis
assume(x.shape[_axis] > 1)
norm = data.draw(st.sampled_from(["backward", "ortho", "forward"]), label="norm")
kwargs = data.draw(
hh.specified_kwargs(
("n", n, None),
("axis", axis, -1),
("norm", norm, "backward"),
),
label="kwargs",
)
return n, axis, norm, kwargs


def draw_s_axes_norm_kwargs(x: Array, data: st.DataObject, *, size_gt_1=False) -> tuple:
all_axes = list(range(x.ndim))
axes = data.draw(
st.none() | st.lists(st.sampled_from(all_axes), min_size=1, unique=True),
label="axes",
)
_axes = all_axes if axes is None else axes
axes_sides = [x.shape[axis] for axis in _axes]
s_strat = st.tuples(
*[st.integers(max(side // 2, 1), math.ceil(side * 1.5)) for side in axes_sides]
)
if axes is None:
s_strat = st.none() | s_strat
s = data.draw(s_strat, label="s")
if size_gt_1:
_s = x.shape if s is None else s
for i in range(x.ndim):
if i in _axes:
side = _s[_axes.index(i)]
else:
side = x.shape[i]
assume(side > 1)
norm = data.draw(st.sampled_from(["backward", "ortho", "forward"]), label="norm")
kwargs = data.draw(
hh.specified_kwargs(
("s", s, None),
("axes", axes, None),
("norm", norm, "backward"),
),
label="kwargs",
)
return s, axes, norm, kwargs


def assert_fft_dtype(func_name: str, *, in_dtype: DataType, out_dtype: DataType):
if in_dtype == xp.float32:
expected = xp.complex64
elif in_dtype == xp.float64:
expected = xp.complex128
else:
assert dh.is_float_dtype(in_dtype) # sanity check
expected = in_dtype
ph.assert_dtype(
func_name, in_dtype=in_dtype, out_dtype=out_dtype, expected=expected
)


def assert_n_axis_shape(
func_name: str,
*,
x: Array,
n: Optional[int],
axis: int,
out: Array,
size_gt_1: bool = False,
):
_axis = len(x.shape) - 1 if axis == -1 else axis
if n is None:
if size_gt_1:
axis_side = 2 * (x.shape[_axis] - 1)
else:
axis_side = x.shape[_axis]
else:
axis_side = n
expected = x.shape[:_axis] + (axis_side,) + x.shape[_axis + 1 :]
ph.assert_shape(func_name, out_shape=out.shape, expected=expected)


def assert_s_axes_shape(
func_name: str,
*,
x: Array,
s: Optional[List[int]],
axes: Optional[List[int]],
out: Array,
size_gt_1: bool = False,
):
_axes = sh.normalise_axis(axes, x.ndim)
_s = x.shape if s is None else s
expected = []
for i in range(x.ndim):
if i in _axes:
side = _s[_axes.index(i)]
else:
side = x.shape[i]
expected.append(side)
if size_gt_1:
last_axis = _axes[-1]
expected[last_axis] = 2 * (expected[last_axis] - 1)
assume(expected[last_axis] > 0) # TODO: generate valid examples
ph.assert_shape(func_name, out_shape=out.shape, expected=tuple(expected))


@given(
x=xps.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat),
data=st.data(),
)
def test_fft(x, data):
n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data)

out = xp.fft.fft(x, **kwargs)

assert_fft_dtype("fft", in_dtype=x.dtype, out_dtype=out.dtype)
assert_n_axis_shape("fft", x=x, n=n, axis=axis, out=out)


@given(
x=xps.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat),
data=st.data(),
)
def test_ifft(x, data):
n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data)

out = xp.fft.ifft(x, **kwargs)

assert_fft_dtype("ifft", in_dtype=x.dtype, out_dtype=out.dtype)
assert_n_axis_shape("ifft", x=x, n=n, axis=axis, out=out)


@given(
x=xps.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat),
data=st.data(),
)
def test_fftn(x, data):
s, axes, norm, kwargs = draw_s_axes_norm_kwargs(x, data)

out = xp.fft.fftn(x, **kwargs)

assert_fft_dtype("fftn", in_dtype=x.dtype, out_dtype=out.dtype)
assert_s_axes_shape("fftn", x=x, s=s, axes=axes, out=out)


@given(
x=xps.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat),
data=st.data(),
)
def test_ifftn(x, data):
s, axes, norm, kwargs = draw_s_axes_norm_kwargs(x, data)

out = xp.fft.ifftn(x, **kwargs)

assert_fft_dtype("ifftn", in_dtype=x.dtype, out_dtype=out.dtype)
assert_s_axes_shape("ifftn", x=x, s=s, axes=axes, out=out)


@given(
x=xps.arrays(dtype=xps.complex_dtypes(), shape=fft_shapes_strat),
data=st.data(),
)
def test_rfft(x, data):
n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data)

out = xp.fft.rfft(x, **kwargs)

assert_fft_dtype("rfft", in_dtype=x.dtype, out_dtype=out.dtype)
assert_n_axis_shape("rfft", x=x, n=n, axis=axis, out=out)


@given(
x=xps.arrays(dtype=xps.complex_dtypes(), shape=fft_shapes_strat),
data=st.data(),
)
def test_irfft(x, data):
n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data, size_gt_1=True)

out = xp.fft.irfft(x, **kwargs)

assert_fft_dtype("irfft", in_dtype=x.dtype, out_dtype=out.dtype)

_axis = x.ndim - 1 if axis == -1 else axis
if n is None:
axis_side = 2 * (x.shape[_axis] - 1)
else:
axis_side = n
expected_shape = x.shape[:_axis] + (axis_side,) + x.shape[_axis + 1 :]
ph.assert_shape("irfft", out_shape=out.shape, expected=expected_shape)


@given(
x=xps.arrays(dtype=xps.complex_dtypes(), shape=fft_shapes_strat),
data=st.data(),
)
def test_rfftn(x, data):
s, axes, norm, kwargs = draw_s_axes_norm_kwargs(x, data)

out = xp.fft.rfftn(x, **kwargs)

assert_fft_dtype("rfftn", in_dtype=x.dtype, out_dtype=out.dtype)
assert_s_axes_shape("rfftn", x=x, s=s, axes=axes, out=out)


@given(
x=xps.arrays(
dtype=xps.complex_dtypes(), shape=fft_shapes_strat.filter(lambda s: s[-1] > 1)
),
data=st.data(),
)
def test_irfftn(x, data):
s, axes, norm, kwargs = draw_s_axes_norm_kwargs(x, data, size_gt_1=True)

out = xp.fft.irfftn(x, **kwargs)

assert_fft_dtype("irfftn", in_dtype=x.dtype, out_dtype=out.dtype)
assert_s_axes_shape("rfftn", x=x, s=s, axes=axes, out=out, size_gt_1=True)


@given(
x=xps.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat),
data=st.data(),
)
def test_hfft(x, data):
n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data, size_gt_1=True)

out = xp.fft.hfft(x, **kwargs)

assert_fft_dtype("hfft", in_dtype=x.dtype, out_dtype=out.dtype)

_axis = x.ndim - 1 if axis == -1 else axis
if n is None:
axis_side = 2 * (x.shape[_axis] - 1)
else:
axis_side = n
expected_shape = x.shape[:_axis] + (axis_side,) + x.shape[_axis + 1 :]
ph.assert_shape("hfft", out_shape=out.shape, expected=expected_shape)


@given(
x=xps.arrays(dtype=hh.all_floating_dtypes(), shape=fft_shapes_strat),
data=st.data(),
)
def test_ihfft(x, data):
n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data)

out = xp.fft.ihfft(x, **kwargs)

assert_fft_dtype("ihfft", in_dtype=x.dtype, out_dtype=out.dtype)
assert_n_axis_shape("ihfft", x=x, n=n, axis=axis, out=out, size_gt_1=True)


# TODO:
# fftfreq
# rfftfreq
# fftshift
# ifftshift
45 changes: 19 additions & 26 deletions array_api_tests/test_operators_and_elementwise_functions.py
Original file line number Diff line number Diff line change
@@ -33,13 +33,6 @@ def boolean_and_all_integer_dtypes() -> st.SearchStrategy[DataType]:
return xps.boolean_dtypes() | all_integer_dtypes()


def all_floating_dtypes() -> st.SearchStrategy[DataType]:
strat = xps.floating_dtypes()
if api_version >= "2022.12":
strat |= xps.complex_dtypes()
return strat


def mock_int_dtype(n: int, dtype: DataType) -> int:
"""Returns equivalent of `n` that mocks `dtype` behaviour."""
nbits = dh.dtype_nbits[dtype]
@@ -714,7 +707,7 @@ def test_abs(ctx, data):
)


@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes()))
@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
def test_acos(x):
out = xp.acos(x)
ph.assert_dtype("acos", in_dtype=x.dtype, out_dtype=out.dtype)
@@ -724,7 +717,7 @@ def test_acos(x):
)


@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes()))
@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
def test_acosh(x):
out = xp.acosh(x)
ph.assert_dtype("acosh", in_dtype=x.dtype, out_dtype=out.dtype)
@@ -748,7 +741,7 @@ def test_add(ctx, data):
binary_param_assert_against_refimpl(ctx, left, right, res, "+", operator.add)


@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes()))
@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
def test_asin(x):
out = xp.asin(x)
ph.assert_dtype("asin", in_dtype=x.dtype, out_dtype=out.dtype)
@@ -758,15 +751,15 @@ def test_asin(x):
)


@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes()))
@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
def test_asinh(x):
out = xp.asinh(x)
ph.assert_dtype("asinh", in_dtype=x.dtype, out_dtype=out.dtype)
ph.assert_shape("asinh", out_shape=out.shape, expected=x.shape)
unary_assert_against_refimpl("asinh", x, out, math.asinh)


@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes()))
@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
def test_atan(x):
out = xp.atan(x)
ph.assert_dtype("atan", in_dtype=x.dtype, out_dtype=out.dtype)
@@ -782,7 +775,7 @@ def test_atan2(x1, x2):
binary_assert_against_refimpl("atan2", x1, x2, out, math.atan2)


@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes()))
@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
def test_atanh(x):
out = xp.atanh(x)
ph.assert_dtype("atanh", in_dtype=x.dtype, out_dtype=out.dtype)
@@ -932,15 +925,15 @@ def test_conj(x):
unary_assert_against_refimpl("conj", x, out, operator.methodcaller("conjugate"))


@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes()))
@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
def test_cos(x):
out = xp.cos(x)
ph.assert_dtype("cos", in_dtype=x.dtype, out_dtype=out.dtype)
ph.assert_shape("cos", out_shape=out.shape, expected=x.shape)
unary_assert_against_refimpl("cos", x, out, math.cos)


@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes()))
@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
def test_cosh(x):
out = xp.cosh(x)
ph.assert_dtype("cosh", in_dtype=x.dtype, out_dtype=out.dtype)
@@ -1001,15 +994,15 @@ def test_equal(ctx, data):
)


@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes()))
@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
def test_exp(x):
out = xp.exp(x)
ph.assert_dtype("exp", in_dtype=x.dtype, out_dtype=out.dtype)
ph.assert_shape("exp", out_shape=out.shape, expected=x.shape)
unary_assert_against_refimpl("exp", x, out, math.exp)


@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes()))
@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
def test_expm1(x):
out = xp.expm1(x)
ph.assert_dtype("expm1", in_dtype=x.dtype, out_dtype=out.dtype)
@@ -1158,7 +1151,7 @@ def test_less_equal(ctx, data):
)


@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes()))
@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
def test_log(x):
out = xp.log(x)
ph.assert_dtype("log", in_dtype=x.dtype, out_dtype=out.dtype)
@@ -1168,7 +1161,7 @@ def test_log(x):
)


@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes()))
@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
def test_log1p(x):
out = xp.log1p(x)
ph.assert_dtype("log1p", in_dtype=x.dtype, out_dtype=out.dtype)
@@ -1178,7 +1171,7 @@ def test_log1p(x):
)


@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes()))
@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
def test_log2(x):
out = xp.log2(x)
ph.assert_dtype("log2", in_dtype=x.dtype, out_dtype=out.dtype)
@@ -1188,7 +1181,7 @@ def test_log2(x):
)


@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes()))
@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
def test_log10(x):
out = xp.log10(x)
ph.assert_dtype("log10", in_dtype=x.dtype, out_dtype=out.dtype)
@@ -1379,15 +1372,15 @@ def test_sign(x):
)


@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes()))
@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
def test_sin(x):
out = xp.sin(x)
ph.assert_dtype("sin", in_dtype=x.dtype, out_dtype=out.dtype)
ph.assert_shape("sin", out_shape=out.shape, expected=x.shape)
unary_assert_against_refimpl("sin", x, out, math.sin)


@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes()))
@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
def test_sinh(x):
out = xp.sinh(x)
ph.assert_dtype("sinh", in_dtype=x.dtype, out_dtype=out.dtype)
@@ -1405,7 +1398,7 @@ def test_square(x):
)


@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes()))
@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
def test_sqrt(x):
out = xp.sqrt(x)
ph.assert_dtype("sqrt", in_dtype=x.dtype, out_dtype=out.dtype)
@@ -1429,15 +1422,15 @@ def test_subtract(ctx, data):
binary_param_assert_against_refimpl(ctx, left, right, res, "-", operator.sub)


@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes()))
@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
def test_tan(x):
out = xp.tan(x)
ph.assert_dtype("tan", in_dtype=x.dtype, out_dtype=out.dtype)
ph.assert_shape("tan", out_shape=out.shape, expected=x.shape)
unary_assert_against_refimpl("tan", x, out, math.tan)


@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes()))
@given(xps.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
def test_tanh(x):
out = xp.tanh(x)
ph.assert_dtype("tanh", in_dtype=x.dtype, out_dtype=out.dtype)