Skip to content

Commit 05b46ac

Browse files
committed
ENH: allow python scalars as inputs to result_type
1 parent 61bf3c1 commit 05b46ac

File tree

2 files changed

+45
-5
lines changed

2 files changed

+45
-5
lines changed

array_api_strict/_data_type_functions.py

+24-3
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ def isdtype(
197197
else:
198198
raise TypeError(f"'kind' must be a dtype, str, or tuple of dtypes and strs, not {type(kind).__name__}")
199199

200-
def result_type(*arrays_and_dtypes: Union[Array, Dtype]) -> Dtype:
200+
def result_type(*arrays_and_dtypes: Union[Array, Dtype, int, float, complex, bool]) -> Dtype:
201201
"""
202202
Array API compatible wrapper for :py:func:`np.result_type <numpy.result_type>`.
203203
@@ -208,19 +208,40 @@ def result_type(*arrays_and_dtypes: Union[Array, Dtype]) -> Dtype:
208208
# too many extra type promotions like int64 + uint64 -> float64, and does
209209
# value-based casting on scalar arrays.
210210
A = []
211+
scalars = []
211212
for a in arrays_and_dtypes:
212213
if isinstance(a, Array):
213214
a = a.dtype
215+
elif isinstance(a, (bool, int, float, complex)):
216+
scalars.append(a)
214217
elif isinstance(a, np.ndarray) or a not in _all_dtypes:
215218
raise TypeError("result_type() inputs must be array_api arrays or dtypes")
216219
A.append(a)
217220

221+
# remove python scalars
222+
A = [a for a in A if not isinstance(a, (bool, int, float, complex))]
223+
218224
if len(A) == 0:
219225
raise ValueError("at least one array or dtype is required")
220226
elif len(A) == 1:
221-
return A[0]
227+
result = A[0]
222228
else:
223229
t = A[0]
224230
for t2 in A[1:]:
225231
t = _result_type(t, t2)
226-
return t
232+
result = t
233+
234+
if len(scalars) == 0:
235+
return result
236+
237+
if get_array_api_strict_flags()['api_version'] <= '2023.12':
238+
raise TypeError("result_type() inputs must be array_api arrays or dtypes")
239+
240+
# promote python scalars given the result_type for all arrays/dtypes
241+
from ._creation_functions import empty
242+
arr = empty(1, dtype=result)
243+
for s in scalars:
244+
x = arr._promote_scalar(s)
245+
result = _result_type(x.dtype, result)
246+
247+
return result

array_api_strict/tests/test_data_type_functions.py

+21-2
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
import numpy as np
77

88
from .._creation_functions import asarray
9-
from .._data_type_functions import astype, can_cast, isdtype
9+
from .._data_type_functions import astype, can_cast, isdtype, result_type
1010
from .._dtypes import (
11-
bool, int8, int16, uint8, float64,
11+
bool, int8, int16, uint8, float64, int64
1212
)
1313
from .._flags import set_array_api_strict_flags
1414

@@ -70,3 +70,22 @@ def astype_device(api_version):
7070
else:
7171
pytest.raises(TypeError, lambda: astype(a, int8, device=None))
7272
pytest.raises(TypeError, lambda: astype(a, int8, device=a.device))
73+
74+
75+
@pytest.mark.parametrize("api_version", ['2023.12', '2024.12'])
76+
def test_result_type_py_scalars(api_version):
77+
if api_version <= '2023.12':
78+
set_array_api_strict_flags(api_version=api_version)
79+
80+
with pytest.raises(TypeError):
81+
result_type(int16, 3)
82+
else:
83+
with pytest.warns(UserWarning):
84+
set_array_api_strict_flags(api_version=api_version)
85+
86+
assert result_type(int8, 3) == int8
87+
assert result_type(uint8, 3) == uint8
88+
assert result_type(float64, 3) == float64
89+
90+
with pytest.raises(TypeError):
91+
result_type(int64, True)

0 commit comments

Comments
 (0)