diff --git a/array_api_compat/_internal.py b/array_api_compat/_internal.py
index 170a1ff9..7a973567 100644
--- a/array_api_compat/_internal.py
+++ b/array_api_compat/_internal.py
@@ -4,8 +4,19 @@
 from functools import wraps
 from inspect import signature
+from typing import TYPE_CHECKING
 
-def get_xp(xp):
+__all__ = ["get_xp"]
+
+if TYPE_CHECKING:
+    from collections.abc import Callable
+    from types import ModuleType
+    from typing import TypeVar
+
+    _T = TypeVar("_T")
+
+
+def get_xp(xp: "ModuleType") -> "Callable[[Callable[..., _T]], Callable[..., _T]]":
     """
     Decorator to automatically replace xp with the corresponding array module.
@@ -22,14 +33,14 @@ def func(x, /, xp, kwarg=None):
     """
 
-    def inner(f):
+    def inner(f: "Callable[..., _T]", /) -> "Callable[..., _T]":
         @wraps(f)
-        def wrapped_f(*args, **kwargs):
+        def wrapped_f(*args: object, **kwargs: object) -> object:
            return f(*args, xp=xp, **kwargs)
 
         sig = signature(f)
         new_sig = sig.replace(
-            parameters=[sig.parameters[i] for i in sig.parameters if i != "xp"]
+            parameters=[par for i, par in sig.parameters.items() if i != "xp"]
         )
 
         if wrapped_f.__doc__ is None:
@@ -40,7 +51,7 @@ def wrapped_f(*args, **kwargs):
             specification for more details.
 
             """
-        wrapped_f.__signature__ = new_sig
-        return wrapped_f
+        wrapped_f.__signature__ = new_sig  # pyright: ignore[reportAttributeAccessIssue]
+        return wrapped_f  # pyright: ignore[reportReturnType]
 
     return inner
diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py
index 0d123b99..997a2917 100644
--- a/array_api_compat/common/_aliases.py
+++ b/array_api_compat/common/_aliases.py
@@ -5,36 +5,44 @@
 from __future__ import annotations
 
 import inspect
-from typing import NamedTuple, Optional, Sequence, Tuple, Union
-
-from ._helpers import array_namespace, _check_device, device, is_cupy_namespace
+from typing import TYPE_CHECKING, NamedTuple, Optional, Sequence, cast
+
+from ._helpers import (
+    _check_device,  # pyright: ignore[reportPrivateUsage]
+    array_namespace,
+    device,
+    is_cupy_namespace,
+)
 from ._typing import Array, Device, DType, Namespace
 
+if TYPE_CHECKING:
+    from typing_extensions import TypeIs
+
 # These functions are modified from the NumPy versions.
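[Editorial example — not part of the diff] A minimal sketch of how the `get_xp` decorator annotated above is applied; `_square`/`square` are hypothetical names used only for illustration:

    import inspect
    import numpy as np
    from array_api_compat._internal import get_xp

    def _square(x, /, xp, *, dtype=None):
        # `xp` is the array namespace that get_xp binds as a keyword argument.
        return xp.multiply(x, x, dtype=dtype)

    square = get_xp(np)(_square)      # binds xp=numpy and hides `xp` from the signature
    print(square(np.arange(3)))       # [0 1 4]
    print(inspect.signature(square))  # (x, /, *, dtype=None)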
 
 # Creation functions add the device keyword (which does nothing for NumPy)
 def arange(
-    start: Union[int, float],
+    start: float,
     /,
-    stop: Optional[Union[int, float]] = None,
-    step: Union[int, float] = 1,
+    stop: float | None = None,
+    step: float = 1,
     *,
     xp: Namespace,
-    dtype: Optional[DType] = None,
-    device: Optional[Device] = None,
-    **kwargs,
+    dtype: DType | None = None,
+    device: Device | None = None,
+    **kwargs: object,
 ) -> Array:
     _check_device(xp, device)
     return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs)
 
 def empty(
-    shape: Union[int, Tuple[int, ...]],
+    shape: int | tuple[int, ...],
     xp: Namespace,
     *,
-    dtype: Optional[DType] = None,
-    device: Optional[Device] = None,
-    **kwargs,
+    dtype: DType | None = None,
+    device: Device | None = None,
+    **kwargs: object,
 ) -> Array:
     _check_device(xp, device)
     return xp.empty(shape, dtype=dtype, **kwargs)
@@ -44,35 +52,35 @@ def empty_like(
     x: Array,
     /,
     xp: Namespace,
     *,
-    dtype: Optional[DType] = None,
-    device: Optional[Device] = None,
-    **kwargs,
+    dtype: DType | None = None,
+    device: Device | None = None,
+    **kwargs: object,
 ) -> Array:
     _check_device(xp, device)
     return xp.empty_like(x, dtype=dtype, **kwargs)
 
 def eye(
     n_rows: int,
-    n_cols: Optional[int] = None,
+    n_cols: int | None = None,
     /,
     *,
     xp: Namespace,
     k: int = 0,
-    dtype: Optional[DType] = None,
-    device: Optional[Device] = None,
-    **kwargs,
+    dtype: DType | None = None,
+    device: Device | None = None,
+    **kwargs: object,
 ) -> Array:
     _check_device(xp, device)
     return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs)
 
 def full(
-    shape: Union[int, Tuple[int, ...]],
+    shape: int | tuple[int, ...],
     fill_value: complex,
     xp: Namespace,
     *,
-    dtype: Optional[DType] = None,
-    device: Optional[Device] = None,
-    **kwargs,
+    dtype: DType | None = None,
+    device: Device | None = None,
+    **kwargs: object,
 ) -> Array:
     _check_device(xp, device)
     return xp.full(shape, fill_value, dtype=dtype, **kwargs)
@@ -83,35 +91,35 @@ def full_like(
     x: Array,
     /,
     fill_value: complex,
     *,
     xp: Namespace,
-    dtype: Optional[DType] = None,
-    device: Optional[Device] = None,
-    **kwargs,
+    dtype: DType | None = None,
+    device: Device | None = None,
+    **kwargs: object,
 ) -> Array:
     _check_device(xp, device)
     return xp.full_like(x, fill_value, dtype=dtype, **kwargs)
 
 def linspace(
-    start: Union[int, float],
-    stop: Union[int, float],
+    start: float,
+    stop: float,
     /,
     num: int,
     *,
     xp: Namespace,
-    dtype: Optional[DType] = None,
-    device: Optional[Device] = None,
+    dtype: DType | None = None,
+    device: Device | None = None,
     endpoint: bool = True,
-    **kwargs,
+    **kwargs: object,
 ) -> Array:
     _check_device(xp, device)
     return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs)
 
 def ones(
-    shape: Union[int, Tuple[int, ...]],
+    shape: int | tuple[int, ...],
     xp: Namespace,
     *,
-    dtype: Optional[DType] = None,
-    device: Optional[Device] = None,
-    **kwargs,
+    dtype: DType | None = None,
+    device: Device | None = None,
+    **kwargs: object,
 ) -> Array:
     _check_device(xp, device)
     return xp.ones(shape, dtype=dtype, **kwargs)
@@ -121,20 +129,20 @@ def ones_like(
     x: Array,
     /,
     xp: Namespace,
     *,
-    dtype: Optional[DType] = None,
-    device: Optional[Device] = None,
-    **kwargs,
+    dtype: DType | None = None,
+    device: Device | None = None,
+    **kwargs: object,
 ) -> Array:
     _check_device(xp, device)
     return xp.ones_like(x, dtype=dtype, **kwargs)
 
 def zeros(
-    shape: Union[int, Tuple[int, ...]],
+    shape: int | tuple[int, ...],
     xp: Namespace,
     *,
-    dtype: Optional[DType] = None,
-    device: Optional[Device] = None,
-    **kwargs,
+    dtype: DType | None = None,
+    device: Device | None = None,
+    **kwargs: object,
 ) -> Array:
     _check_device(xp, device)
     return xp.zeros(shape, dtype=dtype, **kwargs)
@@ -144,9 +152,9 @@ def zeros_like(
     x: Array,
     /,
     xp: Namespace,
     *,
-    dtype: Optional[DType] = None,
-    device: Optional[Device] = None,
-    **kwargs,
+    dtype: DType | None = None,
+    device: Device | None = None,
+    **kwargs: object,
 ) -> Array:
     _check_device(xp, device)
     return xp.zeros_like(x, dtype=dtype, **kwargs)
@@ -251,10 +259,10 @@ def std(
     x: Array,
     /,
     xp: Namespace,
     *,
-    axis: Optional[Union[int, Tuple[int, ...]]] = None,
-    correction: Union[int, float] = 0.0,  # correction instead of ddof
+    axis: int | tuple[int, ...] | None = None,
+    correction: float = 0.0,  # correction instead of ddof
     keepdims: bool = False,
-    **kwargs,
+    **kwargs: object,
 ) -> Array:
     return xp.std(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs)
@@ -263,10 +271,10 @@ def var(
     x: Array,
     /,
     xp: Namespace,
     *,
-    axis: Optional[Union[int, Tuple[int, ...]]] = None,
-    correction: Union[int, float] = 0.0,  # correction instead of ddof
+    axis: int | tuple[int, ...] | None = None,
+    correction: float = 0.0,  # correction instead of ddof
     keepdims: bool = False,
-    **kwargs,
+    **kwargs: object,
 ) -> Array:
     return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs)
@@ -278,10 +286,10 @@ def cumulative_sum(
     x: Array,
     /,
     xp: Namespace,
     *,
-    axis: Optional[int] = None,
-    dtype: Optional[DType] = None,
+    axis: int | None = None,
+    dtype: DType | None = None,
     include_initial: bool = False,
-    **kwargs,
+    **kwargs: object,
 ) -> Array:
     wrapped_xp = array_namespace(x)
@@ -309,10 +317,10 @@ def cumulative_prod(
     x: Array,
     /,
     xp: Namespace,
     *,
-    axis: Optional[int] = None,
-    dtype: Optional[DType] = None,
+    axis: int | None = None,
+    dtype: DType | None = None,
     include_initial: bool = False,
-    **kwargs,
+    **kwargs: object,
 ) -> Array:
     wrapped_xp = array_namespace(x)
@@ -338,14 +346,14 @@ def cumulative_prod(
 def clip(
     x: Array,
     /,
-    min: Optional[Union[int, float, Array]] = None,
-    max: Optional[Union[int, float, Array]] = None,
+    min: float | Array | None = None,
+    max: float | Array | None = None,
     *,
     xp: Namespace,
     # TODO: np.clip has other ufunc kwargs
-    out: Optional[Array] = None,
+    out: Array | None = None,
 ) -> Array:
-    def _isscalar(a):
+    def _isscalar(a: object) -> TypeIs[int | float | None]:
         return isinstance(a, (int, float, type(None)))
 
     min_shape = () if _isscalar(min) else min.shape
@@ -384,6 +392,7 @@ def _isscalar(a):
     dev = device(x)
     if out is None:
         out = wrapped_xp.empty(result_shape, dtype=x.dtype, device=dev)
+    assert out is not None  # workaround for a type-narrowing issue in pyright
     out[()] = x
 
     if min is not None:
@@ -402,18 +411,18 @@ def _isscalar(a):
     return out[()]
 
 # Unlike transpose(), the axes argument to permute_dims() is required.
-def permute_dims(x: Array, /, axes: Tuple[int, ...], xp: Namespace) -> Array:
+def permute_dims(x: Array, /, axes: tuple[int, ...], xp: Namespace) -> Array:
     return xp.transpose(x, axes)
 
 # np.reshape calls the keyword argument 'newshape' instead of 'shape'
 def reshape(
     x: Array,
     /,
-    shape: Tuple[int, ...],
+    shape: tuple[int, ...],
     xp: Namespace,
     *,
     copy: Optional[bool] = None,
-    **kwargs,
+    **kwargs: object,
 ) -> Array:
     if copy is True:
         x = x.copy()
@@ -433,7 +442,7 @@ def argsort(
     axis: int = -1,
     descending: bool = False,
     stable: bool = True,
-    **kwargs,
+    **kwargs: object,
 ) -> Array:
     # Note: this keyword argument is different, and the default is different.
     # We set it in kwargs like this because numpy.sort uses kind='quicksort'
@@ -464,7 +473,7 @@ def sort(
     axis: int = -1,
     descending: bool = False,
     stable: bool = True,
-    **kwargs,
+    **kwargs: object,
 ) -> Array:
     # Note: this keyword argument is different, and the default is different.
     # We set it in kwargs like this because numpy.sort uses kind='quicksort'
@@ -477,31 +486,31 @@ def sort(
     return res
 
 # nonzero should error for zero-dimensional arrays
-def nonzero(x: Array, /, xp: Namespace, **kwargs) -> Tuple[Array, ...]:
+def nonzero(x: Array, /, xp: Namespace, **kwargs: object) -> tuple[Array, ...]:
     if x.ndim == 0:
         raise ValueError("nonzero() does not support zero-dimensional arrays")
     return xp.nonzero(x, **kwargs)
 
 # ceil, floor, and trunc return integers for integer inputs
-def ceil(x: Array, /, xp: Namespace, **kwargs) -> Array:
+def ceil(x: Array, /, xp: Namespace, **kwargs: object) -> Array:
     if xp.issubdtype(x.dtype, xp.integer):
         return x
     return xp.ceil(x, **kwargs)
 
-def floor(x: Array, /, xp: Namespace, **kwargs) -> Array:
+def floor(x: Array, /, xp: Namespace, **kwargs: object) -> Array:
     if xp.issubdtype(x.dtype, xp.integer):
         return x
     return xp.floor(x, **kwargs)
 
-def trunc(x: Array, /, xp: Namespace, **kwargs) -> Array:
+def trunc(x: Array, /, xp: Namespace, **kwargs: object) -> Array:
     if xp.issubdtype(x.dtype, xp.integer):
         return x
     return xp.trunc(x, **kwargs)
 
 # linear algebra functions
-def matmul(x1: Array, x2: Array, /, xp: Namespace, **kwargs) -> Array:
+def matmul(x1: Array, x2: Array, /, xp: Namespace, **kwargs: object) -> Array:
     return xp.matmul(x1, x2, **kwargs)
 
 # Unlike transpose, matrix_transpose only transposes the last two axes.
@@ -516,8 +525,8 @@ def tensordot(
     /,
     xp: Namespace,
     *,
-    axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2,
-    **kwargs,
+    axes: int | tuple[Sequence[int], Sequence[int]] = 2,
+    **kwargs: object,
 ) -> Array:
     return xp.tensordot(x1, x2, axes=axes, **kwargs)
@@ -541,10 +550,10 @@ def vecdot(x1: Array, x2: Array, /, xp: Namespace, *, axis: int = -1) -> Array:
 
 def isdtype(
     dtype: DType,
-    kind: Union[DType, str, Tuple[Union[DType, str], ...]],
+    kind: DType | str | tuple[DType | str, ...],
     xp: Namespace,
     *,
-    _tuple: bool = True, # Disallow nested tuples
+    _tuple: bool = True,  # Disallow nested tuples
 ) -> bool:
     """
     Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``.
@@ -557,7 +566,10 @@ def isdtype(
     for more details
     """
     if isinstance(kind, tuple) and _tuple:
-        return any(isdtype(dtype, k, xp, _tuple=False) for k in kind)
+        return any(
+            isdtype(dtype, k, xp, _tuple=False)
+            for k in cast("tuple[DType | str, ...]", kind)
+        )
     elif isinstance(kind, str):
         if kind == 'bool':
             return dtype == xp.bool_
@@ -583,14 +595,14 @@ def isdtype(
     return dtype == kind
 
 # unstack is a new function in the 2023.12 array API standard
-def unstack(x: Array, /, xp: Namespace, *, axis: int = 0) -> Tuple[Array, ...]:
+def unstack(x: Array, /, xp: Namespace, *, axis: int = 0) -> tuple[Array, ...]:
     if x.ndim == 0:
         raise ValueError("Input array must be at least 1-d.")
     return tuple(xp.moveaxis(x, axis, 0))
 
 # numpy 1.26 does not use the standard definition for sign on complex numbers
-def sign(x: Array, /, xp: Namespace, **kwargs) -> Array:
+def sign(x: Array, /, xp: Namespace, **kwargs: object) -> Array:
     if isdtype(x.dtype, 'complex floating', xp=xp):
         out = (x/xp.abs(x, **kwargs))[...]
         # sign(0) = 0 but the above formula would give nan
diff --git a/array_api_compat/common/_fft.py b/array_api_compat/common/_fft.py
index bd2a4e1a..6fe834dc 100644
--- a/array_api_compat/common/_fft.py
+++ b/array_api_compat/common/_fft.py
@@ -1,9 +1,11 @@
 from __future__ import annotations
 
 from collections.abc import Sequence
-from typing import Union, Optional, Literal
+from typing import Literal
 
-from ._typing import Device, Array, DType, Namespace
+from ._typing import Array, Device, DType, Namespace
+
+_Norm = Literal["backward", "ortho", "forward"]
 
 # Note: NumPy fft functions improperly upcast float32 and complex64 to
 # complex128, which is why we require wrapping them all here.
@@ -13,9 +15,9 @@ def fft(
     /,
     xp: Namespace,
     *,
-    n: Optional[int] = None,
+    n: int | None = None,
     axis: int = -1,
-    norm: Literal["backward", "ortho", "forward"] = "backward",
+    norm: _Norm = "backward",
 ) -> Array:
     res = xp.fft.fft(x, n=n, axis=axis, norm=norm)
     if x.dtype in [xp.float32, xp.complex64]:
@@ -27,9 +29,9 @@ def ifft(
     /,
     xp: Namespace,
     *,
-    n: Optional[int] = None,
+    n: int | None = None,
     axis: int = -1,
-    norm: Literal["backward", "ortho", "forward"] = "backward",
+    norm: _Norm = "backward",
 ) -> Array:
     res = xp.fft.ifft(x, n=n, axis=axis, norm=norm)
     if x.dtype in [xp.float32, xp.complex64]:
@@ -41,9 +43,9 @@ def fftn(
     /,
     xp: Namespace,
     *,
-    s: Sequence[int] = None,
-    axes: Sequence[int] = None,
-    norm: Literal["backward", "ortho", "forward"] = "backward",
+    s: Sequence[int] | None = None,
+    axes: Sequence[int] | None = None,
+    norm: _Norm = "backward",
 ) -> Array:
     res = xp.fft.fftn(x, s=s, axes=axes, norm=norm)
     if x.dtype in [xp.float32, xp.complex64]:
@@ -55,9 +57,9 @@ def ifftn(
     /,
     xp: Namespace,
     *,
-    s: Sequence[int] = None,
-    axes: Sequence[int] = None,
-    norm: Literal["backward", "ortho", "forward"] = "backward",
+    s: Sequence[int] | None = None,
+    axes: Sequence[int] | None = None,
+    norm: _Norm = "backward",
 ) -> Array:
     res = xp.fft.ifftn(x, s=s, axes=axes, norm=norm)
     if x.dtype in [xp.float32, xp.complex64]:
@@ -69,9 +71,9 @@ def rfft(
     /,
     xp: Namespace,
     *,
-    n: Optional[int] = None,
+    n: int | None = None,
     axis: int = -1,
-    norm: Literal["backward", "ortho", "forward"] = "backward",
+    norm: _Norm = "backward",
 ) -> Array:
     res = xp.fft.rfft(x, n=n, axis=axis, norm=norm)
     if x.dtype == xp.float32:
@@ -83,9 +85,9 @@ def irfft(
     /,
     xp: Namespace,
     *,
-    n: Optional[int] = None,
+    n: int | None = None,
     axis: int = -1,
-    norm: Literal["backward", "ortho", "forward"] = "backward",
+    norm: _Norm = "backward",
 ) -> Array:
     res = xp.fft.irfft(x, n=n, axis=axis, norm=norm)
     if x.dtype == xp.complex64:
@@ -97,9 +99,9 @@ def rfftn(
     /,
     xp: Namespace,
     *,
-    s: Sequence[int] = None,
-    axes: Sequence[int] = None,
-    norm: Literal["backward", "ortho", "forward"] = "backward",
+    s: Sequence[int] | None = None,
+    axes: Sequence[int] | None = None,
+    norm: _Norm = "backward",
 ) -> Array:
     res = xp.fft.rfftn(x, s=s, axes=axes, norm=norm)
     if x.dtype == xp.float32:
@@ -111,9 +113,9 @@ def irfftn(
     /,
     xp: Namespace,
     *,
-    s: Sequence[int] = None,
-    axes: Sequence[int] = None,
-    norm: Literal["backward", "ortho", "forward"] = "backward",
+    s: Sequence[int] | None = None,
+    axes: Sequence[int] | None = None,
+    norm: _Norm = "backward",
 ) -> Array:
     res = xp.fft.irfftn(x, s=s, axes=axes, norm=norm)
     if x.dtype == xp.complex64:
@@ -125,9 +127,9 @@ def hfft(
     /,
     xp: Namespace,
     *,
-    n: Optional[int] = None,
+    n: int | None = None,
     axis: int = -1,
-    norm: Literal["backward", "ortho", "forward"] = "backward",
+    norm: _Norm = "backward",
 ) -> Array:
     res = xp.fft.hfft(x, n=n, axis=axis, norm=norm)
     if x.dtype in [xp.float32, xp.complex64]:
@@ -139,9 +141,9 @@ def ihfft(
     /,
     xp: Namespace,
     *,
-    n: Optional[int] = None,
+    n: int | None = None,
     axis: int = -1,
-    norm: Literal["backward", "ortho", "forward"] = "backward",
+    norm: _Norm = "backward",
 ) -> Array:
     res = xp.fft.ihfft(x, n=n, axis=axis, norm=norm)
     if x.dtype in [xp.float32, xp.complex64]:
@@ -154,8 +156,8 @@ def fftfreq(
     xp: Namespace,
     *,
     d: float = 1.0,
-    dtype: Optional[DType] = None,
-    device: Optional[Device] = None,
+    dtype: DType | None = None,
+    device: Device | None = None,
 ) -> Array:
     if device not in ["cpu", None]:
         raise ValueError(f"Unsupported device {device!r}")
@@ -170,8 +172,8 @@ def rfftfreq(
     xp: Namespace,
     *,
     d: float = 1.0,
-    dtype: Optional[DType] = None,
-    device: Optional[Device] = None,
+    dtype: DType | None = None,
+    device: Device | None = None,
 ) -> Array:
     if device not in ["cpu", None]:
         raise ValueError(f"Unsupported device {device!r}")
@@ -181,12 +183,12 @@ def rfftfreq(
     return res
 
 def fftshift(
-    x: Array, /, xp: Namespace, *, axes: Union[int, Sequence[int]] = None
+    x: Array, /, xp: Namespace, *, axes: int | Sequence[int] | None = None
 ) -> Array:
     return xp.fft.fftshift(x, axes=axes)
 
 def ifftshift(
-    x: Array, /, xp: Namespace, *, axes: Union[int, Sequence[int]] = None
+    x: Array, /, xp: Namespace, *, axes: int | Sequence[int] | None = None
 ) -> Array:
     return xp.fft.ifftshift(x, axes=axes)
diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py
index 6d95069d..8e81ecdd 100644
--- a/array_api_compat/common/_helpers.py
+++ b/array_api_compat/common/_helpers.py
@@ -7,16 +7,44 @@
 """
 from __future__ import annotations
 
-import sys
-import math
 import inspect
+import math
+import sys
 import warnings
-from typing import Optional, Union, Any
+from typing import TYPE_CHECKING, Any, Literal, SupportsIndex, cast, overload
+
+from ._typing import Array, Device, HasShape, Namespace, SupportsArrayNamespace
+
+if TYPE_CHECKING:
+    from collections.abc import Collection
+
+    import dask.array as da
+    import jax
+    import ndonnx as ndx
+    import numpy as np
+    import numpy.typing as npt
+    import sparse  # pyright: ignore[reportMissingTypeStubs]
+    import torch
+    from typing_extensions import TypeAlias, TypeGuard, TypeIs, TypeVar
 
-from ._typing import Array, Device, Namespace
+    _SizeT = TypeVar("_SizeT", bound=int | None)
 
+    _ZeroGradientArray: TypeAlias = npt.NDArray[np.void]
+    _CupyArray: TypeAlias = Any  # cupy has no py.typed
 
-def _is_jax_zero_gradient_array(x: object) -> bool:
+    _ArrayApiObj: TypeAlias = (
+        npt.NDArray[Any]
+        | da.Array
+        | jax.Array
+        | ndx.Array
+        | sparse.SparseArray
+        | torch.Tensor
+        | SupportsArrayNamespace[Any]
+        | _CupyArray
+    )
+
+
+def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]:
     """Return True if `x` is a zero-gradient array.
 
     These arrays are a design quirk of Jax that may one day be removed.
@@ -25,13 +53,17 @@
     if 'numpy' not in sys.modules or 'jax' not in sys.modules:
         return False
 
-    import numpy as np
     import jax
+    import numpy as np
 
-    return isinstance(x, np.ndarray) and x.dtype == jax.float0
+    jax_float0 = cast("np.dtype[np.void]", jax.float0)
+    return (
+        isinstance(x, np.ndarray)
+        and cast("npt.NDArray[np.void]", x).dtype == jax_float0
+    )
 
 
-def is_numpy_array(x: object) -> bool:
+def is_numpy_array(x: object) -> TypeGuard[npt.NDArray[Any]]:
     """
     Return True if `x` is a NumPy array.
@@ -60,7 +92,7 @@ def is_numpy_array(x: object) -> bool:
 
     # TODO: Should we reject ndarray subclasses?
     return (isinstance(x, (np.ndarray, np.generic))
-            and not _is_jax_zero_gradient_array(x))
+            and not _is_jax_zero_gradient_array(x))  # pyright: ignore[reportUnknownArgumentType]  # fmt: skip
 
 
 def is_cupy_array(x: object) -> bool:
@@ -88,13 +120,13 @@ def is_cupy_array(x: object) -> bool:
     if 'cupy' not in sys.modules:
         return False
 
-    import cupy as cp
+    import cupy as cp  # pyright: ignore[reportMissingTypeStubs]
 
     # TODO: Should we reject ndarray subclasses?
-    return isinstance(x, cp.ndarray)
+    return isinstance(x, cp.ndarray)  # pyright: ignore[reportUnknownMemberType]
 
 
-def is_torch_array(x: object) -> bool:
+def is_torch_array(x: object) -> TypeIs[torch.Tensor]:
     """
     Return True if `x` is a PyTorch tensor.
 
@@ -122,7 +154,7 @@ def is_torch_array(x: object) -> bool:
     return isinstance(x, torch.Tensor)
 
 
-def is_ndonnx_array(x: object) -> bool:
+def is_ndonnx_array(x: object) -> TypeIs[ndx.Array]:
     """
     Return True if `x` is a ndonnx Array.
 
@@ -150,7 +182,7 @@ def is_ndonnx_array(x: object) -> bool:
     return isinstance(x, ndx.Array)
 
 
-def is_dask_array(x: object) -> bool:
+def is_dask_array(x: object) -> TypeIs[da.Array]:
     """
     Return True if `x` is a dask.array Array.
 
@@ -177,8 +209,7 @@ def is_dask_array(x: object) -> bool:
     return isinstance(x, dask.array.Array)
 
-
-def is_jax_array(x: object) -> bool:
+def is_jax_array(x: object) -> TypeIs[jax.Array]:
     """
     Return True if `x` is a JAX array.
 
@@ -207,7 +238,7 @@ def is_jax_array(x: object) -> bool:
     return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x)
 
 
-def is_pydata_sparse_array(x) -> bool:
+def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]:
     """
     Return True if `x` is an array from the `sparse` package.
 
@@ -231,13 +262,13 @@ def is_pydata_sparse_array(x) -> bool:
     if 'sparse' not in sys.modules:
         return False
 
-    import sparse
+    import sparse  # pyright: ignore[reportMissingTypeStubs]
 
     # TODO: Account for other backends.
     return isinstance(x, sparse.SparseArray)
 
 
-def is_array_api_obj(x: object) -> bool:
+def is_array_api_obj(x: object) -> TypeIs[_ArrayApiObj]:  # pyright: ignore[reportUnknownParameterType]
     """
     Return True if `x` is an array API compatible array object.
 
@@ -429,7 +460,7 @@ def is_array_api_strict_namespace(xp: Namespace) -> bool:
     return xp.__name__ == 'array_api_strict'
 
 
-def _check_api_version(api_version: str) -> None:
+def _check_api_version(api_version: str | None) -> None:
     if api_version in ['2021.12', '2022.12', '2023.12']:
         warnings.warn(f"The {api_version} version of the array API specification was requested but the returned namespace is actually version 2024.12")
     elif api_version is not None and api_version not in ['2021.12', '2022.12',
@@ -438,9 +469,9 @@ def _check_api_version(api_version: str) -> None:
 
 
 def array_namespace(
-    *xs: Union[Array, bool, int, float, complex, None],
-    api_version: Optional[str] = None,
-    use_compat: Optional[bool] = None,
+    *xs: Array | complex | None,
+    api_version: str | None = None,
+    use_compat: bool | None = None,
 ) -> Namespace:
     """
     Get the array API compatible namespace for the arrays `xs`.
@@ -510,11 +541,12 @@ def your_function(x, y):
 
     _use_compat = use_compat in [None, True]
 
-    namespaces = set()
+    namespaces: set[Namespace] = set()
     for x in xs:
         if is_numpy_array(x):
-            from .. import numpy as numpy_namespace
             import numpy as np
+
+            from .. import numpy as numpy_namespace
             if use_compat is True:
                 _check_api_version(api_version)
                 namespaces.add(numpy_namespace)
@@ -530,7 +562,8 @@ def your_function(x, y):
                 from .. import cupy as cupy_namespace
                 namespaces.add(cupy_namespace)
             else:
-                import cupy as cp
+                import cupy as cp  # pyright: ignore[reportMissingTypeStubs]
+
                 namespaces.add(cp)
         elif is_torch_array(x):
             if _use_compat:
@@ -561,20 +594,21 @@ def your_function(x, y):
             if hasattr(jax.numpy, "__array_api_version__"):
                 jnp = jax.numpy
             else:
-                import jax.experimental.array_api as jnp
+                import jax.experimental.array_api as jnp  # pyright: ignore[reportMissingImports]
             namespaces.add(jnp)
         elif is_pydata_sparse_array(x):
             if use_compat is True:
                 _check_api_version(api_version)
                 raise ValueError("`sparse` does not have an array-api-compat wrapper")
             else:
-                import sparse
+                import sparse  # pyright: ignore[reportMissingTypeStubs]
             # `sparse` is already an array namespace. We do not have a wrapper
             # submodule for it.
             namespaces.add(sparse)
         elif hasattr(x, '__array_namespace__'):
             if use_compat is True:
                 raise ValueError("The given array does not have an array-api-compat wrapper")
+            x = cast("SupportsArrayNamespace[Any]", x)
             namespaces.add(x.__array_namespace__(api_version=api_version))
         elif isinstance(x, (bool, int, float, complex, type(None))):
             continue
@@ -595,7 +629,7 @@ def your_function(x, y):
 # backwards compatibility alias
 get_namespace = array_namespace
 
-def _check_device(xp, device):
+def _check_device(xp: Namespace, device: Device) -> None:  # pyright: ignore[reportUnusedFunction]
     if xp == sys.modules.get('numpy'):
         if device not in ["cpu", None]:
             raise ValueError(f"Unsupported device for NumPy: {device!r}")
@@ -604,7 +638,7 @@
 # when the array backend is not the CPU.
 # (since it is not easy to tell which device a dask array is on)
 class _dask_device:
-    def __repr__(self):
+    def __repr__(self) -> Literal["DASK_DEVICE"]:
         return "DASK_DEVICE"
 
 _DASK_DEVICE = _dask_device()
@@ -615,7 +649,7 @@ def __repr__(self):
 # wrapping or subclassing them. These helper functions can be used instead of
 # the wrapper functions for libraries that need to support both NumPy/CuPy and
 # other libraries that use devices.
-def device(x: Array, /) -> Device:
+def device(x: _ArrayApiObj, /) -> Device:
     """
     Hardware device the array data resides on.
 
@@ -651,7 +685,7 @@ def device(x: Array, /) -> Device:
         return "cpu"
     elif is_dask_array(x):
         # Peek at the metadata of the Dask array to determine type
-        if is_numpy_array(x._meta):
+        if is_numpy_array(x._meta):  # pyright: ignore
             # Must be on CPU since backed by numpy
             return "cpu"
         return _DASK_DEVICE
@@ -675,22 +709,28 @@ def device(x: Array, /) -> Device:
        return x_device
 
         # Everything but DOK has this attr.
         try:
-            inner = x.data
+            inner = x.data  # pyright: ignore
         except AttributeError:
             return "cpu"
         # Return the device of the constituent array
-        return device(inner)
-    return x.device
+        return device(inner)  # pyright: ignore
+    return x.device  # pyright: ignore
+
 
 # Prevent shadowing, used below
 _device = device
 
 # Based on cupy.array_api.Array.to_device
-def _cupy_to_device(x, device, /, stream=None):
-    import cupy as cp
-    from cupy.cuda import Device as _Device
-    from cupy.cuda import stream as stream_module
-    from cupy_backends.cuda.api import runtime
+def _cupy_to_device(
+    x: _CupyArray,
+    device: Device,
+    /,
+    stream: int | Any | None = None,
+) -> _CupyArray:
+    import cupy as cp  # pyright: ignore[reportMissingTypeStubs]
+    from cupy.cuda import Device as _Device  # pyright: ignore
+    from cupy.cuda import stream as stream_module  # pyright: ignore
+    from cupy_backends.cuda.api import runtime  # pyright: ignore
 
     if device == x.device:
         return x
@@ -703,33 +743,38 @@ def _cupy_to_device(x, device, /, stream=None):
         raise ValueError(f"Unsupported device {device!r}")
     else:
         # see cupy/cupy#5985 for the reason how we handle device/stream here
-        prev_device = runtime.getDevice()
-        prev_stream: stream_module.Stream = None
+        prev_device: Any = runtime.getDevice()  # pyright: ignore[reportUnknownMemberType]
+        prev_stream = None
         if stream is not None:
-            prev_stream = stream_module.get_current_stream()
+            prev_stream: Any = stream_module.get_current_stream()  # pyright: ignore
             # stream can be an int as specified in __dlpack__, or a CuPy stream
             if isinstance(stream, int):
-                stream = cp.cuda.ExternalStream(stream)
-            elif isinstance(stream, cp.cuda.Stream):
+                stream = cp.cuda.ExternalStream(stream)  # pyright: ignore
+            elif isinstance(stream, cp.cuda.Stream):  # pyright: ignore[reportUnknownMemberType]
                 pass
             else:
                 raise ValueError('the input stream is not recognized')
-            stream.use()
+            stream.use()  # pyright: ignore[reportUnknownMemberType]
         try:
-            runtime.setDevice(device.id)
+            runtime.setDevice(device.id)  # pyright: ignore[reportUnknownMemberType]
             arr = x.copy()
         finally:
-            runtime.setDevice(prev_device)
+            runtime.setDevice(prev_device)  # pyright: ignore[reportUnknownMemberType]
             if stream is not None:
                 prev_stream.use()
         return arr
 
-def _torch_to_device(x, device, /, stream=None):
+def _torch_to_device(
+    x: torch.Tensor,
+    device: torch.device | str | int,
+    /,
+    stream: None = None,
+) -> torch.Tensor:
     if stream is not None:
         raise NotImplementedError
     return x.to(device)
 
-def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] = None) -> Array:
+def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) -> Array:
     """
     Copy the array from the device on which it currently resides to the specified ``device``.
 
@@ -749,7 +794,7 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
         a ``device`` object (see the `Device Support `__
        section of the array API specification).
-    stream: Optional[Union[int, Any]]
+    stream: int | Any | None
         stream object to use during copy.
         In addition to the types supported in ``array.__dlpack__``, implementations
         may choose to support any library-specific stream object with the caveat
         that any code using such an object would not be portable.
@@ -788,7 +833,7 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
         # cupy does not yet have to_device
         return _cupy_to_device(x, device, stream=stream)
     elif is_torch_array(x):
-        return _torch_to_device(x, device, stream=stream)
+        return _torch_to_device(x, device, stream=stream)  # pyright: ignore[reportArgumentType]
     elif is_dask_array(x):
         if stream is not None:
             raise ValueError("The stream argument to to_device() is not supported")
@@ -799,7 +844,8 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
     elif is_jax_array(x):
         if not hasattr(x, "__array_namespace__"):
             # In JAX v0.4.31 and older, this import adds to_device method to x...
-            import jax.experimental.array_api # noqa: F401
+            import jax.experimental.array_api  # noqa: F401  # pyright: ignore
+
             # ... but only on eager JAX. It won't work inside jax.jit.
             if not hasattr(x, "to_device"):
                 return x
@@ -808,10 +854,16 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
         # Perform trivial check to return the same array if
         # device is same instead of err-ing.
         return x
-    return x.to_device(device, stream=stream)
+    return x.to_device(device, stream=stream)  # pyright: ignore
 
 
-def size(x: Array) -> int | None:
+@overload
+def size(x: HasShape[Collection[SupportsIndex]]) -> int: ...
+@overload
+def size(x: HasShape[Collection[None]]) -> None: ...
+@overload
+def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: ...
+def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None:
     """
     Return the total number of elements of x.
 
@@ -826,7 +878,7 @@ def size(x: Array) -> int | None:
     # Lazy API compliant arrays, such as ndonnx, can contain None in their shape
     if None in x.shape:
         return None
-    out = math.prod(x.shape)
+    out = math.prod(cast("Collection[SupportsIndex]", x.shape))
     # dask.array.Array.shape can contain NaN
     return None if math.isnan(out) else out
@@ -889,7 +941,7 @@ def is_lazy_array(x: object) -> bool:
     # on __bool__ (dask is one such example, which however is special-cased above).
 
     # Select a single point of the array
-    s = size(x)
+    s = size(cast("HasShape[Collection[SupportsIndex | None]]", x))
     if s is None:
         return True
     xp = array_namespace(x)
diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py
index c77ee3b8..a7c7ef4b 100644
--- a/array_api_compat/common/_linalg.py
+++ b/array_api_compat/common/_linalg.py
@@ -1,23 +1,33 @@
 from __future__ import annotations
 
 import math
-from typing import Literal, NamedTuple, Optional, Tuple, Union
+from typing import Literal, NamedTuple, cast
 
 import numpy as np
+
 if np.__version__[0] == "2":
     from numpy.lib.array_utils import normalize_axis_tuple
 else:
     from numpy.core.numeric import normalize_axis_tuple
 
-from ._aliases import matmul, matrix_transpose, tensordot, vecdot, isdtype
 from .._internal import get_xp
-from ._typing import Array, Namespace
+from ._aliases import isdtype, matmul, matrix_transpose, tensordot, vecdot
+from ._typing import Array, DType, Namespace
+
 
 # These are in the main NumPy namespace but not in numpy.linalg
-def cross(x1: Array, x2: Array, /, xp: Namespace, *, axis: int = -1, **kwargs) -> Array:
+def cross(
+    x1: Array,
+    x2: Array,
+    /,
+    xp: Namespace,
+    *,
+    axis: int = -1,
+    **kwargs: object,
+) -> Array:
     return xp.cross(x1, x2, axis=axis, **kwargs)
 
-def outer(x1: Array, x2: Array, /, xp: Namespace, **kwargs) -> Array:
+def outer(x1: Array, x2: Array, /, xp: Namespace, **kwargs: object) -> Array:
     return xp.outer(x1, x2, **kwargs)
 
 class EighResult(NamedTuple):
@@ -39,46 +49,66 @@ class SVDResult(NamedTuple):
 
 # These functions are the same as their NumPy counterparts except they return
 # a namedtuple.
-def eigh(x: Array, /, xp: Namespace, **kwargs) -> EighResult:
+def eigh(x: Array, /, xp: Namespace, **kwargs: object) -> EighResult:
     return EighResult(*xp.linalg.eigh(x, **kwargs))
 
-def qr(x: Array, /, xp: Namespace, *, mode: Literal['reduced', 'complete'] = 'reduced',
-       **kwargs) -> QRResult:
+def qr(
+    x: Array,
+    /,
+    xp: Namespace,
+    *,
+    mode: Literal["reduced", "complete"] = "reduced",
+    **kwargs: object,
+) -> QRResult:
     return QRResult(*xp.linalg.qr(x, mode=mode, **kwargs))
 
-def slogdet(x: Array, /, xp: Namespace, **kwargs) -> SlogdetResult:
+def slogdet(x: Array, /, xp: Namespace, **kwargs: object) -> SlogdetResult:
     return SlogdetResult(*xp.linalg.slogdet(x, **kwargs))
 
 def svd(
-    x: Array, /, xp: Namespace, *, full_matrices: bool = True, **kwargs
+    x: Array,
+    /,
+    xp: Namespace,
+    *,
+    full_matrices: bool = True,
+    **kwargs: object,
 ) -> SVDResult:
     return SVDResult(*xp.linalg.svd(x, full_matrices=full_matrices, **kwargs))
 
 # These functions have additional keyword arguments
 
 # The upper keyword argument is new from NumPy
-def cholesky(x: Array, /, xp: Namespace, *, upper: bool = False, **kwargs) -> Array:
+def cholesky(
+    x: Array,
+    /,
+    xp: Namespace,
+    *,
+    upper: bool = False,
+    **kwargs: object,
+) -> Array:
     L = xp.linalg.cholesky(x, **kwargs)
     if upper:
         U = get_xp(xp)(matrix_transpose)(L)
         if get_xp(xp)(isdtype)(U.dtype, 'complex floating'):
-            U = xp.conj(U)
+            U = xp.conj(U)  # pyright: ignore[reportConstantRedefinition]
         return U
     return L
 
 # The rtol keyword argument of matrix_rank() and pinv() is new from NumPy.
 # Note that it has a different semantic meaning from tol and rcond.
-def matrix_rank(x: Array,
-                /,
-                xp: Namespace,
-                *,
-                rtol: Optional[Union[float, Array]] = None,
-                **kwargs) -> Array:
+def matrix_rank(
+    x: Array,
+    /,
+    xp: Namespace,
+    *,
+    rtol: float | Array | None = None,
+    **kwargs: object,
+) -> Array:
     # this is different from xp.linalg.matrix_rank, which supports 1
     # dimensional arrays.
     if x.ndim < 2:
         raise xp.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional")
-    S = get_xp(xp)(svdvals)(x, **kwargs)
+    S: Array = get_xp(xp)(svdvals)(x, **kwargs)
     if rtol is None:
         tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * xp.finfo(S.dtype).eps
     else:
@@ -88,7 +118,7 @@ def matrix_rank(x: Array,
     return xp.count_nonzero(S > tol, axis=-1)
 
 def pinv(
-    x: Array, /, xp: Namespace, *, rtol: Optional[Union[float, Array]] = None, **kwargs
+    x: Array,
+    /,
+    xp: Namespace,
+    *,
+    rtol: float | Array | None = None,
+    **kwargs: object,
 ) -> Array:
     # this is different from xp.linalg.pinv, which does not multiply the
     # default tolerance by max(M, N).
@@ -104,13 +139,13 @@ def matrix_norm(
     x: Array,
     /,
     xp: Namespace,
     *,
     keepdims: bool = False,
-    ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro',
+    ord: float | Literal["fro", "nuc"] | None = "fro",
 ) -> Array:
     return xp.linalg.norm(x, axis=(-2, -1), keepdims=keepdims, ord=ord)
 
 # svdvals is not in NumPy (but it is in SciPy). It is equivalent to
 # xp.linalg.svd(compute_uv=False).
-def svdvals(x: Array, /, xp: Namespace) -> Union[Array, Tuple[Array, ...]]:
+def svdvals(x: Array, /, xp: Namespace) -> Array | tuple[Array, ...]:
     return xp.linalg.svd(x, compute_uv=False)
 
 def vector_norm(
     x: Array,
     /,
     xp: Namespace,
     *,
-    axis: Optional[Union[int, Tuple[int, ...]]] = None,
+    axis: int | tuple[int, ...] | None = None,
     keepdims: bool = False,
-    ord: Optional[Union[int, float]] = 2,
+    ord: float = 2,
 ) -> Array:
     # xp.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or
     # when axis=None and the input is 2-D, so to force a vector norm, we make
@@ -133,7 +168,10 @@ def vector_norm(
     elif isinstance(axis, tuple):
         # Note: The axis argument supports any number of axes, whereas
         # xp.linalg.norm() only supports a single axis for vector norm.
-        normalized_axis = normalize_axis_tuple(axis, x.ndim)
+        normalized_axis = cast(
+            "tuple[int, ...]",
+            normalize_axis_tuple(axis, x.ndim),  # pyright: ignore[reportCallIssue]
+        )
         rest = tuple(i for i in range(x.ndim) if i not in normalized_axis)
         newshape = axis + rest
         _x = xp.transpose(x, newshape).reshape(
@@ -149,7 +187,13 @@ def vector_norm(
     # We can't reuse xp.linalg.norm(keepdims) because of the reshape hacks
     # above to avoid matrix norm logic.
     shape = list(x.shape)
-    _axis = normalize_axis_tuple(range(x.ndim) if axis is None else axis, x.ndim)
+    _axis = cast(
+        "tuple[int, ...]",
+        normalize_axis_tuple(  # pyright: ignore[reportCallIssue]
+            range(x.ndim) if axis is None else axis,
+            x.ndim,
+        ),
+    )
     for i in _axis:
         shape[i] = 1
     res = xp.reshape(res, tuple(shape))
@@ -159,11 +203,17 @@ def vector_norm(
 
 # xp.diagonal and xp.trace operate on the first two axes whereas these
 # operates on the last two
-def diagonal(x: Array, /, xp: Namespace, *, offset: int = 0, **kwargs) -> Array:
+def diagonal(x: Array, /, xp: Namespace, *, offset: int = 0, **kwargs: object) -> Array:
     return xp.diagonal(x, offset=offset, axis1=-2, axis2=-1, **kwargs)
 
 def trace(
-    x: Array, /, xp: Namespace, *, offset: int = 0, dtype=None, **kwargs
+    x: Array,
+    /,
+    xp: Namespace,
+    *,
+    offset: int = 0,
+    dtype: DType | None = None,
+    **kwargs: object,
 ) -> Array:
     return xp.asarray(
         xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs)
diff --git a/array_api_compat/common/_typing.py b/array_api_compat/common/_typing.py
index 4c3b356b..a38d083b 100644
--- a/array_api_compat/common/_typing.py
+++ b/array_api_compat/common/_typing.py
@@ -1,11 +1,14 @@
 from __future__ import annotations
+
 from types import ModuleType as Namespace
-from typing import Any, TypeVar, Protocol
+from typing import Any, Protocol, TypeVar
 
 __all__ = [
     "Array",
+    "SupportsArrayNamespace",
     "DType",
     "Device",
+    "HasShape",
     "Namespace",
     "NestedSequence",
     "SupportsBufferProtocol",
@@ -18,6 +21,15 @@ def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ...
     def __len__(self, /) -> int: ...
 
 
+class SupportsArrayNamespace(Protocol[_T_co]):
+    def __array_namespace__(self, /, *, api_version: str | None) -> _T_co: ...
+
+
+class HasShape(Protocol[_T_co]):
+    @property
+    def shape(self, /) -> _T_co: ...
+
+
 SupportsBufferProtocol = Any
 Array = Any
 Device = Any
diff --git a/array_api_compat/numpy/__init__.py b/array_api_compat/numpy/__init__.py
index 02c55d28..72900680 100644
--- a/array_api_compat/numpy/__init__.py
+++ b/array_api_compat/numpy/__init__.py
@@ -1,10 +1,16 @@
-from numpy import * # noqa: F403
+# ruff: noqa: PLC0414
+from typing import Final
+
+from numpy import *  # noqa: F403  # pyright: ignore[reportWildcardImportFromLibrary]
 
 # from numpy import * doesn't overwrite these builtin names
-from numpy import abs, max, min, round # noqa: F401
+from numpy import abs as abs
+from numpy import max as max
+from numpy import min as min
+from numpy import round as round
 
 # These imports may overwrite names from the import * above.
-from ._aliases import * # noqa: F403
+from ._aliases import *  # noqa: F403
 
 # Don't know why, but we have to do an absolute import to import linalg. If we
 # instead do
@@ -13,18 +19,18 @@
 #
 # It doesn't overwrite np.linalg from above. The import is generated
 # dynamically so that the library can be vendored.
-__import__(__package__ + '.linalg')
-
-__import__(__package__ + '.fft')
+__import__(__package__ + ".linalg")  # pyright: ignore
 
-from .linalg import matrix_transpose, vecdot # noqa: F401
+__import__(__package__ + ".fft")  # pyright: ignore
 
-from ..common._helpers import * # noqa: F403
+from ..common._helpers import *  # noqa: F403
+from .linalg import matrix_transpose as matrix_transpose
+from .linalg import vecdot as vecdot
 
 try:
     # Used in asarray(). Not present in older versions.
-    from numpy import _CopyMode # noqa: F401
+    from numpy import _CopyMode as _CopyMode  # pyright: ignore[reportPrivateUsage]
 except ImportError:
     pass
 
-__array_api_version__ = '2024.12'
+__array_api_version__: Final = "2024.12"
diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py
index 6536d9a8..0832abee 100644
--- a/array_api_compat/numpy/_aliases.py
+++ b/array_api_compat/numpy/_aliases.py
@@ -1,6 +1,10 @@
+# pyright: reportPrivateUsage=false
 from __future__ import annotations
 
-from typing import Optional, Union
+from builtins import bool as py_bool
+from typing import TYPE_CHECKING, cast
+
+import numpy as np
 
 from .._internal import get_xp
 from ..common import _aliases
@@ -8,7 +12,12 @@
 from ._info import __array_namespace_info__
 from ._typing import Array, Device, DType
 
-import numpy as np
+if TYPE_CHECKING:
+    from typing import Any, Literal
+
+    from typing_extensions import Buffer, TypeAlias, TypeIs
+
+    _Copy: TypeAlias = py_bool | Literal[2] | np._CopyMode
 
 bool = np.bool_
@@ -63,9 +72,9 @@
 sign = get_xp(np)(_aliases.sign)
 
 
-def _supports_buffer_protocol(obj):
+def _supports_buffer_protocol(obj: object) -> TypeIs[Buffer]:  # pyright: ignore[reportUnusedFunction]
     try:
-        memoryview(obj)
+        memoryview(obj)  # pyright: ignore[reportArgumentType]
     except TypeError:
         return False
     return True
@@ -76,15 +85,13 @@ def _supports_buffer_protocol(obj):
 # complicated enough that it's easier to define it separately for each module
 # rather than trying to combine everything into one function in common/
 def asarray(
-    obj: (
-        Array | bool | complex | NestedSequence[bool | complex] | SupportsBufferProtocol
-    ),
+    obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol,
     /,
     *,
-    dtype: Optional[DType] = None,
-    device: Optional[Device] = None,
-    copy: "Optional[Union[bool, np._CopyMode]]" = None,
-    **kwargs,
+    dtype: DType | None = None,
+    device: Device | None = None,
+    copy: _Copy | None = None,
+    **kwargs: Any,
 ) -> Array:
     """
     Array API compatibility wrapper for asarray().
@@ -108,7 +115,7 @@ def asarray(
         if copy is False:
             raise NotImplementedError("asarray(copy=False) requires a newer version of NumPy.")
 
-    return np.array(obj, copy=copy, dtype=dtype, **kwargs)
+    return np.array(obj, copy=copy, dtype=dtype, **kwargs)  # pyright: ignore
 
 
 def astype(
     x: Array,
     dtype: DType,
     /,
     *,
-    copy: bool = True,
-    device: Optional[Device] = None,
+    copy: py_bool = True,
+    device: Device | None = None,
 ) -> Array:
     return x.astype(dtype=dtype, copy=copy)
 
 
 # count_nonzero returns a python int for axis=None and keepdims=False
 # https://github.com/numpy/numpy/issues/17562
-def count_nonzero(x: Array, axis=None, keepdims=False) -> Array:
-    result = np.count_nonzero(x, axis=axis, keepdims=keepdims)
+def count_nonzero(
+    x: Array,
+    axis: int | tuple[int, ...] | None = None,
+    keepdims: py_bool = False,
+) -> Array:
+    result = cast("Any", np.count_nonzero(x, axis=axis, keepdims=keepdims))  # pyright: ignore
     if axis is None and not keepdims:
         return np.asarray(result)
     return result
@@ -148,10 +159,25 @@ def count_nonzero(x: Array, axis=None, keepdims=False) -> Array:
 else:
     unstack = get_xp(np)(_aliases.unstack)
 
-__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'astype',
-                              'acos', 'acosh', 'asin', 'asinh', 'atan',
-                              'atan2', 'atanh', 'bitwise_left_shift',
-                              'bitwise_invert', 'bitwise_right_shift',
-                              'bool', 'concat', 'count_nonzero', 'pow']
+__all__ = [
+    "__array_namespace_info__",
+    "asarray",
+    "astype",
+    "acos",
+    "acosh",
+    "asin",
+    "asinh",
+    "atan",
+    "atan2",
+    "atanh",
+    "bitwise_left_shift",
+    "bitwise_invert",
+    "bitwise_right_shift",
+    "bool",
+    "concat",
+    "count_nonzero",
+    "pow",
+]
+__all__ += _aliases.__all__
 
 _all_ignore = ['np', 'get_xp']
diff --git a/array_api_compat/numpy/_info.py b/array_api_compat/numpy/_info.py
index e706d118..6b8b75e8 100644
--- a/array_api_compat/numpy/_info.py
+++ b/array_api_compat/numpy/_info.py
@@ -7,24 +7,28 @@
 more details.
 
 """
+from __future__ import annotations
+
+from numpy import bool_ as bool
 from numpy import (
+    complex64,
+    complex128,
     dtype,
-    bool_ as bool,
-    intp,
+    float32,
+    float64,
     int8,
     int16,
     int32,
     int64,
+    intp,
     uint8,
     uint16,
     uint32,
     uint64,
-    float32,
-    float64,
-    complex64,
-    complex128,
 )
+
+from ._typing import Device, DType
+
 
 class __array_namespace_info__:
     """
@@ -131,7 +135,11 @@ def default_device(self):
         """
         return "cpu"
 
-    def default_dtypes(self, *, device=None):
+    def default_dtypes(
+        self,
+        *,
+        device: Device | None = None,
+    ) -> dict[str, dtype[intp | float64 | complex128]]:
         """
         The default data types used for new NumPy arrays.
 
@@ -183,7 +191,12 @@ def default_dtypes(self, *, device=None):
             "indexing": dtype(intp),
         }
 
-    def dtypes(self, *, device=None, kind=None):
+    def dtypes(
+        self,
+        *,
+        device: Device | None = None,
+        kind: str | tuple[str, ...] | None = None,
+    ) -> dict[str, DType]:
         """
         The array API data types supported by NumPy.
 
@@ -260,7 +273,7 @@ def dtypes(self, *, device=None, kind=None):
                 "complex128": dtype(complex128),
             }
         if kind == "bool":
-            return {"bool": bool}
+            return {"bool": dtype(bool)}
         if kind == "signed integer":
             return {
                 "int8": dtype(int8),
@@ -312,13 +325,13 @@ def dtypes(self, *, device=None, kind=None):
                 "complex128": dtype(complex128),
             }
         if isinstance(kind, tuple):
-            res = {}
+            res: dict[str, DType] = {}
             for k in kind:
                 res.update(self.dtypes(kind=k))
             return res
         raise ValueError(f"unsupported kind: {kind!r}")
 
-    def devices(self):
+    def devices(self) -> list[Device]:
         """
         The devices supported by NumPy.
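[Editorial example — not part of the diff] Rough usage of the now-annotated inspection API above; the exact dict contents depend on the platform and NumPy version:

    from array_api_compat import numpy as xp

    info = xp.__array_namespace_info__()
    print(info.devices())                     # ['cpu']
    print(info.dtypes(kind="bool"))           # {'bool': dtype('bool')}
    print(info.default_dtypes()["indexing"])  # dtype('int64') on 64-bit platforms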
diff --git a/array_api_compat/numpy/_typing.py b/array_api_compat/numpy/_typing.py
index 6a18a3b2..c49a5dbf 100644
--- a/array_api_compat/numpy/_typing.py
+++ b/array_api_compat/numpy/_typing.py
@@ -3,29 +3,24 @@
 __all__ = ["Array", "DType", "Device"]
 _all_ignore = ["np"]
 
-from typing import Literal, TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Literal
 
 import numpy as np
-from numpy import ndarray as Array
 
 Device = Literal["cpu"]
 if TYPE_CHECKING:
+    from typing_extensions import TypeAlias
+
     # NumPy 1.x on Python 3.9 and 3.10 fails to parse np.dtype[]
-    DType = np.dtype[
-        np.intp
-        | np.int8
-        | np.int16
-        | np.int32
-        | np.int64
-        | np.uint8
-        | np.uint16
-        | np.uint32
-        | np.uint64
+    DType: TypeAlias = np.dtype[
+        np.bool_
+        | np.integer[Any]
        | np.float32
         | np.float64
         | np.complex64
         | np.complex128
-        | np.bool
     ]
+    Array: TypeAlias = np.ndarray[Any, DType]
 else:
     DType = np.dtype
+    Array = np.ndarray
diff --git a/array_api_compat/numpy/fft.py b/array_api_compat/numpy/fft.py
index 28667594..5423bd01 100644
--- a/array_api_compat/numpy/fft.py
+++ b/array_api_compat/numpy/fft.py
@@ -1,10 +1,9 @@
-from numpy.fft import * # noqa: F403
+import numpy as np
 from numpy.fft import __all__ as fft_all
+from numpy.fft import fft2, ifft2, irfft2, rfft2
 
-from ..common import _fft
 from .._internal import get_xp
-
-import numpy as np
+from ..common import _fft
 
 fft = get_xp(np)(_fft.fft)
 ifft = get_xp(np)(_fft.ifft)
@@ -21,7 +20,8 @@
 fftshift = get_xp(np)(_fft.fftshift)
 ifftshift = get_xp(np)(_fft.ifftshift)
 
-__all__ = fft_all + _fft.__all__
+__all__ = ["rfft2", "irfft2", "fft2", "ifft2"]
+__all__ += _fft.__all__
 
 del get_xp
 del np
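[Editorial example — not part of the diff] A quick sketch of why the fft wrappers are re-exported here: they undo the complex128 upcast that plain numpy.fft applies to single-precision input, as noted in common/_fft.py above:

    import numpy as np
    from array_api_compat.numpy import fft

    x = np.linspace(0, 1, 8, dtype=np.float32)
    print(np.fft.fft(x).dtype)  # complex128: plain NumPy upcasts
    print(fft.fft(x).dtype)     # complex64: the wrapper restores single precision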