Fpz/mixed mla dispatch#1271
Draft
jiayyu wants to merge 22 commits into
…ill-decode (P2-M1)
Adds a config flag (default off) that lets the scheduler pack prefill
chunks and decode seqs into the same ScheduledBatch — Phase 2 of chunked
prefill. With the flag off, behavior is unchanged: prefill-only or
decode-only batches as before.
- atom/config.py: new enable_mixed_prefill_decode field (default False)
- atom/model_engine/arg_utils.py: --enable-mixed-prefill-decode CLI flag
- atom/model_engine/scheduler.py:
  - ScheduledBatch.is_mixed = both prefill and decode rows present
  - ScheduledBatch.num_prompt_tokens for runner-side final-chunk detection
  - schedule() fall-through path: when flag is on, decode loop runs after
    Phase 2 prefill against a shared num_batched_tokens budget and
    shared max_num_seqs slot count
  - decode_carryover + decode_scheduled lists preserve FCFS ordering
    when seqs are popped from running but not picked this step
  - is_partial_prefill stays False whenever any decode row is present
- tests/test_scheduler.py: 6 new TestMixedBatch cases covering flag-off
  back-compat, mixed batch production, is_partial_prefill semantics,
  no double-scheduling, and pure-prefill/pure-decode fallbacks
Backends still raise on is_mixed batches; the flag is meaningful only
after the MHA/MLA Split Dispatch lands in P2-M2/M3.
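A minimal sketch of the shared-budget packing described above, not ATOM's actual scheduler. The names `Seq`, `Batch`, and `pack_mixed` are hypothetical, and the real schedule() also handles preemption, PD, and MTP, which are omitted here. It only illustrates prefill chunks and decode rows drawing from one token budget and one slot count, with unpicked decode seqs kept in FCFS order.

```python
from dataclasses import dataclass, field


@dataclass
class Seq:
    req_id: str
    remaining_prefill: int = 0  # 0 means the seq is already decoding


@dataclass
class Batch:
    prefill: list = field(default_factory=list)  # (req_id, chunk_len) pairs
    decode: list = field(default_factory=list)   # req_ids, one token each

    @property
    def is_mixed(self):
        # Mixed means both prefill and decode rows are present.
        return bool(self.prefill) and bool(self.decode)


def pack_mixed(waiting, running, max_tokens, max_seqs):
    """Pack prefill chunks first, then decode rows, against shared limits."""
    batch = Batch()
    budget, slots = max_tokens, max_seqs
    for seq in waiting:
        if budget == 0 or slots == 0:
            break
        chunk = min(seq.remaining_prefill, budget)
        batch.prefill.append((seq.req_id, chunk))
        budget -= chunk
        slots -= 1
    carryover = []  # decode seqs not picked this step, in FCFS order
    for seq in running:
        if budget > 0 and slots > 0:
            batch.decode.append(seq.req_id)
            budget -= 1
            slots -= 1
        else:
            carryover.append(seq)
    return batch, carryover
```

With one 6-token prefill waiting, three decode seqs running, and a budget of 8 tokens, the sketch schedules the prefill plus two decode rows and carries the third decode seq over to the next step.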
…-M4)
prepare_input_ids now writes both the prefill and decode token regions when batch.is_mixed, instead of early-returning the prefill-only slice.
ParallelLMHead.forward gains a mixed-mode branch that gathers last-token-per-prefill-seq for the prefill rows and keeps every token for the decode rows, so the sampler sees one logits row per sequence (layout: [prefill_last × n_p_seqs | decode_tokens]).
Context gains num_prefill_seqs alongside the existing num_prefill_tokens so the lm_head gather skips the per-step CPU sync.
MTP / deferred-output integration with mixed batches is not yet implemented; it is guarded by explicit asserts and a NotImplementedError at the spec_decode_metadata site so misconfiguration fails loud.
DP get_next_batch_info needs no change: mixed batches always carry prefill seqs, so the existing eligible_waiting branch already returns is_prefill=True for cross-rank dummy-prefill sync.
Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
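The [prefill_last × n_p_seqs | decode_tokens] layout can be illustrated with plain index arithmetic. This is a sketch under assumptions: `mixed_logit_rows` is a hypothetical helper, `cu_seqlens_q` is taken to be the cumulative query lengths of the prefill seqs only, and the real gather runs on GPU tensors rather than Python lists.

```python
def mixed_logit_rows(cu_seqlens_q, num_prefill_tokens, num_decode_tokens):
    """Return the token-row indices the lm_head keeps for a mixed batch.

    cu_seqlens_q: cumulative query lengths of the prefill seqs, starting at 0.
    """
    # Last token of each prefill seq sits one row before its cumulative end.
    prefill_last = [end - 1 for end in cu_seqlens_q[1:]]
    # Decode rows follow the prefill region; every decode token is kept.
    decode_rows = list(
        range(num_prefill_tokens, num_prefill_tokens + num_decode_tokens)
    )
    return prefill_last + decode_rows
```

For two prefill seqs of 3 and 5 tokens followed by two decode tokens, the kept rows are 2 and 7 (prefill last tokens) and then 8 and 9 (decode), giving one logits row per sequence.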
… (P2-M5)
TBO ubatch splitting assumes a uniform-layout batch (all prefill or all decode rows); the mixed [prefill | decode] layout would need a custom ubatch split, which isn't implemented yet. Disable TBO when batch.is_mixed so we don't silently produce wrong slices.
Document mixed-batch semantics in scheduler.get_next_batch_info: mixed batches report is_prefill=True (they always carry a prefill seq), so the existing dummy-prefill DP sync path in engine_core handles them without additional change.
Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
Implements the attention split dispatch that lets a mixed batch (prefill chunks + decode seqs packed together by P2-M1) actually run on dense MLA (DeepSeek-V2/V3). Gated by --enable-mixed-prefill-decode (default off); prefill-only and decode-only fast paths are unchanged.
- forward_context: AttentionMetaData gains nested prefill_attn_metadata / decode_attn_metadata sub-fields for split dispatch.
- aiter_mla: prepare_mixed() builds both halves: it reuses prepare_prefill and clones its tensors off the shared forward_vars buffers before the decode half rewrites them; the decode half is rebuilt for n_decode_seqs. Surfaces prefill cu_seqlens_q on the top-level metadata for the lm_head gather.
- attention_mla: forward_impl is_mixed branch splits q/k/v at num_prefill_tokens; prefill rows take the MHA path, decode rows the MLA latent path, and the outputs are concatenated.
- embed_head: check is_mixed BEFORE is_prefill (mixed batches also have is_prefill=True) so the combined gather runs.
- model_runner: auto-disable deferred output when the mixed flag is on (the [prefill|decode] input-id layout isn't wired into deferred yet).
- guards: backends base prepare_mixed, V4 builder, and MHA forward raise/assert on is_mixed so unsupported paths fail loud.
Verified on DeepSeek-R1-0528 (TP8, fp8 KV): 670 mixed batches scheduled during GSM8K 3-shot; flexible-extract 0.9469 vs 0.9477 baseline, within the noise band. 399 relevant unit tests pass with the flag off.
Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
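The split-at-num_prefill_tokens step can be sketched as follows. `split_dispatch`, `prefill_fn`, and `decode_fn` are hypothetical names standing in for the MHA and MLA latent paths; the real forward_impl operates on q/k/v tensors and per-segment metadata rather than Python lists.

```python
def split_dispatch(rows, num_prefill_tokens, prefill_fn, decode_fn):
    """Run prefill and decode rows through separate paths, then concatenate.

    rows is laid out as [prefill tokens | decode tokens].
    """
    prefill_rows = rows[:num_prefill_tokens]
    decode_rows = rows[num_prefill_tokens:]
    # Output keeps the input layout: prefill results first, then decode.
    return prefill_fn(prefill_rows) + decode_fn(decode_rows)
```

Because the two slices are disjoint and the outputs are concatenated in input order, each token's output row stays aligned with its input row.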
Replaces the NotImplementedError guard with a real prepare_mixed that
builds both per-segment metadata sets by reusing the validated
prepare_prefill / prepare_decode unmodified:
- prefill rows [0:n_p] via prepare_prefill (already correct — prefill is
first), then clone its shared forward_vars/_stage-aliasing tensors
(batch_id_per_token, n_committed_csa_per_seq, state_slot_mapping,
block_tables, cu_seqlens_q/k, context_lens, indexer_meta) before the
decode half overwrites them.
- decode rows [n_p:] via prepare_decode against a thin _MixedDecodeView
that slices only the fields prepare_decode reads.
- merged AttentionMetaData_DSV4 carries both as prefill_attn_metadata /
decode_attn_metadata plus is_mixed + num_{prefill,decode}_{tokens,seqs}
markers and merged positions for the shared full-tensor forward ops.
forward_impl still raises on is_mixed until P2-P4 land the per-segment
split; this lands the metadata layer first. Gated by
--enable-mixed-prefill-decode (default off).
Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
Adds the is_mixed branch to DeepseekV4Attention.forward_impl for Dense (ratio==0) layers: splits q_sa/kv/positions at num_prefill_tokens, runs the decode slice (swa_write BEFORE → sparse_attn_v4_paged_decode) and the prefill slice (sparse_attn_v4_paged_prefill → swa_write AFTER) against their own per-segment sub-metadata, then concatenates outputs. The non-mixed decode swa_write is gated `and not is_mixed`.
Asserts ratio==0 for now; CSA (P3) and HCA (P4) follow. Since every V4 model is hybrid (V4-Pro has only 1 Dense layer of 62), this is not yet runnable end-to-end on its own; it lands the Dense slice of the split. Gated by --enable-mixed-prefill-decode (default off).
Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
Supersedes the in-place Dense-only split (828bfd6, reverted): instead of forking forward_impl at every is_decode site, run the whole validated forward body once per segment.
Each sub-metadata (prefill_attn_metadata / decode_attn_metadata from prepare_mixed) is a complete self-consistent V4 metadata object, so for a mixed batch we point fc.attn_metadata + fc.context.is_prefill at each segment in turn, re-enter forward_impl as a pure prefill / pure decode forward, and concatenate outputs.
This covers Dense (ratio==0), CSA (ratio==4) and HCA (ratio==128) in one change because every per-ratio path (swa_write ordering, Indexer prefill-vs-decode dispatch, csa_translate_pack, sparse kernels, Compressor) falls out of re-running the existing non-mixed code. The non-mixed path is byte-for-byte unchanged.
Tradeoff: Q/KV projections + qk_norm_rope + output LoRA run twice (once per segment) rather than shared on the merged tensor. This is correct (disjoint tokens), with a fusion opportunity left for later. Gated by --enable-mixed-prefill-decode (default off).
Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
The V4 builder has no is_sparse attribute (that lives on the AiterMLA builder); the leftover `assert not self.is_sparse or True` placeholder raised AttributeError on the first mixed batch. Remove it. Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
The ParallelLMHead mixed-batch gather reads attn_metadata.cu_seqlens_q to locate per-prefill-seq last tokens; the merged V4 mixed metadata left it None, raising TypeError. Point it at the prefill sub-meta's cu_seqlens_q (same fix as the dense-MLA prepare_mixed). Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
The mixed builder cloned prefill per-fwd tensors off the shared
forward_vars buffers but missed compress_plans: each CompressPlan's
compress_plan_gpu / write_plan_gpu are views into the shared
var["v4_compress_plan_{ratio}"] / var["v4_write_plan_{ratio}"] buffers,
which the subsequent prepare_decode(_build_compress_plans) overwrites. The
prefill segment's Compressor then read the DECODE plan — whose ragged_id /
batch_id reference the decode segment's smaller kv / state_slot_mapping —
causing a MEMORY_VIOLATION in _update_compressor_states_kernel (located via
rocm-debug-agent). Clone both GPU tensors per CompressPlan.
Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
Removes the blanket auto-disable of deferred output when --enable-mixed-prefill-decode is on, and implements the deferred decode-input path for mixed batches in prepare_input_ids:
- decode region starts at decode_offset = total_tokens_prefill (prefill rows already written above);
- decode seqs present in prev_batch are gathered GPU-side from prev_token_ids by req_id, written at their decode-segment batch position (row-aligned with the decode attention metadata);
- decode seqs not in prev_batch fall back to scheduled_tokens;
- prev_batch / prev_token_ids stay advanced only in prepare_sampled_ids.
The change is confined to the is_mixed branch; the proven non-mixed prefill/decode/deferred paths are byte-for-byte unchanged. MTP + mixed remains asserted off.
Runs end-to-end without crashes on dense-MLA (R1) and V4-Pro sparse (GPU fault-free, mixed batches scheduled).
KNOWN ISSUE: a decode seq that idles across prefill-chunk steps (so it is not in the immediately-previous batch) currently falls back to scheduled_tokens, which under deferred output holds a placeholder rather than the seq's last real token, causing a partial GSM8K accuracy drop (R1 full: 0.87 vs 0.9469 non-deferred-mixed). Fix to follow; for accuracy-sensitive runs keep --enable-mixed-prefill-decode off or run without this path until the placeholder sourcing is corrected.
Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
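The gather-with-fallback logic can be sketched on the host side as below. `decode_input_ids` is a hypothetical helper and the arguments are plain lists and dicts; the commit performs the gather GPU-side. The fallback branch is the one the known issue refers to: under deferred output it may return a placeholder rather than the seq's last real token.

```python
def decode_input_ids(decode_req_ids, prev_req_ids, prev_token_ids, scheduled_tokens):
    """Pick one input token per decode seq, in decode-segment row order.

    prev_req_ids / prev_token_ids: row-aligned req ids and sampled tokens
    of the immediately previous batch.
    scheduled_tokens: dict req_id -> token the scheduler recorded.
    """
    prev_row = {rid: i for i, rid in enumerate(prev_req_ids)}
    out = []
    for rid in decode_req_ids:
        if rid in prev_row:
            # Seq ran in the previous step: take its freshly sampled token.
            out.append(prev_token_ids[prev_row[rid]])
        else:
            # Seq idled last step: fall back (may be a placeholder, see above).
            out.append(scheduled_tokens[rid])
    return out
```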
…round
The earlier MLA NaN fix sanitized the seed lse at the call site (nan_to_num on suf_lse). Now that merge_attn_states lives in ATOM (triton_merge_attn_states.py, sole caller is the MLA chunked path; plugin/vllm imports vllm's own copy), fix it at the root instead.
When a token's prefix AND suffix are both empty (max_lse == -inf), the kernel computed -inf-(-inf)=NaN and a 0/0 scale that poisoned the output. This is reachable in ATOM's global-axis chunked prefill: a short seq can fall entirely outside a chunk.
Guard both_empty: force a finite 0/0-split so out=0 (correct for empty attention) and keep lse=-inf. The call-site nan_to_num is now redundant and reverts to chunked_lse = suf_lse.
Verified GSM8K (R1-MXFP4, tp4, fp8, num_concurrent=64, long-prefill 512): 0.9431, same as the call-site workaround, no regression.
Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
(cherry picked from commit 591dfef)
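A scalar sketch of the guard, assuming the usual log-sum-exp merge of two partial attention results. `merge_two` is a hypothetical function and not the Triton kernel, which works per head and per token on tensors. Without the early return, `lse_p - max_lse` would evaluate to `-inf - (-inf)`, which is NaN.

```python
import math


def merge_two(out_p, lse_p, out_s, lse_s):
    """Merge prefix and suffix attention partials by their log-sum-exp weights."""
    max_lse = max(lse_p, lse_s)
    if max_lse == -math.inf:
        # Both segments are empty: attention over nothing is 0, lse stays -inf.
        return 0.0, -math.inf
    w_p = math.exp(lse_p - max_lse)  # exp(-inf) == 0.0 for an empty side
    w_s = math.exp(lse_s - max_lse)
    total = w_p + w_s
    out = (out_p * w_p + out_s * w_s) / total
    return out, max_lse + math.log(total)
```

When only one side is empty its weight is exactly zero, so the result equals the non-empty side; the guard is needed only when both sides are empty.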
HCA prefill used the per-seq committed count (ctx_end//128) for every token, missing the (pos+1)//128 per-token causal cap that CSA already has (and that the reference get_compress_topk_idxs applies). Under chunked prefill ctx_end is the chunk's end, so the same logical token saw a different number of HCA compressed groups depending on which chunk computed it -> chunked != single-shot -> ~0.02 GSM8K drop.
Cap HCA per-token visibility to min((pos+1)//128, n_committed_hca) in the indptr build, the prefill-indices kernel (new HCA_RATIO constexpr), and the reference impl. Decode is unaffected (decode token is at seq end, the cap is a no-op).
Verified GSM8K (V4-Pro, num_concurrent=4, fp8): chunked 0.93 -> 0.9507, single-shot 0.9515 (no regression).
Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
(cherry picked from commit 0ac25d7)
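The per-token cap is small enough to check by hand. The sketch below uses a hypothetical `hca_visible_groups` helper that mirrors the formula in the commit message; the real change lives in the indptr build and the prefill-indices kernel.

```python
HCA_RATIO = 128  # tokens per HCA compressed group, per the commit message


def hca_visible_groups(pos, n_committed_hca, ratio=HCA_RATIO):
    """Number of compressed groups a token at position pos may attend to."""
    # Causal cap: only groups fully completed at or before pos are visible.
    return min((pos + 1) // ratio, n_committed_hca)
```

For a token at position 200, a single-shot 512-token prefill has 4 committed groups and a first 256-token chunk has 2, yet the capped count is 1 in both cases. That equality is what makes chunked and single-shot prefill agree.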
Pre-existing one-line while condition that black normalizes; keeps the tree black-clean (CI enforced). No behavior change. Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
With V4 chunked prefill enabled (post-rebase), a mixed batch can contain a partial-chunk prefill row alongside decode rows. Two bugs caused a GPU memory-access fault (located via kernel-marker bisection; the fault was in the decode segment's swa_write):
1. prepare_mixed: prepare_decode READS var["cu_seqlens_q"] but never writes it: the normal caller (ModelRunner.prepare_inputs) sets it for the whole batch. For a mixed batch that buffer holds the full [prefill|decode] cumulative seqlens, so the decode rows were offset by n_prefill_tokens. swa_write then derived src_id past the end of the 31-token decode kv -> OOB. Reset var["cu_seqlens_q"] to decode-local cumulative seqlens (1 token per decode seq) before calling prepare_decode.
2. forward_impl mixed split: the per-segment re-entry swapped attn_metadata and is_prefill but not ctx.input_ids, so the hash-MoE _hash_topk (which reads ctx.input_ids) saw the full-batch ids while each segment processed a slice. Slice ctx.input_ids per segment and restore after.
Verified: V4-Pro chunked+mixed benchmark (ISL1024 chunk256 conc32) now completes with no fault; 12 mixed batches with partial-chunk prefill rows processed through all 62 layers. 36 scheduler tests pass.
Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
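The offset in the first bug is easy to see with small numbers. Both helpers below are hypothetical and use plain lists in place of the shared forward_vars buffer; they only show how the full-batch cumulative seqlens differ from the decode-local ones that prepare_decode expects.

```python
def full_batch_cu_seqlens(prefill_lens, n_decode_seqs):
    """Cumulative query lengths over the whole [prefill | decode] batch."""
    cu = [0]
    for n in list(prefill_lens) + [1] * n_decode_seqs:
        cu.append(cu[-1] + n)
    return cu


def decode_local_cu_seqlens(n_decode_seqs):
    """Cumulative query lengths for the decode segment alone (1 token per seq)."""
    return list(range(n_decode_seqs + 1))
```

With one 256-token prefill chunk and three decode seqs, the decode tail of the full-batch array is 256, 257, 258, 259 while the decode-local array is 0, 1, 2, 3. Indexing a decode-only kv tensor with the former runs past its end.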
The mixed-batch builder's prefill half (prepare_prefill) issues async H2D copies from SHARED pinned forward_vars buffers via _stage -> copy_to_gpu(non_blocking=True). The subsequent GPU-side .clone()s only guard against GPU buffer reuse; they are enqueued, not executed. The decode half (prepare_decode) and the cu_seqlens_q reset then overwrite those same pinned CPU buffers on the host thread without waiting for the prefill DMA to drain.
With a short prefill chunk the DMA completes before the overwrite, so the bug was invisible in the ISL-1024/chunk-256 repro. With a long prefill chunk (ISL 8192) the in-flight DMA window is wide enough that the host overwrite races the copy -> prefill index tensors receive decode-half values -> downstream tensor[idx] indexes out of range (GPU memory-access fault in index_kernel_impl).
Diagnosis: adding any host sync (an .item() bounds-check probe) made the crash vanish and an in-bounds assert never fired -> async race, fault kernel is the victim.
Fix: drain the current stream after prefill staging+cloning, before decode reuses the pinned buffers. Mixed is eager and rare; one sync per mixed batch is acceptable.
Verified: V4-Pro ISL8192/chunk2048 conc64 -> 19 mixed batches, 0 faults, 128/128 requests successful (was: Memory access fault on first long-prefill mixed batch).
Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
Cherry-picks the long_prefill_token_threshold logic from fpz/chunk_long_prefill (only the threshold parts; skips the unrelated partial-prefill debug-dump helpers in that WIP commit).
The flag caps a single request's per-step prefill chunk in both the Phase-1 resume and Phase-2 admission paths, so a long prefill no longer saturates max_num_batched_tokens, leaving budget for decode rows to interleave. Without this, a conc-2048 / ISL-8192 run packed every step full of prefill and crashed before any decode (and thus any mixed batch) could form. With --long-prefill-token-threshold 1792 the scheduler produces proper "1 prefill (<=1792 tok) + 255 decode" mixed batches and the run clears the prior crash point (1800 prefill + 13 mixed batches, 0 faults).
Also disables deferred output when --enable-mixed-prefill-decode is on: the deferred GPU-gather path for mixed is a known-buggy follow-up (placeholder accuracy bug + suspected scale-dependent corruption), so it must stay off until fixed separately.
Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
A single ROCm fault dumps a 30-50 GB gpucore per rank; on an 8-GPU TP run this fills the disk in seconds and triggers apport "execvp failed" noise. The debug-agent wrapper already set `ulimit -c 0`; the normal launcher did not, so a production fault risked filling the disk. Match the debug path. Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
Verified deferred+mixed is crash-free at conc 2048 / ISL 8192 (deferred forced on -> 4 mixed batches, 0 faults). The earlier "suspected scale-dependent corruption" note was a misdiagnosis: the conc-2048 crash was the missing long_prefill_token_threshold (prefill saturation), not deferred. Deferred's only confirmed issue is the placeholder accuracy bug. Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
… style)
ATOM scheduled prefill-first (resume partial prefills -> admit new prefills -> give decode the leftover budget), so a 1792-token prefill chunk consumed the whole max_num_batched_tokens budget and decode could only sneak in during the brief window when prefill was nearly drained. Result: at ISL 4096 / conc 256 only 2 mixed batches formed out of ~1120 steps (PREFILL x494 -> MIXED x2 -> DECODE x500).
Mirror vLLM V1's running-before-waiting ordering via a decode-budget reservation: when --enable-mixed-prefill-decode is on, reserve the in-flight decodes' token budget (n_decode * (mtp_k+1)) up front, so the prefill phases only spend `max_num_batched_tokens - reserve`. The decode phase then consumes the reserved remainder from the full budget. Net effect equals decode-first for mixed-batch formation, while keeping ATOM's existing phase bodies, PD/MTP/preempt/carryover logic, and prefill-first batch layout untouched (low risk).
long_prefill_token_threshold reverts to a plain per-chunk cap, no longer a prefill/decode ratio knob. Flag-off => reserve 0 => prefill_budget == max_num_batched_tokens => byte-identical to the old behavior.
Verified at ISL 4096 / conc 256 (CUDAGraph): mixed batches now dominate the steady state (PREFILL x3 -> MIXED x134+, e.g. "2 prefill + 114 decode, 1934+114 tokens"), decode count grows as requests flow from prefill into decode, 0 faults. Unit tests 39/39 (3 new: mixed forms when decode in flight, prefill-first layout, flag-off stays prefill-only).
Also drop the per-step req_ids tuple from the schedule log (too long at high concurrency) and the pure-decode log line.
Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
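The reservation arithmetic can be sketched as a single function. `prefill_budget` is a hypothetical name, and the clamp to zero is an assumption of this sketch that the commit message does not describe. The value 2048 for max_num_batched_tokens is also assumed; it is consistent with the logged "1934+114 tokens" example but is not stated in the source.

```python
def prefill_budget(max_num_batched_tokens, n_decode, mtp_k, mixed_enabled):
    """Tokens the prefill phases may spend after reserving decode's share."""
    # Each in-flight decode seq needs mtp_k + 1 tokens; flag off reserves 0.
    reserve = n_decode * (mtp_k + 1) if mixed_enabled else 0
    # Clamp at zero (assumption) so heavy decode load cannot go negative.
    return max(max_num_batched_tokens - reserve, 0)
```

With 114 in-flight decodes and no MTP, an assumed 2048-token budget leaves 1934 tokens for prefill; with the flag off the full budget goes to prefill, matching the old behavior.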
Deferred output is auto-disabled under the mixed flag because of a known accuracy bug (placeholder token across a prefill chunk). That bug is NOT a crash — verified crash-free at conc 2048 / ISL 8192 — and deferred is faster, so ATOM_FORCE_DEFERRED=1 re-enables it to measure mixed+deferred throughput while the accuracy fix is pending. Not for accuracy-sensitive runs. Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
The mixed prefill+decode split re-enters forward_impl once as pure prefill and
once as pure decode, so a kineto trace only showed inner prefill[...] /
decode[...] tags and could not distinguish a real pure-prefill/pure-decode step
from the two segments of a mixed step. Wrap the whole dispatch in
record_function("mixed[n_p=.. n_d=..]") so the steady-state mixed batches are
identifiable in the trace. No behavior change.
Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>