11 changes: 10 additions & 1 deletion src/array_api_extra/__init__.py
@@ -1,6 +1,13 @@
"""Extra array functions built on top of the array API standard."""

from ._delegation import isclose, nan_to_num, one_hot, pad
from ._delegation import (
argpartition,
isclose,
nan_to_num,
one_hot,
pad,
partition,
)
from ._lib._at import at
from ._lib._funcs import (
apply_where,
@@ -23,6 +30,7 @@
__all__ = [
"__version__",
"apply_where",
"argpartition",
"at",
"atleast_nd",
"broadcast_shapes",
@@ -37,6 +45,7 @@
"nunique",
"one_hot",
"pad",
"partition",
"setdiff1d",
"sinc",
]
156 changes: 155 additions & 1 deletion src/array_api_extra/_delegation.py
@@ -15,7 +15,7 @@
is_torch_namespace,
)
from ._lib._utils._compat import device as get_device
from ._lib._utils._helpers import asarrays
from ._lib._utils._helpers import asarrays, eager_shape
from ._lib._utils._typing import Array, DType

__all__ = ["isclose", "nan_to_num", "one_hot", "pad"]
@@ -326,3 +326,157 @@ def pad(
return xp.nn.functional.pad(x, tuple(pad_width), value=constant_values) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]

return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp)


def partition(
a: Array,
kth: int,
/,
axis: int | None = -1,
*,
xp: ModuleType | None = None,
) -> Array:
"""
Return a partitioned copy of an array.

Parameters
----------
a : Array
    Input array. Must be at least 1-dimensional.
kth : int
Element index to partition by.
axis : int, optional
Axis along which to partition. The default is -1 (the last axis).
If None, the flattened array is used.
xp : array_namespace, optional
The standard-compatible namespace for `a`. Default: infer.

Returns
-------
partitioned_array
Array of the same type and shape as a.
"""
# Validate inputs.
if xp is None:
xp = array_namespace(a)
if a.ndim < 1:
msg = "`a` must be at least 1-dimensional"
raise TypeError(msg)
if axis is None:
return partition(xp.reshape(a, (-1,)), kth, axis=0, xp=xp)
(size,) = eager_shape(a, axis)
if not (0 <= kth < size):
msg = f"kth(={kth}) out of bounds [0, {size})"
raise ValueError(msg)

# Delegate where possible.
if is_numpy_namespace(xp) or is_cupy_namespace(xp) or is_jax_namespace(xp):
return xp.partition(a, kth, axis=axis)

# Use top-k when possible:
if is_torch_namespace(xp):
if not (axis == -1 or axis == a.ndim - 1):
a = xp.transpose(a, axis, -1)

# Get smallest `kth` elements along axis
kth += 1 # HACK: we use a non-specified behavior of torch.topk:
# in `a_left`, the element in the last position is the max
a_left, indices = xp.topk(a, kth, dim=-1, largest=False, sorted=False)
Member:

hmmm, I would rather not rely on undocumented behaviour. Is there an alternative?

Contributor Author:

Fair ^^

Three options:

  • add an assert a_left.max() == a_left[k]
  • we can just re-run the same logic with kth=1 and largest=True. The impact on performance is probably 10 to 100% slower depending on the input, but it doesn't add a lot of logic
  • we can do if a_left.max() != a_left[k]: swap_max_with_last_element(a_left, axis=-1) => requires implementing swap_max_with_last_element (and the equivalent for argsort); a rough sketch follows below

I vote for 1 because I'm lazy but I like perf :p
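
For reference, a rough sketch of what option 3's swap_max_with_last_element could look like on the torch backend (hypothetical helper, not part of this PR; assumes the partition axis is already the last one):

import torch

def swap_max_with_last_element(a_left: torch.Tensor) -> torch.Tensor:
    # Move the maximum along the last axis into the last position,
    # swapping it with whatever element was there before.
    vmax, imax = a_left.max(dim=-1, keepdim=True)
    last = a_left[..., -1:].clone()
    out = a_left.clone()
    out.scatter_(-1, imax, last)  # old last element goes where the max was
    out[..., -1:] = vmax          # the max goes to the last position
    return out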

Contributor Author:

Edit: wait, I need to rethink something about the numpy.partition specs...

Contributor Author (@cakedev0), Oct 2, 2025:

So! I rewrote this section entirely; it now relies on torch.kthvalue and is closely aligned with numpy's behavior.

On a side note: the description of the behavior of numpy's partition function is fairly blurry when the k-th element has duplicates... In practice, numpy does a three-way partitioning: <, ==, and >. I reproduced this behavior in my new torch implementation, but jax doesn't (I tried to test the three-way partitioning and jax fails it...).

I will maybe open an issue on numpy to ask for clarification.
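
For context, a kthvalue-based three-way partition along the lines described above could look roughly like this (illustrative sketch, not the actual updated code from this PR; assumes the partition axis is last and a recent PyTorch where torch.argsort supports stable=True):

import torch

def partition_three_way(a: torch.Tensor, kth: int) -> torch.Tensor:
    # The (kth+1)-th smallest value along the last axis is the pivot
    # (torch.kthvalue is 1-indexed).
    pivot = torch.kthvalue(a, kth + 1, dim=-1, keepdim=True).values
    lt = a < pivot
    eq = a == pivot
    # Group key: 0 for "< pivot", 1 for "== pivot", 2 for "> pivot".
    keys = (~lt).long() + (~(lt | eq)).long()
    # A stable argsort on the keys keeps each group contiguous (<, ==, >),
    # which reproduces the three-way behavior described above.
    order = torch.argsort(keys, dim=-1, stable=True)
    return torch.gather(a, -1, order)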

Member:

Good idea!

Contributor:

> On a side note: the description of the behavior of numpy's partition function is fairly blurry when the k-th element has duplicates... In practice, numpy does a three-way partitioning: <, ==, and >. I reproduced this behavior in my new torch implementation, but jax doesn't.

It might be worth contributing this consideration to the array API spec discussion:
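
To make the behavior in question concrete, a quick numpy check (illustrative sketch; numpy's documented contract only guarantees that elements before kth are <= the kth value and elements after are >=, while the contiguous <, ==, > grouping is the observed behavior described above):

import numpy as np

a = np.array([3, 1, 3, 2, 3, 0, 3])
kth = 3
p = np.partition(a, kth)
pivot = p[kth]
# Documented contract: everything before kth is <= pivot, everything after is >= pivot.
assert (p[:kth] <= pivot).all() and (p[kth + 1:] >= pivot).all()
# Observed three-way grouping: the "<" block (0, 1, 2 in some order),
# then the "==" block of pivot duplicates (the ">" block is empty here).
print(p)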


# Build a mask to remove the selected elements
mask_right = xp.ones(a.shape, dtype=bool)
mask_right.scatter_(dim=-1, index=indices, value=False)

# Remaining elements along axis
a_right = a[mask_right] # 1-d array

# Reshape. This is valid only because we work on the last axis
a_right = xp.reshape(a_right, shape=(*a.shape[:-1], -1))

# Concatenate the two parts along axis
partitioned_array = xp.cat((a_left, a_right), dim=-1)
if not (axis == -1 or axis == a.ndim - 1):
partitioned_array = xp.transpose(partitioned_array, axis, -1)
return partitioned_array

# Note: dask topk/argtopk sort the return values, so it's
# not much more efficient than sorting everything when
# kth is not small compared to x.size

return _funcs.partition(a, kth, axis=axis, xp=xp)


def argpartition(
a: Array,
kth: int,
/,
axis: int | None = -1,
*,
xp: ModuleType | None = None,
) -> Array:
"""
Perform an indirect partition along the given axis.

Parameters
----------
a : Array
Input array.
kth : int
Element index to partition by.
axis : int, optional
Axis along which to partition. The default is -1 (the last axis).
If None, the flattened array is used.
xp : array_namespace, optional
The standard-compatible namespace for `a`. Default: infer.

Returns
-------
index_array
Array of indices that partition `a` along the specified axis.
"""
# Validate inputs.
if xp is None:
xp = array_namespace(a)
if is_pydata_sparse_namespace(xp):
msg = "Not implemented for sparse backend: no argsort"
raise NotImplementedError(msg)
if a.ndim < 1:
msg = "`a` must be at least 1-dimensional"
raise TypeError(msg)
if axis is None:
return argpartition(xp.reshape(a, (-1,)), kth, axis=0, xp=xp)
(size,) = eager_shape(a, axis)
if not (0 <= kth < size):
msg = f"kth(={kth}) out of bounds [0, {size})"
raise ValueError(msg)

# Delegate where possible.
if is_numpy_namespace(xp) or is_cupy_namespace(xp) or is_jax_namespace(xp):
return xp.argpartition(a, kth, axis=axis)

# Use top-k when possible:
if is_torch_namespace(xp):
# see `partition` above for the commented details of these steps:
if not (axis == -1 or axis == a.ndim - 1):
a = xp.transpose(a, axis, -1)

kth += 1 # HACK
_, indices_left = xp.topk(a, kth, dim=-1, largest=False, sorted=False)

mask_right = xp.ones(a.shape, dtype=bool)
mask_right.scatter_(dim=-1, index=indices_left, value=False)

indices_right = xp.nonzero(mask_right)[-1]
indices_right = xp.reshape(indices_right, shape=(*a.shape[:-1], -1))

# Concatenate the two parts along axis
index_array = xp.cat((indices_left, indices_right), dim=-1)
if not (axis == -1 or axis == a.ndim - 1):
index_array = xp.transpose(index_array, axis, -1)
return index_array

# Note: dask topk/argtopk sort the return values, so it's
# not much more efficient than sorting everything when
# kth is not small compared to x.size

return _funcs.argpartition(a, kth, axis=axis, xp=xp)
24 changes: 24 additions & 0 deletions src/array_api_extra/_lib/_funcs.py
@@ -1029,3 +1029,27 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=_compat.device(x)),
)
return xp.sin(y) / y


def partition( # numpydoc ignore=PR01,RT01
x: Array,
kth: int, # noqa: ARG001
/,
axis: int = -1,
*,
xp: ModuleType,
) -> Array:
"""See docstring in `array_api_extra._delegation.py`."""
return xp.sort(x, axis=axis, stable=False)


def argpartition( # numpydoc ignore=PR01,RT01
x: Array,
kth: int, # noqa: ARG001
/,
axis: int = -1,
*,
xp: ModuleType,
) -> Array:
"""See docstring in `array_api_extra._delegation.py`."""
return xp.argsort(x, axis=axis, stable=False)
13 changes: 11 additions & 2 deletions src/array_api_extra/_lib/_utils/_helpers.py
@@ -250,22 +250,31 @@ def ndindex(*x: int) -> Generator[tuple[int, ...]]:
yield *i, j


def eager_shape(x: Array, /) -> tuple[int, ...]:
def eager_shape(x: Array, /, axis: int | None = None) -> tuple[int, ...]:
"""
Return shape of an array. Raise if shape is not fully defined.

Parameters
----------
x : Array
Input array.
axis : int, optional
If provided, return only the one-element tuple `(shape[axis],)`.

Returns
-------
tuple[int, ...]
Shape of the array.
"""
shape = x.shape
# Dask arrays use non-standard NaN instead of None
if axis is not None:
s = shape[axis]
# Dask arrays use non-standard NaN instead of None
if s is None or math.isnan(s):
msg = f"Unsupported lazy shape for axis {axis}"
raise TypeError(msg)
return (s,)

if any(s is None or math.isnan(s) for s in shape):
msg = "Unsupported lazy shape"
raise TypeError(msg)
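
A short usage sketch of the new axis parameter (illustrative only; assumes eager_shape is in scope and the array has a fully defined, numpy-backed shape):

import numpy as np

x = np.ones((2, 5, 3))
assert eager_shape(x) == (2, 5, 3)
assert eager_shape(x, axis=1) == (5,)  # only the requested axis, as a 1-tuple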