
Commit 780d4f9

[None][feat] Add MTP>1 support for DS-v3.2 (#9045)
Signed-off-by: Fanrong Li <[email protected]>
1 parent 53491ff commit 780d4f9

File tree: 3 files changed (+165, -32 lines)


tensorrt_llm/_torch/attention_backend/sparse/dsa.py

Lines changed: 134 additions & 21 deletions

@@ -290,6 +290,8 @@ class DSAtrtllmAttentionMetadata(TrtllmAttentionMetadata):
     indexer_max_chunk_size: int
     # Topk for sparse MLA
     sparse_mla_topk: int
+    # max number of draft tokens
+    max_draft_tokens: int = 0
 
     def __init__(self, *args, **kwargs):
         self.num_sms = tensorrt_llm.deep_gemm.get_num_sms()
@@ -432,6 +434,64 @@ def __post_init__(self):
             dtype=torch.int32,
             capture_graph=capture_graph,
         )
+        self.create_expanded_buffers(capture_graph=capture_graph)
+
+    # TODO: remove these expanded buffers when fp8_paged_mqa_logits supports MTP > 1.
+    def create_expanded_buffers(self, capture_graph=False):
+        self.kv_lens_expanded_cuda = self.get_empty(
+            self.cuda_graph_buffers,
+            (self.max_num_sequences * (1 + self.max_draft_tokens), ),
+            cache_name="kv_lens_expanded_cuda",
+            dtype=torch.int32,
+            capture_graph=capture_graph,
+        )
+        self.kv_lens_expanded_host = torch.zeros_like(
+            self.kv_lens_expanded_cuda,
+            device='cpu',
+            pin_memory=True,
+        )
+        self.block_table_expanded = self.get_empty(
+            self.cuda_graph_buffers,
+            [
+                self.max_num_sequences * (1 + self.max_draft_tokens),
+                self.kv_cache_manager.max_blocks_per_seq
+            ],
+            cache_name="block_table_expanded",
+            dtype=torch.int32,
+            capture_graph=capture_graph,
+        )
+        self.host_block_table_expanded = torch.zeros_like(
+            self.block_table_expanded,
+            device='cpu',
+            pin_memory=True,
+        )
+        self.scheduler_metadata_buffer_expanded = self.get_empty(
+            self.cuda_graph_buffers,
+            (self.num_sms + 1, 2),
+            cache_name="scheduler_metadata_buffer_expanded",
+            dtype=torch.int32,
+            capture_graph=capture_graph,
+        )
+
+    # This override is only used to recreate the expanded buffers when max_draft_tokens changes.
+    # TODO: remove this function when fp8_paged_mqa_logits supports MTP > 1.
+    def update_spec_dec_param(
+        self,
+        is_spec_decoding_enabled,
+        is_spec_dec_tree,
+        is_spec_dec_dynamic_tree,
+        max_draft_tokens,
+        spec_decoding_tensor: Optional['SpecDecodingTensor'] = None,
+    ):
+        super().update_spec_dec_param(is_spec_decoding_enabled,
+                                      is_spec_dec_tree,
+                                      is_spec_dec_dynamic_tree,
+                                      max_draft_tokens, spec_decoding_tensor)
+        self.max_draft_tokens = max_draft_tokens
+        init_shape = self.kv_lens_expanded_host.shape[0]
+        if self.max_num_sequences * (1 + self.max_draft_tokens) != init_shape:
+            capture_graph = torch.cuda.is_current_stream_capturing()
+            self.create_expanded_buffers(capture_graph=capture_graph)
 
     def prepare(self):
         super().prepare()
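
The expanded buffers above are sized for the fully flattened generation batch: one row per (request, query token) pair. A minimal standalone sketch of that sizing with purely illustrative values (the names mirror the fields used above; this snippet is not part of the commit):

import torch

# Illustrative values only.
max_num_sequences = 8      # generation requests that may be in flight
max_draft_tokens = 3       # MTP depth: each request carries 1 + 3 query tokens
max_blocks_per_seq = 64    # paged-KV blocks per sequence
num_sms = 132              # SM count passed to the scheduling-metadata helper

# One row per (request, query token) pair once the q tensor is flattened.
flat_rows = max_num_sequences * (1 + max_draft_tokens)  # 32

kv_lens_expanded = torch.zeros(flat_rows, dtype=torch.int32)
block_table_expanded = torch.zeros(flat_rows, max_blocks_per_seq, dtype=torch.int32)
scheduler_metadata_expanded = torch.zeros(num_sms + 1, 2, dtype=torch.int32)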
@@ -535,6 +595,41 @@ def prepare(self):
         else:
             self.max_gen_seq_len = 0
 
+        # fp8_paged_mqa_logits only supports seq_len == 1 or 2, so it cannot handle MTP > 1
+        # directly. To work around this, when MTP > 1 we flatten the q tensor and expand the
+        # kv_lens and block_table so that fp8_paged_mqa_logits can still be used.
+        # TODO: remove this when fp8_paged_mqa_logits supports MTP > 1.
+        if self.max_draft_tokens > 1:
+            # Expand kv_lens_cuda (generation requests only)
+            num_tokens = self.num_generations * (1 + self.max_draft_tokens)
+            gen_kv_lens = kv_lens[self.num_contexts:self.num_seqs]
+            gen_kv_lens_expanded = torch.stack([gen_kv_lens] *
+                                               (1 + self.max_draft_tokens),
+                                               dim=0)
+            gen_kv_lens_expanded = gen_kv_lens_expanded.transpose(
+                0, 1).contiguous().flatten()
+            self.kv_lens_expanded_host[:num_tokens].copy_(gen_kv_lens_expanded)
+            self.kv_lens_expanded_cuda[:num_tokens].copy_(
+                self.kv_lens_expanded_host[:num_tokens], non_blocking=True)
+
+            # Expand indexer_k_cache_block_offsets (generation requests only)
+            if self.kv_cache_manager is not None:
+                block_ids = self.kv_cache_manager.get_batch_cache_indices(
+                    self.request_ids)
+                gen_block_ids = block_ids[self.num_contexts:]
+                if len(gen_block_ids) > 0:
+                    # Find the max length and create a padded tensor
+                    max_len = max(len(bid) for bid in gen_block_ids)
+                    gen_block_tensor = self.host_indexer_k_cache_block_offsets[
+                        self.num_contexts:self.num_seqs, :max_len]
+                    expanded_blocks = gen_block_tensor.repeat_interleave(
+                        1 + self.max_draft_tokens, dim=0)
+                    self.host_block_table_expanded[:num_tokens, :max_len].copy_(
+                        expanded_blocks, non_blocking=True)
+                    self.block_table_expanded[:num_tokens].copy_(
+                        self.host_block_table_expanded[:num_tokens],
+                        non_blocking=True)
+
         # Prepare metadata for indexer
         Indexer.prepare(metadata=self)
 
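To make the expansion above concrete: each generation request's kv_len and block-table row is duplicated once per query token so that the flattened q rows line up with per-row metadata. A small standalone sketch with made-up values (two requests, max_draft_tokens = 3); not part of the commit:

import torch

max_draft_tokens = 3                                    # hypothetical MTP depth
gen_kv_lens = torch.tensor([100, 57], dtype=torch.int32)
gen_block_table = torch.tensor([[7, 8, 0], [3, 0, 0]], dtype=torch.int32)

# Same stack/transpose/flatten pattern as in prepare(): repeat each kv_len
# (1 + max_draft_tokens) times, keeping all rows of a request adjacent.
kv_lens_expanded = torch.stack([gen_kv_lens] * (1 + max_draft_tokens),
                               dim=0).transpose(0, 1).contiguous().flatten()
# -> tensor([100, 100, 100, 100, 57, 57, 57, 57], dtype=torch.int32)

# repeat_interleave yields the matching row order for the block table.
block_table_expanded = gen_block_table.repeat_interleave(1 + max_draft_tokens, dim=0)
# -> shape [8, 3]; rows 0-3 equal [7, 8, 0], rows 4-7 equal [3, 0, 0]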
@@ -799,12 +894,22 @@ def prepare(metadata: DSAtrtllmAttentionMetadata):
         if num_generations > 0:
             # Prepare schedule metadata for fp8_paged_mqa_logits
             # This is a preprocessing step that computes scheduling information for the kernel
-            gen_seq_lens = metadata.kv_lens_cuda_runtime[
-                num_contexts:num_contexts + num_generations]
-            scheduler_metadata_buffer = get_paged_mqa_logits_metadata(
-                gen_seq_lens, tokens_per_block, metadata.num_sms)
-            metadata.scheduler_metadata_buffer.copy_(scheduler_metadata_buffer,
-                                                     non_blocking=True)
+            if metadata.max_draft_tokens <= 1:
+                gen_seq_lens = metadata.kv_lens_cuda_runtime[
+                    num_contexts:num_contexts + num_generations]
+                scheduler_metadata_buffer = get_paged_mqa_logits_metadata(
+                    gen_seq_lens, tokens_per_block, metadata.num_sms)
+                metadata.scheduler_metadata_buffer.copy_(
+                    scheduler_metadata_buffer, non_blocking=True)
+            else:
+                # Expand schedule metadata buffer (only generation)
+                num_tokens = metadata.num_generations * (
+                    1 + metadata.max_draft_tokens)
+                kv_lens_expanded = metadata.kv_lens_expanded_cuda[:num_tokens]
+                scheduler_metadata_buffer_expanded = get_paged_mqa_logits_metadata(
+                    kv_lens_expanded, tokens_per_block, metadata.num_sms)
+                metadata.scheduler_metadata_buffer_expanded.copy_(
+                    scheduler_metadata_buffer_expanded, non_blocking=True)
 
         # Compute slot_mapping for all requests (both context and generation)
         # This maps each token to its flat cache position for vectorized KV cache updates
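
Both branches call the same scheduling-metadata helper; with MTP > 1 it simply sees one kv_len entry per flattened query token instead of one per request. A standalone sketch of that dispatch, with the deep_gemm helper replaced by a placeholder stub so it runs on its own (values and stub are illustrative, not from the commit):

import torch

def get_paged_mqa_logits_metadata(context_lens, tokens_per_block, num_sms):
    # Placeholder stub standing in for the deep_gemm helper used above.
    return torch.zeros(num_sms + 1, 2, dtype=torch.int32)

num_generations, max_draft_tokens, tokens_per_block, num_sms = 2, 3, 64, 132
kv_lens = torch.tensor([100, 57], dtype=torch.int32)   # one entry per generation request

if max_draft_tokens <= 1:
    seq_lens = kv_lens                                  # one row per request
else:
    # One row per flattened query token, mirroring kv_lens_expanded_cuda.
    num_tokens = num_generations * (1 + max_draft_tokens)
    seq_lens = kv_lens.repeat_interleave(1 + max_draft_tokens)[:num_tokens]

scheduler_metadata = get_paged_mqa_logits_metadata(seq_lens, tokens_per_block, num_sms)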
@@ -1053,9 +1158,24 @@ def sparse_attn_indexer(
             # Reshape q for decode phase: [num_gen_tokens, ...] -> [batch_size, next_n, ...]
             q_decode = q_fp8[num_ctx_tokens:num_ctx_tokens + num_gen_tokens,
                              ...]
-            q_decode = q_decode.view(num_generations, -1, *q_fp8.shape[1:])
-            batch_size = q_decode.shape[0]
-            next_n = q_decode.shape[1]
+            batch_size = num_generations
+            next_n = num_gen_tokens // num_generations
+            # Because fp8_paged_mqa_logits cannot support next_n > 2, we need to flatten the q_decode tensor
+            # and expand the corresponding metadata.
+            if next_n <= 2:
+                q_decode = q_decode.view(num_generations, -1, *q_fp8.shape[1:])
+                context_lens = metadata.kv_lens_cuda_runtime[
+                    num_contexts:num_contexts + num_generations]
+                block_table = metadata.indexer_k_cache_block_offsets[
+                    num_contexts:num_contexts + num_generations]
+                scheduler_metadata_buffer = metadata.scheduler_metadata_buffer
+            else:
+                q_decode = q_decode.view(-1, 1, *q_fp8.shape[1:])
+                num_tokens = num_generations * (1 + metadata.max_draft_tokens)
+                context_lens = metadata.kv_lens_expanded_cuda[:num_tokens]
+                block_table = metadata.block_table_expanded[:num_tokens]
+                scheduler_metadata_buffer = metadata.scheduler_metadata_buffer_expanded
+
             assert num_gen_tokens == batch_size * next_n
             weights_decode = weights[num_ctx_tokens:num_ctx_tokens +
                                      num_gen_tokens, ...]
@@ -1064,18 +1184,11 @@ def sparse_attn_indexer(
             # [num_blocks, tokens_per_block, 1, head_dim + scale_size]
             k_cache = metadata.kv_cache_manager.get_indexer_k_cache_buffers(
                 self.layer_idx)
-            logits_decode = fp8_paged_mqa_logits(
-                q_decode,
-                k_cache,
-                weights_decode,
-                metadata.kv_lens_cuda_runtime[
-                    num_contexts:num_contexts +
-                    num_generations],  # context_lens prepared in prepare()
-                metadata.indexer_k_cache_block_offsets[
-                    num_contexts:num_contexts +
-                    num_generations],  # Only pass generation request block tables
-                metadata.scheduler_metadata_buffer,
-                max_seq_len)
+            logits_decode = fp8_paged_mqa_logits(q_decode, k_cache,
+                                                 weights_decode, context_lens,
+                                                 block_table,
+                                                 scheduler_metadata_buffer,
+                                                 max_seq_len)
 
             if use_custom_topk:
                 # Kernel expects kv_lens (total cache length), not seq_lens (new tokens)
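
The decode-path reshape above can be summarized with a small self-contained sketch (illustrative shapes, not from the commit): with next_n <= 2 the q tensor keeps the [batch, next_n, ...] layout that fp8_paged_mqa_logits accepts directly, while with MTP > 1 every query token becomes its own length-1 sequence and is paired with the expanded metadata rows.

import torch

# Hypothetical decode batch: 2 generation requests, MTP depth 3 -> next_n = 4.
num_generations, max_draft_tokens, heads, head_dim = 2, 3, 4, 8
num_gen_tokens = num_generations * (1 + max_draft_tokens)    # 8
q_decode = torch.randn(num_gen_tokens, heads, head_dim)

batch_size = num_generations
next_n = num_gen_tokens // num_generations                   # 4

if next_n <= 2:
    # Layout the kernel supports natively: [batch, next_n, heads, head_dim].
    q_kernel = q_decode.view(num_generations, -1, heads, head_dim)
else:
    # MTP > 1 workaround: one length-1 "sequence" per query token, matched
    # row-for-row with kv_lens_expanded / block_table_expanded.
    q_kernel = q_decode.view(-1, 1, heads, head_dim)         # [8, 1, 4, 8]

assert num_gen_tokens == batch_size * next_n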

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 2 additions & 2 deletions

@@ -2380,7 +2380,7 @@ class TestDeepSeekV32(LlmapiAccuracyTestHarness):
             (8, 1, 8, 0, False, True, True, True, 24, "_DEFAULT"),
             (8, 1, 8, 1, False, True, True, True, 24, "_DEFAULT"),
             (8, 1, 8, 0, True, True, True, True, 24, "_DEFAULT"),
-            (8, 1, 8, 1, False, False, True, True, 1, "TRTLLM"),
+            (8, 1, 8, 3, False, False, True, True, 1, "TRTLLM"),
         ],
         ids=["baseline", "baseline_mtp1", "baseline_fp8kv", "latency"])
     def test_fp8_blockscale(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv,
@@ -2448,7 +2448,7 @@ def test_fp8_blockscale(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv,
             (8, 1, 8, 0, False, True, True, True, 24, "CUTLASS"),
             (8, 1, 8, 1, False, True, True, True, 24, "CUTLASS"),
             (8, 1, 8, 0, True, True, True, True, 24, "CUTLASS"),
-            (8, 1, 8, 1, False, False, True, True, 1, "TRTLLM"),
+            (8, 1, 8, 3, False, False, True, True, 1, "TRTLLM"),
         ],
         ids=["baseline", "baseline_mtp1", "baseline_fp8kv", "latency"])
     def test_nvfp4_multi_gpus(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv,

tests/unittest/_torch/attention/sparse/test_dsa_indexer.py

Lines changed: 29 additions & 9 deletions

@@ -381,7 +381,8 @@ def _create_mock_metadata(request_ids,
                           cache_manager,
                           num_ctx_tokens,
                           num_tokens,
-                          indexer_max_chunk_size=8194):
+                          indexer_max_chunk_size=8194,
+                          max_draft_tokens=0):
     """Helper to create mock metadata for testing."""
 
     class MockKVCacheParams:
@@ -396,6 +397,7 @@ def __init__(self):
             self.request_ids = request_ids
             self.num_contexts = num_contexts
             self.num_generations = num_generations
+            self.max_draft_tokens = max_draft_tokens
             # Keep seq_lens on CPU for split_prefill_chunks and other CPU operations
             # CUDA kernels will convert to CUDA as needed
             self.seq_lens = seq_lens.cpu() if seq_lens.is_cuda else seq_lens
@@ -826,6 +828,7 @@ def test_indexer_decode_with_paged_kv_cache(batch_size, next_n):
         cache_manager=cache_manager,
         num_ctx_tokens=total_context_tokens,
         num_tokens=total_context_tokens,
+        max_draft_tokens=next_n - 1,
     )
     Indexer.prepare(metadata_context)
 
@@ -851,6 +854,7 @@ def test_indexer_decode_with_paged_kv_cache(batch_size, next_n):
         cache_manager=cache_manager,
         num_ctx_tokens=0,
         num_tokens=batch_size * num_gen_tokens,
+        max_draft_tokens=next_n - 1,
     )
     Indexer.prepare(metadata_gen)
 
@@ -1418,6 +1422,7 @@ def test_indexer_decode_custom_vs_fallback(batch_size, next_n, index_topk,
         cache_manager=cache_manager,
         num_ctx_tokens=total_context_tokens,
         num_tokens=total_context_tokens,
+        max_draft_tokens=next_n - 1,
     )
     Indexer.prepare(metadata_context)
     indexer._update_k_cache(k_context_fp8, k_context_scale, metadata_context)
@@ -1450,16 +1455,24 @@ def test_indexer_decode_custom_vs_fallback(batch_size, next_n, index_topk,
         cache_manager=cache_manager,
         num_ctx_tokens=0,
         num_tokens=num_gen_tokens,
+        max_draft_tokens=next_n - 1,
     )
     Indexer.prepare(metadata_gen_write)
     indexer._update_k_cache(k_fp8, k_scale, metadata_gen_write)
 
     # Test with custom CUDA kernel
-    metadata_custom = _create_mock_metadata(request_ids, batch_size, 0,
-                                            batch_size, seq_lens.clone(),
+    metadata_custom = _create_mock_metadata(request_ids,
+                                            batch_size,
+                                            0,
+                                            batch_size,
+                                            seq_lens.clone(),
                                             final_lens.clone(),
-                                            num_cached_tokens, cache_manager, 0,
-                                            num_gen_tokens, max_model_len)
+                                            num_cached_tokens,
+                                            cache_manager,
+                                            0,
+                                            num_gen_tokens,
+                                            max_model_len,
+                                            max_draft_tokens=next_n - 1)
 
     Indexer.prepare(metadata_custom)
     indexer._update_k_cache(k_fp8, k_scale, metadata_custom)
@@ -1476,11 +1489,18 @@ def test_indexer_decode_custom_vs_fallback(batch_size, next_n, index_topk,
             pytest.skip(f"Custom topk not available: {e}")
 
     # Test with PyTorch fallback
-    metadata_fallback = _create_mock_metadata(request_ids, batch_size, 0,
-                                              batch_size, seq_lens.clone(),
+    metadata_fallback = _create_mock_metadata(request_ids,
+                                              batch_size,
+                                              0,
+                                              batch_size,
+                                              seq_lens.clone(),
                                               final_lens.clone(),
-                                              num_cached_tokens, cache_manager,
-                                              0, num_gen_tokens, max_model_len)
+                                              num_cached_tokens,
+                                              cache_manager,
+                                              0,
+                                              num_gen_tokens,
+                                              max_model_len,
+                                              max_draft_tokens=next_n - 1)
 
     Indexer.prepare(metadata_fallback)
     indexer._update_k_cache(k_fp8, k_scale, metadata_fallback)
