1010
1111#include < cuda_runtime.h>
1212
13+ #include < executorch/extension/cuda/caller_stream.h>
1314#include < executorch/runtime/platform/log.h>
1415
1516namespace executorch ::backends::cuda {
@@ -19,6 +20,85 @@ using executorch::runtime::Result;
1920using executorch::runtime::etensor::DeviceIndex;
2021using executorch::runtime::etensor::DeviceType;
2122
23+ namespace {
24+
25+ Error copy_impl (
26+ void * dst,
27+ const void * src,
28+ size_t nbytes,
29+ DeviceIndex index,
30+ cudaMemcpyKind kind) {
31+ ET_CHECK_OR_RETURN_ERROR (
32+ kind == cudaMemcpyHostToDevice || kind == cudaMemcpyDeviceToHost,
33+ InvalidArgument,
34+ " CudaAllocator::copy_impl: unsupported cudaMemcpyKind %d" ,
35+ static_cast <int >(kind));
36+ const char * method = kind == cudaMemcpyHostToDevice
37+ ? " CudaAllocator::copy_host_to_device"
38+ : " CudaAllocator::copy_device_to_host" ;
39+ ET_CHECK_OR_RETURN_ERROR (
40+ dst != nullptr , InvalidArgument, " %s: dst is null" , method);
41+ ET_CHECK_OR_RETURN_ERROR (
42+ src != nullptr , InvalidArgument, " %s: src is null" , method);
43+ ET_CHECK_OR_RETURN_ERROR (
44+ index >= -1 ,
45+ InvalidArgument,
46+ " %s: invalid device index %d (must be >= -1)" ,
47+ method,
48+ static_cast <int >(index));
49+ const auto caller_stream = executorch::extension::cuda::getCallerStream ();
50+ if (caller_stream) {
51+ // TODO: validate caller stream device matches index.
52+ // For now assert index is -1 or 0.
53+ ET_CHECK_OR_RETURN_ERROR (
54+ index == -1 || index == 0 ,
55+ InvalidArgument,
56+ " %s: with caller stream, only supports device 0 or -1 (current), got %d" ,
57+ method,
58+ static_cast <int >(index));
59+ }
60+ if (nbytes == 0 ) {
61+ return Error::Ok;
62+ }
63+
64+ int prev_device = 0 ;
65+ cudaError_t prev_device_err = cudaSuccess;
66+
67+ if (index >= 0 ) {
68+ prev_device_err = cudaGetDevice (&prev_device);
69+ if (prev_device_err == cudaSuccess) {
70+ cudaSetDevice (index);
71+ }
72+ }
73+ cudaError_t err = cudaSuccess;
74+ if (caller_stream) {
75+ err = cudaMemcpyAsync (dst, src, nbytes, kind, *caller_stream);
76+ if (err == cudaSuccess && kind == cudaMemcpyDeviceToHost) {
77+ err = cudaStreamSynchronize (*caller_stream);
78+ }
79+ } else {
80+ err = cudaMemcpy (dst, src, nbytes, kind);
81+ }
82+
83+ if (index >= 0 && prev_device_err == cudaSuccess) {
84+ cudaSetDevice (prev_device);
85+ }
86+
87+ if (err != cudaSuccess) {
88+ ET_LOG (
89+ Error,
90+ " cudaMemcpy %s failed: %s (%zu bytes, device %d)" ,
91+ kind == cudaMemcpyHostToDevice ? " H2D" : " D2H" ,
92+ cudaGetErrorString (err),
93+ nbytes,
94+ static_cast <int >(index));
95+ return Error::Internal;
96+ }
97+ return Error::Ok;
98+ }
99+
100+ } // namespace
101+
22102Result<void *>
23103CudaAllocator::allocate (size_t nbytes, DeviceIndex index, size_t alignment) {
24104 // index == -1 means "use the current CUDA device"; any value < -1 is invalid.
@@ -124,72 +204,20 @@ void CudaAllocator::deallocate(void* ptr, DeviceIndex index) {
124204 }
125205}
126206
127- // TODO(gasoonjia): Add support for async copy
128207Error CudaAllocator::copy_host_to_device (
129208 void * dst,
130209 const void * src,
131210 size_t nbytes,
132211 DeviceIndex index) {
133- int prev_device = 0 ;
134- cudaError_t prev_device_err = cudaSuccess;
135-
136- if (index >= 0 ) {
137- prev_device_err = cudaGetDevice (&prev_device);
138- if (prev_device_err == cudaSuccess) {
139- cudaSetDevice (index);
140- }
141- }
142-
143- cudaError_t err = cudaMemcpy (dst, src, nbytes, cudaMemcpyHostToDevice);
144-
145- if (index >= 0 && prev_device_err == cudaSuccess) {
146- cudaSetDevice (prev_device);
147- }
148-
149- if (err != cudaSuccess) {
150- ET_LOG (
151- Error,
152- " cudaMemcpy H2D failed: %s (%zu bytes, device %d)" ,
153- cudaGetErrorString (err),
154- nbytes,
155- static_cast <int >(index));
156- return Error::Internal;
157- }
158- return Error::Ok;
212+ return copy_impl (dst, src, nbytes, index, cudaMemcpyHostToDevice);
159213}
160214
161- // TODO(gasoonjia): Add support for async copy
162215Error CudaAllocator::copy_device_to_host (
163216 void * dst,
164217 const void * src,
165218 size_t nbytes,
166219 DeviceIndex index) {
167- int prev_device = 0 ;
168- cudaError_t prev_device_err = cudaSuccess;
169-
170- if (index >= 0 ) {
171- prev_device_err = cudaGetDevice (&prev_device);
172- if (prev_device_err == cudaSuccess) {
173- cudaSetDevice (index);
174- }
175- }
176-
177- cudaError_t err = cudaMemcpy (dst, src, nbytes, cudaMemcpyDeviceToHost);
178-
179- if (index >= 0 && prev_device_err == cudaSuccess) {
180- cudaSetDevice (prev_device);
181- }
182-
183- if (err != cudaSuccess) {
184- ET_LOG (
185- Error,
186- " cudaMemcpy D2H failed: %s (%zu bytes, device %d)" ,
187- cudaGetErrorString (err),
188- nbytes,
189- static_cast <int >(index));
190- return Error::Internal;
191- }
192- return Error::Ok;
220+ return copy_impl (dst, src, nbytes, index, cudaMemcpyDeviceToHost);
193221}
194222
195223DeviceType CudaAllocator::device_type () const {
0 commit comments