diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 0de6b8a..630d669 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -234,6 +234,8 @@ def _check_device(self, other): elif isinstance(other, Array): if self.device != other.device: raise ValueError(f"Arrays from two different devices ({self.device} and {other.device}) can not be combined.") + else: + raise TypeError(f"Cannot combine an Array with {type(other)}.") # Helper function to match the type promotion rules in the spec def _promote_scalar(self, scalar): @@ -1066,6 +1068,7 @@ def __imod__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __imod__. """ + self._check_device(other) other = self._check_allowed_dtypes(other, "real numeric", "__imod__") if other is NotImplemented: return other @@ -1088,6 +1091,7 @@ def __imul__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __imul__. """ + self._check_device(other) other = self._check_allowed_dtypes(other, "numeric", "__imul__") if other is NotImplemented: return other @@ -1110,6 +1114,7 @@ def __ior__(self: Array, other: Union[int, bool, Array], /) -> Array: """ Performs the operation __ior__. """ + self._check_device(other) other = self._check_allowed_dtypes(other, "integer or boolean", "__ior__") if other is NotImplemented: return other @@ -1132,6 +1137,7 @@ def __ipow__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __ipow__. """ + self._check_device(other) other = self._check_allowed_dtypes(other, "numeric", "__ipow__") if other is NotImplemented: return other @@ -1144,6 +1150,7 @@ def __rpow__(self: Array, other: Union[int, float, Array], /) -> Array: """ from ._elementwise_functions import pow + self._check_device(other) other = self._check_allowed_dtypes(other, "numeric", "__rpow__") if other is NotImplemented: return other @@ -1155,6 +1162,7 @@ def __irshift__(self: Array, other: Union[int, Array], /) -> Array: """ Performs the operation __irshift__. """ + self._check_device(other) other = self._check_allowed_dtypes(other, "integer", "__irshift__") if other is NotImplemented: return other @@ -1177,6 +1185,7 @@ def __isub__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __isub__. """ + self._check_device(other) other = self._check_allowed_dtypes(other, "numeric", "__isub__") if other is NotImplemented: return other @@ -1199,6 +1208,7 @@ def __itruediv__(self: Array, other: Union[float, Array], /) -> Array: """ Performs the operation __itruediv__. """ + self._check_device(other) other = self._check_allowed_dtypes(other, "floating-point", "__itruediv__") if other is NotImplemented: return other @@ -1221,6 +1231,7 @@ def __ixor__(self: Array, other: Union[int, bool, Array], /) -> Array: """ Performs the operation __ixor__. """ + self._check_device(other) other = self._check_allowed_dtypes(other, "integer or boolean", "__ixor__") if other is NotImplemented: return other diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index 8f185f0..d74cf4f 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -212,6 +212,14 @@ def _array_vals(): else: assert_raises(TypeError, lambda: getattr(x, _op)(y)) + # finally, test that array op ndarray raises + # XXX: as long as there is __array__, __rop__s still + # return ndarrays + if not _op.startswith("__r"): + with assert_raises(TypeError): + getattr(x, _op)(y._array) + + unary_op_dtypes = { "__abs__": "numeric", "__invert__": "integer_or_boolean",