From d17fd2f11dc552aaf6ddec1f2495d889de2a1092 Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Thu, 26 Dec 2024 15:24:25 +0000 Subject: [PATCH 1/3] ENH: `pad`: add delegation --- docs/api-reference.md | 1 + src/array_api_extra/__init__.py | 2 +- src/array_api_extra/_delegators.py | 59 +++++++++++++++++++ src/array_api_extra/_lib/__init__.py | 2 +- src/array_api_extra/_lib/_compat.py | 19 ------ src/array_api_extra/{ => _lib}/_funcs.py | 16 ++--- src/array_api_extra/_lib/_utils/__init__.py | 1 + src/array_api_extra/_lib/_utils/_compat.py | 31 ++++++++++ .../_lib/{ => _utils}/_compat.pyi | 4 ++ .../_lib/{_utils.py => _utils/_helpers.py} | 2 +- .../_lib/{ => _utils}/_typing.py | 0 tests/test_funcs.py | 2 +- tests/test_utils.py | 4 +- 13 files changed, 110 insertions(+), 33 deletions(-) create mode 100644 src/array_api_extra/_delegators.py delete mode 100644 src/array_api_extra/_lib/_compat.py rename src/array_api_extra/{ => _lib}/_funcs.py (98%) create mode 100644 src/array_api_extra/_lib/_utils/__init__.py create mode 100644 src/array_api_extra/_lib/_utils/_compat.py rename src/array_api_extra/_lib/{ => _utils}/_compat.pyi (70%) rename src/array_api_extra/_lib/{_utils.py => _utils/_helpers.py} (97%) rename src/array_api_extra/_lib/{ => _utils}/_typing.py (100%) diff --git a/docs/api-reference.md b/docs/api-reference.md index ffe68f2..457adef 100644 --- a/docs/api-reference.md +++ b/docs/api-reference.md @@ -11,6 +11,7 @@ create_diagonal expand_dims kron + pad setdiff1d sinc ``` diff --git a/src/array_api_extra/__init__.py b/src/array_api_extra/__init__.py index 83808e0..eae343e 100644 --- a/src/array_api_extra/__init__.py +++ b/src/array_api_extra/__init__.py @@ -1,6 +1,6 @@ """Extra array functions built on top of the array API standard.""" -from ._funcs import ( +from ._lib._funcs import ( atleast_nd, cov, create_diagonal, diff --git a/src/array_api_extra/_delegators.py b/src/array_api_extra/_delegators.py new file mode 100644 index 0000000..118a111 --- /dev/null +++ b/src/array_api_extra/_delegators.py @@ -0,0 +1,59 @@ +"""Delegators to existing implementations for Public API Functions.""" + +from ._lib import _funcs +from ._lib._utils._compat import ( + array_namespace, + is_cupy_namespace, + is_jax_namespace, + is_numpy_namespace, + is_torch_namespace, +) +from ._lib._utils._typing import Array, ModuleType + + +def pad( + x: Array, + pad_width: int, + mode: str = "constant", + *, + constant_values: bool | int | float | complex = 0, + xp: ModuleType | None = None, +) -> Array: + """ + Pad the input array. + + Parameters + ---------- + x : array + Input array. + pad_width : int + Pad the input array with this many elements from each side. + mode : str, optional + Only "constant" mode is currently supported, which pads with + the value passed to `constant_values`. + constant_values : python scalar, optional + Use this value to pad the input. Default is zero. + xp : array_namespace, optional + The standard-compatible namespace for `x`. Default: infer. + + Returns + ------- + array + The input array, + padded with ``pad_width`` elements equal to ``constant_values``. + """ + xp = array_namespace(x) if xp is None else xp + + value = constant_values + + # https://github.com/pytorch/pytorch/blob/cf76c05b4dc629ac989d1fb8e789d4fac04a095a/torch/_numpy/_funcs_impl.py#L2045-L2056 + if is_torch_namespace(xp): + pad_width = xp.asarray(pad_width) + pad_width = xp.broadcast_to(pad_width, (x.ndim, 2)) + pad_width = xp.flip(pad_width, axis=(0,)).flatten() + return xp.nn.functional.pad(x, (pad_width,), value=value) + + if is_numpy_namespace(x) or is_jax_namespace(xp) or is_cupy_namespace(xp): + return xp.pad(x, pad_width, mode, constant_values=value) + + return _funcs.pad(x, pad_width, mode, constant_values=value, xp=xp) diff --git a/src/array_api_extra/_lib/__init__.py b/src/array_api_extra/_lib/__init__.py index d7a7952..b5e805b 100644 --- a/src/array_api_extra/_lib/__init__.py +++ b/src/array_api_extra/_lib/__init__.py @@ -1 +1 @@ -"""Modules housing private functions.""" +"""Array-agnostic implementations for the public API.""" diff --git a/src/array_api_extra/_lib/_compat.py b/src/array_api_extra/_lib/_compat.py deleted file mode 100644 index de7a220..0000000 --- a/src/array_api_extra/_lib/_compat.py +++ /dev/null @@ -1,19 +0,0 @@ -"""Acquire helpers from array-api-compat.""" -# Allow packages that vendor both `array-api-extra` and -# `array-api-compat` to override the import location - -try: - from ..._array_api_compat_vendor import ( # pyright: ignore[reportMissingImports] - array_namespace, # pyright: ignore[reportUnknownVariableType] - device, # pyright: ignore[reportUnknownVariableType] - ) -except ImportError: - from array_api_compat import ( # pyright: ignore[reportMissingTypeStubs] - array_namespace, # pyright: ignore[reportUnknownVariableType] - device, - ) - -__all__ = [ - "array_namespace", - "device", -] diff --git a/src/array_api_extra/_funcs.py b/src/array_api_extra/_lib/_funcs.py similarity index 98% rename from src/array_api_extra/_funcs.py rename to src/array_api_extra/_lib/_funcs.py index 369319e..13f43b2 100644 --- a/src/array_api_extra/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -2,9 +2,9 @@ import warnings -from ._lib import _compat, _utils -from ._lib._compat import array_namespace -from ._lib._typing import Array, ModuleType +from ._utils import _compat, _helpers +from ._utils._compat import array_namespace +from ._utils._typing import Array, ModuleType __all__ = [ "atleast_nd", @@ -136,7 +136,7 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array: m = atleast_nd(m, ndim=2, xp=xp) m = xp.astype(m, dtype) - avg = _utils.mean(m, axis=1, xp=xp) + avg = _helpers.mean(m, axis=1, xp=xp) fact = m.shape[1] - 1 if fact <= 0: @@ -449,7 +449,7 @@ def setdiff1d( else: x1 = xp.unique_values(x1) x2 = xp.unique_values(x2) - return x1[_utils.in1d(x1, x2, assume_unique=True, invert=True, xp=xp)] + return x1[_helpers.in1d(x1, x2, assume_unique=True, invert=True, xp=xp)] def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array: @@ -546,8 +546,8 @@ def pad( pad_width: int, mode: str = "constant", *, - xp: ModuleType | None = None, constant_values: bool | int | float | complex = 0, + xp: ModuleType | None = None, ) -> Array: """ Pad the input array. @@ -561,10 +561,10 @@ def pad( mode : str, optional Only "constant" mode is currently supported, which pads with the value passed to `constant_values`. - xp : array_namespace, optional - The standard-compatible namespace for `x`. Default: infer. constant_values : python scalar, optional Use this value to pad the input. Default is zero. + xp : array_namespace, optional + The standard-compatible namespace for `x`. Default: infer. Returns ------- diff --git a/src/array_api_extra/_lib/_utils/__init__.py b/src/array_api_extra/_lib/_utils/__init__.py new file mode 100644 index 0000000..3628c45 --- /dev/null +++ b/src/array_api_extra/_lib/_utils/__init__.py @@ -0,0 +1 @@ +"""Modules housing private utility functions.""" diff --git a/src/array_api_extra/_lib/_utils/_compat.py b/src/array_api_extra/_lib/_utils/_compat.py new file mode 100644 index 0000000..9d50090 --- /dev/null +++ b/src/array_api_extra/_lib/_utils/_compat.py @@ -0,0 +1,31 @@ +"""Acquire helpers from array-api-compat.""" +# Allow packages that vendor both `array-api-extra` and +# `array-api-compat` to override the import location + +try: + from ..._array_api_compat_vendor import ( # pyright: ignore[reportMissingImports] + array_namespace, # pyright: ignore[reportUnknownVariableType] + device, # pyright: ignore[reportUnknownVariableType] + is_cupy_namespace, # pyright: ignore[reportUnknownVariableType] + is_jax_namespace, # pyright: ignore[reportUnknownVariableType] + is_numpy_namespace, # pyright: ignore[reportUnknownVariableType] + is_torch_namespace, # pyright: ignore[reportUnknownVariableType] + ) +except ImportError: + from array_api_compat import ( # pyright: ignore[reportMissingTypeStubs] + array_namespace, # pyright: ignore[reportUnknownVariableType] + device, + is_cupy_namespace, # pyright: ignore[reportUnknownVariableType] + is_jax_namespace, # pyright: ignore[reportUnknownVariableType] + is_numpy_namespace, # pyright: ignore[reportUnknownVariableType] + is_torch_namespace, # pyright: ignore[reportUnknownVariableType] + ) + +__all__ = [ + "array_namespace", + "device", + "is_cupy_namespace", + "is_jax_namespace", + "is_numpy_namespace", + "is_torch_namespace", +] diff --git a/src/array_api_extra/_lib/_compat.pyi b/src/array_api_extra/_lib/_utils/_compat.pyi similarity index 70% rename from src/array_api_extra/_lib/_compat.pyi rename to src/array_api_extra/_lib/_utils/_compat.pyi index 2105708..62a1ea4 100644 --- a/src/array_api_extra/_lib/_compat.pyi +++ b/src/array_api_extra/_lib/_utils/_compat.pyi @@ -15,3 +15,7 @@ def array_namespace( use_compat: bool | None = None, ) -> ArrayModule: ... # numpydoc ignore=GL08 def device(x: Array, /) -> Device: ... # numpydoc ignore=GL08 +def is_cupy_namespace(xp: ModuleType, /) -> bool: ... +def is_jax_namespace(xp: ModuleType, /) -> bool: ... +def is_numpy_namespace(xp: ModuleType, /) -> bool: ... +def is_torch_namespace(xp: ModuleType, /) -> bool: ... diff --git a/src/array_api_extra/_lib/_utils.py b/src/array_api_extra/_lib/_utils/_helpers.py similarity index 97% rename from src/array_api_extra/_lib/_utils.py rename to src/array_api_extra/_lib/_utils/_helpers.py index 523c21b..4e79eba 100644 --- a/src/array_api_extra/_lib/_utils.py +++ b/src/array_api_extra/_lib/_utils/_helpers.py @@ -1,4 +1,4 @@ -"""Utility functions used by `array_api_extra/_funcs.py`.""" +"""Helper functions used by `array_api_extra/_funcs.py`.""" from . import _compat from ._typing import Array, ModuleType diff --git a/src/array_api_extra/_lib/_typing.py b/src/array_api_extra/_lib/_utils/_typing.py similarity index 100% rename from src/array_api_extra/_lib/_typing.py rename to src/array_api_extra/_lib/_utils/_typing.py diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 938a4f3..f622339 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -17,7 +17,7 @@ setdiff1d, sinc, ) -from array_api_extra._lib._typing import Array +from array_api_extra._lib._utils._typing import Array class TestAtLeastND: diff --git a/tests/test_utils.py b/tests/test_utils.py index 1807627..916c4c7 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,8 +3,8 @@ import pytest from numpy.testing import assert_array_equal -from array_api_extra._lib._typing import Array -from array_api_extra._lib._utils import in1d +from array_api_extra._lib._utils._helpers import in1d +from array_api_extra._lib._utils._typing import Array # some test coverage already provided by TestSetDiff1D From 38690bbce581c11275164f6d79e22e4372ea323b Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Thu, 26 Dec 2024 15:33:33 +0000 Subject: [PATCH 2/3] fix vendor test --- vendor_tests/test_vendor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vendor_tests/test_vendor.py b/vendor_tests/test_vendor.py index a248abc..4343d76 100644 --- a/vendor_tests/test_vendor.py +++ b/vendor_tests/test_vendor.py @@ -19,6 +19,6 @@ def test_vendor_extra(): def test_vendor_extra_uses_vendor_compat(): from ._array_api_compat_vendor import array_namespace as n1 - from .array_api_extra._lib._compat import array_namespace as n2 + from .array_api_extra._lib._utils._compat import array_namespace as n2 assert n1 is n2 From d4d05b0abf1c629bb81fbef832f7b32cdb34ab68 Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Thu, 26 Dec 2024 15:46:53 +0000 Subject: [PATCH 3/3] fixes --- src/array_api_extra/__init__.py | 2 +- src/array_api_extra/_delegators.py | 12 ++++--- src/array_api_extra/_lib/_funcs.py | 40 +++------------------- src/array_api_extra/_lib/_utils/_compat.py | 2 +- 4 files changed, 13 insertions(+), 43 deletions(-) diff --git a/src/array_api_extra/__init__.py b/src/array_api_extra/__init__.py index eae343e..9006ca5 100644 --- a/src/array_api_extra/__init__.py +++ b/src/array_api_extra/__init__.py @@ -1,12 +1,12 @@ """Extra array functions built on top of the array API standard.""" +from ._delegators import pad from ._lib._funcs import ( atleast_nd, cov, create_diagonal, expand_dims, kron, - pad, setdiff1d, sinc, ) diff --git a/src/array_api_extra/_delegators.py b/src/array_api_extra/_delegators.py index 118a111..6ca9f88 100644 --- a/src/array_api_extra/_delegators.py +++ b/src/array_api_extra/_delegators.py @@ -44,16 +44,18 @@ def pad( """ xp = array_namespace(x) if xp is None else xp - value = constant_values + if mode != "constant": + msg = "Only `'constant'` mode is currently supported" + raise NotImplementedError(msg) # https://github.com/pytorch/pytorch/blob/cf76c05b4dc629ac989d1fb8e789d4fac04a095a/torch/_numpy/_funcs_impl.py#L2045-L2056 if is_torch_namespace(xp): pad_width = xp.asarray(pad_width) pad_width = xp.broadcast_to(pad_width, (x.ndim, 2)) pad_width = xp.flip(pad_width, axis=(0,)).flatten() - return xp.nn.functional.pad(x, (pad_width,), value=value) + return xp.nn.functional.pad(x, (pad_width,), value=constant_values) - if is_numpy_namespace(x) or is_jax_namespace(xp) or is_cupy_namespace(xp): - return xp.pad(x, pad_width, mode, constant_values=value) + if is_numpy_namespace(xp) or is_jax_namespace(xp) or is_cupy_namespace(xp): + return xp.pad(x, pad_width, mode, constant_values=constant_values) - return _funcs.pad(x, pad_width, mode, constant_values=value, xp=xp) + return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp) diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index 13f43b2..8f96cec 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -544,46 +544,14 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array: def pad( x: Array, pad_width: int, - mode: str = "constant", *, constant_values: bool | int | float | complex = 0, - xp: ModuleType | None = None, -) -> Array: - """ - Pad the input array. - - Parameters - ---------- - x : array - Input array. - pad_width : int - Pad the input array with this many elements from each side. - mode : str, optional - Only "constant" mode is currently supported, which pads with - the value passed to `constant_values`. - constant_values : python scalar, optional - Use this value to pad the input. Default is zero. - xp : array_namespace, optional - The standard-compatible namespace for `x`. Default: infer. - - Returns - ------- - array - The input array, - padded with ``pad_width`` elements equal to ``constant_values``. - """ - if mode != "constant": - msg = "Only `'constant'` mode is currently supported" - raise NotImplementedError(msg) - - value = constant_values - - if xp is None: - xp = array_namespace(x) - + xp: ModuleType, +) -> Array: # numpydoc ignore=PR01,RT01 + """See docstring in `_delegators.py`.""" padded = xp.full( tuple(x + 2 * pad_width for x in x.shape), - fill_value=value, + fill_value=constant_values, dtype=x.dtype, device=_compat.device(x), ) diff --git a/src/array_api_extra/_lib/_utils/_compat.py b/src/array_api_extra/_lib/_utils/_compat.py index 9d50090..943a285 100644 --- a/src/array_api_extra/_lib/_utils/_compat.py +++ b/src/array_api_extra/_lib/_utils/_compat.py @@ -3,7 +3,7 @@ # `array-api-compat` to override the import location try: - from ..._array_api_compat_vendor import ( # pyright: ignore[reportMissingImports] + from ...._array_api_compat_vendor import ( # pyright: ignore[reportMissingImports] array_namespace, # pyright: ignore[reportUnknownVariableType] device, # pyright: ignore[reportUnknownVariableType] is_cupy_namespace, # pyright: ignore[reportUnknownVariableType]