cakedev0
Contributor

@cakedev0 cakedev0 commented Oct 1, 2025

Closes #448

Supports nd-arrays.

The crux of this PR is the torch part: transforming kthvalue outputs to a partition/argpartition. The rest is mostly wrappers and checks.

@lucascolley lucascolley changed the title from "Add partition and argpartition functions" to "ENH: add partition and argpartition functions" Oct 1, 2025
@lucascolley lucascolley added the enhancement and new function labels Oct 1, 2025
@lucascolley lucascolley added this to the 0.9.1 milestone Oct 1, 2025
@lucascolley
Member

I don't know how to handle the sparse backend (is_pydata_sparse_namespace): it has no argsort...

see

@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no argsort")

Member

@lucascolley lucascolley left a comment

thanks @cakedev0 !

Member

@lucascolley lucascolley left a comment

thanks @cakedev0 !

Looks like there is also a merge conflict now.

Comment on lines 382 to 384
kth += 1 # HACK: we use a non-specified behavior of torch.topk:
# in `a_left`, the element in the last position is the max
a_left, indices = xp.topk(a, kth, dim=-1, largest=False, sorted=False)
Member

hmmm, I would rather not rely on undocumented behaviour. Is there an alternative?

Contributor Author

Fair ^^

Three options:

  • Add an assert a_left.max() == a_left[k].
  • Re-run the same logic with kth=1 and largest=True. This is probably 10 to 100% slower depending on the input, but it doesn't add a lot of logic.
  • Do if a_left.max() != a_left[k]: swap_max_with_last_element(a_left, axis=-1). This requires implementing swap_max_with_last_element (and the equivalent for argsort).

I vote for 1 because I'm lazy but I like perf :p
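
For reference, the property that options 1 and 3 guard can be written as a standalone check. This is only a sketch: it uses NumPy rather than torch, it handles 1-D input only, and the helper name check_partition is hypothetical, not something from this PR.

```python
import numpy as np


def check_partition(original, result, kth):
    """Return True if `result` is a valid partition of 1-D `original` at `kth`.

    Hypothetical helper for illustration; not part of the PR.
    """
    original = np.asarray(original)
    result = np.asarray(result)
    pivot = result[kth]
    # The partition may only reorder elements, so the sorted values must match.
    same_values = np.array_equal(np.sort(original), np.sort(result))
    # Everything before position kth is <= pivot, everything after is >= pivot.
    left_ok = bool(np.all(result[:kth] <= pivot))
    right_ok = bool(np.all(result[kth + 1:] >= pivot))
    return same_values and left_ok and right_ok


a = np.array([7, 1, 5, 3, 9, 3])
print(check_partition(a, np.partition(a, 2), kth=2))  # True
print(check_partition(a, a, kth=2))                   # False: 7 sits left of the pivot 5
```

A check like this costs a sort, so it is better suited to a test suite than to the hot path, which is roughly the trade-off the three options are weighing.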

Contributor Author

Edit: wait I need to rethink something about numpy.partition specs...

Contributor Author

@cakedev0 cakedev0 Oct 2, 2025

So! I entirely rewrote this section. It now relies on torch.kthvalue and is closely aligned with numpy's behavior.

On a side note: numpy's description of the partition function is fairly vague about what happens when the k-th element has duplicates... In practice, numpy does a three-way partitioning: <, == and >. I reproduced this behavior in my new torch implementation, but jax doesn't (I tried to test the three-way partitioning and jax fails it...).
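
To make the three-way layout concrete, here is a minimal 1-D sketch in NumPy. It is not the torch code from this PR: the helper name three_way_argpartition is hypothetical, np.partition stands in for whatever supplies the k-th value (torch.kthvalue in the PR), and the input is assumed to contain no NaNs.

```python
import numpy as np


def three_way_argpartition(a, kth):
    """Indices ordering 1-D `a` as [< pivot | == pivot | > pivot].

    Hypothetical helper for illustration; assumes no NaNs in `a`.
    """
    a = np.asarray(a)
    # The pivot is the value a full sort would place at position kth.
    pivot = np.partition(a, kth)[kth]
    # Concatenate the indices of the three groups in order.
    return np.concatenate(
        [
            np.flatnonzero(a < pivot),
            np.flatnonzero(a == pivot),
            np.flatnonzero(a > pivot),
        ]
    )


a = np.array([3, 1, 2, 2, 5, 2, 0])
idx = three_way_argpartition(a, kth=3)
print(a[idx].tolist())  # [1, 0, 2, 2, 2, 3, 5]
```

Since at most kth elements are strictly below the pivot and at least kth + 1 are less than or equal to it, position kth always lands inside the == block.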

I will maybe open an issue on numpy to ask for some clarification.

Member

Good idea!

Contributor

On a side note: numpy's description of the partition function is fairly vague about what happens when the k-th element has duplicates... In practice, numpy does a three-way partitioning: <, == and >. I reproduced this behavior in my new torch implementation, but jax doesn't (I tried to test the three-way partitioning and jax fails it...).

It might be worth contributing this consideration to the array API spec discussion:

@ogrisel ogrisel mentioned this pull request Oct 3, 2025
Member

@lucascolley lucascolley left a comment

thanks @cakedev0, looks close!

@cakedev0
Contributor Author

cakedev0 commented Oct 3, 2025

Thanks for the responsive and helpful reviews, @lucascolley.

Sorry for the numpy docs style details, I'll make sure to read the doc about this carefully before opening another PR 😉

Member

@lucascolley lucascolley left a comment

thanks @cakedev0, let's merge it! And thanks for taking a look too, @ogrisel.

It would be great to follow up on #449 (comment).

Are you interested in taking over gh-341 next?

@lucascolley lucascolley merged commit 6ba1e87 into data-apis:main Oct 3, 2025
11 checks passed
Development

Successfully merging this pull request may close these issues.

ENH: Adding partition and argpartition?
3 participants