diff --git a/atom/config.py b/atom/config.py index d9b582d601..446a54ef80 100644 --- a/atom/config.py +++ b/atom/config.py @@ -981,6 +981,7 @@ class Config: model: str trust_remote_code: bool = False max_num_batched_tokens: int = 16384 + long_prefill_token_threshold: int = 0 attn_prefill_chunk_size: int = 16384 scheduler_delay_factor: float = 0.0 max_num_seqs: int = 512 @@ -999,6 +1000,11 @@ class Config: kv_cache_dtype: str = "bf16" enable_prefix_caching: bool = True enable_chunked_prefill: bool = True + # Mix prefill chunks and decode seqs into the same forward pass (Phase 2 + # of chunked prefill). Default off until the attention backends grow + # split-dispatch support — when off, scheduler emits prefill-only or + # decode-only batches as before. + enable_mixed_prefill_decode: bool = False port: int = 8006 torch_profiler_dir: str | None = field( default_factory=lambda: envs.ATOM_TORCH_PROFILER_DIR @@ -1104,6 +1110,19 @@ def __post_init__(self): self.max_model_len, hf_config_max_position_embeddings ) # assert self.max_num_batched_tokens >= self.max_model_len + if self.long_prefill_token_threshold > 0: + if self.long_prefill_token_threshold > self.max_model_len: + raise ValueError( + f"long_prefill_token_threshold " + f"({self.long_prefill_token_threshold}) cannot be greater " + f"than max_model_len ({self.max_model_len})." + ) + if self.long_prefill_token_threshold < self.kv_cache_block_size: + raise ValueError( + f"long_prefill_token_threshold " + f"({self.long_prefill_token_threshold}) must be >= " + f"kv_cache_block_size ({self.kv_cache_block_size})." + ) if not is_plugin_mode(): if self.torch_profiler_dir is not None: os.makedirs(self.torch_profiler_dir, exist_ok=True) diff --git a/atom/model_engine/arg_utils.py b/atom/model_engine/arg_utils.py index 0cba4b2e33..97d6b38570 100644 --- a/atom/model_engine/arg_utils.py +++ b/atom/model_engine/arg_utils.py @@ -33,11 +33,13 @@ class EngineArgs: data_parallel_size: int = 1 enforce_eager: bool = False enable_prefix_caching: bool = True + enable_mixed_prefill_decode: bool = False port: int = 8006 kv_cache_dtype: str = "bf16" block_size: int = 16 max_model_len: Optional[int] = None max_num_batched_tokens: int = 16384 + long_prefill_token_threshold: int = 0 attn_prefill_chunk_size: int = 16384 enable_chunked_prefill: bool = True scheduler_delay_factor: float = 0.0 @@ -97,6 +99,13 @@ def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: help="Enable prefix caching (default: enabled). " "Use --no-enable_prefix_caching to disable.", ) + parser.add_argument( + "--enable-mixed-prefill-decode", + action="store_true", + help="Pack prefill chunks and decode seqs into the same forward " + "pass. Requires attention backends with split-dispatch support; " + "off by default.", + ) parser.add_argument( "--port", type=int, @@ -192,6 +201,17 @@ def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: default=16384, help="Maximum number of tokens to batch together in async engine", ) + parser.add_argument( + "--long-prefill-token-threshold", + type=int, + default=0, + help=( + "For chunked prefill, cap a single request's per-step prefill " + "size at this many tokens. 0 disables the cap (request is only " + "bounded by max_num_batched_tokens). Useful to interleave long " + "prefills with decode for lower ITL." + ), + ) parser.add_argument( "--attn-prefill-chunk-size", type=int, diff --git a/atom/model_engine/model_runner.py b/atom/model_engine/model_runner.py index db018399e2..f9054bd22b 100644 --- a/atom/model_engine/model_runner.py +++ b/atom/model_engine/model_runner.py @@ -90,7 +90,26 @@ def __init__( num_spec_tokens: int = 0, ): """Asynchronously copy the sampled_token_ids tensor to the host.""" - self.is_deferred_out = True + # Deferred output is disabled when running in P/D disaggregation mode + # (kv_transfer_config is set), enabled otherwise. + # TODO: In P/D disaggregation mode, if have issue, we can disable it + # Mixed prefill+decode: the deferred GPU-gather path has a known + # accuracy bug (idle decode seqs read a placeholder token across a + # prefill chunk — R1 GSM8K 0.87 vs 0.9469). It is NOT a memory bug: + # verified crash-free at conc 2048 / ISL 8192 with deferred forced on. + # Disable deferred under the mixed flag until the accuracy bug is fixed. + self.is_deferred_out = not getattr( + runner.config, "enable_mixed_prefill_decode", False + ) + # Escape hatch: ATOM_FORCE_DEFERRED=1 forces deferred output ON even + # under the mixed flag. Deferred+mixed has a known accuracy bug (not a + # crash — verified crash-free at conc 2048 / ISL 8192) but is faster, so + # this lets us measure mixed+deferred throughput while the accuracy fix + # is pending. Do NOT use for accuracy-sensitive runs. + import os as _os + + if _os.environ.get("ATOM_FORCE_DEFERRED") == "1": + self.is_deferred_out = True self.runner = runner device = runner.device @@ -355,6 +374,7 @@ def prepare_input_ids( total_tokens_prefill = batch.total_tokens_num_prefill total_tokens_decode = batch.total_tokens_num_decode total_reqs_prefill = batch.total_seqs_num_prefill + is_mixed = getattr(batch, "is_mixed", False) """for prefill: all input ids are new""" self.input_ids.np[:total_tokens_prefill] = scheduled_tokens[ :total_tokens_prefill @@ -363,6 +383,75 @@ def prepare_input_ids( self.prev_rejected_num, self.prev_bonus_num = self.recv_mtp_status_async() + if is_mixed: + # Mixed batch layout: [prefill_tokens | decode_tokens]. The prefill + # region is already written above. Fill the decode region (one token + # per decode seq, in batch order — which matches the decode attention + # metadata's row order) starting at `decode_offset`. + # MTP / speculative decode with mixed batches is a separate follow-up + # (the per-seq multi-token layout isn't wired into this branch). + assert not self.use_spec, ( + "Mixed prefill+decode batches do not yet support MTP / speculative " + "decode (follow-up). Disable --enable-mixed-prefill-decode for now." + ) + decode_offset = total_tokens_prefill + sched_decode = scheduled_tokens[ + decode_offset : decode_offset + total_tokens_decode + ] + + # Non-deferred OR first step (no prior batch to gather from): decode + # inputs come straight from scheduled_tokens. This is the path + # already verified at GSM8K parity on R1 / V4-Pro. + if not self.is_deferred_out or self.prev_batch is None: + self.input_ids.np[ + decode_offset : decode_offset + total_tokens_decode + ] = sched_decode + self.input_ids.copy_to_gpu(total_tokens) + return self.input_ids.gpu[:total_tokens] + + # Deferred path: each decode seq's input is the token sampled for it + # last step, kept on-GPU in `prev_token_ids` (ordered by + # prev_batch.req_ids). Map current decode seqs — batch positions + # [n_prefill_seqs:] — to their prev_batch slot. A decode seq is + # ALWAYS in prev_batch in steady state (it decoded last step); a + # genuinely-new decode row (just finished prefill elsewhere) falls + # back to scheduled_tokens. Prefill rows are never gathered. + # + # Destination index = decode_offset + i, where i is the decode seq's + # position within the decode segment — identical to the decode + # attention metadata row order, guaranteeing alignment. + n_prefill_seqs = batch.total_seqs_num_prefill + prev_id_to_idx = {rid: j for j, rid in enumerate(self.prev_batch.req_ids)} + deferred_dst: list[int] = [] + deferred_prev: list[int] = [] + for i, rid in enumerate(batch.req_ids[n_prefill_seqs:]): + prev_idx = prev_id_to_idx.get(rid) + if prev_idx is not None: + deferred_dst.append(decode_offset + i) + deferred_prev.append(prev_idx) + + # Baseline the whole decode region from scheduled_tokens (correct for + # any new decode seq not in prev_batch). Deferred positions are then + # overwritten GPU-side by the gather below. + self.input_ids.np[decode_offset : decode_offset + total_tokens_decode] = ( + sched_decode + ) + self.input_ids.copy_to_gpu(total_tokens) + + if deferred_dst: + self.input_ids_loc.np[: len(deferred_prev)] = deferred_prev + prev_idx_gpu = self.input_ids_loc.copy_to_gpu(len(deferred_prev)) + gathered = torch.gather(self.prev_token_ids, 0, prev_idx_gpu) + dst_gpu = torch.as_tensor( + deferred_dst, dtype=torch.long, device=self.input_ids.gpu.device + ) + self.input_ids.gpu[dst_gpu] = gathered.to(self.input_ids.gpu.dtype) + + # prev_batch / prev_token_ids are advanced by prepare_sampled_ids + # (postprocess) after sampling — NOT here, exactly like the non-mixed + # deferred path. + return self.input_ids.gpu[:total_tokens] + # TODO: remove this when we support mixed prefill and decode in one batch if total_reqs_prefill > 0: return self.input_ids.gpu[:total_tokens_prefill] @@ -1660,6 +1749,10 @@ def _maybe_create_tbo_slices( With the packed-reduce path the eligibility (local + cross-DP AND) is decided in ``_preprocess``; here we just realise the split. """ + if getattr(batch, "is_mixed", False): + # TBO ubatch splitting on a [prefill | decode] layout is not yet + # supported (P2-M5 follow-up). Run mixed batches without TBO. + return None if not tbo_collective_active: return None @@ -1758,6 +1851,7 @@ def _preprocess( def prepare_inputs(self, batch: ScheduledBatch, input_ids: torch.Tensor = None): is_prefill = batch.total_tokens_num_prefill > 0 + is_mixed = getattr(batch, "is_mixed", False) bs = batch.total_seqs_num num_scheduled_tokens = np.asarray(batch.num_scheduled_tokens) cu_seqlens_q, arange = self._get_cumsum_and_arange(num_scheduled_tokens) @@ -1814,11 +1908,19 @@ def prepare_inputs(self, batch: ScheduledBatch, input_ids: torch.Tensor = None): graph_bs=graph_bs, dp_uniform_decode=dp_uniform_decode, forward_mode=forward_mode, + is_mixed=is_mixed, + num_prefill_tokens=batch.total_tokens_num_prefill if is_mixed else 0, + num_prefill_seqs=batch.total_seqs_num_prefill if is_mixed else 0, ) actual_num_tokens = batch.total_tokens_num spec_decode_metadata = None + if is_mixed and hasattr(self, "drafter") and not batch.is_dummy_run: + raise NotImplementedError( + "Mixed prefill+decode batches do not yet support MTP / speculative " + "decode (P2-M4 follow-up). Disable --enable-mixed-prefill-decode." + ) if not is_prefill and hasattr(self, "drafter") and not batch.is_dummy_run: scheduled_bs = batch.total_seqs_num_decode spec_decode_metadata = self.drafter.calc_spec_decode_metadata( diff --git a/atom/model_engine/scheduler.py b/atom/model_engine/scheduler.py index dab098f1a9..fb02f065ed 100644 --- a/atom/model_engine/scheduler.py +++ b/atom/model_engine/scheduler.py @@ -332,6 +332,16 @@ def __init__( self.total_seqs_num_prefill = total_seqs_num_prefill self.total_seqs_num_decode = total_seqs_num_decode + # True iff this batch packs at least one prefill chunk together with + # at least one decode seq. Consumed by attention backends to dispatch + # the prefill rows and decode rows to different kernels (Phase 2 of + # chunked prefill — see docs/mixed_batch_design.md). + self.is_mixed = total_seqs_num_prefill > 0 and total_seqs_num_decode > 0 + # Per-row prompt length, aligned with `req_ids`. Used by the runner + # to decide which prefill rows are "final chunks" (need logits) vs + # intermediate chunks (skip compute_logits). + self.num_prompt_tokens = [seq.num_prompt_tokens for seq in seqs.values()] + self.connector_meta_output = connector_meta_output self.finished_recving_kv_req_ids: list[int] = [] @@ -415,6 +425,7 @@ class Scheduler: def __init__(self, config: Config): self.max_num_seqs = config.max_num_seqs self.max_num_batched_tokens = config.max_num_batched_tokens + self.long_prefill_token_threshold = config.long_prefill_token_threshold self.max_model_len = config.max_model_len self.bos_token_id = config.bos_token_id self.eos_token_id = config.eos_token_id @@ -453,6 +464,9 @@ def __init__(self, config: Config): CacheStats() if config.enable_prefix_caching else None ) self.enable_chunked_prefill = config.enable_chunked_prefill + self.enable_mixed_prefill_decode = getattr( + config, "enable_mixed_prefill_decode", False + ) # Number of running seqs currently mid-prefill (per-seq state lives in # `Sequence.is_partial_prefill`). Maintained as a counter so Phase 1 # of `schedule()` can skip the running-queue scan entirely on @@ -708,6 +722,27 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: if not self.running and not self.waiting: return None + # ---- Decode-first budget reservation (vLLM V1 style) ---- + # vLLM schedules running decodes BEFORE new prefills within one shared + # token budget, so a long prefill chunk can never starve decode out of + # the step and mixed batches form naturally. ATOM keeps its existing + # prefill-first phase bodies (lower risk), but reserves the in-flight + # decodes' token budget up front so the prefill phases below only spend + # `max_num_batched_tokens - decode_token_reserve`. The decode phase then + # consumes the reserved remainder from the full budget. Net effect is + # identical to decode-first for mixed-batch formation. Only active when + # mixed batching is enabled; flag-off => reserve 0 => byte-identical to + # the old prefill-first behavior. + decode_token_reserve = 0 + if self.enable_mixed_prefill_decode: + n_decode_inflight = sum(1 for s in self.running if not s.is_partial_prefill) + n_decode_inflight = min(n_decode_inflight, self.max_num_seqs) + decode_token_reserve = min( + n_decode_inflight * (self.mtp_k + 1), + self.max_num_batched_tokens, + ) + prefill_budget = self.max_num_batched_tokens - decode_token_reserve + # ---- Phase 1: resume partial prefills from running ---- # Gated by `_delayer_allows_prefill` so cross-DP alignment still # holds when one rank is mid-chunked-prefill: a delayer veto skips @@ -722,7 +757,9 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: if not seq.is_partial_prefill: continue remaining = seq.num_prompt_tokens - seq.num_cached_tokens - budget_remaining = self.max_num_batched_tokens - num_batched_tokens + if 0 < self.long_prefill_token_threshold < remaining: + remaining = self.long_prefill_token_threshold + budget_remaining = prefill_budget - num_batched_tokens chunk = min(remaining, budget_remaining) if chunk <= 0: break @@ -738,7 +775,7 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: and (self.delay_factor <= 0 or self._passed_delay(time.time())) and self.waiting and num_seqs_prefill < self.max_num_seqs - and num_batched_tokens < self.max_num_batched_tokens + and num_batched_tokens < prefill_budget ): seq = self.waiting.popleft() @@ -821,7 +858,12 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: num_new_tokens = ( seq.num_tokens - num_cached_blocks * self.block_manager.block_size ) - budget_remaining = self.max_num_batched_tokens - num_batched_tokens + if ( + self.enable_chunked_prefill + and 0 < self.long_prefill_token_threshold < num_new_tokens + ): + num_new_tokens = self.long_prefill_token_threshold + budget_remaining = prefill_budget - num_batched_tokens if self.enable_chunked_prefill: chunk = min(num_new_tokens, budget_remaining) else: @@ -872,7 +914,10 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: total_tokens_num_prefill = sum(num_scheduled_tokens) - if num_seqs_prefill > 0: + # Prefill-only fast path: behavior identical to pre-mixed-batch days. + # When the mixed flag is off, we never pack decode rows alongside + # prefill chunks, so emit the prefill batch immediately. + if num_seqs_prefill > 0 and not self.enable_mixed_prefill_decode: num_cached_tokens_list = [ seq.num_cached_tokens for seq in scheduled_seqs.values() ] @@ -884,7 +929,6 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: f"req_ids: {tuple(scheduled_seqs.keys())}" ) self.prev_prompt = True - # lip: TODO for prefill/decode mixed batch connector_meta_output = None if self.kv_connector is not None: @@ -903,17 +947,38 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: scheduled_seqs, ) - # --- Decode scheduling --- + # --- Decode scheduling (also fall-through for mixed batches) --- + # Three queue states we must handle here when prefills were already + # scheduled this step: + # 1. Partial prefills resumed in Phase 1 are still in `running` + # (Phase 1 didn't pop them) and are in `scheduled_seqs`. + # 2. New prefills admitted in Phase 2 were appended to the back of + # `running` and are in `scheduled_seqs`. + # 3. Partial prefills *not* picked up this step (budget exhausted) + # remain in `running` with `is_partial_prefill=True` and must + # not be decoded — they can only advance via prefill. + # We `popleft` from `running` and route into `decode_scheduled` (real + # decodes this step) or `decode_carryover` (skipped / kept seqs). num_seqs_decode = 0 - num_decode_tokens = 0 tokens_per_decode_seq = self.mtp_k + 1 num_new_tokens = self.mtp_k + 1 remote_kv_blocks: set[int] = set() remote_kv_seq_blocks: dict[int, list[int]] = {} - while self.running and num_seqs_decode < self.max_num_seqs: - if num_decode_tokens + tokens_per_decode_seq > self.max_num_batched_tokens: - break + decode_carryover: list[Sequence] = [] + decode_scheduled: list[Sequence] = [] + while self.running and num_seqs_prefill + num_seqs_decode < self.max_num_seqs: seq = self.running.popleft() + if seq.id in scheduled_seqs: + # Already scheduled as a prefill chunk this step — keep its slot. + decode_carryover.append(seq) + continue + if seq.is_partial_prefill: + # Mid-prefill that didn't make it into Phase 1 this step. + decode_carryover.append(seq) + continue + if num_batched_tokens + tokens_per_decode_seq > self.max_num_batched_tokens: + decode_carryover.append(seq) + break while not self.block_manager.can_append(seq, num_new_tokens): if self.running: self.preempt(self.running.pop()) @@ -924,7 +989,7 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: if seq.spec_token_ids.size > 0: scheduled_spec_decode_tokens[seq.id] = seq.spec_token_ids num_seqs_decode += 1 - num_decode_tokens += num_new_tokens + num_batched_tokens += num_new_tokens # For PD first-decode: if T0 was injected, may_append is # needed for the new position N. Without T0 injection, # blocks were already allocated during prefill. @@ -953,12 +1018,38 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: scheduled_seqs[seq.id] = seq seq.type = SequenceType.DECODE num_scheduled_tokens.append(num_new_tokens) + decode_scheduled.append(seq) seq.is_first_decode = False - total_tokens_num_decode = sum(num_scheduled_tokens) + # Restore running queue order: carryover keeps its previous position, + # decoded seqs move to the front (FCFS), prefills appended in Phase 2 + # remain wherever they were placed. + if decode_carryover: + self.running.extendleft(reversed(decode_carryover)) + if decode_scheduled: + self.running.extendleft(reversed(decode_scheduled)) - if scheduled_seqs: - self.running.extendleft(reversed(scheduled_seqs.values())) + if not scheduled_seqs: + return None + + # Recompute prefill/decode token totals from the per-row list. In + # mixed batches prefill rows come first (Phase 1 + Phase 2 appended), + # decode rows after, so the split is at `num_seqs_prefill`. + total_tokens_num_prefill = sum(num_scheduled_tokens[:num_seqs_prefill]) + total_tokens_num_decode = sum(num_scheduled_tokens[num_seqs_prefill:]) + total_tokens_num = total_tokens_num_prefill + total_tokens_num_decode + + num_cached_tokens_list = [ + seq.num_cached_tokens for seq in scheduled_seqs.values() + ] + + if num_seqs_prefill > 0: + self.prev_prompt = True + logger.info( + f"Scheduled {'mixed' if num_seqs_decode > 0 else 'prefill'} batch: " + f"{num_seqs_prefill} prefill + {num_seqs_decode} decode, " + f"{total_tokens_num_prefill}+{total_tokens_num_decode} tokens" + ) connector_meta_output = None if self.kv_connector is not None: @@ -967,7 +1058,8 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: decode_batch = ScheduledBatch( seqs=scheduled_seqs, num_scheduled_tokens=num_scheduled_tokens, - total_tokens_num=total_tokens_num_decode, + total_tokens_num=total_tokens_num, + total_tokens_num_prefill=total_tokens_num_prefill, total_tokens_num_decode=total_tokens_num_decode, total_seqs_num=num_seqs_prefill + num_seqs_decode, total_seqs_num_prefill=num_seqs_prefill, @@ -975,6 +1067,7 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: connector_meta_output=connector_meta_output, num_spec_step=self.mtp_k, scheduled_spec_decode_tokens=scheduled_spec_decode_tokens, + num_cached_tokens=num_cached_tokens_list, remote_kv_block_ids=sorted(remote_kv_blocks) if remote_kv_blocks else [], remote_kv_seq_blocks=remote_kv_seq_blocks, ) @@ -1382,6 +1475,17 @@ def has_requests(self) -> bool: return self.has_unfinished_requests() def get_next_batch_info(self) -> tuple[bool, int, int]: + # Predicts the next batch shape for cross-DP-rank sync. Returns + # (is_prefill, num_tokens, num_reqs). + # + # Mixed prefill+decode batches (--enable-mixed-prefill-decode) report + # is_prefill=True because they always carry at least one prefill seq + # — that matches the dummy-prefill sync semantics in engine_core + # (all ranks must agree on "prefill phase" so MoE all-to-all stays in + # sync). num_tokens here is a prediction; the actual mixed batch may + # add decode tokens on top, but DP padding uses the post-schedule + # batch.total_tokens_num so the prediction underestimating is fine. + # Check for partial prefills in running (chunked prefill resume) for seq in self.running: if seq.num_cached_tokens < seq.num_prompt_tokens: diff --git a/atom/model_ops/attention_mha.py b/atom/model_ops/attention_mha.py index 9ad2916282..0c1accc88f 100644 --- a/atom/model_ops/attention_mha.py +++ b/atom/model_ops/attention_mha.py @@ -155,6 +155,12 @@ def forward_impl( fwd_ctx: ForwardContext = get_forward_context() + assert not getattr(fwd_ctx.context, "is_mixed", False), ( + "MHA models do not support mixed prefill+decode batches yet " + "(split dispatch is implemented only for dense MLA). " + "Disable --enable-mixed-prefill-decode." + ) + # dummy run will skip attention in cuda graph capture phase if fwd_ctx.context.is_dummy_run: o = torch.empty_like(q) diff --git a/atom/model_ops/attention_mla.py b/atom/model_ops/attention_mla.py index 8349801135..d9fbeb381e 100644 --- a/atom/model_ops/attention_mla.py +++ b/atom/model_ops/attention_mla.py @@ -586,6 +586,9 @@ def _forward_prefill_cached_chunked( ) if chunked_out is None: chunked_out = suf_out + # A seq absent from this chunk has lse=-inf; both-(-inf) merges + # are handled inside merge_attn_states (see its both_empty + # guard), so the seed needs no sanitizing here. chunked_lse = suf_lse else: tmp_out = torch.empty_like(new_out) @@ -990,7 +993,105 @@ def forward_impl( kv_cache_data = forward_context.kv_cache_data kv_cache = kv_cache_data[f"layer_{self.layer_num}"].k_cache - if context.is_prefill and not use_prefill_mla: + if context.is_mixed: + # Mixed prefill+decode split dispatch: the first `num_prefill_tokens` + # rows are prefill chunks (MHA path), the rest are decode tokens (MLA + # latent path). Each half runs its own Q/KV/O projections against its + # own nested metadata, then outputs are concatenated (same hidden dim). + assert not self.is_sparse_mla, ( + "Mixed prefill+decode batches do not yet support sparse MLA " + "(V3.2/V4 indexer). Disable --enable-mixed-prefill-decode." + ) + assert not use_prefill_mla, ( + "Mixed prefill+decode batches do not support the prefill-MLA " + "(sparse) path. Disable --enable-mixed-prefill-decode." + ) + n_prefill = context.num_prefill_tokens + prefill_meta = attn_metadata.prefill_attn_metadata + decode_meta = attn_metadata.decode_attn_metadata + + # ---- Prefill half: MHA path ---- + q_p = q[:n_prefill] + k_nope_p = k_nope[:n_prefill] + k_rope_p = k_rope[:n_prefill] + positions_p = positions[:n_prefill] + + prefill_q = self.q_proj(q_p, x_scale=q_scale).view( + -1, self.num_heads, self.qk_head_dim + ) + prefill_q_pe = prefill_q[..., self.qk_nope_head_dim :] + self.rotary_emb(positions_p, prefill_q_pe, k_rope_p) + + if kv_cache.numel() > 0: + concat_and_cache_mla( + k_nope_p, + k_rope_p.squeeze(1), + kv_cache, + prefill_meta.slot_mapping.flatten(), + kv_cache_dtype=self.kv_cache_dtype, + scale=self._k_scale, + ) + + if prefill_meta.has_cached: + chunk_meta = getattr(prefill_meta, "mla_chunk_meta", None) + if chunk_meta is not None: + out_prefill = self._forward_prefill_cached_chunked( + prefill_q, + k_nope_p, + k_rope_p, + kv_cache, + prefill_meta, + chunk_meta, + ) + else: + out_prefill = self._forward_prefill_cached_single_pass( + prefill_q, kv_cache, prefill_meta + ) + else: + out_prefill = self._forward_prefill_mha( + prefill_q, k_nope_p, k_rope_p, kv_cache, prefill_meta + ) + + # ---- Decode half: MLA latent path ---- + q_d = q[n_prefill:] + k_nope_d = k_nope[n_prefill:] + k_rope_d = k_rope[n_prefill:] + positions_d = positions[n_prefill:] + + q_nope_d, q_rope_d = self._q_proj_and_k_up_proj(q_d, x_scale=q_scale) + q_out_d = torch.empty( + ( + q_nope_d.shape[0], + self.num_heads, + self.kv_lora_rank + self.qk_rope_head_dim, + ), + dtype=decode_meta.dtype_q, + device=q_nope_d.device, + ) + if kv_cache.numel() > 0: + fused_qk_rope_concat_and_cache_mla( + q_nope_d, + q_rope_d, + k_nope_d, + k_rope_d, + kv_cache.view( + kv_cache.shape[0], -1, self.kv_lora_rank + self.qk_rope_head_dim + ), + q_out_d, + decode_meta.slot_mapping, + self._k_scale, + self._q_scale, + positions_d, + self.rotary_emb.cos_cache, + self.rotary_emb.sin_cache, + is_neox=self.rotary_emb.is_neox_style, + is_nope_first=True, + ) + + out_decode = self._forward_decode(q_out_d, kv_cache, decode_meta) + + output = torch.cat([out_prefill, out_decode], dim=0) + elif context.is_prefill and not use_prefill_mla: prefill_q = self.q_proj(q, x_scale=q_scale).view( -1, self.num_heads, self.qk_head_dim ) diff --git a/atom/model_ops/attentions/aiter_mla.py b/atom/model_ops/attentions/aiter_mla.py index ee522ff6dc..4e66d61dac 100644 --- a/atom/model_ops/attentions/aiter_mla.py +++ b/atom/model_ops/attentions/aiter_mla.py @@ -813,6 +813,151 @@ def prepare_prefill(self, batch: ScheduledBatch): attn_metadata.dtype_q = self.dtype_q return attn_metadata, positions + def prepare_mixed(self, batch: ScheduledBatch, bs: int): + """Build split-dispatch metadata for a mixed prefill+decode batch. + + Returns one AttentionMetaData whose `prefill_attn_metadata` drives the + first `n_prefill_tokens` rows (MHA path) and whose `decode_attn_metadata` + drives the remaining decode rows (MLA latent path), plus a merged + `positions` tensor laid out [prefill | decode]. + + The prefill half reuses the tested `prepare_prefill`, then *clones* its + tensors off the shared `forward_vars` buffers so the decode half (which + rewrites those same buffers) cannot corrupt them. + """ + assert not self.is_sparse, ( + "Mixed prefill+decode batches do not yet support sparse MLA " + "(V3.2/V4 indexer). Disable --enable-mixed-prefill-decode." + ) + var = self.model_runner.forward_vars + n_p_seqs = batch.total_seqs_num_prefill + n_p_tokens = batch.total_tokens_num_prefill + n_d_seqs = batch.total_seqs_num_decode + n_d_tokens = batch.total_tokens_num_decode + total_tokens = n_p_tokens + n_d_tokens + + # ---- Prefill half: prefill rows are first, so prepare_prefill reads + # the correct [:n_p_seqs] slice. Capture positions before the decode + # half overwrites the shared positions buffer. ---- + prefill_meta, _ = self.prepare_prefill(batch) + prefill_positions_np = var["positions"].np[:n_p_tokens].copy() + + # Detach prefill metadata from shared forward_vars buffers. + for fname in ( + "cu_seqlens_q", + "cu_seqlens_k", + "slot_mapping", + "context_lens", + "block_tables", + "kv_indptr", + "kv_indices", + "kv_last_page_lens", + "num_cached_tokens", + "seq_starts", + ): + t = getattr(prefill_meta, fname, None) + if isinstance(t, torch.Tensor): + setattr(prefill_meta, fname, t.clone()) + + # ---- Decode half: decode rows live at [n_p_seqs:] in batch arrays but + # are packed into forward_vars buffer rows [0:n_d_seqs] so the persistent + # MLA work-buffer builder (which indexes from 0) sees a contiguous batch. + d_ctx = np.asarray(batch.context_lens[n_p_seqs:], dtype=np.int32) + d_block_tables = batch.block_tables[n_p_seqs:] + d_last_block = batch.last_block_num_tokens[n_p_seqs:] + max_q_d = 1 + max_k_d = int(d_ctx.max()) if n_d_seqs > 0 else 0 + + slot_d = [ + bt[-1] * self.model_runner.block_size + lbt - 1 + for bt, lbt in zip(d_block_tables, d_last_block) + ] + positions_d = (d_ctx - 1).astype(np.int32) + + # Mixed batches always run eager, so the decode half is sized to the + # actual decode seq count (no CUDAGraph batch padding). + d_bs = n_d_seqs + + block_tables_np = var["block_tables"].np + for i, bt in enumerate(d_block_tables): + block_tables_np[i] = 0 + block_tables_np[i, : len(bt)] = bt + + var["context_lens"].np[:d_bs] = d_ctx + + num_blocks_per_seq = cdiv(d_ctx, self.block_size) + kv_indptr = np.cumsum(num_blocks_per_seq) + var["kv_indptr"].np[0] = 0 + var["kv_indptr"].np[1 : d_bs + 1] = kv_indptr + var["kv_last_page_lens"].np[:d_bs] = d_last_block if self.block_size != 1 else 1 + # Decode is one query token per seq: cu_seqlens_q = [0, 1, 2, ..., d_bs]. + var["cu_seqlens_q"].np[: d_bs + 1] = np.arange(d_bs + 1, dtype=np.int32) + var["slot_mapping"].np[:d_bs] = slot_d + + d_block_tables_gpu = var["block_tables"].copy_to_gpu(d_bs) + kv_indptr_gpu = var["kv_indptr"].copy_to_gpu(d_bs + 1) + context_lens_gpu = var["context_lens"].copy_to_gpu(d_bs) + cu_q_gpu = var["cu_seqlens_q"].copy_to_gpu(d_bs + 1) + klp_gpu = var["kv_last_page_lens"].copy_to_gpu(d_bs) + slot_d_gpu = var["slot_mapping"].copy_to_gpu(d_bs) + + kv_indices_gpu = var["kv_indices"].gpu + kv_indices_generate_triton( + d_block_tables_gpu, + kv_indices_gpu, + kv_indptr_gpu, + self.block_ratio, + max_k_d, + ) + + ctx_ps = self.set_mla_persistent_worker_buffers(d_bs, max_q_d) + + decode_meta = AttentionMetaData( + cu_seqlens_q=cu_q_gpu, + max_seqlen_q=max_q_d, + max_seqlen_k=max_k_d, + slot_mapping=slot_d_gpu[:n_d_seqs], + context_lens=context_lens_gpu, + block_tables=d_block_tables_gpu, + kv_indptr=kv_indptr_gpu, + kv_indices=kv_indices_gpu, + kv_last_page_lens=klp_gpu, + **ctx_ps, + ) + decode_meta.dtype_q = self.dtype_q + + # ---- Merge positions and slot_mapping ([prefill | decode]) ---- + var["positions"].np[:n_p_tokens] = prefill_positions_np + var["positions"].np[n_p_tokens:total_tokens] = positions_d + positions = var["positions"].copy_to_gpu(total_tokens) + + merged_slot = np.empty(total_tokens, dtype=np.int64) + prefill_slot = prefill_meta.slot_mapping + merged_slot[:n_p_tokens] = ( + prefill_slot.cpu().numpy() + if isinstance(prefill_slot, torch.Tensor) + else prefill_slot + ) + merged_slot[n_p_tokens:total_tokens] = slot_d + # from_numpy keeps the tensor on CPU regardless of the active default-device + # guard (torch.tensor(..., pin_memory=True) raises under a CUDA guard). + merged_slot_gpu = torch.from_numpy(merged_slot).to( + self.device, non_blocking=True + ) + + attn_metadata = AttentionMetaData( + slot_mapping=merged_slot_gpu, + # Surface prefill cu_seqlens_q on the top-level metadata so the + # ParallelLMHead mixed-batch gather (embed_head.py) can find the + # per-prefill-seq last-token indices without reaching into the + # nested prefill metadata. + cu_seqlens_q=prefill_meta.cu_seqlens_q, + prefill_attn_metadata=prefill_meta, + decode_attn_metadata=decode_meta, + ) + attn_metadata.dtype_q = self.dtype_q + return attn_metadata, positions + def _build_mla_chunk_meta( self, batch: ScheduledBatch, bs: int ) -> Optional[MLAChunkContextMetadata]: diff --git a/atom/model_ops/attentions/backends.py b/atom/model_ops/attentions/backends.py index 36ac0daf32..9af2346658 100644 --- a/atom/model_ops/attentions/backends.py +++ b/atom/model_ops/attentions/backends.py @@ -445,12 +445,21 @@ def build_ubatch_prefill_metadata( return split_attn_metadata(attn_metadata, ub_slice, padded_bs) def build(self, batch: ScheduledBatch, bs: int): + if getattr(batch, "is_mixed", False): + return self.prepare_mixed(batch, bs) is_prefill = batch.total_tokens_num_prefill > 0 if is_prefill: return self.prepare_prefill(batch) else: return self.prepare_decode(batch, bs) + def prepare_mixed(self, batch: ScheduledBatch, bs: int): + raise NotImplementedError( + f"{type(self).__name__} does not support mixed prefill+decode " + "batches yet. Only the dense-MLA backend (AiterMLAMetadataBuilder) " + "implements split dispatch. Disable --enable-mixed-prefill-decode." + ) + class AttentionImpl(nn.Module): @abstractmethod diff --git a/atom/model_ops/attentions/deepseek_v4_attn.py b/atom/model_ops/attentions/deepseek_v4_attn.py index 2ff9de3dd3..2dc40e37b9 100644 --- a/atom/model_ops/attentions/deepseek_v4_attn.py +++ b/atom/model_ops/attentions/deepseek_v4_attn.py @@ -232,6 +232,48 @@ class AttentionMetaData_DSV4(AttentionMetaData): kv_indptr_extend: Optional[torch.Tensor] = None """[total_tokens + 1] int32 GPU — packed cumsum of `extend_count`.""" + # ----- Mixed prefill+decode split dispatch (set by `prepare_mixed`) ----- + # When `is_mixed`, `prefill_attn_metadata` / `decode_attn_metadata` (on the + # base class) carry the two per-segment metadata sets; the forward splits + # the flat q/kv tensor at `num_prefill_tokens` and runs each segment through + # its own path. None/False for non-mixed batches. + is_mixed: bool = False + num_prefill_tokens: int = 0 + num_prefill_seqs: int = 0 + num_decode_tokens: int = 0 + num_decode_seqs: int = 0 + + +class _MixedDecodeView: + """Thin read-only view exposing the DECODE rows ``[n_prefill:]`` of a mixed + batch as if they were a standalone decode batch, so the unmodified + `prepare_decode` builds decode metadata for them. + + Only the fields `prepare_decode` (and `prepare_block_tables`) actually read + are sliced; everything else delegates to the wrapped batch. Per-row arrays + (`context_lens`, `block_tables`, `per_req_cache_groups`) are sliced to drop + the leading prefill rows; counts report the decode totals. + """ + + def __init__(self, batch: ScheduledBatch, n_prefill_seqs: int): + self._batch = batch + self._np = n_prefill_seqs + self.context_lens = batch.context_lens[n_prefill_seqs:] + self.block_tables = batch.block_tables[n_prefill_seqs:] + # per_req_cache_groups holds only seqs with a per-req cache, in batch + # order; for V4 every seq has one, so the decode rows are the tail. + self.per_req_cache_groups = batch.per_req_cache_groups[n_prefill_seqs:] + self.total_seqs_num_decode = batch.total_seqs_num_decode + self.total_tokens_num_decode = batch.total_tokens_num_decode + self.total_seqs_num_prefill = 0 + self.total_tokens_num_prefill = 0 + self.is_dummy_run = batch.is_dummy_run + self.num_spec_step = batch.num_spec_step + + def __getattr__(self, name): + # Anything not explicitly sliced above falls through to the real batch. + return getattr(self._batch, name) + class DeepseekV4Backend(AttentionBackend): """Backend selector entry for V4 hybrid attention. @@ -1071,6 +1113,151 @@ def prepare_mtp_decode( # - v4 indexer meta (Indexer — only present when ratio == 4) return {} + def prepare_mixed(self, batch: ScheduledBatch, bs: int): + """Build split-dispatch metadata for a V4 mixed prefill+decode batch. + + Mirrors the dense-MLA `prepare_mixed` (aiter_mla.py) pattern but for + V4's much larger metadata surface. The batch layout is + ``[prefill rows | decode rows]`` (prefill first, M1 scheduler order). + + Strategy: reuse the existing, validated `prepare_prefill` and + `prepare_decode` unmodified. + 1. `prepare_prefill(batch)` reads only rows ``[0:n_p_seqs]`` (prefill + rows are first), so it is correct as-is. Its per-fwd metadata + aliases shared `forward_vars`/`_stage` buffers, so we CLONE those + aliasing tensors before the decode half overwrites them. + 2. `prepare_decode(decode_view)` runs against a thin sub-batch view + that exposes only the decode rows ``[n_p_seqs:]``. + 3. The returned merged metadata carries both as nested + `v4_prefill_meta` / `v4_decode_meta`, plus merged full-tensor + fields (positions, cu_seqlens_q, batch_id_per_token, + state_slot_mapping, block_tables) for the shared ops in + `forward_impl` that run on the whole flat tensor. + + forward_impl raises until P2-P4 land the per-segment split; this + builder is exercised first (P1) so the metadata is validated before + the forward consumes it. + """ + var = self.model_runner.forward_vars + n_p_seqs = batch.total_seqs_num_prefill + n_p_tokens = batch.total_tokens_num_prefill + n_d_seqs = batch.total_seqs_num_decode + n_d_tokens = batch.total_tokens_num_decode + total_tokens = n_p_tokens + n_d_tokens + + # ---- Prefill half: rows [0:n_p_seqs] are first, prepare_prefill is + # correct as-is. Capture positions before the decode half overwrites + # the shared positions buffer. ---- + prefill_meta, prefill_positions = self.prepare_prefill(batch) + prefill_positions_np = var["positions"].np[:n_p_tokens].copy() + + # Detach prefill metadata from shared forward_vars / _stage buffers that + # the decode half will rewrite. Self-owned tensors built by + # `_build_paged_prefill_meta` (kv_indices_prefix_*, kv_indptr_prefix_*, + # kv_indices_extend, skip_prefix_len_csa) are fresh torch.empty/from_numpy + # and need no clone; only the shared per-fwd buffers do. + for fname in ( + "batch_id_per_token", + "n_committed_csa_per_seq", + "state_slot_mapping", + "block_tables", + "cu_seqlens_q", + "cu_seqlens_k", + "context_lens", + ): + t = getattr(prefill_meta, fname, None) + if isinstance(t, torch.Tensor): + setattr(prefill_meta, fname, t.clone()) + # indexer_meta GPU tensors also alias _stage buffers. + if prefill_meta.indexer_meta is not None: + prefill_meta.indexer_meta = { + k: (v.clone() if isinstance(v, torch.Tensor) else v) + for k, v in prefill_meta.indexer_meta.items() + } + # compress_plans: each CompressPlan's compress_plan_gpu / write_plan_gpu + # are VIEWS into the shared var["v4_compress_plan_{ratio}"] / + # var["v4_write_plan_{ratio}"] buffers, which prepare_decode's + # _build_compress_plans overwrites below. Clone those GPU tensors so the + # prefill segment's Compressor reads its own plan, not the decode plan. + if prefill_meta.compress_plans: + from dataclasses import replace as _dc_replace + + prefill_meta.compress_plans = { + ratio: _dc_replace( + plan, + compress_plan_gpu=plan.compress_plan_gpu.clone(), + write_plan_gpu=plan.write_plan_gpu.clone(), + ) + for ratio, plan in prefill_meta.compress_plans.items() + } + + # The prefill staging above issued async H2D copies from SHARED pinned + # forward_vars buffers (`_stage` → `copy_to_gpu(non_blocking=True)`). + # The GPU-side `.clone()`s only protect against GPU buffer REUSE — they + # are enqueued on the stream but not yet executed. The decode half below + # (and the cu_seqlens_q reset just under it) overwrite those SAME pinned + # CPU buffers on the host thread, which does NOT wait for the prefill + # DMA to drain. With a long prefill chunk the DMA window is wide enough + # that the host overwrite races the in-flight copy → the prefill index + # tensors get decode-half values → downstream `tensor[idx]` indexes OOB + # (GPU memory-access fault, only at large ISL). Drain the stream so every + # prefill H2D has finished reading the pinned buffers before decode + # reuses them. Mixed runs eager and is rare, so one sync per mixed batch + # is acceptable. + torch.cuda.current_stream().synchronize() + + # ---- Decode half: present rows [n_p_seqs:] as a standalone batch so + # the unmodified prepare_decode builds decode metadata into shared + # buffer rows [0:n_d_seqs]. Mixed runs eager, so bs == n_d_seqs (no CG + # padding). ---- + decode_view = _MixedDecodeView(batch, n_p_seqs) + # prepare_decode READS var["cu_seqlens_q"] (it never writes it — the + # normal caller, ModelRunner.prepare_inputs, sets it for the whole + # batch). For the mixed batch that buffer holds the FULL [prefill|decode] + # cumulative seqlens, so the decode rows are offset by n_p_tokens. Reset + # it to the decode-local cumulative seqlens (1 token per decode seq, + # no MTP in mixed) so swa_write / paged-decode index the 31-token decode + # kv correctly instead of running off the end (GPU OOB in swa_write). + decode_max_q = batch.num_spec_step + 1 + var["cu_seqlens_q"].np[: n_d_seqs + 1] = np.arange( + 0, (n_d_seqs + 1) * decode_max_q, decode_max_q, dtype=np.int32 + ) + decode_meta, decode_positions = self.prepare_decode(decode_view, n_d_seqs) + + # ---- Merge full-tensor fields for the shared forward_impl ops. ---- + # positions: [prefill | decode] + var["positions"].np[:n_p_tokens] = prefill_positions_np + var["positions"].np[n_p_tokens:total_tokens] = ( + decode_positions.cpu().numpy() + if isinstance(decode_positions, torch.Tensor) + else decode_positions + ) + positions = var["positions"].copy_to_gpu(total_tokens) + + merged = AttentionMetaData_DSV4( + # Surface prefill cu_seqlens_q so the ParallelLMHead mixed-batch + # gather (embed_head.py) finds per-prefill-seq last-token indices + # without reaching into the nested prefill metadata. + cu_seqlens_q=prefill_meta.cu_seqlens_q, + cu_seqlens_k=None, + max_seqlen_q=max(prefill_meta.max_seqlen_q, decode_meta.max_seqlen_q), + max_seqlen_k=max(prefill_meta.max_seqlen_k, decode_meta.max_seqlen_k), + min_seqlen_q=0, + dropout_p=0.0, + has_cached=prefill_meta.has_cached, + total_kv=(prefill_meta.total_kv or 0) + (decode_meta.total_kv or 0), + state=AttnState.PREFILL_PREFIX, # mixed always carries a prefill row + ) + merged.prefill_attn_metadata = prefill_meta + merged.decode_attn_metadata = decode_meta + # Marker the forward reads to take the mixed branch. + merged.is_mixed = True + merged.num_prefill_tokens = n_p_tokens + merged.num_prefill_seqs = n_p_seqs + merged.num_decode_tokens = n_d_tokens + merged.num_decode_seqs = n_d_seqs + return merged, positions + def prepare_decode(self, batch: ScheduledBatch, bs: int): """V4-style decode prep: populates positions, cu_seqlens_q, block_tables, and state_slot_mapping. @@ -1886,9 +2073,18 @@ def _build_paged_prefill_meta( ), index_topk, ).astype(np.int32) - n_hca_per_token_np = n_committed_hca_per_seq_np[batch_id_per_token_np].astype( - np.int32 - ) + # Per-token CAUSAL HCA visibility (mirrors CSA above and the reference + # `get_compress_topk_idxs` prefill mask): token at `pos` sees only the + # `(pos+1)//128` HCA groups committed up to its own position, capped by + # the per-seq committed count. Without `(pos+1)//128`, every token used + # the per-seq `ctx_end//128`, over-reading FUTURE groups and making a + # token's output depend on the forward's total length (chunked breaks). + # MUST stay in sync with the kernel's inline cap in + # `_v4_paged_prefill_indices_kernel` (HCA_RATIO). + n_hca_per_token_np = np.minimum( + (positions_arr + 1) // 128, + n_committed_hca_per_seq_np[batch_id_per_token_np], + ).astype(np.int32) # 4 indptrs on CPU; last element = total (no D2H to size buffers). ext_indptr_np = np.zeros(T + 1, dtype=np.int32) diff --git a/atom/model_ops/attentions/triton_merge_attn_states.py b/atom/model_ops/attentions/triton_merge_attn_states.py index 4aefe3e85a..828496261a 100644 --- a/atom/model_ops/attentions/triton_merge_attn_states.py +++ b/atom/model_ops/attentions/triton_merge_attn_states.py @@ -128,15 +128,25 @@ def merge_attn_states_kernel( s_lse = float("-inf") if s_lse == float("inf") else s_lse max_lse = tl.maximum(p_lse, s_lse) - p_lse = p_lse - max_lse - s_lse = s_lse - max_lse + # Both prefix AND suffix are empty for this token (no KV on either side) -> + # max_lse == -inf. The naive `p_lse - max_lse` would compute -inf-(-inf)=NaN + # and `out_se` would be 0, making the scale 0/0=NaN that poisons the output. + # This happens in ATOM's global-axis chunked prefill: a short seq can fall + # entirely outside a chunk, so its tokens see an empty prefix AND suffix in + # that chunk. Force a safe 0/0-split: subtract a finite max so each side's + # exp is 0 (out = 0*p_out + 0*s_out = 0, correct for empty attention) and + # keep the merged lse at -inf so any downstream merge stays consistent. + both_empty = max_lse == float("-inf") + safe_max = tl.where(both_empty, 0.0, max_lse) + p_lse = p_lse - safe_max + s_lse = s_lse - safe_max # Will reuse precomputed Exp values for scale factor computation. p_se = tl.exp(p_lse) s_se = tl.exp(s_lse) out_se = p_se + s_se if OUTPUT_LSE: - out_lse = tl.log(out_se) + max_lse + out_lse = tl.where(both_empty, float("-inf"), tl.log(out_se) + safe_max) tl.store(output_lse + head_idx * num_tokens + token_idx, out_lse) p_out = tl.load( @@ -157,8 +167,11 @@ def merge_attn_states_kernel( # NOTE(woosuk): Be careful with the numerical stability. # We should compute the scale first, and then multiply it with the output. # Do not multiply the output with tl.exp(p_lse) or tl.exp(s_lse) directly. - p_scale = p_se / out_se - s_scale = s_se / out_se + # both_empty -> out_se == 0; guard the denominator so the scale is 0/1=0 + # (not 0/0=NaN). p_out/s_out are 0 for empty attention, so out stays 0. + safe_out_se = tl.where(both_empty, 1.0, out_se) + p_scale = p_se / safe_out_se + s_scale = s_se / safe_out_se out = p_out * p_scale + s_out * s_scale if USE_FP8: diff --git a/atom/model_ops/embed_head.py b/atom/model_ops/embed_head.py index 0c2ca9bd60..3da63cbb09 100644 --- a/atom/model_ops/embed_head.py +++ b/atom/model_ops/embed_head.py @@ -175,7 +175,27 @@ def forward(self, x: torch.Tensor): context = forward_context.context attn_metadata = forward_context.attn_metadata # context = get_context() - if context.is_prefill and not context.is_draft: + # Mixed must be checked BEFORE is_prefill: a mixed batch carries + # prefill rows so context.is_prefill is also True, but it needs the + # combined gather (prefill last-tokens + all decode tokens), not the + # prefill-only gather. + if getattr(context, "is_mixed", False) and not context.is_draft: + # Mixed prefill+decode: gather last-token-per-prefill-seq + # for the prefill rows, keep every token for the decode rows + # (decode rows produce one sample per token). + n_p_tokens = context.num_prefill_tokens + n_p_seqs = context.num_prefill_seqs + if n_p_seqs > 0: + prefill_last = attn_metadata.cu_seqlens_q[1 : n_p_seqs + 1] - 1 + decode_range = torch.arange( + n_p_tokens, + x.shape[0], + device=x.device, + dtype=prefill_last.dtype, + ) + sample_indices = torch.cat([prefill_last, decode_range]) + x = x[sample_indices].contiguous() + elif context.is_prefill and not context.is_draft: last_indices = attn_metadata.cu_seqlens_q[1:] - 1 x = x[last_indices].contiguous() logits = tgemm.mm(x, self.weight, self.bias) diff --git a/atom/model_ops/v4_kernels/paged_prefill_indices.py b/atom/model_ops/v4_kernels/paged_prefill_indices.py index c1375da1db..d7edfaf8ce 100644 --- a/atom/model_ops/v4_kernels/paged_prefill_indices.py +++ b/atom/model_ops/v4_kernels/paged_prefill_indices.py @@ -71,6 +71,7 @@ def _v4_paged_prefill_indices_kernel( win: tl.constexpr, cs, # win_with_spec — SWA ring stride (NOT constexpr because varies w/ mtp_k) swa_pages, # state_slot count * cs — boundary into HCA compress section + HCA_RATIO: tl.constexpr, # HCA compress ratio (128) for per-token causal cap BLOCK_N: tl.constexpr, # next_pow2(win) — covers SWA prefix and extend segments ): """One program per token. Writes four per-token segments: @@ -92,7 +93,15 @@ def _v4_paged_prefill_indices_kernel( chunk_start = tl.load(chunk_start_per_seq_ptr + bid) cu_q = tl.load(cu_seqlens_q_per_seq_ptr + bid) state_slot = tl.load(state_slot_per_seq_ptr + bid) - n_hca = tl.load(n_committed_hca_per_seq_ptr + bid) + # Per-token CAUSAL HCA visibility: token at `pos` may see only the + # `(pos+1)//HCA_RATIO` compressed groups committed up to its own position + # (matches the reference `get_compress_topk_idxs` prefill mask, and mirrors + # the CSA `(pos+1)//4` cap). Without this cap every token saw the per-seq + # `n_committed_hca = ctx_end//128`, which over-reads FUTURE groups and makes + # a token's output depend on the forward's total length (chunked != single). + n_hca = tl.minimum( + (pos + 1) // HCA_RATIO, tl.load(n_committed_hca_per_seq_ptr + bid) + ) # Per-token derived quantities (single-pass arithmetic). token_pos_in_chunk = pos - chunk_start @@ -159,6 +168,7 @@ def write_v4_paged_prefill_indices( win: int, cs: int, swa_pages: int, + hca_ratio: int = 128, ) -> None: """One-shot GPU build of the V4 paged-prefill index buffers. @@ -247,6 +257,7 @@ def write_v4_paged_prefill_indices( win=win, cs=cs, swa_pages=swa_pages, + HCA_RATIO=hca_ratio, BLOCK_N=BLOCK_N, ) @@ -272,6 +283,7 @@ def write_v4_paged_prefill_indices_reference( win: int, cs: int, swa_pages: int, + hca_ratio: int = 128, ) -> None: """Pure-Python equivalent of ``write_v4_paged_prefill_indices``. Per-token Python loop — slow but readable; used for unit-test bit-exact @@ -301,7 +313,8 @@ def write_v4_paged_prefill_indices_reference( chunk_start = cs_per_seq_cpu[bid] cu_q = cu_q_cpu[bid] state_slot = state_slot_cpu[bid] - n_hca = n_hca_cpu[bid] + # Per-token causal HCA cap (mirrors kernel + reference get_compress_topk_idxs). + n_hca = min((pos + 1) // hca_ratio, n_hca_cpu[bid]) token_pos_in_chunk = pos - chunk_start swa_low = max(pos - win + 1, 0) diff --git a/atom/models/deepseek_v4.py b/atom/models/deepseek_v4.py index fc9e88431b..43c81f98cf 100644 --- a/atom/models/deepseek_v4.py +++ b/atom/models/deepseek_v4.py @@ -1708,6 +1708,53 @@ def forward_impl( return torch.zeros_like(x) if os.environ.get("ATOM_V4_BYPASS_ATTN") == "1": return torch.zeros_like(x) + + # ===== Mixed prefill+decode split dispatch ===== + # A mixed batch packs [prefill rows | decode rows] in one flat tensor. + # Each sub-metadata (prefill_attn_metadata / decode_attn_metadata) is a + # COMPLETE, self-consistent V4 metadata object, so rather than splitting + # the body at every is_decode site we run the whole validated forward + # body once per segment with the forward-context temporarily pointed at + # that segment's sub-meta + is_prefill flag, then concatenate. The + # non-mixed path (below) is byte-for-byte unchanged. swa_write ordering, + # Indexer prefill-vs-decode dispatch, csa_translate_pack, and the sparse + # kernels all fall out correctly because each segment re-enters this same + # method as a pure prefill or pure decode forward. + attn_md_top = cast("AttentionMetaData_DSV4", fc.attn_metadata) + if getattr(attn_md_top, "is_mixed", False): + n_p = attn_md_top.num_prefill_tokens + n_d = x.size(0) - n_p + p_md = attn_md_top.prefill_attn_metadata + d_md = attn_md_top.decode_attn_metadata + ctx = fc.context + saved_is_prefill = ctx.is_prefill + saved_input_ids = ctx.input_ids + # Tag the whole mixed dispatch so the trace can distinguish a real + # pure-prefill / pure-decode step from the two segments of a mixed + # step (which re-enter this method as prefill/decode and would + # otherwise only show those inner tags). + with torch.profiler.record_function(f"mixed[n_p={n_p} n_d={n_d}]"): + # Prefill segment: rows [0:n_p], pure-prefill path. Slice + # input_ids too so the hash-MoE `_hash_topk` (reads + # ctx.input_ids) matches the segment's token count. + fc.attn_metadata = p_md + ctx.is_prefill = True + if saved_input_ids is not None: + ctx.input_ids = saved_input_ids[:n_p] + out_p = self.forward_impl(x[:n_p], positions[:n_p]) + # Decode segment: rows [n_p:], pure-decode path. + fc.attn_metadata = d_md + ctx.is_prefill = False + if saved_input_ids is not None: + ctx.input_ids = saved_input_ids[n_p:] + out_d = self.forward_impl(x[n_p:], positions[n_p:]) + # Restore the top-level mixed context for any caller / later + # layer. + fc.attn_metadata = attn_md_top + ctx.is_prefill = saved_is_prefill + ctx.input_ids = saved_input_ids + return torch.cat([out_p, out_d], dim=0) + num_tokens = x.size(0) cache_size = self.swa_kv.shape[1] ratio = self.compress_ratio diff --git a/atom/utils/forward_context.py b/atom/utils/forward_context.py index 209aad6154..13ba34a1a0 100644 --- a/atom/utils/forward_context.py +++ b/atom/utils/forward_context.py @@ -326,6 +326,12 @@ class Context: # that need the token ids but cannot receive them as a function arg # (the op signature is fixed by the consumer's plugin contract). input_ids: Optional[torch.Tensor] = None + # Mixed prefill+decode (Phase 2). When True, the first num_prefill_tokens + # rows of q/k/v are prefill chunks and the rest are decode tokens; the + # attention backend splits them via prefill_attn_metadata / decode_attn_metadata. + is_mixed: bool = False + num_prefill_tokens: int = 0 + num_prefill_seqs: int = 0 def __init__( self, @@ -338,6 +344,9 @@ def __init__( dp_uniform_decode: bool = True, forward_mode: Optional[ForwardMode] = None, input_ids: Optional[torch.Tensor] = None, + is_mixed: bool = False, + num_prefill_tokens: int = 0, + num_prefill_seqs: int = 0, ): self.positions = positions self.is_prefill = is_prefill @@ -348,6 +357,9 @@ def __init__( self.dp_uniform_decode = dp_uniform_decode self.forward_mode = forward_mode self.input_ids = input_ids + self.is_mixed = is_mixed + self.num_prefill_tokens = num_prefill_tokens + self.num_prefill_seqs = num_prefill_seqs @dataclass @@ -391,6 +403,15 @@ class AttentionMetaData: num_cached_tokens: Optional[torch.Tensor] = None seq_starts: Optional[torch.Tensor] = None + # Mixed prefill+decode (Phase 2) split-dispatch sub-metadata. For a mixed + # batch the builder returns ONE AttentionMetaData whose `slot_mapping` covers + # all tokens (merged prefill-then-decode), and whose attention dispatch is + # driven by these two nested objects: `prefill_attn_metadata` for the first + # `context.num_prefill_tokens` rows and `decode_attn_metadata` for the rest. + # None for non-mixed batches. + prefill_attn_metadata: Optional["AttentionMetaData"] = None + decode_attn_metadata: Optional["AttentionMetaData"] = None + def __init__( self, cu_seqlens_q: Optional[torch.Tensor] = None, @@ -421,7 +442,11 @@ def __init__( total_kv: Optional[int] = None, num_cached_tokens: Optional[torch.Tensor] = None, seq_starts: Optional[torch.Tensor] = None, + prefill_attn_metadata: Optional["AttentionMetaData"] = None, + decode_attn_metadata: Optional["AttentionMetaData"] = None, ): + self.prefill_attn_metadata = prefill_attn_metadata + self.decode_attn_metadata = decode_attn_metadata self.has_cached = has_cached self.total_kv = total_kv self.num_cached_tokens = num_cached_tokens diff --git a/scripts/start_atom_server.sh b/scripts/start_atom_server.sh index 197c040dd1..28fea738a7 100755 --- a/scripts/start_atom_server.sh +++ b/scripts/start_atom_server.sh @@ -58,6 +58,12 @@ done rm -rf ~/.cache/atom/* rm -rf ./gpucore.* +# Disable core dumps: a single ROCm fault dumps a 30-50 GB gpucore PER RANK, +# which on an 8-GPU TP run fills the disk in seconds (and triggers the +# apport "execvp failed" noise). The debug-agent wrapper already does this; +# the normal launcher must too, since faults can happen in production runs. +ulimit -c 0 + # Write config header to log (truncates old content). # Inherited env vars are dumped explicitly so you never have to wonder # whether ATOM_USE_TRITON_MOE / V4_USE_REF_QUANT / etc. were set. diff --git a/tests/conftest.py b/tests/conftest.py index 326335cb9f..e8ded61b5a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -118,6 +118,7 @@ def __init__(self, **overrides): enable_prefix_caching=False, max_num_seqs=4, max_num_batched_tokens=64, + long_prefill_token_threshold=0, max_model_len=64, bos_token_id=1, eos_token_id=2, diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index da43358ecd..5df21aa4cb 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -136,6 +136,82 @@ def test_decode_preemption(self, seq_factory): assert SequenceStatus.WAITING in statuses +# ── mixed prefill+decode (decode-first budget reservation) ────────────────── + + +class TestMixedDecodeFirst: + """Decode-first budget reservation: when --enable-mixed-prefill-decode is + on, a running decode reserves its token budget BEFORE new prefills spend + it, so a long prefill chunk cannot starve decode out of the step and a + mixed batch forms. Mirrors vLLM V1's running-before-waiting ordering.""" + + def _mixed_sched(self, **kw): + cfg = dict( + enable_mixed_prefill_decode=True, + enable_chunked_prefill=True, + max_num_batched_tokens=8, + num_kvcache_blocks=100, + kv_cache_block_size=4, + max_num_seqs=8, + ) + cfg.update(kw) + return Scheduler(MockConfig(**cfg)) + + def test_mixed_batch_forms_when_decode_inflight(self, seq_factory): + sched = self._mixed_sched() + # One seq reaches decode first. + d = seq_factory([1, 2, 3, 4]) + sched.add(d) + sched.schedule() # prefill d + d.num_cached_tokens = d.num_prompt_tokens + d.append_token(5) + # A long new prompt arrives; its full chunk (8) would, under the old + # prefill-first order, consume the entire budget and push decode out. + p = seq_factory([10, 11, 12, 13, 14, 15, 16, 17]) + sched.add(p) + batch, _ = sched.schedule() + # Decode-first reservation (1 decode * 1 tok) leaves prefill_budget=7, + # so the prefill chunk shrinks and the decode row rides along: MIXED. + assert batch.is_mixed + assert batch.total_seqs_num_decode == 1 + assert batch.total_seqs_num_prefill == 1 + + def test_mixed_batch_layout_prefill_first(self, seq_factory): + sched = self._mixed_sched() + d = seq_factory([1, 2, 3, 4]) + sched.add(d) + sched.schedule() + d.num_cached_tokens = d.num_prompt_tokens + d.append_token(5) + p = seq_factory([10, 11, 12, 13, 14, 15, 16, 17]) + sched.add(p) + batch, _ = sched.schedule() + assert batch.is_mixed + # Runner/attention require [prefill rows | decode rows]: prefill row(s) + # lead, decode row(s) trail. + assert batch.req_ids[0] == p.id + assert batch.req_ids[-1] == d.id + # One decode seq -> exactly one trailing decode token. + assert batch.total_seqs_num_decode == 1 + assert batch.total_tokens_num_decode == 1 + + def test_flag_off_is_prefill_first(self, seq_factory): + # With mixed OFF, scheduling a prefill yields a prefill-only batch even + # when a decode is in flight (byte-identical to legacy behavior). + sched = self._mixed_sched(enable_mixed_prefill_decode=False) + d = seq_factory([1, 2, 3, 4]) + sched.add(d) + sched.schedule() + d.num_cached_tokens = d.num_prompt_tokens + d.append_token(5) + p = seq_factory([10, 11, 12, 13, 14, 15, 16, 17]) + sched.add(p) + batch, _ = sched.schedule() + assert not batch.is_mixed + assert batch.total_seqs_num_prefill == 1 + assert batch.total_seqs_num_decode == 0 + + # ── prefix caching ──────────────────────────────────────────────────────── @@ -430,3 +506,96 @@ def test_normal_decode_window_unchanged(self): ) assert list(batch.scheduled_tokens) == toks[-(mtp_k + 1) :] + + +# ── mixed prefill+decode batch (Phase 2) ─────────────────────────────────── + + +class TestMixedBatch: + """Verify that --enable-mixed-prefill-decode merges prefill chunks and + decode seqs into one ScheduledBatch (Phase 2 of chunked prefill).""" + + def _mixed_scheduler(self, **overrides): + cfg = MockConfig( + enable_mixed_prefill_decode=True, + num_kvcache_blocks=20, + max_num_seqs=4, + max_num_batched_tokens=256, + **overrides, + ) + return Scheduler(cfg) + + def _prefill_only_scheduler(self, **overrides): + cfg = MockConfig( + enable_mixed_prefill_decode=False, + num_kvcache_blocks=20, + max_num_seqs=4, + max_num_batched_tokens=256, + **overrides, + ) + return Scheduler(cfg) + + def _ready_decode_seq(self, sched, seq_factory, token_ids): + seq = seq_factory(token_ids) + sched.add(seq) + sched.schedule() # prefill + seq.num_cached_tokens = seq.num_prompt_tokens + seq.append_token(99) + return seq + + def test_flag_off_keeps_batch_pure_prefill(self, seq_factory): + """Back-compat: without the flag, a step with both pending prefill and + running decode emits prefill-only (decode waits).""" + sched = self._prefill_only_scheduler() + self._ready_decode_seq(sched, seq_factory, [1, 2, 3, 4]) + sched.add(seq_factory([5, 6, 7, 8])) # new prefill + batch, _ = sched.schedule() + assert batch.total_seqs_num_prefill == 1 + assert batch.total_seqs_num_decode == 0 + assert batch.is_mixed is False + + def test_flag_on_produces_mixed_batch(self, seq_factory): + sched = self._mixed_scheduler() + self._ready_decode_seq(sched, seq_factory, [1, 2, 3, 4]) + sched.add(seq_factory([5, 6, 7, 8])) # new prefill + batch, _ = sched.schedule() + assert batch.total_seqs_num_prefill == 1 + assert batch.total_seqs_num_decode == 1 + assert batch.is_mixed is True + # prefill rows come first, then decode rows + assert batch.num_scheduled_tokens[0] == 4 # prefill chunk + assert batch.num_scheduled_tokens[1] == 1 # decode token + + def test_decode_loop_skips_seq_already_scheduled_as_prefill(self, seq_factory): + """A newly-added prefill seq must not also be picked by the decode + loop in the same step (it isn't in running, but guard the contract).""" + sched = self._mixed_scheduler() + decode_seq = self._ready_decode_seq(sched, seq_factory, [1, 2, 3, 4]) + prefill_seq = seq_factory([5, 6, 7, 8]) + sched.add(prefill_seq) + batch, _ = sched.schedule() + # Both seqs scheduled exactly once + assert len(batch.num_scheduled_tokens) == 2 + ids = list(batch.req_ids) if hasattr(batch, "req_ids") else None + if ids is not None: + assert ids.count(decode_seq.id) == 1 + assert ids.count(prefill_seq.id) == 1 + + def test_mixed_batch_decode_only_when_no_prefill(self, seq_factory): + """If no prefill is pending, mixed flag still produces a decode-only + batch (is_mixed False).""" + sched = self._mixed_scheduler() + self._ready_decode_seq(sched, seq_factory, [1, 2, 3, 4]) + batch, _ = sched.schedule() + assert batch.total_seqs_num_prefill == 0 + assert batch.total_seqs_num_decode == 1 + assert batch.is_mixed is False + + def test_mixed_batch_prefill_only_when_no_decode(self, seq_factory): + """If no decode is ready, flag still produces a prefill-only batch.""" + sched = self._mixed_scheduler() + sched.add(seq_factory([5, 6, 7, 8])) + batch, _ = sched.schedule() + assert batch.total_seqs_num_prefill == 1 + assert batch.total_seqs_num_decode == 0 + assert batch.is_mixed is False