Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions include/drjit-core/jit.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,11 @@ extern JIT_EXPORT void *jit_cuda_lookup(const char *name);
* \brief Add CUDA event synchronization between the thread state's stream and
* an external CUDA stream.
*
* An event will be recorded into the thread state's stream and the external stream
* will wait on the event before performing any subsequent work.
* An event will be recorded into the thread state's stream, and the external
* stream will wait on the event before performing any subsequent work. The
* special value stream==2 denotes the caller's per-thread default stream.
* There is no need to ever synchronize with the global NULL stream, since
* Dr.Jit implicitly synchronizes with respect to it.
*
* \param stream The CUstream handle of the external stream
*/
Expand Down
3 changes: 2 additions & 1 deletion src/cuda_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ bool jitc_cuda_api_init() {
LOAD(cuStreamDestroy, "v2");
LOAD(cuStreamSynchronize);
LOAD(cuStreamWaitEvent);
LOAD(cuStreamWaitEvent_ptsz);
LOAD(cuPointerGetAttribute);
LOAD(cuArrayCreate, "v2");
LOAD(cuArray3DCreate, "v2");
Expand Down Expand Up @@ -174,7 +175,7 @@ void jitc_cuda_api_shutdown() {
Z(cuModuleGetFunction); Z(cuModuleLoadData); Z(cuModuleLoadDataEx); Z(cuModuleUnload);
Z(cuOccupancyMaxPotentialBlockSize); Z(cuCtxPushCurrent);
Z(cuCtxPopCurrent); Z(cuStreamCreate); Z(cuStreamDestroy);
Z(cuStreamSynchronize); Z(cuStreamWaitEvent); Z(cuPointerGetAttribute);
Z(cuStreamSynchronize); Z(cuStreamWaitEvent); Z(cuStreamWaitEvent_ptsz); Z(cuPointerGetAttribute);
Z(cuArrayCreate); Z(cuArray3DCreate); Z(cuArray3DGetDescriptor);
Z(cuArrayDestroy); Z(cuTexObjectCreate); Z(cuTexObjectGetResourceDesc);
Z(cuTexObjectDestroy); Z(cuMemcpy2DAsync); Z(cuMemcpy3DAsync);
Expand Down
1 change: 1 addition & 0 deletions src/cuda_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ DR_CUDA_SYM(CUresult (*cuStreamCreate)(CUstream *, unsigned int));
DR_CUDA_SYM(CUresult (*cuStreamDestroy)(CUstream));
DR_CUDA_SYM(CUresult (*cuStreamSynchronize)(CUstream));
DR_CUDA_SYM(CUresult (*cuStreamWaitEvent)(CUstream, CUevent, unsigned int));
DR_CUDA_SYM(CUresult (*cuStreamWaitEvent_ptsz)(CUstream, CUevent, unsigned int));
DR_CUDA_SYM(CUresult (*cuMemAllocAsync)(CUdeviceptr *, size_t, CUstream));
DR_CUDA_SYM(CUresult (*cuMemFreeAsync)(CUdeviceptr, CUstream));

Expand Down
8 changes: 6 additions & 2 deletions src/cuda_core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,12 @@ std::pair<CUmodule, bool> jitc_cuda_compile(const char *buf, bool release_state_
/**
 * \brief Make an external CUDA stream wait for the work that is currently
 * queued on the calling thread's Dr.Jit CUDA stream.
 *
 * An event is recorded into the thread state's stream, and the external
 * stream is then instructed to wait on that event. The call does not block
 * the host; it only orders work between the two streams.
 *
 * \param stream
 *     The CUstream handle of the external stream, passed as an integer. The
 *     special value 2 (CU_STREAM_PER_THREAD in the CUDA driver API) selects
 *     the caller's per-thread default stream.
 */
void jitc_cuda_sync_stream(uintptr_t stream) {
    ThreadState *ts = thread_state(JitBackend::CUDA);
    CUevent sync_event = ts->sync_stream_event;

    // Every driver call below is issued while the guard is alive, so that
    // they all run against the thread state's context.
    // NOTE(review): scoped_set_context presumably makes ts->context current
    // for this scope -- confirm against its definition.
    scoped_set_context guard(ts->context);

    // Mark the point in Dr.Jit's stream that the external stream must reach.
    cuda_check(cuEventRecord(sync_event, (CUstream) ts->stream));

    if (stream != 2) {
        // Ordinary stream handle: wait on it directly.
        cuda_check(cuStreamWaitEvent((CUstream) stream, sync_event,
                                     CU_EVENT_DEFAULT));
    } else {
        // Per-thread default stream: use the "_ptsz" entry point with a null
        // stream handle instead of passing the sentinel value through.
        cuda_check(cuStreamWaitEvent_ptsz(nullptr, sync_event,
                                          CU_EVENT_DEFAULT));
    }
}

void cuda_check_impl(CUresult errval, const char *file, const int line) {
Expand Down