
Commit 567eeed

Authored by jinhongyii, MasterJH5574, CharlieFRuan, and yingchen21
[Runtime][Dist] Implementation of KV cache transfer (#17557)
This PR introduces the KV transfer kernel and the KV cache integration used in prefill-decode disaggregation.

Co-authored-by: Ruihang Lai <[email protected]>
Co-authored-by: Charlie Ruan <[email protected]>
Co-authored-by: Yingcheng Wang <[email protected]>
1 parent 4454f8d commit 567eeed

File tree: 15 files changed, +1743 −40 lines


Diff for: 3rdparty/flashinfer (submodule pointer update; contents not shown)

Diff for: CMakeLists.txt

+3 −1

@@ -478,7 +478,9 @@ if (USE_CUDA AND USE_NVSHMEM)
   if (NOT NVSHMEM_FOUND)
     message(FATAL_ERROR "Cannot find NVSHMEM, USE_NVSHMEM=" ${USE_NVSHMEM})
   endif()
-  tvm_file_glob(GLOB RUNTIME_NVSHMEM_SRCS src/runtime/contrib/nvshmem/*.cc)
+  set(CMAKE_CUDA_SEPARABLE_COMPILATION ON)
+  set(CMAKE_POSITION_INDEPENDENT_CODE ON)
+  tvm_file_glob(GLOB RUNTIME_NVSHMEM_SRCS src/runtime/contrib/nvshmem/*.cc src/runtime/contrib/nvshmem/*.cu)
   list(APPEND RUNTIME_SRCS ${RUNTIME_NVSHMEM_SRCS})
 endif()
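
The lines above make the NVSHMEM runtime sources, now including the new CUDA (*.cu) kernels, part of the TVM runtime whenever both USE_CUDA and USE_NVSHMEM are enabled. A quick way to check whether a given TVM build includes this support is to probe one of the packed functions registered in src/runtime/contrib/nvshmem/init.cc (shown further below). This is a minimal sketch, not part of the commit:

import tvm

# Returns None instead of raising when the runtime was built without NVSHMEM support.
f = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid", allow_missing=True)
print("NVSHMEM runtime available:", f is not None)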

Diff for: docs/how_to/tutorials/optimize_llm.py

+1 −0

@@ -303,6 +303,7 @@ def create_tir_paged_kv_cache(
             rotary_dim=self.head_dim,
             dtype=self.dtype,
             target=target,
+            enable_disaggregation=False,
         )

     def get_default_spec(self):
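
The tutorial keeps disaggregation off, so its behavior is unchanged; a model taking part in prefill-decode disaggregation would pass True instead, which (per the kv_cache.py changes below) is threaded into the paged KV cache constructor. A minimal check of the new parameter from Python, assuming a build that contains this commit and the TIRPagedKVCache class the tutorial uses:

import inspect
from tvm.relax.frontend.nn.llm.kv_cache import TIRPagedKVCache

# The constructor gains an enable_disaggregation argument in this commit;
# the tutorial above simply passes False.
print("enable_disaggregation" in inspect.signature(TIRPagedKVCache.__init__).parameters)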

Diff for: python/tvm/relax/frontend/nn/llm/kv_cache.py

+27 −6

@@ -169,6 +169,7 @@ def __init__( # pylint: disable=too-many-locals
         rope_scaling: Dict[str, Any],
         rope_ext_factors: rx.Expr,
         rotary_dim: int,
+        enable_disaggregation: bool,
         dtype: str,
         target: Target,
         name: str = "paged_kv_cache",
@@ -214,6 +215,8 @@ def __init__( # pylint: disable=too-many-locals
             The RoPE extension factors when "longrope" mode RoPE scaling is enabled.
         rotary_dim : int
             The number of dimensions in the embedding that RoPE is applied to.
+        enable_disaggregation : bool
+            Whether to enable disaggregation in the KV cache.
         """
         if rope_mode == RopeMode.INLINE:
             assert rotary_dim == head_dim, "FlashInfer RoPE does not support partial rotary dim."
@@ -259,6 +262,7 @@ def __init__( # pylint: disable=too-many-locals
             bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask"),
             bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache"),
             rope_ext_factors,
+            rx.PrimValue(enable_disaggregation),
             # fmt: on
             # pylint: enable=line-too-long
         ]
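
On the Relax side, the Python-level flag is forwarded to the runtime cache constructor as one extra argument wrapped in a PrimValue, as the added line shows. A standalone sketch of that wrapper, for illustration only:

from tvm import relax as rx

# rx.PrimValue wraps a plain Python value as a Relax expression; this is how
# enable_disaggregation travels through the call that creates the KV cache.
flag = rx.PrimValue(True)
print(type(flag), flag.value)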
@@ -293,6 +297,7 @@ def __init__( # pylint: disable=too-many-locals
         rope_scaling: Dict[str, Any],
         rope_ext_factors: rx.Expr,
         rotary_dim: int,
+        enable_disaggregation: bool,
         dtype: str,
         target: Target,
         name: str = "paged_kv_cache",
@@ -338,6 +343,8 @@ def __init__( # pylint: disable=too-many-locals
             The RoPE extension factors when "longrope" mode RoPE scaling is enabled.
         rotary_dim : int
             The number of dimensions in the embedding that RoPE is applied to.
+        enable_disaggregation : bool
+            Whether to enable disaggregation in the KV cache.
         target : Target
             The target to build the model to.
         """
@@ -377,6 +384,7 @@ def __init__( # pylint: disable=too-many-locals
             bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask"),
             bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache"),
             rope_ext_factors,
+            rx.PrimValue(enable_disaggregation),
             # fmt: on
             # pylint: enable=line-too-long
         ]
@@ -409,8 +417,9 @@ def tir_kv_cache_transpose_append(
         T.func_attr({"tir.noalias": T.bool(True)})
         ntoken = T.SizeVar("num_tokens_excluding_cache", "int64")
         num_pages = T.int64()
+        pages_elem_offset = T.int64()
         position_map_elem_offset = T.int32()
-        pages = T.match_buffer(var_pages, (num_pages, 2, num_key_value_heads, 16, head_dim), dtype)
+        pages = T.match_buffer(var_pages, (num_pages, 2, num_key_value_heads, 16, head_dim), dtype, elem_offset=pages_elem_offset)
         k_data = T.match_buffer(var_k_data, (ntoken, num_key_value_heads, head_dim), dtype)
         v_data = T.match_buffer(var_v_data, (ntoken, num_key_value_heads, head_dim), dtype)
         position_map = T.match_buffer(
@@ -453,8 +462,9 @@ def tir_kv_cache_debug_get_kv(
         seqlen = T.SizeVar("num_tokens_including_cache", "int64")
         page_size = T.SizeVar("page_size", "int64")
         num_pages = T.int64()
+        pages_elem_offset = T.int64()
         position_map_elem_offset = T.int64()
-        pages = T.match_buffer(var_pages, (num_pages, 2, num_key_value_heads, page_size, head_dim), dtype)
+        pages = T.match_buffer(var_pages, (num_pages, 2, num_key_value_heads, page_size, head_dim), dtype, elem_offset=pages_elem_offset)
         position_map = T.match_buffer(
             var_position_map, (seqlen,), "int32", elem_offset=position_map_elem_offset
         )
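
This and the remaining TIR kernels below get the same treatment: the pages buffer is matched with a symbolic elem_offset rather than an implicit offset of zero, presumably so the kernels stay valid when the KV pages are a view at a nonzero offset into a larger allocation. A self-contained sketch of the pattern, using a hypothetical kernel rather than one from the commit:

import tvm
from tvm.script import tir as T

@T.prim_func
def copy_first_row(var_pages: T.handle, var_out: T.handle):
    T.func_attr({"tir.noalias": T.bool(True)})
    num_pages = T.int64()
    pages_elem_offset = T.int64()
    # elem_offset lets the matched buffer start anywhere inside its backing storage.
    pages = T.match_buffer(
        var_pages, (num_pages, 16), "float16", elem_offset=pages_elem_offset
    )
    out = T.match_buffer(var_out, (16,), "float16")
    for i in range(16):
        out[i] = pages[0, i]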
@@ -594,6 +604,7 @@ def batch_prefill_paged_kv(
         total_len = T.int32(is_size_var=True)
         nnz_pages = T.int32(is_size_var=True)
         max_num_pages = T.int32(is_size_var=True)
+        pages_elem_offset = T.int64(is_size_var=True)
         q_indptr_elem_offset = T.int32(is_size_var=True)
         page_indptr_elem_offset = T.int32(is_size_var=True)
         page_values_elem_offset = T.int32(is_size_var=True)
@@ -603,7 +614,7 @@

         q = T.match_buffer(var_q, (total_len, h_q, d), dtype)
         q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", elem_offset=q_indptr_elem_offset)
-        pages = T.match_buffer(var_pages, (max_num_pages, 2, h_kv, 16, d), dtype)
+        pages = T.match_buffer(var_pages, (max_num_pages, 2, h_kv, 16, d), dtype, elem_offset=pages_elem_offset)
         page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", elem_offset=page_indptr_elem_offset)
         page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", elem_offset=page_values_elem_offset)
         k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", elem_offset=k_rope_pos_offset_elem_offset)
@@ -975,6 +986,7 @@ def batch_decode_paged_kv(
         B = T.int32(is_size_var=True)
         nnz_pages = T.int32(is_size_var=True)
         max_num_pages = T.int32(is_size_var=True)
+        pages_elem_offset = T.int64(is_size_var=True)
         page_indptr_elem_offset = T.int32(is_size_var=True)
         page_values_elem_offset = T.int32(is_size_var=True)
         k_rope_pos_offset_elem_offset = T.int32(is_size_var=True)
@@ -983,7 +995,7 @@

         Q = T.match_buffer(Q_handle, (B, H_qo, D), qkv_dtype)
         pages = T.match_buffer(
-            pages_handle, (max_num_pages, 2, H_kv, 16, D), qkv_dtype
+            pages_handle, (max_num_pages, 2, H_kv, 16, D), qkv_dtype, elem_offset=pages_elem_offset
         )
         page_table_indptr = T.match_buffer(page_table_indptr_handle, (B + 1,), "int32", elem_offset=page_indptr_elem_offset)
         page_table_values = T.match_buffer(page_table_values_handle, (nnz_pages,), "int32", elem_offset=page_values_elem_offset)
@@ -1949,7 +1961,13 @@ def copy_single_page(
     ):
         T.func_attr({"tir.is_scheduled": 1})
         num_pages = T.int32()
-        pages = T.match_buffer(var_pages, (num_pages, 2, num_heads, page_size, head_dim), dtype)
+        pages_elem_offset = T.int64()
+        pages = T.match_buffer(
+            var_pages,
+            (num_pages, 2, num_heads, page_size, head_dim),
+            dtype,
+            elem_offset=pages_elem_offset,
+        )

         for b in T.thread_binding(
             (copy_length * num_heads * head_dim + tx - 1) // tx, thread="blockIdx.x"
@@ -1993,7 +2011,10 @@ def compact_kv_copy(
         total_copy_length = T.int32()
         copy_length_indptr_elem_offset = T.int32()
         copy_src_dst_pos_elem_offset = T.int32()
-        pages = T.match_buffer(var_pages, (num_pages, 2, num_heads, 16, head_dim), dtype)
+        pages_elem_offset = T.int64()
+        pages = T.match_buffer(
+            var_pages, (num_pages, 2, num_heads, 16, head_dim), dtype, elem_offset=pages_elem_offset
+        )
         copy_length_indptr = T.match_buffer(
             var_copy_length_indptr,
             (batch_size + 1,),

Diff for: src/runtime/contrib/nvshmem/init.cc

+54 −4

@@ -18,6 +18,7 @@
  */
 #include <nvshmem.h>
 #include <nvshmemx.h>
+#include <picojson.h>
 #include <tvm/runtime/disco/disco_worker.h>
 #include <tvm/runtime/packed_func.h>
 #include <tvm/runtime/registry.h>
@@ -38,9 +39,14 @@ ShapeTuple InitNVSHMEMUID() {
   return ShapeTuple(uid_64);
 }

-void InitNVSHMEM(ShapeTuple uid_64, int num_workers) {
-  DiscoWorker* worker = DiscoWorker::ThreadLocal();
-  ICHECK(worker != nullptr);
+void InitNVSHMEM(ShapeTuple uid_64, int num_workers, int worker_id_start) {
+  DiscoWorker* worker = ThreadLocalDiscoWorker::Get()->worker;
+  int worker_id;
+  if (worker == nullptr) {
+    worker_id = worker_id_start;
+  } else {
+    worker_id = worker_id_start + worker->worker_id;
+  }
   CHECK_EQ(uid_64.size(), UNIQUEID_PADDING + 1)
       << "ValueError: The length of unique_id must be " << UNIQUEID_PADDING << ", but got "
       << uid_64.size() << ".";
@@ -52,17 +58,61 @@ void InitNVSHMEM(ShapeTuple uid_64, int num_workers) {
   for (int i = 0; i < UNIQUEID_PADDING; ++i) {
     uid.internal[i] = static_cast<char>(uid_64[i + 1]);
   }
-  nvshmemx_set_attr_uniqueid_args(worker->worker_id, num_workers, &uid, &attr);
+  // FIXME: this is a hack to avoid the issue of NVSHMEM using Multi-process-per-GPU to initialize
+  cudaSetDevice(worker_id);
+  nvshmemx_set_attr_uniqueid_args(worker_id, num_workers, &uid, &attr);
   nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr);
   int mype_node = nvshmem_team_my_pe(NVSHMEMX_TEAM_NODE);
   CUDA_CALL(cudaSetDevice(mype_node));
+  if (worker != nullptr) {
+    if (worker->default_device.device_type == DLDeviceType::kDLCPU) {
+      worker->default_device = Device{DLDeviceType::kDLCUDA, mype_node};
+    } else {
+      ICHECK(worker->default_device.device_type == DLDeviceType::kDLCUDA &&
+             worker->default_device.device_id == mype_node)
+          << "The default device of the worker is inconsistent with the device used for NVSHMEM. "
+          << "The default device is " << worker->default_device
+          << ", but the device used for NVSHMEM is " << Device{DLDeviceType::kDLCUDA, mype_node}
+          << ".";
+    }
+  }
   LOG_INFO << "NVSHMEM init finished: mype=" << nvshmem_my_pe() << " "
            << ", npes=" << nvshmem_n_pes();
 }

+void InitNVSHMEMWrapper(String args) {
+  picojson::value v;
+  std::string err = picojson::parse(v, args);
+  if (!err.empty()) {
+    LOG(FATAL) << "JSON parse error: " << err;
+  }
+
+  if (!v.is<picojson::object>()) {
+    LOG(FATAL) << "JSON is not an object";
+  }
+
+  picojson::object& obj = v.get<picojson::object>();
+
+  picojson::array uid_array = obj["uid"].get<picojson::array>();
+  std::vector<int64_t> uid_vector;
+  for (const auto& elem : uid_array) {
+    uid_vector.push_back(elem.get<int64_t>());
+  }
+
+  ShapeTuple uid_64(uid_vector);
+
+  int num_workers = static_cast<int>(obj["npes"].get<int64_t>());
+  int worker_id_start = static_cast<int>(obj["pe_start"].get<int64_t>());
+
+  InitNVSHMEM(uid_64, num_workers, worker_id_start);
+}
+
 TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem_uid").set_body_typed(InitNVSHMEMUID);

 TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem").set_body_typed(InitNVSHMEM);

+TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem_wrapper")
+    .set_body_typed(InitNVSHMEMWrapper);
+
 } // namespace runtime
 } // namespace tvm
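
InitNVSHMEMWrapper takes a single JSON string with the fields "uid", "npes", and "pe_start", so NVSHMEM can be initialized without passing a ShapeTuple from the caller's side. A single-process sketch of driving it from Python (assumes a TVM build with USE_CUDA and USE_NVSHMEM; in a real disaggregated deployment the unique id is generated once and shared with every participating worker, and "npes"/"pe_start" describe that worker group):

import json
import tvm

init_uid = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
init_wrapper = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_wrapper")

# The unique id is returned as a ShapeTuple of integers.
uid = list(init_uid())
# One PE in total, starting at worker id 0, matching the "npes" and "pe_start"
# fields parsed by InitNVSHMEMWrapper.
init_wrapper(json.dumps({"uid": uid, "npes": 1, "pe_start": 0}))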
