diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py
index 4a884893..98b8e425 100644
--- a/array_api_compat/common/_aliases.py
+++ b/array_api_compat/common/_aliases.py
@@ -292,6 +292,36 @@ def cumulative_sum(
         )
     return res
 
+
+def cumulative_prod(
+    x: ndarray,
+    /,
+    xp,
+    *,
+    axis: Optional[int] = None,
+    dtype: Optional[Dtype] = None,
+    include_initial: bool = False,
+    **kwargs
+) -> ndarray:
+    wrapped_xp = array_namespace(x)
+
+    if axis is None:
+        if x.ndim > 1:
+            raise ValueError("axis must be specified in cumulative_prod for more than one dimension")
+        axis = 0
+
+    res = xp.cumprod(x, axis=axis, dtype=dtype, **kwargs)
+
+    # np.cumprod does not support include_initial
+    if include_initial:
+        initial_shape = list(x.shape)
+        initial_shape[axis] = 1
+        res = xp.concatenate(
+            [wrapped_xp.ones(shape=initial_shape, dtype=res.dtype, device=device(res)), res],
+            axis=axis,
+        )
+    return res
+
 # The min and max argument names in clip are different and not optional in numpy, and type
 # promotion behavior is different.
 def clip(
@@ -544,7 +574,7 @@ def sign(x: ndarray, /, xp, **kwargs) -> ndarray:
            'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
            'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
            'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
-           'std', 'var', 'cumulative_sum', 'clip', 'permute_dims',
+           'std', 'var', 'cumulative_sum', 'cumulative_prod', 'clip', 'permute_dims',
            'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc',
            'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype',
            'unstack', 'sign']
diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py
index 8ab5629b..50331fa0 100644
--- a/array_api_compat/cupy/_aliases.py
+++ b/array_api_compat/cupy/_aliases.py
@@ -49,6 +49,7 @@
 std = get_xp(cp)(_aliases.std)
 var = get_xp(cp)(_aliases.var)
 cumulative_sum = get_xp(cp)(_aliases.cumulative_sum)
+cumulative_prod = get_xp(cp)(_aliases.cumulative_prod)
 clip = get_xp(cp)(_aliases.clip)
 permute_dims = get_xp(cp)(_aliases.permute_dims)
 reshape = get_xp(cp)(_aliases.reshape)
diff --git a/array_api_compat/cupy/_info.py b/array_api_compat/cupy/_info.py
index 4440807d..790621e4 100644
--- a/array_api_compat/cupy/_info.py
+++ b/array_api_compat/cupy/_info.py
@@ -101,7 +101,7 @@ def capabilities(self):
             "boolean indexing": True,
             "data-dependent shapes": True,
-            # 'max rank' will be part of the 2024.12 standard
-            # "max rank": 64,
+            # 'max dimensions' is part of the 2024.12 standard
+            "max dimensions": 64,
         }
 
     def default_device(self):
diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py
index ab18fd71..4e2d26f9 100644
--- a/array_api_compat/dask/array/_aliases.py
+++ b/array_api_compat/dask/array/_aliases.py
@@ -121,6 +121,7 @@ def arange(
 std = get_xp(da)(_aliases.std)
 var = get_xp(da)(_aliases.var)
 cumulative_sum = get_xp(da)(_aliases.cumulative_sum)
+cumulative_prod = get_xp(da)(_aliases.cumulative_prod)
 empty = get_xp(da)(_aliases.empty)
 empty_like = get_xp(da)(_aliases.empty_like)
 full = get_xp(da)(_aliases.full)
diff --git a/array_api_compat/dask/array/_info.py b/array_api_compat/dask/array/_info.py
index d3b12dc9..e15a69f4 100644
--- a/array_api_compat/dask/array/_info.py
+++ b/array_api_compat/dask/array/_info.py
@@ -102,7 +102,7 @@ def capabilities(self):
             "boolean indexing": False,
             "data-dependent shapes": False,
-            # 'max rank' will be part of the 2024.12 standard
-            # "max rank": 64,
+            # 'max dimensions' is part of the 2024.12 standard
+            "max dimensions": 64,
         }
 
     def default_device(self):
diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py
index 789eefb3..98eec121 100644
--- a/array_api_compat/numpy/_aliases.py
+++ b/array_api_compat/numpy/_aliases.py
@@ -49,6 +49,7 @@
 std = get_xp(np)(_aliases.std)
 var = get_xp(np)(_aliases.var)
 cumulative_sum = get_xp(np)(_aliases.cumulative_sum)
+cumulative_prod = get_xp(np)(_aliases.cumulative_prod)
 clip = get_xp(np)(_aliases.clip)
 permute_dims = get_xp(np)(_aliases.permute_dims)
 reshape = get_xp(np)(_aliases.reshape)
diff --git a/array_api_compat/numpy/_info.py b/array_api_compat/numpy/_info.py
index 62f7ae62..e706d118 100644
--- a/array_api_compat/numpy/_info.py
+++ b/array_api_compat/numpy/_info.py
@@ -101,7 +101,7 @@ def capabilities(self):
             "boolean indexing": True,
             "data-dependent shapes": True,
-            # 'max rank' will be part of the 2024.12 standard
-            # "max rank": 64,
+            # 'max dimensions' is part of the 2024.12 standard
+            "max dimensions": 64,
         }
 
     def default_device(self):
diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py
index 7af3f2af..5b20aabc 100644
--- a/array_api_compat/torch/_aliases.py
+++ b/array_api_compat/torch/_aliases.py
@@ -8,6 +8,7 @@
     clip as _aliases_clip,
     unstack as _aliases_unstack,
     cumulative_sum as _aliases_cumulative_sum,
+    cumulative_prod as _aliases_cumulative_prod,
 )
 from .._internal import get_xp
 
@@ -124,7 +125,11 @@ def _fix_promotion(x1, x2, only_scalar=True):
         x1 = x1.to(dtype)
     return x1, x2
 
-def result_type(*arrays_and_dtypes: Union[array, Dtype]) -> Dtype:
+
+_py_scalars = (bool, int, float, complex)
+
+
+def result_type(*arrays_and_dtypes: Union[array, Dtype, bool, int, float, complex]) -> Dtype:
     if len(arrays_and_dtypes) == 0:
         raise TypeError("At least one array or dtype must be provided")
     if len(arrays_and_dtypes) == 1:
@@ -136,6 +141,9 @@
         return result_type(arrays_and_dtypes[0], result_type(*arrays_and_dtypes[1:]))
     x, y = arrays_and_dtypes
 
+    if isinstance(x, _py_scalars) or isinstance(y, _py_scalars):
+        return torch.result_type(x, y)
+
     xdt = x.dtype if not isinstance(x, torch.dtype) else x
     ydt = y.dtype if not isinstance(y, torch.dtype) else y
 
@@ -210,6 +218,7 @@ def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep
 clip = get_xp(torch)(_aliases_clip)
 unstack = get_xp(torch)(_aliases_unstack)
 cumulative_sum = get_xp(torch)(_aliases_cumulative_sum)
+cumulative_prod = get_xp(torch)(_aliases_cumulative_prod)
 
 # torch.sort also returns a tuple
 # https://github.com/pytorch/pytorch/issues/70921
@@ -504,6 +513,39 @@ def nonzero(x: array, /, **kwargs) -> Tuple[array, ...]:
         raise ValueError("nonzero() does not support zero-dimensional arrays")
     return torch.nonzero(x, as_tuple=True, **kwargs)
 
+
+# torch uses `dim` instead of `axis`
+def diff(
+    x: array,
+    /,
+    *,
+    axis: int = -1,
+    n: int = 1,
+    prepend: Optional[array] = None,
+    append: Optional[array] = None,
+) -> array:
+    return torch.diff(x, dim=axis, n=n, prepend=prepend, append=append)
+
+
+# torch uses `dim` instead of `axis` and has no `keepdims` argument
+def count_nonzero(
+    x: array,
+    /,
+    *,
+    axis: Optional[Union[int, Tuple[int, ...]]] = None,
+    keepdims: bool = False,
+) -> array:
+    result = torch.count_nonzero(x, dim=axis)
+    if not keepdims:
+        return result
+    # restore the reduced axes as size-1 dimensions
+    if axis is None:
+        return result.reshape([1] * x.ndim)
+    axes = (axis,) if isinstance(axis, int) else axis
+    axes = tuple(a % x.ndim for a in axes)
+    return result.reshape([1 if i in axes else x.shape[i] for i in range(x.ndim)])
+
+
 def where(condition: array, x1: array, x2: array, /) -> array:
     x1, x2 = _fix_promotion(x1, x2)
     return torch.where(condition, x1, x2)
@@ -734,6 +776,11 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
         axis = 0
     return torch.index_select(x, axis, indices, **kwargs)
 
+
+def take_along_axis(x: array, indices: array, /, *, axis: int = -1) -> array:
+    return torch.take_along_dim(x, indices, dim=axis)
+
+
 def sign(x: array, /) -> array:
     # torch sign() does not support complex numbers and does not propagate
     # nans. See https://github.com/data-apis/array-api-compat/issues/136
@@ -752,11 +799,12 @@ def sign(x: array, /) -> array:
 __all__ = ['__array_namespace_info__', 'result_type', 'can_cast',
            'permute_dims', 'bitwise_invert', 'newaxis', 'conj', 'add',
            'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or',
-           'bitwise_right_shift', 'bitwise_xor', 'copysign', 'divide',
+           'bitwise_right_shift', 'bitwise_xor', 'copysign', 'count_nonzero',
+           'diff', 'divide',
            'equal', 'floor_divide', 'greater', 'greater_equal', 'hypot',
            'less', 'less_equal', 'logaddexp', 'maximum', 'minimum',
            'multiply', 'not_equal', 'pow', 'remainder', 'subtract', 'max',
-           'min', 'clip', 'unstack', 'cumulative_sum', 'sort', 'prod', 'sum',
+           'min', 'clip', 'unstack', 'cumulative_sum', 'cumulative_prod', 'sort', 'prod', 'sum',
            'any', 'all', 'mean', 'std', 'var', 'concat', 'squeeze',
            'broadcast_to', 'flip', 'roll', 'nonzero', 'where', 'reshape',
            'arange', 'eye', 'linspace', 'full', 'ones', 'zeros', 'empty',
@@ -764,6 +812,6 @@ def sign(x: array, /) -> array:
            'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
            'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
            'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype',
-           'take', 'sign']
+           'take', 'take_along_axis', 'sign']
 
 _all_ignore = ['torch', 'get_xp']
diff --git a/array_api_compat/torch/_info.py b/array_api_compat/torch/_info.py
index 264caa9e..34fbcb21 100644
--- a/array_api_compat/torch/_info.py
+++ b/array_api_compat/torch/_info.py
@@ -86,7 +86,7 @@ def capabilities(self):
             "boolean indexing": True,
             "data-dependent shapes": True,
-            # 'max rank' will be part of the 2024.12 standard
-            # "max rank": 64,
+            # 'max dimensions' is part of the 2024.12 standard
+            "max dimensions": 64,
         }
 
     def default_device(self):
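---

Reviewer note (not part of the patch): a minimal usage sketch of the new wrappers,
assuming numpy, torch, and this branch of array-api-compat are installed. Expected
outputs, shown in comments, follow the 2024.12 standard semantics implemented above.

    import array_api_compat.numpy as xp

    x = xp.asarray([1, 2, 3, 4])
    xp.cumulative_prod(x)                        # [ 1,  2,  6, 24]
    # include_initial=True prepends a single 1 along the cumulation axis
    xp.cumulative_prod(x, include_initial=True)  # [ 1,  1,  2,  6, 24]

    import torch
    import array_api_compat.torch as txp

    t = torch.tensor([[1, 0], [2, 3]])
    # keepdims is emulated on top of torch.count_nonzero, which lacks it
    txp.count_nonzero(t, axis=1, keepdims=True)  # tensor([[1], [2]])
    # thin wrapper over torch.take_along_dim
    txp.take_along_axis(t, torch.tensor([[1], [0]]), axis=1)  # tensor([[0], [2]])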