ENH: add partition and argpartition functions #449
partition and argpartition functions
see array-api-extra/tests/test_funcs.py Line 1183 in ca20f03
thanks @cakedev0 !
Looks like there is also a merge conflict now.
src/array_api_extra/_delegation.py (Outdated)
kth += 1  # HACK: we use a non-specified behavior of torch.topk:
# in `a_left`, the element in the last position is the max
a_left, indices = xp.topk(a, kth, dim=-1, largest=False, sorted=False)
hmmm, I would rather not rely on undocumented behaviour. Is there an alternative?
Fair ^^
Three options:
- Add an `assert a_left.max() == a_left[k]`.
- We can just re-run the same logic with `kth=1` and `largest=True`. It is probably 10 to 100% slower depending on the input, but it doesn't add a lot of logic.
- We can do `if a_left.max() != a_left[k]: swap_max_with_last_element(a_left, axis=-1)`. This requires implementing `swap_max_with_last_element` (and the equivalent for argsort).
I vote for 1 because I'm lazy but I like perf :p
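For illustration, option 3 could look roughly like the sketch below. It uses plain Python lists rather than torch tensors, and `swap_max_with_last_element` is only the hypothetical helper named above (here without the `axis` argument), not code from this PR.

```python
def swap_max_with_last_element(values):
    """Return a copy of `values` with its maximum moved to the last position.

    Hypothetical helper, 1-D only; the order of the other elements is
    otherwise left untouched.
    """
    out = list(values)
    i_max = max(range(len(out)), key=out.__getitem__)
    out[i_max], out[-1] = out[-1], out[i_max]
    return out


# The kth + 1 smallest elements in arbitrary order, as an unsorted
# top-k call may return them.
a_left = [2, 5, 1, 3]
fixed = swap_max_with_last_element(a_left)
```

After the swap, the last element of `fixed` is the k-th smallest value of the full input, which is the property the original hack assumed `torch.topk` would provide by itself.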
Edit: wait I need to rethink something about numpy.partition specs...
So! I entirely rewrote this section: it now relies on torch.kthvalue and is closely aligned with numpy's behavior.
On a side note: the description of the partition function's behavior in numpy is fairly vague when the k-th element has duplicates. In practice, numpy does a three-way partitioning: <, == and >. I reproduced this behavior in my new torch implementation, but jax doesn't (I tried to test the three-way partitioning and jax fails it).
I will maybe open an issue on numpy to ask for some clarification.
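To make the distinction concrete, here is a small pure-Python sketch of two possible contracts for a partitioned 1-D output. The function names are made up for this illustration, and the "weak" contract is only an assumption about what a looser implementation might guarantee, not a statement of what any library documents.

```python
def is_weak_partition(out, kth):
    """Everything before index kth is <= out[kth], everything after is >=."""
    pivot = out[kth]
    return all(v <= pivot for v in out[:kth]) and all(
        v >= pivot for v in out[kth + 1:]
    )


def is_three_way_partition(out, kth):
    """Strictly smaller values first, then all values equal to out[kth],
    then strictly larger values."""
    pivot = out[kth]
    n_lt = sum(v < pivot for v in out)
    n_eq = sum(v == pivot for v in out)
    return (
        all(v < pivot for v in out[:n_lt])
        and all(v == pivot for v in out[n_lt:n_lt + n_eq])
        and all(v > pivot for v in out[n_lt + n_eq:])
    )


strict = [1, 2, 2, 2, 3, 5, 4]  # <, ==, > blocks are contiguous
loose = [2, 1, 2, 2, 3, 5, 4]   # a duplicate of the pivot precedes a smaller value
```

With `kth=3`, both lists satisfy the weak contract, but only `strict` passes the three-way check. A test written against the three-way check would therefore reject outputs like `loose`.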
Good idea!
> On a side note: the description of the partition function's behavior in numpy is fairly vague when the k-th element has duplicates. In practice, numpy does a three-way partitioning: <, == and >. I reproduced this behavior in my new torch implementation, but jax doesn't (I tried to test the three-way partitioning and jax fails it).
It might be worth contributing this consideration to the array API spec discussion:
thanks @cakedev0, looks close!
Thanks for the reactive and helpful reviews @lucascolley. Sorry for the numpy docs style details, I'll make sure to read the doc about this carefully before opening another PR 😉
thanks @cakedev0, let's merge it! And thanks for taking a look too @ogrisel .
Would be great to follow-up on #449 (comment).
Are you interested in taking over gh-341 next?
Closes #448
Supports nd-arrays.
The crux of this PR is the torch part: transforming kthvalue outputs to a partition/argpartition. The rest is mostly wrappers and checks.
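As a rough illustration of that transformation (not the PR's actual code), the sketch below works on 1-D Python lists. The local `kthvalue` function is a pure-Python stand-in that assumes the same convention as `torch.kthvalue` (1-based `k`, returns a value and an index); the three-way grouping is one possible way to turn that value into a partition.

```python
def kthvalue(values, k):
    """Stand-in for a kthvalue routine on a 1-D input.

    `k` is 1-based; returns (k-th smallest value, an index where it occurs).
    """
    order = sorted(range(len(values)), key=values.__getitem__)
    idx = order[k - 1]
    return values[idx], idx


def argpartition(values, kth):
    """Indices that arrange `values` as: < pivot, == pivot, > pivot."""
    pivot, _ = kthvalue(values, kth + 1)  # numpy-style kth is 0-based
    lt = [i for i, v in enumerate(values) if v < pivot]
    eq = [i for i, v in enumerate(values) if v == pivot]
    gt = [i for i, v in enumerate(values) if v > pivot]
    return lt + eq + gt


def partition(values, kth):
    """Apply the argpartition indices to get the partitioned values."""
    return [values[i] for i in argpartition(values, kth)]


values = [7, 2, 9, 2, 4]
idx = argpartition(values, 2)
out = partition(values, 2)
```

Here `out[2]` is the value a full sort would place at index 2, with smaller values before it and larger ones after. A real nd-array version would need to do the same grouping along an axis with vectorized operations.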