diff --git a/atom/model_engine/model_runner.py b/atom/model_engine/model_runner.py index aaf8789496..8270ff1cdc 100644 --- a/atom/model_engine/model_runner.py +++ b/atom/model_engine/model_runner.py @@ -1595,6 +1595,10 @@ def allocate_kv_cache(self, num_kvcache_blocks): for i, kv_cache_tensor in enumerate(kv_cache_tensors) } transfer_tensors = self.attn_metadata_builder.get_kv_transfer_tensors() + if hasattr(self, "eagle3_draft_builder") and transfer_tensors is not None: + draft_regions = self.eagle3_draft_builder.get_kv_transfer_tensors() + if draft_regions: + transfer_tensors.block_regions.extend(draft_regions) set_kv_cache_data(kv_cache_data, config, transfer_tensors) # Cross-validate: compare estimated vs actual KV cache allocation. diff --git a/atom/model_engine/scheduler.py b/atom/model_engine/scheduler.py index dab098f1a9..6db875046f 100644 --- a/atom/model_engine/scheduler.py +++ b/atom/model_engine/scheduler.py @@ -1226,6 +1226,19 @@ def postprocess( # no-op). if stop_at_idx is not None and stop_at_idx < num_new_token - 1: num_tokens -= (num_new_token - 1) - stop_at_idx + # The same truncation MUST apply to the EMITTED tokens, not just + # the internal seq length. The client-visible text is built from + # RequestOutput.output_tokens (an accumulation of `new_tokens`) by + # generate_async / the streaming callback — NOT from + # completion_token_ids (which the `seq.num_tokens` write above + # governs). Without trimming `new_tokens` here, the post-stop + # tokens the rejection sampler emits past EOS (it does not inspect + # EOS) leak into the response: strict-match still finds the answer, + # but flexible-extract's last-number picks up the leaked trailing + # digit. `injected_t0` (if present) prepends one slot not counted + # in stop_at_idx / num_new_token, so offset the cut by it. 
+ keep = stop_at_idx + 1 + (1 if injected_t0 is not None else 0) + new_tokens = new_tokens[:keep] # Prepare stream output if stream_output_queue is not None and new_tokens: diff --git a/atom/model_ops/attentions/aiter_attention.py b/atom/model_ops/attentions/aiter_attention.py index a7b1a90227..969a3b1021 100644 --- a/atom/model_ops/attentions/aiter_attention.py +++ b/atom/model_ops/attentions/aiter_attention.py @@ -7,6 +7,8 @@ import aiter import numpy as np import torch +import triton +import triton.language as tl from aiter.dist.parallel_state import get_tp_group from atom.model_engine.scheduler import ScheduledBatch from atom.utils import CpuGpuBuffer, envs @@ -15,10 +17,7 @@ kv_indices_generate_triton, ) from atom.model_ops.attention_mha import PagedAttentionImpl, use_pa_decode_bf16_asm -from atom.utils.forward_context import ( - AttentionMetaData, - Context, -) +from atom.utils.forward_context import AttentionMetaData, Context, get_forward_context from atom.utils.tbo import TokenSplitPrefillState from .backends import AttentionBackend, CommonAttentionBuilder @@ -43,6 +42,54 @@ def _is_indexed_sparse_attention(module) -> bool: return bool(getattr(impl, "is_indexed_sparse_attention", False)) +@triton.jit +def _mtp_prepare_decode_metadata_kernel( + context_lens_ptr, + block_tables_ptr, + slot_mapping_ptr, + positions_in_ptr, + positions_out_ptr, + last_token_indices_ptr, + bs, + skip_update: tl.constexpr, + update_context_lens: tl.constexpr, + update_positions: tl.constexpr, + select_positions: tl.constexpr, + block_size: tl.constexpr, + block_table_stride: tl.constexpr, + position_stride: tl.constexpr, + BLOCK: tl.constexpr, +): + if not skip_update: + seq = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK) + mask = seq < bs + + ctx = tl.load(context_lens_ptr + seq, mask=mask, other=1).to(tl.int64) + if update_context_lens: + ctx += 1 + tl.store(context_lens_ptr + seq, ctx, mask=mask) + + if update_positions: + pos_idx = seq + if select_positions: + pos_idx = 
tl.load(last_token_indices_ptr + seq, mask=mask, other=0) + pos = tl.load(positions_in_ptr + pos_idx, mask=mask, other=0) + tl.store(positions_out_ptr + seq * position_stride, pos + 1, mask=mask) + + last_pos = tl.maximum(ctx - 1, 0) + block_col = last_pos // block_size + within_block = last_pos - block_col * block_size + + phys_block = tl.load( + block_tables_ptr + seq * block_table_stride + block_col, + mask=mask, + other=0, + ).to(tl.int64) + tl.store( + slot_mapping_ptr + seq, phys_block * block_size + within_block, mask=mask + ) + + class AiterBackend(AttentionBackend): @staticmethod def get_name() -> str: @@ -59,6 +106,9 @@ def get_impl_cls(): class AiterAttentionMetadataBuilder(CommonAttentionBuilder): BLOCK_TABLE_EXTENDER: list[list[int]] = [[]] + # EagleProposer fuses the per-draft-step position bump into + # prepare_mtp_decode's kernel when this is set (block-paged MHA draft). + fuse_mtp_decode_position_update = True def __init__( self, @@ -712,6 +762,73 @@ def _add_region(tensor): num_blocks=runner.num_physical_kvcache_blocks, ) + def prepare_mtp_decode( + self, + bs: int, + max_seqlen_q: int, + max_seqlen_k: int, + positions: torch.Tensor, + only_update: bool = False, + num_reject_tokens: torch.Tensor = None, + *, + update_context_lens: bool = False, + positions_out: torch.Tensor | None = None, + last_token_indices: torch.Tensor | None = None, + ): + """Per-draft-step metadata for a block-paged MHA Eagle3 draft. + + Called by EagleProposer.propose at mid-step iters. The draft's decode + kernels (``paged_attention_{asm,triton}``) read ``block_tables`` + + ``context_lens``. Eagle can pre-bump ``context_lens`` before this call, + or ask this fused kernel to update it in place. The block_size==1024 + persistent path is the only one consuming ``kv_indptr``/``kv_indices``; + MiniMax-M3 runs at ``--block-size 128`` so the kernel never reads them - + no rebuild. 
+ + The one value we must (re)compute is the write slot for the new draft + token in the draft's own block-paged KV cache: + + slot = block_tables[seq, (ctx-1)//B] * B + (ctx-1) % B, B = block_size + + Returned under ``slot_mapping`` so EagleProposer skips its token-granular + (MLA physical block_size==1) flat-kv slot derivation, which would yield + a bare block id for ``B > 1``. + + ``only_update`` / ``num_reject_tokens`` are MLA/V4-specific knobs and are + unused here: there are no persistent worker buffers to roll over for + ``block_size != 1024``. + """ + var = self.model_runner.forward_vars + slot_mapping = var["slot_mapping"].gpu[:bs] + block_tables = var["block_tables"].gpu + context_lens = var["context_lens"].gpu + update_positions = positions_out is not None + select_positions = update_positions and last_token_indices is not None + if positions_out is None: + positions_out = positions + if last_token_indices is None: + last_token_indices = slot_mapping + # Dummy runs skip the draft attention, so keep this launch as a no-op: + # their synthetic context_lens can point past block_tables. 
+ _mtp_prepare_decode_metadata_kernel[(max(1, triton.cdiv(bs, 128)),)]( + context_lens, + block_tables, + slot_mapping, + positions, + positions_out, + last_token_indices, + bs, + bs == 0 or get_forward_context().context.is_dummy_run, + update_context_lens, + update_positions, + select_positions, + self.model_runner.block_size, + block_tables.stride(0), + positions_out.stride(0) if update_positions else 1, + BLOCK=128, + ) + return {"slot_mapping": slot_mapping} + def prepare_prefill(self, batch: ScheduledBatch): attn_metadata, positions = CommonAttentionBuilder.prepare_prefill(self, batch) if self._has_sparse_attention and not attn_metadata.has_cached: diff --git a/atom/model_ops/embed_head.py b/atom/model_ops/embed_head.py index 0c2ca9bd60..15a8505c97 100644 --- a/atom/model_ops/embed_head.py +++ b/atom/model_ops/embed_head.py @@ -10,6 +10,7 @@ from aiter.dist.parallel_state import get_tp_group from aiter.jit.utils.torch_guard import torch_compile_guard +from atom.model_ops.lm_head_argmax import lm_head_argmax_pack from atom.model_ops.utils import atom_parameter from atom.plugin import is_plugin_mode from atom.utils import envs @@ -151,6 +152,41 @@ def forward(self, x: torch.Tensor): # return y +class ReplicatedEmbedding(nn.Module): + """Full vocab embedding replicated on every TP rank (no sharding). + + Each rank holds the complete ``[num_embeddings, embedding_dim]`` table and + does a purely local lookup, so the forward needs **no all-reduce** — unlike + ``VocabParallelEmbedding``, which shards the vocab and must all-reduce the + masked partial lookups to reconstruct the full vector. + + Trades ``(tp-1)/tp`` of the embedding's memory per rank for one fewer + collective per embed. Use ONLY where the embedding is independent of any + sharded ``lm_head`` (e.g. the EAGLE3 draft, whose embed/lm_head are separate + tensors). Do NOT use for an embedding shared/tied with a TP-sharded lm_head + or with the target model's sharded embedding. 
+ """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + super().__init__() + self.num_embeddings = num_embeddings + self.weight = atom_parameter( + torch.empty(num_embeddings, embedding_dim), + ) + self.weight.weight_loader = self.weight_loader + + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): + # Full (un-sharded) copy: every rank gets the complete table. + assert param.data.size() == loaded_weight.size(), ( + f"ReplicatedEmbedding expects the full weight " + f"{tuple(param.data.size())}, got {tuple(loaded_weight.size())}" + ) + param.data.copy_(loaded_weight) + + def forward(self, x: torch.Tensor): + return F.embedding(x, self.weight) + + class ParallelLMHead(VocabParallelEmbedding): def __init__( @@ -190,3 +226,26 @@ def forward(self, x: torch.Tensor): # dist.gather(logits, all_logits, 0) # logits = torch.cat(all_logits, -1) if self.tp_rank == 0 else None return logits + + def compute_argmax_token(self, x: torch.Tensor) -> torch.Tensor: + """Greedy argmax token over the (TP-sharded) vocab — returns ``[N]`` token + ids WITHOUT all-gathering the full ``[N, vocab]`` logits. + + For greedy speculative drafting only the argmax is needed, so each rank + reduces its own vocab shard to ``(max_val, global_idx)`` and we all-gather + just those ``[N, 2]`` (tp small) instead of the O(vocab) logits. Token + selection is identical to a full-logits ``argmax``: the values compared + are the same bf16 logits (fp32-packed exactly), and tie-breaking matches + the lowest global index — ``torch.max`` picks the lowest local index, and + ``argmax`` over ranks picks the lowest rank (== lowest vocab range). + """ + logits = tgemm.mm(x, self.weight, self.bias) # [N, vocab/tp] + if self.tp_size <= 1: + return logits.argmax(dim=-1) + # Pack (val, idx) as fp32 — idx < 2^24 is exact — and all-gather only the + # per-rank reductions ([N, 2]) instead of the full logits. 
+        packed = lm_head_argmax_pack(logits, self.vocab_start_idx)
+        gathered = get_tp_group().all_gather(packed, dim=0).view(self.tp_size, -1, 2)
+        winner = gathered[:, :, 0].argmax(dim=0)  # [N] winning rank (ties -> lowest)
+        token = gathered[:, :, 1].gather(0, winner.unsqueeze(0)).squeeze(0)  # [N] fp32
+        return token.to(torch.long)
diff --git a/atom/model_ops/fused_aux_rmsnorm.py b/atom/model_ops/fused_aux_rmsnorm.py
new file mode 100644
index 0000000000..9f512284f0
--- /dev/null
+++ b/atom/model_ops/fused_aux_rmsnorm.py
@@ -0,0 +1,211 @@
+# SPDX-License-Identifier: MIT
+# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+"""Fused per-group RMSNorm for EAGLE3 aux hidden-state fusion.
+
+EAGLE3's ``combine_hidden_states`` normalizes ``num_aux`` aux chunks (each with
+its own ``fc_norm`` weight) and concatenates them into the ``[N, num_aux*H]``
+input of the ``fc`` projection. The naive path launches one RMSNorm per chunk
+plus a concat; this kernel does all chunks in a single launch, writing straight
+into the contiguous ``fc`` input buffer.
+
+Input layout: ``x`` is the concatenated aux ``[N, num_aux*H]`` (view as groups
+of ``H`` along the last dim). ``weight`` is the per-group RMSNorm weights
+stacked to ``[num_aux, H]``. Plain RMSNorm (``x * rstd * w``, fp32 reduction) —
+matches ``atom.model_ops.layernorm.RMSNorm`` (NOT the Gemma ``1+w`` variant).
+"""
+
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _fused_group_rmsnorm_kernel(
+    x_ptr,  # [N, G*H] contiguous
+    w_ptr,  # [G, H] contiguous
+    out_ptr,  # [N, G*H] contiguous
+    n_rows,
+    G: tl.constexpr,
+    H: tl.constexpr,
+    eps,
+    BLOCK_H: tl.constexpr,
+):
+    row = tl.program_id(0)
+    g = tl.program_id(1)
+    col = tl.arange(0, BLOCK_H)
+    mask = col < H
+
+    row_base = row * (G * H) + g * H
+    x = tl.load(x_ptr + row_base + col, mask=mask, other=0.0).to(tl.float32)
+    var = tl.sum(x * x, axis=0) / H
+    rstd = 1.0 / tl.sqrt(var + eps)
+    w = tl.load(w_ptr + g * H + col, mask=mask, other=0.0).to(tl.float32)
+    y = x * rstd * w
+    tl.store(out_ptr + row_base + col, y.to(out_ptr.dtype.element_ty), mask=mask)
+
+
+def fused_group_rmsnorm(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    eps: float,
+    num_groups: int,
+) -> torch.Tensor:
+    """Per-group RMSNorm over a concatenated ``[N, num_groups*H]`` tensor.
+
+    Args:
+        x: contiguous ``[N, num_groups*H]`` (groups of ``H`` along dim -1).
+        weight: per-group weights stacked to ``[num_groups, H]`` (contiguous).
+        eps: RMSNorm epsilon.
+        num_groups: number of aux groups (``G``).
+
+    Returns:
+        ``[N, num_groups*H]`` with each group RMS-normalized by its own weight.
+    """
+    assert x.is_cuda, "fused_group_rmsnorm requires a CUDA tensor."
+    assert x.dim() == 2 and x.is_contiguous()
+    n_rows, total = x.shape
+    assert total % num_groups == 0
+    H = total // num_groups
+    assert weight.shape == (
+        num_groups,
+        H,
+    ), f"weight must be [{num_groups}, {H}], got {tuple(weight.shape)}"
+
+    out = torch.empty_like(x)
+    BLOCK_H = triton.next_power_of_2(H)
+    num_warps = 8 if BLOCK_H >= 4096 else (4 if BLOCK_H >= 1024 else 2)
+    grid = (n_rows, num_groups)
+    _fused_group_rmsnorm_kernel[grid](
+        x,
+        weight.contiguous(),
+        out,
+        n_rows,
+        num_groups,
+        H,
+        float(eps),
+        BLOCK_H=BLOCK_H,
+        num_warps=num_warps,
+    )
+    return out
+
+
+# ---------------------------------------------------------------------------
+# Dual-input RMSNorm + concat (EAGLE3 draft decoder-layer attention input)
+#
+# The Eagle3 draft decoder layer normalizes two same-shaped ``[N, H]`` inputs
+# (``embeds`` with ``input_layernorm``, ``hidden_states`` with ``hidden_norm``)
+# and concatenates them into the ``[N, 2H]`` QKV input. The naive path is two
+# RMSNorm launches + a concat (3 launches; the concat re-reads + re-writes 2NH).
+# This kernel does it in a single launch that writes each normalized half
+# straight into the contiguous ``[N, 2H]`` output, cutting memory traffic from
+# ~8NH (norm+norm+cat) to ~4NH. Plain RMSNorm math (``x * rstd * w``, fp32
+# reduction) — matches ``atom.model_ops.layernorm.RMSNorm`` and the sibling
+# ``fused_group_rmsnorm`` above.
+#
+# Raw Triton (no custom-op wrapper): the EAGLE3 draft is built with
+# ``CompilationLevel.NO_COMPILATION`` (eagle.py), so its forward always runs
+# eager and never enters Dynamo — same as ``fused_group_rmsnorm`` above.
+#
+# grid = (n_rows, 2): program (row, 0) normalizes ``a`` -> out[:, :H], program
+# (row, 1) normalizes ``b`` -> out[:, H:]. 2*n_rows programs (vs n_rows) keeps
+# occupancy up at small batch (EAGLE decode N == bs).
+# ---------------------------------------------------------------------------
+
+
+@triton.jit
+def _fused_dual_rmsnorm_cat_kernel(
+    a_ptr,  # [N, H] contiguous
+    b_ptr,  # [N, H] contiguous
+    wa_ptr,  # [H]
+    wb_ptr,  # [H]
+    out_ptr,  # [N, 2H] contiguous
+    H,
+    eps,
+    BLOCK_H: tl.constexpr,
+):
+    row = tl.program_id(0)
+    g = tl.program_id(1)  # 0 -> (a, wa) into out[:, :H]; 1 -> (b, wb) into out[:, H:]
+    col = tl.arange(0, BLOCK_H)
+    mask = col < H
+
+    # g is uniform across the program (one (row, half) per program), so this is
+    # uniform control flow — no divergence, and avoids selecting between two
+    # base pointers (unsupported in Triton). Weights are reused across rows of
+    # the same half, so keep them resident with evict_last.
+    if g == 0:
+        x = tl.load(a_ptr + row * H + col, mask=mask, other=0.0).to(tl.float32)
+        w = tl.load(
+            wa_ptr + col, mask=mask, other=0.0, eviction_policy="evict_last"
+        ).to(tl.float32)
+    else:
+        x = tl.load(b_ptr + row * H + col, mask=mask, other=0.0).to(tl.float32)
+        w = tl.load(
+            wb_ptr + col, mask=mask, other=0.0, eviction_policy="evict_last"
+        ).to(tl.float32)
+
+    var = tl.sum(x * x, axis=0) / H
+    rstd = tl.rsqrt(var + eps)
+    y = x * rstd * w
+    tl.store(
+        out_ptr + row * (2 * H) + g * H + col,
+        y.to(out_ptr.dtype.element_ty),
+        mask=mask,
+    )
+
+
+def fused_dual_rmsnorm_cat(
+    a: torch.Tensor,
+    b: torch.Tensor,
+    w_a: torch.Tensor,
+    w_b: torch.Tensor,
+    eps: float,
+) -> torch.Tensor:
+    """RMS-norm two ``[N, H]`` inputs by their own weights into one ``[N, 2H]``.
+
+    ``out[:, :H] = rmsnorm(a, w_a)``, ``out[:, H:] = rmsnorm(b, w_b)`` — the
+    concatenated attention input for the Eagle3 draft decoder layer, produced
+    in a single Triton launch (no separate per-input norm + concat).
+
+    Args:
+        a, b: contiguous ``[N, H]`` inputs (same shape).
+        w_a, w_b: per-input RMSNorm weights ``[H]``.
+        eps: RMSNorm epsilon (shared by both norms).
+
+    Returns:
+        contiguous ``[N, 2H]`` with the two normalized halves side by side.
+
+    Raises:
+        AssertionError: if ``a``/``b`` are not contiguous 2-D CUDA tensors of one
+            shape, or a weight does not hold exactly ``H`` elements.
+    """
+    # Same up-front validation as ``fused_group_rmsnorm``: the kernel indexes
+    # ``a``/``b`` as dense ``[N, H]`` rows and each weight as ``H`` contiguous
+    # values, so a mismatch would read out of bounds instead of failing.
+    assert a.is_cuda and b.is_cuda, "fused_dual_rmsnorm_cat requires CUDA tensors."
+    assert a.dim() == 2 and a.shape == b.shape, (
+        "a and b must share one 2-D shape, "
+        f"got {tuple(a.shape)} and {tuple(b.shape)}"
+    )
+    assert a.is_contiguous() and b.is_contiguous()
+    n_rows, H = a.shape
+    assert w_a.numel() == H and w_b.numel() == H, (
+        f"w_a and w_b must each hold {H} elements, "
+        f"got {w_a.numel()} and {w_b.numel()}"
+    )
+    out = torch.empty((n_rows, 2 * H), dtype=a.dtype, device=a.device)
+    BLOCK_H = triton.next_power_of_2(H)
+    num_warps = 8 if BLOCK_H >= 4096 else (4 if BLOCK_H >= 1024 else 2)
+    grid = (n_rows, 2)
+    _fused_dual_rmsnorm_cat_kernel[grid](
+        a,
+        b,
+        w_a.contiguous(),
+        w_b.contiguous(),
+        out,
+        H,
+        float(eps),
+        BLOCK_H=BLOCK_H,
+        num_warps=num_warps,
+    )
+    return out
diff --git a/atom/model_ops/lm_head_argmax.py b/atom/model_ops/lm_head_argmax.py
new file mode 100644
index 0000000000..4137f67656
--- /dev/null
+++ b/atom/model_ops/lm_head_argmax.py
@@ -0,0 +1,116 @@
+import torch
+import triton
+import triton.language as tl
+
+from aiter.jit.utils.torch_guard import torch_compile_guard
+
+_MAX_BLOCK_M = 131072
+# One program reduces one row, so small row counts underutilize the GPU.
+_MIN_ROWS_FOR_FUSED_ARGMAX = 16
+# Token ids are packed as fp32, which holds integers exactly only up to 2**24;
+# a larger id would be rounded silently.
+_MAX_EXACT_FP32_INT = 1 << 24
+
+
+@triton.jit
+def _lm_head_argmax_pack_kernel(
+    logits_ptr,
+    packed_ptr,
+    vocab_start_idx,
+    M: tl.constexpr,
+    stride_logits_n: tl.constexpr,
+    stride_logits_m: tl.constexpr,
+    stride_packed_n: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+):
+    # One program per row: reduce the row to (max value, lowest index holding it)
+    # and store the pair as two fp32 values at packed[row, 0] and packed[row, 1].
+    row = tl.program_id(0)
+    offs = tl.arange(0, BLOCK_M)
+    mask = offs < M
+    vals = tl.load(
+        logits_ptr + row * stride_logits_n + offs * stride_logits_m,
+        mask=mask,
+        other=-float("inf"),
+    ).to(tl.float32)
+
+    max_val = tl.max(vals, axis=0)
+    idxs = offs.to(tl.int64)
+    # Lanes that do not hold the max are pushed past BLOCK_M so the min below
+    # picks the lowest matching lane.
+    masked_idxs = tl.where((vals == max_val) & mask, idxs, idxs + BLOCK_M)
+    local_idx = tl.min(masked_idxs, axis=0)
+    # If no in-range lane compared equal to max_val (the equality test fails on
+    # every lane, e.g. NaN values), the min above is >= BLOCK_M; fall back to
+    # lane 0 so the id stays inside [vocab_start_idx, vocab_start_idx + M).
+    local_idx = tl.where(local_idx < M, local_idx, 0)
+    global_idx = local_idx + vocab_start_idx
+
+    tl.store(packed_ptr + row * stride_packed_n, max_val)
+    tl.store(packed_ptr + row * stride_packed_n + 1, global_idx.to(tl.float32))
+
+
+def _lm_head_argmax_pack_fake(
+    logits: torch.Tensor,
+    vocab_start_idx: int,
+) -> torch.Tensor:
+    """Return an uninitialized ``[N, 2]`` fp32 tensor on the logits' device."""
+    return torch.empty((logits.shape[0], 2), dtype=torch.float32, device=logits.device)
+
+
+def _torch_lm_head_argmax_pack(
+    logits: torch.Tensor,
+    vocab_start_idx: int,
+) -> torch.Tensor:
+    """Eager torch path: per-row max over dim -1 packed as ``[N, 2]`` fp32."""
+    local_max_val, local_idx = logits.max(dim=-1)
+    global_idx = local_idx + vocab_start_idx
+    return torch.stack([local_max_val.float(), global_idx.float()], dim=-1)
+
+
+@torch_compile_guard(gen_fake=_lm_head_argmax_pack_fake)
+def lm_head_argmax_pack(logits: torch.Tensor, vocab_start_idx: int) -> torch.Tensor:
+    """Reduce local LM-head logits and pack (max_val, global_idx) as fp32.
+
+    Args:
+        logits: 2-D ``[N, M]`` logits to reduce along the last dim.
+        vocab_start_idx: offset added to each local argmax to form the packed id.
+
+    Returns:
+        ``[N, 2]`` fp32 tensor; column 0 is the row max, column 1 the packed id.
+
+    Raises:
+        ValueError: if ``logits`` is not 2-D, or if the largest id that could be
+            packed exceeds what fp32 represents exactly.
+    """
+    if logits.dim() != 2:
+        raise ValueError("lm_head_argmax_pack expects a 2-D logits tensor")
+
+    N, M = logits.shape
+    if vocab_start_idx + M - 1 > _MAX_EXACT_FP32_INT:
+        raise ValueError(
+            "lm_head_argmax_pack packs token ids as fp32; ids above 2**24 "
+            "cannot be represented exactly"
+        )
+    if N == 0:
+        return torch.empty((0, 2), dtype=torch.float32, device=logits.device)
+    if N < _MIN_ROWS_FOR_FUSED_ARGMAX or M > _MAX_BLOCK_M:
+        return _torch_lm_head_argmax_pack(logits, vocab_start_idx)
+
+    packed = torch.empty((N, 2), dtype=torch.float32, device=logits.device)
+    block_m = triton.next_power_of_2(M)
+    num_warps = 8 if block_m >= 2048 else 4
+
+    _lm_head_argmax_pack_kernel[(N,)](
+        logits,
+        packed,
+        vocab_start_idx,
+        M=M,
+        stride_logits_n=logits.stride(0),
+        stride_logits_m=logits.stride(1),
+        stride_packed_n=packed.stride(0),
+        BLOCK_M=block_m,
+        num_warps=num_warps,
+        num_stages=2,
+    )
+    return packed
diff --git a/atom/models/eagle3_llama.py b/atom/models/eagle3_llama.py
index bfb8dcb501..0aac1f2f67 100644
--- a/atom/models/eagle3_llama.py
+++ b/atom/models/eagle3_llama.py
@@ -21,7 +21,15 @@
 from atom.config import Config
 from atom.model_ops.activation import SiluAndMul
 from atom.model_ops.base_attention import Attention
-from atom.model_ops.embed_head import ParallelLMHead, VocabParallelEmbedding
+from atom.model_ops.embed_head import (
+    ParallelLMHead,
+    ReplicatedEmbedding,
+    VocabParallelEmbedding,
+)
+from atom.model_ops.fused_aux_rmsnorm import (
+    fused_dual_rmsnorm_cat,
+    fused_group_rmsnorm,
+)
 from atom.model_ops.layernorm import RMSNorm
 from atom.model_ops.linear import (
     MergedColumnParallelLinear,
@@ -29,9 +37,17 @@
     ReplicatedLinear,
     RowParallelLinear,
 )
+from atom.utils import envs
 from atom.utils.decorators import support_torch_compile
 from torch import nn
 
+# AR+RMSNorm fusion: when on
(default), RowParallel o_proj/down_proj skip their +# own all-reduce (reduce_results=False) and the downstream RMSNorm fuses +# all-reduce + residual-add + norm into one kernel. Only active at TP>1; the +# RMSNorm/RowParallel paths fall back to plain behavior at TP1. Same env and +# kernel as ATOM's mainline TP models (deepseek_v2, qwen3_moe, ...). +ENABLE_ALLREDUCE_RMSNORM_FUSION = envs.ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION + class Eagle3LlamaAttention(nn.Module): """Llama full-attention with input_size = hidden_size * 2. @@ -49,6 +65,7 @@ def __init__( cache_config: str = "bf16", prefix: str = "", layer_num: int = 0, + reduce_results: bool = True, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -85,6 +102,7 @@ def __init__( input_size=self.total_num_heads * self.head_dim, output_size=hidden_size, bias=False, + reduce_results=reduce_results, prefix=f"{prefix}.o_proj", ) @@ -142,10 +160,20 @@ def __init__( cache_config: str = "bf16", prefix: str = "", layer_num: int = 0, + norm_output: bool = False, ) -> None: super().__init__() self.hidden_size = config.hidden_size + # Point 1 (always): o_proj skips its all-reduce so post_attention_layernorm + # fuses all-reduce + residual-add + norm. Point 2 (norm_output only): + # down_proj skips its all-reduce so the model's final self.norm fuses it; + # for the legacy (norm_output=False) path the output norm is deferred to + # compute_logits with no adjacent residual-add, so down_proj all-reduces + # normally. 
+ attn_reduce = not ENABLE_ALLREDUCE_RMSNORM_FUSION + mlp_reduce = not (ENABLE_ALLREDUCE_RMSNORM_FUSION and norm_output) + self.self_attn = Eagle3LlamaAttention( config=config, hidden_size=self.hidden_size, @@ -156,40 +184,71 @@ def __init__( cache_config=cache_config, prefix=f"{prefix}.self_attn", layer_num=layer_num, + reduce_results=attn_reduce, ) self.mlp = Eagle3LlamaMLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, prefix=f"{prefix}.mlp", + reduce_results=mlp_reduce, ) # Dual norms matching checkpoint keys: midlayer.input_layernorm, midlayer.hidden_norm self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm( - config.hidden_size, eps=config.rms_norm_eps + config.hidden_size, + eps=config.rms_norm_eps, + fused_allreduce=ENABLE_ALLREDUCE_RMSNORM_FUSION, ) + def _dual_norm_cat( + self, embeds: torch.Tensor, hidden_states: torch.Tensor + ) -> torch.Tensor: + """RMS-norm embeds and the carried hidden by their own weights and concat + into the [N, 2*hidden] QKV input. + + Single fused Triton launch (one [N, 2H] write) instead of two RMSNorm + launches + a concat. Falls back to the aiter RMSNorm + torch.cat path + when the kernel's preconditions don't hold (non-CUDA / non-contiguous / + shape mismatch). input_layernorm and hidden_norm share rms_norm_eps. 
+ """ + if ( + embeds.is_cuda + and embeds.is_contiguous() + and hidden_states.is_contiguous() + and embeds.shape == hidden_states.shape + ): + return fused_dual_rmsnorm_cat( + embeds, + hidden_states, + self.input_layernorm.weight, + self.hidden_norm.weight, + self.input_layernorm.eps, + ) + normed_embeds = self.input_layernorm(embeds) + normed_hidden = self.hidden_norm(hidden_states) + return torch.cat([normed_embeds, normed_hidden], dim=-1) + def forward( self, positions: torch.Tensor, embeds: torch.Tensor, hidden_states: torch.Tensor, - ) -> torch.Tensor: - normed_embeds = self.input_layernorm(embeds) - normed_hidden = self.hidden_norm(hidden_states) - # Concat for attention input: [N, hidden*2] - attn_input = torch.cat([normed_embeds, normed_hidden], dim=-1) + ) -> tuple[torch.Tensor, torch.Tensor]: + attn_input = self._dual_norm_cat(embeds, hidden_states) attn_output = self.self_attn(positions, attn_input) - # Residual connection on hidden_states - hidden_states = hidden_states + attn_output - # MLP with pre-norm + residual - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) + # Fused (all-reduce +) residual-add + pre-MLP norm in one kernel: + # residual = [all_reduce(attn_output)] + hidden_states + # hidden_states = post_attention_layernorm(residual) + hidden_states, residual = self.post_attention_layernorm( + attn_output, hidden_states + ) hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - return hidden_states + # Return the MLP output and its residual; the model fuses the final + # residual-add with the output norm (norm_output) or adds plainly. 
+ return hidden_states, residual class Eagle3LlamaMLP(nn.Module): @@ -200,6 +259,7 @@ def __init__( hidden_size: int, intermediate_size: int, prefix: str = "", + reduce_results: bool = True, ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( @@ -212,6 +272,7 @@ def __init__( input_size=intermediate_size, output_size=hidden_size, bias=False, + reduce_results=reduce_results, prefix=f"{prefix}.down_proj", ) self.act_fn = SiluAndMul() @@ -243,6 +304,13 @@ class Eagle3LlamaModel(nn.Module): "up_proj": ("gate_up_proj", 1), } + # The single decoder layer is named `midlayer` here, but some EAGLE3 + # checkpoints ship it as `layers.0.*` (e.g. the torchspec-format + # Inferact/MiniMax-M3-EAGLE3) instead of the kimi-k2.5 `midlayer.*` layout. + # Translate that prefix on load. No-op for `midlayer.*` checkpoints (the + # substring is absent), so both naming conventions load correctly. + weights_mapping = {"layers.0.": "midlayer."} + def __init__(self, atom_config: Config, prefix: str = "", layer_offset: int = 0): super().__init__() config = atom_config.hf_config @@ -267,10 +335,20 @@ def __init__(self, atom_config: Config, prefix: str = "", layer_offset: int = 0) self.num_aux_hidden_states = num_aux self.norm_output = getattr(config, "norm_output", False) - # Independent embedding (vocab matches target model) - self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, config.hidden_size - ) + # Independent embedding (vocab matches target model). The draft embed is + # NOT shared with the (still TP-sharded) lm_head, so it can be replicated + # full on every rank — a local lookup with no post-embedding all-reduce. + # Bit-identical to the sharded path; on by default (trades memory for one + # fewer collective per draft step). Falls back to the sharded embedding + # when ATOM_EAGLE_REPLICATE_EMBED=0. 
+ if envs.ATOM_EAGLE_REPLICATE_EMBED: + self.embed_tokens = ReplicatedEmbedding( + config.vocab_size, config.hidden_size + ) + else: + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, config.hidden_size + ) # Aux fusion: [N, target_hidden_size * num_aux] -> [N, hidden_size] self.fc = ReplicatedLinear( @@ -296,30 +374,82 @@ def __init__(self, atom_config: Config, prefix: str = "", layer_offset: int = 0) cache_config=cache_config, prefix="midlayer", layer_num=layer_offset, + norm_output=self.norm_output, ) - # Final norm - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + # Final norm. Point 2: on the norm_output path it fuses down_proj's + # all-reduce + residual-add + norm. On the legacy path it stays plain + # (called without residual in compute_logits), so no fusion here. + self.norm = RMSNorm( + config.hidden_size, + eps=config.rms_norm_eps, + fused_allreduce=ENABLE_ALLREDUCE_RMSNORM_FUSION and self.norm_output, + ) # Independent lm_head (not shared with target model) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) - def combine_hidden_states(self, hidden_states: torch.Tensor) -> torch.Tensor: - """Project concatenated aux hidden states through fc. + def combine_hidden_states(self, aux_hidden_states) -> torch.Tensor: + """Project the per-layer aux hidden states through fc. Args: - hidden_states: [N, target_hidden_size * num_aux_hidden_states] + aux_hidden_states: either a list/tuple of per-layer aux tensors + ([N, target_hidden_size] each) — preferred, skips an extra + concat — or a single pre-concatenated + [N, target_hidden_size * num_aux_hidden_states] tensor + (back-compat). 
Returns: [N, hidden_size] projected hidden states """ - if self.fc_norm is not None: - chunks = hidden_states.chunk(self.num_aux_hidden_states, dim=-1) - hidden_states = torch.cat( + is_list = isinstance(aux_hidden_states, (list, tuple)) + if self.fc_norm is None: + if is_list: + fc_in = ( + aux_hidden_states[0] + if len(aux_hidden_states) == 1 + else torch.cat(aux_hidden_states, dim=-1) + ) + else: + fc_in = aux_hidden_states + return self.fc(fc_in) + + # fc_norm path: per-group RMSNorm, then fc. Use the single-launch fused + # kernel (one RMSNorm over all aux chunks) instead of per-chunk RMSNorm + # + concat; fall back to the torch path only when the fused kernel's + # preconditions don't hold (non-CUDA / non-contiguous / shape mismatch). + x = torch.cat(aux_hidden_states, dim=-1) if is_list else aux_hidden_states + if ( + x.is_cuda + and x.is_contiguous() + and x.shape[-1] == self.num_aux_hidden_states * self.fc_norm[0].dim + ): + fc_in = fused_group_rmsnorm( + x, + self._fc_norm_weight_stacked(), + self.fc_norm[0].eps, + self.num_aux_hidden_states, + ) + else: + chunks = ( + aux_hidden_states + if is_list + else x.chunk(self.num_aux_hidden_states, dim=-1) + ) + fc_in = torch.cat( [norm(chunk) for norm, chunk in zip(self.fc_norm, chunks)], dim=-1, ) - return self.fc(hidden_states) + return self.fc(fc_in) + + def _fc_norm_weight_stacked(self) -> torch.Tensor: + """Per-group fc_norm weights stacked to [num_aux, H] (cached).""" + ref = self.fc_norm[0].weight + w = getattr(self, "_fc_norm_w_cache", None) + if w is None or w.device != ref.device or w.dtype != ref.dtype: + w = torch.stack([m.weight for m in self.fc_norm], dim=0).contiguous() + self._fc_norm_w_cache = w + return w def forward( self, @@ -334,8 +464,16 @@ def forward( compute_logits() is norm-aware, so EagleProposer only sees one tensor. 
""" embeds = self.embed_tokens(input_ids) - hidden_states = self.midlayer(positions, embeds, hidden_states) - return self.norm(hidden_states) if self.norm_output else hidden_states + hidden_states, residual = self.midlayer(positions, embeds, hidden_states) + if self.norm_output: + # EAGLE 3.1: fused final residual-add + output RMSNorm (one kernel). + hidden_states, _ = self.norm(hidden_states, residual) + else: + # EAGLE 3 / K2.5: carry the pre-norm hidden forward; the norm is + # deferred to compute_logits, so the add stays standalone here + # (byte-equivalent to the legacy path). + hidden_states = residual + hidden_states + return hidden_states def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: # Only norm the legacy pre-norm path; norm_output already normed in @@ -343,3 +481,12 @@ def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: if not self.norm_output: hidden_states = self.norm(hidden_states) return self.lm_head(hidden_states) + + def compute_draft_token(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Greedy draft token via distributed argmax — avoids all-gathering the + full [N, vocab] logits every draft step. Token-identical to + compute_logits(...).argmax(-1); norm handling mirrors compute_logits. 
+ """ + if not self.norm_output: + hidden_states = self.norm(hidden_states) + return self.lm_head.compute_argmax_token(hidden_states) diff --git a/atom/models/minimax_m3.py b/atom/models/minimax_m3.py index 03f5177a6e..fd0ddef08e 100644 --- a/atom/models/minimax_m3.py +++ b/atom/models/minimax_m3.py @@ -587,6 +587,7 @@ def forward( positions: torch.Tensor, hidden_states: torch.Tensor, residual: torch.Tensor | None, + aux_out: list[torch.Tensor] | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: if residual is None: residual = hidden_states @@ -596,6 +597,13 @@ def forward( hidden_states, residual, self.input_layernorm ) + # Eagle3 aux hidden state = the all-reduced residual stream entering this + # layer (post input-norm). Captured here, not as `hidden_states + residual` + # in the model loop, because M3's fused all-reduce RMSNorm leaves that sum + # TP-partial / NaN-prone under CUDAGraph. + if aux_out is not None: + aux_out.append(residual.clone()) + hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states) hidden_states, residual = fused_allreduce_gemma_rms_norm( hidden_states, residual, self.post_attention_layernorm @@ -646,6 +654,10 @@ def __init__( else: self.norm = PPMissingLayer() + # Eagle3 aux hidden-state capture layer ids. Empty unless an Eagle3 drafter + # registers them via MiniMaxM3SparseForCausalLM.set_aux_hidden_state_layers. + self.aux_hidden_state_layers: tuple[int, ...] 
= tuple() + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size ) @@ -672,9 +684,11 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] + aux_hidden_states: list[torch.Tensor] = [] for idx in range(self.start_layer, self.end_layer): + aux_out = aux_hidden_states if idx in self.aux_hidden_state_layers else None hidden_states, residual = self.layers[idx]( - positions, hidden_states, residual + positions, hidden_states, residual, aux_out=aux_out ) if not get_pp_group().is_last_rank: @@ -685,6 +699,8 @@ def forward( hidden_states, _ = fused_allreduce_gemma_rms_norm( hidden_states, residual, self.norm ) + if aux_hidden_states: + return hidden_states, aux_hidden_states return hidden_states def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: @@ -743,6 +759,16 @@ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.get_input_embeddings(input_ids) + def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: + self.model.aux_hidden_state_layers = layers + + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: + """Default Eagle3 aux hidden-state layer ids: early / middle / late of + the target model (early=2, mid=n//2, late=n-3), matching vLLM's default. 
+ """ + num_layers = len(self.model.layers) + return (2, num_layers // 2, num_layers - 3) + def forward( self, input_ids: torch.Tensor, @@ -808,6 +834,12 @@ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.language_model.embed_input_ids(input_ids) + def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: + self.language_model.set_aux_hidden_state_layers(layers) + + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: + return self.language_model.get_eagle3_aux_hidden_state_layers() + def forward( self, input_ids: torch.Tensor, diff --git a/atom/spec_decode/eagle.py b/atom/spec_decode/eagle.py index 2e7e3db357..7a921de241 100644 --- a/atom/spec_decode/eagle.py +++ b/atom/spec_decode/eagle.py @@ -159,6 +159,43 @@ def build_kv_cache_tensor(self, layer_id: int, module): v_scale=getattr(module, "v_scale", None), ) + def get_kv_transfer_tensors(self) -> list: + from atom.kv_transfer.disaggregation.types import KVTransferRegion + + runner = self.model_runner + if not hasattr(runner, "eagle3_kv_cache"): + return [] + + regions: list[KVTransferRegion] = [] + cache = runner.eagle3_kv_cache + for layer_id in range(self.num_layers): + for kv in range(2): + t = cache[kv, layer_id] + regions.append( + KVTransferRegion( + base_addr=t.data_ptr(), + total_bytes=t.numel() * t.element_size(), + unit_bytes=t.stride(0) * t.element_size(), + ) + ) + scale = runner.eagle3_kv_scale + if ( + self.model_runner.config.kv_cache_dtype == "fp8" + and scale is not None + and scale.numel() > 0 + ): + for layer_id in range(self.num_layers): + for kv in range(2): + t = scale[kv, layer_id] + regions.append( + KVTransferRegion( + base_addr=t.data_ptr(), + total_bytes=t.numel() * t.element_size(), + unit_bytes=t.stride(0) * t.element_size(), + ) + ) + return regions + class EagleProposer: @@ -223,6 +260,8 @@ def __init__( else: self.model = model_class(self.config) + 
self._draft_argmax_fused = hasattr(self.model, "compute_draft_token") + i32_kwargs = {"dtype": torch.int32, "device": self.device} i64_kwargs = {"dtype": torch.int64, "device": self.device} max_bs = self.config.max_num_seqs @@ -251,8 +290,6 @@ def _share_if_not_loaded( def load_model(self, target_model: nn.Module) -> None: if self.speculative_config.method == "eagle3": - # Eagle3: load from a separate draft model checkpoint with - # independent embed_tokens and lm_head (no sharing). load_model( self.model, self.speculative_config.model, @@ -415,8 +452,13 @@ def propose( if i == 0 else ret_hidden_states ) - logits = self.model.compute_logits(sample_hidden_states) - new_draft_ids = logits.argmax(dim=-1) + # Distributed argmax (all-gather [N, 2] not [N, vocab]) when the + # draft supports it; token-identical to compute_logits().argmax(). + if self._draft_argmax_fused: + new_draft_ids = self.model.compute_draft_token(sample_hidden_states) + else: + logits = self.model.compute_logits(sample_hidden_states) + new_draft_ids = logits.argmax(dim=-1) draft_token_ids[:, i] = new_draft_ids if i < self.mtp_k - 1: @@ -471,10 +513,20 @@ def propose( # update metadata attn_metadata.max_seqlen_k += 1 - # Update context_lens for each draft step (needed by both - # MHA attention and MLA+sparse indexer) - attn_metadata.context_lens[:bs] += 1 - positions += 1 + fuse_mtp = positions.ndim == 1 and getattr( + self.runner.attn_metadata_builder, + "fuse_mtp_decode_position_update", + False, + ) + if fuse_mtp: + mtp_decode_kwargs = { + "update_context_lens": True, + "positions_out": positions, + } + else: + attn_metadata.context_lens[:bs] += 1 + positions += 1 + mtp_decode_kwargs = {} workinfos = self.runner.attn_metadata_builder.prepare_mtp_decode( bs, ( @@ -486,10 +538,11 @@ def propose( positions, only_update=do_attn_metadata_update, num_reject_tokens=num_reject_tokens if i == 0 else None, + **mtp_decode_kwargs, ) for k, v in workinfos.items(): attn_metadata.__dict__[k] = v - if 
has_flat_kv: + if has_flat_kv and "slot_mapping" not in workinfos: # MLA/MHA path: slot derived from flat kv_indices. slot_mapping[:] = kv_indices[kv_indptr[1 : bs + 1] - 1] diff --git a/atom/utils/envs.py b/atom/utils/envs.py index 8baead10ce..3bd2a53abb 100644 --- a/atom/utils/envs.py +++ b/atom/utils/envs.py @@ -73,6 +73,12 @@ "ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION": lambda: ( os.getenv("ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION", "1") == "1" ), + # Replicate the EAGLE3 draft vocab embedding on every TP rank (full table per + # rank, local lookup) instead of sharding it — eliminates the post-embedding + # all-reduce. The draft embed is independent of the (sharded) lm_head. + "ATOM_EAGLE_REPLICATE_EMBED": lambda: ( + os.getenv("ATOM_EAGLE_REPLICATE_EMBED", "1") == "1" + ), "ATOM_ENABLE_GDN_DECODE_LOSSY_FAST": lambda: ( os.getenv("ATOM_ENABLE_GDN_DECODE_LOSSY_FAST", "0").lower() == "1" ), diff --git a/recipes/MiniMax-M3.md b/recipes/MiniMax-M3.md index afa324e894..4b3e6966bf 100644 --- a/recipes/MiniMax-M3.md +++ b/recipes/MiniMax-M3.md @@ -145,4 +145,113 @@ Reference MXFP8 results from the validated run on 4xMI355 GPUs: | 8 | 80 | 103.52 | 323.33 | 1284.59 | 10.67 | 11.31 | 715.51 | 6364.77 | | 16 | 160 | 143.25 | 414.95 | 2411.41 | 14.80 | 16.44 | 1022.17 | 9224.81 | | 32 | 320 | 208.34 | 565.02 | 4936.02 | 21.42 | 24.16 | 1421.47 | 12711.25 | -| 64 | 640 | 305.81 | 893.93 | 9610.43 | 31.69 | 37.31 | 1929.04 | 17387.94 | \ No newline at end of file +| 64 | 640 | 305.81 | 893.93 | 9610.43 | 31.69 | 37.31 | 1929.04 | 17387.94 | + +## EAGLE3 Speculative Decoding + +EAGLE3 runs a small single-layer draft model alongside the MiniMax-M3 target to +propose multiple tokens per step, which the target then verifies. It is lossless +with respect to the target's greedy output. The draft checkpoint is +[`Inferact/MiniMax-M3-EAGLE3`](https://huggingface.co/Inferact/MiniMax-M3-EAGLE3). 
+Enable it by adding three flags to any of the server commands above: + +- `--method eagle3` +- `--draft-model Inferact/MiniMax-M3-EAGLE3` +- `--num-speculative-tokens 3` + +### Launching Server + +The following starts the MXFP4 target with the EAGLE3 draft on 4xMI355 (the FP4 +server command above plus the three speculative-decoding flags): + +```bash +model_path=amd/MiniMax-M3-MXFP4 +draft_path=Inferact/MiniMax-M3-EAGLE3 + +export ATOM_FORCE_ATTN_TRITON=1 +export AITER_QUICK_REDUCE_QUANTIZATION=INT4 +export AITER_QUICK_REDUCE_CAST_BF16_TO_FP16=0 + +python -m atom.entrypoints.openai_server \ + --model "$model_path" \ + --tensor-parallel-size 4 \ + --server-port 8000 \ + --trust-remote-code \ + --gpu-memory-utilization 0.8 \ + --block-size 128 \ + --max-model-len 32768 \ + --max-num-seqs 256 \ + --kv_cache_dtype fp8 \ + --max-num-batched-tokens 32768 \ + --no-enable_prefix_caching \ + --method eagle3 \ + --draft-model "$draft_path" \ + --num-speculative-tokens 3 2>&1 | tee m3-mxfp4-eagle3-server.log +``` + +### Accuracy Test + +Run GSM8K 5-shot with `lm_eval` (identical to the non-speculative test): + +```bash +model_path=amd/MiniMax-M3-MXFP4 +# For the MXFP8 target, use instead: model_path=MiniMaxAI/MiniMax-M3-MXFP8 +BS=65 + +lm_eval \ + --model local-chat-completions \ + --model_args "model=$model_path,base_url=http://127.0.0.1:8000/v1/chat/completions,num_concurrent=32,max_gen_toks=16384" \ + --tasks gsm8k \ + --num_fewshot 5 \ + --batch_size "${BS}" \ + --apply_chat_template \ + --fewshot_as_multiturn 2>&1 | tee m3-mxfp4-eagle3-bs65-accuracy.log +``` + +Validated MXFP4+EAGLE GSM8K result: + +```text +| Case | ATOM Commit | GSM8K flexible-extract | GSM8K strict-match | Accept ratio | Avg toks/fwd | Accepted / Total Draft | +|---|---:|---:|---:|---:|---:|---:| +| `fp4_eagle_tp4` | `9fc48338` | `0.9469 ± 0.0062` | `0.9477 ± 0.0061` | `73.36%` | `3.20` | `90229 / 123000` | + +MiniMax-M3 Eagle accepted tokens distribution: +`{0: 14.40%, 1: 12.00%, 2: 12.73%, 3: 60.87%}` +``` + +### Serving Benchmark +
+The following script can be used to benchmark online serving throughput and latency: + +```bash +model_path=${model_path:-amd/MiniMax-M3-MXFP4} +ISL=8192 +OSL=1024 +CONC=16 + +python -m atom.benchmarks.benchmark_serving \ + --model="$model_path" \ + --backend=vllm \ + --base-url=http://localhost:8000 \ + --dataset-name=random \ + --random-input-len="${ISL}" \ + --random-output-len="${OSL}" \ + --random-range-ratio=0.8 \ + --num-prompts=$(( CONC * 10 )) \ + --max-concurrency="${CONC}" \ + --request-rate=inf \ + --ignore-eos \ + --save-result \ + --use-chat-template \ + --percentile-metrics="ttft,tpot,itl,e2el" +``` + +Reference MXFP4 EAGLE3 results from our run on 4xMI355 GPUs: + +| CONC | Requests | Duration (s) | Mean TTFT (ms) | P99 TTFT (ms) | Mean TPOT (ms) | P99 TPOT (ms) | Output tok/s | Total tok/s | +|---:|---:|---:|---:|---:|---:|---:|---:|---:| +| 4 | 40 | 43.38 | 287.09 | 755.46 | 4.27 | 7.78 | 850.53 | 7653.56 | +| 8 | 80 | 59.31 | 343.81 | 1516.38 | 5.93 | 10.85 | 1251.08 | 11146.00 | +| 16 | 160 | 78.17 | 430.34 | 2680.95 | 7.91 | 15.58 | 1876.30 | 16928.43 | +| 32 | 320 | 125.69 | 609.24 | 5304.23 | 12.60 | 23.81 | 2355.93 | 21132.49 | +| 64 | 640 | 198.58 | 966.20 | 10476.78 | 19.97 | 40.44 | 2973.94 | 26857.80 |