@@ -15,7 +15,6 @@ import numpy
1515from cuda.core.experimental._utils.cuda_utils import handle_return, driver
1616
1717
18- from cuda.core.experimental._dlpack import make_py_capsule
1918from cuda.core.experimental._memory import Buffer
2019
2120# TODO(leofang): support NumPy structured dtypes
@@ -291,44 +290,6 @@ cdef class StridedMemoryView:
291290 + f" readonly={self.readonly},\n "
292291 + f" exporting_obj={get_simple_repr(self.exporting_obj)})" )
293292
294- def __dlpack__ (
295- self ,
296- *,
297- stream: int | None = None ,
298- max_version: tuple[int , int] | None = None ,
299- dl_device: tuple[int , int] | None = None ,
300- copy: bool | None = None ,
301- ) -> PyCapsule:
302- # Note: we ignore the stream argument entirely (as if it is -1).
303- # It is the user's responsibility to maintain stream order.
304- if dl_device is not None:
305- raise BufferError("Sorry , not supported: dl_device other than None")
306- if copy is True:
307- raise BufferError("Sorry , not supported: copy = True " )
308- if max_version is None:
309- versioned = False
310- else:
311- if not isinstance(max_version , tuple ) or len(max_version ) != 2:
312- raise BufferError(f"Expected max_version tuple[int , int], got {max_version}")
313- versioned = max_version >= (1 , 0 )
314- cdef object dtype = self .get_dtype()
315- if dtype is None:
316- raise ValueError(
317- "Cannot export the StridedMemoryView without a dtype. "
318- "You can create a dtyped view calling view(dtype = ...) method."
319- )
320- capsule = make_py_capsule(
321- self .get_buffer(),
322- versioned,
323- self .ptr,
324- self .get_layout(),
325- _numpy2dlpack_dtype[dtype],
326- )
327- return capsule
328-
329- def __dlpack_device__(self ) -> tuple[int , int]:
330- return self.get_buffer().__dlpack_device__()
331-
332293 cdef inline StridedLayout get_layout(self ):
333294 if self ._layout is None :
334295 if self .dl_tensor:
@@ -478,24 +439,25 @@ cdef StridedMemoryView view_as_dlpack(obj, stream_ptr, view=None):
478439 return buf
479440
480441
481- _numpy2dlpack_dtype = {
482- numpy.dtype(" uint8" ): (< uint8_t> kDLUInt, 8 , 1 ),
483- numpy.dtype(" uint16" ): (< uint8_t> kDLUInt, 16 , 1 ),
484- numpy.dtype(" uint32" ): (< uint8_t> kDLUInt, 32 , 1 ),
485- numpy.dtype(" uint64" ): (< uint8_t> kDLUInt, 64 , 1 ),
486- numpy.dtype(" int8" ): (< uint8_t> kDLInt, 8 , 1 ),
487- numpy.dtype(" int16" ): (< uint8_t> kDLInt, 16 , 1 ),
488- numpy.dtype(" int32" ): (< uint8_t> kDLInt, 32 , 1 ),
489- numpy.dtype(" int64" ): (< uint8_t> kDLInt, 64 , 1 ),
490- numpy.dtype(" float16" ): (< uint8_t> kDLFloat, 16 , 1 ),
491- numpy.dtype(" float32" ): (< uint8_t> kDLFloat, 32 , 1 ),
492- numpy.dtype(" float64" ): (< uint8_t> kDLFloat, 64 , 1 ),
493- numpy.dtype(" complex64" ): (< uint8_t> kDLComplex, 64 , 1 ),
494- numpy.dtype(" complex128" ): (< uint8_t> kDLComplex, 128 , 1 ),
495- numpy.dtype(" bool" ): (< uint8_t> kDLBool, 8 , 1 ),
496- }
497- _typestr2dtype = {dtype.str: dtype for dtype in _numpy2dlpack_dtype.keys()}
498- _typestr2itemsize = {dtype.str: dtype.itemsize for dtype in _numpy2dlpack_dtype.keys()}
442+ _builtin_numeric_dtypes = [
443+ numpy.dtype(" uint8" ),
444+ numpy.dtype(" uint16" ),
445+ numpy.dtype(" uint32" ),
446+ numpy.dtype(" uint64" ),
447+ numpy.dtype(" int8" ),
448+ numpy.dtype(" int16" ),
449+ numpy.dtype(" int32" ),
450+ numpy.dtype(" int64" ),
451+ numpy.dtype(" float16" ),
452+ numpy.dtype(" float32" ),
453+ numpy.dtype(" float64" ),
454+ numpy.dtype(" complex64" ),
455+ numpy.dtype(" complex128" ),
456+ numpy.dtype(" bool" ),
457+ ]
458+ # Doing it once to avoid repeated overhead
459+ _typestr2dtype = {dtype.str: dtype for dtype in _builtin_numeric_dtypes}
460+ _typestr2itemsize = {dtype.str: dtype.itemsize for dtype in _builtin_numeric_dtypes}
499461
500462
501463cdef object dtype_dlpack_to_numpy(DLDataType* dtype):
0 commit comments