@@ -169,6 +169,7 @@ def __init__( # pylint: disable=too-many-locals
169
169
rope_scaling : Dict [str , Any ],
170
170
rope_ext_factors : rx .Expr ,
171
171
rotary_dim : int ,
172
+ enable_disaggregation : bool ,
172
173
dtype : str ,
173
174
target : Target ,
174
175
name : str = "paged_kv_cache" ,
@@ -214,6 +215,8 @@ def __init__( # pylint: disable=too-many-locals
214
215
The RoPE extension factors when "longrope" mode RoPE scaling is enabled.
215
216
rotary_dim : int
216
217
The number of dimensions in the embedding that RoPE is applied to.
218
+ enable_disaggregation : bool
219
+ Whether to enable disaggregation in the KV cache.
217
220
"""
218
221
if rope_mode == RopeMode .INLINE :
219
222
assert rotary_dim == head_dim , "FlashInfer RoPE does not support partial rotary dim."
@@ -259,6 +262,7 @@ def __init__( # pylint: disable=too-many-locals
259
262
bb .add_func (tree_attn (num_key_value_heads , num_attention_heads , head_dim , dtype , rope_scaling , target ), "tir_attention_prefill_with_tree_mask" ),
260
263
bb .add_func (tree_attn_with_paged_kv_cache (num_key_value_heads , num_attention_heads , head_dim , dtype , rope_scaling , target ), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache" ),
261
264
rope_ext_factors ,
265
+ rx .PrimValue (enable_disaggregation ),
262
266
# fmt: on
263
267
# pylint: enable=line-too-long
264
268
]
@@ -293,6 +297,7 @@ def __init__( # pylint: disable=too-many-locals
293
297
rope_scaling : Dict [str , Any ],
294
298
rope_ext_factors : rx .Expr ,
295
299
rotary_dim : int ,
300
+ enable_disaggregation : bool ,
296
301
dtype : str ,
297
302
target : Target ,
298
303
name : str = "paged_kv_cache" ,
@@ -338,6 +343,8 @@ def __init__( # pylint: disable=too-many-locals
338
343
The RoPE extension factors when "longrope" mode RoPE scaling is enabled.
339
344
rotary_dim : int
340
345
The number of dimensions in the embedding that RoPE is applied to.
346
+ enable_disaggregation : bool
347
+ Whether to enable disaggregation in the KV cache.
341
348
target : Target
342
349
The target to build the model to.
343
350
"""
@@ -377,6 +384,7 @@ def __init__( # pylint: disable=too-many-locals
377
384
bb .add_func (tree_attn (num_key_value_heads , num_attention_heads , head_dim , dtype , rope_scaling , target ), "tir_attention_prefill_with_tree_mask" ),
378
385
bb .add_func (tree_attn_with_paged_kv_cache (num_key_value_heads , num_attention_heads , head_dim , dtype , rope_scaling , target ), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache" ),
379
386
rope_ext_factors ,
387
+ rx .PrimValue (enable_disaggregation ),
380
388
# fmt: on
381
389
# pylint: enable=line-too-long
382
390
]
@@ -409,8 +417,9 @@ def tir_kv_cache_transpose_append(
409
417
T .func_attr ({"tir.noalias" : T .bool (True )})
410
418
ntoken = T .SizeVar ("num_tokens_excluding_cache" , "int64" )
411
419
num_pages = T .int64 ()
420
+ pages_elem_offset = T .int64 ()
412
421
position_map_elem_offset = T .int32 ()
413
- pages = T .match_buffer (var_pages , (num_pages , 2 , num_key_value_heads , 16 , head_dim ), dtype )
422
+ pages = T .match_buffer (var_pages , (num_pages , 2 , num_key_value_heads , 16 , head_dim ), dtype , elem_offset = pages_elem_offset )
414
423
k_data = T .match_buffer (var_k_data , (ntoken , num_key_value_heads , head_dim ), dtype )
415
424
v_data = T .match_buffer (var_v_data , (ntoken , num_key_value_heads , head_dim ), dtype )
416
425
position_map = T .match_buffer (
@@ -453,8 +462,9 @@ def tir_kv_cache_debug_get_kv(
453
462
seqlen = T .SizeVar ("num_tokens_including_cache" , "int64" )
454
463
page_size = T .SizeVar ("page_size" , "int64" )
455
464
num_pages = T .int64 ()
465
+ pages_elem_offset = T .int64 ()
456
466
position_map_elem_offset = T .int64 ()
457
- pages = T .match_buffer (var_pages , (num_pages , 2 , num_key_value_heads , page_size , head_dim ), dtype )
467
+ pages = T .match_buffer (var_pages , (num_pages , 2 , num_key_value_heads , page_size , head_dim ), dtype , elem_offset = pages_elem_offset )
458
468
position_map = T .match_buffer (
459
469
var_position_map , (seqlen ,), "int32" , elem_offset = position_map_elem_offset
460
470
)
@@ -594,6 +604,7 @@ def batch_prefill_paged_kv(
594
604
total_len = T .int32 (is_size_var = True )
595
605
nnz_pages = T .int32 (is_size_var = True )
596
606
max_num_pages = T .int32 (is_size_var = True )
607
+ pages_elem_offset = T .int64 (is_size_var = True )
597
608
q_indptr_elem_offset = T .int32 (is_size_var = True )
598
609
page_indptr_elem_offset = T .int32 (is_size_var = True )
599
610
page_values_elem_offset = T .int32 (is_size_var = True )
@@ -603,7 +614,7 @@ def batch_prefill_paged_kv(
603
614
604
615
q = T .match_buffer (var_q , (total_len , h_q , d ), dtype )
605
616
q_indptr = T .match_buffer (var_q_indptr , (batch_size + 1 ,), "int32" , elem_offset = q_indptr_elem_offset )
606
- pages = T .match_buffer (var_pages , (max_num_pages , 2 , h_kv , 16 , d ), dtype )
617
+ pages = T .match_buffer (var_pages , (max_num_pages , 2 , h_kv , 16 , d ), dtype , elem_offset = pages_elem_offset )
607
618
page_indptr = T .match_buffer (var_page_indptr , (batch_size + 1 ,), "int32" , elem_offset = page_indptr_elem_offset )
608
619
page_values = T .match_buffer (var_page_values , (nnz_pages ,), "int32" , elem_offset = page_values_elem_offset )
609
620
k_rope_pos_offset = T .match_buffer (var_k_rope_pos_offset , (batch_size ,), "int32" , elem_offset = k_rope_pos_offset_elem_offset )
@@ -975,6 +986,7 @@ def batch_decode_paged_kv(
975
986
B = T .int32 (is_size_var = True )
976
987
nnz_pages = T .int32 (is_size_var = True )
977
988
max_num_pages = T .int32 (is_size_var = True )
989
+ pages_elem_offset = T .int64 (is_size_var = True )
978
990
page_indptr_elem_offset = T .int32 (is_size_var = True )
979
991
page_values_elem_offset = T .int32 (is_size_var = True )
980
992
k_rope_pos_offset_elem_offset = T .int32 (is_size_var = True )
@@ -983,7 +995,7 @@ def batch_decode_paged_kv(
983
995
984
996
Q = T .match_buffer (Q_handle , (B , H_qo , D ), qkv_dtype )
985
997
pages = T .match_buffer (
986
- pages_handle , (max_num_pages , 2 , H_kv , 16 , D ), qkv_dtype
998
+ pages_handle , (max_num_pages , 2 , H_kv , 16 , D ), qkv_dtype , elem_offset = pages_elem_offset
987
999
)
988
1000
page_table_indptr = T .match_buffer (page_table_indptr_handle , (B + 1 ,), "int32" , elem_offset = page_indptr_elem_offset )
989
1001
page_table_values = T .match_buffer (page_table_values_handle , (nnz_pages ,), "int32" , elem_offset = page_values_elem_offset )
@@ -1949,7 +1961,13 @@ def copy_single_page(
1949
1961
):
1950
1962
T .func_attr ({"tir.is_scheduled" : 1 })
1951
1963
num_pages = T .int32 ()
1952
- pages = T .match_buffer (var_pages , (num_pages , 2 , num_heads , page_size , head_dim ), dtype )
1964
+ pages_elem_offset = T .int64 ()
1965
+ pages = T .match_buffer (
1966
+ var_pages ,
1967
+ (num_pages , 2 , num_heads , page_size , head_dim ),
1968
+ dtype ,
1969
+ elem_offset = pages_elem_offset ,
1970
+ )
1953
1971
1954
1972
for b in T .thread_binding (
1955
1973
(copy_length * num_heads * head_dim + tx - 1 ) // tx , thread = "blockIdx.x"
@@ -1993,7 +2011,10 @@ def compact_kv_copy(
1993
2011
total_copy_length = T .int32 ()
1994
2012
copy_length_indptr_elem_offset = T .int32 ()
1995
2013
copy_src_dst_pos_elem_offset = T .int32 ()
1996
- pages = T .match_buffer (var_pages , (num_pages , 2 , num_heads , 16 , head_dim ), dtype )
2014
+ pages_elem_offset = T .int64 ()
2015
+ pages = T .match_buffer (
2016
+ var_pages , (num_pages , 2 , num_heads , 16 , head_dim ), dtype , elem_offset = pages_elem_offset
2017
+ )
1997
2018
copy_length_indptr = T .match_buffer (
1998
2019
var_copy_length_indptr ,
1999
2020
(batch_size + 1 ,),
0 commit comments