Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate using "cpu" as a device for to_device() with CuPy arrays #87

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion array_api_compat/common/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,14 @@
from typing import NamedTuple
from types import ModuleType
import inspect
import sys

from ._helpers import _check_device, is_numpy_array, array_namespace
from ._helpers import is_numpy_array, array_namespace

def _check_device(xp, device):
if xp == sys.modules.get('numpy'):
if device not in ["cpu", None]:
raise ValueError(f"Unsupported device for NumPy: {device!r}")

# These functions are modified from the NumPy versions.

Expand Down
16 changes: 10 additions & 6 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import sys
import math
import inspect
import warnings

def is_numpy_array(x):
# Avoid importing NumPy if it isn't already
Expand Down Expand Up @@ -154,11 +155,6 @@ def your_function(x, y):
# backwards compatibility alias
get_namespace = array_namespace

def _check_device(xp, device):
if xp == sys.modules.get('numpy'):
if device not in ["cpu", None]:
raise ValueError(f"Unsupported device for NumPy: {device!r}")

# device() is not on numpy.ndarray and to_device() is not on numpy.ndarray
# or cupy.ndarray. They are not included in array objects of this library
# because this library just reuses the respective ndarray classes without
Expand Down Expand Up @@ -204,7 +200,15 @@ def _cupy_to_device(x, device, /, stream=None):
elif device == "cpu":
# allowing us to use `to_device(x, "cpu")`
# is useful for portable test swapping between
# host and device backends
# host and device backends. This is deprecated. See
# https://github.com/data-apis/array-api-compat/issues/86.
# Implementations should prefer using the device keyword to
# from_dlpack() (starting from the 2023.12 version of the standard)
warnings.warn(
"Using `to_device(x, 'cpu')` on CuPy arrays is deprecated. Use the `device` keyword to `from_dlpack()` instead.",
DeprecationWarning,
stacklevel=3,
)
return x.get()
elif not isinstance(device, _Device):
raise ValueError(f"Unsupported device {device!r}")
Expand Down
2 changes: 0 additions & 2 deletions array_api_compat/dask/array/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

from ..._internal import get_xp
from ...common import _aliases, _linalg
from ...common._helpers import _check_device

if TYPE_CHECKING:
from typing import Optional, Tuple, Union
Expand Down Expand Up @@ -38,7 +37,6 @@ def dask_arange(
device: Optional[Device] = None,
**kwargs,
) -> ndarray:
_check_device(xp, device)
args = [start]
if stop is not None:
args.append(stop)
Expand Down
Loading