diff --git a/configs/gemma4/rl_dense.toml b/configs/gemma4/rl_dense.toml
new file mode 100644
index 0000000000..05e48cc74e
--- /dev/null
+++ b/configs/gemma4/rl_dense.toml
@@ -0,0 +1,39 @@
+# Gemma4 31B dense RL test (reverse-text)
+output_dir = "outputs/gemma4-dense-rl"
+max_steps = 5
+seq_len = 2048
+
+[slurm]
+job_name = "gemma4-dense-rl"
+
+[deployment]
+type = "single_node"
+num_train_gpus = 4
+num_infer_gpus = 4
+
+[model]
+name = "google/gemma-4-31B-it"
+
+[wandb]
+project = "gemma4-test"
+name = "rl-dense"
+
+[orchestrator]
+batch_size = 128
+rollouts_per_example = 16
+
+[orchestrator.sampling]
+max_tokens = 128
+
+[[orchestrator.env]]
+id = "reverse-text"
+
+[trainer.model]
+attn = "flash_attention_2"
+
+[trainer.model.ac]
+
+[trainer.optim]
+lr = 3e-6
+
+[inference]
diff --git a/configs/gemma4/rl_moe.toml b/configs/gemma4/rl_moe.toml
new file mode 100644
index 0000000000..75be4d6d17
--- /dev/null
+++ b/configs/gemma4/rl_moe.toml
@@ -0,0 +1,43 @@
+# Gemma4 26B-A4B MoE RL test (reverse-text)
+output_dir = "outputs/gemma4-moe-rl"
+max_steps = 5
+seq_len = 2048
+
+[slurm]
+job_name = "gemma4-moe-rl"
+
+[deployment]
+type = "single_node"
+num_train_gpus = 4
+num_infer_gpus = 4
+
+[model]
+name = "google/gemma-4-26B-A4B-it"
+
+[wandb]
+project = "gemma4-test"
+name = "rl-moe"
+
+[orchestrator]
+batch_size = 128
+rollouts_per_example = 16
+
+[orchestrator.sampling]
+max_tokens = 128
+
+[[orchestrator.env]]
+id = "reverse-text"
+
+[trainer.model]
+attn = "flash_attention_2"
+
+[trainer.model.ac]
+
+[trainer.optim]
+lr = 3e-6
+
+[inference]
+gpu_memory_utilization = 0.7
+
+[inference.model]
+max_model_len = 2048
diff --git a/configs/gemma4/sft_dense.toml b/configs/gemma4/sft_dense.toml
new file mode 100644
index 0000000000..727271be25
--- /dev/null
+++ b/configs/gemma4/sft_dense.toml
@@ -0,0 +1,29 @@
+# Gemma4 31B dense SFT test
+# Usage: uv run sft @ configs/gemma4/sft_dense.toml
+output_dir = "outputs/gemma4-dense-sft"
+max_steps = 10
+
+[slurm]
+job_name = "gemma4-dense-sft"
+
+[deployment]
+type = "single_node"
+num_gpus = 8
+
+[model]
+name = "google/gemma-4-31B-it"
+attn = "flash_attention_2"
+
+[model.ac]
+
+[wandb]
+project = "gemma4-test"
+name = "sft-dense"
+
+[data]
+name = "PrimeIntellect/Reverse-Text-SFT"
+batch_size = 8
+seq_len = 2048
+
+[data.chat_template_kwargs]
+enable_thinking = true
diff --git a/configs/gemma4/sft_moe.toml b/configs/gemma4/sft_moe.toml
new file mode 100644
index 0000000000..8510096a23
--- /dev/null
+++ b/configs/gemma4/sft_moe.toml
@@ -0,0 +1,29 @@
+# Gemma4 26B-A4B MoE SFT test
+# Usage: uv run sft @ configs/gemma4/sft_moe.toml
+output_dir = "outputs/gemma4-moe-sft"
+max_steps = 10
+
+[slurm]
+job_name = "gemma4-moe-sft"
+
+[deployment]
+type = "single_node"
+num_gpus = 8
+
+[model]
+name = "google/gemma-4-26B-A4B-it"
+attn = "flash_attention_2"
+
+[model.ac]
+
+[wandb]
+project = "gemma4-test"
+name = "sft-moe"
+
+[data]
+name = "PrimeIntellect/Reverse-Text-SFT"
+batch_size = 8
+seq_len = 2048
+
+[data.chat_template_kwargs]
+enable_thinking = true
diff --git a/configs/gemma4/sft_qwen3.toml b/configs/gemma4/sft_qwen3.toml
new file mode 100644
index 0000000000..c088f02b31
--- /dev/null
+++ b/configs/gemma4/sft_qwen3.toml
@@ -0,0 +1,18 @@
+# Qwen3 0.6B SFT sanity check
+output_dir = "outputs/qwen3-sft"
+max_steps = 5
+
+[slurm]
+job_name = "qwen3-sft"
+
+[model]
+name = "PrimeIntellect/Qwen3-0.6B"
+
+[wandb]
+project = "gemma4-test"
+name = "sft-qwen3"
+
+[data]
+name = "PrimeIntellect/Reverse-Text-SFT"
+batch_size = 4
+seq_len = 1024
diff --git a/configs/gemma4/sft_qwen35.toml b/configs/gemma4/sft_qwen35.toml
new file mode 100644
index 0000000000..b25d4893b2
--- /dev/null
+++ b/configs/gemma4/sft_qwen35.toml
@@ -0,0 +1,19 @@
+# Qwen3.5 MoE text-only SFT test (proving VLM bug)
+# Usage: uv run sft @ configs/gemma4/sft_qwen35.toml
+output_dir = "outputs/qwen35-moe-sft"
+max_steps = 5
+
+[slurm]
+job_name = "qwen35-moe-sft"
+
+[model]
+name = "Qwen/Qwen3.5-35B-A3B"
+
+[wandb]
+project = "gemma4-test"
+name = "sft-qwen35"
+
+[data]
+name = "PrimeIntellect/Reverse-Text-SFT"
+batch_size = 4
+seq_len = 2048
diff --git a/pyproject.toml b/pyproject.toml
index 957b85c4d2..6642b882a8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -19,7 +19,7 @@ dependencies = [
"torch>=2.9.0",
"torchdata>=0.11.0",
"transformers",
- "vllm>=0.17.0",
+ "vllm>=0.19.0",
"wandb>=0.24.2",
"ring-flash-attn>=0.1.8",
"prime>=0.5.37",
@@ -123,7 +123,7 @@ torch = { index = "pytorch-cu128" }
verifiers = { git = "https://github.com/PrimeIntellect-ai/verifiers.git", rev = "d3c830c" }
torchtitan = { git = "https://github.com/pytorch/torchtitan", rev = "a1fdd7e" }
dion = { git = "https://github.com/samsja/dion.git", rev = "d891eeb" }
-transformers = { git = "https://github.com/huggingface/transformers.git", rev = "5c1c72b" }
+transformers = { git = "https://github.com/huggingface/transformers.git", rev = "c1c3424" }
flash-attn-4 = { git = "https://github.com/Dao-AILab/flash-attention.git", subdirectory = "flash_attn/cute", rev = "abd9943b" }
pydantic-config = { git = "https://github.com/samsja/pydantic_config.git", branch = "main" }
vllm-router = { url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.14/vllm_router-0.1.14-cp38-abi3-linux_x86_64.whl" }
diff --git a/src/prime_rl/__init__.py b/src/prime_rl/__init__.py
index e69de29bb2..d76436d3c5 100644
--- a/src/prime_rl/__init__.py
+++ b/src/prime_rl/__init__.py
@@ -0,0 +1 @@
+import prime_rl._compat # noqa: F401 — must run before ring_flash_attn is imported
diff --git a/src/prime_rl/_compat.py b/src/prime_rl/_compat.py
new file mode 100644
index 0000000000..5471030c84
--- /dev/null
+++ b/src/prime_rl/_compat.py
@@ -0,0 +1,19 @@
+"""Compatibility shim: ring_flash_attn + transformers >= 5.4.
+
+ring_flash_attn 0.1.8 imports `is_flash_attn_greater_or_equal_2_10` from
+`transformers.modeling_flash_attention_utils`. This symbol was removed from
+that module in transformers 5.4 (still available as a deprecated function
+in `transformers.utils.import_utils`, scheduled for removal in 5.8).
+
+ring_flash_attn's except-branch is a no-op (imports the same symbol again),
+so the import crashes on transformers >= 5.4. We patch the symbol back in as
+`True` — the check is dead code since no one uses flash_attn < 2.1.0 anymore.
+
+Upstream fix: https://github.com/zhuzilin/ring-flash-attention/pull/85
+Remove this shim once ring_flash_attn ships a fixed version.
+"""
+
+import transformers.modeling_flash_attention_utils as _mfau
+
+if not hasattr(_mfau, "is_flash_attn_greater_or_equal_2_10"):
+ _mfau.is_flash_attn_greater_or_equal_2_10 = True
diff --git a/src/prime_rl/configs/sft.py b/src/prime_rl/configs/sft.py
index 50d1baacd0..11c778130d 100644
--- a/src/prime_rl/configs/sft.py
+++ b/src/prime_rl/configs/sft.py
@@ -1,5 +1,5 @@
from pathlib import Path
-from typing import Annotated, Literal, TypeAlias
+from typing import Annotated, Any, Literal, TypeAlias
from pydantic import BaseModel, ConfigDict, Field, model_validator
@@ -86,6 +86,14 @@ class SFTDataConfig(BaseDataConfig):
# Configuring
loss_mask: LossMaskConfig = LossMaskConfig()
+ chat_template_kwargs: Annotated[
+ dict[str, Any] | None,
+ Field(
+ description="Extra keyword arguments passed to tokenizer.apply_chat_template(). "
+ "E.g. {'enable_thinking': true} for models with thinking-aware templates (Gemma4, Qwen3)."
+ ),
+ ] = None
+
@model_validator(mode="after")
def validate_subsets_and_splits(self):
if self.subsets is not None or self.splits is not None:
diff --git a/src/prime_rl/inference/patches.py b/src/prime_rl/inference/patches.py
index 9b0ffe053e..18c6c3c021 100644
--- a/src/prime_rl/inference/patches.py
+++ b/src/prime_rl/inference/patches.py
@@ -64,41 +64,8 @@ def slice_lora_a(self, lora_a):
MergedColumnParallelLinearWithShardedLoRA.slice_lora_a = slice_lora_a
-# Monkeypatch PrometheusStatLogger to avoid NotImplementedError for LoRA in DP mode
-def monkey_patch_prometheus_stat_logger_for_lora_in_dp_mode():
- from vllm.v1.metrics import loggers as vllm_metrics_loggers
-
- _original_prometheus_stat_logger_init = vllm_metrics_loggers.PrometheusStatLogger.__init__
-
- def _patched_prometheus_stat_logger_init(self, vllm_config, engine_indexes=None):
- """Patched init that temporarily disables lora_config to skip the DP mode check."""
- original_lora_config = vllm_config.lora_config
- vllm_config.lora_config = None
- try:
- _original_prometheus_stat_logger_init(self, vllm_config, engine_indexes)
- finally:
- vllm_config.lora_config = original_lora_config
- # Re-initialize LoRA metrics if needed (after the DP check is bypassed)
- if original_lora_config is not None:
- self.labelname_max_lora = "max_lora"
- self.labelname_waiting_lora_adapters = "waiting_lora_adapters"
- self.labelname_running_lora_adapters = "running_lora_adapters"
- self.max_lora = original_lora_config.max_loras
- self.gauge_lora_info = vllm_metrics_loggers.PrometheusStatLogger._gauge_cls(
- name="vllm:lora_requests_info",
- documentation="Running stats on lora requests.",
- multiprocess_mode="sum",
- labelnames=[
- self.labelname_max_lora,
- self.labelname_waiting_lora_adapters,
- self.labelname_running_lora_adapters,
- ],
- )
-
- vllm_metrics_loggers.PrometheusStatLogger.__init__ = _patched_prometheus_stat_logger_init
-
-
# Monkeypatch LoadLoRAAdapter to allow loading the same adapter multiple times
+# TODO: may be removable if we pass load_inplace=True (supported since vLLM 0.18, PR #31326)
def monkey_patch_load_lora_adapter():
from http import HTTPStatus
@@ -153,6 +120,7 @@ async def _patched_load_lora_adapter(
# Monkeypatch LRUCacheWorkerLoRAManager to allow loading adapter inplace without doing it every request
+# TODO: may be removable if we pass load_inplace=True (supported since vLLM 0.18, PR #31326)
def monkey_patch_LRUCacheWorkerLoRAManager():
from vllm.lora.worker_manager import LoRARequest, LRUCacheLoRAModelManager, LRUCacheWorkerLoRAManager
@@ -278,109 +246,6 @@ def _patched_get_encode_kwargs(self):
TokenizeParams.get_encode_kwargs = _patched_get_encode_kwargs
-def monkey_patch_hermes_tool_parser_thread_safety():
- """Patch Hermes2ProToolParser to cache tokenizer encode/decode results.
-
- The original __init__ calls tokenizer.encode() and tokenizer.decode() on
- every instantiation. Under concurrent load, the shared HuggingFace tokenizer's
- Rust backend panics with ``RuntimeError: Already borrowed`` because multiple
- threads mutably borrow the same internal state simultaneously.
-
- Fix: run the first __init__ (which calls encode/decode) under a lock, cache
- the results, and reuse them for all subsequent instantiations without ever
- touching the tokenizer again.
- """
- import threading
-
- import regex as re
- from vllm.tool_parsers.abstract_tool_parser import ToolParser
- from vllm.tool_parsers.hermes_tool_parser import Hermes2ProToolParser
-
- _original_init = Hermes2ProToolParser.__init__
- _cache: dict[int, dict] = {}
- _lock = threading.Lock()
-
- def _patched_init(self, tokenizer):
- from vllm.tokenizers.mistral import MistralTokenizer
-
- # Resolve the actual tokenizer that __init__ will use for encode/decode
- actual_tokenizer = tokenizer.tokenizer if isinstance(tokenizer, MistralTokenizer) else tokenizer
- key = id(actual_tokenizer)
-
- if key in _cache:
- # Fast path: skip encode/decode entirely, set up instance from cache
- ToolParser.__init__(self, tokenizer)
- if isinstance(tokenizer, MistralTokenizer):
- self.model_tokenizer = tokenizer.tokenizer
- self.current_tool_name_sent = False
- self.prev_tool_call_arr = []
- self.current_tool_id = -1
- self.streamed_args_for_tool = []
-        self.tool_call_start_token = "<tool_call>"
-        self.tool_call_end_token = "</tool_call>"
-        self.tool_call_regex = re.compile(r"<tool_call>(.*?)</tool_call>|<tool_call>(.*)", re.DOTALL)
-        self.scratch_pad_regex = re.compile(r"<scratch_pad>(.*?)</scratch_pad>", re.DOTALL)
- cached = _cache[key]
- self.tool_call_start_token_ids = cached["start_ids"]
- self.tool_call_end_token_ids = cached["end_ids"]
- self.tool_call_start_token_array = cached["start_array"]
- self.tool_call_end_token_array = cached["end_array"]
- self.buffered_delta_text = ""
- return
-
- # Slow path: first instantiation for this tokenizer, run under lock
- with _lock:
- if key in _cache:
- # Another thread populated it while we waited
- _patched_init(self, tokenizer)
- return
- _original_init(self, tokenizer)
- _cache[key] = {
- "start_ids": self.tool_call_start_token_ids,
- "end_ids": self.tool_call_end_token_ids,
- "start_array": self.tool_call_start_token_array,
- "end_array": self.tool_call_end_token_array,
- }
-
- Hermes2ProToolParser.__init__ = _patched_init
-
-
-def monkey_patch_tokenizer_thread_safety():
- """Patch HuggingFace tokenizer to make _encode_plus thread-safe.
-
- Under concurrent request load, vLLM's API server calls _encode_plus from
- multiple async handlers simultaneously. _encode_plus mutates the Rust
- tokenizer's internal state via set_truncation_and_padding (enable_truncation/
- enable_padding) and encode_special_tokens. The Rust backend uses RefCell-style
- borrow tracking (PyO3), and concurrent mutable borrows cause it to panic
- with ``RuntimeError: Already borrowed``.
-
- Fix: wrap the entire _encode_plus method in a per-tokenizer threading lock
- so that state mutation and the subsequent encode call are atomic.
- """
- import threading
-
- from transformers import PreTrainedTokenizerFast
-
- _original_encode_plus = PreTrainedTokenizerFast._encode_plus
- _locks: dict[int, threading.Lock] = {}
- _meta_lock = threading.Lock()
-
- def _get_lock(tokenizer_id: int) -> threading.Lock:
- if tokenizer_id not in _locks:
- with _meta_lock:
- if tokenizer_id not in _locks:
- _locks[tokenizer_id] = threading.Lock()
- return _locks[tokenizer_id]
-
- def _patched_encode_plus(self, *args, **kwargs):
- lock = _get_lock(id(self._tokenizer))
- with lock:
- return _original_encode_plus(self, *args, **kwargs)
-
- PreTrainedTokenizerFast._encode_plus = _patched_encode_plus
-
-
def monkey_patch_minimax_m2_for_lora():
"""Patch vLLM's MiniMaxM2 model for LoRA compatibility.
@@ -457,7 +322,7 @@ def _patched_forward(self, hidden_states):
def monkey_patch_harmony_stop_token_propagation():
- """Fix: vLLM 0.17.0 doesn't merge harmony stop tokens into per-request SamplingParams.
+ """Fix: vLLM doesn't merge harmony stop tokens into per-request SamplingParams.
The harmony mode sets stop_token_ids (including <|call|> and <|return|>) in
default_sampling_params at server init, but ChatCompletionRequest.to_sampling_params()
diff --git a/src/prime_rl/inference/vllm/server.py b/src/prime_rl/inference/vllm/server.py
index b294aa447b..609b44f84d 100644
--- a/src/prime_rl/inference/vllm/server.py
+++ b/src/prime_rl/inference/vllm/server.py
@@ -7,9 +7,6 @@
from fastapi.responses import JSONResponse, StreamingResponse
from starlette.datastructures import State
from vllm.engine.protocol import EngineClient
-from vllm.entrypoints.chat_utils import load_chat_template
-from vllm.entrypoints.cli.serve import run_api_server_worker_proc
-from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.api_server import init_app_state
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionResponse
from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
@@ -123,6 +120,9 @@
# NemotronH
"nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16": "qwen3_coder",
"nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16": "qwen3_coder",
+ # Gemma4
+ "google/gemma-4-26B-A4B-it": "gemma4",
+ "google/gemma-4-31B-it": "gemma4",
}
@@ -136,30 +136,23 @@ def resolve_tool_call_parser(model_name: str, tool_call_parser: str | None) -> s
logger = get_logger()
from prime_rl.inference.patches import (
monkey_patch_harmony_stop_token_propagation,
- monkey_patch_hermes_tool_parser_thread_safety,
monkey_patch_load_lora_adapter,
- monkey_patch_prometheus_stat_logger_for_lora_in_dp_mode,
monkey_patch_tokenize_params_validation,
- monkey_patch_tokenizer_thread_safety,
)
from prime_rl.inference.vllm.serving_chat_with_tokens import (
ChatCompletionRequestWithTokens,
OpenAIServingChatWithTokens,
)
-# NOTE: Fix harmony stop token propagation for GPT-OSS models (vLLM 0.17.0 bug)
+# NOTE: Fix harmony stop token propagation for GPT-OSS models
+# Upstream issue still open: https://github.com/vllm-project/vllm/issues/22519
monkey_patch_harmony_stop_token_propagation()
-# NOTE: Monkeypatch PrometheusStatLogger to avoid NotImplementedError for LoRA in DP mode
-monkey_patch_prometheus_stat_logger_for_lora_in_dp_mode()
# NOTE: Monkeypatch LoadLoRAAdapter to allow loading the same adapter multiple times
+# May be removable if we pass load_inplace=True (supported since vLLM 0.18, PR #31326)
monkey_patch_load_lora_adapter()
# NOTE: Monkeypatch TokenizeParams to fix overly conservative validation
+# Still needed in vLLM 0.19 — upstream rejects prompt_len > max_model_len - max_tokens
monkey_patch_tokenize_params_validation()
-# NOTE: Monkeypatch Hermes tool parser to fix "Already borrowed" RuntimeError under concurrent load
-monkey_patch_hermes_tool_parser_thread_safety()
-# NOTE: Monkeypatch HF tokenizer to fix "Already borrowed" RuntimeError during concurrent chat template processing
-# Can be removed once https://github.com/vllm-project/vllm/pull/36557 is merged and we upgrade vllm
-monkey_patch_tokenizer_thread_safety()
logger = init_logger("vllm.entrypoints.openai.api_server")
@@ -279,82 +272,36 @@ async def custom_init_app_state(
):
"""
Modifies init_app_state:
- 1. Set up the custom OpenAIServingChatWithTokens state.
- 2. Monkey-patch to allow updating lora adapters in-place.
+ 1. Call the original init_app_state to set up standard state.
+ 2. Replace the serving_chat with our OpenAIServingChatWithTokens wrapper.
"""
- # Setup the regular app state first (in-place)
await init_app_state(engine_client, state, args, supported_tasks)
- # NOTE: Initialize the custom OpenAIServingChatWithTokens state here
- # TODO: Here, we repeat some calls done in init_app_state to be able to
- # correctly set up the OpenAIServingChatWithTokens state, which is a bit
- # brittle, and could probably be made nicer
- if args.enable_log_requests:
- request_logger = RequestLogger(max_log_len=args.max_log_len)
+ if "generate" in supported_tasks and state.openai_serving_chat is not None:
+ original_chat = state.openai_serving_chat
+ serving_chat = object.__new__(OpenAIServingChatWithTokens)
+ serving_chat.__dict__.update(original_chat.__dict__)
+ state.openai_serving_chat = serving_chat
+ state.openai_serving_chat_with_tokens = serving_chat
else:
- request_logger = None
-
- resolved_chat_template = load_chat_template(args.chat_template)
-
- chat_kwargs = dict(
- request_logger=request_logger,
- chat_template=resolved_chat_template,
- chat_template_content_format=args.chat_template_content_format,
- trust_request_chat_template=args.trust_request_chat_template,
- return_tokens_as_token_ids=args.return_tokens_as_token_ids,
- enable_auto_tools=args.enable_auto_tool_choice,
- exclude_tools_when_tool_choice_none=args.exclude_tools_when_tool_choice_none,
- tool_parser=args.tool_call_parser,
- reasoning_parser=args.structured_outputs_config.reasoning_parser,
- enable_prompt_tokens_details=args.enable_prompt_tokens_details,
- enable_force_include_usage=args.enable_force_include_usage,
- enable_log_outputs=args.enable_log_outputs,
- )
- if hasattr(args, "log_error_stack"):
- chat_kwargs["log_error_stack"] = args.log_error_stack
-
- serving_chat = OpenAIServingChatWithTokens(
- engine_client,
- state.openai_serving_models,
- args.response_role,
- **chat_kwargs,
- )
- state.openai_serving_chat = serving_chat if "generate" in supported_tasks else None
- state.openai_serving_chat_with_tokens = serving_chat if "generate" in supported_tasks else None
-
-
-def custom_run_api_server_worker_proc(listen_address, sock, args, client_config=None, **uvicorn_kwargs) -> None:
- """
- Modifies run_api_server_worker_proc:
- 1. Re-import our module to ensure monkey patches are applied in child processes
- """
- # NOTE: This hack ensures that monkey patches are applied in child processes
- # to make our custom routes work in multi-API-server settings.
- import prime_rl.inference.vllm.server # noqa: F401
-
- run_api_server_worker_proc(listen_address, sock, args, client_config, **uvicorn_kwargs)
+ state.openai_serving_chat_with_tokens = None
-import vllm.entrypoints.cli.serve
import vllm.entrypoints.openai.api_server
from vllm.entrypoints.openai.api_server import build_app as _original_build_app
-def custom_build_app(args: Namespace, supported_tasks: tuple):
+def custom_build_app(args: Namespace, supported_tasks: tuple, model_config=None):
"""
Wrap build_app to include our custom router.
"""
- app = _original_build_app(args, supported_tasks)
+ app = _original_build_app(args, supported_tasks, model_config)
app.include_router(router)
return app
-# Also monkey patch run_api_server_worker_proc for multi-api-server mode
-# This is needed because worker processes spawned by run_multi_api_server
-# re-import modules and would otherwise use the original run_server_worker
vllm.entrypoints.openai.api_server.init_app_state = custom_init_app_state
vllm.entrypoints.openai.api_server.build_app = custom_build_app
-vllm.entrypoints.cli.serve.run_api_server_worker_proc = custom_run_api_server_worker_proc
# Adapted from vllm/entrypoints/cli/serve.py
diff --git a/src/prime_rl/inference/vllm/serving_chat_with_tokens.py b/src/prime_rl/inference/vllm/serving_chat_with_tokens.py
index 9d37e72d2a..e7f4b911f6 100644
--- a/src/prime_rl/inference/vllm/serving_chat_with_tokens.py
+++ b/src/prime_rl/inference/vllm/serving_chat_with_tokens.py
@@ -250,7 +250,7 @@ async def create_chat_completion_with_tokens(
request_metadata,
reasoning_parser,
)
- except GenerationError as e:
- return self._convert_generation_error_to_response(e)
+ except GenerationError:
+ raise # Let FastAPI's global generation_error_handler handle it
except ValueError as e:
return self.create_error_response(e)
diff --git a/src/prime_rl/trainer/models/__init__.py b/src/prime_rl/trainer/models/__init__.py
index b2e2068217..2daeb771f0 100644
--- a/src/prime_rl/trainer/models/__init__.py
+++ b/src/prime_rl/trainer/models/__init__.py
@@ -10,6 +10,7 @@
from prime_rl.trainer.models.afmoe import AfmoeConfig, AfmoeForCausalLM
from prime_rl.trainer.models.base import PreTrainedModelPrimeRL
+from prime_rl.trainer.models.gemma4 import Gemma4ForCausalLM, Gemma4TextConfig
from prime_rl.trainer.models.glm4_moe import Glm4MoeConfig, Glm4MoeForCausalLM
from prime_rl.trainer.models.glm_moe_dsa import GlmMoeDsaConfig, GlmMoeDsaForCausalLM
from prime_rl.trainer.models.layers.lm_head import PrimeLmOutput, cast_float_and_contiguous
@@ -20,6 +21,7 @@
from prime_rl.trainer.models.qwen3_moe import Qwen3MoeConfig, Qwen3MoeForCausalLM
# Make custom config discoverable by AutoConfig
+AutoConfig.register("gemma4_text", Gemma4TextConfig, exist_ok=True)
AutoConfig.register("afmoe", AfmoeConfig, exist_ok=True)
AutoConfig.register("glm4_moe", Glm4MoeConfig, exist_ok=True)
AutoConfig.register("glm_moe_dsa", GlmMoeDsaConfig, exist_ok=True)
@@ -29,6 +31,7 @@
AutoConfig.register("qwen3_5_moe_text", Qwen3_5MoeConfig, exist_ok=True)
_CUSTOM_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, OrderedDict())
+_CUSTOM_CAUSAL_LM_MAPPING.register(Gemma4TextConfig, Gemma4ForCausalLM, exist_ok=True)
_CUSTOM_CAUSAL_LM_MAPPING.register(LlamaConfig, LlamaForCausalLM, exist_ok=True)
_CUSTOM_CAUSAL_LM_MAPPING.register(AfmoeConfig, AfmoeForCausalLM, exist_ok=True)
_CUSTOM_CAUSAL_LM_MAPPING.register(Glm4MoeConfig, Glm4MoeForCausalLM, exist_ok=True)
@@ -62,6 +65,7 @@ def supports_custom_impl(model_config: PretrainedConfig) -> bool:
# Used by get_model() to dispatch VLMs that have a custom text model implementation.
# Points to the same unified class — the config drives text-only vs VLM behavior.
_CUSTOM_VLM_MAPPING: dict[str, type] = {
+ "gemma4": Gemma4ForCausalLM,
"qwen3_5_moe": Qwen3_5MoeForCausalLM,
}
diff --git a/src/prime_rl/trainer/models/afmoe/configuration_afmoe.py b/src/prime_rl/trainer/models/afmoe/configuration_afmoe.py
index 2090c72d51..24e8013722 100644
--- a/src/prime_rl/trainer/models/afmoe/configuration_afmoe.py
+++ b/src/prime_rl/trainer/models/afmoe/configuration_afmoe.py
@@ -1,5 +1,4 @@
from transformers.configuration_utils import PretrainedConfig, layer_type_validation
-from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging
logger = logging.get_logger(__name__)
@@ -106,7 +105,7 @@ def __init__(
# Validate rope configs
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
- rope_config_validation(self)
+ self.standardize_rope_params()
super().__init__(
tie_word_embeddings=tie_word_embeddings,
diff --git a/src/prime_rl/trainer/models/afmoe/modeling_afmoe.py b/src/prime_rl/trainer/models/afmoe/modeling_afmoe.py
index 136ad70306..e26f919140 100644
--- a/src/prime_rl/trainer/models/afmoe/modeling_afmoe.py
+++ b/src/prime_rl/trainer/models/afmoe/modeling_afmoe.py
@@ -508,7 +508,7 @@ def forward(
if not isinstance(causal_mask_mapping := attention_mask, dict):
mask_kwargs = {
"config": self.config,
- "input_embeds": inputs_embeds,
+ "inputs_embeds": inputs_embeds,
"attention_mask": attention_mask,
"cache_position": cache_position,
"past_key_values": None,
diff --git a/src/prime_rl/trainer/models/gemma4/__init__.py b/src/prime_rl/trainer/models/gemma4/__init__.py
new file mode 100644
index 0000000000..300b15f0e2
--- /dev/null
+++ b/src/prime_rl/trainer/models/gemma4/__init__.py
@@ -0,0 +1,5 @@
+from transformers.models.gemma4.configuration_gemma4 import Gemma4TextConfig
+
+from prime_rl.trainer.models.gemma4.modeling_gemma4 import Gemma4ForCausalLM
+
+__all__ = ["Gemma4TextConfig", "Gemma4ForCausalLM"]
diff --git a/src/prime_rl/trainer/models/gemma4/modeling_gemma4.py b/src/prime_rl/trainer/models/gemma4/modeling_gemma4.py
new file mode 100644
index 0000000000..6a0ae367e9
--- /dev/null
+++ b/src/prime_rl/trainer/models/gemma4/modeling_gemma4.py
@@ -0,0 +1,760 @@
+"""Gemma4 custom implementation for PrimeRL training.
+
+Supports both dense (31B) and MoE (26B-A4B) variants, in text-only and VLM modes:
+- Hybrid sliding window + global attention (5:1 pattern)
+- Dual RoPE: theta=10K for sliding, theta=1M + partial_rotary_factor=0.25 for global
+- K=V sharing on global attention layers (no v_proj)
+- QKV norms (q/k with scale, v without scale)
+- Attention scaling = 1.0 (QK norms handle magnitude)
+- Logit softcapping (tanh at 30.0)
+- Per-layer learnable scalar
+- Scaled embeddings (× sqrt(hidden_size))
+
+VLM mode is auto-detected from the config (presence of vision_config). In VLM mode,
+the HF vision tower is used as-is and only the language model uses our custom impl.
+"""
+
+from typing import Optional, Union
+
+import torch
+from torch import Tensor, nn
+from transformers.activations import ACT2FN
+from transformers.configuration_utils import PretrainedConfig
+from transformers.generation import GenerationMixin
+from transformers.modeling_layers import GradientCheckpointingLayer
+from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
+from transformers.models.gemma4.configuration_gemma4 import Gemma4TextConfig
+from transformers.processing_utils import Unpack
+from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
+
+from prime_rl.trainer.models.base import PreTrainedModelPrimeRL
+from prime_rl.trainer.models.layers.lm_head import PrimeLmOutput
+from prime_rl.trainer.models.layers.norms import RMSNorm, RMSNormConfig
+
+# flash-attention-2
+try:
+ from flash_attn import flash_attn_varlen_func
+except ImportError:
+ flash_attn_varlen_func = None # type: ignore
+
+# flash-attention-3
+try:
+ from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
+except ImportError:
+ flash_attn_3_varlen_func = None # type: ignore
+
+try:
+ from flash_attn.cute import flash_attn_varlen_func as flash_attn_4_varlen_func
+except ImportError:
+ flash_attn_4_varlen_func = None # type: ignore
+
+
+# ---------------------------------------------------------------------------
+# Norms
+# ---------------------------------------------------------------------------
+
+
+class Gemma4RMSNormNoScale(nn.Module):
+ """RMSNorm without a learnable scale parameter, for value normalization."""
+
+ def __init__(self, dim: int, eps: float = 1e-6):
+ super().__init__()
+ self.eps = eps
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+ return hidden_states.to(input_dtype)
+
+
+# ---------------------------------------------------------------------------
+# Scaled word embedding
+# ---------------------------------------------------------------------------
+
+
+class Gemma4ScaledWordEmbedding(nn.Embedding):
+ def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0):
+ super().__init__(num_embeddings, embedding_dim, padding_idx)
+ self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)
+
+ def forward(self, input_ids: torch.Tensor):
+ return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)
+
+
+# ---------------------------------------------------------------------------
+# Dual RoPE (per layer type)
+# ---------------------------------------------------------------------------
+
+
+def _compute_default_rope_parameters(config, device=None, seq_len=None, layer_type=None, head_dim_key=None):
+ config.standardize_rope_params()
+ rope_params = config.rope_parameters[layer_type] if layer_type else config.rope_parameters
+ base = rope_params["rope_theta"]
+ partial_rotary_factor = rope_params.get("partial_rotary_factor", 1.0)
+ head_dim = getattr(config, head_dim_key, None) if head_dim_key else None
+ if head_dim is None:
+ head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
+ dim = int(head_dim * partial_rotary_factor)
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim))
+ return inv_freq, 1.0
+
+
class Gemma4DualRotaryEmbedding(nn.Module):
    """Stores separate inv_freq buffers per layer type (sliding vs full).

    Gemma4 uses a different RoPE parameterization per attention layer type, so
    one ``inv_freq`` buffer (and attention-scaling factor) is registered per
    unique entry in ``config.layer_types``.
    """

    def __init__(self, config: Gemma4TextConfig, device=None):
        super().__init__()
        self.config = config
        self.layer_types = set(config.layer_types)
        # Maps layer_type -> scaling factor returned by the rope init function.
        self.attention_scaling = {}

        for layer_type in self.layer_types:
            rope_params = config.rope_parameters[layer_type]
            rope_type = rope_params["rope_type"]

            # Non-default rope types are looked up in the HF registry.
            if rope_type != "default":
                rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
            else:
                rope_init_fn = _compute_default_rope_parameters

            kwargs = {"device": device, "layer_type": layer_type}
            # Global (full) attention layers read their head dim from
            # `global_head_dim` instead of the default `head_dim`.
            if layer_type == "full_attention" and rope_type == "proportional":
                kwargs["head_dim_key"] = "global_head_dim"

            inv_freq, attn_scaling = rope_init_fn(config, **kwargs)
            # Non-persistent: re-derived after meta init (see init_buffers_post_meta).
            self.register_buffer(f"{layer_type}_inv_freq", inv_freq, persistent=False)
            self.attention_scaling[layer_type] = attn_scaling

    @torch.no_grad()
    def forward(self, x: torch.Tensor, position_ids: torch.Tensor, layer_type: str):
        """Return (cos, sin) for `position_ids` using the buffer registered for `layer_type`."""
        inv_freq = getattr(self, f"{layer_type}_inv_freq")
        attn_scaling = self.attention_scaling[layer_type]

        # [dim/2] -> [batch, dim/2, 1] so a batched matmul with positions works.
        inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # Compute angles in float32 with autocast disabled, matching HF RoPE.
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * attn_scaling
            sin = emb.sin() * attn_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+# ---------------------------------------------------------------------------
+# RoPE application
+# ---------------------------------------------------------------------------
+
+
+def _rotate_half(x):
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
def _apply_rotary_pos_emb_single(t, cos, sin, unsqueeze_dim=1):
    """Apply rotary embedding to the first `cos.shape[-1]` channels of `t`.

    Channels beyond the rotary dimension pass through untouched (partial RoPE).
    `unsqueeze_dim` is where the head axis is inserted into cos/sin for broadcasting.
    """
    cos, sin = cos.unsqueeze(unsqueeze_dim), sin.unsqueeze(unsqueeze_dim)
    rotary_dim = cos.shape[-1]
    rotary_part = t[..., :rotary_dim]
    passthrough = t[..., rotary_dim:]
    rotated = rotary_part * cos + _rotate_half(rotary_part) * sin
    return torch.cat([rotated, passthrough], dim=-1)
+
+
+# ---------------------------------------------------------------------------
+# Attention
+# ---------------------------------------------------------------------------
+
# Dispatch table: flash-attention major version -> varlen attention entry point.
# The callables are imported/resolved elsewhere in this module.
_FLASH_ATTN_FUNCS = {
    2: flash_attn_varlen_func,
    3: flash_attn_3_varlen_func,
    4: flash_attn_4_varlen_func,
}
+
+
class Gemma4Attention(nn.Module):
    """Gemma4 attention with hybrid sliding/global, K=V sharing, QKV norms.

    Sliding layers use `config.head_dim` and a windowed flash-attention call;
    global layers may use a larger `config.global_head_dim` and, when that
    exceeds 256, fall back to SDPA (flash-attn's head-dim limit).
    """

    def __init__(self, config: Gemma4TextConfig, layer_idx: int, flash_attn_version: int = 2):
        super().__init__()
        self.layer_idx = layer_idx
        self.layer_type = config.layer_types[layer_idx]
        self.is_sliding = self.layer_type == "sliding_attention"
        self.sliding_window = config.sliding_window if self.is_sliding else None

        # Global attention uses larger head_dim
        self.head_dim = config.global_head_dim if (not self.is_sliding and config.global_head_dim) else config.head_dim
        self.num_attention_heads = config.num_attention_heads

        # K=V sharing: only on global layers when enabled
        self.use_kv_sharing = config.attention_k_eq_v and not self.is_sliding
        num_kv_heads = (
            config.num_global_key_value_heads
            if self.use_kv_sharing and config.num_global_key_value_heads
            else config.num_key_value_heads
        )
        self.num_key_value_heads = num_kv_heads
        self.num_key_value_groups = config.num_attention_heads // num_kv_heads

        # Projections
        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(config.hidden_size, num_kv_heads * self.head_dim, bias=config.attention_bias)
        # No v_proj when K=V sharing is active: V reuses the k_proj output.
        self.v_proj = (
            nn.Linear(config.hidden_size, num_kv_heads * self.head_dim, bias=config.attention_bias)
            if not self.use_kv_sharing
            else None
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )

        # QKV norms
        self.q_norm = RMSNorm(RMSNormConfig(hidden_size=self.head_dim, eps=config.rms_norm_eps))
        self.k_norm = RMSNorm(RMSNormConfig(hidden_size=self.head_dim, eps=config.rms_norm_eps))
        # V gets a scale-free RMSNorm (no learnable weight).
        self.v_norm = Gemma4RMSNormNoScale(self.head_dim, eps=config.rms_norm_eps)

        # Flash attention
        self._flash_attn_version = flash_attn_version
        self._flash_attn_func = _FLASH_ATTN_FUNCS[flash_attn_version]
        if self._flash_attn_version == 4:
            # FA4 is not torch.compile-friendly here; keep it out of dynamo graphs.
            self._flash_attn_call = torch._dynamo.disable(self._flash_attn_func)
        else:
            self._flash_attn_call = self._flash_attn_func

    def _compute_flash_attention(self, q, k, v, cu_seqlens, max_seqlen):
        """Run the varlen flash-attention kernel on packed [total_tokens, heads, dim] tensors."""
        args = [q, k, v, cu_seqlens, cu_seqlens]
        # FA2/FA3 take explicit max_seqlen args; the FA4 entry point does not.
        if self._flash_attn_version != 4:
            args.extend([max_seqlen, max_seqlen])
        # softmax_scale=1.0: scaling is presumably folded into the q/k norms —
        # NOTE(review): confirm against the reference Gemma4 implementation.
        kwargs = {"causal": True, "softmax_scale": 1.0}
        if self.sliding_window is not None:
            # flash-attn window is (left, right); left = window-1 keeps window tokens total.
            kwargs["window_size"] = (self.sliding_window - 1, 0)
        out = self._flash_attn_call(*args, **kwargs)
        # Some versions return (output, lse, ...); keep only the output.
        if isinstance(out, tuple):
            out = out[0]
        return out

    def _compute_sdpa_attention(self, q, k, v, cu_seqlens):
        """SDPA fallback for head_dim > 256 (global attention layers).

        Handles packed sequences by building a block-diagonal causal mask from cu_seqlens.
        """
        # q/k/v: [total_tokens, heads, dim] -> [1, heads, total_tokens, dim]
        q = q.unsqueeze(0).transpose(1, 2)
        k = k.unsqueeze(0).transpose(1, 2)
        v = v.unsqueeze(0).transpose(1, 2)

        # GQA: repeat k/v heads to match q heads
        if k.shape[1] != q.shape[1]:
            n_rep = q.shape[1] // k.shape[1]
            k = k.repeat_interleave(n_rep, dim=1)
            v = v.repeat_interleave(n_rep, dim=1)

        # Build block-diagonal causal mask for packed sequences
        total_len = q.shape[2]
        if cu_seqlens is not None and len(cu_seqlens) > 2:
            # Multiple packed sequences — need block-diagonal mask
            # (everything outside a sequence's own block stays -inf).
            mask = torch.full((total_len, total_len), float("-inf"), device=q.device, dtype=q.dtype)
            for i in range(len(cu_seqlens) - 1):
                start, end = cu_seqlens[i].item(), cu_seqlens[i + 1].item()
                seq_len = end - start
                # Zeros below/on the diagonal (attend), -inf strictly above (future).
                causal = torch.tril(torch.zeros(seq_len, seq_len, device=q.device, dtype=q.dtype))
                causal.masked_fill_(
                    torch.triu(torch.ones(seq_len, seq_len, device=q.device, dtype=torch.bool), diagonal=1),
                    float("-inf"),
                )
                mask[start:end, start:end] = causal
            out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, scale=1.0)
        else:
            out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True, scale=1.0)

        # [1, heads, total_tokens, dim] -> [total_tokens, heads, dim]
        return out.transpose(1, 2).squeeze(0)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        cu_seqlens: torch.LongTensor | None = None,
        max_seqlen: int | None = None,
    ) -> tuple[torch.Tensor, None]:
        """Attention over a packed batch; returns (output, None) to mirror the HF interface.

        Assumes hidden_states is [1, total_tokens, hidden] (packed sequences) —
        the `[0]` indexing before the kernels drops that leading batch dim.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)
        cos, sin = position_embeddings

        # Q projection + norm + RoPE
        query_states = self.q_proj(hidden_states).view(hidden_shape)
        query_states = self.q_norm(query_states)
        query_states = _apply_rotary_pos_emb_single(query_states, cos, sin, unsqueeze_dim=2)
        query_states = query_states.transpose(1, 2)

        # K projection + norm + RoPE
        key_states = self.k_proj(hidden_states).view(hidden_shape)
        # V: either from v_proj or reuse k_proj output (K=V sharing)
        value_states = self.v_proj(hidden_states).view(hidden_shape) if self.v_proj is not None else key_states

        # Note: V (when shared) is taken from the pre-norm, pre-RoPE k_proj output.
        key_states = self.k_norm(key_states)
        key_states = _apply_rotary_pos_emb_single(key_states, cos, sin, unsqueeze_dim=2)
        key_states = key_states.transpose(1, 2)

        value_states = self.v_norm(value_states)
        value_states = value_states.transpose(1, 2)

        # Flash attention expects [total_tokens, heads, dim]
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        # FlashAttention only supports head_dim <= 256; fall back to SDPA for global layers (head_dim=512)
        if self.head_dim > 256:
            attn_output = self._compute_sdpa_attention(query_states[0], key_states[0], value_states[0], cu_seqlens)
        else:
            attn_output = self._compute_flash_attention(
                query_states[0], key_states[0], value_states[0], cu_seqlens, max_seqlen
            )
        # Re-pack heads: [total_tokens, heads, dim] -> [1, total_tokens, heads*dim].
        attn_output = attn_output.contiguous().view(1, attn_output.shape[0], -1)
        attn_output = self.o_proj(attn_output)
        return attn_output, None
+
+
+# ---------------------------------------------------------------------------
+# MLP
+# ---------------------------------------------------------------------------
+
+
class Gemma4MLP(nn.Module):
    """Gated feed-forward block: down_proj(act(gate_proj(x)) * up_proj(x))."""

    def __init__(self, config: Gemma4TextConfig):
        super().__init__()
        hidden, inner = config.hidden_size, config.intermediate_size
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.up_proj = nn.Linear(hidden, inner, bias=False)
        self.down_proj = nn.Linear(inner, hidden, bias=False)
        self.act_fn = ACT2FN[config.hidden_activation]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
+
+
+# ---------------------------------------------------------------------------
+# Decoder layer
+# ---------------------------------------------------------------------------
+
+
+def _get_flash_attn_version(attn_impl: str) -> int:
+ mapping = {
+ "flash_attention_2": 2,
+ "flash_attention_3": 3,
+ "fa4": 4,
+ }
+ return mapping.get(attn_impl, 2)
+
+
class Gemma4DecoderLayer(GradientCheckpointingLayer):
    """Gemma4 transformer block: sandwich-normed attention and MLP, plus a per-layer scalar."""

    def __init__(self, config: Gemma4TextConfig, layer_idx: int):
        super().__init__()
        self.self_attn = Gemma4Attention(
            config, layer_idx, flash_attn_version=_get_flash_attn_version(config._attn_implementation)
        )
        self.mlp = Gemma4MLP(config)

        # 4 layernorms (sandwich norm: pre- and post-norm around each sub-block)
        self.input_layernorm = RMSNorm(RMSNormConfig(hidden_size=config.hidden_size, eps=config.rms_norm_eps))
        self.post_attention_layernorm = RMSNorm(RMSNormConfig(hidden_size=config.hidden_size, eps=config.rms_norm_eps))
        self.pre_feedforward_layernorm = RMSNorm(RMSNormConfig(hidden_size=config.hidden_size, eps=config.rms_norm_eps))
        self.post_feedforward_layernorm = RMSNorm(
            RMSNormConfig(hidden_size=config.hidden_size, eps=config.rms_norm_eps)
        )

        # Per-layer scalar (multiplies the whole block output; loaded from the checkpoint)
        self.register_buffer("layer_scalar", torch.ones(1))

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        cu_seqlens: torch.LongTensor | None = None,
        max_seqlen: int | None = None,
    ) -> torch.Tensor:
        # Self-attention block: norm -> attn -> norm -> residual add
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        # FFN block: norm -> mlp -> norm -> residual add
        residual = hidden_states
        hidden_states = self.pre_feedforward_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_feedforward_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        # Scale the full block output (including the residual path) by the per-layer scalar.
        hidden_states *= self.layer_scalar
        return hidden_states
+
+
+# ---------------------------------------------------------------------------
+# VLM helpers
+# ---------------------------------------------------------------------------
+
+
+def _has_vlm_keys(state_dict: dict[str, Tensor]) -> bool:
+ return any(k.startswith("model.language_model.") for k in state_dict)
+
+
+def _remap_lm_keys(state_dict: dict[str, Tensor], to_flat: bool = True) -> None:
+ """Remap language model keys between VLM and flat format for weight conversion.
+
+ to_flat=True: model.language_model.* -> model.*
+ to_flat=False: model.* -> model.language_model.*
+
+ Vision keys (model.vision_tower.*, model.embed_vision.*) are never touched.
+ """
+ VISION_PREFIXES = ("model.vision_tower.", "model.embed_vision.")
+ src = "model.language_model." if to_flat else "model."
+ dst = "model." if to_flat else "model.language_model."
+ for k in [
+ k for k in list(state_dict.keys()) if k.startswith(src) and not any(k.startswith(p) for p in VISION_PREFIXES)
+ ]:
+ state_dict[dst + k[len(src) :]] = state_dict.pop(k)
+
+
+# ---------------------------------------------------------------------------
+# PreTrained base
+# ---------------------------------------------------------------------------
+
+
@auto_docstring
class Gemma4PreTrainedModel(PreTrainedModelPrimeRL):
    # Base class wiring: weight-format detection and HF<->PrimeRL conversion.
    config_class = Gemma4TextConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Gemma4DecoderLayer"]
    _supports_flash_attn = True
    _supports_sdpa = True

    @classmethod
    def _has_moe_keys(cls, state_dict: dict[str, Tensor]) -> bool:
        """True when any key belongs to an MoE expert (dense checkpoints have none)."""
        return any("experts." in k for k in state_dict)

    @classmethod
    def is_hf_state_dict(cls, state_dict: dict[str, Tensor]) -> bool:
        """Detect HF-format weights: fused gate_up_proj or per-expert gate_proj keys."""
        if not cls._has_moe_keys(state_dict):
            return True  # Dense: HF format = PrimeRL format
        return any("experts.gate_up_proj" in k or "experts.0.gate_proj" in k for k in state_dict)

    @classmethod
    def is_prime_state_dict(cls, state_dict: dict[str, Tensor]) -> bool:
        """Detect PrimeRL-format weights via the stacked experts.w1 tensors."""
        if not cls._has_moe_keys(state_dict):
            return True  # Dense: HF format = PrimeRL format
        return any("experts.w1" in k for k in state_dict)

    @classmethod
    def convert_to_hf(cls, state_dict: dict[str, Tensor]) -> dict[str, Tensor]:
        """Convert a full PrimeRL state dict to HF format (in place, also returned).

        VLM checkpoints are temporarily flattened (model.language_model.* -> model.*)
        so the layer-wise converters see uniform key names, then restored.
        """
        from prime_rl.trainer.models.gemma4_moe.converting_gemma4_moe import convert_prime_to_hf

        vlm = _has_vlm_keys(state_dict)
        if vlm:
            _remap_lm_keys(state_dict, to_flat=True)
        convert_prime_to_hf(state_dict)
        if vlm:
            _remap_lm_keys(state_dict, to_flat=False)
        return state_dict

    @classmethod
    def convert_to_prime(cls, state_dict: dict[str, Tensor]) -> dict[str, Tensor]:
        """Convert a full HF state dict to PrimeRL format (in place, also returned)."""
        from prime_rl.trainer.models.gemma4_moe.converting_gemma4_moe import convert_hf_to_prime

        vlm = _has_vlm_keys(state_dict)
        if vlm:
            _remap_lm_keys(state_dict, to_flat=True)
        convert_hf_to_prime(state_dict)
        if vlm:
            _remap_lm_keys(state_dict, to_flat=False)
        return state_dict

    @classmethod
    def convert_layer_to_hf(cls, state_dict: dict[str, Tensor], layer_idx: int) -> dict[str, Tensor]:
        """Convert a single layer's weights to HF format (used for streamed export)."""
        from prime_rl.trainer.models.gemma4_moe.converting_gemma4_moe import convert_prime_layer_to_hf

        vlm = _has_vlm_keys(state_dict)
        if vlm:
            _remap_lm_keys(state_dict, to_flat=True)
        convert_prime_layer_to_hf(state_dict, layer_idx)
        if vlm:
            _remap_lm_keys(state_dict, to_flat=False)
        return state_dict

    @classmethod
    def convert_layer_to_prime(cls, state_dict: dict[str, Tensor], layer_idx: int) -> dict[str, Tensor]:
        """Convert a single layer's weights to PrimeRL format (used for streamed import)."""
        from prime_rl.trainer.models.gemma4_moe.converting_gemma4_moe import convert_hf_layer_to_prime

        vlm = _has_vlm_keys(state_dict)
        if vlm:
            _remap_lm_keys(state_dict, to_flat=True)
        convert_hf_layer_to_prime(state_dict, layer_idx)
        if vlm:
            _remap_lm_keys(state_dict, to_flat=False)
        return state_dict
+
+
+# ---------------------------------------------------------------------------
+# Model
+# ---------------------------------------------------------------------------
+
+
@auto_docstring
class Gemma4Model(Gemma4PreTrainedModel):
    # Decoder stack: scaled embeddings -> N (dense or MoE) layers -> final norm.
    def __init__(self, config: Gemma4TextConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # Gemma scales token embeddings by sqrt(hidden_size).
        self.embed_tokens = Gemma4ScaledWordEmbedding(
            config.vocab_size,
            config.hidden_size,
            self.padding_idx,
            embed_scale=config.hidden_size**0.5,
        )
        # MoE variants reuse this model class, swapping in the MoE decoder layer.
        if config.enable_moe_block:
            from prime_rl.trainer.models.gemma4_moe.modeling_gemma4_moe import Gemma4MoeDecoderLayer

            layer_cls = Gemma4MoeDecoderLayer
        else:
            layer_cls = Gemma4DecoderLayer
        self.layers = nn.ModuleList([layer_cls(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
        self.norm = RMSNorm(RMSNormConfig(hidden_size=config.hidden_size, eps=config.rms_norm_eps))
        self.rotary_emb = Gemma4DualRotaryEmbedding(config)
        self.gradient_checkpointing = False

        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> BaseModelOutputWithPast:
        # Exactly one of input_ids / inputs_embeds must be provided.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if self.config._attn_implementation in ("flash_attention_2", "flash_attention_3", "fa4"):
            # Recover packed-sequence lengths from position_ids: a position that
            # resets to 0 marks a new sequence, so each element right before a 0
            # (plus the final element) is that sequence's last position; +1 gives
            # the length. The leading element (position 0) seeds the cumsum at 0.
            flat_position_ids = position_ids.view(-1)
            seqlens = torch.cat(
                [
                    flat_position_ids[0:1],
                    flat_position_ids[:-1][(flat_position_ids == 0)[1:]] + 1,
                    flat_position_ids[-1:] + 1,
                ]
            )
            max_seqlen = seqlens.max().item()
            cu_seqlens = seqlens.cumsum(dim=0, dtype=torch.int32)
            # Avoid torch.compile recompiles when the number of packed sequences changes.
            torch._dynamo.mark_dynamic(cu_seqlens, 0)
        else:
            max_seqlen = None
            cu_seqlens = None

        hidden_states = inputs_embeds

        # Compute RoPE embeddings per layer type (sliding vs full attention).
        unique_layer_types = set(self.config.layer_types)
        position_embeddings = {}
        for layer_type in unique_layer_types:
            position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)

        for layer_idx, decoder_layer in enumerate(self.layers):
            layer_type = self.config.layer_types[layer_idx]
            hidden_states = decoder_layer(
                hidden_states,
                position_embeddings=position_embeddings[layer_type],
                cu_seqlens=cu_seqlens,
                max_seqlen=max_seqlen,
            )

        hidden_states = self.norm(hidden_states)
        return BaseModelOutputWithPast(last_hidden_state=hidden_states)
+
+
+# ---------------------------------------------------------------------------
+# VLM composite model
+# ---------------------------------------------------------------------------
+
+
class Gemma4VLMModel(nn.Module):
    """Composite VLM body: HF vision tower + custom PrimeRL text model."""

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config = config
        # Imported lazily so text-only training never requires the HF vision classes.
        from transformers.models.gemma4.modeling_gemma4 import Gemma4MultimodalEmbedder, Gemma4VisionModel

        self.vision_tower = Gemma4VisionModel._from_config(config.vision_config)
        self.embed_vision = Gemma4MultimodalEmbedder(config.vision_config, config.text_config)
        self.language_model = Gemma4Model(config.text_config)

    def get_input_embeddings(self):
        return self.language_model.embed_tokens

    def set_input_embeddings(self, value):
        self.language_model.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        pixel_values: torch.Tensor | None = None,
        **kwargs,
    ) -> BaseModelOutputWithPast:
        """Embed text, splice in image features at image-token positions, run the text model."""
        if inputs_embeds is None:
            inputs_embeds = self.language_model.embed_tokens(input_ids)

        if pixel_values is not None:
            pixel_values = pixel_values.type(self.vision_tower.dtype)
            vision_output = self.vision_tower(pixel_values, return_dict=True)
            image_features = self.embed_vision(vision_output.last_hidden_state)
            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)

            # Replace embeddings at image-token positions with the vision features.
            # masked_scatter fills masked slots in order, so the number of image
            # tokens must match the number of image feature vectors.
            image_mask = input_ids == self.config.image_token_id
            image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_features)

        if position_ids is None:
            # Default: one unpacked sequence with consecutive positions.
            position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0)

        return self.language_model(
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
        )
+
+
+# ---------------------------------------------------------------------------
+# Causal LM (unified text-only + VLM)
+# ---------------------------------------------------------------------------
+
+
@auto_docstring
class Gemma4ForCausalLM(Gemma4PreTrainedModel, GenerationMixin):
    """Unified Gemma4 model for both text-only and VLM configs.

    When config has a vision_config, creates a composite model with HF's frozen
    vision tower + custom text model. Otherwise creates a text-only model.
    """

    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        # Presence of `vision_config` selects the composite VLM body.
        self._is_vlm = hasattr(config, "vision_config")

        if self._is_vlm:
            self.model = Gemma4VLMModel(config)
            text_config = config.text_config
            # Tie lm_head against the embedding nested inside the language model.
            self._tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
        else:
            self.model = Gemma4Model(config)
            text_config = config

        self.vocab_size = text_config.vocab_size
        # NOTE(review): lm_head is constructed as nn.Linear but called in forward
        # with labels/temperature and expected to yield a PrimeLmOutput —
        # presumably PreTrainedModelPrimeRL swaps in a fused LM head; confirm.
        self.lm_head = nn.Linear(text_config.hidden_size, text_config.vocab_size, bias=False)
        self.final_logit_softcapping = text_config.final_logit_softcapping
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module (from the nested language model for VLM)."""
        if self._is_vlm:
            return self.model.get_input_embeddings()
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module (routes through the VLM body when present)."""
        if self._is_vlm:
            self.model.set_input_embeddings(value)
        else:
            self.model.embed_tokens = value

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        temperature: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> PrimeLmOutput:
        # Default position ids: one unpacked sequence with consecutive positions.
        if position_ids is None:
            if inputs_embeds is not None:
                position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0)
            else:
                position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).unsqueeze(0)

        if self._is_vlm:
            outputs = self.model(
                input_ids=input_ids,
                position_ids=position_ids,
                inputs_embeds=inputs_embeds,
                pixel_values=pixel_values,
            )
        else:
            outputs = self.model(
                input_ids=input_ids,
                position_ids=position_ids,
                inputs_embeds=inputs_embeds,
            )

        hidden_states = outputs.last_hidden_state
        # logits_to_keep: int n keeps the last n positions (0 keeps all via slice(None));
        # a tensor selects explicit positions.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        return self.lm_head(
            hidden_states[:, slice_indices, :],
            labels[:, slice_indices] if labels is not None else None,
            temperature=temperature,
        )

    def _get_text_config(self):
        """Return the text sub-config for VLM checkpoints, or the config itself otherwise."""
        return self.config.text_config if self._is_vlm else self.config

    def init_buffers_post_meta(self):
        """Recompute non-persistent buffers after materializing from meta device.

        Non-persistent buffers (embed_scale, RoPE inv_freq) are not in the
        checkpoint, so they must be re-derived from the config here.
        """
        text_config = self._get_text_config()

        # Reinitialize embed_scale (non-persistent buffer)
        if self._is_vlm:
            embed_tokens = self.model.language_model.embed_tokens
            rotary_emb = self.model.language_model.rotary_emb
        else:
            embed_tokens = self.model.embed_tokens
            rotary_emb = self.model.rotary_emb

        embed_tokens.embed_scale.fill_(text_config.hidden_size**0.5)

        # Initialize dual RoPE inv_freq buffers (mirrors Gemma4DualRotaryEmbedding.__init__)
        for layer_type in rotary_emb.layer_types:
            rope_params = text_config.rope_parameters[layer_type]
            rope_type = rope_params["rope_type"]
            if rope_type != "default":
                rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
            else:
                rope_init_fn = _compute_default_rope_parameters

            kwargs = {"device": getattr(rotary_emb, f"{layer_type}_inv_freq").device, "layer_type": layer_type}
            if layer_type == "full_attention" and rope_type == "proportional":
                kwargs["head_dim_key"] = "global_head_dim"

            inv_freq, attn_scaling = rope_init_fn(text_config, **kwargs)
            getattr(rotary_emb, f"{layer_type}_inv_freq").copy_(inv_freq)
            rotary_emb.attention_scaling[layer_type] = attn_scaling
diff --git a/src/prime_rl/trainer/models/gemma4_moe/__init__.py b/src/prime_rl/trainer/models/gemma4_moe/__init__.py
new file mode 100644
index 0000000000..7c79a57790
--- /dev/null
+++ b/src/prime_rl/trainer/models/gemma4_moe/__init__.py
@@ -0,0 +1,5 @@
+from transformers.models.gemma4.configuration_gemma4 import Gemma4TextConfig
+
+from prime_rl.trainer.models.gemma4_moe.modeling_gemma4_moe import Gemma4MoeForCausalLM
+
+__all__ = ["Gemma4TextConfig", "Gemma4MoeForCausalLM"]
diff --git a/src/prime_rl/trainer/models/gemma4_moe/converting_gemma4_moe.py b/src/prime_rl/trainer/models/gemma4_moe/converting_gemma4_moe.py
new file mode 100644
index 0000000000..b4648f0ac7
--- /dev/null
+++ b/src/prime_rl/trainer/models/gemma4_moe/converting_gemma4_moe.py
@@ -0,0 +1,66 @@
+"""Weight conversion between HF and PrimeRL formats for Gemma4 MoE.
+
+HF format: experts.gate_up_proj (num_experts, 2*moe_dim, dim) + experts.down_proj (num_experts, dim, moe_dim)
+ router.norm.weight, router.proj.weight, router.scale, router.per_expert_scale
+PrimeRL format: experts.w1 (num_experts, moe_dim, dim), experts.w2 (num_experts, dim, moe_dim),
+ experts.w3 (num_experts, moe_dim, dim)
+ router keys stay the same (no rename needed since Gemma4 router has its own structure)
+"""
+
+from torch import Tensor
+
+
def get_max_layer_num(state_dict: dict[str, Tensor]) -> int:
    """Return the number of decoder layers referenced by `state_dict` keys.

    Keys look like ``model.layers.<idx>....``; the result is max index + 1.
    Returns 0 when no layer keys are present (e.g. an empty or partial shard),
    instead of letting ``max()`` raise ``ValueError`` on an empty sequence —
    the conversion loops below then simply become no-ops.
    """
    return max((int(key.split(".")[2]) for key in state_dict if "model.layers." in key), default=-1) + 1
+
+
def convert_hf_layer_to_prime(state_dict: dict[str, Tensor], layer_idx: int):
    """Convert one Gemma4 MoE layer from HF format to PrimeRL w1/w2/w3 format.

    Handles both HF layouts recognized by ``is_hf_state_dict``:
    - fused: ``experts.gate_up_proj`` (E, 2*moe_dim, dim) and
      ``experts.down_proj`` (E, dim, moe_dim);
    - per-expert: ``experts.<j>.{gate,up,down}_proj.weight`` (the layout emitted
      by ``convert_prime_layer_to_hf``), stacked back along the expert axis.

    Previously only the fused layout was converted, so a per-expert HF
    checkpoint would silently pass through unconverted. No-op for layers
    without MoE expert keys (dense layers).
    """
    i = layer_idx
    gate_up_key = f"model.layers.{i}.experts.gate_up_proj"
    down_key = f"model.layers.{i}.experts.down_proj"

    if gate_up_key in state_dict:
        gate_up_proj = state_dict.pop(gate_up_key)
        down_proj = state_dict.pop(down_key)
        moe_dim = gate_up_proj.shape[1] // 2

        # Split fused gate_up into w1 (gate) and w3 (up); down becomes w2.
        state_dict[f"model.layers.{i}.experts.w1"] = gate_up_proj[:, :moe_dim, :]
        state_dict[f"model.layers.{i}.experts.w3"] = gate_up_proj[:, moe_dim:, :]
        state_dict[f"model.layers.{i}.experts.w2"] = down_proj
        return

    # Per-expert layout: gather and stack each expert's matrices.
    if f"model.layers.{i}.experts.0.gate_proj.weight" not in state_dict:
        return  # nothing to convert for this layer
    import torch  # local import: the module otherwise only needs torch.Tensor

    gates, downs, ups = [], [], []
    expert_idx = 0
    while f"model.layers.{i}.experts.{expert_idx}.gate_proj.weight" in state_dict:
        gates.append(state_dict.pop(f"model.layers.{i}.experts.{expert_idx}.gate_proj.weight"))
        downs.append(state_dict.pop(f"model.layers.{i}.experts.{expert_idx}.down_proj.weight"))
        ups.append(state_dict.pop(f"model.layers.{i}.experts.{expert_idx}.up_proj.weight"))
        expert_idx += 1
    state_dict[f"model.layers.{i}.experts.w1"] = torch.stack(gates)
    state_dict[f"model.layers.{i}.experts.w2"] = torch.stack(downs)
    state_dict[f"model.layers.{i}.experts.w3"] = torch.stack(ups)
+
+
def convert_prime_layer_to_hf(state_dict: dict[str, Tensor], layer_idx: int):
    """Convert Gemma4 MoE layer from PrimeRL w1/w2/w3 to HF per-expert format.

    No-op when the layer has no stacked expert weights (e.g. dense layers).
    """
    prefix = f"model.layers.{layer_idx}.experts"
    if f"{prefix}.w1" not in state_dict:
        return

    stacked = {name: state_dict.pop(f"{prefix}.{name}") for name in ("w1", "w2", "w3")}
    # w1 -> gate_proj, w2 -> down_proj, w3 -> up_proj, one entry per expert.
    for expert_idx in range(stacked["w1"].shape[0]):
        state_dict[f"{prefix}.{expert_idx}.gate_proj.weight"] = stacked["w1"][expert_idx]
        state_dict[f"{prefix}.{expert_idx}.down_proj.weight"] = stacked["w2"][expert_idx]
        state_dict[f"{prefix}.{expert_idx}.up_proj.weight"] = stacked["w3"][expert_idx]
+
+
def convert_hf_to_prime(state_dict: dict[str, Tensor]):
    """Convert every decoder layer of a full HF state dict to PrimeRL format, in place."""
    for layer_idx in range(get_max_layer_num(state_dict)):
        convert_hf_layer_to_prime(state_dict, layer_idx)
+
+
def convert_prime_to_hf(state_dict: dict[str, Tensor]):
    """Convert every decoder layer of a full PrimeRL state dict to HF format, in place."""
    for layer_idx in range(get_max_layer_num(state_dict)):
        convert_prime_layer_to_hf(state_dict, layer_idx)
diff --git a/src/prime_rl/trainer/models/gemma4_moe/modeling_gemma4_moe.py b/src/prime_rl/trainer/models/gemma4_moe/modeling_gemma4_moe.py
new file mode 100644
index 0000000000..34f9822ebc
--- /dev/null
+++ b/src/prime_rl/trainer/models/gemma4_moe/modeling_gemma4_moe.py
@@ -0,0 +1,435 @@
+"""Gemma4 MoE custom implementation for PrimeRL training.
+
+Extends the dense Gemma4 model with:
+- 128 sparse experts (top-8) running in parallel with a shared dense MLP
+- Custom router: RMSNorm (no scale) + learnable scale + softmax + renormalize + per_expert_scale
+- Router input is the pre-MLP residual (not the MLP output)
+- Expert activation: gelu_pytorch_tanh (not silu)
+- MoE weight conversion: HF fused gate_up_proj → PrimeRL w1/w2/w3
+"""
+
+from typing import Optional, Union
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+from transformers.activations import ACT2FN
+from transformers.generation import GenerationMixin
+from transformers.modeling_layers import GradientCheckpointingLayer
+from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
+from transformers.models.gemma4.configuration_gemma4 import Gemma4TextConfig
+from transformers.processing_utils import Unpack
+from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
+
+from prime_rl.trainer.models.base import PreTrainedModelPrimeRL
+from prime_rl.trainer.models.gemma4.modeling_gemma4 import (
+ Gemma4Attention,
+ Gemma4DualRotaryEmbedding,
+ Gemma4MLP,
+ Gemma4RMSNormNoScale,
+ Gemma4ScaledWordEmbedding,
+ _compute_default_rope_parameters,
+ _get_flash_attn_version,
+)
+from prime_rl.trainer.models.gemma4_moe.converting_gemma4_moe import (
+ convert_hf_layer_to_prime,
+ convert_hf_to_prime,
+ convert_prime_layer_to_hf,
+ convert_prime_to_hf,
+)
+from prime_rl.trainer.models.layers.lm_head import PrimeLmOutput
+from prime_rl.trainer.models.layers.norms import RMSNorm, RMSNormConfig
+
+# ---------------------------------------------------------------------------
+# Gemma4 MoE Router
+# ---------------------------------------------------------------------------
+
+
class Gemma4Router(nn.Module):
    """Gemma4 router: RMSNorm(no_scale) → scale → linear → softmax → topk → renorm → per_expert_scale."""

    def __init__(self, config: Gemma4TextConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_experts = config.num_experts
        self.top_k = config.top_k_experts
        # 1/sqrt(hidden_size), applied to the normalized router input.
        self.scalar_root_size = self.hidden_size**-0.5

        self.norm = Gemma4RMSNormNoScale(self.hidden_size, eps=config.rms_norm_eps)
        self.proj = nn.Linear(config.hidden_size, config.num_experts, bias=False)
        # Learnable per-channel input scale and learnable per-expert output scale.
        self.scale = nn.Parameter(torch.ones(self.hidden_size))
        self.per_expert_scale = nn.Parameter(torch.ones(config.num_experts))

    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Route tokens to experts.

        Returns (top_k_weights, top_k_index, num_tokens_per_expert) where weights
        are softmax probabilities renormalized over the selected top-k and then
        multiplied by the learnable per-expert scale.
        """
        hidden_states = self.norm(hidden_states)
        hidden_states = hidden_states * self.scale * self.scalar_root_size

        expert_scores = self.proj(hidden_states)
        # Softmax over ALL experts in float32, before top-k selection.
        router_probs = F.softmax(expert_scores.float(), dim=-1)

        top_k_weights, top_k_index = torch.topk(router_probs, k=self.top_k, dim=-1)
        top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)
        top_k_weights = top_k_weights * self.per_expert_scale[top_k_index]

        # Per-expert token counts for the grouped expert computation; histc over
        # [0, num_experts] with num_experts bins gives one bin per expert index.
        num_tokens_per_expert = torch.histc(
            top_k_index.reshape(-1).float(),
            bins=self.num_experts,
            min=0,
            max=self.num_experts,
        )

        return top_k_weights.to(hidden_states.dtype), top_k_index, num_tokens_per_expert
+
+
+# ---------------------------------------------------------------------------
+# Gemma4 MoE Experts (w1/w2/w3 format with gelu_pytorch_tanh)
+# ---------------------------------------------------------------------------
+
+
+def _run_experts_for_loop(
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ w3: torch.Tensor,
+ x: torch.Tensor,
+ num_tokens_per_expert: torch.Tensor,
+ act_fn,
+) -> torch.Tensor:
+ num_tokens_per_expert_list = num_tokens_per_expert.tolist()
+ num_padding = x.shape[0] - sum(num_tokens_per_expert_list)
+ x_splits = torch.split(x[: sum(num_tokens_per_expert_list)], num_tokens_per_expert_list, dim=0)
+ out_splits = []
+ for expert_idx, x_expert in enumerate(x_splits):
+ h = act_fn(torch.matmul(x_expert, w1[expert_idx].transpose(-2, -1)))
+ h = h * torch.matmul(x_expert, w3[expert_idx].transpose(-2, -1))
+ h = torch.matmul(h, w2[expert_idx].transpose(-2, -1))
+ out_splits.append(h)
+ out = torch.cat(out_splits, dim=0)
+ if num_padding > 0:
+ out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1]))))
+ return out
+
+
def _run_experts_grouped_mm(
    w1: torch.Tensor,
    w2: torch.Tensor,
    w3: torch.Tensor,
    x: torch.Tensor,
    num_tokens_per_expert: torch.Tensor,
    act_fn,
) -> torch.Tensor:
    """Grouped-GEMM expert dispatch via torch._grouped_mm.

    Presumably expects `x` with tokens grouped by expert (matching the for-loop
    variant); `offsets` are the exclusive end boundaries per expert group.
    Inputs are cast to bfloat16 for the grouped GEMMs and the result is cast
    back to x's dtype. NOTE(review): torch._grouped_mm is a private API that
    needs a recent PyTorch with CUDA support — confirm availability in the
    target environment.
    """
    offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
    # act(x @ w1^T) * (x @ w3^T) @ w2^T, all as grouped per-expert GEMMs.
    h = act_fn(torch._grouped_mm(x.bfloat16(), w1.bfloat16().transpose(-2, -1), offs=offsets))
    h = h * torch._grouped_mm(x.bfloat16(), w3.bfloat16().transpose(-2, -1), offs=offsets)
    out = torch._grouped_mm(h, w2.bfloat16().transpose(-2, -1), offs=offsets).type_as(x)
    return out
+
+
class Gemma4Experts(nn.Module):
    """Expert weights in PrimeRL w1/w2/w3 format with configurable activation."""

    def __init__(self, config: Gemma4TextConfig, use_grouped_mm: bool = True):
        super().__init__()
        self.num_experts = config.num_experts
        model_dim, expert_dim = config.hidden_size, config.moe_intermediate_size
        # w1: gate (dim -> moe_dim), w3: up (dim -> moe_dim), w2: down (moe_dim -> dim).
        self.w1 = nn.Parameter(torch.empty(self.num_experts, expert_dim, model_dim))
        self.w2 = nn.Parameter(torch.empty(self.num_experts, model_dim, expert_dim))
        self.w3 = nn.Parameter(torch.empty(self.num_experts, expert_dim, model_dim))
        self.act_fn = ACT2FN[config.hidden_activation]
        self.use_grouped_mm = use_grouped_mm

    def forward(self, x: torch.Tensor, num_tokens_per_expert: torch.Tensor) -> torch.Tensor:
        """Dispatch expert-grouped tokens through either the grouped-GEMM or reference path."""
        run = _run_experts_grouped_mm if self.use_grouped_mm else _run_experts_for_loop
        return run(self.w1, self.w2, self.w3, x, num_tokens_per_expert, self.act_fn)
+
+
+# ---------------------------------------------------------------------------
+# MoE Decoder Layer
+# ---------------------------------------------------------------------------
+
+
+class Gemma4MoeDecoderLayer(GradientCheckpointingLayer):
+    """Gemma4 decoder layer with shared MLP + sparse MoE in parallel.
+
+    The feed-forward block runs two branches off the same residual stream:
+    a dense ("shared") MLP and a top-k routed expert path. Each branch has its
+    own post-norm; the normalized branch outputs are summed, normed once more,
+    and added back to the residual, then scaled by a per-layer scalar buffer.
+    """
+
+    def __init__(self, config: Gemma4TextConfig, layer_idx: int):
+        super().__init__()
+        self.self_attn = Gemma4Attention(
+            config, layer_idx, flash_attn_version=_get_flash_attn_version(config._attn_implementation)
+        )
+        # Dense MLP applied to every token (the "shared" branch).
+        self.mlp = Gemma4MLP(config)
+
+        # Standard 4 layernorms
+        self.input_layernorm = RMSNorm(RMSNormConfig(hidden_size=config.hidden_size, eps=config.rms_norm_eps))
+        self.post_attention_layernorm = RMSNorm(RMSNormConfig(hidden_size=config.hidden_size, eps=config.rms_norm_eps))
+        self.pre_feedforward_layernorm = RMSNorm(RMSNormConfig(hidden_size=config.hidden_size, eps=config.rms_norm_eps))
+        self.post_feedforward_layernorm = RMSNorm(
+            RMSNormConfig(hidden_size=config.hidden_size, eps=config.rms_norm_eps)
+        )
+
+        # MoE components
+        self.router = Gemma4Router(config)
+        self.experts = Gemma4Experts(config, use_grouped_mm=getattr(config, "use_grouped_mm", True))
+
+        # Extra norms for MoE parallel path:
+        # *_1 norms the shared-MLP output, *_2 norms the expert output,
+        # pre_*_2 norms the expert *input* taken from the residual stream.
+        self.post_feedforward_layernorm_1 = RMSNorm(
+            RMSNormConfig(hidden_size=config.hidden_size, eps=config.rms_norm_eps)
+        )
+        self.post_feedforward_layernorm_2 = RMSNorm(
+            RMSNormConfig(hidden_size=config.hidden_size, eps=config.rms_norm_eps)
+        )
+        self.pre_feedforward_layernorm_2 = RMSNorm(
+            RMSNormConfig(hidden_size=config.hidden_size, eps=config.rms_norm_eps)
+        )
+
+        # Per-layer scalar multiplier on the layer output (buffer, loaded from
+        # the checkpoint; defaults to 1.0 so it is a no-op until overwritten).
+        self.register_buffer("layer_scalar", torch.ones(1))
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: tuple[torch.Tensor, torch.Tensor],
+        cu_seqlens: torch.LongTensor | None = None,
+        max_seqlen: int | None = None,
+    ) -> torch.Tensor:
+        """Apply attention, then the parallel shared-MLP + sparse-MoE block.
+
+        cu_seqlens/max_seqlen are varlen metadata for the flash-attention
+        path; they are None for non-flash implementations.
+        """
+        # Self-attention block (pre- and post-norm sandwich around attention).
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states, _ = self.self_attn(
+            hidden_states=hidden_states,
+            position_embeddings=position_embeddings,
+            cu_seqlens=cu_seqlens,
+            max_seqlen=max_seqlen,
+        )
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = residual + hidden_states
+
+        # FFN + MoE parallel block
+        residual = hidden_states
+        hidden_states = self.pre_feedforward_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+
+        # Shared MLP output normalized
+        hidden_states_1 = self.post_feedforward_layernorm_1(hidden_states)
+
+        # Sparse experts: the router reads the (un-normed) pre-MLP residual.
+        hidden_states_flat = residual.reshape(-1, residual.shape[-1])
+        top_k_weights, top_k_index, num_tokens_per_expert = self.router(hidden_states_flat)
+
+        # Reorder (token, slot) pairs by expert id. Stable argsort keeps token
+        # order within each expert; dividing by top_k maps each sorted slot
+        # back to its source-token row.
+        selected_flat = top_k_index.reshape(-1)
+        token_indices_sorted = torch.argsort(selected_flat, stable=True)
+        scores_sorted = top_k_weights.view(-1)[token_indices_sorted]
+        token_indices_sorted = token_indices_sorted // self.router.top_k
+
+        dim = hidden_states_flat.shape[-1]
+        # Expand to (num_routed_rows, dim) so gather/scatter address full rows.
+        routed_indices = token_indices_sorted.reshape(-1, 1).expand(-1, dim)
+
+        # Pre-norm for expert input (separate norm from the shared-MLP path).
+        expert_input = self.pre_feedforward_layernorm_2(hidden_states_flat)
+        routed_input = torch.gather(expert_input, dim=0, index=routed_indices)
+        # Scale each routed copy by its router score in fp32, then cast back.
+        routed_input = (routed_input.float() * scores_sorted.reshape(-1, 1)).to(routed_input.dtype)
+
+        routed_output = self.experts(routed_input, num_tokens_per_expert)
+
+        # Scatter back: sum the top_k expert outputs for each token.
+        hidden_states_2 = torch.zeros_like(hidden_states_flat)
+        hidden_states_2.scatter_add_(dim=0, index=routed_indices, src=routed_output)
+        hidden_states_2 = hidden_states_2.reshape(residual.shape)
+        hidden_states_2 = self.post_feedforward_layernorm_2(hidden_states_2)
+
+        # Combine shared MLP + sparse experts
+        hidden_states = hidden_states_1 + hidden_states_2
+
+        hidden_states = self.post_feedforward_layernorm(hidden_states)
+        hidden_states = residual + hidden_states
+
+        # Per-layer output scaling (in place; buffer defaults to 1.0).
+        hidden_states *= self.layer_scalar
+        return hidden_states
+
+
+# ---------------------------------------------------------------------------
+# PreTrained base
+# ---------------------------------------------------------------------------
+
+
+@auto_docstring
+@auto_docstring
+class Gemma4MoePreTrainedModel(PreTrainedModelPrimeRL):
+    """PreTrained base for Gemma4 MoE: config plumbing plus state-dict
+    format detection/conversion between HF and PrimeRL expert layouts."""
+
+    config_class = Gemma4TextConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["Gemma4MoeDecoderLayer"]
+    _supports_flash_attn = True
+
+    @classmethod
+    def is_hf_state_dict(cls, state_dict: dict[str, Tensor]) -> bool:
+        # HF checkpoints carry either fused ("experts.gate_up_proj") or
+        # per-expert ("experts.0.gate_proj") projection names.
+        return any("experts.gate_up_proj" in k or "experts.0.gate_proj" in k for k in state_dict)
+
+    @classmethod
+    def is_prime_state_dict(cls, state_dict: dict[str, Tensor]) -> bool:
+        # PrimeRL layout stacks experts into w1/w2/w3 tensors.
+        return any("experts.w1" in k for k in state_dict)
+
+    # The convert_* helpers mutate the given state_dict in place; the dict is
+    # returned as well for call-site convenience.
+    @classmethod
+    def convert_to_hf(cls, state_dict: dict[str, Tensor]) -> dict[str, Tensor]:
+        convert_prime_to_hf(state_dict)
+        return state_dict
+
+    @classmethod
+    def convert_to_prime(cls, state_dict: dict[str, Tensor]) -> dict[str, Tensor]:
+        convert_hf_to_prime(state_dict)
+        return state_dict
+
+    @classmethod
+    def convert_layer_to_hf(cls, state_dict: dict[str, Tensor], layer_idx: int) -> dict[str, Tensor]:
+        # Per-layer variant used for streaming/incremental conversion.
+        convert_prime_layer_to_hf(state_dict, layer_idx)
+        return state_dict
+
+    @classmethod
+    def convert_layer_to_prime(cls, state_dict: dict[str, Tensor], layer_idx: int) -> dict[str, Tensor]:
+        convert_hf_layer_to_prime(state_dict, layer_idx)
+        return state_dict
+
+
+# ---------------------------------------------------------------------------
+# Model
+# ---------------------------------------------------------------------------
+
+
+@auto_docstring
+@auto_docstring
+class Gemma4MoeModel(Gemma4MoePreTrainedModel):
+    """Gemma4 MoE backbone: scaled embeddings, a stack of MoE decoder
+    layers, a final RMSNorm, and dual (per-layer-type) rotary embeddings."""
+
+    def __init__(self, config: Gemma4TextConfig):
+        super().__init__(config)
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        # Embeddings are scaled by sqrt(hidden_size) (Gemma convention).
+        self.embed_tokens = Gemma4ScaledWordEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+            self.padding_idx,
+            embed_scale=config.hidden_size**0.5,
+        )
+        self.layers = nn.ModuleList(
+            [Gemma4MoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+        )
+        self.norm = RMSNorm(RMSNormConfig(hidden_size=config.hidden_size, eps=config.rms_norm_eps))
+        # One rotary embedding module serving all layer types (e.g. sliding vs
+        # full attention) — selected per layer in forward().
+        self.rotary_emb = Gemma4DualRotaryEmbedding(config)
+        self.gradient_checkpointing = False
+        self.post_init()
+
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+    ) -> BaseModelOutputWithPast:
+        """Run the decoder stack; exactly one of input_ids/inputs_embeds is required.
+
+        NOTE(review): position_ids is required (not Optional in practice) on
+        the flash-attention path below — callers like Gemma4MoeForCausalLM
+        always supply it.
+        """
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        if self.config._attn_implementation in ("flash_attention_2", "flash_attention_3", "fa4"):
+            # Derive varlen metadata from position_ids for packed sequences:
+            # a position id resetting to 0 marks the start of a new sequence.
+            # seqlens = [first pos] + [pos+1 at each boundary] + [last pos+1];
+            # its cumsum yields cu_seqlens. Assumes each packed segment's
+            # position ids start at 0 (so the first entry contributes 0).
+            flat_position_ids = position_ids.view(-1)
+            seqlens = torch.cat(
+                [
+                    flat_position_ids[0:1],
+                    flat_position_ids[:-1][(flat_position_ids == 0)[1:]] + 1,
+                    flat_position_ids[-1:] + 1,
+                ]
+            )
+            max_seqlen = seqlens.max().item()
+            cu_seqlens = seqlens.cumsum(dim=0, dtype=torch.int32)
+            # Avoid torch.compile recompiles when the number of packed
+            # sequences varies between batches.
+            torch._dynamo.mark_dynamic(cu_seqlens, 0)
+        else:
+            max_seqlen = None
+            cu_seqlens = None
+
+        hidden_states = inputs_embeds
+
+        # Precompute rotary position embeddings once per distinct layer type
+        # instead of once per layer.
+        unique_layer_types = set(self.config.layer_types)
+        position_embeddings = {}
+        for layer_type in unique_layer_types:
+            position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)
+
+        for layer_idx, decoder_layer in enumerate(self.layers):
+            layer_type = self.config.layer_types[layer_idx]
+            hidden_states = decoder_layer(
+                hidden_states,
+                position_embeddings=position_embeddings[layer_type],
+                cu_seqlens=cu_seqlens,
+                max_seqlen=max_seqlen,
+            )
+
+        hidden_states = self.norm(hidden_states)
+        return BaseModelOutputWithPast(last_hidden_state=hidden_states)
+
+
+# ---------------------------------------------------------------------------
+# Causal LM
+# ---------------------------------------------------------------------------
+
+
+@auto_docstring
+@auto_docstring
+class Gemma4MoeForCausalLM(Gemma4MoePreTrainedModel, GenerationMixin):
+    """Causal-LM head on top of Gemma4MoeModel, with tied embeddings."""
+
+    # lm_head shares weights with the input embeddings.
+    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
+
+    def __init__(self, config: Gemma4TextConfig):
+        super().__init__(config)
+        self.model = Gemma4MoeModel(config)
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.final_logit_softcapping = config.final_logit_softcapping
+        self.post_init()
+
+    @can_return_tuple
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        logits_to_keep: Union[int, torch.Tensor] = 0,
+        temperature: Optional[torch.Tensor] = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> PrimeLmOutput:
+        """Run the backbone and the LM head on the last hidden states.
+
+        NOTE(review): lm_head is invoked as ``lm_head(hidden, labels,
+        temperature=...)`` returning PrimeLmOutput — this only works after the
+        plain nn.Linear has been replaced by a Prime fused head (see
+        inject_prime_lm_head usage in the tests); confirm injection is always
+        performed before forward.
+        """
+        if position_ids is None:
+            # Default to a simple 0..seq_len-1 range (no sequence packing).
+            if inputs_embeds is not None:
+                position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0)
+            else:
+                position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).unsqueeze(0)
+
+        outputs = self.model(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
+        )
+
+        hidden_states = outputs.last_hidden_state
+        # int logits_to_keep selects the trailing slice (0 keeps everything,
+        # since slice(0, None) is the full sequence); a tensor indexes directly.
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        return self.lm_head(
+            hidden_states[:, slice_indices, :],
+            labels[:, slice_indices] if labels is not None else None,
+            temperature=temperature,
+        )
+
+    def init_buffers_post_meta(self):
+        """Re-materialize buffers after meta-device initialization.
+
+        Fills the embedding scale and recomputes rotary inv_freq buffers (and
+        attention scaling) per layer type from the config's rope parameters.
+        """
+        self.model.embed_tokens.embed_scale.fill_(self.config.hidden_size**0.5)
+        rotary_emb = self.model.rotary_emb
+        for layer_type in rotary_emb.layer_types:
+            rope_params = self.config.rope_parameters[layer_type]
+            rope_type = rope_params["rope_type"]
+            if rope_type != "default":
+                rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
+            else:
+                rope_init_fn = _compute_default_rope_parameters
+
+            kwargs = {"device": getattr(rotary_emb, f"{layer_type}_inv_freq").device, "layer_type": layer_type}
+            if layer_type == "full_attention" and rope_type == "proportional":
+                # Global-attention layers use a distinct head dim for RoPE.
+                kwargs["head_dim_key"] = "global_head_dim"
+
+            inv_freq, attn_scaling = rope_init_fn(self.config, **kwargs)
+            getattr(rotary_emb, f"{layer_type}_inv_freq").copy_(inv_freq)
+            rotary_emb.attention_scaling[layer_type] = attn_scaling
diff --git a/src/prime_rl/trainer/models/minimax_m2/configuration_minimax_m2.py b/src/prime_rl/trainer/models/minimax_m2/configuration_minimax_m2.py
index a09d03d648..6fcceec309 100644
--- a/src/prime_rl/trainer/models/minimax_m2/configuration_minimax_m2.py
+++ b/src/prime_rl/trainer/models/minimax_m2/configuration_minimax_m2.py
@@ -1,5 +1,4 @@
from transformers.configuration_utils import PretrainedConfig
-from transformers.modeling_rope_utils import rope_config_validation
class MiniMaxM2Config(PretrainedConfig):
@@ -124,7 +123,7 @@ def __init__(
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
- rope_config_validation(self)
+ self.standardize_rope_params()
# MoE arguments
self.num_local_experts = num_local_experts
diff --git a/src/prime_rl/trainer/perf.py b/src/prime_rl/trainer/perf.py
index a4aeaf9b2c..b22a05fd72 100644
--- a/src/prime_rl/trainer/perf.py
+++ b/src/prime_rl/trainer/perf.py
@@ -135,16 +135,16 @@ def get_active_mm_params(config: PretrainedConfig) -> float:
sparse_mlp_params = 0
# Some MoE models (e.g. DeepSeek) use moe_intermediate_size, others (e.g. Granite) just use intermediate_size
- moe_intermediate_size = getattr(config, "moe_intermediate_size", intermediate_size)
- if hasattr(config, "num_shared_experts"): # Shared experts
+ moe_intermediate_size = getattr(config, "moe_intermediate_size", None) or intermediate_size
+ if hasattr(config, "num_shared_experts") and config.num_shared_experts: # Shared experts
sparse_mlp_params += num_sparse_layers * config.num_shared_experts * 3 * moe_intermediate_size * hidden_size
- if hasattr(config, "num_experts_per_tok"): # Routed experts
+ if hasattr(config, "num_experts_per_tok") and config.num_experts_per_tok: # Routed experts
sparse_mlp_params += (
num_sparse_layers * config.num_experts_per_tok * 3 * moe_intermediate_size * hidden_size
)
if hasattr(config, "n_routed_experts"): # DeepSeek Router
sparse_mlp_params += num_sparse_layers * config.n_routed_experts * hidden_size
- elif hasattr(config, "num_experts"): # Qwen Router
+ elif hasattr(config, "num_experts") and config.num_experts is not None: # Qwen Router
sparse_mlp_params += num_sparse_layers * config.num_experts * hidden_size
else:
sparse_mlp_params = 0
diff --git a/src/prime_rl/trainer/rl/train.py b/src/prime_rl/trainer/rl/train.py
index 626ed151c7..c363f406b5 100644
--- a/src/prime_rl/trainer/rl/train.py
+++ b/src/prime_rl/trainer/rl/train.py
@@ -1,3 +1,5 @@
+import prime_rl._compat # noqa: F401 — patch ring_flash_attn compat before import
+
from contextlib import nullcontext
import time
from datetime import timedelta
diff --git a/src/prime_rl/trainer/sft/data.py b/src/prime_rl/trainer/sft/data.py
index 34ad868df2..80b7a823b3 100644
--- a/src/prime_rl/trainer/sft/data.py
+++ b/src/prime_rl/trainer/sft/data.py
@@ -124,6 +124,7 @@ def __init__(
seq_len: int = 128,
non_dp_size: int = 1,
loss_mask_config: LossMaskConfig = LossMaskConfig(),
+ chat_template_kwargs: dict | None = None,
max_examples: int | None = None,
max_epochs: int | None = None,
):
@@ -136,6 +137,7 @@ def __init__(
self.seed = seed
self.seq_len = seq_len
self.loss_mask_config = loss_mask_config
+ self.chat_template_kwargs = chat_template_kwargs or {}
self.max_examples = max_examples
self.max_epochs = max_epochs
@@ -205,12 +207,16 @@ def should_mask(message: dict) -> bool:
case _:
raise ValueError(f"Invalid message role: {message['role']}")
+ per_example_kwargs = example.get("chat_template_kwargs", {})
+ if self.chat_template_kwargs:
+ per_example_kwargs = {**self.chat_template_kwargs, **per_example_kwargs}
+
input_ids, loss_mask = build_incremental_token_mask(
self.tokenizer,
messages,
role_to_mask=should_mask,
tools=tools,
- chat_template_kwargs=example.get("chat_template_kwargs", {}),
+ chat_template_kwargs=per_example_kwargs,
collapse_consecutive_tool_messages=True,
)
@@ -547,6 +553,7 @@ def setup_dataset(
seed=config.seed,
seq_len=config.seq_len,
loss_mask_config=config.loss_mask,
+ chat_template_kwargs=config.chat_template_kwargs,
non_dp_size=non_dp_size,
max_epochs=max_epochs,
)
diff --git a/src/prime_rl/trainer/sft/train.py b/src/prime_rl/trainer/sft/train.py
index afca134c9f..ce3a5c02a5 100644
--- a/src/prime_rl/trainer/sft/train.py
+++ b/src/prime_rl/trainer/sft/train.py
@@ -1,3 +1,5 @@
+import prime_rl._compat # noqa: F401 — patch ring_flash_attn compat before import
+
import time
from contextlib import nullcontext
from datetime import timedelta
diff --git a/src/prime_rl/utils/vlm.py b/src/prime_rl/utils/vlm.py
index 15b5f1bf3b..18040c1c0a 100644
--- a/src/prime_rl/utils/vlm.py
+++ b/src/prime_rl/utils/vlm.py
@@ -25,6 +25,7 @@ class VLMModelInfo:
# Central registry: model_type -> architecture info.
VLM_REGISTRY: dict[str, VLMModelInfo] = {
+ "gemma4": VLMModelInfo(vision_encoder_attr="model.vision_tower", language_model_attr="model.language_model"),
"qwen3_vl": VLMModelInfo(vision_encoder_attr="model.visual", language_model_attr="model.language_model"),
"qwen3_5": VLMModelInfo(vision_encoder_attr="model.visual", language_model_attr="model.language_model"),
"qwen3_5_moe": VLMModelInfo(vision_encoder_attr="model.visual", language_model_attr="model.language_model"),
diff --git a/tests/unit/train/models/afmoe_hf_modeling/modeling_afmoe.py b/tests/unit/train/models/afmoe_hf_modeling/modeling_afmoe.py
index d7892dfda6..e5fd36c96e 100644
--- a/tests/unit/train/models/afmoe_hf_modeling/modeling_afmoe.py
+++ b/tests/unit/train/models/afmoe_hf_modeling/modeling_afmoe.py
@@ -525,7 +525,7 @@ def forward(
if not isinstance(causal_mask_mapping := attention_mask, dict):
mask_kwargs = {
"config": self.config,
- "input_embeds": inputs_embeds,
+ "inputs_embeds": inputs_embeds,
"attention_mask": attention_mask,
"cache_position": cache_position,
"past_key_values": past_key_values,
diff --git a/tests/unit/train/models/test_gemma4.py b/tests/unit/train/models/test_gemma4.py
new file mode 100644
index 0000000000..8ace0ae83b
--- /dev/null
+++ b/tests/unit/train/models/test_gemma4.py
@@ -0,0 +1,140 @@
+import pytest
+import torch
+from transformers.models.gemma4.configuration_gemma4 import Gemma4TextConfig
+
+from prime_rl.trainer.models.gemma4 import Gemma4ForCausalLM
+from prime_rl.trainer.models.layers.lm_head import inject_prime_lm_head
+
+pytestmark = [pytest.mark.gpu]
+
+
+def _tiny_config(**overrides):
+ defaults = dict(
+ vocab_size=256,
+ hidden_size=256,
+ num_hidden_layers=4,
+ num_attention_heads=4,
+ num_key_value_heads=2,
+ head_dim=64,
+ global_head_dim=128,
+ intermediate_size=512,
+ attention_k_eq_v=True,
+ num_global_key_value_heads=1,
+ sliding_window=128,
+ max_position_embeddings=512,
+ final_logit_softcapping=30.0,
+ hidden_activation="gelu_pytorch_tanh",
+ _attn_implementation="flash_attention_2",
+ )
+ defaults.update(overrides)
+ return Gemma4TextConfig(**defaults)
+
+
+def test_gemma4_forward():
+ config = _tiny_config()
+ model = Gemma4ForCausalLM(config).to(device="cuda", dtype=torch.bfloat16)
+ inject_prime_lm_head(model)
+
+ input_ids = torch.randint(0, 256, (1, 32), device="cuda")
+ position_ids = torch.arange(32, device="cuda").unsqueeze(0)
+
+ with torch.no_grad():
+ output = model(input_ids=input_ids, position_ids=position_ids)
+ assert "logits" in output
+ assert output["logits"].shape == (1, 32, 256)
+
+
+def test_gemma4_backward():
+ config = _tiny_config()
+ model = Gemma4ForCausalLM(config).to(device="cuda", dtype=torch.bfloat16)
+ inject_prime_lm_head(model)
+
+ input_ids = torch.randint(0, 256, (1, 32), device="cuda")
+ position_ids = torch.arange(32, device="cuda").unsqueeze(0)
+
+ output = model(input_ids=input_ids, position_ids=position_ids)
+ logits = output["logits"]
+ assert logits is not None
+ loss = logits.sum()
+ loss.backward()
+
+ # Check gradients flow
+ has_grad = False
+ for name, param in model.named_parameters():
+ if param.requires_grad and param.grad is not None:
+ has_grad = True
+ break
+ assert has_grad, "No gradients found"
+
+
+def test_gemma4_no_kv_sharing():
+ """Test without K=V sharing (like a hypothetical smaller model)."""
+ config = _tiny_config(attention_k_eq_v=False, num_global_key_value_heads=None)
+ model = Gemma4ForCausalLM(config).to(device="cuda", dtype=torch.bfloat16)
+ inject_prime_lm_head(model)
+
+ input_ids = torch.randint(0, 256, (1, 32), device="cuda")
+ position_ids = torch.arange(32, device="cuda").unsqueeze(0)
+
+ with torch.no_grad():
+ output = model(input_ids=input_ids, position_ids=position_ids)
+ assert output["logits"].shape == (1, 32, 256)
+
+
+def _tiny_moe_config(**overrides):
+ defaults = dict(
+ vocab_size=256,
+ hidden_size=128,
+ num_hidden_layers=4,
+ num_attention_heads=4,
+ num_key_value_heads=2,
+ head_dim=32,
+ global_head_dim=64,
+ intermediate_size=256,
+ attention_k_eq_v=True,
+ num_global_key_value_heads=1,
+ sliding_window=128,
+ max_position_embeddings=512,
+ final_logit_softcapping=30.0,
+ hidden_activation="gelu_pytorch_tanh",
+ _attn_implementation="flash_attention_2",
+ enable_moe_block=True,
+ num_experts=8,
+ top_k_experts=2,
+ moe_intermediate_size=64,
+ )
+ defaults.update(overrides)
+ return Gemma4TextConfig(**defaults)
+
+
+def test_gemma4_moe_forward():
+ config = _tiny_moe_config()
+ model = Gemma4ForCausalLM(config).to(device="cuda", dtype=torch.bfloat16)
+ inject_prime_lm_head(model)
+
+ input_ids = torch.randint(0, 256, (1, 32), device="cuda")
+ position_ids = torch.arange(32, device="cuda").unsqueeze(0)
+
+ with torch.no_grad():
+ output = model(input_ids=input_ids, position_ids=position_ids)
+ assert output["logits"].shape == (1, 32, 256)
+
+
+def test_gemma4_moe_backward():
+ config = _tiny_moe_config()
+ model = Gemma4ForCausalLM(config).to(device="cuda", dtype=torch.bfloat16)
+ inject_prime_lm_head(model)
+
+ input_ids = torch.randint(0, 256, (1, 32), device="cuda")
+ position_ids = torch.arange(32, device="cuda").unsqueeze(0)
+
+ output = model(input_ids=input_ids, position_ids=position_ids)
+ loss = output["logits"].sum()
+ loss.backward()
+
+ has_grad = False
+ for name, param in model.named_parameters():
+ if param.requires_grad and param.grad is not None:
+ has_grad = True
+ break
+ assert has_grad, "No gradients found"
diff --git a/tests/unit/train/models/test_qwen3_5_moe_vlm.py b/tests/unit/train/models/test_qwen3_5_moe_vlm.py
index e0901c60e3..c629f6be25 100644
--- a/tests/unit/train/models/test_qwen3_5_moe_vlm.py
+++ b/tests/unit/train/models/test_qwen3_5_moe_vlm.py
@@ -21,6 +21,7 @@ def _tiny_vlm_config():
tc.vocab_size = 256
tc.hidden_size = 256
tc.num_hidden_layers = 2
+ tc.layer_types = ["linear_attention", "full_attention"]
tc.num_attention_heads = 4
tc.num_key_value_heads = 2
tc.head_dim = 64
diff --git a/uv.lock b/uv.lock
index 65744a9cf6..3956d73499 100644
--- a/uv.lock
+++ b/uv.lock
@@ -14,7 +14,7 @@ supported-markers = [
overrides = [
{ name = "nvidia-cudnn-cu12", specifier = ">=9.15" },
{ name = "nvidia-cutlass-dsl", specifier = ">=4.4.1" },
- { name = "transformers", git = "https://github.com/huggingface/transformers.git?rev=5c1c72b" },
+ { name = "transformers", git = "https://github.com/huggingface/transformers.git?rev=c1c3424" },
]
[[package]]
@@ -475,7 +475,7 @@ wheels = [
[[package]]
name = "compressed-tensors"
-version = "0.13.0"
+version = "0.14.0.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "loguru", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
@@ -483,9 +483,9 @@ dependencies = [
{ name = "torch", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
{ name = "transformers", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/fc/65/88dd1c58fb9d0ded51b5c86471b937a1525f91fad2211a6f051dc1ea822d/compressed_tensors-0.13.0.tar.gz", hash = "sha256:23893824d3498ea3f1a829f14a8fa85f9a5e76a34c711a038b8d7c619ca9a67c", size = 200995, upload-time = "2025-12-16T16:03:55.397Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/eb/f1/4c9b01ceaf82ad58ad00919223e09b8e74d4073a2ba8e3ab2f97521ef65c/compressed_tensors-0.14.0.1.tar.gz", hash = "sha256:5ad3841184b6f5020e06059b2463191c5c57a144bb97cab9159978d8118839b1", size = 226393, upload-time = "2026-03-11T17:04:35.57Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/0b/b5/61ac2563c62490922b603c09113a083fd74af3630ec3931e769484d6dcb5/compressed_tensors-0.13.0-py3-none-any.whl", hash = "sha256:3518799c9baf034eb642efb551db6b0537b8713d45a64fe4def26f7f8d6cabec", size = 192620, upload-time = "2025-12-16T16:03:53.041Z" },
+ { url = "https://files.pythonhosted.org/packages/0a/26/16a13993ecf4fdc9c39d63b3a6daabafd32a452cf68b81aa9eb3b8170913/compressed_tensors-0.14.0.1-py3-none-any.whl", hash = "sha256:46c4940a3a779d3d97108c294bfcd9acf4bd0491f7c6737c320f0e815ec732e4", size = 196454, upload-time = "2026-03-11T17:04:33.2Z" },
]
[[package]]
@@ -559,19 +559,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/af/f3/6b032a554019cfb3447e671798c1bd3e79b5f1af20d10253f56cea269ef2/cuda_python-12.9.4-py3-none-any.whl", hash = "sha256:d2cacea882a69863f1e7d27ee71d75f0684f4c76910aff839067e4f89c902279", size = 7594, upload-time = "2025-10-21T14:55:12.846Z" },
]
-[[package]]
-name = "cupy-cuda12x"
-version = "13.6.0"
-source = { registry = "https://pypi.org/simple" }
-dependencies = [
- { name = "fastrlock", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
- { name = "numpy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
-]
-wheels = [
- { url = "https://files.pythonhosted.org/packages/12/c5/7e7fc4816d0de0154e5d9053242c3a08a0ca8b43ee656a6f7b3b95055a7b/cupy_cuda12x-13.6.0-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:a6970ceefe40f9acbede41d7fe17416bd277b1bd2093adcde457b23b578c5a59", size = 127334633, upload-time = "2025-08-18T08:24:43.065Z" },
- { url = "https://files.pythonhosted.org/packages/e0/95/d7e1295141e7d530674a3cc567e13ed0eb6b81524cb122d797ed996b5bea/cupy_cuda12x-13.6.0-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:79b0cacb5e8b190ef409f9e03f06ac8de1b021b0c0dda47674d446f5557e0eb1", size = 112886268, upload-time = "2025-08-18T08:24:49.294Z" },
-]
-
[[package]]
name = "datasets"
version = "4.0.0"
@@ -848,18 +835,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/1e/5c/9516828037b1680de14bae4d93b2868b26417059e1e5cb7d8ad7be649197/fastcore-1.12.31-py3-none-any.whl", hash = "sha256:2634f117fe3a2f6c250f8500f70211c96158399936e90fa6d1010d2d516bf03a", size = 98509, upload-time = "2026-03-22T02:21:19.274Z" },
]
-[[package]]
-name = "fastrlock"
-version = "0.8.3"
-source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/73/b1/1c3d635d955f2b4bf34d45abf8f35492e04dbd7804e94ce65d9f928ef3ec/fastrlock-0.8.3.tar.gz", hash = "sha256:4af6734d92eaa3ab4373e6c9a1dd0d5ad1304e172b1521733c6c3b3d73c8fa5d", size = 79327, upload-time = "2024-12-17T11:03:39.638Z" }
-wheels = [
- { url = "https://files.pythonhosted.org/packages/57/21/ea1511b0ef0d5457efca3bf1823effb9c5cad4fc9dca86ce08e4d65330ce/fastrlock-0.8.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:85a49a1f1e020097d087e1963e42cea6f307897d5ebe2cb6daf4af47ffdd3eed", size = 52201, upload-time = "2024-12-17T11:02:19.512Z" },
- { url = "https://files.pythonhosted.org/packages/80/07/cdecb7aa976f34328372f1c4efd6c9dc1b039b3cc8d3f38787d640009a25/fastrlock-0.8.3-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5f13ec08f1adb1aa916c384b05ecb7dbebb8df9ea81abd045f60941c6283a670", size = 53924, upload-time = "2024-12-17T11:02:20.85Z" },
- { url = "https://files.pythonhosted.org/packages/88/6d/59c497f8db9a125066dd3a7442fab6aecbe90d6fec344c54645eaf311666/fastrlock-0.8.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:0ea4e53a04980d646def0f5e4b5e8bd8c7884288464acab0b37ca0c65c482bfe", size = 52140, upload-time = "2024-12-17T11:02:22.263Z" },
- { url = "https://files.pythonhosted.org/packages/62/04/9138943c2ee803d62a48a3c17b69de2f6fa27677a6896c300369e839a550/fastrlock-0.8.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:38340f6635bd4ee2a4fb02a3a725759fe921f2ca846cb9ca44531ba739cc17b4", size = 53261, upload-time = "2024-12-17T11:02:24.418Z" },
-]
-
[[package]]
name = "filelock"
version = "3.19.1"
@@ -932,9 +907,17 @@ dependencies = [
{ name = "transformers", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
]
+[[package]]
+name = "flashinfer-cubin"
+version = "0.6.6"
+source = { registry = "https://pypi.org/simple" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/12/e8/826f9452bc5f76b94d7eb025f03dcaf1b51b9ed7790386c0285191e69be4/flashinfer_cubin-0.6.6-py3-none-any.whl", hash = "sha256:36508dfc792eb5ecfb15d2c140a7702812e1fa1ab0fb03929b2ed55e3e8191f3", size = 267661457, upload-time = "2026-03-11T01:36:36.538Z" },
+]
+
[[package]]
name = "flashinfer-python"
-version = "0.6.4"
+version = "0.6.6"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "apache-tvm-ffi", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
@@ -951,9 +934,9 @@ dependencies = [
{ name = "torch", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
{ name = "tqdm", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/77/45/15645d2a4ee81d08206f3e132a77323e48312f510462415d7cd1122eba43/flashinfer_python-0.6.4.tar.gz", hash = "sha256:e6ab798bd1030e5ff7a3bc6952f36386c406928f60b79cf964a6db7aa7ccde75", size = 5337134, upload-time = "2026-02-19T07:33:36.647Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/03/70/c5a235297351021f5d3d3233523a85f5a6468495587489ad2f257e8eafe2/flashinfer_python-0.6.6.tar.gz", hash = "sha256:0730ba7c7aad332961933bcebc5119762797161ede57d955f6fd199818ed1d92", size = 5344156, upload-time = "2026-03-11T01:36:21.434Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/17/9a/d2bab76d2bb15062c6a2329614653e4f8bec9c78eec9069856ef0c7c0a79/flashinfer_python-0.6.4-py3-none-any.whl", hash = "sha256:105596b505892ae330af84e250ee0eb6fc2c3a22e8dc42bd46de1b90d36004c8", size = 7819999, upload-time = "2026-02-19T07:33:34.82Z" },
+ { url = "https://files.pythonhosted.org/packages/e0/61/385d06755f3ab66333018285657adf0daf8a90a129448231fd09e315bd2e/flashinfer_python-0.6.6-py3-none-any.whl", hash = "sha256:078f158636969eec1a0d3dea19c3ca90b426b66df89bbf7b7b8276ce2ec08148", size = 7817047, upload-time = "2026-03-11T01:36:19.198Z" },
]
[[package]]
@@ -1089,19 +1072,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/df/81/dcf11b5915a39024d4a98ef14b16cb0c636a4f2f26ef657982d3144c6544/grpcio-1.78.0rc2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:ea66e360e5ea032a1f6dde926915ba683edca811ed6f0e0620c52e264ba364e4", size = 7698266, upload-time = "2026-01-16T07:28:51.342Z" },
]
-[[package]]
-name = "grpcio-reflection"
-version = "1.78.0rc2"
-source = { registry = "https://pypi.org/simple" }
-dependencies = [
- { name = "grpcio", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
- { name = "protobuf", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
-]
-sdist = { url = "https://files.pythonhosted.org/packages/f7/4f/7c5032a6365cf964ef3e1fea7df76294d2804e6414802e36000d57f0eeb3/grpcio_reflection-1.78.0rc2.tar.gz", hash = "sha256:d925a43ef37f93a2129575111bf4add284d344eecc7e610b65f3b61a59c671b7", size = 19111, upload-time = "2026-01-16T08:01:53.077Z" }
-wheels = [
- { url = "https://files.pythonhosted.org/packages/61/c3/a1082fb098bb23c19c5df32b78c83bd000289592e079fa2daf7dddba7a43/grpcio_reflection-1.78.0rc2-py3-none-any.whl", hash = "sha256:c4a5f8e9681a83c0ddc80b8648d74144bf6c7bad1e7e2598c77754a8e7998d61", size = 22838, upload-time = "2026-01-16T08:01:41.466Z" },
-]
-
[[package]]
name = "grpclib"
version = "0.4.9"
@@ -1139,14 +1109,14 @@ wheels = [
[[package]]
name = "hf-xet"
-version = "1.3.0b0"
+version = "1.4.3"
source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/41/4b/59c9a123813f1db5441f037d9a0e9171bd480c4ff3a9562976a8bf8e49ad/hf_xet-1.3.0b0.tar.gz", hash = "sha256:ece497f54c80992e1b145a89065443f6acf9a6b51d8e4648e53e3ad650fbec06", size = 615265, upload-time = "2026-01-28T20:37:21.892Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/53/92/ec9ad04d0b5728dca387a45af7bc98fbb0d73b2118759f5f6038b61a57e8/hf_xet-1.4.3.tar.gz", hash = "sha256:8ddedb73c8c08928c793df2f3401ec26f95be7f7e516a7bee2fbb546f6676113", size = 670477, upload-time = "2026-03-31T22:40:07.874Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/7e/34/a16aa436c3e59007678cee07f5cf3929ba053b14ae16dffd3be1270d3927/hf_xet-1.3.0b0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fa63330e14196071fafc0e369a8e9d3f847335f10d33ca152537fb47bf263440", size = 58044866, upload-time = "2026-01-28T20:36:31.13Z" },
- { url = "https://files.pythonhosted.org/packages/d0/74/2202cc67e82a6eb64e42314e92ff2ee798e6dd5ee394967880b1370e878e/hf_xet-1.3.0b0-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:1f8a48df4e67ab695ae802f0d4d07c3d28fed64ea12decef13f8a8550783a42d", size = 53103717, upload-time = "2026-01-28T20:36:26.633Z" },
- { url = "https://files.pythonhosted.org/packages/8d/eb/9cbf85387377adaef317918318d1921b456625fa2535f39e642ed77076e4/hf_xet-1.3.0b0-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:ae20bc5405c06538ba820e6a3f818df793fee554f83cf071caa641d0b36f08f8", size = 53485235, upload-time = "2026-01-28T20:37:05.554Z" },
- { url = "https://files.pythonhosted.org/packages/0d/28/302fae85503e423e356042a3332e3b2b714b30ce27db2fe415260973bf0e/hf_xet-1.3.0b0-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a566da3478ae73ccd6bca8cb8d1ef85bcd4c36e79912cbfafb5b33890a0f1301", size = 55093706, upload-time = "2026-01-28T20:37:09.561Z" },
+ { url = "https://files.pythonhosted.org/packages/df/9a/a24b26dc8a65f0ecc0fe5be981a19e61e7ca963b85e062c083f3a9100529/hf_xet-1.4.3-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fc360b70c815bf340ed56c7b8c63aacf11762a4b099b2fe2c9bd6d6068668c08", size = 4212320, upload-time = "2026-03-31T22:39:42.922Z" },
+ { url = "https://files.pythonhosted.org/packages/53/60/46d493db155d2ee2801b71fb1b0fd67696359047fdd8caee2c914cc50c79/hf_xet-1.4.3-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:39f2d2e9654cd9b4319885733993807aab6de9dfbd34c42f0b78338d6617421f", size = 3991546, upload-time = "2026-03-31T22:39:41.335Z" },
+ { url = "https://files.pythonhosted.org/packages/bc/f5/067363e1c96c6b17256910830d1b54099d06287e10f4ec6ec4e7e08371fc/hf_xet-1.4.3-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:49ad8a8cead2b56051aa84d7fce3e1335efe68df3cf6c058f22a65513885baac", size = 4193200, upload-time = "2026-03-31T22:40:01.936Z" },
+ { url = "https://files.pythonhosted.org/packages/42/4b/53951592882d9c23080c7644542fda34a3813104e9e11fa1a7d82d419cb8/hf_xet-1.4.3-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:7716d62015477a70ea272d2d68cd7cad140f61c52ee452e133e139abfe2c17ba", size = 4429392, upload-time = "2026-03-31T22:40:03.492Z" },
]
[[package]]
@@ -1209,7 +1179,7 @@ wheels = [
[[package]]
name = "huggingface-hub"
-version = "1.4.1"
+version = "1.9.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "filelock", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
@@ -1218,14 +1188,13 @@ dependencies = [
{ name = "httpx", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
{ name = "packaging", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
{ name = "pyyaml", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
- { name = "shellingham", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
{ name = "tqdm", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
- { name = "typer-slim", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
+ { name = "typer", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
{ name = "typing-extensions", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/c4/fc/eb9bc06130e8bbda6a616e1b80a7aa127681c448d6b49806f61db2670b61/huggingface_hub-1.4.1.tar.gz", hash = "sha256:b41131ec35e631e7383ab26d6146b8d8972abc8b6309b963b306fbcca87f5ed5", size = 642156, upload-time = "2026-02-06T09:20:03.013Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/88/bb/62c7aa86f63a05e2f9b96642fdef9b94526a23979820b09f5455deff4983/huggingface_hub-1.9.0.tar.gz", hash = "sha256:0ea5be7a56135c91797cae6ad726e38eaeb6eb4b77cefff5c9d38ba0ecf874f7", size = 750326, upload-time = "2026-04-03T08:35:55.888Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/d5/ae/2f6d96b4e6c5478d87d606a1934b5d436c4a2bce6bb7c6fdece891c128e3/huggingface_hub-1.4.1-py3-none-any.whl", hash = "sha256:9931d075fb7a79af5abc487106414ec5fba2c0ae86104c0c62fd6cae38873d18", size = 553326, upload-time = "2026-02-06T09:20:00.728Z" },
+ { url = "https://files.pythonhosted.org/packages/73/37/0d15d16150e1829f3e90962c99f28257f6de9e526a680b4c6f5acdb54fd2/huggingface_hub-1.9.0-py3-none-any.whl", hash = "sha256:2999328c058d39fd19ab748dd09bd4da2fbaa4f4c1ddea823eab103051e14a1f", size = 637355, upload-time = "2026-04-03T08:35:53.897Z" },
]
[[package]]
@@ -1499,16 +1468,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/43/6a/ca128561b22b60bd5a0c4ea26649e68c8556b82bc70a0c396eebc977fe86/jupyterlab_widgets-3.0.15-py3-none-any.whl", hash = "sha256:d59023d7d7ef71400d51e6fee9a88867f6e65e10a4201605d2d7f3e8f012a31c", size = 216571, upload-time = "2025-05-05T12:32:29.534Z" },
]
-[[package]]
-name = "kaldi-native-fbank"
-version = "1.22.3"
-source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/3a/2c/84076b352107ce12d56f28c313f1aca1be332d953dd96aec7b84976e6d53/kaldi-native-fbank-1.22.3.tar.gz", hash = "sha256:387bf87225c6b83c93ae652eeaef1b4d531994b6e398e7a77189de340674f9af", size = 71013, upload-time = "2025-10-09T02:31:21.487Z" }
-wheels = [
- { url = "https://files.pythonhosted.org/packages/43/28/6f4fd8953c0b3f30de4526fd024095032abcdc25b6736c77a891687c604e/kaldi_native_fbank-1.22.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:f5a44b4a83cf9bf13d3f77858928068b06d3ec2238c27ff2e39393fbf7749c9f", size = 298887, upload-time = "2025-10-09T02:30:53.739Z" },
- { url = "https://files.pythonhosted.org/packages/84/90/01ef7331c52b1eaf9916f3f7a535155aac2e9e2ddad12a141613d92758c7/kaldi_native_fbank-1.22.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f16e74372fe9e20abb4183f98a8e2288d5ee4c48d04d94b6160311170e007661", size = 322002, upload-time = "2025-10-09T02:30:13.04Z" },
-]
-
[[package]]
name = "kubernetes"
version = "35.0.0"
@@ -1813,7 +1772,7 @@ wheels = [
[[package]]
name = "mistral-common"
-version = "1.9.1"
+version = "1.11.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "jsonschema", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
@@ -1825,9 +1784,9 @@ dependencies = [
{ name = "tiktoken", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
{ name = "typing-extensions", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/db/ce/685b8127a326478e05501cb4c9ca23d1cd9f37e16c465a1e832c75aea709/mistral_common-1.9.1.tar.gz", hash = "sha256:550583d70a395c3586cfb748ffab53bd1d7c3409507f0efc0118bff30ffb26e9", size = 6338922, upload-time = "2026-02-12T10:53:41.639Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/61/97/753c85b5c0a19f4331ac99e0300ac8da06d4b29b629c9cb03064b38561bd/mistral_common-1.11.0.tar.gz", hash = "sha256:439b7fa38f9c3f020154af51bdf30eb81def507643017d8ce9f798384ec47ec3", size = 6355512, upload-time = "2026-04-01T13:54:12.36Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/ac/72/a38bb1fd9fd4d4ef990341c9dd1a7c8061f1951e10efa6d50c0a3f04eced/mistral_common-1.9.1-py3-none-any.whl", hash = "sha256:9e2b2520b6f67bac2e2bb06fcf985b7a1277b01938da2b7cda8cf0fdbfa92e91", size = 6518623, upload-time = "2026-02-12T10:53:39.457Z" },
+ { url = "https://files.pythonhosted.org/packages/60/e4/73ad3c27e3fb613c3ce0953c928202c46cddebac3989b87be1b6f305a9f6/mistral_common-1.11.0-py3-none-any.whl", hash = "sha256:1d3ecaf7c3aa7338cb37b596fd0fb294485753958ee8e7254a6cc23eb30b249b", size = 6531513, upload-time = "2026-04-01T13:54:16.536Z" },
]
[package.optional-dependencies]
@@ -2813,10 +2772,10 @@ requires-dist = [
{ name = "torch", specifier = ">=2.9.0", index = "https://download.pytorch.org/whl/cu128" },
{ name = "torchdata", specifier = ">=0.11.0" },
{ name = "torchtitan", git = "https://github.com/pytorch/torchtitan?rev=a1fdd7e" },
- { name = "transformers", git = "https://github.com/huggingface/transformers.git?rev=5c1c72b" },
+ { name = "transformers", git = "https://github.com/huggingface/transformers.git?rev=c1c3424" },
{ name = "uvloop", specifier = ">=0.21.0" },
{ name = "verifiers", git = "https://github.com/PrimeIntellect-ai/verifiers.git?rev=d3c830c" },
- { name = "vllm", specifier = ">=0.17.0" },
+ { name = "vllm", specifier = ">=0.19.0" },
{ name = "vllm-router", marker = "platform_machine == 'x86_64' and extra == 'disagg'", url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.14/vllm_router-0.1.14-cp38-abi3-linux_x86_64.whl" },
{ name = "wandb", specifier = ">=0.24.2" },
{ name = "wiki-search", marker = "extra == 'envs'", index = "https://hub.primeintellect.ai/primeintellect/simple/" },
@@ -3237,30 +3196,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/8e/28/e688da97806474963b2983a8231572733de11ed48f8899410c0afc8bdf15/quack_kernels-0.3.6-py3-none-any.whl", hash = "sha256:973d974ccca816014006af8f136294ba0f15b1c324197f8b54ffc6500666e0e5", size = 198865, upload-time = "2026-03-24T13:29:59.107Z" },
]
-[[package]]
-name = "ray"
-version = "2.49.1"
-source = { registry = "https://pypi.org/simple" }
-dependencies = [
- { name = "click", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
- { name = "filelock", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
- { name = "jsonschema", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
- { name = "msgpack", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
- { name = "packaging", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
- { name = "protobuf", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
- { name = "pyyaml", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
- { name = "requests", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
-]
-wheels = [
- { url = "https://files.pythonhosted.org/packages/c6/ba/77eae921fc2595087516df6efc0ca03caacc14d16592341985916f1aed13/ray-2.49.1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:cf63b916e399c2a4d484249611a96cee283cef32e544126115a4ad3e872c34eb", size = 69287149, upload-time = "2025-09-03T00:25:47.605Z" },
- { url = "https://files.pythonhosted.org/packages/00/02/c81260c0f94bd34a1442ea488bdd433dfc9e6ed6211c9a59bc4157b8e00e/ray-2.49.1-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:484064fca02732e0b6f4a09dad0d1fb6abd4ca4b6d9bf7c26ab7a17a2887cd09", size = 70114899, upload-time = "2025-09-03T00:25:53.144Z" },
-]
-
-[package.optional-dependencies]
-cgraph = [
- { name = "cupy-cuda12x", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
-]
-
[[package]]
name = "referencing"
version = "0.36.2"
@@ -3277,14 +3212,14 @@ wheels = [
[[package]]
name = "regex"
-version = "2025.9.1"
+version = "2026.3.32"
source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/b2/5a/4c63457fbcaf19d138d72b2e9b39405954f98c0349b31c601bfcb151582c/regex-2025.9.1.tar.gz", hash = "sha256:88ac07b38d20b54d79e704e38aa3bd2c0f8027432164226bdee201a1c0c9c9ff", size = 400852, upload-time = "2025-09-01T22:10:10.479Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/81/93/5ab3e899c47fa7994e524447135a71cd121685a35c8fe35029005f8b236f/regex-2026.3.32.tar.gz", hash = "sha256:f1574566457161678297a116fa5d1556c5a4159d64c5ff7c760e7c564bf66f16", size = 415605, upload-time = "2026-03-28T21:49:22.012Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/0f/74/f933a607a538f785da5021acf5323961b4620972e2c2f1f39b6af4b71db7/regex-2025.9.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e9dc5991592933a4192c166eeb67b29d9234f9c86344481173d1bc52f73a7104", size = 797441, upload-time = "2025-09-01T22:08:39.108Z" },
- { url = "https://files.pythonhosted.org/packages/b2/02/5c891bb5fe0691cc1bad336e3a94b9097fbcf9707ec8ddc1dce9f0397289/regex-2025.9.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:47829ffaf652f30d579534da9085fe30c171fa2a6744a93d52ef7195dc38218b", size = 801991, upload-time = "2025-09-01T22:08:44.072Z" },
- { url = "https://files.pythonhosted.org/packages/f1/ae/fd10d6ad179910f7a1b3e0a7fde1ef8bb65e738e8ac4fd6ecff3f52252e4/regex-2025.9.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:1e978e5a35b293ea43f140c92a3269b6ab13fe0a2bf8a881f7ac740f5a6ade85", size = 786651, upload-time = "2025-09-01T22:08:46.079Z" },
- { url = "https://files.pythonhosted.org/packages/93/fa/b4c6dbdedc85ef4caec54c817cd5f4418dbfa2453214119f2538082bf666/regex-2025.9.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:656563e620de6908cd1c9d4f7b9e0777e3341ca7db9d4383bcaa44709c90281e", size = 788138, upload-time = "2025-09-01T22:08:51.933Z" },
+ { url = "https://files.pythonhosted.org/packages/4a/cf/1955bb5567bc491bd63068e17f75ab0c9ff5e9d08466beec7e347f5e768d/regex-2026.3.32-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:66a5083c3ffe5a5a95f8281ea47a88072d4f24001d562d1d9d28d4cdc005fec5", size = 796431, upload-time = "2026-03-28T21:46:33.101Z" },
+ { url = "https://files.pythonhosted.org/packages/0a/fe/661043d1c263b0d9d10c6ff4e9c9745f3df9641c62b51f96a3473638e7ce/regex-2026.3.32-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f54840bea73541652f1170dc63402a5b776fc851ad36a842da9e5163c1f504a0", size = 801512, upload-time = "2026-03-28T21:46:38.587Z" },
+ { url = "https://files.pythonhosted.org/packages/b6/c8/d833397b70cd1bacfcdc0a611f0e2c1f5b91fee8eedd88affcee770cbbb6/regex-2026.3.32-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:66d3126afe7eac41759cd5f0b3b246598086e88e70527c0d68c9e615b81771c4", size = 785837, upload-time = "2026-03-28T21:46:42.926Z" },
+ { url = "https://files.pythonhosted.org/packages/18/f4/04ed04ebf335a44083695c22772be6a42efa31900415555563acf02cb4de/regex-2026.3.32-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b56993a7aeb4140c4770f4f7965c9e5af4f024457d06e23c01b0d47501cb18ed", size = 788332, upload-time = "2026-03-28T21:46:50.454Z" },
]
[[package]]
@@ -3888,8 +3823,8 @@ wheels = [
[[package]]
name = "transformers"
-version = "5.3.0.dev0"
-source = { git = "https://github.com/huggingface/transformers.git?rev=5c1c72b#5c1c72be5f864d10d0efe8ece0768d9ed6ee4fdd" }
+version = "5.5.0"
+source = { git = "https://github.com/huggingface/transformers.git?rev=c1c3424#c1c34249fa27deefbd4a377dfbf883a39baf5c6d" }
dependencies = [
{ name = "huggingface-hub", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
{ name = "numpy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
@@ -3938,18 +3873,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/7a/ed/d6fca788b51d0d4640c4bc82d0e85bad4b49809bca36bf4af01b4dcb66a7/typer-0.23.0-py3-none-any.whl", hash = "sha256:79f4bc262b6c37872091072a3cb7cb6d7d79ee98c0c658b4364bdcde3c42c913", size = 56668, upload-time = "2026-02-11T15:22:21.075Z" },
]
-[[package]]
-name = "typer-slim"
-version = "0.23.0"
-source = { registry = "https://pypi.org/simple" }
-dependencies = [
- { name = "typer", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
-]
-sdist = { url = "https://files.pythonhosted.org/packages/1f/8a/881cfd399a119db89619dc1b93d36e2fb6720ddb112bceff41203f1abd72/typer_slim-0.23.0.tar.gz", hash = "sha256:be8b60243df27cfee444c6db1b10a85f4f3e54d940574f31a996f78aa35a8254", size = 4773, upload-time = "2026-02-11T15:22:19.106Z" }
-wheels = [
- { url = "https://files.pythonhosted.org/packages/07/3e/ba3a222c80ee070d9497ece3e1fe77253c142925dd4c90f04278aac0a9eb/typer_slim-0.23.0-py3-none-any.whl", hash = "sha256:1d693daf22d998a7b1edab8413cdcb8af07254154ce3956c1664dc11b01e2f8b", size = 3399, upload-time = "2026-02-11T15:22:17.792Z" },
-]
-
[[package]]
name = "types-certifi"
version = "2021.10.8.3"
@@ -4131,7 +4054,7 @@ wheels = [
[[package]]
name = "vllm"
-version = "0.17.0"
+version = "0.19.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "aiohttp", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
@@ -4146,12 +4069,10 @@ dependencies = [
{ name = "einops", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
{ name = "fastapi", extra = ["standard"], marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
{ name = "filelock", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
+ { name = "flashinfer-cubin", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
{ name = "flashinfer-python", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
{ name = "gguf", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
- { name = "grpcio", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
- { name = "grpcio-reflection", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
{ name = "ijson", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
- { name = "kaldi-native-fbank", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
{ name = "lark", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
{ name = "llguidance", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
{ name = "lm-format-enforcer", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
@@ -4162,6 +4083,7 @@ dependencies = [
{ name = "ninja", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
{ name = "numba", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
{ name = "numpy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
+ { name = "nvidia-cudnn-frontend", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
{ name = "nvidia-cutlass-dsl", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
{ name = "openai", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
{ name = "openai-harmony", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
@@ -4184,7 +4106,6 @@ dependencies = [
{ name = "pyyaml", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
{ name = "pyzmq", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
{ name = "quack-kernels", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
- { name = "ray", extra = ["cgraph"], marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
{ name = "regex", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
{ name = "requests", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
{ name = "sentencepiece", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
@@ -4202,10 +4123,10 @@ dependencies = [
{ name = "watchfiles", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
{ name = "xgrammar", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/13/d5/af83a4262ca4d5692a93b3c322ae954e3e6c4e23f8f9db3ab87bd79c919e/vllm-0.17.0.tar.gz", hash = "sha256:b0b62e58ef4eb633ef371f2726976372cf6dfcb7ff2ea9ddf7194c1930d5629a", size = 30541311, upload-time = "2026-03-07T03:54:54.333Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/03/14/c330a72309051f762b357a2e41d5015bedbb106ad1e16a231bdfda2e2163/vllm-0.19.0.tar.gz", hash = "sha256:81e59cf87175e7a62eb8d9acf5989484bbd17089d5eface353f89067bda282d9", size = 31071745, upload-time = "2026-04-03T04:04:52.833Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/f2/72/78a48668f2631def18bbaaa331d7878bcfc5c3137455422aafb0748e1261/vllm-0.17.0-cp38-abi3-manylinux_2_31_aarch64.whl", hash = "sha256:310fb82fe061ed75dceeb4aeb803cd8ee0d590337ec720f7abfb03a69314d710", size = 385329399, upload-time = "2026-03-07T03:54:34.261Z" },
- { url = "https://files.pythonhosted.org/packages/25/4f/972726f9a501f01203b5c4796e1932abbe435fae6d7715a4c3f1aad14a58/vllm-0.17.0-cp38-abi3-manylinux_2_31_x86_64.whl", hash = "sha256:0296670a09d392ee43455d9bebf590d05a9bc2ebce5e25e2919222fc815158da", size = 432927988, upload-time = "2026-03-07T03:54:02.312Z" },
+ { url = "https://files.pythonhosted.org/packages/c8/51/467e7a8cb4838022daa731b7f8b34c228691e36f938e1803c3a702c7bd69/vllm-0.19.0-cp38-abi3-manylinux_2_31_aarch64.whl", hash = "sha256:6ab90ccca5d7ca3bd2c8f90133f0fac85e8f4af582a1c67c6cc3f63c615521e3", size = 384650557, upload-time = "2026-04-03T04:05:52.513Z" },
+ { url = "https://files.pythonhosted.org/packages/b7/08/6a431731e4c163bc1fab85b63e269d84104aad0fba98dac1af34fdc5077f/vllm-0.19.0-cp38-abi3-manylinux_2_31_x86_64.whl", hash = "sha256:2d0e5fae45367bdbf111fcad68f4c0f8fdddd2f2fb643e52f0f2daebef7b41cf", size = 432281473, upload-time = "2026-04-03T04:05:22.07Z" },
]
[[package]]
@@ -4360,7 +4281,7 @@ wheels = [
[[package]]
name = "xgrammar"
-version = "0.1.29"
+version = "0.1.33"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "numpy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
@@ -4370,10 +4291,10 @@ dependencies = [
{ name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "typing-extensions", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/02/a3/70dbe3ffd331a1e7e1ad5a95690a4086e6c7cdb8089f5c7eda712219ccec/xgrammar-0.1.29.tar.gz", hash = "sha256:cf195afa81b489eebf35d4c6f37f27136d05420739ab4a6f7f065c938d7e4baa", size = 2321317, upload-time = "2025-12-19T08:23:54.53Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/db/43/e5dfddb1d2a4fccf3e3a88f103e88698cdefc3182f4e169a359ffe1c1794/xgrammar-0.1.33.tar.gz", hash = "sha256:8dbe5fc3d76651ab1fac7a68fc2a118b885fa0ec7189927fb6e0dce0081aea99", size = 2398956, upload-time = "2026-03-27T10:16:36.582Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/57/94/18793c64bf0368075a34c06e196bf002f1e6ab0aee332268f44e8d356d5a/xgrammar-0.1.29-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6eb370a16b27a683e5f2b9e429ab41440c69977d4a504849ed61831b94cc704c", size = 34705239, upload-time = "2025-12-19T08:23:28.369Z" },
- { url = "https://files.pythonhosted.org/packages/3e/da/4c14e3e00be698009b52700f15326a23272b4b00475939b6acc86b151188/xgrammar-0.1.29-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:79e6e4f5cd33be77418cf91efc482f2b3d773d309891224383bc8a4948ad7b07", size = 34906135, upload-time = "2025-12-19T08:23:30.838Z" },
+ { url = "https://files.pythonhosted.org/packages/4e/04/43d4baca876f5ae1b45897ec30a59801a2da37f16da1fcd85f9555e4c125/xgrammar-0.1.33-cp312-cp312-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9c803e60d791854c5d1f271ece7e1f34d73c82dd4a8b2a06b7af5331482a78ac", size = 42133168, upload-time = "2026-03-27T10:15:16.994Z" },
+ { url = "https://files.pythonhosted.org/packages/f0/a8/672833a3cff027253793aa999401d8364896ebf396967e475c7a878b895f/xgrammar-0.1.33-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:52b8eaa533282a0efb0835db6998ae72e7b3c7875d7a52e360ffebff9b78c30a", size = 42205803, upload-time = "2026-03-27T10:15:21.599Z" },
]
[[package]]