You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: src/array_api_stubs/_draft/statistical_functions.py
+12-3
Original file line number
Diff line number
Diff line change
@@ -177,7 +177,7 @@ def mean(
177
177
Parameters
178
178
----------
179
179
x: array
180
-
input array. Should have a real-valued floating-point data type.
180
+
input array. Should have a floating-point data type.
181
181
axis: Optional[Union[int, Tuple[int, ...]]]
182
182
axis or axes along which arithmetic means must be computed. By default, the mean must be computed over the entire array. If a tuple of integers, arithmetic means must be computed over multiple axes. Default: ``None``.
183
183
keepdims: bool
@@ -189,17 +189,26 @@ def mean(
189
189
if the arithmetic mean was computed over the entire array, a zero-dimensional array containing the arithmetic mean; otherwise, a non-zero-dimensional array containing the arithmetic means. The returned array must have the same data type as ``x``.
190
190
191
191
.. note::
192
-
While this specification recommends that this function only accept input arrays having a real-valued floating-point data type, specification-compliant array libraries may choose to accept input arrays having an integer data type. While mixed data type promotion is implementation-defined, if the input array ``x`` has an integer data type, the returned array must have the default real-valued floating-point data type.
192
+
While this specification recommends that this function only accept input arrays having a floating-point data type, specification-compliant array libraries may choose to accept input arrays having an integer data type. While mixed data type promotion is implementation-defined, if the input array ``x`` has an integer data type, the returned array must have the default real-valued floating-point data type.
193
193
194
194
Notes
195
195
-----
196
196
197
197
**Special Cases**
198
198
199
-
Let ``N`` equal the number of elements over which to compute the arithmetic mean.
199
+
Let ``N`` equal the number of elements over which to compute the arithmetic mean. For real-valued operands,
200
200
201
201
- If ``N`` is ``0``, the arithmetic mean is ``NaN``.
202
202
- If ``x_i`` is ``NaN``, the arithmetic mean is ``NaN`` (i.e., ``NaN`` values propagate).
203
+
204
+
For complex floating-point operands, real-valued floating-point special cases should independently apply to the real and imaginary component operations involving real numbers. For example, let ``a = real(x_i)`` and ``b = imag(x_i)``, and
205
+
206
+
- If ``N`` is ``0``, the arithmetic mean is ``NaN + NaN j``.
207
+
- If ``a`` is ``NaN``, the real component of the result is ``NaN``.
208
+
- Similarly, if ``b`` is ``NaN``, the imaginary component of the result is ``NaN``.
209
+
210
+
.. note::
211
+
Array libraries, such as NumPy, PyTorch, and JAX, currently deviate from this specification in their handling of components which are ``NaN`` when computing the arithmetic mean. In general, consumers of array libraries implementing this specification should use :func:`~array_api.isnan` to test whether the result of computing the arithmetic mean over an array have a complex floating-point data type is ``NaN``, rather than relying on ``NaN`` propagation of individual components.
0 commit comments