Skip to content

Commit 4de3152

Browse files
authored
[CUDA] Fix cuda 13 build (microsoft#26153)
Fix CUDA 13 build errors and warnings. Related: microsoft#25936

I've verified the build on Linux and Windows using the following test settings:

### Build command line

You may need to change cuda_home and cudnn_home to your installation directories, and update CMAKE_CUDA_ARCHITECTURES according to your GPU.

#### Linux Build

```
pip install cmake ninja packaging numpy
sh build.sh --config Release --build_dir build/cuda13 --parallel --use_cuda \
  --cuda_version 12.8 --cuda_home /nvida/cuda13.0/ \
  --cudnn_home /nvida/cudnn9.12_cu13/ \
  --build_wheel --skip_tests \
  --cmake_generator Ninja \
  --enable_cuda_nhwc_ops \
  --use_binskim_compliant_compile_flags \
  --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=90-real;90-virtual \
  --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=ON \
  --cmake_extra_defines onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON
```

#### Windows Build

```
IF "%VCToolsVersion%"=="" call "C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Auxiliary\Build\vcvars64.bat"
build.bat --cmake_generator "Visual Studio 17 2022" --config Release --build_dir build\cuda13 --build_wheel ^
  --parallel 4 --nvcc_threads 1 --build_shared_lib ^
  --use_cuda --cuda_version "13.0" --cuda_home "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v13.0" ^
  --cudnn_home "D:\cudnn\9.13.0.50_cuda13" ^
  --skip_tests ^
  --use_binskim_compliant_compile_flags ^
  --enable_cuda_nhwc_ops ^
  --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=native" ^
  --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=ON ^
  --cmake_extra_defines FETCHCONTENT_TRY_FIND_PACKAGE_MODE=NEVER
```

onnxruntime_test_all.exe passes on an RTX 5060 Ti GPU, so the binary properly supports Blackwell GPUs (CUDA_ARCHITECTURES=120) with CUDA 13.0:

```
[----------] Global test environment tear-down
[==========] 1242 tests from 111 test suites ran. (83468 ms total)
[ PASSED ] 1242 tests.
```
1 parent 42fcd71 commit 4de3152

File tree

17 files changed

+93
-40
lines changed

17 files changed

+93
-40
lines changed

cmake/CMakeLists.txt

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1192,7 +1192,7 @@ function(onnxruntime_configure_target target_name)
11921192

11931193
# Keep BinSkim happy
11941194
if(MSVC AND NOT onnxruntime_target_platform MATCHES "ARM")
1195-
target_link_options(${target_name} PRIVATE "/CETCOMPAT")
1195+
target_link_options(${target_name} PRIVATE "$<$<LINK_LANGUAGE:CXX,C>:/CETCOMPAT>" "$<$<LINK_LANGUAGE:CUDA>:-Xlinker=/CETCOMPAT>")
11961196
endif()
11971197

11981198
endfunction()
@@ -1421,7 +1421,6 @@ configure_file(onnxruntime_config.h.in ${CMAKE_CURRENT_BINARY_DIR}/onnxruntime_c
14211421
get_property(onnxruntime_GENERATOR_IS_MULTI_CONFIG GLOBAL PROPERTY GENERATOR_IS_MULTI_CONFIG)
14221422

14231423
if (onnxruntime_USE_CUDA)
1424-
set(CMAKE_CUDA_RUNTIME_LIBRARY Shared)
14251424
set(CMAKE_CUDA_STANDARD 17)
14261425
if(onnxruntime_CUDA_HOME)
14271426
file(TO_CMAKE_PATH CUDAToolkit_ROOT ${onnxruntime_CUDA_HOME})
@@ -1441,6 +1440,14 @@ if (onnxruntime_USE_CUDA)
14411440
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xfatbin=-compress-all")
14421441
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
14431442
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --Werror default-stream-launch")
1443+
1444+
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0)
1445+
if (UNIX)
1446+
# Suppress deprecation errors (e.g., long4 in CUDA 13)
1447+
add_compile_options(-Wno-deprecated-declarations)
1448+
endif()
1449+
endif()
1450+
14441451
if (NOT WIN32)
14451452
list(APPEND CUDA_NVCC_FLAGS --compiler-options -fPIC)
14461453
endif()

cmake/external/cuda_configuration.cmake

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,19 @@ macro(setup_cuda_compiler)
5858
if(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS CUDA_REQUIRED_VERSION)
5959
message(FATAL_ERROR "CUDA version ${CMAKE_CUDA_COMPILER_VERSION} must be at least ${CUDA_REQUIRED_VERSION}")
6060
endif()
61+
62+
# For CUDA 13+, explicitly set the compiler front-end to Clang to handle
63+
# MSVC-specific pragmas correctly in device code.
64+
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0 AND NOT DEFINED CMAKE_CUDA_COMPILER_FRONTEND_VARIANT)
65+
message(STATUS "Setting CUDA compiler front-end to Clang by default for CUDA 13+.")
66+
set(CMAKE_CUDA_COMPILER_FRONTEND_VARIANT "CLANG")
67+
endif()
68+
69+
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0)
70+
set(CMAKE_CUDA_RUNTIME_LIBRARY "Hybrid")
71+
else()
72+
set(CMAKE_CUDA_RUNTIME_LIBRARY "Shared")
73+
endif()
6174
endmacro()
6275

6376
macro(setup_cuda_architectures)

cmake/onnxruntime_providers_cuda.cmake

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,17 @@
191191
target_compile_options(${target} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--diag-suppress=221>")
192192
endif()
193193

194+
if (CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0)
195+
if (UNIX)
196+
# Suppress -Wattributes warning from protobuf headers with nvcc on Linux
197+
target_compile_options(${target} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:-Xcompiler -Wno-attributes>")
198+
endif()
199+
200+
if (MSVC)
201+
target_compile_options(${target} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--diag-suppress=20199>")
202+
endif()
203+
endif()
204+
194205
if (UNIX)
195206
target_compile_options(${target} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:-Xcompiler -Wno-reorder>"
196207
"$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:-Wno-reorder>")

cmake/onnxruntime_unittests.cmake

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1523,7 +1523,7 @@ endif()
15231523
list(APPEND onnxruntime_shared_lib_test_LIBS cpuinfo)
15241524
endif()
15251525
if (onnxruntime_USE_CUDA)
1526-
list(APPEND onnxruntime_shared_lib_test_LIBS)
1526+
list(APPEND onnxruntime_shared_lib_test_LIBS CUDA::cudart)
15271527
endif()
15281528

15291529
if (onnxruntime_USE_TENSORRT)
@@ -1751,6 +1751,7 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
17511751
if (HAS_QSPECTRE)
17521752
list(APPEND custom_op_lib_option "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /Qspectre>")
17531753
endif()
1754+
set(custom_op_lib_link ${custom_op_lib_link} CUDA::cudart)
17541755
endif()
17551756

17561757
file(GLOB custom_op_src ${custom_op_src_patterns})

onnxruntime/contrib_ops/cuda/bert/attention_impl.h

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,12 @@ namespace cuda {
2222

2323
constexpr int kCumulatedSequenceLengthCacheMaxBatchSize = 128;
2424

25+
// longlong4 is deprecated in cuda 13.
26+
// LongLong4 is similar to longlong4_32a, except that it is also visible to the host compiler (longlong4_32a is only visible to nvcc).
27+
typedef struct __align__(32) {
28+
long long int x, y, z, w;
29+
} LongLong4;
30+
2531
// A cache for cumulated sequence length. It will be initialized in the first request, then become read-only after that.
2632
struct CumulatedSequenceLengthCache {
2733
onnxruntime::IAllocatorUniquePtr<void> buffer;
@@ -144,14 +150,14 @@ Status PastPresentBufferShare(int batch_size, int num_heads, int qk_head_size, i
144150
template <typename T>
145151
Status LaunchStridedCopy(
146152
cudaStream_t stream,
147-
const T* in, int4 in_shape, longlong4 in_strides, const int* in_seqlens_offset, // coord (b,n,s,h)
148-
T* out, longlong4 out_strides, const int* out_seqlens_offset, // coord (b,n,s,h)
153+
const T* in, int4 in_shape, LongLong4 in_strides, const int* in_seqlens_offset, // coord (b,n,s,h)
154+
T* out, LongLong4 out_strides, const int* out_seqlens_offset, // coord (b,n,s,h)
149155
int max_threads_per_block);
150156

151157
template <typename T>
152158
Status LaunchStridedCopy(cudaStream_t stream,
153-
const T* in, int4 in_shape, longlong4 in_strides, // coord (b,n,s,h)
154-
T* out, longlong4 out_strides, // coord (b,n,s,h)
159+
const T* in, int4 in_shape, LongLong4 in_strides, // coord (b,n,s,h)
160+
T* out, LongLong4 out_strides, // coord (b,n,s,h)
155161
int max_threads_per_block);
156162

157163
} // namespace cuda

onnxruntime/contrib_ops/cuda/bert/attention_strided_copy.cu

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ namespace contrib {
1111
namespace cuda {
1212

1313
template <typename T>
14-
__global__ void StridedCopy(const T* in, const int H, longlong4 in_strides, // coord (b,n,s,h)
15-
T* out, longlong4 out_strides, // coord (b,n,s,h)
14+
__global__ void StridedCopy(const T* in, const int H, LongLong4 in_strides, // coord (b,n,s,h)
15+
T* out, LongLong4 out_strides, // coord (b,n,s,h)
1616
const int32_t* in_seqlens_offset, const int32_t* out_seqlens_offset) {
1717
const int h = threadIdx.x;
1818
const int n = threadIdx.y;
@@ -30,8 +30,8 @@ __global__ void StridedCopy(const T* in, const int H, longlong4 in_strides, //
3030
}
3131

3232
template <typename T>
33-
__global__ void StridedCopyLarge(const T* in, const int H, longlong4 in_strides, // coord (b,n,s,h)
34-
T* out, longlong4 out_strides, // coord (b,n,s,h)
33+
__global__ void StridedCopyLarge(const T* in, const int H, LongLong4 in_strides, // coord (b,n,s,h)
34+
T* out, LongLong4 out_strides, // coord (b,n,s,h)
3535
const int* in_seqlens_offset, const int* out_seqlens_offset) {
3636
// Use when (H*)*num_heads > 1024
3737
int h = threadIdx.x;
@@ -77,7 +77,7 @@ struct ToByteType<16> {
7777

7878
template <>
7979
struct ToByteType<32> {
80-
using T = ulonglong4;
80+
using T = LongLong4;
8181
};
8282

8383
template <int NumBytes>
@@ -86,8 +86,8 @@ using ToBytes = typename ToByteType<NumBytes>::T;
8686
template <typename T>
8787
Status LaunchStridedCopy(
8888
cudaStream_t stream,
89-
const T* in, int4 in_shape, longlong4 in_strides, const int* in_seqlens_offset, // coord (b,n,s,h)
90-
T* out, longlong4 out_strides, const int* out_seqlens_offset, // coord (b,n,s,h)
89+
const T* in, int4 in_shape, LongLong4 in_strides, const int* in_seqlens_offset, // coord (b,n,s,h)
90+
T* out, LongLong4 out_strides, const int* out_seqlens_offset, // coord (b,n,s,h)
9191
int max_threads_per_block) {
9292
int batch_size = in_shape.x;
9393
int num_heads = in_shape.y;
@@ -157,8 +157,8 @@ Status LaunchStridedCopy(
157157

158158
template <typename T>
159159
Status LaunchStridedCopy(cudaStream_t stream,
160-
const T* in, int4 in_shape, longlong4 in_strides, // coord (b,n,s,h)
161-
T* out, longlong4 out_strides, // coord (b,n,s,h)
160+
const T* in, int4 in_shape, LongLong4 in_strides, // coord (b,n,s,h)
161+
T* out, LongLong4 out_strides, // coord (b,n,s,h)
162162
int max_threads_per_block) {
163163
const int* in_seqlens_offset = nullptr;
164164
const int* out_seqlens_offset = nullptr;
@@ -170,14 +170,14 @@ Status LaunchStridedCopy(cudaStream_t stream,
170170

171171
template Status LaunchStridedCopy<float>(
172172
cudaStream_t stream,
173-
const float* in, int4 in_shape, longlong4 in_strides,
174-
float* out, longlong4 out_strides,
173+
const float* in, int4 in_shape, LongLong4 in_strides,
174+
float* out, LongLong4 out_strides,
175175
int max_threads_per_block);
176176

177177
template Status LaunchStridedCopy<half>(
178178
cudaStream_t stream,
179-
const half* in, int4 in_shape, longlong4 in_strides,
180-
half* out, longlong4 out_strides,
179+
const half* in, int4 in_shape, LongLong4 in_strides,
180+
half* out, LongLong4 out_strides,
181181
int max_threads_per_block);
182182

183183
} // namespace cuda

onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,13 @@
3131

3232
#pragma once
3333

34+
#include "core/providers/cuda/curand_wrapper.h"
35+
3436
#ifdef HAS_PYTORCH
3537
#include <ATen/cuda/CUDAGeneratorImpl.h>
3638
#include <ATen/cuda/CUDAGraphsUtils.cuh>
3739
#endif
3840

39-
#include <curand_kernel.h>
4041
#include <cmath>
4142
#include <cinttypes>
4243
#include <vector>

onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/sm90_mma_interleaved_tma_gmma_rs_warpspecialized_mixed_input.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@
4949
#include "cute/atom/copy_traits_sm90_tma.hpp"
5050
#include "cute/atom/mma_atom.hpp"
5151
#include "cute/numeric/arithmetic_tuple.hpp"
52-
#include "cute/tensor_predicate.hpp"
5352
#include "contrib_ops/cuda/llm/cutlass_extensions/interleaved_numeric_conversion.h"
5453

5554
/////////////////////////////////////////////////////////////////////////////////////////////////

onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
#pragma once
55

66
#include <stdint.h>
7+
#include "core/providers/cuda/curand_wrapper.h"
78
#include <cuda_fp16.h>
8-
#include <curand_kernel.h>
9+
910
#include <cstdio>
1011
#include "contrib_ops/cpu/transformers/generation_shared.h"
1112

onnxruntime/contrib_ops/rocm/bert/attention_impl.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ namespace onnxruntime {
1414
namespace contrib {
1515
namespace rocm {
1616

17+
typedef struct __align__(32) {
18+
long long int x, y, z, w;
19+
} LongLong4;
20+
1721
size_t GetAttentionScratchSize(
1822
size_t element_size,
1923
int batch_size,
@@ -162,14 +166,14 @@ Status ClassifyAttentionMode(AttentionType type,
162166
template <typename T>
163167
Status LaunchStridedCopy(
164168
hipStream_t stream,
165-
const T* in, int4 in_shape, longlong4 in_strides, const int* in_seqlens_offset, // coord (b,n,s,h)
166-
T* out, longlong4 out_strides, const int* out_seqlens_offset, // coord (b,n,s,h)
169+
const T* in, int4 in_shape, LongLong4 in_strides, const int* in_seqlens_offset, // coord (b,n,s,h)
170+
T* out, LongLong4 out_strides, const int* out_seqlens_offset, // coord (b,n,s,h)
167171
int max_threads_per_block);
168172

169173
template <typename T>
170174
Status LaunchStridedCopy(hipStream_t stream,
171-
const T* in, int4 in_shape, longlong4 in_strides, // coord (b,n,s,h)
172-
T* out, longlong4 out_strides, // coord (b,n,s,h)
175+
const T* in, int4 in_shape, LongLong4 in_strides, // coord (b,n,s,h)
176+
T* out, LongLong4 out_strides, // coord (b,n,s,h)
173177
int max_threads_per_block);
174178
} // namespace rocm
175179
} // namespace contrib

0 commit comments

Comments
 (0)