Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 170 additions & 1 deletion src/prime_rl/inference/patches.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
import inspect
import re

_LAYER_INDEX_RE = re.compile(r"\.(\d+)(?=\.|$)")


def transformers_v5_compat():
"""vLLM general plugin: patch transformers v5 config attrs that vLLM 0.16 still expects.

Expand All @@ -9,10 +15,172 @@
if not hasattr(Qwen3VLMoeTextConfig, "tie_word_embeddings"):
Qwen3VLMoeTextConfig.tie_word_embeddings = False

monkey_patch_indexcache()
_patch_qwen35_lora()
monkey_patch_dp_engine_core_pause_resume_deadlock()


def _skip_topk_from_prefix(config, prefix):
if not hasattr(config, "index_topk"):
return False

num_hidden_layers = getattr(config, "num_hidden_layers", None)
if num_hidden_layers is None:
return False

layer_idx = int(_LAYER_INDEX_RE.findall(prefix)[-1])
if layer_idx >= num_hidden_layers:
return False

return layer_idx % int(getattr(config, "index_topk_freq", 1)) != 0


def monkey_patch_indexcache():
    """Monkey-patch vLLM's DeepSeek-V2 MLA stack so that sparse top-k indexer
    output can be cached and reused across decoder layers ("IndexCache").

    Layers whose ``skip_topk`` flag is set (see ``_skip_topk_from_prefix``)
    do not run their own indexer; instead they consume the top-k indices
    produced by the most recent indexer-running layer, which are threaded
    between layers through a ``ContextVar``. Idempotent: a second call is a
    no-op.
    """
    # Idempotency guard — applying the patches twice would wrap the already
    # patched functions and recurse.
    if getattr(monkey_patch_indexcache, "_patched", False):
        return

    from contextvars import ContextVar

    import torch
    from vllm.model_executor.layers.mla import MultiHeadLatentAttentionWrapper
    from vllm.model_executor.models.deepseek_v2 import (
        DeepseekAttention,
        DeepseekV2DecoderLayer,
        DeepseekV2MLAAttention,
        DeepseekV2MLP,
        DeepseekV2Model,
    )

    # Carries the previous layer's top-k indices to the next layer within one
    # forward pass. A ContextVar keeps concurrent forward passes isolated.
    topk_context = ContextVar("prime_rl_indexcache_topk", default=None)

    # Keep references to the originals so the patched versions can delegate.
    _original_mla_forward = MultiHeadLatentAttentionWrapper.forward
    _original_attn_init = DeepseekV2MLAAttention.__init__
    _original_model_forward = DeepseekV2Model.forward
    _attn_init_signature = inspect.signature(_original_attn_init)

    def _patched_mla_forward(self, positions, hidden_states, llama_4_scaling=None, prev_topk_indices=None):
        """Sparse MLA forward that returns ``(output, topk_indices)``.

        Falls back to the original forward (and its original return shape)
        when the wrapper has no sparse indexer.
        """
        if not self.indexer or not self.is_sparse:
            return _original_mla_forward(self, positions, hidden_states, llama_4_scaling)

        q_c = None
        kv_lora = None

        # Two projection layouts: fused low-rank q path vs. separate q_proj.
        if self.q_lora_rank is not None:
            assert self.fused_qkv_a_proj is not None, "fused_qkv_a_proj is required when q_lora_rank is not None"
            assert self.q_a_layernorm is not None, "q_a_layernorm is required when q_lora_rank is not None"
            assert self.q_b_proj is not None, "q_b_proj is required when q_lora_rank is not None"

            qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
            q_c, kv_lora = qkv_lora.split(
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
                dim=-1,
            )
            q_c = self.q_a_layernorm(q_c)
            q = self.q_b_proj(q_c)[0]
        else:
            assert self.kv_a_proj_with_mqa is not None, "kv_a_proj_with_mqa is required when q_lora_rank is None"
            assert self.q_proj is not None, "q_proj is required when q_lora_rank is None"
            kv_lora = self.kv_a_proj_with_mqa(hidden_states)[0]
            q = self.q_proj(hidden_states)[0]

        kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        kv_c_normed = self.kv_a_layernorm(kv_c)

        q = q.view(-1, self.num_heads, self.qk_head_dim)
        k_pe = k_pe.unsqueeze(1)

        # Rotary embedding applies only to the rope part of q (past the
        # nope-dim split) and to k_pe.
        if self.rotary_emb is not None:
            q[..., self.qk_nope_head_dim :], k_pe = self.rotary_emb(positions, q[..., self.qk_nope_head_dim :], k_pe)

        if getattr(self, "skip_topk", False):
            # IndexCache layer: reuse the cached indices from an earlier
            # indexer-running layer instead of computing our own.
            if prev_topk_indices is None:
                raise ValueError("IndexCache shared layers require cached top-k indices.")
            topk_indices = prev_topk_indices
        else:
            topk_indices = self.indexer(hidden_states, q_c, positions, self.indexer_rope_emb)

        if llama_4_scaling is not None:
            q *= llama_4_scaling

        attn_out = self.mla_attn(
            q,
            kv_c_normed,
            k_pe,
            output_shape=(hidden_states.shape[0], self.num_heads * self.v_head_dim),
        )

        output = self.o_proj(attn_out)[0]
        # Return indices alongside the output so the decoder layer can
        # publish them for the next layer.
        return output, topk_indices

    def _patched_attn_init(self, *args, **kwargs):
        """Run the original __init__, then tag the MLA wrapper with its
        per-layer skip_topk decision derived from config and module prefix."""
        _original_attn_init(self, *args, **kwargs)

        # Recover `config` and `prefix` regardless of whether the caller
        # passed them positionally or by keyword.
        bound = _attn_init_signature.bind_partial(self, *args, **kwargs)
        config = bound.arguments.get("config")
        prefix = bound.arguments.get("prefix", "")
        self.mla_attn.skip_topk = _skip_topk_from_prefix(config, prefix)

    def _patched_attn_forward(self, positions, hidden_states, llama_4_scaling, prev_topk_indices=None):
        # Thin delegation that forwards the cached indices down to the
        # patched MLA wrapper forward.
        return self.mla_attn(
            positions,
            hidden_states,
            llama_4_scaling,
            prev_topk_indices=prev_topk_indices,
        )

    def _patched_decoder_forward(self, positions, hidden_states, residual, llama_4_scaling=None):
        """Decoder-layer forward mirroring vLLM's original, plus reading the
        previous layer's top-k indices from the context var and publishing
        this layer's indices for the next one.

        NOTE(review): the non-attention arithmetic below (float16 scaling by
        routed_scaling_factor, layernorm/MLP ordering) is copied from the
        upstream DeepseekV2DecoderLayer.forward — keep in sync with the
        pinned vLLM version.
        """
        if residual is None:
            residual = hidden_states.clone()
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)

        attn_kwargs = {
            "positions": positions,
            "hidden_states": hidden_states,
        }
        # MHA layers take neither the scaling nor the cached indices.
        if not self.use_mha:
            attn_kwargs["llama_4_scaling"] = llama_4_scaling
            attn_kwargs["prev_topk_indices"] = topk_context.get()
        hidden_states = self.self_attn(**attn_kwargs)
        # Sparse layers return (hidden_states, topk_indices); others return
        # just the tensor (e.g. the unpatched fallback path).
        if isinstance(hidden_states, tuple):
            hidden_states, topk_indices = hidden_states
        else:
            topk_indices = None

        # Publish this layer's indices (or clear them for MHA layers) so the
        # next layer's skip_topk path can pick them up.
        if self.use_mha:
            topk_context.set(None)
        else:
            topk_context.set(topk_indices)

        # float16 overflow mitigation from upstream: pre-divide by the
        # routed scaling factor (residual too, once, on the first layer).
        if not isinstance(self.self_attn, DeepseekAttention) and hidden_states.dtype == torch.float16:
            hidden_states *= 1.0 / self.routed_scaling_factor
            if self.layer_idx == 0:
                residual *= 1.0 / self.routed_scaling_factor

        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)

        if hidden_states.dtype == torch.float16 and isinstance(self.mlp, DeepseekV2MLP):
            hidden_states *= 1.0 / self.routed_scaling_factor

        return hidden_states, residual

    def _patched_model_forward(self, *args, **kwargs):
        # Reset the cached indices at the start of every model forward so no
        # state leaks between requests, and restore the prior token even if
        # the forward raises.
        token = topk_context.set(None)
        try:
            return _original_model_forward(self, *args, **kwargs)
        finally:
            topk_context.reset(token)

    MultiHeadLatentAttentionWrapper.forward = _patched_mla_forward
    DeepseekV2MLAAttention.__init__ = _patched_attn_init
    DeepseekV2MLAAttention.forward = _patched_attn_forward
    DeepseekV2DecoderLayer.forward = _patched_decoder_forward
    DeepseekV2Model.forward = _patched_model_forward
    monkey_patch_indexcache._patched = True


def _patch_qwen35_lora():
"""Fix Qwen3.5 LoRA: align packed_modules_mapping with output_sizes.

Expand Down Expand Up @@ -499,19 +667,20 @@

Upstream: https://github.com/vllm-project/vllm/issues/23244
"""
import types

from vllm import envs
from vllm.distributed.utils import divide
from vllm.lora.layers.fused_moe import FusedMoEWithLoRA
from vllm.model_executor.layers.fused_moe.config import _get_config_dtype_str
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import MarlinExperts
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import FusedMoEModularMethod
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import UnfusedOAITritonExperts
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEKernel
from vllm.model_executor.layers.fused_moe.prepare_finalize import MoEPrepareAndFinalizeNoDPEPModular

from vllm import envs

Check failure on line 682 in src/prime_rl/inference/patches.py

View workflow job for this annotation

GitHub Actions / Ruff

Ruff (I001)

src/prime_rl/inference/patches.py:670:5: I001 Import block is un-sorted or un-formatted

def _fixed_inject(self):
moe_state_dict = {}
top_k = self.base_layer.top_k
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ class GlmMoeDsaConfig(PretrainedConfig):
Whether to use interleaved RoPE style in the sparse indexer.
index_topk (`int`, defaults to 2048):
Number of top tokens selected by the sparse indexer.
index_topk_freq (`int`, defaults to 1):
Keep one sparse indexer every N layers. `1` disables cross-layer reuse.
scoring_func (`str`, defaults to `"sigmoid"`):
Scoring function for MoE router. Must match the vLLM inference
server's expectation (vLLM defaults to ``"softmax"`` when this
Expand Down Expand Up @@ -141,6 +143,7 @@ def __init__(
indexer_rope_interleave=True,
pad_token_id=154820,
index_topk=2048,
index_topk_freq=1,
scoring_func="sigmoid",
topk_method="noaux_tc",
use_grouped_mm=True,
Expand Down Expand Up @@ -194,6 +197,7 @@ def __init__(
self.index_head_dim = index_head_dim
self.indexer_rope_interleave = indexer_rope_interleave
self.index_topk = index_topk
self.index_topk_freq = index_topk_freq
self.scoring_func = scoring_func
self.topk_method = topk_method
self.use_grouped_mm = use_grouped_mm
Expand Down
13 changes: 9 additions & 4 deletions src/prime_rl/trainer/models/glm_moe_dsa/modeling_glm_moe_dsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def __init__(self, config: GlmMoeDsaConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = GlmMoeDsaAttention(_sparse_mla_attention_args(config))
self.self_attn.skip_topk = layer_idx % config.index_topk_freq != 0

moe_args = MoEArgs(
num_experts=config.n_routed_experts,
Expand Down Expand Up @@ -110,16 +111,18 @@ def forward(
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
ks: Optional[torch.Tensor] = None,
ke: Optional[torch.Tensor] = None,
cached_topk_indices: Optional[torch.Tensor] = None,
routed_experts: Optional[torch.LongTensor] = None,
) -> torch.Tensor:
) -> tuple[torch.Tensor, torch.Tensor]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.gather_for_cp(hidden_states)
hidden_states, _ = self.self_attn(
hidden_states, topk_indices = self.self_attn(
hidden_states=hidden_states,
position_embeddings=position_embeddings,
ks=ks,
ke=ke,
cached_topk_indices=cached_topk_indices,
)
hidden_states = self.shard_to_cp(hidden_states)
hidden_states = residual + hidden_states
Expand All @@ -128,7 +131,7 @@ def forward(
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states, routed_experts=routed_experts)
hidden_states = residual + hidden_states
return hidden_states
return hidden_states, topk_indices


@auto_docstring
Expand Down Expand Up @@ -259,14 +262,16 @@ def forward(

hidden_states = inputs_embeds
position_embeddings = self.rotary_emb(hidden_states, position_ids_for_attn)
cached_topk_indices = None

for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
routed_experts_layer = routed_experts[:, :, layer_idx, :] if routed_experts is not None else None
hidden_states = decoder_layer(
hidden_states, cached_topk_indices = decoder_layer(
hidden_states,
position_embeddings=position_embeddings,
ks=ks,
ke=ke,
cached_topk_indices=cached_topk_indices,
routed_experts=routed_experts_layer,
)

Expand Down
16 changes: 11 additions & 5 deletions src/prime_rl/trainer/models/glm_moe_dsa/sparse_mla_attention.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataclasses import dataclass
from typing import cast

import torch
from torch import nn
Expand Down Expand Up @@ -126,6 +127,7 @@ def __init__(self, args: SparseMlaAttentionArgs):

self.o_proj = nn.Linear(self.num_heads * self.v_head_dim, args.hidden_size, bias=args.attention_bias)
self.indexer = Indexer(args)
self.skip_topk = False
self.scaling = self.qk_head_dim ** (-0.5)

def attn_projections(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
Expand Down Expand Up @@ -179,12 +181,16 @@ def forward(
position_embeddings: tuple[torch.Tensor, torch.Tensor],
ks: torch.Tensor | None = None,
ke: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
cached_topk_indices: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
q_latent, k_compressed_normed, k_rope = self.attn_projections(hidden_states)

indices = self.indexer.compute_sparse_indices(
hidden_states, q_latent, ks, ke, self.args.index_topk, position_embeddings
)
if self.skip_topk:
indices = cast(torch.Tensor, cached_topk_indices)
else:
indices = self.indexer.compute_sparse_indices(
hidden_states, q_latent, ks, ke, self.args.index_topk, position_embeddings
)

sparse_q, sparse_kv, w_v = self.mla_up_proj(
q_latent,
Expand All @@ -194,4 +200,4 @@ def forward(
)

out = _SparseMLA.apply(sparse_q, sparse_kv, indices, self.scaling)
return self.output_proj(out, w_v), None
return self.output_proj(out, w_v), indices
Loading