Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ dependencies = [
"torch>=2.9.0",
"torchdata>=0.11.0",
"transformers",
"vllm>=0.17.0",
"vllm>=0.19.0",
"wandb>=0.24.2",
"ring-flash-attn>=0.1.8",
"prime>=0.5.37",
Expand Down Expand Up @@ -123,7 +123,7 @@ torch = { index = "pytorch-cu128" }
verifiers = { git = "https://github.com/PrimeIntellect-ai/verifiers.git", rev = "d3c830c" }
torchtitan = { git = "https://github.com/pytorch/torchtitan", rev = "a1fdd7e" }
dion = { git = "https://github.com/samsja/dion.git", rev = "d891eeb" }
transformers = { git = "https://github.com/huggingface/transformers.git", rev = "5c1c72b" }
transformers = { git = "https://github.com/huggingface/transformers.git", rev = "c1c3424" }
flash-attn-4 = { git = "https://github.com/Dao-AILab/flash-attention.git", subdirectory = "flash_attn/cute", rev = "abd9943b" }
pydantic-config = { git = "https://github.com/samsja/pydantic_config.git", branch = "main" }
vllm-router = { url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.14/vllm_router-0.1.14-cp38-abi3-linux_x86_64.whl" }
Expand Down
1 change: 1 addition & 0 deletions src/prime_rl/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
import prime_rl._compat # noqa: F401 — must run before ring_flash_attn is imported
19 changes: 19 additions & 0 deletions src/prime_rl/_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""Compatibility shim: ring_flash_attn + transformers >= 5.4.

ring_flash_attn 0.1.8 imports `is_flash_attn_greater_or_equal_2_10` from
`transformers.modeling_flash_attention_utils`. This symbol was removed from
that module in transformers 5.4 (still available as a deprecated function
in `transformers.utils.import_utils`, scheduled for removal in 5.8).

ring_flash_attn's except-branch is a no-op (imports the same symbol again),
so the import crashes on transformers >= 5.4. We patch the symbol back in as
`True` — the check is dead code since no one uses flash_attn < 2.1.0 anymore.

Upstream fix: https://github.com/zhuzilin/ring-flash-attention/pull/85
Remove this shim once ring_flash_attn ships a fixed version.
"""

import transformers.modeling_flash_attention_utils as _mfau

if not hasattr(_mfau, "is_flash_attn_greater_or_equal_2_10"):
_mfau.is_flash_attn_greater_or_equal_2_10 = True
141 changes: 3 additions & 138 deletions src/prime_rl/inference/patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,41 +64,8 @@ def slice_lora_a(self, lora_a):
MergedColumnParallelLinearWithShardedLoRA.slice_lora_a = slice_lora_a


# Monkeypatch PrometheusStatLogger to avoid NotImplementedError for LoRA in DP mode
def monkey_patch_prometheus_stat_logger_for_lora_in_dp_mode():
    """Patch PrometheusStatLogger.__init__ to tolerate LoRA in data-parallel mode.

    vLLM's PrometheusStatLogger raises when a lora_config is present in DP
    mode. This patch hides the lora_config during the original __init__ (so
    the DP check never sees it), restores it afterwards, and then manually
    recreates the LoRA-specific metric state that the original __init__ would
    have set up had the check passed.
    """
    from vllm.v1.metrics import loggers as vllm_metrics_loggers

    # Keep a reference to the unpatched __init__ so we can delegate to it.
    _original_prometheus_stat_logger_init = vllm_metrics_loggers.PrometheusStatLogger.__init__

    def _patched_prometheus_stat_logger_init(self, vllm_config, engine_indexes=None):
        """Patched init that temporarily disables lora_config to skip the DP mode check."""
        original_lora_config = vllm_config.lora_config
        # Null out lora_config so the original __init__'s DP+LoRA check is a no-op.
        vllm_config.lora_config = None
        try:
            _original_prometheus_stat_logger_init(self, vllm_config, engine_indexes)
        finally:
            # Always restore the shared config object, even if __init__ raised.
            vllm_config.lora_config = original_lora_config
        # Re-initialize LoRA metrics if needed (after the DP check is bypassed)
        if original_lora_config is not None:
            # NOTE(review): these attribute names mirror what vLLM's own
            # LoRA-enabled __init__ path sets — keep in sync with upstream.
            self.labelname_max_lora = "max_lora"
            self.labelname_waiting_lora_adapters = "waiting_lora_adapters"
            self.labelname_running_lora_adapters = "running_lora_adapters"
            self.max_lora = original_lora_config.max_loras
            self.gauge_lora_info = vllm_metrics_loggers.PrometheusStatLogger._gauge_cls(
                name="vllm:lora_requests_info",
                documentation="Running stats on lora requests.",
                multiprocess_mode="sum",
                labelnames=[
                    self.labelname_max_lora,
                    self.labelname_waiting_lora_adapters,
                    self.labelname_running_lora_adapters,
                ],
            )

    vllm_metrics_loggers.PrometheusStatLogger.__init__ = _patched_prometheus_stat_logger_init


# Monkeypatch LoadLoRAAdapter to allow loading the same adapter multiple times
# TODO: may be removable if we pass load_inplace=True (supported since vLLM 0.18, PR #31326)
def monkey_patch_load_lora_adapter():
from http import HTTPStatus

Expand Down Expand Up @@ -153,6 +120,7 @@ async def _patched_load_lora_adapter(


# Monkeypatch LRUCacheWorkerLoRAManager to allow loading adapter inplace without doing it every request
# TODO: may be removable if we pass load_inplace=True (supported since vLLM 0.18, PR #31326)
def monkey_patch_LRUCacheWorkerLoRAManager():
from vllm.lora.worker_manager import LoRARequest, LRUCacheLoRAModelManager, LRUCacheWorkerLoRAManager

Expand Down Expand Up @@ -278,109 +246,6 @@ def _patched_get_encode_kwargs(self):
TokenizeParams.get_encode_kwargs = _patched_get_encode_kwargs


def monkey_patch_hermes_tool_parser_thread_safety():
    """Patch Hermes2ProToolParser to cache tokenizer encode/decode results.

    The original __init__ calls tokenizer.encode() and tokenizer.decode() on
    every instantiation. Under concurrent load, the shared HuggingFace tokenizer's
    Rust backend panics with ``RuntimeError: Already borrowed`` because multiple
    threads mutably borrow the same internal state simultaneously.

    Fix: run the first __init__ (which calls encode/decode) under a lock, cache
    the results, and reuse them for all subsequent instantiations without ever
    touching the tokenizer again.
    """
    import threading

    import regex as re
    from vllm.tool_parsers.abstract_tool_parser import ToolParser
    from vllm.tool_parsers.hermes_tool_parser import Hermes2ProToolParser

    _original_init = Hermes2ProToolParser.__init__
    # Cache of tokenizer-derived token ids/arrays, keyed by id(tokenizer).
    # NOTE(review): id() keys can alias if a tokenizer is garbage-collected and
    # its address reused; acceptable here since tokenizers live for the server's
    # lifetime — confirm if tokenizers ever become short-lived.
    _cache: dict[int, dict] = {}
    _lock = threading.Lock()

    def _patched_init(self, tokenizer):
        from vllm.tokenizers.mistral import MistralTokenizer

        # Resolve the actual tokenizer that __init__ will use for encode/decode
        actual_tokenizer = tokenizer.tokenizer if isinstance(tokenizer, MistralTokenizer) else tokenizer
        key = id(actual_tokenizer)

        if key in _cache:
            # Fast path: skip encode/decode entirely, set up instance from cache
            ToolParser.__init__(self, tokenizer)
            if isinstance(tokenizer, MistralTokenizer):
                self.model_tokenizer = tokenizer.tokenizer
            # Per-instance streaming state, mirroring the original __init__.
            self.current_tool_name_sent = False
            self.prev_tool_call_arr = []
            self.current_tool_id = -1
            self.streamed_args_for_tool = []
            # Constant markers/regexes — runtime strings reproduced verbatim.
            self.tool_call_start_token = "<tool_call>"
            self.tool_call_end_token = "</tool_call>"
            self.tool_call_regex = re.compile(r"<tool_call>(.*?)</tool_call>|<tool_call>(.*)", re.DOTALL)
            self.scratch_pad_regex = re.compile(r"<scratch_pad>(.*?)</scratch_pad>", re.DOTALL)
            cached = _cache[key]
            # Tokenizer-derived values: the only fields that needed encode/decode.
            self.tool_call_start_token_ids = cached["start_ids"]
            self.tool_call_end_token_ids = cached["end_ids"]
            self.tool_call_start_token_array = cached["start_array"]
            self.tool_call_end_token_array = cached["end_array"]
            self.buffered_delta_text = ""
            return

        # Slow path: first instantiation for this tokenizer, run under lock
        with _lock:
            if key in _cache:
                # Another thread populated it while we waited; recurse to take
                # the fast path above (no further locking needed).
                _patched_init(self, tokenizer)
                return
            _original_init(self, tokenizer)
            _cache[key] = {
                "start_ids": self.tool_call_start_token_ids,
                "end_ids": self.tool_call_end_token_ids,
                "start_array": self.tool_call_start_token_array,
                "end_array": self.tool_call_end_token_array,
            }

    Hermes2ProToolParser.__init__ = _patched_init


def monkey_patch_tokenizer_thread_safety():
    """Patch HuggingFace tokenizer to make _encode_plus thread-safe.

    Under concurrent request load, vLLM's API server calls _encode_plus from
    multiple async handlers simultaneously. _encode_plus mutates the Rust
    tokenizer's internal state via set_truncation_and_padding (enable_truncation/
    enable_padding) and encode_special_tokens. The Rust backend uses RefCell-style
    borrow tracking (PyO3), and concurrent mutable borrows cause it to panic
    with ``RuntimeError: Already borrowed``.

    Fix: wrap the entire _encode_plus method in a per-tokenizer threading lock
    so that state mutation and the subsequent encode call are atomic.
    """
    import threading

    from transformers import PreTrainedTokenizerFast

    _original_encode_plus = PreTrainedTokenizerFast._encode_plus
    # One lock per underlying Rust tokenizer, keyed by id(self._tokenizer).
    # NOTE(review): this dict grows for every distinct tokenizer instance and is
    # never pruned — fine if tokenizers are long-lived; confirm for LoRA churn.
    _locks: dict[int, threading.Lock] = {}
    # Guards creation of entries in _locks itself.
    _meta_lock = threading.Lock()

    def _get_lock(tokenizer_id: int) -> threading.Lock:
        # Double-checked locking: the unguarded membership test is safe under
        # the GIL (dict reads are atomic); the re-check inside _meta_lock
        # prevents two threads from creating duplicate locks.
        if tokenizer_id not in _locks:
            with _meta_lock:
                if tokenizer_id not in _locks:
                    _locks[tokenizer_id] = threading.Lock()
        return _locks[tokenizer_id]

    def _patched_encode_plus(self, *args, **kwargs):
        # Serialize state mutation + encode so they appear atomic per tokenizer.
        lock = _get_lock(id(self._tokenizer))
        with lock:
            return _original_encode_plus(self, *args, **kwargs)

    PreTrainedTokenizerFast._encode_plus = _patched_encode_plus


def monkey_patch_minimax_m2_for_lora():
"""Patch vLLM's MiniMaxM2 model for LoRA compatibility.

Expand Down Expand Up @@ -457,7 +322,7 @@ def _patched_forward(self, hidden_states):


def monkey_patch_harmony_stop_token_propagation():
"""Fix: vLLM 0.17.0 doesn't merge harmony stop tokens into per-request SamplingParams.
"""Fix: vLLM doesn't merge harmony stop tokens into per-request SamplingParams.

The harmony mode sets stop_token_ids (including <|call|> and <|return|>) in
default_sampling_params at server init, but ChatCompletionRequest.to_sampling_params()
Expand Down
86 changes: 14 additions & 72 deletions src/prime_rl/inference/vllm/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@
from fastapi.responses import JSONResponse, StreamingResponse
from starlette.datastructures import State
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import load_chat_template
from vllm.entrypoints.cli.serve import run_api_server_worker_proc
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.api_server import init_app_state
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionResponse
from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
Expand Down Expand Up @@ -136,30 +133,23 @@ def resolve_tool_call_parser(model_name: str, tool_call_parser: str | None) -> s
logger = get_logger()
from prime_rl.inference.patches import (
monkey_patch_harmony_stop_token_propagation,
monkey_patch_hermes_tool_parser_thread_safety,
monkey_patch_load_lora_adapter,
monkey_patch_prometheus_stat_logger_for_lora_in_dp_mode,
monkey_patch_tokenize_params_validation,
monkey_patch_tokenizer_thread_safety,
)
from prime_rl.inference.vllm.serving_chat_with_tokens import (
ChatCompletionRequestWithTokens,
OpenAIServingChatWithTokens,
)

# NOTE: Fix harmony stop token propagation for GPT-OSS models (vLLM 0.17.0 bug)
# NOTE: Fix harmony stop token propagation for GPT-OSS models
# Upstream issue still open: https://github.com/vllm-project/vllm/issues/22519
monkey_patch_harmony_stop_token_propagation()
# NOTE: Monkeypatch PrometheusStatLogger to avoid NotImplementedError for LoRA in DP mode
monkey_patch_prometheus_stat_logger_for_lora_in_dp_mode()
# NOTE: Monkeypatch LoadLoRAAdapter to allow loading the same adapter multiple times
# May be removable if we pass load_inplace=True (supported since vLLM 0.18, PR #31326)
monkey_patch_load_lora_adapter()
# NOTE: Monkeypatch TokenizeParams to fix overly conservative validation
# Still needed in vLLM 0.19 — upstream rejects prompt_len > max_model_len - max_tokens
monkey_patch_tokenize_params_validation()
# NOTE: Monkeypatch Hermes tool parser to fix "Already borrowed" RuntimeError under concurrent load
monkey_patch_hermes_tool_parser_thread_safety()
# NOTE: Monkeypatch HF tokenizer to fix "Already borrowed" RuntimeError during concurrent chat template processing
# Can be removed once https://github.com/vllm-project/vllm/pull/36557 is merged and we upgrade vllm
monkey_patch_tokenizer_thread_safety()

logger = init_logger("vllm.entrypoints.openai.api_server")

Expand Down Expand Up @@ -279,82 +269,34 @@ async def custom_init_app_state(
):
"""
Modifies init_app_state:
1. Set up the custom OpenAIServingChatWithTokens state.
2. Monkey-patch to allow updating lora adapters in-place.
1. Call the original init_app_state to set up standard state.
2. Replace the serving_chat with our OpenAIServingChatWithTokens wrapper.
"""
# Setup the regular app state first (in-place)
await init_app_state(engine_client, state, args, supported_tasks)

# NOTE: Initialize the custom OpenAIServingChatWithTokens state here
# TODO: Here, we repeat some calls done in init_app_state to be able to
# correctly set up the OpenAIServingChatWithTokens state, which is a bit
# brittle, and could probably be made nicer
if args.enable_log_requests:
request_logger = RequestLogger(max_log_len=args.max_log_len)
else:
request_logger = None

resolved_chat_template = load_chat_template(args.chat_template)

chat_kwargs = dict(
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
trust_request_chat_template=args.trust_request_chat_template,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_auto_tools=args.enable_auto_tool_choice,
exclude_tools_when_tool_choice_none=args.exclude_tools_when_tool_choice_none,
tool_parser=args.tool_call_parser,
reasoning_parser=args.structured_outputs_config.reasoning_parser,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage,
enable_log_outputs=args.enable_log_outputs,
)
if hasattr(args, "log_error_stack"):
chat_kwargs["log_error_stack"] = args.log_error_stack

serving_chat = OpenAIServingChatWithTokens(
engine_client,
state.openai_serving_models,
args.response_role,
**chat_kwargs,
)
state.openai_serving_chat = serving_chat if "generate" in supported_tasks else None
state.openai_serving_chat_with_tokens = serving_chat if "generate" in supported_tasks else None


def custom_run_api_server_worker_proc(listen_address, sock, args, client_config=None, **uvicorn_kwargs) -> None:
    """
    Modifies run_api_server_worker_proc:
    1. Re-import our module to ensure monkey patches are applied in child processes

    Delegates all arguments unchanged to vLLM's run_api_server_worker_proc;
    the only added behavior is the side-effect import below.
    """
    # NOTE: This hack ensures that monkey patches are applied in child processes
    # to make our custom routes work in multi-API-server settings.
    # (Worker processes are spawned fresh, so module-level patching in the
    # parent does not carry over; importing the server module re-applies it.)
    import prime_rl.inference.vllm.server  # noqa: F401

    run_api_server_worker_proc(listen_address, sock, args, client_config, **uvicorn_kwargs)
if "generate" in supported_tasks and state.openai_serving_chat is not None:
original_chat = state.openai_serving_chat
serving_chat = object.__new__(OpenAIServingChatWithTokens)
serving_chat.__dict__.update(original_chat.__dict__)
state.openai_serving_chat = serving_chat
state.openai_serving_chat_with_tokens = serving_chat
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing state attribute causes AttributeError on fallback path

Medium Severity

The new custom_init_app_state only sets state.openai_serving_chat_with_tokens inside the if "generate" in supported_tasks and state.openai_serving_chat is not None block. When this condition is false, the attribute is never set. The old code always set it (to None in the else case). The chat_with_tokens dependency at line 179 directly accesses request.app.state.openai_serving_chat_with_tokens, which will raise AttributeError on Starlette's State object if the attribute was never assigned.

Fix in Cursor Fix in Web



import vllm.entrypoints.cli.serve
import vllm.entrypoints.openai.api_server
from vllm.entrypoints.openai.api_server import build_app as _original_build_app


def custom_build_app(args: Namespace, supported_tasks: tuple):
def custom_build_app(args: Namespace, supported_tasks: tuple, model_config=None):
"""
Wrap build_app to include our custom router.
"""
app = _original_build_app(args, supported_tasks)
app = _original_build_app(args, supported_tasks, model_config)
app.include_router(router)
return app


# Also monkey patch run_api_server_worker_proc for multi-api-server mode
# This is needed because worker processes spawned by run_multi_api_server
# re-import modules and would otherwise use the original run_server_worker
vllm.entrypoints.openai.api_server.init_app_state = custom_init_app_state
vllm.entrypoints.openai.api_server.build_app = custom_build_app
vllm.entrypoints.cli.serve.run_api_server_worker_proc = custom_run_api_server_worker_proc


# Adapted from vllm/entrypoints/cli/serve.py
Expand Down
4 changes: 2 additions & 2 deletions src/prime_rl/inference/vllm/serving_chat_with_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ async def create_chat_completion_with_tokens(
request_metadata,
reasoning_parser,
)
except GenerationError as e:
return self._convert_generation_error_to_response(e)
except GenerationError:
raise # Let FastAPI's global generation_error_handler handle it
except ValueError as e:
return self.create_error_response(e)
3 changes: 3 additions & 0 deletions src/prime_rl/trainer/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from prime_rl.trainer.models.afmoe import AfmoeConfig, AfmoeForCausalLM
from prime_rl.trainer.models.base import PreTrainedModelPrimeRL
from prime_rl.trainer.models.gemma4 import Gemma4ForCausalLM, Gemma4TextConfig
from prime_rl.trainer.models.glm4_moe import Glm4MoeConfig, Glm4MoeForCausalLM
from prime_rl.trainer.models.glm_moe_dsa import GlmMoeDsaConfig, GlmMoeDsaForCausalLM
from prime_rl.trainer.models.layers.lm_head import PrimeLmOutput, cast_float_and_contiguous
Expand All @@ -20,6 +21,7 @@
from prime_rl.trainer.models.qwen3_moe import Qwen3MoeConfig, Qwen3MoeForCausalLM

# Make custom config discoverable by AutoConfig
AutoConfig.register("gemma4_text", Gemma4TextConfig, exist_ok=True)
AutoConfig.register("afmoe", AfmoeConfig, exist_ok=True)
AutoConfig.register("glm4_moe", Glm4MoeConfig, exist_ok=True)
AutoConfig.register("glm_moe_dsa", GlmMoeDsaConfig, exist_ok=True)
Expand All @@ -29,6 +31,7 @@
AutoConfig.register("qwen3_5_moe_text", Qwen3_5MoeConfig, exist_ok=True)

_CUSTOM_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, OrderedDict())
_CUSTOM_CAUSAL_LM_MAPPING.register(Gemma4TextConfig, Gemma4ForCausalLM, exist_ok=True)
_CUSTOM_CAUSAL_LM_MAPPING.register(LlamaConfig, LlamaForCausalLM, exist_ok=True)
_CUSTOM_CAUSAL_LM_MAPPING.register(AfmoeConfig, AfmoeForCausalLM, exist_ok=True)
_CUSTOM_CAUSAL_LM_MAPPING.register(Glm4MoeConfig, Glm4MoeForCausalLM, exist_ok=True)
Expand Down
3 changes: 1 addition & 2 deletions src/prime_rl/trainer/models/afmoe/configuration_afmoe.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from transformers.configuration_utils import PretrainedConfig, layer_type_validation
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging

logger = logging.get_logger(__name__)
Expand Down Expand Up @@ -106,7 +105,7 @@ def __init__(
# Validate rope configs
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)
self.standardize_rope_params()

super().__init__(
tie_word_embeddings=tie_word_embeddings,
Expand Down
Loading
Loading