Skip to content

Commit 640e6cd

Browse files
authored
feat: add complex dtype support for mean (#850)
1 parent f7d16ff commit 640e6cd

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

src/array_api_stubs/_draft/statistical_functions.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def mean(
177177
Parameters
178178
----------
179179
x: array
180-
input array. Should have a real-valued floating-point data type.
180+
input array. Should have a floating-point data type.
181181
axis: Optional[Union[int, Tuple[int, ...]]]
182182
axis or axes along which arithmetic means must be computed. By default, the mean must be computed over the entire array. If a tuple of integers, arithmetic means must be computed over multiple axes. Default: ``None``.
183183
keepdims: bool
@@ -189,17 +189,26 @@ def mean(
189189
if the arithmetic mean was computed over the entire array, a zero-dimensional array containing the arithmetic mean; otherwise, a non-zero-dimensional array containing the arithmetic means. The returned array must have the same data type as ``x``.
190190
191191
.. note::
192-
While this specification recommends that this function only accept input arrays having a real-valued floating-point data type, specification-compliant array libraries may choose to accept input arrays having an integer data type. While mixed data type promotion is implementation-defined, if the input array ``x`` has an integer data type, the returned array must have the default real-valued floating-point data type.
192+
While this specification recommends that this function only accept input arrays having a floating-point data type, specification-compliant array libraries may choose to accept input arrays having an integer data type. While mixed data type promotion is implementation-defined, if the input array ``x`` has an integer data type, the returned array must have the default real-valued floating-point data type.
193193
194194
Notes
195195
-----
196196
197197
**Special Cases**
198198
199-
Let ``N`` equal the number of elements over which to compute the arithmetic mean.
199+
Let ``N`` equal the number of elements over which to compute the arithmetic mean. For real-valued operands,
200200
201201
- If ``N`` is ``0``, the arithmetic mean is ``NaN``.
202202
- If ``x_i`` is ``NaN``, the arithmetic mean is ``NaN`` (i.e., ``NaN`` values propagate).
203+
204+
For complex floating-point operands, real-valued floating-point special cases should independently apply to the real and imaginary component operations involving real numbers. For example, let ``a = real(x_i)`` and ``b = imag(x_i)``, and
205+
206+
- If ``N`` is ``0``, the arithmetic mean is ``NaN + NaN j``.
207+
- If ``a`` is ``NaN``, the real component of the result is ``NaN``.
208+
- Similarly, if ``b`` is ``NaN``, the imaginary component of the result is ``NaN``.
209+
210+
.. note::
211+
Array libraries, such as NumPy, PyTorch, and JAX, currently deviate from this specification in their handling of components which are ``NaN`` when computing the arithmetic mean. In general, consumers of array libraries implementing this specification should use :func:`~array_api.isnan` to test whether the result of computing the arithmetic mean over an array have a complex floating-point data type is ``NaN``, rather than relying on ``NaN`` propagation of individual components.
203212
"""
204213

205214

0 commit comments

Comments
 (0)