Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Define complex scalar <op> float array #121

Merged
merged 2 commits into from
Feb 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions array_api_strict/_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,12 @@
_integer_dtypes,
_integer_or_boolean_dtypes,
_floating_dtypes,
_real_floating_dtypes,
_complex_floating_dtypes,
_numeric_dtypes,
_result_type,
_dtype_categories,
_real_to_complex_map,
)
from ._flags import get_array_api_strict_flags, set_array_api_strict_flags

Expand Down Expand Up @@ -243,6 +245,7 @@ def _promote_scalar(self, scalar):
"""
from ._data_type_functions import iinfo

target_dtype = self.dtype
# Note: Only Python scalar types that match the array dtype are
# allowed.
if isinstance(scalar, bool):
Expand All @@ -268,10 +271,13 @@ def _promote_scalar(self, scalar):
"Python float scalars can only be promoted with floating-point arrays."
)
elif isinstance(scalar, complex):
if self.dtype not in _complex_floating_dtypes:
if self.dtype not in _floating_dtypes:
raise TypeError(
"Python complex scalars can only be promoted with complex floating-point arrays."
"Python complex scalars can only be promoted with floating-point arrays."
)
# 1j * array(floating) is allowed
if self.dtype in _real_floating_dtypes:
target_dtype = _real_to_complex_map[self.dtype]
else:
raise TypeError("'scalar' must be a Python scalar")

Expand All @@ -282,7 +288,7 @@ def _promote_scalar(self, scalar):
# behavior for integers within the bounds of the integer dtype.
# Outside of those bounds we use the default NumPy behavior (either
# cast or raise OverflowError).
return Array._new(np.array(scalar, dtype=self.dtype._np_dtype), device=self.device)
return Array._new(np.array(scalar, dtype=target_dtype._np_dtype), device=self.device)

@staticmethod
def _normalize_two_args(x1, x2) -> Tuple[Array, Array]:
Expand Down
1 change: 1 addition & 0 deletions array_api_strict/_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def __hash__(self):
"floating-point": _floating_dtypes,
}

_real_to_complex_map = {float32: complex64, float64: complex128}

# Note: the spec defines a restricted type promotion table compared to NumPy.
# In particular, cross-kind promotions like integer + float or boolean +
Expand Down
10 changes: 9 additions & 1 deletion array_api_strict/tests/test_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,14 @@ def _check_op_array_scalar(dtypes, a, s, func, func_name, BIG_INT=BIG_INT):
# - a Python int or float for real floating-point array dtypes
# - a Python int, float, or complex for complex floating-point array dtypes

# an exception: complex scalar <op> floating array
scalar_types_for_float = [float, int]
if not (func_name.startswith("__i")
or (func_name in ["__floordiv__", "__rfloordiv__", "__mod__", "__rmod__"]
and type(s) == complex)
):
scalar_types_for_float += [complex]

if ((dtypes == "all"
or dtypes == "numeric" and a.dtype in _numeric_dtypes
or dtypes == "real numeric" and a.dtype in _real_numeric_dtypes
Expand All @@ -121,7 +129,7 @@ def _check_op_array_scalar(dtypes, a, s, func, func_name, BIG_INT=BIG_INT):
# isinstance here.
and (a.dtype in _boolean_dtypes and type(s) == bool
or a.dtype in _integer_dtypes and type(s) == int
or a.dtype in _real_floating_dtypes and type(s) in [float, int]
or a.dtype in _real_floating_dtypes and type(s) in scalar_types_for_float
or a.dtype in _complex_floating_dtypes and type(s) in [complex, float, int]
)):
if a.dtype in _integer_dtypes and s == BIG_INT:
Expand Down
14 changes: 12 additions & 2 deletions array_api_strict/tests/test_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,15 +233,25 @@ def _array_vals():
if nargs(func) != 2:
continue

nocomplex = [
'atan2', 'copysign', 'floor_divide', 'hypot', 'logaddexp', 'nextafter',
'remainder',
'greater', 'less', 'greater_equal', 'less_equal', 'maximum', 'minimum',
]

for s in [1, 1.0, 1j, BIG_INT, False]:
for a in _array_vals():
for func1 in [lambda s: func(a, s), lambda s: func(s, a)]:
allowed = _check_op_array_scalar(dtypes, a, s, func1, func_name)

if func_name in nocomplex and type(s) == complex:
allowed = False
else:
allowed = _check_op_array_scalar(dtypes, a, s, func1, func_name)

# only check `func(array, scalar) == `func(array, array)` if
# the former is legal under the promotion rules
if allowed:
conv_scalar = asarray(s, dtype=a.dtype)
conv_scalar = a._promote_scalar(s)

with suppress_warnings() as sup:
# ignore warnings from pow(BIG_INT)
Expand Down
Loading