Skip to content

Commit 80d2e33

Browse files
hallerite and claude
committed
feat: add Gemma4 dense and MoE custom implementations
Bump transformers to 5.5.0 and vLLM to 0.19.0 for Gemma4 support.

Custom trainer implementations for both Gemma4-31B (dense) and Gemma4-26B-A4B (MoE) with:
- Hybrid sliding window + global attention with dual RoPE
- K=V sharing on global attention layers
- QKV norms (q/k with scale, v without)
- Attention scaling = 1.0
- Per-layer learnable scalar and scaled embeddings
- Logit softcapping (tanh at 30.0)
- MoE: shared expert + sparse experts in parallel, custom router
- MoE weight conversion (HF fused gate_up_proj → PrimeRL w1/w2/w3)

Unified registration: one Gemma4ForCausalLM handles both dense and MoE via the enable_moe_block config flag.

Also includes:
- ring_flash_attn compat shim for transformers 5.5 (removed symbol)
- vLLM 0.19 server.py API adaptation (removed run_api_server_worker_proc)
- perf.py None-guard for optional MoE config fields
- Test fix for Qwen3.5 VLM strict dataclass validation

Verified: 365 unit tests pass, SFT trains on both dense and MoE, RL e2e completes 2 steps on reverse-text with mismatch KL ~0.001.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent adbf84a commit 80d2e33

File tree

22 files changed

+1332
-473
lines changed

22 files changed

+1332
-473
lines changed

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ dependencies = [
1919
"torch>=2.9.0",
2020
"torchdata>=0.11.0",
2121
"transformers",
22-
"vllm>=0.17.0",
22+
"vllm>=0.19.0",
2323
"wandb>=0.24.2",
2424
"ring-flash-attn>=0.1.8",
2525
"prime>=0.5.37",
@@ -118,7 +118,7 @@ torch = { index = "pytorch-cu128" }
118118
verifiers = { git = "https://github.com/PrimeIntellect-ai/verifiers.git", rev = "adf8138" }
119119
torchtitan = { git = "https://github.com/pytorch/torchtitan", rev = "a1fdd7e" }
120120
dion = { git = "https://github.com/samsja/dion.git", rev = "d891eeb" }
121-
transformers = { git = "https://github.com/huggingface/transformers.git", rev = "5c1c72b" }
121+
transformers = { git = "https://github.com/huggingface/transformers.git", rev = "c1c3424" }
122122
flash-attn-4 = { git = "https://github.com/Dao-AILab/flash-attention.git", subdirectory = "flash_attn/cute", rev = "abd9943b" }
123123
pydantic-config = { git = "https://github.com/samsja/pydantic_config.git", branch = "main" }
124124
vllm-router = { url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.14/vllm_router-0.1.14-cp38-abi3-linux_x86_64.whl" }

src/prime_rl/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
import prime_rl._compat # noqa: F401 — must run before ring_flash_attn is imported

src/prime_rl/_compat.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
"""Compatibility shim: ring_flash_attn + transformers >= 5.4.
2+
3+
ring_flash_attn 0.1.8 imports `is_flash_attn_greater_or_equal_2_10` from
4+
`transformers.modeling_flash_attention_utils`. This symbol was removed from
5+
that module in transformers 5.4 (still available as a deprecated function
6+
in `transformers.utils.import_utils`, scheduled for removal in 5.8).
7+
8+
ring_flash_attn's except-branch is a no-op (imports the same symbol again),
9+
so the import crashes on transformers >= 5.4. We patch the symbol back in as
10+
`True` — the check is dead code since no one uses flash_attn < 2.1.0 anymore.
11+
12+
Upstream fix: https://github.com/zhuzilin/ring-flash-attention/pull/85
13+
Remove this shim once ring_flash_attn ships a fixed version.
14+
"""
15+
16+
import transformers.modeling_flash_attention_utils as _mfau
17+
18+
if not hasattr(_mfau, "is_flash_attn_greater_or_equal_2_10"):
19+
_mfau.is_flash_attn_greater_or_equal_2_10 = True

src/prime_rl/inference/patches.py

Lines changed: 3 additions & 138 deletions
Original file line numberDiff line numberDiff line change
@@ -64,41 +64,8 @@ def slice_lora_a(self, lora_a):
6464
MergedColumnParallelLinearWithShardedLoRA.slice_lora_a = slice_lora_a
6565

6666

67-
# Monkeypatch PrometheusStatLogger to avoid NotImplementedError for LoRA in DP mode
68-
def monkey_patch_prometheus_stat_logger_for_lora_in_dp_mode():
69-
from vllm.v1.metrics import loggers as vllm_metrics_loggers
70-
71-
_original_prometheus_stat_logger_init = vllm_metrics_loggers.PrometheusStatLogger.__init__
72-
73-
def _patched_prometheus_stat_logger_init(self, vllm_config, engine_indexes=None):
74-
"""Patched init that temporarily disables lora_config to skip the DP mode check."""
75-
original_lora_config = vllm_config.lora_config
76-
vllm_config.lora_config = None
77-
try:
78-
_original_prometheus_stat_logger_init(self, vllm_config, engine_indexes)
79-
finally:
80-
vllm_config.lora_config = original_lora_config
81-
# Re-initialize LoRA metrics if needed (after the DP check is bypassed)
82-
if original_lora_config is not None:
83-
self.labelname_max_lora = "max_lora"
84-
self.labelname_waiting_lora_adapters = "waiting_lora_adapters"
85-
self.labelname_running_lora_adapters = "running_lora_adapters"
86-
self.max_lora = original_lora_config.max_loras
87-
self.gauge_lora_info = vllm_metrics_loggers.PrometheusStatLogger._gauge_cls(
88-
name="vllm:lora_requests_info",
89-
documentation="Running stats on lora requests.",
90-
multiprocess_mode="sum",
91-
labelnames=[
92-
self.labelname_max_lora,
93-
self.labelname_waiting_lora_adapters,
94-
self.labelname_running_lora_adapters,
95-
],
96-
)
97-
98-
vllm_metrics_loggers.PrometheusStatLogger.__init__ = _patched_prometheus_stat_logger_init
99-
100-
10167
# Monkeypatch LoadLoRAAdapter to allow loading the same adapter multiple times
68+
# TODO: may be removable if we pass load_inplace=True (supported since vLLM 0.18, PR #31326)
10269
def monkey_patch_load_lora_adapter():
10370
from http import HTTPStatus
10471

@@ -153,6 +120,7 @@ async def _patched_load_lora_adapter(
153120

154121

155122
# Monkeypatch LRUCacheWorkerLoRAManager to allow loading adapter inplace without doing it every request
123+
# TODO: may be removable if we pass load_inplace=True (supported since vLLM 0.18, PR #31326)
156124
def monkey_patch_LRUCacheWorkerLoRAManager():
157125
from vllm.lora.worker_manager import LoRARequest, LRUCacheLoRAModelManager, LRUCacheWorkerLoRAManager
158126

@@ -278,109 +246,6 @@ def _patched_get_encode_kwargs(self):
278246
TokenizeParams.get_encode_kwargs = _patched_get_encode_kwargs
279247

280248

281-
def monkey_patch_hermes_tool_parser_thread_safety():
282-
"""Patch Hermes2ProToolParser to cache tokenizer encode/decode results.
283-
284-
The original __init__ calls tokenizer.encode() and tokenizer.decode() on
285-
every instantiation. Under concurrent load, the shared HuggingFace tokenizer's
286-
Rust backend panics with ``RuntimeError: Already borrowed`` because multiple
287-
threads mutably borrow the same internal state simultaneously.
288-
289-
Fix: run the first __init__ (which calls encode/decode) under a lock, cache
290-
the results, and reuse them for all subsequent instantiations without ever
291-
touching the tokenizer again.
292-
"""
293-
import threading
294-
295-
import regex as re
296-
from vllm.tool_parsers.abstract_tool_parser import ToolParser
297-
from vllm.tool_parsers.hermes_tool_parser import Hermes2ProToolParser
298-
299-
_original_init = Hermes2ProToolParser.__init__
300-
_cache: dict[int, dict] = {}
301-
_lock = threading.Lock()
302-
303-
def _patched_init(self, tokenizer):
304-
from vllm.tokenizers.mistral import MistralTokenizer
305-
306-
# Resolve the actual tokenizer that __init__ will use for encode/decode
307-
actual_tokenizer = tokenizer.tokenizer if isinstance(tokenizer, MistralTokenizer) else tokenizer
308-
key = id(actual_tokenizer)
309-
310-
if key in _cache:
311-
# Fast path: skip encode/decode entirely, set up instance from cache
312-
ToolParser.__init__(self, tokenizer)
313-
if isinstance(tokenizer, MistralTokenizer):
314-
self.model_tokenizer = tokenizer.tokenizer
315-
self.current_tool_name_sent = False
316-
self.prev_tool_call_arr = []
317-
self.current_tool_id = -1
318-
self.streamed_args_for_tool = []
319-
self.tool_call_start_token = "<tool_call>"
320-
self.tool_call_end_token = "</tool_call>"
321-
self.tool_call_regex = re.compile(r"<tool_call>(.*?)</tool_call>|<tool_call>(.*)", re.DOTALL)
322-
self.scratch_pad_regex = re.compile(r"<scratch_pad>(.*?)</scratch_pad>", re.DOTALL)
323-
cached = _cache[key]
324-
self.tool_call_start_token_ids = cached["start_ids"]
325-
self.tool_call_end_token_ids = cached["end_ids"]
326-
self.tool_call_start_token_array = cached["start_array"]
327-
self.tool_call_end_token_array = cached["end_array"]
328-
self.buffered_delta_text = ""
329-
return
330-
331-
# Slow path: first instantiation for this tokenizer, run under lock
332-
with _lock:
333-
if key in _cache:
334-
# Another thread populated it while we waited
335-
_patched_init(self, tokenizer)
336-
return
337-
_original_init(self, tokenizer)
338-
_cache[key] = {
339-
"start_ids": self.tool_call_start_token_ids,
340-
"end_ids": self.tool_call_end_token_ids,
341-
"start_array": self.tool_call_start_token_array,
342-
"end_array": self.tool_call_end_token_array,
343-
}
344-
345-
Hermes2ProToolParser.__init__ = _patched_init
346-
347-
348-
def monkey_patch_tokenizer_thread_safety():
349-
"""Patch HuggingFace tokenizer to make _encode_plus thread-safe.
350-
351-
Under concurrent request load, vLLM's API server calls _encode_plus from
352-
multiple async handlers simultaneously. _encode_plus mutates the Rust
353-
tokenizer's internal state via set_truncation_and_padding (enable_truncation/
354-
enable_padding) and encode_special_tokens. The Rust backend uses RefCell-style
355-
borrow tracking (PyO3), and concurrent mutable borrows cause it to panic
356-
with ``RuntimeError: Already borrowed``.
357-
358-
Fix: wrap the entire _encode_plus method in a per-tokenizer threading lock
359-
so that state mutation and the subsequent encode call are atomic.
360-
"""
361-
import threading
362-
363-
from transformers import PreTrainedTokenizerFast
364-
365-
_original_encode_plus = PreTrainedTokenizerFast._encode_plus
366-
_locks: dict[int, threading.Lock] = {}
367-
_meta_lock = threading.Lock()
368-
369-
def _get_lock(tokenizer_id: int) -> threading.Lock:
370-
if tokenizer_id not in _locks:
371-
with _meta_lock:
372-
if tokenizer_id not in _locks:
373-
_locks[tokenizer_id] = threading.Lock()
374-
return _locks[tokenizer_id]
375-
376-
def _patched_encode_plus(self, *args, **kwargs):
377-
lock = _get_lock(id(self._tokenizer))
378-
with lock:
379-
return _original_encode_plus(self, *args, **kwargs)
380-
381-
PreTrainedTokenizerFast._encode_plus = _patched_encode_plus
382-
383-
384249
def monkey_patch_minimax_m2_for_lora():
385250
"""Patch vLLM's MiniMaxM2 model for LoRA compatibility.
386251
@@ -457,7 +322,7 @@ def _patched_forward(self, hidden_states):
457322

458323

459324
def monkey_patch_harmony_stop_token_propagation():
460-
"""Fix: vLLM 0.17.0 doesn't merge harmony stop tokens into per-request SamplingParams.
325+
"""Fix: vLLM doesn't merge harmony stop tokens into per-request SamplingParams.
461326
462327
The harmony mode sets stop_token_ids (including <|call|> and <|return|>) in
463328
default_sampling_params at server init, but ChatCompletionRequest.to_sampling_params()

src/prime_rl/inference/vllm/server.py

Lines changed: 14 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,6 @@
77
from fastapi.responses import JSONResponse, StreamingResponse
88
from starlette.datastructures import State
99
from vllm.engine.protocol import EngineClient
10-
from vllm.entrypoints.chat_utils import load_chat_template
11-
from vllm.entrypoints.cli.serve import run_api_server_worker_proc
12-
from vllm.entrypoints.logger import RequestLogger
1310
from vllm.entrypoints.openai.api_server import init_app_state
1411
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionResponse
1512
from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
@@ -136,30 +133,23 @@ def resolve_tool_call_parser(model_name: str, tool_call_parser: str | None) -> s
136133
logger = get_logger()
137134
from prime_rl.inference.patches import (
138135
monkey_patch_harmony_stop_token_propagation,
139-
monkey_patch_hermes_tool_parser_thread_safety,
140136
monkey_patch_load_lora_adapter,
141-
monkey_patch_prometheus_stat_logger_for_lora_in_dp_mode,
142137
monkey_patch_tokenize_params_validation,
143-
monkey_patch_tokenizer_thread_safety,
144138
)
145139
from prime_rl.inference.vllm.serving_chat_with_tokens import (
146140
ChatCompletionRequestWithTokens,
147141
OpenAIServingChatWithTokens,
148142
)
149143

150-
# NOTE: Fix harmony stop token propagation for GPT-OSS models (vLLM 0.17.0 bug)
144+
# NOTE: Fix harmony stop token propagation for GPT-OSS models
145+
# Upstream issue still open: https://github.com/vllm-project/vllm/issues/22519
151146
monkey_patch_harmony_stop_token_propagation()
152-
# NOTE: Monkeypatch PrometheusStatLogger to avoid NotImplementedError for LoRA in DP mode
153-
monkey_patch_prometheus_stat_logger_for_lora_in_dp_mode()
154147
# NOTE: Monkeypatch LoadLoRAAdapter to allow loading the same adapter multiple times
148+
# May be removable if we pass load_inplace=True (supported since vLLM 0.18, PR #31326)
155149
monkey_patch_load_lora_adapter()
156150
# NOTE: Monkeypatch TokenizeParams to fix overly conservative validation
151+
# Still needed in vLLM 0.19 — upstream rejects prompt_len > max_model_len - max_tokens
157152
monkey_patch_tokenize_params_validation()
158-
# NOTE: Monkeypatch Hermes tool parser to fix "Already borrowed" RuntimeError under concurrent load
159-
monkey_patch_hermes_tool_parser_thread_safety()
160-
# NOTE: Monkeypatch HF tokenizer to fix "Already borrowed" RuntimeError during concurrent chat template processing
161-
# Can be removed once https://github.com/vllm-project/vllm/pull/36557 is merged and we upgrade vllm
162-
monkey_patch_tokenizer_thread_safety()
163153

164154
logger = init_logger("vllm.entrypoints.openai.api_server")
165155

@@ -279,82 +269,34 @@ async def custom_init_app_state(
279269
):
280270
"""
281271
Modifies init_app_state:
282-
1. Set up the custom OpenAIServingChatWithTokens state.
283-
2. Monkey-patch to allow updating lora adapters in-place.
272+
1. Call the original init_app_state to set up standard state.
273+
2. Replace the serving_chat with our OpenAIServingChatWithTokens wrapper.
284274
"""
285-
# Setup the regular app state first (in-place)
286275
await init_app_state(engine_client, state, args, supported_tasks)
287276

288-
# NOTE: Initialize the custom OpenAIServingChatWithTokens state here
289-
# TODO: Here, we repeat some calls done in init_app_state to be able to
290-
# correctly set up the OpenAIServingChatWithTokens state, which is a bit
291-
# brittle, and could probably be made nicer
292-
if args.enable_log_requests:
293-
request_logger = RequestLogger(max_log_len=args.max_log_len)
294-
else:
295-
request_logger = None
296-
297-
resolved_chat_template = load_chat_template(args.chat_template)
298-
299-
chat_kwargs = dict(
300-
request_logger=request_logger,
301-
chat_template=resolved_chat_template,
302-
chat_template_content_format=args.chat_template_content_format,
303-
trust_request_chat_template=args.trust_request_chat_template,
304-
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
305-
enable_auto_tools=args.enable_auto_tool_choice,
306-
exclude_tools_when_tool_choice_none=args.exclude_tools_when_tool_choice_none,
307-
tool_parser=args.tool_call_parser,
308-
reasoning_parser=args.structured_outputs_config.reasoning_parser,
309-
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
310-
enable_force_include_usage=args.enable_force_include_usage,
311-
enable_log_outputs=args.enable_log_outputs,
312-
)
313-
if hasattr(args, "log_error_stack"):
314-
chat_kwargs["log_error_stack"] = args.log_error_stack
315-
316-
serving_chat = OpenAIServingChatWithTokens(
317-
engine_client,
318-
state.openai_serving_models,
319-
args.response_role,
320-
**chat_kwargs,
321-
)
322-
state.openai_serving_chat = serving_chat if "generate" in supported_tasks else None
323-
state.openai_serving_chat_with_tokens = serving_chat if "generate" in supported_tasks else None
324-
325-
326-
def custom_run_api_server_worker_proc(listen_address, sock, args, client_config=None, **uvicorn_kwargs) -> None:
327-
"""
328-
Modifies run_api_server_worker_proc:
329-
1. Re-import our module to ensure monkey patches are applied in child processes
330-
"""
331-
# NOTE: This hack ensures that monkey patches are applied in child processes
332-
# to make our custom routes work in multi-API-server settings.
333-
import prime_rl.inference.vllm.server # noqa: F401
334-
335-
run_api_server_worker_proc(listen_address, sock, args, client_config, **uvicorn_kwargs)
277+
if "generate" in supported_tasks and state.openai_serving_chat is not None:
278+
original_chat = state.openai_serving_chat
279+
serving_chat = object.__new__(OpenAIServingChatWithTokens)
280+
serving_chat.__dict__.update(original_chat.__dict__)
281+
state.openai_serving_chat = serving_chat
282+
state.openai_serving_chat_with_tokens = serving_chat
336283

337284

338-
import vllm.entrypoints.cli.serve
339285
import vllm.entrypoints.openai.api_server
340286
from vllm.entrypoints.openai.api_server import build_app as _original_build_app
341287

342288

343-
def custom_build_app(args: Namespace, supported_tasks: tuple):
289+
def custom_build_app(args: Namespace, supported_tasks: tuple, model_config=None):
344290
"""
345291
Wrap build_app to include our custom router.
346292
"""
347-
app = _original_build_app(args, supported_tasks)
293+
app = _original_build_app(args, supported_tasks, model_config)
348294
app.include_router(router)
349295
return app
350296

351297

352-
# Also monkey patch run_api_server_worker_proc for multi-api-server mode
353-
# This is needed because worker processes spawned by run_multi_api_server
354-
# re-import modules and would otherwise use the original run_server_worker
355298
vllm.entrypoints.openai.api_server.init_app_state = custom_init_app_state
356299
vllm.entrypoints.openai.api_server.build_app = custom_build_app
357-
vllm.entrypoints.cli.serve.run_api_server_worker_proc = custom_run_api_server_worker_proc
358300

359301

360302
# Adapted from vllm/entrypoints/cli/serve.py

src/prime_rl/inference/vllm/serving_chat_with_tokens.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ async def create_chat_completion_with_tokens(
250250
request_metadata,
251251
reasoning_parser,
252252
)
253-
except GenerationError as e:
254-
return self._convert_generation_error_to_response(e)
253+
except GenerationError:
254+
raise # Let FastAPI's global generation_error_handler handle it
255255
except ValueError as e:
256256
return self.create_error_response(e)

src/prime_rl/trainer/models/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from prime_rl.trainer.models.afmoe import AfmoeConfig, AfmoeForCausalLM
1212
from prime_rl.trainer.models.base import PreTrainedModelPrimeRL
13+
from prime_rl.trainer.models.gemma4 import Gemma4ForCausalLM, Gemma4TextConfig
1314
from prime_rl.trainer.models.glm4_moe import Glm4MoeConfig, Glm4MoeForCausalLM
1415
from prime_rl.trainer.models.glm_moe_dsa import GlmMoeDsaConfig, GlmMoeDsaForCausalLM
1516
from prime_rl.trainer.models.layers.lm_head import PrimeLmOutput, cast_float_and_contiguous
@@ -20,6 +21,7 @@
2021
from prime_rl.trainer.models.qwen3_moe import Qwen3MoeConfig, Qwen3MoeForCausalLM
2122

2223
# Make custom config discoverable by AutoConfig
24+
AutoConfig.register("gemma4_text", Gemma4TextConfig, exist_ok=True)
2325
AutoConfig.register("afmoe", AfmoeConfig, exist_ok=True)
2426
AutoConfig.register("glm4_moe", Glm4MoeConfig, exist_ok=True)
2527
AutoConfig.register("glm_moe_dsa", GlmMoeDsaConfig, exist_ok=True)
@@ -29,6 +31,7 @@
2931
AutoConfig.register("qwen3_5_moe_text", Qwen3_5MoeConfig, exist_ok=True)
3032

3133
_CUSTOM_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, OrderedDict())
34+
_CUSTOM_CAUSAL_LM_MAPPING.register(Gemma4TextConfig, Gemma4ForCausalLM, exist_ok=True)
3235
_CUSTOM_CAUSAL_LM_MAPPING.register(LlamaConfig, LlamaForCausalLM, exist_ok=True)
3336
_CUSTOM_CAUSAL_LM_MAPPING.register(AfmoeConfig, AfmoeForCausalLM, exist_ok=True)
3437
_CUSTOM_CAUSAL_LM_MAPPING.register(Glm4MoeConfig, Glm4MoeForCausalLM, exist_ok=True)

src/prime_rl/trainer/models/afmoe/configuration_afmoe.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from transformers.configuration_utils import PretrainedConfig, layer_type_validation
2-
from transformers.modeling_rope_utils import rope_config_validation
32
from transformers.utils import logging
43

54
logger = logging.get_logger(__name__)
@@ -106,7 +105,7 @@ def __init__(
106105
# Validate rope configs
107106
if self.rope_scaling is not None and "type" in self.rope_scaling:
108107
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
109-
rope_config_validation(self)
108+
self.standardize_rope_params()
110109

111110
super().__init__(
112111
tie_word_embeddings=tie_word_embeddings,

0 commit comments

Comments
 (0)