Skip to content
4 changes: 4 additions & 0 deletions atom/model_engine/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1595,6 +1595,10 @@ def allocate_kv_cache(self, num_kvcache_blocks):
for i, kv_cache_tensor in enumerate(kv_cache_tensors)
}
transfer_tensors = self.attn_metadata_builder.get_kv_transfer_tensors()
if hasattr(self, "eagle3_draft_builder") and transfer_tensors is not None:
draft_regions = self.eagle3_draft_builder.get_kv_transfer_tensors()
if draft_regions:
transfer_tensors.block_regions.extend(draft_regions)
set_kv_cache_data(kv_cache_data, config, transfer_tensors)

# Cross-validate: compare estimated vs actual KV cache allocation.
Expand Down
13 changes: 13 additions & 0 deletions atom/model_engine/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1226,6 +1226,19 @@ def postprocess(
# no-op).
if stop_at_idx is not None and stop_at_idx < num_new_token - 1:
num_tokens -= (num_new_token - 1) - stop_at_idx
# The same truncation MUST apply to the EMITTED tokens, not just
# the internal seq length. The client-visible text is built from
# RequestOutput.output_tokens (an accumulation of `new_tokens`) by
# generate_async / the streaming callback — NOT from
# completion_token_ids (which the `seq.num_tokens` write above
# governs). Without trimming `new_tokens` here, the post-stop
# tokens the rejection sampler emits past EOS (it does not inspect
# EOS) leak into the response: strict-match still finds the answer,
# but flexible-extract's last-number picks up the leaked trailing
# digit. `injected_t0` (if present) prepends one slot not counted
# in stop_at_idx / num_new_token, so offset the cut by it.
keep = stop_at_idx + 1 + (1 if injected_t0 is not None else 0)
new_tokens = new_tokens[:keep]

# Prepare stream output
if stream_output_queue is not None and new_tokens:
Expand Down
125 changes: 121 additions & 4 deletions atom/model_ops/attentions/aiter_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import aiter
import numpy as np
import torch
import triton
import triton.language as tl
from aiter.dist.parallel_state import get_tp_group
from atom.model_engine.scheduler import ScheduledBatch
from atom.utils import CpuGpuBuffer, envs
Expand All @@ -15,10 +17,7 @@
kv_indices_generate_triton,
)
from atom.model_ops.attention_mha import PagedAttentionImpl, use_pa_decode_bf16_asm
from atom.utils.forward_context import (
AttentionMetaData,
Context,
)
from atom.utils.forward_context import AttentionMetaData, Context, get_forward_context
from atom.utils.tbo import TokenSplitPrefillState

from .backends import AttentionBackend, CommonAttentionBuilder
Expand All @@ -43,6 +42,54 @@ def _is_indexed_sparse_attention(module) -> bool:
return bool(getattr(impl, "is_indexed_sparse_attention", False))


@triton.jit
def _mtp_prepare_decode_metadata_kernel(
    context_lens_ptr,
    block_tables_ptr,
    slot_mapping_ptr,
    positions_in_ptr,
    positions_out_ptr,
    last_token_indices_ptr,
    bs,
    skip_update: tl.constexpr,
    update_context_lens: tl.constexpr,
    update_positions: tl.constexpr,
    select_positions: tl.constexpr,
    block_size: tl.constexpr,
    block_table_stride: tl.constexpr,
    position_stride: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """Fused per-sequence metadata update for one draft decode step.

    Each program handles ``BLOCK`` consecutive sequence rows (rows ``>= bs``
    are masked out) and, per row:

    * optionally increments ``context_lens`` in place
      (``update_context_lens``);
    * optionally writes ``positions_in[src] + 1`` to
      ``positions_out[row * position_stride]`` (``update_positions``), where
      ``src`` is ``last_token_indices[row]`` when ``select_positions`` is set
      and ``row`` otherwise;
    * always writes the KV write slot of the last context token,
      ``block_tables[row, (ctx - 1) // block_size] * block_size
      + (ctx - 1) % block_size``, to ``slot_mapping[row]``.

    With ``skip_update`` set the whole body is compiled out and the launch
    does nothing.
    """
    if not skip_update:
        rows = tl.arange(0, BLOCK) + tl.program_id(0) * BLOCK
        in_range = rows < bs

        # Masked lanes read 1, so their (unused) last-token index below is 0.
        # The value is widened to int64 before any slot arithmetic.
        ctx_len = tl.load(context_lens_ptr + rows, mask=in_range, other=1).to(
            tl.int64
        )
        if update_context_lens:
            ctx_len = ctx_len + 1
            tl.store(context_lens_ptr + rows, ctx_len, mask=in_range)

        if update_positions:
            # Compile-time choice of which input position each row reads.
            if select_positions:
                src_idx = tl.load(
                    last_token_indices_ptr + rows, mask=in_range, other=0
                )
            else:
                src_idx = rows
            prev_pos = tl.load(positions_in_ptr + src_idx, mask=in_range, other=0)
            tl.store(
                positions_out_ptr + rows * position_stride,
                prev_pos + 1,
                mask=in_range,
            )

        # Index of the newest token, clamped so an empty context maps to 0.
        newest = tl.maximum(ctx_len - 1, 0)
        table_col = newest // block_size
        offset_in_block = newest - table_col * block_size

        physical = tl.load(
            block_tables_ptr + rows * block_table_stride + table_col,
            mask=in_range,
            other=0,
        ).to(tl.int64)
        tl.store(
            slot_mapping_ptr + rows,
            physical * block_size + offset_in_block,
            mask=in_range,
        )


class AiterBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
Expand All @@ -59,6 +106,9 @@ def get_impl_cls():

class AiterAttentionMetadataBuilder(CommonAttentionBuilder):
BLOCK_TABLE_EXTENDER: list[list[int]] = [[]]
# EagleProposer fuses the per-draft-step position bump into
# prepare_mtp_decode's kernel when this is set (block-paged MHA draft).
fuse_mtp_decode_position_update = True

def __init__(
self,
Expand Down Expand Up @@ -712,6 +762,73 @@ def _add_region(tensor):
num_blocks=runner.num_physical_kvcache_blocks,
)

def prepare_mtp_decode(
    self,
    bs: int,
    max_seqlen_q: int,
    max_seqlen_k: int,
    positions: torch.Tensor,
    only_update: bool = False,
    num_reject_tokens: torch.Tensor = None,
    *,
    update_context_lens: bool = False,
    positions_out: torch.Tensor | None = None,
    last_token_indices: torch.Tensor | None = None,
):
    """Refresh per-draft-step decode metadata for a block-paged MHA Eagle3 draft.

    EagleProposer.propose calls this on its mid-step iterations. The draft's
    decode kernels (``paged_attention_{asm,triton}``) consume ``block_tables``
    and ``context_lens``; ``context_lens`` may already have been bumped by the
    caller, or this call can bump it in place via ``update_context_lens``.
    ``kv_indptr``/``kv_indices`` are only consumed by the block_size==1024
    persistent path; MiniMax-M3 runs at ``--block-size 128``, so they are not
    rebuilt here.

    What does get recomputed is the slot where the new draft token is written
    in the draft's own block-paged KV cache:

        slot = block_tables[seq, (ctx-1)//B] * B + (ctx-1) % B,  B = block_size

    Args:
        bs: Number of sequences; only the first ``bs`` slot-mapping entries
            are touched.
        max_seqlen_q: Unused by this builder.
        max_seqlen_k: Unused by this builder.
        positions: Input positions read by the kernel when a position update
            is requested.
        only_update: MLA/V4-specific knob, unused here (no persistent worker
            buffers to roll over for ``block_size != 1024``).
        num_reject_tokens: MLA/V4-specific knob, unused here.
        update_context_lens: When true, the kernel increments
            ``context_lens`` in place before deriving the slot.
        positions_out: When given, the kernel writes ``position + 1`` for
            each sequence into it; when ``None`` positions are left alone.
        last_token_indices: When given together with ``positions_out``,
            selects which entry of ``positions`` each sequence reads.

    Returns:
        ``{"slot_mapping": <view of the first bs slots>}``. Returning it lets
        EagleProposer skip its token-granular (MLA physical block_size==1)
        flat-kv slot derivation, which would yield a bare block id for
        ``B > 1``.
    """
    runner = self.model_runner
    buffers = runner.forward_vars
    slot_mapping = buffers["slot_mapping"].gpu[:bs]
    block_tables = buffers["block_tables"].gpu
    context_lens = buffers["context_lens"].gpu

    bump_positions = positions_out is not None
    gather_positions = bump_positions and last_token_indices is not None

    # The kernel signature always takes these pointers; when the matching
    # feature is off, pass an existing tensor as a placeholder.
    out_positions = positions_out if bump_positions else positions
    index_source = (
        last_token_indices if last_token_indices is not None else slot_mapping
    )
    out_stride = out_positions.stride(0) if bump_positions else 1

    # Dummy runs skip the draft attention, so keep this launch as a no-op:
    # their synthetic context_lens can point past block_tables.
    skip_work = bs == 0 or get_forward_context().context.is_dummy_run

    launch_grid = (max(1, triton.cdiv(bs, 128)),)
    _mtp_prepare_decode_metadata_kernel[launch_grid](
        context_lens,
        block_tables,
        slot_mapping,
        positions,
        out_positions,
        index_source,
        bs,
        skip_work,
        update_context_lens,
        bump_positions,
        gather_positions,
        runner.block_size,
        block_tables.stride(0),
        out_stride,
        BLOCK=128,
    )
    return {"slot_mapping": slot_mapping}

def prepare_prefill(self, batch: ScheduledBatch):
attn_metadata, positions = CommonAttentionBuilder.prepare_prefill(self, batch)
if self._has_sparse_attention and not attn_metadata.has_cached:
Expand Down
59 changes: 59 additions & 0 deletions atom/model_ops/embed_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from aiter.dist.parallel_state import get_tp_group
from aiter.jit.utils.torch_guard import torch_compile_guard

from atom.model_ops.lm_head_argmax import lm_head_argmax_pack
from atom.model_ops.utils import atom_parameter
from atom.plugin import is_plugin_mode
from atom.utils import envs
Expand Down Expand Up @@ -151,6 +152,41 @@ def forward(self, x: torch.Tensor):
# return y


class ReplicatedEmbedding(nn.Module):
    """Vocab embedding stored in full on every TP rank (never sharded).

    Every rank owns the whole ``[num_embeddings, embedding_dim]`` table, so a
    lookup is purely local and the forward pass needs **no all-reduce**. This
    contrasts with ``VocabParallelEmbedding``, which shards the vocab and has
    to all-reduce masked partial lookups to rebuild the full vector.

    The cost is ``(tp-1)/tp`` of the embedding's memory per rank, in exchange
    for one fewer collective per embed. Use it ONLY when the embedding is
    independent of any sharded ``lm_head`` (e.g. the EAGLE3 draft, whose
    embed/lm_head are separate tensors). Do NOT use it for an embedding that
    is shared/tied with a TP-sharded lm_head or with the target model's
    sharded embedding.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        """Allocate an uninitialized full-size table and register its loader."""
        super().__init__()
        self.num_embeddings = num_embeddings
        table = torch.empty(num_embeddings, embedding_dim)
        self.weight = atom_parameter(table)
        # The loader hangs off the parameter itself; presumably checkpoint
        # loading discovers it there — verify against the loader code.
        self.weight.weight_loader = self.weight_loader

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy the complete checkpoint tensor into ``param`` (no slicing).

        The checkpoint tensor must already have exactly the parameter's
        shape, since every rank receives the full table.
        """
        target = param.data
        want, got = target.size(), loaded_weight.size()
        assert want == got, (
            f"ReplicatedEmbedding expects the full weight "
            f"{tuple(want)}, got {tuple(got)}"
        )
        target.copy_(loaded_weight)

    def forward(self, x: torch.Tensor):
        """Look up the rows of the local table indexed by ``x``."""
        return F.embedding(input=x, weight=self.weight)


class ParallelLMHead(VocabParallelEmbedding):

def __init__(
Expand Down Expand Up @@ -190,3 +226,26 @@ def forward(self, x: torch.Tensor):
# dist.gather(logits, all_logits, 0)
# logits = torch.cat(all_logits, -1) if self.tp_rank == 0 else None
return logits

def compute_argmax_token(self, x: torch.Tensor) -> torch.Tensor:
    """Return the greedy (argmax) token id per row without gathering logits.

    Greedy speculative drafting only needs the argmax, so the full
    ``[N, vocab]`` logits are never all-gathered. Each rank reduces its own
    vocab shard to a ``(max_val, global_idx)`` pair per row, and only those
    ``[N, 2]`` tensors are exchanged across the (small) TP group.

    The chosen token matches a full-logits ``argmax``: the compared values
    are the same bf16 logits (packed into fp32 exactly), and ties resolve to
    the lowest global index — ``torch.max`` takes the lowest local index and
    the cross-rank ``argmax`` takes the lowest rank (== lowest vocab range).

    Args:
        x: Hidden states fed to the LM head.

    Returns:
        ``[N]`` token ids.
    """
    shard_logits = tgemm.mm(x, self.weight, self.bias)  # [N, vocab/tp]

    # Single rank: the local shard is the whole vocab.
    if self.tp_size <= 1:
        return shard_logits.argmax(dim=-1)

    # (val, idx) pairs are packed as fp32 — idx < 2^24 is exact — so only the
    # per-rank reductions ([N, 2]) cross the wire instead of the full logits.
    per_rank = lm_head_argmax_pack(shard_logits, self.vocab_start_idx)
    stacked = get_tp_group().all_gather(per_rank, dim=0)
    stacked = stacked.view(self.tp_size, -1, 2)  # [tp, N, 2]

    best_vals = stacked[..., 0]
    best_ids = stacked[..., 1]
    winning_rank = best_vals.argmax(dim=0)  # [N] winning rank (ties -> lowest)
    chosen = torch.gather(best_ids, 0, winning_rank[None])[0]  # [N] fp32
    return chosen.long()
Comment on lines +242 to +251
Loading
Loading