Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TST: test binary operators vs. numpy generics #145

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

crusaderky
Copy link
Contributor

@crusaderky crusaderky commented Apr 2, 2025

In #135, binary operators vs. numpy.int64 became (correctly) disallowed.
Add a test for it.
Also test binary operators vs. numpy.float64 and numpy.complex128, which must be allowed as they are subclasses of float and complex respectively.

)
def test_binary_operators_vs_numpy_float(op):
"""
np.float64 is a subclass of float and must be allowed.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have to admit I'm not a fan of this. This is technically correct, it is a subclass. That said, it is an implementation detail IMO.

I'd think that all instances where numpy returns a numpy scalar, all other libraries return 0D arrays, so I'd think we better always handle them together with arrays not python scalars.

Copy link
Contributor Author

@crusaderky crusaderky Apr 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you saying that

x = xp.asarray(1, dtype=xp.float32)
y = np.float64(2)
z = xp.maximum(x, y)

maximum should crash instead of returning an Array[float32]?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Definitely not crash :-).

I think the following is correct[ because np.float64(2) is conceptually a np.asarray(2.) not python scalar 2.)

In [4]: xp = array_namespace(np.empty(2))

In [5]: xp.maximum(xp.asarray(1, dtype=xp.float32), 2)
Out[5]: np.float32(2.0)

In [6]: xp.maximum(xp.asarray(1, dtype=xp.float32), np.float64(2))
Out[6]: np.float64(2.0)

Likewise, the following is also correct, for the same reason:

In [12]: xpt = array_namespace(torch.empty(2))

In [13]: xpt.maximum(xpt.asarray(1, dtype=xpt.float32), np.float64(2))
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[13], line 1
----> 1 xpt.maximum(xpt.asarray(1, dtype=xpt.float32), np.float64(2))

File ~/repos/array-api-compat/array_api_compat/torch/_aliases.py:91, in _two_arg.<locals>._f(x1, x2, **kwargs)
     88 @_wraps(f)
     89 def _f(x1, x2, /, **kwargs):
     90     x1, x2 = _fix_promotion(x1, x2)
---> 91     return f(x1, x2, **kwargs)

TypeError: maximum(): argument 'other' (position 2) must be Tensor, not numpy.float64

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In [6]: xp.maximum(xp.asarray(1, dtype=xp.float32), np.float64(2))
Out[6]: np.float64(2.0)

This looks definitely wrong to me. array-api-strict shouldn't quietly accept a numpy array.
Right now the output is the expected one without explicitly disallowing inheritance:

>>> xp.maximum(xp.asarray(1, dtype=xp.float32), np.float64(2))
Array(2., dtype=array_api_strict.float32)
>>> xp.maximum(xp.asarray(1, dtype=xp.float32), np.asarray(2, dtype=np.float64))
TypeError: Only real numeric dtypes are allowed in maximum(...). Got array_api_strict.float32 and float64.
>>> xp.maximum(xp.asarray(1, dtype=xp.int32), np.int64(2))
TypeError: Only real numeric dtypes are allowed in maximum(...). Got array_api_strict.int32 and int64.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: the message of the TypeError is a bit misleading; it should call out explicitly something like "only Array and python scalars are accepted; got numpy.ndarray

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks definitely wrong to me. array-api-strict shouldn't quietly accept a numpy array.
Right now the output is the expected one without explicitly disallowing inheritance:

Yes, in my example xp is array_api_compat.numpy. As your example shows, array-api-strict does the right thing in main, too. So what is this PR fixing?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants