Skip to content
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
8d2791e
feat: RTPLLM plugin GLM5 integration
zhaoan12-prc May 20, 2026
bf1c92f
feat: RTPLLM GLM5 enable cuda graph
zhaoan12-prc Jun 2, 2026
a49dc53
fix: RTP glm5 qwen35 cuda graph conflict
zhaoan12-prc Jun 4, 2026
d9cbe9d
fix: RTP crash when long input_len > 16384
zhaoan12-prc Jun 5, 2026
8281090
fix:[RTP] making GLM5 run true Sparse MLA
zhaoan12-prc Jun 5, 2026
8b92a5c
refactor: RTP glm5 code
zhaoan12-prc Jun 5, 2026
27be06e
feat: RTP glm5 optimize sparse decode path
zhaoan12-prc Jun 5, 2026
d6afeda
refactor: RTP remove redundant envs
zhaoan12-prc Jun 5, 2026
0afe687
refactor: [RTP] unify GLM5 MLA on sparse path, drop dead dense backend
zhaoan12-prc Jun 8, 2026
b4997d6
fix: RTP GLM5 prefil reuse Sparse MLA metadata
zhaoan12-prc Jun 13, 2026
d208756
fix: RTP GLM5 enable FP8 MLA path
zhaoan12-prc Jun 15, 2026
48089f9
feat: RTP GLM5 conflict issue after rebase
zhaoan12-prc Jun 17, 2026
d31dbb0
fix: RTP plugin imports conflict after rebase main
zhaoan12-prc Jun 18, 2026
21b8465
refactor: RTP GLM5 tests merge
zhaoan12-prc Jun 18, 2026
a551185
refactor: cleanup GLM5 RTP sparse MLA backend
zhaoan12-prc Jun 19, 2026
d1ec87b
refactor: RTP remove redundant labels
zhaoan12-prc Jun 19, 2026
8a441f9
refactor: RTP GLM5 remove redundant code
zhaoan12-prc Jun 19, 2026
3540e0c
refactor: RTP GLM5 remove mla redundant code
zhaoan12-prc Jun 19, 2026
0a3d321
fix: RTP Qwen35 use prewarmed req id buffer for RTP CUDA graphs
zhaoan12-prc Jun 19, 2026
7c6380b
fix: RTP remove redundant qwen35 code
zhaoan12-prc Jun 22, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion atom/plugin/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def _generate_atom_config_from_rtpllm_config(config: Any):

return Config(
model=rtpllm_model_config.ckpt_path,
max_num_batched_tokens=max(16384, max_generate_batch_size),
max_num_batched_tokens=max(max_model_len, max_generate_batch_size),
max_num_seqs=max_generate_batch_size,
max_model_len=max_model_len,
gpu_memory_utilization=0.9,
Expand Down
8 changes: 8 additions & 0 deletions atom/plugin/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def _prepare_model_atom_sglang(
def _prepare_model_atom_rtpllm(
config: Any,
atom_config: Any,
model_arch: str,
model_cls: Any,
set_attn_cls: Any,
init_aiter_dist: Any,
Expand All @@ -120,6 +121,12 @@ def _prepare_model_atom_rtpllm(
)

set_attn_cls()
if model_arch == "GlmMoeDsaForCausalLM":
from atom.plugin.rtpllm.attention_backend import (
apply_attention_mla_rtpllm_patch,
)

apply_attention_mla_rtpllm_patch()

# init aiter dist for using aiter custom collective ops
init_aiter_dist(config=atom_config)
Expand Down Expand Up @@ -172,6 +179,7 @@ def prepare_model(config: Any, engine: str):
return _prepare_model_atom_rtpllm(
config,
atom_config,
model_arch,
model_cls,
set_attn_cls,
init_aiter_dist,
Expand Down
7 changes: 7 additions & 0 deletions atom/plugin/rtpllm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""RTP-LLM plugin helpers.

Keep the package root import side-effect free. RTP external model registration
is triggered by importing ``atom.plugin.rtpllm.models``.
"""

__all__: list[str] = []
33 changes: 30 additions & 3 deletions atom/plugin/rtpllm/attention_backend/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,37 @@
from .attention_gdn import apply_attention_gdn_rtpllm_patch
from .attention_switch import apply_attention_mha_rtpllm_patch
from .rtp_full_attention import AttentionForRTPLLM, RTPFullAttention
from .rtp_mla_attention import RTPMLAAttention, apply_attention_mla_rtpllm_patch
from .rtp_sparse_mla_backend import RTPSparseMlaBackend


def __getattr__(name):
    """Resolve selected package attributes lazily, on first access.

    Each supported name imports the submodule that defines it and returns
    the imported object. ``RTPAttention`` resolves to the same object as
    ``RTPFullAttention``.

    Raises:
        AttributeError: If ``name`` is not one of the lazily resolved names.
    """
    if name == "AttentionForRTPLLM":
        from .rtp_full_attention import AttentionForRTPLLM as resolved
    elif name in ("RTPFullAttention", "RTPAttention"):
        from .rtp_full_attention import RTPFullAttention as resolved
    elif name == "apply_attention_gdn_rtpllm_patch":
        from .attention_gdn import apply_attention_gdn_rtpllm_patch as resolved
    elif name == "apply_attention_mha_rtpllm_patch":
        from .attention_switch import apply_attention_mha_rtpllm_patch as resolved
    else:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    return resolved


__all__ = [
"AttentionForRTPLLM",
"RTPFullAttention",
"RTPMLAAttention",
"RTPSparseMlaBackend",
"apply_attention_gdn_rtpllm_patch",
"apply_attention_mha_rtpllm_patch",
"apply_attention_mla_rtpllm_patch",
]
253 changes: 253 additions & 0 deletions atom/plugin/rtpllm/attention_backend/rtp_mla_attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,253 @@
"""RTP-style MLA adapter for GLM5 rtp-llm plugin mode."""

from __future__ import annotations

import inspect
from types import MethodType
from typing import Optional

import torch


def _resolve_index_topk(attn) -> int:
    """Return the top-k size configured for ``attn`` as an ``int``.

    The first non-``None`` value wins, searched in this order:
    ``attn.indexer.index_topk``, ``attn.indexer.topk_tokens``,
    ``attn.index_topk``, ``attn.config.index_topk``.

    Raises:
        AttributeError: If none of those locations provides a value.
    """
    indexer = getattr(attn, "indexer", None)
    config = getattr(attn, "config", None)
    candidates = (
        (indexer, "index_topk"),
        (indexer, "topk_tokens"),
        (attn, "index_topk"),
        (config, "index_topk"),
    )
    for holder, field in candidates:
        if holder is None:
            continue
        found = getattr(holder, field, None)
        if found is not None:
            return int(found)
    raise AttributeError("GLM5 RTP MLA indexer requires index_topk/topk_tokens")


def _get_topk_indices_buffer(attn) -> torch.Tensor:
    """Locate the top-k indices buffer associated with ``attn``.

    The first non-``None`` value wins, searched in this order:
    ``attn.indexer.topk_indices_buffer``, ``attn.topk_indices_buffer``,
    ``attn._topk_indices_buffer``.

    Raises:
        AttributeError: If no buffer is found in any of those locations.
    """
    lookup_order = (
        (getattr(attn, "indexer", None), "topk_indices_buffer"),
        (attn, "topk_indices_buffer"),
        (attn, "_topk_indices_buffer"),
    )
    for owner, attr_name in lookup_order:
        # getattr on a missing owner (None) simply yields the None default.
        candidate = getattr(owner, attr_name, None)
        if candidate is not None:
            return candidate
    raise AttributeError("GLM5 RTP MLA indexer requires topk_indices_buffer")


def _should_emit_topk_indices(attn) -> bool:
    """Tell whether top-k indices should be produced for the current forward.

    Returns ``False`` only when the active forward context exposes a
    ``context`` whose ``is_dummy_run`` is truthy. If the forward context
    cannot be imported or fetched for any reason, the answer defaults to
    ``True``. ``attn`` is accepted but not consulted.
    """
    try:
        from atom.utils.forward_context import get_forward_context

        active = get_forward_context()
    except Exception:
        # Best effort: without a forward context, behave as a regular run.
        return True

    run_context = getattr(active, "context", None)
    return not getattr(run_context, "is_dummy_run", False)


def _use_rtp_sparse_attn_indexer(indexer: object | None) -> None:
    """Point ``indexer`` at the RTP sparse-attention op and guard its buffer.

    Does nothing when ``indexer`` is ``None`` or has no
    ``sparse_attn_indexer_impl`` attribute. Otherwise it:

    * replaces ``indexer.sparse_attn_indexer_impl`` with
      ``torch.ops.aiter.rtp_sparse_attn_indexer``;
    * once per indexer, and only if it has a ``forward``, wraps that
      ``forward`` so a suitable ``topk_indices_buffer`` exists before the
      original ``forward`` runs.
    """
    if indexer is None:
        return
    if not hasattr(indexer, "sparse_attn_indexer_impl"):
        return

    # Imported before the op lookup below; presumably the module registers
    # torch.ops.aiter.rtp_sparse_attn_indexer as a side effect -- confirm.
    __import__("atom.plugin.rtpllm.attention_backend.rtp_sparse_mla_backend")
    indexer.sparse_attn_indexer_impl = torch.ops.aiter.rtp_sparse_attn_indexer

    already_patched = getattr(indexer, "_atom_rtp_topk_buffer_patched", False)
    if already_patched or not hasattr(indexer, "forward"):
        return

    wrapped_forward = indexer.forward

    def _forward_with_topk_buffer(self, hidden_states, *args, **kwargs):
        rows_needed = int(hidden_states.shape[0])

        cols_needed = getattr(self, "topk_tokens", None)
        if cols_needed is None:
            cols_needed = getattr(self, "index_topk")
        cols_needed = int(cols_needed)

        current = getattr(self, "topk_indices_buffer", None)
        # The existing buffer is reused only if it is 2-D, lives on the same
        # device as the input, and is at least rows_needed x cols_needed.
        if current is None or current.dim() != 2:
            reusable = False
        elif current.device != hidden_states.device:
            reusable = False
        else:
            reusable = (
                int(current.shape[0]) >= rows_needed
                and int(current.shape[1]) >= cols_needed
            )

        if not reusable:
            self.topk_indices_buffer = torch.empty(
                rows_needed,
                cols_needed,
                dtype=torch.int32,
                device=hidden_states.device,
            )
        # Keep the sparse-KV alias pointing at whatever buffer is current.
        self.sparse_kv_indices_buffer = self.topk_indices_buffer
        return wrapped_forward(hidden_states, *args, **kwargs)

    indexer.forward = MethodType(_forward_with_topk_buffer, indexer)
    indexer._atom_rtp_topk_buffer_patched = True


class RTPMLAAttention:
    """RTP MLA adapter for the native GLM5 MLA call contract.

    The adapter copies projection layers and dimension metadata from the
    ``mla_modules`` keyword argument, optionally projects the query, picks the
    top-k indices to use, and delegates the attention computation to a sparse
    backend object.
    """

    use_mla = True

    def __init__(self, *args, **kwargs) -> None:
        """Capture constructor arguments and wire up the sparse backend.

        Keyword arguments read here: ``mla_modules``, ``sparse_backend``,
        ``scale``, ``kv_cache``, ``layer_id`` (falling back to ``layer_num``,
        then ``0``). All positional and keyword arguments are also stored
        verbatim on ``self.args`` / ``self.kwargs``.
        """
        self.args = args
        self.kwargs = kwargs

        modules = kwargs.get("mla_modules")
        self.mla_modules = modules

        # Projection layers and the indexer are optional; absent ones are None.
        self.q_proj = getattr(modules, "q_proj", None)
        self.o_proj = getattr(modules, "o_proj", None)
        self.kv_b_proj = getattr(modules, "kv_b_proj", None)
        self.indexer = getattr(modules, "indexer", None)
        _use_rtp_sparse_attn_indexer(self.indexer)

        self.qk_head_dim = getattr(modules, "qk_head_dim", None)
        self.v_head_dim = getattr(modules, "v_head_dim", None)
        self.q_lora_rank = getattr(modules, "q_lora_rank", None)
        self.kv_lora_rank = getattr(modules, "kv_lora_rank", None)
        self.num_heads = getattr(modules, "num_heads", None)
        self.num_local_heads = getattr(modules, "num_local_heads", self.num_heads)
        self.index_topk = getattr(modules, "index_topk", None)
        # getattr on a missing indexer (None) yields the None default.
        self.topk_indices_buffer = getattr(self.indexer, "topk_indices_buffer", None)

        # An injected backend wins; otherwise build one when mla_modules exist.
        backend = kwargs.get("sparse_backend")
        if backend is None and modules is not None:
            from atom.plugin.rtpllm.attention_backend.rtp_sparse_mla_backend import (
                RTPSparseMlaBackend,
            )

            backend = RTPSparseMlaBackend(
                v_head_dim=modules.v_head_dim,
                mla_modules=modules,
                scale=kwargs.get("scale"),
            )
        self.sparse_backend = backend

        self.kv_cache = kwargs.get("kv_cache")
        fallback_layer = kwargs.get("layer_num", 0)
        self.layer_id = int(kwargs.get("layer_id", fallback_layer))

        # Decided once so forward() does not re-inspect the backend per call.
        self._sparse_backend_accepts_positions = (
            self.sparse_backend is not None
            and self._backend_accepts_positions(self.sparse_backend)
        )

    @staticmethod
    def _backend_accepts_positions(backend: object) -> bool:
        """Report whether ``backend.forward`` can take a ``positions`` keyword.

        True when the signature names a ``positions`` parameter or has a
        ``**kwargs`` catch-all; False when the signature cannot be inspected.
        """
        try:
            params = inspect.signature(backend.forward).parameters
        except (AttributeError, TypeError, ValueError):
            return False
        if "positions" in params:
            return True
        return any(
            param.kind == inspect.Parameter.VAR_KEYWORD for param in params.values()
        )

    def _project_query(
        self, query: torch.Tensor, q_scale: Optional[torch.Tensor]
    ) -> tuple[torch.Tensor, bool]:
        """Return ``(query_for_backend, was_projected_here)``.

        A 3-D ``query``, or a missing ``q_proj``, is passed through untouched.
        Otherwise ``q_proj(query, q_scale)`` is applied and, if the result is
        not already 3-D, reshaped to ``(-1, num_heads, qk_head_dim)``. A
        missing head count or head dim is derived from the projected width.

        Raises:
            AttributeError: If neither a head count nor ``qk_head_dim`` is known.
        """
        if query.ndim == 3 or self.q_proj is None:
            return query, False

        projected = self.q_proj(query, q_scale)
        if projected.ndim == 3:
            return projected, True

        if self.num_local_heads is not None:
            head_count = self.num_local_heads
        else:
            head_count = self.num_heads

        if head_count is None:
            if self.qk_head_dim is None:
                raise AttributeError("GLM5 RTP MLA native contract requires num_heads")
            head_count = projected.shape[-1] // int(self.qk_head_dim)
        if self.qk_head_dim is None:
            # Cached on the instance so later calls reuse the derived value.
            self.qk_head_dim = projected.shape[-1] // int(head_count)

        reshaped = projected.reshape(-1, int(head_count), int(self.qk_head_dim))
        return reshaped, True

    def _resolve_topk_indices(
        self,
        query: torch.Tensor,
        q_scale: Optional[torch.Tensor],
        positions: Optional[torch.Tensor],
        explicit_topk_indices: Optional[torch.Tensor],
    ) -> Optional[torch.Tensor]:
        """Choose the top-k indices handed to the sparse backend.

        Explicit indices are returned as given. Without an indexer, or when
        ``_should_emit_topk_indices`` says no, the result is ``None``.
        Otherwise a ``[:query.shape[0], :index_topk]`` slice of the top-k
        buffer is returned. ``q_scale`` and ``positions`` are not used here.
        """
        if explicit_topk_indices is not None:
            return explicit_topk_indices
        if self.indexer is None or not _should_emit_topk_indices(self):
            return None

        width = _resolve_index_topk(self)
        buffer = _get_topk_indices_buffer(self)
        return buffer[: query.shape[0], :width]

    def forward(
        self,
        query: torch.Tensor,
        compressed_kv: torch.Tensor,
        k_pe: torch.Tensor,
        positions: Optional[torch.Tensor] = None,
        q_scale: Optional[torch.Tensor] = None,
        topk_indices: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Run attention through the sparse backend.

        The query is projected if needed, top-k indices are resolved, and the
        backend's ``forward`` is called with the stored KV cache and layer id.
        ``positions`` is forwarded only if the backend accepts it. When the
        query was projected here and ``o_proj`` exists, the backend output is
        flattened per token and passed through ``o_proj``.

        Raises:
            NotImplementedError: If no sparse backend is configured.
        """
        if self.sparse_backend is None:
            raise NotImplementedError(
                "RTPMLAAttention requires an attention backend for contract execution"
            )

        q, native_projected = self._project_query(query, q_scale)

        # topk_indices is a named parameter, so this lookup falls back to it.
        requested_topk = kwargs.get("topk_indices", topk_indices)
        topk_indices = self._resolve_topk_indices(
            query, q_scale, positions, requested_topk
        )

        backend_kwargs = {"topk_indices": topk_indices}
        if self._sparse_backend_accepts_positions:
            backend_kwargs["positions"] = positions

        attn_output = self.sparse_backend.forward(
            q,
            compressed_kv,
            k_pe,
            self.kv_cache,
            self.layer_id,
            **backend_kwargs,
        )

        if not native_projected or self.o_proj is None:
            return attn_output
        flattened = attn_output.reshape(attn_output.shape[0], -1).contiguous()
        return self.o_proj(flattened)

    __call__ = forward


def apply_attention_mla_rtpllm_patch() -> None:
    """Switch ATOM's generic Attention symbol to the RTP MLA adapter.

    Rebinds ``Attention`` on ``atom.model_ops`` and
    ``atom.model_ops.base_attention`` to ``RTPMLAAttention``, and also exposes
    the class as ``atom.model_ops.RTPMLAAttention``. The same rebinding is
    applied to ``atom.models.deepseek_v2`` when that module is already loaded
    or can be imported; an import failure there is silently ignored.
    """
    import importlib
    import sys

    model_ops = importlib.import_module("atom.model_ops")
    base_attention_module = importlib.import_module("atom.model_ops.base_attention")

    model_ops.RTPMLAAttention = RTPMLAAttention
    for patched_module in (model_ops, base_attention_module):
        patched_module.Attention = RTPMLAAttention

    # Reuse the module if it is already loaded; otherwise try to import it.
    deepseek_v2 = sys.modules.get("atom.models.deepseek_v2")
    if deepseek_v2 is None:
        try:
            import atom.models.deepseek_v2 as deepseek_v2
        except ImportError:  # also covers ModuleNotFoundError (a subclass)
            return
    deepseek_v2.Attention = RTPMLAAttention
Loading
Loading