feat: mixed prefill-decode batching (part 3: compatibility with speculative decoding for MLA models) #205
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 91482d0584
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: f863184d8a
  sampling_info=sampling_info,
  bs=bs,
- is_spec_decode=self.drafter is not None and num_extends <= 0,
+ is_spec_decode=self.drafter is not None and num_extends < bs,
Restrict spec-grammar mode to true decode batches
Changing is_spec_decode to self.drafter is not None and num_extends < bs makes MIXED batches enter spec-grammar mode, but _forward_step still only uploads candidate tokens to the grammar runtime when ctx.forward_mode.is_decode() (schedule_fill(...input_ids_buf_slice=...)). In MIXED + speculative + grammar runs, this marks has_candidates=True while candidate buffers are not refreshed, so the capturable grammar path computes per-token masks from stale/incorrect candidate data, producing wrong constrained sampling behavior.
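The difference between the two predicates can be shown with a small sketch. The function names and standalone signatures below are hypothetical and are not the project's API; the sketch assumes only that num_extends counts the prefill rows in a batch of bs rows.

```python
def batch_kind(num_extends: int, bs: int) -> str:
    """Classify a batch by how many of its rows are prefill (extend) rows."""
    if num_extends <= 0:
        return "DECODE"
    if num_extends < bs:
        return "MIXED"
    return "EXTEND"


def is_spec_decode_old(has_drafter: bool, num_extends: int, bs: int) -> bool:
    # Old predicate: true only for pure decode batches.
    return has_drafter and num_extends <= 0


def is_spec_decode_new(has_drafter: bool, num_extends: int, bs: int) -> bool:
    # New predicate: also true for MIXED batches (some, but not all, rows are prefill).
    return has_drafter and num_extends < bs


if __name__ == "__main__":
    for num_extends in (0, 2, 4):
        print(
            batch_kind(num_extends, bs=4),
            is_spec_decode_old(True, num_extends, bs=4),
            is_spec_decode_new(True, num_extends, bs=4),
        )
```

The two predicates agree on DECODE and EXTEND batches and differ only on MIXED ones, which is the case the comment above is concerned with.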
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 51f14df981
  sampling_info=sampling_info,
  bs=bs,
- is_spec_decode=self.drafter is not None and num_extends <= 0,
+ is_spec_decode=self.drafter is not None and num_extends < bs,
Keep spec-grammar path off for mixed batches in eager mode
Setting is_spec_decode true for any batch with num_extends < bs makes MIXED speculative batches enter the per-token grammar path even on the eager fallback. In that path, _fill_eager_bitmask assumes a pure decode layout and reads input_ids_buf[:bs*spec_num_tokens].view(bs, spec_num_tokens), but MIXED batches store ragged prefill tokens first and decode candidates after them, so grammar masks are generated from misaligned tokens. This produces incorrect constrained sampling whenever capturable grammar is disabled (or unavailable) and MIXED+spec decode is active.
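The misalignment described above can be reproduced with plain lists. This is a minimal sketch, assuming the layout stated in the comment (ragged prefill tokens first, decode candidates after them); the helper names are hypothetical and plain Python lists stand in for the real tensor buffer.

```python
def pure_decode_view(input_ids, bs, spec_num_tokens):
    """Mimic input_ids_buf[:bs*spec_num_tokens].view(bs, spec_num_tokens)."""
    flat = input_ids[: bs * spec_num_tokens]
    return [flat[i * spec_num_tokens:(i + 1) * spec_num_tokens] for i in range(bs)]


def mixed_decode_view(input_ids, prefill_lens, num_decodes, spec_num_tokens):
    """Skip the ragged prefill tokens, then reshape only the decode candidates."""
    start = sum(prefill_lens)
    flat = input_ids[start: start + num_decodes * spec_num_tokens]
    return [flat[i * spec_num_tokens:(i + 1) * spec_num_tokens] for i in range(num_decodes)]


# One prefill request with 3 tokens, then two decode requests with 2 candidates each.
input_ids = [101, 102, 103, 7, 8, 9, 10]

print(pure_decode_view(input_ids, bs=3, spec_num_tokens=2))
# [[101, 102], [103, 7], [8, 9]]  (prefill tokens leak into the candidate rows)

print(mixed_decode_view(input_ids, prefill_lens=[3], num_decodes=2, spec_num_tokens=2))
# [[7, 8], [9, 10]]
```

Under this assumed layout, the pure-decode reshape pairs prefill tokens with decode candidates, so any mask computed from those rows is derived from the wrong tokens.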
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 3bf882de69
mla_metadata, num_splits = get_mla_metadata(
    plan_seq_lens.to(torch.int32),
    num_heads_plan,
    1,
)
Build FlashMLA decode planner from decode rows only
In MIXED batches, this path computes flashmla_metadata/num_splits from full seq_lens (prefill + decode rows), but decode execution later consumes only the decode slice (block_table[num_extends:] in forward_extend/forward_decode). That misaligns planner metadata with the actual decode rows, so when num_extends > 0 the kernel can run with row planning derived from the wrong requests, leading to incorrect attention behavior or shape/runtime failures. The decode planner needs to be built (or consistently sliced) for decode rows only.
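A toy illustration of the row mismatch follows. It assumes, as the comment states, that prefill rows come first and that decode execution consumes block_table[num_extends:]. toy_plan is a hypothetical stand-in that emits one planning entry per row; it is not get_mla_metadata and does not reproduce that function's behavior.

```python
import math


def toy_plan(seq_lens, tokens_per_split=4):
    """Hypothetical planner: one split count per input row."""
    return [max(1, math.ceil(n / tokens_per_split)) for n in seq_lens]


def plan_decode_rows_only(seq_lens, num_extends):
    """Plan from the same row slice that decode execution will consume."""
    return toy_plan(seq_lens[num_extends:])


seq_lens = [100, 5, 9]          # row 0 is a prefill row, rows 1-2 are decode rows
block_table = ["blk0", "blk1", "blk2"]
num_extends = 1

decode_block_table = block_table[num_extends:]
full_plan = toy_plan(seq_lens)
decode_plan = plan_decode_rows_only(seq_lens, num_extends)

print(len(full_plan), len(decode_block_table))    # 3 2: plan rows do not match decode rows
print(len(decode_plan), len(decode_block_table))  # 2 2: one plan entry per decode row
```

Slicing the plan inputs with the same num_extends offset that the kernel call uses keeps the two views of the batch consistent.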
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 452d75aedf
self.draft_attn_backend.init_forward_metadata(
    bs=padded_bs,
    num_extends=num_extends,
    req_pool_indices=req_pool_indices,
    seq_lens=draft_seq_lens,
Reintroduce decode-metadata init for Triton draft backends
When _init_forward_metadata switched from two draft-backend init calls to a single call, EXTEND/MIXED draft batches no longer force a DECODE-shaped metadata refresh before Eagle step 2+. TritonAttnBackend stores only one forward_metadata, so after the single call it can still hold prefill-style qo_indptr/kv_indptr; then _run_multi_step_decode invokes decode kernels against that stale layout, which can misindex KV ranges or fail on shape assumptions for mixed prefill+decode speculative batches.
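The stale-metadata hazard can be sketched with a toy backend that, like the one described above, keeps a single forward_metadata slot. Everything here is hypothetical (ToyBackend, its method signature, and the assumption that a decode step uses one query token per request); it only illustrates why a decode-shaped refresh is needed before decode kernels run.

```python
from itertools import accumulate


class ToyBackend:
    """Hypothetical backend with one metadata slot shared by all modes."""

    def __init__(self):
        self.forward_metadata = None

    def init_forward_metadata(self, mode, q_lens):
        # qo_indptr is the prefix sum of per-request query lengths.
        self.forward_metadata = {"mode": mode, "qo_indptr": [0, *accumulate(q_lens)]}


def run_decode_step(backend, bs):
    """A decode step here expects exactly one query token per request."""
    expected = list(range(bs + 1))
    if backend.forward_metadata["qo_indptr"] != expected:
        raise ValueError("decode step saw non-decode metadata")
    return backend.forward_metadata["qo_indptr"]


backend = ToyBackend()
backend.init_forward_metadata("extend", q_lens=[3, 5])   # prefill-style layout
try:
    run_decode_step(backend, bs=2)
except ValueError as err:
    print("stale layout:", err)

backend.init_forward_metadata("decode", q_lens=[1, 1])   # decode-shaped refresh
print(run_decode_step(backend, bs=2))                    # [0, 1, 2]
```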
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
Summary
Extend --enable-mixed-batch from V4 sparse attention to Kimi K2.5 (and all MLA backends), with first-class spec-dec support.
Mixed prefill-decode batching:
Changes
Scheduler & runtime
- […] enable_mixed_prefill_decode so MIXED actually emits under long-prefill workloads.
- […] --enable-mixed-batch; users opt in per-workload.
MLA backends (all gain MIXED support)
- tokenspeed_mla, trtllm_mla, flashmla: init_forward_metadata fills both prefill + decode metadata under MIXED, with num_extends discriminator on decode metadata for kernel-call-time slicing.
- DeepseekV3AttentionMLA.forward runs pre_attn_comm once, slices by num_prefill_tokens, dispatches prefill/decode through their native paths, single shared o_proj.
- out= plumbing eliminates the per-call BF16 copy in chunked prefill.
Spec-dec + MIXED
- […] is_draft + extend_or_mixed, with seq_lens aliased to the drafter's live buffer for in-place multi-step advance.
- LogitsProcessor collapsed from a 4-branch prune into a single gather_ids gather; row indices computed by the caller (ModelExecutor + Eagle drafter). Cleaner contract, fewer special cases, MIXED-with-verify just works.
- […] spec_num_tokens field on AttentionBackend with sentinel-aware config defaults.
- EagleDraftInput.num_extends threaded for correct drafter dispatch under EXTEND target.
- […] vcdelta, per-row is_decode_slot gate, hybrid sampler logprob writeback).
Interface cleanup
- […] num_tokens arg from all init_forward_metadata* signatures.
- […] spec_num_tokens → q_len_per_req to distinguish per-call shape from configured constant.
- […] set_decode_num_extends(int) setter with override_num_extends(int) context manager that restores prior value.
CI
- […] --enable-mixed-batch.
- […] tokenspeed_mla + kimi-k2.5-eagle3-mla) for MIXED-safety.
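As an illustration of the override_num_extends(int) context manager that restores the prior value, here is a minimal sketch. The class ToyAttnBackend and its decode_num_extends attribute are hypothetical placeholders; only the context-manager name comes from the description above, and the real implementation may differ.

```python
from contextlib import contextmanager


class ToyAttnBackend:
    """Hypothetical backend holding a decode-side num_extends value."""

    def __init__(self):
        self.decode_num_extends = 0

    @contextmanager
    def override_num_extends(self, num_extends: int):
        prev = self.decode_num_extends
        self.decode_num_extends = num_extends
        try:
            yield self
        finally:
            # Restore the prior value even if the body raises.
            self.decode_num_extends = prev


backend = ToyAttnBackend()
with backend.override_num_extends(3):
    print(backend.decode_num_extends)  # 3
print(backend.decode_num_extends)      # 0
```

Compared with a plain setter, a scoped override cannot leave the value changed after the block exits, and nested overrides unwind in order.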