From e9baf91bdd2669d657dcb8f9ee10193ca9c252b3 Mon Sep 17 00:00:00 2001 From: Athan Reines Date: Mon, 25 Nov 2024 11:36:45 -0800 Subject: [PATCH 1/5] feat: add scalar support to element-wise functions --- .../_draft/elementwise_functions.py | 404 +++++++++++------- 1 file changed, 251 insertions(+), 153 deletions(-) diff --git a/src/array_api_stubs/_draft/elementwise_functions.py b/src/array_api_stubs/_draft/elementwise_functions.py index 156715200..8b8f4b2d8 100644 --- a/src/array_api_stubs/_draft/elementwise_functions.py +++ b/src/array_api_stubs/_draft/elementwise_functions.py @@ -272,15 +272,15 @@ def acosh(x: array, /) -> array: """ -def add(x1: array, x2: array, /) -> array: +def add(x1: Union[array, int, float], x2: Union[array, int, float], /) -> array: """ Calculates the sum for each element ``x1_i`` of the input array ``x1`` with the respective element ``x2_i`` of the input array ``x2``. Parameters ---------- - x1: array + x1: Union[array, int, float] first input array. Should have a numeric data type. - x2: array + x2: Union[array, int, float] second input array. Must be compatible with ``x1`` (see :ref:`broadcasting`). Should have a numeric data type. Returns @@ -291,6 +291,8 @@ def add(x1: array, x2: array, /) -> array: Notes ----- + - At least one of ``x1`` or ``x2`` must be an array. + **Special cases** For real-valued floating-point operands, @@ -514,7 +516,7 @@ def atan(x: array, /) -> array: """ -def atan2(x1: array, x2: array, /) -> array: +def atan2(x1: Union[array, float], x2: Union[array, float], /) -> array: """ Calculates an implementation-dependent approximation of the inverse tangent of the quotient ``x1/x2``, having domain ``[-infinity, +infinity] x [-infinity, +infinity]`` (where the ``x`` notation denotes the set of ordered pairs of elements ``(x1_i, x2_i)``) and codomain ``[-π, +π]``, for each pair of elements ``(x1_i, x2_i)`` of the input arrays ``x1`` and ``x2``, respectively. Each element-wise result is expressed in radians. @@ -527,9 +529,9 @@ def atan2(x1: array, x2: array, /) -> array: Parameters ---------- - x1: array + x1: Union[array, float] input array corresponding to the y-coordinates. Should have a real-valued floating-point data type. - x2: array + x2: Union[array, float] input array corresponding to the x-coordinates. Must be compatible with ``x1`` (see :ref:`broadcasting`). Should have a real-valued floating-point data type. Returns @@ -540,6 +542,8 @@ def atan2(x1: array, x2: array, /) -> array: Notes ----- + - At least one of ``x1`` or ``x2`` must be an array. + **Special cases** For floating-point operands, @@ -639,39 +643,51 @@ def atanh(x: array, /) -> array: """ -def bitwise_and(x1: array, x2: array, /) -> array: +def bitwise_and(x1: Union[array, int, bool], x2: Union[array, int, bool], /) -> array: """ Computes the bitwise AND of the underlying binary representation of each element ``x1_i`` of the input array ``x1`` with the respective element ``x2_i`` of the input array ``x2``. Parameters ---------- - x1: array + x1: Union[array, int, bool] first input array. Should have an integer or boolean data type. - x2: array + x2: Union[array, int, bool] second input array. Must be compatible with ``x1`` (see :ref:`broadcasting`). Should have an integer or boolean data type. Returns ------- out: array an array containing the element-wise results. The returned array must have a data type determined by :ref:`type-promotion`. + + Notes + ----- + + - At least one of ``x1`` or ``x2`` must be an array. """ -def bitwise_left_shift(x1: array, x2: array, /) -> array: +def bitwise_left_shift( + x1: Union[array, int, bool], x2: Union[array, int, bool], / +) -> array: """ Shifts the bits of each element ``x1_i`` of the input array ``x1`` to the left by appending ``x2_i`` (i.e., the respective element in the input array ``x2``) zeros to the right of ``x1_i``. Parameters ---------- - x1: array + x1: Union[array, int, bool] first input array. Should have an integer data type. - x2: array + x2: Union[array, int, bool] second input array. Must be compatible with ``x1`` (see :ref:`broadcasting`). Should have an integer data type. Each element must be greater than or equal to ``0``. Returns ------- out: array an array containing the element-wise results. The returned array must have a data type determined by :ref:`type-promotion`. + + Notes + ----- + + - At least one of ``x1`` or ``x2`` must be an array. """ @@ -691,25 +707,32 @@ def bitwise_invert(x: array, /) -> array: """ -def bitwise_or(x1: array, x2: array, /) -> array: +def bitwise_or(x1: Union[array, int, bool], x2: Union[array, int, bool], /) -> array: """ Computes the bitwise OR of the underlying binary representation of each element ``x1_i`` of the input array ``x1`` with the respective element ``x2_i`` of the input array ``x2``. Parameters ---------- - x1: array + x1: Union[array, int, bool] first input array. Should have an integer or boolean data type. - x2: array + x2: Union[array, int, bool] second input array. Must be compatible with ``x1`` (see :ref:`broadcasting`). Should have an integer or boolean data type. Returns ------- out: array an array containing the element-wise results. The returned array must have a data type determined by :ref:`type-promotion`. + + Notes + ----- + + - At least one of ``x1`` or ``x2`` must be an array. """ -def bitwise_right_shift(x1: array, x2: array, /) -> array: +def bitwise_right_shift( + x1: Union[array, int, bool], x2: Union[array, int, bool], / +) -> array: """ Shifts the bits of each element ``x1_i`` of the input array ``x1`` to the right according to the respective element ``x2_i`` of the input array ``x2``. @@ -718,33 +741,43 @@ def bitwise_right_shift(x1: array, x2: array, /) -> array: Parameters ---------- - x1: array + x1: Union[array, int, bool] first input array. Should have an integer data type. - x2: array + x2: Union[array, int, bool] second input array. Must be compatible with ``x1`` (see :ref:`broadcasting`). Should have an integer data type. Each element must be greater than or equal to ``0``. Returns ------- out: array an array containing the element-wise results. The returned array must have a data type determined by :ref:`type-promotion`. + + Notes + ----- + + - At least one of ``x1`` or ``x2`` must be an array. """ -def bitwise_xor(x1: array, x2: array, /) -> array: +def bitwise_xor(x1: Union[array, int, bool], x2: Union[array, int, bool], /) -> array: """ Computes the bitwise XOR of the underlying binary representation of each element ``x1_i`` of the input array ``x1`` with the respective element ``x2_i`` of the input array ``x2``. Parameters ---------- - x1: array + x1: Union[array, int, bool] first input array. Should have an integer or boolean data type. - x2: array + x2: Union[array, int, bool] second input array. Must be compatible with ``x1`` (see :ref:`broadcasting`). Should have an integer or boolean data type. Returns ------- out: array an array containing the element-wise results. The returned array must have a data type determined by :ref:`type-promotion`. + + Notes + ----- + + - At least one of ``x1`` or ``x2`` must be an array. """ @@ -852,15 +885,15 @@ def conj(x: array, /) -> array: """ -def copysign(x1: array, x2: array, /) -> array: +def copysign(x1: Union[array, float], x2: Union[array, float], /) -> array: r""" Composes a floating-point value with the magnitude of ``x1_i`` and the sign of ``x2_i`` for each element of the input array ``x1``. Parameters ---------- - x1: array + x1: Union[array, float] input array containing magnitudes. Should have a real-valued floating-point data type. - x2: array + x2: Union[array, float] input array whose sign bits are applied to the magnitudes of ``x1``. Must be compatible with ``x1`` (see :ref:`broadcasting`). Should have a real-valued floating-point data type. Returns @@ -871,6 +904,8 @@ def copysign(x1: array, x2: array, /) -> array: Notes ----- + - At least one of ``x1`` or ``x2`` must be an array. + **Special cases** For real-valued floating-point operands, let ``|x|`` be the absolute value, and if ``x1_i`` is not ``NaN``, @@ -1003,7 +1038,9 @@ def cosh(x: array, /) -> array: """ -def divide(x1: array, x2: array, /) -> array: +def divide( + x1: Union[array, int, float, complex], x2: Union[array, int, float, complex], / +) -> array: r""" Calculates the division of each element ``x1_i`` of the input array ``x1`` with the respective element ``x2_i`` of the input array ``x2``. @@ -1014,9 +1051,9 @@ def divide(x1: array, x2: array, /) -> array: Parameters ---------- - x1: array + x1: Union[array, int, float, complex] dividend input array. Should have a numeric data type. - x2: array + x2: Union[array, int, float, complex] divisor input array. Must be compatible with ``x1`` (see :ref:`broadcasting`). Should have a numeric data type. Returns @@ -1027,6 +1064,8 @@ def divide(x1: array, x2: array, /) -> array: Notes ----- + - At least one of ``x1`` or ``x2`` must be an array. + **Special cases** For real-valued floating-point operands, @@ -1086,15 +1125,19 @@ def divide(x1: array, x2: array, /) -> array: """ -def equal(x1: array, x2: array, /) -> array: +def equal( + x1: Union[array, int, float, complex, bool], + x2: Union[array, int, float, complex, bool], + /, +) -> array: r""" Computes the truth value of ``x1_i == x2_i`` for each element ``x1_i`` of the input array ``x1`` with the respective element ``x2_i`` of the input array ``x2``. Parameters ---------- - x1: array + x1: Union[array, int, float, complex, bool] first input array. May have any data type. - x2: array + x2: Union[array, int, float, complex, bool] second input array. Must be compatible with ``x1`` (see :ref:`broadcasting`). May have any data type. Returns @@ -1105,6 +1148,8 @@ def equal(x1: array, x2: array, /) -> array: Notes ----- + - At least one of ``x1`` or ``x2`` must be an array. + **Special Cases** For real-valued floating-point operands, @@ -1279,7 +1324,9 @@ def floor(x: array, /) -> array: """ -def floor_divide(x1: array, x2: array, /) -> array: +def floor_divide( + x1: Union[array, int, float], x2: Union[array, int, float], / +) -> array: r""" Rounds the result of dividing each element ``x1_i`` of the input array ``x1`` by the respective element ``x2_i`` of the input array ``x2`` to the greatest (i.e., closest to `+infinity`) integer-value number that is not greater than the division result. @@ -1288,9 +1335,9 @@ def floor_divide(x1: array, x2: array, /) -> array: Parameters ---------- - x1: array + x1: Union[array, int, float] dividend input array. Should have a real-valued data type. - x2: array + x2: Union[array, int, float] divisor input array. Must be compatible with ``x1`` (see :ref:`broadcasting`). Should have a real-valued data type. Returns @@ -1301,6 +1348,8 @@ def floor_divide(x1: array, x2: array, /) -> array: Notes ----- + - At least one of ``x1`` or ``x2`` must be an array. + **Special cases** .. note:: @@ -1339,7 +1388,7 @@ def floor_divide(x1: array, x2: array, /) -> array: """ -def greater(x1: array, x2: array, /) -> array: +def greater(x1: Union[array, int, float], x2: Union[array, int, float], /) -> array: """ Computes the truth value of ``x1_i > x2_i`` for each element ``x1_i`` of the input array ``x1`` with the respective element ``x2_i`` of the input array ``x2``. @@ -1348,9 +1397,9 @@ def greater(x1: array, x2: array, /) -> array: Parameters ---------- - x1: array + x1: Union[array, int, float] first input array. Should have a real-valued data type. - x2: array + x2: Union[array, int, float] second input array. Must be compatible with ``x1`` (see :ref:`broadcasting`). Should have a real-valued data type. Returns @@ -1358,13 +1407,18 @@ def greater(x1: array, x2: array, /) -> array: out: array an array containing the element-wise results. The returned array must have a data type of ``bool``. - .. note:: - Comparison of arrays without a corresponding promotable data type (see :ref:`type-promotion`) is undefined and thus implementation-dependent. + Notes + ----- + + - At least one of ``x1`` or ``x2`` must be an array. + - Comparison of arrays without a corresponding promotable data type (see :ref:`type-promotion`) is undefined and thus implementation-dependent. """ -def greater_equal(x1: array, x2: array, /) -> array: +def greater_equal( + x1: Union[array, int, float], x2: Union[array, int, float], / +) -> array: """ Computes the truth value of ``x1_i >= x2_i`` for each element ``x1_i`` of the input array ``x1`` with the respective element ``x2_i`` of the input array ``x2``. @@ -1373,9 +1427,9 @@ def greater_equal(x1: array, x2: array, /) -> array: Parameters ---------- - x1: array + x1: Union[array, int, float] first input array. Should have a real-valued data type. - x2: array + x2: Union[array, int, float] second input array. Must be compatible with ``x1`` (see :ref:`broadcasting`). Should have a real-valued data type. Returns @@ -1383,12 +1437,15 @@ def greater_equal(x1: array, x2: array, /) -> array: out: array an array containing the element-wise results. The returned array must have a data type of ``bool``. - .. note:: - Comparison of arrays without a corresponding promotable data type (see :ref:`type-promotion`) is undefined and thus implementation-dependent. + Notes + ----- + + - At least one of ``x1`` or ``x2`` must be an array. + - Comparison of arrays without a corresponding promotable data type (see :ref:`type-promotion`) is undefined and thus implementation-dependent. """ -def hypot(x1: array, x2: array, /) -> array: +def hypot(x1: Union[array, float], x2: Union[array, float], /) -> array: r""" Computes the square root of the sum of squares for each element ``x1_i`` of the input array ``x1`` with the respective element ``x2_i`` of the input array ``x2``. @@ -1397,9 +1454,9 @@ def hypot(x1: array, x2: array, /) -> array: Parameters ---------- - x1: array + x1: Union[array, float] first input array. Should have a real-valued floating-point data type. - x2: array + x2: Union[array, float] second input array. Must be compatible with ``x1`` (see :ref:`broadcasting`). Should have a real-valued floating-point data type. Returns @@ -1410,7 +1467,8 @@ def hypot(x1: array, x2: array, /) -> array: Notes ----- - The purpose of this function is to avoid underflow and overflow during intermediate stages of computation. Accordingly, conforming implementations should not use naive implementations. + - At least one of ``x1`` or ``x2`` must be an array. + - The purpose of this function is to avoid underflow and overflow during intermediate stages of computation. Accordingly, conforming implementations should not use naive implementations. **Special Cases** @@ -1562,7 +1620,7 @@ def isnan(x: array, /) -> array: """ -def less(x1: array, x2: array, /) -> array: +def less(x1: Union[array, int, float], x2: Union[array, int, float], /) -> array: """ Computes the truth value of ``x1_i < x2_i`` for each element ``x1_i`` of the input array ``x1`` with the respective element ``x2_i`` of the input array ``x2``. @@ -1571,9 +1629,9 @@ def less(x1: array, x2: array, /) -> array: Parameters ---------- - x1: array + x1: Union[array, int, float] first input array. Should have a real-valued data type. - x2: array + x2: Union[array, int, float] second input array. Must be compatible with ``x1`` (see :ref:`broadcasting`). Should have a real-valued data type. Returns @@ -1581,12 +1639,15 @@ def less(x1: array, x2: array, /) -> array: out: array an array containing the element-wise results. The returned array must have a data type of ``bool``. - .. note:: - Comparison of arrays without a corresponding promotable data type (see :ref:`type-promotion`) is undefined and thus implementation-dependent. + Notes + ----- + + - At least one of ``x1`` or ``x2`` must be an array. + - Comparison of arrays without a corresponding promotable data type (see :ref:`type-promotion`) is undefined and thus implementation-dependent. """ -def less_equal(x1: array, x2: array, /) -> array: +def less_equal(x1: Union[array, int, float], x2: Union[array, int, float], /) -> array: """ Computes the truth value of ``x1_i <= x2_i`` for each element ``x1_i`` of the input array ``x1`` with the respective element ``x2_i`` of the input array ``x2``. @@ -1595,9 +1656,9 @@ def less_equal(x1: array, x2: array, /) -> array: Parameters ---------- - x1: array + x1: Union[array, int, float] first input array. Should have a real-valued data type. - x2: array + x2: Union[array, int, float] second input array. Must be compatible with ``x1`` (see :ref:`broadcasting`). Should have a real-valued data type. Returns @@ -1605,8 +1666,11 @@ def less_equal(x1: array, x2: array, /) -> array: out: array an array containing the element-wise results. The returned array must have a data type of ``bool``. - .. note:: - Comparison of arrays without a corresponding promotable data type (see :ref:`type-promotion`) is undefined and thus implementation-dependent. + Notes + ----- + + - At least one of ``x1`` or ``x2`` must be an array. + - Comparison of arrays without a corresponding promotable data type (see :ref:`type-promotion`) is undefined and thus implementation-dependent. """ @@ -1818,15 +1882,15 @@ def log10(x: array, /) -> array: """ -def logaddexp(x1: array, x2: array, /) -> array: +def logaddexp(x1: Union[array, float], x2: Union[array, float], /) -> array: """ Calculates the logarithm of the sum of exponentiations ``log(exp(x1) + exp(x2))`` for each element ``x1_i`` of the input array ``x1`` with the respective element ``x2_i`` of the input array ``x2``. Parameters ---------- - x1: array + x1: Union[array, float] first input array. Should have a real-valued floating-point data type. - x2: array + x2: Union[array, float] second input array. Must be compatible with ``x1`` (see :ref:`broadcasting`). Should have a real-valued floating-point data type. Returns @@ -1837,6 +1901,8 @@ def logaddexp(x1: array, x2: array, /) -> array: Notes ----- + - At least one of ``x1`` or ``x2`` must be an array. + **Special cases** For floating-point operands, @@ -1847,7 +1913,7 @@ def logaddexp(x1: array, x2: array, /) -> array: """ -def logical_and(x1: array, x2: array, /) -> array: +def logical_and(x1: Union[array, bool], x2: Union[array, bool], /) -> array: """ Computes the logical AND for each element ``x1_i`` of the input array ``x1`` with the respective element ``x2_i`` of the input array ``x2``. @@ -1856,15 +1922,20 @@ def logical_and(x1: array, x2: array, /) -> array: Parameters ---------- - x1: array + x1: Union[array, bool] first input array. Should have a boolean data type. - x2: array + x2: Union[array, bool] second input array. Must be compatible with ``x1`` (see :ref:`broadcasting`). Should have a boolean data type. Returns ------- out: array an array containing the element-wise results. The returned array must have a data type of `bool`. + + Notes + ----- + + - At least one of ``x1`` or ``x2`` must be an array. """ @@ -1887,7 +1958,7 @@ def logical_not(x: array, /) -> array: """ -def logical_or(x1: array, x2: array, /) -> array: +def logical_or(x1: Union[array, bool], x2: Union[array, bool], /) -> array: """ Computes the logical OR for each element ``x1_i`` of the input array ``x1`` with the respective element ``x2_i`` of the input array ``x2``. @@ -1896,19 +1967,24 @@ def logical_or(x1: array, x2: array, /) -> array: Parameters ---------- - x1: array + x1: Union[array, bool] first input array. Should have a boolean data type. - x2: array + x2: Union[array, bool] second input array. Must be compatible with ``x1`` (see :ref:`broadcasting`). Should have a boolean data type. Returns ------- out: array an array containing the element-wise results. The returned array must have a data type of ``bool``. + + Notes + ----- + + - At least one of ``x1`` or ``x2`` must be an array. """ -def logical_xor(x1: array, x2: array, /) -> array: +def logical_xor(x1: Union[array, bool], x2: Union[array, bool], /) -> array: """ Computes the logical XOR for each element ``x1_i`` of the input array ``x1`` with the respective element ``x2_i`` of the input array ``x2``. @@ -1917,27 +1993,32 @@ def logical_xor(x1: array, x2: array, /) -> array: Parameters ---------- - x1: array + x1: Union[array, bool] first input array. Should have a boolean data type. - x2: array + x2: Union[array, bool] second input array. Must be compatible with ``x1`` (see :ref:`broadcasting`). Should have a boolean data type. Returns ------- out: array an array containing the element-wise results. The returned array must have a data type of ``bool``. + + Notes + ----- + + - At least one of ``x1`` or ``x2`` must be an array. """ -def maximum(x1: array, x2: array, /) -> array: +def maximum(x1: Union[array, int, float], x2: Union[array, int, float], /) -> array: r""" Computes the maximum value for each element ``x1_i`` of the input array ``x1`` relative to the respective element ``x2_i`` of the input array ``x2``. Parameters ---------- - x1: array + x1: Union[array, int, float] first input array. Should have a real-valued data type. - x2: array + x2: Union[array, int, float] second input array. Must be compatible with ``x1`` (see :ref:`broadcasting`). Should have a real-valued data type. Returns @@ -1948,9 +2029,9 @@ def maximum(x1: array, x2: array, /) -> array: Notes ----- - The order of signed zeros is unspecified and thus implementation-defined. When choosing between ``-0`` or ``+0`` as a maximum value, specification-compliant libraries may choose to return either value. - - For backward compatibility, conforming implementations may support complex numbers; however, inequality comparison of complex numbers is unspecified and thus implementation-defined (see :ref:`complex-number-ordering`). + - At least one of ``x1`` or ``x2`` must be an array. + - The order of signed zeros is unspecified and thus implementation-defined. When choosing between ``-0`` or ``+0`` as a maximum value, specification-compliant libraries may choose to return either value. + - For backward compatibility, conforming implementations may support complex numbers; however, inequality comparison of complex numbers is unspecified and thus implementation-defined (see :ref:`complex-number-ordering`). **Special Cases** @@ -1962,15 +2043,15 @@ def maximum(x1: array, x2: array, /) -> array: """ -def minimum(x1: array, x2: array, /) -> array: +def minimum(x1: Union[array, int, float], x2: Union[array, int, float], /) -> array: r""" Computes the minimum value for each element ``x1_i`` of the input array ``x1`` relative to the respective element ``x2_i`` of the input array ``x2``. Parameters ---------- - x1: array + x1: Union[array, int, float] first input array. Should have a real-valued data type. - x2: array + x2: Union[array, int, float] second input array. Must be compatible with ``x1`` (see :ref:`broadcasting`). Should have a real-valued data type. Returns @@ -1981,9 +2062,9 @@ def minimum(x1: array, x2: array, /) -> array: Notes ----- - The order of signed zeros is unspecified and thus implementation-defined. When choosing between ``-0`` or ``+0`` as a minimum value, specification-compliant libraries may choose to return either value. - - For backward compatibility, conforming implementations may support complex numbers; however, inequality comparison of complex numbers is unspecified and thus implementation-defined (see :ref:`complex-number-ordering`). + - At least one of ``x1`` or ``x2`` must be an array. + - The order of signed zeros is unspecified and thus implementation-defined. When choosing between ``-0`` or ``+0`` as a minimum value, specification-compliant libraries may choose to return either value. + - For backward compatibility, conforming implementations may support complex numbers; however, inequality comparison of complex numbers is unspecified and thus implementation-defined (see :ref:`complex-number-ordering`). **Special Cases** @@ -1995,7 +2076,9 @@ def minimum(x1: array, x2: array, /) -> array: """ -def multiply(x1: array, x2: array, /) -> array: +def multiply( + x1: Union[array, int, float, complex], x2: Union[array, int, float, complex], / +) -> array: r""" Calculates the product for each element ``x1_i`` of the input array ``x1`` with the respective element ``x2_i`` of the input array ``x2``. @@ -2004,9 +2087,9 @@ def multiply(x1: array, x2: array, /) -> array: Parameters ---------- - x1: array + x1: Union[array, int, float, complex] first input array. Should have a numeric data type. - x2: array + x2: Union[array, int, float, complex] second input array. Must be compatible with ``x1`` (see :ref:`broadcasting`). Should have a numeric data type. Returns @@ -2017,6 +2100,8 @@ def multiply(x1: array, x2: array, /) -> array: Notes ----- + - At least one of ``x1`` or ``x2`` must be an array. + **Special cases** For real-valued floating-point operands, @@ -2091,15 +2176,15 @@ def negative(x: array, /) -> array: """ -def nextafter(x1: array, x2: array, /) -> array: +def nextafter(x1: Union[array, float], x2: Union[array, float], /) -> array: """ Returns the next representable floating-point value for each element ``x1_i`` of the input array ``x1`` in the direction of the respective element ``x2_i`` of the input array ``x2``. Parameters ---------- - x1: array + x1: Union[array, float] first input array. Should have a real-valued floating-point data type. - x2: array + x2: Union[array, float] second input array. Must be compatible with ``x1`` (see :ref:`broadcasting`). Should have the same data type as ``x1``. Returns @@ -2110,6 +2195,8 @@ def nextafter(x1: array, x2: array, /) -> array: Notes ----- + - At least one of ``x1`` or ``x2`` must be an array. + **Special cases** For real-valued floating-point operands, @@ -2120,15 +2207,19 @@ def nextafter(x1: array, x2: array, /) -> array: """ -def not_equal(x1: array, x2: array, /) -> array: +def not_equal( + x1: Union[array, int, float, complex, bool], + x2: Union[array, int, float, complex, bool], + /, +) -> array: """ Computes the truth value of ``x1_i != x2_i`` for each element ``x1_i`` of the input array ``x1`` with the respective element ``x2_i`` of the input array ``x2``. Parameters ---------- - x1: array + x1: Union[array, int, float, complex, bool] first input array. May have any data type. - x2: array + x2: Union[array, int, float, complex, bool] second input array. Must be compatible with ``x1`` (see :ref:`broadcasting`). Returns @@ -2139,6 +2230,8 @@ def not_equal(x1: array, x2: array, /) -> array: Notes ----- + - At least one of ``x1`` or ``x2`` must be an array. + **Special Cases** For real-valued floating-point operands, @@ -2187,73 +2280,75 @@ def positive(x: array, /) -> array: """ -def pow(x1: array, x2: array, /) -> array: +def pow( + x1: Union[array, int, float, complex], x2: Union[array, int, float, complex], / +) -> array: r""" - Calculates an implementation-dependent approximation of exponentiation by raising each element ``x1_i`` (the base) of the input array ``x1`` to the power of ``x2_i`` (the exponent), where ``x2_i`` is the corresponding element of the input array ``x2``. + Calculates an implementation-dependent approximation of exponentiation by raising each element ``x1_i`` (the base) of the input array ``x1`` to the power of ``x2_i`` (the exponent), where ``x2_i`` is the corresponding element of the input array ``x2``. - .. note:: - If both ``x1`` and ``x2`` have integer data types, the result of ``pow`` when ``x2_i`` is negative (i.e., less than zero) is unspecified and thus implementation-dependent. + Parameters + ---------- + x1: Union[array, int, float, complex] + first input array whose elements correspond to the exponentiation base. Should have a numeric data type. + x2: Union[array, int, float, complex] + second input array whose elements correspond to the exponentiation exponent. Must be compatible with ``x1`` (see :ref:`broadcasting`). Should have a numeric data type. - If ``x1`` has an integer data type and ``x2`` has a floating-point data type, behavior is implementation-dependent (type promotion between data type "kinds" (integer versus floating-point) is unspecified). + Returns + ------- + out: array + an array containing the element-wise results. The returned array must have a data type determined by :ref:`type-promotion`. - .. note:: - By convention, the branch cut of the natural logarithm is the negative real axis :math:`(-\infty, 0)`. + Notes + ----- - The natural logarithm is a continuous function from above the branch cut, taking into account the sign of the imaginary component. As special cases involving complex floating-point operands should be handled according to ``exp(x2*log(x1))``, exponentiation has the same branch cut for ``x1`` as the natural logarithm (see :func:`~array_api.log`). + - At least one of ``x1`` or ``x2`` must be an array. - *Note: branch cuts follow C99 and have provisional status* (see :ref:`branch-cuts`). + - If both ``x1`` and ``x2`` have integer data types, the result of ``pow`` when ``x2_i`` is negative (i.e., less than zero) is unspecified and thus implementation-dependent. - Parameters - ---------- - x1: array - first input array whose elements correspond to the exponentiation base. Should have a numeric data type. - x2: array - second input array whose elements correspond to the exponentiation exponent. Must be compatible with ``x1`` (see :ref:`broadcasting`). Should have a numeric data type. + - If ``x1`` has an integer data type and ``x2`` has a floating-point data type, behavior is implementation-dependent (type promotion between data type "kinds" (integer versus floating-point) is unspecified). - Returns - ------- - out: array - an array containing the element-wise results. The returned array must have a data type determined by :ref:`type-promotion`. + - By convention, the branch cut of the natural logarithm is the negative real axis :math:`(-\infty, 0)`. - Notes - ----- + The natural logarithm is a continuous function from above the branch cut, taking into account the sign of the imaginary component. As special cases involving complex floating-point operands should be handled according to ``exp(x2*log(x1))``, exponentiation has the same branch cut for ``x1`` as the natural logarithm (see :func:`~array_api.log`). - **Special cases** + *Note: branch cuts follow C99 and have provisional status* (see :ref:`branch-cuts`). - For real-valued floating-point operands, + **Special cases** - - If ``x1_i`` is not equal to ``1`` and ``x2_i`` is ``NaN``, the result is ``NaN``. - - If ``x2_i`` is ``+0``, the result is ``1``, even if ``x1_i`` is ``NaN``. - - If ``x2_i`` is ``-0``, the result is ``1``, even if ``x1_i`` is ``NaN``. - - If ``x1_i`` is ``NaN`` and ``x2_i`` is not equal to ``0``, the result is ``NaN``. - - If ``abs(x1_i)`` is greater than ``1`` and ``x2_i`` is ``+infinity``, the result is ``+infinity``. - - If ``abs(x1_i)`` is greater than ``1`` and ``x2_i`` is ``-infinity``, the result is ``+0``. - - If ``abs(x1_i)`` is ``1`` and ``x2_i`` is ``+infinity``, the result is ``1``. - - If ``abs(x1_i)`` is ``1`` and ``x2_i`` is ``-infinity``, the result is ``1``. - - If ``x1_i`` is ``1`` and ``x2_i`` is not ``NaN``, the result is ``1``. - - If ``abs(x1_i)`` is less than ``1`` and ``x2_i`` is ``+infinity``, the result is ``+0``. - - If ``abs(x1_i)`` is less than ``1`` and ``x2_i`` is ``-infinity``, the result is ``+infinity``. - - If ``x1_i`` is ``+infinity`` and ``x2_i`` is greater than ``0``, the result is ``+infinity``. - - If ``x1_i`` is ``+infinity`` and ``x2_i`` is less than ``0``, the result is ``+0``. - - If ``x1_i`` is ``-infinity``, ``x2_i`` is greater than ``0``, and ``x2_i`` is an odd integer value, the result is ``-infinity``. - - If ``x1_i`` is ``-infinity``, ``x2_i`` is greater than ``0``, and ``x2_i`` is not an odd integer value, the result is ``+infinity``. - - If ``x1_i`` is ``-infinity``, ``x2_i`` is less than ``0``, and ``x2_i`` is an odd integer value, the result is ``-0``. - - If ``x1_i`` is ``-infinity``, ``x2_i`` is less than ``0``, and ``x2_i`` is not an odd integer value, the result is ``+0``. - - If ``x1_i`` is ``+0`` and ``x2_i`` is greater than ``0``, the result is ``+0``. - - If ``x1_i`` is ``+0`` and ``x2_i`` is less than ``0``, the result is ``+infinity``. - - If ``x1_i`` is ``-0``, ``x2_i`` is greater than ``0``, and ``x2_i`` is an odd integer value, the result is ``-0``. - - If ``x1_i`` is ``-0``, ``x2_i`` is greater than ``0``, and ``x2_i`` is not an odd integer value, the result is ``+0``. - - If ``x1_i`` is ``-0``, ``x2_i`` is less than ``0``, and ``x2_i`` is an odd integer value, the result is ``-infinity``. - - If ``x1_i`` is ``-0``, ``x2_i`` is less than ``0``, and ``x2_i`` is not an odd integer value, the result is ``+infinity``. - - If ``x1_i`` is less than ``0``, ``x1_i`` is a finite number, ``x2_i`` is a finite number, and ``x2_i`` is not an integer value, the result is ``NaN``. + For real-valued floating-point operands, - For complex floating-point operands, special cases should be handled as if the operation is implemented as ``exp(x2*log(x1))``. + - If ``x1_i`` is not equal to ``1`` and ``x2_i`` is ``NaN``, the result is ``NaN``. + - If ``x2_i`` is ``+0``, the result is ``1``, even if ``x1_i`` is ``NaN``. + - If ``x2_i`` is ``-0``, the result is ``1``, even if ``x1_i`` is ``NaN``. + - If ``x1_i`` is ``NaN`` and ``x2_i`` is not equal to ``0``, the result is ``NaN``. + - If ``abs(x1_i)`` is greater than ``1`` and ``x2_i`` is ``+infinity``, the result is ``+infinity``. + - If ``abs(x1_i)`` is greater than ``1`` and ``x2_i`` is ``-infinity``, the result is ``+0``. + - If ``abs(x1_i)`` is ``1`` and ``x2_i`` is ``+infinity``, the result is ``1``. + - If ``abs(x1_i)`` is ``1`` and ``x2_i`` is ``-infinity``, the result is ``1``. + - If ``x1_i`` is ``1`` and ``x2_i`` is not ``NaN``, the result is ``1``. + - If ``abs(x1_i)`` is less than ``1`` and ``x2_i`` is ``+infinity``, the result is ``+0``. + - If ``abs(x1_i)`` is less than ``1`` and ``x2_i`` is ``-infinity``, the result is ``+infinity``. + - If ``x1_i`` is ``+infinity`` and ``x2_i`` is greater than ``0``, the result is ``+infinity``. + - If ``x1_i`` is ``+infinity`` and ``x2_i`` is less than ``0``, the result is ``+0``. + - If ``x1_i`` is ``-infinity``, ``x2_i`` is greater than ``0``, and ``x2_i`` is an odd integer value, the result is ``-infinity``. + - If ``x1_i`` is ``-infinity``, ``x2_i`` is greater than ``0``, and ``x2_i`` is not an odd integer value, the result is ``+infinity``. + - If ``x1_i`` is ``-infinity``, ``x2_i`` is less than ``0``, and ``x2_i`` is an odd integer value, the result is ``-0``. + - If ``x1_i`` is ``-infinity``, ``x2_i`` is less than ``0``, and ``x2_i`` is not an odd integer value, the result is ``+0``. + - If ``x1_i`` is ``+0`` and ``x2_i`` is greater than ``0``, the result is ``+0``. + - If ``x1_i`` is ``+0`` and ``x2_i`` is less than ``0``, the result is ``+infinity``. + - If ``x1_i`` is ``-0``, ``x2_i`` is greater than ``0``, and ``x2_i`` is an odd integer value, the result is ``-0``. + - If ``x1_i`` is ``-0``, ``x2_i`` is greater than ``0``, and ``x2_i`` is not an odd integer value, the result is ``+0``. + - If ``x1_i`` is ``-0``, ``x2_i`` is less than ``0``, and ``x2_i`` is an odd integer value, the result is ``-infinity``. + - If ``x1_i`` is ``-0``, ``x2_i`` is less than ``0``, and ``x2_i`` is not an odd integer value, the result is ``+infinity``. + - If ``x1_i`` is less than ``0``, ``x1_i`` is a finite number, ``x2_i`` is a finite number, and ``x2_i`` is not an integer value, the result is ``NaN``. - .. note:: - Conforming implementations are allowed to treat special cases involving complex floating-point operands more carefully than as described in this specification. + For complex floating-point operands, special cases should be handled as if the operation is implemented as ``exp(x2*log(x1))``. - .. versionchanged:: 2022.12 - Added complex data type support. + .. note:: + Conforming implementations are allowed to treat special cases involving complex floating-point operands more carefully than as described in this specification. + + .. versionchanged:: 2022.12 + Added complex data type support. """ @@ -2301,21 +2396,18 @@ def reciprocal(x: array, /) -> array: """ -def remainder(x1: array, x2: array, /) -> array: +def remainder(x1: Union[array, int, float], x2: Union[array, int, float], /) -> array: """ Returns the remainder of division for each element ``x1_i`` of the input array ``x1`` and the respective element ``x2_i`` of the input array ``x2``. .. note:: This function is equivalent to the Python modulus operator ``x1_i % x2_i``. - .. note:: - For input arrays which promote to an integer data type, the result of division by zero is unspecified and thus implementation-defined. - Parameters ---------- - x1: array + x1: Union[array, int, float] dividend input array. Should have a real-valued data type. - x2: array + x2: Union[array, int, float] divisor input array. Must be compatible with ``x1`` (see :ref:`broadcasting`). Should have a real-valued data type. Returns @@ -2326,6 +2418,9 @@ def remainder(x1: array, x2: array, /) -> array: Notes ----- + - At least one of ``x1`` or ``x2`` must be an array. + - For input arrays which promote to an integer data type, the result of division by zero is unspecified and thus implementation-defined. + **Special cases** .. note:: @@ -2681,17 +2776,17 @@ def sqrt(x: array, /) -> array: """ -def subtract(x1: array, x2: array, /) -> array: +def subtract( + x1: Union[array, int, float, complex], x2: Union[array, int, float, complex], / +) -> array: """ Calculates the difference for each element ``x1_i`` of the input array ``x1`` with the respective element ``x2_i`` of the input array ``x2``. - The result of ``x1_i - x2_i`` must be the same as ``x1_i + (-x2_i)`` and must be governed by the same floating-point rules as addition (see :meth:`add`). - Parameters ---------- - x1: array + x1: Union[array, int, float, complex] first input array. Should have a numeric data type. - x2: array + x2: Union[array, int, float, complex] second input array. Must be compatible with ``x1`` (see :ref:`broadcasting`). Should have a numeric data type. Returns @@ -2702,6 +2797,9 @@ def subtract(x1: array, x2: array, /) -> array: Notes ----- + - At least one of ``x1`` or ``x2`` must be an array. + - The result of ``x1_i - x2_i`` must be the same as ``x1_i + (-x2_i)`` and must be governed by the same floating-point rules as addition (see :meth:`add`). + .. versionchanged:: 2022.12 Added complex data type support. """ From 1e1182886e013159cc5fc62bc36d2b1744c0888b Mon Sep 17 00:00:00 2001 From: Athan Reines Date: Mon, 25 Nov 2024 11:52:03 -0800 Subject: [PATCH 2/5] fix: update types and move notes --- src/array_api_stubs/_draft/array_object.py | 175 ++++++++---------- .../_draft/elementwise_functions.py | 45 ++--- 2 files changed, 91 insertions(+), 129 deletions(-) diff --git a/src/array_api_stubs/_draft/array_object.py b/src/array_api_stubs/_draft/array_object.py index 3c6fa8763..01f067895 100644 --- a/src/array_api_stubs/_draft/array_object.py +++ b/src/array_api_stubs/_draft/array_object.py @@ -494,7 +494,7 @@ def __dlpack_device__(self: array, /) -> Tuple[Enum, int]: ONE_API = 14 """ - def __eq__(self: array, other: Union[int, float, bool, array], /) -> array: + def __eq__(self: array, other: Union[int, float, complex, bool, array], /) -> array: r""" Computes the truth value of ``self_i == other_i`` for each element of an array instance with the respective element of the array ``other``. @@ -502,7 +502,7 @@ def __eq__(self: array, other: Union[int, float, bool, array], /) -> array: ---------- self: array array instance. May have any data type. - other: Union[int, float, bool, array] + other: Union[int, float, complex, bool, array] other array. Must be compatible with ``self`` (see :ref:`broadcasting`). May have any data type. Returns @@ -510,12 +510,11 @@ def __eq__(self: array, other: Union[int, float, bool, array], /) -> array: out: array an array containing the element-wise results. The returned array must have a data type of ``bool``. + Notes + ----- - .. note:: - Element-wise results, including special cases, must equal the results returned by the equivalent element-wise function :func:`~array_api.equal`. - - .. note:: - Comparison of arrays without a corresponding promotable data type (see :ref:`type-promotion`) is undefined and thus implementation-dependent. + - Element-wise results, including special cases, must equal the results returned by the equivalent element-wise function :func:`~array_api.equal`. + - Comparison of arrays without a corresponding promotable data type (see :ref:`type-promotion`) is undefined and thus implementation-dependent. """ def __float__(self: array, /) -> float: @@ -584,9 +583,6 @@ def __ge__(self: array, other: Union[int, float, array], /) -> array: """ Computes the truth value of ``self_i >= other_i`` for each element of an array instance with the respective element of the array ``other``. - .. note:: - For backward compatibility, conforming implementations may support complex numbers; however, inequality comparison of complex numbers is unspecified and thus implementation-dependent (see :ref:`complex-number-ordering`). - Parameters ---------- self: array @@ -599,12 +595,12 @@ def __ge__(self: array, other: Union[int, float, array], /) -> array: out: array an array containing the element-wise results. The returned array must have a data type of ``bool``. + Notes + ----- - .. note:: - Element-wise results must equal the results returned by the equivalent element-wise function :func:`~array_api.greater_equal`. - - .. note:: - Comparison of arrays without a corresponding promotable data type (see :ref:`type-promotion`) is undefined and thus implementation-dependent. + - Element-wise results must equal the results returned by the equivalent element-wise function :func:`~array_api.greater_equal`. + - Comparison of arrays without a corresponding promotable data type (see :ref:`type-promotion`) is undefined and thus implementation-dependent. + - For backward compatibility, conforming implementations may support complex numbers; however, inequality comparison of complex numbers is unspecified and thus implementation-dependent (see :ref:`complex-number-ordering`). """ def __getitem__( @@ -645,9 +641,6 @@ def __gt__(self: array, other: Union[int, float, array], /) -> array: """ Computes the truth value of ``self_i > other_i`` for each element of an array instance with the respective element of the array ``other``. - .. note:: - For backward compatibility, conforming implementations may support complex numbers; however, inequality comparison of complex numbers is unspecified and thus implementation-dependent (see :ref:`complex-number-ordering`). - Parameters ---------- self: array @@ -660,12 +653,12 @@ def __gt__(self: array, other: Union[int, float, array], /) -> array: out: array an array containing the element-wise results. The returned array must have a data type of ``bool``. + Notes + ----- - .. note:: - Element-wise results must equal the results returned by the equivalent element-wise function :func:`~array_api.greater`. - - .. note:: - Comparison of arrays without a corresponding promotable data type (see :ref:`type-promotion`) is undefined and thus implementation-dependent. + - Element-wise results must equal the results returned by the equivalent element-wise function :func:`~array_api.greater`. + - Comparison of arrays without a corresponding promotable data type (see :ref:`type-promotion`) is undefined and thus implementation-dependent. + - For backward compatibility, conforming implementations may support complex numbers; however, inequality comparison of complex numbers is unspecified and thus implementation-dependent (see :ref:`complex-number-ordering`). """ def __index__(self: array, /) -> int: @@ -769,9 +762,6 @@ def __le__(self: array, other: Union[int, float, array], /) -> array: """ Computes the truth value of ``self_i <= other_i`` for each element of an array instance with the respective element of the array ``other``. - .. note:: - For backward compatibility, conforming implementations may support complex numbers; however, inequality comparison of complex numbers is unspecified and thus implementation-dependent (see :ref:`complex-number-ordering`). - Parameters ---------- self: array @@ -784,12 +774,12 @@ def __le__(self: array, other: Union[int, float, array], /) -> array: out: array an array containing the element-wise results. The returned array must have a data type of ``bool``. + Notes + ----- - .. note:: - Element-wise results must equal the results returned by the equivalent element-wise function :func:`~array_api.less_equal`. - - .. note:: - Comparison of arrays without a corresponding promotable data type (see :ref:`type-promotion`) is undefined and thus implementation-dependent. + - Element-wise results must equal the results returned by the equivalent element-wise function :func:`~array_api.less_equal`. + - Comparison of arrays without a corresponding promotable data type (see :ref:`type-promotion`) is undefined and thus implementation-dependent. + - For backward compatibility, conforming implementations may support complex numbers; however, inequality comparison of complex numbers is unspecified and thus implementation-dependent (see :ref:`complex-number-ordering`). """ def __lshift__(self: array, other: Union[int, array], /) -> array: @@ -808,18 +798,16 @@ def __lshift__(self: array, other: Union[int, array], /) -> array: out: array an array containing the element-wise results. The returned array must have the same data type as ``self``. + Notes + ----- - .. note:: - Element-wise results must equal the results returned by the equivalent element-wise function :func:`~array_api.bitwise_left_shift`. + - Element-wise results must equal the results returned by the equivalent element-wise function :func:`~array_api.bitwise_left_shift`. """ def __lt__(self: array, other: Union[int, float, array], /) -> array: """ Computes the truth value of ``self_i < other_i`` for each element of an array instance with the respective element of the array ``other``. - .. note:: - For backward compatibility, conforming implementations may support complex numbers; however, inequality comparison of complex numbers is unspecified and thus implementation-dependent (see :ref:`complex-number-ordering`). - Parameters ---------- self: array @@ -832,12 +820,12 @@ def __lt__(self: array, other: Union[int, float, array], /) -> array: out: array an array containing the element-wise results. The returned array must have a data type of ``bool``. + Notes + ----- - .. note:: - Element-wise results must equal the results returned by the equivalent element-wise function :func:`~array_api.less`. - - .. note:: - Comparison of arrays without a corresponding promotable data type (see :ref:`type-promotion`) is undefined and thus implementation-dependent. + - Element-wise results must equal the results returned by the equivalent element-wise function :func:`~array_api.less`. + - Comparison of arrays without a corresponding promotable data type (see :ref:`type-promotion`) is undefined and thus implementation-dependent. + - For backward compatibility, conforming implementations may support complex numbers; however, inequality comparison of complex numbers is unspecified and thus implementation-dependent (see :ref:`complex-number-ordering`). """ def __matmul__(self: array, other: array, /) -> array: @@ -892,9 +880,6 @@ def __mod__(self: array, other: Union[int, float, array], /) -> array: """ Evaluates ``self_i % other_i`` for each element of an array instance with the respective element of the array ``other``. - .. note:: - For input arrays which promote to an integer data type, the result of division by zero is unspecified and thus implementation-defined. - Parameters ---------- self: array @@ -907,12 +892,14 @@ def __mod__(self: array, other: Union[int, float, array], /) -> array: out: array an array containing the element-wise results. Each element-wise result must have the same sign as the respective element ``other_i``. The returned array must have a real-valued floating-point data type determined by :ref:`type-promotion`. + Notes + ----- - .. note:: - Element-wise results, including special cases, must equal the results returned by the equivalent element-wise function :func:`~array_api.remainder`. + - Element-wise results, including special cases, must equal the results returned by the equivalent element-wise function :func:`~array_api.remainder`. + - For input arrays which promote to an integer data type, the result of division by zero is unspecified and thus implementation-defined. """ - def __mul__(self: array, other: Union[int, float, array], /) -> array: + def __mul__(self: array, other: Union[int, float, complex, array], /) -> array: r""" Calculates the product for each element of an array instance with the respective element of the array ``other``. @@ -923,7 +910,7 @@ def __mul__(self: array, other: Union[int, float, array], /) -> array: ---------- self: array array instance. Should have a numeric data type. - other: Union[int, float, array] + other: Union[int, float, complex, array] other array. Must be compatible with ``self`` (see :ref:`broadcasting`). Should have a numeric data type. Returns @@ -934,14 +921,13 @@ def __mul__(self: array, other: Union[int, float, array], /) -> array: Notes ----- - .. note:: - Element-wise results, including special cases, must equal the results returned by the equivalent element-wise function :func:`~array_api.multiply`. + - Element-wise results, including special cases, must equal the results returned by the equivalent element-wise function :func:`~array_api.multiply`. .. versionchanged:: 2022.12 Added complex data type support. """ - def __ne__(self: array, other: Union[int, float, bool, array], /) -> array: + def __ne__(self: array, other: Union[int, float, complex, bool, array], /) -> array: """ Computes the truth value of ``self_i != other_i`` for each element of an array instance with the respective element of the array ``other``. @@ -949,7 +935,7 @@ def __ne__(self: array, other: Union[int, float, bool, array], /) -> array: ---------- self: array array instance. May have any data type. - other: Union[int, float, bool, array] + other: Union[int, float, complex, bool, array] other array. Must be compatible with ``self`` (see :ref:`broadcasting`). May have any data type. Returns @@ -957,15 +943,11 @@ def __ne__(self: array, other: Union[int, float, bool, array], /) -> array: out: array an array containing the element-wise results. The returned array must have a data type of ``bool`` (i.e., must be a boolean array). - Notes ----- - .. note:: - Element-wise results, including special cases, must equal the results returned by the equivalent element-wise function :func:`~array_api.not_equal`. - - .. note:: - Comparison of arrays without a corresponding promotable data type (see :ref:`type-promotion`) is undefined and thus implementation-dependent. + - Element-wise results, including special cases, must equal the results returned by the equivalent element-wise function :func:`~array_api.not_equal`. + - Comparison of arrays without a corresponding promotable data type (see :ref:`type-promotion`) is undefined and thus implementation-dependent. .. versionchanged:: 2022.12 Added complex data type support. @@ -1017,9 +999,10 @@ def __or__(self: array, other: Union[int, bool, array], /) -> array: out: array an array containing the element-wise results. The returned array must have a data type determined by :ref:`type-promotion`. + Notes + ----- - .. note:: - Element-wise results must equal the results returned by the equivalent element-wise function :func:`~array_api.bitwise_or`. + - Element-wise results must equal the results returned by the equivalent element-wise function :func:`~array_api.bitwise_or`. """ def __pos__(self: array, /) -> array: @@ -1046,15 +1029,10 @@ def __pos__(self: array, /) -> array: Added complex data type support. """ - def __pow__(self: array, other: Union[int, float, array], /) -> array: + def __pow__(self: array, other: Union[int, float, complex, array], /) -> array: r""" Calculates an implementation-dependent approximation of exponentiation by raising each element (the base) of an array instance to the power of ``other_i`` (the exponent), where ``other_i`` is the corresponding element of the array ``other``. - .. note:: - If both ``self`` and ``other`` have integer data types, the result of ``__pow__`` when `other_i` is negative (i.e., less than zero) is unspecified and thus implementation-dependent. - - If ``self`` has an integer data type and ``other`` has a floating-point data type, behavior is implementation-dependent, as type promotion between data type "kinds" (e.g., integer versus floating-point) is unspecified. - Parameters ---------- self: array @@ -1070,8 +1048,9 @@ def __pow__(self: array, other: Union[int, float, array], /) -> array: Notes ----- - .. note:: - Element-wise results, including special cases, must equal the results returned by the equivalent element-wise function :func:`~array_api.pow`. + - Element-wise results, including special cases, must equal the results returned by the equivalent element-wise function :func:`~array_api.pow`. + - If both ``self`` and ``other`` have integer data types, the result of ``__pow__`` when `other_i` is negative (i.e., less than zero) is unspecified and thus implementation-dependent. + - If ``self`` has an integer data type and ``other`` has a floating-point data type, behavior is implementation-dependent, as type promotion between data type "kinds" (e.g., integer versus floating-point) is unspecified. .. versionchanged:: 2022.12 Added complex data type support. @@ -1093,9 +1072,10 @@ def __rshift__(self: array, other: Union[int, array], /) -> array: out: array an array containing the element-wise results. The returned array must have the same data type as ``self``. + Notes + ----- - .. note:: - Element-wise results must equal the results returned by the equivalent element-wise function :func:`~array_api.bitwise_right_shift`. + - Element-wise results must equal the results returned by the equivalent element-wise function :func:`~array_api.bitwise_right_shift`. """ def __setitem__( @@ -1130,43 +1110,36 @@ def __setitem__( When ``value`` is an ``array`` of a different data type than ``self``, how values are cast to the data type of ``self`` is implementation defined. """ - def __sub__(self: array, other: Union[int, float, array], /) -> array: + def __sub__(self: array, other: Union[int, float, complex, array], /) -> array: """ - Calculates the difference for each element of an array instance with the respective element of the array ``other``. + Calculates the difference for each element of an array instance with the respective element of the array ``other``. - The result of ``self_i - other_i`` must be the same as ``self_i + (-other_i)`` and must be governed by the same floating-point rules as addition (see :meth:`array.__add__`). - - Parameters - ---------- - self: array - array instance (minuend array). Should have a numeric data type. - other: Union[int, float, array] - subtrahend array. Must be compatible with ``self`` (see :ref:`broadcasting`). Should have a numeric data type. + Parameters + ---------- + self: array + array instance (minuend array). Should have a numeric data type. + other: Union[int, float, array] + subtrahend array. Must be compatible with ``self`` (see :ref:`broadcasting`). Should have a numeric data type. - Returns - ------- - out: array - an array containing the element-wise differences. The returned array must have a data type determined by :ref:`type-promotion`. + Returns + ------- + out: array + an array containing the element-wise differences. The returned array must have a data type determined by :ref:`type-promotion`. - Notes - ----- + Notes + ----- - .. note:: - Element-wise results must equal the results returned by the equivalent element-wise function :func:`~array_api.subtract`. + - Element-wise results must equal the results returned by the equivalent element-wise function :func:`~array_api.subtract`. + - The result of ``self_i - other_i`` must be the same as ``self_i + (-other_i)`` and must be governed by the same floating-point rules as addition (see :meth:`array.__add__`). - .. versionchanged:: 2022.12 - Added complex data type support. + .. versionchanged:: 2022.12 + Added complex data type support. """ def __truediv__(self: array, other: Union[int, float, array], /) -> array: r""" Evaluates ``self_i / other_i`` for each element of an array instance with the respective element of the array ``other``. - .. note:: - If one or both of ``self`` and ``other`` have integer data types, the result is implementation-dependent, as type promotion between data type "kinds" (e.g., integer versus floating-point) is unspecified. - - Specification-compliant libraries may choose to raise an error or return an array containing the element-wise results. If an array is returned, the array must have a real-valued floating-point data type. - Parameters ---------- self: array @@ -1182,8 +1155,11 @@ def __truediv__(self: array, other: Union[int, float, array], /) -> array: Notes ----- - .. note:: - Element-wise results, including special cases, must equal the results returned by the equivalent element-wise function :func:`~array_api.divide`. + - Element-wise results, including special cases, must equal the results returned by the equivalent element-wise function :func:`~array_api.divide`. + + - If one or both of ``self`` and ``other`` have integer data types, the result is implementation-dependent, as type promotion between data type "kinds" (e.g., integer versus floating-point) is unspecified. + + Specification-compliant libraries may choose to raise an error or return an array containing the element-wise results. If an array is returned, the array must have a real-valued floating-point data type. .. versionchanged:: 2022.12 Added complex data type support. @@ -1205,9 +1181,10 @@ def __xor__(self: array, other: Union[int, bool, array], /) -> array: out: array an array containing the element-wise results. The returned array must have a data type determined by :ref:`type-promotion`. + Notes + ----- - .. note:: - Element-wise results must equal the results returned by the equivalent element-wise function :func:`~array_api.bitwise_xor`. + - Element-wise results must equal the results returned by the equivalent element-wise function :func:`~array_api.bitwise_xor`. """ def to_device( diff --git a/src/array_api_stubs/_draft/elementwise_functions.py b/src/array_api_stubs/_draft/elementwise_functions.py index 8b8f4b2d8..57f1c9e90 100644 --- a/src/array_api_stubs/_draft/elementwise_functions.py +++ b/src/array_api_stubs/_draft/elementwise_functions.py @@ -666,17 +666,15 @@ def bitwise_and(x1: Union[array, int, bool], x2: Union[array, int, bool], /) -> """ -def bitwise_left_shift( - x1: Union[array, int, bool], x2: Union[array, int, bool], / -) -> array: +def bitwise_left_shift(x1: Union[array, int], x2: Union[array, int], /) -> array: """ Shifts the bits of each element ``x1_i`` of the input array ``x1`` to the left by appending ``x2_i`` (i.e., the respective element in the input array ``x2``) zeros to the right of ``x1_i``. Parameters ---------- - x1: Union[array, int, bool] + x1: Union[array, int] first input array. Should have an integer data type. - x2: Union[array, int, bool] + x2: Union[array, int] second input array. Must be compatible with ``x1`` (see :ref:`broadcasting`). Should have an integer data type. Each element must be greater than or equal to ``0``. Returns @@ -730,9 +728,7 @@ def bitwise_or(x1: Union[array, int, bool], x2: Union[array, int, bool], /) -> a """ -def bitwise_right_shift( - x1: Union[array, int, bool], x2: Union[array, int, bool], / -) -> array: +def bitwise_right_shift(x1: Union[array, int], x2: Union[array, int], /) -> array: """ Shifts the bits of each element ``x1_i`` of the input array ``x1`` to the right according to the respective element ``x2_i`` of the input array ``x2``. @@ -741,9 +737,9 @@ def bitwise_right_shift( Parameters ---------- - x1: Union[array, int, bool] + x1: Union[array, int] first input array. Should have an integer data type. - x2: Union[array, int, bool] + x2: Union[array, int] second input array. Must be compatible with ``x1`` (see :ref:`broadcasting`). Should have an integer data type. Each element must be greater than or equal to ``0``. Returns @@ -1044,11 +1040,6 @@ def divide( r""" Calculates the division of each element ``x1_i`` of the input array ``x1`` with the respective element ``x2_i`` of the input array ``x2``. - .. note:: - If one or both of the input arrays have integer data types, the result is implementation-dependent, as type promotion between data type "kinds" (e.g., integer versus floating-point) is unspecified. - - Specification-compliant libraries may choose to raise an error or return an array containing the element-wise results. If an array is returned, the array must have a real-valued floating-point data type. - Parameters ---------- x1: Union[array, int, float, complex] @@ -1066,6 +1057,10 @@ def divide( - At least one of ``x1`` or ``x2`` must be an array. + - If one or both of the input arrays have integer data types, the result is implementation-dependent, as type promotion between data type "kinds" (e.g., integer versus floating-point) is unspecified. + + Specification-compliant libraries may choose to raise an error or return an array containing the element-wise results. If an array is returned, the array must have a real-valued floating-point data type. + **Special cases** For real-valued floating-point operands, @@ -1330,9 +1325,6 @@ def floor_divide( r""" Rounds the result of dividing each element ``x1_i`` of the input array ``x1`` by the respective element ``x2_i`` of the input array ``x2`` to the greatest (i.e., closest to `+infinity`) integer-value number that is not greater than the division result. - .. note:: - For input arrays which promote to an integer data type, the result of division by zero is unspecified and thus implementation-defined. - Parameters ---------- x1: Union[array, int, float] @@ -1349,6 +1341,7 @@ def floor_divide( ----- - At least one of ``x1`` or ``x2`` must be an array. + - For input arrays which promote to an integer data type, the result of division by zero is unspecified and thus implementation-defined. **Special cases** @@ -1392,9 +1385,6 @@ def greater(x1: Union[array, int, float], x2: Union[array, int, float], /) -> ar """ Computes the truth value of ``x1_i > x2_i`` for each element ``x1_i`` of the input array ``x1`` with the respective element ``x2_i`` of the input array ``x2``. - .. note:: - For backward compatibility, conforming implementations may support complex numbers; however, inequality comparison of complex numbers is unspecified and thus implementation-dependent (see :ref:`complex-number-ordering`). - Parameters ---------- x1: Union[array, int, float] @@ -1412,6 +1402,7 @@ def greater(x1: Union[array, int, float], x2: Union[array, int, float], /) -> ar - At least one of ``x1`` or ``x2`` must be an array. - Comparison of arrays without a corresponding promotable data type (see :ref:`type-promotion`) is undefined and thus implementation-dependent. + - For backward compatibility, conforming implementations may support complex numbers; however, inequality comparison of complex numbers is unspecified and thus implementation-dependent (see :ref:`complex-number-ordering`). """ @@ -1422,9 +1413,6 @@ def greater_equal( """ Computes the truth value of ``x1_i >= x2_i`` for each element ``x1_i`` of the input array ``x1`` with the respective element ``x2_i`` of the input array ``x2``. - .. note:: - For backward compatibility, conforming implementations may support complex numbers; however, inequality comparison of complex numbers is unspecified and thus implementation-dependent (see :ref:`complex-number-ordering`). - Parameters ---------- x1: Union[array, int, float] @@ -1442,6 +1430,7 @@ def greater_equal( - At least one of ``x1`` or ``x2`` must be an array. - Comparison of arrays without a corresponding promotable data type (see :ref:`type-promotion`) is undefined and thus implementation-dependent. + - For backward compatibility, conforming implementations may support complex numbers; however, inequality comparison of complex numbers is unspecified and thus implementation-dependent (see :ref:`complex-number-ordering`). """ @@ -1624,9 +1613,6 @@ def less(x1: Union[array, int, float], x2: Union[array, int, float], /) -> array """ Computes the truth value of ``x1_i < x2_i`` for each element ``x1_i`` of the input array ``x1`` with the respective element ``x2_i`` of the input array ``x2``. - .. note:: - For backward compatibility, conforming implementations may support complex numbers; however, inequality comparison of complex numbers is unspecified and thus implementation-dependent (see :ref:`complex-number-ordering`). - Parameters ---------- x1: Union[array, int, float] @@ -1644,6 +1630,7 @@ def less(x1: Union[array, int, float], x2: Union[array, int, float], /) -> array - At least one of ``x1`` or ``x2`` must be an array. - Comparison of arrays without a corresponding promotable data type (see :ref:`type-promotion`) is undefined and thus implementation-dependent. + - For backward compatibility, conforming implementations may support complex numbers; however, inequality comparison of complex numbers is unspecified and thus implementation-dependent (see :ref:`complex-number-ordering`). """ @@ -1651,9 +1638,6 @@ def less_equal(x1: Union[array, int, float], x2: Union[array, int, float], /) -> """ Computes the truth value of ``x1_i <= x2_i`` for each element ``x1_i`` of the input array ``x1`` with the respective element ``x2_i`` of the input array ``x2``. - .. note:: - For backward compatibility, conforming implementations may support complex numbers; however, inequality comparison of complex numbers is unspecified and thus implementation-dependent (see :ref:`complex-number-ordering`). - Parameters ---------- x1: Union[array, int, float] @@ -1671,6 +1655,7 @@ def less_equal(x1: Union[array, int, float], x2: Union[array, int, float], /) -> - At least one of ``x1`` or ``x2`` must be an array. - Comparison of arrays without a corresponding promotable data type (see :ref:`type-promotion`) is undefined and thus implementation-dependent. + - For backward compatibility, conforming implementations may support complex numbers; however, inequality comparison of complex numbers is unspecified and thus implementation-dependent (see :ref:`complex-number-ordering`). """ From a81840052e712e99453723d71dc6512fc6d1beac Mon Sep 17 00:00:00 2001 From: Athan Reines Date: Mon, 25 Nov 2024 11:58:47 -0800 Subject: [PATCH 3/5] fix: update signature to indicate complex scalar support --- src/array_api_stubs/_draft/array_object.py | 10 +++++----- src/array_api_stubs/_draft/elementwise_functions.py | 8 +++++--- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/array_api_stubs/_draft/array_object.py b/src/array_api_stubs/_draft/array_object.py index 01f067895..4f537e010 100644 --- a/src/array_api_stubs/_draft/array_object.py +++ b/src/array_api_stubs/_draft/array_object.py @@ -148,7 +148,7 @@ def __abs__(self: array, /) -> array: Added complex data type support. """ - def __add__(self: array, other: Union[int, float, array], /) -> array: + def __add__(self: array, other: Union[int, float, complex, array], /) -> array: """ Calculates the sum for each element of an array instance with the respective element of the array ``other``. @@ -167,8 +167,7 @@ def __add__(self: array, other: Union[int, float, array], /) -> array: Notes ----- - .. note:: - Element-wise results, including special cases, must equal the results returned by the equivalent element-wise function :func:`~array_api.add`. + - Element-wise results, including special cases, must equal the results returned by the equivalent element-wise function :func:`~array_api.add`. .. versionchanged:: 2022.12 Added complex data type support. @@ -190,9 +189,10 @@ def __and__(self: array, other: Union[int, bool, array], /) -> array: out: array an array containing the element-wise results. The returned array must have a data type determined by :ref:`type-promotion`. + Notes + ----- - .. note:: - Element-wise results must equal the results returned by the equivalent element-wise function :func:`~array_api.bitwise_and`. + - Element-wise results must equal the results returned by the equivalent element-wise function :func:`~array_api.bitwise_and`. """ def __array_namespace__( diff --git a/src/array_api_stubs/_draft/elementwise_functions.py b/src/array_api_stubs/_draft/elementwise_functions.py index 57f1c9e90..033efbaec 100644 --- a/src/array_api_stubs/_draft/elementwise_functions.py +++ b/src/array_api_stubs/_draft/elementwise_functions.py @@ -272,15 +272,17 @@ def acosh(x: array, /) -> array: """ -def add(x1: Union[array, int, float], x2: Union[array, int, float], /) -> array: +def add( + x1: Union[array, int, float, complex], x2: Union[array, int, float, complex], / +) -> array: """ Calculates the sum for each element ``x1_i`` of the input array ``x1`` with the respective element ``x2_i`` of the input array ``x2``. Parameters ---------- - x1: Union[array, int, float] + x1: Union[array, int, float, complex] first input array. Should have a numeric data type. - x2: Union[array, int, float] + x2: Union[array, int, float, complex] second input array. Must be compatible with ``x1`` (see :ref:`broadcasting`). Should have a numeric data type. Returns From fc36a424ee5765eb50b847ba8295fb2cdc88bc9b Mon Sep 17 00:00:00 2001 From: Athan Reines Date: Thu, 28 Nov 2024 08:01:38 -0800 Subject: [PATCH 4/5] chore: fix alignment --- src/array_api_stubs/_draft/array_object.py | 30 ++--- .../_draft/elementwise_functions.py | 104 +++++++++--------- 2 files changed, 67 insertions(+), 67 deletions(-) diff --git a/src/array_api_stubs/_draft/array_object.py b/src/array_api_stubs/_draft/array_object.py index 4f537e010..bed22a834 100644 --- a/src/array_api_stubs/_draft/array_object.py +++ b/src/array_api_stubs/_draft/array_object.py @@ -1112,28 +1112,28 @@ def __setitem__( def __sub__(self: array, other: Union[int, float, complex, array], /) -> array: """ - Calculates the difference for each element of an array instance with the respective element of the array ``other``. + Calculates the difference for each element of an array instance with the respective element of the array ``other``. - Parameters - ---------- - self: array - array instance (minuend array). Should have a numeric data type. - other: Union[int, float, array] - subtrahend array. Must be compatible with ``self`` (see :ref:`broadcasting`). Should have a numeric data type. + Parameters + ---------- + self: array + array instance (minuend array). Should have a numeric data type. + other: Union[int, float, array] + subtrahend array. Must be compatible with ``self`` (see :ref:`broadcasting`). Should have a numeric data type. - Returns - ------- - out: array - an array containing the element-wise differences. The returned array must have a data type determined by :ref:`type-promotion`. + Returns + ------- + out: array + an array containing the element-wise differences. The returned array must have a data type determined by :ref:`type-promotion`. - Notes - ----- + Notes + ----- - Element-wise results must equal the results returned by the equivalent element-wise function :func:`~array_api.subtract`. - The result of ``self_i - other_i`` must be the same as ``self_i + (-other_i)`` and must be governed by the same floating-point rules as addition (see :meth:`array.__add__`). - .. versionchanged:: 2022.12 - Added complex data type support. + .. versionchanged:: 2022.12 + Added complex data type support. """ def __truediv__(self: array, other: Union[int, float, array], /) -> array: diff --git a/src/array_api_stubs/_draft/elementwise_functions.py b/src/array_api_stubs/_draft/elementwise_functions.py index 033efbaec..0cc911f2e 100644 --- a/src/array_api_stubs/_draft/elementwise_functions.py +++ b/src/array_api_stubs/_draft/elementwise_functions.py @@ -2271,28 +2271,28 @@ def pow( x1: Union[array, int, float, complex], x2: Union[array, int, float, complex], / ) -> array: r""" - Calculates an implementation-dependent approximation of exponentiation by raising each element ``x1_i`` (the base) of the input array ``x1`` to the power of ``x2_i`` (the exponent), where ``x2_i`` is the corresponding element of the input array ``x2``. + Calculates an implementation-dependent approximation of exponentiation by raising each element ``x1_i`` (the base) of the input array ``x1`` to the power of ``x2_i`` (the exponent), where ``x2_i`` is the corresponding element of the input array ``x2``. - Parameters - ---------- - x1: Union[array, int, float, complex] - first input array whose elements correspond to the exponentiation base. Should have a numeric data type. - x2: Union[array, int, float, complex] - second input array whose elements correspond to the exponentiation exponent. Must be compatible with ``x1`` (see :ref:`broadcasting`). Should have a numeric data type. + Parameters + ---------- + x1: Union[array, int, float, complex] + first input array whose elements correspond to the exponentiation base. Should have a numeric data type. + x2: Union[array, int, float, complex] + second input array whose elements correspond to the exponentiation exponent. Must be compatible with ``x1`` (see :ref:`broadcasting`). Should have a numeric data type. - Returns - ------- - out: array - an array containing the element-wise results. The returned array must have a data type determined by :ref:`type-promotion`. + Returns + ------- + out: array + an array containing the element-wise results. The returned array must have a data type determined by :ref:`type-promotion`. - Notes - ----- + Notes + ----- - - At least one of ``x1`` or ``x2`` must be an array. + - At least one of ``x1`` or ``x2`` must be an array. - - If both ``x1`` and ``x2`` have integer data types, the result of ``pow`` when ``x2_i`` is negative (i.e., less than zero) is unspecified and thus implementation-dependent. + - If both ``x1`` and ``x2`` have integer data types, the result of ``pow`` when ``x2_i`` is negative (i.e., less than zero) is unspecified and thus implementation-dependent. - - If ``x1`` has an integer data type and ``x2`` has a floating-point data type, behavior is implementation-dependent (type promotion between data type "kinds" (integer versus floating-point) is unspecified). + - If ``x1`` has an integer data type and ``x2`` has a floating-point data type, behavior is implementation-dependent (type promotion between data type "kinds" (integer versus floating-point) is unspecified). - By convention, the branch cut of the natural logarithm is the negative real axis :math:`(-\infty, 0)`. @@ -2300,42 +2300,42 @@ def pow( *Note: branch cuts follow C99 and have provisional status* (see :ref:`branch-cuts`). - **Special cases** - - For real-valued floating-point operands, - - - If ``x1_i`` is not equal to ``1`` and ``x2_i`` is ``NaN``, the result is ``NaN``. - - If ``x2_i`` is ``+0``, the result is ``1``, even if ``x1_i`` is ``NaN``. - - If ``x2_i`` is ``-0``, the result is ``1``, even if ``x1_i`` is ``NaN``. - - If ``x1_i`` is ``NaN`` and ``x2_i`` is not equal to ``0``, the result is ``NaN``. - - If ``abs(x1_i)`` is greater than ``1`` and ``x2_i`` is ``+infinity``, the result is ``+infinity``. - - If ``abs(x1_i)`` is greater than ``1`` and ``x2_i`` is ``-infinity``, the result is ``+0``. - - If ``abs(x1_i)`` is ``1`` and ``x2_i`` is ``+infinity``, the result is ``1``. - - If ``abs(x1_i)`` is ``1`` and ``x2_i`` is ``-infinity``, the result is ``1``. - - If ``x1_i`` is ``1`` and ``x2_i`` is not ``NaN``, the result is ``1``. - - If ``abs(x1_i)`` is less than ``1`` and ``x2_i`` is ``+infinity``, the result is ``+0``. - - If ``abs(x1_i)`` is less than ``1`` and ``x2_i`` is ``-infinity``, the result is ``+infinity``. - - If ``x1_i`` is ``+infinity`` and ``x2_i`` is greater than ``0``, the result is ``+infinity``. - - If ``x1_i`` is ``+infinity`` and ``x2_i`` is less than ``0``, the result is ``+0``. - - If ``x1_i`` is ``-infinity``, ``x2_i`` is greater than ``0``, and ``x2_i`` is an odd integer value, the result is ``-infinity``. - - If ``x1_i`` is ``-infinity``, ``x2_i`` is greater than ``0``, and ``x2_i`` is not an odd integer value, the result is ``+infinity``. - - If ``x1_i`` is ``-infinity``, ``x2_i`` is less than ``0``, and ``x2_i`` is an odd integer value, the result is ``-0``. - - If ``x1_i`` is ``-infinity``, ``x2_i`` is less than ``0``, and ``x2_i`` is not an odd integer value, the result is ``+0``. - - If ``x1_i`` is ``+0`` and ``x2_i`` is greater than ``0``, the result is ``+0``. - - If ``x1_i`` is ``+0`` and ``x2_i`` is less than ``0``, the result is ``+infinity``. - - If ``x1_i`` is ``-0``, ``x2_i`` is greater than ``0``, and ``x2_i`` is an odd integer value, the result is ``-0``. - - If ``x1_i`` is ``-0``, ``x2_i`` is greater than ``0``, and ``x2_i`` is not an odd integer value, the result is ``+0``. - - If ``x1_i`` is ``-0``, ``x2_i`` is less than ``0``, and ``x2_i`` is an odd integer value, the result is ``-infinity``. - - If ``x1_i`` is ``-0``, ``x2_i`` is less than ``0``, and ``x2_i`` is not an odd integer value, the result is ``+infinity``. - - If ``x1_i`` is less than ``0``, ``x1_i`` is a finite number, ``x2_i`` is a finite number, and ``x2_i`` is not an integer value, the result is ``NaN``. - - For complex floating-point operands, special cases should be handled as if the operation is implemented as ``exp(x2*log(x1))``. - - .. note:: - Conforming implementations are allowed to treat special cases involving complex floating-point operands more carefully than as described in this specification. - - .. versionchanged:: 2022.12 - Added complex data type support. + **Special cases** + + For real-valued floating-point operands, + + - If ``x1_i`` is not equal to ``1`` and ``x2_i`` is ``NaN``, the result is ``NaN``. + - If ``x2_i`` is ``+0``, the result is ``1``, even if ``x1_i`` is ``NaN``. + - If ``x2_i`` is ``-0``, the result is ``1``, even if ``x1_i`` is ``NaN``. + - If ``x1_i`` is ``NaN`` and ``x2_i`` is not equal to ``0``, the result is ``NaN``. + - If ``abs(x1_i)`` is greater than ``1`` and ``x2_i`` is ``+infinity``, the result is ``+infinity``. + - If ``abs(x1_i)`` is greater than ``1`` and ``x2_i`` is ``-infinity``, the result is ``+0``. + - If ``abs(x1_i)`` is ``1`` and ``x2_i`` is ``+infinity``, the result is ``1``. + - If ``abs(x1_i)`` is ``1`` and ``x2_i`` is ``-infinity``, the result is ``1``. + - If ``x1_i`` is ``1`` and ``x2_i`` is not ``NaN``, the result is ``1``. + - If ``abs(x1_i)`` is less than ``1`` and ``x2_i`` is ``+infinity``, the result is ``+0``. + - If ``abs(x1_i)`` is less than ``1`` and ``x2_i`` is ``-infinity``, the result is ``+infinity``. + - If ``x1_i`` is ``+infinity`` and ``x2_i`` is greater than ``0``, the result is ``+infinity``. + - If ``x1_i`` is ``+infinity`` and ``x2_i`` is less than ``0``, the result is ``+0``. + - If ``x1_i`` is ``-infinity``, ``x2_i`` is greater than ``0``, and ``x2_i`` is an odd integer value, the result is ``-infinity``. + - If ``x1_i`` is ``-infinity``, ``x2_i`` is greater than ``0``, and ``x2_i`` is not an odd integer value, the result is ``+infinity``. + - If ``x1_i`` is ``-infinity``, ``x2_i`` is less than ``0``, and ``x2_i`` is an odd integer value, the result is ``-0``. + - If ``x1_i`` is ``-infinity``, ``x2_i`` is less than ``0``, and ``x2_i`` is not an odd integer value, the result is ``+0``. + - If ``x1_i`` is ``+0`` and ``x2_i`` is greater than ``0``, the result is ``+0``. + - If ``x1_i`` is ``+0`` and ``x2_i`` is less than ``0``, the result is ``+infinity``. + - If ``x1_i`` is ``-0``, ``x2_i`` is greater than ``0``, and ``x2_i`` is an odd integer value, the result is ``-0``. + - If ``x1_i`` is ``-0``, ``x2_i`` is greater than ``0``, and ``x2_i`` is not an odd integer value, the result is ``+0``. + - If ``x1_i`` is ``-0``, ``x2_i`` is less than ``0``, and ``x2_i`` is an odd integer value, the result is ``-infinity``. + - If ``x1_i`` is ``-0``, ``x2_i`` is less than ``0``, and ``x2_i`` is not an odd integer value, the result is ``+infinity``. + - If ``x1_i`` is less than ``0``, ``x1_i`` is a finite number, ``x2_i`` is a finite number, and ``x2_i`` is not an integer value, the result is ``NaN``. + + For complex floating-point operands, special cases should be handled as if the operation is implemented as ``exp(x2*log(x1))``. + + .. note:: + Conforming implementations are allowed to treat special cases involving complex floating-point operands more carefully than as described in this specification. + + .. versionchanged:: 2022.12 + Added complex data type support. """ From da24bf551030c51ef5e3a9a4bd31bec222303450 Mon Sep 17 00:00:00 2001 From: Athan Reines Date: Thu, 28 Nov 2024 08:17:27 -0800 Subject: [PATCH 5/5] fix: ensure explicit `int` support and fix missing `complex` types --- src/array_api_stubs/_draft/array_object.py | 24 ++++++------- .../_draft/elementwise_functions.py | 34 +++++++++---------- 2 files changed, 28 insertions(+), 30 deletions(-) diff --git a/src/array_api_stubs/_draft/array_object.py b/src/array_api_stubs/_draft/array_object.py index bed22a834..6d55b1eee 100644 --- a/src/array_api_stubs/_draft/array_object.py +++ b/src/array_api_stubs/_draft/array_object.py @@ -1037,7 +1037,7 @@ def __pow__(self: array, other: Union[int, float, complex, array], /) -> array: ---------- self: array array instance whose elements correspond to the exponentiation base. Should have a numeric data type. - other: Union[int, float, array] + other: Union[int, float, complex, array] other array whose elements correspond to the exponentiation exponent. Must be compatible with ``self`` (see :ref:`broadcasting`). Should have a numeric data type. Returns @@ -1083,7 +1083,7 @@ def __setitem__( key: Union[ int, slice, ellipsis, Tuple[Union[int, slice, ellipsis], ...], array ], - value: Union[int, float, bool, array], + value: Union[int, float, complex, bool, array], /, ) -> None: """ @@ -1097,17 +1097,15 @@ def __setitem__( array instance. key: Union[int, slice, ellipsis, Tuple[Union[int, slice, ellipsis], ...], array] index key. - value: Union[int, float, bool, array] + value: Union[int, float, complex, bool, array] value(s) to set. Must be compatible with ``self[key]`` (see :ref:`broadcasting`). + Notes + ----- - .. note:: - - Setting array values must not affect the data type of ``self``. - - When ``value`` is a Python scalar (i.e., ``int``, ``float``, ``bool``), behavior must follow specification guidance on mixing arrays with Python scalars (see :ref:`type-promotion`). - - When ``value`` is an ``array`` of a different data type than ``self``, how values are cast to the data type of ``self`` is implementation defined. + - Setting array values must not affect the data type of ``self``. + - When ``value`` is a Python scalar (i.e., ``int``, ``float``, ``complex``, ``bool``), behavior must follow specification guidance on mixing arrays with Python scalars (see :ref:`type-promotion`). + - When ``value`` is an ``array`` of a different data type than ``self``, how values are cast to the data type of ``self`` is implementation defined. """ def __sub__(self: array, other: Union[int, float, complex, array], /) -> array: @@ -1118,7 +1116,7 @@ def __sub__(self: array, other: Union[int, float, complex, array], /) -> array: ---------- self: array array instance (minuend array). Should have a numeric data type. - other: Union[int, float, array] + other: Union[int, float, complex, array] subtrahend array. Must be compatible with ``self`` (see :ref:`broadcasting`). Should have a numeric data type. Returns @@ -1136,7 +1134,7 @@ def __sub__(self: array, other: Union[int, float, complex, array], /) -> array: Added complex data type support. """ - def __truediv__(self: array, other: Union[int, float, array], /) -> array: + def __truediv__(self: array, other: Union[int, float, complex, array], /) -> array: r""" Evaluates ``self_i / other_i`` for each element of an array instance with the respective element of the array ``other``. @@ -1144,7 +1142,7 @@ def __truediv__(self: array, other: Union[int, float, array], /) -> array: ---------- self: array array instance. Should have a numeric data type. - other: Union[int, float, array] + other: Union[int, float, complex, array] other array. Must be compatible with ``self`` (see :ref:`broadcasting`). Should have a numeric data type. Returns diff --git a/src/array_api_stubs/_draft/elementwise_functions.py b/src/array_api_stubs/_draft/elementwise_functions.py index 0cc911f2e..fa9390a01 100644 --- a/src/array_api_stubs/_draft/elementwise_functions.py +++ b/src/array_api_stubs/_draft/elementwise_functions.py @@ -518,7 +518,7 @@ def atan(x: array, /) -> array: """ -def atan2(x1: Union[array, float], x2: Union[array, float], /) -> array: +def atan2(x1: Union[array, int, float], x2: Union[array, int, float], /) -> array: """ Calculates an implementation-dependent approximation of the inverse tangent of the quotient ``x1/x2``, having domain ``[-infinity, +infinity] x [-infinity, +infinity]`` (where the ``x`` notation denotes the set of ordered pairs of elements ``(x1_i, x2_i)``) and codomain ``[-π, +π]``, for each pair of elements ``(x1_i, x2_i)`` of the input arrays ``x1`` and ``x2``, respectively. Each element-wise result is expressed in radians. @@ -531,9 +531,9 @@ def atan2(x1: Union[array, float], x2: Union[array, float], /) -> array: Parameters ---------- - x1: Union[array, float] + x1: Union[array, int, float] input array corresponding to the y-coordinates. Should have a real-valued floating-point data type. - x2: Union[array, float] + x2: Union[array, int, float] input array corresponding to the x-coordinates. Must be compatible with ``x1`` (see :ref:`broadcasting`). Should have a real-valued floating-point data type. Returns @@ -824,9 +824,9 @@ def clip( x: array input array. Should have a real-valued data type. min: Optional[Union[int, float, array]] - lower-bound of the range to which to clamp. If ``None``, no lower bound must be applied. Must be compatible with ``x1`` (see :ref:`broadcasting`). Should have a real-valued data type. Default: ``None``. + lower-bound of the range to which to clamp. If ``None``, no lower bound must be applied. Must be compatible with ``x`` (see :ref:`broadcasting`). Should have a real-valued data type. Default: ``None``. max: Optional[Union[int, float, array]] - upper-bound of the range to which to clamp. If ``None``, no upper bound must be applied. Must be compatible with ``x1`` (see :ref:`broadcasting`). Should have a real-valued data type. Default: ``None``. + upper-bound of the range to which to clamp. If ``None``, no upper bound must be applied. Must be compatible with ``x`` (see :ref:`broadcasting`). Should have a real-valued data type. Default: ``None``. Returns ------- @@ -883,15 +883,15 @@ def conj(x: array, /) -> array: """ -def copysign(x1: Union[array, float], x2: Union[array, float], /) -> array: +def copysign(x1: Union[array, int, float], x2: Union[array, int, float], /) -> array: r""" Composes a floating-point value with the magnitude of ``x1_i`` and the sign of ``x2_i`` for each element of the input array ``x1``. Parameters ---------- - x1: Union[array, float] + x1: Union[array, int, float] input array containing magnitudes. Should have a real-valued floating-point data type. - x2: Union[array, float] + x2: Union[array, int, float] input array whose sign bits are applied to the magnitudes of ``x1``. Must be compatible with ``x1`` (see :ref:`broadcasting`). Should have a real-valued floating-point data type. Returns @@ -1436,7 +1436,7 @@ def greater_equal( """ -def hypot(x1: Union[array, float], x2: Union[array, float], /) -> array: +def hypot(x1: Union[array, int, float], x2: Union[array, int, float], /) -> array: r""" Computes the square root of the sum of squares for each element ``x1_i`` of the input array ``x1`` with the respective element ``x2_i`` of the input array ``x2``. @@ -1445,9 +1445,9 @@ def hypot(x1: Union[array, float], x2: Union[array, float], /) -> array: Parameters ---------- - x1: Union[array, float] + x1: Union[array, int, float] first input array. Should have a real-valued floating-point data type. - x2: Union[array, float] + x2: Union[array, int, float] second input array. Must be compatible with ``x1`` (see :ref:`broadcasting`). Should have a real-valued floating-point data type. Returns @@ -1869,15 +1869,15 @@ def log10(x: array, /) -> array: """ -def logaddexp(x1: Union[array, float], x2: Union[array, float], /) -> array: +def logaddexp(x1: Union[array, int, float], x2: Union[array, int, float], /) -> array: """ Calculates the logarithm of the sum of exponentiations ``log(exp(x1) + exp(x2))`` for each element ``x1_i`` of the input array ``x1`` with the respective element ``x2_i`` of the input array ``x2``. Parameters ---------- - x1: Union[array, float] + x1: Union[array, int, float] first input array. Should have a real-valued floating-point data type. - x2: Union[array, float] + x2: Union[array, int, float] second input array. Must be compatible with ``x1`` (see :ref:`broadcasting`). Should have a real-valued floating-point data type. Returns @@ -2163,15 +2163,15 @@ def negative(x: array, /) -> array: """ -def nextafter(x1: Union[array, float], x2: Union[array, float], /) -> array: +def nextafter(x1: Union[array, int, float], x2: Union[array, int, float], /) -> array: """ Returns the next representable floating-point value for each element ``x1_i`` of the input array ``x1`` in the direction of the respective element ``x2_i`` of the input array ``x2``. Parameters ---------- - x1: Union[array, float] + x1: Union[array, int, float] first input array. Should have a real-valued floating-point data type. - x2: Union[array, float] + x2: Union[array, int, float] second input array. Must be compatible with ``x1`` (see :ref:`broadcasting`). Should have the same data type as ``x1``. Returns