From bf115d5b350b15b14e62357a16c67c1ebf34c91c Mon Sep 17 00:00:00 2001 From: XiaobingSuper Date: Thu, 25 Jun 2026 04:11:43 -0500 Subject: [PATCH 1/5] feat(minimax-m3): split index cache projection Route MiniMax-M3 index Q/K through a separate projection and thread it through the attention stack so cached top-k layers can skip indexer work while preserving the non-cache path. Co-authored-by: Cursor --- atom/config.py | 34 ++- atom/model_ops/attention_mha.py | 257 ++++++++++++++++--- atom/model_ops/attentions/aiter_attention.py | 7 + atom/model_ops/base_attention.py | 3 + atom/model_ops/linear.py | 99 +++++++ atom/model_ops/paged_attention.py | 11 +- atom/models/minimax_m3.py | 74 ++++-- 7 files changed, 429 insertions(+), 56 deletions(-) diff --git a/atom/config.py b/atom/config.py index b5fdd0d60..bdba10640 100644 --- a/atom/config.py +++ b/atom/config.py @@ -726,6 +726,16 @@ def _normalize_minimax_m3_text_config(hf_config: PretrainedConfig) -> None: if getattr(text_config, "swiglu_beta", None) is None: text_config.swiglu_beta = 1.0 + for attr_name in ( + "use_index_cache", + "index_topk_freq", + "index_topk_pattern", + "index_skip_topk_offset", + ): + attr_value = getattr(hf_config, attr_name, None) + if attr_value is not None: + setattr(text_config, attr_name, attr_value) + for attr_name, attr_value in vars(text_config).items(): if attr_name.startswith("_") or getattr(hf_config, attr_name, None) is not None: continue @@ -1239,11 +1249,29 @@ def compute_hash(self) -> str: factors.append(vllm_factors) factors.append(self.tensor_parallel_size) factors.append(self.enable_dp_attention) + text_config = getattr(self.hf_config, "text_config", self.hf_config) factors.append( ( - getattr(self.hf_config, "use_index_cache", False), - getattr(self.hf_config, "index_topk_freq", None), - getattr(self.hf_config, "index_topk_pattern", None), + getattr( + text_config, + "use_index_cache", + getattr(self.hf_config, "use_index_cache", False), + ), + getattr( + text_config, + "index_topk_freq", + getattr(self.hf_config, "index_topk_freq", None), + ), + getattr( + text_config, + "index_topk_pattern", + getattr(self.hf_config, "index_topk_pattern", None), + ), + getattr( + text_config, + "index_skip_topk_offset", + getattr(self.hf_config, "index_skip_topk_offset", None), + ), ) ) diff --git a/atom/model_ops/attention_mha.py b/atom/model_ops/attention_mha.py index 91ba9a6ca..2557b6af3 100644 --- a/atom/model_ops/attention_mha.py +++ b/atom/model_ops/attention_mha.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +import os from functools import cache from typing import Optional @@ -168,6 +169,7 @@ def forward_impl( position: torch.Tensor = None, q_scale: torch.Tensor = None, qkv: torch.Tensor = None, + index_qk: torch.Tensor = None, ): fwd_ctx: ForwardContext = get_forward_context() @@ -184,7 +186,7 @@ def forward_impl( # rope cache q, k, v, k_cache, v_cache, k_scale, v_scale = self.rope_cache( - q, k, v, qkv, position, fwd_ctx + q, k, v, qkv, index_qk, position, fwd_ctx ) attn_impl = self.dispatch_backend(fwd_ctx, q, k, v) @@ -195,7 +197,7 @@ def forward_impl( return o @mark_trace(prefix="rope_cache", torch_compile=False) - def rope_cache(self, q, k, v, qkv, position, fwd_ctx: ForwardContext): + def rope_cache(self, q, k, v, qkv, index_qk, position, fwd_ctx: ForwardContext): attn_metadata = fwd_ctx.attn_metadata kv_cache_data = fwd_ctx.kv_cache_data @@ -878,11 +880,18 @@ def forward( position: torch.Tensor = None, q_scale: Optional[torch.Tensor] = None, qkv: torch.Tensor = None, + index_qk: torch.Tensor = None, output: torch.Tensor = None, **kwargs, ): return self.forward_impl( - q=query, k=key, v=value, position=position, q_scale=q_scale, qkv=qkv + q=query, + k=key, + v=value, + position=position, + q_scale=q_scale, + qkv=qkv, + index_qk=index_qk, ) @@ -937,6 +946,8 @@ def __init__( topk: int = 0, init_blocks: int = 0, local_blocks: int = 0, + skip_index_topk: bool = False, + sparse_layer_ordinal: int = -1, **kwargs, ): super().__init__( @@ -972,9 +983,15 @@ def __init__( self.topk = topk self.init_blocks = init_blocks self.local_blocks = local_blocks + self.skip_index_topk = skip_index_topk + self.sparse_layer_ordinal = sparse_layer_ordinal # Bound by AiterAttentionMetadataBuilder.build_kv_cache_tensor (Task 6): # the page-128 indexer-key cache. None until the runner binds it. self.index_cache: Optional[torch.Tensor] = None + # Optional shared dict bound by the metadata builder. It is scoped to the + # current sparse metadata object and carries the last full layer top-k. + self.index_topk_cache_state: Optional[dict] = None + self._index_q_cache_key_info: Optional[tuple] = None # Rotated indexer query produced by rope_cache, consumed (and cleared) by # dispatch_backend within the same single-threaded layer forward. self._index_q: Optional[torch.Tensor] = None @@ -1009,12 +1026,12 @@ def _to_page16_shuffle(k_cache, v_cache, k_scale, v_scale): return k16, v16, k_scale, v_scale @mark_trace(prefix="rope_cache", torch_compile=False) - def rope_cache(self, q, k, v, qkv, position, fwd_ctx: ForwardContext): + def rope_cache(self, q, k, v, qkv, index_qk, position, fwd_ctx: ForwardContext): """MiniMax-M3 fused qk-norm + partial-NeoX-RoPE + page-16 SHUFFLE KV insert + indexer-key insert, via ``aiter.fused_qknorm_idxrqknorm``. - Consumes the PACKED ``qkv`` tensor (Gemma (1+w) norm path needs it) laid - out as ``[q | k | v | index_q | index_k]``. Writes: + Consumes the main ``qkv`` tensor laid out as ``[q | k | v]`` plus an + optional ``index_qk`` tensor laid out as ``[index_q | index_k]``. Writes: * normed+roped main K/V -> SHUFFLE K/V cache (asm_layout=True) * normed+roped index_k -> page-128 index_cache * fp8 per-token dequant scales -> k_scale / v_scale (when fp8) @@ -1054,15 +1071,9 @@ def rope_cache(self, q, k, v, qkv, position, fwd_ctx: ForwardContext): slot_mapping = sparse_metadata.slot_mapping qkv = qkv.contiguous() + if index_qk is not None: + index_qk = index_qk.contiguous() num_tokens = qkv.shape[0] - q_out = torch.empty( - (num_tokens, self.num_heads * self.head_dim), - dtype=qkv.dtype, - device=qkv.device, - ) - index_q = torch.empty( - (num_tokens, self.index_q_size), dtype=qkv.dtype, device=qkv.device - ) from atom.models.minimax_m3 import _minimax_m3_cos_sin_cache cos_sin_cache = _minimax_m3_cos_sin_cache(self.rotary_emb, qkv) @@ -1074,6 +1085,56 @@ def rope_cache(self, q, k, v, qkv, position, fwd_ctx: ForwardContext): fused_k_scale = k_scale if is_fp8 else None fused_v_scale = v_scale if is_fp8 else None + if self.skip_index_topk: + from atom.model_ops.triton_fused_qkv_norm_rope_cache import ( + triton_fused_norm_rope_cache, + ) + + state = self._topk_cache_state(sparse_metadata) + if state is not None: + self._debug_topk_cache_event(state, "skip_indexer_rope", ("rope",)) + q_size = self.num_heads * self.head_dim + kv_size = self.num_kv_heads * self.head_dim + q_raw, k_raw, v_raw = torch.split(qkv, [q_size, kv_size, kv_size], dim=-1) + q_out, k_out = triton_fused_norm_rope_cache( + q_raw, + k_raw, + v_raw, + position, + q_norm=self.q_norm, + k_norm=self.k_norm, + rotary_emb=self.rotary_emb, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + k_cache=k_cache, + v_cache=v_cache, + k_scale=fused_k_scale, + v_scale=fused_v_scale, + slot_mapping=slot_mapping, + kv_cache_dtype=self.kv_cache_dtype, + ) + q = q_out.view(-1, self.num_heads, self.head_dim) + k = k_out.view(-1, self.num_kv_heads, self.head_dim) + v = v_raw.view(-1, self.num_kv_heads, self.head_dim) + self._index_q = None + self._index_q_cache_key_info = ( + (num_tokens, self.num_idx_heads, self.index_head_dim), + qkv.dtype, + qkv.device, + ) + return q, k, v, k_cache, v_cache, k_scale, v_scale + + q_out = torch.empty( + (num_tokens, self.num_heads * self.head_dim), + dtype=qkv.dtype, + device=qkv.device, + ) + if index_qk is None: + raise RuntimeError("MiniMax-M3 non-skip sparse layer requires index_qk") + index_q = torch.empty( + (num_tokens, self.index_q_size), dtype=qkv.dtype, device=qkv.device + ) aiter.fused_qknorm_idxrqknorm( qkv, self.q_norm.weight, @@ -1099,12 +1160,18 @@ def rope_cache(self, q, k, v, qkv, position, fwd_ctx: ForwardContext): k_scale=fused_k_scale, v_scale=fused_v_scale, asm_layout=True, + index_qk=index_qk, ) q = q_out.view(-1, self.num_heads, self.head_dim) # Stash the rotated indexer query for dispatch_backend (same-forward, # single-threaded; cleared after the sparse backend consumes it). self._index_q = index_q.view(-1, self.num_idx_heads, self.index_head_dim) + self._index_q_cache_key_info = ( + tuple(self._index_q.shape), + self._index_q.dtype, + self._index_q.device, + ) return q, k, v, k_cache, v_cache, k_scale, v_scale @@ -1132,6 +1199,94 @@ def _sparse_metadata(self, fwd_ctx: ForwardContext): sm = getattr(attn_metadata, "sparse_attention_metadata", None) return sm if sm is not None else attn_metadata + def _topk_cache_state(self, sparse_metadata): + state = self.index_topk_cache_state + if state is None: + return None + metadata_id = id(sparse_metadata) + if state.get("metadata_id") != metadata_id: + debug_totals = state.get("_debug_totals") + state.clear() + if debug_totals is not None: + state["_debug_totals"] = debug_totals + state["metadata_id"] = metadata_id + return state + + def _debug_topk_cache_event(self, state: dict, event: str, key: tuple): + if os.environ.get("ATOM_DEBUG_MINIMAX_M3_INDEX_CACHE") != "1": + return + totals = state.setdefault("_debug_totals", {}) + mode = key[0] if key else "unknown" + counter_key = f"{mode}_{event}" + totals[counter_key] = totals.get(counter_key, 0) + 1 + total_events = sum(totals.values()) + interval = int(os.environ.get("ATOM_DEBUG_MINIMAX_M3_INDEX_CACHE_INTERVAL", "256")) + if total_events <= 32 or (interval > 0 and total_events % interval == 0): + print( + "[MiniMax-M3 IndexCache] " + f"pid={os.getpid()} event={event} mode={mode} " + f"layer={self.layer_num} sparse_ord={self.sparse_layer_ordinal} " + f"skip={self.skip_index_topk} totals={totals}", + flush=True, + ) + + def _topk_cache_key( + self, + mode: str, + index_q: torch.Tensor, + block_table: torch.Tensor, + seq_lens: torch.Tensor, + max_query_len: int, + max_seq_len: int, + ) -> tuple: + if index_q is None: + if self._index_q_cache_key_info is None: + raise RuntimeError("MiniMax-M3 index cache key missing index_q metadata") + index_q_shape, index_q_dtype, index_q_device = self._index_q_cache_key_info + else: + index_q_shape = tuple(index_q.shape) + index_q_dtype = index_q.dtype + index_q_device = index_q.device + return ( + mode, + index_q_shape, + index_q_dtype, + index_q_device, + tuple(block_table.shape), + tuple(block_table.stride()), + tuple(seq_lens.shape), + self.topk, + self.init_blocks, + self.local_blocks, + self.num_kv_heads, + max_query_len, + max_seq_len, + ) + + def _load_cached_topk(self, sparse_metadata, key: tuple): + if not self.skip_index_topk: + return None + state = self._topk_cache_state(sparse_metadata) + if state is None: + return None + entry = state.get("topk") + if entry is None or entry.get("key") != key: + self._debug_topk_cache_event(state, "miss", key) + return None + self._debug_topk_cache_event(state, "hit", key) + return entry["value"] + + def _store_cached_topk(self, sparse_metadata, key: tuple, value: tuple): + state = self._topk_cache_state(sparse_metadata) + if state is not None: + state["topk"] = { + "key": key, + "value": value, + "layer_num": self.layer_num, + "sparse_layer_ordinal": self.sparse_layer_ordinal, + } + self._debug_topk_cache_event(state, "store", key) + @mark_trace(prefix="sparse_attention_prefill", torch_compile=False) def _sparse_prefill( self, q, k, v, k_cache, v_cache, k_scale, v_scale, fwd_ctx: ForwardContext @@ -1150,22 +1305,39 @@ def _sparse_prefill( prefix_lens = prefill_md.context_lens block_tables = prefill_md.block_table - topk_idx, sparse_bt, sparse_ctx = minimax_m3_index_topk( + topk_key = self._topk_cache_key( + "prefill", index_q, - self.index_cache, block_tables, - cu_seqlens_q, seq_lens, - prefix_lens, prefill_md.max_query_len, prefill_md.max_seq_len, - self.topk, - self.init_blocks, - self.local_blocks, - self.num_kv_heads, - self.scale, - emit_sparse_block_table=True, ) + cached_topk = self._load_cached_topk(sparse_metadata, topk_key) + if cached_topk is None: + if index_q is None: + raise RuntimeError("MiniMax-M3 index cache miss on a skip-index layer") + topk_idx, sparse_bt, sparse_ctx = minimax_m3_index_topk( + index_q, + self.index_cache, + block_tables, + cu_seqlens_q, + seq_lens, + prefix_lens, + prefill_md.max_query_len, + prefill_md.max_seq_len, + self.topk, + self.init_blocks, + self.local_blocks, + self.num_kv_heads, + self.scale, + emit_sparse_block_table=True, + ) + self._store_cached_topk( + sparse_metadata, topk_key, (topk_idx, sparse_bt, sparse_ctx) + ) + else: + topk_idx, sparse_bt, sparse_ctx = cached_topk output = torch.empty_like(q) minimax_m3_sparse_attn_prefill_asm( q, @@ -1188,6 +1360,7 @@ def _sparse_prefill( ) output = output.view(*q.shape) self._index_q = None + self._index_q_cache_key_info = None return output @mark_trace(prefix="sparse_attention_decode", torch_compile=False) @@ -1205,20 +1378,37 @@ def _sparse_decode( assert decode_md is not None, "sparse decode metadata missing" max_query_len = getattr(decode_md, "max_query_len", 1) - topk_idx, sparse_bt, sparse_ctx = minimax_m3_index_topk_decode( + topk_key = self._topk_cache_key( + "decode", index_q, - self.index_cache, decode_md.block_table, decode_md.seq_lens, + max_query_len, sparse_metadata.max_seq_len, - self.topk, - self.init_blocks, - self.local_blocks, - self.num_kv_heads, - self.scale, - emit_sparse_block_table=True, - max_query_len=max_query_len, ) + cached_topk = self._load_cached_topk(sparse_metadata, topk_key) + if cached_topk is None: + if index_q is None: + raise RuntimeError("MiniMax-M3 index cache miss on a skip-index layer") + topk_idx, sparse_bt, sparse_ctx = minimax_m3_index_topk_decode( + index_q, + self.index_cache, + decode_md.block_table, + decode_md.seq_lens, + sparse_metadata.max_seq_len, + self.topk, + self.init_blocks, + self.local_blocks, + self.num_kv_heads, + self.scale, + emit_sparse_block_table=True, + max_query_len=max_query_len, + ) + self._store_cached_topk( + sparse_metadata, topk_key, (topk_idx, sparse_bt, sparse_ctx) + ) + else: + topk_idx, sparse_bt, sparse_ctx = cached_topk output = torch.empty_like(q) minimax_m3_sparse_attn_decode_asm( q, @@ -1236,5 +1426,6 @@ def _sparse_decode( sparse_ctx=sparse_ctx, ) self._index_q = None + self._index_q_cache_key_info = None output = output.view(*q.shape) return output diff --git a/atom/model_ops/attentions/aiter_attention.py b/atom/model_ops/attentions/aiter_attention.py index 0897f7da9..a7b1a9022 100644 --- a/atom/model_ops/attentions/aiter_attention.py +++ b/atom/model_ops/attentions/aiter_attention.py @@ -499,6 +499,10 @@ def allocate_kv_cache_tensors( device="cuda", ) tensors["_sparse_attention_cache_next"] = 0 + if getattr(text_config, "use_index_cache", False) or getattr( + hf_config, "use_index_cache", False + ): + tensors["_sparse_attention_topk_cache_state"] = {} return tensors def build_kv_cache_tensor(self, layer_id: int, module): @@ -530,6 +534,9 @@ def build_kv_cache_tensor(self, layer_id: int, module): runner._sparse_attention_cache_next += 1 module.impl.index_cache = runner.sparse_attention_index_cache[sparse_idx] module.impl.max_model_len = runner.config.max_model_len + module.impl.index_topk_cache_state = getattr( + runner, "_sparse_attention_topk_cache_state", None + ) # NOTE: no return — fall through to the standard MHA binding below. if not ( diff --git a/atom/model_ops/base_attention.py b/atom/model_ops/base_attention.py index a086a87b6..1becd039a 100644 --- a/atom/model_ops/base_attention.py +++ b/atom/model_ops/base_attention.py @@ -327,6 +327,7 @@ def fake_( layer_name: str, use_mla: bool, qkv: torch.Tensor, + index_qk: Optional[torch.Tensor], ) -> torch.Tensor: output_shape = list(q.shape) # If we fusion rmsnorm and quant, the input dtype is fp8, but actually we use bf16 for output. @@ -352,6 +353,7 @@ def unified_attention_with_output_base( layer_name: str, use_mla: bool, qkv: torch.Tensor, + index_qk: Optional[torch.Tensor], ) -> torch.Tensor: atom_config = get_current_atom_config() self = atom_config.compilation_config.static_forward_context[layer_name] @@ -371,6 +373,7 @@ def unified_attention_with_output_base( position=positions, q_scale=q_scale, qkv=qkv, + index_qk=index_qk, ) diff --git a/atom/model_ops/linear.py b/atom/model_ops/linear.py index 8c7180d74..9787e783d 100644 --- a/atom/model_ops/linear.py +++ b/atom/model_ops/linear.py @@ -1567,6 +1567,105 @@ def weight_loader( param.weight_loader_process(param_data, loaded_weight) +class MinimaxM3IndexerParallelLinear(ColumnParallelLinear): + """MiniMax-M3 lightning-indexer projection for ``[index_q | index_k]``. + + Splitting this from the main QKV projection lets sparse layers that reuse a + cached top-k skip the indexer GEMM entirely. + """ + + def __init__( + self, + hidden_size: int, + total_num_kv_heads: int, + total_num_index_heads: int, + index_head_size: int, + bias: bool = False, + quant_config: Optional[QuantizationConfig] = None, + source_quant_dtype: torch.dtype = None, + prefix: str = "", + **kwargs, + ): + if total_num_index_heads != total_num_kv_heads: + raise ValueError( + "MiniMax-M3 index_q must shard like KV heads: " + "total_num_index_heads must equal total_num_kv_heads." + ) + + self.index_head_size = index_head_size + self.total_num_kv_heads = total_num_kv_heads + self.total_num_index_heads = total_num_index_heads + + tp_size = get_tp_group().world_size + if self.total_num_kv_heads >= tp_size: + self.num_kv_heads = divide(self.total_num_kv_heads, tp_size) + self.num_kv_head_replicas = 1 + else: + self.num_kv_heads = 1 + self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads) + self.num_index_heads = self.num_kv_heads + + output_sizes = [ + self.num_index_heads * self.index_head_size * tp_size, + self.index_head_size * tp_size, + ] + + super().__init__( + hidden_size, + output_sizes, + bias=bias, + quant_config=quant_config, + source_quant_dtype=source_quant_dtype, + prefix=prefix, + **kwargs, + ) + + def _shard_offset_size(self, loaded_shard_id: str) -> tuple[int, int]: + index_q_size = self.num_index_heads * self.index_head_size + mapping = { + "index_q": (0, index_q_size), + "index_k": (index_q_size, self.index_head_size), + } + if loaded_shard_id not in mapping: + raise ValueError( + "MiniMax-M3 indexer shard id must be one of " + "'index_q', 'index_k'; got " + f"{loaded_shard_id!r}." + ) + return mapping[loaded_shard_id] + + def weight_loader( + self, + param: nn.Parameter, + loaded_weight: torch.Tensor, + loaded_shard_id: str, + ): + shard_offset, shard_size = self._shard_offset_size(loaded_shard_id) + if param is getattr(self, "weight_scale", None) or param is getattr( + self, "input_scale", None + ): + if self.quant_type == QuantType.per_1x128: + shard_offset = (shard_offset + 127) // 128 + shard_size = (shard_size + 127) // 128 + elif self.quant_type == QuantType.per_Tensor: + loaded_weight = loaded_weight.view(1, 1).repeat(self.tp_size, 1) + shard_offset = ["index_q", "index_k"].index(loaded_shard_id) + shard_size = 1 + + shard_rank = ( + 0 + if loaded_shard_id == "index_k" + else self.tp_rank // self.num_kv_head_replicas + ) + param_data = param.data.narrow(self.tp_dim, shard_offset, shard_size) + loaded_weight = loaded_weight.narrow( + self.tp_dim, + shard_rank * shard_size, + shard_size, + ) + param.weight_loader_process(param_data, loaded_weight) + + class RowParallelLinear(LinearBase): def __init__( self, diff --git a/atom/model_ops/paged_attention.py b/atom/model_ops/paged_attention.py index 7937ee111..714eec5db 100644 --- a/atom/model_ops/paged_attention.py +++ b/atom/model_ops/paged_attention.py @@ -116,9 +116,18 @@ def forward( positions: torch.Tensor = None, q_scale: Optional[torch.Tensor] = None, qkv: torch.Tensor = None, + index_qk: Optional[torch.Tensor] = None, **kwargs, ): output = torch.ops.aiter.unified_attention_with_output_base( - query, q_scale, key, value, positions, self.layer_name, self.use_mla, qkv + query, + q_scale, + key, + value, + positions, + self.layer_name, + self.use_mla, + qkv, + index_qk, ) return output diff --git a/atom/models/minimax_m3.py b/atom/models/minimax_m3.py index 8296f5ed0..2e883f5b6 100644 --- a/atom/models/minimax_m3.py +++ b/atom/models/minimax_m3.py @@ -24,7 +24,7 @@ ) from atom.model_ops import module_dispatch_ops as _module_dispatch_ops # noqa: F401 from atom.model_ops.linear import ( - MinimaxM3QKVParallelLinearWithIndexer, + MinimaxM3IndexerParallelLinear, MergedColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, @@ -62,6 +62,40 @@ def _sparse_attention_layer_ids(config: PretrainedConfig) -> set[int]: return {i for i, enabled in enumerate(freq) if enabled != 0} +def _sparse_attention_layer_ordinals(config: PretrainedConfig) -> dict[int, int]: + return { + layer_id: ordinal + for ordinal, layer_id in enumerate(sorted(_sparse_attention_layer_ids(config))) + } + + +def _should_skip_minimax_m3_index_topk( + config: PretrainedConfig, layer_id: int +) -> tuple[bool, int]: + sparse_ordinals = _sparse_attention_layer_ordinals(config) + sparse_ordinal = sparse_ordinals.get(layer_id, -1) + if sparse_ordinal < 0: + return False, sparse_ordinal + if not getattr(config, "use_index_cache", False): + return False, sparse_ordinal + + index_topk_freq = int(getattr(config, "index_topk_freq", 1) or 1) + index_topk_pattern = getattr(config, "index_topk_pattern", None) + if index_topk_pattern is not None: + if 0 <= sparse_ordinal < len(index_topk_pattern): + return index_topk_pattern[sparse_ordinal] == "S", sparse_ordinal + return False, sparse_ordinal + + if index_topk_freq <= 0: + raise ValueError("index_topk_freq must be a positive integer") + if index_topk_freq == 1: + return False, sparse_ordinal + + # MiniMax-M3 schedules sharing by sparse-layer ordinal, not absolute layer id. + offset = int(getattr(config, "index_skip_topk_offset", 0)) + return max(sparse_ordinal - offset, 0) % index_topk_freq != 0, sparse_ordinal + + def _is_moe_layer(config: PretrainedConfig, layer_id: int) -> bool: moe_layer_freq = getattr(config, "moe_layer_freq", None) if moe_layer_freq is None: @@ -402,6 +436,9 @@ def __init__( self.topk_blocks = sparse_cfg["sparse_topk_blocks"] self.init_blocks = sparse_cfg.get("sparse_init_block", 0) self.local_blocks = sparse_cfg.get("sparse_local_block", 0) + self.skip_index_topk, self.sparse_layer_ordinal = ( + _should_skip_minimax_m3_index_topk(config, layer_id) + ) score_type = sparse_cfg.get("sparse_score_type", "max") if score_type != "max": raise ValueError( @@ -409,11 +446,18 @@ def __init__( f"sparse_score_type='max', got {score_type!r}." ) - self.qkv_proj = MinimaxM3QKVParallelLinearWithIndexer( + self.qkv_proj = QKVParallelLinear( self.hidden_size, self.head_dim, self.total_num_heads, self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.indexer_proj = MinimaxM3IndexerParallelLinear( + self.hidden_size, + self.total_num_kv_heads, self.total_idx_heads, self.idx_head_dim, bias=False, @@ -472,6 +516,8 @@ def __init__( topk=self.topk_blocks, init_blocks=self.init_blocks, local_blocks=self.local_blocks, + skip_index_topk=self.skip_index_topk, + sparse_layer_ordinal=self.sparse_layer_ordinal, ) def forward( @@ -479,22 +525,12 @@ def forward( positions: torch.Tensor, hidden_states: torch.Tensor, ) -> torch.Tensor: - # Packed qkv = [q | k | v | index_q | index_k]. The sparse impl's - # rope_cache consumes the packed tensor directly (passed via qkv=); the - # split q/k/v slices satisfy forward_impl's view contract (they are - # rebuilt from qkv inside rope_cache). + # Main qkv is always needed. The indexer projection is skipped on layers + # that reuse the previous sparse top-k. qkv = self.qkv_proj(hidden_states) - q, k, v, _, _ = qkv.split( - [ - self.q_size, - self.kv_size, - self.kv_size, - self.index_q_size, - self.idx_head_dim, - ], - dim=-1, - ) - attn_output = self.attn(q, k, v, positions, qkv=qkv) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + index_qk = None if self.skip_index_topk else self.indexer_proj(hidden_states) + attn_output = self.attn(q, k, v, positions, qkv=qkv, index_qk=index_qk) return self.o_proj(attn_output) @@ -659,8 +695,8 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: class MiniMaxM3SparseForCausalLM(nn.Module): packed_modules_mapping = { - ".index_q_proj": (".qkv_proj", "index_q"), - ".index_k_proj": (".qkv_proj", "index_k"), + ".index_q_proj": (".indexer_proj", "index_q"), + ".index_k_proj": (".indexer_proj", "index_k"), ".q_proj": (".qkv_proj", "q"), ".k_proj": (".qkv_proj", "k"), ".v_proj": (".qkv_proj", "v"), From c9e8546b273d4c3252fba8ae6938f7a0faceb682 Mon Sep 17 00:00:00 2001 From: XiaobingSuper Date: Thu, 25 Jun 2026 04:37:26 -0500 Subject: [PATCH 2/5] refactor(minimax-m3): keep indexer qk packed Keep MiniMax-M3 index Q/K in the packed QKV projection so index-cache support only skips top-k work and does not require a separate aiter input ABI. Co-authored-by: Cursor --- atom/model_ops/attention_mha.py | 24 +++---- atom/model_ops/base_attention.py | 3 - atom/model_ops/linear.py | 100 ------------------------------ atom/model_ops/paged_attention.py | 2 - atom/models/minimax_m3.py | 33 +++++----- 5 files changed, 27 insertions(+), 135 deletions(-) diff --git a/atom/model_ops/attention_mha.py b/atom/model_ops/attention_mha.py index 2557b6af3..82b46fbb4 100644 --- a/atom/model_ops/attention_mha.py +++ b/atom/model_ops/attention_mha.py @@ -169,7 +169,6 @@ def forward_impl( position: torch.Tensor = None, q_scale: torch.Tensor = None, qkv: torch.Tensor = None, - index_qk: torch.Tensor = None, ): fwd_ctx: ForwardContext = get_forward_context() @@ -186,7 +185,7 @@ def forward_impl( # rope cache q, k, v, k_cache, v_cache, k_scale, v_scale = self.rope_cache( - q, k, v, qkv, index_qk, position, fwd_ctx + q, k, v, qkv, position, fwd_ctx ) attn_impl = self.dispatch_backend(fwd_ctx, q, k, v) @@ -197,7 +196,7 @@ def forward_impl( return o @mark_trace(prefix="rope_cache", torch_compile=False) - def rope_cache(self, q, k, v, qkv, index_qk, position, fwd_ctx: ForwardContext): + def rope_cache(self, q, k, v, qkv, position, fwd_ctx: ForwardContext): attn_metadata = fwd_ctx.attn_metadata kv_cache_data = fwd_ctx.kv_cache_data @@ -880,7 +879,6 @@ def forward( position: torch.Tensor = None, q_scale: Optional[torch.Tensor] = None, qkv: torch.Tensor = None, - index_qk: torch.Tensor = None, output: torch.Tensor = None, **kwargs, ): @@ -891,7 +889,6 @@ def forward( position=position, q_scale=q_scale, qkv=qkv, - index_qk=index_qk, ) @@ -1026,12 +1023,12 @@ def _to_page16_shuffle(k_cache, v_cache, k_scale, v_scale): return k16, v16, k_scale, v_scale @mark_trace(prefix="rope_cache", torch_compile=False) - def rope_cache(self, q, k, v, qkv, index_qk, position, fwd_ctx: ForwardContext): + def rope_cache(self, q, k, v, qkv, position, fwd_ctx: ForwardContext): """MiniMax-M3 fused qk-norm + partial-NeoX-RoPE + page-16 SHUFFLE KV insert + indexer-key insert, via ``aiter.fused_qknorm_idxrqknorm``. - Consumes the main ``qkv`` tensor laid out as ``[q | k | v]`` plus an - optional ``index_qk`` tensor laid out as ``[index_q | index_k]``. Writes: + Consumes the packed ``qkv`` tensor laid out as + ``[q | k | v | index_q | index_k]``. Writes: * normed+roped main K/V -> SHUFFLE K/V cache (asm_layout=True) * normed+roped index_k -> page-128 index_cache * fp8 per-token dequant scales -> k_scale / v_scale (when fp8) @@ -1071,8 +1068,6 @@ def rope_cache(self, q, k, v, qkv, index_qk, position, fwd_ctx: ForwardContext): slot_mapping = sparse_metadata.slot_mapping qkv = qkv.contiguous() - if index_qk is not None: - index_qk = index_qk.contiguous() num_tokens = qkv.shape[0] from atom.models.minimax_m3 import _minimax_m3_cos_sin_cache @@ -1095,7 +1090,11 @@ def rope_cache(self, q, k, v, qkv, index_qk, position, fwd_ctx: ForwardContext): self._debug_topk_cache_event(state, "skip_indexer_rope", ("rope",)) q_size = self.num_heads * self.head_dim kv_size = self.num_kv_heads * self.head_dim - q_raw, k_raw, v_raw = torch.split(qkv, [q_size, kv_size, kv_size], dim=-1) + q_raw, k_raw, v_raw, _, _ = torch.split( + qkv, + [q_size, kv_size, kv_size, self.index_q_size, self.index_head_dim], + dim=-1, + ) q_out, k_out = triton_fused_norm_rope_cache( q_raw, k_raw, @@ -1130,8 +1129,6 @@ def rope_cache(self, q, k, v, qkv, index_qk, position, fwd_ctx: ForwardContext): dtype=qkv.dtype, device=qkv.device, ) - if index_qk is None: - raise RuntimeError("MiniMax-M3 non-skip sparse layer requires index_qk") index_q = torch.empty( (num_tokens, self.index_q_size), dtype=qkv.dtype, device=qkv.device ) @@ -1160,7 +1157,6 @@ def rope_cache(self, q, k, v, qkv, index_qk, position, fwd_ctx: ForwardContext): k_scale=fused_k_scale, v_scale=fused_v_scale, asm_layout=True, - index_qk=index_qk, ) q = q_out.view(-1, self.num_heads, self.head_dim) diff --git a/atom/model_ops/base_attention.py b/atom/model_ops/base_attention.py index 1becd039a..a086a87b6 100644 --- a/atom/model_ops/base_attention.py +++ b/atom/model_ops/base_attention.py @@ -327,7 +327,6 @@ def fake_( layer_name: str, use_mla: bool, qkv: torch.Tensor, - index_qk: Optional[torch.Tensor], ) -> torch.Tensor: output_shape = list(q.shape) # If we fusion rmsnorm and quant, the input dtype is fp8, but actually we use bf16 for output. @@ -353,7 +352,6 @@ def unified_attention_with_output_base( layer_name: str, use_mla: bool, qkv: torch.Tensor, - index_qk: Optional[torch.Tensor], ) -> torch.Tensor: atom_config = get_current_atom_config() self = atom_config.compilation_config.static_forward_context[layer_name] @@ -373,7 +371,6 @@ def unified_attention_with_output_base( position=positions, q_scale=q_scale, qkv=qkv, - index_qk=index_qk, ) diff --git a/atom/model_ops/linear.py b/atom/model_ops/linear.py index 9787e783d..50680000c 100644 --- a/atom/model_ops/linear.py +++ b/atom/model_ops/linear.py @@ -1566,106 +1566,6 @@ def weight_loader( ) param.weight_loader_process(param_data, loaded_weight) - -class MinimaxM3IndexerParallelLinear(ColumnParallelLinear): - """MiniMax-M3 lightning-indexer projection for ``[index_q | index_k]``. - - Splitting this from the main QKV projection lets sparse layers that reuse a - cached top-k skip the indexer GEMM entirely. - """ - - def __init__( - self, - hidden_size: int, - total_num_kv_heads: int, - total_num_index_heads: int, - index_head_size: int, - bias: bool = False, - quant_config: Optional[QuantizationConfig] = None, - source_quant_dtype: torch.dtype = None, - prefix: str = "", - **kwargs, - ): - if total_num_index_heads != total_num_kv_heads: - raise ValueError( - "MiniMax-M3 index_q must shard like KV heads: " - "total_num_index_heads must equal total_num_kv_heads." - ) - - self.index_head_size = index_head_size - self.total_num_kv_heads = total_num_kv_heads - self.total_num_index_heads = total_num_index_heads - - tp_size = get_tp_group().world_size - if self.total_num_kv_heads >= tp_size: - self.num_kv_heads = divide(self.total_num_kv_heads, tp_size) - self.num_kv_head_replicas = 1 - else: - self.num_kv_heads = 1 - self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads) - self.num_index_heads = self.num_kv_heads - - output_sizes = [ - self.num_index_heads * self.index_head_size * tp_size, - self.index_head_size * tp_size, - ] - - super().__init__( - hidden_size, - output_sizes, - bias=bias, - quant_config=quant_config, - source_quant_dtype=source_quant_dtype, - prefix=prefix, - **kwargs, - ) - - def _shard_offset_size(self, loaded_shard_id: str) -> tuple[int, int]: - index_q_size = self.num_index_heads * self.index_head_size - mapping = { - "index_q": (0, index_q_size), - "index_k": (index_q_size, self.index_head_size), - } - if loaded_shard_id not in mapping: - raise ValueError( - "MiniMax-M3 indexer shard id must be one of " - "'index_q', 'index_k'; got " - f"{loaded_shard_id!r}." - ) - return mapping[loaded_shard_id] - - def weight_loader( - self, - param: nn.Parameter, - loaded_weight: torch.Tensor, - loaded_shard_id: str, - ): - shard_offset, shard_size = self._shard_offset_size(loaded_shard_id) - if param is getattr(self, "weight_scale", None) or param is getattr( - self, "input_scale", None - ): - if self.quant_type == QuantType.per_1x128: - shard_offset = (shard_offset + 127) // 128 - shard_size = (shard_size + 127) // 128 - elif self.quant_type == QuantType.per_Tensor: - loaded_weight = loaded_weight.view(1, 1).repeat(self.tp_size, 1) - shard_offset = ["index_q", "index_k"].index(loaded_shard_id) - shard_size = 1 - - shard_rank = ( - 0 - if loaded_shard_id == "index_k" - else self.tp_rank // self.num_kv_head_replicas - ) - param_data = param.data.narrow(self.tp_dim, shard_offset, shard_size) - loaded_weight = loaded_weight.narrow( - self.tp_dim, - shard_rank * shard_size, - shard_size, - ) - param.weight_loader_process(param_data, loaded_weight) - - class RowParallelLinear(LinearBase): def __init__( self, diff --git a/atom/model_ops/paged_attention.py b/atom/model_ops/paged_attention.py index 714eec5db..d44b0832f 100644 --- a/atom/model_ops/paged_attention.py +++ b/atom/model_ops/paged_attention.py @@ -116,7 +116,6 @@ def forward( positions: torch.Tensor = None, q_scale: Optional[torch.Tensor] = None, qkv: torch.Tensor = None, - index_qk: Optional[torch.Tensor] = None, **kwargs, ): output = torch.ops.aiter.unified_attention_with_output_base( @@ -128,6 +127,5 @@ def forward( self.layer_name, self.use_mla, qkv, - index_qk, ) return output diff --git a/atom/models/minimax_m3.py b/atom/models/minimax_m3.py index 2e883f5b6..03f5177a6 100644 --- a/atom/models/minimax_m3.py +++ b/atom/models/minimax_m3.py @@ -24,7 +24,7 @@ ) from atom.model_ops import module_dispatch_ops as _module_dispatch_ops # noqa: F401 from atom.model_ops.linear import ( - MinimaxM3IndexerParallelLinear, + MinimaxM3QKVParallelLinearWithIndexer, MergedColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, @@ -446,18 +446,11 @@ def __init__( f"sparse_score_type='max', got {score_type!r}." ) - self.qkv_proj = QKVParallelLinear( + self.qkv_proj = MinimaxM3QKVParallelLinearWithIndexer( self.hidden_size, self.head_dim, self.total_num_heads, self.total_num_kv_heads, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.qkv_proj", - ) - self.indexer_proj = MinimaxM3IndexerParallelLinear( - self.hidden_size, - self.total_num_kv_heads, self.total_idx_heads, self.idx_head_dim, bias=False, @@ -525,12 +518,20 @@ def forward( positions: torch.Tensor, hidden_states: torch.Tensor, ) -> torch.Tensor: - # Main qkv is always needed. The indexer projection is skipped on layers - # that reuse the previous sparse top-k. + # Keep index Q/K packed with main QKV. Layers that reuse cached top-k skip + # the indexer norm/rope/top-k path, but still compute the packed GEMM. qkv = self.qkv_proj(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - index_qk = None if self.skip_index_topk else self.indexer_proj(hidden_states) - attn_output = self.attn(q, k, v, positions, qkv=qkv, index_qk=index_qk) + q, k, v, _, _ = qkv.split( + [ + self.q_size, + self.kv_size, + self.kv_size, + self.index_q_size, + self.idx_head_dim, + ], + dim=-1, + ) + attn_output = self.attn(q, k, v, positions, qkv=qkv) return self.o_proj(attn_output) @@ -695,8 +696,8 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: class MiniMaxM3SparseForCausalLM(nn.Module): packed_modules_mapping = { - ".index_q_proj": (".indexer_proj", "index_q"), - ".index_k_proj": (".indexer_proj", "index_k"), + ".index_q_proj": (".qkv_proj", "index_q"), + ".index_k_proj": (".qkv_proj", "index_k"), ".q_proj": (".qkv_proj", "q"), ".k_proj": (".qkv_proj", "k"), ".v_proj": (".qkv_proj", "v"), From 65e76c09c0b843a1726f3f6b1809a7a95ba0c5ac Mon Sep 17 00:00:00 2001 From: XiaobingSuper Date: Thu, 25 Jun 2026 04:38:22 -0500 Subject: [PATCH 3/5] chore(minimax-m3): drop leftover formatting noise Remove residual formatting-only changes from the packed index-cache refactor so the branch only carries functional sparse-attention updates. Co-authored-by: Cursor --- atom/model_ops/linear.py | 1 + atom/model_ops/paged_attention.py | 9 +-------- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/atom/model_ops/linear.py b/atom/model_ops/linear.py index 50680000c..8c7180d74 100644 --- a/atom/model_ops/linear.py +++ b/atom/model_ops/linear.py @@ -1566,6 +1566,7 @@ def weight_loader( ) param.weight_loader_process(param_data, loaded_weight) + class RowParallelLinear(LinearBase): def __init__( self, diff --git a/atom/model_ops/paged_attention.py b/atom/model_ops/paged_attention.py index d44b0832f..7937ee111 100644 --- a/atom/model_ops/paged_attention.py +++ b/atom/model_ops/paged_attention.py @@ -119,13 +119,6 @@ def forward( **kwargs, ): output = torch.ops.aiter.unified_attention_with_output_base( - query, - q_scale, - key, - value, - positions, - self.layer_name, - self.use_mla, - qkv, + query, q_scale, key, value, positions, self.layer_name, self.use_mla, qkv ) return output From 573094a4c40cbe9fc513950e14f83744e29b624a Mon Sep 17 00:00:00 2001 From: XiaobingSuper Date: Thu, 25 Jun 2026 09:40:08 +0000 Subject: [PATCH 4/5] code format --- atom/model_ops/attention_mha.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/atom/model_ops/attention_mha.py b/atom/model_ops/attention_mha.py index 82b46fbb4..a522a2be0 100644 --- a/atom/model_ops/attention_mha.py +++ b/atom/model_ops/attention_mha.py @@ -1216,7 +1216,9 @@ def _debug_topk_cache_event(self, state: dict, event: str, key: tuple): counter_key = f"{mode}_{event}" totals[counter_key] = totals.get(counter_key, 0) + 1 total_events = sum(totals.values()) - interval = int(os.environ.get("ATOM_DEBUG_MINIMAX_M3_INDEX_CACHE_INTERVAL", "256")) + interval = int( + os.environ.get("ATOM_DEBUG_MINIMAX_M3_INDEX_CACHE_INTERVAL", "256") + ) if total_events <= 32 or (interval > 0 and total_events % interval == 0): print( "[MiniMax-M3 IndexCache] " @@ -1237,7 +1239,9 @@ def _topk_cache_key( ) -> tuple: if index_q is None: if self._index_q_cache_key_info is None: - raise RuntimeError("MiniMax-M3 index cache key missing index_q metadata") + raise RuntimeError( + "MiniMax-M3 index cache key missing index_q metadata" + ) index_q_shape, index_q_dtype, index_q_device = self._index_q_cache_key_info else: index_q_shape = tuple(index_q.shape) From 4d12345a8b7cb0af0976a899e840cbbb2e5c1e84 Mon Sep 17 00:00:00 2001 From: XiaobingSuper Date: Thu, 25 Jun 2026 04:46:38 -0500 Subject: [PATCH 5/5] chore(minimax-m3): remove index cache debug logging Drop temporary hit/miss logging and counters from the MiniMax-M3 top-k cache path now that the packed index-cache flow is settled. Co-authored-by: Cursor --- atom/model_ops/attention_mha.py | 30 ------------------------------ 1 file changed, 30 deletions(-) diff --git a/atom/model_ops/attention_mha.py b/atom/model_ops/attention_mha.py index a522a2be0..eee5503e1 100644 --- a/atom/model_ops/attention_mha.py +++ b/atom/model_ops/attention_mha.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. -import os from functools import cache from typing import Optional @@ -1085,9 +1084,6 @@ def rope_cache(self, q, k, v, qkv, position, fwd_ctx: ForwardContext): triton_fused_norm_rope_cache, ) - state = self._topk_cache_state(sparse_metadata) - if state is not None: - self._debug_topk_cache_event(state, "skip_indexer_rope", ("rope",)) q_size = self.num_heads * self.head_dim kv_size = self.num_kv_heads * self.head_dim q_raw, k_raw, v_raw, _, _ = torch.split( @@ -1201,33 +1197,10 @@ def _topk_cache_state(self, sparse_metadata): return None metadata_id = id(sparse_metadata) if state.get("metadata_id") != metadata_id: - debug_totals = state.get("_debug_totals") state.clear() - if debug_totals is not None: - state["_debug_totals"] = debug_totals state["metadata_id"] = metadata_id return state - def _debug_topk_cache_event(self, state: dict, event: str, key: tuple): - if os.environ.get("ATOM_DEBUG_MINIMAX_M3_INDEX_CACHE") != "1": - return - totals = state.setdefault("_debug_totals", {}) - mode = key[0] if key else "unknown" - counter_key = f"{mode}_{event}" - totals[counter_key] = totals.get(counter_key, 0) + 1 - total_events = sum(totals.values()) - interval = int( - os.environ.get("ATOM_DEBUG_MINIMAX_M3_INDEX_CACHE_INTERVAL", "256") - ) - if total_events <= 32 or (interval > 0 and total_events % interval == 0): - print( - "[MiniMax-M3 IndexCache] " - f"pid={os.getpid()} event={event} mode={mode} " - f"layer={self.layer_num} sparse_ord={self.sparse_layer_ordinal} " - f"skip={self.skip_index_topk} totals={totals}", - flush=True, - ) - def _topk_cache_key( self, mode: str, @@ -1271,9 +1244,7 @@ def _load_cached_topk(self, sparse_metadata, key: tuple): return None entry = state.get("topk") if entry is None or entry.get("key") != key: - self._debug_topk_cache_event(state, "miss", key) return None - self._debug_topk_cache_event(state, "hit", key) return entry["value"] def _store_cached_topk(self, sparse_metadata, key: tuple, value: tuple): @@ -1285,7 +1256,6 @@ def _store_cached_topk(self, sparse_metadata, key: tuple, value: tuple): "layer_num": self.layer_num, "sparse_layer_ordinal": self.sparse_layer_ordinal, } - self._debug_topk_cache_event(state, "store", key) @mark_trace(prefix="sparse_attention_prefill", torch_compile=False) def _sparse_prefill(