-
Notifications
You must be signed in to change notification settings - Fork 9
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
TST: test binary operators vs. numpy generics #145
base: main
Are you sure you want to change the base?
Conversation
) | ||
def test_binary_operators_vs_numpy_float(op): | ||
""" | ||
np.float64 is a subclass of float and must be allowed. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Have to admit I'm not a fan of this. This is technically correct, it is a subclass. That said, it is an implementation detail IMO.
I'd think that all instances where numpy returns a numpy scalar, all other libraries return 0D arrays, so I'd think we better always handle them together with arrays not python scalars.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are you saying that
x = xp.asarray(1, dtype=xp.float32)
y = np.float64(2)
z = xp.maximum(x, y)
maximum
should crash instead of returning an Array[float32]
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Definitely not crash :-).
I think the following is correct[ because np.float64(2)
is conceptually a np.asarray(2.)
not python scalar 2.)
In [4]: xp = array_namespace(np.empty(2))
In [5]: xp.maximum(xp.asarray(1, dtype=xp.float32), 2)
Out[5]: np.float32(2.0)
In [6]: xp.maximum(xp.asarray(1, dtype=xp.float32), np.float64(2))
Out[6]: np.float64(2.0)
Likewise, the following is also correct, for the same reason:
In [12]: xpt = array_namespace(torch.empty(2))
In [13]: xpt.maximum(xpt.asarray(1, dtype=xpt.float32), np.float64(2))
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[13], line 1
----> 1 xpt.maximum(xpt.asarray(1, dtype=xpt.float32), np.float64(2))
File ~/repos/array-api-compat/array_api_compat/torch/_aliases.py:91, in _two_arg.<locals>._f(x1, x2, **kwargs)
88 @_wraps(f)
89 def _f(x1, x2, /, **kwargs):
90 x1, x2 = _fix_promotion(x1, x2)
---> 91 return f(x1, x2, **kwargs)
TypeError: maximum(): argument 'other' (position 2) must be Tensor, not numpy.float64
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In [6]: xp.maximum(xp.asarray(1, dtype=xp.float32), np.float64(2))
Out[6]: np.float64(2.0)
This looks definitely wrong to me. array-api-strict shouldn't quietly accept a numpy array.
Right now the output is the expected one without explicitly disallowing inheritance:
>>> xp.maximum(xp.asarray(1, dtype=xp.float32), np.float64(2))
Array(2., dtype=array_api_strict.float32)
>>> xp.maximum(xp.asarray(1, dtype=xp.float32), np.asarray(2, dtype=np.float64))
TypeError: Only real numeric dtypes are allowed in maximum(...). Got array_api_strict.float32 and float64.
>>> xp.maximum(xp.asarray(1, dtype=xp.int32), np.int64(2))
TypeError: Only real numeric dtypes are allowed in maximum(...). Got array_api_strict.int32 and int64.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: the message of the TypeError is a bit misleading; it should call out explicitly something like "only Array and python scalars are accepted; got numpy.ndarray
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks definitely wrong to me. array-api-strict shouldn't quietly accept a numpy array.
Right now the output is the expected one without explicitly disallowing inheritance:
Yes, in my example xp
is array_api_compat.numpy
. As your example shows, array-api-strict does the right thing in main, too. So what is this PR fixing?
In #135, binary operators vs. numpy.int64 became (correctly) disallowed.
Add a test for it.
Also test binary operators vs. numpy.float64 and numpy.complex128, which must be allowed as they are subclasses of float and complex respectively.