diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 9906f97..1da7603 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -1104,6 +1104,7 @@ def __imod__(self, other: Array | float, /) -> Array: """ Performs the operation __imod__. """ + self._check_type_device(other) other = self._check_allowed_dtypes(other, "real numeric", "__imod__") if other is NotImplemented: return other @@ -1126,6 +1127,7 @@ def __imul__(self, other: Array | complex, /) -> Array: """ Performs the operation __imul__. """ + self._check_type_device(other) other = self._check_allowed_dtypes(other, "numeric", "__imul__") if other is NotImplemented: return other @@ -1148,6 +1150,7 @@ def __ior__(self, other: Array | int, /) -> Array: """ Performs the operation __ior__. """ + self._check_type_device(other) other = self._check_allowed_dtypes(other, "integer or boolean", "__ior__") if other is NotImplemented: return other @@ -1170,6 +1173,7 @@ def __ipow__(self, other: Array | complex, /) -> Array: """ Performs the operation __ipow__. """ + self._check_type_device(other) other = self._check_allowed_dtypes(other, "numeric", "__ipow__") if other is NotImplemented: return other @@ -1182,6 +1186,7 @@ def __rpow__(self, other: Array | complex, /) -> Array: """ from ._elementwise_functions import pow # type: ignore[attr-defined] + self._check_type_device(other) other = self._check_allowed_dtypes(other, "numeric", "__rpow__") if other is NotImplemented: return other @@ -1193,6 +1198,7 @@ def __irshift__(self, other: Array | int, /) -> Array: """ Performs the operation __irshift__. """ + self._check_type_device(other) other = self._check_allowed_dtypes(other, "integer", "__irshift__") if other is NotImplemented: return other @@ -1215,6 +1221,7 @@ def __isub__(self, other: Array | complex, /) -> Array: """ Performs the operation __isub__. """ + self._check_type_device(other) other = self._check_allowed_dtypes(other, "numeric", "__isub__") if other is NotImplemented: return other @@ -1237,6 +1244,7 @@ def __itruediv__(self, other: Array | complex, /) -> Array: """ Performs the operation __itruediv__. """ + self._check_type_device(other) other = self._check_allowed_dtypes(other, "floating-point", "__itruediv__") if other is NotImplemented: return other @@ -1259,6 +1267,7 @@ def __ixor__(self, other: Array | int, /) -> Array: """ Performs the operation __ixor__. """ + self._check_type_device(other) other = self._check_allowed_dtypes(other, "integer or boolean", "__ixor__") if other is NotImplemented: return other diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index bb4263c..de52f4c 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -344,6 +344,13 @@ def _array_vals(): getattr(x, _op)(y) else: assert_raises(TypeError, lambda: getattr(x, _op)(y)) + # finally, test that array op ndarray raises + # XXX: as long as there is __array__ or __buffer__, __rop__s + # still return ndarrays + if not _op.startswith("__r"): + with assert_raises(TypeError): + getattr(x, _op)(y._array) + for op, dtypes in unary_op_dtypes.items(): for a in _array_vals():