Skip to content

Commit a448710

Browse files
committed
torch: allow python scalars in result_type
1 parent bfe3fcc commit a448710

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

array_api_compat/torch/_aliases.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,9 @@ def _fix_promotion(x1, x2, only_scalar=True):
119119
x1 = x1.to(dtype)
120120
return x1, x2
121121

122+
123+
_torch_dtype_and_py_scalars = (torch.dtype, bool, int, float, complex)
124+
122125
def result_type(*arrays_and_dtypes: Union[array, Dtype]) -> Dtype:
123126
if len(arrays_and_dtypes) == 0:
124127
raise TypeError("At least one array or dtype must be provided")
@@ -140,8 +143,8 @@ def result_type(*arrays_and_dtypes: Union[array, Dtype]) -> Dtype:
140143
# This doesn't result_type(dtype, dtype) for non-array API dtypes
141144
# because torch.result_type only accepts tensors. This does however, allow
142145
# cross-kind promotion.
143-
x = torch.tensor([], dtype=x) if isinstance(x, torch.dtype) else x
144-
y = torch.tensor([], dtype=y) if isinstance(y, torch.dtype) else y
146+
x = torch.tensor([], dtype=x) if isinstance(x, _torch_dtype_and_py_scalars) else x
147+
y = torch.tensor([], dtype=y) if isinstance(y, _torch_dtype_and_py_scalars) else y
145148
return torch.result_type(x, y)
146149

147150
def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:

0 commit comments

Comments
 (0)