@@ -290,6 +290,8 @@ class DSAtrtllmAttentionMetadata(TrtllmAttentionMetadata):
     indexer_max_chunk_size: int
     # Topk for sparse MLA
     sparse_mla_topk: int
+    # Max number of draft tokens
+    max_draft_tokens: int = 0

     def __init__(self, *args, **kwargs):
         self.num_sms = tensorrt_llm.deep_gemm.get_num_sms()
@@ -432,6 +434,64 @@ def __post_init__(self):
             dtype=torch.int32,
             capture_graph=capture_graph,
         )
+        self.create_expanded_buffers(capture_graph=capture_graph)
+
+    # TODO: remove these expanded buffers when fp8_paged_mqa_logits supports MTP > 1.
+    def create_expanded_buffers(self, capture_graph=False):
+        self.kv_lens_expanded_cuda = self.get_empty(
+            self.cuda_graph_buffers,
+            (self.max_num_sequences * (1 + self.max_draft_tokens), ),
+            cache_name="kv_lens_expanded_cuda",
+            dtype=torch.int32,
+            capture_graph=capture_graph,
+        )
+        self.kv_lens_expanded_host = torch.zeros_like(
+            self.kv_lens_expanded_cuda,
+            device='cpu',
+            pin_memory=True,
+        )
+        self.block_table_expanded = self.get_empty(
+            self.cuda_graph_buffers,
+            [
+                self.max_num_sequences * (1 + self.max_draft_tokens),
+                self.kv_cache_manager.max_blocks_per_seq
+            ],
+            cache_name="block_table_expanded",
+            dtype=torch.int32,
+            capture_graph=capture_graph,
+        )
+        self.host_block_table_expanded = torch.zeros_like(
+            self.block_table_expanded,
+            device='cpu',
+            pin_memory=True,
+        )
+        self.scheduler_metadata_buffer_expanded = self.get_empty(
+            self.cuda_graph_buffers,
+            (self.num_sms + 1, 2),
+            cache_name="scheduler_metadata_buffer_expanded",
+            dtype=torch.int32,
+            capture_graph=capture_graph,
+        )
+
+    # This override only recreates the expanded buffers when max_draft_tokens changes.
+    # TODO: remove this function when fp8_paged_mqa_logits can support MTP > 1.
+    def update_spec_dec_param(
+        self,
+        is_spec_decoding_enabled,
+        is_spec_dec_tree,
+        is_spec_dec_dynamic_tree,
+        max_draft_tokens,
+        spec_decoding_tensor: Optional['SpecDecodingTensor'] = None,
+    ):
+        super().update_spec_dec_param(is_spec_decoding_enabled,
+                                      is_spec_dec_tree,
+                                      is_spec_dec_dynamic_tree,
+                                      max_draft_tokens, spec_decoding_tensor)
+        self.max_draft_tokens = max_draft_tokens
+        init_shape = self.kv_lens_expanded_host.shape[0]
+        if self.max_num_sequences * (1 + self.max_draft_tokens) != init_shape:
+            capture_graph = torch.cuda.is_current_stream_capturing()
+            self.create_expanded_buffers(capture_graph=capture_graph)

     def prepare(self):
         super().prepare()
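The expanded buffers are sized for one row per (target + draft) token of every sequence, and `update_spec_dec_param` only rebuilds them when that row count actually changes. A minimal sketch of the sizing rule, with an assumed `max_num_sequences` of 8 (not a value taken from the patch):

```python
# Sizing sketch only; max_num_sequences = 8 is an assumed example value.
max_num_sequences = 8

def expanded_rows(max_draft_tokens: int) -> int:
    # One kv_len / block-table row per (target + draft) token of every sequence.
    return max_num_sequences * (1 + max_draft_tokens)

# update_spec_dec_param recreates the buffers only when this row count changes,
# e.g. switching from MTP 1 (16 rows) to MTP 3 (32 rows).
assert expanded_rows(1) == 16
assert expanded_rows(3) == 32
```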
@@ -535,6 +595,41 @@ def prepare(self):
         else:
             self.max_gen_seq_len = 0

+        # fp8_paged_mqa_logits only supports seq_len == 1 or 2, so it cannot handle MTP > 1.
+        # To work around this, when MTP > 1 we flatten the q tensor and expand kv_lens and
+        # the block table so that fp8_paged_mqa_logits can still be used.
+        # TODO: remove this when fp8_paged_mqa_logits supports MTP > 1.
+        if self.max_draft_tokens > 1:
+            # Expand kv_lens_cuda (only generation)
+            num_tokens = self.num_generations * (1 + self.max_draft_tokens)
+            gen_kv_lens = kv_lens[self.num_contexts:self.num_seqs]
+            gen_kv_lens_expanded = torch.stack([gen_kv_lens] *
+                                               (1 + self.max_draft_tokens),
+                                               dim=0)
+            gen_kv_lens_expanded = gen_kv_lens_expanded.transpose(
+                0, 1).contiguous().flatten()
+            self.kv_lens_expanded_host[:num_tokens].copy_(gen_kv_lens_expanded)
+            self.kv_lens_expanded_cuda[:num_tokens].copy_(
+                self.kv_lens_expanded_host[:num_tokens], non_blocking=True)
+
+            # Expand indexer_k_cache_block_offsets (only generation)
+            if self.kv_cache_manager is not None:
+                block_ids = self.kv_cache_manager.get_batch_cache_indices(
+                    self.request_ids)
+                gen_block_ids = block_ids[self.num_contexts:]
+                if len(gen_block_ids) > 0:
+                    # Find max length and create padded tensor
+                    max_len = max(len(bid) for bid in gen_block_ids)
+                    gen_block_tensor = self.host_indexer_k_cache_block_offsets[
+                        self.num_contexts:self.num_seqs, :max_len]
+                    expanded_blocks = gen_block_tensor.repeat_interleave(
+                        1 + self.max_draft_tokens, dim=0)
+                    self.host_block_table_expanded[:num_tokens, :max_len].copy_(
+                        expanded_blocks, non_blocking=True)
+                    self.block_table_expanded[:num_tokens].copy_(
+                        self.host_block_table_expanded[:num_tokens],
+                        non_blocking=True)
+
         # Prepare metadata for indexer
         Indexer.prepare(metadata=self)

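To make the expansion in `prepare()` concrete, here is a small standalone sketch (not part of the patch) of the layout it produces, assuming two generation requests with cache lengths [7, 9] and `max_draft_tokens = 2`:

```python
import torch

gen_kv_lens = torch.tensor([7, 9], dtype=torch.int32)
next_n = 3  # 1 + max_draft_tokens

# stack -> transpose -> flatten repeats each request's length once per query token,
# keeping per-request entries contiguous: [7, 7, 7, 9, 9, 9].
expanded = torch.stack([gen_kv_lens] * next_n,
                       dim=0).transpose(0, 1).contiguous().flatten()
assert expanded.tolist() == [7, 7, 7, 9, 9, 9]

# The block table uses repeat_interleave, which yields the same row-per-token layout.
assert torch.equal(expanded, gen_kv_lens.repeat_interleave(next_n))
```

Each of the `num_generations * (1 + max_draft_tokens)` rows therefore describes one query token that reads the full cache of its request.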
@@ -799,12 +894,22 @@ def prepare(metadata: DSAtrtllmAttentionMetadata):
         if num_generations > 0:
             # Prepare schedule metadata for fp8_paged_mqa_logits
             # This is a preprocessing step that computes scheduling information for the kernel
-            gen_seq_lens = metadata.kv_lens_cuda_runtime[
-                num_contexts:num_contexts + num_generations]
-            scheduler_metadata_buffer = get_paged_mqa_logits_metadata(
-                gen_seq_lens, tokens_per_block, metadata.num_sms)
-            metadata.scheduler_metadata_buffer.copy_(scheduler_metadata_buffer,
-                                                     non_blocking=True)
+            if metadata.max_draft_tokens <= 1:
+                gen_seq_lens = metadata.kv_lens_cuda_runtime[
+                    num_contexts:num_contexts + num_generations]
+                scheduler_metadata_buffer = get_paged_mqa_logits_metadata(
+                    gen_seq_lens, tokens_per_block, metadata.num_sms)
+                metadata.scheduler_metadata_buffer.copy_(
+                    scheduler_metadata_buffer, non_blocking=True)
+            else:
+                # Expand schedule metadata buffer (only generation)
+                num_tokens = metadata.num_generations * (
+                    1 + metadata.max_draft_tokens)
+                kv_lens_expanded = metadata.kv_lens_expanded_cuda[:num_tokens]
+                scheduler_metadata_buffer_expanded = get_paged_mqa_logits_metadata(
+                    kv_lens_expanded, tokens_per_block, metadata.num_sms)
+                metadata.scheduler_metadata_buffer_expanded.copy_(
+                    scheduler_metadata_buffer_expanded, non_blocking=True)

         # Compute slot_mapping for all requests (both context and generation)
         # This maps each token to its flat cache position for vectorized KV cache updates
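Both branches follow the same pattern: compute fresh scheduling metadata, then `copy_` it into a preallocated buffer so CUDA graph replay always sees a stable destination address. A minimal sketch of that pattern with a stand-in helper (the real `get_paged_mqa_logits_metadata` is the helper already used above; only the buffer shape `(num_sms + 1, 2)` is taken from the patch):

```python
import torch

num_sms = 4
# Preallocated once (as in create_expanded_buffers); reused across steps.
scheduler_metadata_buffer_expanded = torch.zeros((num_sms + 1, 2), dtype=torch.int32)

def fake_get_paged_mqa_logits_metadata(kv_lens, tokens_per_block, num_sms):
    # Stand-in for the real helper; only the shape/dtype contract matters here.
    return torch.ones((num_sms + 1, 2), dtype=torch.int32)

kv_lens_expanded = torch.tensor([7, 7, 7, 7, 9, 9, 9, 9], dtype=torch.int32)
fresh = fake_get_paged_mqa_logits_metadata(kv_lens_expanded, 64, num_sms)
scheduler_metadata_buffer_expanded.copy_(fresh, non_blocking=True)
```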
@@ -1053,9 +1158,24 @@ def sparse_attn_indexer(
            # Reshape q for decode phase: [num_gen_tokens, ...] -> [batch_size, next_n, ...]
            q_decode = q_fp8[num_ctx_tokens:num_ctx_tokens + num_gen_tokens,
                             ...]
-           q_decode = q_decode.view(num_generations, -1, *q_fp8.shape[1:])
-           batch_size = q_decode.shape[0]
-           next_n = q_decode.shape[1]
+           batch_size = num_generations
+           next_n = num_gen_tokens // num_generations
+           # fp8_paged_mqa_logits cannot support next_n > 2, so in that case we flatten
+           # the q_decode tensor and use the expanded metadata prepared in prepare().
+           if next_n <= 2:
+               q_decode = q_decode.view(num_generations, -1, *q_fp8.shape[1:])
+               context_lens = metadata.kv_lens_cuda_runtime[
+                   num_contexts:num_contexts + num_generations]
+               block_table = metadata.indexer_k_cache_block_offsets[
+                   num_contexts:num_contexts + num_generations]
+               scheduler_metadata_buffer = metadata.scheduler_metadata_buffer
+           else:
+               q_decode = q_decode.view(-1, 1, *q_fp8.shape[1:])
+               num_tokens = num_generations * (1 + metadata.max_draft_tokens)
+               context_lens = metadata.kv_lens_expanded_cuda[:num_tokens]
+               block_table = metadata.block_table_expanded[:num_tokens]
+               scheduler_metadata_buffer = metadata.scheduler_metadata_buffer_expanded
+
            assert num_gen_tokens == batch_size * next_n
            weights_decode = weights[num_ctx_tokens:num_ctx_tokens +
                                     num_gen_tokens, ...]
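The two `q_decode` layouts can be illustrated with assumed sizes (2 generation requests, 4 query tokens each, 8 heads, head_dim 16; none of these numbers come from the patch):

```python
import torch

num_generations, next_n, heads, dim = 2, 4, 8, 16
q = torch.randn(num_generations * next_n, heads, dim)

if next_n <= 2:
    # One row per request holding its next_n query tokens.
    q_decode = q.view(num_generations, -1, heads, dim)
else:
    # Flattened: every query token becomes its own single-token "sequence", which is
    # why kv_lens / block_table must also be expanded to one entry per token.
    q_decode = q.view(-1, 1, heads, dim)

assert q_decode.shape == (num_generations * next_n, 1, heads, dim)
```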
@@ -1064,18 +1184,11 @@
            # [num_blocks, tokens_per_block, 1, head_dim + scale_size]
            k_cache = metadata.kv_cache_manager.get_indexer_k_cache_buffers(
                self.layer_idx)
-           logits_decode = fp8_paged_mqa_logits(
-               q_decode,
-               k_cache,
-               weights_decode,
-               metadata.kv_lens_cuda_runtime[
-                   num_contexts:num_contexts +
-                   num_generations],  # context_lens prepared in prepare()
-               metadata.indexer_k_cache_block_offsets[
-                   num_contexts:num_contexts +
-                   num_generations],  # Only pass generation request block tables
-               metadata.scheduler_metadata_buffer,
-               max_seq_len)
+           logits_decode = fp8_paged_mqa_logits(q_decode, k_cache,
+                                                weights_decode, context_lens,
+                                                block_table,
+                                                scheduler_metadata_buffer,
+                                                max_seq_len)

            if use_custom_topk:
                # Kernel expects kv_lens (total cache length), not seq_lens (new tokens)