diff --git a/include/drjit-core/jit.h b/include/drjit-core/jit.h index c923ef7d..5c3cbd37 100644 --- a/include/drjit-core/jit.h +++ b/include/drjit-core/jit.h @@ -199,8 +199,11 @@ extern JIT_EXPORT void *jit_cuda_lookup(const char *name); * \brief Add CUDA event synchronization between thread state's and external * CUDA stream. * - * An event will be recorded into the thread's states stream and the external stream - * will wait on the event before performing any subsequent work. + * An event will be recorded into the thread's states stream and the external + * stream will wait on the event before performing any subsequent work. The + * special value stream==2 denotes the caller's per-thread default stream. + * There is no need to ever synchronize with the global NULL stream, since + * Dr.Jit implicitly synchronizes with respect to it. * * \param stream The CUstream handle of the external stream */ diff --git a/src/cuda_api.cpp b/src/cuda_api.cpp index 293f4de4..b3d0f70e 100644 --- a/src/cuda_api.cpp +++ b/src/cuda_api.cpp @@ -123,6 +123,7 @@ bool jitc_cuda_api_init() { LOAD(cuStreamDestroy, "v2"); LOAD(cuStreamSynchronize); LOAD(cuStreamWaitEvent); + LOAD(cuStreamWaitEvent_ptsz); LOAD(cuPointerGetAttribute); LOAD(cuArrayCreate, "v2"); LOAD(cuArray3DCreate, "v2"); @@ -174,7 +175,7 @@ void jitc_cuda_api_shutdown() { Z(cuModuleGetFunction); Z(cuModuleLoadData); Z(cuModuleLoadDataEx); Z(cuModuleUnload); Z(cuOccupancyMaxPotentialBlockSize); Z(cuCtxPushCurrent); Z(cuCtxPopCurrent); Z(cuStreamCreate); Z(cuStreamDestroy); - Z(cuStreamSynchronize); Z(cuStreamWaitEvent); Z(cuPointerGetAttribute); + Z(cuStreamSynchronize); Z(cuStreamWaitEvent); Z(cuStreamWaitEvent_ptsz); Z(cuPointerGetAttribute); Z(cuArrayCreate); Z(cuArray3DCreate); Z(cuArray3DGetDescriptor); Z(cuArrayDestroy); Z(cuTexObjectCreate); Z(cuTexObjectGetResourceDesc); Z(cuTexObjectDestroy); Z(cuMemcpy2DAsync); Z(cuMemcpy3DAsync); diff --git a/src/cuda_api.h b/src/cuda_api.h index a2a04c09..142d9e59 100644 --- a/src/cuda_api.h +++ b/src/cuda_api.h @@ -260,6 +260,7 @@ DR_CUDA_SYM(CUresult (*cuStreamCreate)(CUstream *, unsigned int)); DR_CUDA_SYM(CUresult (*cuStreamDestroy)(CUstream)); DR_CUDA_SYM(CUresult (*cuStreamSynchronize)(CUstream)); DR_CUDA_SYM(CUresult (*cuStreamWaitEvent)(CUstream, CUevent, unsigned int)); +DR_CUDA_SYM(CUresult (*cuStreamWaitEvent_ptsz)(CUstream, CUevent, unsigned int)); DR_CUDA_SYM(CUresult (*cuMemAllocAsync)(CUdeviceptr *, size_t, CUstream)); DR_CUDA_SYM(CUresult (*cuMemFreeAsync)(CUdeviceptr, CUstream)); diff --git a/src/cuda_core.cpp b/src/cuda_core.cpp index 1e54389a..cbcbdda9 100644 --- a/src/cuda_core.cpp +++ b/src/cuda_core.cpp @@ -109,8 +109,12 @@ std::pair jitc_cuda_compile(const char *buf, bool release_state_ void jitc_cuda_sync_stream(uintptr_t stream) { ThreadState* ts = thread_state(JitBackend::CUDA); CUevent sync_event = ts->sync_stream_event; - cuda_check(cuEventRecord(sync_event, (CUstream)ts->stream)); - cuda_check(cuStreamWaitEvent((CUstream)stream, sync_event, CU_EVENT_DEFAULT)); + scoped_set_context guard(ts->context); + cuda_check(cuEventRecord(sync_event, (CUstream) ts->stream)); + if (stream != 2) + cuda_check(cuStreamWaitEvent((CUstream)stream, sync_event, CU_EVENT_DEFAULT)); + else + cuda_check(cuStreamWaitEvent_ptsz(nullptr, sync_event, CU_EVENT_DEFAULT)); } void cuda_check_impl(CUresult errval, const char *file, const int line) {