Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 31 additions & 3 deletions atom/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,16 @@ def _normalize_minimax_m3_text_config(hf_config: PretrainedConfig) -> None:
if getattr(text_config, "swiglu_beta", None) is None:
text_config.swiglu_beta = 1.0

for attr_name in (
"use_index_cache",
"index_topk_freq",
"index_topk_pattern",
"index_skip_topk_offset",
):
attr_value = getattr(hf_config, attr_name, None)
if attr_value is not None:
setattr(text_config, attr_name, attr_value)

for attr_name, attr_value in vars(text_config).items():
if attr_name.startswith("_") or getattr(hf_config, attr_name, None) is not None:
continue
Expand Down Expand Up @@ -1239,11 +1249,29 @@ def compute_hash(self) -> str:
factors.append(vllm_factors)
factors.append(self.tensor_parallel_size)
factors.append(self.enable_dp_attention)
text_config = getattr(self.hf_config, "text_config", self.hf_config)
factors.append(
(
getattr(self.hf_config, "use_index_cache", False),
getattr(self.hf_config, "index_topk_freq", None),
getattr(self.hf_config, "index_topk_pattern", None),
getattr(
text_config,
"use_index_cache",
getattr(self.hf_config, "use_index_cache", False),
),
getattr(
text_config,
"index_topk_freq",
getattr(self.hf_config, "index_topk_freq", None),
),
getattr(
text_config,
"index_topk_pattern",
getattr(self.hf_config, "index_topk_pattern", None),
),
getattr(
text_config,
"index_skip_topk_offset",
getattr(self.hf_config, "index_skip_topk_offset", None),
),
)
)

Expand Down
221 changes: 191 additions & 30 deletions atom/model_ops/attention_mha.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,7 +882,12 @@ def forward(
**kwargs,
):
return self.forward_impl(
q=query, k=key, v=value, position=position, q_scale=q_scale, qkv=qkv
q=query,
k=key,
v=value,
position=position,
q_scale=q_scale,
qkv=qkv,
)


Expand Down Expand Up @@ -937,6 +942,8 @@ def __init__(
topk: int = 0,
init_blocks: int = 0,
local_blocks: int = 0,
skip_index_topk: bool = False,
sparse_layer_ordinal: int = -1,
**kwargs,
):
super().__init__(
Expand Down Expand Up @@ -972,9 +979,15 @@ def __init__(
self.topk = topk
self.init_blocks = init_blocks
self.local_blocks = local_blocks
self.skip_index_topk = skip_index_topk
self.sparse_layer_ordinal = sparse_layer_ordinal
# Bound by AiterAttentionMetadataBuilder.build_kv_cache_tensor (Task 6):
# the page-128 indexer-key cache. None until the runner binds it.
self.index_cache: Optional[torch.Tensor] = None
# Optional shared dict bound by the metadata builder. It is scoped to the
# current sparse metadata object and carries the last full layer top-k.
self.index_topk_cache_state: Optional[dict] = None
self._index_q_cache_key_info: Optional[tuple] = None
# Rotated indexer query produced by rope_cache, consumed (and cleared) by
# dispatch_backend within the same single-threaded layer forward.
self._index_q: Optional[torch.Tensor] = None
Expand Down Expand Up @@ -1013,8 +1026,8 @@ def rope_cache(self, q, k, v, qkv, position, fwd_ctx: ForwardContext):
"""MiniMax-M3 fused qk-norm + partial-NeoX-RoPE + page-16 SHUFFLE KV insert
+ indexer-key insert, via ``aiter.fused_qknorm_idxrqknorm``.

Consumes the PACKED ``qkv`` tensor (Gemma (1+w) norm path needs it) laid
out as ``[q | k | v | index_q | index_k]``. Writes:
Consumes the packed ``qkv`` tensor laid out as
``[q | k | v | index_q | index_k]``. Writes:
* normed+roped main K/V -> SHUFFLE K/V cache (asm_layout=True)
* normed+roped index_k -> page-128 index_cache
* fp8 per-token dequant scales -> k_scale / v_scale (when fp8)
Expand Down Expand Up @@ -1055,14 +1068,6 @@ def rope_cache(self, q, k, v, qkv, position, fwd_ctx: ForwardContext):

qkv = qkv.contiguous()
num_tokens = qkv.shape[0]
q_out = torch.empty(
(num_tokens, self.num_heads * self.head_dim),
dtype=qkv.dtype,
device=qkv.device,
)
index_q = torch.empty(
(num_tokens, self.index_q_size), dtype=qkv.dtype, device=qkv.device
)
from atom.models.minimax_m3 import _minimax_m3_cos_sin_cache

cos_sin_cache = _minimax_m3_cos_sin_cache(self.rotary_emb, qkv)
Expand All @@ -1074,6 +1079,55 @@ def rope_cache(self, q, k, v, qkv, position, fwd_ctx: ForwardContext):
fused_k_scale = k_scale if is_fp8 else None
fused_v_scale = v_scale if is_fp8 else None

if self.skip_index_topk:
from atom.model_ops.triton_fused_qkv_norm_rope_cache import (
triton_fused_norm_rope_cache,
)

Comment on lines +1082 to +1086
q_size = self.num_heads * self.head_dim
kv_size = self.num_kv_heads * self.head_dim
q_raw, k_raw, v_raw, _, _ = torch.split(
qkv,
[q_size, kv_size, kv_size, self.index_q_size, self.index_head_dim],
dim=-1,
)
q_out, k_out = triton_fused_norm_rope_cache(
q_raw,
k_raw,
v_raw,
position,
q_norm=self.q_norm,
k_norm=self.k_norm,
rotary_emb=self.rotary_emb,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
head_dim=self.head_dim,
k_cache=k_cache,
v_cache=v_cache,
k_scale=fused_k_scale,
v_scale=fused_v_scale,
slot_mapping=slot_mapping,
kv_cache_dtype=self.kv_cache_dtype,
)
q = q_out.view(-1, self.num_heads, self.head_dim)
k = k_out.view(-1, self.num_kv_heads, self.head_dim)
v = v_raw.view(-1, self.num_kv_heads, self.head_dim)
self._index_q = None
self._index_q_cache_key_info = (
(num_tokens, self.num_idx_heads, self.index_head_dim),
qkv.dtype,
qkv.device,
)
return q, k, v, k_cache, v_cache, k_scale, v_scale

q_out = torch.empty(
(num_tokens, self.num_heads * self.head_dim),
dtype=qkv.dtype,
device=qkv.device,
)
index_q = torch.empty(
(num_tokens, self.index_q_size), dtype=qkv.dtype, device=qkv.device
)
aiter.fused_qknorm_idxrqknorm(
qkv,
self.q_norm.weight,
Expand Down Expand Up @@ -1105,6 +1159,11 @@ def rope_cache(self, q, k, v, qkv, position, fwd_ctx: ForwardContext):
# Stash the rotated indexer query for dispatch_backend (same-forward,
# single-threaded; cleared after the sparse backend consumes it).
self._index_q = index_q.view(-1, self.num_idx_heads, self.index_head_dim)
self._index_q_cache_key_info = (
tuple(self._index_q.shape),
self._index_q.dtype,
self._index_q.device,
)

return q, k, v, k_cache, v_cache, k_scale, v_scale

Expand Down Expand Up @@ -1132,6 +1191,72 @@ def _sparse_metadata(self, fwd_ctx: ForwardContext):
sm = getattr(attn_metadata, "sparse_attention_metadata", None)
return sm if sm is not None else attn_metadata

def _topk_cache_state(self, sparse_metadata):
state = self.index_topk_cache_state
if state is None:
return None
metadata_id = id(sparse_metadata)
if state.get("metadata_id") != metadata_id:
state.clear()
state["metadata_id"] = metadata_id
return state

def _topk_cache_key(
self,
mode: str,
index_q: torch.Tensor,
block_table: torch.Tensor,
seq_lens: torch.Tensor,
max_query_len: int,
max_seq_len: int,
) -> tuple:
if index_q is None:
if self._index_q_cache_key_info is None:
raise RuntimeError(
"MiniMax-M3 index cache key missing index_q metadata"
)
index_q_shape, index_q_dtype, index_q_device = self._index_q_cache_key_info
else:
index_q_shape = tuple(index_q.shape)
index_q_dtype = index_q.dtype
index_q_device = index_q.device
return (
mode,
index_q_shape,
index_q_dtype,
index_q_device,
tuple(block_table.shape),
tuple(block_table.stride()),
tuple(seq_lens.shape),
self.topk,
self.init_blocks,
self.local_blocks,
self.num_kv_heads,
max_query_len,
max_seq_len,
)

def _load_cached_topk(self, sparse_metadata, key: tuple):
if not self.skip_index_topk:
return None
state = self._topk_cache_state(sparse_metadata)
if state is None:
return None
entry = state.get("topk")
if entry is None or entry.get("key") != key:
return None
return entry["value"]

def _store_cached_topk(self, sparse_metadata, key: tuple, value: tuple):
state = self._topk_cache_state(sparse_metadata)
if state is not None:
state["topk"] = {
"key": key,
"value": value,
"layer_num": self.layer_num,
"sparse_layer_ordinal": self.sparse_layer_ordinal,
}

@mark_trace(prefix="sparse_attention_prefill", torch_compile=False)
def _sparse_prefill(
self, q, k, v, k_cache, v_cache, k_scale, v_scale, fwd_ctx: ForwardContext
Expand All @@ -1150,22 +1275,39 @@ def _sparse_prefill(
prefix_lens = prefill_md.context_lens
block_tables = prefill_md.block_table

topk_idx, sparse_bt, sparse_ctx = minimax_m3_index_topk(
topk_key = self._topk_cache_key(
"prefill",
index_q,
self.index_cache,
block_tables,
cu_seqlens_q,
seq_lens,
prefix_lens,
prefill_md.max_query_len,
prefill_md.max_seq_len,
self.topk,
self.init_blocks,
self.local_blocks,
self.num_kv_heads,
self.scale,
emit_sparse_block_table=True,
)
cached_topk = self._load_cached_topk(sparse_metadata, topk_key)
if cached_topk is None:
if index_q is None:
raise RuntimeError("MiniMax-M3 index cache miss on a skip-index layer")

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

put these part under model itself?…… not must in this pr,but we do need a refact…

topk_idx, sparse_bt, sparse_ctx = minimax_m3_index_topk(
Comment on lines +1288 to +1290
index_q,
self.index_cache,
block_tables,
cu_seqlens_q,
seq_lens,
prefix_lens,
prefill_md.max_query_len,
prefill_md.max_seq_len,
self.topk,
self.init_blocks,
self.local_blocks,
self.num_kv_heads,
self.scale,
emit_sparse_block_table=True,
)
self._store_cached_topk(
sparse_metadata, topk_key, (topk_idx, sparse_bt, sparse_ctx)
)
else:
topk_idx, sparse_bt, sparse_ctx = cached_topk
output = torch.empty_like(q)
minimax_m3_sparse_attn_prefill_asm(
q,
Expand All @@ -1188,6 +1330,7 @@ def _sparse_prefill(
)
output = output.view(*q.shape)
self._index_q = None
self._index_q_cache_key_info = None
return output

@mark_trace(prefix="sparse_attention_decode", torch_compile=False)
Expand All @@ -1205,20 +1348,37 @@ def _sparse_decode(
assert decode_md is not None, "sparse decode metadata missing"
max_query_len = getattr(decode_md, "max_query_len", 1)

topk_idx, sparse_bt, sparse_ctx = minimax_m3_index_topk_decode(
topk_key = self._topk_cache_key(
"decode",
index_q,
self.index_cache,
decode_md.block_table,
decode_md.seq_lens,
max_query_len,
sparse_metadata.max_seq_len,
self.topk,
self.init_blocks,
self.local_blocks,
self.num_kv_heads,
self.scale,
emit_sparse_block_table=True,
max_query_len=max_query_len,
)
cached_topk = self._load_cached_topk(sparse_metadata, topk_key)
if cached_topk is None:
if index_q is None:
raise RuntimeError("MiniMax-M3 index cache miss on a skip-index layer")
topk_idx, sparse_bt, sparse_ctx = minimax_m3_index_topk_decode(
Comment on lines +1360 to +1363
index_q,
self.index_cache,
decode_md.block_table,
decode_md.seq_lens,
sparse_metadata.max_seq_len,
self.topk,
self.init_blocks,
self.local_blocks,
self.num_kv_heads,
self.scale,
emit_sparse_block_table=True,
max_query_len=max_query_len,
)
self._store_cached_topk(
sparse_metadata, topk_key, (topk_idx, sparse_bt, sparse_ctx)
)
else:
topk_idx, sparse_bt, sparse_ctx = cached_topk
output = torch.empty_like(q)
minimax_m3_sparse_attn_decode_asm(
q,
Expand All @@ -1236,5 +1396,6 @@ def _sparse_decode(
sparse_ctx=sparse_ctx,
)
self._index_q = None
self._index_q_cache_key_info = None
output = output.view(*q.shape)
return output
7 changes: 7 additions & 0 deletions atom/model_ops/attentions/aiter_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,10 @@ def allocate_kv_cache_tensors(
device="cuda",
)
tensors["_sparse_attention_cache_next"] = 0
if getattr(text_config, "use_index_cache", False) or getattr(
hf_config, "use_index_cache", False
):
tensors["_sparse_attention_topk_cache_state"] = {}
return tensors

def build_kv_cache_tensor(self, layer_id: int, module):
Expand Down Expand Up @@ -530,6 +534,9 @@ def build_kv_cache_tensor(self, layer_id: int, module):
runner._sparse_attention_cache_next += 1
module.impl.index_cache = runner.sparse_attention_index_cache[sparse_idx]
module.impl.max_model_len = runner.config.max_model_len
module.impl.index_topk_cache_state = getattr(
runner, "_sparse_attention_topk_cache_state", None
)
# NOTE: no return — fall through to the standard MHA binding below.

if not (
Expand Down
Loading
Loading