Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/prime_rl/configs/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,13 @@ class SharedModelConfig(BaseConfig):
Field(description="The name of the model to use."),
] = "Qwen/Qwen3-0.6B"

index_topk_freq: Annotated[
int | None,
Field(
description="Override the loaded Hugging Face config's `index_topk_freq` for trainer and inference without editing the model directory.",
),
] = None
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing CHANGELOG entry for new config field

Low Severity

This PR adds a new index_topk_freq config field to both SharedModelConfig in configs/rl.py and BaseModelConfig in configs/shared.py, but CHANGELOG.md has no corresponding entry. Per project rules, any PR that modifies configuration structures (added, removed, renamed, moved, or default value changes) must update the changelog.

Additional Locations (1)
Fix in Cursor Fix in Web

Triggered by project rule: BugBot Instructions


vlm: Annotated[
"VLMConfig | None",
Field(description="VLM configuration. Set to enable vision-language model support."),
Expand Down Expand Up @@ -524,10 +531,14 @@ def auto_setup_model(self):
"""Auto-setup shared model config for trainer, orchestrator, and inference."""
if self.model is not None:
self.trainer.model.name = self.model.name
if self.trainer.model.index_topk_freq is None:
self.trainer.model.index_topk_freq = self.model.index_topk_freq
if self.inference is not None:
inference_model_explicitly_set = "name" in self.inference.model.model_fields_set
if not inference_model_explicitly_set:
self.inference.model.name = self.model.name
if self.inference.model.index_topk_freq is None:
self.inference.model.index_topk_freq = self.model.index_topk_freq
self.orchestrator.model.name = self.inference.model.name
else:
self.orchestrator.model.name = self.model.name
Expand Down
7 changes: 7 additions & 0 deletions src/prime_rl/configs/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,13 @@ class BaseModelConfig(BaseConfig):

name: Annotated[str, Field(description="Name or path of the HF model to use.")] = "Qwen/Qwen3-0.6B"

index_topk_freq: Annotated[
int | None,
Field(
description="Override the loaded Hugging Face config's `index_topk_freq` at runtime without editing the model directory.",
),
] = None

trust_remote_code: Annotated[
bool,
Field(
Expand Down
175 changes: 175 additions & 0 deletions src/prime_rl/inference/patches.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
import inspect
import os
import re

_LAYER_INDEX_RE = re.compile(r"\.(\d+)(?=\.|$)")
_INDEX_TOPK_FREQ_ENV_VAR = "PRIME_RL_INDEX_TOPK_FREQ"


def transformers_v5_compat():
"""vLLM general plugin: patch transformers v5 config attrs that vLLM 0.16 still expects.

Expand All @@ -9,10 +17,177 @@ def transformers_v5_compat():
if not hasattr(Qwen3VLMoeTextConfig, "tie_word_embeddings"):
Qwen3VLMoeTextConfig.tie_word_embeddings = False

monkey_patch_indexcache()
_patch_qwen35_lora()
monkey_patch_dp_engine_core_pause_resume_deadlock()


def _skip_topk_from_prefix(config, prefix):
if not hasattr(config, "index_topk"):
return False

num_hidden_layers = getattr(config, "num_hidden_layers", None)
if num_hidden_layers is None:
return False

layer_idx = int(_LAYER_INDEX_RE.findall(prefix)[-1])
if layer_idx >= num_hidden_layers:
return False

return layer_idx % int(getattr(config, "index_topk_freq", 1)) != 0


def monkey_patch_indexcache():
    """Patch vLLM's DeepSeek-V2 MLA stack to reuse sparse top-k indices across layers.

    When the ``PRIME_RL_INDEX_TOPK_FREQ`` env var is set, only every N-th layer
    runs its sparse indexer; intermediate layers ("shared" layers, flagged via
    ``skip_topk`` by :func:`_skip_topk_from_prefix`) reuse the most recent layer's
    top-k indices, which are threaded between layers through a ContextVar.

    Patches, in order:
      * ``MultiHeadLatentAttentionWrapper.forward`` — returns ``(output, topk_indices)``
        for sparse layers; shared layers consume ``prev_topk_indices`` instead of
        calling their indexer.
      * ``DeepseekV2MLAAttention.__init__`` — injects the ``index_topk_freq``
        override into the HF config before construction and tags ``mla_attn.skip_topk``.
      * ``DeepseekV2MLAAttention.forward`` — forwards ``prev_topk_indices`` through.
      * ``DeepseekV2DecoderLayer.forward`` — reads/writes the ContextVar around the
        attention call.
      * ``DeepseekV2Model.forward`` — resets the ContextVar per model forward pass.

    Idempotent: a ``_patched`` attribute on the function guards re-entry.
    """
    # Guard against double-patching (this runs from a vLLM plugin hook).
    if getattr(monkey_patch_indexcache, "_patched", False):
        return

    from contextvars import ContextVar

    import torch
    from vllm.model_executor.layers.mla import MultiHeadLatentAttentionWrapper
    from vllm.model_executor.models.deepseek_v2 import (
        DeepseekAttention,
        DeepseekV2DecoderLayer,
        DeepseekV2MLAAttention,
        DeepseekV2MLP,
        DeepseekV2Model,
    )

    # Carries the previous sparse layer's top-k indices to the next layer.
    # ContextVar (rather than a module global) keeps concurrent forward passes
    # in separate contexts isolated from each other.
    topk_context = ContextVar("prime_rl_indexcache_topk", default=None)
    # Optional override propagated from the server config via the environment
    # (set in inference/server.py before vLLM import). None means "no override".
    raw_index_topk_freq = os.environ.get(_INDEX_TOPK_FREQ_ENV_VAR)
    index_topk_freq = int(raw_index_topk_freq) if raw_index_topk_freq is not None else None

    # Keep originals so the non-sparse path and the model forward can delegate.
    _original_mla_forward = MultiHeadLatentAttentionWrapper.forward
    _original_attn_init = DeepseekV2MLAAttention.__init__
    _original_model_forward = DeepseekV2Model.forward
    _attn_init_signature = inspect.signature(_original_attn_init)

    def _patched_mla_forward(self, positions, hidden_states, llama_4_scaling=None, prev_topk_indices=None):
        # Replacement MLA forward that also returns the top-k indices so the
        # decoder layer can cache them for subsequent shared layers.
        if not self.indexer or not self.is_sparse:
            # Dense / non-indexer layers: defer to the unpatched implementation
            # (returns a plain tensor, handled by the decoder's isinstance check).
            return _original_mla_forward(self, positions, hidden_states, llama_4_scaling)

        q_c = None
        kv_lora = None

        if self.q_lora_rank is not None:
            # Low-rank query path: fused QKV-A projection, then split into the
            # query latent and the KV latent (+ rope dims).
            assert self.fused_qkv_a_proj is not None, "fused_qkv_a_proj is required when q_lora_rank is not None"
            assert self.q_a_layernorm is not None, "q_a_layernorm is required when q_lora_rank is not None"
            assert self.q_b_proj is not None, "q_b_proj is required when q_lora_rank is not None"

            qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
            q_c, kv_lora = qkv_lora.split(
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
                dim=-1,
            )
            q_c = self.q_a_layernorm(q_c)
            q = self.q_b_proj(q_c)[0]
        else:
            # Full-rank query path: separate projections; q_c stays None.
            assert self.kv_a_proj_with_mqa is not None, "kv_a_proj_with_mqa is required when q_lora_rank is None"
            assert self.q_proj is not None, "q_proj is required when q_lora_rank is None"
            kv_lora = self.kv_a_proj_with_mqa(hidden_states)[0]
            q = self.q_proj(hidden_states)[0]

        kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        kv_c_normed = self.kv_a_layernorm(kv_c)

        q = q.view(-1, self.num_heads, self.qk_head_dim)
        k_pe = k_pe.unsqueeze(1)

        if self.rotary_emb is not None:
            # Rotary embedding applies only to the rope slice of q and to k_pe.
            q[..., self.qk_nope_head_dim :], k_pe = self.rotary_emb(positions, q[..., self.qk_nope_head_dim :], k_pe)

        if getattr(self, "skip_topk", False):
            # Shared layer: must have indices cached by an earlier indexer layer.
            if prev_topk_indices is None:
                raise ValueError("IndexCache shared layers require cached top-k indices.")
            topk_indices = prev_topk_indices
        else:
            # Indexer-owning layer: compute fresh top-k indices.
            topk_indices = self.indexer(hidden_states, q_c, positions, self.indexer_rope_emb)

        if llama_4_scaling is not None:
            q *= llama_4_scaling

        attn_out = self.mla_attn(
            q,
            kv_c_normed,
            k_pe,
            output_shape=(hidden_states.shape[0], self.num_heads * self.v_head_dim),
        )

        output = self.o_proj(attn_out)[0]
        # Return indices alongside the output so the decoder can re-cache them.
        return output, topk_indices

    def _patched_attn_init(self, *args, **kwargs):
        # bind_partial resolves positional/keyword args so `config` and `prefix`
        # can be located regardless of how vLLM passes them.
        bound = _attn_init_signature.bind_partial(self, *args, **kwargs)
        config = bound.arguments.get("config")
        if config is not None and index_topk_freq is not None:
            # Apply the env override BEFORE the original __init__ reads the config.
            config.index_topk_freq = index_topk_freq

        _original_attn_init(self, *args, **kwargs)

        # Tag the wrapper (created by the original __init__) with whether this
        # layer reuses cached indices instead of running its own indexer.
        prefix = bound.arguments.get("prefix", "")
        self.mla_attn.skip_topk = _skip_topk_from_prefix(config, prefix)

    def _patched_attn_forward(self, positions, hidden_states, llama_4_scaling, prev_topk_indices=None):
        # Thin shim: forward the cached indices down to the MLA wrapper.
        return self.mla_attn(
            positions,
            hidden_states,
            llama_4_scaling,
            prev_topk_indices=prev_topk_indices,
        )

    def _patched_decoder_forward(self, positions, hidden_states, residual, llama_4_scaling=None):
        if residual is None:
            residual = hidden_states.clone()
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)

        attn_kwargs = {
            "positions": positions,
            "hidden_states": hidden_states,
        }
        if not self.use_mha:
            # MLA path only: pass scaling plus the previous layer's cached indices.
            attn_kwargs["llama_4_scaling"] = llama_4_scaling
            attn_kwargs["prev_topk_indices"] = topk_context.get()
        hidden_states = self.self_attn(**attn_kwargs)
        # Sparse layers return (output, indices); dense paths return a tensor.
        if isinstance(hidden_states, tuple):
            hidden_states, topk_indices = hidden_states
        else:
            topk_indices = None

        if self.use_mha:
            # MHA layers break the sparse-index chain.
            topk_context.set(None)
        else:
            topk_context.set(topk_indices)

        # fp16 numerics: pre-divide by routed_scaling_factor (mirrors vLLM's
        # DeepSeek-V2 decoder layer behavior for non-MHA attention).
        if not isinstance(self.self_attn, DeepseekAttention) and hidden_states.dtype == torch.float16:
            hidden_states *= 1.0 / self.routed_scaling_factor
            if self.layer_idx == 0:
                residual *= 1.0 / self.routed_scaling_factor

        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)

        if hidden_states.dtype == torch.float16 and isinstance(self.mlp, DeepseekV2MLP):
            hidden_states *= 1.0 / self.routed_scaling_factor

        return hidden_states, residual

    def _patched_model_forward(self, *args, **kwargs):
        # Start each model forward with a clean cache; the token-based reset
        # restores the outer context even if the forward raises.
        token = topk_context.set(None)
        try:
            return _original_model_forward(self, *args, **kwargs)
        finally:
            topk_context.reset(token)

    MultiHeadLatentAttentionWrapper.forward = _patched_mla_forward
    DeepseekV2MLAAttention.__init__ = _patched_attn_init
    DeepseekV2MLAAttention.forward = _patched_attn_forward
    DeepseekV2DecoderLayer.forward = _patched_decoder_forward
    DeepseekV2Model.forward = _patched_model_forward
    monkey_patch_indexcache._patched = True


def _patch_qwen35_lora():
"""Fix Qwen3.5 LoRA: align packed_modules_mapping with output_sizes.

Expand Down
7 changes: 7 additions & 0 deletions src/prime_rl/inference/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from prime_rl.configs.inference import InferenceConfig
from prime_rl.utils.config import cli

_INDEX_TOPK_FREQ_ENV_VAR = "PRIME_RL_INDEX_TOPK_FREQ"


def setup_vllm_env(config: InferenceConfig):
"""Set vLLM environment variables based on config. Must be called before importing vLLM."""
Expand All @@ -13,6 +15,11 @@ def setup_vllm_env(config: InferenceConfig):
if config.enable_lora:
os.environ["VLLM_ALLOW_RUNTIME_LORA_UPDATING"] = "True"

if config.model.index_topk_freq is not None:
os.environ[_INDEX_TOPK_FREQ_ENV_VAR] = str(config.model.index_topk_freq)
else:
os.environ.pop(_INDEX_TOPK_FREQ_ENV_VAR, None)


def main():
config = cli(InferenceConfig)
Expand Down
3 changes: 3 additions & 0 deletions src/prime_rl/trainer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,9 @@ def get_model(
config.name, attn_implementation=config.attn, trust_remote_code=config.trust_remote_code
),
)
if config.index_topk_freq is not None:
logger.info(f"Applying trainer index_topk_freq override: {config.index_topk_freq}")
model_config.index_topk_freq = config.index_topk_freq
model_config.use_cache = False
is_vlm_arch = is_vlm_architecture(model_config)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ class GlmMoeDsaConfig(PretrainedConfig):
Whether to use interleaved RoPE style in the sparse indexer.
index_topk (`int`, defaults to 2048):
Number of top tokens selected by the sparse indexer.
index_topk_freq (`int`, defaults to 1):
Keep one sparse indexer every N layers. `1` disables cross-layer reuse.
scoring_func (`str`, defaults to `"sigmoid"`):
Scoring function for MoE router. Must match the vLLM inference
server's expectation (vLLM defaults to ``"softmax"`` when this
Expand Down Expand Up @@ -141,6 +143,7 @@ def __init__(
indexer_rope_interleave=True,
pad_token_id=154820,
index_topk=2048,
index_topk_freq=1,
scoring_func="sigmoid",
topk_method="noaux_tc",
use_grouped_mm=True,
Expand Down Expand Up @@ -194,6 +197,7 @@ def __init__(
self.index_head_dim = index_head_dim
self.indexer_rope_interleave = indexer_rope_interleave
self.index_topk = index_topk
self.index_topk_freq = index_topk_freq
self.scoring_func = scoring_func
self.topk_method = topk_method
self.use_grouped_mm = use_grouped_mm
Expand Down
13 changes: 9 additions & 4 deletions src/prime_rl/trainer/models/glm_moe_dsa/modeling_glm_moe_dsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def __init__(self, config: GlmMoeDsaConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = GlmMoeDsaAttention(_sparse_mla_attention_args(config))
self.self_attn.skip_topk = layer_idx % config.index_topk_freq != 0

moe_args = MoEArgs(
num_experts=config.n_routed_experts,
Expand Down Expand Up @@ -110,16 +111,18 @@ def forward(
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
ks: Optional[torch.Tensor] = None,
ke: Optional[torch.Tensor] = None,
cached_topk_indices: Optional[torch.Tensor] = None,
routed_experts: Optional[torch.LongTensor] = None,
) -> torch.Tensor:
) -> tuple[torch.Tensor, torch.Tensor]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.gather_for_cp(hidden_states)
hidden_states, _ = self.self_attn(
hidden_states, topk_indices = self.self_attn(
hidden_states=hidden_states,
position_embeddings=position_embeddings,
ks=ks,
ke=ke,
cached_topk_indices=cached_topk_indices,
)
hidden_states = self.shard_to_cp(hidden_states)
hidden_states = residual + hidden_states
Expand All @@ -128,7 +131,7 @@ def forward(
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states, routed_experts=routed_experts)
hidden_states = residual + hidden_states
return hidden_states
return hidden_states, topk_indices


@auto_docstring
Expand Down Expand Up @@ -259,14 +262,16 @@ def forward(

hidden_states = inputs_embeds
position_embeddings = self.rotary_emb(hidden_states, position_ids_for_attn)
cached_topk_indices = None

for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
routed_experts_layer = routed_experts[:, :, layer_idx, :] if routed_experts is not None else None
hidden_states = decoder_layer(
hidden_states, cached_topk_indices = decoder_layer(
hidden_states,
position_embeddings=position_embeddings,
ks=ks,
ke=ke,
cached_topk_indices=cached_topk_indices,
routed_experts=routed_experts_layer,
)

Expand Down
16 changes: 11 additions & 5 deletions src/prime_rl/trainer/models/glm_moe_dsa/sparse_mla_attention.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataclasses import dataclass
from typing import cast

import torch
from torch import nn
Expand Down Expand Up @@ -126,6 +127,7 @@ def __init__(self, args: SparseMlaAttentionArgs):

self.o_proj = nn.Linear(self.num_heads * self.v_head_dim, args.hidden_size, bias=args.attention_bias)
self.indexer = Indexer(args)
self.skip_topk = False
self.scaling = self.qk_head_dim ** (-0.5)

def attn_projections(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
Expand Down Expand Up @@ -179,12 +181,16 @@ def forward(
position_embeddings: tuple[torch.Tensor, torch.Tensor],
ks: torch.Tensor | None = None,
ke: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
cached_topk_indices: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
q_latent, k_compressed_normed, k_rope = self.attn_projections(hidden_states)

indices = self.indexer.compute_sparse_indices(
hidden_states, q_latent, ks, ke, self.args.index_topk, position_embeddings
)
if self.skip_topk:
indices = cast(torch.Tensor, cached_topk_indices)
else:
indices = self.indexer.compute_sparse_indices(
hidden_states, q_latent, ks, ke, self.args.index_topk, position_embeddings
)

sparse_q, sparse_kv, w_v = self.mla_up_proj(
q_latent,
Expand All @@ -194,4 +200,4 @@ def forward(
)

out = _SparseMLA.apply(sparse_q, sparse_kv, indices, self.scaling)
return self.output_proj(out, w_v), None
return self.output_proj(out, w_v), indices
Loading