Skip to content

Fpz/mixed mla dispatch#1271

Draft
jiayyu wants to merge 22 commits into
mainfrom
fpz/mixed_mla_dispatch
Draft

Fpz/mixed mla dispatch#1271
jiayyu wants to merge 22 commits into
mainfrom
fpz/mixed_mla_dispatch

Conversation

@jiayyu

@jiayyu jiayyu commented Jun 18, 2026

Copy link
Copy Markdown
Contributor

Motivation

Technical Details

Test Plan

Test Result

Submission Checklist

jiayyu and others added 22 commits June 17, 2026 10:28
…ill-decode (P2-M1)

Adds a config flag (default off) that lets the scheduler pack prefill
chunks and decode seqs into the same ScheduledBatch — Phase 2 of chunked
prefill. With the flag off, behavior is unchanged: prefill-only or
decode-only batches as before.

- atom/config.py: new enable_mixed_prefill_decode field (default False)
- atom/model_engine/arg_utils.py: --enable-mixed-prefill-decode CLI flag
- atom/model_engine/scheduler.py:
  - ScheduledBatch.is_mixed = both prefill and decode rows present
  - ScheduledBatch.num_prompt_tokens for runner-side final-chunk detection
  - schedule() fall-through path: when flag is on, decode loop runs after
    Phase 2 prefill against a shared num_batched_tokens budget and
    shared max_num_seqs slot count
  - decode_carryover + decode_scheduled lists preserve FCFS ordering
    when seqs are popped from running but not picked this step
  - is_partial_prefill stays False whenever any decode row is present
- tests/test_scheduler.py: 6 new TestMixedBatch cases covering flag-off
  back-compat, mixed batch production, is_partial_prefill semantics,
  no double-scheduling, and pure-prefill/pure-decode fallbacks

Backends still raise on is_mixed batches; the flag is meaningful only
after the MHA/MLA Split Dispatch lands in P2-M2/M3.
…-M4)

prepare_input_ids now writes both the prefill and decode token regions
when batch.is_mixed, instead of early-returning the prefill-only slice.

ParallelLMHead.forward gains a mixed-mode branch that gathers
last-token-per-prefill-seq for the prefill rows and keeps every token
for the decode rows, so the sampler sees one logits row per sequence
(layout: [prefill_last × n_p_seqs | decode_tokens]).

Context gains num_prefill_seqs alongside the existing num_prefill_tokens
so the lm_head gather skips the per-step CPU sync.

MTP / deferred-output integration with mixed batches is not yet
implemented — guarded by explicit asserts and a NotImplementedError at
the spec_decode_metadata site so misconfiguration fails loud.

DP get_next_batch_info needs no change: mixed batches always carry
prefill seqs, so the existing eligible_waiting branch already returns
is_prefill=True for cross-rank dummy-prefill sync.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
… (P2-M5)

TBO ubatch splitting assumes a uniform-layout batch (all prefill or all
decode rows); the mixed [prefill | decode] layout would need a custom
ubatch split, which isn't implemented yet. Disable TBO when batch.is_mixed
so we don't silently produce wrong slices.

Document mixed-batch semantics in scheduler.get_next_batch_info: mixed
batches report is_prefill=True (they always carry a prefill seq), so the
existing dummy-prefill DP sync path in engine_core handles them without
additional change.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
Implements the attention split dispatch that lets a mixed batch
(prefill chunks + decode seqs packed together by P2-M1) actually run
on dense MLA (DeepSeek-V2/V3). Gated by --enable-mixed-prefill-decode
(default off); prefill-only and decode-only fast paths are unchanged.

- forward_context: AttentionMetaData gains nested prefill_attn_metadata
  / decode_attn_metadata sub-fields for split dispatch.
- aiter_mla: prepare_mixed() builds both halves — reuses prepare_prefill
  and clones its tensors off the shared forward_vars buffers before the
  decode half rewrites them; decode half rebuilt for n_decode_seqs.
  Surfaces prefill cu_seqlens_q on the top-level metadata for the
  lm_head gather.
- attention_mla: forward_impl is_mixed branch splits q/k/v at
  num_prefill_tokens — prefill rows take the MHA path, decode rows the
  MLA latent path, outputs concatenated.
- embed_head: check is_mixed BEFORE is_prefill (mixed batches also have
  is_prefill=True) so the combined gather runs.
- model_runner: auto-disable deferred output when the mixed flag is on
  (the [prefill|decode] input-id layout isn't wired into deferred yet).
- guards: backends base prepare_mixed, V4 builder, and MHA forward
  raise/assert on is_mixed so unsupported paths fail loud.

Verified on DeepSeek-R1-0528 (TP8, fp8 KV): 670 mixed batches scheduled
during GSM8K 3-shot; flexible-extract 0.9469 vs 0.9477 baseline, within
the noise band. 399 relevant unit tests pass with the flag off.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
Replaces the NotImplementedError guard with a real prepare_mixed that
builds both per-segment metadata sets by reusing the validated
prepare_prefill / prepare_decode unmodified:
- prefill rows [0:n_p] via prepare_prefill (already correct — prefill is
  first), then clone its shared forward_vars/_stage-aliasing tensors
  (batch_id_per_token, n_committed_csa_per_seq, state_slot_mapping,
  block_tables, cu_seqlens_q/k, context_lens, indexer_meta) before the
  decode half overwrites them.
- decode rows [n_p:] via prepare_decode against a thin _MixedDecodeView
  that slices only the fields prepare_decode reads.
- merged AttentionMetaData_DSV4 carries both as prefill_attn_metadata /
  decode_attn_metadata plus is_mixed + num_{prefill,decode}_{tokens,seqs}
  markers and merged positions for the shared full-tensor forward ops.

forward_impl still raises on is_mixed until P2-P4 land the per-segment
split; this lands the metadata layer first. Gated by
--enable-mixed-prefill-decode (default off).

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
Adds the is_mixed branch to DeepseekV4Attention.forward_impl for Dense
(ratio==0) layers: splits q_sa/kv/positions at num_prefill_tokens, runs
the decode slice (swa_write BEFORE → sparse_attn_v4_paged_decode) and the
prefill slice (sparse_attn_v4_paged_prefill → swa_write AFTER) against
their own per-segment sub-metadata, then concatenates outputs. The
non-mixed decode swa_write is gated `and not is_mixed`.

Asserts ratio==0 for now; CSA (P3) and HCA (P4) follow. Since every V4
model is hybrid (V4-Pro has only 1 Dense layer of 62), this is not yet
runnable end-to-end on its own — it lands the Dense slice of the split.
Gated by --enable-mixed-prefill-decode (default off).

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
Supersedes the in-place Dense-only split (828bfd6, reverted): instead of
forking forward_impl at every is_decode site, run the whole validated
forward body once per segment. Each sub-metadata (prefill_attn_metadata /
decode_attn_metadata from prepare_mixed) is a complete self-consistent V4
metadata object, so for a mixed batch we point fc.attn_metadata +
fc.context.is_prefill at each segment in turn, re-enter forward_impl as a
pure prefill / pure decode forward, and concatenate outputs.

This covers Dense (ratio==0), CSA (ratio==4) and HCA (ratio==128) in one
change because every per-ratio path (swa_write ordering, Indexer
prefill-vs-decode dispatch, csa_translate_pack, sparse kernels, Compressor)
falls out of re-running the existing non-mixed code. The non-mixed path is
byte-for-byte unchanged.

Tradeoff: Q/KV projections + qk_norm_rope + output LoRA run twice (once per
segment) rather than shared on the merged tensor — correct (disjoint
tokens), with a fusion opportunity left for later. Gated by
--enable-mixed-prefill-decode (default off).

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
The V4 builder has no is_sparse attribute (that lives on the AiterMLA
builder); the leftover `assert not self.is_sparse or True` placeholder
raised AttributeError on the first mixed batch. Remove it.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
The ParallelLMHead mixed-batch gather reads attn_metadata.cu_seqlens_q to
locate per-prefill-seq last tokens; the merged V4 mixed metadata left it
None, raising TypeError. Point it at the prefill sub-meta's cu_seqlens_q
(same fix as the dense-MLA prepare_mixed).

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
The mixed builder cloned prefill per-fwd tensors off the shared
forward_vars buffers but missed compress_plans: each CompressPlan's
compress_plan_gpu / write_plan_gpu are views into the shared
var["v4_compress_plan_{ratio}"] / var["v4_write_plan_{ratio}"] buffers,
which the subsequent prepare_decode(_build_compress_plans) overwrites. The
prefill segment's Compressor then read the DECODE plan — whose ragged_id /
batch_id reference the decode segment's smaller kv / state_slot_mapping —
causing a MEMORY_VIOLATION in _update_compressor_states_kernel (located via
rocm-debug-agent). Clone both GPU tensors per CompressPlan.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
Removes the blanket auto-disable of deferred output when
--enable-mixed-prefill-decode is on, and implements the deferred decode-input
path for mixed batches in prepare_input_ids:
- decode region starts at decode_offset = total_tokens_prefill (prefill rows
  already written above);
- decode seqs present in prev_batch are gathered GPU-side from prev_token_ids
  by req_id, written at their decode-segment batch position (row-aligned with
  the decode attention metadata);
- decode seqs not in prev_batch fall back to scheduled_tokens.
- prev_batch / prev_token_ids stay advanced only in prepare_sampled_ids.

The change is confined to the is_mixed branch; the proven non-mixed
prefill/decode/deferred paths are byte-for-byte unchanged. MTP + mixed remains
asserted off.

Runs end-to-end without crashes on dense-MLA (R1) and V4-Pro sparse (GPU
fault-free, mixed batches scheduled). KNOWN ISSUE: a decode seq that idles
across prefill-chunk steps (so it is not in the immediately-previous batch)
currently falls back to scheduled_tokens, which under deferred output holds a
placeholder rather than the seq's last real token — causing a partial GSM8K
accuracy drop (R1 full: 0.87 vs 0.9469 non-deferred-mixed). Fix to follow;
for accuracy-sensitive runs keep --enable-mixed-prefill-decode off or run
without this path until the placeholder sourcing is corrected.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
…round

The earlier MLA NaN fix sanitized the seed lse at the call site
(nan_to_num on suf_lse). Now that merge_attn_states lives in ATOM
(triton_merge_attn_states.py, sole caller is the MLA chunked path —
plugin/vllm imports vllm's own copy), fix it at the root instead.

When a token's prefix AND suffix are both empty (max_lse == -inf), the
kernel computed -inf-(-inf)=NaN and a 0/0 scale that poisoned the output.
This is reachable in ATOM's global-axis chunked prefill: a short seq can
fall entirely outside a chunk. Guard both_empty: force a finite 0/0-split
so out=0 (correct for empty attention) and keep lse=-inf. The call-site
nan_to_num is now redundant and reverts to chunked_lse = suf_lse.

Verified GSM8K (R1-MXFP4, tp4, fp8, num_concurrent=64, long-prefill 512):
0.9431 — same as the call-site workaround, no regression.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
(cherry picked from commit 591dfef)
HCA prefill used the per-seq committed count (ctx_end//128) for every
token, missing the (pos+1)//128 per-token causal cap that CSA already has
(and that the reference get_compress_topk_idxs applies). Under chunked
prefill ctx_end is the chunk's end, so the same logical token saw a
different number of HCA compressed groups depending on which chunk
computed it -> chunked != single-shot -> ~0.02 GSM8K drop.

Cap HCA per-token visibility to min((pos+1)//128, n_committed_hca) in the
indptr build, the prefill-indices kernel (new HCA_RATIO constexpr), and
the reference impl. Decode is unaffected (decode token is at seq end, the
cap is a no-op).

Verified GSM8K (V4-Pro, num_concurrent=4, fp8): chunked 0.93 -> 0.9507,
single-shot 0.9515 (no regression).

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
(cherry picked from commit 0ac25d7)
Pre-existing one-line while condition that black normalizes; keeps the tree
black-clean (CI enforced). No behavior change.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
With V4 chunked prefill enabled (post-rebase), a mixed batch can contain a
partial-chunk prefill row alongside decode rows. Two bugs caused a GPU
memory-access fault (located via kernel-marker bisection — fault was in the
decode segment's swa_write):

1. prepare_mixed: prepare_decode READS var["cu_seqlens_q"] but never writes
   it — the normal caller (ModelRunner.prepare_inputs) sets it for the whole
   batch. For a mixed batch that buffer holds the full [prefill|decode]
   cumulative seqlens, so the decode rows were offset by n_prefill_tokens.
   swa_write then derived src_id past the end of the 31-token decode kv ->
   OOB. Reset var["cu_seqlens_q"] to decode-local cumulative seqlens (1 token
   per decode seq) before calling prepare_decode.

2. forward_impl mixed split: the per-segment re-entry swapped attn_metadata
   and is_prefill but not ctx.input_ids, so the hash-MoE _hash_topk (which
   reads ctx.input_ids) saw the full-batch ids while each segment processed a
   slice. Slice ctx.input_ids per segment and restore after.

Verified: V4-Pro chunked+mixed benchmark (ISL1024 chunk256 conc32) now
completes with no fault; 12 mixed batches with partial-chunk prefill rows
processed through all 62 layers. 36 scheduler tests pass.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
The mixed-batch builder's prefill half (prepare_prefill) issues async H2D
copies from SHARED pinned forward_vars buffers via _stage ->
copy_to_gpu(non_blocking=True). The subsequent GPU-side .clone()s only guard
against GPU buffer reuse; they are enqueued, not executed. The decode half
(prepare_decode) and the cu_seqlens_q reset then overwrite those same pinned
CPU buffers on the host thread without waiting for the prefill DMA to drain.

With a short prefill chunk the DMA completes before the overwrite, so the bug
was invisible in the ISL-1024/chunk-256 repro. With a long prefill chunk
(ISL 8192) the in-flight DMA window is wide enough that the host overwrite
races the copy -> prefill index tensors receive decode-half values ->
downstream tensor[idx] indexes out of range (GPU memory-access fault in
index_kernel_impl).

Diagnosis: adding any host sync (an .item() bounds-check probe) made the crash
vanish and an in-bounds assert never fired -> async race, fault kernel is the
victim. Fix: drain the current stream after prefill staging+cloning, before
decode reuses the pinned buffers. Mixed is eager and rare; one sync per mixed
batch is acceptable.

Verified: V4-Pro ISL8192/chunk2048 conc64 -> 19 mixed batches, 0 faults,
128/128 requests successful (was: Memory access fault on first long-prefill
mixed batch).

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
Cherry-picks the long_prefill_token_threshold logic from
fpz/chunk_long_prefill (only the threshold parts; skips the unrelated
partial-prefill debug-dump helpers in that WIP commit). The flag caps a
single request's per-step prefill chunk in both the Phase-1 resume and
Phase-2 admission paths, so a long prefill no longer saturates
max_num_batched_tokens — leaving budget for decode rows to interleave.

Without this, a conc-2048 / ISL-8192 run packed every step full of
prefill and crashed before any decode (and thus any mixed batch) could
form. With --long-prefill-token-threshold 1792 the scheduler produces
proper "1 prefill (<=1792 tok) + 255 decode" mixed batches and the run
clears the prior crash point (1800 prefill + 13 mixed batches, 0 faults).

Also disables deferred output when --enable-mixed-prefill-decode is on:
the deferred GPU-gather path for mixed is a known-buggy follow-up
(placeholder accuracy bug + suspected scale-dependent corruption), so it
must stay off until fixed separately.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
A single ROCm fault dumps a 30-50 GB gpucore per rank; on an 8-GPU TP run
this fills the disk in seconds and triggers apport "execvp failed" noise.
The debug-agent wrapper already set `ulimit -c 0`; the normal launcher did
not, so a production fault risked filling the disk. Match the debug path.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
Verified deferred+mixed is crash-free at conc 2048 / ISL 8192 (deferred
forced on -> 4 mixed batches, 0 faults). The earlier "suspected
scale-dependent corruption" note was a misdiagnosis: the conc-2048 crash
was the missing long_prefill_token_threshold (prefill saturation), not
deferred. Deferred's only confirmed issue is the placeholder accuracy bug.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
… style)

ATOM scheduled prefill-first (resume partial prefills -> admit new prefills ->
give decode the leftover budget), so a 1792-token prefill chunk consumed the
whole max_num_batched_tokens budget and decode could only sneak in during the
brief window when prefill was nearly drained. Result: at ISL 4096 / conc 256
only 2 mixed batches formed out of ~1120 steps (PREFILL x494 -> MIXED x2 ->
DECODE x500).

Mirror vLLM V1's running-before-waiting ordering via a decode-budget
reservation: when --enable-mixed-prefill-decode is on, reserve the in-flight
decodes' token budget (n_decode * (mtp_k+1)) up front, so the prefill phases
only spend `max_num_batched_tokens - reserve`. The decode phase then consumes
the reserved remainder from the full budget. Net effect equals decode-first for
mixed-batch formation, while keeping ATOM's existing phase bodies, PD/MTP/
preempt/carryover logic, and prefill-first batch layout untouched (low risk).
long_prefill_token_threshold reverts to a plain per-chunk cap, no longer a
prefill/decode ratio knob.

Flag-off => reserve 0 => prefill_budget == max_num_batched_tokens =>
byte-identical to the old behavior.

Verified at ISL 4096 / conc 256 (CUDAGraph): mixed batches now dominate the
steady state (PREFILL x3 -> MIXED x134+, e.g. "2 prefill + 114 decode,
1934+114 tokens"), decode count grows as requests flow from prefill into
decode, 0 faults. Unit tests 39/39 (3 new: mixed forms when decode in flight,
prefill-first layout, flag-off stays prefill-only).

Also drop the per-step req_ids tuple from the schedule log (too long at high
concurrency) and the pure-decode log line.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
Deferred output is auto-disabled under the mixed flag because of a known
accuracy bug (placeholder token across a prefill chunk). That bug is NOT a
crash — verified crash-free at conc 2048 / ISL 8192 — and deferred is faster,
so ATOM_FORCE_DEFERRED=1 re-enables it to measure mixed+deferred throughput
while the accuracy fix is pending. Not for accuracy-sensitive runs.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
The mixed prefill+decode split re-enters forward_impl once as pure prefill and
once as pure decode, so a kineto trace only showed inner prefill[...] /
decode[...] tags and could not distinguish a real pure-prefill/pure-decode step
from the two segments of a mixed step. Wrap the whole dispatch in
record_function("mixed[n_p=.. n_d=..]") so the steady-state mixed batches are
identifiable in the trace. No behavior change.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant