@@ -197,7 +197,7 @@ def isdtype(
197
197
else :
198
198
raise TypeError (f"'kind' must be a dtype, str, or tuple of dtypes and strs, not { type (kind ).__name__ } " )
199
199
200
- def result_type (* arrays_and_dtypes : Union [Array , Dtype ]) -> Dtype :
200
+ def result_type (* arrays_and_dtypes : Union [Array , Dtype , int , float , complex , bool ]) -> Dtype :
201
201
"""
202
202
Array API compatible wrapper for :py:func:`np.result_type <numpy.result_type>`.
203
203
@@ -208,19 +208,40 @@ def result_type(*arrays_and_dtypes: Union[Array, Dtype]) -> Dtype:
208
208
# too many extra type promotions like int64 + uint64 -> float64, and does
209
209
# value-based casting on scalar arrays.
210
210
A = []
211
+ scalars = []
211
212
for a in arrays_and_dtypes :
212
213
if isinstance (a , Array ):
213
214
a = a .dtype
215
+ elif isinstance (a , (bool , int , float , complex )):
216
+ scalars .append (a )
214
217
elif isinstance (a , np .ndarray ) or a not in _all_dtypes :
215
218
raise TypeError ("result_type() inputs must be array_api arrays or dtypes" )
216
219
A .append (a )
217
220
221
+ # remove python scalars
222
+ A = [a for a in A if not isinstance (a , (bool , int , float , complex ))]
223
+
218
224
if len (A ) == 0 :
219
225
raise ValueError ("at least one array or dtype is required" )
220
226
elif len (A ) == 1 :
221
- return A [0 ]
227
+ result = A [0 ]
222
228
else :
223
229
t = A [0 ]
224
230
for t2 in A [1 :]:
225
231
t = _result_type (t , t2 )
226
- return t
232
+ result = t
233
+
234
+ if len (scalars ) == 0 :
235
+ return result
236
+
237
+ if get_array_api_strict_flags ()['api_version' ] <= '2023.12' :
238
+ raise TypeError ("result_type() inputs must be array_api arrays or dtypes" )
239
+
240
+ # promote python scalars given the result_type for all arrays/dtypes
241
+ from ._creation_functions import empty
242
+ arr = empty (1 , dtype = result )
243
+ for s in scalars :
244
+ x = arr ._promote_scalar (s )
245
+ result = _result_type (x .dtype , result )
246
+
247
+ return result
0 commit comments