From 8d2791ea4d8b6fd0b1a8adf10d901355c15bb07c Mon Sep 17 00:00:00 2001 From: Zhao An Date: Wed, 20 May 2026 07:21:17 +0000 Subject: [PATCH 01/20] feat: RTPLLM plugin GLM5 integration --- atom/plugin/prepare.py | 8 + atom/plugin/rtpllm/__init__.py | 8 + .../rtpllm/attention_backend/__init__.py | 28 + .../rtp_dense_mla_backend.py | 498 ++++++++ .../attention_backend/rtp_mla_attention.py | 202 ++++ .../attention_backend/rtp_mla_metadata.py | 39 + .../rtp_sparse_mla_backend.py | 886 ++++++++++++++ atom/plugin/rtpllm/models/__init__.py | 15 +- .../rtpllm/models/base_model_wrapper.py | 9 + atom/plugin/rtpllm/models/glm5.py | 520 ++++++++ atom/plugin/rtpllm/utils/forward_context.py | 397 ++++++- .../test_rtpllm_forward_context_semantics.py | 430 ++++++- .../test_rtpllm_glm5_indexer_contract.py | 230 ++++ .../test_rtpllm_glm5_mha_bridge_guard.py | 66 ++ .../test_rtpllm_glm5_mla_bridge_shape.py | 24 + .../test_rtpllm_glm5_mla_forward_contract.py | 1053 +++++++++++++++++ tests/plugin/test_rtpllm_glm5_mla_patch.py | 23 + tests/plugin/test_rtpllm_glm5_ownership.py | 39 + tests/plugin/test_rtpllm_glm5_registration.py | 78 ++ ...est_rtpllm_glm5_sparse_backend_contract.py | 353 ++++++ .../test_rtpllm_glm5_wrapper_lifecycle.py | 320 +++++ tests/plugin/test_rtpllm_model_wrapper.py | 17 +- tests/plugin/test_rtpllm_prepare_model.py | 44 + 23 files changed, 5244 insertions(+), 43 deletions(-) create mode 100644 atom/plugin/rtpllm/attention_backend/rtp_dense_mla_backend.py create mode 100644 atom/plugin/rtpllm/attention_backend/rtp_mla_attention.py create mode 100644 atom/plugin/rtpllm/attention_backend/rtp_mla_metadata.py create mode 100644 atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py create mode 100644 atom/plugin/rtpllm/models/glm5.py create mode 100644 tests/plugin/test_rtpllm_glm5_indexer_contract.py create mode 100644 tests/plugin/test_rtpllm_glm5_mha_bridge_guard.py create mode 100644 tests/plugin/test_rtpllm_glm5_mla_bridge_shape.py create mode 100644 
tests/plugin/test_rtpllm_glm5_mla_forward_contract.py create mode 100644 tests/plugin/test_rtpllm_glm5_mla_patch.py create mode 100644 tests/plugin/test_rtpllm_glm5_ownership.py create mode 100644 tests/plugin/test_rtpllm_glm5_registration.py create mode 100644 tests/plugin/test_rtpllm_glm5_sparse_backend_contract.py create mode 100644 tests/plugin/test_rtpllm_glm5_wrapper_lifecycle.py diff --git a/atom/plugin/prepare.py b/atom/plugin/prepare.py index 191e9a8bad..41b892c474 100644 --- a/atom/plugin/prepare.py +++ b/atom/plugin/prepare.py @@ -98,6 +98,7 @@ def _prepare_model_atom_sglang( def _prepare_model_atom_rtpllm( config: Any, atom_config: Any, + model_arch: str, model_cls: Any, set_attn_cls: Any, init_aiter_dist: Any, @@ -120,6 +121,12 @@ def _prepare_model_atom_rtpllm( ) set_attn_cls() + if model_arch == "GlmMoeDsaForCausalLM": + from atom.plugin.rtpllm.attention_backend import ( + apply_attention_mla_rtpllm_patch, + ) + + apply_attention_mla_rtpllm_patch() # init aiter dist for using aiter custom collective ops init_aiter_dist(config=atom_config) @@ -172,6 +179,7 @@ def prepare_model(config: Any, engine: str): return _prepare_model_atom_rtpllm( config, atom_config, + model_arch, model_cls, set_attn_cls, init_aiter_dist, diff --git a/atom/plugin/rtpllm/__init__.py b/atom/plugin/rtpllm/__init__.py index e69de29bb2..4dad126add 100644 --- a/atom/plugin/rtpllm/__init__.py +++ b/atom/plugin/rtpllm/__init__.py @@ -0,0 +1,8 @@ +try: + from .models import base_model_wrapper as _base_model_wrapper +except ModuleNotFoundError as exc: + if exc.name != "rtp_llm": + raise + _base_model_wrapper = None + +__all__ = ["_base_model_wrapper"] diff --git a/atom/plugin/rtpllm/attention_backend/__init__.py b/atom/plugin/rtpllm/attention_backend/__init__.py index 1afa3cdb59..9a157a3f71 100644 --- a/atom/plugin/rtpllm/attention_backend/__init__.py +++ b/atom/plugin/rtpllm/attention_backend/__init__.py @@ -1,10 +1,38 @@ from .attention_gdn import apply_attention_gdn_rtpllm_patch from 
.attention_switch import apply_attention_mha_rtpllm_patch from .rtp_full_attention import AttentionForRTPLLM, RTPFullAttention +from .rtp_dense_mla_backend import RTPDenseMlaBackend +from .rtp_mla_attention import RTPMLAAttention, apply_attention_mla_rtpllm_patch +from .rtp_mla_metadata import ( + GLM5_RTP_BRIDGE_MODE, + GLM5_RTP_BRIDGE_MODE_M0_DENSE, + GLM5_RTP_OWNERSHIP, + RTPMlaPluginMetadata, +) +from .rtp_sparse_mla_backend import RTPSparseMlaBackend + + +def __getattr__(name): + if name in {"RTPAttention", "RTPFullAttention"}: + from .rtp_full_attention import RTPAttention, RTPFullAttention + + return {"RTPAttention": RTPAttention, "RTPFullAttention": RTPFullAttention}[ + name + ] + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + __all__ = [ "AttentionForRTPLLM", "RTPFullAttention", + "RTPDenseMlaBackend", + "RTPMLAAttention", + "RTPSparseMlaBackend", + "GLM5_RTP_BRIDGE_MODE", + "GLM5_RTP_BRIDGE_MODE_M0_DENSE", + "GLM5_RTP_OWNERSHIP", + "RTPMlaPluginMetadata", "apply_attention_gdn_rtpllm_patch", "apply_attention_mha_rtpllm_patch", + "apply_attention_mla_rtpllm_patch", ] diff --git a/atom/plugin/rtpllm/attention_backend/rtp_dense_mla_backend.py b/atom/plugin/rtpllm/attention_backend/rtp_dense_mla_backend.py new file mode 100644 index 0000000000..8cd9b06f25 --- /dev/null +++ b/atom/plugin/rtpllm/attention_backend/rtp_dense_mla_backend.py @@ -0,0 +1,498 @@ +"""Dense MLA fallback for GLM5 rtp-llm plugin mode.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Optional + +import torch + + +_FP8_CACHE_DTYPES = tuple( + dtype + for dtype in ( + getattr(torch, "float8_e4m3fnuz", None), + getattr(torch, "float8_e4m3fn", None), + torch.uint8, + ) + if dtype is not None +) + + +def _raise_cache_error(message: str) -> None: + raise RuntimeError(message) + + +@dataclass(frozen=True) +class _DenseMlaMetadata: + query_start_loc: torch.Tensor + seq_lens: torch.Tensor | None + block_table: 
torch.Tensor | None + slot_mapping: torch.Tensor | None + is_prefill: bool + block_size: int + + +class RTPDenseMlaBackend: + """Small dense MLA backend used before the sparse kernel is wired. + + This backend intentionally avoids vLLM plugin metadata. It consumes the + native GLM5 five-tuple already prepared by DeepseekV2MLAAttention and uses + RTPForwardContext metadata only to recover per-sequence token ranges. + """ + + def __init__(self, *, mla_modules: Any) -> None: + self.mla_modules = mla_modules + self.kv_b_proj = getattr(mla_modules, "kv_b_proj", None) + self.rotary_emb = getattr(mla_modules, "rotary_emb", None) + self.v_head_dim = int(getattr(mla_modules, "v_head_dim")) + self.qk_nope_head_dim = getattr(mla_modules, "qk_nope_head_dim", None) + self.qk_rope_head_dim = getattr(mla_modules, "qk_rope_head_dim", None) + self._projection_checked = False + + @staticmethod + def _read_is_prefill(context: Any) -> bool: + if context is None or not hasattr(context, "is_prefill"): + raise ValueError( + "GLM5 RTP dense MLA requires explicit context.is_prefill metadata." 
+ ) + return bool(getattr(context, "is_prefill")) + + @staticmethod + def _get_metadata(num_tokens: int, device: torch.device) -> _DenseMlaMetadata: + attn_metadata = None + context = None + rtp_seq_size_per_block = 1 + try: + from atom.utils.forward_context import get_forward_context + + forward_context = get_forward_context() + attn_metadata = getattr(forward_context, "attn_metadata", None) + context = getattr(forward_context, "context", None) + rtp_seq_size_per_block = int( + getattr(attn_metadata, "rtp_seq_size_per_block", 0) + or getattr(attn_metadata, "rtp_kernel_seq_size_per_block", 0) + or 0 + ) + plugin_metadata = getattr(attn_metadata, "plugin_metadata", None) + query_start_loc = getattr(plugin_metadata, "query_start_loc", None) + if query_start_loc is None: + query_start_loc = getattr(plugin_metadata, "rtp_cu_seqlens_q", None) + if query_start_loc is None: + query_start_loc = getattr(attn_metadata, "cu_seqlens_q", None) + if query_start_loc is None: + decode_metadata = getattr(plugin_metadata, "decode_metadata", None) + query_start_loc = getattr(decode_metadata, "query_start_loc", None) + seq_lens = getattr(plugin_metadata, "seq_lens", None) + if seq_lens is None: + seq_lens = getattr(attn_metadata, "context_lens", None) + block_table = getattr(plugin_metadata, "block_table", None) + if block_table is None: + block_table = getattr(attn_metadata, "block_tables", None) + slot_mapping = getattr(plugin_metadata, "slot_mapping", None) + if slot_mapping is None: + slot_mapping = getattr(attn_metadata, "slot_mapping", None) + except Exception: + query_start_loc = None + seq_lens = None + block_table = None + slot_mapping = None + + if ( + context is not None + and hasattr(context, "is_prefill") + and not bool(getattr(context, "is_prefill")) + and isinstance(seq_lens, torch.Tensor) + and int(seq_lens.numel()) == num_tokens + and isinstance(block_table, torch.Tensor) + and isinstance(slot_mapping, torch.Tensor) + ): + query_start_loc = torch.arange( + num_tokens 
+ 1, dtype=torch.int64, device=device + ) + + if query_start_loc is not None and int(query_start_loc.numel()) >= 2: + query_start_loc = query_start_loc.to(device=device, dtype=torch.int64) + if int(query_start_loc[0].item()) == 0 and int(query_start_loc[-1].item()) == num_tokens: + is_prefill = RTPDenseMlaBackend._read_is_prefill(context) + return _DenseMlaMetadata( + query_start_loc=query_start_loc, + seq_lens=( + seq_lens.to(device=device, dtype=torch.int64) + if isinstance(seq_lens, torch.Tensor) + else None + ), + block_table=( + block_table.to(device=device, dtype=torch.int64) + if isinstance(block_table, torch.Tensor) + else None + ), + slot_mapping=( + slot_mapping.to(device=device, dtype=torch.int64) + if isinstance(slot_mapping, torch.Tensor) + else None + ), + is_prefill=is_prefill, + block_size=max(1, rtp_seq_size_per_block), + ) + if num_tokens != 1: + raise ValueError( + "GLM5 RTP dense MLA requires query_start_loc metadata for " + f"multi-token batches (num_tokens={num_tokens})." 
+ ) + is_prefill = RTPDenseMlaBackend._read_is_prefill(context) + return _DenseMlaMetadata( + query_start_loc=torch.tensor([0, num_tokens], dtype=torch.int64, device=device), + seq_lens=None, + block_table=None, + slot_mapping=None, + is_prefill=is_prefill, + block_size=max(1, rtp_seq_size_per_block), + ) + + @staticmethod + def _unwrap_linear_output(value: Any) -> torch.Tensor: + if isinstance(value, tuple): + value = value[0] + if not isinstance(value, torch.Tensor): + raise TypeError(f"Expected kv_b_proj to return Tensor, got {type(value)!r}.") + return value + + def _apply_current_rope( + self, + q: torch.Tensor, + k_pe: torch.Tensor, + positions: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + rope_dim = int(self.qk_rope_head_dim or k_pe.shape[-1]) + if rope_dim == 0: + return q, k_pe + if self.rotary_emb is None: + raise ValueError("GLM5 RTP dense MLA requires rotary_emb for RoPE dimensions.") + if positions is None or int(positions.numel()) != int(q.shape[0]): + got = None if positions is None else int(positions.numel()) + raise ValueError( + "GLM5 RTP dense MLA requires per-token absolute positions for RoPE " + f"(positions={got}, tokens={int(q.shape[0])})." + ) + if int(q.shape[-1]) < rope_dim: + raise ValueError( + f"GLM5 RTP dense MLA invalid q shape for RoPE: q={tuple(q.shape)}, " + f"rope_dim={rope_dim}." + ) + + q_rope = q.clone() + k_pe_rope = k_pe.clone() + # RotaryEmbedding.forward rotates the full tensor it receives. Passing + # only q_pe/k_pe is equivalent to the fused MLA path's nope-first layout. 
+ rotated_q_pe, rotated_k_pe = self.rotary_emb( + positions.to(device=q.device, dtype=torch.long), + q_rope[..., -rope_dim:], + k_pe_rope, + ) + q_rope[..., -rope_dim:] = rotated_q_pe + return q_rope, rotated_k_pe + + def _project_kv( + self, + q: torch.Tensor, + compressed_kv: torch.Tensor, + k_pe: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor] | None: + if self.kv_b_proj is None: + return None + _, num_heads, qk_head_dim = q.shape + num_kv_tokens = int(compressed_kv.shape[0]) + rope_dim = int(self.qk_rope_head_dim or k_pe.shape[-1]) + nope_dim = int(self.qk_nope_head_dim or (qk_head_dim - rope_dim)) + if nope_dim <= 0: + raise ValueError( + f"Invalid MLA qk dims: qk_head_dim={qk_head_dim}, rope_dim={rope_dim}." + ) + + compressed_kv = compressed_kv.contiguous() + kv_nope = self._unwrap_linear_output(self.kv_b_proj(compressed_kv)) + if kv_nope.numel() == 0: + raise ValueError("GLM5 RTP dense MLA kv_b_proj returned an empty tensor.") + expected_last_dim = num_heads * (nope_dim + self.v_head_dim) + if kv_nope.shape[-1] != expected_last_dim: + raise ValueError( + "GLM5 RTP dense MLA kv_b_proj output shape mismatch " + f"(got={tuple(kv_nope.shape)}, expected_last_dim={expected_last_dim}, " + f"num_heads={num_heads}, qk_nope_head_dim={nope_dim}, " + f"v_head_dim={self.v_head_dim})." 
+ ) + if not self._projection_checked: + self._projection_checked = True + + kv_nope = kv_nope.reshape(num_kv_tokens, num_heads, nope_dim + self.v_head_dim) + k_nope, value = kv_nope.split([nope_dim, self.v_head_dim], dim=-1) + if k_pe.dim() == 2: + k_pe = k_pe.unsqueeze(1) + k_pe = k_pe.expand(num_kv_tokens, num_heads, rope_dim) + key = torch.cat((k_nope, k_pe), dim=-1) + return key, value + + @staticmethod + def _causal_attention( + q: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + query_start_loc: torch.Tensor, + scale: float, + ) -> torch.Tensor: + pieces: list[torch.Tensor] = [] + for start_tensor, end_tensor in zip(query_start_loc[:-1], query_start_loc[1:]): + start = int(start_tensor.item()) + end = int(end_tensor.item()) + if end <= start: + continue + q_seg = q[start:end].float() + k_seg = key[start:end].float() + v_seg = value[start:end].float() + scores = torch.einsum("tnd,snd->nts", q_seg, k_seg) * scale + seq_len = end - start + causal_mask = torch.ones( + (seq_len, seq_len), dtype=torch.bool, device=q.device + ).tril() + scores = scores.masked_fill(~causal_mask.unsqueeze(0), float("-inf")) + probs = torch.softmax(scores, dim=-1) + pieces.append(torch.einsum("nts,snd->tnd", probs, v_seg)) + if not pieces: + return value.new_empty((0, value.shape[1], value.shape[2])) + return torch.cat(pieces, dim=0) + + @staticmethod + def _cross_causal_attention( + q: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, + ) -> torch.Tensor: + q_len = int(q.shape[0]) + k_len = int(key.shape[0]) + if q_len == 0: + return value.new_empty((0, value.shape[1], value.shape[2])) + if k_len < q_len: + raise ValueError( + f"GLM5 RTP dense MLA got invalid cross attention lengths: q={q_len}, k={k_len}." 
+ ) + scores = torch.einsum("tnd,snd->nts", q.float(), key.float()) * scale + q_pos = torch.arange(q_len, device=q.device).unsqueeze(1) + k_pos = torch.arange(k_len, device=q.device).unsqueeze(0) + causal_mask = k_pos <= (k_len - q_len + q_pos) + scores = scores.masked_fill(~causal_mask.unsqueeze(0), float("-inf")) + probs = torch.softmax(scores, dim=-1) + return torch.einsum("nts,snd->tnd", probs, value.float()) + + @staticmethod + def _flatten_latent_cache( + layer_cache: Any, + *, + block_size: int, + kv_dim: int, + ) -> torch.Tensor | None: + kv_cache_base = getattr(layer_cache, "kv_cache_base", None) + if not isinstance(kv_cache_base, torch.Tensor) or kv_cache_base.numel() == 0: + return None + if kv_cache_base.dtype in _FP8_CACHE_DTYPES: + raise NotImplementedError( + "GLM5 RTP dense MLA reference path requires BF16/FP16 latent KV cache; " + "FP8 KV cache layout/dequant is not supported yet." + ) + if kv_cache_base.dim() == 3 and int(kv_cache_base.shape[-1]) == kv_dim: + return kv_cache_base.reshape(-1, kv_dim) + if kv_cache_base.dim() == 2 and int(kv_cache_base.shape[1]) % block_size == 0: + per_token_dim = int(kv_cache_base.shape[1]) // block_size + if per_token_dim == kv_dim: + return kv_cache_base.view(kv_cache_base.shape[0], block_size, kv_dim).reshape( + -1, kv_dim + ) + return None + + @staticmethod + def _write_current_to_cache( + *, + layer_cache: Any, + compressed_kv: torch.Tensor, + k_pe: torch.Tensor, + metadata: _DenseMlaMetadata, + kv_dim: int, + ) -> None: + if metadata.slot_mapping is None: + return + flat_cache = RTPDenseMlaBackend._flatten_latent_cache( + layer_cache, block_size=metadata.block_size, kv_dim=kv_dim + ) + if flat_cache is None: + return + latent = torch.cat((compressed_kv, k_pe), dim=-1) + if latent.shape[0] != metadata.slot_mapping.shape[0]: + return + slots = metadata.slot_mapping[: latent.shape[0]].long() + flat_size = int(flat_cache.shape[0]) + non_negative = slots >= 0 + in_bounds = non_negative & (slots < flat_size) + if 
bool((non_negative & (slots >= flat_size)).any().item()): + bad_slots = slots[non_negative & (slots >= flat_size)] + _raise_cache_error( + "GLM5 RTP dense MLA refuses to write out-of-bounds slot_mapping " + f"(block_size={metadata.block_size}, flat_tokens={flat_size}, " + f"slot_min={int(bad_slots.min().item())}, " + f"slot_max={int(bad_slots.max().item())})." + ) + if not bool(in_bounds.any().item()): + return + flat_cache[slots[in_bounds]] = latent[in_bounds].to(dtype=flat_cache.dtype) + + @staticmethod + def _resolve_layer_cache(kv_cache: object, layer_id: int) -> object: + if kv_cache is not None: + return kv_cache + try: + from atom.utils.forward_context import get_forward_context + + forward_context = get_forward_context() + kv_cache_data = getattr(forward_context, "kv_cache_data", None) + if kv_cache_data is None: + return None + layer_cache_entry = kv_cache_data.get(f"layer_{int(layer_id)}") + if layer_cache_entry is None: + return None + return getattr(layer_cache_entry, "k_cache", layer_cache_entry) + except Exception: + return None + + @staticmethod + def _gather_latent_history( + *, + layer_cache: Any, + metadata: _DenseMlaMetadata, + batch_idx: int, + kv_dim: int, + ) -> torch.Tensor | None: + if metadata.block_table is None or metadata.seq_lens is None: + return None + flat_cache = RTPDenseMlaBackend._flatten_latent_cache( + layer_cache, block_size=metadata.block_size, kv_dim=kv_dim + ) + if flat_cache is None: + return None + seq_len = int(metadata.seq_lens[batch_idx].item()) + if seq_len <= 0: + return None + block_row = metadata.block_table[batch_idx].long() + positions = torch.arange(seq_len, dtype=torch.long, device=flat_cache.device) + block_cols = torch.div(positions, metadata.block_size, rounding_mode="floor") + block_col_max = int(block_cols.max().item()) + if block_col_max >= int(block_row.numel()): + return None + offsets = positions.remainder(metadata.block_size) + slots = block_row[block_cols] * metadata.block_size + offsets + flat_size = 
int(flat_cache.shape[0]) + if bool(((slots < 0) | (slots >= flat_size)).any().item()): + bad_slots = slots[(slots < 0) | (slots >= flat_size)] + _raise_cache_error( + "GLM5 RTP dense MLA refuses to gather out-of-bounds KV history " + f"(batch_idx={batch_idx}, seq_len={seq_len}, " + f"block_size={metadata.block_size}, flat_tokens={flat_size}, " + f"slot_min={int(bad_slots.min().item())}, " + f"slot_max={int(bad_slots.max().item())})." + ) + return flat_cache[slots] + + @staticmethod + def _require_decode_cache_metadata( + *, + layer_cache: Any, + metadata: _DenseMlaMetadata, + kv_dim: int, + ) -> None: + missing = [] + if metadata.block_table is None: + missing.append("block_table") + if metadata.seq_lens is None: + missing.append("seq_lens") + if metadata.slot_mapping is None: + missing.append("slot_mapping") + if missing: + raise ValueError( + "GLM5 RTP dense MLA decode requires RTP KV metadata: " + + ", ".join(missing) + + "." + ) + flat_cache = RTPDenseMlaBackend._flatten_latent_cache( + layer_cache, block_size=metadata.block_size, kv_dim=kv_dim + ) + if flat_cache is None: + raise ValueError( + "GLM5 RTP dense MLA decode requires a readable BF16/FP16 kv_cache_base." 
+ ) + + def forward( + self, + q: torch.Tensor, + compressed_kv: torch.Tensor, + k_pe: torch.Tensor, + kv_cache: object, + layer_id: int, + topk_indices: Optional[torch.Tensor] = None, + positions: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + del topk_indices + if self.kv_b_proj is None: + return q.new_zeros((q.shape[0], q.shape[1], self.v_head_dim)) + layer_cache = self._resolve_layer_cache(kv_cache, layer_id) + q, k_pe = self._apply_current_rope(q, k_pe, positions) + projected = self._project_kv(q, compressed_kv, k_pe) + if projected is None: + return q.new_zeros((q.shape[0], q.shape[1], self.v_head_dim)) + metadata = self._get_metadata(q.shape[0], q.device) + kv_dim = int(compressed_kv.shape[-1]) + int(k_pe.shape[-1]) + self._write_current_to_cache( + layer_cache=layer_cache, + compressed_kv=compressed_kv, + k_pe=k_pe, + metadata=metadata, + kv_dim=kv_dim, + ) + key, value = projected + query_start_loc = metadata.query_start_loc + scale = float(q.shape[-1] ** -0.5) + if metadata.is_prefill: + output = self._causal_attention(q, key, value, query_start_loc, scale) + return output.to(dtype=compressed_kv.dtype) + + self._require_decode_cache_metadata( + layer_cache=layer_cache, + metadata=metadata, + kv_dim=kv_dim, + ) + pieces: list[torch.Tensor] = [] + for batch_idx, (start_tensor, end_tensor) in enumerate( + zip(query_start_loc[:-1], query_start_loc[1:]) + ): + start = int(start_tensor.item()) + end = int(end_tensor.item()) + if end <= start: + continue + q_seg = q[start:end] + latent_history = self._gather_latent_history( + layer_cache=layer_cache, + metadata=metadata, + batch_idx=batch_idx, + kv_dim=kv_dim, + ) + if latent_history is None: + raise ValueError( + "GLM5 RTP dense MLA decode failed to gather latent KV history." 
+ ) + hist_compressed_kv, hist_k_pe = latent_history.split( + [compressed_kv.shape[-1], k_pe.shape[-1]], dim=-1 + ) + hist_key, hist_value = self._project_kv(q_seg, hist_compressed_kv, hist_k_pe) + pieces.append( + self._cross_causal_attention(q_seg, hist_key, hist_value, scale) + ) + output = torch.cat(pieces, dim=0) if pieces else value.new_empty((0, *value.shape[1:])) + return output.to(dtype=compressed_kv.dtype) diff --git a/atom/plugin/rtpllm/attention_backend/rtp_mla_attention.py b/atom/plugin/rtpllm/attention_backend/rtp_mla_attention.py new file mode 100644 index 0000000000..3944a92c10 --- /dev/null +++ b/atom/plugin/rtpllm/attention_backend/rtp_mla_attention.py @@ -0,0 +1,202 @@ +"""RTP-style MLA adapter for GLM5 rtp-llm plugin mode.""" + +from __future__ import annotations + +import inspect +from typing import Optional + +import torch + + +def _resolve_index_topk(attn) -> int: + for obj, attr in ( + (getattr(attn, "indexer", None), "index_topk"), + (getattr(attn, "indexer", None), "topk_tokens"), + (attn, "index_topk"), + (getattr(attn, "config", None), "index_topk"), + ): + value = getattr(obj, attr, None) if obj is not None else None + if value is not None: + return int(value) + raise AttributeError("GLM5 RTP MLA M1 indexer requires index_topk/topk_tokens") + + +def _get_topk_indices_buffer(attn) -> torch.Tensor: + indexer = getattr(attn, "indexer", None) + buffer = getattr(indexer, "topk_indices_buffer", None) if indexer is not None else None + if buffer is None: + buffer = getattr(attn, "topk_indices_buffer", None) + if buffer is None: + buffer = getattr(attn, "_topk_indices_buffer", None) + if buffer is None: + raise AttributeError("GLM5 RTP MLA M1 indexer requires topk_indices_buffer") + return buffer + + +def _should_emit_topk_indices(attn) -> bool: + try: + from atom.utils.forward_context import get_forward_context + + forward_context = get_forward_context() + except Exception: + return True + + context = getattr(forward_context, "context", None) 
+ if getattr(context, "is_dummy_run", False): + return False + attn_metadata = getattr(forward_context, "attn_metadata", None) + if getattr(context, "is_prefill", False) and attn_metadata is not None: + max_seqlen_k = getattr(attn_metadata, "max_seqlen_k", None) + if max_seqlen_k is not None: + try: + return int(max_seqlen_k) > _get_topk_indices_buffer(attn).shape[1] + except AttributeError: + return True + return True + + +class RTPMLAAttention: + """Dense RTP MLA adapter for the native GLM5 MLA call contract.""" + + use_mla = True + + def __init__(self, *args, **kwargs) -> None: + self.args = args + self.kwargs = kwargs + mla_modules = kwargs.get("mla_modules") + self.mla_modules = mla_modules + self.q_proj = getattr(mla_modules, "q_proj", None) + self.o_proj = getattr(mla_modules, "o_proj", None) + self.kv_b_proj = getattr(mla_modules, "kv_b_proj", None) + self.indexer = getattr(mla_modules, "indexer", None) + self.qk_head_dim = getattr(mla_modules, "qk_head_dim", None) + self.v_head_dim = getattr(mla_modules, "v_head_dim", None) + self.q_lora_rank = getattr(mla_modules, "q_lora_rank", None) + self.kv_lora_rank = getattr(mla_modules, "kv_lora_rank", None) + self.num_heads = getattr(mla_modules, "num_heads", None) + self.num_local_heads = getattr(mla_modules, "num_local_heads", self.num_heads) + self.index_topk = getattr(mla_modules, "index_topk", None) + self.topk_indices_buffer = ( + getattr(self.indexer, "topk_indices_buffer", None) + if self.indexer is not None + else None + ) + injected_backend = kwargs.get("dense_backend") + if injected_backend is not None: + self.dense_backend = injected_backend + elif mla_modules is not None: + from atom.plugin.rtpllm.attention_backend.rtp_dense_mla_backend import ( + RTPDenseMlaBackend, + ) + from atom.plugin.rtpllm.attention_backend.rtp_sparse_mla_backend import ( + RTPSparseMlaBackend, + ) + + self.dense_backend = RTPSparseMlaBackend( + dense_backend=RTPDenseMlaBackend(mla_modules=mla_modules), + 
v_head_dim=mla_modules.v_head_dim, + mla_modules=mla_modules, + scale=kwargs.get("scale"), + ) + else: + self.dense_backend = None + self.kv_cache = kwargs.get("kv_cache") + self.layer_id = int(kwargs.get("layer_id", kwargs.get("layer_num", 0))) + + @staticmethod + def _backend_accepts_positions(backend: object) -> bool: + try: + signature = inspect.signature(backend.forward) + except (AttributeError, TypeError, ValueError): + return False + return "positions" in signature.parameters or any( + parameter.kind == inspect.Parameter.VAR_KEYWORD + for parameter in signature.parameters.values() + ) + + def _project_query( + self, query: torch.Tensor, q_scale: Optional[torch.Tensor] + ) -> tuple[torch.Tensor, bool]: + if query.ndim == 3: + return query, False + if self.q_proj is None: + return query, False + + q = self.q_proj(query, q_scale) + if q.ndim == 3: + return q, True + + num_heads = self.num_local_heads if self.num_local_heads is not None else self.num_heads + if num_heads is None: + if self.qk_head_dim is None: + raise AttributeError("GLM5 RTP MLA native contract requires num_heads") + num_heads = q.shape[-1] // int(self.qk_head_dim) + if self.qk_head_dim is None: + self.qk_head_dim = q.shape[-1] // int(num_heads) + return q.reshape(-1, int(num_heads), int(self.qk_head_dim)), True + + def _resolve_topk_indices( + self, + query: torch.Tensor, + q_scale: Optional[torch.Tensor], + positions: Optional[torch.Tensor], + explicit_topk_indices: Optional[torch.Tensor], + ) -> Optional[torch.Tensor]: + if explicit_topk_indices is not None: + return explicit_topk_indices + if self.indexer is None: + return None + + if not _should_emit_topk_indices(self): + return None + index_topk = _resolve_index_topk(self) + return _get_topk_indices_buffer(self)[: query.shape[0], :index_topk] + + def forward( + self, + query: torch.Tensor, + compressed_kv: torch.Tensor, + k_pe: torch.Tensor, + positions: Optional[torch.Tensor] = None, + q_scale: Optional[torch.Tensor] = None, + 
topk_indices: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + if self.dense_backend is None: + raise NotImplementedError( + "RTPMLAAttention requires an attention backend for contract execution" + ) + q, native_projected = self._project_query(query, q_scale) + topk_indices = self._resolve_topk_indices( + query, + q_scale, + positions, + kwargs.get("topk_indices", topk_indices), + ) + forward_kwargs = {"topk_indices": topk_indices} + if self._backend_accepts_positions(self.dense_backend): + forward_kwargs["positions"] = positions + attn_output = self.dense_backend.forward( + q, + compressed_kv, + k_pe, + self.kv_cache, + self.layer_id, + **forward_kwargs, + ) + if native_projected and self.o_proj is not None: + attn_output = attn_output.reshape(attn_output.shape[0], -1).contiguous() + return self.o_proj(attn_output) + return attn_output + + __call__ = forward + + +def apply_attention_mla_rtpllm_patch() -> None: + """Switch ATOM's generic Attention symbol to the RTP MLA adapter.""" + + import atom.model_ops as ops + + ops.RTPMLAAttention = RTPMLAAttention + ops.Attention = RTPMLAAttention + diff --git a/atom/plugin/rtpllm/attention_backend/rtp_mla_metadata.py b/atom/plugin/rtpllm/attention_backend/rtp_mla_metadata.py new file mode 100644 index 0000000000..bcd1b20c5e --- /dev/null +++ b/atom/plugin/rtpllm/attention_backend/rtp_mla_metadata.py @@ -0,0 +1,39 @@ +"""Metadata and static contracts for GLM5 MLA in rtp-llm plugin mode.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional + +import torch + + +GLM5_RTP_BRIDGE_MODE_M0_DENSE = "m0_dense" +GLM5_RTP_BRIDGE_MODE = GLM5_RTP_BRIDGE_MODE_M0_DENSE + + +GLM5_RTP_OWNERSHIP = { + "main_q_norm": "DeepseekV2MLAAttention", + "main_kv_norm": "DeepseekV2MLAAttention", + "main_rope": "RTPMLAAttention", + "main_kv_cache": "RTPMLAAttention", + "indexer_k_norm": "Indexer", + "indexer_rope": "Indexer", + "indexer_cache": "Indexer", + "topk_selector": "Indexer", 
+} + + +@dataclass(frozen=True) +class RTPMlaPluginMetadata: + """Minimal M0 placeholder for RTP MLA metadata. + + M0 intentionally does not model indexer/top-k metadata. M1/M2 should extend + this structure instead of overloading MHA plugin metadata. + """ + + is_prefill: bool + slot_mapping: Optional[torch.Tensor] = None + block_table: Optional[torch.Tensor] = None + seq_lens: Optional[torch.Tensor] = None + diff --git a/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py b/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py new file mode 100644 index 0000000000..b82270cd83 --- /dev/null +++ b/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py @@ -0,0 +1,886 @@ +"""Contract-executable sparse MLA backend for GLM5 rtp-llm plugin mode.""" + +from __future__ import annotations + +import inspect +import os +from dataclasses import dataclass +from typing import Any, Optional + +import torch + + +class _SparseUnavailable(RuntimeError): + pass + + +@dataclass +class _AbsorbedWeights: + w_kc: torch.Tensor + w_vc: torch.Tensor + + +@dataclass +class _AtomSparseMetadata: + qo_indptr: torch.Tensor + paged_kv_indptr: torch.Tensor + paged_kv_indices: torch.Tensor + paged_kv_last_page_len: torch.Tensor + work_meta_data: torch.Tensor + work_indptr: torch.Tensor + work_info_set: torch.Tensor + reduce_indptr: torch.Tensor + reduce_final_map: torch.Tensor + reduce_partial_map: torch.Tensor + padded_num_heads: int + head_repeat_factor: int + + +class _ContractSparseMlaImpl: + """CPU/mock sparse implementation used before the real RTP kernel is wired.""" + + def __init__(self, v_head_dim: int) -> None: + self.v_head_dim = int(v_head_dim) + self.calls = [] + + def forward( + self, + q: torch.Tensor, + compressed_kv: torch.Tensor, + k_pe: torch.Tensor, + kv_cache: object, + layer_id: int, + *, + topk_indices: torch.Tensor, + attn_metadata: object, + positions: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + self.calls.append( + { + "q": q, + 
"compressed_kv": compressed_kv, + "k_pe": k_pe, + "kv_cache": kv_cache, + "layer_id": layer_id, + "topk_indices": topk_indices, + "attn_metadata": attn_metadata, + "positions": positions, + } + ) + return q.new_zeros((q.shape[0], q.shape[1], self.v_head_dim)) + + +class _RealSparseMlaImpl: + """Runtime sparse MLA adapter for ATOM-owned GLM5 weights and RTP KV cache.""" + + def __init__( + self, + *, + mla_modules: Any, + v_head_dim: int, + scale: Optional[float] = None, + ) -> None: + self.mla_modules = mla_modules + self.v_head_dim = int(v_head_dim) + self.kv_lora_rank = int(getattr(mla_modules, "kv_lora_rank")) + self.qk_nope_head_dim = int(getattr(mla_modules, "qk_nope_head_dim")) + self.qk_rope_head_dim = int(getattr(mla_modules, "qk_rope_head_dim")) + self.num_heads = int(getattr(mla_modules, "num_heads", 0) or 0) + self.rotary_emb = getattr(mla_modules, "rotary_emb", None) + self.kv_b_proj = getattr(mla_modules, "kv_b_proj", None) + self.scale = ( + float(scale) + if scale is not None + else float((self.qk_nope_head_dim + self.qk_rope_head_dim) ** -0.5) + ) + self._absorbed_weights: _AbsorbedWeights | None = None + self._cache_write_scale: dict[torch.device, torch.Tensor] = {} + + @staticmethod + def _unwrap_linear_output(value: Any) -> torch.Tensor: + if isinstance(value, tuple): + value = value[0] + if not isinstance(value, torch.Tensor): + raise TypeError(f"Expected kv_b_proj to return Tensor, got {type(value)!r}.") + return value + + def _infer_num_heads(self, q: torch.Tensor) -> int: + if self.num_heads > 0: + return self.num_heads + self.num_heads = int(q.shape[1]) + return self.num_heads + + def _read_kv_b_proj_weight(self) -> torch.Tensor: + if self.kv_b_proj is None: + raise _SparseUnavailable("GLM5 RTP sparse MLA requires kv_b_proj.") + try: + from atom.model_ops.utils import get_and_maybe_dequant_weights + + weight = get_and_maybe_dequant_weights(self.kv_b_proj) + except Exception: + weight = getattr(self.kv_b_proj, "weight", None) + if not 
isinstance(weight, torch.Tensor): + raise _SparseUnavailable("GLM5 RTP sparse MLA cannot read kv_b_proj.weight.") + if weight.dtype in ( + getattr(torch, "float8_e4m3fn", None), + getattr(torch, "float8_e4m3fnuz", None), + getattr(torch, "float8_e5m2", None), + getattr(torch, "float8_e5m2fnuz", None), + ): + raise _SparseUnavailable( + "GLM5 RTP sparse MLA needs dequantized kv_b_proj weights for " + "the current adapter." + ) + return weight + + def _get_absorbed_weights(self, q: torch.Tensor) -> _AbsorbedWeights: + cached = self._absorbed_weights + if cached is not None and cached.w_kc.device == q.device: + return cached + + weight = self._read_kv_b_proj_weight().to(device=q.device) + num_heads = self._infer_num_heads(q) + expected_out = num_heads * (self.qk_nope_head_dim + self.v_head_dim) + if weight.ndim != 2: + raise _SparseUnavailable( + f"GLM5 RTP sparse MLA got invalid kv_b_proj weight shape {tuple(weight.shape)}." + ) + if int(weight.shape[0]) == expected_out and int(weight.shape[1]) == self.kv_lora_rank: + kv_b_weight = weight.T.contiguous() + elif int(weight.shape[1]) == expected_out and int(weight.shape[0]) == self.kv_lora_rank: + kv_b_weight = weight.contiguous() + else: + raise _SparseUnavailable( + "GLM5 RTP sparse MLA kv_b_proj weight shape mismatch " + f"(got={tuple(weight.shape)}, expected_out={expected_out}, " + f"kv_lora_rank={self.kv_lora_rank})." 
+ ) + + kv_b_weight = kv_b_weight.view( + self.kv_lora_rank, + num_heads, + self.qk_nope_head_dim + self.v_head_dim, + ) + w_uk, w_uv = kv_b_weight.split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1 + ) + absorbed = _AbsorbedWeights( + w_kc=w_uk.permute(1, 2, 0).contiguous(), + w_vc=w_uv.permute(1, 0, 2).contiguous(), + ) + self._absorbed_weights = absorbed + return absorbed + + def _apply_rope( + self, + q: torch.Tensor, + k_pe: torch.Tensor, + positions: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + rope_dim = int(self.qk_rope_head_dim) + if rope_dim == 0: + return q, k_pe + if self.rotary_emb is None: + raise _SparseUnavailable("GLM5 RTP sparse MLA requires rotary_emb.") + if positions is None or int(positions.numel()) != int(q.shape[0]): + raise _SparseUnavailable( + "GLM5 RTP sparse MLA requires per-token positions for RoPE " + f"(positions={None if positions is None else int(positions.numel())}, " + f"tokens={int(q.shape[0])})." + ) + q_rope = q.clone() + k_pe_rope = k_pe.clone() + rotated_q_pe, rotated_k_pe = self.rotary_emb( + positions.to(device=q.device, dtype=torch.long), + q_rope[..., -rope_dim:], + k_pe_rope, + ) + q_rope[..., -rope_dim:] = rotated_q_pe + return q_rope, rotated_k_pe + + def _cache_dtype_name(self, kv_cache_base: torch.Tensor) -> str: + fp8_dtypes = { + dtype + for dtype in ( + getattr(torch, "float8_e4m3fn", None), + getattr(torch, "float8_e4m3fnuz", None), + getattr(torch, "float8_e5m2", None), + getattr(torch, "float8_e5m2fnuz", None), + torch.uint8, + ) + if dtype is not None + } + if kv_cache_base.dtype not in fp8_dtypes: + return "auto" + explicit = os.getenv("ATOM_RTP_MLA_FP8_CACHE_DTYPE", "").strip() + if explicit: + return explicit + return "fp8_model1_mla" if self.kv_lora_rank == 448 else "fp8_ds_mla" + + def _write_current_to_cache( + self, + *, + compressed_kv: torch.Tensor, + k_pe: torch.Tensor, + kv_cache: Any, + attn_metadata: Any, + ) -> torch.Tensor: + kv_cache_base = getattr(kv_cache, 
"kv_cache_base", None) + if not isinstance(kv_cache_base, torch.Tensor) or kv_cache_base.numel() == 0: + raise _SparseUnavailable("GLM5 RTP sparse MLA requires kv_cache_base.") + slot_mapping = getattr(attn_metadata, "slot_mapping", None) + if slot_mapping is None: + plugin_metadata = getattr(attn_metadata, "plugin_metadata", None) + slot_mapping = getattr(plugin_metadata, "slot_mapping", None) + if not isinstance(slot_mapping, torch.Tensor): + raise _SparseUnavailable("GLM5 RTP sparse MLA requires slot_mapping.") + try: + from aiter import concat_and_cache_mla + except Exception as exc: + raise _SparseUnavailable(f"aiter.concat_and_cache_mla unavailable: {exc}") from exc + + scale = self._cache_write_scale.get(compressed_kv.device) + if scale is None: + scale = torch.tensor(1.0, dtype=torch.float32, device=compressed_kv.device) + self._cache_write_scale[compressed_kv.device] = scale + try: + concat_and_cache_mla( + compressed_kv, + k_pe, + kv_cache_base, + slot_mapping.to(device=compressed_kv.device, dtype=torch.int64), + kv_cache_dtype=self._cache_dtype_name(kv_cache_base), + scale=scale, + ) + except Exception as exc: + raise _SparseUnavailable(f"concat_and_cache_mla failed: {exc}") from exc + return kv_cache_base + + @staticmethod + def _build_req_id_per_token( + attn_metadata: Any, + num_tokens: int, + device: torch.device, + ) -> torch.Tensor: + plugin_metadata = getattr(attn_metadata, "plugin_metadata", None) + req_id = getattr(plugin_metadata, "req_id_per_token", None) + if isinstance(req_id, torch.Tensor) and int(req_id.numel()) >= num_tokens: + return req_id[:num_tokens].to(device=device, dtype=torch.int32) + query_start_loc = getattr(plugin_metadata, "query_start_loc", None) + if query_start_loc is None: + query_start_loc = getattr(plugin_metadata, "rtp_cu_seqlens_q", None) + if query_start_loc is None: + query_start_loc = getattr(attn_metadata, "cu_seqlens_q", None) + if isinstance(query_start_loc, torch.Tensor) and int(query_start_loc.numel()) >= 2: + 
qsl = query_start_loc.to(device=device, dtype=torch.int64) + lengths = qsl[1:] - qsl[:-1] + return torch.repeat_interleave( + torch.arange(int(lengths.numel()), device=device, dtype=torch.int32), + lengths, + )[:num_tokens].contiguous() + return torch.arange(num_tokens, device=device, dtype=torch.int32) + + @staticmethod + def _block_table(attn_metadata: Any, device: torch.device) -> torch.Tensor: + plugin_metadata = getattr(attn_metadata, "plugin_metadata", None) + block_table = getattr(plugin_metadata, "block_table", None) + if block_table is None: + block_table = getattr(attn_metadata, "block_tables", None) + if not isinstance(block_table, torch.Tensor): + raise _SparseUnavailable("GLM5 RTP sparse MLA requires block_table.") + if block_table.ndim == 1: + block_table = block_table.unsqueeze(0) + return block_table.to(device=device, dtype=torch.int32) + + @staticmethod + def _convert_topk_to_global( + *, + topk_indices: torch.Tensor, + attn_metadata: Any, + block_size: int, + ) -> torch.Tensor: + num_tokens, topk = topk_indices.shape + device = topk_indices.device + block_table = _RealSparseMlaImpl._block_table(attn_metadata, device) + req_id = _RealSparseMlaImpl._build_req_id_per_token( + attn_metadata, num_tokens, device + ).to(dtype=torch.long) + token_indices = topk_indices.to(device=device, dtype=torch.long) + valid = token_indices >= 0 + block_cols = torch.div( + torch.clamp(token_indices, min=0), + int(block_size), + rounding_mode="floor", + ) + offsets = torch.remainder(torch.clamp(token_indices, min=0), int(block_size)) + valid = valid & (req_id[:, None] >= 0) & (req_id[:, None] < block_table.shape[0]) + valid = valid & (block_cols >= 0) & (block_cols < block_table.shape[1]) + safe_req = torch.clamp(req_id, min=0, max=max(int(block_table.shape[0]) - 1, 0)) + safe_cols = torch.clamp(block_cols, min=0, max=max(int(block_table.shape[1]) - 1, 0)) + block_ids = block_table.to(dtype=torch.long)[safe_req[:, None], safe_cols] + valid = valid & (block_ids >= 0) + 
global_indices = block_ids * int(block_size) + offsets + return torch.where(valid, global_indices, torch.zeros_like(global_indices)).to( + dtype=torch.int32 + ) + + @staticmethod + def _decode_indptr( + *, + num_tokens: int, + topk: int, + device: torch.device, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + qo_indptr = torch.arange(num_tokens + 1, device=device, dtype=torch.int32) + paged_kv_indptr = ( + torch.arange(num_tokens + 1, device=device, dtype=torch.int32) * int(topk) + ) + paged_kv_last_page_len = torch.ones( + (num_tokens,), device=device, dtype=torch.int32 + ) + return qo_indptr, paged_kv_indptr, paged_kv_last_page_len + + @staticmethod + def _generate_sparse_seqlen_torch( + *, + query_lens: torch.Tensor, + seq_lens: torch.Tensor, + query_start_loc: torch.Tensor, + topk: int, + num_tokens: int, + ) -> torch.Tensor: + out = torch.zeros((num_tokens,), dtype=torch.int32, device=query_lens.device) + for req_id in range(int(query_lens.numel())): + q_len = int(query_lens[req_id].item()) + seq_len = int(seq_lens[req_id].item()) + start = int(query_start_loc[req_id].item()) + if q_len <= 0 or seq_len <= 0: + continue + context_start = seq_len - q_len + offsets = torch.arange(q_len, device=query_lens.device, dtype=torch.int32) + out[start : start + q_len] = torch.clamp( + context_start + offsets + 1, + min=0, + max=int(topk), + ) + return out + + @staticmethod + def _aiter_dtype_for_tensor(tensor: torch.Tensor) -> Any: + try: + from aiter import dtypes + except Exception as exc: + raise _SparseUnavailable(f"aiter dtypes unavailable: {exc}") from exc + + fp8_dtypes = { + dtype + for dtype in ( + getattr(torch, "float8_e4m3fn", None), + getattr(torch, "float8_e4m3fnuz", None), + getattr(torch, "float8_e5m2", None), + getattr(torch, "float8_e5m2fnuz", None), + torch.uint8, + getattr(dtypes, "fp8", None), + ) + if dtype is not None + } + if tensor.dtype in fp8_dtypes: + return dtypes.fp8 + if tensor.dtype == torch.float16: + return 
def _build_atom_sparse_metadata(
    self,
    *,
    q_latent: torch.Tensor,
    kv_cache_base: torch.Tensor,
    topk_indices: torch.Tensor,
    attn_metadata: Any,
    block_size: int,
) -> _AtomSparseMetadata:
    """Assemble the index and work-partition tensors for the aiter decode call.

    Each query token becomes its own single-token entry in ``qo_indptr``; the
    KV side uses pages of size 1 whose per-token counts come from the sparse
    sequence-length helper.

    Args:
        q_latent: queries; dim 0 is read as the token count and dim 1 as the
            head count.
        kv_cache_base: cache tensor; only its dtype is inspected here.
        topk_indices: per-token selected indices; dim 1 is read as top-k.
        attn_metadata: object carrying ``plugin_metadata`` plus fallbacks.
        block_size: forwarded as ``BLOCK_SIZE`` to the index converter.

    Raises:
        _SparseUnavailable: during CUDA graph capture, when the aiter/ATOM
            helpers cannot be imported, or when required metadata is missing
            or has an unexpected length.
    """
    # The code below performs a host sync (``.cpu().item()``) and fresh
    # allocations, so it refuses to run while a CUDA graph is being captured.
    if torch.cuda.is_current_stream_capturing():
        raise _SparseUnavailable("ATOM sparse MLA metadata is not graph-capture safe yet.")
    try:
        from aiter import get_mla_metadata_info_v1, get_mla_metadata_v1
        from atom.plugin.attention_mla_sparse import (
            generate_sparse_seqlen_triton,
            triton_convert_req_index_to_global_index,
        )
    except Exception as exc:
        raise _SparseUnavailable(f"ATOM sparse MLA metadata helpers unavailable: {exc}") from exc

    plugin_metadata = getattr(attn_metadata, "plugin_metadata", None)
    if plugin_metadata is None:
        raise _SparseUnavailable("GLM5 RTP sparse MLA requires plugin metadata.")

    num_tokens = int(q_latent.shape[0])
    num_heads = int(q_latent.shape[1])
    topk = int(topk_indices.shape[1])
    device = q_latent.device

    # Cumulative query offsets: ``query_start_loc`` first, then the
    # ``rtp_cu_seqlens_q`` alias. At least two entries (one request) are needed.
    query_start_loc = getattr(plugin_metadata, "query_start_loc", None)
    if query_start_loc is None:
        query_start_loc = getattr(plugin_metadata, "rtp_cu_seqlens_q", None)
    if not isinstance(query_start_loc, torch.Tensor) or int(query_start_loc.numel()) < 2:
        raise _SparseUnavailable("GLM5 RTP sparse MLA requires query_start_loc.")
    query_start_loc = query_start_loc.to(device=device, dtype=torch.int32).contiguous()

    # Exactly one sequence length per request, i.e. one fewer than the number
    # of offsets in query_start_loc.
    seq_lens = getattr(plugin_metadata, "seq_lens", None)
    if seq_lens is None:
        seq_lens = getattr(attn_metadata, "context_lens", None)
    if not isinstance(seq_lens, torch.Tensor) or int(seq_lens.numel()) + 1 != int(
        query_start_loc.numel()
    ):
        raise _SparseUnavailable("GLM5 RTP sparse MLA requires seq_lens per request.")
    seq_lens = seq_lens.to(device=device, dtype=torch.int32).contiguous()

    req_id = self._build_req_id_per_token(attn_metadata, num_tokens, device).to(
        dtype=torch.int32
    )
    block_table = self._block_table(attn_metadata, device).to(dtype=torch.int32)
    topk_indices_i32 = topk_indices.to(device=device, dtype=torch.int32).contiguous()
    query_lens = (query_start_loc[1:] - query_start_loc[:-1]).contiguous()

    # Per-token KV count: torch fallback on CPU, Triton helper elsewhere.
    if device.type == "cpu":
        sparse_seqlen = self._generate_sparse_seqlen_torch(
            query_lens=query_lens,
            seq_lens=seq_lens,
            query_start_loc=query_start_loc,
            topk=topk,
            num_tokens=num_tokens,
        )
    else:
        sparse_seqlen = generate_sparse_seqlen_triton(
            query_lens,
            seq_lens,
            query_start_loc,
            topk,
            num_tokens,
            # Largest query length, read back to the host (forces a sync).
            int(torch.max(query_lens).detach().cpu().item()) if num_tokens else 1,
        )

    # Every token is a query segment of length 1.
    qo_indptr = torch.arange(num_tokens + 1, device=device, dtype=torch.int32)
    # Prefix sum of the per-token KV counts, written in place after the leading 0.
    paged_kv_indptr = torch.zeros((num_tokens + 1,), device=device, dtype=torch.int32)
    torch.cumsum(sparse_seqlen, dim=0, out=paged_kv_indptr[1:])
    paged_kv_last_page_len = torch.ones((num_tokens,), device=device, dtype=torch.int32)
    # Sized for the worst case of ``topk`` entries per token.
    paged_kv_indices = torch.zeros((num_tokens * topk,), device=device, dtype=torch.int32)

    # NOTE(review): presumably fills ``paged_kv_indices`` with global cache
    # slot ids derived from the block table — confirm against the helper.
    # This call is also made when ``device`` is CPU; verify the helper
    # supports that path.
    triton_convert_req_index_to_global_index(
        req_id,
        block_table,
        topk_indices_i32,
        paged_kv_indptr,
        paged_kv_indices,
        BLOCK_SIZE=int(block_size),
        NUM_TOPK_TOKENS=topk,
    )

    # Use at least 16 heads, rounded up to a multiple of ``num_heads`` so each
    # real head is repeated a whole number of times.
    padded_num_heads = max(num_heads, 16)
    if padded_num_heads % num_heads != 0:
        padded_num_heads = ((padded_num_heads + num_heads - 1) // num_heads) * num_heads
    head_repeat_factor = padded_num_heads // num_heads
    q_dtype = self._aiter_dtype_for_tensor(q_latent)
    kv_dtype = self._aiter_dtype_for_tensor(kv_cache_base)
    # aiter reports the (size, dtype) of each scratch buffer it needs.
    (
        (work_meta_data_size, work_meta_data_type),
        (work_indptr_size, work_indptr_type),
        (work_info_set_size, work_info_set_type),
        (reduce_indptr_size, reduce_indptr_type),
        (reduce_final_map_size, reduce_final_map_type),
        (reduce_partial_map_size, reduce_partial_map_type),
    ) = get_mla_metadata_info_v1(
        max(num_tokens, 1),
        1,
        padded_num_heads,
        q_dtype,
        kv_dtype,
        is_sparse=True,
        fast_mode=True,
    )
    work_meta_data = torch.empty(work_meta_data_size, dtype=work_meta_data_type, device=device)
    work_indptr = torch.empty(work_indptr_size, dtype=work_indptr_type, device=device)
    work_info_set = torch.empty(work_info_set_size, dtype=work_info_set_type, device=device)
    reduce_indptr = torch.empty(reduce_indptr_size, dtype=reduce_indptr_type, device=device)
    reduce_final_map = torch.empty(
        reduce_final_map_size, dtype=reduce_final_map_type, device=device
    )
    reduce_partial_map = torch.empty(
        reduce_partial_map_size, dtype=reduce_partial_map_type, device=device
    )
    # Fills the scratch buffers in place. Note the argument order here is
    # work_info_set before work_indptr, unlike the size tuple above.
    get_mla_metadata_v1(
        qo_indptr,
        paged_kv_indptr,
        paged_kv_last_page_len,
        padded_num_heads,
        1,
        True,
        work_meta_data,
        work_info_set,
        work_indptr,
        reduce_indptr,
        reduce_final_map,
        reduce_partial_map,
        page_size=1,
        kv_granularity=16,
        max_seqlen_qo=1,
        uni_seqlen_qo=1,
        fast_mode=True,
        dtype_q=q_dtype,
        dtype_kv=kv_dtype,
    )
    return _AtomSparseMetadata(
        qo_indptr=qo_indptr,
        paged_kv_indptr=paged_kv_indptr,
        paged_kv_indices=paged_kv_indices,
        paged_kv_last_page_len=paged_kv_last_page_len,
        work_meta_data=work_meta_data,
        work_indptr=work_indptr,
        work_info_set=work_info_set,
        reduce_indptr=reduce_indptr,
        reduce_final_map=reduce_final_map,
        reduce_partial_map=reduce_partial_map,
        padded_num_heads=padded_num_heads,
        head_repeat_factor=head_repeat_factor,
    )
def _run_aiter_sparse_decode(
    self,
    *,
    q_latent: torch.Tensor,
    kv_cache_base: torch.Tensor,
    topk_indices: torch.Tensor,
    attn_metadata: Any,
    block_size: int,
) -> torch.Tensor:
    """Run sparse decode through ``aiter.mla.mla_decode_fwd``.

    Query heads are repeated to the padded head count reported by the
    metadata builder, and the duplicates are dropped from the result.

    Raises:
        _SparseUnavailable: the aiter kernel cannot be imported or the kernel
            call itself fails.
    """
    try:
        from aiter.mla import mla_decode_fwd
    except Exception as exc:
        raise _SparseUnavailable(f"aiter.mla_decode_fwd unavailable: {exc}") from exc

    # Three-way unpack: also rejects inputs that are not rank 3.
    token_count, _head_count, latent_width = q_latent.shape
    meta = self._build_atom_sparse_metadata(
        q_latent=q_latent,
        kv_cache_base=kv_cache_base,
        topk_indices=topk_indices,
        attn_metadata=attn_metadata,
        block_size=block_size,
    )
    repeat = meta.head_repeat_factor
    needs_padding = repeat > 1
    kernel_q = q_latent.repeat_interleave(repeat, dim=1) if needs_padding else q_latent
    result = torch.empty(
        (token_count, meta.padded_num_heads, self.kv_lora_rank),
        dtype=kernel_q.dtype,
        device=q_latent.device,
    )
    try:
        # View the cache as one entry per page (page size 1).
        paged_cache = kv_cache_base.reshape(-1, 1, 1, latent_width)
        mla_decode_fwd(
            kernel_q,
            paged_cache,
            result,
            meta.qo_indptr,
            meta.paged_kv_indptr,
            meta.paged_kv_indices,
            meta.paged_kv_last_page_len,
            1,
            sm_scale=self.scale,
            page_size=1,
            work_meta_data=meta.work_meta_data,
            work_indptr=meta.work_indptr,
            work_info_set=meta.work_info_set,
            reduce_indptr=meta.reduce_indptr,
            reduce_final_map=meta.reduce_final_map,
            reduce_partial_map=meta.reduce_partial_map,
        )
    except Exception as exc:
        raise _SparseUnavailable(f"mla_decode_fwd failed: {exc}") from exc
    if not needs_padding:
        return result
    # Keep the first copy of every repeated head.
    return result[:, ::repeat, :].contiguous()
class RTPSparseMlaBackend:
    """Sparse top-k MLA front end with a dense fallback.

    ``forward`` routes a call to the dense backend when no ``topk_indices``
    are given, when only the default mock implementation is installed and the
    mock has not been enabled through the environment, or when the sparse
    implementation raises ``_SparseUnavailable`` (unless strict mode is on).

    The sparse implementation is chosen in ``__init__``: an explicit
    ``sparse_impl`` wins, then a ``_RealSparseMlaImpl`` built from
    ``mla_modules`` when it exposes every required attribute, and otherwise
    the recording ``_ContractSparseMlaImpl`` mock.

    Environment switches (values ``1``/``true``/``yes``/``on``, any case):
        ATOM_RTP_ENABLE_SPARSE_MLA_MOCK: let the default mock handle sparse calls.
        ATOM_RTP_SPARSE_MLA_STRICT: re-raise ``_SparseUnavailable`` instead of
            falling back to the dense backend.
    """

    # Environment values (after strip/lower) that switch a flag on.
    _TRUTHY_ENV_VALUES = frozenset({"1", "true", "yes", "on"})

    # Attributes ``mla_modules`` must expose before the real adapter is built.
    _REQUIRED_MLA_ATTRS = (
        "kv_lora_rank",
        "qk_nope_head_dim",
        "qk_rope_head_dim",
        "kv_b_proj",
        "rotary_emb",
    )

    def __init__(
        self,
        *,
        dense_backend: object,
        sparse_impl: Optional[object] = None,
        v_head_dim: Optional[int] = None,
        mla_modules: Optional[object] = None,
        scale: Optional[float] = None,
    ) -> None:
        self.dense_backend = dense_backend
        # Falls back to the dense backend's own value; raises AttributeError
        # when neither source provides it.
        self.v_head_dim = int(
            v_head_dim
            if v_head_dim is not None
            else getattr(dense_backend, "v_head_dim")
        )
        if sparse_impl is not None:
            self.sparse_impl = sparse_impl
            self._default_mock = False
        elif mla_modules is not None and all(
            hasattr(mla_modules, attr) for attr in self._REQUIRED_MLA_ATTRS
        ):
            self.sparse_impl = _RealSparseMlaImpl(
                mla_modules=mla_modules,
                v_head_dim=self.v_head_dim,
                scale=scale,
            )
            self._default_mock = False
        else:
            self.sparse_impl = _ContractSparseMlaImpl(self.v_head_dim)
            self._default_mock = True

    @staticmethod
    def _get_attn_metadata() -> object:
        """Return the current forward context's ``attn_metadata`` or ``None``.

        Best effort on purpose: any failure to import or read the forward
        context yields ``None`` and the sparse implementation decides what a
        missing metadata object means.
        """
        try:
            from atom.utils.forward_context import get_forward_context

            return getattr(get_forward_context(), "attn_metadata", None)
        except Exception:
            return None

    @staticmethod
    def _validate_topk_indices(q: torch.Tensor, topk_indices: torch.Tensor) -> None:
        """Raise ``ValueError`` unless ``topk_indices`` is int32 ``[T, K]`` with ``T == q.shape[0]``."""
        if topk_indices.ndim != 2:
            raise ValueError(
                "Expected topk_indices to be rank-2 [T,K], "
                f"got shape {tuple(topk_indices.shape)}"
            )
        if topk_indices.dtype != torch.int32:
            raise ValueError(
                f"Expected topk_indices dtype torch.int32, got {topk_indices.dtype}"
            )
        if topk_indices.shape[0] != q.shape[0]:
            raise ValueError(
                "Expected topk_indices first dimension to match q tokens, "
                f"got {topk_indices.shape[0]} and {q.shape[0]}"
            )

    @staticmethod
    def _env_flag(name: str) -> bool:
        """Return True when environment variable ``name`` holds a truthy value."""
        raw = os.getenv(name, "0")
        return raw.strip().lower() in RTPSparseMlaBackend._TRUTHY_ENV_VALUES

    @staticmethod
    def _enable_sparse_mock() -> bool:
        """Whether the default mock may serve sparse calls."""
        return RTPSparseMlaBackend._env_flag("ATOM_RTP_ENABLE_SPARSE_MLA_MOCK")

    @staticmethod
    def _strict_sparse() -> bool:
        """Whether ``_SparseUnavailable`` should propagate instead of falling back."""
        return RTPSparseMlaBackend._env_flag("ATOM_RTP_SPARSE_MLA_STRICT")

    @staticmethod
    def _call_accepts_positions(callable_obj: object) -> bool:
        """Return True when ``callable_obj`` takes ``positions`` or ``**kwargs``.

        Objects whose signature cannot be inspected count as not accepting it.
        """
        try:
            signature = inspect.signature(callable_obj)
        except (TypeError, ValueError):
            return False
        return "positions" in signature.parameters or any(
            parameter.kind == inspect.Parameter.VAR_KEYWORD
            for parameter in signature.parameters.values()
        )

    @staticmethod
    def _impl_accepts_positions(impl: object) -> bool:
        """Return True when ``impl.forward`` takes ``positions`` or ``**kwargs``."""
        forward = getattr(impl, "forward", None)
        if forward is None:
            return False
        return RTPSparseMlaBackend._call_accepts_positions(forward)

    def _dense_forward(
        self,
        q: torch.Tensor,
        compressed_kv: torch.Tensor,
        k_pe: torch.Tensor,
        kv_cache: object,
        layer_id: int,
        topk_indices: Optional[torch.Tensor],
        positions: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Call the dense backend, passing ``positions`` only if it accepts them."""
        kwargs = {"topk_indices": topk_indices}
        if self._call_accepts_positions(self.dense_backend.forward):
            kwargs["positions"] = positions
        return self.dense_backend.forward(
            q,
            compressed_kv,
            k_pe,
            kv_cache,
            layer_id,
            **kwargs,
        )

    def forward(
        self,
        q: torch.Tensor,
        compressed_kv: torch.Tensor,
        k_pe: torch.Tensor,
        kv_cache: object,
        layer_id: int,
        topk_indices: Optional[torch.Tensor] = None,
        positions: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run sparse attention when possible, dense attention otherwise.

        Raises:
            ValueError: ``topk_indices`` is given but malformed.
            _SparseUnavailable: the sparse path is unavailable and strict mode
                is enabled.
        """
        if topk_indices is None:
            return self._dense_forward(
                q, compressed_kv, k_pe, kv_cache, layer_id, None, positions
            )

        self._validate_topk_indices(q, topk_indices)
        mock_disabled = self._default_mock and not self._enable_sparse_mock()
        if mock_disabled or not callable(getattr(self.sparse_impl, "forward", None)):
            return self._dense_forward(
                q, compressed_kv, k_pe, kv_cache, layer_id, topk_indices, positions
            )

        kwargs = {
            "topk_indices": topk_indices,
            "attn_metadata": self._get_attn_metadata(),
        }
        if self._impl_accepts_positions(self.sparse_impl):
            kwargs["positions"] = positions
        try:
            return self.sparse_impl.forward(
                q,
                compressed_kv,
                k_pe,
                kv_cache,
                layer_id,
                **kwargs,
            )
        except _SparseUnavailable:
            if self._strict_sparse():
                raise
            return self._dense_forward(
                q, compressed_kv, k_pe, kv_cache, layer_id, topk_indices, positions
            )
ATOMGlm5Moe, []) + _model_factory["glm_5"] = ATOMGlm5Moe + _hf_architecture_2_ft["GlmMoeDsaForCausalLM"] = "glm_5" + + _register_atom_qwen35_moe() +_register_atom_glm5_moe() diff --git a/atom/plugin/rtpllm/models/glm5.py b/atom/plugin/rtpllm/models/glm5.py new file mode 100644 index 0000000000..aa9236d893 --- /dev/null +++ b/atom/plugin/rtpllm/models/glm5.py @@ -0,0 +1,520 @@ +"""GLM5 wrapper for rtp-llm external model loading.""" + +from __future__ import annotations + +import logging +import os +from typing import Any + +import torch +from rtp_llm.config.model_config import ModelConfig +from rtp_llm.model_loader.model_weight_info import ModelWeights +from rtp_llm.models.deepseek_v2 import DeepSeekV2 +from rtp_llm.models_py.model_desc.module_base import GptModelBase +from rtp_llm.ops import ParallelismConfig +from rtp_llm.ops.compute_ops import PyModelInputs, PyModelOutputs +from rtp_llm.utils.model_weight import W + +logger = logging.getLogger("atom.plugin.rtpllm.models") + +# Patched in tests; lazily imported in runtime to keep module import lightweight. 
+RTPForwardContext = None + + +class _NoopWeightManager: + def update(self, req): # noqa: ANN001 + return None + + +class _NoopModelWeightsLoader: + _py_eplb = None + + def load_lora_weights(self, adapter_name, lora_path, device): # noqa: ANN001 + logger.warning( + "No-op model_weights_loader received load_lora_weights(%s, %s, %s); " + "external plugin mode uses ATOM model weights path only.", + adapter_name, + lora_path, + device, + ) + return None + + +class _ATOMGlm5AttnPyObj: + """Minimal attention object so RTP does not build native MLA fmha_impl.""" + + is_cuda_graph = False + + @property + def fmha_params(self): + return None + + def prepare_cuda_graph(self, attn_inputs) -> None: # noqa: ANN001 + return None + + +class _ATOMGlm5MoeRuntime(GptModelBase): + """rtp-llm runtime adapter backed by an ATOM GLM5 model.""" + + def __init__( + self, + model_config: ModelConfig, + parallelism_config: ParallelismConfig, + weights: ModelWeights, + max_generate_batch_size: int, + atom_model: Any, + fmha_config=None, + py_hw_kernel_config=None, + device_resource_config=None, + ) -> None: + super().__init__( + model_config, + parallelism_config, + weights, + max_generate_batch_size=max_generate_batch_size, + fmha_config=fmha_config, + py_hw_kernel_config=py_hw_kernel_config, + device_resource_config=device_resource_config, + ) + self.model = atom_model + first_param = next(iter(self.model.parameters()), None) + if first_param is not None: + self._model_device = first_param.device + self._model_dtype = first_param.dtype + else: + self._model_device = torch.device("cpu") + self._model_dtype = torch.get_default_dtype() + forward_context_cls = self._get_forward_context_cls() + self._rtp_layer_maps = forward_context_cls.collect_layer_maps(model=self.model) + self._rtp_kv_cache_data: dict | None = None + self._rtp_kv_cache_signature: tuple | None = None + self._rtp_layer_group_map: dict[int, int] | None = None + self._rtp_layer_group_map_signature: tuple | None = None + 
self._cg_max_seq_len: int = int( + getattr(model_config, "max_seq_len", 0) + or getattr(model_config, "max_position_embeddings", 0) + or 32768 + ) + self._atom_attn_pyobj: _ATOMGlm5AttnPyObj | None = None + + def load_weights(self): + return None + + def prepare_fmha_impl( + self, inputs: PyModelInputs, is_cuda_graph: bool = False + ) -> _ATOMGlm5AttnPyObj: + if self._atom_attn_pyobj is None: + self._atom_attn_pyobj = _ATOMGlm5AttnPyObj() + self._atom_attn_pyobj.is_cuda_graph = bool(is_cuda_graph) + if bool(is_cuda_graph): + inputs.attention_inputs.is_cuda_graph = True + return self._atom_attn_pyobj + + @staticmethod + def _get_forward_context_cls(): + global RTPForwardContext + if RTPForwardContext is None: + from atom.plugin.rtpllm.utils import RTPForwardContext as _RTPForwardContext + + RTPForwardContext = _RTPForwardContext + return RTPForwardContext + + def _get_model_device(self) -> torch.device: + return self._model_device + + def _get_model_dtype(self) -> torch.dtype: + return self._model_dtype + + def _get_token_num( + self, inputs: PyModelInputs, input_ids: torch.Tensor | None + ) -> int: + if input_ids is not None and input_ids.numel() > 0: + return int(input_ids.numel()) + input_hiddens = getattr(inputs, "input_hiddens", None) + if input_hiddens is not None and input_hiddens.numel() > 0: + return int(input_hiddens.shape[0]) + return 0 + + @staticmethod + def _build_token_positions( + input_lengths: torch.Tensor, + starts: torch.Tensor, + ) -> torch.Tensor | None: + token_starts = torch.repeat_interleave(starts, input_lengths) + if token_starts.numel() == 0: + return None + per_seq_base = input_lengths.cumsum(dim=0) - input_lengths + token_ordinal = ( + torch.cumsum( + torch.repeat_interleave(torch.ones_like(input_lengths), input_lengths), + dim=0, + ) + - 1 + ) + token_ordinal = token_ordinal - torch.repeat_interleave( + per_seq_base, input_lengths + ) + return (token_starts + token_ordinal).to(dtype=torch.int32).contiguous() + + def 
_build_positions_from_attention_inputs( + self, attn_inputs: Any, model_device: torch.device + ) -> torch.Tensor | None: + if attn_inputs is None: + return None + + input_lengths = getattr(attn_inputs, "input_lengths", None) + if input_lengths is None or input_lengths.numel() == 0: + return None + input_lengths_i32 = input_lengths.to( + device=model_device, dtype=torch.int32, non_blocking=True + ).contiguous() + + is_prefill = bool(getattr(attn_inputs, "is_prefill", False)) + if is_prefill: + prefix_lengths = getattr(attn_inputs, "prefix_lengths", None) + if prefix_lengths is None or prefix_lengths.numel() == 0: + return None + prefix_lengths_i32 = prefix_lengths.to( + device=model_device, dtype=torch.int32, non_blocking=True + ).contiguous() + if int(prefix_lengths_i32.numel()) < int(input_lengths_i32.numel()): + return None + starts = prefix_lengths_i32[: int(input_lengths_i32.numel())] + return self._build_token_positions(input_lengths_i32, starts) + + sequence_lengths = getattr(attn_inputs, "sequence_lengths", None) + if sequence_lengths is None or sequence_lengths.numel() == 0: + return None + sequence_lengths_i32 = sequence_lengths.to( + device=model_device, dtype=torch.int32, non_blocking=True + ).contiguous() + if int(sequence_lengths_i32.numel()) < int(input_lengths_i32.numel()): + return None + starts = sequence_lengths_i32[: int(input_lengths_i32.numel())] - input_lengths_i32 + 1 + return self._build_token_positions(input_lengths_i32, starts) + + def _extract_combo_positions( + self, inputs: PyModelInputs, model_device: torch.device + ) -> torch.Tensor | None: + bert_inputs = getattr(inputs, "bert_embedding_inputs", None) + if bert_inputs is None: + return None + combo_position_ids = getattr(bert_inputs, "combo_position_ids", None) + if combo_position_ids is None or combo_position_ids.numel() == 0: + return None + return combo_position_ids.to( + device=model_device, dtype=torch.long, non_blocking=True + ).contiguous() + + def _extract_positions( + self, 
inputs: PyModelInputs, model_device: torch.device, token_num: int + ) -> torch.Tensor: + attn_inputs = getattr(inputs, "attention_inputs", None) + if attn_inputs is None: + raise ValueError( + "GLM5 RTP plugin requires inputs.attention_inputs to provide position metadata." + ) + positions = getattr(attn_inputs, "position_ids", None) + if positions is None or positions.numel() == 0: + positions = self._extract_combo_positions( + inputs=inputs, model_device=model_device + ) + if positions is None or positions.numel() == 0: + positions = self._build_positions_from_attention_inputs( + attn_inputs=attn_inputs, + model_device=model_device, + ) + if positions is None or positions.numel() == 0: + raise ValueError( + "GLM5 RTP plugin requires real position metadata from attention_inputs." + ) + positions = positions.to( + device=model_device, dtype=torch.long, non_blocking=True + ).contiguous() + if not torch.cuda.is_current_stream_capturing(): + pos_tokens = int(positions.shape[-1]) if positions.dim() > 0 else int(positions.numel()) + if token_num > 0 and pos_tokens != token_num: + rebuilt_positions = self._build_positions_from_attention_inputs( + attn_inputs=attn_inputs, + model_device=model_device, + ) + rebuilt_tokens = ( + int(rebuilt_positions.shape[-1]) + if rebuilt_positions is not None and rebuilt_positions.dim() > 0 + else ( + int(rebuilt_positions.numel()) + if rebuilt_positions is not None + else -1 + ) + ) + if rebuilt_positions is not None and rebuilt_tokens == token_num: + positions = rebuilt_positions.to( + device=model_device, dtype=torch.long, non_blocking=True + ).contiguous() + elif pos_tokens > token_num: + positions = positions[..., -token_num:].contiguous() + else: + raise ValueError( + "GLM5 RTP plugin position_ids/token_num mismatch " + f"(position_ids_tokens={pos_tokens}, token_num={token_num})." 
+ ) + return positions + + def forward(self, inputs: PyModelInputs, fmha_impl=None) -> PyModelOutputs: # noqa: ANN001 + if bool(getattr(fmha_impl, "is_cuda_graph", False)): + inputs.attention_inputs.is_cuda_graph = True + model_device = self._get_model_device() + model_dtype = self._get_model_dtype() + input_ids = inputs.input_ids + inputs_embeds = None + + if ( + input_ids is not None + and input_ids.numel() > 0 + and input_ids.device != model_device + ): + input_ids = input_ids.to(device=model_device, non_blocking=True) + token_num = self._get_token_num(inputs=inputs, input_ids=input_ids) + positions = self._extract_positions( + inputs=inputs, model_device=model_device, token_num=token_num + ) + if input_ids is None or input_ids.numel() == 0: + inputs_embeds = inputs.input_hiddens + if ( + inputs_embeds is not None + and inputs_embeds.numel() > 0 + and inputs_embeds.device != model_device + ): + inputs_embeds = inputs_embeds.to(device=model_device, non_blocking=True) + if ( + inputs_embeds is not None + and inputs_embeds.numel() > 0 + and inputs_embeds.dtype != model_dtype + ): + inputs_embeds = inputs_embeds.to(dtype=model_dtype) + + forward_context_cls = self._get_forward_context_cls() + with forward_context_cls.bind( + model=self.model, + runtime=self, + inputs=inputs, + positions=positions, + layer_maps=self._rtp_layer_maps, + cg_max_seq_len=int(self._cg_max_seq_len), + cg_bufs=getattr(self, "_cg_meta_bufs", None), + ): + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=None, + inputs_embeds=inputs_embeds, + ) + return PyModelOutputs(hidden_states) + + +class ATOMGlm5Moe(DeepSeekV2): + """GLM5 model class that starts ATOM runtime in rtp-llm plugin mode.""" + + @staticmethod + def _is_external_plugin_mode() -> bool: + modules = os.getenv("RTP_LLM_EXTERNAL_MODEL_PACKAGES", "") + return "atom.plugin.rtpllm.models" in modules + + def support_cuda_graph(self) -> bool: + if os.getenv("ENABLE_CUDA_GRAPH", "1") == "0": 
+ logger.info("ENABLE_CUDA_GRAPH=0 - ATOMGlm5Moe forces eager forward.") + return False + return True + + @staticmethod + def _make_glm5_hf_mapper(): + from atom.model_loader.loader import WeightsMapper + + return WeightsMapper( + orig_to_new_prefix={}, + orig_to_new_substr={ + "indexers_proj.": "indexer.weights_proj.", + }, + ) + + @staticmethod + def _get_named_parameters(atom_model: Any) -> dict[str, torch.Tensor]: + if atom_model is None or not hasattr(atom_model, "named_parameters"): + return {} + return { + name: param + for name, param in atom_model.named_parameters(recurse=True) + if param is not None + } + + @staticmethod + def _first_param( + params: dict[str, torch.Tensor], candidates: tuple[str, ...] + ) -> torch.Tensor | None: + for name in candidates: + param = params.get(name) + if param is not None: + return param + return None + + def _inject_rtp_projection_weights(self, atom_model: Any) -> None: + params = self._get_named_parameters(atom_model) + if not params: + logger.warning( + "Skip GLM5 RTP projection weight injection because atom_model has no named parameters." 
+ ) + return + + required = { + W.lm_head: ( + "language_model.lm_head.weight", + "lm_head.weight", + ), + W.embedding: ( + "language_model.model.embed_tokens.weight", + "model.embed_tokens.weight", + ), + W.final_ln_gamma: ( + "language_model.model.norm.weight", + "model.norm.weight", + ), + } + missing = [] + for weight_name, candidates in required.items(): + param = self._first_param(params, candidates) + if param is None: + missing.append((weight_name, candidates)) + continue + self.weight.set_global_weight(weight_name, param.detach()) + logger.info( + "Injected GLM5 runtime %s for RTP: %s", + weight_name, + tuple(param.shape), + ) + if missing: + details = ", ".join( + f"{weight_name} candidates={candidates}" + for weight_name, candidates in missing + ) + raise ValueError(f"Cannot locate GLM5 RTP runtime projection weights: {details}") + + def _assert_norm_weights_loaded(self, atom_model: Any) -> None: + params = self._get_named_parameters(atom_model) + if not params: + logger.warning( + "Skip GLM5 norm weight validation because atom_model has no named parameters." + ) + return + norm_w = self._first_param( + params, + ( + "language_model.model.layers.0.input_layernorm.weight", + "model.layers.0.input_layernorm.weight", + ), + ) + if norm_w is None: + raise ValueError( + "Cannot locate GLM5 layer-0 input_layernorm.weight after ATOM load in RTP plugin mode." + ) + norm_w_cpu = norm_w.detach().float().reshape(-1).cpu() + if norm_w_cpu.numel() == 0 or bool(torch.all(norm_w_cpu == 0)): + raise ValueError( + "Loaded GLM5 layer-0 input_layernorm.weight is all zeros; " + "refusing to run with default values." 
+ ) + + def load(self, skip_python_model: bool = False): + if self._is_external_plugin_mode(): + self.device = self._get_device_str() + self.weight = ModelWeights( + num_layers=self.model_config.num_layers, + device=self.device, + dtype=self.model_config.compute_dtype, + ) + self.model_weights_loader = _NoopModelWeightsLoader() + self.py_eplb = self.model_weights_loader._py_eplb + self.weight_manager = _NoopWeightManager() + if skip_python_model: + logger.info( + "External plugin mode: skip ATOM GLM5 python model creation as requested" + ) + return + self._create_python_model() + logger.info( + "External plugin mode: use ATOM GLM5 loading path and skip native load" + ) + return + + super().load(skip_python_model=skip_python_model) + + def _create_python_model(self): + if not self._is_external_plugin_mode(): + return super()._create_python_model() + + import atom + from atom.model_loader.loader import load_model_in_plugin_mode + + target_device = torch.device( + self.device if getattr(self, "device", None) else "cuda" + ) + target_dtype = self.model_config.compute_dtype + old_default_dtype = torch.get_default_dtype() + try: + old_default_device = torch.get_default_device() + except Exception: + old_default_device = None + + torch.set_default_device(target_device) + if target_dtype in { + torch.float16, + torch.bfloat16, + torch.float32, + torch.float64, + }: + torch.set_default_dtype(target_dtype) + + try: + atom_model = atom.prepare_model(config=self, engine="rtpllm") + if atom_model is None: + raise ValueError("ATOM failed to create GLM5 model for rtp-llm plugin") + + if hasattr(atom_model, "to"): + atom_model = atom_model.to(target_device) + + atom_config = getattr(atom_model, "atom_config", None) + if atom_config is None: + atom_config = getattr( + getattr(atom_model, "model", None), "atom_config", None + ) + if atom_config is None: + # M0 tests use mocked ATOM models; real loading must expose atom_config. 
+ atom_config = getattr(self, "atom_config", None) + + load_model_in_plugin_mode( + model=atom_model, + config=atom_config, + prefix="model.", + weights_mapper=self._make_glm5_hf_mapper(), + ) + self._assert_norm_weights_loaded(atom_model) + self._inject_rtp_projection_weights(atom_model) + finally: + torch.set_default_dtype(old_default_dtype) + if old_default_device is not None: + torch.set_default_device(old_default_device) + else: + torch.set_default_device("cpu") + + self.py_model = _ATOMGlm5MoeRuntime( + model_config=self.model_config, + parallelism_config=self.parallelism_config, + weights=self.weight, + max_generate_batch_size=self.max_generate_batch_size, + fmha_config=self.fmha_config, + py_hw_kernel_config=self.hw_kernel_config, + device_resource_config=self.device_resource_config, + atom_model=atom_model, + ) + logger.info("Created ATOM GLM5 runtime for rtp-llm plugin mode") + return self.py_model + diff --git a/atom/plugin/rtpllm/utils/forward_context.py b/atom/plugin/rtpllm/utils/forward_context.py index 2566966d97..265b4612dd 100644 --- a/atom/plugin/rtpllm/utils/forward_context.py +++ b/atom/plugin/rtpllm/utils/forward_context.py @@ -1,11 +1,13 @@ from __future__ import annotations import os +import warnings from contextlib import contextmanager from dataclasses import dataclass from typing import Any, Dict, Iterator, Tuple import torch +from aiter import dtypes from atom.config import KVCacheTensor, get_current_atom_config from atom.model_ops.attention_gdn import GatedDeltaNet @@ -73,7 +75,9 @@ class RTPForwardContext: layer_group_map: Dict[int, int] context: Context num_tokens: int - LayerMaps = tuple[Dict[int, GatedDeltaNet], Dict[int, Any]] + mla_layer_map: Dict[int, Any] + use_rtp_indexer_cache: bool = False + LayerMaps = tuple[Dict[int, GatedDeltaNet], Dict[int, Any], Dict[int, Any]] @staticmethod def _non_empty_int32( @@ -305,6 +309,22 @@ def _select_block_table_for_layer( return by_group[gid] return getattr(attn_inputs, 
"kv_cache_kernel_block_id_device", None) + @staticmethod + def _select_physical_block_table_for_layer( + attn_inputs: Any, + group_id: int | None = None, + ) -> torch.Tensor | None: + # MLA cache writes use concat_and_cache_mla(slot_mapping), whose slot is + # indexed in the physical KV cache layout, not the smaller kernel block + # granularity used by some RTP attention kernels. + block_table = getattr(attn_inputs, "kv_cache_block_id_device", None) + if block_table is not None: + return block_table + return RTPForwardContext._select_block_table_for_layer( + attn_inputs=attn_inputs, + group_id=group_id, + ) + @staticmethod def _build_layer_group_map(attn_inputs: Any) -> Dict[int, int]: layer_to_group = getattr(attn_inputs, "kv_cache_layer_to_group", None) @@ -646,18 +666,7 @@ def _build_slot_mapping( "RTP plugin block_table/query_start_loc batch mismatch " f"(block_table={int(bt.shape[0])}, batch={batch_size})." ) - validate_slot_mapping = os.getenv("ATOM_VALIDATE_SLOT_MAPPING", "0") == "1" - if validate_slot_mapping and int(qsl[-1].item()) != num_tokens: - raise ValueError( - "RTP plugin query_start_loc/positions token mismatch " - f"(query_start_loc[-1]={int(qsl[-1].item())}, positions={num_tokens})." - ) - lengths = qsl[1:] - qsl[:-1] - if validate_slot_mapping and torch.any(lengths <= 0): - raise ValueError( - "RTP plugin query_start_loc contains non-positive sequence length." - ) if in_capture and cg_bufs is not None: # Zero-alloc path: use pre-allocated buffers so captured GPU ops # reference stable addresses that stay alive through replay. @@ -694,29 +703,14 @@ def _build_slot_mapping( torch.arange(batch_size, device=device, dtype=torch.int64), lengths.to(dtype=torch.int64), ) - if validate_slot_mapping and int(seq_id.numel()) != num_tokens: - raise ValueError( - "RTP plugin internal seq_id construction mismatch for slot_mapping." 
- ) block_col = torch.div( pos_i32, int(seq_size_per_block), rounding_mode="floor", ) - if validate_slot_mapping and ( - torch.any(block_col < 0) or torch.any(block_col >= bt.shape[1]) - ): - raise ValueError( - "RTP plugin block-table index out of range for full-attn slot_mapping " - f"(max_col={int(bt.shape[1]) - 1})." - ) slot_base = bt[seq_id, block_col.to(dtype=torch.int64)] - if validate_slot_mapping and torch.any(slot_base < 0): - raise ValueError( - "RTP plugin resolved padded/invalid (-1) block slot for full-attn slot_mapping." - ) token_offset = torch.remainder(pos_i32, int(seq_size_per_block)) slot_mapping = slot_base * int(seq_size_per_block) + token_offset return slot_mapping.to(dtype=torch.int64).contiguous() @@ -790,21 +784,84 @@ def _build_query_start_loc_for_plugin( f"(batch={batch_size}, num_tokens={int(num_tokens)})." ) + @staticmethod + def _build_req_id_per_token( + *, + query_start_loc: torch.Tensor, + num_tokens: int, + device: torch.device, + cg_bufs: dict | None = None, + ) -> torch.Tensor: + batch_size = int(query_start_loc.numel()) - 1 + if batch_size <= 0: + raise ValueError( + "RTP plugin cannot build req_id_per_token for empty batch." + ) + if int(num_tokens) == 0: + return torch.empty((0,), dtype=torch.int32, device=device) + if cg_bufs is not None and "seq_id" in cg_bufs: + seq_id = cg_bufs["seq_id"][:num_tokens] + return seq_id.to( + device=device, dtype=torch.int32, non_blocking=True + ).contiguous() + lengths = (query_start_loc[1:] - query_start_loc[:-1]).to(dtype=torch.int64) + if not torch.cuda.is_current_stream_capturing() and int( + lengths.sum().item() + ) != int(num_tokens): + raise ValueError( + "RTP plugin query_start_loc/num_tokens mismatch for req_id_per_token " + f"(query_start_loc[-1]={int(query_start_loc[-1].item())}, " + f"num_tokens={int(num_tokens)})." 
+ ) + return torch.repeat_interleave( + torch.arange(batch_size, device=device, dtype=torch.int32), + lengths, + ).contiguous() + + @staticmethod + def _expand_block_table_for_atom_indexer( + block_table: torch.Tensor, + *, + seq_size_per_block: int, + kernel_seq_size_per_block: int, + ) -> torch.Tensor: + if ( + kernel_seq_size_per_block <= 0 + or seq_size_per_block <= 0 + or seq_size_per_block == kernel_seq_size_per_block + ): + return block_table + if seq_size_per_block % kernel_seq_size_per_block != 0: + raise ValueError( + "RTP plugin cannot expand block_table for ATOM indexer: " + f"seq_size_per_block={seq_size_per_block}, " + f"kernel_seq_size_per_block={kernel_seq_size_per_block}." + ) + block_ratio = int(seq_size_per_block // kernel_seq_size_per_block) + offsets = torch.arange( + block_ratio, device=block_table.device, dtype=torch.int32 + ) + base = block_table.to(dtype=torch.int32) + expanded = base.unsqueeze(-1) * block_ratio + offsets + expanded = torch.where(base.unsqueeze(-1) >= 0, expanded, -1) + return expanded.reshape(base.shape[0], base.shape[1] * block_ratio).contiguous() + @staticmethod def _build_plugin_attention_metadata( *, attn_inputs: Any, positions: torch.Tensor, seq_size_per_block: int, + kernel_seq_size_per_block: int = 0, cg_max_seq_len: int = 0, cg_bufs: dict | None = None, ) -> AttentionMetaData: - block_table = RTPForwardContext._select_block_table_for_layer( + block_table = RTPForwardContext._select_physical_block_table_for_layer( attn_inputs=attn_inputs, ) if block_table is None or block_table.numel() == 0: raise ValueError( - "RTP plugin requires kv_cache_kernel_block_id_device for plugin attention metadata." + "RTP plugin requires kv_cache_block_id_device for plugin attention metadata." 
) device = positions.device is_prefill = bool(getattr(attn_inputs, "is_prefill", False)) @@ -853,6 +910,12 @@ def _build_plugin_attention_metadata( seq_size_per_block=seq_size_per_block, cg_bufs=cg_bufs, ) + req_id_per_token = RTPForwardContext._build_req_id_per_token( + query_start_loc=query_start_loc, + num_tokens=num_actual_tokens, + device=device, + cg_bufs=cg_bufs if in_capture else None, + ) is_dummy_warmup = False if in_capture: @@ -932,6 +995,16 @@ def _build_plugin_attention_metadata( block_table_i32 = block_table.to( device=device, dtype=torch.int32, non_blocking=True ).contiguous() + if in_capture: + indexer_block_table_i32 = block_table_i32 + else: + indexer_block_table_i32 = ( + RTPForwardContext._expand_block_table_for_atom_indexer( + block_table_i32, + seq_size_per_block=int(seq_size_per_block), + kernel_seq_size_per_block=int(kernel_seq_size_per_block), + ) + ) plugin_md = AiterFlashAttentionMetadataForPluginMode( num_actual_tokens=num_actual_tokens, num_actual_kv_tokens=num_actual_kv_tokens, @@ -957,6 +1030,28 @@ def _build_plugin_attention_metadata( ) # Prefill-only fields shared across all full-attn layers in the step. 
plugin_md.rtp_cu_seqlens_q = query_start_loc + plugin_md.req_id_per_token = req_id_per_token + plugin_md.topk_tokens = 0 + plugin_md.sparse_block_size = int(seq_size_per_block) + cu_seqlen_ks = None + cu_seqlen_ke = None + if is_prefill: + prefill_lengths = (query_start_loc[1:] - query_start_loc[:-1]).to( + dtype=torch.int64 + ) + if in_capture and cg_bufs is not None and "seq_id" in cg_bufs: + seq_id_for_span = cg_bufs["seq_id"][:num_actual_tokens] + else: + seq_id_for_span = torch.repeat_interleave( + torch.arange(batch_size, device=device, dtype=torch.int64), + prefill_lengths, + ) + cu_seqlen_ks = ( + query_start_loc[:-1][seq_id_for_span].to(dtype=torch.int32).contiguous() + ) + cu_seqlen_ke = ( + torch.arange(num_actual_tokens, device=device, dtype=torch.int32) + 1 + ).contiguous() # Mark dummy probe (RTP initCapture's "forward for output datatype" feeds # all-zero seq_lens/block_tables); RTPFullAttention short-circuits to zeros. plugin_md.is_dummy_warmup = bool(is_dummy_warmup) @@ -973,11 +1068,17 @@ def _build_plugin_attention_metadata( else: plugin_md.rtp_has_prefix = False attn_metadata = AttentionMetaData( + cu_seqlens_q=query_start_loc, + cu_seqlens_k=query_start_loc, max_seqlen_q=max_query_len, max_seqlen_k=max_seq_len, - block_tables=plugin_md.block_table, + block_tables=indexer_block_table_i32, slot_mapping=slot_mapping, context_lens=seq_lens, + cu_seqlen_ks=cu_seqlen_ks, + cu_seqlen_ke=cu_seqlen_ke, + has_cached=False, + total_kv=int(num_actual_kv_tokens), ) attn_metadata.plugin_metadata = plugin_md return attn_metadata @@ -986,7 +1087,17 @@ def _build_plugin_attention_metadata( def collect_layer_maps(model: Any) -> LayerMaps: gdn_layer_map: Dict[int, GatedDeltaNet] = {} full_attn_layer_map: Dict[int, Any] = {} + mla_layer_map: Dict[int, Any] = {} rtp_attention_cls: type[Any] | None = None + rtp_mla_attention_cls: type[Any] | None = None + try: + from atom.plugin.rtpllm.attention_backend.rtp_mla_attention import ( + RTPMLAAttention, + ) + + 
rtp_mla_attention_cls = RTPMLAAttention + except (ImportError, ModuleNotFoundError): + rtp_mla_attention_cls = None try: from atom.plugin.rtpllm.attention_backend import AttentionForRTPLLM @@ -997,6 +1108,20 @@ def collect_layer_maps(model: Any) -> LayerMaps: for module in model.modules(): if isinstance(module, GatedDeltaNet): gdn_layer_map[int(module.layer_num)] = module + elif ( + getattr(module, "indexer", None) is not None + and getattr(module, "mla_attn", None) is not None + and getattr(module, "layer_num", None) is not None + ): + mla_layer_map[int(module.layer_num)] = module + elif rtp_mla_attention_cls is not None and isinstance( + module, rtp_mla_attention_cls + ): + layer_num = getattr(module, "layer_id", None) + if layer_num is None: + layer_num = getattr(module, "layer_num", None) + if layer_num is not None and int(layer_num) not in mla_layer_map: + mla_layer_map[int(layer_num)] = module elif isinstance(module, (PagedAttention, PagedAttentionImpl)) or ( rtp_attention_cls is not None and isinstance(module, rtp_attention_cls) ): @@ -1006,7 +1131,7 @@ def collect_layer_maps(model: Any) -> LayerMaps: layer_num = getattr(module, "layer_num", None) if layer_num is not None: full_attn_layer_map[int(layer_num)] = module - return gdn_layer_map, full_attn_layer_map + return gdn_layer_map, full_attn_layer_map, mla_layer_map @staticmethod def _build_kv_cache_tensors( @@ -1016,9 +1141,9 @@ def _build_kv_cache_tensors( if runtime.kv_cache is None: raise ValueError("RTP plugin requires initialized kv_cache for ATOM model.") - gdn_layer_map, full_attn_layer_map = layer_maps + gdn_layer_map, full_attn_layer_map, mla_layer_map = layer_maps - if not gdn_layer_map and not full_attn_layer_map: + if not gdn_layer_map and not full_attn_layer_map and not mla_layer_map: return {} cache_tensors: Dict[str, KVCacheTensor] = {} @@ -1109,6 +1234,31 @@ def _build_kv_cache_tensors( k_scale=None, v_scale=None, ) + # Build MLA cache references separately from full attention. 
MLA adapters + # own their kv_cache pointer and refresh it in bind() for every forward. + for layer_num in mla_layer_map.keys(): + layer_key = f"layer_{layer_num}" + if layer_key in cache_tensors: + continue + + layer_cache = runtime.kv_cache.get_layer_cache(layer_num) + kv_cache_base = getattr(layer_cache, "kv_cache_base", None) + if kv_cache_base is None: + raise ValueError( + f"Layer {layer_num} kv_cache_base is missing for MLA cache." + ) + if kv_cache_base.dim() < 1: + raise ValueError( + f"Layer {layer_num} MLA kv_cache_base has invalid shape " + f"{tuple(kv_cache_base.shape)}." + ) + cache_tensors[layer_key] = KVCacheTensor( + layer_num=layer_num, + k_cache=layer_cache, + v_cache=None, + k_scale=None, + v_scale=None, + ) return cache_tensors @staticmethod @@ -1118,10 +1268,12 @@ def _kv_cache_signature( ) -> Tuple[Any, ...]: if runtime.kv_cache is None: return ("no_kv_cache",) - gdn_layer_map, full_attn_layer_map = layer_maps + gdn_layer_map, full_attn_layer_map, mla_layer_map = layer_maps signature: list[Any] = [id(runtime.kv_cache)] all_layer_nums = sorted( - set(gdn_layer_map.keys()) | set(full_attn_layer_map.keys()) + set(gdn_layer_map.keys()) + | set(full_attn_layer_map.keys()) + | set(mla_layer_map.keys()) ) for layer_num in all_layer_nums: layer_cache = runtime.kv_cache.get_layer_cache(layer_num) @@ -1136,6 +1288,16 @@ def _kv_cache_signature( int(kv_cache_base.numel()), ) ) + kv_scale_base = getattr(layer_cache, "kv_scale_base", None) + if kv_scale_base is not None and kv_scale_base.numel() > 0: + signature.append( + ( + int(layer_num), + "scale", + int(kv_scale_base.data_ptr()), + int(kv_scale_base.numel()), + ) + ) return tuple(signature) @classmethod @@ -1194,7 +1356,8 @@ def build( attn_metadata = cls._build_plugin_attention_metadata( attn_inputs=attn_inputs, positions=positions, - seq_size_per_block=kernel_seq_size_per_block, + seq_size_per_block=seq_size_per_block, + kernel_seq_size_per_block=kernel_seq_size_per_block, 
cg_max_seq_len=int(cg_max_seq_len), cg_bufs=cg_bufs, ) @@ -1234,8 +1397,158 @@ def build( layer_group_map=layer_group_map, context=context, num_tokens=int(positions.numel()), + mla_layer_map=resolved_layer_maps[2], + use_rtp_indexer_cache=cls._use_rtp_indexer_cache(), + ) + + @staticmethod + def _use_rtp_indexer_cache() -> bool: + return os.getenv("ATOM_RTP_USE_RTP_INDEXER_CACHE", "0").strip().lower() in { + "1", + "true", + "yes", + "on", + } + + @staticmethod + def _resolve_rtp_indexer_cache( + *, + layer_num: int, + layer_cache: Any, + indexer: Any, + block_size: int, + ) -> torch.Tensor: + kv_scale_base = getattr(layer_cache, "kv_scale_base", None) + if kv_scale_base is None or kv_scale_base.numel() == 0: + raise ValueError( + f"Layer {layer_num} RTP indexer cache requires non-empty kv_scale_base." + ) + if kv_scale_base.dtype == torch.uint8: + kv_scale_base = kv_scale_base.view(dtypes.fp8) + if kv_scale_base.dtype != dtypes.fp8: + raise ValueError( + f"Layer {layer_num} RTP indexer cache dtype mismatch " + f"(got={kv_scale_base.dtype}, expected={dtypes.fp8} or torch.uint8)." + ) + if block_size <= 0: + raise ValueError( + f"Layer {layer_num} RTP indexer cache got invalid block_size={block_size}." + ) + head_dim = int(getattr(indexer, "head_dim", 0) or 0) + if head_dim <= 0: + raise ValueError( + f"Layer {layer_num} RTP indexer cache requires positive indexer.head_dim." 
+ ) + if head_dim != 128: + warnings.warn( + "RTP indexer cache binding has only been layout-checked for " + "GLM5 head_dim=128; cross-kernel byte semantics are not verified " + f"for head_dim={head_dim}.", + RuntimeWarning, + stacklevel=2, + ) + expected_raw_dim = head_dim + (head_dim // 128) * 4 + expected_aligned_dim = ((expected_raw_dim + 15) // 16) * 16 + allowed_dims = {expected_raw_dim, expected_aligned_dim} + + if kv_scale_base.dim() == 3 and int(kv_scale_base.shape[-1]) in allowed_dims: + return kv_scale_base + if kv_scale_base.dim() == 2 and int(kv_scale_base.shape[1]) % block_size == 0: + per_token_dim = int(kv_scale_base.shape[1]) // block_size + if per_token_dim in allowed_dims: + return kv_scale_base.view( + kv_scale_base.shape[0], block_size, per_token_dim + ) + raise ValueError( + f"Layer {layer_num} RTP indexer cache layout mismatch " + f"(shape={tuple(kv_scale_base.shape)}, block_size={block_size}, " + f"allowed_last_dims={sorted(allowed_dims)})." ) + @staticmethod + def _build_fallback_indexer_cache( + *, + cache_owner: Any, + layer_cache: Any, + indexer: Any, + block_size: int, + ) -> torch.Tensor | None: + kv_cache_base = getattr(layer_cache, "kv_cache_base", None) + if kv_cache_base is None or kv_cache_base.dim() == 0: + return None + index_dim = int(getattr(indexer, "head_dim", 0) or 0) + 4 + if index_dim <= 4: + return None + aligned_dim = ((index_dim + 15) // 16) * 16 + num_tokens = int(kv_cache_base.shape[0]) * block_size + cached = getattr(cache_owner, "_rtp_indexer_kv_cache", None) + expected_shape = (num_tokens, 1, aligned_dim) + if ( + cached is None + or tuple(cached.shape) != expected_shape + or cached.device != kv_cache_base.device + or cached.dtype != dtypes.fp8 + ): + cached = torch.empty( + expected_shape, + device=kv_cache_base.device, + dtype=dtypes.fp8, + ) + setattr(cache_owner, "_rtp_indexer_kv_cache", cached) + return cached + + @staticmethod + def _attach_mla_layer_caches( + forward_context: "RTPForwardContext", + ) -> 
tuple[list[tuple[Any, str, Any]], list[tuple[list[Any], int, Any]]]: + restore_attrs: list[tuple[Any, str, Any]] = [] + restore_indices: list[tuple[list[Any], int, Any]] = [] + for layer_num, layer in forward_context.mla_layer_map.items(): + cache_tensor = forward_context.kv_cache_data.get(f"layer_{layer_num}") + if cache_tensor is None: + continue + cache_owner = getattr(layer, "mla_attn", layer) + restore_attrs.append( + (cache_owner, "kv_cache", getattr(cache_owner, "kv_cache", None)) + ) + cache_owner.kv_cache = cache_tensor.k_cache + indexer = getattr(layer, "indexer", None) + if indexer is None: + indexer = getattr(cache_owner, "indexer", None) + indexer_cache = getattr(indexer, "k_cache", None) + indexer_kv_cache = getattr(indexer_cache, "kv_cache", None) + if not isinstance(indexer_kv_cache, list) or not indexer_kv_cache: + continue + layer_cache = cache_tensor.k_cache + kv_cache_base = getattr(layer_cache, "kv_cache_base", None) + if kv_cache_base is None or kv_cache_base.dim() == 0: + continue + block_size = int( + getattr(forward_context, "rtp_seq_size_per_block", 0) + or getattr(forward_context, "rtp_kernel_seq_size_per_block", 0) + or getattr(get_current_atom_config(), "kv_cache_block_size", 0) + or 1 + ) + if bool(getattr(forward_context, "use_rtp_indexer_cache", False)): + indexer_cache_tensor = RTPForwardContext._resolve_rtp_indexer_cache( + layer_num=layer_num, + layer_cache=layer_cache, + indexer=indexer, + block_size=block_size, + ) + else: + indexer_cache_tensor = RTPForwardContext._build_fallback_indexer_cache( + cache_owner=cache_owner, + layer_cache=layer_cache, + indexer=indexer, + block_size=block_size, + ) + if indexer_cache_tensor is None: + continue + restore_indices.append((indexer_kv_cache, 0, indexer_kv_cache[0])) + indexer_kv_cache[0] = indexer_cache_tensor + return restore_attrs, restore_indices + @classmethod @contextmanager def bind( @@ -1265,8 +1578,16 @@ def bind( attn_md.rtp_kernel_seq_size_per_block = ( 
forward_context.rtp_kernel_seq_size_per_block ) + attn_md.rtp_seq_size_per_block = getattr( + forward_context, "rtp_seq_size_per_block", 0 + ) attn_md.rtp_layer_group_map = forward_context.layer_group_map + restore_mla_attrs: list[tuple[Any, str, Any]] = [] + restore_mla_indices: list[tuple[list[Any], int, Any]] = [] try: + restore_mla_attrs, restore_mla_indices = cls._attach_mla_layer_caches( + forward_context + ) set_kv_cache_data(forward_context.kv_cache_data) set_forward_context( attn_metadata=attn_md, @@ -1276,5 +1597,9 @@ def bind( ) yield finally: + for target, index, old_cache in reversed(restore_mla_indices): + target[index] = old_cache + for target, attr, old_cache in reversed(restore_mla_attrs): + setattr(target, attr, old_cache) reset_forward_context() set_kv_cache_data(prev_kv if prev_kv is not None else {}) diff --git a/tests/plugin/test_rtpllm_forward_context_semantics.py b/tests/plugin/test_rtpllm_forward_context_semantics.py index bd88b5b468..be1f581e01 100644 --- a/tests/plugin/test_rtpllm_forward_context_semantics.py +++ b/tests/plugin/test_rtpllm_forward_context_semantics.py @@ -4,7 +4,9 @@ import types from types import SimpleNamespace +import pytest import torch +from aiter import dtypes class _KwargsObject: @@ -40,10 +42,15 @@ def _install_forward_context_stubs(): utils_forward_context = types.ModuleType("atom.utils.forward_context") utils_forward_context.AttentionMetaData = _KwargsObject utils_forward_context.Context = _KwargsObject - utils_forward_context._forward_kv_cache_context = {} + utils_forward_context._forward_kv_cache_context = SimpleNamespace(kv_cache_data={}) utils_forward_context.reset_forward_context = lambda *args, **kwargs: None utils_forward_context.set_forward_context = lambda *args, **kwargs: None - utils_forward_context.set_kv_cache_data = lambda *args, **kwargs: None + utils_forward_context.get_forward_context = lambda *args, **kwargs: SimpleNamespace() + + def _set_kv_cache_data(value): + 
utils_forward_context._forward_kv_cache_context.kv_cache_data = value + + utils_forward_context.set_kv_cache_data = _set_kv_cache_data sys.modules["atom.utils.forward_context"] = utils_forward_context @@ -59,6 +66,7 @@ def _make_attn_inputs( sequence_lengths=None, sequence_lengths_plus_1_d=None, cu_seqlens=None, + kv_cache_block_id_device=None, kv_cache_kernel_block_id_device=None, is_prefill=False, is_cuda_graph=False, @@ -69,6 +77,7 @@ def _make_attn_inputs( sequence_lengths=sequence_lengths, sequence_lengths_plus_1_d=sequence_lengths_plus_1_d, cu_seqlens=cu_seqlens, + kv_cache_block_id_device=kv_cache_block_id_device, kv_cache_kernel_block_id_device=kv_cache_kernel_block_id_device, is_prefill=is_prefill, is_cuda_graph=is_cuda_graph, @@ -132,6 +141,68 @@ def test_rtpllm_forward_context_decode_metadata_state_indices_shape(): assert md.non_spec_state_indices_tensor.cpu().tolist() == [125] +def test_plugin_attention_metadata_slot_mapping_uses_physical_block_table(): + attn_inputs = _make_attn_inputs( + input_lengths=torch.tensor([1], dtype=torch.int32), + sequence_lengths=torch.tensor([1030], dtype=torch.int32), + kv_cache_block_id_device=torch.tensor([[7, 8]], dtype=torch.int32), + kv_cache_kernel_block_id_device=torch.tensor( + [[700, 701, 702]], dtype=torch.int32 + ), + is_prefill=False, + ) + + md = RTPForwardContext._build_plugin_attention_metadata( + attn_inputs=attn_inputs, + positions=torch.tensor([1029], dtype=torch.int32), + seq_size_per_block=1024, + ) + + assert md.plugin_metadata.block_table.cpu().tolist() == [[7, 8]] + assert md.plugin_metadata.slot_mapping.cpu().tolist() == [8 * 1024 + 5] + + +def test_plugin_attention_metadata_builds_req_id_per_token(): + attn_inputs = _make_attn_inputs( + input_lengths=torch.tensor([2, 1], dtype=torch.int32), + prefix_lengths=torch.tensor([0, 0], dtype=torch.int32), + cu_seqlens=torch.tensor([0, 2, 3], dtype=torch.int32), + kv_cache_block_id_device=torch.tensor([[3], [4]], dtype=torch.int32), + 
kv_cache_kernel_block_id_device=torch.tensor([[30], [40]], dtype=torch.int32), + is_prefill=True, + ) + + md = RTPForwardContext._build_plugin_attention_metadata( + attn_inputs=attn_inputs, + positions=torch.tensor([0, 1, 0], dtype=torch.int32), + seq_size_per_block=1024, + ) + + assert md.plugin_metadata.req_id_per_token.cpu().tolist() == [0, 0, 1] + assert md.plugin_metadata.sparse_block_size == 1024 + assert md.cu_seqlens_q.cpu().tolist() == [0, 2, 3] + assert md.cu_seqlens_k.cpu().tolist() == [0, 2, 3] + assert md.cu_seqlen_ks.cpu().tolist() == [0, 0, 2] + assert md.cu_seqlen_ke.cpu().tolist() == [1, 2, 3] + assert md.total_kv == 3 + + +def test_rtp_indexer_cache_accepts_byte_packed_kv_scale_base(): + kv_scale_base = torch.empty((2, 1024, 132), dtype=torch.uint8) + layer_cache = SimpleNamespace(kv_scale_base=kv_scale_base) + indexer = SimpleNamespace(head_dim=128) + + cache = RTPForwardContext._resolve_rtp_indexer_cache( + layer_num=0, + layer_cache=layer_cache, + indexer=indexer, + block_size=1024, + ) + + assert tuple(cache.shape) == (2, 1024, 132) + assert cache.dtype == dtypes.fp8 + + def test_rtpllm_decode_seq_lens_priority_splits_graph_and_eager_modes(): input_lengths = torch.tensor([1], dtype=torch.int32) sequence_lengths = torch.tensor([35], dtype=torch.int32) @@ -159,3 +230,358 @@ def test_rtpllm_decode_seq_lens_priority_splits_graph_and_eager_modes(): graph_inputs, device=input_lengths.device ) assert graph_seq_lens.cpu().tolist() == [36] + + +def test_collect_layer_maps_keeps_mla_layers_separate(): + from atom.plugin.rtpllm.attention_backend.rtp_mla_attention import RTPMLAAttention + + mla_layer = RTPMLAAttention(dense_backend=object(), layer_num=7) + model = SimpleNamespace(modules=lambda: [mla_layer]) + + gdn_map, full_attn_map, mla_map = RTPForwardContext.collect_layer_maps(model) + + assert gdn_map == {} + assert full_attn_map == {} + assert mla_map == {7: mla_layer} + + +def test_collect_layer_maps_keeps_sparse_mla_owner_for_indexer_cache(): + 
from atom.plugin.rtpllm.attention_backend.rtp_mla_attention import RTPMLAAttention + + mla_layer = RTPMLAAttention(dense_backend=object(), layer_num=7) + sparse_owner = SimpleNamespace( + layer_num=7, + indexer=SimpleNamespace(), + mla_attn=mla_layer, + ) + model = SimpleNamespace(modules=lambda: [sparse_owner, mla_layer]) + + gdn_map, full_attn_map, mla_map = RTPForwardContext.collect_layer_maps(model) + + assert gdn_map == {} + assert full_attn_map == {} + assert mla_map == {7: sparse_owner} + + +def test_collect_layer_maps_recognizes_atom_mla_wrapper_by_indexer_and_mla_attn(): + from atom.plugin.rtpllm.attention_backend.rtp_mla_attention import RTPMLAAttention + + inner_mla = RTPMLAAttention(dense_backend=object(), layer_num=9) + atom_wrapper = SimpleNamespace( + layer_num=9, + indexer=SimpleNamespace(), + mla_attn=inner_mla, + ) + model = SimpleNamespace(modules=lambda: [atom_wrapper]) + + _, _, mla_map = RTPForwardContext.collect_layer_maps(model) + + assert mla_map == {9: atom_wrapper} + + +def test_build_kv_cache_tensors_threads_raw_layer_cache_for_mla(): + from atom.plugin.rtpllm.attention_backend.rtp_mla_attention import RTPMLAAttention + + layer_cache = SimpleNamespace(kv_cache_base=torch.empty(2, 3)) + runtime = SimpleNamespace( + kv_cache=SimpleNamespace(get_layer_cache=lambda layer_num: layer_cache) + ) + mla_layer = RTPMLAAttention(dense_backend=object(), layer_num=7) + + cache_tensors = RTPForwardContext._build_kv_cache_tensors( + runtime=runtime, + layer_maps=({}, {}, {7: mla_layer}), + ) + + assert cache_tensors["layer_7"].layer_num == 7 + assert cache_tensors["layer_7"].k_cache is layer_cache + + +def test_bind_temporarily_attaches_mla_layer_cache(monkeypatch): + from atom.plugin.rtpllm.attention_backend.rtp_mla_attention import RTPMLAAttention + + old_cache = SimpleNamespace(name="old-cache") + new_cache = SimpleNamespace(name="new-cache") + mla_layer = RTPMLAAttention(dense_backend=object(), layer_num=7, kv_cache=old_cache) + forward_context = 
SimpleNamespace( + attn_metadata=SimpleNamespace(), + gdn_metadata=SimpleNamespace(), + rtp_attn_inputs=SimpleNamespace(), + rtp_kernel_seq_size_per_block=16, + layer_group_map={}, + kv_cache_data={"layer_7": SimpleNamespace(k_cache=new_cache)}, + context=SimpleNamespace(), + num_tokens=1, + mla_layer_map={7: mla_layer}, + use_rtp_indexer_cache=False, + ) + + monkeypatch.setattr( + RTPForwardContext, + "build", + classmethod(lambda cls, **kwargs: forward_context), + ) + monkeypatch.setattr( + "atom.plugin.rtpllm.utils.forward_context.get_current_atom_config", + lambda: SimpleNamespace(kv_cache_block_size=99), + ) + + with RTPForwardContext.bind( + model=SimpleNamespace(), + runtime=SimpleNamespace(), + inputs=SimpleNamespace(), + positions=torch.tensor([0], dtype=torch.int32), + ): + assert mla_layer.kv_cache is new_cache + + assert mla_layer.kv_cache is old_cache + + +def test_bind_writes_kv_cache_to_mla_attn_owner_not_outer_wrapper(monkeypatch): + from atom.plugin.rtpllm.attention_backend.rtp_mla_attention import RTPMLAAttention + + outer_cache = SimpleNamespace(name="outer-cache") + old_inner_cache = SimpleNamespace(name="old-inner-cache") + new_cache = SimpleNamespace(kv_cache_base=torch.empty(2, 3)) + indexer = SimpleNamespace( + head_dim=128, + k_cache=SimpleNamespace(kv_cache=[torch.empty(0)]), + ) + mla_layer = RTPMLAAttention( + dense_backend=object(), + layer_num=7, + kv_cache=old_inner_cache, + ) + outer = SimpleNamespace( + layer_num=7, + indexer=indexer, + mla_attn=mla_layer, + kv_cache=outer_cache, + ) + forward_context = SimpleNamespace( + attn_metadata=SimpleNamespace(), + gdn_metadata=SimpleNamespace(), + rtp_attn_inputs=SimpleNamespace(), + rtp_kernel_seq_size_per_block=16, + layer_group_map={}, + kv_cache_data={"layer_7": SimpleNamespace(k_cache=new_cache)}, + context=SimpleNamespace(), + num_tokens=1, + mla_layer_map={7: outer}, + use_rtp_indexer_cache=False, + ) + + monkeypatch.setattr( + RTPForwardContext, + "build", + classmethod(lambda cls, 
**kwargs: forward_context), + ) + + with RTPForwardContext.bind( + model=SimpleNamespace(), + runtime=SimpleNamespace(), + inputs=SimpleNamespace(), + positions=torch.tensor([0], dtype=torch.int32), + ): + assert outer.kv_cache is outer_cache + assert mla_layer.kv_cache is new_cache + + assert outer.kv_cache is outer_cache + assert mla_layer.kv_cache is old_inner_cache + + +def test_bind_temporarily_attaches_sparse_mla_indexer_cache(monkeypatch): + from atom.plugin.rtpllm.attention_backend.rtp_mla_attention import RTPMLAAttention + + old_cache = SimpleNamespace(name="old-cache") + layer_cache = SimpleNamespace(kv_cache_base=torch.empty(2, 3)) + old_index_cache = torch.empty(0) + indexer = SimpleNamespace( + head_dim=128, + k_cache=SimpleNamespace(kv_cache=[old_index_cache]), + ) + mla_layer = RTPMLAAttention( + dense_backend=object(), + layer_num=7, + kv_cache=old_cache, + mla_modules=SimpleNamespace(indexer=indexer), + ) + forward_context = SimpleNamespace( + attn_metadata=SimpleNamespace(), + gdn_metadata=SimpleNamespace(), + rtp_attn_inputs=SimpleNamespace(), + rtp_kernel_seq_size_per_block=16, + layer_group_map={}, + kv_cache_data={"layer_7": SimpleNamespace(k_cache=layer_cache)}, + context=SimpleNamespace(), + num_tokens=1, + mla_layer_map={7: mla_layer}, + use_rtp_indexer_cache=False, + ) + + monkeypatch.setattr( + RTPForwardContext, + "build", + classmethod(lambda cls, **kwargs: forward_context), + ) + monkeypatch.setattr( + "atom.plugin.rtpllm.utils.forward_context.get_current_atom_config", + lambda: SimpleNamespace(kv_cache_block_size=16), + ) + + with RTPForwardContext.bind( + model=SimpleNamespace(), + runtime=SimpleNamespace(), + inputs=SimpleNamespace(), + positions=torch.tensor([0], dtype=torch.int32), + ): + assert mla_layer.kv_cache is layer_cache + assert indexer.k_cache.kv_cache[0] is not old_index_cache + assert indexer.k_cache.kv_cache[0].shape == (32, 1, 144) + + assert mla_layer.kv_cache is old_cache + assert indexer.k_cache.kv_cache[0] is 
old_index_cache + + +def test_bind_uses_rtp_kv_scale_base_when_enabled(monkeypatch): + from atom.plugin.rtpllm.attention_backend.rtp_mla_attention import RTPMLAAttention + + old_cache = SimpleNamespace(name="old-cache") + old_index_cache = torch.empty(0) + kv_scale_base = torch.empty(2, 16, 132, dtype=dtypes.fp8) + layer_cache = SimpleNamespace( + kv_cache_base=torch.empty(2, 3), + kv_scale_base=kv_scale_base, + ) + indexer = SimpleNamespace( + head_dim=128, + k_cache=SimpleNamespace(kv_cache=[old_index_cache]), + ) + mla_layer = RTPMLAAttention( + dense_backend=object(), + layer_num=7, + kv_cache=old_cache, + mla_modules=SimpleNamespace(indexer=indexer), + ) + forward_context = SimpleNamespace( + attn_metadata=SimpleNamespace(), + gdn_metadata=SimpleNamespace(), + rtp_attn_inputs=SimpleNamespace(), + rtp_kernel_seq_size_per_block=16, + layer_group_map={}, + kv_cache_data={"layer_7": SimpleNamespace(k_cache=layer_cache)}, + context=SimpleNamespace(), + num_tokens=1, + mla_layer_map={7: mla_layer}, + use_rtp_indexer_cache=True, + ) + + monkeypatch.setattr( + RTPForwardContext, + "build", + classmethod(lambda cls, **kwargs: forward_context), + ) + + with RTPForwardContext.bind( + model=SimpleNamespace(), + runtime=SimpleNamespace(), + inputs=SimpleNamespace(), + positions=torch.tensor([0], dtype=torch.int32), + ): + assert indexer.k_cache.kv_cache[0] is kv_scale_base + + assert indexer.k_cache.kv_cache[0] is old_index_cache + + +def test_bind_accepts_flattened_rtp_kv_scale_base_when_enabled(monkeypatch): + from atom.plugin.rtpllm.attention_backend.rtp_mla_attention import RTPMLAAttention + + old_index_cache = torch.empty(0) + flat_kv_scale_base = torch.empty(2, 16 * 132, dtype=dtypes.fp8) + layer_cache = SimpleNamespace( + kv_cache_base=torch.empty(2, 3), + kv_scale_base=flat_kv_scale_base, + ) + indexer = SimpleNamespace( + head_dim=128, + k_cache=SimpleNamespace(kv_cache=[old_index_cache]), + ) + mla_layer = RTPMLAAttention( + dense_backend=object(), + layer_num=7, 
+ mla_modules=SimpleNamespace(indexer=indexer), + ) + forward_context = SimpleNamespace( + attn_metadata=SimpleNamespace(), + gdn_metadata=SimpleNamespace(), + rtp_attn_inputs=SimpleNamespace(), + rtp_kernel_seq_size_per_block=16, + layer_group_map={}, + kv_cache_data={"layer_7": SimpleNamespace(k_cache=layer_cache)}, + context=SimpleNamespace(), + num_tokens=1, + mla_layer_map={7: mla_layer}, + use_rtp_indexer_cache=True, + ) + + monkeypatch.setattr( + RTPForwardContext, + "build", + classmethod(lambda cls, **kwargs: forward_context), + ) + + with RTPForwardContext.bind( + model=SimpleNamespace(), + runtime=SimpleNamespace(), + inputs=SimpleNamespace(), + positions=torch.tensor([0], dtype=torch.int32), + ): + assert indexer.k_cache.kv_cache[0].data_ptr() == flat_kv_scale_base.data_ptr() + assert indexer.k_cache.kv_cache[0].shape == (2, 16, 132) + + assert indexer.k_cache.kv_cache[0] is old_index_cache + + +def test_bind_rejects_incompatible_indexer_cache_layout(monkeypatch): + from atom.plugin.rtpllm.attention_backend.rtp_mla_attention import RTPMLAAttention + + layer_cache = SimpleNamespace( + kv_cache_base=torch.empty(2, 3), + kv_scale_base=torch.empty(2, 16, 64, dtype=dtypes.fp8), + ) + indexer = SimpleNamespace( + head_dim=128, + k_cache=SimpleNamespace(kv_cache=[torch.empty(0)]), + ) + mla_layer = RTPMLAAttention( + dense_backend=object(), + layer_num=7, + mla_modules=SimpleNamespace(indexer=indexer), + ) + forward_context = SimpleNamespace( + attn_metadata=SimpleNamespace(), + gdn_metadata=SimpleNamespace(), + rtp_attn_inputs=SimpleNamespace(), + rtp_kernel_seq_size_per_block=16, + layer_group_map={}, + kv_cache_data={"layer_7": SimpleNamespace(k_cache=layer_cache)}, + context=SimpleNamespace(), + num_tokens=1, + mla_layer_map={7: mla_layer}, + use_rtp_indexer_cache=True, + ) + + monkeypatch.setattr( + RTPForwardContext, + "build", + classmethod(lambda cls, **kwargs: forward_context), + ) + + with pytest.raises(ValueError, match="layout mismatch"): + with 
RTPForwardContext.bind( + model=SimpleNamespace(), + runtime=SimpleNamespace(), + inputs=SimpleNamespace(), + positions=torch.tensor([0], dtype=torch.int32), + ): + pass diff --git a/tests/plugin/test_rtpllm_glm5_indexer_contract.py b/tests/plugin/test_rtpllm_glm5_indexer_contract.py new file mode 100644 index 0000000000..bdea3d16b4 --- /dev/null +++ b/tests/plugin/test_rtpllm_glm5_indexer_contract.py @@ -0,0 +1,230 @@ +"""Contract-executable tests for GLM5 RTP MLA M1.5 indexer behavior.""" + +import builtins +import sys +from types import SimpleNamespace + +import torch + +from atom.plugin.rtpllm.attention_backend.rtp_mla_attention import RTPMLAAttention + + +_FORBIDDEN_CUDA_SPARSE_MODULES = ( + "flashmla_sparse", + "flash_mla", + "sparse_mla", + "attention_mla_sparse", +) + + +def _guard_sparse_kernel_imports(monkeypatch): + original_import = builtins.__import__ + + def _guarded_import(name, *args, **kwargs): + if any(part in _FORBIDDEN_CUDA_SPARSE_MODULES for part in name.split(".")): + raise AssertionError(f"M1.5 tests must not import sparse MLA kernels: {name}") + return original_import(name, *args, **kwargs) + + monkeypatch.setattr(builtins, "__import__", _guarded_import) + + +class _FakeDenseBackend: + def __init__(self, v_head_dim: int): + self.v_head_dim = v_head_dim + self.calls = [] + + def forward(self, q, compressed_kv, k_pe, kv_cache, layer_id, topk_indices=None): + self.calls.append( + { + "q": q, + "compressed_kv": compressed_kv, + "k_pe": k_pe, + "kv_cache": kv_cache, + "layer_id": layer_id, + "topk_indices": topk_indices, + } + ) + return q.new_empty((q.shape[0], q.shape[1], self.v_head_dim)) + + +class _FakeIndexer: + def __init__(self, topk_values): + self.calls = [] + self.index_topk = topk_values.shape[1] + self.topk_indices_buffer = torch.full( + (topk_values.shape[0], topk_values.shape[1] + 2), + -1, + dtype=torch.int32, + ) + self.topk_indices_buffer[ + : topk_values.shape[0], : topk_values.shape[1] + ].copy_(topk_values) + self.weights = 
torch.full(topk_values.shape, 99.0, dtype=torch.float32) + + def __call__(self, *args, **kwargs): + self.calls.append((args, kwargs)) + return self.weights + + +class _FakeQProj: + def __init__(self, output): + self.output = output + self.calls = [] + + def __call__(self, query, q_scale=None): + self.calls.append((query, q_scale)) + return self.output + + +class _FakeOProj: + def __init__(self): + self.calls = [] + + def __call__(self, tensor): + self.calls.append(tensor) + return tensor + + +def _make_attention(topk_values): + token_count = topk_values.shape[0] + num_heads = 2 + qk_head_dim = 4 + v_head_dim = 3 + projected_q = torch.arange( + token_count * num_heads * qk_head_dim, dtype=torch.float32 + ).reshape(token_count, num_heads * qk_head_dim) + backend = _FakeDenseBackend(v_head_dim=v_head_dim) + indexer = _FakeIndexer(topk_values) + modules = SimpleNamespace( + q_proj=_FakeQProj(projected_q), + o_proj=_FakeOProj(), + kv_b_proj=object(), + indexer=indexer, + v_head_dim=v_head_dim, + qk_head_dim=qk_head_dim, + num_heads=num_heads, + num_local_heads=num_heads, + index_topk=topk_values.shape[1], + ) + attention = RTPMLAAttention( + mla_modules=modules, + dense_backend=backend, + layer_num=7, + kv_cache="kv-cache", + ) + return attention, modules, backend + + +def test_constructor_injects_indexer_and_topk_indices_buffer_owner_path(): + topk_buffer = torch.tensor([[4, 1, 3, 0]], dtype=torch.int32) + indexer = SimpleNamespace(topk_indices_buffer=topk_buffer, index_topk=4) + modules = SimpleNamespace( + q_proj=object(), + o_proj=object(), + kv_b_proj=object(), + indexer=indexer, + v_head_dim=3, + ) + attention = RTPMLAAttention(mla_modules=modules) + + assert attention.indexer is indexer + assert attention.topk_indices_buffer is topk_buffer + + +def _run_attention(attention, token_count: int): + query = torch.empty(token_count, 6) + compressed_kv = torch.empty(token_count, 8) + k_rope = torch.empty(token_count, 3) + positions = torch.arange(token_count, 
dtype=torch.int32) + return attention.forward( + query, + compressed_kv, + k_rope, + positions=positions, + ) + + +def test_indexer_buffer_topk_is_passed_to_dense_backend_when_emit_allowed(monkeypatch): + _guard_sparse_kernel_imports(monkeypatch) + topk_values = torch.tensor([[4, 1, 3, 0], [2, 0, 1, 3]], dtype=torch.int32) + attention, modules, backend = _make_attention(topk_values) + + _run_attention(attention, token_count=topk_values.shape[0]) + + assert modules.indexer.calls == [] + topk_indices = backend.calls[0]["topk_indices"] + assert topk_indices is not None + assert topk_indices.dtype == torch.int32 + assert topk_indices.shape == topk_values.shape + assert torch.equal(topk_indices, topk_values) + assert topk_indices is not modules.indexer.weights + assert not torch.equal(topk_indices.to(torch.float32), modules.indexer.weights) + + +def _patch_forward_context(monkeypatch, *, is_dummy_run, is_prefill, max_seqlen_k): + forward_context_mod = sys.modules["atom.utils.forward_context"] + + fake_forward_context = SimpleNamespace( + context=SimpleNamespace(is_dummy_run=is_dummy_run, is_prefill=is_prefill), + attn_metadata=SimpleNamespace(max_seqlen_k=max_seqlen_k), + ) + monkeypatch.setattr( + forward_context_mod, + "get_forward_context", + lambda: fake_forward_context, + raising=False, + ) + + +def test_dummy_run_does_not_emit_topk_to_dense_backend(monkeypatch): + _guard_sparse_kernel_imports(monkeypatch) + _patch_forward_context( + monkeypatch, + is_dummy_run=True, + is_prefill=False, + max_seqlen_k=4096, + ) + topk_values = torch.tensor([[4, 1, 3, 0], [2, 0, 1, 3]], dtype=torch.int32) + attention, modules, backend = _make_attention(topk_values) + + _run_attention(attention, token_count=topk_values.shape[0]) + + assert modules.indexer.calls == [] + assert backend.calls[0]["topk_indices"] is None + + +def test_short_prefill_does_not_emit_topk_to_dense_backend(monkeypatch): + _guard_sparse_kernel_imports(monkeypatch) + _patch_forward_context( + monkeypatch, + 
is_dummy_run=False, + is_prefill=True, + max_seqlen_k=4, + ) + topk_values = torch.tensor([[4, 1, 3, 0], [2, 0, 1, 3]], dtype=torch.int32) + attention, modules, backend = _make_attention(topk_values) + + _run_attention(attention, token_count=topk_values.shape[0]) + + assert modules.indexer.calls == [] + assert backend.calls[0]["topk_indices"] is None + + +def test_prefill_within_topk_buffer_padding_does_not_emit_topk(monkeypatch): + _guard_sparse_kernel_imports(monkeypatch) + _patch_forward_context( + monkeypatch, + is_dummy_run=False, + is_prefill=True, + max_seqlen_k=5, + ) + topk_values = torch.tensor([[4, 1, 3, 0], [2, 0, 1, 3]], dtype=torch.int32) + attention, modules, backend = _make_attention(topk_values) + + _run_attention(attention, token_count=topk_values.shape[0]) + + assert modules.indexer.index_topk == 4 + assert modules.indexer.topk_indices_buffer.shape[1] == 6 + assert modules.indexer.calls == [] + assert backend.calls[0]["topk_indices"] is None + diff --git a/tests/plugin/test_rtpllm_glm5_mha_bridge_guard.py b/tests/plugin/test_rtpllm_glm5_mha_bridge_guard.py new file mode 100644 index 0000000000..908b17cffb --- /dev/null +++ b/tests/plugin/test_rtpllm_glm5_mha_bridge_guard.py @@ -0,0 +1,66 @@ +"""Static guards for the GLM5 rtp-llm plugin path.""" + +import ast +from pathlib import Path + + +_ATOM_ROOT = Path(__file__).resolve().parents[2] +_FORBIDDEN_IMPORT_TIME_SPARSE_KERNELS = { + "flashmla_sparse", + "flash_mla", + "sparse_mla", + "attention_mla_sparse", +} + + +def _read_plugin_file(relative_path: str) -> str: + return (_ATOM_ROOT / relative_path).read_text() + + +def test_glm5_wrapper_does_not_use_mha_or_qwen_patches(): + source = _read_plugin_file("atom/plugin/rtpllm/models/glm5.py") + + assert "RTPFullAttention" not in source + assert "apply_attention_mha_rtpllm_patch" not in source + assert "apply_attention_gdn_rtpllm_patch" not in source + assert "apply_qwen3_next_rtpllm_patch" not in source + + +def 
test_glm5_wrapper_does_not_reference_deepseek_mla_patch(): + source = _read_plugin_file("atom/plugin/rtpllm/models/glm5.py") + + assert "apply_deepseek_mla_rtpllm_patch" not in source + + +def test_rtp_mla_prepare_does_not_keep_native_forward_mirror_helpers(): + assert not ( + _ATOM_ROOT / "atom/plugin/rtpllm/attention_backend/rtp_mla_prepare.py" + ).exists() + + +def test_glm5_mla_backend_is_not_full_attention_adapter(): + source = _read_plugin_file("atom/plugin/rtpllm/attention_backend/rtp_mla_attention.py") + + assert "class RTPMLAAttention" in source + assert "use_mla" in source + assert "RTPFullAttention" not in source + + +def test_sparse_mla_backend_has_no_import_time_cuda_sparse_kernel_dependencies(): + backend_path = _ATOM_ROOT / "atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py" + assert backend_path.exists() + + tree = ast.parse(backend_path.read_text()) + imported_modules = set() + for node in ast.walk(tree): + if isinstance(node, ast.Import): + imported_modules.update(alias.name for alias in node.names) + elif isinstance(node, ast.ImportFrom) and node.module is not None: + imported_modules.add(node.module) + + assert not any( + forbidden in module_name.split(".") + for module_name in imported_modules + for forbidden in _FORBIDDEN_IMPORT_TIME_SPARSE_KERNELS + ) + diff --git a/tests/plugin/test_rtpllm_glm5_mla_bridge_shape.py b/tests/plugin/test_rtpllm_glm5_mla_bridge_shape.py new file mode 100644 index 0000000000..cd4c0602c6 --- /dev/null +++ b/tests/plugin/test_rtpllm_glm5_mla_bridge_shape.py @@ -0,0 +1,24 @@ +"""Shape-level tests for the GLM5 RTP MLA bridge.""" + +from types import SimpleNamespace + +import torch + +from atom.plugin.rtpllm.attention_backend.rtp_mla_attention import RTPMLAAttention + + +def test_mla_attention_legacy_boundary_shape_stays_executable_during_migration(): + q = torch.empty(2, 4, 256) + compressed_kv = torch.empty(2, 512) + k_pe = torch.empty(2, 64) + positions = torch.arange(2, dtype=torch.int32) + attention = 
RTPMLAAttention(mla_modules=SimpleNamespace(v_head_dim=128)) + + output = attention(q, compressed_kv, k_pe, positions=positions) + + assert output.shape == (2, 4, 128) + + +def test_mla_attention_is_marked_as_mla_adapter(): + assert RTPMLAAttention.use_mla is True + diff --git a/tests/plugin/test_rtpllm_glm5_mla_forward_contract.py b/tests/plugin/test_rtpllm_glm5_mla_forward_contract.py new file mode 100644 index 0000000000..406f5e88d9 --- /dev/null +++ b/tests/plugin/test_rtpllm_glm5_mla_forward_contract.py @@ -0,0 +1,1053 @@ +"""Contract-executable tests for GLM5 RTP MLA native forward.""" + +import builtins +import importlib +import inspect +from types import SimpleNamespace + +import pytest +import torch + +from atom.plugin.rtpllm.attention_backend.rtp_mla_attention import RTPMLAAttention + + +_FORBIDDEN_CUDA_SPARSE_MODULES = ( + "flashmla_sparse", + "flash_mla", + "sparse_mla", + "attention_mla_sparse", +) + + +class _FakeDenseBackend: + def __init__(self, v_head_dim: int): + self.v_head_dim = v_head_dim + self.calls = [] + + def forward( + self, + q, + compressed_kv, + k_pe, + kv_cache, + layer_id, + topk_indices=None, + positions=None, + ): + self.calls.append( + { + "q": q, + "compressed_kv": compressed_kv, + "k_pe": k_pe, + "kv_cache": kv_cache, + "layer_id": layer_id, + "topk_indices": topk_indices, + "positions": positions, + } + ) + return q.new_zeros((q.shape[0], q.shape[1], self.v_head_dim)) + + +def test_rtp_mla_attention_keeps_legacy_dense_boundary_during_migration(): + backend = _FakeDenseBackend(v_head_dim=16) + attention = RTPMLAAttention(dense_backend=backend, layer_id=7, kv_cache="cache") + q = torch.empty(3, 2, 12, dtype=torch.bfloat16) + compressed_kv = torch.empty(3, 8, dtype=torch.bfloat16) + k_pe = torch.empty(3, 4, dtype=torch.bfloat16) + positions = torch.arange(3, dtype=torch.int32) + + output = attention.forward( + q, + compressed_kv, + k_pe, + positions=positions, + topk_indices=None, + ) + + assert output.shape == (3, 2, 16) + assert 
len(backend.calls) == 1 + call = backend.calls[0] + assert call["q"] is q + assert call["compressed_kv"] is compressed_kv + assert call["k_pe"] is k_pe + assert call["kv_cache"] == "cache" + assert call["layer_id"] == 7 + assert call["topk_indices"] is None + assert call["positions"] is positions + + +def _guard_sparse_kernel_imports(monkeypatch): + original_import = builtins.__import__ + + def _guarded_import(name, *args, **kwargs): + if any(part in _FORBIDDEN_CUDA_SPARSE_MODULES for part in name.split(".")): + raise AssertionError(f"M1 dense contract must not import sparse kernel module: {name}") + return original_import(name, *args, **kwargs) + + monkeypatch.setattr(builtins, "__import__", _guarded_import) + + +def test_rtp_mla_attention_accepts_explicit_topk_and_passes_it_to_dense_backend(monkeypatch): + _guard_sparse_kernel_imports(monkeypatch) + attention = RTPMLAAttention(dense_backend=_FakeDenseBackend(v_head_dim=16)) + q = torch.empty(1, 2, 12) + compressed_kv = torch.empty(1, 8) + k_pe = torch.empty(1, 4) + positions = torch.arange(1, dtype=torch.int32) + topk = torch.tensor([[3, 1, 0, 2]], dtype=torch.int32) + + output = attention.forward( + q, + compressed_kv, + k_pe, + positions=positions, + topk_indices=topk, + ) + + assert output.shape == (1, 2, 16) + assert len(attention.dense_backend.calls) == 1 + assert attention.dense_backend.calls[0]["topk_indices"] is topk + + +def test_dense_backend_output_does_not_depend_on_topk_values(monkeypatch): + _guard_sparse_kernel_imports(monkeypatch) + backend = _FakeDenseBackend(v_head_dim=16) + attention = RTPMLAAttention(dense_backend=backend) + q = torch.ones(2, 2, 12) + compressed_kv = torch.empty(2, 8) + k_pe = torch.empty(2, 4) + positions = torch.arange(2, dtype=torch.int32) + topk_a = torch.tensor([[3, 1, 0, 2], [2, 0, 1, 3]], dtype=torch.int32) + topk_b = torch.tensor([[0, 2, 1, 3], [3, 1, 2, 0]], dtype=torch.int32) + + out_a = attention.forward( + q, + compressed_kv, + k_pe, + positions=positions, + 
topk_indices=topk_a, + ) + out_b = attention.forward( + q, + compressed_kv, + k_pe, + positions=positions, + topk_indices=topk_b, + ) + + assert torch.equal(out_a, out_b) + assert backend.calls[0]["topk_indices"] is topk_a + assert backend.calls[1]["topk_indices"] is topk_b + + +def test_native_forward_signature_exposes_q_scale_argument(): + signature = inspect.signature(RTPMLAAttention.forward) + + assert "q_scale" in signature.parameters + + +@pytest.mark.parametrize("attr", ["q_proj", "o_proj", "kv_b_proj", "v_head_dim"]) +def test_constructor_injects_native_mla_module_attributes(attr): + modules = SimpleNamespace( + q_proj=object(), + o_proj=object(), + kv_b_proj=object(), + v_head_dim=16, + ) + attention = RTPMLAAttention(mla_modules=modules) + + assert getattr(attention, attr) == getattr(modules, attr) + + +class _FakeQProj: + def __init__(self, output): + self.output = output + self.calls = [] + + def __call__(self, query, q_scale=None): + self.calls.append((query, q_scale)) + return self.output + + +class _FakeOProj: + def __init__(self, hidden_dim: int): + self.hidden_dim = hidden_dim + self.calls = [] + + def __call__(self, tensor): + self.calls.append(tensor) + return tensor.new_empty((tensor.shape[0], self.hidden_dim)) + + +def test_native_five_tuple_projects_latent_query_and_applies_o_proj(monkeypatch): + _guard_sparse_kernel_imports(monkeypatch) + token_count = 3 + num_heads = 2 + qk_head_dim = 4 + v_head_dim = 5 + hidden_dim = 7 + query = torch.arange(token_count * 6, dtype=torch.float32).reshape(token_count, 6) + q_scale = torch.ones(token_count, 1) + projected_q = torch.arange( + token_count * num_heads * qk_head_dim, dtype=torch.float32 + ).reshape(token_count, num_heads * qk_head_dim) + compressed_kv = torch.empty(token_count, 8) + k_rope = torch.empty(token_count, 3) + positions = torch.arange(token_count, dtype=torch.int32) + backend = _FakeDenseBackend(v_head_dim=v_head_dim) + modules = SimpleNamespace( + q_proj=_FakeQProj(projected_q), + 
o_proj=_FakeOProj(hidden_dim=hidden_dim), + kv_b_proj=object(), + v_head_dim=v_head_dim, + qk_head_dim=qk_head_dim, + num_heads=num_heads, + num_local_heads=num_heads, + ) + attention = RTPMLAAttention( + mla_modules=modules, + dense_backend=backend, + layer_num=5, + kv_cache="kv-cache", + ) + + output = attention.forward( + query, + compressed_kv, + k_rope, + positions=positions, + q_scale=q_scale, + ) + + assert modules.q_proj.calls == [(query, q_scale)] + assert len(backend.calls) == 1 + call = backend.calls[0] + assert call["q"].shape == (token_count, num_heads, qk_head_dim) + assert torch.equal(call["q"].reshape(token_count, -1), projected_q) + assert call["compressed_kv"] is compressed_kv + assert call["k_pe"] is k_rope + assert call["kv_cache"] == "kv-cache" + assert call["layer_id"] == 5 + assert len(modules.o_proj.calls) == 1 + assert modules.o_proj.calls[0].shape == (token_count, num_heads * v_head_dim) + assert output.shape == (token_count, hidden_dim) + + +def test_rtp_mla_attention_builds_m0_backend_from_mla_modules(): + modules = SimpleNamespace(v_head_dim=16) + attention = RTPMLAAttention(mla_modules=modules, layer_num=3) + q = torch.empty(2, 4, 12) + compressed_kv = torch.empty(2, 8) + k_pe = torch.empty(2, 4) + + output = attention(q, compressed_kv, k_pe, positions=torch.arange(2)) + + assert output.shape == (2, 4, 16) + + +def test_rtp_mla_attention_defaults_to_sparse_backend_from_mla_modules(monkeypatch): + _guard_sparse_kernel_imports(monkeypatch) + from atom.plugin.rtpllm.attention_backend.rtp_dense_mla_backend import ( + RTPDenseMlaBackend, + ) + from atom.plugin.rtpllm.attention_backend.rtp_sparse_mla_backend import ( + RTPSparseMlaBackend, + ) + + modules = SimpleNamespace(v_head_dim=16) + attention = RTPMLAAttention(mla_modules=modules, layer_num=3) + + assert isinstance(attention.dense_backend, RTPSparseMlaBackend) + assert isinstance(attention.dense_backend.dense_backend, RTPDenseMlaBackend) + + +class _FakeKVProj: + def __init__(self, 
output: torch.Tensor): + self.output = output + self.calls = [] + + def __call__(self, compressed_kv): + self.calls.append(compressed_kv) + output = self.output.to(device=compressed_kv.device, dtype=compressed_kv.dtype) + if output.shape[0] == 1 and compressed_kv.shape[0] != 1: + output = output.expand(compressed_kv.shape[0], -1).contiguous() + return output + + +class _DeterministicKVProj: + def __init__(self, output_dim: int): + self.output_dim = output_dim + self.calls = [] + + def __call__(self, compressed_kv): + self.calls.append(compressed_kv.detach().clone()) + token_signal = compressed_kv.float().mean(dim=-1, keepdim=True) + basis = torch.linspace( + 0.0, + 1.0, + self.output_dim, + device=compressed_kv.device, + dtype=torch.float32, + ).unsqueeze(0) + return (token_signal + basis).to(dtype=compressed_kv.dtype) + + +class _FakeRotaryEmbedding: + is_neox_style = True + + def __init__(self): + self.calls = [] + + def __call__(self, positions, query, key): + self.calls.append( + { + "positions": positions.detach().clone(), + "query": query.detach().clone(), + "key": key.detach().clone(), + } + ) + offset = positions.to(device=query.device, dtype=query.dtype) + while offset.ndim < query.ndim: + offset = offset.unsqueeze(-1) + query = query + offset + key_offset = positions.to(device=key.device, dtype=key.dtype) + while key_offset.ndim < key.ndim: + key_offset = key_offset.unsqueeze(-1) + key = key + key_offset + return query, key + + +def _patch_forward_context( + monkeypatch, + *, + is_prefill, + query_start_loc, + seq_lens=None, + block_table=None, + slot_mapping=None, + kv_cache_data=None, +): + plugin_metadata = SimpleNamespace( + query_start_loc=query_start_loc, + rtp_cu_seqlens_q=query_start_loc, + seq_lens=seq_lens, + block_table=block_table, + slot_mapping=slot_mapping, + ) + fake_context = SimpleNamespace( + attn_metadata=SimpleNamespace( + plugin_metadata=plugin_metadata, + rtp_kernel_seq_size_per_block=4, + ), + 
context=SimpleNamespace(is_prefill=is_prefill), + kv_cache_data=kv_cache_data, + ) + forward_context_module = importlib.import_module("atom.utils.forward_context") + monkeypatch.setattr( + forward_context_module, + "get_forward_context", + lambda: fake_context, + ) + + +def _patch_forward_context_with_top_level_attn_metadata( + monkeypatch, + *, + is_prefill, + seq_lens, + block_table, + slot_mapping, + kv_cache_data=None, +): + fake_context = SimpleNamespace( + attn_metadata=SimpleNamespace( + plugin_metadata=None, + context_lens=seq_lens, + block_tables=block_table, + slot_mapping=slot_mapping, + cu_seqlens_q=None, + rtp_kernel_seq_size_per_block=4, + ), + context=SimpleNamespace(is_prefill=is_prefill), + kv_cache_data=kv_cache_data, + ) + forward_context_module = importlib.import_module("atom.utils.forward_context") + monkeypatch.setattr( + forward_context_module, + "get_forward_context", + lambda: fake_context, + ) + + +def _patch_forward_context_without_is_prefill(monkeypatch, *, query_start_loc): + plugin_metadata = SimpleNamespace( + query_start_loc=query_start_loc, + rtp_cu_seqlens_q=query_start_loc, + seq_lens=None, + block_table=None, + slot_mapping=None, + ) + fake_context = SimpleNamespace( + attn_metadata=SimpleNamespace( + plugin_metadata=plugin_metadata, + rtp_kernel_seq_size_per_block=4, + ), + context=SimpleNamespace(), + ) + forward_context_module = importlib.import_module("atom.utils.forward_context") + monkeypatch.setattr( + forward_context_module, + "get_forward_context", + lambda: fake_context, + ) + + +def test_default_dense_mla_backend_computes_nonzero_attention(monkeypatch): + _guard_sparse_kernel_imports(monkeypatch) + from atom.plugin.rtpllm.attention_backend.rtp_sparse_mla_backend import ( + RTPSparseMlaBackend, + ) + + q = torch.tensor([[[1.0, 0.0]], [[0.0, 1.0]]], dtype=torch.float32) + compressed_kv = torch.ones(2, 4, dtype=torch.float32) + # Per token: [k_nope_dim=2, v_head_dim=1]. 
+ kv_projection = torch.tensor([[1.0, 0.0, 5.0], [0.0, 1.0, 7.0]]) + modules = SimpleNamespace( + v_head_dim=1, + qk_nope_head_dim=2, + qk_rope_head_dim=0, + kv_b_proj=_FakeKVProj(kv_projection), + ) + _patch_forward_context( + monkeypatch, + is_prefill=True, + query_start_loc=torch.tensor([0, 2], dtype=torch.int32), + ) + attention = RTPMLAAttention(mla_modules=modules, layer_num=3) + + output = attention(q, compressed_kv, q.new_empty((2, 0)), positions=torch.arange(2)) + + assert isinstance(attention.dense_backend, RTPSparseMlaBackend) + assert output.shape == (2, 1, 1) + assert not torch.equal(output, torch.zeros_like(output)) + assert len(modules.kv_b_proj.calls) == 1 + assert modules.kv_b_proj.calls[0] is compressed_kv + + +def test_default_dense_mla_backend_rejects_missing_multi_token_metadata(monkeypatch): + _guard_sparse_kernel_imports(monkeypatch) + q = torch.randn(2, 1, 2) + compressed_kv = torch.ones(2, 4) + modules = SimpleNamespace( + v_head_dim=1, + qk_nope_head_dim=2, + qk_rope_head_dim=0, + kv_b_proj=_FakeKVProj(torch.empty(2, 3)), + ) + attention = RTPMLAAttention(mla_modules=modules, layer_num=3) + + with pytest.raises(ValueError, match="query_start_loc metadata"): + attention(q, compressed_kv, q.new_empty((2, 0)), positions=torch.arange(2)) + + +def test_default_dense_mla_backend_decode_reads_history_from_raw_cache(monkeypatch): + _guard_sparse_kernel_imports(monkeypatch) + q = torch.tensor([[[0.0, 1.0]]], dtype=torch.float32) + compressed_kv = torch.tensor([[9.0, 9.0, 9.0, 9.0]], dtype=torch.float32) + # The backend projects each latent token into [k_nope0, k_nope1, v]. + kv_projection = torch.tensor([[0.0, 0.0, 1.0]]) + modules = SimpleNamespace( + v_head_dim=1, + qk_nope_head_dim=2, + qk_rope_head_dim=0, + kv_b_proj=_FakeKVProj(kv_projection), + ) + layer_cache = SimpleNamespace(kv_cache_base=torch.zeros(1, 4, 4)) + # Three historical latent tokens are already in cache. 
+ layer_cache.kv_cache_base[0, 0] = torch.tensor([1.0, 0.0, 0.0, 0.0]) + layer_cache.kv_cache_base[0, 1] = torch.tensor([2.0, 0.0, 0.0, 0.0]) + layer_cache.kv_cache_base[0, 2] = torch.tensor([3.0, 0.0, 0.0, 0.0]) + _patch_forward_context( + monkeypatch, + is_prefill=False, + query_start_loc=torch.tensor([0, 1], dtype=torch.int32), + seq_lens=torch.tensor([4], dtype=torch.int32), + block_table=torch.tensor([[0]], dtype=torch.int32), + slot_mapping=torch.tensor([3], dtype=torch.int32), + ) + attention = RTPMLAAttention( + mla_modules=modules, + layer_num=3, + kv_cache=layer_cache, + ) + + output = attention(q, compressed_kv, q.new_empty((1, 0)), positions=torch.arange(1)) + + assert output.shape == (1, 1, 1) + assert layer_cache.kv_cache_base[0, 3].tolist() == [9.0, 9.0, 9.0, 9.0] + + +def test_default_dense_mla_backend_decode_uses_top_level_rtp_metadata(monkeypatch): + _guard_sparse_kernel_imports(monkeypatch) + q = torch.tensor([[[0.0, 1.0]]], dtype=torch.float32) + compressed_kv = torch.tensor([[9.0, 9.0, 9.0, 9.0]], dtype=torch.float32) + modules = SimpleNamespace( + v_head_dim=1, + qk_nope_head_dim=2, + qk_rope_head_dim=0, + kv_b_proj=_FakeKVProj(torch.tensor([[0.0, 0.0, 1.0]])), + ) + layer_cache = SimpleNamespace(kv_cache_base=torch.zeros(1, 4, 4)) + layer_cache.kv_cache_base[0, 0] = torch.tensor([1.0, 0.0, 0.0, 0.0]) + _patch_forward_context_with_top_level_attn_metadata( + monkeypatch, + is_prefill=False, + seq_lens=torch.tensor([2], dtype=torch.int32), + block_table=torch.tensor([[0]], dtype=torch.int32), + slot_mapping=torch.tensor([1], dtype=torch.int32), + ) + attention = RTPMLAAttention( + mla_modules=modules, + layer_num=3, + kv_cache=layer_cache, + ) + + output = attention(q, compressed_kv, q.new_empty((1, 0)), positions=torch.arange(1)) + + assert output.shape == (1, 1, 1) + assert layer_cache.kv_cache_base[0, 1].tolist() == [9.0, 9.0, 9.0, 9.0] + + +def test_default_dense_mla_backend_decode_rebuilds_stale_query_start_loc(monkeypatch): + 
_guard_sparse_kernel_imports(monkeypatch) + q = torch.tensor([[[0.0, 1.0]]], dtype=torch.float32) + compressed_kv = torch.tensor([[9.0, 9.0, 9.0, 9.0]], dtype=torch.float32) + modules = SimpleNamespace( + v_head_dim=1, + qk_nope_head_dim=2, + qk_rope_head_dim=0, + kv_b_proj=_FakeKVProj(torch.tensor([[0.0, 0.0, 1.0]])), + ) + layer_cache = SimpleNamespace(kv_cache_base=torch.zeros(1, 4, 4)) + layer_cache.kv_cache_base[0, 0] = torch.tensor([1.0, 0.0, 0.0, 0.0]) + _patch_forward_context( + monkeypatch, + is_prefill=False, + query_start_loc=torch.tensor([0, 2], dtype=torch.int32), + seq_lens=torch.tensor([2], dtype=torch.int32), + block_table=torch.tensor([[0]], dtype=torch.int32), + slot_mapping=torch.tensor([1], dtype=torch.int32), + ) + attention = RTPMLAAttention( + mla_modules=modules, + layer_num=3, + kv_cache=layer_cache, + ) + + output = attention(q, compressed_kv, q.new_empty((1, 0)), positions=torch.arange(1)) + + assert output.shape == (1, 1, 1) + assert layer_cache.kv_cache_base[0, 1].tolist() == [9.0, 9.0, 9.0, 9.0] + + +def test_default_sparse_wrapper_validates_topk_but_falls_back_to_dense(monkeypatch): + _guard_sparse_kernel_imports(monkeypatch) + from atom.plugin.rtpllm.attention_backend.rtp_sparse_mla_backend import ( + RTPSparseMlaBackend, + ) + + dense_backend = _FakeDenseBackend(v_head_dim=4) + sparse_impl = SimpleNamespace(calls=[]) + backend = RTPSparseMlaBackend( + dense_backend=dense_backend, + sparse_impl=sparse_impl, + v_head_dim=4, + ) + q = torch.ones(2, 1, 3) + compressed_kv = torch.ones(2, 5) + k_pe = torch.ones(2, 2) + positions = torch.arange(2) + topk = torch.tensor([[1, 0], [0, 1]], dtype=torch.int32) + + output = backend.forward( + q, + compressed_kv, + k_pe, + kv_cache="cache", + layer_id=9, + topk_indices=topk, + positions=positions, + ) + + assert output.shape == (2, 1, 4) + assert len(dense_backend.calls) == 1 + assert dense_backend.calls[0]["topk_indices"] is topk + assert dense_backend.calls[0]["positions"] is positions + assert 
sparse_impl.calls == [] + + +def test_default_dense_mla_backend_resolves_kv_cache_from_forward_context(monkeypatch): + _guard_sparse_kernel_imports(monkeypatch) + q = torch.tensor([[[0.0, 1.0]]], dtype=torch.float32) + compressed_kv = torch.tensor([[9.0, 9.0, 9.0, 9.0]], dtype=torch.float32) + modules = SimpleNamespace( + v_head_dim=1, + qk_nope_head_dim=2, + qk_rope_head_dim=0, + kv_b_proj=_FakeKVProj(torch.tensor([[0.0, 0.0, 1.0]])), + ) + layer_cache = SimpleNamespace(kv_cache_base=torch.zeros(1, 4, 4)) + layer_cache.kv_cache_base[0, 0] = torch.tensor([1.0, 0.0, 0.0, 0.0]) + _patch_forward_context( + monkeypatch, + is_prefill=False, + query_start_loc=torch.tensor([0, 1], dtype=torch.int32), + seq_lens=torch.tensor([2], dtype=torch.int32), + block_table=torch.tensor([[0]], dtype=torch.int32), + slot_mapping=torch.tensor([1], dtype=torch.int32), + kv_cache_data={"layer_3": SimpleNamespace(k_cache=layer_cache)}, + ) + attention = RTPMLAAttention(mla_modules=modules, layer_num=3) + + output = attention(q, compressed_kv, q.new_empty((1, 0)), positions=torch.arange(1)) + + assert output.shape == (1, 1, 1) + assert layer_cache.kv_cache_base[0, 1].tolist() == [9.0, 9.0, 9.0, 9.0] + + +def test_default_dense_mla_backend_accepts_noncontiguous_compressed_kv(monkeypatch): + _guard_sparse_kernel_imports(monkeypatch) + q = torch.tensor([[[1.0, 0.0]], [[0.0, 1.0]]], dtype=torch.float32) + storage = torch.arange(16, dtype=torch.float32).reshape(2, 8) + compressed_kv = storage[:, ::2] + assert not compressed_kv.is_contiguous() + modules = SimpleNamespace( + v_head_dim=1, + qk_nope_head_dim=2, + qk_rope_head_dim=0, + kv_b_proj=_FakeKVProj(torch.tensor([[1.0, 0.0, 1.0], [0.0, 1.0, 2.0]])), + ) + _patch_forward_context( + monkeypatch, + is_prefill=True, + query_start_loc=torch.tensor([0, 2], dtype=torch.int32), + ) + attention = RTPMLAAttention(mla_modules=modules, layer_num=3) + + output = attention(q, compressed_kv, q.new_empty((2, 0)), positions=torch.arange(2)) + + assert 
output.shape == (2, 1, 1) + assert modules.kv_b_proj.calls[0].is_contiguous() + + +def test_default_dense_mla_backend_skips_negative_slot_mapping(monkeypatch): + _guard_sparse_kernel_imports(monkeypatch) + q = torch.tensor([[[1.0, 0.0]], [[0.0, 1.0]]], dtype=torch.float32) + compressed_kv = torch.tensor( + [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]], dtype=torch.float32 + ) + modules = SimpleNamespace( + v_head_dim=1, + qk_nope_head_dim=2, + qk_rope_head_dim=0, + kv_b_proj=_FakeKVProj(torch.tensor([[1.0, 0.0, 1.0], [0.0, 1.0, 2.0]])), + ) + layer_cache = SimpleNamespace(kv_cache_base=torch.zeros(1, 4, 4)) + _patch_forward_context( + monkeypatch, + is_prefill=True, + query_start_loc=torch.tensor([0, 2], dtype=torch.int32), + slot_mapping=torch.tensor([-1, 1], dtype=torch.int32), + ) + attention = RTPMLAAttention(mla_modules=modules, layer_num=3, kv_cache=layer_cache) + + attention(q, compressed_kv, q.new_empty((2, 0)), positions=torch.arange(2)) + + assert torch.equal(layer_cache.kv_cache_base[0, -1], torch.zeros(4)) + assert torch.equal(layer_cache.kv_cache_base[0, 1], compressed_kv[1]) + + +def test_default_dense_mla_backend_rejects_oob_slot_mapping(monkeypatch): + _guard_sparse_kernel_imports(monkeypatch) + q = torch.tensor([[[1.0, 0.0]]], dtype=torch.float32) + compressed_kv = torch.tensor([[1.0, 2.0, 3.0, 4.0]], dtype=torch.float32) + modules = SimpleNamespace( + v_head_dim=1, + qk_nope_head_dim=2, + qk_rope_head_dim=0, + kv_b_proj=_FakeKVProj(torch.tensor([[1.0, 0.0, 1.0]])), + ) + layer_cache = SimpleNamespace(kv_cache_base=torch.zeros(1, 4, 4)) + _patch_forward_context( + monkeypatch, + is_prefill=True, + query_start_loc=torch.tensor([0, 1], dtype=torch.int32), + slot_mapping=torch.tensor([4], dtype=torch.int32), + ) + attention = RTPMLAAttention(mla_modules=modules, layer_num=3, kv_cache=layer_cache) + + with pytest.raises(RuntimeError, match="out-of-bounds slot_mapping"): + attention(q, compressed_kv, q.new_empty((1, 0)), positions=torch.arange(1)) + + 
assert torch.equal(layer_cache.kv_cache_base, torch.zeros_like(layer_cache.kv_cache_base)) + + +def test_default_dense_mla_backend_writes_post_rope_kpe_to_cache(monkeypatch): + _guard_sparse_kernel_imports(monkeypatch) + rotary_emb = _FakeRotaryEmbedding() + q = torch.tensor([[[1.0, 2.0, 10.0, 20.0]], [[3.0, 4.0, 30.0, 40.0]]]) + compressed_kv = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + k_pe = torch.tensor([[100.0, 200.0], [300.0, 400.0]]) + positions = torch.tensor([5, 7], dtype=torch.long) + modules = SimpleNamespace( + v_head_dim=1, + qk_nope_head_dim=2, + qk_rope_head_dim=2, + rotary_emb=rotary_emb, + kv_b_proj=_FakeKVProj(torch.tensor([[1.0, 0.0, 1.0], [0.0, 1.0, 2.0]])), + ) + layer_cache = SimpleNamespace(kv_cache_base=torch.zeros(1, 4, 5)) + _patch_forward_context( + monkeypatch, + is_prefill=True, + query_start_loc=torch.tensor([0, 2], dtype=torch.int32), + slot_mapping=torch.tensor([0, 1], dtype=torch.int32), + ) + attention = RTPMLAAttention(mla_modules=modules, layer_num=3, kv_cache=layer_cache) + + attention(q, compressed_kv, k_pe, positions=positions) + + expected_k_pe = k_pe + positions.to(k_pe.dtype).unsqueeze(-1) + expected_cache = torch.cat((compressed_kv, expected_k_pe), dim=-1) + assert torch.equal(layer_cache.kv_cache_base[0, :2], expected_cache) + assert torch.equal(rotary_emb.calls[0]["positions"], positions) + + +def test_default_dense_mla_backend_uses_post_rope_q_for_attention(monkeypatch): + _guard_sparse_kernel_imports(monkeypatch) + from atom.plugin.rtpllm.attention_backend.rtp_dense_mla_backend import ( + RTPDenseMlaBackend, + ) + + captured = {} + + def _fake_causal_attention(q, key, value, query_start_loc, scale): + del key, query_start_loc, scale + captured["q"] = q.detach().clone() + return value.new_zeros((q.shape[0], q.shape[1], value.shape[-1])) + + monkeypatch.setattr( + RTPDenseMlaBackend, + "_causal_attention", + staticmethod(_fake_causal_attention), + ) + rotary_emb = _FakeRotaryEmbedding() + q = torch.tensor([[[1.0, 
2.0, 10.0, 20.0]], [[3.0, 4.0, 30.0, 40.0]]]) + compressed_kv = torch.ones(2, 3) + k_pe = torch.ones(2, 2) + positions = torch.tensor([5, 7], dtype=torch.long) + modules = SimpleNamespace( + v_head_dim=1, + qk_nope_head_dim=2, + qk_rope_head_dim=2, + rotary_emb=rotary_emb, + kv_b_proj=_FakeKVProj(torch.tensor([[1.0, 0.0, 1.0], [0.0, 1.0, 2.0]])), + ) + _patch_forward_context( + monkeypatch, + is_prefill=True, + query_start_loc=torch.tensor([0, 2], dtype=torch.int32), + ) + attention = RTPMLAAttention(mla_modules=modules, layer_num=3) + + attention(q, compressed_kv, k_pe, positions=positions) + + expected_q = q.clone() + expected_q[..., -2:] += positions.to(q.dtype).view(2, 1, 1) + assert torch.equal(captured["q"], expected_q) + + +def test_default_dense_mla_backend_decode_history_kpe_not_double_roped(monkeypatch): + _guard_sparse_kernel_imports(monkeypatch) + from atom.plugin.rtpllm.attention_backend.rtp_dense_mla_backend import ( + RTPDenseMlaBackend, + ) + + captured_k_pe = [] + original_project_kv = RTPDenseMlaBackend._project_kv + + def _capture_project_kv(self, q, compressed_kv, k_pe): + captured_k_pe.append(k_pe.detach().clone()) + return original_project_kv(self, q, compressed_kv, k_pe) + + monkeypatch.setattr(RTPDenseMlaBackend, "_project_kv", _capture_project_kv) + rotary_emb = _FakeRotaryEmbedding() + modules = SimpleNamespace( + v_head_dim=1, + qk_nope_head_dim=2, + qk_rope_head_dim=2, + rotary_emb=rotary_emb, + kv_b_proj=_FakeKVProj(torch.tensor([[1.0, 0.0, 1.0]])), + ) + layer_cache = SimpleNamespace(kv_cache_base=torch.zeros(1, 4, 5)) + attention = RTPMLAAttention(mla_modules=modules, layer_num=3, kv_cache=layer_cache) + prefill_k_pe = torch.tensor([[10.0, 20.0], [30.0, 40.0]]) + prefill_positions = torch.tensor([4, 5], dtype=torch.long) + _patch_forward_context( + monkeypatch, + is_prefill=True, + query_start_loc=torch.tensor([0, 2], dtype=torch.int32), + slot_mapping=torch.tensor([0, 1], dtype=torch.int32), + ) + attention( + torch.ones(2, 1, 4), + 
torch.ones(2, 3), + prefill_k_pe, + positions=prefill_positions, + ) + + decode_k_pe = torch.tensor([[50.0, 60.0]]) + decode_positions = torch.tensor([6], dtype=torch.long) + _patch_forward_context( + monkeypatch, + is_prefill=False, + query_start_loc=torch.tensor([0, 1], dtype=torch.int32), + seq_lens=torch.tensor([3], dtype=torch.int32), + block_table=torch.tensor([[0]], dtype=torch.int32), + slot_mapping=torch.tensor([2], dtype=torch.int32), + ) + attention( + torch.ones(1, 1, 4), + torch.ones(1, 3), + decode_k_pe, + positions=decode_positions, + ) + + expected_history_k_pe = torch.cat( + ( + prefill_k_pe + prefill_positions.to(prefill_k_pe.dtype).unsqueeze(-1), + decode_k_pe + decode_positions.to(decode_k_pe.dtype).unsqueeze(-1), + ), + dim=0, + ) + assert torch.equal(captured_k_pe[-1], expected_history_k_pe) + + +def test_default_dense_mla_backend_rejects_missing_is_prefill_metadata(monkeypatch): + _guard_sparse_kernel_imports(monkeypatch) + q = torch.randn(1, 1, 2) + compressed_kv = torch.ones(1, 4) + modules = SimpleNamespace( + v_head_dim=1, + qk_nope_head_dim=2, + qk_rope_head_dim=0, + kv_b_proj=_FakeKVProj(torch.empty(1, 3)), + ) + _patch_forward_context_without_is_prefill( + monkeypatch, + query_start_loc=torch.tensor([0, 1], dtype=torch.int32), + ) + attention = RTPMLAAttention(mla_modules=modules, layer_num=3) + + with pytest.raises(ValueError, match="context.is_prefill"): + attention(q, compressed_kv, q.new_empty((1, 0)), positions=torch.arange(1)) + + +@pytest.mark.parametrize( + ("field_name", "seq_lens", "block_table", "slot_mapping"), + [ + ("seq_lens", None, torch.tensor([[0]], dtype=torch.int32), torch.tensor([0])), + ("block_table", torch.tensor([1], dtype=torch.int32), None, torch.tensor([0])), + ("slot_mapping", torch.tensor([1], dtype=torch.int32), torch.tensor([[0]], dtype=torch.int32), None), + ], +) +def test_default_dense_mla_backend_decode_requires_rtp_metadata( + monkeypatch, field_name, seq_lens, block_table, slot_mapping +): + 
_guard_sparse_kernel_imports(monkeypatch) + q = torch.randn(1, 1, 2) + compressed_kv = torch.ones(1, 4) + modules = SimpleNamespace( + v_head_dim=1, + qk_nope_head_dim=2, + qk_rope_head_dim=0, + kv_b_proj=_FakeKVProj(torch.empty(1, 3)), + ) + layer_cache = SimpleNamespace(kv_cache_base=torch.zeros(1, 4, 4)) + _patch_forward_context( + monkeypatch, + is_prefill=False, + query_start_loc=torch.tensor([0, 1], dtype=torch.int32), + seq_lens=seq_lens, + block_table=block_table, + slot_mapping=slot_mapping, + ) + attention = RTPMLAAttention(mla_modules=modules, layer_num=3, kv_cache=layer_cache) + + with pytest.raises(ValueError, match=field_name): + attention(q, compressed_kv, q.new_empty((1, 0)), positions=torch.arange(1)) + + +def test_default_dense_mla_backend_decode_requires_readable_cache(monkeypatch): + _guard_sparse_kernel_imports(monkeypatch) + q = torch.randn(1, 1, 2) + compressed_kv = torch.ones(1, 4) + modules = SimpleNamespace( + v_head_dim=1, + qk_nope_head_dim=2, + qk_rope_head_dim=0, + kv_b_proj=_FakeKVProj(torch.empty(1, 3)), + ) + layer_cache = SimpleNamespace(kv_cache_base=torch.empty(0)) + _patch_forward_context( + monkeypatch, + is_prefill=False, + query_start_loc=torch.tensor([0, 1], dtype=torch.int32), + seq_lens=torch.tensor([1], dtype=torch.int32), + block_table=torch.tensor([[0]], dtype=torch.int32), + slot_mapping=torch.tensor([0], dtype=torch.int32), + ) + attention = RTPMLAAttention(mla_modules=modules, layer_num=3, kv_cache=layer_cache) + + with pytest.raises(ValueError, match="kv_cache_base"): + attention(q, compressed_kv, q.new_empty((1, 0)), positions=torch.arange(1)) + + +def test_default_dense_mla_backend_rejects_fp8_kv_cache(monkeypatch): + _guard_sparse_kernel_imports(monkeypatch) + q = torch.randn(1, 1, 2) + compressed_kv = torch.ones(1, 4) + modules = SimpleNamespace( + v_head_dim=1, + qk_nope_head_dim=2, + qk_rope_head_dim=0, + kv_b_proj=_FakeKVProj(torch.empty(1, 3)), + ) + layer_cache = SimpleNamespace(kv_cache_base=torch.zeros(1, 
4, 4, dtype=torch.uint8)) + _patch_forward_context( + monkeypatch, + is_prefill=False, + query_start_loc=torch.tensor([0, 1], dtype=torch.int32), + seq_lens=torch.tensor([1], dtype=torch.int32), + block_table=torch.tensor([[0]], dtype=torch.int32), + slot_mapping=torch.tensor([0], dtype=torch.int32), + ) + attention = RTPMLAAttention(mla_modules=modules, layer_num=3, kv_cache=layer_cache) + + with pytest.raises(NotImplementedError, match="FP8 KV cache"): + attention(q, compressed_kv, q.new_empty((1, 0)), positions=torch.arange(1)) + + +def test_default_dense_mla_backend_glm5_shape_bf16_cache_roundtrip(monkeypatch): + _guard_sparse_kernel_imports(monkeypatch) + num_heads = 32 + qk_nope_head_dim = 128 + qk_rope_head_dim = 64 + v_head_dim = 128 + kv_lora_rank = 512 + kv_dim = kv_lora_rank + qk_rope_head_dim + output_dim = num_heads * (qk_nope_head_dim + v_head_dim) + kv_proj = _DeterministicKVProj(output_dim) + rotary_emb = _FakeRotaryEmbedding() + modules = SimpleNamespace( + v_head_dim=v_head_dim, + qk_nope_head_dim=qk_nope_head_dim, + qk_rope_head_dim=qk_rope_head_dim, + rotary_emb=rotary_emb, + kv_b_proj=kv_proj, + ) + layer_cache = SimpleNamespace(kv_cache_base=torch.zeros(1, 4, kv_dim, dtype=torch.bfloat16)) + attention = RTPMLAAttention(mla_modules=modules, layer_num=3, kv_cache=layer_cache) + q_prefill = torch.randn( + 3, + num_heads, + qk_nope_head_dim + qk_rope_head_dim, + dtype=torch.bfloat16, + ) + compressed_prefill = torch.randn(3, kv_lora_rank, dtype=torch.bfloat16) + k_pe_prefill = torch.randn(3, qk_rope_head_dim, dtype=torch.bfloat16) + _patch_forward_context( + monkeypatch, + is_prefill=True, + query_start_loc=torch.tensor([0, 3], dtype=torch.int32), + slot_mapping=torch.tensor([0, 1, 2], dtype=torch.int32), + ) + + prefill_output = attention( + q_prefill, + compressed_prefill, + k_pe_prefill, + positions=torch.arange(3), + ) + + assert prefill_output.shape == (3, num_heads, v_head_dim) + expected_prefill_k_pe = k_pe_prefill + torch.arange(3).to( + 
dtype=k_pe_prefill.dtype + ).unsqueeze(-1) + expected_prefill_cache = torch.cat((compressed_prefill, expected_prefill_k_pe), dim=-1) + assert torch.equal(layer_cache.kv_cache_base[0, :3], expected_prefill_cache) + + q_decode = torch.randn( + 1, + num_heads, + qk_nope_head_dim + qk_rope_head_dim, + dtype=torch.bfloat16, + ) + compressed_decode = torch.randn(1, kv_lora_rank, dtype=torch.bfloat16) + k_pe_decode = torch.randn(1, qk_rope_head_dim, dtype=torch.bfloat16) + _patch_forward_context( + monkeypatch, + is_prefill=False, + query_start_loc=torch.tensor([0, 1], dtype=torch.int32), + seq_lens=torch.tensor([4], dtype=torch.int32), + block_table=torch.tensor([[0]], dtype=torch.int32), + slot_mapping=torch.tensor([3], dtype=torch.int32), + ) + + decode_output = attention( + q_decode, + compressed_decode, + k_pe_decode, + positions=torch.arange(1), + ) + + assert decode_output.shape == (1, num_heads, v_head_dim) + expected_decode_k_pe = k_pe_decode + torch.arange(1).to( + dtype=k_pe_decode.dtype + ).unsqueeze(-1) + expected_decode_cache = torch.cat((compressed_decode, expected_decode_k_pe), dim=-1) + assert torch.equal(layer_cache.kv_cache_base[0, 3:4], expected_decode_cache) + expected_history = torch.cat((compressed_prefill, compressed_decode), dim=0) + assert torch.equal(kv_proj.calls[-1], expected_history) + + +def test_default_dense_mla_backend_rejects_bad_kv_projection_shape(monkeypatch): + _guard_sparse_kernel_imports(monkeypatch) + q = torch.randn(2, 1, 2) + compressed_kv = torch.ones(2, 4) + modules = SimpleNamespace( + v_head_dim=1, + qk_nope_head_dim=2, + qk_rope_head_dim=0, + kv_b_proj=_FakeKVProj(torch.empty(2, 2)), + ) + _patch_forward_context( + monkeypatch, + is_prefill=True, + query_start_loc=torch.tensor([0, 2], dtype=torch.int32), + ) + attention = RTPMLAAttention(mla_modules=modules, layer_num=3) + + with pytest.raises(ValueError, match="kv_b_proj output shape mismatch"): + attention(q, compressed_kv, q.new_empty((2, 0)), positions=torch.arange(2)) 
+ + +def test_rtp_mla_attention_explicit_dense_backend_overrides_sparse_default(monkeypatch): + _guard_sparse_kernel_imports(monkeypatch) + dense_backend = _FakeDenseBackend(v_head_dim=16) + modules = SimpleNamespace(v_head_dim=16) + + attention = RTPMLAAttention(mla_modules=modules, dense_backend=dense_backend) + + assert attention.dense_backend is dense_backend + diff --git a/tests/plugin/test_rtpllm_glm5_mla_patch.py b/tests/plugin/test_rtpllm_glm5_mla_patch.py new file mode 100644 index 0000000000..78d0bc52ca --- /dev/null +++ b/tests/plugin/test_rtpllm_glm5_mla_patch.py @@ -0,0 +1,23 @@ +"""No-monkey-patch guards for GLM5 RTP MLA M1.5 forward.""" + +from pathlib import Path + + +_ATOM_ROOT = Path(__file__).resolve().parents[2] + + +def _read_plugin_file(relative_path: str) -> str: + return (_ATOM_ROOT / relative_path).read_text() + + +def test_rtp_mla_prepare_no_longer_contains_deepseek_forward_monkey_patch(): + assert not ( + _ATOM_ROOT / "atom/plugin/rtpllm/attention_backend/rtp_mla_prepare.py" + ).exists() + + +def test_glm5_wrapper_does_not_import_or_call_deepseek_mla_patch(): + source = _read_plugin_file("atom/plugin/rtpllm/models/glm5.py") + + assert "apply_deepseek_mla_rtpllm_patch" not in source + diff --git a/tests/plugin/test_rtpllm_glm5_ownership.py b/tests/plugin/test_rtpllm_glm5_ownership.py new file mode 100644 index 0000000000..1bfa19cfc9 --- /dev/null +++ b/tests/plugin/test_rtpllm_glm5_ownership.py @@ -0,0 +1,39 @@ +"""Ownership contract tests for GLM5 rtp-llm M0.""" + +from atom.plugin.rtpllm.attention_backend.rtp_mla_metadata import ( + GLM5_RTP_BRIDGE_MODE, + GLM5_RTP_BRIDGE_MODE_M0_DENSE, + GLM5_RTP_OWNERSHIP, +) + + +def test_glm5_bridge_mode_starts_in_m0_dense(): + assert GLM5_RTP_BRIDGE_MODE == GLM5_RTP_BRIDGE_MODE_M0_DENSE + + +def test_glm5_ownership_unique_and_separates_rope_paths(): + required = { + "main_q_norm", + "main_kv_norm", + "main_rope", + "main_kv_cache", + "indexer_k_norm", + "indexer_rope", + "indexer_cache", + 
"topk_selector", + } + + assert required <= set(GLM5_RTP_OWNERSHIP) + for key in required: + owner = GLM5_RTP_OWNERSHIP[key] + assert isinstance(owner, str) + assert owner + + assert GLM5_RTP_OWNERSHIP["main_rope"] != GLM5_RTP_OWNERSHIP["indexer_rope"] + + +def test_glm5_ownership_forbids_qwen_and_mha_components(): + forbidden = ("GatedDeltaNet", "RTPFullAttention", "Qwen3Next") + for owner in GLM5_RTP_OWNERSHIP.values(): + assert all(name not in owner for name in forbidden) + diff --git a/tests/plugin/test_rtpllm_glm5_registration.py b/tests/plugin/test_rtpllm_glm5_registration.py new file mode 100644 index 0000000000..8abdb77e23 --- /dev/null +++ b/tests/plugin/test_rtpllm_glm5_registration.py @@ -0,0 +1,78 @@ +"""Tests for GLM5 rtp-llm plugin registration.""" + +import importlib +import sys +from types import ModuleType +from unittest.mock import MagicMock, call, patch + + +def _package(name: str) -> ModuleType: + module = ModuleType(name) + module.__path__ = [] + return module + + +def test_rtpllm_wrapper_registers_glm5_override_and_alias(): + register_model_mock = MagicMock() + + fake_rtp_register_mod = ModuleType("rtp_llm.model_factory_register") + fake_rtp_register_mod.register_model = register_model_mock + fake_rtp_register_mod._model_factory = {} + fake_rtp_register_mod._hf_architecture_2_ft = {} + + fake_atom_register_mod = ModuleType("atom.plugin.register") + fake_atom_register_mod._ATOM_SUPPORTED_MODELS = {} + + fake_atom_deepseek_mod = ModuleType("atom.models.deepseek_v2") + + class _FakeGlmMoeDsaForCausalLM: + pass + + fake_atom_deepseek_mod.GlmMoeDsaForCausalLM = _FakeGlmMoeDsaForCausalLM + + fake_atom_qwen_mod = ModuleType("atom.plugin.rtpllm.models.qwen3_5") + + class _FakeATOMQwen35Moe: + pass + + fake_atom_qwen_mod.ATOMQwen35Moe = _FakeATOMQwen35Moe + + fake_atom_glm_mod = ModuleType("atom.plugin.rtpllm.models.glm5") + + class _FakeATOMGlm5Moe: + pass + + fake_atom_glm_mod.ATOMGlm5Moe = _FakeATOMGlm5Moe + + fake_modules = { + "rtp_llm": 
_package("rtp_llm"), + "rtp_llm.models": _package("rtp_llm.models"), + "rtp_llm.model_factory_register": fake_rtp_register_mod, + "atom.models.deepseek_v2": fake_atom_deepseek_mod, + "atom.plugin.register": fake_atom_register_mod, + "atom.plugin.rtpllm.models.qwen3_5": fake_atom_qwen_mod, + "atom.plugin.rtpllm.models.glm5": fake_atom_glm_mod, + } + + with patch.dict(sys.modules, fake_modules): + sys.modules.pop("atom.plugin.rtpllm.models", None) + sys.modules.pop("atom.plugin.rtpllm.models.base_model_wrapper", None) + importlib.import_module("atom.plugin.rtpllm.models") + + assert fake_rtp_register_mod._model_factory["glm_5"] is _FakeATOMGlm5Moe + assert ( + fake_rtp_register_mod._hf_architecture_2_ft["GlmMoeDsaForCausalLM"] + == "glm_5" + ) + assert ( + fake_atom_register_mod._ATOM_SUPPORTED_MODELS["GlmMoeDsaForCausalLM"] + is _FakeGlmMoeDsaForCausalLM + ) + register_model_mock.assert_has_calls( + [ + call("atom_qwen35_moe", _FakeATOMQwen35Moe, []), + call("atom_glm5_moe", _FakeATOMGlm5Moe, []), + ], + any_order=False, + ) + diff --git a/tests/plugin/test_rtpllm_glm5_sparse_backend_contract.py b/tests/plugin/test_rtpllm_glm5_sparse_backend_contract.py new file mode 100644 index 0000000000..065d265b6b --- /dev/null +++ b/tests/plugin/test_rtpllm_glm5_sparse_backend_contract.py @@ -0,0 +1,353 @@ +"""Contract-executable tests for GLM5 RTP MLA M2 sparse topk consumption.""" + +import builtins +import importlib +import inspect +import sys +from types import SimpleNamespace + +import torch + + +_SPARSE_BACKEND_MODULE = ( + "atom.plugin.rtpllm.attention_backend.rtp_sparse_mla_backend" +) +_FORBIDDEN_CUDA_SPARSE_MODULES = ( + "flashmla_sparse", + "flash_mla", + "sparse_mla", + "attention_mla_sparse", +) + + +def _guard_sparse_kernel_imports(monkeypatch): + original_import = builtins.__import__ + + def _guarded_import(name, *args, **kwargs): + if any(part in _FORBIDDEN_CUDA_SPARSE_MODULES for part in name.split(".")): + raise AssertionError(f"M2 sparse contract must not 
import CUDA sparse kernel: {name}") + return original_import(name, *args, **kwargs) + + monkeypatch.setattr(builtins, "__import__", _guarded_import) + + +def _load_sparse_backend(monkeypatch): + _guard_sparse_kernel_imports(monkeypatch) + importlib.invalidate_caches() + module = importlib.import_module(_SPARSE_BACKEND_MODULE) + return module.RTPSparseMlaBackend + + +class _FakeDenseBackend: + def __init__(self, v_head_dim: int = 5): + self.v_head_dim = v_head_dim + self.calls = [] + + def forward(self, q, compressed_kv, k_pe, kv_cache, layer_id, topk_indices=None): + self.calls.append( + { + "q": q, + "compressed_kv": compressed_kv, + "k_pe": k_pe, + "kv_cache": kv_cache, + "layer_id": layer_id, + "topk_indices": topk_indices, + } + ) + return q.new_full((q.shape[0], q.shape[1], self.v_head_dim), -1) + + +class _FakeSparseImpl: + def __init__(self, v_head_dim: int = 5): + self.v_head_dim = v_head_dim + self.calls = [] + + def forward( + self, + q, + compressed_kv, + k_pe, + kv_cache, + layer_id, + *, + topk_indices, + attn_metadata, + ): + self.calls.append( + { + "q": q, + "compressed_kv": compressed_kv, + "k_pe": k_pe, + "kv_cache": kv_cache, + "layer_id": layer_id, + "topk_indices": topk_indices, + "attn_metadata": attn_metadata, + } + ) + return q.new_full((q.shape[0], q.shape[1], self.v_head_dim), 7) + + +def _build_backend(backend_cls, dense_backend, sparse_impl): + params = inspect.signature(backend_cls).parameters + kwargs = {} + if "dense_backend" not in params: + raise AssertionError("RTPSparseMlaBackend must accept dense_backend= for dense fallback") + kwargs["dense_backend"] = dense_backend + + if "sparse_impl" in params: + kwargs["sparse_impl"] = sparse_impl + else: + raise AssertionError("RTPSparseMlaBackend must accept a mock sparse impl injection") + + if "v_head_dim" in params: + kwargs["v_head_dim"] = dense_backend.v_head_dim + return backend_cls(**kwargs) + + +def _make_inputs(): + return ( + torch.randn(3, 2, 4), + torch.randn(3, 8), + 
torch.randn(3, 3), + SimpleNamespace(name="kv-cache"), + 11, + ) + + +def test_sparse_backend_passes_topk_through_unchanged(monkeypatch): + backend_cls = _load_sparse_backend(monkeypatch) + dense_backend = _FakeDenseBackend() + sparse_impl = _FakeSparseImpl() + backend = _build_backend(backend_cls, dense_backend, sparse_impl) + q, compressed_kv, k_pe, kv_cache, layer_id = _make_inputs() + topk = torch.tensor([[4, 1], [3, 0], [2, 1]], dtype=torch.int32) + + output = backend.forward( + q, compressed_kv, k_pe, kv_cache, layer_id, topk_indices=topk + ) + + assert output.shape == (3, 2, sparse_impl.v_head_dim) + assert dense_backend.calls == [] + assert len(sparse_impl.calls) == 1 + assert sparse_impl.calls[0]["topk_indices"] is topk + assert sparse_impl.calls[0]["topk_indices"].dtype == torch.int32 + assert sparse_impl.calls[0]["topk_indices"].shape == (3, 2) + + +def test_sparse_backend_falls_back_to_dense_when_topk_is_none(monkeypatch): + backend_cls = _load_sparse_backend(monkeypatch) + dense_backend = _FakeDenseBackend() + sparse_impl = _FakeSparseImpl() + backend = _build_backend(backend_cls, dense_backend, sparse_impl) + q, compressed_kv, k_pe, kv_cache, layer_id = _make_inputs() + + output = backend.forward( + q, compressed_kv, k_pe, kv_cache, layer_id, topk_indices=None + ) + + assert output.shape == (3, 2, dense_backend.v_head_dim) + assert len(dense_backend.calls) == 1 + assert sparse_impl.calls == [] + assert dense_backend.calls[0]["q"] is q + assert dense_backend.calls[0]["compressed_kv"] is compressed_kv + assert dense_backend.calls[0]["k_pe"] is k_pe + assert dense_backend.calls[0]["kv_cache"] is kv_cache + assert dense_backend.calls[0]["layer_id"] == layer_id + assert dense_backend.calls[0]["topk_indices"] is None + + +def test_sparse_backend_threads_kv_cache_and_layer_id_to_sparse_impl(monkeypatch): + backend_cls = _load_sparse_backend(monkeypatch) + dense_backend = _FakeDenseBackend() + sparse_impl = _FakeSparseImpl() + backend = 
_build_backend(backend_cls, dense_backend, sparse_impl) + q, compressed_kv, k_pe, kv_cache, layer_id = _make_inputs() + topk = torch.tensor([[1, 0], [0, 1], [1, 1]], dtype=torch.int32) + + backend.forward(q, compressed_kv, k_pe, kv_cache, layer_id, topk_indices=topk) + + call = sparse_impl.calls[0] + assert call["q"] is q + assert call["compressed_kv"] is compressed_kv + assert call["k_pe"] is k_pe + assert call["kv_cache"] is kv_cache + assert call["layer_id"] == layer_id + + +def test_sparse_backend_pulls_attn_metadata_from_forward_context(monkeypatch): + backend_cls = _load_sparse_backend(monkeypatch) + forward_context_mod = sys.modules["atom.utils.forward_context"] + + attn_metadata = SimpleNamespace(block_table="block-table", seq_lens="seq-lens") + fake_forward_context = SimpleNamespace( + context=SimpleNamespace(is_dummy_run=False), + attn_metadata=attn_metadata, + ) + monkeypatch.setattr( + forward_context_mod, + "get_forward_context", + lambda: fake_forward_context, + raising=False, + ) + dense_backend = _FakeDenseBackend() + sparse_impl = _FakeSparseImpl() + backend = _build_backend(backend_cls, dense_backend, sparse_impl) + q, compressed_kv, k_pe, kv_cache, layer_id = _make_inputs() + topk = torch.tensor([[1, 0], [0, 1], [1, 1]], dtype=torch.int32) + + backend.forward(q, compressed_kv, k_pe, kv_cache, layer_id, topk_indices=topk) + + assert sparse_impl.calls[0]["attn_metadata"] is attn_metadata + + +def test_sparse_backend_forward_signature_matches_dense_boundary(monkeypatch): + backend_cls = _load_sparse_backend(monkeypatch) + + signature = inspect.signature(backend_cls.forward) + params = signature.parameters + + assert list(params) == [ + "self", + "q", + "compressed_kv", + "k_pe", + "kv_cache", + "layer_id", + "topk_indices", + "positions", + ] + assert params["topk_indices"].default is None + + +def test_sparse_backend_converts_request_local_topk_to_global_slots(monkeypatch): + backend_cls = _load_sparse_backend(monkeypatch) + module = 
importlib.import_module(_SPARSE_BACKEND_MODULE) + convert = module._RealSparseMlaImpl._convert_topk_to_global + plugin_metadata = SimpleNamespace( + block_table=torch.tensor([[7, 8], [20, 21]], dtype=torch.int32), + req_id_per_token=torch.tensor([0, 1], dtype=torch.int32), + ) + attn_metadata = SimpleNamespace(plugin_metadata=plugin_metadata) + topk = torch.tensor( + [ + [0, 1029, -1], + [1024, 2048, 5], + ], + dtype=torch.int32, + ) + + del backend_cls + global_topk = convert( + topk_indices=topk, + attn_metadata=attn_metadata, + block_size=1024, + ) + + assert global_topk.cpu().tolist() == [ + [7 * 1024, 8 * 1024 + 5, 0], + [21 * 1024, 0, 20 * 1024 + 5], + ] + + +def test_real_sparse_decode_uses_atom_aiter_metadata(monkeypatch): + module = importlib.import_module(_SPARSE_BACKEND_MODULE) + calls = {} + + import aiter + + def fake_metadata_info(*args, **kwargs): + calls["metadata_info"] = (args, kwargs) + return ( + (4, torch.int32), + (4, torch.int32), + (4, torch.int32), + (4, torch.int32), + (4, torch.int32), + (4, torch.int32), + ) + + def fake_metadata_v1(*args, **kwargs): + calls["metadata_v1"] = (args, kwargs) + + monkeypatch.setattr(aiter, "get_mla_metadata_info_v1", fake_metadata_info, raising=False) + monkeypatch.setattr(aiter, "get_mla_metadata_v1", fake_metadata_v1, raising=False) + + fake_mla = type(sys)("aiter.mla") + + def fake_mla_decode_fwd(q, kv, output, qo_indptr, paged_kv_indptr, paged_kv_indices, + paged_kv_last_page_len, *args, **kwargs): + calls["mla_decode_fwd"] = { + "q": q, + "kv": kv, + "output": output, + "qo_indptr": qo_indptr, + "paged_kv_indptr": paged_kv_indptr, + "paged_kv_indices": paged_kv_indices, + "paged_kv_last_page_len": paged_kv_last_page_len, + "args": args, + "kwargs": kwargs, + } + output.fill_(3) + + fake_mla.mla_decode_fwd = fake_mla_decode_fwd + monkeypatch.setitem(sys.modules, "aiter.mla", fake_mla) + + fake_sparse_helpers = type(sys)("atom.plugin.attention_mla_sparse") + + def fake_generate_sparse_seqlen(query_lens, 
seq_lens, query_start_loc, topk, + num_tokens, max_query_len): + return torch.tensor([3, 2], dtype=torch.int32, device=query_lens.device) + + def fake_convert(req_id, block_table, token_indices, cu_seqlens, out, + BLOCK_SIZE=1, NUM_TOPK_TOKENS=0, BLOCK_N=128): + out[:5] = torch.tensor([0, 1, 2, 4, 5], dtype=torch.int32, device=out.device) + + fake_sparse_helpers.generate_sparse_seqlen_triton = fake_generate_sparse_seqlen + fake_sparse_helpers.triton_convert_req_index_to_global_index = fake_convert + monkeypatch.setitem( + sys.modules, + "atom.plugin.attention_mla_sparse", + fake_sparse_helpers, + ) + + impl = module._RealSparseMlaImpl( + mla_modules=SimpleNamespace( + kv_lora_rank=4, + qk_nope_head_dim=2, + qk_rope_head_dim=1, + num_heads=2, + rotary_emb=None, + kv_b_proj=SimpleNamespace(weight=torch.empty(0)), + ), + v_head_dim=3, + ) + q_latent = torch.randn(2, 2, 5) + kv_cache = torch.randn(8, 1, 5) + topk = torch.tensor([[0, 1, 2], [0, 1, -1]], dtype=torch.int32) + attn_metadata = SimpleNamespace( + plugin_metadata=SimpleNamespace( + query_start_loc=torch.tensor([0, 1, 2], dtype=torch.int32), + seq_lens=torch.tensor([3, 2], dtype=torch.int32), + req_id_per_token=torch.tensor([0, 1], dtype=torch.int32), + block_table=torch.tensor([[0], [1]], dtype=torch.int32), + ) + ) + + output = impl._run_aiter_sparse_decode( + q_latent=q_latent, + kv_cache_base=kv_cache, + topk_indices=topk, + attn_metadata=attn_metadata, + block_size=4, + ) + + assert output.shape == (2, 2, 4) + assert torch.all(output == 3) + decode_call = calls["mla_decode_fwd"] + assert decode_call["q"].shape == (2, 16, 5) + assert decode_call["output"].shape == (2, 16, 4) + assert decode_call["paged_kv_indptr"].tolist() == [0, 3, 5] + assert decode_call["paged_kv_indices"][:5].tolist() == [0, 1, 2, 4, 5] + assert decode_call["kwargs"]["page_size"] == 1 + assert decode_call["kwargs"]["work_meta_data"] is not None + assert decode_call["kwargs"]["reduce_final_map"] is not None diff --git 
a/tests/plugin/test_rtpllm_glm5_wrapper_lifecycle.py b/tests/plugin/test_rtpllm_glm5_wrapper_lifecycle.py new file mode 100644 index 0000000000..7410111019 --- /dev/null +++ b/tests/plugin/test_rtpllm_glm5_wrapper_lifecycle.py @@ -0,0 +1,320 @@ +"""Lifecycle tests for the GLM5 rtp-llm wrapper.""" + +from contextlib import nullcontext +import importlib +import os +import sys +from types import ModuleType, SimpleNamespace +from unittest.mock import MagicMock, patch + +import torch + + +def _package(name: str) -> ModuleType: + module = ModuleType(name) + module.__path__ = [] + return module + + +def _install_fake_rtp_modules() -> dict[str, ModuleType]: + fake_config_mod = ModuleType("rtp_llm.config.model_config") + + class _FakeModelConfig: + pass + + fake_config_mod.ModelConfig = _FakeModelConfig + + fake_factory_register_mod = ModuleType("rtp_llm.model_factory_register") + fake_factory_register_mod.register_model = MagicMock() + fake_factory_register_mod._model_factory = {} + fake_factory_register_mod._hf_architecture_2_ft = {} + + fake_deepseek_mod = ModuleType("rtp_llm.models.deepseek_v2") + + class _FakeDeepSeekV2: + def _get_device_str(self): + return "cpu" + + def _create_python_model(self): + self.native_create_python_model_called = True + + def load(self, skip_python_model=False): + self.native_load_called = skip_python_model + + fake_deepseek_mod.DeepSeekV2 = _FakeDeepSeekV2 + + fake_weight_info_mod = ModuleType("rtp_llm.model_loader.model_weight_info") + + class _FakeModelWeights: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.global_weights = {} + + def set_global_weight(self, name, tensor): + self.global_weights[name] = tensor + + fake_weight_info_mod.ModelWeights = _FakeModelWeights + + fake_module_base_mod = ModuleType("rtp_llm.models_py.model_desc.module_base") + + class _FakeGptModelBase: + def __init__(self, *args, **kwargs): + self.init_args = args + self.init_kwargs = kwargs + + fake_module_base_mod.GptModelBase = _FakeGptModelBase + 
+ fake_ops_mod = ModuleType("rtp_llm.ops") + + class _FakeParallelismConfig: + pass + + fake_ops_mod.ParallelismConfig = _FakeParallelismConfig + + fake_compute_ops_mod = ModuleType("rtp_llm.ops.compute_ops") + + class _FakePyModelInputs: + pass + + class _FakePyModelOutputs: + def __init__(self, hidden_states): + self.hidden_states = hidden_states + + fake_compute_ops_mod.PyModelInputs = _FakePyModelInputs + fake_compute_ops_mod.PyModelOutputs = _FakePyModelOutputs + + fake_weight_mod = ModuleType("rtp_llm.utils.model_weight") + fake_weight_mod.W = SimpleNamespace( + lm_head="lm_head", + embedding="embedding", + final_ln_gamma="final_ln_gamma", + ) + + fake_loader_mod = ModuleType("atom.model_loader.loader") + + class _FakeWeightsMapper: + def __init__(self, **kwargs): + self.kwargs = kwargs + + fake_loader_mod.WeightsMapper = _FakeWeightsMapper + fake_loader_mod.load_model_in_plugin_mode = MagicMock() + + return { + "atom.model_loader": _package("atom.model_loader"), + "atom.model_loader.loader": fake_loader_mod, + "rtp_llm": _package("rtp_llm"), + "rtp_llm.config": _package("rtp_llm.config"), + "rtp_llm.config.model_config": fake_config_mod, + "rtp_llm.model_factory_register": fake_factory_register_mod, + "rtp_llm.models": _package("rtp_llm.models"), + "rtp_llm.models.deepseek_v2": fake_deepseek_mod, + "rtp_llm.model_loader": _package("rtp_llm.model_loader"), + "rtp_llm.model_loader.model_weight_info": fake_weight_info_mod, + "rtp_llm.models_py": _package("rtp_llm.models_py"), + "rtp_llm.models_py.model_desc": _package("rtp_llm.models_py.model_desc"), + "rtp_llm.models_py.model_desc.module_base": fake_module_base_mod, + "rtp_llm.ops": fake_ops_mod, + "rtp_llm.ops.compute_ops": fake_compute_ops_mod, + "rtp_llm.utils": _package("rtp_llm.utils"), + "rtp_llm.utils.model_weight": fake_weight_mod, + } + + +def _make_wrapper_instance(cls): + instance = cls.__new__(cls) + instance.model_config = SimpleNamespace( + num_layers=1, + compute_dtype=torch.bfloat16, + ) + 
instance.parallelism_config = SimpleNamespace() + instance.max_generate_batch_size = 1 + instance.fmha_config = None + instance.hw_kernel_config = None + instance.device_resource_config = None + return instance + + +def test_glm5_load_skip_python_model_does_not_create_atom_model(): + fake_modules = _install_fake_rtp_modules() + + with patch.dict(sys.modules, fake_modules), patch.dict( + os.environ, + {"RTP_LLM_EXTERNAL_MODEL_PACKAGES": "atom.plugin.rtpllm.models"}, + ): + sys.modules.pop("atom.plugin.rtpllm.models.glm5", None) + module = importlib.import_module("atom.plugin.rtpllm.models.glm5") + module = importlib.reload(module) + instance = _make_wrapper_instance(module.ATOMGlm5Moe) + instance._create_python_model = MagicMock() + + instance.load(skip_python_model=True) + + instance._create_python_model.assert_not_called() + assert instance.device == "cpu" + assert isinstance(instance.model_weights_loader, module._NoopModelWeightsLoader) + assert isinstance(instance.weight_manager, module._NoopWeightManager) + + +def _patch_optional_attr(module, attr): + if hasattr(module, attr): + return patch.object(module, attr) + return nullcontext(MagicMock(name=attr)) + + +def test_glm5_create_python_model_lets_prepare_model_own_mla_patching(): + fake_modules = _install_fake_rtp_modules() + fake_atom_model = MagicMock(name="atom_model") + fake_atom_model.to.return_value = fake_atom_model + + with patch.dict( + sys.modules, + fake_modules, + ), patch.dict( + os.environ, + {"RTP_LLM_EXTERNAL_MODEL_PACKAGES": "atom.plugin.rtpllm.models"}, + ), patch("atom.prepare_model", return_value=fake_atom_model, create=True) as prepare_model: + sys.modules.pop("atom.plugin.rtpllm.models.glm5", None) + module = importlib.import_module("atom.plugin.rtpllm.models.glm5") + module = importlib.reload(module) + instance = _make_wrapper_instance(module.ATOMGlm5Moe) + instance.device = "cpu" + instance.weight = MagicMock() + + with _patch_optional_attr( + module, "apply_attention_mla_rtpllm_patch" 
+ ) as mla_patch, _patch_optional_attr( + module, "apply_deepseek_mla_rtpllm_patch" + ) as deepseek_patch: + result = instance._create_python_model() + + prepare_model.assert_called_once_with(config=instance, engine="rtpllm") + mla_patch.assert_not_called() + deepseek_patch.assert_not_called() + load_model_in_plugin_mode = fake_modules[ + "atom.model_loader.loader" + ].load_model_in_plugin_mode + load_model_in_plugin_mode.assert_called_once() + assert result is instance.py_model + + +def test_glm5_support_cuda_graph_honors_eager_env(): + fake_modules = _install_fake_rtp_modules() + + with patch.dict(sys.modules, fake_modules), patch.dict( + os.environ, + { + "RTP_LLM_EXTERNAL_MODEL_PACKAGES": "atom.plugin.rtpllm.models", + "ENABLE_CUDA_GRAPH": "0", + }, + ): + sys.modules.pop("atom.plugin.rtpllm.models.glm5", None) + module = importlib.import_module("atom.plugin.rtpllm.models.glm5") + module = importlib.reload(module) + instance = _make_wrapper_instance(module.ATOMGlm5Moe) + + assert instance.support_cuda_graph() is False + + +def test_glm5_runtime_forward_wraps_model_call_in_rtp_context(monkeypatch): + fake_modules = _install_fake_rtp_modules() + expected_input_ids = torch.tensor([10, 11], dtype=torch.int64) + position_ids = torch.tensor([5, 6], dtype=torch.int32) + hidden_states = torch.randn(2, 4) + events = [] + + class _FakeAtomModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(1)) + + def forward(self, *, input_ids, positions, intermediate_tensors, inputs_embeds): + events.append(("model", bool(_FakeRTPForwardContext.in_context))) + assert torch.equal(input_ids, expected_input_ids) + assert torch.equal(positions, position_ids.to(torch.long)) + assert positions.dtype == torch.long + assert intermediate_tensors is None + assert inputs_embeds is None + return hidden_states + + class _FakeBind: + def __enter__(self): + _FakeRTPForwardContext.in_context = True + events.append(("enter", None)) + + def 
__exit__(self, exc_type, exc, tb): + events.append(("exit", None)) + _FakeRTPForwardContext.in_context = False + + class _FakeRTPForwardContext: + in_context = False + + @staticmethod + def collect_layer_maps(model): + return ({}, {}, {}) + + @staticmethod + def bind(**kwargs): + assert torch.equal(kwargs["positions"], position_ids.to(torch.long)) + assert kwargs["positions"].dtype == torch.long + return _FakeBind() + + with patch.dict(sys.modules, fake_modules), patch.dict( + os.environ, + {"RTP_LLM_EXTERNAL_MODEL_PACKAGES": "atom.plugin.rtpllm.models"}, + ): + sys.modules.pop("atom.plugin.rtpllm.models.glm5", None) + module = importlib.import_module("atom.plugin.rtpllm.models.glm5") + module = importlib.reload(module) + monkeypatch.setattr(module, "RTPForwardContext", _FakeRTPForwardContext) + runtime = module._ATOMGlm5MoeRuntime( + model_config=SimpleNamespace(max_seq_len=16), + parallelism_config=SimpleNamespace(), + weights=MagicMock(), + max_generate_batch_size=2, + atom_model=_FakeAtomModel(), + ) + runtime.kv_cache = SimpleNamespace() + inputs = SimpleNamespace( + input_ids=expected_input_ids, + input_hiddens=None, + attention_inputs=SimpleNamespace(position_ids=position_ids), + ) + + output = runtime.forward(inputs) + + assert output.hidden_states is hidden_states + assert events == [("enter", None), ("model", True), ("exit", None)] + + +def test_glm5_runtime_prepare_fmha_impl_bypasses_native_mla_factory(monkeypatch): + fake_modules = _install_fake_rtp_modules() + + class _FakeRTPForwardContext: + @staticmethod + def collect_layer_maps(model): + return ({}, {}, {}) + + with patch.dict(sys.modules, fake_modules), patch.dict( + os.environ, + {"RTP_LLM_EXTERNAL_MODEL_PACKAGES": "atom.plugin.rtpllm.models"}, + ): + sys.modules.pop("atom.plugin.rtpllm.models.glm5", None) + module = importlib.import_module("atom.plugin.rtpllm.models.glm5") + module = importlib.reload(module) + monkeypatch.setattr(module, "RTPForwardContext", _FakeRTPForwardContext) + atom_model 
= torch.nn.Linear(1, 1) + runtime = module._ATOMGlm5MoeRuntime( + model_config=SimpleNamespace(max_seq_len=16), + parallelism_config=SimpleNamespace(), + weights=MagicMock(), + max_generate_batch_size=2, + atom_model=atom_model, + ) + inputs = SimpleNamespace(attention_inputs=SimpleNamespace()) + + attn_pyobj = runtime.prepare_fmha_impl(inputs, is_cuda_graph=False) + + assert attn_pyobj.fmha_params is None + assert attn_pyobj.is_cuda_graph is False + assert hasattr(attn_pyobj, "prepare_cuda_graph") + diff --git a/tests/plugin/test_rtpllm_model_wrapper.py b/tests/plugin/test_rtpllm_model_wrapper.py index 9ff4838d2a..cafbbbbd4c 100644 --- a/tests/plugin/test_rtpllm_model_wrapper.py +++ b/tests/plugin/test_rtpllm_model_wrapper.py @@ -3,7 +3,7 @@ import importlib import sys from types import ModuleType -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, call, patch def _package(name: str) -> ModuleType: @@ -26,12 +26,19 @@ class _FakeATOMQwen35Moe: pass fake_atom_qwen_mod.ATOMQwen35Moe = _FakeATOMQwen35Moe + fake_atom_glm_mod = ModuleType("atom.plugin.rtpllm.models.glm5") + + class _FakeATOMGlm5Moe: + pass + + fake_atom_glm_mod.ATOMGlm5Moe = _FakeATOMGlm5Moe fake_modules = { "rtp_llm": _package("rtp_llm"), "rtp_llm.models": _package("rtp_llm.models"), "rtp_llm.model_factory_register": fake_register_mod, "atom.plugin.rtpllm.models.qwen3_5": fake_atom_qwen_mod, + "atom.plugin.rtpllm.models.glm5": fake_atom_glm_mod, } with patch.dict(sys.modules, fake_modules): @@ -46,6 +53,10 @@ class _FakeATOMQwen35Moe: ] == "qwen35_moe" ) - register_model_mock.assert_called_with( - "atom_qwen35_moe", _FakeATOMQwen35Moe, [] + register_model_mock.assert_has_calls( + [ + call("atom_qwen35_moe", _FakeATOMQwen35Moe, []), + call("atom_glm5_moe", _FakeATOMGlm5Moe, []), + ], + any_order=False, ) diff --git a/tests/plugin/test_rtpllm_prepare_model.py b/tests/plugin/test_rtpllm_prepare_model.py index 0ff0114fa0..6dcc7c3460 100644 --- 
a/tests/plugin/test_rtpllm_prepare_model.py +++ b/tests/plugin/test_rtpllm_prepare_model.py @@ -65,3 +65,47 @@ def test_prepare_model_rtpllm_happy_path(): fake_quant_config.remap_layer_name.assert_called_once() fake_model_cls.assert_called_once_with(atom_config=fake_atom_config) assert result is fake_model + + +def test_prepare_model_rtpllm_glm5_reapplies_mla_attention_patch(): + fake_atom_config = _Obj( + hf_config=_Obj(architectures=["GlmMoeDsaForCausalLM"]), + plugin_config=_Obj(is_plugin_mode=True), + quant_config=_Obj( + exclude_layers=[], + remap_layer_name=MagicMock(), + ), + ) + fake_model = MagicMock(name="FakeGlm5") + fake_model_cls = MagicMock(return_value=fake_model) + + fake_register = MagicMock() + fake_register._ATOM_SUPPORTED_MODELS = {"GlmMoeDsaForCausalLM": fake_model_cls} + fake_register.register_ops_to_sglang = MagicMock() + fake_register.init_aiter_dist = MagicMock() + fake_register.set_attn_cls = MagicMock() + + fake_config_mod = MagicMock() + fake_config_mod.generate_atom_config_for_plugin_mode = MagicMock( + return_value=fake_atom_config + ) + + fake_rtpllm_attention_backend = MagicMock() + + with patch.dict( + sys.modules, + { + "atom.plugin.register": fake_register, + "atom.plugin.config": fake_config_mod, + "atom.plugin.rtpllm.attention_backend": fake_rtpllm_attention_backend, + }, + ): + result = plugin_prepare.prepare_model( + config=_Obj(model_config=_Obj()), engine="rtpllm" + ) + + fake_register.set_attn_cls.assert_called_once() + fake_rtpllm_attention_backend.apply_attention_mla_rtpllm_patch.assert_called_once() + fake_atom_config.quant_config.remap_layer_name.assert_called_once() + fake_model_cls.assert_called_once_with(atom_config=fake_atom_config) + assert result is fake_model From bf1c92fdee925f74d86aea560404699276b52fc5 Mon Sep 17 00:00:00 2001 From: Zhao An Date: Tue, 2 Jun 2026 07:03:52 +0000 Subject: [PATCH 02/20] feat: RTPLLM GLM5 enable cuda graph --- .../rtp_dense_mla_backend.py | 36 +- 
.../attention_backend/rtp_mla_attention.py | 8 + .../rtp_sparse_mla_backend.py | 689 ++++++++++++++---- atom/plugin/rtpllm/models/glm5.py | 275 ++++++- atom/plugin/rtpllm/utils/forward_context.py | 310 ++++++-- .../test_rtpllm_forward_context_semantics.py | 216 ++---- .../test_rtpllm_glm5_indexer_contract.py | 24 + .../test_rtpllm_glm5_mla_forward_contract.py | 36 +- ...est_rtpllm_glm5_sparse_backend_contract.py | 162 ++++ .../test_rtpllm_glm5_wrapper_lifecycle.py | 57 ++ 10 files changed, 1428 insertions(+), 385 deletions(-) diff --git a/atom/plugin/rtpllm/attention_backend/rtp_dense_mla_backend.py b/atom/plugin/rtpllm/attention_backend/rtp_dense_mla_backend.py index 8cd9b06f25..28f9903f2f 100644 --- a/atom/plugin/rtpllm/attention_backend/rtp_dense_mla_backend.py +++ b/atom/plugin/rtpllm/attention_backend/rtp_dense_mla_backend.py @@ -50,6 +50,19 @@ def __init__(self, *, mla_modules: Any) -> None: self.qk_rope_head_dim = getattr(mla_modules, "qk_rope_head_dim", None) self._projection_checked = False + def prepare_cuda_graph(self, attn_inputs) -> None: # noqa: ANN001 + del attn_inputs + + def prewarm_for_cuda_graph( + self, + *, + max_num_tokens: int, + max_seq_len: int, + query_dtype: torch.dtype, + device: torch.device, + ) -> None: + del max_num_tokens, max_seq_len, query_dtype, device + @staticmethod def _read_is_prefill(context: Any) -> bool: if context is None or not hasattr(context, "is_prefill"): @@ -62,19 +75,19 @@ def _read_is_prefill(context: Any) -> bool: def _get_metadata(num_tokens: int, device: torch.device) -> _DenseMlaMetadata: attn_metadata = None context = None - rtp_seq_size_per_block = 1 + rtp_seq_size_per_block = 0 try: from atom.utils.forward_context import get_forward_context forward_context = get_forward_context() attn_metadata = getattr(forward_context, "attn_metadata", None) context = getattr(forward_context, "context", None) + plugin_metadata = getattr(attn_metadata, "plugin_metadata", None) rtp_seq_size_per_block = int( - 
getattr(attn_metadata, "rtp_seq_size_per_block", 0) - or getattr(attn_metadata, "rtp_kernel_seq_size_per_block", 0) + getattr(plugin_metadata, "sparse_block_size", 0) + or getattr(attn_metadata, "rtp_seq_size_per_block", 0) or 0 ) - plugin_metadata = getattr(attn_metadata, "plugin_metadata", None) query_start_loc = getattr(plugin_metadata, "query_start_loc", None) if query_start_loc is None: query_start_loc = getattr(plugin_metadata, "rtp_cu_seqlens_q", None) @@ -378,17 +391,26 @@ def _gather_latent_history( ) if flat_cache is None: return None + block_size = int(metadata.block_size) + kv_cache_base = getattr(layer_cache, "kv_cache_base", None) + if ( + block_size <= 1 + and isinstance(kv_cache_base, torch.Tensor) + and kv_cache_base.dim() == 3 + and int(kv_cache_base.shape[-1]) == kv_dim + ): + block_size = int(kv_cache_base.shape[1]) seq_len = int(metadata.seq_lens[batch_idx].item()) if seq_len <= 0: return None block_row = metadata.block_table[batch_idx].long() positions = torch.arange(seq_len, dtype=torch.long, device=flat_cache.device) - block_cols = torch.div(positions, metadata.block_size, rounding_mode="floor") + block_cols = torch.div(positions, block_size, rounding_mode="floor") block_col_max = int(block_cols.max().item()) if block_col_max >= int(block_row.numel()): return None - offsets = positions.remainder(metadata.block_size) - slots = block_row[block_cols] * metadata.block_size + offsets + offsets = positions.remainder(block_size) + slots = block_row[block_cols] * block_size + offsets flat_size = int(flat_cache.shape[0]) if bool(((slots < 0) | (slots >= flat_size)).any().item()): bad_slots = slots[(slots < 0) | (slots >= flat_size)] diff --git a/atom/plugin/rtpllm/attention_backend/rtp_mla_attention.py b/atom/plugin/rtpllm/attention_backend/rtp_mla_attention.py index 3944a92c10..8b86e5d6cb 100644 --- a/atom/plugin/rtpllm/attention_backend/rtp_mla_attention.py +++ b/atom/plugin/rtpllm/attention_backend/rtp_mla_attention.py @@ -55,6 +55,13 @@ def 
_should_emit_topk_indices(attn) -> bool: return True +def _use_rtp_sparse_attn_indexer(indexer: object | None) -> None: + if indexer is None or not hasattr(indexer, "sparse_attn_indexer_impl"): + return + __import__("atom.plugin.rtpllm.attention_backend.rtp_sparse_mla_backend") + indexer.sparse_attn_indexer_impl = torch.ops.aiter.rtp_sparse_attn_indexer + + class RTPMLAAttention: """Dense RTP MLA adapter for the native GLM5 MLA call contract.""" @@ -69,6 +76,7 @@ def __init__(self, *args, **kwargs) -> None: self.o_proj = getattr(mla_modules, "o_proj", None) self.kv_b_proj = getattr(mla_modules, "kv_b_proj", None) self.indexer = getattr(mla_modules, "indexer", None) + _use_rtp_sparse_attn_indexer(self.indexer) self.qk_head_dim = getattr(mla_modules, "qk_head_dim", None) self.v_head_dim = getattr(mla_modules, "v_head_dim", None) self.q_lora_rank = getattr(mla_modules, "q_lora_rank", None) diff --git a/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py b/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py index b82270cd83..27721be6dd 100644 --- a/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py +++ b/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py @@ -3,12 +3,13 @@ from __future__ import annotations import inspect -import os from dataclasses import dataclass from typing import Any, Optional import torch +from atom.utils.custom_register import direct_register_custom_op + class _SparseUnavailable(RuntimeError): pass @@ -95,6 +96,8 @@ def __init__( ) self._absorbed_weights: _AbsorbedWeights | None = None self._cache_write_scale: dict[torch.device, torch.Tensor] = {} + self._cg_sparse_bufs: dict[str, torch.Tensor] | None = None + self._cg_workspace_signature: tuple[Any, ...] 
| None = None @staticmethod def _unwrap_linear_output(value: Any) -> torch.Tensor: @@ -105,10 +108,26 @@ def _unwrap_linear_output(value: Any) -> torch.Tensor: return value def _infer_num_heads(self, q: torch.Tensor) -> int: - if self.num_heads > 0: - return self.num_heads - self.num_heads = int(q.shape[1]) - return self.num_heads + num_heads = int(q.shape[1]) + if self.num_heads != num_heads: + self.num_heads = num_heads + return num_heads + + def _infer_num_heads_from_weight(self, fallback: int) -> int: + try: + weight = self._read_kv_b_proj_weight() + except Exception: + return int(fallback) + per_head_dim = int(self.qk_nope_head_dim + self.v_head_dim) + if per_head_dim <= 0 or weight.ndim != 2: + return int(fallback) + for dim in weight.shape: + dim_i = int(dim) + if dim_i > 0 and dim_i % per_head_dim == 0: + candidate = dim_i // per_head_dim + if candidate > 0: + return max(int(fallback), int(candidate)) + return int(fallback) def _read_kv_b_proj_weight(self) -> torch.Tensor: if self.kv_b_proj is None: @@ -188,10 +207,44 @@ def _apply_rope( f"(positions={None if positions is None else int(positions.numel())}, " f"tokens={int(q.shape[0])})." ) - q_rope = q.clone() - k_pe_rope = k_pe.clone() + in_capture = torch.cuda.is_current_stream_capturing() + if in_capture: + if self._cg_sparse_bufs is None: + raise _SparseUnavailable("GLM5 RTP sparse MLA capture requires RoPE buffers.") + if positions.device != q.device or positions.dtype != torch.long: + raise _SparseUnavailable( + "GLM5 RTP sparse MLA capture requires int64 positions on device." + ) + if not positions.is_contiguous(): + raise _SparseUnavailable( + "GLM5 RTP sparse MLA capture requires contiguous positions." 
+ ) + q_rope = self._cg_sparse_bufs["q_rope"][: q.shape[0], : q.shape[1], : q.shape[2]] + q_rope.copy_(q) + if k_pe.dim() == 2: + k_pe_rope = self._cg_sparse_bufs["k_pe_rope_2d"][ + : k_pe.shape[0], : k_pe.shape[1] + ] + elif k_pe.dim() == 3 and int(k_pe.shape[1]) == 1: + k_pe_rope = self._cg_sparse_bufs["k_pe_rope_3d"][ + : k_pe.shape[0], : k_pe.shape[1], : k_pe.shape[2] + ] + elif k_pe.dim() == 3: + k_pe_rope = self._cg_sparse_bufs["k_pe_rope_heads"][ + : k_pe.shape[0], : k_pe.shape[1], : k_pe.shape[2] + ] + else: + raise _SparseUnavailable( + f"GLM5 RTP sparse MLA capture got invalid k_pe ndim={k_pe.dim()}." + ) + k_pe_rope.copy_(k_pe) + rope_positions = positions.view(-1) + else: + q_rope = q.clone() + k_pe_rope = k_pe.clone() + rope_positions = positions.reshape(-1).to(device=q.device, dtype=torch.long) rotated_q_pe, rotated_k_pe = self.rotary_emb( - positions.to(device=q.device, dtype=torch.long), + rope_positions, q_rope[..., -rope_dim:], k_pe_rope, ) @@ -212,9 +265,6 @@ def _cache_dtype_name(self, kv_cache_base: torch.Tensor) -> str: } if kv_cache_base.dtype not in fp8_dtypes: return "auto" - explicit = os.getenv("ATOM_RTP_MLA_FP8_CACHE_DTYPE", "").strip() - if explicit: - return explicit return "fp8_model1_mla" if self.kv_lora_rank == 448 else "fp8_ds_mla" def _write_current_to_cache( @@ -243,12 +293,23 @@ def _write_current_to_cache( if scale is None: scale = torch.tensor(1.0, dtype=torch.float32, device=compressed_kv.device) self._cache_write_scale[compressed_kv.device] = scale + in_capture = torch.cuda.is_current_stream_capturing() + if in_capture: + if slot_mapping.device != compressed_kv.device or slot_mapping.dtype != torch.int64: + raise _SparseUnavailable( + "GLM5 RTP sparse MLA capture requires int64 slot_mapping on device." 
+ ) + slot_mapping_for_cache = slot_mapping + else: + slot_mapping_for_cache = slot_mapping.to( + device=compressed_kv.device, dtype=torch.int64 + ) try: concat_and_cache_mla( compressed_kv, k_pe, kv_cache_base, - slot_mapping.to(device=compressed_kv.device, dtype=torch.int64), + slot_mapping_for_cache, kv_cache_dtype=self._cache_dtype_name(kv_cache_base), scale=scale, ) @@ -390,6 +451,136 @@ def _aiter_dtype_for_tensor(tensor: torch.Tensor) -> Any: return dtypes.d_dtypes["fp16"] return dtypes.d_dtypes["bf16"] + @staticmethod + def _aiter_dtype_for_torch_dtype(dtype: torch.dtype, *, assume_fp8: bool = False) -> Any: + try: + from aiter import dtypes + except Exception as exc: + raise _SparseUnavailable(f"aiter dtypes unavailable: {exc}") from exc + if assume_fp8: + return dtypes.fp8 + if dtype == torch.float16: + return dtypes.d_dtypes["fp16"] + return dtypes.d_dtypes["bf16"] + + def _resolve_topk_for_prewarm(self) -> int: + for obj, attr in ( + (getattr(self.mla_modules, "indexer", None), "index_topk"), + (getattr(self.mla_modules, "indexer", None), "topk_tokens"), + (self.mla_modules, "index_topk"), + (getattr(self.mla_modules, "config", None), "index_topk"), + ): + value = getattr(obj, attr, None) if obj is not None else None + if value is not None: + return int(value) + return 2048 + + def prewarm_for_cuda_graph( + self, + *, + max_num_tokens: int, + max_seq_len: int, + query_dtype: torch.dtype, + device: torch.device, + ) -> None: + del max_seq_len + try: + from aiter import get_mla_metadata_info_v1 + except Exception as exc: + raise _SparseUnavailable(f"aiter metadata prewarm unavailable: {exc}") from exc + + max_tokens = int(max_num_tokens) + if max_tokens <= 0: + return + num_heads = int(self.num_heads or getattr(self.mla_modules, "num_local_heads", 0) or 0) + if num_heads <= 0: + # Lazily inferred in eager path; graph capture needs a stable budget. 
+ num_heads = int(getattr(self.mla_modules, "num_heads", 0) or 1) + num_heads = self._infer_num_heads_from_weight(num_heads) + self.num_heads = num_heads + padded_num_heads = max(num_heads, 16) + if padded_num_heads % num_heads != 0: + padded_num_heads = ((padded_num_heads + num_heads - 1) // num_heads) * num_heads + topk = self._resolve_topk_for_prewarm() + latent_dim = self.kv_lora_rank + self.qk_rope_head_dim + q_dtype = self._aiter_dtype_for_torch_dtype(query_dtype) + kv_dtype = self._aiter_dtype_for_torch_dtype(query_dtype, assume_fp8=True) + ( + (work_meta_data_size, work_meta_data_type), + (work_indptr_size, work_indptr_type), + (work_info_set_size, work_info_set_type), + (reduce_indptr_size, reduce_indptr_type), + (reduce_final_map_size, reduce_final_map_type), + (reduce_partial_map_size, reduce_partial_map_type), + ) = get_mla_metadata_info_v1( + max(max_tokens, 1), + 1, + padded_num_heads, + q_dtype, + kv_dtype, + is_sparse=True, + fast_mode=True, + ) + self._cg_sparse_bufs = { + "qo_indptr": torch.arange(max_tokens + 1, device=device, dtype=torch.int32), + "sparse_seqlen": torch.empty(max_tokens, device=device, dtype=torch.int32), + "paged_kv_indptr": torch.empty(max_tokens + 1, device=device, dtype=torch.int32), + "paged_kv_last_page_len": torch.ones(max_tokens, device=device, dtype=torch.int32), + "paged_kv_indices": torch.empty(max_tokens * topk, device=device, dtype=torch.int32), + "q_rope": torch.empty( + max_tokens, + num_heads, + self.qk_nope_head_dim + self.qk_rope_head_dim, + device=device, + dtype=query_dtype, + ), + "k_pe_rope_2d": torch.empty( + max_tokens, self.qk_rope_head_dim, device=device, dtype=query_dtype + ), + "k_pe_rope_3d": torch.empty( + max_tokens, 1, self.qk_rope_head_dim, device=device, dtype=query_dtype + ), + "k_pe_rope_heads": torch.empty( + max_tokens, num_heads, self.qk_rope_head_dim, device=device, dtype=query_dtype + ), + "q_latent_nope_t": torch.empty( + num_heads, max_tokens, self.kv_lora_rank, device=device, 
dtype=query_dtype + ), + "q_latent": torch.empty( + max_tokens, num_heads, latent_dim, device=device, dtype=query_dtype + ), + "q_for_kernel": torch.empty( + max_tokens, padded_num_heads, latent_dim, device=device, dtype=query_dtype + ), + "latent_output": torch.empty( + max_tokens, padded_num_heads, self.kv_lora_rank, device=device, dtype=query_dtype + ), + "final_output_t": torch.empty( + num_heads, max_tokens, self.v_head_dim, device=device, dtype=query_dtype + ), + "work_meta_data": torch.empty(work_meta_data_size, dtype=work_meta_data_type, device=device), + "work_indptr": torch.empty(work_indptr_size, dtype=work_indptr_type, device=device), + "work_info_set": torch.empty(work_info_set_size, dtype=work_info_set_type, device=device), + "reduce_indptr": torch.empty(reduce_indptr_size, dtype=reduce_indptr_type, device=device), + "reduce_final_map": torch.empty( + reduce_final_map_size, dtype=reduce_final_map_type, device=device + ), + "reduce_partial_map": torch.empty( + reduce_partial_map_size, dtype=reduce_partial_map_type, device=device + ), + } + self._cg_sparse_bufs["paged_kv_indptr"].zero_() + self._cache_write_scale[device] = torch.tensor( + 1.0, dtype=torch.float32, device=device + ) + self._cg_workspace_signature = ( + max_tokens, + padded_num_heads, + topk, + query_dtype, + device, + ) + def _build_atom_sparse_metadata( self, *, @@ -399,8 +590,6 @@ def _build_atom_sparse_metadata( attn_metadata: Any, block_size: int, ) -> _AtomSparseMetadata: - if torch.cuda.is_current_stream_capturing(): - raise _SparseUnavailable("ATOM sparse MLA metadata is not graph-capture safe yet.") try: from aiter import get_mla_metadata_info_v1, get_mla_metadata_v1 from atom.plugin.attention_mla_sparse import ( @@ -418,13 +607,22 @@ def _build_atom_sparse_metadata( num_heads = int(q_latent.shape[1]) topk = int(topk_indices.shape[1]) device = q_latent.device + in_capture = torch.cuda.is_current_stream_capturing() + cg_bufs = getattr(plugin_metadata, "cg_bufs", None) + 
sparse_bufs = self._cg_sparse_bufs query_start_loc = getattr(plugin_metadata, "query_start_loc", None) if query_start_loc is None: query_start_loc = getattr(plugin_metadata, "rtp_cu_seqlens_q", None) if not isinstance(query_start_loc, torch.Tensor) or int(query_start_loc.numel()) < 2: raise _SparseUnavailable("GLM5 RTP sparse MLA requires query_start_loc.") - query_start_loc = query_start_loc.to(device=device, dtype=torch.int32).contiguous() + if in_capture: + if query_start_loc.device != device or query_start_loc.dtype != torch.int32: + raise _SparseUnavailable( + "GLM5 RTP sparse MLA capture requires int32 query_start_loc on device." + ) + else: + query_start_loc = query_start_loc.to(device=device, dtype=torch.int32).contiguous() seq_lens = getattr(plugin_metadata, "seq_lens", None) if seq_lens is None: @@ -433,38 +631,87 @@ def _build_atom_sparse_metadata( query_start_loc.numel() ): raise _SparseUnavailable("GLM5 RTP sparse MLA requires seq_lens per request.") - seq_lens = seq_lens.to(device=device, dtype=torch.int32).contiguous() - - req_id = self._build_req_id_per_token(attn_metadata, num_tokens, device).to( - dtype=torch.int32 - ) - block_table = self._block_table(attn_metadata, device).to(dtype=torch.int32) - topk_indices_i32 = topk_indices.to(device=device, dtype=torch.int32).contiguous() - query_lens = (query_start_loc[1:] - query_start_loc[:-1]).contiguous() - - if device.type == "cpu": - sparse_seqlen = self._generate_sparse_seqlen_torch( - query_lens=query_lens, - seq_lens=seq_lens, - query_start_loc=query_start_loc, - topk=topk, - num_tokens=num_tokens, - ) + if in_capture: + if seq_lens.device != device or seq_lens.dtype != torch.int32: + raise _SparseUnavailable( + "GLM5 RTP sparse MLA capture requires int32 seq_lens on device." 
+ ) + else: + seq_lens = seq_lens.to(device=device, dtype=torch.int32).contiguous() + + if in_capture: + if not isinstance(cg_bufs, dict) or sparse_bufs is None: + raise _SparseUnavailable("GLM5 RTP sparse MLA capture requires prewarmed buffers.") + req_id = cg_bufs.get("seq_id_i32", None) + if not isinstance(req_id, torch.Tensor): + raise _SparseUnavailable( + "GLM5 RTP sparse MLA capture requires prewarmed seq_id_i32." + ) + req_id = req_id[:num_tokens] + block_table = getattr(plugin_metadata, "block_table", None) + if not isinstance(block_table, torch.Tensor): + raise _SparseUnavailable("GLM5 RTP sparse MLA capture requires block_table.") + if block_table.device != device or block_table.dtype != torch.int32: + raise _SparseUnavailable( + "GLM5 RTP sparse MLA capture requires int32 block_table on device." + ) + topk_indices_i32 = topk_indices + if topk_indices_i32.device != device or topk_indices_i32.dtype != torch.int32: + raise _SparseUnavailable( + "GLM5 RTP sparse MLA capture requires int32 topk_indices on device." + ) + if not topk_indices_i32.is_contiguous(): + raise _SparseUnavailable( + "GLM5 RTP sparse MLA capture requires contiguous topk_indices." 
+ ) + sparse_seqlen = sparse_bufs["sparse_seqlen"][:num_tokens] + torch.clamp(seq_lens[:num_tokens], min=0, max=topk, out=sparse_seqlen) + max_query_len_for_sparse = 1 else: - sparse_seqlen = generate_sparse_seqlen_triton( - query_lens, - seq_lens, - query_start_loc, - topk, - num_tokens, - int(torch.max(query_lens).detach().cpu().item()) if num_tokens else 1, + req_id = self._build_req_id_per_token(attn_metadata, num_tokens, device).to( + dtype=torch.int32 + ) + block_table = self._block_table(attn_metadata, device).to(dtype=torch.int32) + topk_indices_i32 = topk_indices.to(device=device, dtype=torch.int32).contiguous() + query_lens = (query_start_loc[1:] - query_start_loc[:-1]).contiguous() + max_query_len_for_sparse = ( + int(torch.max(query_lens).detach().cpu().item()) if num_tokens else 1 ) - qo_indptr = torch.arange(num_tokens + 1, device=device, dtype=torch.int32) - paged_kv_indptr = torch.zeros((num_tokens + 1,), device=device, dtype=torch.int32) + if device.type == "cpu": + sparse_seqlen = self._generate_sparse_seqlen_torch( + query_lens=query_lens, + seq_lens=seq_lens, + query_start_loc=query_start_loc, + topk=topk, + num_tokens=num_tokens, + ) + else: + sparse_seqlen = generate_sparse_seqlen_triton( + query_lens, + seq_lens, + query_start_loc, + topk, + num_tokens, + max_query_len_for_sparse, + ) + + if in_capture: + qo_indptr = sparse_bufs["qo_indptr"][: num_tokens + 1] + paged_kv_indptr = sparse_bufs["paged_kv_indptr"][: num_tokens + 1] + paged_kv_indptr[0].zero_() + paged_kv_last_page_len = sparse_bufs["paged_kv_last_page_len"][:num_tokens] + if int(sparse_bufs["paged_kv_indices"].numel()) < num_tokens * topk: + raise _SparseUnavailable( + "GLM5 RTP sparse MLA capture paged_kv_indices buffer is too small." 
+ ) + paged_kv_indices = sparse_bufs["paged_kv_indices"][: num_tokens * topk] + else: + qo_indptr = torch.arange(num_tokens + 1, device=device, dtype=torch.int32) + paged_kv_indptr = torch.zeros((num_tokens + 1,), device=device, dtype=torch.int32) + paged_kv_last_page_len = torch.ones((num_tokens,), device=device, dtype=torch.int32) + paged_kv_indices = torch.zeros((num_tokens * topk,), device=device, dtype=torch.int32) torch.cumsum(sparse_seqlen, dim=0, out=paged_kv_indptr[1:]) - paged_kv_last_page_len = torch.ones((num_tokens,), device=device, dtype=torch.int32) - paged_kv_indices = torch.zeros((num_tokens * topk,), device=device, dtype=torch.int32) triton_convert_req_index_to_global_index( req_id, @@ -482,32 +729,40 @@ def _build_atom_sparse_metadata( head_repeat_factor = padded_num_heads // num_heads q_dtype = self._aiter_dtype_for_tensor(q_latent) kv_dtype = self._aiter_dtype_for_tensor(kv_cache_base) - ( - (work_meta_data_size, work_meta_data_type), - (work_indptr_size, work_indptr_type), - (work_info_set_size, work_info_set_type), - (reduce_indptr_size, reduce_indptr_type), - (reduce_final_map_size, reduce_final_map_type), - (reduce_partial_map_size, reduce_partial_map_type), - ) = get_mla_metadata_info_v1( - max(num_tokens, 1), - 1, - padded_num_heads, - q_dtype, - kv_dtype, - is_sparse=True, - fast_mode=True, - ) - work_meta_data = torch.empty(work_meta_data_size, dtype=work_meta_data_type, device=device) - work_indptr = torch.empty(work_indptr_size, dtype=work_indptr_type, device=device) - work_info_set = torch.empty(work_info_set_size, dtype=work_info_set_type, device=device) - reduce_indptr = torch.empty(reduce_indptr_size, dtype=reduce_indptr_type, device=device) - reduce_final_map = torch.empty( - reduce_final_map_size, dtype=reduce_final_map_type, device=device - ) - reduce_partial_map = torch.empty( - reduce_partial_map_size, dtype=reduce_partial_map_type, device=device - ) + if in_capture: + work_meta_data = sparse_bufs["work_meta_data"] + 
work_indptr = sparse_bufs["work_indptr"] + work_info_set = sparse_bufs["work_info_set"] + reduce_indptr = sparse_bufs["reduce_indptr"] + reduce_final_map = sparse_bufs["reduce_final_map"] + reduce_partial_map = sparse_bufs["reduce_partial_map"] + else: + ( + (work_meta_data_size, work_meta_data_type), + (work_indptr_size, work_indptr_type), + (work_info_set_size, work_info_set_type), + (reduce_indptr_size, reduce_indptr_type), + (reduce_final_map_size, reduce_final_map_type), + (reduce_partial_map_size, reduce_partial_map_type), + ) = get_mla_metadata_info_v1( + max(num_tokens, 1), + 1, + padded_num_heads, + q_dtype, + kv_dtype, + is_sparse=True, + fast_mode=True, + ) + work_meta_data = torch.empty(work_meta_data_size, dtype=work_meta_data_type, device=device) + work_indptr = torch.empty(work_indptr_size, dtype=work_indptr_type, device=device) + work_info_set = torch.empty(work_info_set_size, dtype=work_info_set_type, device=device) + reduce_indptr = torch.empty(reduce_indptr_size, dtype=reduce_indptr_type, device=device) + reduce_final_map = torch.empty( + reduce_final_map_size, dtype=reduce_final_map_type, device=device + ) + reduce_partial_map = torch.empty( + reduce_partial_map_size, dtype=reduce_partial_map_type, device=device + ) get_mla_metadata_v1( qo_indptr, paged_kv_indptr, @@ -523,8 +778,8 @@ def _build_atom_sparse_metadata( reduce_partial_map, page_size=1, kv_granularity=16, - max_seqlen_qo=1, - uni_seqlen_qo=1, + max_seqlen_qo=max_query_len_for_sparse, + uni_seqlen_qo=max_query_len_for_sparse, fast_mode=True, dtype_q=q_dtype, dtype_kv=kv_dtype, @@ -553,9 +808,24 @@ def _run_sparse_decode( attn_metadata: Any, block_size: int, ) -> torch.Tensor: + if torch.cuda.is_current_stream_capturing(): + return self._run_aiter_sparse_decode( + q_latent=q_latent, + kv_cache_base=kv_cache_base, + topk_indices=topk_indices, + attn_metadata=attn_metadata, + block_size=block_size, + ) + plugin_metadata = getattr(attn_metadata, "plugin_metadata", None) + is_prefill = 
bool(getattr(plugin_metadata, "num_prefills", 0) or 0) try: from flash_mla import flash_mla_sparse_fwd except Exception as exc: + if is_prefill: + raise _SparseUnavailable( + "GLM5 RTP sparse MLA prefill requires flash_mla_sparse_fwd; " + "refusing to run prefill through the decode kernel." + ) from exc return self._run_aiter_sparse_decode( q_latent=q_latent, kv_cache_base=kv_cache_base, @@ -605,17 +875,32 @@ def _run_aiter_sparse_decode( attn_metadata=attn_metadata, block_size=block_size, ) + in_capture = torch.cuda.is_current_stream_capturing() if sparse_meta.head_repeat_factor > 1: - q_for_kernel = q_latent.repeat_interleave( - sparse_meta.head_repeat_factor, dim=1 - ) + if in_capture and self._cg_sparse_bufs is not None: + q_for_kernel = self._cg_sparse_bufs["q_for_kernel"][ + :num_tokens, : sparse_meta.padded_num_heads, : + ] + for repeat_idx in range(sparse_meta.head_repeat_factor): + q_for_kernel[ + :, repeat_idx :: sparse_meta.head_repeat_factor, : + ].copy_(q_latent) + else: + q_for_kernel = q_latent.repeat_interleave( + sparse_meta.head_repeat_factor, dim=1 + ) else: q_for_kernel = q_latent - output = torch.empty( - (num_tokens, sparse_meta.padded_num_heads, self.kv_lora_rank), - dtype=q_for_kernel.dtype, - device=q_latent.device, - ) + if in_capture and self._cg_sparse_bufs is not None: + output = self._cg_sparse_bufs["latent_output"][ + :num_tokens, : sparse_meta.padded_num_heads, : + ] + else: + output = torch.empty( + (num_tokens, sparse_meta.padded_num_heads, self.kv_lora_rank), + dtype=q_for_kernel.dtype, + device=q_latent.device, + ) try: kv_buffer = kv_cache_base.reshape(-1, 1, 1, latent_dim) mla_decode_fwd( @@ -639,7 +924,9 @@ def _run_aiter_sparse_decode( except Exception as exc: raise _SparseUnavailable(f"mla_decode_fwd failed: {exc}") from exc if sparse_meta.head_repeat_factor > 1: - output = output[:, :: sparse_meta.head_repeat_factor, :].contiguous() + output = output[:, :: sparse_meta.head_repeat_factor, :] + if not in_capture: + output = 
output.contiguous() return output def forward( @@ -658,7 +945,7 @@ def forward( if attn_metadata is None: raise _SparseUnavailable("GLM5 RTP sparse MLA requires attn_metadata.") if getattr(getattr(attn_metadata, "plugin_metadata", None), "is_dummy_warmup", False): - raise _SparseUnavailable("GLM5 RTP sparse MLA skips dummy warmup.") + return q.new_zeros((q.shape[0], q.shape[1], self.v_head_dim)) q_rope, k_pe_rope = self._apply_rope(q, k_pe, positions) kv_cache_base = self._write_current_to_cache( compressed_kv=compressed_kv, @@ -669,17 +956,36 @@ def forward( absorbed = self._get_absorbed_weights(q_rope) q_nope = q_rope[..., : self.qk_nope_head_dim] - q_latent_nope = torch.bmm( - q_nope.transpose(0, 1).to(dtype=absorbed.w_kc.dtype), - absorbed.w_kc, - ).transpose(0, 1) - q_latent = torch.empty( - q.shape[0], - q.shape[1], - self.kv_lora_rank + self.qk_rope_head_dim, - dtype=q_latent_nope.dtype, - device=q.device, - ) + in_capture = torch.cuda.is_current_stream_capturing() + if in_capture: + if self._cg_sparse_bufs is None: + raise _SparseUnavailable("GLM5 RTP sparse MLA capture requires q buffers.") + if q_nope.dtype != absorbed.w_kc.dtype: + raise _SparseUnavailable( + "GLM5 RTP sparse MLA capture requires q_nope dtype to match absorbed weights." 
+ ) + q_latent_nope_t = self._cg_sparse_bufs["q_latent_nope_t"][ + : q.shape[1], : q.shape[0], : + ] + torch.bmm(q_nope.transpose(0, 1), absorbed.w_kc, out=q_latent_nope_t) + q_latent_nope = q_latent_nope_t.transpose(0, 1) + q_latent = self._cg_sparse_bufs["q_latent"][ + : q.shape[0], + : q.shape[1], + : self.kv_lora_rank + self.qk_rope_head_dim, + ] + else: + q_latent_nope = torch.bmm( + q_nope.transpose(0, 1).to(dtype=absorbed.w_kc.dtype), + absorbed.w_kc, + ).transpose(0, 1) + q_latent = torch.empty( + q.shape[0], + q.shape[1], + self.kv_lora_rank + self.qk_rope_head_dim, + dtype=q_latent_nope.dtype, + device=q.device, + ) q_latent[..., : self.kv_lora_rank] = q_latent_nope if self.qk_rope_head_dim > 0: q_latent[..., self.kv_lora_rank :] = q_rope[ @@ -699,6 +1005,21 @@ def forward( attn_metadata=attn_metadata, block_size=block_size, ) + if in_capture: + if latent_output.dtype != absorbed.w_vc.dtype: + raise _SparseUnavailable( + "GLM5 RTP sparse MLA capture requires latent output dtype to match absorbed weights." + ) + output_t = self._cg_sparse_bufs["final_output_t"][ + : q.shape[1], : q.shape[0], : + ] + torch.bmm(latent_output.transpose(0, 1), absorbed.w_vc, out=output_t) + output = output_t.transpose(0, 1) + if output.dtype != q.dtype: + raise _SparseUnavailable( + "GLM5 RTP sparse MLA capture requires final output dtype to match q." 
+ ) + return output output = torch.bmm( latent_output.transpose(0, 1).to(dtype=absorbed.w_vc.dtype), absorbed.w_vc, @@ -751,6 +1072,34 @@ def __init__( self.sparse_impl = _ContractSparseMlaImpl(self.v_head_dim) self._default_mock = True + def prepare_cuda_graph(self, attn_inputs) -> None: # noqa: ANN001 + del attn_inputs + + def prewarm_for_cuda_graph( + self, + *, + max_num_tokens: int, + max_seq_len: int, + query_dtype: torch.dtype, + device: torch.device, + ) -> None: + dense_prewarm = getattr(self.dense_backend, "prewarm_for_cuda_graph", None) + if callable(dense_prewarm): + dense_prewarm( + max_num_tokens=max_num_tokens, + max_seq_len=max_seq_len, + query_dtype=query_dtype, + device=device, + ) + sparse_prewarm = getattr(self.sparse_impl, "prewarm_for_cuda_graph", None) + if callable(sparse_prewarm): + sparse_prewarm( + max_num_tokens=max_num_tokens, + max_seq_len=max_seq_len, + query_dtype=query_dtype, + device=device, + ) + @staticmethod def _get_attn_metadata() -> object: try: @@ -777,24 +1126,6 @@ def _validate_topk_indices(q: torch.Tensor, topk_indices: torch.Tensor) -> None: f"got {topk_indices.shape[0]} and {q.shape[0]}" ) - @staticmethod - def _enable_sparse_mock() -> bool: - return os.getenv("ATOM_RTP_ENABLE_SPARSE_MLA_MOCK", "0").strip().lower() in { - "1", - "true", - "yes", - "on", - } - - @staticmethod - def _strict_sparse() -> bool: - return os.getenv("ATOM_RTP_SPARSE_MLA_STRICT", "0").strip().lower() in { - "1", - "true", - "yes", - "on", - } - @staticmethod def _impl_accepts_positions(impl: object) -> bool: try: @@ -849,23 +1180,24 @@ def forward( topk_indices: Optional[torch.Tensor] = None, positions: Optional[torch.Tensor] = None, ) -> torch.Tensor: + attn_metadata = self._get_attn_metadata() + if getattr(getattr(attn_metadata, "plugin_metadata", None), "is_dummy_warmup", False): + return q.new_zeros((q.shape[0], q.shape[1], self.v_head_dim)) + if topk_indices is None: return self._dense_forward( q, compressed_kv, k_pe, kv_cache, layer_id, 
None, positions ) self._validate_topk_indices(q, topk_indices) - if ( - (self._default_mock and not self._enable_sparse_mock()) - or not callable(getattr(self.sparse_impl, "forward", None)) - ): - return self._dense_forward( - q, compressed_kv, k_pe, kv_cache, layer_id, topk_indices, positions + if self._default_mock or not callable(getattr(self.sparse_impl, "forward", None)): + raise _SparseUnavailable( + "GLM5 RTP sparse MLA is unavailable; refusing dense fallback." ) kwargs = { "topk_indices": topk_indices, - "attn_metadata": self._get_attn_metadata(), + "attn_metadata": attn_metadata, } if self._impl_accepts_positions(self.sparse_impl): kwargs["positions"] = positions @@ -879,8 +1211,121 @@ def forward( **kwargs, ) except _SparseUnavailable: - if self._strict_sparse(): - raise - return self._dense_forward( - q, compressed_kv, k_pe, kv_cache, layer_id, topk_indices, positions - ) + plugin_metadata = getattr(attn_metadata, "plugin_metadata", None) + if bool(getattr(plugin_metadata, "num_prefills", 0) or 0): + return self._dense_forward( + q, compressed_kv, k_pe, kv_cache, layer_id, topk_indices, positions + ) + raise + + +def rtp_sparse_attn_indexer( + hidden_states: torch.Tensor, + k_cache_prefix: str, + kv_cache: torch.Tensor, + q_input: torch.Tensor, + k: torch.Tensor, + weights: torch.Tensor, + quant_block_size: int, + scale_fmt: Optional[str], + topk_tokens: int, + head_dim: int, + max_model_len: int, + total_seq_lens: int, + topk_indices_buffer: torch.Tensor, + k_norm_weight: torch.Tensor, + k_norm_bias: torch.Tensor, + k_norm_eps: float, + positions: torch.Tensor, + cos_cache: torch.Tensor, + sin_cache: torch.Tensor, + weights_scale: float, + is_neox_style: bool, + use_qk_rope_cache_fusion: bool, +) -> torch.Tensor: + from atom.models.deepseek_v2 import sparse_attn_indexer + + return sparse_attn_indexer( + hidden_states, + k_cache_prefix, + kv_cache, + q_input, + k, + weights, + quant_block_size, + scale_fmt, + topk_tokens, + head_dim, + max_model_len, + 
total_seq_lens, + topk_indices_buffer, + k_norm_weight, + k_norm_bias, + k_norm_eps, + positions, + cos_cache, + sin_cache, + weights_scale, + is_neox_style, + use_qk_rope_cache_fusion, + ) + + +def rtp_sparse_attn_indexer_fake( + hidden_states: torch.Tensor, + k_cache_prefix: str, + kv_cache: torch.Tensor, + q_input: torch.Tensor, + k: torch.Tensor, + weights: torch.Tensor, + quant_block_size: int, + scale_fmt: Optional[str], + topk_tokens: int, + head_dim: int, + max_model_len: int, + total_seq_lens: int, + topk_indices_buffer: torch.Tensor, + k_norm_weight: torch.Tensor, + k_norm_bias: torch.Tensor, + k_norm_eps: float, + positions: torch.Tensor, + cos_cache: torch.Tensor, + sin_cache: torch.Tensor, + weights_scale: float, + is_neox_style: bool, + use_qk_rope_cache_fusion: bool, +) -> torch.Tensor: + from atom.models.deepseek_v2 import sparse_attn_indexer_fake + + return sparse_attn_indexer_fake( + hidden_states, + k_cache_prefix, + kv_cache, + q_input, + k, + weights, + quant_block_size, + scale_fmt, + topk_tokens, + head_dim, + max_model_len, + total_seq_lens, + topk_indices_buffer, + k_norm_weight, + k_norm_bias, + k_norm_eps, + positions, + cos_cache, + sin_cache, + weights_scale, + is_neox_style, + use_qk_rope_cache_fusion, + ) + + +direct_register_custom_op( + op_name="rtp_sparse_attn_indexer", + op_func=rtp_sparse_attn_indexer, + mutates_args=["topk_indices_buffer"], + fake_impl=rtp_sparse_attn_indexer_fake, +) diff --git a/atom/plugin/rtpllm/models/glm5.py b/atom/plugin/rtpllm/models/glm5.py index aa9236d893..885cd4690c 100644 --- a/atom/plugin/rtpllm/models/glm5.py +++ b/atom/plugin/rtpllm/models/glm5.py @@ -41,16 +41,80 @@ def load_lora_weights(self, adapter_name, lora_path, device): # noqa: ANN001 class _ATOMGlm5AttnPyObj: - """Minimal attention object so RTP does not build native MLA fmha_impl.""" + """Container returned to RTP CudaGraphRunner for replay-time hooks.""" - is_cuda_graph = False + def __init__(self, runtime: "_ATOMGlm5MoeRuntime") -> 
None: + self._runtime = runtime + self.is_cuda_graph = False + self._rtp_mla_layers: list[Any] = [] + self._rtp_sparse_mla_backends: list[Any] = [] + self._rtp_dense_mla_backends: list[Any] = [] + self._collect_mla_layers() + + @staticmethod + def _append_unique(items: list[Any], value: Any) -> None: + if value is not None and all(value is not item for item in items): + items.append(value) + + def _collect_mla_layers(self) -> None: + try: + from atom.plugin.rtpllm.attention_backend import ( + RTPDenseMlaBackend, + RTPMLAAttention, + RTPSparseMlaBackend, + ) + except (ImportError, ModuleNotFoundError): + RTPDenseMlaBackend = None + RTPMLAAttention = None + RTPSparseMlaBackend = None + + candidates: list[Any] = [] + _, _, mla_layer_map = self._runtime._rtp_layer_maps + candidates.extend(mla_layer_map.values()) + for module in self._runtime.model.modules(): + candidates.append(module) + mla_attn = getattr(module, "mla_attn", None) + if mla_attn is not None: + candidates.append(mla_attn) + + for candidate in candidates: + if RTPMLAAttention is not None and isinstance(candidate, RTPMLAAttention): + self._append_unique(self._rtp_mla_layers, candidate) + backend = getattr(candidate, "dense_backend", None) + else: + backend = getattr(candidate, "dense_backend", None) + if backend is None and RTPSparseMlaBackend is not None and isinstance( + candidate, RTPSparseMlaBackend + ): + backend = candidate + + if RTPSparseMlaBackend is not None and isinstance( + backend, RTPSparseMlaBackend + ): + self._append_unique(self._rtp_sparse_mla_backends, backend) + dense_backend = getattr(backend, "dense_backend", None) + if RTPDenseMlaBackend is not None and isinstance( + dense_backend, RTPDenseMlaBackend + ): + self._append_unique(self._rtp_dense_mla_backends, dense_backend) + elif RTPDenseMlaBackend is not None and isinstance( + backend, RTPDenseMlaBackend + ): + self._append_unique(self._rtp_dense_mla_backends, backend) @property def fmha_params(self): return None def 
prepare_cuda_graph(self, attn_inputs) -> None: # noqa: ANN001 - return None + for layer in self._rtp_mla_layers: + prepare = getattr(layer, "prepare_cuda_graph", None) + if callable(prepare): + prepare(attn_inputs) + for backend in self._rtp_sparse_mla_backends + self._rtp_dense_mla_backends: + prepare = getattr(backend, "prepare_cuda_graph", None) + if callable(prepare): + prepare(attn_inputs) class _ATOMGlm5MoeRuntime(GptModelBase): @@ -90,12 +154,20 @@ def __init__( self._rtp_kv_cache_signature: tuple | None = None self._rtp_layer_group_map: dict[int, int] | None = None self._rtp_layer_group_map_signature: tuple | None = None + decode_caps = getattr(py_hw_kernel_config, "decode_capture_batch_sizes", None) + if decode_caps: + self._cg_max_num_tokens: int = min( + int(max(decode_caps)), int(max_generate_batch_size) + ) + else: + self._cg_max_num_tokens: int = int(max_generate_batch_size) self._cg_max_seq_len: int = int( getattr(model_config, "max_seq_len", 0) or getattr(model_config, "max_position_embeddings", 0) or 32768 ) self._atom_attn_pyobj: _ATOMGlm5AttnPyObj | None = None + self._cg_layers_prewarmed: bool = False def load_weights(self): return None @@ -104,12 +176,132 @@ def prepare_fmha_impl( self, inputs: PyModelInputs, is_cuda_graph: bool = False ) -> _ATOMGlm5AttnPyObj: if self._atom_attn_pyobj is None: - self._atom_attn_pyobj = _ATOMGlm5AttnPyObj() + self._atom_attn_pyobj = _ATOMGlm5AttnPyObj(self) self._atom_attn_pyobj.is_cuda_graph = bool(is_cuda_graph) if bool(is_cuda_graph): inputs.attention_inputs.is_cuda_graph = True + self._ensure_cuda_graph_prewarmed() return self._atom_attn_pyobj + def _ensure_cuda_graph_prewarmed(self) -> None: + if self._cg_layers_prewarmed: + return + if self._atom_attn_pyobj is None: + return + + max_num_tokens = int(self._cg_max_num_tokens) + max_seq_len = int(self._cg_max_seq_len) + if max_num_tokens <= 0 or max_seq_len <= 0: + logger.warning( + "ATOM GLM5 cuda-graph prewarm skipped: invalid budget " + 
"(max_num_tokens=%d, max_seq_len=%d)", + max_num_tokens, + max_seq_len, + ) + return + + device = self._get_model_device() + dtype = self._get_model_dtype() + kv_cache = getattr(self, "kv_cache", None) + seq_size_per_block = ( + int(getattr(kv_cache, "seq_size_per_block", 0)) + or int(os.getenv("SEQ_SIZE_PER_BLOCK", "0") or 0) + or 1 + ) + kernel_seq_size_per_block = ( + int(getattr(kv_cache, "kernel_seq_size_per_block", 0)) + or int(os.getenv("KERNEL_SEQ_SIZE_PER_BLOCK", "0") or 0) + or seq_size_per_block + ) + physical_max_blocks = ( + int(max_seq_len) + seq_size_per_block - 1 + ) // seq_size_per_block + 1 + recovered_physical_max_blocks = ( + int(max_seq_len) + seq_size_per_block - 1 + ) // seq_size_per_block + indexer_max_blocks = ( + int(max_seq_len) + kernel_seq_size_per_block - 1 + ) // kernel_seq_size_per_block + 1 + block_table_max_blocks = max(physical_max_blocks, indexer_max_blocks) + + for backend in self._atom_attn_pyobj._rtp_sparse_mla_backends: + prewarm = getattr(backend, "prewarm_for_cuda_graph", None) + if callable(prewarm): + prewarm( + max_num_tokens=max_num_tokens, + max_seq_len=max_seq_len, + query_dtype=dtype, + device=device, + ) + for backend in self._atom_attn_pyobj._rtp_dense_mla_backends: + prewarm = getattr(backend, "prewarm_for_cuda_graph", None) + if callable(prewarm): + prewarm( + max_num_tokens=max_num_tokens, + max_seq_len=max_seq_len, + query_dtype=dtype, + device=device, + ) + + self._cg_meta_bufs: dict[str, torch.Tensor] = { + "query_start_loc": torch.arange( + 0, max_num_tokens + 1, device=device, dtype=torch.int32 + ), + "seq_id": torch.arange(0, max_num_tokens, device=device, dtype=torch.int64), + "seq_id_i32": torch.arange( + 0, max_num_tokens, device=device, dtype=torch.int32 + ), + "positions_i32": torch.empty( + max_num_tokens, device=device, dtype=torch.int32 + ), + "positions_i64": torch.empty( + max_num_tokens, device=device, dtype=torch.int64 + ), + "block_col": torch.empty(max_num_tokens, device=device, 
dtype=torch.int32), + "block_col_i64": torch.empty( + max_num_tokens, device=device, dtype=torch.int64 + ), + "slot_base": torch.empty(max_num_tokens, device=device, dtype=torch.int32), + "token_offset": torch.empty( + max_num_tokens, device=device, dtype=torch.int32 + ), + "slot_mapping": torch.empty( + max_num_tokens, device=device, dtype=torch.int64 + ), + "seq_lens_i32": torch.empty( + max_num_tokens, device=device, dtype=torch.int32 + ), + "physical_block_table_i32": torch.empty( + max_num_tokens, + recovered_physical_max_blocks, + device=device, + dtype=torch.int32, + ), + "block_table_i32": torch.empty( + max_num_tokens, block_table_max_blocks, device=device, dtype=torch.int32 + ), + "indexer_block_table_i32": torch.empty( + max_num_tokens, indexer_max_blocks, device=device, dtype=torch.int32 + ), + } + self._cg_layers_prewarmed = True + logger.info( + "ATOM GLM5 cuda-graph prewarmed " + "(max_num_tokens=%d, max_seq_len=%d, sparse_layers=%d, dense_layers=%d, " + "physical_block_table_i32[%dx%d], block_table_i32[%dx%d], " + "indexer_block_table_i32[%dx%d])", + max_num_tokens, + max_seq_len, + len(self._atom_attn_pyobj._rtp_sparse_mla_backends), + len(self._atom_attn_pyobj._rtp_dense_mla_backends), + max_num_tokens, + recovered_physical_max_blocks, + max_num_tokens, + block_table_max_blocks, + max_num_tokens, + indexer_max_blocks, + ) + @staticmethod def _get_forward_context_cls(): global RTPForwardContext @@ -182,6 +374,16 @@ def _build_positions_from_attention_inputs( starts = prefix_lengths_i32[: int(input_lengths_i32.numel())] return self._build_token_positions(input_lengths_i32, starts) + sequence_lengths_plus_1 = getattr(attn_inputs, "sequence_lengths_plus_1_d", None) + if sequence_lengths_plus_1 is not None and sequence_lengths_plus_1.numel() > 0: + seq_plus_one_i32 = sequence_lengths_plus_1.to( + device=model_device, dtype=torch.int32, non_blocking=True + ).contiguous() + if int(seq_plus_one_i32.numel()) < int(input_lengths_i32.numel()): + return None 
+ starts = seq_plus_one_i32[: int(input_lengths_i32.numel())] - input_lengths_i32 + return self._build_token_positions(input_lengths_i32, starts) + sequence_lengths = getattr(attn_inputs, "sequence_lengths", None) if sequence_lengths is None or sequence_lengths.numel() == 0: return None @@ -193,6 +395,38 @@ def _build_positions_from_attention_inputs( starts = sequence_lengths_i32[: int(input_lengths_i32.numel())] - input_lengths_i32 + 1 return self._build_token_positions(input_lengths_i32, starts) + def _build_graph_decode_positions( + self, attn_inputs: Any, model_device: torch.device + ) -> torch.Tensor | None: + sequence_lengths_plus_1 = getattr(attn_inputs, "sequence_lengths_plus_1_d", None) + if sequence_lengths_plus_1 is None or sequence_lengths_plus_1.numel() == 0: + return None + input_lengths = getattr(attn_inputs, "input_lengths", None) + if input_lengths is None or input_lengths.numel() == 0: + return None + num_tokens = int(input_lengths.numel()) + seq_plus_one_i32 = sequence_lengths_plus_1.to( + device=model_device, dtype=torch.int32, non_blocking=True + ) + if int(seq_plus_one_i32.numel()) < num_tokens: + return None + cg_bufs = getattr(self, "_cg_meta_bufs", None) + if isinstance(cg_bufs, dict): + positions_buf = cg_bufs.get("positions_i32") + if isinstance(positions_buf, torch.Tensor) and int(positions_buf.numel()) >= num_tokens: + positions_i32 = positions_buf[:num_tokens] + torch.sub(seq_plus_one_i32[:num_tokens], 1, out=positions_i32) + positions_i64_buf = cg_bufs.get("positions_i64") + if ( + isinstance(positions_i64_buf, torch.Tensor) + and int(positions_i64_buf.numel()) >= num_tokens + ): + positions_i64 = positions_i64_buf[:num_tokens] + positions_i64.copy_(positions_i32) + return positions_i64 + return positions_i32 + return (seq_plus_one_i32[:num_tokens] - 1).to(dtype=torch.long).contiguous() + def _extract_combo_positions( self, inputs: PyModelInputs, model_device: torch.device ) -> torch.Tensor | None: @@ -214,7 +448,20 @@ def 
_extract_positions( raise ValueError( "GLM5 RTP plugin requires inputs.attention_inputs to provide position metadata." ) - positions = getattr(attn_inputs, "position_ids", None) + positions = None + graph_decode = bool(getattr(attn_inputs, "is_cuda_graph", False)) and not bool( + getattr(attn_inputs, "is_prefill", False) + ) + if graph_decode: + # RTP CudaGraphRunner refreshes sequence_lengths_plus_1_d before + # replay, but not position_ids. Build decode positions from the + # refreshed RTP length tensors so RoPE advances on every replay. + positions = self._build_graph_decode_positions( + attn_inputs=attn_inputs, + model_device=model_device, + ) + if positions is None or positions.numel() == 0: + positions = getattr(attn_inputs, "position_ids", None) if positions is None or positions.numel() == 0: positions = self._extract_combo_positions( inputs=inputs, model_device=model_device @@ -228,9 +475,16 @@ def _extract_positions( raise ValueError( "GLM5 RTP plugin requires real position metadata from attention_inputs." ) - positions = positions.to( - device=model_device, dtype=torch.long, non_blocking=True - ).contiguous() + if torch.cuda.is_current_stream_capturing(): + if positions.device != model_device: + raise RuntimeError( + "GLM5 RTP cuda-graph capture requires positions on model device." 
+ ) + positions = positions.contiguous() + else: + positions = positions.to( + device=model_device, dtype=torch.long, non_blocking=True + ).contiguous() if not torch.cuda.is_current_stream_capturing(): pos_tokens = int(positions.shape[-1]) if positions.dim() > 0 else int(positions.numel()) if token_num > 0 and pos_tokens != token_num: @@ -261,7 +515,8 @@ def _extract_positions( return positions def forward(self, inputs: PyModelInputs, fmha_impl=None) -> PyModelOutputs: # noqa: ANN001 - if bool(getattr(fmha_impl, "is_cuda_graph", False)): + is_cuda_graph = bool(getattr(fmha_impl, "is_cuda_graph", False)) + if is_cuda_graph: inputs.attention_inputs.is_cuda_graph = True model_device = self._get_model_device() model_dtype = self._get_model_dtype() @@ -278,6 +533,8 @@ def forward(self, inputs: PyModelInputs, fmha_impl=None) -> PyModelOutputs: # n positions = self._extract_positions( inputs=inputs, model_device=model_device, token_num=token_num ) + if is_cuda_graph and token_num > 0: + positions = positions[:token_num] if input_ids is None or input_ids.numel() == 0: inputs_embeds = inputs.input_hiddens if ( diff --git a/atom/plugin/rtpllm/utils/forward_context.py b/atom/plugin/rtpllm/utils/forward_context.py index 265b4612dd..92638b4e47 100644 --- a/atom/plugin/rtpllm/utils/forward_context.py +++ b/atom/plugin/rtpllm/utils/forward_context.py @@ -9,6 +9,13 @@ import torch from aiter import dtypes +try: + import triton + import triton.language as tl +except (ImportError, ModuleNotFoundError): + triton = None + tl = None + from atom.config import KVCacheTensor, get_current_atom_config from atom.model_ops.attention_gdn import GatedDeltaNet from atom.model_ops.attention_mha import PagedAttentionImpl @@ -63,6 +70,46 @@ class AiterFlashAttentionMetadataForPluginMode: context: Any = None +if triton is not None: + + @triton.jit + def _expand_block_table_for_atom_indexer_kernel( + block_table, + output, + num_cols: tl.constexpr, + output_cols: tl.constexpr, + block_ratio: 
tl.constexpr, + BLOCK_RATIO: tl.constexpr, + ): + row = tl.program_id(0) + col = tl.program_id(1) + offsets = tl.arange(0, BLOCK_RATIO) + value = tl.load(block_table + row * num_cols + col) + expanded = value * block_ratio + offsets + expanded = tl.where(value >= 0, expanded, -1) + tl.store(output + row * output_cols + col * block_ratio + offsets, expanded) + + @triton.jit + def _recover_physical_block_table_from_kernel_kernel( + kernel_block_table, + output, + kernel_cols: tl.constexpr, + physical_cols: tl.constexpr, + block_ratio: tl.constexpr, + ): + row = tl.program_id(0) + col = tl.program_id(1) + kernel_col = col * block_ratio + value = tl.load( + kernel_block_table + row * kernel_cols + kernel_col, + mask=kernel_col < kernel_cols, + other=-1, + ) + physical = value // block_ratio + physical = tl.where(value >= 0, physical, -1) + tl.store(output + row * physical_cols + col, physical) + + @dataclass(frozen=True) class RTPForwardContext: gdn_metadata: GDNAttentionMetadata @@ -76,7 +123,6 @@ class RTPForwardContext: context: Context num_tokens: int mla_layer_map: Dict[int, Any] - use_rtp_indexer_cache: bool = False LayerMaps = tuple[Dict[int, GatedDeltaNet], Dict[int, Any], Dict[int, Any]] @staticmethod @@ -325,6 +371,76 @@ def _select_physical_block_table_for_layer( group_id=group_id, ) + @staticmethod + def _recover_physical_block_table_from_kernel( + kernel_block_table: torch.Tensor, + *, + seq_size_per_block: int, + kernel_seq_size_per_block: int, + cg_bufs: dict | None = None, + ) -> torch.Tensor: + if ( + kernel_seq_size_per_block <= 0 + or seq_size_per_block <= 0 + or seq_size_per_block == kernel_seq_size_per_block + ): + return kernel_block_table + if seq_size_per_block % kernel_seq_size_per_block != 0: + raise ValueError( + "RTP plugin cannot recover physical block_table from kernel block_table: " + f"seq_size_per_block={seq_size_per_block}, " + f"kernel_seq_size_per_block={kernel_seq_size_per_block}." 
+ ) + if kernel_block_table.dim() == 1: + kernel_block_table = kernel_block_table.unsqueeze(0) + if kernel_block_table.dim() != 2: + raise ValueError( + "RTP plugin invalid kernel block_table shape for physical recovery: " + f"{tuple(kernel_block_table.shape)}" + ) + block_ratio = int(seq_size_per_block // kernel_seq_size_per_block) + bs_now = int(kernel_block_table.shape[0]) + kernel_cols = int(kernel_block_table.shape[1]) + if kernel_cols < block_ratio or kernel_cols % block_ratio != 0: + return kernel_block_table.to( + device=kernel_block_table.device, dtype=torch.int32, non_blocking=True + ).contiguous() + physical_cols = (kernel_cols + block_ratio - 1) // block_ratio + in_capture = torch.cuda.is_current_stream_capturing() + if in_capture and cg_bufs is not None: + if triton is None: + raise RuntimeError( + "RTP plugin cuda-graph capture requires Triton for capture-safe " + "physical block_table recovery." + ) + out_buf = cg_bufs.get("physical_block_table_i32") + if not isinstance(out_buf, torch.Tensor): + raise RuntimeError( + "RTP plugin capture requires prewarmed physical_block_table_i32." + ) + if int(out_buf.shape[0]) < bs_now or int(out_buf.shape[1]) < physical_cols: + raise RuntimeError( + "RTP plugin prewarmed block_table_i32 buffer is too small for " + "physical recovery " + f"(buffer={tuple(out_buf.shape)}, required=({bs_now}, {physical_cols}))." 
+ ) + out_view = out_buf[:bs_now, :physical_cols] + _recover_physical_block_table_from_kernel_kernel[(bs_now, physical_cols)]( + kernel_block_table, + out_view, + kernel_cols, + physical_cols, + block_ratio, + ) + return out_view + + sampled = kernel_block_table[:, : physical_cols * block_ratio : block_ratio] + recovered = torch.div(sampled, block_ratio, rounding_mode="floor") + recovered = torch.where(sampled >= 0, recovered, sampled) + return recovered.to( + device=kernel_block_table.device, dtype=torch.int32, non_blocking=True + ).contiguous() + @staticmethod def _build_layer_group_map(attn_inputs: Any) -> Dict[int, int]: layer_to_group = getattr(attn_inputs, "kv_cache_layer_to_group", None) @@ -496,10 +612,9 @@ def _build_gdn_metadata( def _build_seq_lens(attn_inputs: Any, *, device: torch.device) -> torch.Tensor: """Build kernel seq_lens using RTP-native field priority. - Non-cuda-graph decode keeps the pre-cuda-graph field priority: - sequence_lengths_plus_1_d first, then sequence_lengths + input_lengths. - Cuda-graph warmup/replay keeps the graph-safe priority introduced for - dummy inputs. + Decode uses RTP's canonical sequence_lengths_plus_1_d first in both + eager and CUDA-graph paths. This keeps context_lens aligned with the + block-table slot/state-index calculation during graph replay. 
""" input_lengths = RTPForwardContext._non_empty_int32( getattr(attn_inputs, "input_lengths", None), @@ -532,22 +647,18 @@ def _build_seq_lens(attn_inputs: Any, *, device: torch.device) -> torch.Tensor: ) return (prefix_lengths + input_lengths).contiguous() - non_cuda_graph_mode = not torch.cuda.is_current_stream_capturing() and not bool( - getattr(attn_inputs, "is_cuda_graph", False) + sequence_lengths_plus_1 = RTPForwardContext._non_empty_int32( + getattr(attn_inputs, "sequence_lengths_plus_1_d", None), + device=device, ) - if non_cuda_graph_mode: - sequence_lengths_plus_1 = RTPForwardContext._non_empty_int32( - getattr(attn_inputs, "sequence_lengths_plus_1_d", None), - device=device, - ) - if sequence_lengths_plus_1 is not None: - if int(sequence_lengths_plus_1.numel()) != int(input_lengths.numel()): - raise ValueError( - "RTP plugin sequence_lengths_plus_1_d/input_lengths batch mismatch " - f"(sequence_lengths_plus_1_d={int(sequence_lengths_plus_1.numel())}, " - f"input_lengths={int(input_lengths.numel())})." - ) - return sequence_lengths_plus_1.contiguous() + if sequence_lengths_plus_1 is not None: + if int(sequence_lengths_plus_1.numel()) != int(input_lengths.numel()): + raise ValueError( + "RTP plugin sequence_lengths_plus_1_d/input_lengths batch mismatch " + f"(sequence_lengths_plus_1_d={int(sequence_lengths_plus_1.numel())}, " + f"input_lengths={int(input_lengths.numel())})." + ) + return sequence_lengths_plus_1.contiguous() sequence_lengths = RTPForwardContext._non_empty_int32( getattr(attn_inputs, "sequence_lengths", None), @@ -564,20 +675,6 @@ def _build_seq_lens(attn_inputs: Any, *, device: torch.device) -> torch.Tensor: # real context length is sequence_lengths + input_lengths. 
return (sequence_lengths + input_lengths).contiguous() - if not non_cuda_graph_mode: - sequence_lengths_plus_1 = RTPForwardContext._non_empty_int32( - getattr(attn_inputs, "sequence_lengths_plus_1_d", None), - device=device, - ) - if sequence_lengths_plus_1 is not None: - if int(sequence_lengths_plus_1.numel()) != int(input_lengths.numel()): - raise ValueError( - "RTP plugin sequence_lengths_plus_1_d/input_lengths batch mismatch " - f"(sequence_lengths_plus_1_d={int(sequence_lengths_plus_1.numel())}, " - f"input_lengths={int(input_lengths.numel())})." - ) - return sequence_lengths_plus_1.contiguous() - raise ValueError( "RTP decode requires attention_inputs.sequence_lengths_plus_1_d or " "sequence_lengths for seq_lens." @@ -846,6 +943,56 @@ def _expand_block_table_for_atom_indexer( expanded = torch.where(base.unsqueeze(-1) >= 0, expanded, -1) return expanded.reshape(base.shape[0], base.shape[1] * block_ratio).contiguous() + @staticmethod + def _expand_block_table_for_atom_indexer_capture( + block_table: torch.Tensor, + *, + seq_size_per_block: int, + kernel_seq_size_per_block: int, + cg_bufs: dict, + ) -> torch.Tensor: + if ( + kernel_seq_size_per_block <= 0 + or seq_size_per_block <= 0 + or seq_size_per_block == kernel_seq_size_per_block + ): + return block_table + if seq_size_per_block % kernel_seq_size_per_block != 0: + raise ValueError( + "RTP plugin cannot expand block_table for ATOM indexer: " + f"seq_size_per_block={seq_size_per_block}, " + f"kernel_seq_size_per_block={kernel_seq_size_per_block}." + ) + if triton is None: + raise RuntimeError( + "RTP plugin cuda-graph capture requires Triton for capture-safe " + "ATOM indexer block_table expansion." + ) + out_buf = cg_bufs.get("indexer_block_table_i32") + if not isinstance(out_buf, torch.Tensor): + raise RuntimeError( + "RTP plugin capture requires prewarmed indexer_block_table_i32." 
+ ) + block_ratio = int(seq_size_per_block // kernel_seq_size_per_block) + bs_now = int(block_table.shape[0]) + cols_now = int(block_table.shape[1]) + expanded_cols = cols_now * block_ratio + if int(out_buf.shape[0]) < bs_now or int(out_buf.shape[1]) < expanded_cols: + raise RuntimeError( + "RTP plugin prewarmed indexer_block_table_i32 buffer is too small " + f"(buffer={tuple(out_buf.shape)}, required=({bs_now}, {expanded_cols}))." + ) + out_view = out_buf[:bs_now, :expanded_cols] + _expand_block_table_for_atom_indexer_kernel[(bs_now, cols_now)]( + block_table, + out_view, + cols_now, + expanded_cols, + block_ratio, + BLOCK_RATIO=block_ratio, + ) + return out_view + @staticmethod def _build_plugin_attention_metadata( *, @@ -856,9 +1003,23 @@ def _build_plugin_attention_metadata( cg_max_seq_len: int = 0, cg_bufs: dict | None = None, ) -> AttentionMetaData: - block_table = RTPForwardContext._select_physical_block_table_for_layer( - attn_inputs=attn_inputs, - ) + physical_block_table = getattr(attn_inputs, "kv_cache_block_id_device", None) + if physical_block_table is not None and physical_block_table.numel() > 0: + block_table = physical_block_table + else: + kernel_block_table = RTPForwardContext._select_block_table_for_layer( + attn_inputs=attn_inputs, + ) + block_table = ( + None + if kernel_block_table is None + else RTPForwardContext._recover_physical_block_table_from_kernel( + kernel_block_table, + seq_size_per_block=int(seq_size_per_block), + kernel_seq_size_per_block=int(kernel_seq_size_per_block), + cg_bufs=cg_bufs, + ) + ) if block_table is None or block_table.numel() == 0: raise ValueError( "RTP plugin requires kv_cache_block_id_device for plugin attention metadata." @@ -894,6 +1055,20 @@ def _build_plugin_attention_metadata( # slice here so slot_mapping and num_actual_tokens are correctly sized. 
if in_capture and not is_prefill: positions = positions[:batch_size] + if positions.dtype != torch.int32: + positions_i32_buf = cg_bufs.get("positions_i32") + if not isinstance(positions_i32_buf, torch.Tensor): + raise RuntimeError( + "RTP plugin capture requires prewarmed positions_i32 buffer." + ) + if int(positions_i32_buf.shape[0]) < batch_size: + raise RuntimeError( + "RTP plugin prewarmed positions_i32 buffer is too small " + f"(buffer={int(positions_i32_buf.shape[0])}, required={batch_size})." + ) + positions_i32 = positions_i32_buf[:batch_size] + positions_i32.copy_(positions, non_blocking=True) + positions = positions_i32 num_actual_tokens = int(positions.numel()) query_start_loc = RTPForwardContext._build_query_start_loc_for_plugin( @@ -979,24 +1154,42 @@ def _build_plugin_attention_metadata( in_capture = torch.cuda.is_current_stream_capturing() if in_capture and cg_bufs is not None: - # Zero-alloc capture path: always route through prewarmed block_table_i32. - bt_buf = cg_bufs["block_table_i32"] - bs_now = int(block_table.shape[0]) - cols_now = int(block_table.shape[1]) - if int(bt_buf.shape[0]) < bs_now or int(bt_buf.shape[1]) < cols_now: + # Capture must keep the compact physical table layout. Copying into a + # wider prewarmed table and slicing columns would create a strided view + # that the downstream Triton expand kernel does not understand. + if block_table.dtype != torch.int32: raise RuntimeError( - "RTP plugin prewarmed block_table_i32 buffer is too small " - f"(buffer={tuple(bt_buf.shape)}, required=({bs_now}, {cols_now}))." + "RTP plugin capture requires block_table to be int32 to avoid allocation." ) - bt_view = bt_buf[:bs_now, :cols_now] - bt_view.copy_(block_table, non_blocking=True) - block_table_i32 = bt_view + if not block_table.is_contiguous(): + raise RuntimeError( + "RTP plugin capture requires block_table to be contiguous to avoid allocation." 
+ ) + block_table_i32 = block_table else: block_table_i32 = block_table.to( device=device, dtype=torch.int32, non_blocking=True ).contiguous() if in_capture: - indexer_block_table_i32 = block_table_i32 + expected_kernel_cols = 0 + if cg_max_seq_len > 0 and int(kernel_seq_size_per_block) > 0: + expected_kernel_cols = ( + int(cg_max_seq_len) + int(kernel_seq_size_per_block) - 1 + ) // int(kernel_seq_size_per_block) + if ( + expected_kernel_cols > 0 + and int(block_table_i32.shape[1]) >= expected_kernel_cols + ): + indexer_block_table_i32 = block_table_i32 + else: + indexer_block_table_i32 = ( + RTPForwardContext._expand_block_table_for_atom_indexer_capture( + block_table_i32, + seq_size_per_block=int(seq_size_per_block), + kernel_seq_size_per_block=int(kernel_seq_size_per_block), + cg_bufs=cg_bufs, + ) + ) else: indexer_block_table_i32 = ( RTPForwardContext._expand_block_table_for_atom_indexer( @@ -1033,6 +1226,7 @@ def _build_plugin_attention_metadata( plugin_md.req_id_per_token = req_id_per_token plugin_md.topk_tokens = 0 plugin_md.sparse_block_size = int(seq_size_per_block) + plugin_md.cg_bufs = cg_bufs cu_seqlen_ks = None cu_seqlen_ke = None if is_prefill: @@ -1529,20 +1723,12 @@ def _attach_mla_layer_caches( or getattr(get_current_atom_config(), "kv_cache_block_size", 0) or 1 ) - if bool(getattr(forward_context, "use_rtp_indexer_cache", False)): - indexer_cache_tensor = RTPForwardContext._resolve_rtp_indexer_cache( - layer_num=layer_num, - layer_cache=layer_cache, - indexer=indexer, - block_size=block_size, - ) - else: - indexer_cache_tensor = RTPForwardContext._build_fallback_indexer_cache( - cache_owner=cache_owner, - layer_cache=layer_cache, - indexer=indexer, - block_size=block_size, - ) + indexer_cache_tensor = RTPForwardContext._build_fallback_indexer_cache( + cache_owner=cache_owner, + layer_cache=layer_cache, + indexer=indexer, + block_size=block_size, + ) if indexer_cache_tensor is None: continue restore_indices.append((indexer_kv_cache, 0, 
indexer_kv_cache[0])) diff --git a/tests/plugin/test_rtpllm_forward_context_semantics.py b/tests/plugin/test_rtpllm_forward_context_semantics.py index be1f581e01..aeae4f2066 100644 --- a/tests/plugin/test_rtpllm_forward_context_semantics.py +++ b/tests/plugin/test_rtpllm_forward_context_semantics.py @@ -4,9 +4,7 @@ import types from types import SimpleNamespace -import pytest import torch -from aiter import dtypes class _KwargsObject: @@ -162,6 +160,53 @@ def test_plugin_attention_metadata_slot_mapping_uses_physical_block_table(): assert md.plugin_metadata.slot_mapping.cpu().tolist() == [8 * 1024 + 5] +def test_recover_physical_block_table_accepts_expanded_kernel_layout(): + expanded = torch.tensor( + [[448, 449, 450, 451, 452, 453, 454, 455]], dtype=torch.int32 + ) + + recovered = RTPForwardContext._recover_physical_block_table_from_kernel( + expanded, + seq_size_per_block=1024, + kernel_seq_size_per_block=128, + ) + + assert recovered.cpu().tolist() == [[56]] + + +def test_recover_physical_block_table_keeps_compact_physical_layout(): + compact = torch.tensor([[7, 8, 9]], dtype=torch.int32) + + recovered = RTPForwardContext._recover_physical_block_table_from_kernel( + compact, + seq_size_per_block=1024, + kernel_seq_size_per_block=16, + ) + + assert recovered.cpu().tolist() == [[7, 8, 9]] + + +def test_plugin_attention_metadata_keeps_indexer_block_table_expanded(): + attn_inputs = _make_attn_inputs( + input_lengths=torch.tensor([1030], dtype=torch.int32), + prefix_lengths=torch.tensor([0], dtype=torch.int32), + kv_cache_block_id_device=torch.tensor([[7, 8]], dtype=torch.int32), + is_prefill=True, + ) + + md = RTPForwardContext._build_plugin_attention_metadata( + attn_inputs=attn_inputs, + positions=torch.arange(1030, dtype=torch.int32), + seq_size_per_block=1024, + kernel_seq_size_per_block=16, + ) + + assert md.plugin_metadata.block_table.cpu().tolist() == [[7, 8]] + assert md.block_tables.shape == (1, 128) + assert md.block_tables[0, :4].cpu().tolist() == [448, 
449, 450, 451] + assert md.block_tables[0, 64:68].cpu().tolist() == [512, 513, 514, 515] + + def test_plugin_attention_metadata_builds_req_id_per_token(): attn_inputs = _make_attn_inputs( input_lengths=torch.tensor([2, 1], dtype=torch.int32), @@ -187,23 +232,7 @@ def test_plugin_attention_metadata_builds_req_id_per_token(): assert md.total_kv == 3 -def test_rtp_indexer_cache_accepts_byte_packed_kv_scale_base(): - kv_scale_base = torch.empty((2, 1024, 132), dtype=torch.uint8) - layer_cache = SimpleNamespace(kv_scale_base=kv_scale_base) - indexer = SimpleNamespace(head_dim=128) - - cache = RTPForwardContext._resolve_rtp_indexer_cache( - layer_num=0, - layer_cache=layer_cache, - indexer=indexer, - block_size=1024, - ) - - assert tuple(cache.shape) == (2, 1024, 132) - assert cache.dtype == dtypes.fp8 - - -def test_rtpllm_decode_seq_lens_priority_splits_graph_and_eager_modes(): +def test_rtpllm_decode_seq_lens_uses_rtp_plus_one_in_graph_and_eager_modes(): input_lengths = torch.tensor([1], dtype=torch.int32) sequence_lengths = torch.tensor([35], dtype=torch.int32) sequence_lengths_plus_1 = torch.tensor([35], dtype=torch.int32) @@ -229,7 +258,7 @@ def test_rtpllm_decode_seq_lens_priority_splits_graph_and_eager_modes(): graph_seq_lens = RTPForwardContext._build_seq_lens( graph_inputs, device=input_lengths.device ) - assert graph_seq_lens.cpu().tolist() == [36] + assert graph_seq_lens.cpu().tolist() == [35] def test_collect_layer_maps_keeps_mla_layers_separate(): @@ -313,7 +342,6 @@ def test_bind_temporarily_attaches_mla_layer_cache(monkeypatch): context=SimpleNamespace(), num_tokens=1, mla_layer_map={7: mla_layer}, - use_rtp_indexer_cache=False, ) monkeypatch.setattr( @@ -368,7 +396,6 @@ def test_bind_writes_kv_cache_to_mla_attn_owner_not_outer_wrapper(monkeypatch): context=SimpleNamespace(), num_tokens=1, mla_layer_map={7: outer}, - use_rtp_indexer_cache=False, ) monkeypatch.setattr( @@ -416,7 +443,6 @@ def 
test_bind_temporarily_attaches_sparse_mla_indexer_cache(monkeypatch): context=SimpleNamespace(), num_tokens=1, mla_layer_map={7: mla_layer}, - use_rtp_indexer_cache=False, ) monkeypatch.setattr( @@ -441,147 +467,3 @@ def test_bind_temporarily_attaches_sparse_mla_indexer_cache(monkeypatch): assert mla_layer.kv_cache is old_cache assert indexer.k_cache.kv_cache[0] is old_index_cache - - -def test_bind_uses_rtp_kv_scale_base_when_enabled(monkeypatch): - from atom.plugin.rtpllm.attention_backend.rtp_mla_attention import RTPMLAAttention - - old_cache = SimpleNamespace(name="old-cache") - old_index_cache = torch.empty(0) - kv_scale_base = torch.empty(2, 16, 132, dtype=dtypes.fp8) - layer_cache = SimpleNamespace( - kv_cache_base=torch.empty(2, 3), - kv_scale_base=kv_scale_base, - ) - indexer = SimpleNamespace( - head_dim=128, - k_cache=SimpleNamespace(kv_cache=[old_index_cache]), - ) - mla_layer = RTPMLAAttention( - dense_backend=object(), - layer_num=7, - kv_cache=old_cache, - mla_modules=SimpleNamespace(indexer=indexer), - ) - forward_context = SimpleNamespace( - attn_metadata=SimpleNamespace(), - gdn_metadata=SimpleNamespace(), - rtp_attn_inputs=SimpleNamespace(), - rtp_kernel_seq_size_per_block=16, - layer_group_map={}, - kv_cache_data={"layer_7": SimpleNamespace(k_cache=layer_cache)}, - context=SimpleNamespace(), - num_tokens=1, - mla_layer_map={7: mla_layer}, - use_rtp_indexer_cache=True, - ) - - monkeypatch.setattr( - RTPForwardContext, - "build", - classmethod(lambda cls, **kwargs: forward_context), - ) - - with RTPForwardContext.bind( - model=SimpleNamespace(), - runtime=SimpleNamespace(), - inputs=SimpleNamespace(), - positions=torch.tensor([0], dtype=torch.int32), - ): - assert indexer.k_cache.kv_cache[0] is kv_scale_base - - assert indexer.k_cache.kv_cache[0] is old_index_cache - - -def test_bind_accepts_flattened_rtp_kv_scale_base_when_enabled(monkeypatch): - from atom.plugin.rtpllm.attention_backend.rtp_mla_attention import RTPMLAAttention - - 
old_index_cache = torch.empty(0) - flat_kv_scale_base = torch.empty(2, 16 * 132, dtype=dtypes.fp8) - layer_cache = SimpleNamespace( - kv_cache_base=torch.empty(2, 3), - kv_scale_base=flat_kv_scale_base, - ) - indexer = SimpleNamespace( - head_dim=128, - k_cache=SimpleNamespace(kv_cache=[old_index_cache]), - ) - mla_layer = RTPMLAAttention( - dense_backend=object(), - layer_num=7, - mla_modules=SimpleNamespace(indexer=indexer), - ) - forward_context = SimpleNamespace( - attn_metadata=SimpleNamespace(), - gdn_metadata=SimpleNamespace(), - rtp_attn_inputs=SimpleNamespace(), - rtp_kernel_seq_size_per_block=16, - layer_group_map={}, - kv_cache_data={"layer_7": SimpleNamespace(k_cache=layer_cache)}, - context=SimpleNamespace(), - num_tokens=1, - mla_layer_map={7: mla_layer}, - use_rtp_indexer_cache=True, - ) - - monkeypatch.setattr( - RTPForwardContext, - "build", - classmethod(lambda cls, **kwargs: forward_context), - ) - - with RTPForwardContext.bind( - model=SimpleNamespace(), - runtime=SimpleNamespace(), - inputs=SimpleNamespace(), - positions=torch.tensor([0], dtype=torch.int32), - ): - assert indexer.k_cache.kv_cache[0].data_ptr() == flat_kv_scale_base.data_ptr() - assert indexer.k_cache.kv_cache[0].shape == (2, 16, 132) - - assert indexer.k_cache.kv_cache[0] is old_index_cache - - -def test_bind_rejects_incompatible_indexer_cache_layout(monkeypatch): - from atom.plugin.rtpllm.attention_backend.rtp_mla_attention import RTPMLAAttention - - layer_cache = SimpleNamespace( - kv_cache_base=torch.empty(2, 3), - kv_scale_base=torch.empty(2, 16, 64, dtype=dtypes.fp8), - ) - indexer = SimpleNamespace( - head_dim=128, - k_cache=SimpleNamespace(kv_cache=[torch.empty(0)]), - ) - mla_layer = RTPMLAAttention( - dense_backend=object(), - layer_num=7, - mla_modules=SimpleNamespace(indexer=indexer), - ) - forward_context = SimpleNamespace( - attn_metadata=SimpleNamespace(), - gdn_metadata=SimpleNamespace(), - rtp_attn_inputs=SimpleNamespace(), - rtp_kernel_seq_size_per_block=16, - 
layer_group_map={}, - kv_cache_data={"layer_7": SimpleNamespace(k_cache=layer_cache)}, - context=SimpleNamespace(), - num_tokens=1, - mla_layer_map={7: mla_layer}, - use_rtp_indexer_cache=True, - ) - - monkeypatch.setattr( - RTPForwardContext, - "build", - classmethod(lambda cls, **kwargs: forward_context), - ) - - with pytest.raises(ValueError, match="layout mismatch"): - with RTPForwardContext.bind( - model=SimpleNamespace(), - runtime=SimpleNamespace(), - inputs=SimpleNamespace(), - positions=torch.tensor([0], dtype=torch.int32), - ): - pass diff --git a/tests/plugin/test_rtpllm_glm5_indexer_contract.py b/tests/plugin/test_rtpllm_glm5_indexer_contract.py index bdea3d16b4..84d9841cca 100644 --- a/tests/plugin/test_rtpllm_glm5_indexer_contract.py +++ b/tests/plugin/test_rtpllm_glm5_indexer_contract.py @@ -131,6 +131,30 @@ def test_constructor_injects_indexer_and_topk_indices_buffer_owner_path(): assert attention.topk_indices_buffer is topk_buffer +def test_constructor_swaps_indexer_to_rtp_sparse_indexer_op(monkeypatch): + default_op = object() + rtp_op = object() + monkeypatch.setattr(torch.ops.aiter, "rtp_sparse_attn_indexer", rtp_op, raising=False) + topk_buffer = torch.tensor([[4, 1, 3, 0]], dtype=torch.int32) + indexer = SimpleNamespace( + topk_indices_buffer=topk_buffer, + index_topk=4, + sparse_attn_indexer_impl=default_op, + ) + modules = SimpleNamespace( + q_proj=object(), + o_proj=object(), + kv_b_proj=object(), + indexer=indexer, + v_head_dim=3, + ) + + attention = RTPMLAAttention(mla_modules=modules, dense_backend=object()) + + assert attention.indexer is indexer + assert indexer.sparse_attn_indexer_impl is rtp_op + + def _run_attention(attention, token_count: int): query = torch.empty(token_count, 6) compressed_kv = torch.empty(token_count, 8) diff --git a/tests/plugin/test_rtpllm_glm5_mla_forward_contract.py b/tests/plugin/test_rtpllm_glm5_mla_forward_contract.py index 406f5e88d9..a4088fd106 100644 --- 
a/tests/plugin/test_rtpllm_glm5_mla_forward_contract.py +++ b/tests/plugin/test_rtpllm_glm5_mla_forward_contract.py @@ -550,15 +550,13 @@ def test_default_dense_mla_backend_decode_rebuilds_stale_query_start_loc(monkeyp assert layer_cache.kv_cache_base[0, 1].tolist() == [9.0, 9.0, 9.0, 9.0] -def test_default_sparse_wrapper_validates_topk_but_falls_back_to_dense(monkeypatch): +def test_default_sparse_wrapper_refuses_mock_dense_fallback(monkeypatch): _guard_sparse_kernel_imports(monkeypatch) - from atom.plugin.rtpllm.attention_backend.rtp_sparse_mla_backend import ( - RTPSparseMlaBackend, - ) + from atom.plugin.rtpllm.attention_backend import rtp_sparse_mla_backend dense_backend = _FakeDenseBackend(v_head_dim=4) sparse_impl = SimpleNamespace(calls=[]) - backend = RTPSparseMlaBackend( + backend = rtp_sparse_mla_backend.RTPSparseMlaBackend( dense_backend=dense_backend, sparse_impl=sparse_impl, v_head_dim=4, @@ -569,20 +567,22 @@ def test_default_sparse_wrapper_validates_topk_but_falls_back_to_dense(monkeypat positions = torch.arange(2) topk = torch.tensor([[1, 0], [0, 1]], dtype=torch.int32) - output = backend.forward( - q, - compressed_kv, - k_pe, - kv_cache="cache", - layer_id=9, - topk_indices=topk, - positions=positions, - ) + try: + backend.forward( + q, + compressed_kv, + k_pe, + kv_cache="cache", + layer_id=9, + topk_indices=topk, + positions=positions, + ) + except rtp_sparse_mla_backend._SparseUnavailable: + pass + else: + raise AssertionError("default sparse mock must not silently fallback to dense") - assert output.shape == (2, 1, 4) - assert len(dense_backend.calls) == 1 - assert dense_backend.calls[0]["topk_indices"] is topk - assert dense_backend.calls[0]["positions"] is positions + assert dense_backend.calls == [] assert sparse_impl.calls == [] diff --git a/tests/plugin/test_rtpllm_glm5_sparse_backend_contract.py b/tests/plugin/test_rtpllm_glm5_sparse_backend_contract.py index 065d265b6b..6d6147a683 100644 --- 
a/tests/plugin/test_rtpllm_glm5_sparse_backend_contract.py +++ b/tests/plugin/test_rtpllm_glm5_sparse_backend_contract.py @@ -38,6 +38,98 @@ def _load_sparse_backend(monkeypatch): return module.RTPSparseMlaBackend +def test_rtp_sparse_attn_indexer_bridge_forwards_to_main_indexer(monkeypatch): + module = importlib.import_module(_SPARSE_BACKEND_MODULE) + calls = [] + expected = torch.empty(1) + + def fake_sparse_attn_indexer(*args): + calls.append(args) + return expected + + fake_deepseek = type(sys)("atom.models.deepseek_v2") + fake_deepseek.sparse_attn_indexer = fake_sparse_attn_indexer + monkeypatch.setitem(sys.modules, "atom.models.deepseek_v2", fake_deepseek) + + tensor = torch.empty(1) + output = module.rtp_sparse_attn_indexer( + tensor, + "indexer.prefix", + tensor, + tensor, + tensor, + tensor, + 128, + None, + 2048, + 64, + 4096, + 1, + tensor, + tensor, + tensor, + 1e-6, + tensor, + tensor, + tensor, + 1.0, + True, + False, + ) + + assert output is expected + assert len(calls) == 1 + assert calls[0][0] is tensor + assert calls[0][1] == "indexer.prefix" + assert calls[0][6:12] == (128, None, 2048, 64, 4096, 1) + + +def test_rtp_sparse_attn_indexer_fake_bridge_forwards_to_main_fake(monkeypatch): + module = importlib.import_module(_SPARSE_BACKEND_MODULE) + calls = [] + expected = torch.empty(1) + + def fake_sparse_attn_indexer_fake(*args): + calls.append(args) + return expected + + fake_deepseek = type(sys)("atom.models.deepseek_v2") + fake_deepseek.sparse_attn_indexer_fake = fake_sparse_attn_indexer_fake + monkeypatch.setitem(sys.modules, "atom.models.deepseek_v2", fake_deepseek) + + tensor = torch.empty(1) + output = module.rtp_sparse_attn_indexer_fake( + tensor, + "indexer.prefix", + tensor, + tensor, + tensor, + tensor, + 128, + None, + 2048, + 64, + 4096, + 1, + tensor, + tensor, + tensor, + 1e-6, + tensor, + tensor, + tensor, + 1.0, + True, + False, + ) + + assert output is expected + assert len(calls) == 1 + assert calls[0][0] is tensor + assert 
calls[0][1] == "indexer.prefix" + assert calls[0][6:12] == (128, None, 2048, 64, 4096, 1) + + class _FakeDenseBackend: def __init__(self, v_head_dim: int = 5): self.v_head_dim = v_head_dim @@ -200,6 +292,76 @@ def test_sparse_backend_pulls_attn_metadata_from_forward_context(monkeypatch): assert sparse_impl.calls[0]["attn_metadata"] is attn_metadata +def test_sparse_backend_prefill_missing_flashmla_falls_back_to_dense(monkeypatch): + backend_cls = _load_sparse_backend(monkeypatch) + module = importlib.import_module(_SPARSE_BACKEND_MODULE) + forward_context_mod = sys.modules["atom.utils.forward_context"] + + attn_metadata = SimpleNamespace( + plugin_metadata=SimpleNamespace(num_prefills=1, is_dummy_warmup=False) + ) + fake_forward_context = SimpleNamespace( + context=SimpleNamespace(is_dummy_run=False), + attn_metadata=attn_metadata, + ) + monkeypatch.setattr( + forward_context_mod, + "get_forward_context", + lambda: fake_forward_context, + raising=False, + ) + + class _MissingPrefillSparse: + def forward(self, *args, **kwargs): + raise module._SparseUnavailable("flash_mla_sparse_fwd unavailable") + + dense_backend = _FakeDenseBackend() + backend = _build_backend(backend_cls, dense_backend, _MissingPrefillSparse()) + q, compressed_kv, k_pe, kv_cache, layer_id = _make_inputs() + topk = torch.tensor([[1, 0], [0, 1], [1, 1]], dtype=torch.int32) + + output = backend.forward(q, compressed_kv, k_pe, kv_cache, layer_id, topk_indices=topk) + + assert output.shape == (3, 2, dense_backend.v_head_dim) + assert len(dense_backend.calls) == 1 + assert dense_backend.calls[0]["topk_indices"] is topk + + +def test_sparse_backend_decode_missing_sparse_kernel_still_raises(monkeypatch): + backend_cls = _load_sparse_backend(monkeypatch) + module = importlib.import_module(_SPARSE_BACKEND_MODULE) + forward_context_mod = sys.modules["atom.utils.forward_context"] + + attn_metadata = SimpleNamespace( + plugin_metadata=SimpleNamespace(num_prefills=0, is_dummy_warmup=False) + ) + 
fake_forward_context = SimpleNamespace( + context=SimpleNamespace(is_dummy_run=False), + attn_metadata=attn_metadata, + ) + monkeypatch.setattr( + forward_context_mod, + "get_forward_context", + lambda: fake_forward_context, + raising=False, + ) + + class _MissingDecodeSparse: + def forward(self, *args, **kwargs): + raise module._SparseUnavailable("flash_mla_sparse_fwd unavailable") + + backend = _build_backend(backend_cls, _FakeDenseBackend(), _MissingDecodeSparse()) + q, compressed_kv, k_pe, kv_cache, layer_id = _make_inputs() + topk = torch.tensor([[1, 0], [0, 1], [1, 1]], dtype=torch.int32) + + try: + backend.forward(q, compressed_kv, k_pe, kv_cache, layer_id, topk_indices=topk) + except module._SparseUnavailable: + pass + else: + raise AssertionError("decode sparse unavailability must not fall back to dense") + + def test_sparse_backend_forward_signature_matches_dense_boundary(monkeypatch): backend_cls = _load_sparse_backend(monkeypatch) diff --git a/tests/plugin/test_rtpllm_glm5_wrapper_lifecycle.py b/tests/plugin/test_rtpllm_glm5_wrapper_lifecycle.py index 7410111019..4510425373 100644 --- a/tests/plugin/test_rtpllm_glm5_wrapper_lifecycle.py +++ b/tests/plugin/test_rtpllm_glm5_wrapper_lifecycle.py @@ -318,3 +318,60 @@ def collect_layer_maps(model): assert attn_pyobj.is_cuda_graph is False assert hasattr(attn_pyobj, "prepare_cuda_graph") + +def test_glm5_runtime_decode_positions_prefer_sequence_lengths_plus_one(): + fake_modules = _install_fake_rtp_modules() + + with patch.dict(sys.modules, fake_modules), patch.dict( + os.environ, + {"RTP_LLM_EXTERNAL_MODEL_PACKAGES": "atom.plugin.rtpllm.models"}, + ): + sys.modules.pop("atom.plugin.rtpllm.models.glm5", None) + module = importlib.import_module("atom.plugin.rtpllm.models.glm5") + module = importlib.reload(module) + runtime = object.__new__(module._ATOMGlm5MoeRuntime) + attn_inputs = SimpleNamespace( + input_lengths=torch.tensor([1, 2], dtype=torch.int32), + is_prefill=False, + 
sequence_lengths=torch.tensor([999, 999], dtype=torch.int32), + sequence_lengths_plus_1_d=torch.tensor([35, 50], dtype=torch.int32), + ) + + positions = runtime._build_positions_from_attention_inputs( + attn_inputs=attn_inputs, + model_device=torch.device("cpu"), + ) + + assert positions.cpu().tolist() == [34, 48, 49] + + +def test_glm5_runtime_graph_decode_ignores_stale_position_ids(): + fake_modules = _install_fake_rtp_modules() + + with patch.dict(sys.modules, fake_modules), patch.dict( + os.environ, + {"RTP_LLM_EXTERNAL_MODEL_PACKAGES": "atom.plugin.rtpllm.models"}, + ): + sys.modules.pop("atom.plugin.rtpllm.models.glm5", None) + module = importlib.import_module("atom.plugin.rtpllm.models.glm5") + module = importlib.reload(module) + runtime = object.__new__(module._ATOMGlm5MoeRuntime) + inputs = SimpleNamespace( + bert_embedding_inputs=None, + attention_inputs=SimpleNamespace( + input_lengths=torch.tensor([1, 2], dtype=torch.int32), + is_prefill=False, + is_cuda_graph=True, + position_ids=torch.tensor([0, 0, 0], dtype=torch.int32), + sequence_lengths_plus_1_d=torch.tensor([35, 50], dtype=torch.int32), + ), + ) + + positions = runtime._extract_positions( + inputs=inputs, + model_device=torch.device("cpu"), + token_num=3, + ) + + assert positions.cpu().tolist() == [34, 48, 49] + From a49dc53bc6a085f76a2558b158bf1ca01a8d7822 Mon Sep 17 00:00:00 2001 From: Zhao An Date: Thu, 4 Jun 2026 11:05:45 +0000 Subject: [PATCH 03/20] fix: RTP glm5 qwen35 cuda graph conflict --- atom/plugin/rtpllm/models/__init__.py | 13 +- atom/plugin/rtpllm/models/glm5.py | 48 +- atom/plugin/rtpllm/models/qwen3_5.py | 24 +- atom/plugin/rtpllm/utils/__init__.py | 12 +- atom/plugin/rtpllm/utils/forward_context.py | 503 ++++++++++++++++-- .../test_rtpllm_forward_context_semantics.py | 55 +- 6 files changed, 578 insertions(+), 77 deletions(-) diff --git a/atom/plugin/rtpllm/models/__init__.py b/atom/plugin/rtpllm/models/__init__.py index 0fd1c7d8bb..c99f5363fd 100644 --- 
a/atom/plugin/rtpllm/models/__init__.py +++ b/atom/plugin/rtpllm/models/__init__.py @@ -6,9 +6,14 @@ ATOMGlm5Moe = None ATOMQwen35Moe = None else: - from atom.models.deepseek_v2 import GlmMoeDsaForCausalLM - from atom.plugin.register import _ATOM_SUPPORTED_MODELS - - _ATOM_SUPPORTED_MODELS.setdefault("GlmMoeDsaForCausalLM", GlmMoeDsaForCausalLM) + try: + from atom.models.deepseek_v2 import GlmMoeDsaForCausalLM + from atom.plugin.register import _ATOM_SUPPORTED_MODELS + except ImportError: + # Unit tests may stub partial module trees and intentionally skip + # full model imports. Keep wrapper symbols importable in that case. + pass + else: + _ATOM_SUPPORTED_MODELS.setdefault("GlmMoeDsaForCausalLM", GlmMoeDsaForCausalLM) __all__ = ["ATOMGlm5Moe", "ATOMQwen35Moe"] diff --git a/atom/plugin/rtpllm/models/glm5.py b/atom/plugin/rtpllm/models/glm5.py index 885cd4690c..5783d081b2 100644 --- a/atom/plugin/rtpllm/models/glm5.py +++ b/atom/plugin/rtpllm/models/glm5.py @@ -83,8 +83,10 @@ def _collect_mla_layers(self) -> None: backend = getattr(candidate, "dense_backend", None) else: backend = getattr(candidate, "dense_backend", None) - if backend is None and RTPSparseMlaBackend is not None and isinstance( - candidate, RTPSparseMlaBackend + if ( + backend is None + and RTPSparseMlaBackend is not None + and isinstance(candidate, RTPSparseMlaBackend) ): backend = candidate @@ -306,7 +308,9 @@ def _ensure_cuda_graph_prewarmed(self) -> None: def _get_forward_context_cls(): global RTPForwardContext if RTPForwardContext is None: - from atom.plugin.rtpllm.utils import RTPForwardContext as _RTPForwardContext + from atom.plugin.rtpllm.utils import ( + RTPForwardMLAContext as _RTPForwardContext, + ) RTPForwardContext = _RTPForwardContext return RTPForwardContext @@ -374,14 +378,18 @@ def _build_positions_from_attention_inputs( starts = prefix_lengths_i32[: int(input_lengths_i32.numel())] return self._build_token_positions(input_lengths_i32, starts) - sequence_lengths_plus_1 = 
getattr(attn_inputs, "sequence_lengths_plus_1_d", None) + sequence_lengths_plus_1 = getattr( + attn_inputs, "sequence_lengths_plus_1_d", None + ) if sequence_lengths_plus_1 is not None and sequence_lengths_plus_1.numel() > 0: seq_plus_one_i32 = sequence_lengths_plus_1.to( device=model_device, dtype=torch.int32, non_blocking=True ).contiguous() if int(seq_plus_one_i32.numel()) < int(input_lengths_i32.numel()): return None - starts = seq_plus_one_i32[: int(input_lengths_i32.numel())] - input_lengths_i32 + starts = ( + seq_plus_one_i32[: int(input_lengths_i32.numel())] - input_lengths_i32 + ) return self._build_token_positions(input_lengths_i32, starts) sequence_lengths = getattr(attn_inputs, "sequence_lengths", None) @@ -392,13 +400,19 @@ def _build_positions_from_attention_inputs( ).contiguous() if int(sequence_lengths_i32.numel()) < int(input_lengths_i32.numel()): return None - starts = sequence_lengths_i32[: int(input_lengths_i32.numel())] - input_lengths_i32 + 1 + starts = ( + sequence_lengths_i32[: int(input_lengths_i32.numel())] + - input_lengths_i32 + + 1 + ) return self._build_token_positions(input_lengths_i32, starts) def _build_graph_decode_positions( self, attn_inputs: Any, model_device: torch.device ) -> torch.Tensor | None: - sequence_lengths_plus_1 = getattr(attn_inputs, "sequence_lengths_plus_1_d", None) + sequence_lengths_plus_1 = getattr( + attn_inputs, "sequence_lengths_plus_1_d", None + ) if sequence_lengths_plus_1 is None or sequence_lengths_plus_1.numel() == 0: return None input_lengths = getattr(attn_inputs, "input_lengths", None) @@ -413,7 +427,10 @@ def _build_graph_decode_positions( cg_bufs = getattr(self, "_cg_meta_bufs", None) if isinstance(cg_bufs, dict): positions_buf = cg_bufs.get("positions_i32") - if isinstance(positions_buf, torch.Tensor) and int(positions_buf.numel()) >= num_tokens: + if ( + isinstance(positions_buf, torch.Tensor) + and int(positions_buf.numel()) >= num_tokens + ): positions_i32 = positions_buf[:num_tokens] 
torch.sub(seq_plus_one_i32[:num_tokens], 1, out=positions_i32) positions_i64_buf = cg_bufs.get("positions_i64") @@ -486,7 +503,11 @@ def _extract_positions( device=model_device, dtype=torch.long, non_blocking=True ).contiguous() if not torch.cuda.is_current_stream_capturing(): - pos_tokens = int(positions.shape[-1]) if positions.dim() > 0 else int(positions.numel()) + pos_tokens = ( + int(positions.shape[-1]) + if positions.dim() > 0 + else int(positions.numel()) + ) if token_num > 0 and pos_tokens != token_num: rebuilt_positions = self._build_positions_from_attention_inputs( attn_inputs=attn_inputs, @@ -514,7 +535,9 @@ def _extract_positions( ) return positions - def forward(self, inputs: PyModelInputs, fmha_impl=None) -> PyModelOutputs: # noqa: ANN001 + def forward( + self, inputs: PyModelInputs, fmha_impl=None + ) -> PyModelOutputs: # noqa: ANN001 is_cuda_graph = bool(getattr(fmha_impl, "is_cuda_graph", False)) if is_cuda_graph: inputs.attention_inputs.is_cuda_graph = True @@ -653,7 +676,9 @@ def _inject_rtp_projection_weights(self, atom_model: Any) -> None: f"{weight_name} candidates={candidates}" for weight_name, candidates in missing ) - raise ValueError(f"Cannot locate GLM5 RTP runtime projection weights: {details}") + raise ValueError( + f"Cannot locate GLM5 RTP runtime projection weights: {details}" + ) def _assert_norm_weights_loaded(self, atom_model: Any) -> None: params = self._get_named_parameters(atom_model) @@ -774,4 +799,3 @@ def _create_python_model(self): ) logger.info("Created ATOM GLM5 runtime for rtp-llm plugin mode") return self.py_model - diff --git a/atom/plugin/rtpllm/models/qwen3_5.py b/atom/plugin/rtpllm/models/qwen3_5.py index eb41294831..4e44455fd5 100644 --- a/atom/plugin/rtpllm/models/qwen3_5.py +++ b/atom/plugin/rtpllm/models/qwen3_5.py @@ -24,7 +24,7 @@ apply_attention_mha_rtpllm_patch, ) from atom.plugin.rtpllm.models.qwen3_next import apply_qwen3_next_rtpllm_patch -from atom.plugin.rtpllm.utils import RTPForwardContext +from 
atom.plugin.rtpllm.utils import RTPForwardQwen35HybridContext logger = logging.getLogger("atom.plugin.rtpllm.models") @@ -127,7 +127,9 @@ def __init__( self._model_device = first_param.device self._model_dtype = first_param.dtype # Cache module layer maps once to avoid per-forward model.modules() traversal. - self._rtp_layer_maps = RTPForwardContext.collect_layer_maps(model=self.model) + self._rtp_layer_maps = RTPForwardQwen35HybridContext.collect_layer_maps( + model=self.model + ) # Lazy-built in forward_context; invalidated by kv buffer signature change. self._rtp_kv_cache_data: dict | None = None self._rtp_kv_cache_signature: tuple | None = None @@ -384,9 +386,17 @@ def _ensure_cuda_graph_prewarmed(self) -> None: or int(getattr(kv_cache, "seq_size_per_block", 0)) or 1 ) + seq_size_per_block = ( + int(getattr(kv_cache, "seq_size_per_block", 0)) + or kernel_seq_size_per_block + or 1 + ) max_blocks = ( int(max_seq_len) + kernel_seq_size_per_block - 1 ) // kernel_seq_size_per_block + 1 + physical_max_blocks = ( + int(max_seq_len) + seq_size_per_block - 1 + ) // seq_size_per_block # query_start_loc for decode: always [0, 1, 2, ..., bs], i.e. arange(bs+1). # seq_id for decode slot_mapping: seq_id[i] == i, i.e. arange(bs). 
self._cg_meta_bufs: dict = { @@ -403,12 +413,16 @@ def _ensure_cuda_graph_prewarmed(self) -> None: "block_table_i32": torch.empty( max_bs, max_blocks, device=device, dtype=torch.int32 ), + "physical_block_table_i32": torch.empty( + max_bs, max(physical_max_blocks, 1), device=device, dtype=torch.int32 + ), } self._cg_layers_prewarmed = True logger.info( "ATOM RTPFullAttention cuda-graph prewarmed for %d layers " "(max_num_tokens=%d, max_seq_len=%d, rtp_kv_heads=%s, " - "meta_bufs: query_start_loc[%d], slot_mapping[%d], block_table_i32[%dx%d])", + "meta_bufs: query_start_loc[%d], slot_mapping[%d], block_table_i32[%dx%d], " + "physical_block_table_i32[%dx%d])", len(self._atom_attn_pyobj._rtp_full_attn_layers), max_num_tokens, max_seq_len, @@ -417,6 +431,8 @@ def _ensure_cuda_graph_prewarmed(self) -> None: max_bs, max_bs, max_blocks, + max_bs, + max(physical_max_blocks, 1), ) def forward(self, inputs: PyModelInputs, fmha_impl: Any = None) -> PyModelOutputs: @@ -452,7 +468,7 @@ def forward(self, inputs: PyModelInputs, fmha_impl: Any = None) -> PyModelOutput ): inputs_embeds = inputs_embeds.to(dtype=model_dtype) - with RTPForwardContext.bind( + with RTPForwardQwen35HybridContext.bind( model=self.model, runtime=self, inputs=inputs, diff --git a/atom/plugin/rtpllm/utils/__init__.py b/atom/plugin/rtpllm/utils/__init__.py index 7a85bec249..d82cf33c0e 100644 --- a/atom/plugin/rtpllm/utils/__init__.py +++ b/atom/plugin/rtpllm/utils/__init__.py @@ -1,3 +1,11 @@ -from .forward_context import RTPForwardContext +from .forward_context import ( + RTPForwardContext, + RTPForwardMLAContext, + RTPForwardQwen35HybridContext, +) -__all__ = ["RTPForwardContext"] +__all__ = [ + "RTPForwardContext", + "RTPForwardMLAContext", + "RTPForwardQwen35HybridContext", +] diff --git a/atom/plugin/rtpllm/utils/forward_context.py b/atom/plugin/rtpllm/utils/forward_context.py index 92638b4e47..4517262d9c 100644 --- a/atom/plugin/rtpllm/utils/forward_context.py +++ 
b/atom/plugin/rtpllm/utils/forward_context.py @@ -993,8 +993,32 @@ def _expand_block_table_for_atom_indexer_capture( ) return out_view - @staticmethod + @classmethod + def _resolve_plugin_block_table( + cls, + *, + attn_inputs: Any, + seq_size_per_block: int, + kernel_seq_size_per_block: int, + cg_bufs: dict | None, + in_capture: bool, + ) -> torch.Tensor | None: + physical_block_table = getattr(attn_inputs, "kv_cache_block_id_device", None) + if physical_block_table is not None and physical_block_table.numel() > 0: + return physical_block_table + kernel_block_table = cls._select_block_table_for_layer(attn_inputs=attn_inputs) + if kernel_block_table is None or kernel_block_table.numel() == 0: + return None + return cls._recover_physical_block_table_from_kernel( + kernel_block_table, + seq_size_per_block=int(seq_size_per_block), + kernel_seq_size_per_block=int(kernel_seq_size_per_block), + cg_bufs=cg_bufs, + ) + + @classmethod def _build_plugin_attention_metadata( + cls, *, attn_inputs: Any, positions: torch.Tensor, @@ -1003,35 +1027,25 @@ def _build_plugin_attention_metadata( cg_max_seq_len: int = 0, cg_bufs: dict | None = None, ) -> AttentionMetaData: - physical_block_table = getattr(attn_inputs, "kv_cache_block_id_device", None) - if physical_block_table is not None and physical_block_table.numel() > 0: - block_table = physical_block_table - else: - kernel_block_table = RTPForwardContext._select_block_table_for_layer( - attn_inputs=attn_inputs, - ) - block_table = ( - None - if kernel_block_table is None - else RTPForwardContext._recover_physical_block_table_from_kernel( - kernel_block_table, - seq_size_per_block=int(seq_size_per_block), - kernel_seq_size_per_block=int(kernel_seq_size_per_block), - cg_bufs=cg_bufs, - ) - ) + in_capture = torch.cuda.is_current_stream_capturing() + block_table = cls._resolve_plugin_block_table( + attn_inputs=attn_inputs, + seq_size_per_block=int(seq_size_per_block), + kernel_seq_size_per_block=int(kernel_seq_size_per_block), + 
cg_bufs=cg_bufs, + in_capture=in_capture, + ) if block_table is None or block_table.numel() == 0: raise ValueError( "RTP plugin requires kv_cache_block_id_device for plugin attention metadata." ) device = positions.device is_prefill = bool(getattr(attn_inputs, "is_prefill", False)) - in_capture = torch.cuda.is_current_stream_capturing() if in_capture and cg_bufs is None: raise RuntimeError( "RTP plugin capture requires prewarmed cg_bufs; metadata fallback path is disabled." ) - seq_lens = RTPForwardContext._build_seq_lens(attn_inputs, device=device) + seq_lens = cls._build_seq_lens(attn_inputs, device=device) if in_capture and cg_bufs is not None: bs_now = int(seq_lens.shape[0]) seq_lens_buf = cg_bufs["seq_lens_i32"] @@ -1071,21 +1085,21 @@ def _build_plugin_attention_metadata( positions = positions_i32 num_actual_tokens = int(positions.numel()) - query_start_loc = RTPForwardContext._build_query_start_loc_for_plugin( + query_start_loc = cls._build_query_start_loc_for_plugin( attn_inputs=attn_inputs, seq_lens=seq_lens, num_tokens=num_actual_tokens, device=device, cg_bufs=cg_bufs, ) - slot_mapping = RTPForwardContext._build_slot_mapping( + slot_mapping = cls._build_slot_mapping( positions=positions, query_start_loc=query_start_loc, block_table=block_table, seq_size_per_block=seq_size_per_block, cg_bufs=cg_bufs, ) - req_id_per_token = RTPForwardContext._build_req_id_per_token( + req_id_per_token = cls._build_req_id_per_token( query_start_loc=query_start_loc, num_tokens=num_actual_tokens, device=device, @@ -1170,34 +1184,14 @@ def _build_plugin_attention_metadata( block_table_i32 = block_table.to( device=device, dtype=torch.int32, non_blocking=True ).contiguous() - if in_capture: - expected_kernel_cols = 0 - if cg_max_seq_len > 0 and int(kernel_seq_size_per_block) > 0: - expected_kernel_cols = ( - int(cg_max_seq_len) + int(kernel_seq_size_per_block) - 1 - ) // int(kernel_seq_size_per_block) - if ( - expected_kernel_cols > 0 - and int(block_table_i32.shape[1]) >= 
expected_kernel_cols - ): - indexer_block_table_i32 = block_table_i32 - else: - indexer_block_table_i32 = ( - RTPForwardContext._expand_block_table_for_atom_indexer_capture( - block_table_i32, - seq_size_per_block=int(seq_size_per_block), - kernel_seq_size_per_block=int(kernel_seq_size_per_block), - cg_bufs=cg_bufs, - ) - ) - else: - indexer_block_table_i32 = ( - RTPForwardContext._expand_block_table_for_atom_indexer( - block_table_i32, - seq_size_per_block=int(seq_size_per_block), - kernel_seq_size_per_block=int(kernel_seq_size_per_block), - ) - ) + indexer_block_table_i32 = cls._build_indexer_block_tables( + block_table_i32=block_table_i32, + seq_size_per_block=int(seq_size_per_block), + kernel_seq_size_per_block=int(kernel_seq_size_per_block), + cg_max_seq_len=int(cg_max_seq_len), + in_capture=in_capture, + cg_bufs=cg_bufs, + ) plugin_md = AiterFlashAttentionMetadataForPluginMode( num_actual_tokens=num_actual_tokens, num_actual_kv_tokens=num_actual_kv_tokens, @@ -1591,7 +1585,7 @@ def build( layer_group_map=layer_group_map, context=context, num_tokens=int(positions.numel()), - mla_layer_map=resolved_layer_maps[2], + mla_layer_map=cls._resolve_mla_layer_map(resolved_layer_maps), use_rtp_indexer_cache=cls._use_rtp_indexer_cache(), ) @@ -1789,3 +1783,408 @@ def bind( setattr(target, attr, old_cache) reset_forward_context() set_kv_cache_data(prev_kv if prev_kv is not None else {}) + + +@dataclass(frozen=True) +class RTPForwardMLAContext(RTPForwardContext): + @classmethod + def _resolve_plugin_block_table( + cls, + *, + attn_inputs: Any, + seq_size_per_block: int, + kernel_seq_size_per_block: int, + cg_bufs: dict | None, + in_capture: bool, + ) -> torch.Tensor | None: + physical_block_table = getattr(attn_inputs, "kv_cache_block_id_device", None) + if physical_block_table is not None and physical_block_table.numel() > 0: + return physical_block_table + kernel_block_table = cls._select_block_table_for_layer(attn_inputs=attn_inputs) + if kernel_block_table is None or 
kernel_block_table.numel() == 0: + return None + return cls._recover_physical_block_table_from_kernel( + kernel_block_table, + seq_size_per_block=int(seq_size_per_block), + kernel_seq_size_per_block=int(kernel_seq_size_per_block), + cg_bufs=cg_bufs if in_capture else None, + ) + + @classmethod + def _build_indexer_block_tables( + cls, + *, + block_table_i32: torch.Tensor, + seq_size_per_block: int, + kernel_seq_size_per_block: int, + cg_max_seq_len: int, + in_capture: bool, + cg_bufs: dict | None, + ) -> torch.Tensor: + if in_capture: + expected_kernel_cols = 0 + if cg_max_seq_len > 0 and int(kernel_seq_size_per_block) > 0: + expected_kernel_cols = ( + int(cg_max_seq_len) + int(kernel_seq_size_per_block) - 1 + ) // int(kernel_seq_size_per_block) + if ( + expected_kernel_cols > 0 + and int(block_table_i32.shape[1]) >= expected_kernel_cols + ): + return block_table_i32 + return cls._expand_block_table_for_atom_indexer_capture( + block_table_i32, + seq_size_per_block=int(seq_size_per_block), + kernel_seq_size_per_block=int(kernel_seq_size_per_block), + cg_bufs=cg_bufs, + ) + return cls._expand_block_table_for_atom_indexer( + block_table_i32, + seq_size_per_block=int(seq_size_per_block), + kernel_seq_size_per_block=int(kernel_seq_size_per_block), + ) + + @classmethod + def _resolve_mla_layer_map( + cls, layer_maps: RTPForwardContext.LayerMaps + ) -> Dict[int, Any]: + del cls + return layer_maps[2] + + +@dataclass(frozen=True) +class RTPForwardQwen35HybridContext(RTPForwardContext): + @staticmethod + def _build_seq_lens(attn_inputs: Any, *, device: torch.device) -> torch.Tensor: + """Qwen3.5 decode-cudagraph compatible seq_lens priority. + + Keep the validated sequence_lengths_plus_1_d ordering from + `develop/rtp_atom_0526_qwen35_cuda_graph_ok`. 
+ """ + input_lengths = RTPForwardContext._non_empty_int32( + getattr(attn_inputs, "input_lengths", None), + device=device, + ) + if input_lengths is None: + raise ValueError( + "RTP plugin requires attention_inputs.input_lengths for seq_lens." + ) + is_prefill = bool(getattr(attn_inputs, "is_prefill", False)) + if is_prefill: + prefix_lengths = RTPForwardContext._non_empty_int32( + getattr(attn_inputs, "prefix_lengths_d", None), + device=device, + ) + if prefix_lengths is None: + prefix_lengths = RTPForwardContext._non_empty_int32( + getattr(attn_inputs, "prefix_lengths", None), + device=device, + ) + if prefix_lengths is None: + raise ValueError( + "RTP prefill requires attention_inputs.prefix_lengths for seq_lens." + ) + if int(prefix_lengths.numel()) != int(input_lengths.numel()): + raise ValueError( + "RTP plugin prefix_lengths/input_lengths batch mismatch " + f"(prefix_lengths={int(prefix_lengths.numel())}, " + f"input_lengths={int(input_lengths.numel())})." + ) + return (prefix_lengths + input_lengths).contiguous() + + non_cuda_graph_mode = not torch.cuda.is_current_stream_capturing() and not bool( + getattr(attn_inputs, "is_cuda_graph", False) + ) + if non_cuda_graph_mode: + sequence_lengths_plus_1 = RTPForwardContext._non_empty_int32( + getattr(attn_inputs, "sequence_lengths_plus_1_d", None), + device=device, + ) + if sequence_lengths_plus_1 is not None: + if int(sequence_lengths_plus_1.numel()) != int(input_lengths.numel()): + raise ValueError( + "RTP plugin sequence_lengths_plus_1_d/input_lengths batch mismatch " + f"(sequence_lengths_plus_1_d={int(sequence_lengths_plus_1.numel())}, " + f"input_lengths={int(input_lengths.numel())})." 
+ ) + return sequence_lengths_plus_1.contiguous() + + sequence_lengths = RTPForwardContext._non_empty_int32( + getattr(attn_inputs, "sequence_lengths", None), + device=device, + ) + if sequence_lengths is not None: + if int(sequence_lengths.numel()) != int(input_lengths.numel()): + raise ValueError( + "RTP plugin sequence_lengths/input_lengths batch mismatch " + f"(sequence_lengths={int(sequence_lengths.numel())}, " + f"input_lengths={int(input_lengths.numel())})." + ) + return (sequence_lengths + input_lengths).contiguous() + + if not non_cuda_graph_mode: + sequence_lengths_plus_1 = RTPForwardContext._non_empty_int32( + getattr(attn_inputs, "sequence_lengths_plus_1_d", None), + device=device, + ) + if sequence_lengths_plus_1 is not None: + if int(sequence_lengths_plus_1.numel()) != int(input_lengths.numel()): + raise ValueError( + "RTP plugin sequence_lengths_plus_1_d/input_lengths batch mismatch " + f"(sequence_lengths_plus_1_d={int(sequence_lengths_plus_1.numel())}, " + f"input_lengths={int(input_lengths.numel())})." + ) + return sequence_lengths_plus_1.contiguous() + + raise ValueError( + "RTP decode requires attention_inputs.sequence_lengths_plus_1_d or " + "sequence_lengths for seq_lens." + ) + + @classmethod + def _resolve_plugin_block_table( + cls, + *, + attn_inputs: Any, + seq_size_per_block: int, + kernel_seq_size_per_block: int, + cg_bufs: dict | None, + in_capture: bool, + ) -> torch.Tensor | None: + del cls, seq_size_per_block, kernel_seq_size_per_block, cg_bufs, in_capture + return RTPForwardContext._select_block_table_for_layer(attn_inputs=attn_inputs) + + @staticmethod + def _build_query_start_loc_for_plugin( + *, + attn_inputs: Any, + seq_lens: torch.Tensor, + num_tokens: int, + device: torch.device, + cg_bufs: dict | None = None, + ) -> torch.Tensor: + batch_size = int(seq_lens.numel()) + if batch_size <= 0: + raise ValueError( + "RTP plugin cannot build query_start_loc with empty seq_lens." 
+ ) + + in_capture = torch.cuda.is_current_stream_capturing() + if in_capture and cg_bufs is not None: + return cg_bufs["query_start_loc"][: batch_size + 1] + + if in_capture: + raise ValueError( + "RTP plugin capture requires prewarmed cg_bufs for query_start_loc " + f"(batch={batch_size}, num_tokens={int(num_tokens)})." + ) + + qsl = RTPForwardContext._query_start_loc(attn_inputs, device=device) + if qsl is not None and qsl.numel() == batch_size + 1: + lengths = qsl[1:] - qsl[:-1] + qsl_stats = torch.stack([qsl[-1], torch.min(lengths)], dim=0).to( + device="cpu" + ) + qsl_total_tokens, qsl_min_len = [int(v) for v in qsl_stats.tolist()] + if qsl_total_tokens == int(num_tokens) and qsl_min_len > 0: + return qsl.contiguous() + + input_lengths = RTPForwardContext._non_empty_int32( + getattr(attn_inputs, "input_lengths", None), + device=device, + ) + if input_lengths is not None and int(input_lengths.numel()) == batch_size: + input_stats = torch.stack( + [torch.min(input_lengths), torch.sum(input_lengths)], + dim=0, + ).to(device="cpu") + min_input_len, total_input_len = [int(v) for v in input_stats.tolist()] + if min_input_len > 0 and total_input_len == int(num_tokens): + prefix = torch.zeros((1,), dtype=torch.int32, device=device) + return torch.cat( + [prefix, input_lengths.cumsum(dim=0)], dim=0 + ).contiguous() + + if int(num_tokens) == batch_size: + prefix = torch.arange(0, batch_size + 1, dtype=torch.int32, device=device) + return prefix.contiguous() + if batch_size == 1: + return torch.tensor([0, int(num_tokens)], dtype=torch.int32, device=device) + + raise ValueError( + "RTP plugin failed to build valid query_start_loc for plugin attention " + f"(batch={batch_size}, num_tokens={int(num_tokens)})." 
+ ) + + @classmethod + def _build_plugin_attention_metadata( + cls, + *, + attn_inputs: Any, + positions: torch.Tensor, + seq_size_per_block: int, + kernel_seq_size_per_block: int = 0, + cg_max_seq_len: int = 0, + cg_bufs: dict | None = None, + ) -> AttentionMetaData: + del kernel_seq_size_per_block + block_table = cls._resolve_plugin_block_table( + attn_inputs=attn_inputs, + seq_size_per_block=int(seq_size_per_block), + kernel_seq_size_per_block=0, + cg_bufs=cg_bufs, + in_capture=torch.cuda.is_current_stream_capturing(), + ) + if block_table is None or block_table.numel() == 0: + raise ValueError( + "RTP plugin requires kv_cache_kernel_block_id_device for plugin attention metadata." + ) + device = positions.device + is_prefill = bool(getattr(attn_inputs, "is_prefill", False)) + in_capture = torch.cuda.is_current_stream_capturing() + if in_capture and cg_bufs is None: + raise RuntimeError( + "RTP plugin capture requires prewarmed cg_bufs; metadata fallback path is disabled." + ) + seq_lens = cls._build_seq_lens(attn_inputs, device=device) + if in_capture and cg_bufs is not None: + bs_now = int(seq_lens.shape[0]) + seq_lens_buf = cg_bufs["seq_lens_i32"] + if int(seq_lens_buf.shape[0]) < bs_now: + raise RuntimeError( + "RTP plugin prewarmed seq_lens_i32 buffer is too small " + f"(buffer={int(seq_lens_buf.shape[0])}, required={bs_now})." 
+ ) + seq_lens_view = seq_lens_buf[:bs_now] + seq_lens_view.copy_(seq_lens, non_blocking=True) + seq_lens = seq_lens_view + else: + seq_lens = seq_lens.to( + device=device, dtype=torch.int32, non_blocking=True + ).contiguous() + batch_size = int(seq_lens.numel()) + + if in_capture and not is_prefill: + positions = positions[:batch_size] + num_actual_tokens = int(positions.numel()) + + query_start_loc = cls._build_query_start_loc_for_plugin( + attn_inputs=attn_inputs, + seq_lens=seq_lens, + num_tokens=num_actual_tokens, + device=device, + cg_bufs=cg_bufs, + ) + slot_mapping = cls._build_slot_mapping( + positions=positions, + query_start_loc=query_start_loc, + block_table=block_table, + seq_size_per_block=seq_size_per_block, + cg_bufs=cg_bufs, + ) + + is_dummy_warmup = False + if in_capture: + max_query_len = 1 + if cg_max_seq_len <= 0: + raise RuntimeError( + "RTP plugin cuda-graph capture requires cg_max_seq_len; " + "did you forget to thread it through RTPForwardContext.bind?" + ) + max_seq_len = int(cg_max_seq_len) + num_actual_kv_tokens = max_seq_len * batch_size + else: + query_lens = query_start_loc[1:] - query_start_loc[:-1] + stats = torch.stack( + [ + torch.max(query_lens), + torch.max(seq_lens), + torch.sum(seq_lens), + ], + dim=0, + ).to(device="cpu") + max_query_len, max_seq_len, num_actual_kv_tokens = [ + int(v) for v in stats.tolist() + ] + if max_seq_len <= 0: + is_dummy_warmup = True + max_seq_len = int(cg_max_seq_len) if cg_max_seq_len > 0 else 1 + if max_query_len <= 0: + max_query_len = 1 + + decode_md = None + prefill_md = None + if is_prefill: + prefill_md = AiterFlashAttentionPrefillMetadata( + max_query_len=max_query_len, + max_seq_len=max_seq_len, + query_start_loc=query_start_loc, + ) + else: + decode_md = AiterFlashAttentionDecodeMetadata( + max_query_len=max_query_len, + max_seq_len=max_seq_len, + query_start_loc=query_start_loc, + ) + + if in_capture and cg_bufs is not None: + bt_buf = cg_bufs["block_table_i32"] + bs_now = 
int(block_table.shape[0]) + cols_now = int(block_table.shape[1]) + if int(bt_buf.shape[0]) < bs_now or int(bt_buf.shape[1]) < cols_now: + raise RuntimeError( + "RTP plugin prewarmed block_table_i32 buffer is too small " + f"(buffer={tuple(bt_buf.shape)}, required=({bs_now}, {cols_now}))." + ) + bt_view = bt_buf[:bs_now, :cols_now] + bt_view.copy_(block_table, non_blocking=True) + block_table_i32 = bt_view + else: + block_table_i32 = block_table.to( + device=device, dtype=torch.int32, non_blocking=True + ).contiguous() + + plugin_md = AiterFlashAttentionMetadataForPluginMode( + num_actual_tokens=num_actual_tokens, + num_actual_kv_tokens=num_actual_kv_tokens, + max_query_len=max_query_len, + query_start_loc=query_start_loc, + max_seq_len=max_seq_len, + seq_lens=seq_lens, + slot_mapping=slot_mapping, + block_table=block_table_i32, + num_decodes=0 if is_prefill else batch_size, + num_decode_tokens=0 if is_prefill else num_actual_tokens, + num_prefills=batch_size if is_prefill else 0, + num_prefill_tokens=num_actual_tokens if is_prefill else 0, + num_extends=0, + num_extend_tokens=0, + decode_metadata=decode_md, + prefill_metadata=prefill_md, + extend_metadata=None, + use_cascade=False, + common_prefix_len=0, + total_tokens=0, + context=None, + ) + plugin_md.rtp_cu_seqlens_q = query_start_loc + plugin_md.is_dummy_warmup = bool(is_dummy_warmup) + prefix_lengths = getattr(attn_inputs, "prefix_lengths", None) + if ( + prefix_lengths is not None + and int(prefix_lengths.numel()) > 0 + and not in_capture + ): + plugin_md.rtp_has_prefix = bool((prefix_lengths > 0).any().item()) + else: + plugin_md.rtp_has_prefix = False + + attn_metadata = AttentionMetaData( + max_seqlen_q=max_query_len, + max_seqlen_k=max_seq_len, + block_tables=plugin_md.block_table, + slot_mapping=slot_mapping, + context_lens=seq_lens, + ) + attn_metadata.plugin_metadata = plugin_md + return attn_metadata diff --git a/tests/plugin/test_rtpllm_forward_context_semantics.py 
b/tests/plugin/test_rtpllm_forward_context_semantics.py index aeae4f2066..3d8207703f 100644 --- a/tests/plugin/test_rtpllm_forward_context_semantics.py +++ b/tests/plugin/test_rtpllm_forward_context_semantics.py @@ -43,7 +43,9 @@ def _install_forward_context_stubs(): utils_forward_context._forward_kv_cache_context = SimpleNamespace(kv_cache_data={}) utils_forward_context.reset_forward_context = lambda *args, **kwargs: None utils_forward_context.set_forward_context = lambda *args, **kwargs: None - utils_forward_context.get_forward_context = lambda *args, **kwargs: SimpleNamespace() + utils_forward_context.get_forward_context = ( + lambda *args, **kwargs: SimpleNamespace() + ) def _set_kv_cache_data(value): utils_forward_context._forward_kv_cache_context.kv_cache_data = value @@ -54,7 +56,10 @@ def _set_kv_cache_data(value): _install_forward_context_stubs() -from atom.plugin.rtpllm.utils.forward_context import RTPForwardContext # noqa: E402 +from atom.plugin.rtpllm.utils.forward_context import ( # noqa: E402 + RTPForwardContext, + RTPForwardMLAContext, +) def _make_attn_inputs( @@ -194,7 +199,7 @@ def test_plugin_attention_metadata_keeps_indexer_block_table_expanded(): is_prefill=True, ) - md = RTPForwardContext._build_plugin_attention_metadata( + md = RTPForwardMLAContext._build_plugin_attention_metadata( attn_inputs=attn_inputs, positions=torch.arange(1030, dtype=torch.int32), seq_size_per_block=1024, @@ -207,6 +212,50 @@ def test_plugin_attention_metadata_keeps_indexer_block_table_expanded(): assert md.block_tables[0, 64:68].cpu().tolist() == [512, 513, 514, 515] +def test_plugin_attention_metadata_keeps_physical_block_table_for_base_context(): + attn_inputs = _make_attn_inputs( + input_lengths=torch.tensor([1030], dtype=torch.int32), + prefix_lengths=torch.tensor([0], dtype=torch.int32), + kv_cache_block_id_device=torch.tensor([[7, 8]], dtype=torch.int32), + is_prefill=True, + ) + + md = RTPForwardContext._build_plugin_attention_metadata( + 
attn_inputs=attn_inputs, + positions=torch.arange(1030, dtype=torch.int32), + seq_size_per_block=1024, + kernel_seq_size_per_block=16, + ) + + assert md.plugin_metadata.block_table.cpu().tolist() == [[7, 8]] + assert md.block_tables.shape == (1, 2) + assert md.block_tables.cpu().tolist() == [[7, 8]] + + +def test_base_context_capture_recovers_physical_table_with_prewarmed_buffer(): + attn_inputs = _make_attn_inputs( + input_lengths=torch.tensor([1], dtype=torch.int32), + sequence_lengths=torch.tensor([35], dtype=torch.int32), + kv_cache_kernel_block_id_device=torch.tensor( + [[448, 449, 450, 451, 452, 453, 454, 455]], dtype=torch.int32 + ), + is_prefill=False, + is_cuda_graph=True, + ) + + cg_bufs = {"physical_block_table_i32": torch.empty((1, 1), dtype=torch.int32)} + block_table = RTPForwardContext._resolve_plugin_block_table( + attn_inputs=attn_inputs, + seq_size_per_block=1024, + kernel_seq_size_per_block=128, + cg_bufs=cg_bufs, + in_capture=True, + ) + + assert block_table is not None + assert block_table.cpu().tolist() == [[56]] + + def test_plugin_attention_metadata_builds_req_id_per_token(): attn_inputs = _make_attn_inputs( input_lengths=torch.tensor([2, 1], dtype=torch.int32), From d9cbe9d838772689865bb8add2244eee30bbc1dc Mon Sep 17 00:00:00 2001 From: Zhao An Date: Fri, 5 Jun 2026 03:04:57 +0000 Subject: [PATCH 04/20] fix: RTP crash when long input_len > 16384 --- atom/plugin/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/atom/plugin/config.py b/atom/plugin/config.py index ae72b4fad2..17787ccd13 100644 --- a/atom/plugin/config.py +++ b/atom/plugin/config.py @@ -409,7 +409,7 @@ def _generate_atom_config_from_rtpllm_config(config: Any): return Config( model=rtpllm_model_config.ckpt_path, - max_num_batched_tokens=max(16384, max_generate_batch_size), + max_num_batched_tokens=max(max_model_len, max_generate_batch_size), max_num_seqs=max_generate_batch_size, max_model_len=max_model_len, gpu_memory_utilization=0.9, From 
828109092ee9c546037414a6d649476a2f64572c Mon Sep 17 00:00:00 2001 From: Zhao An Date: Fri, 5 Jun 2026 09:00:57 +0000 Subject: [PATCH 05/20] fix:[RTP] making GLM5 run true Sparse MLA --- .../rtp_sparse_mla_backend.py | 621 ++++++++++++++---- atom/plugin/rtpllm/utils/forward_context.py | 24 +- .../test_rtpllm_forward_context_semantics.py | 26 + ...est_rtpllm_glm5_sparse_backend_contract.py | 173 ++++- .../test_rtpllm_glm5_wrapper_lifecycle.py | 118 +++- 5 files changed, 766 insertions(+), 196 deletions(-) diff --git a/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py b/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py index 27721be6dd..b3fddabf22 100644 --- a/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py +++ b/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py @@ -3,6 +3,7 @@ from __future__ import annotations import inspect +import os from dataclasses import dataclass from typing import Any, Optional @@ -35,6 +36,7 @@ class _AtomSparseMetadata: reduce_partial_map: torch.Tensor padded_num_heads: int head_repeat_factor: int + page_size: int class _ContractSparseMlaImpl: @@ -99,12 +101,129 @@ def __init__( self._cg_sparse_bufs: dict[str, torch.Tensor] | None = None self._cg_workspace_signature: tuple[Any, ...] | None = None + @staticmethod + def _resolve_sparse_page_size() -> int: + value = os.getenv("ATOM_RTP_GLM5_SPARSE_PAGE_SIZE", "1") + try: + page_size = int(value) + except ValueError as exc: + raise _SparseUnavailable( + f"Invalid ATOM_RTP_GLM5_SPARSE_PAGE_SIZE={value!r}." + ) from exc + if page_size <= 0: + raise _SparseUnavailable( + "GLM5 RTP sparse MLA requires positive page_size, " + f"got page_size={page_size}." 
+ ) + return page_size + + @staticmethod + def _validate_sparse_index_contract( + *, + paged_kv_indptr: torch.Tensor, + paged_kv_indices: torch.Tensor, + num_tokens: int, + page_size: int, + max_slots: int, + ) -> None: + if int(paged_kv_indptr.numel()) != num_tokens + 1: + raise _SparseUnavailable( + "GLM5 RTP sparse MLA invalid paged_kv_indptr length " + f"(got={int(paged_kv_indptr.numel())}, expected={num_tokens + 1})." + ) + if int(paged_kv_indptr[0].item()) != 0: + raise _SparseUnavailable( + "GLM5 RTP sparse MLA paged_kv_indptr[0] must be 0, " + f"got {int(paged_kv_indptr[0].item())}." + ) + if num_tokens > 0: + deltas = paged_kv_indptr[1:] - paged_kv_indptr[:-1] + if bool((deltas < 0).any().item()): + raise _SparseUnavailable( + "GLM5 RTP sparse MLA paged_kv_indptr must be non-decreasing." + ) + used = int(paged_kv_indptr[-1].item()) + if used < 0 or used > int(paged_kv_indices.numel()): + raise _SparseUnavailable( + "GLM5 RTP sparse MLA paged_kv_indptr[-1] out of range " + f"(used={used}, capacity={int(paged_kv_indices.numel())})." + ) + if used == 0: + return + used_indices = paged_kv_indices[:used] + min_index = int(used_indices.min().item()) + max_index = int(used_indices.max().item()) + if min_index < 0 or max_index >= max_slots: + raise _SparseUnavailable( + "GLM5 RTP sparse MLA produced out-of-range paged_kv_indices " + f"(min={min_index}, max={max_index}, slots={max_slots}, " + f"page_size={page_size})." + ) + + @staticmethod + def _validate_sparse_last_page_contract( + *, + paged_kv_indptr: torch.Tensor, + paged_kv_last_page_len: torch.Tensor, + num_tokens: int, + page_size: int, + ) -> None: + if int(paged_kv_last_page_len.numel()) != int(num_tokens): + raise _SparseUnavailable( + "GLM5 RTP sparse MLA invalid paged_kv_last_page_len length " + f"(got={int(paged_kv_last_page_len.numel())}, expected={int(num_tokens)})." 
+ ) + if num_tokens <= 0: + return + deltas = paged_kv_indptr[1:] - paged_kv_indptr[:-1] + active_mask = deltas > 0 + if not bool(active_mask.any().item()): + return + active_last_page_len = paged_kv_last_page_len[active_mask] + min_last_page_len = int(active_last_page_len.min().item()) + max_last_page_len = int(active_last_page_len.max().item()) + if min_last_page_len < 1 or max_last_page_len > int(page_size): + raise _SparseUnavailable( + "GLM5 RTP sparse MLA invalid paged_kv_last_page_len range " + f"(min={min_last_page_len}, max={max_last_page_len}, " + f"page_size={int(page_size)})." + ) + if int(page_size) == 1 and bool((active_last_page_len != 1).any().item()): + raise _SparseUnavailable( + "GLM5 RTP sparse MLA expects paged_kv_last_page_len==1 when page_size=1." + ) + + @staticmethod + def _to_page_indices( + *, token_indices: torch.Tensor, page_size: int, max_slots: int + ) -> torch.Tensor: + if page_size == 1: + return token_indices.to(dtype=torch.int32) + page_indices = torch.div( + token_indices.to(dtype=torch.int64), + int(page_size), + rounding_mode="floor", + ).to(dtype=torch.int32) + page_indices.clamp_(min=0, max=max(int(max_slots) - 1, 0)) + return page_indices + + @staticmethod + def _kv_token_slot_capacity(kv_cache_base: torch.Tensor) -> int: + if kv_cache_base.ndim <= 0: + return 0 + latent_dim = int(kv_cache_base.shape[-1]) if kv_cache_base.ndim >= 1 else 0 + if latent_dim <= 0: + return 0 + return int(kv_cache_base.numel() // latent_dim) + @staticmethod def _unwrap_linear_output(value: Any) -> torch.Tensor: if isinstance(value, tuple): value = value[0] if not isinstance(value, torch.Tensor): - raise TypeError(f"Expected kv_b_proj to return Tensor, got {type(value)!r}.") + raise TypeError( + f"Expected kv_b_proj to return Tensor, got {type(value)!r}." 
+ ) return value def _infer_num_heads(self, q: torch.Tensor) -> int: @@ -139,7 +258,9 @@ def _read_kv_b_proj_weight(self) -> torch.Tensor: except Exception: weight = getattr(self.kv_b_proj, "weight", None) if not isinstance(weight, torch.Tensor): - raise _SparseUnavailable("GLM5 RTP sparse MLA cannot read kv_b_proj.weight.") + raise _SparseUnavailable( + "GLM5 RTP sparse MLA cannot read kv_b_proj.weight." + ) if weight.dtype in ( getattr(torch, "float8_e4m3fn", None), getattr(torch, "float8_e4m3fnuz", None), @@ -164,9 +285,15 @@ def _get_absorbed_weights(self, q: torch.Tensor) -> _AbsorbedWeights: raise _SparseUnavailable( f"GLM5 RTP sparse MLA got invalid kv_b_proj weight shape {tuple(weight.shape)}." ) - if int(weight.shape[0]) == expected_out and int(weight.shape[1]) == self.kv_lora_rank: + if ( + int(weight.shape[0]) == expected_out + and int(weight.shape[1]) == self.kv_lora_rank + ): kv_b_weight = weight.T.contiguous() - elif int(weight.shape[1]) == expected_out and int(weight.shape[0]) == self.kv_lora_rank: + elif ( + int(weight.shape[1]) == expected_out + and int(weight.shape[0]) == self.kv_lora_rank + ): kv_b_weight = weight.contiguous() else: raise _SparseUnavailable( @@ -180,9 +307,7 @@ def _get_absorbed_weights(self, q: torch.Tensor) -> _AbsorbedWeights: num_heads, self.qk_nope_head_dim + self.v_head_dim, ) - w_uk, w_uv = kv_b_weight.split( - [self.qk_nope_head_dim, self.v_head_dim], dim=-1 - ) + w_uk, w_uv = kv_b_weight.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) absorbed = _AbsorbedWeights( w_kc=w_uk.permute(1, 2, 0).contiguous(), w_vc=w_uv.permute(1, 0, 2).contiguous(), @@ -210,7 +335,9 @@ def _apply_rope( in_capture = torch.cuda.is_current_stream_capturing() if in_capture: if self._cg_sparse_bufs is None: - raise _SparseUnavailable("GLM5 RTP sparse MLA capture requires RoPE buffers.") + raise _SparseUnavailable( + "GLM5 RTP sparse MLA capture requires RoPE buffers." 
+ ) if positions.device != q.device or positions.dtype != torch.long: raise _SparseUnavailable( "GLM5 RTP sparse MLA capture requires int64 positions on device." @@ -219,7 +346,9 @@ def _apply_rope( raise _SparseUnavailable( "GLM5 RTP sparse MLA capture requires contiguous positions." ) - q_rope = self._cg_sparse_bufs["q_rope"][: q.shape[0], : q.shape[1], : q.shape[2]] + q_rope = self._cg_sparse_bufs["q_rope"][ + : q.shape[0], : q.shape[1], : q.shape[2] + ] q_rope.copy_(q) if k_pe.dim() == 2: k_pe_rope = self._cg_sparse_bufs["k_pe_rope_2d"][ @@ -287,7 +416,9 @@ def _write_current_to_cache( try: from aiter import concat_and_cache_mla except Exception as exc: - raise _SparseUnavailable(f"aiter.concat_and_cache_mla unavailable: {exc}") from exc + raise _SparseUnavailable( + f"aiter.concat_and_cache_mla unavailable: {exc}" + ) from exc scale = self._cache_write_scale.get(compressed_kv.device) if scale is None: @@ -295,7 +426,10 @@ def _write_current_to_cache( self._cache_write_scale[compressed_kv.device] = scale in_capture = torch.cuda.is_current_stream_capturing() if in_capture: - if slot_mapping.device != compressed_kv.device or slot_mapping.dtype != torch.int64: + if ( + slot_mapping.device != compressed_kv.device + or slot_mapping.dtype != torch.int64 + ): raise _SparseUnavailable( "GLM5 RTP sparse MLA capture requires int64 slot_mapping on device." 
) @@ -332,7 +466,10 @@ def _build_req_id_per_token( query_start_loc = getattr(plugin_metadata, "rtp_cu_seqlens_q", None) if query_start_loc is None: query_start_loc = getattr(attn_metadata, "cu_seqlens_q", None) - if isinstance(query_start_loc, torch.Tensor) and int(query_start_loc.numel()) >= 2: + if ( + isinstance(query_start_loc, torch.Tensor) + and int(query_start_loc.numel()) >= 2 + ): qsl = query_start_loc.to(device=device, dtype=torch.int64) lengths = qsl[1:] - qsl[:-1] return torch.repeat_interleave( @@ -360,6 +497,10 @@ def _convert_topk_to_global( attn_metadata: Any, block_size: int, ) -> torch.Tensor: + if int(block_size) <= 0: + raise _SparseUnavailable( + f"GLM5 RTP sparse MLA requires positive block_size, got {block_size}." + ) num_tokens, topk = topk_indices.shape device = topk_indices.device block_table = _RealSparseMlaImpl._block_table(attn_metadata, device) @@ -374,10 +515,14 @@ def _convert_topk_to_global( rounding_mode="floor", ) offsets = torch.remainder(torch.clamp(token_indices, min=0), int(block_size)) - valid = valid & (req_id[:, None] >= 0) & (req_id[:, None] < block_table.shape[0]) + valid = ( + valid & (req_id[:, None] >= 0) & (req_id[:, None] < block_table.shape[0]) + ) valid = valid & (block_cols >= 0) & (block_cols < block_table.shape[1]) safe_req = torch.clamp(req_id, min=0, max=max(int(block_table.shape[0]) - 1, 0)) - safe_cols = torch.clamp(block_cols, min=0, max=max(int(block_table.shape[1]) - 1, 0)) + safe_cols = torch.clamp( + block_cols, min=0, max=max(int(block_table.shape[1]) - 1, 0) + ) block_ids = block_table.to(dtype=torch.long)[safe_req[:, None], safe_cols] valid = valid & (block_ids >= 0) global_indices = block_ids * int(block_size) + offsets @@ -393,9 +538,9 @@ def _decode_indptr( device: torch.device, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: qo_indptr = torch.arange(num_tokens + 1, device=device, dtype=torch.int32) - paged_kv_indptr = ( - torch.arange(num_tokens + 1, device=device, dtype=torch.int32) * 
int(topk) - ) + paged_kv_indptr = torch.arange( + num_tokens + 1, device=device, dtype=torch.int32 + ) * int(topk) paged_kv_last_page_len = torch.ones( (num_tokens,), device=device, dtype=torch.int32 ) @@ -452,7 +597,9 @@ def _aiter_dtype_for_tensor(tensor: torch.Tensor) -> Any: return dtypes.d_dtypes["bf16"] @staticmethod - def _aiter_dtype_for_torch_dtype(dtype: torch.dtype, *, assume_fp8: bool = False) -> Any: + def _aiter_dtype_for_torch_dtype( + dtype: torch.dtype, *, assume_fp8: bool = False + ) -> Any: try: from aiter import dtypes except Exception as exc: @@ -475,6 +622,39 @@ def _resolve_topk_for_prewarm(self) -> int: return int(value) return 2048 + @staticmethod + def _metadata_token_budget(*, num_tokens: int, topk: int) -> int: + # Sparse decode can materialize up to num_tokens * topk ragged entries. + # Use this upper bound to avoid undersized work/reduce metadata buffers. + return max(int(num_tokens) * max(int(topk), 1), 1) + + @staticmethod + def _validate_capture_sparse_buffer_capacity( + *, + sparse_bufs: dict[str, torch.Tensor], + num_tokens: int, + topk: int, + ) -> None: + needed_indices = int(num_tokens) * int(topk) + if int(sparse_bufs["paged_kv_indices"].numel()) < needed_indices: + raise _SparseUnavailable( + "GLM5 RTP sparse MLA capture paged_kv_indices buffer is too small " + f"(buffer={int(sparse_bufs['paged_kv_indices'].numel())}, " + f"required={needed_indices})." + ) + if int(sparse_bufs["qo_indptr"].numel()) < int(num_tokens) + 1: + raise _SparseUnavailable( + "GLM5 RTP sparse MLA capture qo_indptr buffer is too small." + ) + if int(sparse_bufs["paged_kv_indptr"].numel()) < int(num_tokens) + 1: + raise _SparseUnavailable( + "GLM5 RTP sparse MLA capture paged_kv_indptr buffer is too small." + ) + if int(sparse_bufs["paged_kv_last_page_len"].numel()) < int(num_tokens): + raise _SparseUnavailable( + "GLM5 RTP sparse MLA capture paged_kv_last_page_len buffer is too small." 
+ ) + def prewarm_for_cuda_graph( self, *, @@ -487,12 +667,16 @@ def prewarm_for_cuda_graph( try: from aiter import get_mla_metadata_info_v1 except Exception as exc: - raise _SparseUnavailable(f"aiter metadata prewarm unavailable: {exc}") from exc + raise _SparseUnavailable( + f"aiter metadata prewarm unavailable: {exc}" + ) from exc max_tokens = int(max_num_tokens) if max_tokens <= 0: return - num_heads = int(self.num_heads or getattr(self.mla_modules, "num_local_heads", 0) or 0) + num_heads = int( + self.num_heads or getattr(self.mla_modules, "num_local_heads", 0) or 0 + ) if num_heads <= 0: # Lazily inferred in eager path; graph capture needs a stable budget. num_heads = int(getattr(self.mla_modules, "num_heads", 0) or 1) @@ -500,11 +684,16 @@ def prewarm_for_cuda_graph( self.num_heads = num_heads padded_num_heads = max(num_heads, 16) if padded_num_heads % num_heads != 0: - padded_num_heads = ((padded_num_heads + num_heads - 1) // num_heads) * num_heads + padded_num_heads = ( + (padded_num_heads + num_heads - 1) // num_heads + ) * num_heads topk = self._resolve_topk_for_prewarm() latent_dim = self.kv_lora_rank + self.qk_rope_head_dim q_dtype = self._aiter_dtype_for_torch_dtype(query_dtype) kv_dtype = self._aiter_dtype_for_torch_dtype(query_dtype, assume_fp8=True) + metadata_budget_tokens = self._metadata_token_budget( + num_tokens=max_tokens, topk=topk + ) ( (work_meta_data_size, work_meta_data_type), (work_indptr_size, work_indptr_type), @@ -513,7 +702,7 @@ def prewarm_for_cuda_graph( (reduce_final_map_size, reduce_final_map_type), (reduce_partial_map_size, reduce_partial_map_type), ) = get_mla_metadata_info_v1( - max(max_tokens, 1), + metadata_budget_tokens, 1, padded_num_heads, q_dtype, @@ -524,9 +713,15 @@ def prewarm_for_cuda_graph( self._cg_sparse_bufs = { "qo_indptr": torch.arange(max_tokens + 1, device=device, dtype=torch.int32), "sparse_seqlen": torch.empty(max_tokens, device=device, dtype=torch.int32), - "paged_kv_indptr": torch.empty(max_tokens + 1, 
device=device, dtype=torch.int32), - "paged_kv_last_page_len": torch.ones(max_tokens, device=device, dtype=torch.int32), - "paged_kv_indices": torch.empty(max_tokens * topk, device=device, dtype=torch.int32), + "paged_kv_indptr": torch.empty( + max_tokens + 1, device=device, dtype=torch.int32 + ), + "paged_kv_last_page_len": torch.ones( + max_tokens, device=device, dtype=torch.int32 + ), + "paged_kv_indices": torch.empty( + max_tokens * topk, device=device, dtype=torch.int32 + ), "q_rope": torch.empty( max_tokens, num_heads, @@ -541,27 +736,51 @@ def prewarm_for_cuda_graph( max_tokens, 1, self.qk_rope_head_dim, device=device, dtype=query_dtype ), "k_pe_rope_heads": torch.empty( - max_tokens, num_heads, self.qk_rope_head_dim, device=device, dtype=query_dtype + max_tokens, + num_heads, + self.qk_rope_head_dim, + device=device, + dtype=query_dtype, ), "q_latent_nope_t": torch.empty( - num_heads, max_tokens, self.kv_lora_rank, device=device, dtype=query_dtype + num_heads, + max_tokens, + self.kv_lora_rank, + device=device, + dtype=query_dtype, ), "q_latent": torch.empty( max_tokens, num_heads, latent_dim, device=device, dtype=query_dtype ), "q_for_kernel": torch.empty( - max_tokens, padded_num_heads, latent_dim, device=device, dtype=query_dtype + max_tokens, + padded_num_heads, + latent_dim, + device=device, + dtype=query_dtype, ), "latent_output": torch.empty( - max_tokens, padded_num_heads, self.kv_lora_rank, device=device, dtype=query_dtype + max_tokens, + padded_num_heads, + self.kv_lora_rank, + device=device, + dtype=query_dtype, ), "final_output_t": torch.empty( num_heads, max_tokens, self.v_head_dim, device=device, dtype=query_dtype ), - "work_meta_data": torch.empty(work_meta_data_size, dtype=work_meta_data_type, device=device), - "work_indptr": torch.empty(work_indptr_size, dtype=work_indptr_type, device=device), - "work_info_set": torch.empty(work_info_set_size, dtype=work_info_set_type, device=device), - "reduce_indptr": torch.empty(reduce_indptr_size, 
dtype=reduce_indptr_type, device=device), + "work_meta_data": torch.empty( + work_meta_data_size, dtype=work_meta_data_type, device=device + ), + "work_indptr": torch.empty( + work_indptr_size, dtype=work_indptr_type, device=device + ), + "work_info_set": torch.empty( + work_info_set_size, dtype=work_info_set_type, device=device + ), + "reduce_indptr": torch.empty( + reduce_indptr_size, dtype=reduce_indptr_type, device=device + ), "reduce_final_map": torch.empty( reduce_final_map_size, dtype=reduce_final_map_type, device=device ), @@ -593,11 +812,12 @@ def _build_atom_sparse_metadata( try: from aiter import get_mla_metadata_info_v1, get_mla_metadata_v1 from atom.plugin.attention_mla_sparse import ( - generate_sparse_seqlen_triton, triton_convert_req_index_to_global_index, ) except Exception as exc: - raise _SparseUnavailable(f"ATOM sparse MLA metadata helpers unavailable: {exc}") from exc + raise _SparseUnavailable( + f"ATOM sparse MLA metadata helpers unavailable: {exc}" + ) from exc plugin_metadata = getattr(attn_metadata, "plugin_metadata", None) if plugin_metadata is None: @@ -614,7 +834,10 @@ def _build_atom_sparse_metadata( query_start_loc = getattr(plugin_metadata, "query_start_loc", None) if query_start_loc is None: query_start_loc = getattr(plugin_metadata, "rtp_cu_seqlens_q", None) - if not isinstance(query_start_loc, torch.Tensor) or int(query_start_loc.numel()) < 2: + if ( + not isinstance(query_start_loc, torch.Tensor) + or int(query_start_loc.numel()) < 2 + ): raise _SparseUnavailable("GLM5 RTP sparse MLA requires query_start_loc.") if in_capture: if query_start_loc.device != device or query_start_loc.dtype != torch.int32: @@ -622,7 +845,9 @@ def _build_atom_sparse_metadata( "GLM5 RTP sparse MLA capture requires int32 query_start_loc on device." 
) else: - query_start_loc = query_start_loc.to(device=device, dtype=torch.int32).contiguous() + query_start_loc = query_start_loc.to( + device=device, dtype=torch.int32 + ).contiguous() seq_lens = getattr(plugin_metadata, "seq_lens", None) if seq_lens is None: @@ -630,7 +855,9 @@ def _build_atom_sparse_metadata( if not isinstance(seq_lens, torch.Tensor) or int(seq_lens.numel()) + 1 != int( query_start_loc.numel() ): - raise _SparseUnavailable("GLM5 RTP sparse MLA requires seq_lens per request.") + raise _SparseUnavailable( + "GLM5 RTP sparse MLA requires seq_lens per request." + ) if in_capture: if seq_lens.device != device or seq_lens.dtype != torch.int32: raise _SparseUnavailable( @@ -641,7 +868,9 @@ def _build_atom_sparse_metadata( if in_capture: if not isinstance(cg_bufs, dict) or sparse_bufs is None: - raise _SparseUnavailable("GLM5 RTP sparse MLA capture requires prewarmed buffers.") + raise _SparseUnavailable( + "GLM5 RTP sparse MLA capture requires prewarmed buffers." + ) req_id = cg_bufs.get("seq_id_i32", None) if not isinstance(req_id, torch.Tensor): raise _SparseUnavailable( @@ -650,13 +879,18 @@ def _build_atom_sparse_metadata( req_id = req_id[:num_tokens] block_table = getattr(plugin_metadata, "block_table", None) if not isinstance(block_table, torch.Tensor): - raise _SparseUnavailable("GLM5 RTP sparse MLA capture requires block_table.") + raise _SparseUnavailable( + "GLM5 RTP sparse MLA capture requires block_table." + ) if block_table.device != device or block_table.dtype != torch.int32: raise _SparseUnavailable( "GLM5 RTP sparse MLA capture requires int32 block_table on device." ) topk_indices_i32 = topk_indices - if topk_indices_i32.device != device or topk_indices_i32.dtype != torch.int32: + if ( + topk_indices_i32.device != device + or topk_indices_i32.dtype != torch.int32 + ): raise _SparseUnavailable( "GLM5 RTP sparse MLA capture requires int32 topk_indices on device." 
) @@ -664,6 +898,11 @@ def _build_atom_sparse_metadata( raise _SparseUnavailable( "GLM5 RTP sparse MLA capture requires contiguous topk_indices." ) + self._validate_capture_sparse_buffer_capacity( + sparse_bufs=sparse_bufs, + num_tokens=num_tokens, + topk=topk, + ) sparse_seqlen = sparse_bufs["sparse_seqlen"][:num_tokens] torch.clamp(seq_lens[:num_tokens], min=0, max=topk, out=sparse_seqlen) max_query_len_for_sparse = 1 @@ -672,60 +911,78 @@ def _build_atom_sparse_metadata( dtype=torch.int32 ) block_table = self._block_table(attn_metadata, device).to(dtype=torch.int32) - topk_indices_i32 = topk_indices.to(device=device, dtype=torch.int32).contiguous() - query_lens = (query_start_loc[1:] - query_start_loc[:-1]).contiguous() - max_query_len_for_sparse = ( - int(torch.max(query_lens).detach().cpu().item()) if num_tokens else 1 - ) - - if device.type == "cpu": - sparse_seqlen = self._generate_sparse_seqlen_torch( - query_lens=query_lens, - seq_lens=seq_lens, - query_start_loc=query_start_loc, - topk=topk, - num_tokens=num_tokens, - ) - else: - sparse_seqlen = generate_sparse_seqlen_triton( - query_lens, - seq_lens, - query_start_loc, - topk, - num_tokens, - max_query_len_for_sparse, - ) + topk_indices_i32 = topk_indices.to( + device=device, dtype=torch.int32 + ).contiguous() + # Keep prefill aligned with ATOM sparse metadata contract: token-ragged + # representation always uses max_q_len=1. + max_query_len_for_sparse = 1 + # Derive sparse lengths directly from indexer output validity. This is + # robust for chunked prefill where seq_lens may be chunk-local. 
+ sparse_seqlen = torch.sum(topk_indices_i32 >= 0, dim=1, dtype=torch.int32) if in_capture: qo_indptr = sparse_bufs["qo_indptr"][: num_tokens + 1] paged_kv_indptr = sparse_bufs["paged_kv_indptr"][: num_tokens + 1] paged_kv_indptr[0].zero_() paged_kv_last_page_len = sparse_bufs["paged_kv_last_page_len"][:num_tokens] - if int(sparse_bufs["paged_kv_indices"].numel()) < num_tokens * topk: - raise _SparseUnavailable( - "GLM5 RTP sparse MLA capture paged_kv_indices buffer is too small." - ) paged_kv_indices = sparse_bufs["paged_kv_indices"][: num_tokens * topk] else: qo_indptr = torch.arange(num_tokens + 1, device=device, dtype=torch.int32) - paged_kv_indptr = torch.zeros((num_tokens + 1,), device=device, dtype=torch.int32) - paged_kv_last_page_len = torch.ones((num_tokens,), device=device, dtype=torch.int32) - paged_kv_indices = torch.zeros((num_tokens * topk,), device=device, dtype=torch.int32) + paged_kv_indptr = torch.zeros( + (num_tokens + 1,), device=device, dtype=torch.int32 + ) + paged_kv_last_page_len = torch.ones( + (num_tokens,), device=device, dtype=torch.int32 + ) + paged_kv_indices = torch.zeros( + (num_tokens * topk,), device=device, dtype=torch.int32 + ) torch.cumsum(sparse_seqlen, dim=0, out=paged_kv_indptr[1:]) - triton_convert_req_index_to_global_index( - req_id, - block_table, - topk_indices_i32, - paged_kv_indptr, - paged_kv_indices, - BLOCK_SIZE=int(block_size), - NUM_TOPK_TOKENS=topk, - ) + if not in_capture and int(block_size) <= 0: + raise _SparseUnavailable( + f"GLM5 RTP sparse MLA requires positive block_size, got {block_size}." + ) + + if not in_capture and topk >= 2048: + # Debug-safe path for long-context GLM5: avoid Triton req->global + # conversion kernel first, because if this step writes invalid data it + # can hard-crash GPU before Python can surface an exception. 
+ global_topk = self._convert_topk_to_global( + topk_indices=topk_indices_i32, + attn_metadata=attn_metadata, + block_size=int(block_size), + ) + token_k = torch.arange(topk, device=device, dtype=torch.int32).unsqueeze(0) + valid_mask = token_k < sparse_seqlen.unsqueeze(1) + flattened = global_topk.masked_select(valid_mask) + expected = int(paged_kv_indptr[-1].item()) if int(num_tokens) > 0 else 0 + if int(flattened.numel()) != expected: + raise _SparseUnavailable( + "GLM5 RTP sparse MLA inconsistent sparse metadata size " + f"(flattened={int(flattened.numel())}, expected={expected})." + ) + if expected > 0: + paged_kv_indices[:expected].copy_(flattened.to(dtype=torch.int32)) + if int(paged_kv_indices.numel()) > expected: + paged_kv_indices[expected:].zero_() + else: + triton_convert_req_index_to_global_index( + req_id, + block_table, + topk_indices_i32, + paged_kv_indptr, + paged_kv_indices, + BLOCK_SIZE=int(block_size), + NUM_TOPK_TOKENS=topk, + ) padded_num_heads = max(num_heads, 16) if padded_num_heads % num_heads != 0: - padded_num_heads = ((padded_num_heads + num_heads - 1) // num_heads) * num_heads + padded_num_heads = ( + (padded_num_heads + num_heads - 1) // num_heads + ) * num_heads head_repeat_factor = padded_num_heads // num_heads q_dtype = self._aiter_dtype_for_tensor(q_latent) kv_dtype = self._aiter_dtype_for_tensor(kv_cache_base) @@ -737,6 +994,13 @@ def _build_atom_sparse_metadata( reduce_final_map = sparse_bufs["reduce_final_map"] reduce_partial_map = sparse_bufs["reduce_partial_map"] else: + used_sparse_entries = ( + int(paged_kv_indptr[-1].item()) if int(num_tokens) > 0 else 0 + ) + metadata_budget_tokens = max( + self._metadata_token_budget(num_tokens=num_tokens, topk=topk), + used_sparse_entries, + ) ( (work_meta_data_size, work_meta_data_type), (work_indptr_size, work_indptr_type), @@ -745,7 +1009,7 @@ def _build_atom_sparse_metadata( (reduce_final_map_size, reduce_final_map_type), (reduce_partial_map_size, reduce_partial_map_type), ) = 
get_mla_metadata_info_v1( - max(num_tokens, 1), + metadata_budget_tokens, 1, padded_num_heads, q_dtype, @@ -753,16 +1017,70 @@ def _build_atom_sparse_metadata( is_sparse=True, fast_mode=True, ) - work_meta_data = torch.empty(work_meta_data_size, dtype=work_meta_data_type, device=device) - work_indptr = torch.empty(work_indptr_size, dtype=work_indptr_type, device=device) - work_info_set = torch.empty(work_info_set_size, dtype=work_info_set_type, device=device) - reduce_indptr = torch.empty(reduce_indptr_size, dtype=reduce_indptr_type, device=device) + work_meta_data = torch.empty( + work_meta_data_size, dtype=work_meta_data_type, device=device + ) + work_indptr = torch.empty( + work_indptr_size, dtype=work_indptr_type, device=device + ) + work_info_set = torch.empty( + work_info_set_size, dtype=work_info_set_type, device=device + ) + reduce_indptr = torch.empty( + reduce_indptr_size, dtype=reduce_indptr_type, device=device + ) reduce_final_map = torch.empty( reduce_final_map_size, dtype=reduce_final_map_type, device=device ) reduce_partial_map = torch.empty( reduce_partial_map_size, dtype=reduce_partial_map_type, device=device ) + requested_page_size = self._resolve_sparse_page_size() + kv_token_slots = self._kv_token_slot_capacity(kv_cache_base) + page_size = requested_page_size + if ( + in_capture + and requested_page_size > 1 + and kv_token_slots % int(requested_page_size) != 0 + ): + # CUDA graph capture uses warmup shapes that may not be page-aligned; + # keep capture alive by using token page mode. + page_size = 1 + if page_size > 1: + max_page_slots = max( + (kv_token_slots + int(page_size) - 1) // int(page_size), + 1, + ) + if in_capture: + # Capture-safe: avoid host sync (e.g. item/min/max) in graph capture. 
+ if int(paged_kv_indices.numel()) > 0: + paged_kv_indices.floor_divide_(int(page_size)) + paged_kv_indices.clamp_(min=0, max=max_page_slots - 1) + else: + used_now = int(paged_kv_indptr[-1].item()) if num_tokens > 0 else 0 + if used_now > 0: + paged_kv_indices[:used_now] = self._to_page_indices( + token_indices=paged_kv_indices[:used_now], + page_size=page_size, + max_slots=max_page_slots, + ) + else: + max_page_slots = int(kv_token_slots) + + if in_capture and int(paged_kv_indices.numel()) > 0: + # Capture path cannot run host-synced range checks; clamp indices into + # the current kv slot range to avoid kernel-side OOB accesses. + paged_kv_indices.clamp_(min=0, max=max(int(max_page_slots) - 1, 0)) + + if not in_capture: + self._validate_sparse_index_contract( + paged_kv_indptr=paged_kv_indptr, + paged_kv_indices=paged_kv_indices, + num_tokens=num_tokens, + page_size=page_size, + max_slots=max_page_slots, + ) + get_mla_metadata_v1( qo_indptr, paged_kv_indptr, @@ -776,7 +1094,7 @@ def _build_atom_sparse_metadata( reduce_indptr, reduce_final_map, reduce_partial_map, - page_size=1, + page_size=page_size, kv_granularity=16, max_seqlen_qo=max_query_len_for_sparse, uni_seqlen_qo=max_query_len_for_sparse, @@ -797,6 +1115,7 @@ def _build_atom_sparse_metadata( reduce_partial_map=reduce_partial_map, padded_num_heads=padded_num_heads, head_repeat_factor=head_repeat_factor, + page_size=page_size, ) def _run_sparse_decode( @@ -808,50 +1127,14 @@ def _run_sparse_decode( attn_metadata: Any, block_size: int, ) -> torch.Tensor: - if torch.cuda.is_current_stream_capturing(): - return self._run_aiter_sparse_decode( - q_latent=q_latent, - kv_cache_base=kv_cache_base, - topk_indices=topk_indices, - attn_metadata=attn_metadata, - block_size=block_size, - ) - plugin_metadata = getattr(attn_metadata, "plugin_metadata", None) - is_prefill = bool(getattr(plugin_metadata, "num_prefills", 0) or 0) - try: - from flash_mla import flash_mla_sparse_fwd - except Exception as exc: - if 
is_prefill: - raise _SparseUnavailable( - "GLM5 RTP sparse MLA prefill requires flash_mla_sparse_fwd; " - "refusing to run prefill through the decode kernel." - ) from exc - return self._run_aiter_sparse_decode( - q_latent=q_latent, - kv_cache_base=kv_cache_base, - topk_indices=topk_indices, - attn_metadata=attn_metadata, - block_size=block_size, - ) - - latent_dim = int(q_latent.shape[-1]) - global_topk = self._convert_topk_to_global( + # Keep GLM5 sparse path aligned with ATOM native MLA kernels. + return self._run_aiter_sparse_decode( + q_latent=q_latent, + kv_cache_base=kv_cache_base, topk_indices=topk_indices, attn_metadata=attn_metadata, block_size=block_size, ) - try: - kv_buffer = kv_cache_base.reshape(-1, latent_dim) - output, _, _ = flash_mla_sparse_fwd( - q_latent, - kv_buffer, - global_topk.contiguous().unsqueeze(1), - self.scale, - d_v=self.kv_lora_rank, - ) - except Exception as exc: - raise _SparseUnavailable(f"flash_mla_sparse_fwd failed: {exc}") from exc - return output def _run_aiter_sparse_decode( self, @@ -865,7 +1148,9 @@ def _run_aiter_sparse_decode( try: from aiter.mla import mla_decode_fwd except Exception as exc: - raise _SparseUnavailable(f"aiter.mla_decode_fwd unavailable: {exc}") from exc + raise _SparseUnavailable( + f"aiter.mla_decode_fwd unavailable: {exc}" + ) from exc num_tokens, num_heads, latent_dim = q_latent.shape sparse_meta = self._build_atom_sparse_metadata( @@ -876,6 +1161,7 @@ def _run_aiter_sparse_decode( block_size=block_size, ) in_capture = torch.cuda.is_current_stream_capturing() + page_size = int(sparse_meta.page_size) if sparse_meta.head_repeat_factor > 1: if in_capture and self._cg_sparse_bufs is not None: q_for_kernel = self._cg_sparse_bufs["q_for_kernel"][ @@ -902,7 +1188,43 @@ def _run_aiter_sparse_decode( device=q_latent.device, ) try: - kv_buffer = kv_cache_base.reshape(-1, 1, 1, latent_dim) + if page_size == 1: + kv_buffer = kv_cache_base.reshape(-1, 1, 1, latent_dim) + else: + kv_slots = 
int(kv_cache_base.shape[0]) + padded_slots = ( + (kv_slots + int(page_size) - 1) // int(page_size) + ) * int(page_size) + if padded_slots != kv_slots: + if in_capture: + raise _SparseUnavailable( + "GLM5 RTP sparse MLA kv buffer cannot be reshaped by " + "page_size during capture " + f"(kv_slots={kv_slots}, page_size={page_size})." + ) + pad_shape = list(kv_cache_base.shape) + pad_shape[0] = padded_slots - kv_slots + pad = torch.zeros( + pad_shape, + dtype=kv_cache_base.dtype, + device=kv_cache_base.device, + ) + kv_cache_base = torch.cat((kv_cache_base, pad), dim=0) + kv_buffer = kv_cache_base.reshape(-1, page_size, 1, latent_dim) + if not in_capture and int(sparse_meta.paged_kv_indices.numel()) > 0: + self._validate_sparse_index_contract( + paged_kv_indptr=sparse_meta.paged_kv_indptr, + paged_kv_indices=sparse_meta.paged_kv_indices, + num_tokens=num_tokens, + page_size=page_size, + max_slots=int(kv_buffer.shape[0]), + ) + self._validate_sparse_last_page_contract( + paged_kv_indptr=sparse_meta.paged_kv_indptr, + paged_kv_last_page_len=sparse_meta.paged_kv_last_page_len, + num_tokens=num_tokens, + page_size=page_size, + ) mla_decode_fwd( q_for_kernel, kv_buffer, @@ -913,7 +1235,7 @@ def _run_aiter_sparse_decode( sparse_meta.paged_kv_last_page_len, 1, sm_scale=self.scale, - page_size=1, + page_size=page_size, work_meta_data=sparse_meta.work_meta_data, work_indptr=sparse_meta.work_indptr, work_info_set=sparse_meta.work_info_set, @@ -944,7 +1266,9 @@ def forward( del layer_id if attn_metadata is None: raise _SparseUnavailable("GLM5 RTP sparse MLA requires attn_metadata.") - if getattr(getattr(attn_metadata, "plugin_metadata", None), "is_dummy_warmup", False): + if getattr( + getattr(attn_metadata, "plugin_metadata", None), "is_dummy_warmup", False + ): return q.new_zeros((q.shape[0], q.shape[1], self.v_head_dim)) q_rope, k_pe_rope = self._apply_rope(q, k_pe, positions) kv_cache_base = self._write_current_to_cache( @@ -959,7 +1283,9 @@ def forward( in_capture = 
torch.cuda.is_current_stream_capturing() if in_capture: if self._cg_sparse_bufs is None: - raise _SparseUnavailable("GLM5 RTP sparse MLA capture requires q buffers.") + raise _SparseUnavailable( + "GLM5 RTP sparse MLA capture requires q buffers." + ) if q_nope.dtype != absorbed.w_kc.dtype: raise _SparseUnavailable( "GLM5 RTP sparse MLA capture requires q_nope dtype to match absorbed weights." @@ -997,7 +1323,9 @@ def forward( plugin_metadata = getattr(attn_metadata, "plugin_metadata", None) block_size = int(getattr(plugin_metadata, "sparse_block_size", 0) or 0) if block_size <= 0: - raise _SparseUnavailable("GLM5 RTP sparse MLA requires physical block size.") + raise _SparseUnavailable( + "GLM5 RTP sparse MLA requires physical block size." + ) latent_output = self._run_sparse_decode( q_latent=q_latent, kv_cache_base=kv_cache_base, @@ -1181,16 +1509,27 @@ def forward( positions: Optional[torch.Tensor] = None, ) -> torch.Tensor: attn_metadata = self._get_attn_metadata() - if getattr(getattr(attn_metadata, "plugin_metadata", None), "is_dummy_warmup", False): + if getattr( + getattr(attn_metadata, "plugin_metadata", None), "is_dummy_warmup", False + ): return q.new_zeros((q.shape[0], q.shape[1], self.v_head_dim)) if topk_indices is None: + plugin_metadata = getattr(attn_metadata, "plugin_metadata", None) + num_prefills = int(getattr(plugin_metadata, "num_prefills", 0) or 0) + if num_prefills <= 0: + raise _SparseUnavailable( + "GLM5 RTP sparse MLA decode requires topk_indices; " + "refusing dense fallback." + ) return self._dense_forward( q, compressed_kv, k_pe, kv_cache, layer_id, None, positions ) self._validate_topk_indices(q, topk_indices) - if self._default_mock or not callable(getattr(self.sparse_impl, "forward", None)): + if self._default_mock or not callable( + getattr(self.sparse_impl, "forward", None) + ): raise _SparseUnavailable( "GLM5 RTP sparse MLA is unavailable; refusing dense fallback." 
) @@ -1210,13 +1549,11 @@ def forward( layer_id, **kwargs, ) - except _SparseUnavailable: - plugin_metadata = getattr(attn_metadata, "plugin_metadata", None) - if bool(getattr(plugin_metadata, "num_prefills", 0) or 0): - return self._dense_forward( - q, compressed_kv, k_pe, kv_cache, layer_id, topk_indices, positions - ) - raise + except _SparseUnavailable as exc: + raise _SparseUnavailable( + "GLM5 RTP sparse MLA unavailable; dense fallback is disabled. " + f"root_cause={exc}" + ) from exc def rtp_sparse_attn_indexer( diff --git a/atom/plugin/rtpllm/utils/forward_context.py b/atom/plugin/rtpllm/utils/forward_context.py index 4517262d9c..6295f6dc47 100644 --- a/atom/plugin/rtpllm/utils/forward_context.py +++ b/atom/plugin/rtpllm/utils/forward_context.py @@ -626,6 +626,20 @@ def _build_seq_lens(attn_inputs: Any, *, device: torch.device) -> torch.Tensor: ) is_prefill = bool(getattr(attn_inputs, "is_prefill", False)) if is_prefill: + # For chunked prefill, prefix_lengths can remain per-chunk while + # sequence_lengths_plus_1_d tracks the true cumulative context length. + sequence_lengths_plus_1 = RTPForwardContext._non_empty_int32( + getattr(attn_inputs, "sequence_lengths_plus_1_d", None), + device=device, + ) + if sequence_lengths_plus_1 is not None: + if int(sequence_lengths_plus_1.numel()) != int(input_lengths.numel()): + raise ValueError( + "RTP plugin sequence_lengths_plus_1_d/input_lengths batch mismatch " + f"(sequence_lengths_plus_1_d={int(sequence_lengths_plus_1.numel())}, " + f"input_lengths={int(input_lengths.numel())})." 
+ ) + return sequence_lengths_plus_1.contiguous() prefix_lengths = RTPForwardContext._non_empty_int32( getattr(attn_inputs, "prefix_lengths_d", None), device=device, @@ -1221,6 +1235,7 @@ def _build_plugin_attention_metadata( plugin_md.topk_tokens = 0 plugin_md.sparse_block_size = int(seq_size_per_block) plugin_md.cg_bufs = cg_bufs + plugin_md.positions = positions cu_seqlen_ks = None cu_seqlen_ke = None if is_prefill: @@ -1715,8 +1730,15 @@ def _attach_mla_layer_caches( getattr(forward_context, "rtp_seq_size_per_block", 0) or getattr(forward_context, "rtp_kernel_seq_size_per_block", 0) or getattr(get_current_atom_config(), "kv_cache_block_size", 0) - or 1 ) + if block_size <= 0: + raise ValueError( + "RTP plugin requires positive block_size for MLA indexer cache " + f"(layer={layer_num}, rtp_seq_size_per_block=" + f"{getattr(forward_context, 'rtp_seq_size_per_block', 0)}, " + "rtp_kernel_seq_size_per_block=" + f"{getattr(forward_context, 'rtp_kernel_seq_size_per_block', 0)})." + ) indexer_cache_tensor = RTPForwardContext._build_fallback_indexer_cache( cache_owner=cache_owner, layer_cache=layer_cache, diff --git a/tests/plugin/test_rtpllm_forward_context_semantics.py b/tests/plugin/test_rtpllm_forward_context_semantics.py index 3d8207703f..af4f03ad6a 100644 --- a/tests/plugin/test_rtpllm_forward_context_semantics.py +++ b/tests/plugin/test_rtpllm_forward_context_semantics.py @@ -59,6 +59,7 @@ def _set_kv_cache_data(value): from atom.plugin.rtpllm.utils.forward_context import ( # noqa: E402 RTPForwardContext, RTPForwardMLAContext, + RTPForwardQwen35HybridContext, ) @@ -212,6 +213,31 @@ def test_plugin_attention_metadata_keeps_indexer_block_table_expanded(): assert md.block_tables[0, 64:68].cpu().tolist() == [512, 513, 514, 515] +def test_qwen35_context_does_not_use_glm5_indexer_block_expansion(): + block_table = torch.tensor([[7, 8]], dtype=torch.int32) + + qwen_block_tables = RTPForwardQwen35HybridContext._build_indexer_block_tables( + block_table_i32=block_table, 
+ seq_size_per_block=1024, + kernel_seq_size_per_block=16, + cg_max_seq_len=0, + in_capture=False, + cg_bufs=None, + ) + glm5_block_tables = RTPForwardMLAContext._build_indexer_block_tables( + block_table_i32=block_table, + seq_size_per_block=1024, + kernel_seq_size_per_block=16, + cg_max_seq_len=0, + in_capture=False, + cg_bufs=None, + ) + + assert qwen_block_tables.shape == (1, 2) + assert qwen_block_tables.cpu().tolist() == [[7, 8]] + assert glm5_block_tables.shape[1] > qwen_block_tables.shape[1] + + def test_plugin_attention_metadata_keeps_physical_block_table_for_base_context(): attn_inputs = _make_attn_inputs( input_lengths=torch.tensor([1030], dtype=torch.int32), diff --git a/tests/plugin/test_rtpllm_glm5_sparse_backend_contract.py b/tests/plugin/test_rtpllm_glm5_sparse_backend_contract.py index 6d6147a683..99fe9516d1 100644 --- a/tests/plugin/test_rtpllm_glm5_sparse_backend_contract.py +++ b/tests/plugin/test_rtpllm_glm5_sparse_backend_contract.py @@ -8,10 +8,7 @@ import torch - -_SPARSE_BACKEND_MODULE = ( - "atom.plugin.rtpllm.attention_backend.rtp_sparse_mla_backend" -) +_SPARSE_BACKEND_MODULE = "atom.plugin.rtpllm.attention_backend.rtp_sparse_mla_backend" _FORBIDDEN_CUDA_SPARSE_MODULES = ( "flashmla_sparse", "flash_mla", @@ -25,7 +22,9 @@ def _guard_sparse_kernel_imports(monkeypatch): def _guarded_import(name, *args, **kwargs): if any(part in _FORBIDDEN_CUDA_SPARSE_MODULES for part in name.split(".")): - raise AssertionError(f"M2 sparse contract must not import CUDA sparse kernel: {name}") + raise AssertionError( + f"M2 sparse contract must not import CUDA sparse kernel: {name}" + ) return original_import(name, *args, **kwargs) monkeypatch.setattr(builtins, "__import__", _guarded_import) @@ -183,13 +182,17 @@ def _build_backend(backend_cls, dense_backend, sparse_impl): params = inspect.signature(backend_cls).parameters kwargs = {} if "dense_backend" not in params: - raise AssertionError("RTPSparseMlaBackend must accept dense_backend= for dense fallback") 
+ raise AssertionError( + "RTPSparseMlaBackend must accept dense_backend= for dense fallback" + ) kwargs["dense_backend"] = dense_backend if "sparse_impl" in params: kwargs["sparse_impl"] = sparse_impl else: - raise AssertionError("RTPSparseMlaBackend must accept a mock sparse impl injection") + raise AssertionError( + "RTPSparseMlaBackend must accept a mock sparse impl injection" + ) if "v_head_dim" in params: kwargs["v_head_dim"] = dense_backend.v_head_dim @@ -228,6 +231,19 @@ def test_sparse_backend_passes_topk_through_unchanged(monkeypatch): def test_sparse_backend_falls_back_to_dense_when_topk_is_none(monkeypatch): backend_cls = _load_sparse_backend(monkeypatch) + forward_context_mod = sys.modules["atom.utils.forward_context"] + fake_forward_context = SimpleNamespace( + context=SimpleNamespace(is_dummy_run=False), + attn_metadata=SimpleNamespace( + plugin_metadata=SimpleNamespace(num_prefills=1, is_dummy_warmup=False) + ), + ) + monkeypatch.setattr( + forward_context_mod, + "get_forward_context", + lambda: fake_forward_context, + raising=False, + ) dense_backend = _FakeDenseBackend() sparse_impl = _FakeSparseImpl() backend = _build_backend(backend_cls, dense_backend, sparse_impl) @@ -248,6 +264,37 @@ def test_sparse_backend_falls_back_to_dense_when_topk_is_none(monkeypatch): assert dense_backend.calls[0]["topk_indices"] is None +def test_sparse_backend_decode_without_topk_raises(monkeypatch): + backend_cls = _load_sparse_backend(monkeypatch) + module = importlib.import_module(_SPARSE_BACKEND_MODULE) + forward_context_mod = sys.modules["atom.utils.forward_context"] + fake_forward_context = SimpleNamespace( + context=SimpleNamespace(is_dummy_run=False), + attn_metadata=SimpleNamespace( + plugin_metadata=SimpleNamespace(num_prefills=0, is_dummy_warmup=False) + ), + ) + monkeypatch.setattr( + forward_context_mod, + "get_forward_context", + lambda: fake_forward_context, + raising=False, + ) + dense_backend = _FakeDenseBackend() + sparse_impl = _FakeSparseImpl() + 
backend = _build_backend(backend_cls, dense_backend, sparse_impl) + q, compressed_kv, k_pe, kv_cache, layer_id = _make_inputs() + + try: + backend.forward(q, compressed_kv, k_pe, kv_cache, layer_id, topk_indices=None) + except module._SparseUnavailable as exc: + assert "decode requires topk_indices" in str(exc) + else: + raise AssertionError("Expected missing decode topk_indices to raise") + assert dense_backend.calls == [] + assert sparse_impl.calls == [] + + def test_sparse_backend_threads_kv_cache_and_layer_id_to_sparse_impl(monkeypatch): backend_cls = _load_sparse_backend(monkeypatch) dense_backend = _FakeDenseBackend() @@ -292,7 +339,7 @@ def test_sparse_backend_pulls_attn_metadata_from_forward_context(monkeypatch): assert sparse_impl.calls[0]["attn_metadata"] is attn_metadata -def test_sparse_backend_prefill_missing_flashmla_falls_back_to_dense(monkeypatch): +def test_sparse_backend_prefill_missing_sparse_kernel_raises(monkeypatch): backend_cls = _load_sparse_backend(monkeypatch) module = importlib.import_module(_SPARSE_BACKEND_MODULE) forward_context_mod = sys.modules["atom.utils.forward_context"] @@ -320,11 +367,16 @@ def forward(self, *args, **kwargs): q, compressed_kv, k_pe, kv_cache, layer_id = _make_inputs() topk = torch.tensor([[1, 0], [0, 1], [1, 1]], dtype=torch.int32) - output = backend.forward(q, compressed_kv, k_pe, kv_cache, layer_id, topk_indices=topk) + try: + backend.forward(q, compressed_kv, k_pe, kv_cache, layer_id, topk_indices=topk) + except module._SparseUnavailable: + pass + else: + raise AssertionError( + "prefill sparse unavailability must not fall back to dense" + ) - assert output.shape == (3, 2, dense_backend.v_head_dim) - assert len(dense_backend.calls) == 1 - assert dense_backend.calls[0]["topk_indices"] is topk + assert len(dense_backend.calls) == 0 def test_sparse_backend_decode_missing_sparse_kernel_still_raises(monkeypatch): @@ -431,13 +483,24 @@ def fake_metadata_info(*args, **kwargs): def fake_metadata_v1(*args, **kwargs): 
calls["metadata_v1"] = (args, kwargs) - monkeypatch.setattr(aiter, "get_mla_metadata_info_v1", fake_metadata_info, raising=False) + monkeypatch.setattr( + aiter, "get_mla_metadata_info_v1", fake_metadata_info, raising=False + ) monkeypatch.setattr(aiter, "get_mla_metadata_v1", fake_metadata_v1, raising=False) fake_mla = type(sys)("aiter.mla") - def fake_mla_decode_fwd(q, kv, output, qo_indptr, paged_kv_indptr, paged_kv_indices, - paged_kv_last_page_len, *args, **kwargs): + def fake_mla_decode_fwd( + q, + kv, + output, + qo_indptr, + paged_kv_indptr, + paged_kv_indices, + paged_kv_last_page_len, + *args, + **kwargs, + ): calls["mla_decode_fwd"] = { "q": q, "kv": kv, @@ -456,12 +519,21 @@ def fake_mla_decode_fwd(q, kv, output, qo_indptr, paged_kv_indptr, paged_kv_indi fake_sparse_helpers = type(sys)("atom.plugin.attention_mla_sparse") - def fake_generate_sparse_seqlen(query_lens, seq_lens, query_start_loc, topk, - num_tokens, max_query_len): + def fake_generate_sparse_seqlen( + query_lens, seq_lens, query_start_loc, topk, num_tokens, max_query_len + ): return torch.tensor([3, 2], dtype=torch.int32, device=query_lens.device) - def fake_convert(req_id, block_table, token_indices, cu_seqlens, out, - BLOCK_SIZE=1, NUM_TOPK_TOKENS=0, BLOCK_N=128): + def fake_convert( + req_id, + block_table, + token_indices, + cu_seqlens, + out, + BLOCK_SIZE=1, + NUM_TOPK_TOKENS=0, + BLOCK_N=128, + ): out[:5] = torch.tensor([0, 1, 2, 4, 5], dtype=torch.int32, device=out.device) fake_sparse_helpers.generate_sparse_seqlen_triton = fake_generate_sparse_seqlen @@ -513,3 +585,66 @@ def fake_convert(req_id, block_table, token_indices, cu_seqlens, out, assert decode_call["kwargs"]["page_size"] == 1 assert decode_call["kwargs"]["work_meta_data"] is not None assert decode_call["kwargs"]["reduce_final_map"] is not None + + +def test_real_sparse_decode_rejects_oob_paged_kv_indices(monkeypatch): + module = importlib.import_module(_SPARSE_BACKEND_MODULE) + decode_called = {"value": False} + + fake_mla 
= type(sys)("aiter.mla") + + def fake_mla_decode_fwd(*args, **kwargs): + decode_called["value"] = True + + fake_mla.mla_decode_fwd = fake_mla_decode_fwd + monkeypatch.setitem(sys.modules, "aiter.mla", fake_mla) + + impl = module._RealSparseMlaImpl( + mla_modules=SimpleNamespace( + kv_lora_rank=4, + qk_nope_head_dim=2, + qk_rope_head_dim=1, + num_heads=2, + rotary_emb=None, + kv_b_proj=SimpleNamespace(weight=torch.empty(0)), + ), + v_head_dim=3, + ) + q_latent = torch.randn(2, 2, 5) + kv_cache = torch.randn(8, 1, 5) + topk = torch.tensor([[0, 1, 2], [0, 1, -1]], dtype=torch.int32) + attn_metadata = SimpleNamespace(plugin_metadata=SimpleNamespace()) + + oob_meta = module._AtomSparseMetadata( + qo_indptr=torch.tensor([0, 1, 2], dtype=torch.int32), + paged_kv_indptr=torch.tensor([0, 3, 6], dtype=torch.int32), + # kv_buffer has 8 slots, index=8 is out of range. + paged_kv_indices=torch.tensor([0, 1, 2, 3, 4, 8], dtype=torch.int32), + paged_kv_last_page_len=torch.ones(2, dtype=torch.int32), + work_meta_data=torch.zeros(1, dtype=torch.int32), + work_indptr=torch.zeros(1, dtype=torch.int32), + work_info_set=torch.zeros(1, dtype=torch.int32), + reduce_indptr=torch.zeros(1, dtype=torch.int32), + reduce_final_map=torch.zeros(1, dtype=torch.int32), + reduce_partial_map=torch.zeros(1, dtype=torch.int32), + padded_num_heads=2, + head_repeat_factor=1, + page_size=1, + ) + monkeypatch.setattr(impl, "_build_atom_sparse_metadata", lambda **kwargs: oob_meta) + + try: + impl._run_aiter_sparse_decode( + q_latent=q_latent, + kv_cache_base=kv_cache, + topk_indices=topk, + attn_metadata=attn_metadata, + block_size=4, + ) + except module._SparseUnavailable as exc: + assert "out-of-range paged_kv_indices" in str(exc) + else: + raise AssertionError( + "Expected OOB paged_kv_indices to raise _SparseUnavailable" + ) + assert decode_called["value"] is False diff --git a/tests/plugin/test_rtpllm_glm5_wrapper_lifecycle.py b/tests/plugin/test_rtpllm_glm5_wrapper_lifecycle.py index 
4510425373..747b0e7ef1 100644 --- a/tests/plugin/test_rtpllm_glm5_wrapper_lifecycle.py +++ b/tests/plugin/test_rtpllm_glm5_wrapper_lifecycle.py @@ -137,9 +137,12 @@ def _make_wrapper_instance(cls): def test_glm5_load_skip_python_model_does_not_create_atom_model(): fake_modules = _install_fake_rtp_modules() - with patch.dict(sys.modules, fake_modules), patch.dict( - os.environ, - {"RTP_LLM_EXTERNAL_MODEL_PACKAGES": "atom.plugin.rtpllm.models"}, + with ( + patch.dict(sys.modules, fake_modules), + patch.dict( + os.environ, + {"RTP_LLM_EXTERNAL_MODEL_PACKAGES": "atom.plugin.rtpllm.models"}, + ), ): sys.modules.pop("atom.plugin.rtpllm.models.glm5", None) module = importlib.import_module("atom.plugin.rtpllm.models.glm5") @@ -166,13 +169,19 @@ def test_glm5_create_python_model_lets_prepare_model_own_mla_patching(): fake_atom_model = MagicMock(name="atom_model") fake_atom_model.to.return_value = fake_atom_model - with patch.dict( - sys.modules, - fake_modules, - ), patch.dict( - os.environ, - {"RTP_LLM_EXTERNAL_MODEL_PACKAGES": "atom.plugin.rtpllm.models"}, - ), patch("atom.prepare_model", return_value=fake_atom_model, create=True) as prepare_model: + with ( + patch.dict( + sys.modules, + fake_modules, + ), + patch.dict( + os.environ, + {"RTP_LLM_EXTERNAL_MODEL_PACKAGES": "atom.plugin.rtpllm.models"}, + ), + patch( + "atom.prepare_model", return_value=fake_atom_model, create=True + ) as prepare_model, + ): sys.modules.pop("atom.plugin.rtpllm.models.glm5", None) module = importlib.import_module("atom.plugin.rtpllm.models.glm5") module = importlib.reload(module) @@ -180,11 +189,14 @@ def test_glm5_create_python_model_lets_prepare_model_own_mla_patching(): instance.device = "cpu" instance.weight = MagicMock() - with _patch_optional_attr( - module, "apply_attention_mla_rtpllm_patch" - ) as mla_patch, _patch_optional_attr( - module, "apply_deepseek_mla_rtpllm_patch" - ) as deepseek_patch: + with ( + _patch_optional_attr( + module, "apply_attention_mla_rtpllm_patch" + ) as 
mla_patch, + _patch_optional_attr( + module, "apply_deepseek_mla_rtpllm_patch" + ) as deepseek_patch, + ): result = instance._create_python_model() prepare_model.assert_called_once_with(config=instance, engine="rtpllm") @@ -200,12 +212,15 @@ def test_glm5_create_python_model_lets_prepare_model_own_mla_patching(): def test_glm5_support_cuda_graph_honors_eager_env(): fake_modules = _install_fake_rtp_modules() - with patch.dict(sys.modules, fake_modules), patch.dict( - os.environ, - { - "RTP_LLM_EXTERNAL_MODEL_PACKAGES": "atom.plugin.rtpllm.models", - "ENABLE_CUDA_GRAPH": "0", - }, + with ( + patch.dict(sys.modules, fake_modules), + patch.dict( + os.environ, + { + "RTP_LLM_EXTERNAL_MODEL_PACKAGES": "atom.plugin.rtpllm.models", + "ENABLE_CUDA_GRAPH": "0", + }, + ), ): sys.modules.pop("atom.plugin.rtpllm.models.glm5", None) module = importlib.import_module("atom.plugin.rtpllm.models.glm5") @@ -215,6 +230,30 @@ def test_glm5_support_cuda_graph_honors_eager_env(): assert instance.support_cuda_graph() is False +def test_glm5_runtime_uses_mla_forward_context_class(): + fake_modules = _install_fake_rtp_modules() + fake_utils_mod = ModuleType("atom.plugin.rtpllm.utils") + marker_context_cls = object() + fake_utils_mod.RTPForwardMLAContext = marker_context_cls + + with ( + patch.dict(sys.modules, fake_modules), + patch.dict(sys.modules, {"atom.plugin.rtpllm.utils": fake_utils_mod}), + patch.dict( + os.environ, + {"RTP_LLM_EXTERNAL_MODEL_PACKAGES": "atom.plugin.rtpllm.models"}, + ), + ): + sys.modules.pop("atom.plugin.rtpllm.models.glm5", None) + module = importlib.import_module("atom.plugin.rtpllm.models.glm5") + module = importlib.reload(module) + module.RTPForwardContext = None + + context_cls = module._ATOMGlm5MoeRuntime._get_forward_context_cls() + + assert context_cls is marker_context_cls + + def test_glm5_runtime_forward_wraps_model_call_in_rtp_context(monkeypatch): fake_modules = _install_fake_rtp_modules() expected_input_ids = torch.tensor([10, 11], dtype=torch.int64) 
@@ -258,9 +297,12 @@ def bind(**kwargs): assert kwargs["positions"].dtype == torch.long return _FakeBind() - with patch.dict(sys.modules, fake_modules), patch.dict( - os.environ, - {"RTP_LLM_EXTERNAL_MODEL_PACKAGES": "atom.plugin.rtpllm.models"}, + with ( + patch.dict(sys.modules, fake_modules), + patch.dict( + os.environ, + {"RTP_LLM_EXTERNAL_MODEL_PACKAGES": "atom.plugin.rtpllm.models"}, + ), ): sys.modules.pop("atom.plugin.rtpllm.models.glm5", None) module = importlib.import_module("atom.plugin.rtpllm.models.glm5") @@ -294,9 +336,12 @@ class _FakeRTPForwardContext: def collect_layer_maps(model): return ({}, {}, {}) - with patch.dict(sys.modules, fake_modules), patch.dict( - os.environ, - {"RTP_LLM_EXTERNAL_MODEL_PACKAGES": "atom.plugin.rtpllm.models"}, + with ( + patch.dict(sys.modules, fake_modules), + patch.dict( + os.environ, + {"RTP_LLM_EXTERNAL_MODEL_PACKAGES": "atom.plugin.rtpllm.models"}, + ), ): sys.modules.pop("atom.plugin.rtpllm.models.glm5", None) module = importlib.import_module("atom.plugin.rtpllm.models.glm5") @@ -322,9 +367,12 @@ def collect_layer_maps(model): def test_glm5_runtime_decode_positions_prefer_sequence_lengths_plus_one(): fake_modules = _install_fake_rtp_modules() - with patch.dict(sys.modules, fake_modules), patch.dict( - os.environ, - {"RTP_LLM_EXTERNAL_MODEL_PACKAGES": "atom.plugin.rtpllm.models"}, + with ( + patch.dict(sys.modules, fake_modules), + patch.dict( + os.environ, + {"RTP_LLM_EXTERNAL_MODEL_PACKAGES": "atom.plugin.rtpllm.models"}, + ), ): sys.modules.pop("atom.plugin.rtpllm.models.glm5", None) module = importlib.import_module("atom.plugin.rtpllm.models.glm5") @@ -348,9 +396,12 @@ def test_glm5_runtime_decode_positions_prefer_sequence_lengths_plus_one(): def test_glm5_runtime_graph_decode_ignores_stale_position_ids(): fake_modules = _install_fake_rtp_modules() - with patch.dict(sys.modules, fake_modules), patch.dict( - os.environ, - {"RTP_LLM_EXTERNAL_MODEL_PACKAGES": "atom.plugin.rtpllm.models"}, + with ( + 
patch.dict(sys.modules, fake_modules), + patch.dict( + os.environ, + {"RTP_LLM_EXTERNAL_MODEL_PACKAGES": "atom.plugin.rtpllm.models"}, + ), ): sys.modules.pop("atom.plugin.rtpllm.models.glm5", None) module = importlib.import_module("atom.plugin.rtpllm.models.glm5") @@ -374,4 +425,3 @@ def test_glm5_runtime_graph_decode_ignores_stale_position_ids(): ) assert positions.cpu().tolist() == [34, 48, 49] - From 8b92a5c283647ddf786591e12885d0ce7ed270c8 Mon Sep 17 00:00:00 2001 From: Zhao An Date: Fri, 5 Jun 2026 11:01:16 +0000 Subject: [PATCH 06/20] refactor: RTP glm5 code --- .../attention_backend/rtp_mla_attention.py | 16 +++- .../rtp_sparse_mla_backend.py | 83 ++++++------------- atom/plugin/rtpllm/utils/forward_context.py | 34 ++++---- ...est_rtpllm_glm5_sparse_backend_contract.py | 1 + 4 files changed, 58 insertions(+), 76 deletions(-) diff --git a/atom/plugin/rtpllm/attention_backend/rtp_mla_attention.py b/atom/plugin/rtpllm/attention_backend/rtp_mla_attention.py index 8b86e5d6cb..b6b6a410d4 100644 --- a/atom/plugin/rtpllm/attention_backend/rtp_mla_attention.py +++ b/atom/plugin/rtpllm/attention_backend/rtp_mla_attention.py @@ -23,7 +23,9 @@ def _resolve_index_topk(attn) -> int: def _get_topk_indices_buffer(attn) -> torch.Tensor: indexer = getattr(attn, "indexer", None) - buffer = getattr(indexer, "topk_indices_buffer", None) if indexer is not None else None + buffer = ( + getattr(indexer, "topk_indices_buffer", None) if indexer is not None else None + ) if buffer is None: buffer = getattr(attn, "topk_indices_buffer", None) if buffer is None: @@ -110,6 +112,11 @@ def __init__(self, *args, **kwargs) -> None: self.dense_backend = None self.kv_cache = kwargs.get("kv_cache") self.layer_id = int(kwargs.get("layer_id", kwargs.get("layer_num", 0))) + self._dense_backend_accepts_positions = ( + self._backend_accepts_positions(self.dense_backend) + if self.dense_backend is not None + else False + ) @staticmethod def _backend_accepts_positions(backend: object) -> bool: @@ 
-134,7 +141,9 @@ def _project_query( if q.ndim == 3: return q, True - num_heads = self.num_local_heads if self.num_local_heads is not None else self.num_heads + num_heads = ( + self.num_local_heads if self.num_local_heads is not None else self.num_heads + ) if num_heads is None: if self.qk_head_dim is None: raise AttributeError("GLM5 RTP MLA native contract requires num_heads") @@ -182,7 +191,7 @@ def forward( kwargs.get("topk_indices", topk_indices), ) forward_kwargs = {"topk_indices": topk_indices} - if self._backend_accepts_positions(self.dense_backend): + if self._dense_backend_accepts_positions: forward_kwargs["positions"] = positions attn_output = self.dense_backend.forward( q, @@ -207,4 +216,3 @@ def apply_attention_mla_rtpllm_patch() -> None: ops.RTPMLAAttention = RTPMLAAttention ops.Attention = RTPMLAAttention - diff --git a/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py b/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py index b3fddabf22..13518faa86 100644 --- a/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py +++ b/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py @@ -100,6 +100,13 @@ def __init__( self._cache_write_scale: dict[torch.device, torch.Tensor] = {} self._cg_sparse_bufs: dict[str, torch.Tensor] | None = None self._cg_workspace_signature: tuple[Any, ...] 
| None = None + self._sparse_page_size = self._resolve_sparse_page_size() + self._enable_debug_safe_path = ( + os.getenv("ATOM_RTP_GLM5_SPARSE_DEBUG_SAFE", "0") == "1" + ) + self._enable_sparse_validate = ( + os.getenv("ATOM_RTP_GLM5_SPARSE_VALIDATE", "0") == "1" + ) @staticmethod def _resolve_sparse_page_size() -> int: @@ -530,47 +537,6 @@ def _convert_topk_to_global( dtype=torch.int32 ) - @staticmethod - def _decode_indptr( - *, - num_tokens: int, - topk: int, - device: torch.device, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - qo_indptr = torch.arange(num_tokens + 1, device=device, dtype=torch.int32) - paged_kv_indptr = torch.arange( - num_tokens + 1, device=device, dtype=torch.int32 - ) * int(topk) - paged_kv_last_page_len = torch.ones( - (num_tokens,), device=device, dtype=torch.int32 - ) - return qo_indptr, paged_kv_indptr, paged_kv_last_page_len - - @staticmethod - def _generate_sparse_seqlen_torch( - *, - query_lens: torch.Tensor, - seq_lens: torch.Tensor, - query_start_loc: torch.Tensor, - topk: int, - num_tokens: int, - ) -> torch.Tensor: - out = torch.zeros((num_tokens,), dtype=torch.int32, device=query_lens.device) - for req_id in range(int(query_lens.numel())): - q_len = int(query_lens[req_id].item()) - seq_len = int(seq_lens[req_id].item()) - start = int(query_start_loc[req_id].item()) - if q_len <= 0 or seq_len <= 0: - continue - context_start = seq_len - q_len - offsets = torch.arange(q_len, device=query_lens.device, dtype=torch.int32) - out[start : start + q_len] = torch.clamp( - context_start + offsets + 1, - min=0, - max=int(topk), - ) - return out - @staticmethod def _aiter_dtype_for_tensor(tensor: torch.Tensor) -> Any: try: @@ -945,10 +911,9 @@ def _build_atom_sparse_metadata( f"GLM5 RTP sparse MLA requires positive block_size, got {block_size}." 
) - if not in_capture and topk >= 2048: - # Debug-safe path for long-context GLM5: avoid Triton req->global - # conversion kernel first, because if this step writes invalid data it - # can hard-crash GPU before Python can surface an exception. + if not in_capture and self._enable_debug_safe_path and topk >= 2048: + # Keep a debug-safe fallback for field diagnostics. This path is slower + # than Triton and should stay disabled in normal serving. global_topk = self._convert_topk_to_global( topk_indices=topk_indices_i32, attn_metadata=attn_metadata, @@ -994,12 +959,8 @@ def _build_atom_sparse_metadata( reduce_final_map = sparse_bufs["reduce_final_map"] reduce_partial_map = sparse_bufs["reduce_partial_map"] else: - used_sparse_entries = ( - int(paged_kv_indptr[-1].item()) if int(num_tokens) > 0 else 0 - ) - metadata_budget_tokens = max( - self._metadata_token_budget(num_tokens=num_tokens, topk=topk), - used_sparse_entries, + metadata_budget_tokens = self._metadata_token_budget( + num_tokens=num_tokens, topk=topk ) ( (work_meta_data_size, work_meta_data_type), @@ -1035,7 +996,7 @@ def _build_atom_sparse_metadata( reduce_partial_map = torch.empty( reduce_partial_map_size, dtype=reduce_partial_map_type, device=device ) - requested_page_size = self._resolve_sparse_page_size() + requested_page_size = self._sparse_page_size kv_token_slots = self._kv_token_slot_capacity(kv_cache_base) page_size = requested_page_size if ( @@ -1072,7 +1033,7 @@ def _build_atom_sparse_metadata( # the current kv slot range to avoid kernel-side OOB accesses. 
paged_kv_indices.clamp_(min=0, max=max(int(max_page_slots) - 1, 0)) - if not in_capture: + if not in_capture and self._enable_sparse_validate: self._validate_sparse_index_contract( paged_kv_indptr=paged_kv_indptr, paged_kv_indices=paged_kv_indices, @@ -1211,7 +1172,11 @@ def _run_aiter_sparse_decode( ) kv_cache_base = torch.cat((kv_cache_base, pad), dim=0) kv_buffer = kv_cache_base.reshape(-1, page_size, 1, latent_dim) - if not in_capture and int(sparse_meta.paged_kv_indices.numel()) > 0: + if ( + not in_capture + and self._enable_sparse_validate + and int(sparse_meta.paged_kv_indices.numel()) > 0 + ): self._validate_sparse_index_contract( paged_kv_indptr=sparse_meta.paged_kv_indptr, paged_kv_indices=sparse_meta.paged_kv_indices, @@ -1399,6 +1364,12 @@ def __init__( else: self.sparse_impl = _ContractSparseMlaImpl(self.v_head_dim) self._default_mock = True + self._sparse_impl_accepts_positions = self._impl_accepts_positions( + self.sparse_impl + ) + self._dense_forward_accepts_positions = self._call_accepts_positions( + getattr(self.dense_backend, "forward", None) + ) def prepare_cuda_graph(self, attn_inputs) -> None: # noqa: ANN001 del attn_inputs @@ -1487,7 +1458,7 @@ def _dense_forward( positions: Optional[torch.Tensor], ) -> torch.Tensor: kwargs = {"topk_indices": topk_indices} - if self._call_accepts_positions(self.dense_backend.forward): + if self._dense_forward_accepts_positions: kwargs["positions"] = positions return self.dense_backend.forward( q, @@ -1538,7 +1509,7 @@ def forward( "topk_indices": topk_indices, "attn_metadata": attn_metadata, } - if self._impl_accepts_positions(self.sparse_impl): + if self._sparse_impl_accepts_positions: kwargs["positions"] = positions try: return self.sparse_impl.forward( diff --git a/atom/plugin/rtpllm/utils/forward_context.py b/atom/plugin/rtpllm/utils/forward_context.py index 6295f6dc47..bd7f15c996 100644 --- a/atom/plugin/rtpllm/utils/forward_context.py +++ b/atom/plugin/rtpllm/utils/forward_context.py @@ -112,7 +112,7 
@@ def _recover_physical_block_table_from_kernel_kernel( @dataclass(frozen=True) class RTPForwardContext: - gdn_metadata: GDNAttentionMetadata + gdn_metadata: GDNAttentionMetadata | None attn_metadata: AttentionMetaData rtp_attn_inputs: Any rtp_seq_size_per_block: int @@ -1235,7 +1235,6 @@ def _build_plugin_attention_metadata( plugin_md.topk_tokens = 0 plugin_md.sparse_block_size = int(seq_size_per_block) plugin_md.cg_bufs = cg_bufs - plugin_md.positions = positions cu_seqlen_ks = None cu_seqlen_ke = None if is_prefill: @@ -1531,6 +1530,8 @@ def build( if kernel_seq_size_per_block <= 0: kernel_seq_size_per_block = int(seq_size_per_block) state_indices_cache: Dict[tuple[int, bool], torch.Tensor] = {} + resolved_layer_maps = layer_maps or cls.collect_layer_maps(model) + gdn_layer_map, _, _ = resolved_layer_maps layer_group_map_signature = cls._layer_group_map_signature(attn_inputs) layer_group_map = getattr(runtime, "_rtp_layer_group_map", None) cached_layer_group_map_signature = getattr( @@ -1543,19 +1544,21 @@ def build( layer_group_map = cls._build_layer_group_map(attn_inputs) runtime._rtp_layer_group_map = layer_group_map runtime._rtp_layer_group_map_signature = layer_group_map_signature - gdn_metadata = cls._build_gdn_metadata( - attn_inputs, - seq_size_per_block=seq_size_per_block, - num_tokens=int(positions.numel()), - state_indices_cache=state_indices_cache, - layer_group_map=layer_group_map, - ) - # Keep raw RTP attention inputs in metadata so GDN can resolve per-layer - # block-map/state-index semantics (same idea as RTP's select_block_map_for_layer). 
- gdn_metadata.rtp_attn_inputs = attn_inputs - gdn_metadata.rtp_seq_size_per_block = int(seq_size_per_block) - gdn_metadata.rtp_state_indices_cache = state_indices_cache - gdn_metadata.rtp_layer_group_map = layer_group_map + gdn_metadata = None + if gdn_layer_map: + gdn_metadata = cls._build_gdn_metadata( + attn_inputs, + seq_size_per_block=seq_size_per_block, + num_tokens=int(positions.numel()), + state_indices_cache=state_indices_cache, + layer_group_map=layer_group_map, + ) + # Keep raw RTP attention inputs in metadata so GDN can resolve per-layer + # block-map/state-index semantics (same idea as RTP's select_block_map_for_layer). + gdn_metadata.rtp_attn_inputs = attn_inputs + gdn_metadata.rtp_seq_size_per_block = int(seq_size_per_block) + gdn_metadata.rtp_state_indices_cache = state_indices_cache + gdn_metadata.rtp_layer_group_map = layer_group_map attn_metadata = cls._build_plugin_attention_metadata( attn_inputs=attn_inputs, positions=positions, @@ -1564,7 +1567,6 @@ def build( cg_max_seq_len=int(cg_max_seq_len), cg_bufs=cg_bufs, ) - resolved_layer_maps = layer_maps or cls.collect_layer_maps(model) kv_cache_signature = cls._kv_cache_signature( runtime=runtime, layer_maps=resolved_layer_maps, diff --git a/tests/plugin/test_rtpllm_glm5_sparse_backend_contract.py b/tests/plugin/test_rtpllm_glm5_sparse_backend_contract.py index 99fe9516d1..6cf0dac27e 100644 --- a/tests/plugin/test_rtpllm_glm5_sparse_backend_contract.py +++ b/tests/plugin/test_rtpllm_glm5_sparse_backend_contract.py @@ -590,6 +590,7 @@ def fake_convert( def test_real_sparse_decode_rejects_oob_paged_kv_indices(monkeypatch): module = importlib.import_module(_SPARSE_BACKEND_MODULE) decode_called = {"value": False} + monkeypatch.setenv("ATOM_RTP_GLM5_SPARSE_VALIDATE", "1") fake_mla = type(sys)("aiter.mla") From 27be06e2d4a43d624a3117dd8612b2f0e4a5abb5 Mon Sep 17 00:00:00 2001 From: Zhao An Date: Fri, 5 Jun 2026 11:31:07 +0000 Subject: [PATCH 07/20] feat: RTP glm5 optimize sparse decode path --- 
.../rtp_sparse_mla_backend.py | 98 ++++++++++++++----- 1 file changed, 71 insertions(+), 27 deletions(-) diff --git a/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py b/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py index 13518faa86..06ec2677df 100644 --- a/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py +++ b/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py @@ -107,6 +107,9 @@ def __init__( self._enable_sparse_validate = ( os.getenv("ATOM_RTP_GLM5_SPARSE_VALIDATE", "0") == "1" ) + self._enable_capture_metadata_reuse = ( + os.getenv("ATOM_RTP_GLM5_SPARSE_CAPTURE_META_REUSE", "1") == "1" + ) @staticmethod def _resolve_sparse_page_size() -> int: @@ -996,6 +999,30 @@ def _build_atom_sparse_metadata( reduce_partial_map = torch.empty( reduce_partial_map_size, dtype=reduce_partial_map_type, device=device ) + capture_meta_sig = ( + int(num_tokens), + int(topk), + int(padded_num_heads), + str(q_dtype), + str(kv_dtype), + str(device), + ) + reuse_capture_metadata = False + if in_capture and self._enable_capture_metadata_reuse: + cached_capture_meta = getattr( + plugin_metadata, "_rtp_sparse_capture_meta_workspace", None + ) + if ( + isinstance(cached_capture_meta, dict) + and cached_capture_meta.get("signature") == capture_meta_sig + ): + work_meta_data = cached_capture_meta["work_meta_data"] + work_indptr = cached_capture_meta["work_indptr"] + work_info_set = cached_capture_meta["work_info_set"] + reduce_indptr = cached_capture_meta["reduce_indptr"] + reduce_final_map = cached_capture_meta["reduce_final_map"] + reduce_partial_map = cached_capture_meta["reduce_partial_map"] + reuse_capture_metadata = True requested_page_size = self._sparse_page_size kv_token_slots = self._kv_token_slot_capacity(kv_cache_base) page_size = requested_page_size @@ -1042,27 +1069,38 @@ def _build_atom_sparse_metadata( max_slots=max_page_slots, ) - get_mla_metadata_v1( - qo_indptr, - paged_kv_indptr, - paged_kv_last_page_len, - 
padded_num_heads, - 1, - True, - work_meta_data, - work_info_set, - work_indptr, - reduce_indptr, - reduce_final_map, - reduce_partial_map, - page_size=page_size, - kv_granularity=16, - max_seqlen_qo=max_query_len_for_sparse, - uni_seqlen_qo=max_query_len_for_sparse, - fast_mode=True, - dtype_q=q_dtype, - dtype_kv=kv_dtype, - ) + if not reuse_capture_metadata: + get_mla_metadata_v1( + qo_indptr, + paged_kv_indptr, + paged_kv_last_page_len, + padded_num_heads, + 1, + True, + work_meta_data, + work_info_set, + work_indptr, + reduce_indptr, + reduce_final_map, + reduce_partial_map, + page_size=page_size, + kv_granularity=16, + max_seqlen_qo=max_query_len_for_sparse, + uni_seqlen_qo=max_query_len_for_sparse, + fast_mode=True, + dtype_q=q_dtype, + dtype_kv=kv_dtype, + ) + if in_capture and self._enable_capture_metadata_reuse: + plugin_metadata._rtp_sparse_capture_meta_workspace = { + "signature": capture_meta_sig, + "work_meta_data": work_meta_data, + "work_indptr": work_indptr, + "work_info_set": work_info_set, + "reduce_indptr": reduce_indptr, + "reduce_final_map": reduce_final_map, + "reduce_partial_map": reduce_partial_map, + } return _AtomSparseMetadata( qo_indptr=qo_indptr, paged_kv_indptr=paged_kv_indptr, @@ -1128,13 +1166,19 @@ def _run_aiter_sparse_decode( q_for_kernel = self._cg_sparse_bufs["q_for_kernel"][ :num_tokens, : sparse_meta.padded_num_heads, : ] - for repeat_idx in range(sparse_meta.head_repeat_factor): - q_for_kernel[ - :, repeat_idx :: sparse_meta.head_repeat_factor, : - ].copy_(q_latent) + # Capture path: use one broadcasted copy to fill repeated heads, + # avoiding per-repeat slice copies in the decode hot path. 
+ q_for_kernel.view( + num_tokens, + num_heads, + sparse_meta.head_repeat_factor, + latent_dim, + ).copy_(q_latent.unsqueeze(2)) else: - q_for_kernel = q_latent.repeat_interleave( - sparse_meta.head_repeat_factor, dim=1 + q_for_kernel = ( + q_latent.unsqueeze(2) + .expand(-1, -1, sparse_meta.head_repeat_factor, -1) + .reshape(num_tokens, sparse_meta.padded_num_heads, latent_dim) ) else: q_for_kernel = q_latent From d6afedac22cd47a0b7414758b9e10d8cd96d45f8 Mon Sep 17 00:00:00 2001 From: Zhao An Date: Fri, 5 Jun 2026 16:35:23 +0000 Subject: [PATCH 08/20] refactor: RTP remove redundant envs --- .../rtp_sparse_mla_backend.py | 139 ++---------------- 1 file changed, 15 insertions(+), 124 deletions(-) diff --git a/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py b/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py index 06ec2677df..07cd089c8f 100644 --- a/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py +++ b/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py @@ -100,32 +100,9 @@ def __init__( self._cache_write_scale: dict[torch.device, torch.Tensor] = {} self._cg_sparse_bufs: dict[str, torch.Tensor] | None = None self._cg_workspace_signature: tuple[Any, ...] | None = None - self._sparse_page_size = self._resolve_sparse_page_size() - self._enable_debug_safe_path = ( - os.getenv("ATOM_RTP_GLM5_SPARSE_DEBUG_SAFE", "0") == "1" - ) self._enable_sparse_validate = ( os.getenv("ATOM_RTP_GLM5_SPARSE_VALIDATE", "0") == "1" ) - self._enable_capture_metadata_reuse = ( - os.getenv("ATOM_RTP_GLM5_SPARSE_CAPTURE_META_REUSE", "1") == "1" - ) - - @staticmethod - def _resolve_sparse_page_size() -> int: - value = os.getenv("ATOM_RTP_GLM5_SPARSE_PAGE_SIZE", "1") - try: - page_size = int(value) - except ValueError as exc: - raise _SparseUnavailable( - f"Invalid ATOM_RTP_GLM5_SPARSE_PAGE_SIZE={value!r}." 
- ) from exc - if page_size <= 0: - raise _SparseUnavailable( - "GLM5 RTP sparse MLA requires positive page_size, " - f"got page_size={page_size}." - ) - return page_size @staticmethod def _validate_sparse_index_contract( @@ -203,20 +180,6 @@ def _validate_sparse_last_page_contract( "GLM5 RTP sparse MLA expects paged_kv_last_page_len==1 when page_size=1." ) - @staticmethod - def _to_page_indices( - *, token_indices: torch.Tensor, page_size: int, max_slots: int - ) -> torch.Tensor: - if page_size == 1: - return token_indices.to(dtype=torch.int32) - page_indices = torch.div( - token_indices.to(dtype=torch.int64), - int(page_size), - rounding_mode="floor", - ).to(dtype=torch.int32) - page_indices.clamp_(min=0, max=max(int(max_slots) - 1, 0)) - return page_indices - @staticmethod def _kv_token_slot_capacity(kv_cache_base: torch.Tensor) -> int: if kv_cache_base.ndim <= 0: @@ -914,37 +877,15 @@ def _build_atom_sparse_metadata( f"GLM5 RTP sparse MLA requires positive block_size, got {block_size}." ) - if not in_capture and self._enable_debug_safe_path and topk >= 2048: - # Keep a debug-safe fallback for field diagnostics. This path is slower - # than Triton and should stay disabled in normal serving. - global_topk = self._convert_topk_to_global( - topk_indices=topk_indices_i32, - attn_metadata=attn_metadata, - block_size=int(block_size), - ) - token_k = torch.arange(topk, device=device, dtype=torch.int32).unsqueeze(0) - valid_mask = token_k < sparse_seqlen.unsqueeze(1) - flattened = global_topk.masked_select(valid_mask) - expected = int(paged_kv_indptr[-1].item()) if int(num_tokens) > 0 else 0 - if int(flattened.numel()) != expected: - raise _SparseUnavailable( - "GLM5 RTP sparse MLA inconsistent sparse metadata size " - f"(flattened={int(flattened.numel())}, expected={expected})." 
- ) - if expected > 0: - paged_kv_indices[:expected].copy_(flattened.to(dtype=torch.int32)) - if int(paged_kv_indices.numel()) > expected: - paged_kv_indices[expected:].zero_() - else: - triton_convert_req_index_to_global_index( - req_id, - block_table, - topk_indices_i32, - paged_kv_indptr, - paged_kv_indices, - BLOCK_SIZE=int(block_size), - NUM_TOPK_TOKENS=topk, - ) + triton_convert_req_index_to_global_index( + req_id, + block_table, + topk_indices_i32, + paged_kv_indptr, + paged_kv_indices, + BLOCK_SIZE=int(block_size), + NUM_TOPK_TOKENS=topk, + ) padded_num_heads = max(num_heads, 16) if padded_num_heads % num_heads != 0: @@ -1008,7 +949,7 @@ def _build_atom_sparse_metadata( str(device), ) reuse_capture_metadata = False - if in_capture and self._enable_capture_metadata_reuse: + if in_capture: cached_capture_meta = getattr( plugin_metadata, "_rtp_sparse_capture_meta_workspace", None ) @@ -1023,37 +964,9 @@ def _build_atom_sparse_metadata( reduce_final_map = cached_capture_meta["reduce_final_map"] reduce_partial_map = cached_capture_meta["reduce_partial_map"] reuse_capture_metadata = True - requested_page_size = self._sparse_page_size kv_token_slots = self._kv_token_slot_capacity(kv_cache_base) - page_size = requested_page_size - if ( - in_capture - and requested_page_size > 1 - and kv_token_slots % int(requested_page_size) != 0 - ): - # CUDA graph capture uses warmup shapes that may not be page-aligned; - # keep capture alive by using token page mode. - page_size = 1 - if page_size > 1: - max_page_slots = max( - (kv_token_slots + int(page_size) - 1) // int(page_size), - 1, - ) - if in_capture: - # Capture-safe: avoid host sync (e.g. item/min/max) in graph capture. 
- if int(paged_kv_indices.numel()) > 0: - paged_kv_indices.floor_divide_(int(page_size)) - paged_kv_indices.clamp_(min=0, max=max_page_slots - 1) - else: - used_now = int(paged_kv_indptr[-1].item()) if num_tokens > 0 else 0 - if used_now > 0: - paged_kv_indices[:used_now] = self._to_page_indices( - token_indices=paged_kv_indices[:used_now], - page_size=page_size, - max_slots=max_page_slots, - ) - else: - max_page_slots = int(kv_token_slots) + page_size = 1 + max_page_slots = int(kv_token_slots) if in_capture and int(paged_kv_indices.numel()) > 0: # Capture path cannot run host-synced range checks; clamp indices into @@ -1091,7 +1004,7 @@ def _build_atom_sparse_metadata( dtype_q=q_dtype, dtype_kv=kv_dtype, ) - if in_capture and self._enable_capture_metadata_reuse: + if in_capture: plugin_metadata._rtp_sparse_capture_meta_workspace = { "signature": capture_meta_sig, "work_meta_data": work_meta_data, @@ -1160,7 +1073,7 @@ def _run_aiter_sparse_decode( block_size=block_size, ) in_capture = torch.cuda.is_current_stream_capturing() - page_size = int(sparse_meta.page_size) + page_size = 1 if sparse_meta.head_repeat_factor > 1: if in_capture and self._cg_sparse_bufs is not None: q_for_kernel = self._cg_sparse_bufs["q_for_kernel"][ @@ -1193,29 +1106,7 @@ def _run_aiter_sparse_decode( device=q_latent.device, ) try: - if page_size == 1: - kv_buffer = kv_cache_base.reshape(-1, 1, 1, latent_dim) - else: - kv_slots = int(kv_cache_base.shape[0]) - padded_slots = ( - (kv_slots + int(page_size) - 1) // int(page_size) - ) * int(page_size) - if padded_slots != kv_slots: - if in_capture: - raise _SparseUnavailable( - "GLM5 RTP sparse MLA kv buffer cannot be reshaped by " - "page_size during capture " - f"(kv_slots={kv_slots}, page_size={page_size})." 
- ) - pad_shape = list(kv_cache_base.shape) - pad_shape[0] = padded_slots - kv_slots - pad = torch.zeros( - pad_shape, - dtype=kv_cache_base.dtype, - device=kv_cache_base.device, - ) - kv_cache_base = torch.cat((kv_cache_base, pad), dim=0) - kv_buffer = kv_cache_base.reshape(-1, page_size, 1, latent_dim) + kv_buffer = kv_cache_base.reshape(-1, 1, 1, latent_dim) if ( not in_capture and self._enable_sparse_validate From 0afe6873922bf6f258aead449f145d097465a99c Mon Sep 17 00:00:00 2001 From: Zhao An Date: Mon, 8 Jun 2026 07:21:03 +0000 Subject: [PATCH 09/20] refactor: [RTP] unify GLM5 MLA on sparse path, drop dead dense backend --- .../rtpllm/attention_backend/__init__.py | 12 +- .../rtp_dense_mla_backend.py | 520 -------- .../attention_backend/rtp_mla_attention.py | 34 +- .../rtp_sparse_mla_backend.py | 270 +++-- atom/plugin/rtpllm/models/glm5.py | 31 +- .../test_rtpllm_forward_context_semantics.py | 16 +- .../test_rtpllm_glm5_indexer_contract.py | 40 +- .../test_rtpllm_glm5_mla_forward_contract.py | 1053 ----------------- ...est_rtpllm_glm5_sparse_backend_contract.py | 128 +- 9 files changed, 296 insertions(+), 1808 deletions(-) delete mode 100644 atom/plugin/rtpllm/attention_backend/rtp_dense_mla_backend.py delete mode 100644 tests/plugin/test_rtpllm_glm5_mla_forward_contract.py diff --git a/atom/plugin/rtpllm/attention_backend/__init__.py b/atom/plugin/rtpllm/attention_backend/__init__.py index 9a157a3f71..053a7e3d47 100644 --- a/atom/plugin/rtpllm/attention_backend/__init__.py +++ b/atom/plugin/rtpllm/attention_backend/__init__.py @@ -1,7 +1,4 @@ -from .attention_gdn import apply_attention_gdn_rtpllm_patch -from .attention_switch import apply_attention_mha_rtpllm_patch from .rtp_full_attention import AttentionForRTPLLM, RTPFullAttention -from .rtp_dense_mla_backend import RTPDenseMlaBackend from .rtp_mla_attention import RTPMLAAttention, apply_attention_mla_rtpllm_patch from .rtp_mla_metadata import ( GLM5_RTP_BRIDGE_MODE, @@ -19,13 +16,20 @@ def 
__getattr__(name): return {"RTPAttention": RTPAttention, "RTPFullAttention": RTPFullAttention}[ name ] + if name == "apply_attention_gdn_rtpllm_patch": + from .attention_gdn import apply_attention_gdn_rtpllm_patch + + return apply_attention_gdn_rtpllm_patch + if name == "apply_attention_mha_rtpllm_patch": + from .attention_switch import apply_attention_mha_rtpllm_patch + + return apply_attention_mha_rtpllm_patch raise AttributeError(f"module {__name__!r} has no attribute {name!r}") __all__ = [ "AttentionForRTPLLM", "RTPFullAttention", - "RTPDenseMlaBackend", "RTPMLAAttention", "RTPSparseMlaBackend", "GLM5_RTP_BRIDGE_MODE", diff --git a/atom/plugin/rtpllm/attention_backend/rtp_dense_mla_backend.py b/atom/plugin/rtpllm/attention_backend/rtp_dense_mla_backend.py deleted file mode 100644 index 28f9903f2f..0000000000 --- a/atom/plugin/rtpllm/attention_backend/rtp_dense_mla_backend.py +++ /dev/null @@ -1,520 +0,0 @@ -"""Dense MLA fallback for GLM5 rtp-llm plugin mode.""" - -from __future__ import annotations - -from dataclasses import dataclass -from typing import Any, Optional - -import torch - - -_FP8_CACHE_DTYPES = tuple( - dtype - for dtype in ( - getattr(torch, "float8_e4m3fnuz", None), - getattr(torch, "float8_e4m3fn", None), - torch.uint8, - ) - if dtype is not None -) - - -def _raise_cache_error(message: str) -> None: - raise RuntimeError(message) - - -@dataclass(frozen=True) -class _DenseMlaMetadata: - query_start_loc: torch.Tensor - seq_lens: torch.Tensor | None - block_table: torch.Tensor | None - slot_mapping: torch.Tensor | None - is_prefill: bool - block_size: int - - -class RTPDenseMlaBackend: - """Small dense MLA backend used before the sparse kernel is wired. - - This backend intentionally avoids vLLM plugin metadata. It consumes the - native GLM5 five-tuple already prepared by DeepseekV2MLAAttention and uses - RTPForwardContext metadata only to recover per-sequence token ranges. 
- """ - - def __init__(self, *, mla_modules: Any) -> None: - self.mla_modules = mla_modules - self.kv_b_proj = getattr(mla_modules, "kv_b_proj", None) - self.rotary_emb = getattr(mla_modules, "rotary_emb", None) - self.v_head_dim = int(getattr(mla_modules, "v_head_dim")) - self.qk_nope_head_dim = getattr(mla_modules, "qk_nope_head_dim", None) - self.qk_rope_head_dim = getattr(mla_modules, "qk_rope_head_dim", None) - self._projection_checked = False - - def prepare_cuda_graph(self, attn_inputs) -> None: # noqa: ANN001 - del attn_inputs - - def prewarm_for_cuda_graph( - self, - *, - max_num_tokens: int, - max_seq_len: int, - query_dtype: torch.dtype, - device: torch.device, - ) -> None: - del max_num_tokens, max_seq_len, query_dtype, device - - @staticmethod - def _read_is_prefill(context: Any) -> bool: - if context is None or not hasattr(context, "is_prefill"): - raise ValueError( - "GLM5 RTP dense MLA requires explicit context.is_prefill metadata." - ) - return bool(getattr(context, "is_prefill")) - - @staticmethod - def _get_metadata(num_tokens: int, device: torch.device) -> _DenseMlaMetadata: - attn_metadata = None - context = None - rtp_seq_size_per_block = 0 - try: - from atom.utils.forward_context import get_forward_context - - forward_context = get_forward_context() - attn_metadata = getattr(forward_context, "attn_metadata", None) - context = getattr(forward_context, "context", None) - plugin_metadata = getattr(attn_metadata, "plugin_metadata", None) - rtp_seq_size_per_block = int( - getattr(plugin_metadata, "sparse_block_size", 0) - or getattr(attn_metadata, "rtp_seq_size_per_block", 0) - or 0 - ) - query_start_loc = getattr(plugin_metadata, "query_start_loc", None) - if query_start_loc is None: - query_start_loc = getattr(plugin_metadata, "rtp_cu_seqlens_q", None) - if query_start_loc is None: - query_start_loc = getattr(attn_metadata, "cu_seqlens_q", None) - if query_start_loc is None: - decode_metadata = getattr(plugin_metadata, "decode_metadata", None) - 
query_start_loc = getattr(decode_metadata, "query_start_loc", None) - seq_lens = getattr(plugin_metadata, "seq_lens", None) - if seq_lens is None: - seq_lens = getattr(attn_metadata, "context_lens", None) - block_table = getattr(plugin_metadata, "block_table", None) - if block_table is None: - block_table = getattr(attn_metadata, "block_tables", None) - slot_mapping = getattr(plugin_metadata, "slot_mapping", None) - if slot_mapping is None: - slot_mapping = getattr(attn_metadata, "slot_mapping", None) - except Exception: - query_start_loc = None - seq_lens = None - block_table = None - slot_mapping = None - - if ( - context is not None - and hasattr(context, "is_prefill") - and not bool(getattr(context, "is_prefill")) - and isinstance(seq_lens, torch.Tensor) - and int(seq_lens.numel()) == num_tokens - and isinstance(block_table, torch.Tensor) - and isinstance(slot_mapping, torch.Tensor) - ): - query_start_loc = torch.arange( - num_tokens + 1, dtype=torch.int64, device=device - ) - - if query_start_loc is not None and int(query_start_loc.numel()) >= 2: - query_start_loc = query_start_loc.to(device=device, dtype=torch.int64) - if int(query_start_loc[0].item()) == 0 and int(query_start_loc[-1].item()) == num_tokens: - is_prefill = RTPDenseMlaBackend._read_is_prefill(context) - return _DenseMlaMetadata( - query_start_loc=query_start_loc, - seq_lens=( - seq_lens.to(device=device, dtype=torch.int64) - if isinstance(seq_lens, torch.Tensor) - else None - ), - block_table=( - block_table.to(device=device, dtype=torch.int64) - if isinstance(block_table, torch.Tensor) - else None - ), - slot_mapping=( - slot_mapping.to(device=device, dtype=torch.int64) - if isinstance(slot_mapping, torch.Tensor) - else None - ), - is_prefill=is_prefill, - block_size=max(1, rtp_seq_size_per_block), - ) - if num_tokens != 1: - raise ValueError( - "GLM5 RTP dense MLA requires query_start_loc metadata for " - f"multi-token batches (num_tokens={num_tokens})." 
- ) - is_prefill = RTPDenseMlaBackend._read_is_prefill(context) - return _DenseMlaMetadata( - query_start_loc=torch.tensor([0, num_tokens], dtype=torch.int64, device=device), - seq_lens=None, - block_table=None, - slot_mapping=None, - is_prefill=is_prefill, - block_size=max(1, rtp_seq_size_per_block), - ) - - @staticmethod - def _unwrap_linear_output(value: Any) -> torch.Tensor: - if isinstance(value, tuple): - value = value[0] - if not isinstance(value, torch.Tensor): - raise TypeError(f"Expected kv_b_proj to return Tensor, got {type(value)!r}.") - return value - - def _apply_current_rope( - self, - q: torch.Tensor, - k_pe: torch.Tensor, - positions: Optional[torch.Tensor], - ) -> tuple[torch.Tensor, torch.Tensor]: - rope_dim = int(self.qk_rope_head_dim or k_pe.shape[-1]) - if rope_dim == 0: - return q, k_pe - if self.rotary_emb is None: - raise ValueError("GLM5 RTP dense MLA requires rotary_emb for RoPE dimensions.") - if positions is None or int(positions.numel()) != int(q.shape[0]): - got = None if positions is None else int(positions.numel()) - raise ValueError( - "GLM5 RTP dense MLA requires per-token absolute positions for RoPE " - f"(positions={got}, tokens={int(q.shape[0])})." - ) - if int(q.shape[-1]) < rope_dim: - raise ValueError( - f"GLM5 RTP dense MLA invalid q shape for RoPE: q={tuple(q.shape)}, " - f"rope_dim={rope_dim}." - ) - - q_rope = q.clone() - k_pe_rope = k_pe.clone() - # RotaryEmbedding.forward rotates the full tensor it receives. Passing - # only q_pe/k_pe is equivalent to the fused MLA path's nope-first layout. 
- rotated_q_pe, rotated_k_pe = self.rotary_emb( - positions.to(device=q.device, dtype=torch.long), - q_rope[..., -rope_dim:], - k_pe_rope, - ) - q_rope[..., -rope_dim:] = rotated_q_pe - return q_rope, rotated_k_pe - - def _project_kv( - self, - q: torch.Tensor, - compressed_kv: torch.Tensor, - k_pe: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor] | None: - if self.kv_b_proj is None: - return None - _, num_heads, qk_head_dim = q.shape - num_kv_tokens = int(compressed_kv.shape[0]) - rope_dim = int(self.qk_rope_head_dim or k_pe.shape[-1]) - nope_dim = int(self.qk_nope_head_dim or (qk_head_dim - rope_dim)) - if nope_dim <= 0: - raise ValueError( - f"Invalid MLA qk dims: qk_head_dim={qk_head_dim}, rope_dim={rope_dim}." - ) - - compressed_kv = compressed_kv.contiguous() - kv_nope = self._unwrap_linear_output(self.kv_b_proj(compressed_kv)) - if kv_nope.numel() == 0: - raise ValueError("GLM5 RTP dense MLA kv_b_proj returned an empty tensor.") - expected_last_dim = num_heads * (nope_dim + self.v_head_dim) - if kv_nope.shape[-1] != expected_last_dim: - raise ValueError( - "GLM5 RTP dense MLA kv_b_proj output shape mismatch " - f"(got={tuple(kv_nope.shape)}, expected_last_dim={expected_last_dim}, " - f"num_heads={num_heads}, qk_nope_head_dim={nope_dim}, " - f"v_head_dim={self.v_head_dim})." 
- ) - if not self._projection_checked: - self._projection_checked = True - - kv_nope = kv_nope.reshape(num_kv_tokens, num_heads, nope_dim + self.v_head_dim) - k_nope, value = kv_nope.split([nope_dim, self.v_head_dim], dim=-1) - if k_pe.dim() == 2: - k_pe = k_pe.unsqueeze(1) - k_pe = k_pe.expand(num_kv_tokens, num_heads, rope_dim) - key = torch.cat((k_nope, k_pe), dim=-1) - return key, value - - @staticmethod - def _causal_attention( - q: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - query_start_loc: torch.Tensor, - scale: float, - ) -> torch.Tensor: - pieces: list[torch.Tensor] = [] - for start_tensor, end_tensor in zip(query_start_loc[:-1], query_start_loc[1:]): - start = int(start_tensor.item()) - end = int(end_tensor.item()) - if end <= start: - continue - q_seg = q[start:end].float() - k_seg = key[start:end].float() - v_seg = value[start:end].float() - scores = torch.einsum("tnd,snd->nts", q_seg, k_seg) * scale - seq_len = end - start - causal_mask = torch.ones( - (seq_len, seq_len), dtype=torch.bool, device=q.device - ).tril() - scores = scores.masked_fill(~causal_mask.unsqueeze(0), float("-inf")) - probs = torch.softmax(scores, dim=-1) - pieces.append(torch.einsum("nts,snd->tnd", probs, v_seg)) - if not pieces: - return value.new_empty((0, value.shape[1], value.shape[2])) - return torch.cat(pieces, dim=0) - - @staticmethod - def _cross_causal_attention( - q: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - scale: float, - ) -> torch.Tensor: - q_len = int(q.shape[0]) - k_len = int(key.shape[0]) - if q_len == 0: - return value.new_empty((0, value.shape[1], value.shape[2])) - if k_len < q_len: - raise ValueError( - f"GLM5 RTP dense MLA got invalid cross attention lengths: q={q_len}, k={k_len}." 
- ) - scores = torch.einsum("tnd,snd->nts", q.float(), key.float()) * scale - q_pos = torch.arange(q_len, device=q.device).unsqueeze(1) - k_pos = torch.arange(k_len, device=q.device).unsqueeze(0) - causal_mask = k_pos <= (k_len - q_len + q_pos) - scores = scores.masked_fill(~causal_mask.unsqueeze(0), float("-inf")) - probs = torch.softmax(scores, dim=-1) - return torch.einsum("nts,snd->tnd", probs, value.float()) - - @staticmethod - def _flatten_latent_cache( - layer_cache: Any, - *, - block_size: int, - kv_dim: int, - ) -> torch.Tensor | None: - kv_cache_base = getattr(layer_cache, "kv_cache_base", None) - if not isinstance(kv_cache_base, torch.Tensor) or kv_cache_base.numel() == 0: - return None - if kv_cache_base.dtype in _FP8_CACHE_DTYPES: - raise NotImplementedError( - "GLM5 RTP dense MLA reference path requires BF16/FP16 latent KV cache; " - "FP8 KV cache layout/dequant is not supported yet." - ) - if kv_cache_base.dim() == 3 and int(kv_cache_base.shape[-1]) == kv_dim: - return kv_cache_base.reshape(-1, kv_dim) - if kv_cache_base.dim() == 2 and int(kv_cache_base.shape[1]) % block_size == 0: - per_token_dim = int(kv_cache_base.shape[1]) // block_size - if per_token_dim == kv_dim: - return kv_cache_base.view(kv_cache_base.shape[0], block_size, kv_dim).reshape( - -1, kv_dim - ) - return None - - @staticmethod - def _write_current_to_cache( - *, - layer_cache: Any, - compressed_kv: torch.Tensor, - k_pe: torch.Tensor, - metadata: _DenseMlaMetadata, - kv_dim: int, - ) -> None: - if metadata.slot_mapping is None: - return - flat_cache = RTPDenseMlaBackend._flatten_latent_cache( - layer_cache, block_size=metadata.block_size, kv_dim=kv_dim - ) - if flat_cache is None: - return - latent = torch.cat((compressed_kv, k_pe), dim=-1) - if latent.shape[0] != metadata.slot_mapping.shape[0]: - return - slots = metadata.slot_mapping[: latent.shape[0]].long() - flat_size = int(flat_cache.shape[0]) - non_negative = slots >= 0 - in_bounds = non_negative & (slots < flat_size) - if 
bool((non_negative & (slots >= flat_size)).any().item()): - bad_slots = slots[non_negative & (slots >= flat_size)] - _raise_cache_error( - "GLM5 RTP dense MLA refuses to write out-of-bounds slot_mapping " - f"(block_size={metadata.block_size}, flat_tokens={flat_size}, " - f"slot_min={int(bad_slots.min().item())}, " - f"slot_max={int(bad_slots.max().item())})." - ) - if not bool(in_bounds.any().item()): - return - flat_cache[slots[in_bounds]] = latent[in_bounds].to(dtype=flat_cache.dtype) - - @staticmethod - def _resolve_layer_cache(kv_cache: object, layer_id: int) -> object: - if kv_cache is not None: - return kv_cache - try: - from atom.utils.forward_context import get_forward_context - - forward_context = get_forward_context() - kv_cache_data = getattr(forward_context, "kv_cache_data", None) - if kv_cache_data is None: - return None - layer_cache_entry = kv_cache_data.get(f"layer_{int(layer_id)}") - if layer_cache_entry is None: - return None - return getattr(layer_cache_entry, "k_cache", layer_cache_entry) - except Exception: - return None - - @staticmethod - def _gather_latent_history( - *, - layer_cache: Any, - metadata: _DenseMlaMetadata, - batch_idx: int, - kv_dim: int, - ) -> torch.Tensor | None: - if metadata.block_table is None or metadata.seq_lens is None: - return None - flat_cache = RTPDenseMlaBackend._flatten_latent_cache( - layer_cache, block_size=metadata.block_size, kv_dim=kv_dim - ) - if flat_cache is None: - return None - block_size = int(metadata.block_size) - kv_cache_base = getattr(layer_cache, "kv_cache_base", None) - if ( - block_size <= 1 - and isinstance(kv_cache_base, torch.Tensor) - and kv_cache_base.dim() == 3 - and int(kv_cache_base.shape[-1]) == kv_dim - ): - block_size = int(kv_cache_base.shape[1]) - seq_len = int(metadata.seq_lens[batch_idx].item()) - if seq_len <= 0: - return None - block_row = metadata.block_table[batch_idx].long() - positions = torch.arange(seq_len, dtype=torch.long, device=flat_cache.device) - block_cols = 
torch.div(positions, block_size, rounding_mode="floor") - block_col_max = int(block_cols.max().item()) - if block_col_max >= int(block_row.numel()): - return None - offsets = positions.remainder(block_size) - slots = block_row[block_cols] * block_size + offsets - flat_size = int(flat_cache.shape[0]) - if bool(((slots < 0) | (slots >= flat_size)).any().item()): - bad_slots = slots[(slots < 0) | (slots >= flat_size)] - _raise_cache_error( - "GLM5 RTP dense MLA refuses to gather out-of-bounds KV history " - f"(batch_idx={batch_idx}, seq_len={seq_len}, " - f"block_size={metadata.block_size}, flat_tokens={flat_size}, " - f"slot_min={int(bad_slots.min().item())}, " - f"slot_max={int(bad_slots.max().item())})." - ) - return flat_cache[slots] - - @staticmethod - def _require_decode_cache_metadata( - *, - layer_cache: Any, - metadata: _DenseMlaMetadata, - kv_dim: int, - ) -> None: - missing = [] - if metadata.block_table is None: - missing.append("block_table") - if metadata.seq_lens is None: - missing.append("seq_lens") - if metadata.slot_mapping is None: - missing.append("slot_mapping") - if missing: - raise ValueError( - "GLM5 RTP dense MLA decode requires RTP KV metadata: " - + ", ".join(missing) - + "." - ) - flat_cache = RTPDenseMlaBackend._flatten_latent_cache( - layer_cache, block_size=metadata.block_size, kv_dim=kv_dim - ) - if flat_cache is None: - raise ValueError( - "GLM5 RTP dense MLA decode requires a readable BF16/FP16 kv_cache_base." 
- ) - - def forward( - self, - q: torch.Tensor, - compressed_kv: torch.Tensor, - k_pe: torch.Tensor, - kv_cache: object, - layer_id: int, - topk_indices: Optional[torch.Tensor] = None, - positions: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - del topk_indices - if self.kv_b_proj is None: - return q.new_zeros((q.shape[0], q.shape[1], self.v_head_dim)) - layer_cache = self._resolve_layer_cache(kv_cache, layer_id) - q, k_pe = self._apply_current_rope(q, k_pe, positions) - projected = self._project_kv(q, compressed_kv, k_pe) - if projected is None: - return q.new_zeros((q.shape[0], q.shape[1], self.v_head_dim)) - metadata = self._get_metadata(q.shape[0], q.device) - kv_dim = int(compressed_kv.shape[-1]) + int(k_pe.shape[-1]) - self._write_current_to_cache( - layer_cache=layer_cache, - compressed_kv=compressed_kv, - k_pe=k_pe, - metadata=metadata, - kv_dim=kv_dim, - ) - key, value = projected - query_start_loc = metadata.query_start_loc - scale = float(q.shape[-1] ** -0.5) - if metadata.is_prefill: - output = self._causal_attention(q, key, value, query_start_loc, scale) - return output.to(dtype=compressed_kv.dtype) - - self._require_decode_cache_metadata( - layer_cache=layer_cache, - metadata=metadata, - kv_dim=kv_dim, - ) - pieces: list[torch.Tensor] = [] - for batch_idx, (start_tensor, end_tensor) in enumerate( - zip(query_start_loc[:-1], query_start_loc[1:]) - ): - start = int(start_tensor.item()) - end = int(end_tensor.item()) - if end <= start: - continue - q_seg = q[start:end] - latent_history = self._gather_latent_history( - layer_cache=layer_cache, - metadata=metadata, - batch_idx=batch_idx, - kv_dim=kv_dim, - ) - if latent_history is None: - raise ValueError( - "GLM5 RTP dense MLA decode failed to gather latent KV history." 
- ) - hist_compressed_kv, hist_k_pe = latent_history.split( - [compressed_kv.shape[-1], k_pe.shape[-1]], dim=-1 - ) - hist_key, hist_value = self._project_kv(q_seg, hist_compressed_kv, hist_k_pe) - pieces.append( - self._cross_causal_attention(q_seg, hist_key, hist_value, scale) - ) - output = torch.cat(pieces, dim=0) if pieces else value.new_empty((0, *value.shape[1:])) - return output.to(dtype=compressed_kv.dtype) diff --git a/atom/plugin/rtpllm/attention_backend/rtp_mla_attention.py b/atom/plugin/rtpllm/attention_backend/rtp_mla_attention.py index b6b6a410d4..fd1d30cd6c 100644 --- a/atom/plugin/rtpllm/attention_backend/rtp_mla_attention.py +++ b/atom/plugin/rtpllm/attention_backend/rtp_mla_attention.py @@ -46,14 +46,6 @@ def _should_emit_topk_indices(attn) -> bool: context = getattr(forward_context, "context", None) if getattr(context, "is_dummy_run", False): return False - attn_metadata = getattr(forward_context, "attn_metadata", None) - if getattr(context, "is_prefill", False) and attn_metadata is not None: - max_seqlen_k = getattr(attn_metadata, "max_seqlen_k", None) - if max_seqlen_k is not None: - try: - return int(max_seqlen_k) > _get_topk_indices_buffer(attn).shape[1] - except AttributeError: - return True return True @@ -65,7 +57,7 @@ def _use_rtp_sparse_attn_indexer(indexer: object | None) -> None: class RTPMLAAttention: - """Dense RTP MLA adapter for the native GLM5 MLA call contract.""" + """RTP MLA adapter for the native GLM5 MLA call contract.""" use_mla = True @@ -91,30 +83,26 @@ def __init__(self, *args, **kwargs) -> None: if self.indexer is not None else None ) - injected_backend = kwargs.get("dense_backend") + injected_backend = kwargs.get("sparse_backend", kwargs.get("dense_backend")) if injected_backend is not None: - self.dense_backend = injected_backend + self.sparse_backend = injected_backend elif mla_modules is not None: - from atom.plugin.rtpllm.attention_backend.rtp_dense_mla_backend import ( - RTPDenseMlaBackend, - ) from 
atom.plugin.rtpllm.attention_backend.rtp_sparse_mla_backend import ( RTPSparseMlaBackend, ) - self.dense_backend = RTPSparseMlaBackend( - dense_backend=RTPDenseMlaBackend(mla_modules=mla_modules), + self.sparse_backend = RTPSparseMlaBackend( v_head_dim=mla_modules.v_head_dim, mla_modules=mla_modules, scale=kwargs.get("scale"), ) else: - self.dense_backend = None + self.sparse_backend = None self.kv_cache = kwargs.get("kv_cache") self.layer_id = int(kwargs.get("layer_id", kwargs.get("layer_num", 0))) - self._dense_backend_accepts_positions = ( - self._backend_accepts_positions(self.dense_backend) - if self.dense_backend is not None + self._sparse_backend_accepts_positions = ( + self._backend_accepts_positions(self.sparse_backend) + if self.sparse_backend is not None else False ) @@ -179,7 +167,7 @@ def forward( topk_indices: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: - if self.dense_backend is None: + if self.sparse_backend is None: raise NotImplementedError( "RTPMLAAttention requires an attention backend for contract execution" ) @@ -191,9 +179,9 @@ def forward( kwargs.get("topk_indices", topk_indices), ) forward_kwargs = {"topk_indices": topk_indices} - if self._dense_backend_accepts_positions: + if self._sparse_backend_accepts_positions: forward_kwargs["positions"] = positions - attn_output = self.dense_backend.forward( + attn_output = self.sparse_backend.forward( q, compressed_kv, k_pe, diff --git a/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py b/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py index 07cd089c8f..6dc763a56a 100644 --- a/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py +++ b/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py @@ -860,16 +860,48 @@ def _build_atom_sparse_metadata( paged_kv_last_page_len = sparse_bufs["paged_kv_last_page_len"][:num_tokens] paged_kv_indices = sparse_bufs["paged_kv_indices"][: num_tokens * topk] else: - qo_indptr = torch.arange(num_tokens + 1, 
device=device, dtype=torch.int32) - paged_kv_indptr = torch.zeros( - (num_tokens + 1,), device=device, dtype=torch.int32 + eager_sig = ( + int(num_tokens), + int(topk), + str(device), ) - paged_kv_last_page_len = torch.ones( - (num_tokens,), device=device, dtype=torch.int32 - ) - paged_kv_indices = torch.zeros( - (num_tokens * topk,), device=device, dtype=torch.int32 + cached_eager = getattr(plugin_metadata, "_rtp_sparse_eager_workspace", None) + if ( + isinstance(cached_eager, dict) + and cached_eager.get("signature") == eager_sig + ): + qo_indptr = cached_eager["qo_indptr"] + paged_kv_indptr = cached_eager["paged_kv_indptr"] + paged_kv_last_page_len = cached_eager["paged_kv_last_page_len"] + paged_kv_indices = cached_eager["paged_kv_indices"] + else: + qo_indptr = torch.empty( + num_tokens + 1, device=device, dtype=torch.int32 + ) + paged_kv_indptr = torch.empty( + num_tokens + 1, device=device, dtype=torch.int32 + ) + paged_kv_last_page_len = torch.empty( + num_tokens, device=device, dtype=torch.int32 + ) + paged_kv_indices = torch.empty( + num_tokens * topk, device=device, dtype=torch.int32 + ) + try: + plugin_metadata._rtp_sparse_eager_workspace = { + "signature": eager_sig, + "qo_indptr": qo_indptr, + "paged_kv_indptr": paged_kv_indptr, + "paged_kv_last_page_len": paged_kv_last_page_len, + "paged_kv_indices": paged_kv_indices, + } + except Exception: + pass + qo_indptr.copy_( + torch.arange(num_tokens + 1, device=device, dtype=torch.int32) ) + paged_kv_indptr.zero_() + paged_kv_last_page_len.fill_(1) torch.cumsum(sparse_seqlen, dim=0, out=paged_kv_indptr[1:]) if not in_capture and int(block_size) <= 0: @@ -903,43 +935,79 @@ def _build_atom_sparse_metadata( reduce_final_map = sparse_bufs["reduce_final_map"] reduce_partial_map = sparse_bufs["reduce_partial_map"] else: - metadata_budget_tokens = self._metadata_token_budget( - num_tokens=num_tokens, topk=topk - ) - ( - (work_meta_data_size, work_meta_data_type), - (work_indptr_size, work_indptr_type), - 
(work_info_set_size, work_info_set_type), - (reduce_indptr_size, reduce_indptr_type), - (reduce_final_map_size, reduce_final_map_type), - (reduce_partial_map_size, reduce_partial_map_type), - ) = get_mla_metadata_info_v1( - metadata_budget_tokens, - 1, - padded_num_heads, - q_dtype, - kv_dtype, - is_sparse=True, - fast_mode=True, + eager_meta_sig = ( + int(num_tokens), + int(topk), + int(padded_num_heads), + str(q_dtype), + str(kv_dtype), + str(device), ) - work_meta_data = torch.empty( - work_meta_data_size, dtype=work_meta_data_type, device=device - ) - work_indptr = torch.empty( - work_indptr_size, dtype=work_indptr_type, device=device - ) - work_info_set = torch.empty( - work_info_set_size, dtype=work_info_set_type, device=device - ) - reduce_indptr = torch.empty( - reduce_indptr_size, dtype=reduce_indptr_type, device=device - ) - reduce_final_map = torch.empty( - reduce_final_map_size, dtype=reduce_final_map_type, device=device - ) - reduce_partial_map = torch.empty( - reduce_partial_map_size, dtype=reduce_partial_map_type, device=device + cached_eager_meta = getattr( + plugin_metadata, "_rtp_sparse_eager_meta_workspace", None ) + if ( + isinstance(cached_eager_meta, dict) + and cached_eager_meta.get("signature") == eager_meta_sig + ): + work_meta_data = cached_eager_meta["work_meta_data"] + work_indptr = cached_eager_meta["work_indptr"] + work_info_set = cached_eager_meta["work_info_set"] + reduce_indptr = cached_eager_meta["reduce_indptr"] + reduce_final_map = cached_eager_meta["reduce_final_map"] + reduce_partial_map = cached_eager_meta["reduce_partial_map"] + else: + metadata_budget_tokens = self._metadata_token_budget( + num_tokens=num_tokens, topk=topk + ) + ( + (work_meta_data_size, work_meta_data_type), + (work_indptr_size, work_indptr_type), + (work_info_set_size, work_info_set_type), + (reduce_indptr_size, reduce_indptr_type), + (reduce_final_map_size, reduce_final_map_type), + (reduce_partial_map_size, reduce_partial_map_type), + ) = 
get_mla_metadata_info_v1( + metadata_budget_tokens, + 1, + padded_num_heads, + q_dtype, + kv_dtype, + is_sparse=True, + fast_mode=True, + ) + work_meta_data = torch.empty( + work_meta_data_size, dtype=work_meta_data_type, device=device + ) + work_indptr = torch.empty( + work_indptr_size, dtype=work_indptr_type, device=device + ) + work_info_set = torch.empty( + work_info_set_size, dtype=work_info_set_type, device=device + ) + reduce_indptr = torch.empty( + reduce_indptr_size, dtype=reduce_indptr_type, device=device + ) + reduce_final_map = torch.empty( + reduce_final_map_size, dtype=reduce_final_map_type, device=device + ) + reduce_partial_map = torch.empty( + reduce_partial_map_size, + dtype=reduce_partial_map_type, + device=device, + ) + try: + plugin_metadata._rtp_sparse_eager_meta_workspace = { + "signature": eager_meta_sig, + "work_meta_data": work_meta_data, + "work_indptr": work_indptr, + "work_info_set": work_info_set, + "reduce_indptr": reduce_indptr, + "reduce_final_map": reduce_final_map, + "reduce_partial_map": reduce_partial_map, + } + except Exception: + pass capture_meta_sig = ( int(num_tokens), int(topk), @@ -1265,18 +1333,18 @@ class RTPSparseMlaBackend: def __init__( self, *, - dense_backend: object, sparse_impl: Optional[object] = None, v_head_dim: Optional[int] = None, mla_modules: Optional[object] = None, scale: Optional[float] = None, ) -> None: - self.dense_backend = dense_backend - self.v_head_dim = int( - v_head_dim - if v_head_dim is not None - else getattr(dense_backend, "v_head_dim") - ) + if v_head_dim is None: + if mla_modules is None or not hasattr(mla_modules, "v_head_dim"): + raise ValueError( + "RTPSparseMlaBackend requires v_head_dim or mla_modules.v_head_dim." 
+ ) + v_head_dim = getattr(mla_modules, "v_head_dim") + self.v_head_dim = int(v_head_dim) if sparse_impl is not None: self.sparse_impl = sparse_impl self._default_mock = False @@ -1302,9 +1370,6 @@ def __init__( self._sparse_impl_accepts_positions = self._impl_accepts_positions( self.sparse_impl ) - self._dense_forward_accepts_positions = self._call_accepts_positions( - getattr(self.dense_backend, "forward", None) - ) def prepare_cuda_graph(self, attn_inputs) -> None: # noqa: ANN001 del attn_inputs @@ -1317,14 +1382,6 @@ def prewarm_for_cuda_graph( query_dtype: torch.dtype, device: torch.device, ) -> None: - dense_prewarm = getattr(self.dense_backend, "prewarm_for_cuda_graph", None) - if callable(dense_prewarm): - dense_prewarm( - max_num_tokens=max_num_tokens, - max_seq_len=max_seq_len, - query_dtype=query_dtype, - device=device, - ) sparse_prewarm = getattr(self.sparse_impl, "prewarm_for_cuda_graph", None) if callable(sparse_prewarm): sparse_prewarm( @@ -1371,39 +1428,6 @@ def _impl_accepts_positions(impl: object) -> bool: for parameter in signature.parameters.values() ) - @staticmethod - def _call_accepts_positions(callable_obj: object) -> bool: - try: - signature = inspect.signature(callable_obj) - except (TypeError, ValueError): - return False - return "positions" in signature.parameters or any( - parameter.kind == inspect.Parameter.VAR_KEYWORD - for parameter in signature.parameters.values() - ) - - def _dense_forward( - self, - q: torch.Tensor, - compressed_kv: torch.Tensor, - k_pe: torch.Tensor, - kv_cache: object, - layer_id: int, - topk_indices: Optional[torch.Tensor], - positions: Optional[torch.Tensor], - ) -> torch.Tensor: - kwargs = {"topk_indices": topk_indices} - if self._dense_forward_accepts_positions: - kwargs["positions"] = positions - return self.dense_backend.forward( - q, - compressed_kv, - k_pe, - kv_cache, - layer_id, - **kwargs, - ) - def forward( self, q: torch.Tensor, @@ -1421,17 +1445,9 @@ def forward( return q.new_zeros((q.shape[0], 
q.shape[1], self.v_head_dim)) if topk_indices is None: - plugin_metadata = getattr(attn_metadata, "plugin_metadata", None) - num_prefills = int(getattr(plugin_metadata, "num_prefills", 0) or 0) - if num_prefills <= 0: - raise _SparseUnavailable( - "GLM5 RTP sparse MLA decode requires topk_indices; " - "refusing dense fallback." - ) - return self._dense_forward( - q, compressed_kv, k_pe, kv_cache, layer_id, None, positions + raise _SparseUnavailable( + "GLM5 RTP sparse MLA requires topk_indices; refusing dense fallback." ) - self._validate_topk_indices(q, topk_indices) if self._default_mock or not callable( getattr(self.sparse_impl, "forward", None) @@ -1486,6 +1502,52 @@ def rtp_sparse_attn_indexer( is_neox_style: bool, use_qk_rope_cache_fusion: bool, ) -> torch.Tensor: + try: + from atom.utils.forward_context import get_forward_context + + forward_context = get_forward_context() + except Exception: + forward_context = None + context = getattr(forward_context, "context", None) + attn_metadata = getattr(forward_context, "attn_metadata", None) + # For short prefill (ctx <= topk buffer width), DeepSeek indexer returns early and + # doesn't write topk buffer. Emit causal full-history indices to keep sparse path valid. 
+ if ( + context is not None + and bool(getattr(context, "is_prefill", False)) + and attn_metadata is not None + and topk_indices_buffer is not None + and positions is not None + ): + max_seqlen_k = int(getattr(attn_metadata, "max_seqlen_k", 0) or 0) + topk_capacity = int(topk_indices_buffer.shape[1]) + if max_seqlen_k > 0 and max_seqlen_k <= topk_capacity: + num_tokens = int(hidden_states.shape[0]) + if num_tokens > 0: + positions_i32 = positions.to( + device=topk_indices_buffer.device, dtype=torch.int32 + ).view(-1) + row_limits = ( + (positions_i32 + 1).clamp(min=0, max=topk_tokens).view(-1, 1) + ) + col_ids = torch.arange( + topk_tokens, + device=topk_indices_buffer.device, + dtype=torch.int32, + ).view(1, -1) + causal_topk = torch.where( + col_ids < row_limits, + col_ids.expand(num_tokens, topk_tokens), + torch.full( + (num_tokens, topk_tokens), + -1, + device=topk_indices_buffer.device, + dtype=torch.int32, + ), + ) + topk_indices_buffer[:num_tokens, :topk_tokens].copy_(causal_topk) + return weights + from atom.models.deepseek_v2 import sparse_attn_indexer return sparse_attn_indexer( diff --git a/atom/plugin/rtpllm/models/glm5.py b/atom/plugin/rtpllm/models/glm5.py index 5783d081b2..537491ed65 100644 --- a/atom/plugin/rtpllm/models/glm5.py +++ b/atom/plugin/rtpllm/models/glm5.py @@ -48,7 +48,6 @@ def __init__(self, runtime: "_ATOMGlm5MoeRuntime") -> None: self.is_cuda_graph = False self._rtp_mla_layers: list[Any] = [] self._rtp_sparse_mla_backends: list[Any] = [] - self._rtp_dense_mla_backends: list[Any] = [] self._collect_mla_layers() @staticmethod @@ -59,12 +58,10 @@ def _append_unique(items: list[Any], value: Any) -> None: def _collect_mla_layers(self) -> None: try: from atom.plugin.rtpllm.attention_backend import ( - RTPDenseMlaBackend, RTPMLAAttention, RTPSparseMlaBackend, ) except (ImportError, ModuleNotFoundError): - RTPDenseMlaBackend = None RTPMLAAttention = None RTPSparseMlaBackend = None @@ -80,9 +77,9 @@ def _collect_mla_layers(self) -> None: for 
candidate in candidates: if RTPMLAAttention is not None and isinstance(candidate, RTPMLAAttention): self._append_unique(self._rtp_mla_layers, candidate) - backend = getattr(candidate, "dense_backend", None) + backend = getattr(candidate, "sparse_backend", None) else: - backend = getattr(candidate, "dense_backend", None) + backend = getattr(candidate, "sparse_backend", None) if ( backend is None and RTPSparseMlaBackend is not None @@ -94,15 +91,6 @@ def _collect_mla_layers(self) -> None: backend, RTPSparseMlaBackend ): self._append_unique(self._rtp_sparse_mla_backends, backend) - dense_backend = getattr(backend, "dense_backend", None) - if RTPDenseMlaBackend is not None and isinstance( - dense_backend, RTPDenseMlaBackend - ): - self._append_unique(self._rtp_dense_mla_backends, dense_backend) - elif RTPDenseMlaBackend is not None and isinstance( - backend, RTPDenseMlaBackend - ): - self._append_unique(self._rtp_dense_mla_backends, backend) @property def fmha_params(self): @@ -113,7 +101,7 @@ def prepare_cuda_graph(self, attn_inputs) -> None: # noqa: ANN001 prepare = getattr(layer, "prepare_cuda_graph", None) if callable(prepare): prepare(attn_inputs) - for backend in self._rtp_sparse_mla_backends + self._rtp_dense_mla_backends: + for backend in self._rtp_sparse_mla_backends: prepare = getattr(backend, "prepare_cuda_graph", None) if callable(prepare): prepare(attn_inputs) @@ -235,16 +223,6 @@ def _ensure_cuda_graph_prewarmed(self) -> None: query_dtype=dtype, device=device, ) - for backend in self._atom_attn_pyobj._rtp_dense_mla_backends: - prewarm = getattr(backend, "prewarm_for_cuda_graph", None) - if callable(prewarm): - prewarm( - max_num_tokens=max_num_tokens, - max_seq_len=max_seq_len, - query_dtype=dtype, - device=device, - ) - self._cg_meta_bufs: dict[str, torch.Tensor] = { "query_start_loc": torch.arange( 0, max_num_tokens + 1, device=device, dtype=torch.int32 @@ -289,13 +267,12 @@ def _ensure_cuda_graph_prewarmed(self) -> None: self._cg_layers_prewarmed = 
True logger.info( "ATOM GLM5 cuda-graph prewarmed " - "(max_num_tokens=%d, max_seq_len=%d, sparse_layers=%d, dense_layers=%d, " + "(max_num_tokens=%d, max_seq_len=%d, sparse_layers=%d, " "physical_block_table_i32[%dx%d], block_table_i32[%dx%d], " "indexer_block_table_i32[%dx%d])", max_num_tokens, max_seq_len, len(self._atom_attn_pyobj._rtp_sparse_mla_backends), - len(self._atom_attn_pyobj._rtp_dense_mla_backends), max_num_tokens, recovered_physical_max_blocks, max_num_tokens, diff --git a/tests/plugin/test_rtpllm_forward_context_semantics.py b/tests/plugin/test_rtpllm_forward_context_semantics.py index af4f03ad6a..5f1761d8cc 100644 --- a/tests/plugin/test_rtpllm_forward_context_semantics.py +++ b/tests/plugin/test_rtpllm_forward_context_semantics.py @@ -339,7 +339,7 @@ def test_rtpllm_decode_seq_lens_uses_rtp_plus_one_in_graph_and_eager_modes(): def test_collect_layer_maps_keeps_mla_layers_separate(): from atom.plugin.rtpllm.attention_backend.rtp_mla_attention import RTPMLAAttention - mla_layer = RTPMLAAttention(dense_backend=object(), layer_num=7) + mla_layer = RTPMLAAttention(sparse_backend=object(), layer_num=7) model = SimpleNamespace(modules=lambda: [mla_layer]) gdn_map, full_attn_map, mla_map = RTPForwardContext.collect_layer_maps(model) @@ -352,7 +352,7 @@ def test_collect_layer_maps_keeps_mla_layers_separate(): def test_collect_layer_maps_keeps_sparse_mla_owner_for_indexer_cache(): from atom.plugin.rtpllm.attention_backend.rtp_mla_attention import RTPMLAAttention - mla_layer = RTPMLAAttention(dense_backend=object(), layer_num=7) + mla_layer = RTPMLAAttention(sparse_backend=object(), layer_num=7) sparse_owner = SimpleNamespace( layer_num=7, indexer=SimpleNamespace(), @@ -370,7 +370,7 @@ def test_collect_layer_maps_keeps_sparse_mla_owner_for_indexer_cache(): def test_collect_layer_maps_recognizes_atom_mla_wrapper_by_indexer_and_mla_attn(): from atom.plugin.rtpllm.attention_backend.rtp_mla_attention import RTPMLAAttention - inner_mla = 
RTPMLAAttention(dense_backend=object(), layer_num=9) + inner_mla = RTPMLAAttention(sparse_backend=object(), layer_num=9) atom_wrapper = SimpleNamespace( layer_num=9, indexer=SimpleNamespace(), @@ -390,7 +390,7 @@ def test_build_kv_cache_tensors_threads_raw_layer_cache_for_mla(): runtime = SimpleNamespace( kv_cache=SimpleNamespace(get_layer_cache=lambda layer_num: layer_cache) ) - mla_layer = RTPMLAAttention(dense_backend=object(), layer_num=7) + mla_layer = RTPMLAAttention(sparse_backend=object(), layer_num=7) cache_tensors = RTPForwardContext._build_kv_cache_tensors( runtime=runtime, @@ -406,7 +406,9 @@ def test_bind_temporarily_attaches_mla_layer_cache(monkeypatch): old_cache = SimpleNamespace(name="old-cache") new_cache = SimpleNamespace(name="new-cache") - mla_layer = RTPMLAAttention(dense_backend=object(), layer_num=7, kv_cache=old_cache) + mla_layer = RTPMLAAttention( + sparse_backend=object(), layer_num=7, kv_cache=old_cache + ) forward_context = SimpleNamespace( attn_metadata=SimpleNamespace(), gdn_metadata=SimpleNamespace(), @@ -451,7 +453,7 @@ def test_bind_writes_kv_cache_to_mla_attn_owner_not_outer_wrapper(monkeypatch): k_cache=SimpleNamespace(kv_cache=[torch.empty(0)]), ) mla_layer = RTPMLAAttention( - dense_backend=object(), + sparse_backend=object(), layer_num=7, kv_cache=old_inner_cache, ) @@ -503,7 +505,7 @@ def test_bind_temporarily_attaches_sparse_mla_indexer_cache(monkeypatch): k_cache=SimpleNamespace(kv_cache=[old_index_cache]), ) mla_layer = RTPMLAAttention( - dense_backend=object(), + sparse_backend=object(), layer_num=7, kv_cache=old_cache, mla_modules=SimpleNamespace(indexer=indexer), diff --git a/tests/plugin/test_rtpllm_glm5_indexer_contract.py b/tests/plugin/test_rtpllm_glm5_indexer_contract.py index 84d9841cca..8369a54546 100644 --- a/tests/plugin/test_rtpllm_glm5_indexer_contract.py +++ b/tests/plugin/test_rtpllm_glm5_indexer_contract.py @@ -8,7 +8,6 @@ from atom.plugin.rtpllm.attention_backend.rtp_mla_attention import RTPMLAAttention 
- _FORBIDDEN_CUDA_SPARSE_MODULES = ( "flashmla_sparse", "flash_mla", @@ -22,13 +21,15 @@ def _guard_sparse_kernel_imports(monkeypatch): def _guarded_import(name, *args, **kwargs): if any(part in _FORBIDDEN_CUDA_SPARSE_MODULES for part in name.split(".")): - raise AssertionError(f"M1.5 tests must not import sparse MLA kernels: {name}") + raise AssertionError( + f"M1.5 tests must not import sparse MLA kernels: {name}" + ) return original_import(name, *args, **kwargs) monkeypatch.setattr(builtins, "__import__", _guarded_import) -class _FakeDenseBackend: +class _FakeSparseBackend: def __init__(self, v_head_dim: int): self.v_head_dim = v_head_dim self.calls = [] @@ -56,9 +57,9 @@ def __init__(self, topk_values): -1, dtype=torch.int32, ) - self.topk_indices_buffer[ - : topk_values.shape[0], : topk_values.shape[1] - ].copy_(topk_values) + self.topk_indices_buffer[: topk_values.shape[0], : topk_values.shape[1]].copy_( + topk_values + ) self.weights = torch.full(topk_values.shape, 99.0, dtype=torch.float32) def __call__(self, *args, **kwargs): @@ -93,7 +94,7 @@ def _make_attention(topk_values): projected_q = torch.arange( token_count * num_heads * qk_head_dim, dtype=torch.float32 ).reshape(token_count, num_heads * qk_head_dim) - backend = _FakeDenseBackend(v_head_dim=v_head_dim) + backend = _FakeSparseBackend(v_head_dim=v_head_dim) indexer = _FakeIndexer(topk_values) modules = SimpleNamespace( q_proj=_FakeQProj(projected_q), @@ -108,7 +109,7 @@ def _make_attention(topk_values): ) attention = RTPMLAAttention( mla_modules=modules, - dense_backend=backend, + sparse_backend=backend, layer_num=7, kv_cache="kv-cache", ) @@ -134,7 +135,9 @@ def test_constructor_injects_indexer_and_topk_indices_buffer_owner_path(): def test_constructor_swaps_indexer_to_rtp_sparse_indexer_op(monkeypatch): default_op = object() rtp_op = object() - monkeypatch.setattr(torch.ops.aiter, "rtp_sparse_attn_indexer", rtp_op, raising=False) + monkeypatch.setattr( + torch.ops.aiter, "rtp_sparse_attn_indexer", 
rtp_op, raising=False + ) topk_buffer = torch.tensor([[4, 1, 3, 0]], dtype=torch.int32) indexer = SimpleNamespace( topk_indices_buffer=topk_buffer, @@ -149,7 +152,7 @@ def test_constructor_swaps_indexer_to_rtp_sparse_indexer_op(monkeypatch): v_head_dim=3, ) - attention = RTPMLAAttention(mla_modules=modules, dense_backend=object()) + attention = RTPMLAAttention(mla_modules=modules, sparse_backend=object()) assert attention.indexer is indexer assert indexer.sparse_attn_indexer_impl is rtp_op @@ -168,7 +171,7 @@ def _run_attention(attention, token_count: int): ) -def test_indexer_buffer_topk_is_passed_to_dense_backend_when_emit_allowed(monkeypatch): +def test_indexer_buffer_topk_is_passed_to_sparse_backend_when_emit_allowed(monkeypatch): _guard_sparse_kernel_imports(monkeypatch) topk_values = torch.tensor([[4, 1, 3, 0], [2, 0, 1, 3]], dtype=torch.int32) attention, modules, backend = _make_attention(topk_values) @@ -200,7 +203,7 @@ def _patch_forward_context(monkeypatch, *, is_dummy_run, is_prefill, max_seqlen_ ) -def test_dummy_run_does_not_emit_topk_to_dense_backend(monkeypatch): +def test_dummy_run_does_not_emit_topk_to_sparse_backend(monkeypatch): _guard_sparse_kernel_imports(monkeypatch) _patch_forward_context( monkeypatch, @@ -217,7 +220,7 @@ def test_dummy_run_does_not_emit_topk_to_dense_backend(monkeypatch): assert backend.calls[0]["topk_indices"] is None -def test_short_prefill_does_not_emit_topk_to_dense_backend(monkeypatch): +def test_short_prefill_emits_topk_to_sparse_backend(monkeypatch): _guard_sparse_kernel_imports(monkeypatch) _patch_forward_context( monkeypatch, @@ -231,10 +234,12 @@ def test_short_prefill_does_not_emit_topk_to_dense_backend(monkeypatch): _run_attention(attention, token_count=topk_values.shape[0]) assert modules.indexer.calls == [] - assert backend.calls[0]["topk_indices"] is None + topk_indices = backend.calls[0]["topk_indices"] + assert topk_indices is not None + assert torch.equal(topk_indices, topk_values) -def 
test_prefill_within_topk_buffer_padding_does_not_emit_topk(monkeypatch): +def test_prefill_within_topk_buffer_padding_still_emits_topk(monkeypatch): _guard_sparse_kernel_imports(monkeypatch) _patch_forward_context( monkeypatch, @@ -250,5 +255,6 @@ def test_prefill_within_topk_buffer_padding_does_not_emit_topk(monkeypatch): assert modules.indexer.index_topk == 4 assert modules.indexer.topk_indices_buffer.shape[1] == 6 assert modules.indexer.calls == [] - assert backend.calls[0]["topk_indices"] is None - + topk_indices = backend.calls[0]["topk_indices"] + assert topk_indices is not None + assert torch.equal(topk_indices, topk_values) diff --git a/tests/plugin/test_rtpllm_glm5_mla_forward_contract.py b/tests/plugin/test_rtpllm_glm5_mla_forward_contract.py deleted file mode 100644 index a4088fd106..0000000000 --- a/tests/plugin/test_rtpllm_glm5_mla_forward_contract.py +++ /dev/null @@ -1,1053 +0,0 @@ -"""Contract-executable tests for GLM5 RTP MLA native forward.""" - -import builtins -import importlib -import inspect -from types import SimpleNamespace - -import pytest -import torch - -from atom.plugin.rtpllm.attention_backend.rtp_mla_attention import RTPMLAAttention - - -_FORBIDDEN_CUDA_SPARSE_MODULES = ( - "flashmla_sparse", - "flash_mla", - "sparse_mla", - "attention_mla_sparse", -) - - -class _FakeDenseBackend: - def __init__(self, v_head_dim: int): - self.v_head_dim = v_head_dim - self.calls = [] - - def forward( - self, - q, - compressed_kv, - k_pe, - kv_cache, - layer_id, - topk_indices=None, - positions=None, - ): - self.calls.append( - { - "q": q, - "compressed_kv": compressed_kv, - "k_pe": k_pe, - "kv_cache": kv_cache, - "layer_id": layer_id, - "topk_indices": topk_indices, - "positions": positions, - } - ) - return q.new_zeros((q.shape[0], q.shape[1], self.v_head_dim)) - - -def test_rtp_mla_attention_keeps_legacy_dense_boundary_during_migration(): - backend = _FakeDenseBackend(v_head_dim=16) - attention = RTPMLAAttention(dense_backend=backend, layer_id=7, 
kv_cache="cache") - q = torch.empty(3, 2, 12, dtype=torch.bfloat16) - compressed_kv = torch.empty(3, 8, dtype=torch.bfloat16) - k_pe = torch.empty(3, 4, dtype=torch.bfloat16) - positions = torch.arange(3, dtype=torch.int32) - - output = attention.forward( - q, - compressed_kv, - k_pe, - positions=positions, - topk_indices=None, - ) - - assert output.shape == (3, 2, 16) - assert len(backend.calls) == 1 - call = backend.calls[0] - assert call["q"] is q - assert call["compressed_kv"] is compressed_kv - assert call["k_pe"] is k_pe - assert call["kv_cache"] == "cache" - assert call["layer_id"] == 7 - assert call["topk_indices"] is None - assert call["positions"] is positions - - -def _guard_sparse_kernel_imports(monkeypatch): - original_import = builtins.__import__ - - def _guarded_import(name, *args, **kwargs): - if any(part in _FORBIDDEN_CUDA_SPARSE_MODULES for part in name.split(".")): - raise AssertionError(f"M1 dense contract must not import sparse kernel module: {name}") - return original_import(name, *args, **kwargs) - - monkeypatch.setattr(builtins, "__import__", _guarded_import) - - -def test_rtp_mla_attention_accepts_explicit_topk_and_passes_it_to_dense_backend(monkeypatch): - _guard_sparse_kernel_imports(monkeypatch) - attention = RTPMLAAttention(dense_backend=_FakeDenseBackend(v_head_dim=16)) - q = torch.empty(1, 2, 12) - compressed_kv = torch.empty(1, 8) - k_pe = torch.empty(1, 4) - positions = torch.arange(1, dtype=torch.int32) - topk = torch.tensor([[3, 1, 0, 2]], dtype=torch.int32) - - output = attention.forward( - q, - compressed_kv, - k_pe, - positions=positions, - topk_indices=topk, - ) - - assert output.shape == (1, 2, 16) - assert len(attention.dense_backend.calls) == 1 - assert attention.dense_backend.calls[0]["topk_indices"] is topk - - -def test_dense_backend_output_does_not_depend_on_topk_values(monkeypatch): - _guard_sparse_kernel_imports(monkeypatch) - backend = _FakeDenseBackend(v_head_dim=16) - attention = 
RTPMLAAttention(dense_backend=backend) - q = torch.ones(2, 2, 12) - compressed_kv = torch.empty(2, 8) - k_pe = torch.empty(2, 4) - positions = torch.arange(2, dtype=torch.int32) - topk_a = torch.tensor([[3, 1, 0, 2], [2, 0, 1, 3]], dtype=torch.int32) - topk_b = torch.tensor([[0, 2, 1, 3], [3, 1, 2, 0]], dtype=torch.int32) - - out_a = attention.forward( - q, - compressed_kv, - k_pe, - positions=positions, - topk_indices=topk_a, - ) - out_b = attention.forward( - q, - compressed_kv, - k_pe, - positions=positions, - topk_indices=topk_b, - ) - - assert torch.equal(out_a, out_b) - assert backend.calls[0]["topk_indices"] is topk_a - assert backend.calls[1]["topk_indices"] is topk_b - - -def test_native_forward_signature_exposes_q_scale_argument(): - signature = inspect.signature(RTPMLAAttention.forward) - - assert "q_scale" in signature.parameters - - -@pytest.mark.parametrize("attr", ["q_proj", "o_proj", "kv_b_proj", "v_head_dim"]) -def test_constructor_injects_native_mla_module_attributes(attr): - modules = SimpleNamespace( - q_proj=object(), - o_proj=object(), - kv_b_proj=object(), - v_head_dim=16, - ) - attention = RTPMLAAttention(mla_modules=modules) - - assert getattr(attention, attr) == getattr(modules, attr) - - -class _FakeQProj: - def __init__(self, output): - self.output = output - self.calls = [] - - def __call__(self, query, q_scale=None): - self.calls.append((query, q_scale)) - return self.output - - -class _FakeOProj: - def __init__(self, hidden_dim: int): - self.hidden_dim = hidden_dim - self.calls = [] - - def __call__(self, tensor): - self.calls.append(tensor) - return tensor.new_empty((tensor.shape[0], self.hidden_dim)) - - -def test_native_five_tuple_projects_latent_query_and_applies_o_proj(monkeypatch): - _guard_sparse_kernel_imports(monkeypatch) - token_count = 3 - num_heads = 2 - qk_head_dim = 4 - v_head_dim = 5 - hidden_dim = 7 - query = torch.arange(token_count * 6, dtype=torch.float32).reshape(token_count, 6) - q_scale = torch.ones(token_count, 
1) - projected_q = torch.arange( - token_count * num_heads * qk_head_dim, dtype=torch.float32 - ).reshape(token_count, num_heads * qk_head_dim) - compressed_kv = torch.empty(token_count, 8) - k_rope = torch.empty(token_count, 3) - positions = torch.arange(token_count, dtype=torch.int32) - backend = _FakeDenseBackend(v_head_dim=v_head_dim) - modules = SimpleNamespace( - q_proj=_FakeQProj(projected_q), - o_proj=_FakeOProj(hidden_dim=hidden_dim), - kv_b_proj=object(), - v_head_dim=v_head_dim, - qk_head_dim=qk_head_dim, - num_heads=num_heads, - num_local_heads=num_heads, - ) - attention = RTPMLAAttention( - mla_modules=modules, - dense_backend=backend, - layer_num=5, - kv_cache="kv-cache", - ) - - output = attention.forward( - query, - compressed_kv, - k_rope, - positions=positions, - q_scale=q_scale, - ) - - assert modules.q_proj.calls == [(query, q_scale)] - assert len(backend.calls) == 1 - call = backend.calls[0] - assert call["q"].shape == (token_count, num_heads, qk_head_dim) - assert torch.equal(call["q"].reshape(token_count, -1), projected_q) - assert call["compressed_kv"] is compressed_kv - assert call["k_pe"] is k_rope - assert call["kv_cache"] == "kv-cache" - assert call["layer_id"] == 5 - assert len(modules.o_proj.calls) == 1 - assert modules.o_proj.calls[0].shape == (token_count, num_heads * v_head_dim) - assert output.shape == (token_count, hidden_dim) - - -def test_rtp_mla_attention_builds_m0_backend_from_mla_modules(): - modules = SimpleNamespace(v_head_dim=16) - attention = RTPMLAAttention(mla_modules=modules, layer_num=3) - q = torch.empty(2, 4, 12) - compressed_kv = torch.empty(2, 8) - k_pe = torch.empty(2, 4) - - output = attention(q, compressed_kv, k_pe, positions=torch.arange(2)) - - assert output.shape == (2, 4, 16) - - -def test_rtp_mla_attention_defaults_to_sparse_backend_from_mla_modules(monkeypatch): - _guard_sparse_kernel_imports(monkeypatch) - from atom.plugin.rtpllm.attention_backend.rtp_dense_mla_backend import ( - RTPDenseMlaBackend, - ) 
- from atom.plugin.rtpllm.attention_backend.rtp_sparse_mla_backend import ( - RTPSparseMlaBackend, - ) - - modules = SimpleNamespace(v_head_dim=16) - attention = RTPMLAAttention(mla_modules=modules, layer_num=3) - - assert isinstance(attention.dense_backend, RTPSparseMlaBackend) - assert isinstance(attention.dense_backend.dense_backend, RTPDenseMlaBackend) - - -class _FakeKVProj: - def __init__(self, output: torch.Tensor): - self.output = output - self.calls = [] - - def __call__(self, compressed_kv): - self.calls.append(compressed_kv) - output = self.output.to(device=compressed_kv.device, dtype=compressed_kv.dtype) - if output.shape[0] == 1 and compressed_kv.shape[0] != 1: - output = output.expand(compressed_kv.shape[0], -1).contiguous() - return output - - -class _DeterministicKVProj: - def __init__(self, output_dim: int): - self.output_dim = output_dim - self.calls = [] - - def __call__(self, compressed_kv): - self.calls.append(compressed_kv.detach().clone()) - token_signal = compressed_kv.float().mean(dim=-1, keepdim=True) - basis = torch.linspace( - 0.0, - 1.0, - self.output_dim, - device=compressed_kv.device, - dtype=torch.float32, - ).unsqueeze(0) - return (token_signal + basis).to(dtype=compressed_kv.dtype) - - -class _FakeRotaryEmbedding: - is_neox_style = True - - def __init__(self): - self.calls = [] - - def __call__(self, positions, query, key): - self.calls.append( - { - "positions": positions.detach().clone(), - "query": query.detach().clone(), - "key": key.detach().clone(), - } - ) - offset = positions.to(device=query.device, dtype=query.dtype) - while offset.ndim < query.ndim: - offset = offset.unsqueeze(-1) - query = query + offset - key_offset = positions.to(device=key.device, dtype=key.dtype) - while key_offset.ndim < key.ndim: - key_offset = key_offset.unsqueeze(-1) - key = key + key_offset - return query, key - - -def _patch_forward_context( - monkeypatch, - *, - is_prefill, - query_start_loc, - seq_lens=None, - block_table=None, - 
slot_mapping=None, - kv_cache_data=None, -): - plugin_metadata = SimpleNamespace( - query_start_loc=query_start_loc, - rtp_cu_seqlens_q=query_start_loc, - seq_lens=seq_lens, - block_table=block_table, - slot_mapping=slot_mapping, - ) - fake_context = SimpleNamespace( - attn_metadata=SimpleNamespace( - plugin_metadata=plugin_metadata, - rtp_kernel_seq_size_per_block=4, - ), - context=SimpleNamespace(is_prefill=is_prefill), - kv_cache_data=kv_cache_data, - ) - forward_context_module = importlib.import_module("atom.utils.forward_context") - monkeypatch.setattr( - forward_context_module, - "get_forward_context", - lambda: fake_context, - ) - - -def _patch_forward_context_with_top_level_attn_metadata( - monkeypatch, - *, - is_prefill, - seq_lens, - block_table, - slot_mapping, - kv_cache_data=None, -): - fake_context = SimpleNamespace( - attn_metadata=SimpleNamespace( - plugin_metadata=None, - context_lens=seq_lens, - block_tables=block_table, - slot_mapping=slot_mapping, - cu_seqlens_q=None, - rtp_kernel_seq_size_per_block=4, - ), - context=SimpleNamespace(is_prefill=is_prefill), - kv_cache_data=kv_cache_data, - ) - forward_context_module = importlib.import_module("atom.utils.forward_context") - monkeypatch.setattr( - forward_context_module, - "get_forward_context", - lambda: fake_context, - ) - - -def _patch_forward_context_without_is_prefill(monkeypatch, *, query_start_loc): - plugin_metadata = SimpleNamespace( - query_start_loc=query_start_loc, - rtp_cu_seqlens_q=query_start_loc, - seq_lens=None, - block_table=None, - slot_mapping=None, - ) - fake_context = SimpleNamespace( - attn_metadata=SimpleNamespace( - plugin_metadata=plugin_metadata, - rtp_kernel_seq_size_per_block=4, - ), - context=SimpleNamespace(), - ) - forward_context_module = importlib.import_module("atom.utils.forward_context") - monkeypatch.setattr( - forward_context_module, - "get_forward_context", - lambda: fake_context, - ) - - -def 
test_default_dense_mla_backend_computes_nonzero_attention(monkeypatch): - _guard_sparse_kernel_imports(monkeypatch) - from atom.plugin.rtpllm.attention_backend.rtp_sparse_mla_backend import ( - RTPSparseMlaBackend, - ) - - q = torch.tensor([[[1.0, 0.0]], [[0.0, 1.0]]], dtype=torch.float32) - compressed_kv = torch.ones(2, 4, dtype=torch.float32) - # Per token: [k_nope_dim=2, v_head_dim=1]. - kv_projection = torch.tensor([[1.0, 0.0, 5.0], [0.0, 1.0, 7.0]]) - modules = SimpleNamespace( - v_head_dim=1, - qk_nope_head_dim=2, - qk_rope_head_dim=0, - kv_b_proj=_FakeKVProj(kv_projection), - ) - _patch_forward_context( - monkeypatch, - is_prefill=True, - query_start_loc=torch.tensor([0, 2], dtype=torch.int32), - ) - attention = RTPMLAAttention(mla_modules=modules, layer_num=3) - - output = attention(q, compressed_kv, q.new_empty((2, 0)), positions=torch.arange(2)) - - assert isinstance(attention.dense_backend, RTPSparseMlaBackend) - assert output.shape == (2, 1, 1) - assert not torch.equal(output, torch.zeros_like(output)) - assert len(modules.kv_b_proj.calls) == 1 - assert modules.kv_b_proj.calls[0] is compressed_kv - - -def test_default_dense_mla_backend_rejects_missing_multi_token_metadata(monkeypatch): - _guard_sparse_kernel_imports(monkeypatch) - q = torch.randn(2, 1, 2) - compressed_kv = torch.ones(2, 4) - modules = SimpleNamespace( - v_head_dim=1, - qk_nope_head_dim=2, - qk_rope_head_dim=0, - kv_b_proj=_FakeKVProj(torch.empty(2, 3)), - ) - attention = RTPMLAAttention(mla_modules=modules, layer_num=3) - - with pytest.raises(ValueError, match="query_start_loc metadata"): - attention(q, compressed_kv, q.new_empty((2, 0)), positions=torch.arange(2)) - - -def test_default_dense_mla_backend_decode_reads_history_from_raw_cache(monkeypatch): - _guard_sparse_kernel_imports(monkeypatch) - q = torch.tensor([[[0.0, 1.0]]], dtype=torch.float32) - compressed_kv = torch.tensor([[9.0, 9.0, 9.0, 9.0]], dtype=torch.float32) - # The backend projects each latent token into [k_nope0, 
k_nope1, v]. - kv_projection = torch.tensor([[0.0, 0.0, 1.0]]) - modules = SimpleNamespace( - v_head_dim=1, - qk_nope_head_dim=2, - qk_rope_head_dim=0, - kv_b_proj=_FakeKVProj(kv_projection), - ) - layer_cache = SimpleNamespace(kv_cache_base=torch.zeros(1, 4, 4)) - # Three historical latent tokens are already in cache. - layer_cache.kv_cache_base[0, 0] = torch.tensor([1.0, 0.0, 0.0, 0.0]) - layer_cache.kv_cache_base[0, 1] = torch.tensor([2.0, 0.0, 0.0, 0.0]) - layer_cache.kv_cache_base[0, 2] = torch.tensor([3.0, 0.0, 0.0, 0.0]) - _patch_forward_context( - monkeypatch, - is_prefill=False, - query_start_loc=torch.tensor([0, 1], dtype=torch.int32), - seq_lens=torch.tensor([4], dtype=torch.int32), - block_table=torch.tensor([[0]], dtype=torch.int32), - slot_mapping=torch.tensor([3], dtype=torch.int32), - ) - attention = RTPMLAAttention( - mla_modules=modules, - layer_num=3, - kv_cache=layer_cache, - ) - - output = attention(q, compressed_kv, q.new_empty((1, 0)), positions=torch.arange(1)) - - assert output.shape == (1, 1, 1) - assert layer_cache.kv_cache_base[0, 3].tolist() == [9.0, 9.0, 9.0, 9.0] - - -def test_default_dense_mla_backend_decode_uses_top_level_rtp_metadata(monkeypatch): - _guard_sparse_kernel_imports(monkeypatch) - q = torch.tensor([[[0.0, 1.0]]], dtype=torch.float32) - compressed_kv = torch.tensor([[9.0, 9.0, 9.0, 9.0]], dtype=torch.float32) - modules = SimpleNamespace( - v_head_dim=1, - qk_nope_head_dim=2, - qk_rope_head_dim=0, - kv_b_proj=_FakeKVProj(torch.tensor([[0.0, 0.0, 1.0]])), - ) - layer_cache = SimpleNamespace(kv_cache_base=torch.zeros(1, 4, 4)) - layer_cache.kv_cache_base[0, 0] = torch.tensor([1.0, 0.0, 0.0, 0.0]) - _patch_forward_context_with_top_level_attn_metadata( - monkeypatch, - is_prefill=False, - seq_lens=torch.tensor([2], dtype=torch.int32), - block_table=torch.tensor([[0]], dtype=torch.int32), - slot_mapping=torch.tensor([1], dtype=torch.int32), - ) - attention = RTPMLAAttention( - mla_modules=modules, - layer_num=3, - 
kv_cache=layer_cache, - ) - - output = attention(q, compressed_kv, q.new_empty((1, 0)), positions=torch.arange(1)) - - assert output.shape == (1, 1, 1) - assert layer_cache.kv_cache_base[0, 1].tolist() == [9.0, 9.0, 9.0, 9.0] - - -def test_default_dense_mla_backend_decode_rebuilds_stale_query_start_loc(monkeypatch): - _guard_sparse_kernel_imports(monkeypatch) - q = torch.tensor([[[0.0, 1.0]]], dtype=torch.float32) - compressed_kv = torch.tensor([[9.0, 9.0, 9.0, 9.0]], dtype=torch.float32) - modules = SimpleNamespace( - v_head_dim=1, - qk_nope_head_dim=2, - qk_rope_head_dim=0, - kv_b_proj=_FakeKVProj(torch.tensor([[0.0, 0.0, 1.0]])), - ) - layer_cache = SimpleNamespace(kv_cache_base=torch.zeros(1, 4, 4)) - layer_cache.kv_cache_base[0, 0] = torch.tensor([1.0, 0.0, 0.0, 0.0]) - _patch_forward_context( - monkeypatch, - is_prefill=False, - query_start_loc=torch.tensor([0, 2], dtype=torch.int32), - seq_lens=torch.tensor([2], dtype=torch.int32), - block_table=torch.tensor([[0]], dtype=torch.int32), - slot_mapping=torch.tensor([1], dtype=torch.int32), - ) - attention = RTPMLAAttention( - mla_modules=modules, - layer_num=3, - kv_cache=layer_cache, - ) - - output = attention(q, compressed_kv, q.new_empty((1, 0)), positions=torch.arange(1)) - - assert output.shape == (1, 1, 1) - assert layer_cache.kv_cache_base[0, 1].tolist() == [9.0, 9.0, 9.0, 9.0] - - -def test_default_sparse_wrapper_refuses_mock_dense_fallback(monkeypatch): - _guard_sparse_kernel_imports(monkeypatch) - from atom.plugin.rtpllm.attention_backend import rtp_sparse_mla_backend - - dense_backend = _FakeDenseBackend(v_head_dim=4) - sparse_impl = SimpleNamespace(calls=[]) - backend = rtp_sparse_mla_backend.RTPSparseMlaBackend( - dense_backend=dense_backend, - sparse_impl=sparse_impl, - v_head_dim=4, - ) - q = torch.ones(2, 1, 3) - compressed_kv = torch.ones(2, 5) - k_pe = torch.ones(2, 2) - positions = torch.arange(2) - topk = torch.tensor([[1, 0], [0, 1]], dtype=torch.int32) - - try: - backend.forward( - q, - 
compressed_kv, - k_pe, - kv_cache="cache", - layer_id=9, - topk_indices=topk, - positions=positions, - ) - except rtp_sparse_mla_backend._SparseUnavailable: - pass - else: - raise AssertionError("default sparse mock must not silently fallback to dense") - - assert dense_backend.calls == [] - assert sparse_impl.calls == [] - - -def test_default_dense_mla_backend_resolves_kv_cache_from_forward_context(monkeypatch): - _guard_sparse_kernel_imports(monkeypatch) - q = torch.tensor([[[0.0, 1.0]]], dtype=torch.float32) - compressed_kv = torch.tensor([[9.0, 9.0, 9.0, 9.0]], dtype=torch.float32) - modules = SimpleNamespace( - v_head_dim=1, - qk_nope_head_dim=2, - qk_rope_head_dim=0, - kv_b_proj=_FakeKVProj(torch.tensor([[0.0, 0.0, 1.0]])), - ) - layer_cache = SimpleNamespace(kv_cache_base=torch.zeros(1, 4, 4)) - layer_cache.kv_cache_base[0, 0] = torch.tensor([1.0, 0.0, 0.0, 0.0]) - _patch_forward_context( - monkeypatch, - is_prefill=False, - query_start_loc=torch.tensor([0, 1], dtype=torch.int32), - seq_lens=torch.tensor([2], dtype=torch.int32), - block_table=torch.tensor([[0]], dtype=torch.int32), - slot_mapping=torch.tensor([1], dtype=torch.int32), - kv_cache_data={"layer_3": SimpleNamespace(k_cache=layer_cache)}, - ) - attention = RTPMLAAttention(mla_modules=modules, layer_num=3) - - output = attention(q, compressed_kv, q.new_empty((1, 0)), positions=torch.arange(1)) - - assert output.shape == (1, 1, 1) - assert layer_cache.kv_cache_base[0, 1].tolist() == [9.0, 9.0, 9.0, 9.0] - - -def test_default_dense_mla_backend_accepts_noncontiguous_compressed_kv(monkeypatch): - _guard_sparse_kernel_imports(monkeypatch) - q = torch.tensor([[[1.0, 0.0]], [[0.0, 1.0]]], dtype=torch.float32) - storage = torch.arange(16, dtype=torch.float32).reshape(2, 8) - compressed_kv = storage[:, ::2] - assert not compressed_kv.is_contiguous() - modules = SimpleNamespace( - v_head_dim=1, - qk_nope_head_dim=2, - qk_rope_head_dim=0, - kv_b_proj=_FakeKVProj(torch.tensor([[1.0, 0.0, 1.0], [0.0, 1.0, 
2.0]])), - ) - _patch_forward_context( - monkeypatch, - is_prefill=True, - query_start_loc=torch.tensor([0, 2], dtype=torch.int32), - ) - attention = RTPMLAAttention(mla_modules=modules, layer_num=3) - - output = attention(q, compressed_kv, q.new_empty((2, 0)), positions=torch.arange(2)) - - assert output.shape == (2, 1, 1) - assert modules.kv_b_proj.calls[0].is_contiguous() - - -def test_default_dense_mla_backend_skips_negative_slot_mapping(monkeypatch): - _guard_sparse_kernel_imports(monkeypatch) - q = torch.tensor([[[1.0, 0.0]], [[0.0, 1.0]]], dtype=torch.float32) - compressed_kv = torch.tensor( - [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]], dtype=torch.float32 - ) - modules = SimpleNamespace( - v_head_dim=1, - qk_nope_head_dim=2, - qk_rope_head_dim=0, - kv_b_proj=_FakeKVProj(torch.tensor([[1.0, 0.0, 1.0], [0.0, 1.0, 2.0]])), - ) - layer_cache = SimpleNamespace(kv_cache_base=torch.zeros(1, 4, 4)) - _patch_forward_context( - monkeypatch, - is_prefill=True, - query_start_loc=torch.tensor([0, 2], dtype=torch.int32), - slot_mapping=torch.tensor([-1, 1], dtype=torch.int32), - ) - attention = RTPMLAAttention(mla_modules=modules, layer_num=3, kv_cache=layer_cache) - - attention(q, compressed_kv, q.new_empty((2, 0)), positions=torch.arange(2)) - - assert torch.equal(layer_cache.kv_cache_base[0, -1], torch.zeros(4)) - assert torch.equal(layer_cache.kv_cache_base[0, 1], compressed_kv[1]) - - -def test_default_dense_mla_backend_rejects_oob_slot_mapping(monkeypatch): - _guard_sparse_kernel_imports(monkeypatch) - q = torch.tensor([[[1.0, 0.0]]], dtype=torch.float32) - compressed_kv = torch.tensor([[1.0, 2.0, 3.0, 4.0]], dtype=torch.float32) - modules = SimpleNamespace( - v_head_dim=1, - qk_nope_head_dim=2, - qk_rope_head_dim=0, - kv_b_proj=_FakeKVProj(torch.tensor([[1.0, 0.0, 1.0]])), - ) - layer_cache = SimpleNamespace(kv_cache_base=torch.zeros(1, 4, 4)) - _patch_forward_context( - monkeypatch, - is_prefill=True, - query_start_loc=torch.tensor([0, 1], dtype=torch.int32), - 
slot_mapping=torch.tensor([4], dtype=torch.int32), - ) - attention = RTPMLAAttention(mla_modules=modules, layer_num=3, kv_cache=layer_cache) - - with pytest.raises(RuntimeError, match="out-of-bounds slot_mapping"): - attention(q, compressed_kv, q.new_empty((1, 0)), positions=torch.arange(1)) - - assert torch.equal(layer_cache.kv_cache_base, torch.zeros_like(layer_cache.kv_cache_base)) - - -def test_default_dense_mla_backend_writes_post_rope_kpe_to_cache(monkeypatch): - _guard_sparse_kernel_imports(monkeypatch) - rotary_emb = _FakeRotaryEmbedding() - q = torch.tensor([[[1.0, 2.0, 10.0, 20.0]], [[3.0, 4.0, 30.0, 40.0]]]) - compressed_kv = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) - k_pe = torch.tensor([[100.0, 200.0], [300.0, 400.0]]) - positions = torch.tensor([5, 7], dtype=torch.long) - modules = SimpleNamespace( - v_head_dim=1, - qk_nope_head_dim=2, - qk_rope_head_dim=2, - rotary_emb=rotary_emb, - kv_b_proj=_FakeKVProj(torch.tensor([[1.0, 0.0, 1.0], [0.0, 1.0, 2.0]])), - ) - layer_cache = SimpleNamespace(kv_cache_base=torch.zeros(1, 4, 5)) - _patch_forward_context( - monkeypatch, - is_prefill=True, - query_start_loc=torch.tensor([0, 2], dtype=torch.int32), - slot_mapping=torch.tensor([0, 1], dtype=torch.int32), - ) - attention = RTPMLAAttention(mla_modules=modules, layer_num=3, kv_cache=layer_cache) - - attention(q, compressed_kv, k_pe, positions=positions) - - expected_k_pe = k_pe + positions.to(k_pe.dtype).unsqueeze(-1) - expected_cache = torch.cat((compressed_kv, expected_k_pe), dim=-1) - assert torch.equal(layer_cache.kv_cache_base[0, :2], expected_cache) - assert torch.equal(rotary_emb.calls[0]["positions"], positions) - - -def test_default_dense_mla_backend_uses_post_rope_q_for_attention(monkeypatch): - _guard_sparse_kernel_imports(monkeypatch) - from atom.plugin.rtpllm.attention_backend.rtp_dense_mla_backend import ( - RTPDenseMlaBackend, - ) - - captured = {} - - def _fake_causal_attention(q, key, value, query_start_loc, scale): - del key, 
query_start_loc, scale - captured["q"] = q.detach().clone() - return value.new_zeros((q.shape[0], q.shape[1], value.shape[-1])) - - monkeypatch.setattr( - RTPDenseMlaBackend, - "_causal_attention", - staticmethod(_fake_causal_attention), - ) - rotary_emb = _FakeRotaryEmbedding() - q = torch.tensor([[[1.0, 2.0, 10.0, 20.0]], [[3.0, 4.0, 30.0, 40.0]]]) - compressed_kv = torch.ones(2, 3) - k_pe = torch.ones(2, 2) - positions = torch.tensor([5, 7], dtype=torch.long) - modules = SimpleNamespace( - v_head_dim=1, - qk_nope_head_dim=2, - qk_rope_head_dim=2, - rotary_emb=rotary_emb, - kv_b_proj=_FakeKVProj(torch.tensor([[1.0, 0.0, 1.0], [0.0, 1.0, 2.0]])), - ) - _patch_forward_context( - monkeypatch, - is_prefill=True, - query_start_loc=torch.tensor([0, 2], dtype=torch.int32), - ) - attention = RTPMLAAttention(mla_modules=modules, layer_num=3) - - attention(q, compressed_kv, k_pe, positions=positions) - - expected_q = q.clone() - expected_q[..., -2:] += positions.to(q.dtype).view(2, 1, 1) - assert torch.equal(captured["q"], expected_q) - - -def test_default_dense_mla_backend_decode_history_kpe_not_double_roped(monkeypatch): - _guard_sparse_kernel_imports(monkeypatch) - from atom.plugin.rtpllm.attention_backend.rtp_dense_mla_backend import ( - RTPDenseMlaBackend, - ) - - captured_k_pe = [] - original_project_kv = RTPDenseMlaBackend._project_kv - - def _capture_project_kv(self, q, compressed_kv, k_pe): - captured_k_pe.append(k_pe.detach().clone()) - return original_project_kv(self, q, compressed_kv, k_pe) - - monkeypatch.setattr(RTPDenseMlaBackend, "_project_kv", _capture_project_kv) - rotary_emb = _FakeRotaryEmbedding() - modules = SimpleNamespace( - v_head_dim=1, - qk_nope_head_dim=2, - qk_rope_head_dim=2, - rotary_emb=rotary_emb, - kv_b_proj=_FakeKVProj(torch.tensor([[1.0, 0.0, 1.0]])), - ) - layer_cache = SimpleNamespace(kv_cache_base=torch.zeros(1, 4, 5)) - attention = RTPMLAAttention(mla_modules=modules, layer_num=3, kv_cache=layer_cache) - prefill_k_pe = 
torch.tensor([[10.0, 20.0], [30.0, 40.0]]) - prefill_positions = torch.tensor([4, 5], dtype=torch.long) - _patch_forward_context( - monkeypatch, - is_prefill=True, - query_start_loc=torch.tensor([0, 2], dtype=torch.int32), - slot_mapping=torch.tensor([0, 1], dtype=torch.int32), - ) - attention( - torch.ones(2, 1, 4), - torch.ones(2, 3), - prefill_k_pe, - positions=prefill_positions, - ) - - decode_k_pe = torch.tensor([[50.0, 60.0]]) - decode_positions = torch.tensor([6], dtype=torch.long) - _patch_forward_context( - monkeypatch, - is_prefill=False, - query_start_loc=torch.tensor([0, 1], dtype=torch.int32), - seq_lens=torch.tensor([3], dtype=torch.int32), - block_table=torch.tensor([[0]], dtype=torch.int32), - slot_mapping=torch.tensor([2], dtype=torch.int32), - ) - attention( - torch.ones(1, 1, 4), - torch.ones(1, 3), - decode_k_pe, - positions=decode_positions, - ) - - expected_history_k_pe = torch.cat( - ( - prefill_k_pe + prefill_positions.to(prefill_k_pe.dtype).unsqueeze(-1), - decode_k_pe + decode_positions.to(decode_k_pe.dtype).unsqueeze(-1), - ), - dim=0, - ) - assert torch.equal(captured_k_pe[-1], expected_history_k_pe) - - -def test_default_dense_mla_backend_rejects_missing_is_prefill_metadata(monkeypatch): - _guard_sparse_kernel_imports(monkeypatch) - q = torch.randn(1, 1, 2) - compressed_kv = torch.ones(1, 4) - modules = SimpleNamespace( - v_head_dim=1, - qk_nope_head_dim=2, - qk_rope_head_dim=0, - kv_b_proj=_FakeKVProj(torch.empty(1, 3)), - ) - _patch_forward_context_without_is_prefill( - monkeypatch, - query_start_loc=torch.tensor([0, 1], dtype=torch.int32), - ) - attention = RTPMLAAttention(mla_modules=modules, layer_num=3) - - with pytest.raises(ValueError, match="context.is_prefill"): - attention(q, compressed_kv, q.new_empty((1, 0)), positions=torch.arange(1)) - - -@pytest.mark.parametrize( - ("field_name", "seq_lens", "block_table", "slot_mapping"), - [ - ("seq_lens", None, torch.tensor([[0]], dtype=torch.int32), torch.tensor([0])), - 
("block_table", torch.tensor([1], dtype=torch.int32), None, torch.tensor([0])), - ("slot_mapping", torch.tensor([1], dtype=torch.int32), torch.tensor([[0]], dtype=torch.int32), None), - ], -) -def test_default_dense_mla_backend_decode_requires_rtp_metadata( - monkeypatch, field_name, seq_lens, block_table, slot_mapping -): - _guard_sparse_kernel_imports(monkeypatch) - q = torch.randn(1, 1, 2) - compressed_kv = torch.ones(1, 4) - modules = SimpleNamespace( - v_head_dim=1, - qk_nope_head_dim=2, - qk_rope_head_dim=0, - kv_b_proj=_FakeKVProj(torch.empty(1, 3)), - ) - layer_cache = SimpleNamespace(kv_cache_base=torch.zeros(1, 4, 4)) - _patch_forward_context( - monkeypatch, - is_prefill=False, - query_start_loc=torch.tensor([0, 1], dtype=torch.int32), - seq_lens=seq_lens, - block_table=block_table, - slot_mapping=slot_mapping, - ) - attention = RTPMLAAttention(mla_modules=modules, layer_num=3, kv_cache=layer_cache) - - with pytest.raises(ValueError, match=field_name): - attention(q, compressed_kv, q.new_empty((1, 0)), positions=torch.arange(1)) - - -def test_default_dense_mla_backend_decode_requires_readable_cache(monkeypatch): - _guard_sparse_kernel_imports(monkeypatch) - q = torch.randn(1, 1, 2) - compressed_kv = torch.ones(1, 4) - modules = SimpleNamespace( - v_head_dim=1, - qk_nope_head_dim=2, - qk_rope_head_dim=0, - kv_b_proj=_FakeKVProj(torch.empty(1, 3)), - ) - layer_cache = SimpleNamespace(kv_cache_base=torch.empty(0)) - _patch_forward_context( - monkeypatch, - is_prefill=False, - query_start_loc=torch.tensor([0, 1], dtype=torch.int32), - seq_lens=torch.tensor([1], dtype=torch.int32), - block_table=torch.tensor([[0]], dtype=torch.int32), - slot_mapping=torch.tensor([0], dtype=torch.int32), - ) - attention = RTPMLAAttention(mla_modules=modules, layer_num=3, kv_cache=layer_cache) - - with pytest.raises(ValueError, match="kv_cache_base"): - attention(q, compressed_kv, q.new_empty((1, 0)), positions=torch.arange(1)) - - -def 
test_default_dense_mla_backend_rejects_fp8_kv_cache(monkeypatch): - _guard_sparse_kernel_imports(monkeypatch) - q = torch.randn(1, 1, 2) - compressed_kv = torch.ones(1, 4) - modules = SimpleNamespace( - v_head_dim=1, - qk_nope_head_dim=2, - qk_rope_head_dim=0, - kv_b_proj=_FakeKVProj(torch.empty(1, 3)), - ) - layer_cache = SimpleNamespace(kv_cache_base=torch.zeros(1, 4, 4, dtype=torch.uint8)) - _patch_forward_context( - monkeypatch, - is_prefill=False, - query_start_loc=torch.tensor([0, 1], dtype=torch.int32), - seq_lens=torch.tensor([1], dtype=torch.int32), - block_table=torch.tensor([[0]], dtype=torch.int32), - slot_mapping=torch.tensor([0], dtype=torch.int32), - ) - attention = RTPMLAAttention(mla_modules=modules, layer_num=3, kv_cache=layer_cache) - - with pytest.raises(NotImplementedError, match="FP8 KV cache"): - attention(q, compressed_kv, q.new_empty((1, 0)), positions=torch.arange(1)) - - -def test_default_dense_mla_backend_glm5_shape_bf16_cache_roundtrip(monkeypatch): - _guard_sparse_kernel_imports(monkeypatch) - num_heads = 32 - qk_nope_head_dim = 128 - qk_rope_head_dim = 64 - v_head_dim = 128 - kv_lora_rank = 512 - kv_dim = kv_lora_rank + qk_rope_head_dim - output_dim = num_heads * (qk_nope_head_dim + v_head_dim) - kv_proj = _DeterministicKVProj(output_dim) - rotary_emb = _FakeRotaryEmbedding() - modules = SimpleNamespace( - v_head_dim=v_head_dim, - qk_nope_head_dim=qk_nope_head_dim, - qk_rope_head_dim=qk_rope_head_dim, - rotary_emb=rotary_emb, - kv_b_proj=kv_proj, - ) - layer_cache = SimpleNamespace(kv_cache_base=torch.zeros(1, 4, kv_dim, dtype=torch.bfloat16)) - attention = RTPMLAAttention(mla_modules=modules, layer_num=3, kv_cache=layer_cache) - q_prefill = torch.randn( - 3, - num_heads, - qk_nope_head_dim + qk_rope_head_dim, - dtype=torch.bfloat16, - ) - compressed_prefill = torch.randn(3, kv_lora_rank, dtype=torch.bfloat16) - k_pe_prefill = torch.randn(3, qk_rope_head_dim, dtype=torch.bfloat16) - _patch_forward_context( - monkeypatch, - 
is_prefill=True, - query_start_loc=torch.tensor([0, 3], dtype=torch.int32), - slot_mapping=torch.tensor([0, 1, 2], dtype=torch.int32), - ) - - prefill_output = attention( - q_prefill, - compressed_prefill, - k_pe_prefill, - positions=torch.arange(3), - ) - - assert prefill_output.shape == (3, num_heads, v_head_dim) - expected_prefill_k_pe = k_pe_prefill + torch.arange(3).to( - dtype=k_pe_prefill.dtype - ).unsqueeze(-1) - expected_prefill_cache = torch.cat((compressed_prefill, expected_prefill_k_pe), dim=-1) - assert torch.equal(layer_cache.kv_cache_base[0, :3], expected_prefill_cache) - - q_decode = torch.randn( - 1, - num_heads, - qk_nope_head_dim + qk_rope_head_dim, - dtype=torch.bfloat16, - ) - compressed_decode = torch.randn(1, kv_lora_rank, dtype=torch.bfloat16) - k_pe_decode = torch.randn(1, qk_rope_head_dim, dtype=torch.bfloat16) - _patch_forward_context( - monkeypatch, - is_prefill=False, - query_start_loc=torch.tensor([0, 1], dtype=torch.int32), - seq_lens=torch.tensor([4], dtype=torch.int32), - block_table=torch.tensor([[0]], dtype=torch.int32), - slot_mapping=torch.tensor([3], dtype=torch.int32), - ) - - decode_output = attention( - q_decode, - compressed_decode, - k_pe_decode, - positions=torch.arange(1), - ) - - assert decode_output.shape == (1, num_heads, v_head_dim) - expected_decode_k_pe = k_pe_decode + torch.arange(1).to( - dtype=k_pe_decode.dtype - ).unsqueeze(-1) - expected_decode_cache = torch.cat((compressed_decode, expected_decode_k_pe), dim=-1) - assert torch.equal(layer_cache.kv_cache_base[0, 3:4], expected_decode_cache) - expected_history = torch.cat((compressed_prefill, compressed_decode), dim=0) - assert torch.equal(kv_proj.calls[-1], expected_history) - - -def test_default_dense_mla_backend_rejects_bad_kv_projection_shape(monkeypatch): - _guard_sparse_kernel_imports(monkeypatch) - q = torch.randn(2, 1, 2) - compressed_kv = torch.ones(2, 4) - modules = SimpleNamespace( - v_head_dim=1, - qk_nope_head_dim=2, - qk_rope_head_dim=0, - 
kv_b_proj=_FakeKVProj(torch.empty(2, 2)), - ) - _patch_forward_context( - monkeypatch, - is_prefill=True, - query_start_loc=torch.tensor([0, 2], dtype=torch.int32), - ) - attention = RTPMLAAttention(mla_modules=modules, layer_num=3) - - with pytest.raises(ValueError, match="kv_b_proj output shape mismatch"): - attention(q, compressed_kv, q.new_empty((2, 0)), positions=torch.arange(2)) - - -def test_rtp_mla_attention_explicit_dense_backend_overrides_sparse_default(monkeypatch): - _guard_sparse_kernel_imports(monkeypatch) - dense_backend = _FakeDenseBackend(v_head_dim=16) - modules = SimpleNamespace(v_head_dim=16) - - attention = RTPMLAAttention(mla_modules=modules, dense_backend=dense_backend) - - assert attention.dense_backend is dense_backend - diff --git a/tests/plugin/test_rtpllm_glm5_sparse_backend_contract.py b/tests/plugin/test_rtpllm_glm5_sparse_backend_contract.py index 6cf0dac27e..ed563d3b0b 100644 --- a/tests/plugin/test_rtpllm_glm5_sparse_backend_contract.py +++ b/tests/plugin/test_rtpllm_glm5_sparse_backend_contract.py @@ -129,23 +129,64 @@ def fake_sparse_attn_indexer_fake(*args): assert calls[0][6:12] == (128, None, 2048, 64, 4096, 1) -class _FakeDenseBackend: - def __init__(self, v_head_dim: int = 5): - self.v_head_dim = v_head_dim - self.calls = [] +def test_rtp_sparse_attn_indexer_short_prefill_fills_causal_topk(monkeypatch): + module = importlib.import_module(_SPARSE_BACKEND_MODULE) + forward_context_mod = sys.modules["atom.utils.forward_context"] + fake_forward_context = SimpleNamespace( + context=SimpleNamespace(is_prefill=True, is_dummy_run=False), + attn_metadata=SimpleNamespace(max_seqlen_k=4), + ) + monkeypatch.setattr( + forward_context_mod, + "get_forward_context", + lambda: fake_forward_context, + raising=False, + ) - def forward(self, q, compressed_kv, k_pe, kv_cache, layer_id, topk_indices=None): - self.calls.append( - { - "q": q, - "compressed_kv": compressed_kv, - "k_pe": k_pe, - "kv_cache": kv_cache, - "layer_id": layer_id, - 
"topk_indices": topk_indices, - } + def _unexpected_call(*args, **kwargs): + raise AssertionError( + "short prefill path should not call deepseek sparse_attn_indexer" ) - return q.new_full((q.shape[0], q.shape[1], self.v_head_dim), -1) + + fake_deepseek = type(sys)("atom.models.deepseek_v2") + fake_deepseek.sparse_attn_indexer = _unexpected_call + monkeypatch.setitem(sys.modules, "atom.models.deepseek_v2", fake_deepseek) + + topk_buffer = torch.full((3, 8), -99, dtype=torch.int32) + positions = torch.tensor([0, 1, 3], dtype=torch.int32) + tensor = torch.empty(3, 2) + weights = torch.randn(3, 4) + out = module.rtp_sparse_attn_indexer( + tensor, + "indexer.prefix", + tensor, + tensor, + tensor, + weights, + 128, + None, + 6, + 64, + 4096, + 3, + topk_buffer, + tensor, + tensor, + 1e-6, + positions, + tensor, + tensor, + 1.0, + True, + False, + ) + + assert out is weights + assert topk_buffer[:3, :6].tolist() == [ + [0, -1, -1, -1, -1, -1], + [0, 1, -1, -1, -1, -1], + [0, 1, 2, 3, -1, -1], + ] class _FakeSparseImpl: @@ -178,14 +219,9 @@ def forward( return q.new_full((q.shape[0], q.shape[1], self.v_head_dim), 7) -def _build_backend(backend_cls, dense_backend, sparse_impl): +def _build_backend(backend_cls, sparse_impl): params = inspect.signature(backend_cls).parameters kwargs = {} - if "dense_backend" not in params: - raise AssertionError( - "RTPSparseMlaBackend must accept dense_backend= for dense fallback" - ) - kwargs["dense_backend"] = dense_backend if "sparse_impl" in params: kwargs["sparse_impl"] = sparse_impl @@ -195,7 +231,7 @@ def _build_backend(backend_cls, dense_backend, sparse_impl): ) if "v_head_dim" in params: - kwargs["v_head_dim"] = dense_backend.v_head_dim + kwargs["v_head_dim"] = int(getattr(sparse_impl, "v_head_dim", 5)) return backend_cls(**kwargs) @@ -211,9 +247,8 @@ def _make_inputs(): def test_sparse_backend_passes_topk_through_unchanged(monkeypatch): backend_cls = _load_sparse_backend(monkeypatch) - dense_backend = _FakeDenseBackend() 
sparse_impl = _FakeSparseImpl() - backend = _build_backend(backend_cls, dense_backend, sparse_impl) + backend = _build_backend(backend_cls, sparse_impl) q, compressed_kv, k_pe, kv_cache, layer_id = _make_inputs() topk = torch.tensor([[4, 1], [3, 0], [2, 1]], dtype=torch.int32) @@ -222,14 +257,13 @@ def test_sparse_backend_passes_topk_through_unchanged(monkeypatch): ) assert output.shape == (3, 2, sparse_impl.v_head_dim) - assert dense_backend.calls == [] assert len(sparse_impl.calls) == 1 assert sparse_impl.calls[0]["topk_indices"] is topk assert sparse_impl.calls[0]["topk_indices"].dtype == torch.int32 assert sparse_impl.calls[0]["topk_indices"].shape == (3, 2) -def test_sparse_backend_falls_back_to_dense_when_topk_is_none(monkeypatch): +def test_sparse_backend_prefill_without_topk_raises(monkeypatch): backend_cls = _load_sparse_backend(monkeypatch) forward_context_mod = sys.modules["atom.utils.forward_context"] fake_forward_context = SimpleNamespace( @@ -244,24 +278,18 @@ def test_sparse_backend_falls_back_to_dense_when_topk_is_none(monkeypatch): lambda: fake_forward_context, raising=False, ) - dense_backend = _FakeDenseBackend() sparse_impl = _FakeSparseImpl() - backend = _build_backend(backend_cls, dense_backend, sparse_impl) + backend = _build_backend(backend_cls, sparse_impl) q, compressed_kv, k_pe, kv_cache, layer_id = _make_inputs() - output = backend.forward( - q, compressed_kv, k_pe, kv_cache, layer_id, topk_indices=None - ) - - assert output.shape == (3, 2, dense_backend.v_head_dim) - assert len(dense_backend.calls) == 1 + module = importlib.import_module(_SPARSE_BACKEND_MODULE) + try: + backend.forward(q, compressed_kv, k_pe, kv_cache, layer_id, topk_indices=None) + except module._SparseUnavailable as exc: + assert "requires topk_indices" in str(exc) + else: + raise AssertionError("Expected missing prefill topk_indices to raise") assert sparse_impl.calls == [] - assert dense_backend.calls[0]["q"] is q - assert dense_backend.calls[0]["compressed_kv"] is 
compressed_kv - assert dense_backend.calls[0]["k_pe"] is k_pe - assert dense_backend.calls[0]["kv_cache"] is kv_cache - assert dense_backend.calls[0]["layer_id"] == layer_id - assert dense_backend.calls[0]["topk_indices"] is None def test_sparse_backend_decode_without_topk_raises(monkeypatch): @@ -280,26 +308,23 @@ def test_sparse_backend_decode_without_topk_raises(monkeypatch): lambda: fake_forward_context, raising=False, ) - dense_backend = _FakeDenseBackend() sparse_impl = _FakeSparseImpl() - backend = _build_backend(backend_cls, dense_backend, sparse_impl) + backend = _build_backend(backend_cls, sparse_impl) q, compressed_kv, k_pe, kv_cache, layer_id = _make_inputs() try: backend.forward(q, compressed_kv, k_pe, kv_cache, layer_id, topk_indices=None) except module._SparseUnavailable as exc: - assert "decode requires topk_indices" in str(exc) + assert "requires topk_indices" in str(exc) else: raise AssertionError("Expected missing decode topk_indices to raise") - assert dense_backend.calls == [] assert sparse_impl.calls == [] def test_sparse_backend_threads_kv_cache_and_layer_id_to_sparse_impl(monkeypatch): backend_cls = _load_sparse_backend(monkeypatch) - dense_backend = _FakeDenseBackend() sparse_impl = _FakeSparseImpl() - backend = _build_backend(backend_cls, dense_backend, sparse_impl) + backend = _build_backend(backend_cls, sparse_impl) q, compressed_kv, k_pe, kv_cache, layer_id = _make_inputs() topk = torch.tensor([[1, 0], [0, 1], [1, 1]], dtype=torch.int32) @@ -328,9 +353,8 @@ def test_sparse_backend_pulls_attn_metadata_from_forward_context(monkeypatch): lambda: fake_forward_context, raising=False, ) - dense_backend = _FakeDenseBackend() sparse_impl = _FakeSparseImpl() - backend = _build_backend(backend_cls, dense_backend, sparse_impl) + backend = _build_backend(backend_cls, sparse_impl) q, compressed_kv, k_pe, kv_cache, layer_id = _make_inputs() topk = torch.tensor([[1, 0], [0, 1], [1, 1]], dtype=torch.int32) @@ -362,8 +386,8 @@ class 
_MissingPrefillSparse: def forward(self, *args, **kwargs): raise module._SparseUnavailable("flash_mla_sparse_fwd unavailable") - dense_backend = _FakeDenseBackend() - backend = _build_backend(backend_cls, dense_backend, _MissingPrefillSparse()) + sparse_impl = _MissingPrefillSparse() + backend = _build_backend(backend_cls, sparse_impl) q, compressed_kv, k_pe, kv_cache, layer_id = _make_inputs() topk = torch.tensor([[1, 0], [0, 1], [1, 1]], dtype=torch.int32) @@ -376,8 +400,6 @@ def forward(self, *args, **kwargs): "prefill sparse unavailability must not fall back to dense" ) - assert len(dense_backend.calls) == 0 - def test_sparse_backend_decode_missing_sparse_kernel_still_raises(monkeypatch): backend_cls = _load_sparse_backend(monkeypatch) @@ -402,7 +424,7 @@ class _MissingDecodeSparse: def forward(self, *args, **kwargs): raise module._SparseUnavailable("flash_mla_sparse_fwd unavailable") - backend = _build_backend(backend_cls, _FakeDenseBackend(), _MissingDecodeSparse()) + backend = _build_backend(backend_cls, _MissingDecodeSparse()) q, compressed_kv, k_pe, kv_cache, layer_id = _make_inputs() topk = torch.tensor([[1, 0], [0, 1], [1, 1]], dtype=torch.int32) From b4997d6f4f5df7e7672bdcd08d6a2578692a2d8a Mon Sep 17 00:00:00 2001 From: Zhao An Date: Sat, 13 Jun 2026 17:25:59 +0000 Subject: [PATCH 10/20] fix: RTP GLM5 prefil reuse Sparse MLA metadata --- .../rtp_sparse_mla_backend.py | 13 ++- ...est_rtpllm_glm5_sparse_backend_contract.py | 91 +++++++++++++++++++ 2 files changed, 103 insertions(+), 1 deletion(-) diff --git a/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py b/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py index 6dc763a56a..e7446950e3 100644 --- a/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py +++ b/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py @@ -927,6 +927,7 @@ def _build_atom_sparse_metadata( head_repeat_factor = padded_num_heads // num_heads q_dtype = 
self._aiter_dtype_for_tensor(q_latent) kv_dtype = self._aiter_dtype_for_tensor(kv_cache_base) + reuse_eager_metadata = False if in_capture: work_meta_data = sparse_bufs["work_meta_data"] work_indptr = sparse_bufs["work_indptr"] @@ -956,6 +957,9 @@ def _build_atom_sparse_metadata( reduce_indptr = cached_eager_meta["reduce_indptr"] reduce_final_map = cached_eager_meta["reduce_final_map"] reduce_partial_map = cached_eager_meta["reduce_partial_map"] + reuse_eager_metadata = bool( + cached_eager_meta.get("metadata_ready", False) + ) else: metadata_budget_tokens = self._metadata_token_budget( num_tokens=num_tokens, topk=topk @@ -1005,6 +1009,7 @@ def _build_atom_sparse_metadata( "reduce_indptr": reduce_indptr, "reduce_final_map": reduce_final_map, "reduce_partial_map": reduce_partial_map, + "metadata_ready": False, } except Exception: pass @@ -1050,7 +1055,7 @@ def _build_atom_sparse_metadata( max_slots=max_page_slots, ) - if not reuse_capture_metadata: + if not reuse_capture_metadata and not reuse_eager_metadata: get_mla_metadata_v1( qo_indptr, paged_kv_indptr, @@ -1072,6 +1077,12 @@ def _build_atom_sparse_metadata( dtype_q=q_dtype, dtype_kv=kv_dtype, ) + if not in_capture: + cached_eager_meta = getattr( + plugin_metadata, "_rtp_sparse_eager_meta_workspace", None + ) + if isinstance(cached_eager_meta, dict): + cached_eager_meta["metadata_ready"] = True if in_capture: plugin_metadata._rtp_sparse_capture_meta_workspace = { "signature": capture_meta_sig, diff --git a/tests/plugin/test_rtpllm_glm5_sparse_backend_contract.py b/tests/plugin/test_rtpllm_glm5_sparse_backend_contract.py index ed563d3b0b..b67d907df3 100644 --- a/tests/plugin/test_rtpllm_glm5_sparse_backend_contract.py +++ b/tests/plugin/test_rtpllm_glm5_sparse_backend_contract.py @@ -609,6 +609,97 @@ def fake_convert( assert decode_call["kwargs"]["reduce_final_map"] is not None +def test_real_sparse_eager_metadata_workspace_skips_refill(monkeypatch): + module = importlib.import_module(_SPARSE_BACKEND_MODULE) + 
metadata_calls = [] + + fake_aiter = type(sys)("aiter") + fake_aiter.dtypes = SimpleNamespace(d_dtypes={"bf16": "bf16", "fp16": "fp16"}) + + def fake_metadata_info(*args, **kwargs): + return ( + (4, torch.int32), + (4, torch.int32), + (4, torch.int32), + (4, torch.int32), + (4, torch.int32), + (4, torch.int32), + ) + + def fake_metadata_v1(*args, **kwargs): + metadata_calls.append((args, kwargs)) + + fake_aiter.get_mla_metadata_info_v1 = fake_metadata_info + fake_aiter.get_mla_metadata_v1 = fake_metadata_v1 + monkeypatch.setitem(sys.modules, "aiter", fake_aiter) + monkeypatch.setattr( + torch.cuda, "is_current_stream_capturing", lambda: False, raising=False + ) + + fake_sparse_helpers = type(sys)("atom.plugin.attention_mla_sparse") + + def fake_convert( + req_id, + block_table, + token_indices, + cu_seqlens, + out, + BLOCK_SIZE=1, + NUM_TOPK_TOKENS=0, + BLOCK_N=128, + ): + del req_id, block_table, token_indices, BLOCK_SIZE, NUM_TOPK_TOKENS, BLOCK_N + out[: int(cu_seqlens[-1].item())].zero_() + + fake_sparse_helpers.triton_convert_req_index_to_global_index = fake_convert + monkeypatch.setitem( + sys.modules, + "atom.plugin.attention_mla_sparse", + fake_sparse_helpers, + ) + + impl = module._RealSparseMlaImpl( + mla_modules=SimpleNamespace( + kv_lora_rank=4, + qk_nope_head_dim=2, + qk_rope_head_dim=1, + num_heads=2, + rotary_emb=None, + kv_b_proj=SimpleNamespace(weight=torch.empty(0)), + ), + v_head_dim=3, + ) + q_latent = torch.randn(2, 2, 5) + kv_cache = torch.randn(8, 1, 5) + topk = torch.tensor([[0, 1, 2], [0, 1, -1]], dtype=torch.int32) + plugin_metadata = SimpleNamespace( + query_start_loc=torch.tensor([0, 1, 2], dtype=torch.int32), + seq_lens=torch.tensor([3, 2], dtype=torch.int32), + req_id_per_token=torch.tensor([0, 1], dtype=torch.int32), + block_table=torch.tensor([[0], [1]], dtype=torch.int32), + ) + attn_metadata = SimpleNamespace(plugin_metadata=plugin_metadata) + + first = impl._build_atom_sparse_metadata( + q_latent=q_latent, + kv_cache_base=kv_cache, 
+ topk_indices=topk, + attn_metadata=attn_metadata, + block_size=4, + ) + second = impl._build_atom_sparse_metadata( + q_latent=q_latent, + kv_cache_base=kv_cache, + topk_indices=topk, + attn_metadata=attn_metadata, + block_size=4, + ) + + assert len(metadata_calls) == 1 + assert second.work_meta_data is first.work_meta_data + assert plugin_metadata._rtp_sparse_eager_meta_workspace["metadata_ready"] is True + + def test_real_sparse_decode_rejects_oob_paged_kv_indices(monkeypatch): module = importlib.import_module(_SPARSE_BACKEND_MODULE) decode_called = {"value": False} From d208756e8d9eba50d6ce947749fa3df90c25dbb0 Mon Sep 17 00:00:00 2001 From: Zhao An Date: Mon, 15 Jun 2026 06:19:02 +0000 Subject: [PATCH 11/20] fix: RTP GLM5 enable FP8 MLA path --- .../rtp_sparse_mla_backend.py | 37 +++++++++++++++++-- atom/plugin/rtpllm/models/glm5.py | 7 ++++ ...est_rtpllm_glm5_sparse_backend_contract.py | 27 +++++++++++++- 3 files changed, 66 insertions(+), 5 deletions(-) diff --git a/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py b/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py index e7446950e3..bf601af1da 100644 --- a/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py +++ b/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py @@ -367,7 +367,8 @@ def _cache_dtype_name(self, kv_cache_base: torch.Tensor) -> str: } if kv_cache_base.dtype not in fp8_dtypes: return "auto" - return "fp8_model1_mla" if self.kv_lora_rank == 448 else "fp8_ds_mla" + # RTP allocates GLM5 FP8 MLA KV cache in the aiter 576-byte/token layout. 
+ return "fp8" def _write_current_to_cache( self, @@ -597,7 +598,7 @@ def prewarm_for_cuda_graph( ) -> None: del max_seq_len try: - from aiter import get_mla_metadata_info_v1 + from aiter import dtypes, get_mla_metadata_info_v1 except Exception as exc: raise _SparseUnavailable( f"aiter metadata prewarm unavailable: {exc}" @@ -691,6 +692,13 @@ def prewarm_for_cuda_graph( device=device, dtype=query_dtype, ), + "q_for_kernel_fp8": torch.empty( + max_tokens, + padded_num_heads, + latent_dim, + device=device, + dtype=dtypes.fp8, + ), "latent_output": torch.empty( max_tokens, padded_num_heads, @@ -1174,6 +1182,7 @@ def _run_aiter_sparse_decode( ) else: q_for_kernel = q_latent + output_dtype = q_for_kernel.dtype if in_capture and self._cg_sparse_bufs is not None: output = self._cg_sparse_bufs["latent_output"][ :num_tokens, : sparse_meta.padded_num_heads, : @@ -1181,9 +1190,30 @@ def _run_aiter_sparse_decode( else: output = torch.empty( (num_tokens, sparse_meta.padded_num_heads, self.kv_lora_rank), - dtype=q_for_kernel.dtype, + dtype=output_dtype, device=q_latent.device, ) + fp8_scale_kwargs = {} + if self._cache_dtype_name(kv_cache_base) == "fp8": + kv_scale = self._cache_write_scale.get(kv_cache_base.device) + if kv_scale is None: + kv_scale = torch.tensor( + 1.0, dtype=torch.float32, device=kv_cache_base.device + ) + self._cache_write_scale[kv_cache_base.device] = kv_scale + fp8_scale_kwargs = {"q_scale": kv_scale, "kv_scale": kv_scale} + try: + from aiter import dtypes + except Exception as exc: + raise _SparseUnavailable(f"aiter dtypes unavailable: {exc}") from exc + if in_capture and self._cg_sparse_bufs is not None: + q_for_kernel_fp8 = self._cg_sparse_bufs["q_for_kernel_fp8"][ + :num_tokens, : sparse_meta.padded_num_heads, : + ] + q_for_kernel_fp8.copy_(q_for_kernel) + q_for_kernel = q_for_kernel_fp8 + else: + q_for_kernel = q_for_kernel.to(dtype=dtypes.fp8) try: kv_buffer = kv_cache_base.reshape(-1, 1, 1, latent_dim) if ( @@ -1221,6 +1251,7 @@ def 
_run_aiter_sparse_decode( reduce_indptr=sparse_meta.reduce_indptr, reduce_final_map=sparse_meta.reduce_final_map, reduce_partial_map=sparse_meta.reduce_partial_map, + **fp8_scale_kwargs, ) except Exception as exc: raise _SparseUnavailable(f"mla_decode_fwd failed: {exc}") from exc diff --git a/atom/plugin/rtpllm/models/glm5.py b/atom/plugin/rtpllm/models/glm5.py index 537491ed65..2999e022a6 100644 --- a/atom/plugin/rtpllm/models/glm5.py +++ b/atom/plugin/rtpllm/models/glm5.py @@ -577,6 +577,13 @@ def _is_external_plugin_mode() -> bool: modules = os.getenv("RTP_LLM_EXTERNAL_MODEL_PACKAGES", "") return "atom.plugin.rtpllm.models" in modules + @classmethod + def _create_config(cls, ckpt_path: str): + config = super()._create_config(ckpt_path) + # ATOM sparse MLA reads the FP8 KV cache through aiter's 576-token layout. + config.attn_config.mla_use_aiter_fp8_layout = True + return config + def support_cuda_graph(self) -> bool: if os.getenv("ENABLE_CUDA_GRAPH", "1") == "0": logger.info("ENABLE_CUDA_GRAPH=0 - ATOMGlm5Moe forces eager forward.") diff --git a/tests/plugin/test_rtpllm_glm5_sparse_backend_contract.py b/tests/plugin/test_rtpllm_glm5_sparse_backend_contract.py index b67d907df3..b1e7e99e12 100644 --- a/tests/plugin/test_rtpllm_glm5_sparse_backend_contract.py +++ b/tests/plugin/test_rtpllm_glm5_sparse_backend_contract.py @@ -577,8 +577,8 @@ def fake_convert( ), v_head_dim=3, ) - q_latent = torch.randn(2, 2, 5) - kv_cache = torch.randn(8, 1, 5) + q_latent = torch.randn(2, 2, 5, dtype=torch.bfloat16) + kv_cache = torch.empty(8, 1, 5, dtype=torch.uint8) topk = torch.tensor([[0, 1, 2], [0, 1, -1]], dtype=torch.int32) attn_metadata = SimpleNamespace( plugin_metadata=SimpleNamespace( @@ -598,17 +598,40 @@ def fake_convert( ) assert output.shape == (2, 2, 4) + assert output.dtype == torch.bfloat16 assert torch.all(output == 3) decode_call = calls["mla_decode_fwd"] assert decode_call["q"].shape == (2, 16, 5) + assert decode_call["q"].dtype == aiter.dtypes.fp8 assert 
decode_call["output"].shape == (2, 16, 4) + assert decode_call["output"].dtype == torch.bfloat16 assert decode_call["paged_kv_indptr"].tolist() == [0, 3, 5] assert decode_call["paged_kv_indices"][:5].tolist() == [0, 1, 2, 4, 5] assert decode_call["kwargs"]["page_size"] == 1 + assert decode_call["kwargs"]["q_scale"] is not None + assert decode_call["kwargs"]["kv_scale"] is not None assert decode_call["kwargs"]["work_meta_data"] is not None assert decode_call["kwargs"]["reduce_final_map"] is not None +def test_real_sparse_cache_dtype_uses_aiter_fp8_layout(): + module = importlib.import_module(_SPARSE_BACKEND_MODULE) + impl = module._RealSparseMlaImpl( + mla_modules=SimpleNamespace( + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + num_heads=2, + rotary_emb=None, + kv_b_proj=SimpleNamespace(weight=torch.empty(0)), + ), + v_head_dim=128, + ) + + assert impl._cache_dtype_name(torch.empty(1, 576, dtype=torch.uint8)) == "fp8" + assert impl._cache_dtype_name(torch.empty(1, 576, dtype=torch.bfloat16)) == "auto" + + def test_real_sparse_eager_metadata_workspace_skips_refill(monkeypatch): module = importlib.import_module(_SPARSE_BACKEND_MODULE) metadata_calls = [] From 48089f91d68cf2ca0db0d722c610fefb634ab354 Mon Sep 17 00:00:00 2001 From: Zhao An Date: Wed, 17 Jun 2026 10:43:28 +0000 Subject: [PATCH 12/20] feat: RTP GLM5 conflict issue after rebase --- .../rtpllm/attention_backend/__init__.py | 17 +- .../attention_backend/rtp_mla_attention.py | 46 ++++ .../rtp_sparse_mla_backend.py | 260 +++++++++++++++++- atom/plugin/rtpllm/models/glm5.py | 6 +- atom/plugin/rtpllm/utils/forward_context.py | 43 ++- .../test_rtpllm_glm5_indexer_contract.py | 37 +++ tests/plugin/test_rtpllm_glm5_mla_patch.py | 29 +- ...est_rtpllm_glm5_sparse_backend_contract.py | 83 ++++++ 8 files changed, 508 insertions(+), 13 deletions(-) diff --git a/atom/plugin/rtpllm/attention_backend/__init__.py b/atom/plugin/rtpllm/attention_backend/__init__.py index 053a7e3d47..4045e2fe78 100644 --- 
a/atom/plugin/rtpllm/attention_backend/__init__.py +++ b/atom/plugin/rtpllm/attention_backend/__init__.py @@ -1,4 +1,3 @@ -from .rtp_full_attention import AttentionForRTPLLM, RTPFullAttention from .rtp_mla_attention import RTPMLAAttention, apply_attention_mla_rtpllm_patch from .rtp_mla_metadata import ( GLM5_RTP_BRIDGE_MODE, @@ -10,12 +9,18 @@ def __getattr__(name): - if name in {"RTPAttention", "RTPFullAttention"}: - from .rtp_full_attention import RTPAttention, RTPFullAttention + if name == "AttentionForRTPLLM": + from .rtp_full_attention import AttentionForRTPLLM - return {"RTPAttention": RTPAttention, "RTPFullAttention": RTPFullAttention}[ - name - ] + return AttentionForRTPLLM + if name == "RTPFullAttention": + from .rtp_full_attention import RTPFullAttention + + return RTPFullAttention + if name == "RTPAttention": + from .rtp_full_attention import RTPFullAttention + + return RTPFullAttention if name == "apply_attention_gdn_rtpllm_patch": from .attention_gdn import apply_attention_gdn_rtpllm_patch diff --git a/atom/plugin/rtpllm/attention_backend/rtp_mla_attention.py b/atom/plugin/rtpllm/attention_backend/rtp_mla_attention.py index fd1d30cd6c..9c510e063a 100644 --- a/atom/plugin/rtpllm/attention_backend/rtp_mla_attention.py +++ b/atom/plugin/rtpllm/attention_backend/rtp_mla_attention.py @@ -3,6 +3,7 @@ from __future__ import annotations import inspect +from types import MethodType from typing import Optional import torch @@ -54,6 +55,39 @@ def _use_rtp_sparse_attn_indexer(indexer: object | None) -> None: return __import__("atom.plugin.rtpllm.attention_backend.rtp_sparse_mla_backend") indexer.sparse_attn_indexer_impl = torch.ops.aiter.rtp_sparse_attn_indexer + if getattr(indexer, "_atom_rtp_topk_buffer_patched", False) or not hasattr( + indexer, "forward" + ): + return + original_forward = indexer.forward + + def _forward_with_topk_buffer(self, hidden_states, *args, **kwargs): + num_tokens = int(hidden_states.shape[0]) + topk_tokens = getattr(self, 
"topk_tokens", None) + if topk_tokens is None: + topk_tokens = getattr(self, "index_topk") + topk_tokens = int(topk_tokens) + buffer = getattr(self, "topk_indices_buffer", None) + needs_new_buffer = ( + buffer is None + or buffer.dim() != 2 + or buffer.device != hidden_states.device + or int(buffer.shape[0]) < num_tokens + or int(buffer.shape[1]) < topk_tokens + ) + if needs_new_buffer: + buffer = torch.empty( + num_tokens, + topk_tokens, + dtype=torch.int32, + device=hidden_states.device, + ) + self.topk_indices_buffer = buffer + self.sparse_kv_indices_buffer = self.topk_indices_buffer + return original_forward(hidden_states, *args, **kwargs) + + indexer.forward = MethodType(_forward_with_topk_buffer, indexer) + indexer._atom_rtp_topk_buffer_patched = True class RTPMLAAttention: @@ -200,7 +234,19 @@ def forward( def apply_attention_mla_rtpllm_patch() -> None: """Switch ATOM's generic Attention symbol to the RTP MLA adapter.""" + import sys + import atom.model_ops as ops + import atom.model_ops.base_attention as base_attention ops.RTPMLAAttention = RTPMLAAttention ops.Attention = RTPMLAAttention + base_attention.Attention = RTPMLAAttention + + deepseek_v2 = sys.modules.get("atom.models.deepseek_v2") + if deepseek_v2 is None: + try: + import atom.models.deepseek_v2 as deepseek_v2 + except (ImportError, ModuleNotFoundError): + return + deepseek_v2.Attention = RTPMLAAttention diff --git a/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py b/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py index bf601af1da..21cb816ecb 100644 --- a/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py +++ b/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py @@ -2,6 +2,7 @@ from __future__ import annotations +import importlib import inspect import os from dataclasses import dataclass @@ -16,6 +17,25 @@ class _SparseUnavailable(RuntimeError): pass +def _resolve_plugin_sparse_index_converter(): + """Resolve the plugin-style request-local topk 
to global KV index converter.""" + errors: list[str] = [] + for module_name in ( + # Old GLM5 RTP branch location. + "atom.plugin.attention_mla_sparse", + # Current refactored plugin location with the same call contract. + "atom.plugin.vllm.attention.layer_sparse_mla", + ): + try: + module = importlib.import_module(module_name) + return getattr(module, "triton_convert_req_index_to_global_index") + except Exception as exc: + errors.append(f"{module_name}: {exc}") + raise _SparseUnavailable( + "plugin sparse MLA index converter unavailable; " + "; ".join(errors) + ) + + @dataclass class _AbsorbedWeights: w_kc: torch.Tensor @@ -751,8 +771,9 @@ def _build_atom_sparse_metadata( ) -> _AtomSparseMetadata: try: from aiter import get_mla_metadata_info_v1, get_mla_metadata_v1 - from atom.plugin.attention_mla_sparse import ( - triton_convert_req_index_to_global_index, + + triton_convert_req_index_to_global_index = ( + _resolve_plugin_sparse_index_converter() ) except Exception as exc: raise _SparseUnavailable( @@ -1487,6 +1508,8 @@ def forward( return q.new_zeros((q.shape[0], q.shape[1], self.v_head_dim)) if topk_indices is None: + if self._default_mock: + return q.new_zeros((q.shape[0], q.shape[1], self.v_head_dim)) raise _SparseUnavailable( "GLM5 RTP sparse MLA requires topk_indices; refusing dense fallback." 
) @@ -1520,6 +1543,211 @@ def forward( ) from exc +def _run_rtp_sparse_attn_indexer_topk_only( + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + q_input: torch.Tensor, + k: torch.Tensor, + weights: torch.Tensor, + quant_block_size: int, + scale_fmt: Optional[str], + topk_tokens: int, + head_dim: int, + max_model_len: int, + total_seq_lens: int, + topk_indices_buffer: torch.Tensor, + k_norm_weight: torch.Tensor, + k_norm_bias: torch.Tensor, + k_norm_eps: float, + positions: torch.Tensor, + cos_cache: torch.Tensor, + sin_cache: torch.Tensor, + weights_scale: float, + is_neox_style: bool, + use_qk_rope_cache_fusion: bool, + context: Any, + attn_metadata: Any, +) -> torch.Tensor: + from aiter import ( + cp_gather_indexer_k_quant_cache, + dtypes, + indexer_k_quant_and_cache, + indexer_qk_rope_quant_and_cache, + top_k_per_row_decode, + top_k_per_row_prefill, + ) + from aiter.ops.triton.fp8_mqa_logits import fp8_mqa_logits + from aiter.ops.triton.pa_mqa_logits import deepgemm_fp8_paged_mqa_logits + from atom.config import get_current_atom_config + + slot_mapping = getattr(attn_metadata, "slot_mapping", None) + if slot_mapping is None: + raise _SparseUnavailable("RTP sparse indexer requires slot_mapping metadata.") + if topk_indices_buffer is None: + raise _SparseUnavailable("RTP sparse indexer requires topk_indices_buffer.") + if topk_indices_buffer.dim() != 2: + raise _SparseUnavailable( + "RTP sparse indexer requires a 2D topk_indices_buffer; " + f"got shape={tuple(topk_indices_buffer.shape)}." + ) + + if bool(getattr(context, "is_dummy_run", False)): + return torch.zeros_like(weights, dtype=torch.float32) + + num_tokens = int(hidden_states.shape[0]) + if num_tokens <= 0: + return weights + topk_indices = topk_indices_buffer[:num_tokens, :topk_tokens] + if topk_indices.dtype != torch.int32: + raise _SparseUnavailable( + f"RTP sparse indexer topk buffer must be int32, got {topk_indices.dtype}." 
+ ) + + runner_block_size = int(get_current_atom_config().kv_cache_block_size) + kv_cache = kv_cache.view(-1, runner_block_size, kv_cache.shape[-1]) + + if use_qk_rope_cache_fusion: + q_bf16 = q_input + q_fp8 = torch.empty_like(q_bf16, dtype=dtypes.fp8) + weights_out = torch.empty( + weights.shape, device=weights.device, dtype=torch.float32 + ) + indexer_qk_rope_quant_and_cache( + q_bf16, + q_fp8, + weights, + weights_out, + k, + kv_cache, + slot_mapping, + k_norm_weight, + k_norm_bias, + positions, + cos_cache, + sin_cache, + k_norm_eps, + quant_block_size, + scale_fmt, + weights_scale, + preshuffle=True, + is_neox=is_neox_style, + ) + weights = weights_out + else: + q_fp8 = q_input + indexer_k_quant_and_cache( + k, + kv_cache, + slot_mapping, + quant_block_size, + scale_fmt, + preshuffle=True, + ) + + is_prefill = bool(getattr(context, "is_prefill", False)) + max_seqlen_k = int(getattr(attn_metadata, "max_seqlen_k", 0) or 0) + if is_prefill and max_seqlen_k <= int(topk_tokens): + return weights + + if is_prefill: + total_seq_lens = int(hidden_states.shape[0]) + has_cached = bool(getattr(attn_metadata, "has_cached", False)) + total_kv = ( + int(getattr(attn_metadata, "total_kv", total_seq_lens)) + if has_cached + else total_seq_lens + ) + k_fp8 = torch.empty([total_kv, head_dim], device=k.device, dtype=dtypes.fp8) + k_scale = torch.empty([total_kv, 1], device=k.device, dtype=torch.float32) + block_tables = getattr(attn_metadata, "block_tables", None) + cu_seqlens_q = getattr(attn_metadata, "cu_seqlens_q", None) + if block_tables is None or cu_seqlens_q is None: + raise _SparseUnavailable( + "RTP sparse prefill indexer requires block_tables and cu_seqlens_q." + ) + cu_seqlens_k = ( + getattr(attn_metadata, "cu_seqlens_k", None) if has_cached else cu_seqlens_q + ) + if cu_seqlens_k is None: + raise _SparseUnavailable( + "RTP sparse prefill indexer requires cu_seqlens_k." 
+ ) + cp_gather_indexer_k_quant_cache( + kv_cache, + k_fp8, + k_scale.view(dtypes.fp8), + block_tables, + cu_seqlens_k, + preshuffle=True, + ) + cu_seqlen_ks = getattr(attn_metadata, "cu_seqlen_ks", None) + cu_seqlen_ke = getattr(attn_metadata, "cu_seqlen_ke", None) + if cu_seqlen_ks is None or cu_seqlen_ke is None: + raise _SparseUnavailable( + "RTP sparse prefill indexer requires cu_seqlen_ks/cu_seqlen_ke." + ) + num_decode_tokens = 0 + logits = fp8_mqa_logits( + Q=q_fp8[num_decode_tokens:num_tokens], + KV=k_fp8, + kv_scales=k_scale, + weights=weights[num_decode_tokens:num_tokens], + cu_starts=cu_seqlen_ks, + cu_ends=cu_seqlen_ke, + ) + top_k_per_row_prefill( + logits=logits, + rowStarts=cu_seqlen_ks, + rowEnds=cu_seqlen_ke, + indices=topk_indices[num_decode_tokens:num_tokens, :topk_tokens], + values=None, + numRows=logits.shape[0], + stride0=logits.stride(0), + stride1=logits.stride(1), + ) + return weights + + max_seqlen_q = int(getattr(attn_metadata, "max_seqlen_q", 1) or 1) + num_decode_tokens = int(context.batch_size) * max_seqlen_q + kv_cache_for_logits = kv_cache.unsqueeze(-2) + padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape( + int(context.batch_size), -1, *q_fp8.shape[1:] + ) + batch_size, next_n, _heads, _dim = padded_q_fp8_decode_tokens.shape + logits = torch.empty( + [batch_size * next_n, int(max_model_len)], + dtype=torch.float32, + device=hidden_states.device, + ) + context_lens = getattr(attn_metadata, "context_lens", None) + block_tables = getattr(attn_metadata, "block_tables", None) + if context_lens is None or block_tables is None: + raise _SparseUnavailable( + "RTP sparse decode indexer requires context_lens and block_tables." 
+ ) + deepgemm_fp8_paged_mqa_logits( + padded_q_fp8_decode_tokens, + kv_cache_for_logits, + weights[:num_decode_tokens], + logits, + context_lens, + block_tables, + int(max_model_len), + KVBlockSize=runner_block_size, + Preshuffle=True, + ) + top_k_per_row_decode( + logits, + next_n, + context_lens, + topk_indices[:num_decode_tokens, :topk_tokens], + logits.shape[0], + logits.stride(0), + logits.stride(1), + ) + return weights + + def rtp_sparse_attn_indexer( hidden_states: torch.Tensor, k_cache_prefix: str, @@ -1559,6 +1787,7 @@ def rtp_sparse_attn_indexer( and bool(getattr(context, "is_prefill", False)) and attn_metadata is not None and topk_indices_buffer is not None + and topk_indices_buffer.dim() == 2 and positions is not None ): max_seqlen_k = int(getattr(attn_metadata, "max_seqlen_k", 0) or 0) @@ -1590,6 +1819,33 @@ def rtp_sparse_attn_indexer( topk_indices_buffer[:num_tokens, :topk_tokens].copy_(causal_topk) return weights + if context is not None and attn_metadata is not None: + return _run_rtp_sparse_attn_indexer_topk_only( + hidden_states, + kv_cache, + q_input, + k, + weights, + quant_block_size, + scale_fmt, + topk_tokens, + head_dim, + max_model_len, + total_seq_lens, + topk_indices_buffer, + k_norm_weight, + k_norm_bias, + k_norm_eps, + positions, + cos_cache, + sin_cache, + weights_scale, + is_neox_style, + use_qk_rope_cache_fusion, + context, + attn_metadata, + ) + from atom.models.deepseek_v2 import sparse_attn_indexer return sparse_attn_indexer( diff --git a/atom/plugin/rtpllm/models/glm5.py b/atom/plugin/rtpllm/models/glm5.py index 2999e022a6..ce0f9d2642 100644 --- a/atom/plugin/rtpllm/models/glm5.py +++ b/atom/plugin/rtpllm/models/glm5.py @@ -720,6 +720,10 @@ def _create_python_model(self): import atom from atom.model_loader.loader import load_model_in_plugin_mode + prepare_model = getattr(atom, "prepare_model", None) + if prepare_model is None: + from atom.plugin.prepare import prepare_model + target_device = torch.device( self.device if 
getattr(self, "device", None) else "cuda" ) @@ -740,7 +744,7 @@ def _create_python_model(self): torch.set_default_dtype(target_dtype) try: - atom_model = atom.prepare_model(config=self, engine="rtpllm") + atom_model = prepare_model(config=self, engine="rtpllm") if atom_model is None: raise ValueError("ATOM failed to create GLM5 model for rtp-llm plugin") diff --git a/atom/plugin/rtpllm/utils/forward_context.py b/atom/plugin/rtpllm/utils/forward_context.py index bd7f15c996..0eca3270f5 100644 --- a/atom/plugin/rtpllm/utils/forward_context.py +++ b/atom/plugin/rtpllm/utils/forward_context.py @@ -18,8 +18,18 @@ from atom.config import KVCacheTensor, get_current_atom_config from atom.model_ops.attention_gdn import GatedDeltaNet -from atom.model_ops.attention_mha import PagedAttentionImpl -from atom.model_ops.paged_attention import Attention as PagedAttention + +try: + from atom.model_ops.attention_mha import PagedAttentionImpl +except (ImportError, ModuleNotFoundError): + PagedAttentionImpl = type("PagedAttentionImpl", (), {}) +try: + from atom.model_ops.paged_attention import Attention as PagedAttention +except (ImportError, ModuleNotFoundError): + try: + from atom.model_ops.paged_attention import PagedAttention + except (ImportError, ModuleNotFoundError): + PagedAttention = type("PagedAttention", (), {}) from atom.model_ops.attentions.gdn_attn import ( GDNAttentionMetadata, compute_causal_conv1d_metadata, @@ -1007,6 +1017,29 @@ def _expand_block_table_for_atom_indexer_capture( ) return out_view + @classmethod + def _build_indexer_block_tables( + cls, + *, + block_table_i32: torch.Tensor, + seq_size_per_block: int, + kernel_seq_size_per_block: int, + cg_max_seq_len: int, + in_capture: bool, + cg_bufs: dict | None, + ) -> torch.Tensor: + del ( + cls, + seq_size_per_block, + kernel_seq_size_per_block, + cg_max_seq_len, + in_capture, + cg_bufs, + ) + # Base path (e.g. Qwen3.5): keep compact physical table layout and do not + # expand to indexer granularity. 
+ return block_table_i32 + @classmethod def _resolve_plugin_block_table( cls, @@ -1603,7 +1636,6 @@ def build( context=context, num_tokens=int(positions.numel()), mla_layer_map=cls._resolve_mla_layer_map(resolved_layer_maps), - use_rtp_indexer_cache=cls._use_rtp_indexer_cache(), ) @staticmethod @@ -1670,6 +1702,11 @@ def _resolve_rtp_indexer_cache( f"allowed_last_dims={sorted(allowed_dims)})." ) + @classmethod + def _resolve_mla_layer_map(cls, layer_maps: LayerMaps) -> Dict[int, Any]: + del cls, layer_maps + return {} + @staticmethod def _build_fallback_indexer_cache( *, diff --git a/tests/plugin/test_rtpllm_glm5_indexer_contract.py b/tests/plugin/test_rtpllm_glm5_indexer_contract.py index 8369a54546..bd22838cb3 100644 --- a/tests/plugin/test_rtpllm_glm5_indexer_contract.py +++ b/tests/plugin/test_rtpllm_glm5_indexer_contract.py @@ -158,6 +158,43 @@ def test_constructor_swaps_indexer_to_rtp_sparse_indexer_op(monkeypatch): assert indexer.sparse_attn_indexer_impl is rtp_op +def test_constructor_patches_indexer_forward_to_own_topk_buffer(monkeypatch): + default_op = object() + rtp_op = object() + monkeypatch.setattr( + torch.ops.aiter, "rtp_sparse_attn_indexer", rtp_op, raising=False + ) + + class _ForwardIndexer: + def __init__(self): + self.topk_tokens = 4 + self.sparse_attn_indexer_impl = default_op + self.sparse_kv_indices_buffer = torch.empty(0, dtype=torch.int32) + self.seen_sparse_buffer = None + + def forward(self, hidden_states): + self.seen_sparse_buffer = self.sparse_kv_indices_buffer + return hidden_states + + indexer = _ForwardIndexer() + modules = SimpleNamespace( + q_proj=object(), + o_proj=object(), + kv_b_proj=object(), + indexer=indexer, + v_head_dim=3, + ) + + RTPMLAAttention(mla_modules=modules, sparse_backend=object()) + hidden_states = torch.empty(2, 8) + indexer.forward(hidden_states) + + assert indexer.sparse_attn_indexer_impl is rtp_op + assert indexer.topk_indices_buffer.shape == (2, 4) + assert indexer.topk_indices_buffer.dtype == 
torch.int32 + assert indexer.seen_sparse_buffer is indexer.topk_indices_buffer + + def _run_attention(attention, token_count: int): query = torch.empty(token_count, 6) compressed_kv = torch.empty(token_count, 8) diff --git a/tests/plugin/test_rtpllm_glm5_mla_patch.py b/tests/plugin/test_rtpllm_glm5_mla_patch.py index 78d0bc52ca..fb50f83d3b 100644 --- a/tests/plugin/test_rtpllm_glm5_mla_patch.py +++ b/tests/plugin/test_rtpllm_glm5_mla_patch.py @@ -2,7 +2,6 @@ from pathlib import Path - _ATOM_ROOT = Path(__file__).resolve().parents[2] @@ -21,3 +20,31 @@ def test_glm5_wrapper_does_not_import_or_call_deepseek_mla_patch(): assert "apply_deepseek_mla_rtpllm_patch" not in source + +def test_rtp_mla_patch_updates_deepseek_attention_symbol(monkeypatch): + import sys + import types + + from atom.plugin.rtpllm.attention_backend.rtp_mla_attention import ( + RTPMLAAttention, + apply_attention_mla_rtpllm_patch, + ) + + sentinel = object() + fake_ops = types.ModuleType("atom.model_ops") + fake_ops.Attention = sentinel + fake_base_attention = types.ModuleType("atom.model_ops.base_attention") + fake_base_attention.Attention = sentinel + fake_deepseek = types.ModuleType("atom.models.deepseek_v2") + fake_deepseek.Attention = sentinel + monkeypatch.setitem(sys.modules, "atom.model_ops", fake_ops) + monkeypatch.setitem( + sys.modules, "atom.model_ops.base_attention", fake_base_attention + ) + monkeypatch.setitem(sys.modules, "atom.models.deepseek_v2", fake_deepseek) + + apply_attention_mla_rtpllm_patch() + + assert fake_ops.Attention is RTPMLAAttention + assert fake_base_attention.Attention is RTPMLAAttention + assert fake_deepseek.Attention is RTPMLAAttention diff --git a/tests/plugin/test_rtpllm_glm5_sparse_backend_contract.py b/tests/plugin/test_rtpllm_glm5_sparse_backend_contract.py index b1e7e99e12..c02c0b038c 100644 --- a/tests/plugin/test_rtpllm_glm5_sparse_backend_contract.py +++ b/tests/plugin/test_rtpllm_glm5_sparse_backend_contract.py @@ -83,6 +83,72 @@ def 
fake_sparse_attn_indexer(*args): assert calls[0][6:12] == (128, None, 2048, 64, 4096, 1) +def test_rtp_sparse_attn_indexer_uses_rtp_topk_path_when_context_exists(monkeypatch): + module = importlib.import_module(_SPARSE_BACKEND_MODULE) + forward_context_mod = sys.modules["atom.utils.forward_context"] + fake_forward_context = SimpleNamespace( + context=SimpleNamespace(is_prefill=False, is_dummy_run=False, batch_size=1), + attn_metadata=SimpleNamespace(max_seqlen_q=1), + ) + monkeypatch.setattr( + forward_context_mod, + "get_forward_context", + lambda: fake_forward_context, + raising=False, + ) + + def _unexpected_call(*args, **kwargs): + raise AssertionError("RTP context path must not call deepseek sparse indexer") + + fake_deepseek = type(sys)("atom.models.deepseek_v2") + fake_deepseek.sparse_attn_indexer = _unexpected_call + monkeypatch.setitem(sys.modules, "atom.models.deepseek_v2", fake_deepseek) + + expected = torch.empty(1) + calls = [] + + def _fake_topk_only(*args): + calls.append(args) + return expected + + monkeypatch.setattr( + module, "_run_rtp_sparse_attn_indexer_topk_only", _fake_topk_only + ) + tensor = torch.empty(1) + + output = module.rtp_sparse_attn_indexer( + tensor, + "indexer.prefix", + tensor, + tensor, + tensor, + tensor, + 128, + None, + 2048, + 64, + 4096, + 1, + torch.empty(1, 2048, dtype=torch.int32), + tensor, + tensor, + 1e-6, + tensor, + tensor, + tensor, + 1.0, + True, + False, + ) + + assert output is expected + assert len(calls) == 1 + assert calls[0][-2:] == ( + fake_forward_context.context, + fake_forward_context.attn_metadata, + ) + + def test_rtp_sparse_attn_indexer_fake_bridge_forwards_to_main_fake(monkeypatch): module = importlib.import_module(_SPARSE_BACKEND_MODULE) calls = [] @@ -632,6 +698,23 @@ def test_real_sparse_cache_dtype_uses_aiter_fp8_layout(): assert impl._cache_dtype_name(torch.empty(1, 576, dtype=torch.bfloat16)) == "auto" +def test_sparse_index_converter_resolves_current_refactored_path(monkeypatch): + module = 
importlib.import_module(_SPARSE_BACKEND_MODULE) + old_module_name = "atom.plugin.attention_mla_sparse" + new_module_name = "atom.plugin.vllm.attention.layer_sparse_mla" + monkeypatch.delitem(sys.modules, old_module_name, raising=False) + + fake_new_helpers = type(sys)(new_module_name) + + def fake_convert(): + return None + + fake_new_helpers.triton_convert_req_index_to_global_index = fake_convert + monkeypatch.setitem(sys.modules, new_module_name, fake_new_helpers) + + assert module._resolve_plugin_sparse_index_converter() is fake_convert + + def test_real_sparse_eager_metadata_workspace_skips_refill(monkeypatch): module = importlib.import_module(_SPARSE_BACKEND_MODULE) metadata_calls = [] From d31dbb084fe96eb64dcbb0315de4a9fdb6a173cd Mon Sep 17 00:00:00 2001 From: Zhao An Date: Thu, 18 Jun 2026 13:14:02 +0000 Subject: [PATCH 13/20] fix: RTP plugin imports conflict after rebase main --- atom/plugin/rtpllm/__init__.py | 13 ++++---- .../attention_backend/rtp_mla_attention.py | 5 +-- atom/plugin/rtpllm/models/qwen3_5.py | 33 +++++++++++-------- ...est_rtpllm_glm5_sparse_backend_contract.py | 30 ++++++++++++----- .../test_rtpllm_glm5_wrapper_lifecycle.py | 13 ++++++++ 5 files changed, 63 insertions(+), 31 deletions(-) diff --git a/atom/plugin/rtpllm/__init__.py b/atom/plugin/rtpllm/__init__.py index 4dad126add..eee9517201 100644 --- a/atom/plugin/rtpllm/__init__.py +++ b/atom/plugin/rtpllm/__init__.py @@ -1,8 +1,7 @@ -try: - from .models import base_model_wrapper as _base_model_wrapper -except ModuleNotFoundError as exc: - if exc.name != "rtp_llm": - raise - _base_model_wrapper = None +"""RTP-LLM plugin helpers. -__all__ = ["_base_model_wrapper"] +Keep the package root import side-effect free. RTP external model registration +is triggered by importing ``atom.plugin.rtpllm.models``. 
+""" + +__all__: list[str] = [] diff --git a/atom/plugin/rtpllm/attention_backend/rtp_mla_attention.py b/atom/plugin/rtpllm/attention_backend/rtp_mla_attention.py index 9c510e063a..0109a1f667 100644 --- a/atom/plugin/rtpllm/attention_backend/rtp_mla_attention.py +++ b/atom/plugin/rtpllm/attention_backend/rtp_mla_attention.py @@ -234,10 +234,11 @@ def forward( def apply_attention_mla_rtpllm_patch() -> None: """Switch ATOM's generic Attention symbol to the RTP MLA adapter.""" + import importlib import sys - import atom.model_ops as ops - import atom.model_ops.base_attention as base_attention + ops = importlib.import_module("atom.model_ops") + base_attention = importlib.import_module("atom.model_ops.base_attention") ops.RTPMLAAttention = RTPMLAAttention ops.Attention = RTPMLAAttention diff --git a/atom/plugin/rtpllm/models/qwen3_5.py b/atom/plugin/rtpllm/models/qwen3_5.py index 4e44455fd5..e45dc7717f 100644 --- a/atom/plugin/rtpllm/models/qwen3_5.py +++ b/atom/plugin/rtpllm/models/qwen3_5.py @@ -13,18 +13,7 @@ from rtp_llm.ops.compute_ops import PyModelInputs, PyModelOutputs from rtp_llm.utils.model_weight import W -from atom.model_loader.loader import WeightsMapper -from atom.models.qwen3_5 import ( - detect_fused_expert_format, - get_fused_expert_mapping, - load_fused_expert_weights, -) -from atom.plugin.rtpllm.attention_backend import ( - apply_attention_gdn_rtpllm_patch, - apply_attention_mha_rtpllm_patch, -) from atom.plugin.rtpllm.models.qwen3_next import apply_qwen3_next_rtpllm_patch -from atom.plugin.rtpllm.utils import RTPForwardQwen35HybridContext logger = logging.getLogger("atom.plugin.rtpllm.models") @@ -126,8 +115,11 @@ def __init__( ) self._model_device = first_param.device self._model_dtype = first_param.dtype + from atom.plugin.rtpllm.utils import RTPForwardQwen35HybridContext + + self._rtp_forward_context_cls = RTPForwardQwen35HybridContext # Cache module layer maps once to avoid per-forward model.modules() traversal. 
- self._rtp_layer_maps = RTPForwardQwen35HybridContext.collect_layer_maps( + self._rtp_layer_maps = self._rtp_forward_context_cls.collect_layer_maps( model=self.model ) # Lazy-built in forward_context; invalidated by kv buffer signature change. @@ -468,7 +460,7 @@ def forward(self, inputs: PyModelInputs, fmha_impl: Any = None) -> PyModelOutput ): inputs_embeds = inputs_embeds.to(dtype=model_dtype) - with RTPForwardQwen35HybridContext.bind( + with self._rtp_forward_context_cls.bind( model=self.model, runtime=self, inputs=inputs, @@ -586,7 +578,9 @@ def support_cuda_graph(self) -> bool: return True @staticmethod - def _make_qwen35_hf_mapper() -> WeightsMapper: + def _make_qwen35_hf_mapper(): + from atom.model_loader.loader import WeightsMapper + # Keep loading on outer text-only wrapper so packed_modules_mapping works. # Normalize checkpoint prefixes to match wrapper's weights_mapping rules. return WeightsMapper( @@ -754,6 +748,12 @@ def _load_fused_expert_weights_for_qwen35( shard_id: str, num_experts: int, ) -> bool: + from atom.models.qwen3_5 import ( + detect_fused_expert_format, + get_fused_expert_mapping, + load_fused_expert_weights, + ) + if not detect_fused_expert_format(original_name): return False mapping = get_fused_expert_mapping() @@ -771,6 +771,11 @@ def _load_fused_expert_weights_for_qwen35( try: # Keep RTP-specific behavior in plugin path only. 
_set_framework_backbone("rtpllm") + from atom.plugin.rtpllm.attention_backend import ( + apply_attention_gdn_rtpllm_patch, + apply_attention_mha_rtpllm_patch, + ) + apply_attention_gdn_rtpllm_patch() apply_attention_mha_rtpllm_patch() apply_qwen3_next_rtpllm_patch() diff --git a/tests/plugin/test_rtpllm_glm5_sparse_backend_contract.py b/tests/plugin/test_rtpllm_glm5_sparse_backend_contract.py index c02c0b038c..d6b16ab60b 100644 --- a/tests/plugin/test_rtpllm_glm5_sparse_backend_contract.py +++ b/tests/plugin/test_rtpllm_glm5_sparse_backend_contract.py @@ -37,6 +37,15 @@ def _load_sparse_backend(monkeypatch): return module.RTPSparseMlaBackend +def _forward_context_module(): + module = sys.modules.get("atom.utils.forward_context") + if module is None: + module = type(sys)("atom.utils.forward_context") + module.get_forward_context = lambda: None + sys.modules["atom.utils.forward_context"] = module + return module + + def test_rtp_sparse_attn_indexer_bridge_forwards_to_main_indexer(monkeypatch): module = importlib.import_module(_SPARSE_BACKEND_MODULE) calls = [] @@ -85,7 +94,7 @@ def fake_sparse_attn_indexer(*args): def test_rtp_sparse_attn_indexer_uses_rtp_topk_path_when_context_exists(monkeypatch): module = importlib.import_module(_SPARSE_BACKEND_MODULE) - forward_context_mod = sys.modules["atom.utils.forward_context"] + forward_context_mod = _forward_context_module() fake_forward_context = SimpleNamespace( context=SimpleNamespace(is_prefill=False, is_dummy_run=False, batch_size=1), attn_metadata=SimpleNamespace(max_seqlen_q=1), @@ -197,7 +206,7 @@ def fake_sparse_attn_indexer_fake(*args): def test_rtp_sparse_attn_indexer_short_prefill_fills_causal_topk(monkeypatch): module = importlib.import_module(_SPARSE_BACKEND_MODULE) - forward_context_mod = sys.modules["atom.utils.forward_context"] + forward_context_mod = _forward_context_module() fake_forward_context = SimpleNamespace( context=SimpleNamespace(is_prefill=True, is_dummy_run=False), 
attn_metadata=SimpleNamespace(max_seqlen_k=4), @@ -331,7 +340,7 @@ def test_sparse_backend_passes_topk_through_unchanged(monkeypatch): def test_sparse_backend_prefill_without_topk_raises(monkeypatch): backend_cls = _load_sparse_backend(monkeypatch) - forward_context_mod = sys.modules["atom.utils.forward_context"] + forward_context_mod = _forward_context_module() fake_forward_context = SimpleNamespace( context=SimpleNamespace(is_dummy_run=False), attn_metadata=SimpleNamespace( @@ -361,7 +370,7 @@ def test_sparse_backend_prefill_without_topk_raises(monkeypatch): def test_sparse_backend_decode_without_topk_raises(monkeypatch): backend_cls = _load_sparse_backend(monkeypatch) module = importlib.import_module(_SPARSE_BACKEND_MODULE) - forward_context_mod = sys.modules["atom.utils.forward_context"] + forward_context_mod = _forward_context_module() fake_forward_context = SimpleNamespace( context=SimpleNamespace(is_dummy_run=False), attn_metadata=SimpleNamespace( @@ -406,7 +415,7 @@ def test_sparse_backend_threads_kv_cache_and_layer_id_to_sparse_impl(monkeypatch def test_sparse_backend_pulls_attn_metadata_from_forward_context(monkeypatch): backend_cls = _load_sparse_backend(monkeypatch) - forward_context_mod = sys.modules["atom.utils.forward_context"] + forward_context_mod = _forward_context_module() attn_metadata = SimpleNamespace(block_table="block-table", seq_lens="seq-lens") fake_forward_context = SimpleNamespace( @@ -432,7 +441,7 @@ def test_sparse_backend_pulls_attn_metadata_from_forward_context(monkeypatch): def test_sparse_backend_prefill_missing_sparse_kernel_raises(monkeypatch): backend_cls = _load_sparse_backend(monkeypatch) module = importlib.import_module(_SPARSE_BACKEND_MODULE) - forward_context_mod = sys.modules["atom.utils.forward_context"] + forward_context_mod = _forward_context_module() attn_metadata = SimpleNamespace( plugin_metadata=SimpleNamespace(num_prefills=1, is_dummy_warmup=False) @@ -470,7 +479,7 @@ def forward(self, *args, **kwargs): def 
test_sparse_backend_decode_missing_sparse_kernel_still_raises(monkeypatch): backend_cls = _load_sparse_backend(monkeypatch) module = importlib.import_module(_SPARSE_BACKEND_MODULE) - forward_context_mod = sys.modules["atom.utils.forward_context"] + forward_context_mod = _forward_context_module() attn_metadata = SimpleNamespace( plugin_metadata=SimpleNamespace(num_prefills=0, is_dummy_warmup=False) @@ -555,7 +564,12 @@ def test_real_sparse_decode_uses_atom_aiter_metadata(monkeypatch): module = importlib.import_module(_SPARSE_BACKEND_MODULE) calls = {} - import aiter + aiter = type(sys)("aiter") + aiter.dtypes = SimpleNamespace( + fp8=torch.float8_e4m3fnuz, + d_dtypes={"fp16": torch.float16, "bf16": torch.bfloat16}, + ) + monkeypatch.setitem(sys.modules, "aiter", aiter) def fake_metadata_info(*args, **kwargs): calls["metadata_info"] = (args, kwargs) diff --git a/tests/plugin/test_rtpllm_glm5_wrapper_lifecycle.py b/tests/plugin/test_rtpllm_glm5_wrapper_lifecycle.py index 747b0e7ef1..ffe25f3a56 100644 --- a/tests/plugin/test_rtpllm_glm5_wrapper_lifecycle.py +++ b/tests/plugin/test_rtpllm_glm5_wrapper_lifecycle.py @@ -53,6 +53,10 @@ def __init__(self, **kwargs): def set_global_weight(self, name, tensor): self.global_weights[name] = tensor + class _FakeModelDeployWeightInfo: + pass + + fake_weight_info_mod.ModelDeployWeightInfo = _FakeModelDeployWeightInfo fake_weight_info_mod.ModelWeights = _FakeModelWeights fake_module_base_mod = ModuleType("rtp_llm.models_py.model_desc.module_base") @@ -168,6 +172,14 @@ def test_glm5_create_python_model_lets_prepare_model_own_mla_patching(): fake_modules = _install_fake_rtp_modules() fake_atom_model = MagicMock(name="atom_model") fake_atom_model.to.return_value = fake_atom_model + fake_utils_mod = ModuleType("atom.plugin.rtpllm.utils") + + class _FakeRTPForwardMLAContext: + @staticmethod + def collect_layer_maps(model): + return ({}, {}, {}) + + fake_utils_mod.RTPForwardMLAContext = _FakeRTPForwardMLAContext with ( patch.dict( @@ 
-178,6 +190,7 @@ def test_glm5_create_python_model_lets_prepare_model_own_mla_patching(): os.environ, {"RTP_LLM_EXTERNAL_MODEL_PACKAGES": "atom.plugin.rtpllm.models"}, ), + patch.dict(sys.modules, {"atom.plugin.rtpllm.utils": fake_utils_mod}), patch( "atom.prepare_model", return_value=fake_atom_model, create=True ) as prepare_model, From 21b84657a34fbd50d17e386a6fc0e70c26506ef0 Mon Sep 17 00:00:00 2001 From: Zhao An Date: Thu, 18 Jun 2026 15:16:39 +0000 Subject: [PATCH 14/20] refactor: RTP GLM5 tests merge --- .../test_rtpllm_glm5_indexer_contract.py | 297 ------------------ .../test_rtpllm_glm5_mha_bridge_guard.py | 66 ---- .../test_rtpllm_glm5_mla_bridge_shape.py | 24 -- tests/plugin/test_rtpllm_glm5_mla_patch.py | 50 --- tests/plugin/test_rtpllm_glm5_ownership.py | 39 --- tests/plugin/test_rtpllm_glm5_registration.py | 78 ----- ...est_rtpllm_glm5_sparse_backend_contract.py | 275 ++++++++++++++++ .../test_rtpllm_glm5_wrapper_lifecycle.py | 225 ++++++++++++- 8 files changed, 499 insertions(+), 555 deletions(-) delete mode 100644 tests/plugin/test_rtpllm_glm5_indexer_contract.py delete mode 100644 tests/plugin/test_rtpllm_glm5_mha_bridge_guard.py delete mode 100644 tests/plugin/test_rtpllm_glm5_mla_bridge_shape.py delete mode 100644 tests/plugin/test_rtpllm_glm5_mla_patch.py delete mode 100644 tests/plugin/test_rtpllm_glm5_ownership.py delete mode 100644 tests/plugin/test_rtpllm_glm5_registration.py diff --git a/tests/plugin/test_rtpllm_glm5_indexer_contract.py b/tests/plugin/test_rtpllm_glm5_indexer_contract.py deleted file mode 100644 index bd22838cb3..0000000000 --- a/tests/plugin/test_rtpllm_glm5_indexer_contract.py +++ /dev/null @@ -1,297 +0,0 @@ -"""Contract-executable tests for GLM5 RTP MLA M1.5 indexer behavior.""" - -import builtins -import sys -from types import SimpleNamespace - -import torch - -from atom.plugin.rtpllm.attention_backend.rtp_mla_attention import RTPMLAAttention - -_FORBIDDEN_CUDA_SPARSE_MODULES = ( - "flashmla_sparse", - "flash_mla", - 
"sparse_mla", - "attention_mla_sparse", -) - - -def _guard_sparse_kernel_imports(monkeypatch): - original_import = builtins.__import__ - - def _guarded_import(name, *args, **kwargs): - if any(part in _FORBIDDEN_CUDA_SPARSE_MODULES for part in name.split(".")): - raise AssertionError( - f"M1.5 tests must not import sparse MLA kernels: {name}" - ) - return original_import(name, *args, **kwargs) - - monkeypatch.setattr(builtins, "__import__", _guarded_import) - - -class _FakeSparseBackend: - def __init__(self, v_head_dim: int): - self.v_head_dim = v_head_dim - self.calls = [] - - def forward(self, q, compressed_kv, k_pe, kv_cache, layer_id, topk_indices=None): - self.calls.append( - { - "q": q, - "compressed_kv": compressed_kv, - "k_pe": k_pe, - "kv_cache": kv_cache, - "layer_id": layer_id, - "topk_indices": topk_indices, - } - ) - return q.new_empty((q.shape[0], q.shape[1], self.v_head_dim)) - - -class _FakeIndexer: - def __init__(self, topk_values): - self.calls = [] - self.index_topk = topk_values.shape[1] - self.topk_indices_buffer = torch.full( - (topk_values.shape[0], topk_values.shape[1] + 2), - -1, - dtype=torch.int32, - ) - self.topk_indices_buffer[: topk_values.shape[0], : topk_values.shape[1]].copy_( - topk_values - ) - self.weights = torch.full(topk_values.shape, 99.0, dtype=torch.float32) - - def __call__(self, *args, **kwargs): - self.calls.append((args, kwargs)) - return self.weights - - -class _FakeQProj: - def __init__(self, output): - self.output = output - self.calls = [] - - def __call__(self, query, q_scale=None): - self.calls.append((query, q_scale)) - return self.output - - -class _FakeOProj: - def __init__(self): - self.calls = [] - - def __call__(self, tensor): - self.calls.append(tensor) - return tensor - - -def _make_attention(topk_values): - token_count = topk_values.shape[0] - num_heads = 2 - qk_head_dim = 4 - v_head_dim = 3 - projected_q = torch.arange( - token_count * num_heads * qk_head_dim, dtype=torch.float32 - ).reshape(token_count, 
num_heads * qk_head_dim) - backend = _FakeSparseBackend(v_head_dim=v_head_dim) - indexer = _FakeIndexer(topk_values) - modules = SimpleNamespace( - q_proj=_FakeQProj(projected_q), - o_proj=_FakeOProj(), - kv_b_proj=object(), - indexer=indexer, - v_head_dim=v_head_dim, - qk_head_dim=qk_head_dim, - num_heads=num_heads, - num_local_heads=num_heads, - index_topk=topk_values.shape[1], - ) - attention = RTPMLAAttention( - mla_modules=modules, - sparse_backend=backend, - layer_num=7, - kv_cache="kv-cache", - ) - return attention, modules, backend - - -def test_constructor_injects_indexer_and_topk_indices_buffer_owner_path(): - topk_buffer = torch.tensor([[4, 1, 3, 0]], dtype=torch.int32) - indexer = SimpleNamespace(topk_indices_buffer=topk_buffer, index_topk=4) - modules = SimpleNamespace( - q_proj=object(), - o_proj=object(), - kv_b_proj=object(), - indexer=indexer, - v_head_dim=3, - ) - attention = RTPMLAAttention(mla_modules=modules) - - assert attention.indexer is indexer - assert attention.topk_indices_buffer is topk_buffer - - -def test_constructor_swaps_indexer_to_rtp_sparse_indexer_op(monkeypatch): - default_op = object() - rtp_op = object() - monkeypatch.setattr( - torch.ops.aiter, "rtp_sparse_attn_indexer", rtp_op, raising=False - ) - topk_buffer = torch.tensor([[4, 1, 3, 0]], dtype=torch.int32) - indexer = SimpleNamespace( - topk_indices_buffer=topk_buffer, - index_topk=4, - sparse_attn_indexer_impl=default_op, - ) - modules = SimpleNamespace( - q_proj=object(), - o_proj=object(), - kv_b_proj=object(), - indexer=indexer, - v_head_dim=3, - ) - - attention = RTPMLAAttention(mla_modules=modules, sparse_backend=object()) - - assert attention.indexer is indexer - assert indexer.sparse_attn_indexer_impl is rtp_op - - -def test_constructor_patches_indexer_forward_to_own_topk_buffer(monkeypatch): - default_op = object() - rtp_op = object() - monkeypatch.setattr( - torch.ops.aiter, "rtp_sparse_attn_indexer", rtp_op, raising=False - ) - - class _ForwardIndexer: - def 
__init__(self): - self.topk_tokens = 4 - self.sparse_attn_indexer_impl = default_op - self.sparse_kv_indices_buffer = torch.empty(0, dtype=torch.int32) - self.seen_sparse_buffer = None - - def forward(self, hidden_states): - self.seen_sparse_buffer = self.sparse_kv_indices_buffer - return hidden_states - - indexer = _ForwardIndexer() - modules = SimpleNamespace( - q_proj=object(), - o_proj=object(), - kv_b_proj=object(), - indexer=indexer, - v_head_dim=3, - ) - - RTPMLAAttention(mla_modules=modules, sparse_backend=object()) - hidden_states = torch.empty(2, 8) - indexer.forward(hidden_states) - - assert indexer.sparse_attn_indexer_impl is rtp_op - assert indexer.topk_indices_buffer.shape == (2, 4) - assert indexer.topk_indices_buffer.dtype == torch.int32 - assert indexer.seen_sparse_buffer is indexer.topk_indices_buffer - - -def _run_attention(attention, token_count: int): - query = torch.empty(token_count, 6) - compressed_kv = torch.empty(token_count, 8) - k_rope = torch.empty(token_count, 3) - positions = torch.arange(token_count, dtype=torch.int32) - return attention.forward( - query, - compressed_kv, - k_rope, - positions=positions, - ) - - -def test_indexer_buffer_topk_is_passed_to_sparse_backend_when_emit_allowed(monkeypatch): - _guard_sparse_kernel_imports(monkeypatch) - topk_values = torch.tensor([[4, 1, 3, 0], [2, 0, 1, 3]], dtype=torch.int32) - attention, modules, backend = _make_attention(topk_values) - - _run_attention(attention, token_count=topk_values.shape[0]) - - assert modules.indexer.calls == [] - topk_indices = backend.calls[0]["topk_indices"] - assert topk_indices is not None - assert topk_indices.dtype == torch.int32 - assert topk_indices.shape == topk_values.shape - assert torch.equal(topk_indices, topk_values) - assert topk_indices is not modules.indexer.weights - assert not torch.equal(topk_indices.to(torch.float32), modules.indexer.weights) - - -def _patch_forward_context(monkeypatch, *, is_dummy_run, is_prefill, max_seqlen_k): - 
forward_context_mod = sys.modules["atom.utils.forward_context"] - - fake_forward_context = SimpleNamespace( - context=SimpleNamespace(is_dummy_run=is_dummy_run, is_prefill=is_prefill), - attn_metadata=SimpleNamespace(max_seqlen_k=max_seqlen_k), - ) - monkeypatch.setattr( - forward_context_mod, - "get_forward_context", - lambda: fake_forward_context, - raising=False, - ) - - -def test_dummy_run_does_not_emit_topk_to_sparse_backend(monkeypatch): - _guard_sparse_kernel_imports(monkeypatch) - _patch_forward_context( - monkeypatch, - is_dummy_run=True, - is_prefill=False, - max_seqlen_k=4096, - ) - topk_values = torch.tensor([[4, 1, 3, 0], [2, 0, 1, 3]], dtype=torch.int32) - attention, modules, backend = _make_attention(topk_values) - - _run_attention(attention, token_count=topk_values.shape[0]) - - assert modules.indexer.calls == [] - assert backend.calls[0]["topk_indices"] is None - - -def test_short_prefill_emits_topk_to_sparse_backend(monkeypatch): - _guard_sparse_kernel_imports(monkeypatch) - _patch_forward_context( - monkeypatch, - is_dummy_run=False, - is_prefill=True, - max_seqlen_k=4, - ) - topk_values = torch.tensor([[4, 1, 3, 0], [2, 0, 1, 3]], dtype=torch.int32) - attention, modules, backend = _make_attention(topk_values) - - _run_attention(attention, token_count=topk_values.shape[0]) - - assert modules.indexer.calls == [] - topk_indices = backend.calls[0]["topk_indices"] - assert topk_indices is not None - assert torch.equal(topk_indices, topk_values) - - -def test_prefill_within_topk_buffer_padding_still_emits_topk(monkeypatch): - _guard_sparse_kernel_imports(monkeypatch) - _patch_forward_context( - monkeypatch, - is_dummy_run=False, - is_prefill=True, - max_seqlen_k=5, - ) - topk_values = torch.tensor([[4, 1, 3, 0], [2, 0, 1, 3]], dtype=torch.int32) - attention, modules, backend = _make_attention(topk_values) - - _run_attention(attention, token_count=topk_values.shape[0]) - - assert modules.indexer.index_topk == 4 - assert 
modules.indexer.topk_indices_buffer.shape[1] == 6 - assert modules.indexer.calls == [] - topk_indices = backend.calls[0]["topk_indices"] - assert topk_indices is not None - assert torch.equal(topk_indices, topk_values) diff --git a/tests/plugin/test_rtpllm_glm5_mha_bridge_guard.py b/tests/plugin/test_rtpllm_glm5_mha_bridge_guard.py deleted file mode 100644 index 908b17cffb..0000000000 --- a/tests/plugin/test_rtpllm_glm5_mha_bridge_guard.py +++ /dev/null @@ -1,66 +0,0 @@ -"""Static guards for the GLM5 rtp-llm plugin path.""" - -import ast -from pathlib import Path - - -_ATOM_ROOT = Path(__file__).resolve().parents[2] -_FORBIDDEN_IMPORT_TIME_SPARSE_KERNELS = { - "flashmla_sparse", - "flash_mla", - "sparse_mla", - "attention_mla_sparse", -} - - -def _read_plugin_file(relative_path: str) -> str: - return (_ATOM_ROOT / relative_path).read_text() - - -def test_glm5_wrapper_does_not_use_mha_or_qwen_patches(): - source = _read_plugin_file("atom/plugin/rtpllm/models/glm5.py") - - assert "RTPFullAttention" not in source - assert "apply_attention_mha_rtpllm_patch" not in source - assert "apply_attention_gdn_rtpllm_patch" not in source - assert "apply_qwen3_next_rtpllm_patch" not in source - - -def test_glm5_wrapper_does_not_reference_deepseek_mla_patch(): - source = _read_plugin_file("atom/plugin/rtpllm/models/glm5.py") - - assert "apply_deepseek_mla_rtpllm_patch" not in source - - -def test_rtp_mla_prepare_does_not_keep_native_forward_mirror_helpers(): - assert not ( - _ATOM_ROOT / "atom/plugin/rtpllm/attention_backend/rtp_mla_prepare.py" - ).exists() - - -def test_glm5_mla_backend_is_not_full_attention_adapter(): - source = _read_plugin_file("atom/plugin/rtpllm/attention_backend/rtp_mla_attention.py") - - assert "class RTPMLAAttention" in source - assert "use_mla" in source - assert "RTPFullAttention" not in source - - -def test_sparse_mla_backend_has_no_import_time_cuda_sparse_kernel_dependencies(): - backend_path = _ATOM_ROOT / 
"atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py" - assert backend_path.exists() - - tree = ast.parse(backend_path.read_text()) - imported_modules = set() - for node in ast.walk(tree): - if isinstance(node, ast.Import): - imported_modules.update(alias.name for alias in node.names) - elif isinstance(node, ast.ImportFrom) and node.module is not None: - imported_modules.add(node.module) - - assert not any( - forbidden in module_name.split(".") - for module_name in imported_modules - for forbidden in _FORBIDDEN_IMPORT_TIME_SPARSE_KERNELS - ) - diff --git a/tests/plugin/test_rtpllm_glm5_mla_bridge_shape.py b/tests/plugin/test_rtpllm_glm5_mla_bridge_shape.py deleted file mode 100644 index cd4c0602c6..0000000000 --- a/tests/plugin/test_rtpllm_glm5_mla_bridge_shape.py +++ /dev/null @@ -1,24 +0,0 @@ -"""Shape-level tests for the GLM5 RTP MLA bridge.""" - -from types import SimpleNamespace - -import torch - -from atom.plugin.rtpllm.attention_backend.rtp_mla_attention import RTPMLAAttention - - -def test_mla_attention_legacy_boundary_shape_stays_executable_during_migration(): - q = torch.empty(2, 4, 256) - compressed_kv = torch.empty(2, 512) - k_pe = torch.empty(2, 64) - positions = torch.arange(2, dtype=torch.int32) - attention = RTPMLAAttention(mla_modules=SimpleNamespace(v_head_dim=128)) - - output = attention(q, compressed_kv, k_pe, positions=positions) - - assert output.shape == (2, 4, 128) - - -def test_mla_attention_is_marked_as_mla_adapter(): - assert RTPMLAAttention.use_mla is True - diff --git a/tests/plugin/test_rtpllm_glm5_mla_patch.py b/tests/plugin/test_rtpllm_glm5_mla_patch.py deleted file mode 100644 index fb50f83d3b..0000000000 --- a/tests/plugin/test_rtpllm_glm5_mla_patch.py +++ /dev/null @@ -1,50 +0,0 @@ -"""No-monkey-patch guards for GLM5 RTP MLA M1.5 forward.""" - -from pathlib import Path - -_ATOM_ROOT = Path(__file__).resolve().parents[2] - - -def _read_plugin_file(relative_path: str) -> str: - return (_ATOM_ROOT / 
relative_path).read_text() - - -def test_rtp_mla_prepare_no_longer_contains_deepseek_forward_monkey_patch(): - assert not ( - _ATOM_ROOT / "atom/plugin/rtpllm/attention_backend/rtp_mla_prepare.py" - ).exists() - - -def test_glm5_wrapper_does_not_import_or_call_deepseek_mla_patch(): - source = _read_plugin_file("atom/plugin/rtpllm/models/glm5.py") - - assert "apply_deepseek_mla_rtpllm_patch" not in source - - -def test_rtp_mla_patch_updates_deepseek_attention_symbol(monkeypatch): - import sys - import types - - from atom.plugin.rtpllm.attention_backend.rtp_mla_attention import ( - RTPMLAAttention, - apply_attention_mla_rtpllm_patch, - ) - - sentinel = object() - fake_ops = types.ModuleType("atom.model_ops") - fake_ops.Attention = sentinel - fake_base_attention = types.ModuleType("atom.model_ops.base_attention") - fake_base_attention.Attention = sentinel - fake_deepseek = types.ModuleType("atom.models.deepseek_v2") - fake_deepseek.Attention = sentinel - monkeypatch.setitem(sys.modules, "atom.model_ops", fake_ops) - monkeypatch.setitem( - sys.modules, "atom.model_ops.base_attention", fake_base_attention - ) - monkeypatch.setitem(sys.modules, "atom.models.deepseek_v2", fake_deepseek) - - apply_attention_mla_rtpllm_patch() - - assert fake_ops.Attention is RTPMLAAttention - assert fake_base_attention.Attention is RTPMLAAttention - assert fake_deepseek.Attention is RTPMLAAttention diff --git a/tests/plugin/test_rtpllm_glm5_ownership.py b/tests/plugin/test_rtpllm_glm5_ownership.py deleted file mode 100644 index 1bfa19cfc9..0000000000 --- a/tests/plugin/test_rtpllm_glm5_ownership.py +++ /dev/null @@ -1,39 +0,0 @@ -"""Ownership contract tests for GLM5 rtp-llm M0.""" - -from atom.plugin.rtpllm.attention_backend.rtp_mla_metadata import ( - GLM5_RTP_BRIDGE_MODE, - GLM5_RTP_BRIDGE_MODE_M0_DENSE, - GLM5_RTP_OWNERSHIP, -) - - -def test_glm5_bridge_mode_starts_in_m0_dense(): - assert GLM5_RTP_BRIDGE_MODE == GLM5_RTP_BRIDGE_MODE_M0_DENSE - - -def 
test_glm5_ownership_unique_and_separates_rope_paths(): - required = { - "main_q_norm", - "main_kv_norm", - "main_rope", - "main_kv_cache", - "indexer_k_norm", - "indexer_rope", - "indexer_cache", - "topk_selector", - } - - assert required <= set(GLM5_RTP_OWNERSHIP) - for key in required: - owner = GLM5_RTP_OWNERSHIP[key] - assert isinstance(owner, str) - assert owner - - assert GLM5_RTP_OWNERSHIP["main_rope"] != GLM5_RTP_OWNERSHIP["indexer_rope"] - - -def test_glm5_ownership_forbids_qwen_and_mha_components(): - forbidden = ("GatedDeltaNet", "RTPFullAttention", "Qwen3Next") - for owner in GLM5_RTP_OWNERSHIP.values(): - assert all(name not in owner for name in forbidden) - diff --git a/tests/plugin/test_rtpllm_glm5_registration.py b/tests/plugin/test_rtpllm_glm5_registration.py deleted file mode 100644 index 8abdb77e23..0000000000 --- a/tests/plugin/test_rtpllm_glm5_registration.py +++ /dev/null @@ -1,78 +0,0 @@ -"""Tests for GLM5 rtp-llm plugin registration.""" - -import importlib -import sys -from types import ModuleType -from unittest.mock import MagicMock, call, patch - - -def _package(name: str) -> ModuleType: - module = ModuleType(name) - module.__path__ = [] - return module - - -def test_rtpllm_wrapper_registers_glm5_override_and_alias(): - register_model_mock = MagicMock() - - fake_rtp_register_mod = ModuleType("rtp_llm.model_factory_register") - fake_rtp_register_mod.register_model = register_model_mock - fake_rtp_register_mod._model_factory = {} - fake_rtp_register_mod._hf_architecture_2_ft = {} - - fake_atom_register_mod = ModuleType("atom.plugin.register") - fake_atom_register_mod._ATOM_SUPPORTED_MODELS = {} - - fake_atom_deepseek_mod = ModuleType("atom.models.deepseek_v2") - - class _FakeGlmMoeDsaForCausalLM: - pass - - fake_atom_deepseek_mod.GlmMoeDsaForCausalLM = _FakeGlmMoeDsaForCausalLM - - fake_atom_qwen_mod = ModuleType("atom.plugin.rtpllm.models.qwen3_5") - - class _FakeATOMQwen35Moe: - pass - - fake_atom_qwen_mod.ATOMQwen35Moe = 
_FakeATOMQwen35Moe - - fake_atom_glm_mod = ModuleType("atom.plugin.rtpllm.models.glm5") - - class _FakeATOMGlm5Moe: - pass - - fake_atom_glm_mod.ATOMGlm5Moe = _FakeATOMGlm5Moe - - fake_modules = { - "rtp_llm": _package("rtp_llm"), - "rtp_llm.models": _package("rtp_llm.models"), - "rtp_llm.model_factory_register": fake_rtp_register_mod, - "atom.models.deepseek_v2": fake_atom_deepseek_mod, - "atom.plugin.register": fake_atom_register_mod, - "atom.plugin.rtpllm.models.qwen3_5": fake_atom_qwen_mod, - "atom.plugin.rtpllm.models.glm5": fake_atom_glm_mod, - } - - with patch.dict(sys.modules, fake_modules): - sys.modules.pop("atom.plugin.rtpllm.models", None) - sys.modules.pop("atom.plugin.rtpllm.models.base_model_wrapper", None) - importlib.import_module("atom.plugin.rtpllm.models") - - assert fake_rtp_register_mod._model_factory["glm_5"] is _FakeATOMGlm5Moe - assert ( - fake_rtp_register_mod._hf_architecture_2_ft["GlmMoeDsaForCausalLM"] - == "glm_5" - ) - assert ( - fake_atom_register_mod._ATOM_SUPPORTED_MODELS["GlmMoeDsaForCausalLM"] - is _FakeGlmMoeDsaForCausalLM - ) - register_model_mock.assert_has_calls( - [ - call("atom_qwen35_moe", _FakeATOMQwen35Moe, []), - call("atom_glm5_moe", _FakeATOMGlm5Moe, []), - ], - any_order=False, - ) - diff --git a/tests/plugin/test_rtpllm_glm5_sparse_backend_contract.py b/tests/plugin/test_rtpllm_glm5_sparse_backend_contract.py index d6b16ab60b..6fdcaae730 100644 --- a/tests/plugin/test_rtpllm_glm5_sparse_backend_contract.py +++ b/tests/plugin/test_rtpllm_glm5_sparse_backend_contract.py @@ -882,3 +882,278 @@ def fake_mla_decode_fwd(*args, **kwargs): "Expected OOB paged_kv_indices to raise _SparseUnavailable" ) assert decode_called["value"] is False + + +def _load_rtp_mla_attention(): + module = importlib.import_module( + "atom.plugin.rtpllm.attention_backend.rtp_mla_attention" + ) + return module.RTPMLAAttention + + +class _FakeSparseBackend: + def __init__(self, v_head_dim: int): + self.v_head_dim = v_head_dim + self.calls = [] + + 
def forward(self, q, compressed_kv, k_pe, kv_cache, layer_id, topk_indices=None): + self.calls.append( + { + "q": q, + "compressed_kv": compressed_kv, + "k_pe": k_pe, + "kv_cache": kv_cache, + "layer_id": layer_id, + "topk_indices": topk_indices, + } + ) + return q.new_empty((q.shape[0], q.shape[1], self.v_head_dim)) + + +class _FakeIndexer: + def __init__(self, topk_values): + self.calls = [] + self.index_topk = topk_values.shape[1] + self.topk_indices_buffer = torch.full( + (topk_values.shape[0], topk_values.shape[1] + 2), + -1, + dtype=torch.int32, + ) + self.topk_indices_buffer[: topk_values.shape[0], : topk_values.shape[1]].copy_( + topk_values + ) + self.weights = torch.full(topk_values.shape, 99.0, dtype=torch.float32) + + def __call__(self, *args, **kwargs): + self.calls.append((args, kwargs)) + return self.weights + + +class _FakeQProj: + def __init__(self, output): + self.output = output + self.calls = [] + + def __call__(self, query, q_scale=None): + self.calls.append((query, q_scale)) + return self.output + + +class _FakeOProj: + def __init__(self): + self.calls = [] + + def __call__(self, tensor): + self.calls.append(tensor) + return tensor + + +def _make_attention(topk_values): + token_count = topk_values.shape[0] + num_heads = 2 + qk_head_dim = 4 + v_head_dim = 3 + projected_q = torch.arange( + token_count * num_heads * qk_head_dim, dtype=torch.float32 + ).reshape(token_count, num_heads * qk_head_dim) + backend = _FakeSparseBackend(v_head_dim=v_head_dim) + indexer = _FakeIndexer(topk_values) + modules = SimpleNamespace( + q_proj=_FakeQProj(projected_q), + o_proj=_FakeOProj(), + kv_b_proj=object(), + indexer=indexer, + v_head_dim=v_head_dim, + qk_head_dim=qk_head_dim, + num_heads=num_heads, + num_local_heads=num_heads, + index_topk=topk_values.shape[1], + ) + attention = _load_rtp_mla_attention()( + mla_modules=modules, + sparse_backend=backend, + layer_num=7, + kv_cache="kv-cache", + ) + return attention, modules, backend + + +def 
_run_attention(attention, token_count: int): + query = torch.empty(token_count, 6) + compressed_kv = torch.empty(token_count, 8) + k_rope = torch.empty(token_count, 3) + positions = torch.arange(token_count, dtype=torch.int32) + return attention.forward( + query, + compressed_kv, + k_rope, + positions=positions, + ) + + +def _patch_forward_context(monkeypatch, *, is_dummy_run, is_prefill, max_seqlen_k): + forward_context_mod = sys.modules["atom.utils.forward_context"] + + fake_forward_context = SimpleNamespace( + context=SimpleNamespace(is_dummy_run=is_dummy_run, is_prefill=is_prefill), + attn_metadata=SimpleNamespace(max_seqlen_k=max_seqlen_k), + ) + monkeypatch.setattr( + forward_context_mod, + "get_forward_context", + lambda: fake_forward_context, + raising=False, + ) + + +def test_constructor_injects_indexer_and_topk_indices_buffer_owner_path(): + topk_buffer = torch.tensor([[4, 1, 3, 0]], dtype=torch.int32) + indexer = SimpleNamespace(topk_indices_buffer=topk_buffer, index_topk=4) + modules = SimpleNamespace( + q_proj=object(), + o_proj=object(), + kv_b_proj=object(), + indexer=indexer, + v_head_dim=3, + ) + attention = _load_rtp_mla_attention()(mla_modules=modules) + + assert attention.indexer is indexer + assert attention.topk_indices_buffer is topk_buffer + + +def test_constructor_swaps_indexer_to_rtp_sparse_indexer_op(monkeypatch): + default_op = object() + rtp_op = object() + monkeypatch.setattr( + torch.ops.aiter, "rtp_sparse_attn_indexer", rtp_op, raising=False + ) + topk_buffer = torch.tensor([[4, 1, 3, 0]], dtype=torch.int32) + indexer = SimpleNamespace( + topk_indices_buffer=topk_buffer, + index_topk=4, + sparse_attn_indexer_impl=default_op, + ) + modules = SimpleNamespace( + q_proj=object(), + o_proj=object(), + kv_b_proj=object(), + indexer=indexer, + v_head_dim=3, + ) + + attention = _load_rtp_mla_attention()(mla_modules=modules, sparse_backend=object()) + + assert attention.indexer is indexer + assert indexer.sparse_attn_indexer_impl is rtp_op + 
+ +def test_constructor_patches_indexer_forward_to_own_topk_buffer(monkeypatch): + default_op = object() + rtp_op = object() + monkeypatch.setattr( + torch.ops.aiter, "rtp_sparse_attn_indexer", rtp_op, raising=False + ) + + class _ForwardIndexer: + def __init__(self): + self.topk_tokens = 4 + self.sparse_attn_indexer_impl = default_op + self.sparse_kv_indices_buffer = torch.empty(0, dtype=torch.int32) + self.seen_sparse_buffer = None + + def forward(self, hidden_states): + self.seen_sparse_buffer = self.sparse_kv_indices_buffer + return hidden_states + + indexer = _ForwardIndexer() + modules = SimpleNamespace( + q_proj=object(), + o_proj=object(), + kv_b_proj=object(), + indexer=indexer, + v_head_dim=3, + ) + + _load_rtp_mla_attention()(mla_modules=modules, sparse_backend=object()) + hidden_states = torch.empty(2, 8) + indexer.forward(hidden_states) + + assert indexer.sparse_attn_indexer_impl is rtp_op + assert indexer.topk_indices_buffer.shape == (2, 4) + assert indexer.topk_indices_buffer.dtype == torch.int32 + assert indexer.seen_sparse_buffer is indexer.topk_indices_buffer + + +def test_indexer_buffer_topk_is_passed_to_sparse_backend_when_emit_allowed(monkeypatch): + _guard_sparse_kernel_imports(monkeypatch) + topk_values = torch.tensor([[4, 1, 3, 0], [2, 0, 1, 3]], dtype=torch.int32) + attention, modules, backend = _make_attention(topk_values) + + _run_attention(attention, token_count=topk_values.shape[0]) + + assert modules.indexer.calls == [] + topk_indices = backend.calls[0]["topk_indices"] + assert topk_indices is not None + assert topk_indices.dtype == torch.int32 + assert topk_indices.shape == topk_values.shape + assert torch.equal(topk_indices, topk_values) + assert topk_indices is not modules.indexer.weights + assert not torch.equal(topk_indices.to(torch.float32), modules.indexer.weights) + + +def test_dummy_run_does_not_emit_topk_to_sparse_backend(monkeypatch): + _guard_sparse_kernel_imports(monkeypatch) + _patch_forward_context( + monkeypatch, + 
is_dummy_run=True, + is_prefill=False, + max_seqlen_k=4096, + ) + topk_values = torch.tensor([[4, 1, 3, 0], [2, 0, 1, 3]], dtype=torch.int32) + attention, modules, backend = _make_attention(topk_values) + + _run_attention(attention, token_count=topk_values.shape[0]) + + assert modules.indexer.calls == [] + assert backend.calls[0]["topk_indices"] is None + + +def test_short_prefill_emits_topk_to_sparse_backend(monkeypatch): + _guard_sparse_kernel_imports(monkeypatch) + _patch_forward_context( + monkeypatch, + is_dummy_run=False, + is_prefill=True, + max_seqlen_k=4, + ) + topk_values = torch.tensor([[4, 1, 3, 0], [2, 0, 1, 3]], dtype=torch.int32) + attention, modules, backend = _make_attention(topk_values) + + _run_attention(attention, token_count=topk_values.shape[0]) + + assert modules.indexer.calls == [] + topk_indices = backend.calls[0]["topk_indices"] + assert topk_indices is not None + assert torch.equal(topk_indices, topk_values) + + +def test_prefill_within_topk_buffer_padding_still_emits_topk(monkeypatch): + _guard_sparse_kernel_imports(monkeypatch) + _patch_forward_context( + monkeypatch, + is_dummy_run=False, + is_prefill=True, + max_seqlen_k=5, + ) + topk_values = torch.tensor([[4, 1, 3, 0], [2, 0, 1, 3]], dtype=torch.int32) + attention, modules, backend = _make_attention(topk_values) + + _run_attention(attention, token_count=topk_values.shape[0]) + + assert modules.indexer.index_topk == 4 + assert modules.indexer.topk_indices_buffer.shape[1] == 6 + assert modules.indexer.calls == [] + topk_indices = backend.calls[0]["topk_indices"] + assert topk_indices is not None + assert torch.equal(topk_indices, topk_values) diff --git a/tests/plugin/test_rtpllm_glm5_wrapper_lifecycle.py b/tests/plugin/test_rtpllm_glm5_wrapper_lifecycle.py index ffe25f3a56..b734cb37d2 100644 --- a/tests/plugin/test_rtpllm_glm5_wrapper_lifecycle.py +++ b/tests/plugin/test_rtpllm_glm5_wrapper_lifecycle.py @@ -1,14 +1,24 @@ """Lifecycle tests for the GLM5 rtp-llm wrapper.""" +import ast 
from contextlib import nullcontext import importlib import os +from pathlib import Path import sys from types import ModuleType, SimpleNamespace -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, call, patch import torch +_ATOM_ROOT = Path(__file__).resolve().parents[2] +_FORBIDDEN_IMPORT_TIME_SPARSE_KERNELS = { + "flashmla_sparse", + "flash_mla", + "sparse_mla", + "attention_mla_sparse", +} + def _package(name: str) -> ModuleType: module = ModuleType(name) @@ -168,6 +178,10 @@ def _patch_optional_attr(module, attr): return nullcontext(MagicMock(name=attr)) +def _read_plugin_file(relative_path: str) -> str: + return (_ATOM_ROOT / relative_path).read_text() + + def test_glm5_create_python_model_lets_prepare_model_own_mla_patching(): fake_modules = _install_fake_rtp_modules() fake_atom_model = MagicMock(name="atom_model") @@ -438,3 +452,212 @@ def test_glm5_runtime_graph_decode_ignores_stale_position_ids(): ) assert positions.cpu().tolist() == [34, 48, 49] + + +def test_rtpllm_wrapper_registers_glm5_override_and_alias(): + register_model_mock = MagicMock() + + fake_rtp_register_mod = ModuleType("rtp_llm.model_factory_register") + fake_rtp_register_mod.register_model = register_model_mock + fake_rtp_register_mod._model_factory = {} + fake_rtp_register_mod._hf_architecture_2_ft = {} + + fake_atom_register_mod = ModuleType("atom.plugin.register") + fake_atom_register_mod._ATOM_SUPPORTED_MODELS = {} + + fake_atom_deepseek_mod = ModuleType("atom.models.deepseek_v2") + + class _FakeGlmMoeDsaForCausalLM: + pass + + fake_atom_deepseek_mod.GlmMoeDsaForCausalLM = _FakeGlmMoeDsaForCausalLM + + fake_atom_qwen_mod = ModuleType("atom.plugin.rtpllm.models.qwen3_5") + + class _FakeATOMQwen35Moe: + pass + + fake_atom_qwen_mod.ATOMQwen35Moe = _FakeATOMQwen35Moe + + fake_atom_glm_mod = ModuleType("atom.plugin.rtpllm.models.glm5") + + class _FakeATOMGlm5Moe: + pass + + fake_atom_glm_mod.ATOMGlm5Moe = _FakeATOMGlm5Moe + + fake_modules = { + "rtp_llm": 
_package("rtp_llm"), + "rtp_llm.models": _package("rtp_llm.models"), + "rtp_llm.model_factory_register": fake_rtp_register_mod, + "atom.models.deepseek_v2": fake_atom_deepseek_mod, + "atom.plugin.register": fake_atom_register_mod, + "atom.plugin.rtpllm.models.qwen3_5": fake_atom_qwen_mod, + "atom.plugin.rtpllm.models.glm5": fake_atom_glm_mod, + } + + with patch.dict(sys.modules, fake_modules): + sys.modules.pop("atom.plugin.rtpllm.models", None) + sys.modules.pop("atom.plugin.rtpllm.models.base_model_wrapper", None) + importlib.import_module("atom.plugin.rtpllm.models") + + assert fake_rtp_register_mod._model_factory["glm_5"] is _FakeATOMGlm5Moe + assert ( + fake_rtp_register_mod._hf_architecture_2_ft["GlmMoeDsaForCausalLM"] + == "glm_5" + ) + assert ( + fake_atom_register_mod._ATOM_SUPPORTED_MODELS["GlmMoeDsaForCausalLM"] + is _FakeGlmMoeDsaForCausalLM + ) + register_model_mock.assert_has_calls( + [ + call("atom_qwen35_moe", _FakeATOMQwen35Moe, []), + call("atom_glm5_moe", _FakeATOMGlm5Moe, []), + ], + any_order=False, + ) + + +def test_glm5_bridge_mode_starts_in_m0_dense(): + from atom.plugin.rtpllm.attention_backend.rtp_mla_metadata import ( + GLM5_RTP_BRIDGE_MODE, + GLM5_RTP_BRIDGE_MODE_M0_DENSE, + ) + + assert GLM5_RTP_BRIDGE_MODE == GLM5_RTP_BRIDGE_MODE_M0_DENSE + + +def test_glm5_ownership_unique_and_separates_rope_paths(): + from atom.plugin.rtpllm.attention_backend.rtp_mla_metadata import ( + GLM5_RTP_OWNERSHIP, + ) + + required = { + "main_q_norm", + "main_kv_norm", + "main_rope", + "main_kv_cache", + "indexer_k_norm", + "indexer_rope", + "indexer_cache", + "topk_selector", + } + + assert required <= set(GLM5_RTP_OWNERSHIP) + for key in required: + owner = GLM5_RTP_OWNERSHIP[key] + assert isinstance(owner, str) + assert owner + + assert GLM5_RTP_OWNERSHIP["main_rope"] != GLM5_RTP_OWNERSHIP["indexer_rope"] + + +def test_glm5_ownership_forbids_qwen_and_mha_components(): + from atom.plugin.rtpllm.attention_backend.rtp_mla_metadata import ( + 
GLM5_RTP_OWNERSHIP, + ) + + forbidden = ("GatedDeltaNet", "RTPFullAttention", "Qwen3Next") + for owner in GLM5_RTP_OWNERSHIP.values(): + assert all(name not in owner for name in forbidden) + + +def test_mla_attention_legacy_boundary_shape_stays_executable_during_migration(): + from atom.plugin.rtpllm.attention_backend.rtp_mla_attention import RTPMLAAttention + + q = torch.empty(2, 4, 256) + compressed_kv = torch.empty(2, 512) + k_pe = torch.empty(2, 64) + positions = torch.arange(2, dtype=torch.int32) + attention = RTPMLAAttention(mla_modules=SimpleNamespace(v_head_dim=128)) + + output = attention(q, compressed_kv, k_pe, positions=positions) + + assert output.shape == (2, 4, 128) + + +def test_mla_attention_is_marked_as_mla_adapter(): + from atom.plugin.rtpllm.attention_backend.rtp_mla_attention import RTPMLAAttention + + assert RTPMLAAttention.use_mla is True + + +def test_glm5_wrapper_does_not_use_mha_or_qwen_patches(): + source = _read_plugin_file("atom/plugin/rtpllm/models/glm5.py") + + assert "RTPFullAttention" not in source + assert "apply_attention_mha_rtpllm_patch" not in source + assert "apply_attention_gdn_rtpllm_patch" not in source + assert "apply_qwen3_next_rtpllm_patch" not in source + + +def test_glm5_wrapper_does_not_import_or_call_deepseek_mla_patch(): + source = _read_plugin_file("atom/plugin/rtpllm/models/glm5.py") + + assert "apply_deepseek_mla_rtpllm_patch" not in source + + +def test_rtp_mla_prepare_no_longer_contains_deepseek_forward_monkey_patch(): + assert not ( + _ATOM_ROOT / "atom/plugin/rtpllm/attention_backend/rtp_mla_prepare.py" + ).exists() + + +def test_glm5_mla_backend_is_not_full_attention_adapter(): + source = _read_plugin_file( + "atom/plugin/rtpllm/attention_backend/rtp_mla_attention.py" + ) + + assert "class RTPMLAAttention" in source + assert "use_mla" in source + assert "RTPFullAttention" not in source + + +def test_sparse_mla_backend_has_no_import_time_cuda_sparse_kernel_dependencies(): + backend_path = ( + _ATOM_ROOT / 
"atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py" + ) + assert backend_path.exists() + + tree = ast.parse(backend_path.read_text()) + imported_modules = set() + for node in ast.walk(tree): + if isinstance(node, ast.Import): + imported_modules.update(alias.name for alias in node.names) + elif isinstance(node, ast.ImportFrom) and node.module is not None: + imported_modules.add(node.module) + + assert not any( + forbidden in module_name.split(".") + for module_name in imported_modules + for forbidden in _FORBIDDEN_IMPORT_TIME_SPARSE_KERNELS + ) + + +def test_rtp_mla_patch_updates_deepseek_attention_symbol(monkeypatch): + import types + + from atom.plugin.rtpllm.attention_backend.rtp_mla_attention import ( + RTPMLAAttention, + apply_attention_mla_rtpllm_patch, + ) + + sentinel = object() + fake_ops = types.ModuleType("atom.model_ops") + fake_ops.Attention = sentinel + fake_base_attention = types.ModuleType("atom.model_ops.base_attention") + fake_base_attention.Attention = sentinel + fake_deepseek = types.ModuleType("atom.models.deepseek_v2") + fake_deepseek.Attention = sentinel + monkeypatch.setitem(sys.modules, "atom.model_ops", fake_ops) + monkeypatch.setitem( + sys.modules, "atom.model_ops.base_attention", fake_base_attention + ) + monkeypatch.setitem(sys.modules, "atom.models.deepseek_v2", fake_deepseek) + + apply_attention_mla_rtpllm_patch() + + assert fake_ops.Attention is RTPMLAAttention + assert fake_base_attention.Attention is RTPMLAAttention + assert fake_deepseek.Attention is RTPMLAAttention From a5511850d79df1875e4a730a076a86f21fa7c65b Mon Sep 17 00:00:00 2001 From: Zhao An Date: Fri, 19 Jun 2026 05:57:50 +0000 Subject: [PATCH 15/20] refactor: cleanup GLM5 RTP sparse MLA backend --- .../rtp_sparse_mla_backend.py | 30 +++++-------------- 1 file changed, 7 insertions(+), 23 deletions(-) diff --git a/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py b/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py index 
21cb816ecb..39cbed07d3 100644 --- a/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py +++ b/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py @@ -60,7 +60,7 @@ class _AtomSparseMetadata: class _ContractSparseMlaImpl: - """CPU/mock sparse implementation used before the real RTP kernel is wired.""" + """Lightweight implementation for unit tests and explicit dependency injection.""" def __init__(self, v_head_dim: int) -> None: self.v_head_dim = int(v_head_dim) @@ -1138,24 +1138,6 @@ def _build_atom_sparse_metadata( page_size=page_size, ) - def _run_sparse_decode( - self, - *, - q_latent: torch.Tensor, - kv_cache_base: torch.Tensor, - topk_indices: torch.Tensor, - attn_metadata: Any, - block_size: int, - ) -> torch.Tensor: - # Keep GLM5 sparse path aligned with ATOM native MLA kernels. - return self._run_aiter_sparse_decode( - q_latent=q_latent, - kv_cache_base=kv_cache_base, - topk_indices=topk_indices, - attn_metadata=attn_metadata, - block_size=block_size, - ) - def _run_aiter_sparse_decode( self, *, @@ -1357,7 +1339,7 @@ def forward( raise _SparseUnavailable( "GLM5 RTP sparse MLA requires physical block size." ) - latent_output = self._run_sparse_decode( + latent_output = self._run_aiter_sparse_decode( q_latent=q_latent, kv_cache_base=kv_cache_base, topk_indices=topk_indices, @@ -1387,10 +1369,12 @@ def forward( class RTPSparseMlaBackend: - """M2 sparse top-k consumption contract. + """Sparse MLA backend used by GLM5 RTP plugin mode. - This backend intentionally avoids importing RTP CUDA sparse kernels. It only - validates and threads the sparse contract so M2.5 can replace the mock impl. + Real GLM5 layers use ATOM-owned MLA modules and the AITER sparse decode + kernel. The lightweight implementation is kept for unit tests and explicit + injection only; production paths refuse dense fallback when sparse execution + is unavailable. 
""" def __init__( From d1ec87bf7dbb2b048911268e4c0e2df0937985ca Mon Sep 17 00:00:00 2001 From: Zhao An Date: Fri, 19 Jun 2026 06:35:14 +0000 Subject: [PATCH 16/20] refactor: RTP remove redundant labels --- .../rtpllm/attention_backend/__init__.py | 8 ++++---- .../attention_backend/rtp_mla_attention.py | 4 ++-- .../attention_backend/rtp_mla_metadata.py | 12 +++-------- .../rtp_sparse_mla_backend.py | 20 +++++++++---------- atom/plugin/rtpllm/models/glm5.py | 2 +- ...est_rtpllm_glm5_sparse_backend_contract.py | 6 +++--- .../test_rtpllm_glm5_wrapper_lifecycle.py | 8 ++++---- 7 files changed, 27 insertions(+), 33 deletions(-) diff --git a/atom/plugin/rtpllm/attention_backend/__init__.py b/atom/plugin/rtpllm/attention_backend/__init__.py index 4045e2fe78..adb8dd8cda 100644 --- a/atom/plugin/rtpllm/attention_backend/__init__.py +++ b/atom/plugin/rtpllm/attention_backend/__init__.py @@ -1,7 +1,7 @@ from .rtp_mla_attention import RTPMLAAttention, apply_attention_mla_rtpllm_patch from .rtp_mla_metadata import ( - GLM5_RTP_BRIDGE_MODE, - GLM5_RTP_BRIDGE_MODE_M0_DENSE, + GLM5_RTP_MLA_MODE, + GLM5_RTP_MLA_MODE_DENSE, GLM5_RTP_OWNERSHIP, RTPMlaPluginMetadata, ) @@ -37,8 +37,8 @@ def __getattr__(name): "RTPFullAttention", "RTPMLAAttention", "RTPSparseMlaBackend", - "GLM5_RTP_BRIDGE_MODE", - "GLM5_RTP_BRIDGE_MODE_M0_DENSE", + "GLM5_RTP_MLA_MODE", + "GLM5_RTP_MLA_MODE_DENSE", "GLM5_RTP_OWNERSHIP", "RTPMlaPluginMetadata", "apply_attention_gdn_rtpllm_patch", diff --git a/atom/plugin/rtpllm/attention_backend/rtp_mla_attention.py b/atom/plugin/rtpllm/attention_backend/rtp_mla_attention.py index 0109a1f667..84233dd039 100644 --- a/atom/plugin/rtpllm/attention_backend/rtp_mla_attention.py +++ b/atom/plugin/rtpllm/attention_backend/rtp_mla_attention.py @@ -19,7 +19,7 @@ def _resolve_index_topk(attn) -> int: value = getattr(obj, attr, None) if obj is not None else None if value is not None: return int(value) - raise AttributeError("GLM5 RTP MLA M1 indexer requires index_topk/topk_tokens") 
+ raise AttributeError("GLM5 RTP MLA indexer requires index_topk/topk_tokens") def _get_topk_indices_buffer(attn) -> torch.Tensor: @@ -32,7 +32,7 @@ def _get_topk_indices_buffer(attn) -> torch.Tensor: if buffer is None: buffer = getattr(attn, "_topk_indices_buffer", None) if buffer is None: - raise AttributeError("GLM5 RTP MLA M1 indexer requires topk_indices_buffer") + raise AttributeError("GLM5 RTP MLA indexer requires topk_indices_buffer") return buffer diff --git a/atom/plugin/rtpllm/attention_backend/rtp_mla_metadata.py b/atom/plugin/rtpllm/attention_backend/rtp_mla_metadata.py index bcd1b20c5e..50997bc72e 100644 --- a/atom/plugin/rtpllm/attention_backend/rtp_mla_metadata.py +++ b/atom/plugin/rtpllm/attention_backend/rtp_mla_metadata.py @@ -7,9 +7,8 @@ import torch - -GLM5_RTP_BRIDGE_MODE_M0_DENSE = "m0_dense" -GLM5_RTP_BRIDGE_MODE = GLM5_RTP_BRIDGE_MODE_M0_DENSE +GLM5_RTP_MLA_MODE_DENSE = "dense" +GLM5_RTP_MLA_MODE = GLM5_RTP_MLA_MODE_DENSE GLM5_RTP_OWNERSHIP = { @@ -26,14 +25,9 @@ @dataclass(frozen=True) class RTPMlaPluginMetadata: - """Minimal M0 placeholder for RTP MLA metadata. - - M0 intentionally does not model indexer/top-k metadata. M1/M2 should extend - this structure instead of overloading MHA plugin metadata. 
- """ + """Metadata shared by GLM5 RTP MLA attention paths.""" is_prefill: bool slot_mapping: Optional[torch.Tensor] = None block_table: Optional[torch.Tensor] = None seq_lens: Optional[torch.Tensor] = None - diff --git a/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py b/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py index 39cbed07d3..2daa35399f 100644 --- a/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py +++ b/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py @@ -1,4 +1,4 @@ -"""Contract-executable sparse MLA backend for GLM5 rtp-llm plugin mode.""" +"""Sparse MLA backend for GLM5 rtp-llm plugin mode.""" from __future__ import annotations @@ -21,9 +21,9 @@ def _resolve_plugin_sparse_index_converter(): """Resolve the plugin-style request-local topk to global KV index converter.""" errors: list[str] = [] for module_name in ( - # Old GLM5 RTP branch location. + # Compatibility import path used by earlier plugin layouts. "atom.plugin.attention_mla_sparse", - # Current refactored plugin location with the same call contract. + # Current plugin helper location with the same call signature. 
"atom.plugin.vllm.attention.layer_sparse_mla", ): try: @@ -59,7 +59,7 @@ class _AtomSparseMetadata: page_size: int -class _ContractSparseMlaImpl: +class _LightweightSparseMlaImpl: """Lightweight implementation for unit tests and explicit dependency injection.""" def __init__(self, v_head_dim: int) -> None: @@ -1394,7 +1394,7 @@ def __init__( self.v_head_dim = int(v_head_dim) if sparse_impl is not None: self.sparse_impl = sparse_impl - self._default_mock = False + self._uses_lightweight_impl = False elif mla_modules is not None and all( hasattr(mla_modules, attr) for attr in ( @@ -1410,10 +1410,10 @@ def __init__( v_head_dim=self.v_head_dim, scale=scale, ) - self._default_mock = False + self._uses_lightweight_impl = False else: - self.sparse_impl = _ContractSparseMlaImpl(self.v_head_dim) - self._default_mock = True + self.sparse_impl = _LightweightSparseMlaImpl(self.v_head_dim) + self._uses_lightweight_impl = True self._sparse_impl_accepts_positions = self._impl_accepts_positions( self.sparse_impl ) @@ -1492,13 +1492,13 @@ def forward( return q.new_zeros((q.shape[0], q.shape[1], self.v_head_dim)) if topk_indices is None: - if self._default_mock: + if self._uses_lightweight_impl: return q.new_zeros((q.shape[0], q.shape[1], self.v_head_dim)) raise _SparseUnavailable( "GLM5 RTP sparse MLA requires topk_indices; refusing dense fallback." ) self._validate_topk_indices(q, topk_indices) - if self._default_mock or not callable( + if self._uses_lightweight_impl or not callable( getattr(self.sparse_impl, "forward", None) ): raise _SparseUnavailable( diff --git a/atom/plugin/rtpllm/models/glm5.py b/atom/plugin/rtpllm/models/glm5.py index ce0f9d2642..41c1b86131 100644 --- a/atom/plugin/rtpllm/models/glm5.py +++ b/atom/plugin/rtpllm/models/glm5.py @@ -757,7 +757,7 @@ def _create_python_model(self): getattr(atom_model, "model", None), "atom_config", None ) if atom_config is None: - # M0 tests use mocked ATOM models; real loading must expose atom_config. 
+ # Unit tests may use mocked ATOM models; real loading must expose atom_config. atom_config = getattr(self, "atom_config", None) load_model_in_plugin_mode( diff --git a/tests/plugin/test_rtpllm_glm5_sparse_backend_contract.py b/tests/plugin/test_rtpllm_glm5_sparse_backend_contract.py index 6fdcaae730..4a74ada455 100644 --- a/tests/plugin/test_rtpllm_glm5_sparse_backend_contract.py +++ b/tests/plugin/test_rtpllm_glm5_sparse_backend_contract.py @@ -1,4 +1,4 @@ -"""Contract-executable tests for GLM5 RTP MLA M2 sparse topk consumption.""" +"""Tests for GLM5 RTP MLA sparse topk consumption.""" import builtins import importlib @@ -23,7 +23,7 @@ def _guard_sparse_kernel_imports(monkeypatch): def _guarded_import(name, *args, **kwargs): if any(part in _FORBIDDEN_CUDA_SPARSE_MODULES for part in name.split(".")): raise AssertionError( - f"M2 sparse contract must not import CUDA sparse kernel: {name}" + f"GLM5 RTP sparse tests must not import CUDA sparse kernel: {name}" ) return original_import(name, *args, **kwargs) @@ -302,7 +302,7 @@ def _build_backend(backend_cls, sparse_impl): kwargs["sparse_impl"] = sparse_impl else: raise AssertionError( - "RTPSparseMlaBackend must accept a mock sparse impl injection" + "RTPSparseMlaBackend must accept an injected sparse implementation" ) if "v_head_dim" in params: diff --git a/tests/plugin/test_rtpllm_glm5_wrapper_lifecycle.py b/tests/plugin/test_rtpllm_glm5_wrapper_lifecycle.py index b734cb37d2..3ebb14280c 100644 --- a/tests/plugin/test_rtpllm_glm5_wrapper_lifecycle.py +++ b/tests/plugin/test_rtpllm_glm5_wrapper_lifecycle.py @@ -519,13 +519,13 @@ class _FakeATOMGlm5Moe: ) -def test_glm5_bridge_mode_starts_in_m0_dense(): +def test_glm5_mla_mode_starts_in_dense_mode(): from atom.plugin.rtpllm.attention_backend.rtp_mla_metadata import ( - GLM5_RTP_BRIDGE_MODE, - GLM5_RTP_BRIDGE_MODE_M0_DENSE, + GLM5_RTP_MLA_MODE, + GLM5_RTP_MLA_MODE_DENSE, ) - assert GLM5_RTP_BRIDGE_MODE == GLM5_RTP_BRIDGE_MODE_M0_DENSE + assert GLM5_RTP_MLA_MODE == 
GLM5_RTP_MLA_MODE_DENSE def test_glm5_ownership_unique_and_separates_rope_paths(): From 8a441f9cfdfe428c6253681ccf613da063fa5dba Mon Sep 17 00:00:00 2001 From: Zhao An Date: Fri, 19 Jun 2026 06:57:18 +0000 Subject: [PATCH 17/20] refactor: RTP GLM5 remove redundant code --- .../rtpllm/attention_backend/__init__.py | 10 ----- .../attention_backend/rtp_mla_attention.py | 2 +- .../attention_backend/rtp_mla_metadata.py | 33 -------------- .../test_rtpllm_glm5_wrapper_lifecycle.py | 44 ------------------- 4 files changed, 1 insertion(+), 88 deletions(-) delete mode 100644 atom/plugin/rtpllm/attention_backend/rtp_mla_metadata.py diff --git a/atom/plugin/rtpllm/attention_backend/__init__.py b/atom/plugin/rtpllm/attention_backend/__init__.py index adb8dd8cda..0e7f68318a 100644 --- a/atom/plugin/rtpllm/attention_backend/__init__.py +++ b/atom/plugin/rtpllm/attention_backend/__init__.py @@ -1,10 +1,4 @@ from .rtp_mla_attention import RTPMLAAttention, apply_attention_mla_rtpllm_patch -from .rtp_mla_metadata import ( - GLM5_RTP_MLA_MODE, - GLM5_RTP_MLA_MODE_DENSE, - GLM5_RTP_OWNERSHIP, - RTPMlaPluginMetadata, -) from .rtp_sparse_mla_backend import RTPSparseMlaBackend @@ -37,10 +31,6 @@ def __getattr__(name): "RTPFullAttention", "RTPMLAAttention", "RTPSparseMlaBackend", - "GLM5_RTP_MLA_MODE", - "GLM5_RTP_MLA_MODE_DENSE", - "GLM5_RTP_OWNERSHIP", - "RTPMlaPluginMetadata", "apply_attention_gdn_rtpllm_patch", "apply_attention_mha_rtpllm_patch", "apply_attention_mla_rtpllm_patch", diff --git a/atom/plugin/rtpllm/attention_backend/rtp_mla_attention.py b/atom/plugin/rtpllm/attention_backend/rtp_mla_attention.py index 84233dd039..c6c3857f68 100644 --- a/atom/plugin/rtpllm/attention_backend/rtp_mla_attention.py +++ b/atom/plugin/rtpllm/attention_backend/rtp_mla_attention.py @@ -117,7 +117,7 @@ def __init__(self, *args, **kwargs) -> None: if self.indexer is not None else None ) - injected_backend = kwargs.get("sparse_backend", kwargs.get("dense_backend")) + injected_backend = 
kwargs.get("sparse_backend") if injected_backend is not None: self.sparse_backend = injected_backend elif mla_modules is not None: diff --git a/atom/plugin/rtpllm/attention_backend/rtp_mla_metadata.py b/atom/plugin/rtpllm/attention_backend/rtp_mla_metadata.py deleted file mode 100644 index 50997bc72e..0000000000 --- a/atom/plugin/rtpllm/attention_backend/rtp_mla_metadata.py +++ /dev/null @@ -1,33 +0,0 @@ -"""Metadata and static contracts for GLM5 MLA in rtp-llm plugin mode.""" - -from __future__ import annotations - -from dataclasses import dataclass -from typing import Optional - -import torch - -GLM5_RTP_MLA_MODE_DENSE = "dense" -GLM5_RTP_MLA_MODE = GLM5_RTP_MLA_MODE_DENSE - - -GLM5_RTP_OWNERSHIP = { - "main_q_norm": "DeepseekV2MLAAttention", - "main_kv_norm": "DeepseekV2MLAAttention", - "main_rope": "RTPMLAAttention", - "main_kv_cache": "RTPMLAAttention", - "indexer_k_norm": "Indexer", - "indexer_rope": "Indexer", - "indexer_cache": "Indexer", - "topk_selector": "Indexer", -} - - -@dataclass(frozen=True) -class RTPMlaPluginMetadata: - """Metadata shared by GLM5 RTP MLA attention paths.""" - - is_prefill: bool - slot_mapping: Optional[torch.Tensor] = None - block_table: Optional[torch.Tensor] = None - seq_lens: Optional[torch.Tensor] = None diff --git a/tests/plugin/test_rtpllm_glm5_wrapper_lifecycle.py b/tests/plugin/test_rtpllm_glm5_wrapper_lifecycle.py index 3ebb14280c..d8c37cefad 100644 --- a/tests/plugin/test_rtpllm_glm5_wrapper_lifecycle.py +++ b/tests/plugin/test_rtpllm_glm5_wrapper_lifecycle.py @@ -519,50 +519,6 @@ class _FakeATOMGlm5Moe: ) -def test_glm5_mla_mode_starts_in_dense_mode(): - from atom.plugin.rtpllm.attention_backend.rtp_mla_metadata import ( - GLM5_RTP_MLA_MODE, - GLM5_RTP_MLA_MODE_DENSE, - ) - - assert GLM5_RTP_MLA_MODE == GLM5_RTP_MLA_MODE_DENSE - - -def test_glm5_ownership_unique_and_separates_rope_paths(): - from atom.plugin.rtpllm.attention_backend.rtp_mla_metadata import ( - GLM5_RTP_OWNERSHIP, - ) - - required = { - "main_q_norm", - 
"main_kv_norm", - "main_rope", - "main_kv_cache", - "indexer_k_norm", - "indexer_rope", - "indexer_cache", - "topk_selector", - } - - assert required <= set(GLM5_RTP_OWNERSHIP) - for key in required: - owner = GLM5_RTP_OWNERSHIP[key] - assert isinstance(owner, str) - assert owner - - assert GLM5_RTP_OWNERSHIP["main_rope"] != GLM5_RTP_OWNERSHIP["indexer_rope"] - - -def test_glm5_ownership_forbids_qwen_and_mha_components(): - from atom.plugin.rtpllm.attention_backend.rtp_mla_metadata import ( - GLM5_RTP_OWNERSHIP, - ) - - forbidden = ("GatedDeltaNet", "RTPFullAttention", "Qwen3Next") - for owner in GLM5_RTP_OWNERSHIP.values(): - assert all(name not in owner for name in forbidden) - - def test_mla_attention_legacy_boundary_shape_stays_executable_during_migration(): from atom.plugin.rtpllm.attention_backend.rtp_mla_attention import RTPMLAAttention From 3540e0ca32e8cb478683a16981af3d0f1957bc6e Mon Sep 17 00:00:00 2001 From: Zhao An Date: Fri, 19 Jun 2026 07:56:43 +0000 Subject: [PATCH 18/20] refactor: RTP GLM5 remove mla redundant code --- .../rtp_sparse_mla_backend.py | 10 --- atom/plugin/rtpllm/utils/forward_context.py | 82 ------------------- 2 files changed, 92 deletions(-) diff --git a/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py b/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py index 2daa35399f..263863031e 100644 --- a/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py +++ b/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py @@ -209,16 +209,6 @@ def _kv_token_slot_capacity(kv_cache_base: torch.Tensor) -> int: return 0 return int(kv_cache_base.numel() // latent_dim) - @staticmethod - def _unwrap_linear_output(value: Any) -> torch.Tensor: - if isinstance(value, tuple): - value = value[0] - if not isinstance(value, torch.Tensor): - raise TypeError( - f"Expected kv_b_proj to return Tensor, got {type(value)!r}." 
- ) - return value - def _infer_num_heads(self, q: torch.Tensor) -> int: num_heads = int(q.shape[1]) if self.num_heads != num_heads: diff --git a/atom/plugin/rtpllm/utils/forward_context.py b/atom/plugin/rtpllm/utils/forward_context.py index 0eca3270f5..03071fbcfd 100644 --- a/atom/plugin/rtpllm/utils/forward_context.py +++ b/atom/plugin/rtpllm/utils/forward_context.py @@ -1,7 +1,5 @@ from __future__ import annotations -import os -import warnings from contextlib import contextmanager from dataclasses import dataclass from typing import Any, Dict, Iterator, Tuple @@ -365,22 +363,6 @@ def _select_block_table_for_layer( return by_group[gid] return getattr(attn_inputs, "kv_cache_kernel_block_id_device", None) - @staticmethod - def _select_physical_block_table_for_layer( - attn_inputs: Any, - group_id: int | None = None, - ) -> torch.Tensor | None: - # MLA cache writes use concat_and_cache_mla(slot_mapping), whose slot is - # indexed in the physical KV cache layout, not the smaller kernel block - # granularity used by some RTP attention kernels. 
- block_table = getattr(attn_inputs, "kv_cache_block_id_device", None) - if block_table is not None: - return block_table - return RTPForwardContext._select_block_table_for_layer( - attn_inputs=attn_inputs, - group_id=group_id, - ) - @staticmethod def _recover_physical_block_table_from_kernel( kernel_block_table: torch.Tensor, @@ -1638,70 +1620,6 @@ def build( mla_layer_map=cls._resolve_mla_layer_map(resolved_layer_maps), ) - @staticmethod - def _use_rtp_indexer_cache() -> bool: - return os.getenv("ATOM_RTP_USE_RTP_INDEXER_CACHE", "0").strip().lower() in { - "1", - "true", - "yes", - "on", - } - - @staticmethod - def _resolve_rtp_indexer_cache( - *, - layer_num: int, - layer_cache: Any, - indexer: Any, - block_size: int, - ) -> torch.Tensor: - kv_scale_base = getattr(layer_cache, "kv_scale_base", None) - if kv_scale_base is None or kv_scale_base.numel() == 0: - raise ValueError( - f"Layer {layer_num} RTP indexer cache requires non-empty kv_scale_base." - ) - if kv_scale_base.dtype == torch.uint8: - kv_scale_base = kv_scale_base.view(dtypes.fp8) - if kv_scale_base.dtype != dtypes.fp8: - raise ValueError( - f"Layer {layer_num} RTP indexer cache dtype mismatch " - f"(got={kv_scale_base.dtype}, expected={dtypes.fp8} or torch.uint8)." - ) - if block_size <= 0: - raise ValueError( - f"Layer {layer_num} RTP indexer cache got invalid block_size={block_size}." - ) - head_dim = int(getattr(indexer, "head_dim", 0) or 0) - if head_dim <= 0: - raise ValueError( - f"Layer {layer_num} RTP indexer cache requires positive indexer.head_dim." 
- ) - if head_dim != 128: - warnings.warn( - "RTP indexer cache binding has only been layout-checked for " - "GLM5 head_dim=128; cross-kernel byte semantics are not verified " - f"for head_dim={head_dim}.", - RuntimeWarning, - stacklevel=2, - ) - expected_raw_dim = head_dim + (head_dim // 128) * 4 - expected_aligned_dim = ((expected_raw_dim + 15) // 16) * 16 - allowed_dims = {expected_raw_dim, expected_aligned_dim} - - if kv_scale_base.dim() == 3 and int(kv_scale_base.shape[-1]) in allowed_dims: - return kv_scale_base - if kv_scale_base.dim() == 2 and int(kv_scale_base.shape[1]) % block_size == 0: - per_token_dim = int(kv_scale_base.shape[1]) // block_size - if per_token_dim in allowed_dims: - return kv_scale_base.view( - kv_scale_base.shape[0], block_size, per_token_dim - ) - raise ValueError( - f"Layer {layer_num} RTP indexer cache layout mismatch " - f"(shape={tuple(kv_scale_base.shape)}, block_size={block_size}, " - f"allowed_last_dims={sorted(allowed_dims)})." - ) - @classmethod def _resolve_mla_layer_map(cls, layer_maps: LayerMaps) -> Dict[int, Any]: del cls, layer_maps From 0a3d321d6b992f01075220b4bc2ef55ec583658e Mon Sep 17 00:00:00 2001 From: Zhao An Date: Fri, 19 Jun 2026 09:30:22 +0000 Subject: [PATCH 19/20] fix: RTP Qwen35 use prewarmed req id buffer for RTP CUDA graphs --- atom/plugin/rtpllm/models/qwen3_5.py | 1 + atom/plugin/rtpllm/utils/forward_context.py | 30 +++++++++++--- .../test_rtpllm_forward_context_semantics.py | 39 +++++++++++++++++++ 3 files changed, 65 insertions(+), 5 deletions(-) diff --git a/atom/plugin/rtpllm/models/qwen3_5.py b/atom/plugin/rtpllm/models/qwen3_5.py index e45dc7717f..dbe21cc903 100644 --- a/atom/plugin/rtpllm/models/qwen3_5.py +++ b/atom/plugin/rtpllm/models/qwen3_5.py @@ -396,6 +396,7 @@ def _ensure_cuda_graph_prewarmed(self) -> None: 0, max_bs + 1, device=device, dtype=torch.int32 ), "seq_id": torch.arange(0, max_bs, device=device, dtype=torch.int64), + "seq_id_i32": torch.arange(0, max_bs, device=device, 
dtype=torch.int32), "block_col": torch.empty(max_bs, device=device, dtype=torch.int32), "block_col_i64": torch.empty(max_bs, device=device, dtype=torch.int64), "slot_base": torch.empty(max_bs, device=device, dtype=torch.int32), diff --git a/atom/plugin/rtpllm/utils/forward_context.py b/atom/plugin/rtpllm/utils/forward_context.py index 03071fbcfd..0e536ace82 100644 --- a/atom/plugin/rtpllm/utils/forward_context.py +++ b/atom/plugin/rtpllm/utils/forward_context.py @@ -900,13 +900,33 @@ def _build_req_id_per_token( raise ValueError( "RTP plugin cannot build req_id_per_token for empty batch." ) + in_capture = torch.cuda.is_current_stream_capturing() + if cg_bufs is not None and "seq_id_i32" in cg_bufs: + seq_id_i32 = cg_bufs["seq_id_i32"] + if not isinstance(seq_id_i32, torch.Tensor): + raise RuntimeError( + "RTP plugin capture requires prewarmed seq_id_i32 tensor." + ) + if int(seq_id_i32.shape[0]) < int(num_tokens): + raise RuntimeError( + "RTP plugin prewarmed seq_id_i32 buffer is too small " + f"(buffer={int(seq_id_i32.shape[0])}, required={int(num_tokens)})." + ) + if seq_id_i32.device != device or seq_id_i32.dtype != torch.int32: + raise RuntimeError( + "RTP plugin capture requires seq_id_i32 to be int32 on model device." + ) + if not seq_id_i32.is_contiguous(): + raise RuntimeError( + "RTP plugin capture requires seq_id_i32 to be contiguous." + ) + return seq_id_i32[:num_tokens] + if in_capture: + raise RuntimeError( + "RTP plugin capture requires prewarmed seq_id_i32 for req_id_per_token." 
+ ) if int(num_tokens) == 0: return torch.empty((0,), dtype=torch.int32, device=device) - if cg_bufs is not None and "seq_id" in cg_bufs: - seq_id = cg_bufs["seq_id"][:num_tokens] - return seq_id.to( - device=device, dtype=torch.int32, non_blocking=True - ).contiguous() lengths = (query_start_loc[1:] - query_start_loc[:-1]).to(dtype=torch.int64) if not torch.cuda.is_current_stream_capturing() and int( lengths.sum().item() diff --git a/tests/plugin/test_rtpllm_forward_context_semantics.py b/tests/plugin/test_rtpllm_forward_context_semantics.py index 5f1761d8cc..e316879ee1 100644 --- a/tests/plugin/test_rtpllm_forward_context_semantics.py +++ b/tests/plugin/test_rtpllm_forward_context_semantics.py @@ -307,6 +307,45 @@ def test_plugin_attention_metadata_builds_req_id_per_token(): assert md.total_kv == 3 +def test_build_req_id_per_token_prefers_prewarmed_i32_buffer(monkeypatch): + query_start_loc = torch.tensor([0, 1, 2, 3], dtype=torch.int32) + seq_id_i32 = torch.arange(8, dtype=torch.int32) + + monkeypatch.setattr(torch.cuda, "is_current_stream_capturing", lambda: True) + + req_id = RTPForwardContext._build_req_id_per_token( + query_start_loc=query_start_loc, + num_tokens=3, + device=query_start_loc.device, + cg_bufs={ + "seq_id": torch.arange(8, dtype=torch.int64), + "seq_id_i32": seq_id_i32, + }, + ) + + assert req_id.dtype == torch.int32 + assert req_id.data_ptr() == seq_id_i32.data_ptr() + assert req_id.cpu().tolist() == [0, 1, 2] + + +def test_build_req_id_per_token_requires_prewarmed_i32_buffer_in_capture(monkeypatch): + query_start_loc = torch.tensor([0, 1], dtype=torch.int32) + + monkeypatch.setattr(torch.cuda, "is_current_stream_capturing", lambda: True) + + try: + RTPForwardContext._build_req_id_per_token( + query_start_loc=query_start_loc, + num_tokens=1, + device=query_start_loc.device, + cg_bufs={"seq_id": torch.arange(1, dtype=torch.int64)}, + ) + except RuntimeError as exc: + assert "prewarmed seq_id_i32" in str(exc) + else: + raise 
AssertionError("expected missing seq_id_i32 to fail during capture") + + def test_rtpllm_decode_seq_lens_uses_rtp_plus_one_in_graph_and_eager_modes(): input_lengths = torch.tensor([1], dtype=torch.int32) sequence_lengths = torch.tensor([35], dtype=torch.int32) From 7c6380b101fc8d4f98eb9c37020827c9b87578a0 Mon Sep 17 00:00:00 2001 From: Zhao An Date: Mon, 22 Jun 2026 09:26:31 +0000 Subject: [PATCH 20/20] fix: RTP remove redundant qwen35 code --- atom/plugin/rtpllm/models/qwen3_5.py | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/atom/plugin/rtpllm/models/qwen3_5.py b/atom/plugin/rtpllm/models/qwen3_5.py index dbe21cc903..0f1acd44a1 100644 --- a/atom/plugin/rtpllm/models/qwen3_5.py +++ b/atom/plugin/rtpllm/models/qwen3_5.py @@ -378,17 +378,9 @@ def _ensure_cuda_graph_prewarmed(self) -> None: or int(getattr(kv_cache, "seq_size_per_block", 0)) or 1 ) - seq_size_per_block = ( - int(getattr(kv_cache, "seq_size_per_block", 0)) - or kernel_seq_size_per_block - or 1 - ) max_blocks = ( int(max_seq_len) + kernel_seq_size_per_block - 1 ) // kernel_seq_size_per_block + 1 - physical_max_blocks = ( - int(max_seq_len) + seq_size_per_block - 1 - ) // seq_size_per_block # query_start_loc for decode: always [0, 1, 2, ..., bs], i.e. arange(bs+1). # seq_id for decode slot_mapping: seq_id[i] == i, i.e. arange(bs). 
self._cg_meta_bufs: dict = { @@ -396,7 +388,6 @@ def _ensure_cuda_graph_prewarmed(self) -> None: 0, max_bs + 1, device=device, dtype=torch.int32 ), "seq_id": torch.arange(0, max_bs, device=device, dtype=torch.int64), - "seq_id_i32": torch.arange(0, max_bs, device=device, dtype=torch.int32), "block_col": torch.empty(max_bs, device=device, dtype=torch.int32), "block_col_i64": torch.empty(max_bs, device=device, dtype=torch.int64), "slot_base": torch.empty(max_bs, device=device, dtype=torch.int32), @@ -406,16 +397,12 @@ def _ensure_cuda_graph_prewarmed(self) -> None: "block_table_i32": torch.empty( max_bs, max_blocks, device=device, dtype=torch.int32 ), - "physical_block_table_i32": torch.empty( - max_bs, max(physical_max_blocks, 1), device=device, dtype=torch.int32 - ), } self._cg_layers_prewarmed = True logger.info( "ATOM RTPFullAttention cuda-graph prewarmed for %d layers " "(max_num_tokens=%d, max_seq_len=%d, rtp_kv_heads=%s, " - "meta_bufs: query_start_loc[%d], slot_mapping[%d], block_table_i32[%dx%d], " - "physical_block_table_i32[%dx%d])", + "meta_bufs: query_start_loc[%d], slot_mapping[%d], block_table_i32[%dx%d])", len(self._atom_attn_pyobj._rtp_full_attn_layers), max_num_tokens, max_seq_len, @@ -424,8 +411,6 @@ def _ensure_cuda_graph_prewarmed(self) -> None: max_bs, max_bs, max_blocks, - max_bs, - max(physical_max_blocks, 1), ) def forward(self, inputs: PyModelInputs, fmha_impl: Any = None) -> PyModelOutputs: