Skip to content

Commit 90c18a8

Browse files
authored
JIT LTO Cagra Search (#1807)
CUDA 13 binary size reduction from 282 MB to 257 MB (-8.86%).

Benchmark:

<img width="1137" height="776" alt="image" src="https://github.com/user-attachments/assets/81b4f25a-c999-423d-a2c9-3e45566904bf" />

Apply updates from CAGRA related PRs:
- [x] #1800
- [ ] #1781
- [x] #1780
- [x] #1771
- [x] #1834
- [x] #1851
- [x] #1831

JIT related PRs:
- [x] #1812

Authors:
- Divye Gala (https://github.com/divyegala)
- Kyle Edwards (https://github.com/KyleFromNVIDIA)
- Dante Gama Dessavre (https://github.com/dantegd)
- Bradley Dice (https://github.com/bdice)

Approvers:
- Kyle Edwards (https://github.com/KyleFromNVIDIA)
- Dante Gama Dessavre (https://github.com/dantegd)

URL: #1807
1 parent ca84936 commit 90c18a8

78 files changed

Lines changed: 7557 additions & 4389 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

conda/recipes/libcuvs/recipe.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -399,13 +399,13 @@ outputs:
399399
- librmm =${{ minor_version }}
400400
- nccl ${{ nccl_version }}
401401
- cuda-cudart-dev
402+
- cuda-nvrtc-dev
402403
- cuda-profiler-api
403404
- libcublas-dev
404405
- libcurand-dev
405406
- libcusolver-dev
406407
- libcusparse-dev
407408
- libnvjitlink-dev
408-
- cuda-nvrtc-dev
409409
run:
410410
- ${{ pin_subpackage("libcuvs-headers", exact=True) }}
411411
- ${{ pin_subpackage("libcuvs", exact=True) }}

cpp/CMakeLists.txt

Lines changed: 178 additions & 43 deletions
Large diffs are not rendered by default.

cpp/cmake/config.json

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@
2727
"kwargs": {
2828
"KERNEL_FILE": 1,
2929
"FATBIN_HEADER_FILE": 1,
30-
"LINK_LIBRARIES": "*"
30+
"LINK_LIBRARIES": "*",
31+
"EXTRA_COMPILE_OPTIONS": "*"
3132
}
3233
},
3334
"process_jit_lto_matrix_entry": {
@@ -41,7 +42,8 @@
4142
"FRAGMENT_TAG_HEADER_FILES": "*",
4243
"OUTPUT_DIRECTORY": 1,
4344
"MATRIX_JSON_ENTRY": 1,
44-
"KERNEL_LINK_LIBRARIES": "*"
45+
"KERNEL_LINK_LIBRARIES": "*",
46+
"KERNEL_EXTRA_COMPILE_OPTIONS": "*"
4547
}
4648
},
4749
"generate_jit_lto_kernels": {
@@ -56,7 +58,8 @@
5658
"FRAGMENT_TAG_FORMAT": 1,
5759
"FRAGMENT_TAG_HEADER_FILES": "*",
5860
"OUTPUT_DIRECTORY": 1,
59-
"KERNEL_LINK_LIBRARIES": "*"
61+
"KERNEL_LINK_LIBRARIES": "*",
62+
"KERNEL_EXTRA_COMPILE_OPTIONS": "*"
6063
}
6164
},
6265
"generate_inst_matrix": {

cpp/cmake/modules/generate_jit_lto_kernels.cmake

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ include(${CMAKE_CURRENT_LIST_DIR}/compute_matrix_product.cmake)
1212
function(add_jit_lto_kernel kernel_target)
1313
set(options)
1414
set(one_value KERNEL_FILE FATBIN_HEADER_FILE)
15-
set(multi_value LINK_LIBRARIES)
15+
set(multi_value LINK_LIBRARIES EXTRA_COMPILE_OPTIONS)
1616

1717
cmake_parse_arguments(_JIT_LTO "${options}" "${one_value}" "${multi_value}" ${ARGN})
1818

@@ -22,6 +22,9 @@ function(add_jit_lto_kernel kernel_target)
2222
# through the LINK_LIBRARIES argument.
2323
target_link_libraries(${kernel_target} PRIVATE ${_JIT_LTO_LINK_LIBRARIES})
2424
target_compile_options(${kernel_target} PRIVATE -Xfatbin=--compress-all --compress-mode=size)
25+
if(_JIT_LTO_EXTRA_COMPILE_OPTIONS)
26+
target_compile_options(${kernel_target} PRIVATE ${_JIT_LTO_EXTRA_COMPILE_OPTIONS})
27+
endif()
2528
set_target_properties(
2629
${kernel_target}
2730
PROPERTIES CUDA_SEPARABLE_COMPILATION ON
@@ -43,7 +46,7 @@ function(process_jit_lto_matrix_entry source_list_var)
4346
set(one_value NAME_FORMAT KERNEL_INPUT_FILE OUTPUT_DIRECTORY FRAGMENT_TAG_FORMAT
4447
MATRIX_JSON_ENTRY
4548
)
46-
set(multi_value KERNEL_LINK_LIBRARIES FRAGMENT_TAG_HEADER_FILES)
49+
set(multi_value KERNEL_LINK_LIBRARIES FRAGMENT_TAG_HEADER_FILES KERNEL_EXTRA_COMPILE_OPTIONS)
4750

4851
cmake_parse_arguments(_JIT_LTO "${options}" "${one_value}" "${multi_value}" ${ARGN})
4952

@@ -71,6 +74,7 @@ function(process_jit_lto_matrix_entry source_list_var)
7174
KERNEL_FILE "${kernel_file}"
7275
FATBIN_HEADER_FILE "${fatbin_header_file}"
7376
LINK_LIBRARIES ${_JIT_LTO_KERNEL_LINK_LIBRARIES}
77+
EXTRA_COMPILE_OPTIONS ${_JIT_LTO_KERNEL_EXTRA_COMPILE_OPTIONS}
7478
)
7579
list(APPEND ${source_list_var} "${fatbin_header_file}" "${fatbin_file}")
7680
set(${source_list_var}
@@ -84,7 +88,7 @@ function(generate_jit_lto_kernels source_list_var)
8488
set(one_value NAME_FORMAT MATRIX_JSON_FILE MATRIX_JSON_STRING KERNEL_INPUT_FILE
8589
FRAGMENT_TAG_FORMAT OUTPUT_DIRECTORY
8690
)
87-
set(multi_value KERNEL_LINK_LIBRARIES FRAGMENT_TAG_HEADER_FILES)
91+
set(multi_value KERNEL_LINK_LIBRARIES FRAGMENT_TAG_HEADER_FILES KERNEL_EXTRA_COMPILE_OPTIONS)
8892

8993
cmake_parse_arguments(_JIT_LTO "${options}" "${one_value}" "${multi_value}" ${ARGN})
9094

@@ -121,6 +125,7 @@ function(generate_jit_lto_kernels source_list_var)
121125
OUTPUT_DIRECTORY "${_JIT_LTO_OUTPUT_DIRECTORY}"
122126
MATRIX_JSON_ENTRY "${matrix_json_entry}"
123127
KERNEL_LINK_LIBRARIES ${_JIT_LTO_KERNEL_LINK_LIBRARIES}
128+
KERNEL_EXTRA_COMPILE_OPTIONS ${_JIT_LTO_KERNEL_EXTRA_COMPILE_OPTIONS}
124129
)
125130
endforeach()
126131

cpp/include/cuvs/detail/jit_lto/AlgorithmLauncher.hpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,24 @@ struct AlgorithmLauncher {
3737
this->call(stream, grid, block, shared_mem, kernel_args);
3838
}
3939

40+
template <typename FuncT, typename... Args>
41+
void dispatch_cooperative(
42+
cudaStream_t stream, dim3 grid, dim3 block, std::size_t shared_mem, Args&&... args)
43+
{
44+
static_assert(
45+
std::is_same_v<FuncT, void(std::remove_reference_t<Args>...)>,
46+
"dispatch_cooperative() argument types do not match the kernel function signature FuncT");
47+
48+
void* kernel_args[] = {const_cast<void*>(static_cast<void const*>(&args))...};
49+
this->call_cooperative(stream, grid, block, shared_mem, kernel_args);
50+
}
51+
4052
cudaKernel_t get_kernel() { return this->kernel; }
4153

4254
private:
4355
void call(cudaStream_t stream, dim3 grid, dim3 block, std::size_t shared_mem, void** args);
56+
void call_cooperative(
57+
cudaStream_t stream, dim3 grid, dim3 block, std::size_t shared_mem, void** args);
4458
cudaKernel_t kernel;
4559
cudaLibrary_t library;
4660
};

cpp/include/cuvs/detail/jit_lto/AlgorithmPlanner.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,12 @@ struct AlgorithmPlanner {
4444
add_fragment(std::make_unique<StaticFatbinFragmentEntry<FragmentTag>>());
4545
}
4646

47+
protected:
48+
/** Extra link-time option strings passed to nvJitLink. Base build()
49+
* always passes "-lto" and "-arch=sm_XX" first; derived planners may append here in their
50+
* constructor body. */
51+
std::vector<std::string> linktime_extra_options;
52+
4753
private:
4854
std::string get_fragments_key() const;
4955
std::shared_ptr<AlgorithmLauncher> build();
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
/*
2+
* SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
#pragma once
7+
8+
#include <cstdint>
9+
10+
namespace cuvs::neighbors::cagra::detail {
11+
12+
struct tag_dist_f {};
13+
struct tag_metric_l2 {};
14+
struct tag_metric_inner_product {};
15+
struct tag_metric_cosine {};
16+
struct tag_metric_hamming {};
17+
struct tag_codebook_none {};
18+
struct tag_codebook_half {};
19+
struct tag_metric_l1 {};
20+
struct tag_norm_noop {};
21+
struct tag_norm_cosine {};
22+
23+
/// Multi-kernel planners that do not link `sample_filter` into the JIT link (e.g.
24+
/// `random_pickup`). Real filters use `cuvs::neighbors::detail::tag_filter_*` on
25+
/// `CagraPlannerBase`.
26+
struct tag_cagra_jit_sample_filter_link_absent {};
27+
28+
template <typename DataTag,
29+
typename IndexTag,
30+
typename DistanceTag,
31+
typename QueryTag,
32+
typename CodebookTag,
33+
uint32_t TeamSize,
34+
uint32_t DatasetBlockDim,
35+
uint32_t PqBits,
36+
uint32_t PqLen>
37+
struct fragment_tag_setup_workspace {};
38+
39+
template <typename DataTag,
40+
typename IndexTag,
41+
typename DistanceTag,
42+
typename QueryTag,
43+
typename CodebookTag,
44+
uint32_t TeamSize,
45+
uint32_t DatasetBlockDim,
46+
uint32_t PqBits,
47+
uint32_t PqLen>
48+
struct fragment_tag_compute_distance {};
49+
50+
template <typename QueryTag, typename DistanceTag, typename MetricTag>
51+
struct fragment_tag_dist_op {};
52+
53+
template <typename DataTag,
54+
typename IndexTag,
55+
typename DistanceTag,
56+
typename QueryTag,
57+
uint32_t TeamSize,
58+
uint32_t DatasetBlockDim,
59+
typename NormTag>
60+
struct fragment_tag_apply_normalization_standard {};
61+
62+
template <typename DataTag,
63+
typename SourceIndexTag,
64+
typename IndexTag,
65+
typename DistanceTag,
66+
bool TopkByBitonicSort,
67+
bool BitonicSortAndMergeMultiWarps>
68+
struct fragment_tag_search_single_cta {};
69+
70+
template <typename DataTag,
71+
typename SourceIndexTag,
72+
typename IndexTag,
73+
typename DistanceTag,
74+
bool TopkByBitonicSort,
75+
bool BitonicSortAndMergeMultiWarps>
76+
struct fragment_tag_search_single_cta_p {};
77+
78+
template <typename DataTag, typename SourceIndexTag, typename IndexTag, typename DistanceTag>
79+
struct fragment_tag_search_multi_cta {};
80+
81+
template <typename DataTag, typename IndexTag, typename DistanceTag>
82+
struct fragment_tag_random_pickup {};
83+
84+
template <typename DataTag, typename IndexTag, typename DistanceTag, typename SourceIndexTag>
85+
struct fragment_tag_compute_distance_to_child_nodes {};
86+
87+
template <typename IndexTag, typename DistanceTag, typename SourceIndexTag>
88+
struct fragment_tag_apply_filter_kernel {};
89+
90+
template <typename BitsetTag, typename SourceIndexTag, typename FilterTag>
91+
struct fragment_tag_sample_filter {};
92+
93+
} // namespace cuvs::neighbors::cagra::detail

cpp/include/cuvs/detail/jit_lto/common_fragments.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,16 @@
77

88
namespace cuvs::neighbors::detail {
99

10+
struct tag_f {};
11+
struct tag_h {};
12+
struct tag_i8 {};
13+
struct tag_u8 {};
1014
struct tag_filter_none {};
1115
struct tag_filter_bitset {};
1216

1317
struct tag_bitset_u32 {};
1418

19+
struct tag_index_u32 {};
1520
struct tag_index_i64 {};
1621

1722
template <typename BitsetTag, typename IndexTag, typename FilterTag>

cpp/include/cuvs/detail/jit_lto/ivf_flat/interleaved_scan_fragments.hpp

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,13 @@
77

88
namespace cuvs::neighbors::ivf_flat::detail {
99

10-
// Tag types for data types
11-
struct tag_f {};
12-
struct tag_h {};
13-
struct tag_i8 {};
14-
struct tag_u8 {};
15-
1610
// Tag types for accumulator types
1711
struct tag_acc_f {};
1812
struct tag_acc_h {};
1913
struct tag_acc_i32 {};
2014
struct tag_acc_u32 {};
2115

22-
// Tag types for distance metrics with full template info
16+
// Tag types for distance metrics
2317
struct tag_metric_euclidean {};
2418
struct tag_metric_inner_product {};
2519
struct tag_metric_custom_udf {};

cpp/src/detail/jit_lto/AlgorithmLauncher.cpp

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ AlgorithmLauncher::AlgorithmLauncher(AlgorithmLauncher&& other) noexcept
2424
AlgorithmLauncher& AlgorithmLauncher::operator=(AlgorithmLauncher&& other) noexcept
2525
{
2626
if (this != &other) {
27-
// Unload current library if it exists
2827
if (library != nullptr) { cudaLibraryUnload(library); }
2928
kernel = other.kernel;
3029
library = other.library;
@@ -47,3 +46,21 @@ void AlgorithmLauncher::call(
4746

4847
RAFT_CUDA_TRY(cudaLaunchKernelExC(&config, kernel, kernel_args));
4948
}
49+
50+
void AlgorithmLauncher::call_cooperative(
51+
cudaStream_t stream, dim3 grid, dim3 block, std::size_t shared_mem, void** kernel_args)
52+
{
53+
cudaLaunchAttribute attribute[1];
54+
attribute[0].id = cudaLaunchAttributeCooperative;
55+
attribute[0].val.cooperative = 1;
56+
57+
cudaLaunchConfig_t config{};
58+
config.gridDim = grid;
59+
config.blockDim = block;
60+
config.stream = stream;
61+
config.dynamicSmemBytes = shared_mem;
62+
config.numAttrs = 1;
63+
config.attrs = attribute;
64+
65+
RAFT_CUDA_TRY(cudaLaunchKernelExC(&config, kernel, kernel_args));
66+
}

0 commit comments

Comments
 (0)