Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
a432f2b
scheduler: emit mixed prefill+decode batch behind --enable-mixed-pref…
jiayyu May 27, 2026
e069292
runner: input_ids + lm_head gather for mixed prefill+decode batch (P2…
jiayyu May 27, 2026
717337d
runner: disable TBO for mixed prefill+decode batch + document DP sync…
jiayyu May 27, 2026
89b4c1a
attn: dense-MLA split dispatch for mixed prefill+decode batch (P2-M2/M3)
jiayyu Jun 15, 2026
1072fd3
v4(attn): prepare_mixed builder for mixed prefill+decode (P1)
jiayyu Jun 15, 2026
082130e
v4(attn): forward_impl mixed split dispatch, Dense layers (P2)
jiayyu Jun 15, 2026
7c7e38a
v4(attn): mixed split dispatch via per-segment forward re-entry (P2-P4)
jiayyu Jun 15, 2026
51191df
v4(attn): drop bogus is_sparse placeholder in prepare_mixed
jiayyu Jun 15, 2026
b6c5da9
v4(attn): surface prefill cu_seqlens_q on merged mixed metadata
jiayyu Jun 15, 2026
46b5f98
v4(attn): clone compress_plans in prepare_mixed (fix GPU OOB)
jiayyu Jun 16, 2026
fe98f22
runner: enable deferred output for mixed prefill+decode batches
jiayyu Jun 16, 2026
d79a97f
fix(mla): handle both-empty merge in the kernel, drop call-site worka…
jiayyu Jun 12, 2026
deae538
fix(v4): add per-token causal cap to HCA prefill visibility
jiayyu Jun 10, 2026
10a9390
style: black-format mixed-batch scheduler while-loop
jiayyu Jun 16, 2026
2b004a8
v4(attn): fix chunked-prefill + mixed-batch GPU OOB
jiayyu Jun 17, 2026
411b64b
v4(attn): fix prepare_mixed pinned-buffer H2D race at large ISL
jiayyu Jun 17, 2026
c471a18
sched: add long_prefill_token_threshold + disable deferred out for mixed
jiayyu Jun 18, 2026
a8ea258
scripts: disable core dumps in start_atom_server.sh
jiayyu Jun 18, 2026
957d6ad
runner: correct deferred+mixed comment — accuracy bug, not memory
jiayyu Jun 18, 2026
956e6d2
sched: decode-first budget reservation so mixed batches form (vLLM V1…
jiayyu Jun 18, 2026
54ffed5
runner: add ATOM_FORCE_DEFERRED escape hatch for mixed+deferred
jiayyu Jun 18, 2026
aaee21d
v4(attn): tag mixed dispatch with record_function for trace visibility
jiayyu Jun 18, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions atom/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -981,6 +981,7 @@ class Config:
model: str
trust_remote_code: bool = False
max_num_batched_tokens: int = 16384
long_prefill_token_threshold: int = 0
attn_prefill_chunk_size: int = 16384
scheduler_delay_factor: float = 0.0
max_num_seqs: int = 512
Expand All @@ -999,6 +1000,11 @@ class Config:
kv_cache_dtype: str = "bf16"
enable_prefix_caching: bool = True
enable_chunked_prefill: bool = True
# Mix prefill chunks and decode seqs into the same forward pass (Phase 2
# of chunked prefill). Default off until the attention backends grow
# split-dispatch support — when off, scheduler emits prefill-only or
# decode-only batches as before.
enable_mixed_prefill_decode: bool = False
port: int = 8006
torch_profiler_dir: str | None = field(
default_factory=lambda: envs.ATOM_TORCH_PROFILER_DIR
Expand Down Expand Up @@ -1104,6 +1110,19 @@ def __post_init__(self):
self.max_model_len, hf_config_max_position_embeddings
)
# assert self.max_num_batched_tokens >= self.max_model_len
if self.long_prefill_token_threshold > 0:
if self.long_prefill_token_threshold > self.max_model_len:
raise ValueError(
f"long_prefill_token_threshold "
f"({self.long_prefill_token_threshold}) cannot be greater "
f"than max_model_len ({self.max_model_len})."
)
if self.long_prefill_token_threshold < self.kv_cache_block_size:
raise ValueError(
f"long_prefill_token_threshold "
f"({self.long_prefill_token_threshold}) must be >= "
f"kv_cache_block_size ({self.kv_cache_block_size})."
)
if not is_plugin_mode():
if self.torch_profiler_dir is not None:
os.makedirs(self.torch_profiler_dir, exist_ok=True)
Expand Down
20 changes: 20 additions & 0 deletions atom/model_engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,13 @@ class EngineArgs:
data_parallel_size: int = 1
enforce_eager: bool = False
enable_prefix_caching: bool = True
enable_mixed_prefill_decode: bool = False
port: int = 8006
kv_cache_dtype: str = "bf16"
block_size: int = 16
max_model_len: Optional[int] = None
max_num_batched_tokens: int = 16384
long_prefill_token_threshold: int = 0
attn_prefill_chunk_size: int = 16384
enable_chunked_prefill: bool = True
scheduler_delay_factor: float = 0.0
Expand Down Expand Up @@ -97,6 +99,13 @@ def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
help="Enable prefix caching (default: enabled). "
"Use --no-enable_prefix_caching to disable.",
)
parser.add_argument(
"--enable-mixed-prefill-decode",
action="store_true",
help="Pack prefill chunks and decode seqs into the same forward "
"pass. Requires attention backends with split-dispatch support; "
"off by default.",
)
parser.add_argument(
"--port",
type=int,
Expand Down Expand Up @@ -192,6 +201,17 @@ def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
default=16384,
help="Maximum number of tokens to batch together in async engine",
)
parser.add_argument(
"--long-prefill-token-threshold",
type=int,
default=0,
help=(
"For chunked prefill, cap a single request's per-step prefill "
"size at this many tokens. 0 disables the cap (request is only "
"bounded by max_num_batched_tokens). Useful to interleave long "
"prefills with decode for lower ITL."
),
)
parser.add_argument(
"--attn-prefill-chunk-size",
type=int,
Expand Down
104 changes: 103 additions & 1 deletion atom/model_engine/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,26 @@ def __init__(
num_spec_tokens: int = 0,
):
"""Asynchronously copy the sampled_token_ids tensor to the host."""
self.is_deferred_out = True
# Deferred output is disabled when running in P/D disaggregation mode
# (kv_transfer_config is set), enabled otherwise.
# TODO: In P/D disaggregation mode, if have issue, we can disable it
# Mixed prefill+decode: the deferred GPU-gather path has a known
# accuracy bug (idle decode seqs read a placeholder token across a
# prefill chunk — R1 GSM8K 0.87 vs 0.9469). It is NOT a memory bug:
# verified crash-free at conc 2048 / ISL 8192 with deferred forced on.
# Disable deferred under the mixed flag until the accuracy bug is fixed.
self.is_deferred_out = not getattr(
runner.config, "enable_mixed_prefill_decode", False
)
# Escape hatch: ATOM_FORCE_DEFERRED=1 forces deferred output ON even
# under the mixed flag. Deferred+mixed has a known accuracy bug (not a
# crash — verified crash-free at conc 2048 / ISL 8192) but is faster, so
# this lets us measure mixed+deferred throughput while the accuracy fix
# is pending. Do NOT use for accuracy-sensitive runs.
import os as _os

if _os.environ.get("ATOM_FORCE_DEFERRED") == "1":
self.is_deferred_out = True

self.runner = runner
device = runner.device
Expand Down Expand Up @@ -355,6 +374,7 @@ def prepare_input_ids(
total_tokens_prefill = batch.total_tokens_num_prefill
total_tokens_decode = batch.total_tokens_num_decode
total_reqs_prefill = batch.total_seqs_num_prefill
is_mixed = getattr(batch, "is_mixed", False)
"""for prefill: all input ids are new"""
self.input_ids.np[:total_tokens_prefill] = scheduled_tokens[
:total_tokens_prefill
Expand All @@ -363,6 +383,75 @@ def prepare_input_ids(

self.prev_rejected_num, self.prev_bonus_num = self.recv_mtp_status_async()

if is_mixed:
# Mixed batch layout: [prefill_tokens | decode_tokens]. The prefill
# region is already written above. Fill the decode region (one token
# per decode seq, in batch order — which matches the decode attention
# metadata's row order) starting at `decode_offset`.
# MTP / speculative decode with mixed batches is a separate follow-up
# (the per-seq multi-token layout isn't wired into this branch).
assert not self.use_spec, (
"Mixed prefill+decode batches do not yet support MTP / speculative "
"decode (follow-up). Disable --enable-mixed-prefill-decode for now."
)
decode_offset = total_tokens_prefill
sched_decode = scheduled_tokens[
decode_offset : decode_offset + total_tokens_decode
]

# Non-deferred OR first step (no prior batch to gather from): decode
# inputs come straight from scheduled_tokens. This is the path
# already verified at GSM8K parity on R1 / V4-Pro.
if not self.is_deferred_out or self.prev_batch is None:
self.input_ids.np[
decode_offset : decode_offset + total_tokens_decode
] = sched_decode
self.input_ids.copy_to_gpu(total_tokens)
return self.input_ids.gpu[:total_tokens]

# Deferred path: each decode seq's input is the token sampled for it
# last step, kept on-GPU in `prev_token_ids` (ordered by
# prev_batch.req_ids). Map current decode seqs — batch positions
# [n_prefill_seqs:] — to their prev_batch slot. A decode seq is
# ALWAYS in prev_batch in steady state (it decoded last step); a
# genuinely-new decode row (just finished prefill elsewhere) falls
# back to scheduled_tokens. Prefill rows are never gathered.
#
# Destination index = decode_offset + i, where i is the decode seq's
# position within the decode segment — identical to the decode
# attention metadata row order, guaranteeing alignment.
n_prefill_seqs = batch.total_seqs_num_prefill
prev_id_to_idx = {rid: j for j, rid in enumerate(self.prev_batch.req_ids)}
deferred_dst: list[int] = []
deferred_prev: list[int] = []
for i, rid in enumerate(batch.req_ids[n_prefill_seqs:]):
prev_idx = prev_id_to_idx.get(rid)
if prev_idx is not None:
deferred_dst.append(decode_offset + i)
deferred_prev.append(prev_idx)

# Baseline the whole decode region from scheduled_tokens (correct for
# any new decode seq not in prev_batch). Deferred positions are then
# overwritten GPU-side by the gather below.
self.input_ids.np[decode_offset : decode_offset + total_tokens_decode] = (
sched_decode
)
self.input_ids.copy_to_gpu(total_tokens)

if deferred_dst:
self.input_ids_loc.np[: len(deferred_prev)] = deferred_prev
prev_idx_gpu = self.input_ids_loc.copy_to_gpu(len(deferred_prev))
gathered = torch.gather(self.prev_token_ids, 0, prev_idx_gpu)
dst_gpu = torch.as_tensor(
deferred_dst, dtype=torch.long, device=self.input_ids.gpu.device
)
self.input_ids.gpu[dst_gpu] = gathered.to(self.input_ids.gpu.dtype)

# prev_batch / prev_token_ids are advanced by prepare_sampled_ids
# (postprocess) after sampling — NOT here, exactly like the non-mixed
# deferred path.
return self.input_ids.gpu[:total_tokens]

# TODO: remove this when we support mixed prefill and decode in one batch
if total_reqs_prefill > 0:
return self.input_ids.gpu[:total_tokens_prefill]
Expand Down Expand Up @@ -1660,6 +1749,10 @@ def _maybe_create_tbo_slices(
With the packed-reduce path the eligibility (local + cross-DP AND)
is decided in ``_preprocess``; here we just realise the split.
"""
if getattr(batch, "is_mixed", False):
# TBO ubatch splitting on a [prefill | decode] layout is not yet
# supported (P2-M5 follow-up). Run mixed batches without TBO.
return None
if not tbo_collective_active:
return None

Expand Down Expand Up @@ -1758,6 +1851,7 @@ def _preprocess(

def prepare_inputs(self, batch: ScheduledBatch, input_ids: torch.Tensor = None):
is_prefill = batch.total_tokens_num_prefill > 0
is_mixed = getattr(batch, "is_mixed", False)
bs = batch.total_seqs_num
num_scheduled_tokens = np.asarray(batch.num_scheduled_tokens)
cu_seqlens_q, arange = self._get_cumsum_and_arange(num_scheduled_tokens)
Expand Down Expand Up @@ -1814,11 +1908,19 @@ def prepare_inputs(self, batch: ScheduledBatch, input_ids: torch.Tensor = None):
graph_bs=graph_bs,
dp_uniform_decode=dp_uniform_decode,
forward_mode=forward_mode,
is_mixed=is_mixed,
num_prefill_tokens=batch.total_tokens_num_prefill if is_mixed else 0,
num_prefill_seqs=batch.total_seqs_num_prefill if is_mixed else 0,
)

actual_num_tokens = batch.total_tokens_num

spec_decode_metadata = None
if is_mixed and hasattr(self, "drafter") and not batch.is_dummy_run:
raise NotImplementedError(
"Mixed prefill+decode batches do not yet support MTP / speculative "
"decode (P2-M4 follow-up). Disable --enable-mixed-prefill-decode."
)
if not is_prefill and hasattr(self, "drafter") and not batch.is_dummy_run:
scheduled_bs = batch.total_seqs_num_decode
spec_decode_metadata = self.drafter.calc_spec_decode_metadata(
Expand Down
Loading
Loading