Skip to content

Commit e80208d

Browse files
committed
🏷️: add Device parametrization
Signed-off-by: nstarman <[email protected]>
1 parent c65c5d5 commit e80208d

File tree

2 files changed

+13
-6
lines changed

2 files changed

+13
-6
lines changed

src/array_api_typing/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
"Array",
55
"HasArrayNamespace",
66
"HasDType",
7+
"HasDevice",
78
"HasMatrixTranspose",
89
"HasNDim",
910
"HasShape",
@@ -16,6 +17,7 @@
1617
from ._array import (
1718
Array,
1819
HasArrayNamespace,
20+
HasDevice,
1921
HasDType,
2022
HasMatrixTranspose,
2123
HasNDim,

src/array_api_typing/_array.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
NamespaceT_co = TypeVar("NamespaceT_co", covariant=True, default=ModuleType)
1818
DTypeT_co = TypeVar("DTypeT_co", covariant=True)
19+
DeviceT_co = TypeVar("DeviceT_co", covariant=True, default=object)
1920

2021

2122
class HasArrayNamespace(Protocol[NamespaceT_co]):
@@ -74,11 +75,11 @@ def dtype(self, /) -> DTypeT_co:
7475
...
7576

7677

77-
class HasDevice(Protocol):
78+
class HasDevice(Protocol[DeviceT_co]):
7879
"""Protocol for array classes that have a device attribute."""
7980

8081
@property
81-
def device(self) -> object: # TODO: more specific type
82+
def device(self) -> DeviceT_co:
8283
"""Hardware device the array data resides on."""
8384
...
8485

@@ -191,7 +192,7 @@ def T(self) -> Self: # noqa: N802
191192
class Array(
192193
# ------ Attributes -------
193194
HasDType[DTypeT_co],
194-
HasDevice,
195+
HasDevice[DeviceT_co],
195196
HasMatrixTranspose,
196197
HasNDim,
197198
HasShape,
@@ -200,14 +201,18 @@ class Array(
200201
# ------- Methods ---------
201202
HasArrayNamespace[NamespaceT_co],
202203
# -------------------------
203-
Protocol[DTypeT_co, NamespaceT_co],
204+
Protocol[DTypeT_co, DeviceT_co, NamespaceT_co],
204205
):
205206
"""Array API specification for array object attributes and methods.
206207
207-
The type is: ``Array[+DTypeT, +NamespaceT = ModuleType] = Array[DTypeT,
208-
NamespaceT]`` where:
208+
The type is: ``Array[+DTypeT, +DeviceT = object, +NamespaceT = ModuleType] =
209+
Array[DTypeT, DeviceT, NamespaceT]`` where:
209210
210211
- `DTypeT` is the data type of the array elements.
212+
- `DeviceT` is the type of the device attribute. It defaults to `object` to
213+
enable skipping device specification. Array objects supporting device
214+
management can specify a more specific type if they use types (as opposed
215+
to object instances) to distinguish between different devices.
211216
- `NamespaceT` is the type of the array namespace. It defaults to
212217
`ModuleType`, which is the most common form of array namespace (e.g.,
213218
`numpy`, `cupy`, etc.). However, it can be any type, e.g. a

0 commit comments

Comments
 (0)