diff --git a/atom/config.py b/atom/config.py index b5fdd0d60..bdba10640 100644 --- a/atom/config.py +++ b/atom/config.py @@ -726,6 +726,16 @@ def _normalize_minimax_m3_text_config(hf_config: PretrainedConfig) -> None: if getattr(text_config, "swiglu_beta", None) is None: text_config.swiglu_beta = 1.0 + for attr_name in ( + "use_index_cache", + "index_topk_freq", + "index_topk_pattern", + "index_skip_topk_offset", + ): + attr_value = getattr(hf_config, attr_name, None) + if attr_value is not None: + setattr(text_config, attr_name, attr_value) + for attr_name, attr_value in vars(text_config).items(): if attr_name.startswith("_") or getattr(hf_config, attr_name, None) is not None: continue @@ -1239,11 +1249,29 @@ def compute_hash(self) -> str: factors.append(vllm_factors) factors.append(self.tensor_parallel_size) factors.append(self.enable_dp_attention) + text_config = getattr(self.hf_config, "text_config", self.hf_config) factors.append( ( - getattr(self.hf_config, "use_index_cache", False), - getattr(self.hf_config, "index_topk_freq", None), - getattr(self.hf_config, "index_topk_pattern", None), + getattr( + text_config, + "use_index_cache", + getattr(self.hf_config, "use_index_cache", False), + ), + getattr( + text_config, + "index_topk_freq", + getattr(self.hf_config, "index_topk_freq", None), + ), + getattr( + text_config, + "index_topk_pattern", + getattr(self.hf_config, "index_topk_pattern", None), + ), + getattr( + text_config, + "index_skip_topk_offset", + getattr(self.hf_config, "index_skip_topk_offset", None), + ), ) ) diff --git a/atom/model_ops/attention_mha.py b/atom/model_ops/attention_mha.py index 91ba9a6ca..eee5503e1 100644 --- a/atom/model_ops/attention_mha.py +++ b/atom/model_ops/attention_mha.py @@ -882,7 +882,12 @@ def forward( **kwargs, ): return self.forward_impl( - q=query, k=key, v=value, position=position, q_scale=q_scale, qkv=qkv + q=query, + k=key, + v=value, + position=position, + q_scale=q_scale, + qkv=qkv, ) @@ -937,6 +942,8 @@ def __init__( topk: int = 0, init_blocks: int = 0, local_blocks: int = 0, + skip_index_topk: bool = False, + sparse_layer_ordinal: int = -1, **kwargs, ): super().__init__( @@ -972,9 +979,15 @@ def __init__( self.topk = topk self.init_blocks = init_blocks self.local_blocks = local_blocks + self.skip_index_topk = skip_index_topk + self.sparse_layer_ordinal = sparse_layer_ordinal # Bound by AiterAttentionMetadataBuilder.build_kv_cache_tensor (Task 6): # the page-128 indexer-key cache. None until the runner binds it. self.index_cache: Optional[torch.Tensor] = None + # Optional shared dict bound by the metadata builder. It is scoped to the + # current sparse metadata object and carries the last full layer top-k. + self.index_topk_cache_state: Optional[dict] = None + self._index_q_cache_key_info: Optional[tuple] = None # Rotated indexer query produced by rope_cache, consumed (and cleared) by # dispatch_backend within the same single-threaded layer forward. self._index_q: Optional[torch.Tensor] = None @@ -1013,8 +1026,8 @@ def rope_cache(self, q, k, v, qkv, position, fwd_ctx: ForwardContext): """MiniMax-M3 fused qk-norm + partial-NeoX-RoPE + page-16 SHUFFLE KV insert + indexer-key insert, via ``aiter.fused_qknorm_idxrqknorm``. - Consumes the PACKED ``qkv`` tensor (Gemma (1+w) norm path needs it) laid - out as ``[q | k | v | index_q | index_k]``. Writes: + Consumes the packed ``qkv`` tensor laid out as + ``[q | k | v | index_q | index_k]``. Writes: * normed+roped main K/V -> SHUFFLE K/V cache (asm_layout=True) * normed+roped index_k -> page-128 index_cache * fp8 per-token dequant scales -> k_scale / v_scale (when fp8) @@ -1055,14 +1068,6 @@ def rope_cache(self, q, k, v, qkv, position, fwd_ctx: ForwardContext): qkv = qkv.contiguous() num_tokens = qkv.shape[0] - q_out = torch.empty( - (num_tokens, self.num_heads * self.head_dim), - dtype=qkv.dtype, - device=qkv.device, - ) - index_q = torch.empty( - (num_tokens, self.index_q_size), dtype=qkv.dtype, device=qkv.device - ) from atom.models.minimax_m3 import _minimax_m3_cos_sin_cache cos_sin_cache = _minimax_m3_cos_sin_cache(self.rotary_emb, qkv) @@ -1074,6 +1079,55 @@ def rope_cache(self, q, k, v, qkv, position, fwd_ctx: ForwardContext): fused_k_scale = k_scale if is_fp8 else None fused_v_scale = v_scale if is_fp8 else None + if self.skip_index_topk: + from atom.model_ops.triton_fused_qkv_norm_rope_cache import ( + triton_fused_norm_rope_cache, + ) + + q_size = self.num_heads * self.head_dim + kv_size = self.num_kv_heads * self.head_dim + q_raw, k_raw, v_raw, _, _ = torch.split( + qkv, + [q_size, kv_size, kv_size, self.index_q_size, self.index_head_dim], + dim=-1, + ) + q_out, k_out = triton_fused_norm_rope_cache( + q_raw, + k_raw, + v_raw, + position, + q_norm=self.q_norm, + k_norm=self.k_norm, + rotary_emb=self.rotary_emb, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + k_cache=k_cache, + v_cache=v_cache, + k_scale=fused_k_scale, + v_scale=fused_v_scale, + slot_mapping=slot_mapping, + kv_cache_dtype=self.kv_cache_dtype, + ) + q = q_out.view(-1, self.num_heads, self.head_dim) + k = k_out.view(-1, self.num_kv_heads, self.head_dim) + v = v_raw.view(-1, self.num_kv_heads, self.head_dim) + self._index_q = None + self._index_q_cache_key_info = ( + (num_tokens, self.num_idx_heads, self.index_head_dim), + qkv.dtype, + qkv.device, + ) + return q, k, v, k_cache, v_cache, k_scale, v_scale + + q_out = torch.empty( + (num_tokens, self.num_heads * self.head_dim), + dtype=qkv.dtype, + device=qkv.device, + ) + index_q = torch.empty( + (num_tokens, self.index_q_size), dtype=qkv.dtype, device=qkv.device + ) aiter.fused_qknorm_idxrqknorm( qkv, self.q_norm.weight, @@ -1105,6 +1159,11 @@ def rope_cache(self, q, k, v, qkv, position, fwd_ctx: ForwardContext): # Stash the rotated indexer query for dispatch_backend (same-forward, # single-threaded; cleared after the sparse backend consumes it). self._index_q = index_q.view(-1, self.num_idx_heads, self.index_head_dim) + self._index_q_cache_key_info = ( + tuple(self._index_q.shape), + self._index_q.dtype, + self._index_q.device, + ) return q, k, v, k_cache, v_cache, k_scale, v_scale @@ -1132,6 +1191,72 @@ def _sparse_metadata(self, fwd_ctx: ForwardContext): sm = getattr(attn_metadata, "sparse_attention_metadata", None) return sm if sm is not None else attn_metadata + def _topk_cache_state(self, sparse_metadata): + state = self.index_topk_cache_state + if state is None: + return None + metadata_id = id(sparse_metadata) + if state.get("metadata_id") != metadata_id: + state.clear() + state["metadata_id"] = metadata_id + return state + + def _topk_cache_key( + self, + mode: str, + index_q: torch.Tensor, + block_table: torch.Tensor, + seq_lens: torch.Tensor, + max_query_len: int, + max_seq_len: int, + ) -> tuple: + if index_q is None: + if self._index_q_cache_key_info is None: + raise RuntimeError( + "MiniMax-M3 index cache key missing index_q metadata" + ) + index_q_shape, index_q_dtype, index_q_device = self._index_q_cache_key_info + else: + index_q_shape = tuple(index_q.shape) + index_q_dtype = index_q.dtype + index_q_device = index_q.device + return ( + mode, + index_q_shape, + index_q_dtype, + index_q_device, + tuple(block_table.shape), + tuple(block_table.stride()), + tuple(seq_lens.shape), + self.topk, + self.init_blocks, + self.local_blocks, + self.num_kv_heads, + max_query_len, + max_seq_len, + ) + + def _load_cached_topk(self, sparse_metadata, key: tuple): + if not self.skip_index_topk: + return None + state = self._topk_cache_state(sparse_metadata) + if state is None: + return None + entry = state.get("topk") + if entry is None or entry.get("key") != key: + return None + return entry["value"] + + def _store_cached_topk(self, sparse_metadata, key: tuple, value: tuple): + state = self._topk_cache_state(sparse_metadata) + if state is not None: + state["topk"] = { + "key": key, + "value": value, + "layer_num": self.layer_num, + "sparse_layer_ordinal": self.sparse_layer_ordinal, + } + @mark_trace(prefix="sparse_attention_prefill", torch_compile=False) def _sparse_prefill( self, q, k, v, k_cache, v_cache, k_scale, v_scale, fwd_ctx: ForwardContext @@ -1150,22 +1275,39 @@ def _sparse_prefill( prefix_lens = prefill_md.context_lens block_tables = prefill_md.block_table - topk_idx, sparse_bt, sparse_ctx = minimax_m3_index_topk( + topk_key = self._topk_cache_key( + "prefill", index_q, - self.index_cache, block_tables, - cu_seqlens_q, seq_lens, - prefix_lens, prefill_md.max_query_len, prefill_md.max_seq_len, - self.topk, - self.init_blocks, - self.local_blocks, - self.num_kv_heads, - self.scale, - emit_sparse_block_table=True, ) + cached_topk = self._load_cached_topk(sparse_metadata, topk_key) + if cached_topk is None: + if index_q is None: + raise RuntimeError("MiniMax-M3 index cache miss on a skip-index layer") + topk_idx, sparse_bt, sparse_ctx = minimax_m3_index_topk( + index_q, + self.index_cache, + block_tables, + cu_seqlens_q, + seq_lens, + prefix_lens, + prefill_md.max_query_len, + prefill_md.max_seq_len, + self.topk, + self.init_blocks, + self.local_blocks, + self.num_kv_heads, + self.scale, + emit_sparse_block_table=True, + ) + self._store_cached_topk( + sparse_metadata, topk_key, (topk_idx, sparse_bt, sparse_ctx) + ) + else: + topk_idx, sparse_bt, sparse_ctx = cached_topk output = torch.empty_like(q) minimax_m3_sparse_attn_prefill_asm( q, @@ -1188,6 +1330,7 @@ def _sparse_prefill( ) output = output.view(*q.shape) self._index_q = None + self._index_q_cache_key_info = None return output @mark_trace(prefix="sparse_attention_decode", torch_compile=False) @@ -1205,20 +1348,37 @@ def _sparse_decode( assert decode_md is not None, "sparse decode metadata missing" max_query_len = getattr(decode_md, "max_query_len", 1) - topk_idx, sparse_bt, sparse_ctx = minimax_m3_index_topk_decode( + topk_key = self._topk_cache_key( + "decode", index_q, - self.index_cache, decode_md.block_table, decode_md.seq_lens, + max_query_len, sparse_metadata.max_seq_len, - self.topk, - self.init_blocks, - self.local_blocks, - self.num_kv_heads, - self.scale, - emit_sparse_block_table=True, - max_query_len=max_query_len, ) + cached_topk = self._load_cached_topk(sparse_metadata, topk_key) + if cached_topk is None: + if index_q is None: + raise RuntimeError("MiniMax-M3 index cache miss on a skip-index layer") + topk_idx, sparse_bt, sparse_ctx = minimax_m3_index_topk_decode( + index_q, + self.index_cache, + decode_md.block_table, + decode_md.seq_lens, + sparse_metadata.max_seq_len, + self.topk, + self.init_blocks, + self.local_blocks, + self.num_kv_heads, + self.scale, + emit_sparse_block_table=True, + max_query_len=max_query_len, + ) + self._store_cached_topk( + sparse_metadata, topk_key, (topk_idx, sparse_bt, sparse_ctx) + ) + else: + topk_idx, sparse_bt, sparse_ctx = cached_topk output = torch.empty_like(q) minimax_m3_sparse_attn_decode_asm( q, @@ -1236,5 +1396,6 @@ def _sparse_decode( sparse_ctx=sparse_ctx, ) self._index_q = None + self._index_q_cache_key_info = None output = output.view(*q.shape) return output diff --git a/atom/model_ops/attentions/aiter_attention.py b/atom/model_ops/attentions/aiter_attention.py index 0897f7da9..a7b1a9022 100644 --- a/atom/model_ops/attentions/aiter_attention.py +++ b/atom/model_ops/attentions/aiter_attention.py @@ -499,6 +499,10 @@ def allocate_kv_cache_tensors( device="cuda", ) tensors["_sparse_attention_cache_next"] = 0 + if getattr(text_config, "use_index_cache", False) or getattr( + hf_config, "use_index_cache", False + ): + tensors["_sparse_attention_topk_cache_state"] = {} return tensors def build_kv_cache_tensor(self, layer_id: int, module): @@ -530,6 +534,9 @@ def build_kv_cache_tensor(self, layer_id: int, module): runner._sparse_attention_cache_next += 1 module.impl.index_cache = runner.sparse_attention_index_cache[sparse_idx] module.impl.max_model_len = runner.config.max_model_len + module.impl.index_topk_cache_state = getattr( + runner, "_sparse_attention_topk_cache_state", None + ) # NOTE: no return — fall through to the standard MHA binding below. if not ( diff --git a/atom/models/minimax_m3.py b/atom/models/minimax_m3.py index 8296f5ed0..03f5177a6 100644 --- a/atom/models/minimax_m3.py +++ b/atom/models/minimax_m3.py @@ -62,6 +62,40 @@ def _sparse_attention_layer_ids(config: PretrainedConfig) -> set[int]: return {i for i, enabled in enumerate(freq) if enabled != 0} +def _sparse_attention_layer_ordinals(config: PretrainedConfig) -> dict[int, int]: + return { + layer_id: ordinal + for ordinal, layer_id in enumerate(sorted(_sparse_attention_layer_ids(config))) + } + + +def _should_skip_minimax_m3_index_topk( + config: PretrainedConfig, layer_id: int +) -> tuple[bool, int]: + sparse_ordinals = _sparse_attention_layer_ordinals(config) + sparse_ordinal = sparse_ordinals.get(layer_id, -1) + if sparse_ordinal < 0: + return False, sparse_ordinal + if not getattr(config, "use_index_cache", False): + return False, sparse_ordinal + + index_topk_freq = int(getattr(config, "index_topk_freq", 1) or 1) + index_topk_pattern = getattr(config, "index_topk_pattern", None) + if index_topk_pattern is not None: + if 0 <= sparse_ordinal < len(index_topk_pattern): + return index_topk_pattern[sparse_ordinal] == "S", sparse_ordinal + return False, sparse_ordinal + + if index_topk_freq <= 0: + raise ValueError("index_topk_freq must be a positive integer") + if index_topk_freq == 1: + return False, sparse_ordinal + + # MiniMax-M3 schedules sharing by sparse-layer ordinal, not absolute layer id. + offset = int(getattr(config, "index_skip_topk_offset", 0)) + return max(sparse_ordinal - offset, 0) % index_topk_freq != 0, sparse_ordinal + + def _is_moe_layer(config: PretrainedConfig, layer_id: int) -> bool: moe_layer_freq = getattr(config, "moe_layer_freq", None) if moe_layer_freq is None: @@ -402,6 +436,9 @@ def __init__( self.topk_blocks = sparse_cfg["sparse_topk_blocks"] self.init_blocks = sparse_cfg.get("sparse_init_block", 0) self.local_blocks = sparse_cfg.get("sparse_local_block", 0) + self.skip_index_topk, self.sparse_layer_ordinal = ( + _should_skip_minimax_m3_index_topk(config, layer_id) + ) score_type = sparse_cfg.get("sparse_score_type", "max") if score_type != "max": raise ValueError( @@ -472,6 +509,8 @@ def __init__( topk=self.topk_blocks, init_blocks=self.init_blocks, local_blocks=self.local_blocks, + skip_index_topk=self.skip_index_topk, + sparse_layer_ordinal=self.sparse_layer_ordinal, ) def forward( @@ -479,10 +518,8 @@ def forward( positions: torch.Tensor, hidden_states: torch.Tensor, ) -> torch.Tensor: - # Packed qkv = [q | k | v | index_q | index_k]. The sparse impl's - # rope_cache consumes the packed tensor directly (passed via qkv=); the - # split q/k/v slices satisfy forward_impl's view contract (they are - # rebuilt from qkv inside rope_cache). + # Keep index Q/K packed with main QKV. Layers that reuse cached top-k skip + # the indexer norm/rope/top-k path, but still compute the packed GEMM. qkv = self.qkv_proj(hidden_states) q, k, v, _, _ = qkv.split( [