
Conversation

benchislett
Collaborator

@benchislett benchislett commented Sep 30, 2025

Purpose

This PR refactors the MLACommonMetadataBuilder so that MLA implementations can easily support a spec-decode kernel optimization. This is then used to enable FlashInfer-MLA support via the trtllm-gen kernels, which have explicit support for spec-as-decode.
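A minimal sketch of the spec-as-decode idea, based on the flags and the `1 + num_spec_tokens` threshold discussed later in this thread; the helper name and signature below are invented for illustration and are not the PR's actual code.

def effective_reorder_batch_threshold(
    supports_spec_as_decode: bool, num_speculative_tokens: int
) -> int:
    # Plain decode kernels only accept queries of length 1, so requests
    # carrying draft tokens would otherwise fall back to the prefill path.
    if supports_spec_as_decode:
        # Treat "1 target token + N draft tokens" as a decode-sized query.
        return 1 + num_speculative_tokens
    return 1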

Test Plan

I ran a suite of evals over nvidia/DeepSeek-R1-FP4 and deepseek-ai/DeepSeek-R1-0528 on 4xB200 and 8xB200 respectively, using Cutlass-MLA and FlashInfer-MLA backends. Running MTP with FP4 on B200 requires the fix in #25987.

lm_eval \
  --model local-completions \
  --tasks gsm8k \
  --model_args base_url=http://0.0.0.0:8049/v1/completions,model=nvidia/DeepSeek-R1-FP4,tokenized_requests=False,tokenizer_backend=None,num_concurrent=128,timeout=120,max_retries=5

Known issues

The Cutlass-MLA backend produces incorrect output when using speculative decoding. It is not clear to me why this happens; I debugged with --enforce-eager but found nothing wrong other than the incorrect model output. I have not verified whether this also occurs on H200, but FLASH_ATTN_MLA is also an option on Hopper, so it may be sufficient to deprecate Cutlass-MLA when speculative decoding is enabled.

See #26042 for tracking of this correctness issue; it appears the root cause is MLA chunked prefill.

The fix is in #26063. I will rerun the experiments to get a better baseline, but the MTP correctness results on this branch remain valid.

Test Result

4xB200 nvidia/DeepSeek-R1-FP4 FlashInfer-MLA MTP=3

VLLM_ATTENTION_BACKEND=FLASHINFER_MLA VLLM_USE_FLASHINFER_MOE_FP4=1 vllm serve nvidia/DeepSeek-R1-FP4 -tp 4 --max-model-len 8192 --no-enable-prefix-caching --port 8049 --speculative-config '{"method": "mtp", "num_speculative_tokens": 3}'

Finish: 1319/1319 [01:07<00:00, 19.62it/s]

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9500|±  | 0.006|
|     |       |strict-match    |     5|exact_match|↑  |0.9507|±  | 0.006|

4xB200 nvidia/DeepSeek-R1-FP4 Cutlass-MLA MTP=3

VLLM_ATTENTION_BACKEND=CUTLASS_MLA VLLM_USE_FLASHINFER_MOE_FP4=1 vllm serve nvidia/DeepSeek-R1-FP4 -tp 4 --max-model-len 8192 --no-enable-prefix-caching --port 8049 --speculative-config '{"method": "mtp", "num_speculative_tokens": 3}'

FAILED. Also fails with --enforce-eager
Finish: 1319/1319 [05:55<00:00,  3.71it/s]

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |    0|±  |     0|
|     |       |strict-match    |     5|exact_match|↑  |    0|±  |     0|

4xB200 nvidia/DeepSeek-R1-FP4 FlashInfer-MLA No-Spec

VLLM_ATTENTION_BACKEND=FLASHINFER_MLA VLLM_USE_FLASHINFER_MOE_FP4=1 vllm serve nvidia/DeepSeek-R1-FP4 -tp 4 --max-model-len 8192 --no-enable-prefix-caching --port 8049

Finish: 1319/1319 [01:30<00:00, 14.62it/s]

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9447|±  |0.0063|
|     |       |strict-match    |     5|exact_match|↑  |0.9454|±  |0.0063|

4xB200 nvidia/DeepSeek-R1-FP4 Cutlass-MLA No-Spec

VLLM_ATTENTION_BACKEND=CUTLASS_MLA VLLM_USE_FLASHINFER_MOE_FP4=1 vllm serve nvidia/DeepSeek-R1-FP4 -tp 4 --max-model-len 8192 --no-enable-prefix-caching --port 8049

Finish: 1319/1319 [01:29<00:00, 14.70it/s]

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9492|±  |0.0060|
|     |       |strict-match    |     5|exact_match|↑  |0.9477|±  |0.0061|

8xB200 deepseek-ai/DeepSeek-R1-0528 FlashInfer-MLA MTP=3

VLLM_ATTENTION_BACKEND=FLASHINFER_MLA vllm serve deepseek-ai/DeepSeek-R1-0528 -tp 8 --max-model-len 8192 --no-enable-prefix-caching --port 8049 --speculative-config '{"method": "mtp", "num_speculative_tokens": 3}'

Finish: 1319/1319 [01:24<00:00, 15.67it/s]

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9568|±  |0.0056|
|     |       |strict-match    |     5|exact_match|↑  |0.9530|±  |0.0058|

8xB200 deepseek-ai/DeepSeek-R1-0528 Cutlass-MLA MTP=3

VLLM_ATTENTION_BACKEND=CUTLASS_MLA vllm serve deepseek-ai/DeepSeek-R1-0528 -tp 8 --max-model-len 8192 --no-enable-prefix-caching --port 8049 --speculative-config '{"method": "mtp", "num_speculative_tokens": 3}'

FAIL

8xB200 deepseek-ai/DeepSeek-R1-0528 FlashInfer-MLA No-Spec

VLLM_ATTENTION_BACKEND=FLASHINFER_MLA vllm serve deepseek-ai/DeepSeek-R1-0528 -tp 8 --max-model-len 8192 --no-enable-prefix-caching --port 8049

Finish: 1319/1319 [01:44<00:00, 12.62it/s]

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9545|±  |0.0057|
|     |       |strict-match    |     5|exact_match|↑  |0.9522|±  |0.0059|

8xB200 deepseek-ai/DeepSeek-R1-0528 Cutlass-MLA No-Spec

VLLM_ATTENTION_BACKEND=CUTLASS_MLA vllm serve deepseek-ai/DeepSeek-R1-0528 -tp 8 --max-model-len 8192 --no-enable-prefix-caching --port 8049

Finish: 1319/1319 [01:42<00:00, 12.86it/s]

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9560|±  |0.0056|
|     |       |strict-match    |     5|exact_match|↑  |0.9538|±  |0.0058|

Signed-off-by: Benjamin Chislett <[email protected]>
@mergify mergify bot added the v1 label Sep 30, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request refactors the MLA backend to support speculative decoding with FlashInfer, which is a great improvement. The changes are mostly well-structured. However, I found a critical issue in the fallback logic for handling non-uniform query lengths in FlashInferMLAImpl, which could lead to a runtime error. My review includes a suggestion to fix this.

@benchislett
Collaborator Author

Update: the failed baseline is most likely due to an unknown bug in MLA chunked prefill logic. See #26042

# `reorder_batch_threshold > 1`, any decode requests which do not
# have the same query length as the first decode request will
# fall back to the prefill kernel.
supports_nonuniform_decode: ClassVar[bool] = False
Collaborator


nit: is this needed if it's always set to False? (I think we should set it for FlashAttnMLA, since that backend does support non-uniform decode.)

Collaborator


I think we can maybe just unify supports_spec_as_decode and supports_nonuniform_decode into a single supports_only_uniform_spec_decode flag: when that's False, we leave reorder_batch_threshold untouched and set require_uniform = False.

Collaborator Author


@LucasWilkinson I'm pretty sure there can be a full matrix of options here, and that the different combinations are useful. For example (see the sketch after this list):

  • supports_spec_as_decode and supports_nonuniform_decode: FlashAttnMLA, where require_uniform=False is correct (it can handle varlen), and the long reorder_batch_threshold allows it to handle spec requests.
  • supports_spec_as_decode and not supports_nonuniform_decode, where require_uniform=True is required to function correctly, but reorder_batch_threshold can be overridden to 1 + num_spec_tokens to handle spec decoding.
  • not supports_spec_as_decode and not supports_nonuniform_decode is the default for the backends which require q_len == 1.
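A hypothetical illustration of the three combinations above; only the flag and attribute names quoted in this thread are real, the class names are invented for the sketch.

from typing import ClassVar


class VarlenDecodeBackend:  # the FlashAttnMLA-style case
    # Decode kernel handles varlen queries, so uniformity is not required
    # and a long reorder_batch_threshold can cover spec requests directly.
    supports_spec_as_decode: ClassVar[bool] = True
    supports_nonuniform_decode: ClassVar[bool] = True


class UniformSpecDecodeBackend:  # spec-as-decode, but uniform lengths only
    # require_uniform=True is needed, and reorder_batch_threshold is
    # overridden to 1 + num_spec_tokens so spec queries stay on the decode path.
    supports_spec_as_decode: ClassVar[bool] = True
    supports_nonuniform_decode: ClassVar[bool] = False


class PlainDecodeBackend:  # default: decode kernel requires q_len == 1
    supports_spec_as_decode: ClassVar[bool] = False
    supports_nonuniform_decode: ClassVar[bool] = False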

Collaborator Author


I will update FlashAttnMLA to reflect the correct defaults, but I don't know how to support each of these 3 cases cleanly with only a single flag. Let me know if you would still prefer a different interface.
