diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index a917441..afee030 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -26,10 +26,12 @@ _integer_dtypes, _integer_or_boolean_dtypes, _floating_dtypes, + _real_floating_dtypes, _complex_floating_dtypes, _numeric_dtypes, _result_type, _dtype_categories, + _real_to_complex_map, ) from ._flags import get_array_api_strict_flags, set_array_api_strict_flags @@ -243,6 +245,7 @@ def _promote_scalar(self, scalar): """ from ._data_type_functions import iinfo + target_dtype = self.dtype # Note: Only Python scalar types that match the array dtype are # allowed. if isinstance(scalar, bool): @@ -268,10 +271,13 @@ def _promote_scalar(self, scalar): "Python float scalars can only be promoted with floating-point arrays." ) elif isinstance(scalar, complex): - if self.dtype not in _complex_floating_dtypes: + if self.dtype not in _floating_dtypes: raise TypeError( - "Python complex scalars can only be promoted with complex floating-point arrays." + "Python complex scalars can only be promoted with floating-point arrays." ) + # 1j * array(floating) is allowed + if self.dtype in _real_floating_dtypes: + target_dtype = _real_to_complex_map[self.dtype] else: raise TypeError("'scalar' must be a Python scalar") @@ -282,7 +288,7 @@ def _promote_scalar(self, scalar): # behavior for integers within the bounds of the integer dtype. # Outside of those bounds we use the default NumPy behavior (either # cast or raise OverflowError). - return Array._new(np.array(scalar, dtype=self.dtype._np_dtype), device=self.device) + return Array._new(np.array(scalar, dtype=target_dtype._np_dtype), device=self.device) @staticmethod def _normalize_two_args(x1, x2) -> Tuple[Array, Array]: diff --git a/array_api_strict/_dtypes.py b/array_api_strict/_dtypes.py index b51ed92..66304dd 100644 --- a/array_api_strict/_dtypes.py +++ b/array_api_strict/_dtypes.py @@ -126,6 +126,7 @@ def __hash__(self): "floating-point": _floating_dtypes, } +_real_to_complex_map = {float32: complex64, float64: complex128} # Note: the spec defines a restricted type promotion table compared to NumPy. # In particular, cross-kind promotions like integer + float or boolean + diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index 4535d99..edfa073 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -108,6 +108,14 @@ def _check_op_array_scalar(dtypes, a, s, func, func_name, BIG_INT=BIG_INT): # - a Python int or float for real floating-point array dtypes # - a Python int, float, or complex for complex floating-point array dtypes + # an exception: complex scalar floating array + scalar_types_for_float = [float, int] + if not (func_name.startswith("__i") + or (func_name in ["__floordiv__", "__rfloordiv__", "__mod__", "__rmod__"] + and type(s) == complex) + ): + scalar_types_for_float += [complex] + if ((dtypes == "all" or dtypes == "numeric" and a.dtype in _numeric_dtypes or dtypes == "real numeric" and a.dtype in _real_numeric_dtypes @@ -121,7 +129,7 @@ def _check_op_array_scalar(dtypes, a, s, func, func_name, BIG_INT=BIG_INT): # isinstance here. and (a.dtype in _boolean_dtypes and type(s) == bool or a.dtype in _integer_dtypes and type(s) == int - or a.dtype in _real_floating_dtypes and type(s) in [float, int] + or a.dtype in _real_floating_dtypes and type(s) in scalar_types_for_float or a.dtype in _complex_floating_dtypes and type(s) in [complex, float, int] )): if a.dtype in _integer_dtypes and s == BIG_INT: diff --git a/array_api_strict/tests/test_elementwise_functions.py b/array_api_strict/tests/test_elementwise_functions.py index 0b90f0b..cc3a2cd 100644 --- a/array_api_strict/tests/test_elementwise_functions.py +++ b/array_api_strict/tests/test_elementwise_functions.py @@ -233,15 +233,25 @@ def _array_vals(): if nargs(func) != 2: continue + nocomplex = [ + 'atan2', 'copysign', 'floor_divide', 'hypot', 'logaddexp', 'nextafter', + 'remainder', + 'greater', 'less', 'greater_equal', 'less_equal', 'maximum', 'minimum', + ] + for s in [1, 1.0, 1j, BIG_INT, False]: for a in _array_vals(): for func1 in [lambda s: func(a, s), lambda s: func(s, a)]: - allowed = _check_op_array_scalar(dtypes, a, s, func1, func_name) + + if func_name in nocomplex and type(s) == complex: + allowed = False + else: + allowed = _check_op_array_scalar(dtypes, a, s, func1, func_name) # only check `func(array, scalar) == `func(array, array)` if # the former is legal under the promotion rules if allowed: - conv_scalar = asarray(s, dtype=a.dtype) + conv_scalar = a._promote_scalar(s) with suppress_warnings() as sup: # ignore warnings from pow(BIG_INT)