Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: reject ndarrays in binary operators #103

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions array_api_strict/_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,8 @@ def _check_device(self, other):
elif isinstance(other, Array):
if self.device != other.device:
raise ValueError(f"Arrays from two different devices ({self.device} and {other.device}) can not be combined.")
else:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should explicitly only reject ndarray. That way NotImplemented can still work on other types (although I honestly don't know how important that is).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd think that mixing different array objects is a bad idea in general, and is basically impossible to make consistent. Even when it fails, there are just too many possible reasons or failure modes.
So one either buys in duck-typing all the way (and whatever quacking falls out, it does), or say no to anything which is not the same array namespace. If -strict strives to be strict, the latter makes sense, it seems.
That said, I'm happy with whatever you think is best.

In [4]: xp.arange(5) + cp.arange(5)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[4], line 1
----> 1 xp.arange(5) + cp.arange(5)

File cupy/_core/core.pyx:1265, in cupy._core.core._ndarray_base.__add__()

File cupy/_core/_kernel.pyx:1286, in cupy._core._kernel.ufunc.__call__()

File cupy/_core/_kernel.pyx:159, in cupy._core._kernel._preprocess_args()

File cupy/_core/_kernel.pyx:145, in cupy._core._kernel._preprocess_arg()

TypeError: Unsupported type <class 'array_api_strict._array_object.Array'>

In [5]: cp.arange(3) + xp.arange(3)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[5], line 1
----> 1 cp.arange(3) + xp.arange(3)

File cupy/_core/core.pyx:1269, in cupy._core.core._ndarray_base.__add__()

TypeError: operand type(s) all returned NotImplemented from __array_ufunc__(<ufunc 'add'>, '__call__', array([0, 1, 2]), Array([0, 1, 2], dtype=array_api_strict.int64)): 'ndarray', 'Array'

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So one either buys in duck-typing all the way (and whatever quacking falls out, it does), or say no to anything which is not the same array namespace. If -strict strives to be strict, the latter makes sense, it seems.

That sounds good to me.

raise TypeError(f"Cannot combine an Array with {type(other)}.")

# Helper function to match the type promotion rules in the spec
def _promote_scalar(self, scalar):
Expand Down Expand Up @@ -1066,6 +1068,7 @@ def __imod__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __imod__.
"""
self._check_device(other)
other = self._check_allowed_dtypes(other, "real numeric", "__imod__")
if other is NotImplemented:
return other
Expand All @@ -1088,6 +1091,7 @@ def __imul__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __imul__.
"""
self._check_device(other)
other = self._check_allowed_dtypes(other, "numeric", "__imul__")
if other is NotImplemented:
return other
Expand All @@ -1110,6 +1114,7 @@ def __ior__(self: Array, other: Union[int, bool, Array], /) -> Array:
"""
Performs the operation __ior__.
"""
self._check_device(other)
other = self._check_allowed_dtypes(other, "integer or boolean", "__ior__")
if other is NotImplemented:
return other
Expand All @@ -1132,6 +1137,7 @@ def __ipow__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __ipow__.
"""
self._check_device(other)
other = self._check_allowed_dtypes(other, "numeric", "__ipow__")
if other is NotImplemented:
return other
Expand All @@ -1144,6 +1150,7 @@ def __rpow__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
from ._elementwise_functions import pow

self._check_device(other)
other = self._check_allowed_dtypes(other, "numeric", "__rpow__")
if other is NotImplemented:
return other
Expand All @@ -1155,6 +1162,7 @@ def __irshift__(self: Array, other: Union[int, Array], /) -> Array:
"""
Performs the operation __irshift__.
"""
self._check_device(other)
other = self._check_allowed_dtypes(other, "integer", "__irshift__")
if other is NotImplemented:
return other
Expand All @@ -1177,6 +1185,7 @@ def __isub__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __isub__.
"""
self._check_device(other)
other = self._check_allowed_dtypes(other, "numeric", "__isub__")
if other is NotImplemented:
return other
Expand All @@ -1199,6 +1208,7 @@ def __itruediv__(self: Array, other: Union[float, Array], /) -> Array:
"""
Performs the operation __itruediv__.
"""
self._check_device(other)
other = self._check_allowed_dtypes(other, "floating-point", "__itruediv__")
if other is NotImplemented:
return other
Expand All @@ -1221,6 +1231,7 @@ def __ixor__(self: Array, other: Union[int, bool, Array], /) -> Array:
"""
Performs the operation __ixor__.
"""
self._check_device(other)
other = self._check_allowed_dtypes(other, "integer or boolean", "__ixor__")
if other is NotImplemented:
return other
Expand Down
8 changes: 8 additions & 0 deletions array_api_strict/tests/test_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,14 @@ def _array_vals():
else:
assert_raises(TypeError, lambda: getattr(x, _op)(y))

# finally, test that array op ndarray raises
# XXX: as long as there is __array__, __rop__s still
# return ndarrays
if not _op.startswith("__r"):
with assert_raises(TypeError):
getattr(x, _op)(y._array)


unary_op_dtypes = {
"__abs__": "numeric",
"__invert__": "integer_or_boolean",
Expand Down
Loading