[Spec Decode] Enable efficient speculative decoding with FlashInfer-MLA #25984
Conversation
Signed-off-by: Benjamin Chislett <[email protected]>
Code Review
This pull request refactors the MLA backend to support speculative decoding with FlashInfer, which is a great improvement. The changes are mostly well-structured. However, I found a critical issue in the fallback logic for handling non-uniform query lengths in FlashInferMLAImpl, which could lead to a runtime error. My review includes a suggestion to fix this.
Update: the failed baseline is most likely due to an unknown bug in MLA chunked prefill logic. See #26042.
# `reorder_batch_threshold > 1`, any decode requests which do not
# have the same query length as the first decode request will
# fall back to the prefill kernel.
supports_nonuniform_decode: ClassVar[bool] = False
nit: is this needed if it's always set to false? (I think we should set this for FlashAttnMLA, since it does support supports_nonuniform_decode.)
I think we can maybe just unify supports_spec_as_decode and supports_nonuniform_decode into a single supports_only_uniform_spec_decode, and when that's False we just leave reorder_batch_threshold untouched and set require_uniform = False.
@LucasWilkinson I'm pretty sure there can be a full matrix of options here, and that different combinations are useful (see the sketch below). For example:
1. supports_spec_as_decode and supports_nonuniform_decode: FlashAttnMLA, where require_uniform=False is correct (it can handle varlen), and the long reorder_batch_threshold allows it to handle spec requests.
2. supports_spec_as_decode and not supports_nonuniform_decode, where require_uniform=True is required to function correctly, but reorder_batch_threshold can be overridden to 1 + num_spec_tokens to handle spec decoding.
3. not supports_spec_as_decode and not supports_nonuniform_decode is the default for the backends which require q_len == 1.
I will update FlashAttnMLA to reflect the correct defaults, but I don't know how to support each of these 3 cases cleanly with only a single flag. Let me know if you would still prefer a different interface.
Signed-off-by: Benjamin Chislett <[email protected]>
Signed-off-by: Benjamin Chislett <[email protected]>
Purpose
This PR refactors the MLACommonMetadataBuilder to easily support spec-decode kernel optimizations in MLA implementations. This is used to enable FlashInfer-MLA support using the trtllm-gen kernels, which have explicit support for spec-as-decode.
Test Plan
I ran a suite of evals over nvidia/DeepSeek-R1-FP4 and deepseek-ai/DeepSeek-R1-0528 on 4xB200 and 8xB200 respectively, using the Cutlass-MLA and FlashInfer-MLA backends. Running MTP with FP4 on B200 requires the fix in #25987.
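For reference, a minimal offline sketch of one of these configurations (4xB200, FlashInfer-MLA, MTP=3). The attention-backend name, the speculative_config schema, and the "deepseek_mtp" method string are assumptions based on common vLLM conventions rather than taken from this PR, and may differ between versions:

```python
import os

from vllm import LLM, SamplingParams

# Assumed backend name for the FlashInfer-MLA attention backend.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER_MLA"

llm = LLM(
    model="nvidia/DeepSeek-R1-FP4",
    tensor_parallel_size=4,
    speculative_config={
        "method": "deepseek_mtp",      # assumed method name for MTP drafting
        "num_speculative_tokens": 3,   # MTP=3
    },
)

outputs = llm.generate(["The capital of France is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```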
Known issues
The Cutlass-MLA backend produces incorrect output when using speculative decoding. It is not clear to me why this happens; I have debugged with enforce-eager but did not identify any issues other than incorrect model output. I have not verified whether this also occurs on H200, but I believe FLASH_ATTN_MLA is also an option on Hopper, so it may be sufficient to deprecate Cutlass-MLA when speculative decoding is enabled.
See #26042 for tracking on this correctness issue, which seems to indicate the root cause is MLA chunked prefill.
The fix is in #26063. I will rerun the experiments for a better baseline, but the correctness of this branch for MTP is still valid.
Test Result
4xB200 nvidia/DeepSeek-R1-FP4 FlashInfer-MLA MTP=3
4xB200 nvidia/DeepSeek-R1-FP4 Cutlass-MLA MTP=3
4xB200 nvidia/DeepSeek-R1-FP4 FlashInfer-MLA No-Spec
4xB200 nvidia/DeepSeek-R1-FP4 Cutlass-MLA No-Spec
8xB200 deepseek-ai/DeepSeek-R1-0528 FlashInfer-MLA MTP=3
8xB200 deepseek-ai/DeepSeek-R1-0528 Cutlass-MLA MTP=3
8xB200 deepseek-ai/DeepSeek-R1-0528 FlashInfer-MLA No-Spec
8xB200 deepseek-ai/DeepSeek-R1-0528 Cutlass-MLA No-Spec