Arithmetic operations accept numpy arrays #102

Open
ev-br opened this issue Nov 27, 2024 · 3 comments · May be fixed by #115

Comments

@ev-br
Contributor

ev-br commented Nov 27, 2024

Supposedly, mixing array-api-strict arrays with other array types should not be allowed.

Or all of them should be allowed, but then we'd need to specify something like __array_priority__, and that opens quite a Pandora's box, so I guess not?

In [5]: import numpy as np

In [6]: import array_api_strict as xp

In [7]: xp.arange(5, dtype=xp.int8) + np.arange(5, dtype=np.complex64)
Out[7]: array([0.+0.j, 2.+0.j, 4.+0.j, 6.+0.j, 8.+0.j], dtype=complex64)           # xp + np -> np !

In [8]: import torch

In [10]: xp.arange(5, dtype=xp.int8) + torch.arange(5)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[10], line 1
----> 1 xp.arange(5, dtype=xp.int8) + torch.arange(5)

TypeError: unsupported operand type(s) for +: 'Array' and 'Tensor'

In [11]: import jax.numpy as jnp

In [12]: xp.arange(5, dtype=xp.int8) + jnp.arange(5, dtype=jnp.complex64)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[12], line 1
----> 1 xp.arange(5, dtype=xp.int8) + jnp.arange(5, dtype=jnp.complex64)

TypeError: unsupported operand type(s) for +: 'Array' and 'jaxlib.xla_extension.ArrayImpl'
@ev-br
Contributor Author

ev-br commented Nov 27, 2024

The offender is https://github.com/data-apis/array-api-strict/blob/main/array_api_strict/_array_object.py#L189 :

    def _check_allowed_dtypes(self, other: bool | int | float | Array, dtype_category: str, op: str) -> Array:
        ...
        if isinstance(other, (int, complex, float, bool)):
            ...
        elif isinstance(other, Array):
            ...
        else:
            return NotImplemented

and then, after __add__ returns NotImplemented, Python falls back to the reflected np.ndarray.__radd__, and control flow ends up in __array__, which happily calls np.asarray(self._array).
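To make the fallback concrete, here is a minimal, self-contained sketch that does not import array-api-strict. StrictArray is a hypothetical stand-in that, like _check_allowed_dtypes, handles only its own type and Python scalars and returns NotImplemented for anything else, and that (like the real Array) defines __array__. It assumes NumPy is installed.

```python
import numpy as np


class StrictArray:
    """Hypothetical stand-in for array_api_strict's Array: wraps an ndarray."""

    def __init__(self, data):
        self._array = np.asarray(data)

    def __add__(self, other):
        # Mirrors _check_allowed_dtypes: only our own arrays or Python scalars
        # are handled; anything else is left to the other operand.
        if isinstance(other, StrictArray):
            return StrictArray(self._array + other._array)
        if isinstance(other, (bool, int, float, complex)):
            return StrictArray(self._array + other)
        return NotImplemented

    def __array__(self, dtype=None, copy=None):
        # The escape hatch: NumPy can always turn us into an ndarray.
        # `copy` is accepted so NumPy 2 can pass it; we copy only when asked.
        arr = np.asarray(self._array, dtype=dtype)
        return arr.copy() if copy else arr


a = StrictArray(np.arange(5, dtype=np.int8))
b = np.arange(5, dtype=np.complex64)

# StrictArray.__add__ returns NotImplemented, so Python tries b.__radd__(a);
# NumPy converts `a` through __array__ and the result is a plain ndarray.
result = a + b
print(type(result).__name__, result.dtype)
```

This reproduces the xp + np -> np behaviour from the first comment without any array-api-strict code involved.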

ISTM it's best to add an explicit type check to __op__(self, other) and explicitly whitelist arrays or scalars, so that we do not depend on #67.
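A rough sketch of the explicit-whitelist idea, again using a hypothetical StrictArray rather than the actual array-api-strict code: the binary operator raises TypeError for any operand that is not an Array or a Python scalar, instead of returning NotImplemented, so the foreign type never gets a chance to take over even though __array__ is still defined.

```python
import numpy as np


class StrictArray:
    """Hypothetical minimal Array with an explicit operand whitelist."""

    def __init__(self, data):
        self._array = np.asarray(data)

    def _check_operand(self, other, op):
        # Whitelist: our own arrays and Python scalars. NumPy arrays, torch
        # tensors, etc. are rejected outright rather than deferred to.
        if not isinstance(other, (StrictArray, bool, int, float, complex)):
            raise TypeError(
                f"unsupported operand type(s) for {op}: "
                f"'{type(self).__name__}' and '{type(other).__name__}'"
            )

    def __add__(self, other):
        self._check_operand(other, "+")
        rhs = other._array if isinstance(other, StrictArray) else other
        return StrictArray(self._array + rhs)

    def __array__(self, dtype=None, copy=None):
        # Still present (it cannot be removed yet), but never reached from
        # StrictArray + ndarray because _check_operand raises first.
        arr = np.asarray(self._array, dtype=dtype)
        return arr.copy() if copy else arr


x = StrictArray(np.arange(5, dtype=np.int8))

try:
    x + np.arange(5, dtype=np.complex64)
    rejected = False
except TypeError as exc:
    rejected = True
    print(exc)  # unsupported operand type(s) for +: 'StrictArray' and 'ndarray'
```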

EDIT: Alternatively, we can probably attach the _allow_array flag to the Array object and force it to False in binops.

@asmeurer
Member

Yes, ideally we would remove __array__, which would presumably fix this. But there were problems with that, which are described in that issue (the principal one being that array-api-strict doesn't support the buffer protocol, as discussed in a recent consortium meeting).

@ev-br
Contributor Author

ev-br commented Nov 27, 2024

Yes, it would just disappear if we could remove __array__. But realistically we cannot, not until from_dlpack matures anyway.
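For comparison, a minimal sketch of the DLPack route as an explicit, opt-in conversion. DLPackOnly is a hypothetical wrapper (not part of array-api-strict) that exposes __dlpack__ and __dlpack_device__ but deliberately no __array__; the sketch assumes NumPy >= 1.23, where np.from_dlpack is public.

```python
import numpy as np


class DLPackOnly:
    """Hypothetical wrapper exposing DLPack but deliberately no __array__."""

    def __init__(self, data):
        self._array = np.asarray(data)

    def __dlpack__(self, **kwargs):
        # Forward whatever the consumer asks for (stream, max_version, copy,
        # dl_device, ...) to the wrapped ndarray's own exporter.
        return self._array.__dlpack__(**kwargs)

    def __dlpack_device__(self):
        return self._array.__dlpack_device__()


base = np.arange(5)
w = DLPackOnly(base)

# Conversion happens only when explicitly requested; there is no implicit
# __array__ path for NumPy to pick up inside a binary operator.
out = np.from_dlpack(w)
print(out, np.shares_memory(out, base))
```

The design point is that the consumer has to call from_dlpack itself, so nothing converts the object behind the user's back.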

Re: buffer protocol, PEP 688 was mentioned in the consortium meeting, and indeed:

>>> class X:
...     def __init__(self, a):
...         self._a = a
...     def __buffer__(self, flags):
...         print('__buffer__')
...         return memoryview(self._a)
...     def __release_buffer__(self, buffer):
...         print('__release__')
...         # what now? do we `del self._a` here?
...
>>> import numpy as np
>>> x = X(np.arange(5))
>>> np.asarray(x)
__buffer__
array([0, 1, 2, 3, 4])

The main problem, as also mentioned offline, is that __buffer__ is new in Python 3.12. This means that downstream libraries like SciPy will only be able to rely on it in a couple of years at best (cf. NEP 29, https://numpy.org/neps/nep-0029-deprecation_policy.html).

So it seems that our options today are either

  • do nothing and wait for DLPack or __buffer__ support, whichever comes first;
  • block what we can, as in #103 (BUG: reject ndarrays in binary operators): block xp.array op np.array and accept that we can't do anything about np.array op xp.array.
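A minimal sketch of why the second option is one-sided, using a hypothetical StrictArray rather than the actual #103 patch: even if the Array binops (including the reflected one) reject ndarrays, for np.array op xp.array NumPy's own ndarray.__add__ runs first, converts the operand through __array__, and never defers to StrictArray.__radd__.

```python
import numpy as np


class StrictArray:
    """Hypothetical Array that rejects foreign operands in its own binops."""

    def __init__(self, data):
        self._array = np.asarray(data)

    def __add__(self, other):
        if not isinstance(other, (StrictArray, bool, int, float, complex)):
            raise TypeError(f"cannot add StrictArray and {type(other).__name__}")
        rhs = other._array if isinstance(other, StrictArray) else other
        return StrictArray(self._array + rhs)

    __radd__ = __add__  # the reflected op applies the same whitelist

    def __array__(self, dtype=None, copy=None):
        arr = np.asarray(self._array, dtype=dtype)
        return arr.copy() if copy else arr


x = StrictArray(np.arange(3))
y = np.arange(3)

# xp op np: our __add__ runs first and rejects the ndarray.
try:
    x + y
    forward_rejected = False
except TypeError:
    forward_rejected = True

# np op xp: ndarray.__add__ runs first, converts x via __array__, and returns
# an ndarray, so StrictArray.__radd__ (and its check) is never consulted.
backward = y + x
print(forward_rejected, type(backward).__name__)
```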
