diff --git a/libcudacxx/include/cuda/__driver/driver_api.h b/libcudacxx/include/cuda/__driver/driver_api.h index dd7c1778c20..37952b77b2d 100644 --- a/libcudacxx/include/cuda/__driver/driver_api.h +++ b/libcudacxx/include/cuda/__driver/driver_api.h @@ -27,6 +27,11 @@ # include # include # include +# if _CCCL_OS(WINDOWS) +# include +# else +# include +# endif # include @@ -46,21 +51,37 @@ _CCCL_BEGIN_NAMESPACE_CUDA_DRIVER _CCCL_SUPPRESS_DEPRECATED_PUSH //! @brief Gets the cuGetProcAddress function pointer. -[[nodiscard]] _CCCL_HOST_API inline auto __getProcAddressFn() -> decltype(cuGetProcAddress)* +[[nodiscard]] _CCCL_PUBLIC_HOST_API inline auto __getProcAddressFn() -> decltype(cuGetProcAddress)* { - // TODO switch to dlopen of libcuda.so instead of the below - void* __fn; - ::cudaDriverEntryPointQueryResult __result; -# if _CCCL_CTK_AT_LEAST(13, 0) - ::cudaError_t __status = - ::cudaGetDriverEntryPointByVersion("cuGetProcAddress", &__fn, 13000, ::cudaEnableDefault, &__result); -# else - ::cudaError_t __status = ::cudaGetDriverEntryPoint("cuGetProcAddress", &__fn, ::cudaEnableDefault, &__result); -# endif - if (__status != ::cudaSuccess || __result != ::cudaDriverEntryPointSuccess) + const char* __fn_name = "cuGetProcAddress_v2"; +# if _CCCL_OS(WINDOWS) + static auto __driver_library = ::LoadLibraryExA("nvcuda.dll", nullptr, LOAD_LIBRARY_SEARCH_SYSTEM32); + if (__driver_library == nullptr) + { + ::cuda::__throw_cuda_error(::cudaErrorUnknown, "Failed to load nvcuda.dll"); + } + static void* __fn = ::GetProcAddress(__driver_library, __fn_name); + if (__fn == nullptr) + { + ::cuda::__throw_cuda_error(::cudaErrorInitializationError, "Failed to get cuGetProcAddress from nvcuda.dll"); + } +# else // ^^^ _CCCL_OS(WINDOWS) ^^^ / vvv !_CCCL_OS(WINDOWS) vvv +# if _CCCL_OS(ANDROID) + const char* __driver_library_name = "libcuda.so"; +# else // ^^^ _CCCL_OS(ANDROID) ^^^ / vvv !_CCCL_OS(ANDROID) vvv + const char* __driver_library_name = "libcuda.so.1"; +# endif // ^^^ !_CCCL_OS(ANDROID) ^^^ + static void* __driver_library = ::dlopen(__driver_library_name, RTLD_NOW); + if (__driver_library == nullptr) + { + ::cuda::__throw_cuda_error(::cudaErrorUnknown, "Failed to load libcuda.so.1"); + } + static void* __fn = ::dlsym(__driver_library, __fn_name); + if (__fn == nullptr) { - ::cuda::__throw_cuda_error(::cudaErrorUnknown, "Failed to get cuGetProcAddress"); + ::cuda::__throw_cuda_error(::cudaErrorInitializationError, "Failed to get cuGetProcAddress from libcuda.so.1"); } +# endif // ^^^ !_CCCL_OS(WINDOWS) ^^^ return reinterpret_cast(__fn); } @@ -151,7 +172,7 @@ _CCCL_HOST_API inline void __call_driver_fn(Fn __fn, const char* __err_msg, Args //! @return The address of the symbol. //! //! @throws @c cuda::cuda_error if the symbol cannot be obtained or the CUDA driver failed to initialize. -[[nodiscard]] _CCCL_HOST_API inline void* +[[nodiscard]] _CCCL_PUBLIC_HOST_API inline void* __get_driver_entry_point(const char* __name, [[maybe_unused]] int __major = 12, [[maybe_unused]] int __minor = 0) { // Get cuGetProcAddress function and call cuInit(0) only on the first call