-
Notifications
You must be signed in to change notification settings - Fork 55
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
heat operators eq ne now allow non array operands #1773
base: main
Are you sure you want to change the base?
heat operators eq ne now allow non array operands #1773
Conversation
Thank you for the PR! |
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1773 +/- ##
==========================================
- Coverage 92.26% 92.25% -0.01%
==========================================
Files 84 84
Lines 12447 12453 +6
==========================================
+ Hits 11484 11489 +5
- Misses 963 964 +1
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👏 thanks, just a few comments
@@ -33,11 +33,12 @@ | |||
] | |||
|
|||
|
|||
def eq(x: Union[DNDarray, float, int], y: Union[DNDarray, float, int]) -> DNDarray: | |||
def eq(x, y) -> DNDarray: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
type hints are still needed
""" | ||
Returns a :class:`~heat.core.dndarray.DNDarray` containing the results of element-wise comparision. | ||
Takes the first and second operand (scalar or :class:`~heat.core.dndarray.DNDarray`) whose elements are to be | ||
compared as argument. | ||
Returns False if the operands are not scalars or :class:`~heat.core.dndarray.DNDarray` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This might be a bit too restrictive as well, for example check this out:
>>> import heat as ht
>>> import torch
>>> import numpy as np
>>> a = ht.arange(10)
>>> b = torch.arange(10)
>>> c = np.arange(10)
>>> comp = list(range(10))
>>> a == comp
False
>>> b == comp
False
>>> c == comp
array([ True, True, True, True, True, True, True, True, True,
True])
>>> (c == comp).all()
True
@@ -372,11 +378,12 @@ def lt(x: Union[DNDarray, float, int], y: Union[DNDarray, float, int]) -> DNDarr | |||
less.__doc__ = lt.__doc__ | |||
|
|||
|
|||
def ne(x: Union[DNDarray, float, int], y: Union[DNDarray, float, int]) -> DNDarray: | |||
def ne(x, y) -> DNDarray: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
type hints
""" | ||
Returns a :class:`~heat.core.dndarray.DNDarray` containing the results of element-wise rich comparison of non-equality between values from two operands, commutative. | ||
Takes the first and second operand (scalar or :class:`~heat.core.dndarray.DNDarray`) whose elements are to be | ||
compared as argument. | ||
Returns True if the operands are not scalars or :class:`~heat.core.dndarray.DNDarray` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
see above, we might need to allow more flexibility
with self.assertRaises(TypeError): | ||
ht.eq("self.a_tensor", "s") | ||
self.assertFalse(ht.eq(self.a_tensor, self.another_vector)) | ||
self.assertFalse(ht.eq(self.a_tensor, self.errorneous_type)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
self.assertFalse(ht.eq(self.a_tensor, self.errorneous_type)) | |
self.assertFalse(ht.eq(self.a_tensor, self.erroneous_type)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
errorneous_type
should be correct. I get an error when running it with erroneous_type
with self.assertRaises(TypeError): | ||
ht.ne("self.a_tensor", "s") | ||
self.assertTrue(ht.ne(self.a_tensor, self.another_vector)) | ||
self.assertTrue(ht.ne(self.a_tensor, self.errorneous_type)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
self.assertTrue(ht.ne(self.a_tensor, self.errorneous_type)) | |
self.assertTrue(ht.ne(self.a_tensor, self.erroneous_type)) |
…llow_non-array_operands Claudia said so
Thank you for the PR! |
…llow_non-array_operands
Thank you for the PR! |
Due Diligence
Description
The eq and ne operands were changed to allow the comparison between not supported Types without causing an error.
Instead eq now returns False and ne returns True, when before there would have been a TypeError or ValueError.
It was also necessary to change the relations test as it checked whether these errors occurred.
Issue/s resolved: #1292
Changes proposed:
Type of change
Bug fix
Memory requirements
Performance
Does this change modify the behaviour of other functions? If so, which?
no?