Skip to content

Commit 25e7177

Browse files
authored
Fixes to the array-api stubs (#806)
* Replace info with __array_namespace_info__ in the stubs 'info' is not an actual top-level name in the namespace. * Use consistent wording for complex dtypes in the fft stubs * Fix some copysign special-cases for better machine readability and consistency
1 parent cee4167 commit 25e7177

File tree

7 files changed

+45
-49
lines changed

7 files changed

+45
-49
lines changed

src/array_api_stubs/_2022_12/fft.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def fft(
3535
Parameters
3636
----------
3737
x: array
38-
input array. Should have a complex-valued floating-point data type.
38+
input array. Should have a complex floating-point data type.
3939
n: Optional[int]
4040
number of elements over which to compute the transform along the axis (dimension) specified by ``axis``. Let ``M`` be the size of the input array along the axis specified by ``axis``. When ``n`` is ``None``, the function must set ``n`` equal to ``M``.
4141
@@ -84,7 +84,7 @@ def ifft(
8484
Parameters
8585
----------
8686
x: array
87-
input array. Should have a complex-valued floating-point data type.
87+
input array. Should have a complex floating-point data type.
8888
n: Optional[int]
8989
number of elements over which to compute the transform along the axis (dimension) specified by ``axis``. Let ``M`` be the size of the input array along the axis specified by ``axis``. When ``n`` is ``None``, the function must set ``n`` equal to ``M``.
9090
@@ -133,7 +133,7 @@ def fftn(
133133
Parameters
134134
----------
135135
x: array
136-
input array. Should have a complex-valued floating-point data type.
136+
input array. Should have a complex floating-point data type.
137137
s: Optional[Sequence[int]]
138138
number of elements over which to compute the transform along the axes (dimensions) specified by ``axes``. Let ``i`` be the index of the ``n``-th axis specified by ``axes`` (i.e., ``i = axes[n]``) and ``M[i]`` be the size of the input array along axis ``i``. When ``s`` is ``None``, the function must set ``s`` equal to a sequence of integers such that ``s[i]`` equals ``M[i]`` for all ``i``.
139139
@@ -188,7 +188,7 @@ def ifftn(
188188
Parameters
189189
----------
190190
x: array
191-
input array. Should have a complex-valued floating-point data type.
191+
input array. Should have a complex floating-point data type.
192192
s: Optional[Sequence[int]]
193193
number of elements over which to compute the transform along the axes (dimensions) specified by ``axes``. Let ``i`` be the index of the ``n``-th axis specified by ``axes`` (i.e., ``i = axes[n]``) and ``M[i]`` be the size of the input array along axis ``i``. When ``s`` is ``None``, the function must set ``s`` equal to a sequence of integers such that ``s[i]`` equals ``M[i]`` for all ``i``.
194194
@@ -292,7 +292,7 @@ def irfft(
292292
Parameters
293293
----------
294294
x: array
295-
input array. Should have a complex-valued floating-point data type.
295+
input array. Should have a complex floating-point data type.
296296
n: Optional[int]
297297
number of elements along the transformed axis (dimension) specified by ``axis`` in the **output array**. Let ``M`` be the size of the input array along the axis specified by ``axis``. When ``n`` is ``None``, the function must set ``n`` equal to ``2*(M-1)``.
298298
@@ -398,7 +398,7 @@ def irfftn(
398398
Parameters
399399
----------
400400
x: array
401-
input array. Should have a complex-valued floating-point data type.
401+
input array. Should have a complex floating-point data type.
402402
s: Optional[Sequence[int]]
403403
number of elements along the transformed axes (dimensions) specified by ``axes`` in the **output array**. Let ``i`` be the index of the ``n``-th axis specified by ``axes`` (i.e., ``i = axes[n]``) and ``M[i]`` be the size of the input array along axis ``i``. When ``s`` is ``None``, the function must set ``s`` equal to a sequence of integers such that ``s[i]`` equals ``M[i]`` for all ``i``, except for the last transformed axis in which ``s[i]`` equals ``2*(M[i]-1)``. For each ``i``, let ``n`` equal ``s[i]``, except for the last transformed axis in which ``n`` equals ``s[i]//2+1``.
404404
@@ -452,7 +452,7 @@ def hfft(
452452
Parameters
453453
----------
454454
x: array
455-
input array. Should have a complex-valued floating-point data type.
455+
input array. Should have a complex floating-point data type.
456456
n: Optional[int]
457457
number of elements along the transformed axis (dimension) specified by ``axis`` in the **output array**. Let ``M`` be the size of the input array along the axis specified by ``axis``. When ``n`` is ``None``, the function must set ``n`` equal to ``2*(M-1)``.
458458

src/array_api_stubs/_2023_12/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from .utility_functions import *
1717
from . import linalg
1818
from . import fft
19-
from . import info
19+
from .info import __array_namespace_info__
2020

2121

2222
__array_api_version__: str = "YYYY.MM"

src/array_api_stubs/_2023_12/elementwise_functions.py

+6-8
Original file line numberDiff line numberDiff line change
@@ -874,14 +874,12 @@ def copysign(x1: array, x2: array, /) -> array:
874874
- If ``x2_i`` is ``NaN`` and the sign bit of ``x2_i`` is ``1``, the result is ``-|x1_i|``.
875875
- If ``x2_i`` is ``NaN`` and the sign bit of ``x2_i`` is ``0``, the result is ``|x1_i|``.
876876
877-
If ``x1_i`` is ``NaN``,
878-
879-
- If ``x2_i`` is less than ``0``, the result is ``NaN`` with a sign bit of ``1``.
880-
- If ``x2_i`` is ``-0``, the result is ``NaN`` with a sign bit of ``1``.
881-
- If ``x2_i`` is ``+0``, the result is ``NaN`` with a sign bit of ``0``.
882-
- If ``x2_i`` is greater than ``0``, the result is ``NaN`` with a sign bit of ``0``.
883-
- If ``x2_i`` is ``NaN`` and the sign bit of ``x2_i`` is ``1``, the result is ``NaN`` with a sign bit of ``1``.
884-
- If ``x2_i`` is ``NaN`` and the sign bit of ``x2_i`` is ``0``, the result is ``NaN`` with a sign bit of ``0``.
877+
- If ``x1_i`` is ``NaN`` and ``x2_i`` is less than ``0``, the result is ``NaN`` with a sign bit of ``1``.
878+
- If ``x1_i`` is ``NaN`` and ``x2_i`` is ``-0``, the result is ``NaN`` with a sign bit of ``1``.
879+
- If ``x1_i`` is ``NaN`` and ``x2_i`` is ``+0``, the result is ``NaN`` with a sign bit of ``0``.
880+
- If ``x1_i`` is ``NaN`` and ``x2_i`` is greater than ``0``, the result is ``NaN`` with a sign bit of ``0``.
881+
- If ``x1_i`` is ``NaN`` and ``x2_i`` is ``NaN`` and the sign bit of ``x2_i`` is ``1``, the result is ``NaN`` with a sign bit of ``1``.
882+
- If ``x1_i`` is ``NaN`` and ``x2_i`` is ``NaN`` and the sign bit of ``x2_i`` is ``0``, the result is ``NaN`` with a sign bit of ``0``.
885883
886884
.. versionadded:: 2023.12
887885
"""

src/array_api_stubs/_2023_12/fft.py

+12-12
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def fft(
3535
Parameters
3636
----------
3737
x: array
38-
input array. Should have a complex-valued floating-point data type.
38+
input array. Should have a complex floating-point data type.
3939
n: Optional[int]
4040
number of elements over which to compute the transform along the axis (dimension) specified by ``axis``. Let ``M`` be the size of the input array along the axis specified by ``axis``. When ``n`` is ``None``, the function must set ``n`` equal to ``M``.
4141
@@ -66,7 +66,7 @@ def fft(
6666
.. versionadded:: 2022.12
6767
6868
.. versionchanged:: 2023.12
69-
Required the input array have a complex-valued floating-point data type and required that the output array have the same data type as the input array.
69+
Required the input array have a complex floating-point data type and required that the output array have the same data type as the input array.
7070
"""
7171

7272

@@ -87,7 +87,7 @@ def ifft(
8787
Parameters
8888
----------
8989
x: array
90-
input array. Should have a complex-valued floating-point data type.
90+
input array. Should have a complex floating-point data type.
9191
n: Optional[int]
9292
number of elements over which to compute the transform along the axis (dimension) specified by ``axis``. Let ``M`` be the size of the input array along the axis specified by ``axis``. When ``n`` is ``None``, the function must set ``n`` equal to ``M``.
9393
@@ -118,7 +118,7 @@ def ifft(
118118
.. versionadded:: 2022.12
119119
120120
.. versionchanged:: 2023.12
121-
Required the input array have a complex-valued floating-point data type and required that the output array have the same data type as the input array.
121+
Required the input array have a complex floating-point data type and required that the output array have the same data type as the input array.
122122
"""
123123

124124

@@ -139,7 +139,7 @@ def fftn(
139139
Parameters
140140
----------
141141
x: array
142-
input array. Should have a complex-valued floating-point data type.
142+
input array. Should have a complex floating-point data type.
143143
s: Optional[Sequence[int]]
144144
number of elements over which to compute the transform along the axes (dimensions) specified by ``axes``. Let ``i`` be the index of the ``n``-th axis specified by ``axes`` (i.e., ``i = axes[n]``) and ``M[i]`` be the size of the input array along axis ``i``. When ``s`` is ``None``, the function must set ``s`` equal to a sequence of integers such that ``s[i]`` equals ``M[i]`` for all ``i``.
145145
@@ -176,7 +176,7 @@ def fftn(
176176
.. versionadded:: 2022.12
177177
178178
.. versionchanged:: 2023.12
179-
Required the input array have a complex-valued floating-point data type and required that the output array have the same data type as the input array.
179+
Required the input array have a complex floating-point data type and required that the output array have the same data type as the input array.
180180
"""
181181

182182

@@ -197,7 +197,7 @@ def ifftn(
197197
Parameters
198198
----------
199199
x: array
200-
input array. Should have a complex-valued floating-point data type.
200+
input array. Should have a complex floating-point data type.
201201
s: Optional[Sequence[int]]
202202
number of elements over which to compute the transform along the axes (dimensions) specified by ``axes``. Let ``i`` be the index of the ``n``-th axis specified by ``axes`` (i.e., ``i = axes[n]``) and ``M[i]`` be the size of the input array along axis ``i``. When ``s`` is ``None``, the function must set ``s`` equal to a sequence of integers such that ``s[i]`` equals ``M[i]`` for all ``i``.
203203
@@ -234,7 +234,7 @@ def ifftn(
234234
.. versionadded:: 2022.12
235235
236236
.. versionchanged:: 2023.12
237-
Required the input array have a complex-valued floating-point data type and required that the output array have the same data type as the input array.
237+
Required the input array have a complex floating-point data type and required that the output array have the same data type as the input array.
238238
"""
239239

240240

@@ -304,7 +304,7 @@ def irfft(
304304
Parameters
305305
----------
306306
x: array
307-
input array. Should have a complex-valued floating-point data type.
307+
input array. Should have a complex floating-point data type.
308308
n: Optional[int]
309309
number of elements along the transformed axis (dimension) specified by ``axis`` in the **output array**. Let ``M`` be the size of the input array along the axis specified by ``axis``. When ``n`` is ``None``, the function must set ``n`` equal to ``2*(M-1)``.
310310
@@ -413,7 +413,7 @@ def irfftn(
413413
Parameters
414414
----------
415415
x: array
416-
input array. Should have a complex-valued floating-point data type.
416+
input array. Should have a complex floating-point data type.
417417
s: Optional[Sequence[int]]
418418
number of elements along the transformed axes (dimensions) specified by ``axes`` in the **output array**. Let ``i`` be the index of the ``n``-th axis specified by ``axes`` (i.e., ``i = axes[n]``) and ``M[i]`` be the size of the input array along axis ``i``. When ``s`` is ``None``, the function must set ``s`` equal to a sequence of integers such that ``s[i]`` equals ``M[i]`` for all ``i``, except for the last transformed axis in which ``s[i]`` equals ``2*(M[i]-1)``. For each ``i``, let ``n`` equal ``s[i]``, except for the last transformed axis in which ``n`` equals ``s[i]//2+1``.
419419
@@ -470,7 +470,7 @@ def hfft(
470470
Parameters
471471
----------
472472
x: array
473-
input array. Should have a complex-valued floating-point data type.
473+
input array. Should have a complex floating-point data type.
474474
n: Optional[int]
475475
number of elements along the transformed axis (dimension) specified by ``axis`` in the **output array**. Let ``M`` be the size of the input array along the axis specified by ``axis``. When ``n`` is ``None``, the function must set ``n`` equal to ``2*(M-1)``.
476476
@@ -501,7 +501,7 @@ def hfft(
501501
.. versionadded:: 2022.12
502502
503503
.. versionchanged:: 2023.12
504-
Required the input array to have a complex-valued floating-point data type and required that the output array have a real-valued data type having the same precision as the input array.
504+
Required the input array to have a complex floating-point data type and required that the output array have a real-valued data type having the same precision as the input array.
505505
"""
506506

507507

src/array_api_stubs/_draft/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from .utility_functions import *
1717
from . import linalg
1818
from . import fft
19-
from . import info
19+
from .info import __array_namespace_info__
2020

2121

2222
__array_api_version__: str = "YYYY.MM"

src/array_api_stubs/_draft/elementwise_functions.py

+6-8
Original file line numberDiff line numberDiff line change
@@ -875,14 +875,12 @@ def copysign(x1: array, x2: array, /) -> array:
875875
- If ``x2_i`` is ``NaN`` and the sign bit of ``x2_i`` is ``1``, the result is ``-|x1_i|``.
876876
- If ``x2_i`` is ``NaN`` and the sign bit of ``x2_i`` is ``0``, the result is ``|x1_i|``.
877877
878-
If ``x1_i`` is ``NaN``,
879-
880-
- If ``x2_i`` is less than ``0``, the result is ``NaN`` with a sign bit of ``1``.
881-
- If ``x2_i`` is ``-0``, the result is ``NaN`` with a sign bit of ``1``.
882-
- If ``x2_i`` is ``+0``, the result is ``NaN`` with a sign bit of ``0``.
883-
- If ``x2_i`` is greater than ``0``, the result is ``NaN`` with a sign bit of ``0``.
884-
- If ``x2_i`` is ``NaN`` and the sign bit of ``x2_i`` is ``1``, the result is ``NaN`` with a sign bit of ``1``.
885-
- If ``x2_i`` is ``NaN`` and the sign bit of ``x2_i`` is ``0``, the result is ``NaN`` with a sign bit of ``0``.
878+
- If ``x1_i`` is ``NaN`` and ``x2_i`` is less than ``0``, the result is ``NaN`` with a sign bit of ``1``.
879+
- If ``x1_i`` is ``NaN`` and ``x2_i`` is ``-0``, the result is ``NaN`` with a sign bit of ``1``.
880+
- If ``x1_i`` is ``NaN`` and ``x2_i`` is ``+0``, the result is ``NaN`` with a sign bit of ``0``.
881+
- If ``x1_i`` is ``NaN`` and ``x2_i`` is greater than ``0``, the result is ``NaN`` with a sign bit of ``0``.
882+
- If ``x1_i`` is ``NaN`` and ``x2_i`` is ``NaN`` and the sign bit of ``x2_i`` is ``1``, the result is ``NaN`` with a sign bit of ``1``.
883+
- If ``x1_i`` is ``NaN`` and ``x2_i`` is ``NaN`` and the sign bit of ``x2_i`` is ``0``, the result is ``NaN`` with a sign bit of ``0``.
886884
887885
.. versionadded:: 2023.12
888886
"""

0 commit comments

Comments
 (0)