Skip to content

Commit 0f3bc89

Browse files
committed
Disable (for now) exporting the SMV via dlpack
Signed-off-by: Kamil Tokarski <[email protected]>
1 parent cf03092 commit 0f3bc89

File tree

1 file changed

+19
-57
lines changed

1 file changed

+19
-57
lines changed

cuda_core/cuda/core/experimental/_memoryview.pyx

Lines changed: 19 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ import numpy
1515
from cuda.core.experimental._utils.cuda_utils import handle_return, driver
1616

1717

18-
from cuda.core.experimental._dlpack import make_py_capsule
1918
from cuda.core.experimental._memory import Buffer
2019

2120
# TODO(leofang): support NumPy structured dtypes
@@ -291,44 +290,6 @@ cdef class StridedMemoryView:
291290
+ f" readonly={self.readonly},\n"
292291
+ f" exporting_obj={get_simple_repr(self.exporting_obj)})")
293292

294-
def __dlpack__(
295-
self,
296-
*,
297-
stream: int | None = None,
298-
max_version: tuple[int, int] | None = None,
299-
dl_device: tuple[int, int] | None = None,
300-
copy: bool | None = None,
301-
) -> PyCapsule:
302-
# Note: we ignore the stream argument entirely (as if it is -1).
303-
# It is the user's responsibility to maintain stream order.
304-
if dl_device is not None:
305-
raise BufferError("Sorry, not supported: dl_device other than None")
306-
if copy is True:
307-
raise BufferError("Sorry, not supported: copy=True")
308-
if max_version is None:
309-
versioned = False
310-
else:
311-
if not isinstance(max_version, tuple) or len(max_version) != 2:
312-
raise BufferError(f"Expected max_version tuple[int, int], got {max_version}")
313-
versioned = max_version >= (1, 0)
314-
cdef object dtype = self.get_dtype()
315-
if dtype is None:
316-
raise ValueError(
317-
"Cannot export the StridedMemoryView without a dtype. "
318-
"You can create a dtyped view calling view(dtype=...) method."
319-
)
320-
capsule = make_py_capsule(
321-
self.get_buffer(),
322-
versioned,
323-
self.ptr,
324-
self.get_layout(),
325-
_numpy2dlpack_dtype[dtype],
326-
)
327-
return capsule
328-
329-
def __dlpack_device__(self) -> tuple[int, int]:
330-
return self.get_buffer().__dlpack_device__()
331-
332293
cdef inline StridedLayout get_layout(self):
333294
if self._layout is None:
334295
if self.dl_tensor:
@@ -478,24 +439,25 @@ cdef StridedMemoryView view_as_dlpack(obj, stream_ptr, view=None):
478439
return buf
479440

480441

481-
_numpy2dlpack_dtype = {
482-
numpy.dtype("uint8"): (<uint8_t>kDLUInt, 8, 1),
483-
numpy.dtype("uint16"): (<uint8_t>kDLUInt, 16, 1),
484-
numpy.dtype("uint32"): (<uint8_t>kDLUInt, 32, 1),
485-
numpy.dtype("uint64"): (<uint8_t>kDLUInt, 64, 1),
486-
numpy.dtype("int8"): (<uint8_t>kDLInt, 8, 1),
487-
numpy.dtype("int16"): (<uint8_t>kDLInt, 16, 1),
488-
numpy.dtype("int32"): (<uint8_t>kDLInt, 32, 1),
489-
numpy.dtype("int64"): (<uint8_t>kDLInt, 64, 1),
490-
numpy.dtype("float16"): (<uint8_t>kDLFloat, 16, 1),
491-
numpy.dtype("float32"): (<uint8_t>kDLFloat, 32, 1),
492-
numpy.dtype("float64"): (<uint8_t>kDLFloat, 64, 1),
493-
numpy.dtype("complex64"): (<uint8_t>kDLComplex, 64, 1),
494-
numpy.dtype("complex128"): (<uint8_t>kDLComplex, 128, 1),
495-
numpy.dtype("bool"): (<uint8_t>kDLBool, 8, 1),
496-
}
497-
_typestr2dtype = {dtype.str: dtype for dtype in _numpy2dlpack_dtype.keys()}
498-
_typestr2itemsize = {dtype.str: dtype.itemsize for dtype in _numpy2dlpack_dtype.keys()}
442+
_builtin_numeric_dtypes = [
443+
numpy.dtype("uint8"),
444+
numpy.dtype("uint16"),
445+
numpy.dtype("uint32"),
446+
numpy.dtype("uint64"),
447+
numpy.dtype("int8"),
448+
numpy.dtype("int16"),
449+
numpy.dtype("int32"),
450+
numpy.dtype("int64"),
451+
numpy.dtype("float16"),
452+
numpy.dtype("float32"),
453+
numpy.dtype("float64"),
454+
numpy.dtype("complex64"),
455+
numpy.dtype("complex128"),
456+
numpy.dtype("bool"),
457+
]
458+
# Doing it once to avoid repeated overhead
459+
_typestr2dtype = {dtype.str: dtype for dtype in _builtin_numeric_dtypes}
460+
_typestr2itemsize = {dtype.str: dtype.itemsize for dtype in _builtin_numeric_dtypes}
499461

500462

501463
cdef object dtype_dlpack_to_numpy(DLDataType* dtype):

0 commit comments

Comments
 (0)