2
2
from typing import Union
3
3
4
4
import pytest
5
- from hypothesis import given
5
+ from hypothesis import given , assume
6
6
from hypothesis import strategies as st
7
7
8
8
from . import _array_module as xp
@@ -34,6 +34,8 @@ def _float_match_complex(complex_dtype):
34
34
data = st .data (),
35
35
)
36
36
def test_astype (x_dtype , dtype , kw , data ):
37
+ _complex_dtypes = (xp .complex64 , xp .complex128 )
38
+
37
39
if xp .bool in (x_dtype , dtype ):
38
40
elements_strat = hh .from_dtype (x_dtype )
39
41
else :
@@ -46,12 +48,12 @@ def test_astype(x_dtype, dtype, kw, data):
46
48
cast = float
47
49
48
50
real_dtype = x_dtype
49
- if x_dtype in ( xp . complex64 , xp . complex128 ) :
51
+ if x_dtype in _complex_dtypes :
50
52
real_dtype = _float_match_complex (x_dtype )
51
53
m1 , M1 = dh .dtype_ranges [real_dtype ]
52
54
53
55
real_dtype = dtype
54
- if dtype in ( xp . complex64 , xp . complex128 ) :
56
+ if dtype in _complex_dtypes :
55
57
real_dtype = _float_match_complex (x_dtype )
56
58
m2 , M2 = dh .dtype_ranges [real_dtype ]
57
59
@@ -69,6 +71,11 @@ def test_astype(x_dtype, dtype, kw, data):
69
71
hh .arrays (dtype = x_dtype , shape = hh .shapes (), elements = elements_strat ), label = "x"
70
72
)
71
73
74
+ # according to the spec, "Casting a complex floating-point array to a real-valued
75
+ # data type should not be permitted."
76
+ # https://data-apis.org/array-api/latest/API_specification/generated/array_api.astype.html#astype
77
+ assume (not ((x_dtype in _complex_dtypes ) and (dtype not in _complex_dtypes )))
78
+
72
79
out = xp .astype (x , dtype , ** kw )
73
80
74
81
ph .assert_kw_dtype ("astype" , kw_dtype = dtype , out_dtype = out .dtype )
0 commit comments