ENH: add partition and argpartition functions #449
Changes from 7 commits
@@ -15,7 +15,7 @@
     is_torch_namespace,
 )
 from ._lib._utils._compat import device as get_device
-from ._lib._utils._helpers import asarrays
+from ._lib._utils._helpers import asarrays, eager_shape
 from ._lib._utils._typing import Array, DType

 __all__ = ["isclose", "nan_to_num", "one_hot", "pad"]

@@ -326,3 +326,157 @@ def pad(
        return xp.nn.functional.pad(x, tuple(pad_width), value=constant_values)  # type: ignore[arg-type]  # pyright: ignore[reportArgumentType]

    return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp)

def partition(
    a: Array,
    kth: int,
    /,
    axis: int | None = -1,
    *,
    xp: ModuleType | None = None,
) -> Array:
    """
    Return a partitioned copy of an array.

    Parameters
    ----------
    a : Array
        Input array.
    kth : int
        Element index to partition by.
    axis : int, optional
        Axis along which to partition. The default is -1 (the last axis).
        If None, the flattened array is used.
    xp : array_namespace, optional
        The standard-compatible namespace for `a`. Default: infer.

    Returns
    -------
    partitioned_array
        Array of the same type and shape as `a`.
    """
    # Validate inputs.
    if xp is None:
        xp = array_namespace(a)
    if a.ndim < 1:
        msg = "`a` must be at least 1-dimensional"
        raise TypeError(msg)
    if axis is None:
        return partition(xp.reshape(a, (-1,)), kth, axis=0, xp=xp)
    (size,) = eager_shape(a, axis)
    if not (0 <= kth < size):
        msg = f"kth(={kth}) out of bounds [0, {size})"
        raise ValueError(msg)

    # Delegate where possible.
    if is_numpy_namespace(xp) or is_cupy_namespace(xp) or is_jax_namespace(xp):
        return xp.partition(a, kth, axis=axis)

    # Use top-k when possible:
    if is_torch_namespace(xp):
        if not (axis == -1 or axis == a.ndim - 1):
            a = xp.transpose(a, axis, -1)

        # Get smallest `kth` elements along axis
        kth += 1  # HACK: we use a non-specified behavior of torch.topk:
        # in `a_left`, the element in the last position is the max
        a_left, indices = xp.topk(a, kth, dim=-1, largest=False, sorted=False)

        # Build a mask to remove the selected elements
        mask_right = xp.ones(a.shape, dtype=bool)
        mask_right.scatter_(dim=-1, index=indices, value=False)

        # Remaining elements along axis
        a_right = a[mask_right]  # 1-d array

        # Reshape. This is valid only because we work on the last axis
        a_right = xp.reshape(a_right, shape=(*a.shape[:-1], -1))

        # Concatenate the two parts along axis
        partitioned_array = xp.cat((a_left, a_right), dim=-1)
        if not (axis == -1 or axis == a.ndim - 1):
            partitioned_array = xp.transpose(partitioned_array, axis, -1)
        return partitioned_array

    # Note: dask topk/argtopk sort the return values, so it's
    # not much more efficient than sorting everything when
    # kth is not small compared to x.size

    return _funcs.partition(a, kth, axis=axis, xp=xp)
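
A minimal usage sketch of the guarantee documented above, assuming NumPy as the backing namespace and that the new function is exposed at the top level as `xpx.partition` (the alias `xpx` and the top-level re-export are assumptions, not part of this diff):

import array_api_extra as xpx
import numpy as np

# Sketch only: assumes `partition` is re-exported as `xpx.partition`.
a = np.asarray([7, 1, 5, 3, 9, 2])
kth = 2
p = xpx.partition(a, kth)

# The element at index `kth` ends up in its sorted position; everything
# before it is <= p[kth] and everything after it is >= p[kth].
assert p[kth] == np.sort(a)[kth]
assert (p[:kth] <= p[kth]).all() and (p[kth + 1:] >= p[kth]).all()
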
def argpartition(
    a: Array,
    kth: int,
    /,
    axis: int | None = -1,
    *,
    xp: ModuleType | None = None,
) -> Array:
    """
    Perform an indirect partition along the given axis.

    Parameters
    ----------
    a : Array
        Input array.
    kth : int
        Element index to partition by.
    axis : int, optional
        Axis along which to partition. The default is -1 (the last axis).
        If None, the flattened array is used.
    xp : array_namespace, optional
        The standard-compatible namespace for `a`. Default: infer.

    Returns
    -------
    index_array
        Array of indices that partition `a` along the specified axis.
    """
    # Validate inputs.
    if xp is None:
        xp = array_namespace(a)
    if is_pydata_sparse_namespace(xp):
        msg = "Not implemented for sparse backend: no argsort"
        raise NotImplementedError(msg)
    if a.ndim < 1:
        msg = "`a` must be at least 1-dimensional"
        raise TypeError(msg)
    if axis is None:
        return argpartition(xp.reshape(a, (-1,)), kth, axis=0, xp=xp)
    (size,) = eager_shape(a, axis)
    if not (0 <= kth < size):
        msg = f"kth(={kth}) out of bounds [0, {size})"
        raise ValueError(msg)

    # Delegate where possible.
    if is_numpy_namespace(xp) or is_cupy_namespace(xp) or is_jax_namespace(xp):
        return xp.argpartition(a, kth, axis=axis)

    # Use top-k when possible:
    if is_torch_namespace(xp):
        # see `partition` above for commented details of these steps:
        if not (axis == -1 or axis == a.ndim - 1):
            a = xp.transpose(a, axis, -1)

        kth += 1  # HACK
        _, indices_left = xp.topk(a, kth, dim=-1, largest=False, sorted=False)

        mask_right = xp.ones(a.shape, dtype=bool)
        mask_right.scatter_(dim=-1, index=indices_left, value=False)

        indices_right = xp.nonzero(mask_right)[-1]
        indices_right = xp.reshape(indices_right, shape=(*a.shape[:-1], -1))

        # Concatenate the two parts along axis
        index_array = xp.cat((indices_left, indices_right), dim=-1)
        if not (axis == -1 or axis == a.ndim - 1):
            index_array = xp.transpose(index_array, axis, -1)
        return index_array

    # Note: dask topk/argtopk sort the return values, so it's
    # not much more efficient than sorting everything when
    # kth is not small compared to x.size

    return _funcs.argpartition(a, kth, axis=axis, xp=xp)
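
Similarly, a sketch of the indirect form under the same assumptions (NumPy backend, hypothetical top-level alias `xpx`): `argpartition` returns indices such that gathering `a` along the axis with them yields a partitioned array.

import array_api_extra as xpx
import numpy as np

# Sketch only: assumes `argpartition` is re-exported as `xpx.argpartition`.
a = np.asarray([[7, 1, 5, 3], [9, 2, 8, 4]])
kth = 1
idx = xpx.argpartition(a, kth, axis=-1)

# Applying the indices along the last axis reproduces `partition`:
# column `kth` holds each row's (kth+1)-th smallest value.
gathered = np.take_along_axis(a, idx, axis=-1)
assert (gathered[:, kth] == np.sort(a, axis=-1)[:, kth]).all()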