Skip to content

Commit ec14168

Browse files
committed
Explicitly use the current device resource in DeviceBuffer
Previously we were relying on the C++ and Python-level device resources to agree, but this need not be the case. To avoid depending on that agreement, first get the current device resource and then pass it explicitly when allocating the wrapped C++ device_buffer while creating a DeviceBuffer. - Closes #1506
1 parent bd3f0d8 commit ec14168

File tree

4 files changed

+33
-17
lines changed

4 files changed

+33
-17
lines changed

python/rmm/_lib/device_buffer.pxd

+20-6
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,31 @@ from libcpp.memory cimport unique_ptr
1717

1818
from rmm._cuda.stream cimport Stream
1919
from rmm._lib.cuda_stream_view cimport cuda_stream_view
20-
from rmm._lib.memory_resource cimport DeviceMemoryResource
20+
from rmm._lib.memory_resource cimport (
21+
DeviceMemoryResource,
22+
device_memory_resource,
23+
)
2124

2225

2326
cdef extern from "rmm/device_buffer.hpp" namespace "rmm" nogil:
2427
cdef cppclass device_buffer:
2528
device_buffer()
26-
device_buffer(size_t size, cuda_stream_view stream) except +
27-
device_buffer(const void* source_data,
28-
size_t size, cuda_stream_view stream) except +
29-
device_buffer(const device_buffer buf,
30-
cuda_stream_view stream) except +
29+
device_buffer(
30+
size_t size,
31+
cuda_stream_view stream,
32+
device_memory_resource *
33+
) except +
34+
device_buffer(
35+
const void* source_data,
36+
size_t size,
37+
cuda_stream_view stream,
38+
device_memory_resource *
39+
) except +
40+
device_buffer(
41+
const device_buffer buf,
42+
cuda_stream_view stream,
43+
device_memory_resource *
44+
) except +
3145
void reserve(size_t new_capacity, cuda_stream_view stream) except +
3246
void resize(size_t new_size, cuda_stream_view stream) except +
3347
void shrink_to_fit(cuda_stream_view stream) except +

python/rmm/_lib/device_buffer.pyx

+11-9
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,10 @@ from cuda.ccudart cimport (
3232
cudaStream_t,
3333
)
3434

35-
from rmm._lib.memory_resource cimport get_current_device_resource
35+
from rmm._lib.memory_resource cimport (
36+
device_memory_resource,
37+
get_current_device_resource,
38+
)
3639

3740

3841
# The DeviceMemoryResource attribute could be released prematurely
@@ -75,22 +78,21 @@ cdef class DeviceBuffer:
7578
>>> db = rmm.DeviceBuffer(size=5)
7679
"""
7780
cdef const void* c_ptr
78-
81+
cdef device_memory_resource * mr_ptr
82+
# Save a reference to the MR and stream used for allocation
83+
self.mr = get_current_device_resource()
84+
mr_ptr = self.mr.get_mr()
7985
with nogil:
8086
c_ptr = <const void*>ptr
8187

82-
if size == 0:
83-
self.c_obj.reset(new device_buffer())
84-
elif c_ptr == NULL:
85-
self.c_obj.reset(new device_buffer(size, stream.view()))
88+
if c_ptr == NULL or size == 0:
89+
self.c_obj.reset(new device_buffer(size, stream.view(), mr_ptr))
8690
else:
87-
self.c_obj.reset(new device_buffer(c_ptr, size, stream.view()))
91+
self.c_obj.reset(new device_buffer(c_ptr, size, stream.view(), mr_ptr))
8892

8993
if stream.c_is_default():
9094
stream.c_synchronize()
9195

92-
# Save a reference to the MR and stream used for allocation
93-
self.mr = get_current_device_resource()
9496
self.stream = stream
9597

9698
def __len__(self):

python/rmm/_lib/memory_resource.pxd

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ cdef extern from "rmm/mr/device/device_memory_resource.hpp" \
3434

3535
cdef class DeviceMemoryResource:
3636
cdef shared_ptr[device_memory_resource] c_obj
37-
cdef device_memory_resource* get_mr(self)
37+
cdef device_memory_resource* get_mr(self) noexcept nogil
3838

3939
cdef class UpstreamResourceAdaptor(DeviceMemoryResource):
4040
cdef readonly DeviceMemoryResource upstream_mr

python/rmm/_lib/memory_resource.pyx

+1-1
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ cdef extern from "rmm/mr/device/failure_callback_resource_adaptor.hpp" \
218218

219219
cdef class DeviceMemoryResource:
220220

221-
cdef device_memory_resource* get_mr(self):
221+
cdef device_memory_resource* get_mr(self) noexcept nogil:
222222
"""Get the underlying C++ memory resource object."""
223223
return self.c_obj.get()
224224

0 commit comments

Comments
 (0)