diff --git a/src/prime_rl/configs/rl.py b/src/prime_rl/configs/rl.py
index d370a175b2..04a24f41d6 100644
--- a/src/prime_rl/configs/rl.py
+++ b/src/prime_rl/configs/rl.py
@@ -126,6 +126,13 @@ class SharedModelConfig(BaseConfig):
         Field(description="The name of the model to use."),
     ] = "Qwen/Qwen3-0.6B"
 
+    index_topk_freq: Annotated[
+        int | None,
+        Field(
+            description="Override the loaded Hugging Face config's `index_topk_freq` for trainer and inference without editing the model directory.",
+        ),
+    ] = None
+
     vlm: Annotated[
         "VLMConfig | None",
         Field(description="VLM configuration. Set to enable vision-language model support."),
@@ -524,10 +531,14 @@ def auto_setup_model(self):
         """Auto-setup shared model config for trainer, orchestrator, and inference."""
         if self.model is not None:
             self.trainer.model.name = self.model.name
+            if self.trainer.model.index_topk_freq is None:
+                self.trainer.model.index_topk_freq = self.model.index_topk_freq
             if self.inference is not None:
                 inference_model_explicitly_set = "name" in self.inference.model.model_fields_set
                 if not inference_model_explicitly_set:
                     self.inference.model.name = self.model.name
+                if self.inference.model.index_topk_freq is None:
+                    self.inference.model.index_topk_freq = self.model.index_topk_freq
                 self.orchestrator.model.name = self.inference.model.name
         else:
             self.orchestrator.model.name = self.model.name
diff --git a/src/prime_rl/configs/shared.py b/src/prime_rl/configs/shared.py
index 0438ac86fb..33dd31e83f 100644
--- a/src/prime_rl/configs/shared.py
+++ b/src/prime_rl/configs/shared.py
@@ -116,6 +116,13 @@ class BaseModelConfig(BaseConfig):
 
     name: Annotated[str, Field(description="Name or path of the HF model to use.")] = "Qwen/Qwen3-0.6B"
 
+    index_topk_freq: Annotated[
+        int | None,
+        Field(
+            description="Override the loaded Hugging Face config's `index_topk_freq` at runtime without editing the model directory.",
+        ),
+    ] = None
+
     trust_remote_code: Annotated[
         bool,
         Field(
diff --git a/src/prime_rl/inference/patches.py b/src/prime_rl/inference/patches.py
index 9b0ffe053e..07dd8691f8 100644
--- a/src/prime_rl/inference/patches.py
+++ b/src/prime_rl/inference/patches.py
@@ -1,3 +1,11 @@
+import inspect
+import os
+import re
+
+_LAYER_INDEX_RE = re.compile(r"\.(\d+)(?=\.|$)")
+_INDEX_TOPK_FREQ_ENV_VAR = "PRIME_RL_INDEX_TOPK_FREQ"
+
+
 def transformers_v5_compat():
     """vLLM general plugin: patch transformers v5 config attrs that vLLM 0.16 still expects.
 
@@ -9,10 +17,177 @@ def transformers_v5_compat():
     if not hasattr(Qwen3VLMoeTextConfig, "tie_word_embeddings"):
         Qwen3VLMoeTextConfig.tie_word_embeddings = False
 
+    monkey_patch_indexcache()
     _patch_qwen35_lora()
     monkey_patch_dp_engine_core_pause_resume_deadlock()
 
 
+def _skip_topk_from_prefix(config, prefix):
+    if not hasattr(config, "index_topk"):
+        return False
+
+    num_hidden_layers = getattr(config, "num_hidden_layers", None)
+    if num_hidden_layers is None:
+        return False
+
+    layer_idx = int(_LAYER_INDEX_RE.findall(prefix)[-1])
+    if layer_idx >= num_hidden_layers:
+        return False
+
+    return layer_idx % int(getattr(config, "index_topk_freq", 1)) != 0
+
+
+def monkey_patch_indexcache():
+    if getattr(monkey_patch_indexcache, "_patched", False):
+        return
+
+    from contextvars import ContextVar
+
+    import torch
+    from vllm.model_executor.layers.mla import MultiHeadLatentAttentionWrapper
+    from vllm.model_executor.models.deepseek_v2 import (
+        DeepseekAttention,
+        DeepseekV2DecoderLayer,
+        DeepseekV2MLAAttention,
+        DeepseekV2MLP,
+        DeepseekV2Model,
+    )
+
+    topk_context = ContextVar("prime_rl_indexcache_topk", default=None)
+    raw_index_topk_freq = os.environ.get(_INDEX_TOPK_FREQ_ENV_VAR)
+    index_topk_freq = int(raw_index_topk_freq) if raw_index_topk_freq is not None else None
+
+    _original_mla_forward = MultiHeadLatentAttentionWrapper.forward
+    _original_attn_init = DeepseekV2MLAAttention.__init__
+    _original_model_forward = DeepseekV2Model.forward
+    _attn_init_signature = inspect.signature(_original_attn_init)
+
+    def _patched_mla_forward(self, positions, hidden_states, llama_4_scaling=None, prev_topk_indices=None):
+        if not self.indexer or not self.is_sparse:
+            return _original_mla_forward(self, positions, hidden_states, llama_4_scaling)
+
+        q_c = None
+        kv_lora = None
+
+        if self.q_lora_rank is not None:
+            assert self.fused_qkv_a_proj is not None, "fused_qkv_a_proj is required when q_lora_rank is not None"
+            assert self.q_a_layernorm is not None, "q_a_layernorm is required when q_lora_rank is not None"
+            assert self.q_b_proj is not None, "q_b_proj is required when q_lora_rank is not None"
+
+            qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
+            q_c, kv_lora = qkv_lora.split(
+                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
+                dim=-1,
+            )
+            q_c = self.q_a_layernorm(q_c)
+            q = self.q_b_proj(q_c)[0]
+        else:
+            assert self.kv_a_proj_with_mqa is not None, "kv_a_proj_with_mqa is required when q_lora_rank is None"
+            assert self.q_proj is not None, "q_proj is required when q_lora_rank is None"
+            kv_lora = self.kv_a_proj_with_mqa(hidden_states)[0]
+            q = self.q_proj(hidden_states)[0]
+
+        kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+        kv_c_normed = self.kv_a_layernorm(kv_c)
+
+        q = q.view(-1, self.num_heads, self.qk_head_dim)
+        k_pe = k_pe.unsqueeze(1)
+
+        if self.rotary_emb is not None:
+            q[..., self.qk_nope_head_dim :], k_pe = self.rotary_emb(positions, q[..., self.qk_nope_head_dim :], k_pe)
+
+        if getattr(self, "skip_topk", False):
+            if prev_topk_indices is None:
+                raise ValueError("IndexCache shared layers require cached top-k indices.")
+            topk_indices = prev_topk_indices
+        else:
+            topk_indices = self.indexer(hidden_states, q_c, positions, self.indexer_rope_emb)
+
+        if llama_4_scaling is not None:
+            q *= llama_4_scaling
+
+        attn_out = self.mla_attn(
+            q,
+            kv_c_normed,
+            k_pe,
+            output_shape=(hidden_states.shape[0], self.num_heads * self.v_head_dim),
+        )
+
+        output = self.o_proj(attn_out)[0]
+        return output, topk_indices
+
+    def _patched_attn_init(self, *args, **kwargs):
+        bound = _attn_init_signature.bind_partial(self, *args, **kwargs)
+        config = bound.arguments.get("config")
+        if config is not None and index_topk_freq is not None:
+            config.index_topk_freq = index_topk_freq
+
+        _original_attn_init(self, *args, **kwargs)
+
+        prefix = bound.arguments.get("prefix", "")
+        self.mla_attn.skip_topk = _skip_topk_from_prefix(config, prefix)
+
+    def _patched_attn_forward(self, positions, hidden_states, llama_4_scaling, prev_topk_indices=None):
+        return self.mla_attn(
+            positions,
+            hidden_states,
+            llama_4_scaling,
+            prev_topk_indices=prev_topk_indices,
+        )
+
+    def _patched_decoder_forward(self, positions, hidden_states, residual, llama_4_scaling=None):
+        if residual is None:
+            residual = hidden_states.clone()
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+        attn_kwargs = {
+            "positions": positions,
+            "hidden_states": hidden_states,
+        }
+        if not self.use_mha:
+            attn_kwargs["llama_4_scaling"] = llama_4_scaling
+            attn_kwargs["prev_topk_indices"] = topk_context.get()
+        hidden_states = self.self_attn(**attn_kwargs)
+        if isinstance(hidden_states, tuple):
+            hidden_states, topk_indices = hidden_states
+        else:
+            topk_indices = None
+
+        if self.use_mha:
+            topk_context.set(None)
+        else:
+            topk_context.set(topk_indices)
+
+        if not isinstance(self.self_attn, DeepseekAttention) and hidden_states.dtype == torch.float16:
+            hidden_states *= 1.0 / self.routed_scaling_factor
+            if self.layer_idx == 0:
+                residual *= 1.0 / self.routed_scaling_factor
+
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+
+        if hidden_states.dtype == torch.float16 and isinstance(self.mlp, DeepseekV2MLP):
+            hidden_states *= 1.0 / self.routed_scaling_factor
+
+        return hidden_states, residual
+
+    def _patched_model_forward(self, *args, **kwargs):
+        token = topk_context.set(None)
+        try:
+            return _original_model_forward(self, *args, **kwargs)
+        finally:
+            topk_context.reset(token)
+
+    MultiHeadLatentAttentionWrapper.forward = _patched_mla_forward
+    DeepseekV2MLAAttention.__init__ = _patched_attn_init
+    DeepseekV2MLAAttention.forward = _patched_attn_forward
+    DeepseekV2DecoderLayer.forward = _patched_decoder_forward
+    DeepseekV2Model.forward = _patched_model_forward
+    monkey_patch_indexcache._patched = True
+
+
 def _patch_qwen35_lora():
     """Fix Qwen3.5 LoRA: align packed_modules_mapping with output_sizes.
 
diff --git a/src/prime_rl/inference/server.py b/src/prime_rl/inference/server.py
index 322f0145eb..1b6b758d77 100644
--- a/src/prime_rl/inference/server.py
+++ b/src/prime_rl/inference/server.py
@@ -3,6 +3,8 @@
 from prime_rl.configs.inference import InferenceConfig
 from prime_rl.utils.config import cli
 
+_INDEX_TOPK_FREQ_ENV_VAR = "PRIME_RL_INDEX_TOPK_FREQ"
+
 def setup_vllm_env(config: InferenceConfig):
     """Set vLLM environment variables based on config.
     Must be called before importing vLLM."""
@@ -13,6 +15,11 @@ def setup_vllm_env(config: InferenceConfig):
     if config.enable_lora:
         os.environ["VLLM_ALLOW_RUNTIME_LORA_UPDATING"] = "True"
 
+    if config.model.index_topk_freq is not None:
+        os.environ[_INDEX_TOPK_FREQ_ENV_VAR] = str(config.model.index_topk_freq)
+    else:
+        os.environ.pop(_INDEX_TOPK_FREQ_ENV_VAR, None)
+
 
 def main():
     config = cli(InferenceConfig)
diff --git a/src/prime_rl/trainer/model.py b/src/prime_rl/trainer/model.py
index 188b0f61f7..197f71f2b0 100644
--- a/src/prime_rl/trainer/model.py
+++ b/src/prime_rl/trainer/model.py
@@ -225,6 +225,9 @@ def get_model(
             config.name, attn_implementation=config.attn, trust_remote_code=config.trust_remote_code
         ),
     )
+    if config.index_topk_freq is not None:
+        logger.info(f"Applying trainer index_topk_freq override: {config.index_topk_freq}")
+        model_config.index_topk_freq = config.index_topk_freq
     model_config.use_cache = False
 
     is_vlm_arch = is_vlm_architecture(model_config)
diff --git a/src/prime_rl/trainer/models/glm_moe_dsa/configuration_glm_moe_dsa.py b/src/prime_rl/trainer/models/glm_moe_dsa/configuration_glm_moe_dsa.py
index caca4fc5e9..8184954182 100644
--- a/src/prime_rl/trainer/models/glm_moe_dsa/configuration_glm_moe_dsa.py
+++ b/src/prime_rl/trainer/models/glm_moe_dsa/configuration_glm_moe_dsa.py
@@ -73,6 +73,8 @@ class GlmMoeDsaConfig(PretrainedConfig):
             Whether to use interleaved RoPE style in the sparse indexer.
         index_topk (`int`, defaults to 2048):
             Number of top tokens selected by the sparse indexer.
+        index_topk_freq (`int`, defaults to 1):
+            Compute sparse top-k indices every N layers and reuse them in the layers in between; `1` disables cross-layer reuse.
         scoring_func (`str`, defaults to `"sigmoid"`):
             Scoring function for MoE router. Must match the vLLM inference
             server's expectation (vLLM defaults to ``"softmax"`` when this
@@ -141,6 +143,7 @@ def __init__(
         indexer_rope_interleave=True,
         pad_token_id=154820,
         index_topk=2048,
+        index_topk_freq=1,
         scoring_func="sigmoid",
         topk_method="noaux_tc",
         use_grouped_mm=True,
@@ -194,6 +197,7 @@ def __init__(
         self.index_head_dim = index_head_dim
         self.indexer_rope_interleave = indexer_rope_interleave
         self.index_topk = index_topk
+        self.index_topk_freq = index_topk_freq
         self.scoring_func = scoring_func
         self.topk_method = topk_method
         self.use_grouped_mm = use_grouped_mm
diff --git a/src/prime_rl/trainer/models/glm_moe_dsa/modeling_glm_moe_dsa.py b/src/prime_rl/trainer/models/glm_moe_dsa/modeling_glm_moe_dsa.py
index a87074be7b..7970e7a33c 100644
--- a/src/prime_rl/trainer/models/glm_moe_dsa/modeling_glm_moe_dsa.py
+++ b/src/prime_rl/trainer/models/glm_moe_dsa/modeling_glm_moe_dsa.py
@@ -55,6 +55,7 @@ def __init__(self, config: GlmMoeDsaConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.self_attn = GlmMoeDsaAttention(_sparse_mla_attention_args(config))
+        self.self_attn.skip_topk = layer_idx % config.index_topk_freq != 0
 
         moe_args = MoEArgs(
             num_experts=config.n_routed_experts,
@@ -110,16 +111,18 @@ def forward(
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
         ks: Optional[torch.Tensor] = None,
         ke: Optional[torch.Tensor] = None,
+        cached_topk_indices: Optional[torch.Tensor] = None,
         routed_experts: Optional[torch.LongTensor] = None,
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
         hidden_states = self.gather_for_cp(hidden_states)
-        hidden_states, _ = self.self_attn(
+        hidden_states, topk_indices = self.self_attn(
             hidden_states=hidden_states,
             position_embeddings=position_embeddings,
             ks=ks,
             ke=ke,
+            cached_topk_indices=cached_topk_indices,
         )
         hidden_states = self.shard_to_cp(hidden_states)
         hidden_states = residual + hidden_states
@@ -128,7 +131,7 @@ def forward(
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = self.mlp(hidden_states, routed_experts=routed_experts)
         hidden_states = residual + hidden_states
-        return hidden_states
+        return hidden_states, topk_indices
 
 
 @auto_docstring
@@ -259,14 +262,16 @@ def forward(
 
         hidden_states = inputs_embeds
         position_embeddings = self.rotary_emb(hidden_states, position_ids_for_attn)
 
+        cached_topk_indices = None
        for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
             routed_experts_layer = routed_experts[:, :, layer_idx, :] if routed_experts is not None else None
-            hidden_states = decoder_layer(
+            hidden_states, cached_topk_indices = decoder_layer(
                 hidden_states,
                 position_embeddings=position_embeddings,
                 ks=ks,
                 ke=ke,
+                cached_topk_indices=cached_topk_indices,
                 routed_experts=routed_experts_layer,
             )
 
diff --git a/src/prime_rl/trainer/models/glm_moe_dsa/sparse_mla_attention.py b/src/prime_rl/trainer/models/glm_moe_dsa/sparse_mla_attention.py
index e0d2021e57..0c7896b017 100644
--- a/src/prime_rl/trainer/models/glm_moe_dsa/sparse_mla_attention.py
+++ b/src/prime_rl/trainer/models/glm_moe_dsa/sparse_mla_attention.py
@@ -1,4 +1,5 @@
 from dataclasses import dataclass
+from typing import cast
 
 import torch
 from torch import nn
@@ -126,6 +127,7 @@ def __init__(self, args: SparseMlaAttentionArgs):
         self.o_proj = nn.Linear(self.num_heads * self.v_head_dim, args.hidden_size, bias=args.attention_bias)
 
         self.indexer = Indexer(args)
+        self.skip_topk = False
         self.scaling = self.qk_head_dim ** (-0.5)
 
     def attn_projections(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -179,12 +181,16 @@ def forward(
         position_embeddings: tuple[torch.Tensor, torch.Tensor],
         ks: torch.Tensor | None = None,
         ke: torch.Tensor | None = None,
-    ) -> tuple[torch.Tensor, torch.Tensor | None]:
+        cached_topk_indices: torch.Tensor | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         q_latent, k_compressed_normed, k_rope = self.attn_projections(hidden_states)
 
-        indices = self.indexer.compute_sparse_indices(
-            hidden_states, q_latent, ks, ke, self.args.index_topk, position_embeddings
-        )
+        if self.skip_topk:
+            indices = cast(torch.Tensor, cached_topk_indices)
+        else:
+            indices = self.indexer.compute_sparse_indices(
+                hidden_states, q_latent, ks, ke, self.args.index_topk, position_embeddings
+            )
 
         sparse_q, sparse_kv, w_v = self.mla_up_proj(
             q_latent,
@@ -194,4 +200,4 @@ def forward(
         )
 
         out = _SparseMLA.apply(sparse_q, sparse_kv, indices, self.scaling)
-        return self.output_proj(out, w_v), None
+        return self.output_proj(out, w_v), indices
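
Usage sketch (not part of the patch; the TOML section layout below is assumed from the config classes touched above, so treat it as illustrative): with these changes `index_topk_freq` can be set once on the shared model config. `auto_setup_model` copies it to the trainer and inference configs, the trainer applies it to the loaded Hugging Face config in `get_model`, and the inference server hands it to the vLLM plugin through the `PRIME_RL_INDEX_TOPK_FREQ` environment variable.

    [model]
    name = "path/to/glm-moe-dsa-checkpoint"  # hypothetical path; the knob only matters for sparse-indexer (DSA) models
    index_topk_freq = 4                      # compute top-k indices on every 4th layer, reuse them in between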