diff --git a/configs/gemma4/rl_dense.toml b/configs/gemma4/rl_dense.toml new file mode 100644 index 0000000000..05e48cc74e --- /dev/null +++ b/configs/gemma4/rl_dense.toml @@ -0,0 +1,39 @@ +# Gemma4 31B dense RL test (reverse-text) +output_dir = "outputs/gemma4-dense-rl" +max_steps = 5 +seq_len = 2048 + +[slurm] +job_name = "gemma4-dense-rl" + +[deployment] +type = "single_node" +num_train_gpus = 4 +num_infer_gpus = 4 + +[model] +name = "google/gemma-4-31B-it" + +[wandb] +project = "gemma4-test" +name = "rl-dense" + +[orchestrator] +batch_size = 128 +rollouts_per_example = 16 + +[orchestrator.sampling] +max_tokens = 128 + +[[orchestrator.env]] +id = "reverse-text" + +[trainer.model] +attn = "flash_attention_2" + +[trainer.model.ac] + +[trainer.optim] +lr = 3e-6 + +[inference] diff --git a/configs/gemma4/rl_moe.toml b/configs/gemma4/rl_moe.toml new file mode 100644 index 0000000000..75be4d6d17 --- /dev/null +++ b/configs/gemma4/rl_moe.toml @@ -0,0 +1,43 @@ +# Gemma4 26B-A4B MoE RL test (reverse-text) +output_dir = "outputs/gemma4-moe-rl" +max_steps = 5 +seq_len = 2048 + +[slurm] +job_name = "gemma4-moe-rl" + +[deployment] +type = "single_node" +num_train_gpus = 4 +num_infer_gpus = 4 + +[model] +name = "google/gemma-4-26B-A4B-it" + +[wandb] +project = "gemma4-test" +name = "rl-moe" + +[orchestrator] +batch_size = 128 +rollouts_per_example = 16 + +[orchestrator.sampling] +max_tokens = 128 + +[[orchestrator.env]] +id = "reverse-text" + +[trainer.model] +attn = "flash_attention_2" + +[trainer.model.ac] + +[trainer.optim] +lr = 3e-6 + +[inference] +gpu_memory_utilization = 0.7 + +[inference.model] +max_model_len = 2048 diff --git a/configs/gemma4/sft_dense.toml b/configs/gemma4/sft_dense.toml new file mode 100644 index 0000000000..727271be25 --- /dev/null +++ b/configs/gemma4/sft_dense.toml @@ -0,0 +1,29 @@ +# Gemma4 31B dense SFT test +# Usage: uv run sft @ configs/gemma4/sft_dense.toml +output_dir = "outputs/gemma4-dense-sft" +max_steps = 10 + +[slurm] +job_name = 
"gemma4-dense-sft" + +[deployment] +type = "single_node" +num_gpus = 8 + +[model] +name = "google/gemma-4-31B-it" +attn = "flash_attention_2" + +[model.ac] + +[wandb] +project = "gemma4-test" +name = "sft-dense" + +[data] +name = "PrimeIntellect/Reverse-Text-SFT" +batch_size = 8 +seq_len = 2048 + +[data.chat_template_kwargs] +enable_thinking = true diff --git a/configs/gemma4/sft_moe.toml b/configs/gemma4/sft_moe.toml new file mode 100644 index 0000000000..8510096a23 --- /dev/null +++ b/configs/gemma4/sft_moe.toml @@ -0,0 +1,29 @@ +# Gemma4 26B-A4B MoE SFT test +# Usage: uv run sft @ configs/gemma4/sft_moe.toml +output_dir = "outputs/gemma4-moe-sft" +max_steps = 10 + +[slurm] +job_name = "gemma4-moe-sft" + +[deployment] +type = "single_node" +num_gpus = 8 + +[model] +name = "google/gemma-4-26B-A4B-it" +attn = "flash_attention_2" + +[model.ac] + +[wandb] +project = "gemma4-test" +name = "sft-moe" + +[data] +name = "PrimeIntellect/Reverse-Text-SFT" +batch_size = 8 +seq_len = 2048 + +[data.chat_template_kwargs] +enable_thinking = true diff --git a/configs/gemma4/sft_qwen3.toml b/configs/gemma4/sft_qwen3.toml new file mode 100644 index 0000000000..c088f02b31 --- /dev/null +++ b/configs/gemma4/sft_qwen3.toml @@ -0,0 +1,18 @@ +# Qwen3 0.6B SFT sanity check +output_dir = "outputs/qwen3-sft" +max_steps = 5 + +[slurm] +job_name = "qwen3-sft" + +[model] +name = "PrimeIntellect/Qwen3-0.6B" + +[wandb] +project = "gemma4-test" +name = "sft-qwen3" + +[data] +name = "PrimeIntellect/Reverse-Text-SFT" +batch_size = 4 +seq_len = 1024 diff --git a/configs/gemma4/sft_qwen35.toml b/configs/gemma4/sft_qwen35.toml new file mode 100644 index 0000000000..b25d4893b2 --- /dev/null +++ b/configs/gemma4/sft_qwen35.toml @@ -0,0 +1,19 @@ +# Qwen3.5 MoE text-only SFT test (proving VLM bug) +# Usage: uv run sft @ configs/gemma4/sft_qwen35.toml +output_dir = "outputs/qwen35-moe-sft" +max_steps = 5 + +[slurm] +job_name = "qwen35-moe-sft" + +[model] +name = "Qwen/Qwen3.5-35B-A3B" + +[wandb] +project 
= "gemma4-test" +name = "sft-qwen35" + +[data] +name = "PrimeIntellect/Reverse-Text-SFT" +batch_size = 4 +seq_len = 2048 diff --git a/pyproject.toml b/pyproject.toml index 957b85c4d2..6642b882a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ dependencies = [ "torch>=2.9.0", "torchdata>=0.11.0", "transformers", - "vllm>=0.17.0", + "vllm>=0.19.0", "wandb>=0.24.2", "ring-flash-attn>=0.1.8", "prime>=0.5.37", @@ -123,7 +123,7 @@ torch = { index = "pytorch-cu128" } verifiers = { git = "https://github.com/PrimeIntellect-ai/verifiers.git", rev = "d3c830c" } torchtitan = { git = "https://github.com/pytorch/torchtitan", rev = "a1fdd7e" } dion = { git = "https://github.com/samsja/dion.git", rev = "d891eeb" } -transformers = { git = "https://github.com/huggingface/transformers.git", rev = "5c1c72b" } +transformers = { git = "https://github.com/huggingface/transformers.git", rev = "c1c3424" } flash-attn-4 = { git = "https://github.com/Dao-AILab/flash-attention.git", subdirectory = "flash_attn/cute", rev = "abd9943b" } pydantic-config = { git = "https://github.com/samsja/pydantic_config.git", branch = "main" } vllm-router = { url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.14/vllm_router-0.1.14-cp38-abi3-linux_x86_64.whl" } diff --git a/src/prime_rl/__init__.py b/src/prime_rl/__init__.py index e69de29bb2..d76436d3c5 100644 --- a/src/prime_rl/__init__.py +++ b/src/prime_rl/__init__.py @@ -0,0 +1 @@ +import prime_rl._compat # noqa: F401 — must run before ring_flash_attn is imported diff --git a/src/prime_rl/_compat.py b/src/prime_rl/_compat.py new file mode 100644 index 0000000000..5471030c84 --- /dev/null +++ b/src/prime_rl/_compat.py @@ -0,0 +1,19 @@ +"""Compatibility shim: ring_flash_attn + transformers >= 5.4. + +ring_flash_attn 0.1.8 imports `is_flash_attn_greater_or_equal_2_10` from +`transformers.modeling_flash_attention_utils`. 
This symbol was removed from +that module in transformers 5.4 (still available as a deprecated function +in `transformers.utils.import_utils`, scheduled for removal in 5.8). + +ring_flash_attn's except-branch is a no-op (imports the same symbol again), +so the import crashes on transformers >= 5.4. We patch the symbol back in as +`True` — the check is dead code since no one uses flash_attn < 2.1.0 anymore. + +Upstream fix: https://github.com/zhuzilin/ring-flash-attention/pull/85 +Remove this shim once ring_flash_attn ships a fixed version. +""" + +import transformers.modeling_flash_attention_utils as _mfau + +if not hasattr(_mfau, "is_flash_attn_greater_or_equal_2_10"): + _mfau.is_flash_attn_greater_or_equal_2_10 = True diff --git a/src/prime_rl/configs/sft.py b/src/prime_rl/configs/sft.py index 50d1baacd0..11c778130d 100644 --- a/src/prime_rl/configs/sft.py +++ b/src/prime_rl/configs/sft.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Annotated, Literal, TypeAlias +from typing import Annotated, Any, Literal, TypeAlias from pydantic import BaseModel, ConfigDict, Field, model_validator @@ -86,6 +86,14 @@ class SFTDataConfig(BaseDataConfig): # Configuring loss_mask: LossMaskConfig = LossMaskConfig() + chat_template_kwargs: Annotated[ + dict[str, Any] | None, + Field( + description="Extra keyword arguments passed to tokenizer.apply_chat_template(). " + "E.g. {'enable_thinking': true} for models with thinking-aware templates (Gemma4, Qwen3)." 
+ ), + ] = None + @model_validator(mode="after") def validate_subsets_and_splits(self): if self.subsets is not None or self.splits is not None: diff --git a/src/prime_rl/inference/patches.py b/src/prime_rl/inference/patches.py index 9b0ffe053e..18c6c3c021 100644 --- a/src/prime_rl/inference/patches.py +++ b/src/prime_rl/inference/patches.py @@ -64,41 +64,8 @@ def slice_lora_a(self, lora_a): MergedColumnParallelLinearWithShardedLoRA.slice_lora_a = slice_lora_a -# Monkeypatch PrometheusStatLogger to avoid NotImplementedError for LoRA in DP mode -def monkey_patch_prometheus_stat_logger_for_lora_in_dp_mode(): - from vllm.v1.metrics import loggers as vllm_metrics_loggers - - _original_prometheus_stat_logger_init = vllm_metrics_loggers.PrometheusStatLogger.__init__ - - def _patched_prometheus_stat_logger_init(self, vllm_config, engine_indexes=None): - """Patched init that temporarily disables lora_config to skip the DP mode check.""" - original_lora_config = vllm_config.lora_config - vllm_config.lora_config = None - try: - _original_prometheus_stat_logger_init(self, vllm_config, engine_indexes) - finally: - vllm_config.lora_config = original_lora_config - # Re-initialize LoRA metrics if needed (after the DP check is bypassed) - if original_lora_config is not None: - self.labelname_max_lora = "max_lora" - self.labelname_waiting_lora_adapters = "waiting_lora_adapters" - self.labelname_running_lora_adapters = "running_lora_adapters" - self.max_lora = original_lora_config.max_loras - self.gauge_lora_info = vllm_metrics_loggers.PrometheusStatLogger._gauge_cls( - name="vllm:lora_requests_info", - documentation="Running stats on lora requests.", - multiprocess_mode="sum", - labelnames=[ - self.labelname_max_lora, - self.labelname_waiting_lora_adapters, - self.labelname_running_lora_adapters, - ], - ) - - vllm_metrics_loggers.PrometheusStatLogger.__init__ = _patched_prometheus_stat_logger_init - - # Monkeypatch LoadLoRAAdapter to allow loading the same adapter multiple times +# 
TODO: may be removable if we pass load_inplace=True (supported since vLLM 0.18, PR #31326) def monkey_patch_load_lora_adapter(): from http import HTTPStatus @@ -153,6 +120,7 @@ async def _patched_load_lora_adapter( # Monkeypatch LRUCacheWorkerLoRAManager to allow loading adapter inplace without doing it every request +# TODO: may be removable if we pass load_inplace=True (supported since vLLM 0.18, PR #31326) def monkey_patch_LRUCacheWorkerLoRAManager(): from vllm.lora.worker_manager import LoRARequest, LRUCacheLoRAModelManager, LRUCacheWorkerLoRAManager @@ -278,109 +246,6 @@ def _patched_get_encode_kwargs(self): TokenizeParams.get_encode_kwargs = _patched_get_encode_kwargs -def monkey_patch_hermes_tool_parser_thread_safety(): - """Patch Hermes2ProToolParser to cache tokenizer encode/decode results. - - The original __init__ calls tokenizer.encode() and tokenizer.decode() on - every instantiation. Under concurrent load, the shared HuggingFace tokenizer's - Rust backend panics with ``RuntimeError: Already borrowed`` because multiple - threads mutably borrow the same internal state simultaneously. - - Fix: run the first __init__ (which calls encode/decode) under a lock, cache - the results, and reuse them for all subsequent instantiations without ever - touching the tokenizer again. 
- """ - import threading - - import regex as re - from vllm.tool_parsers.abstract_tool_parser import ToolParser - from vllm.tool_parsers.hermes_tool_parser import Hermes2ProToolParser - - _original_init = Hermes2ProToolParser.__init__ - _cache: dict[int, dict] = {} - _lock = threading.Lock() - - def _patched_init(self, tokenizer): - from vllm.tokenizers.mistral import MistralTokenizer - - # Resolve the actual tokenizer that __init__ will use for encode/decode - actual_tokenizer = tokenizer.tokenizer if isinstance(tokenizer, MistralTokenizer) else tokenizer - key = id(actual_tokenizer) - - if key in _cache: - # Fast path: skip encode/decode entirely, set up instance from cache - ToolParser.__init__(self, tokenizer) - if isinstance(tokenizer, MistralTokenizer): - self.model_tokenizer = tokenizer.tokenizer - self.current_tool_name_sent = False - self.prev_tool_call_arr = [] - self.current_tool_id = -1 - self.streamed_args_for_tool = [] - self.tool_call_start_token = "" - self.tool_call_end_token = "" - self.tool_call_regex = re.compile(r"(.*?)|(.*)", re.DOTALL) - self.scratch_pad_regex = re.compile(r"(.*?)", re.DOTALL) - cached = _cache[key] - self.tool_call_start_token_ids = cached["start_ids"] - self.tool_call_end_token_ids = cached["end_ids"] - self.tool_call_start_token_array = cached["start_array"] - self.tool_call_end_token_array = cached["end_array"] - self.buffered_delta_text = "" - return - - # Slow path: first instantiation for this tokenizer, run under lock - with _lock: - if key in _cache: - # Another thread populated it while we waited - _patched_init(self, tokenizer) - return - _original_init(self, tokenizer) - _cache[key] = { - "start_ids": self.tool_call_start_token_ids, - "end_ids": self.tool_call_end_token_ids, - "start_array": self.tool_call_start_token_array, - "end_array": self.tool_call_end_token_array, - } - - Hermes2ProToolParser.__init__ = _patched_init - - -def monkey_patch_tokenizer_thread_safety(): - """Patch HuggingFace tokenizer to make 
_encode_plus thread-safe. - - Under concurrent request load, vLLM's API server calls _encode_plus from - multiple async handlers simultaneously. _encode_plus mutates the Rust - tokenizer's internal state via set_truncation_and_padding (enable_truncation/ - enable_padding) and encode_special_tokens. The Rust backend uses RefCell-style - borrow tracking (PyO3), and concurrent mutable borrows cause it to panic - with ``RuntimeError: Already borrowed``. - - Fix: wrap the entire _encode_plus method in a per-tokenizer threading lock - so that state mutation and the subsequent encode call are atomic. - """ - import threading - - from transformers import PreTrainedTokenizerFast - - _original_encode_plus = PreTrainedTokenizerFast._encode_plus - _locks: dict[int, threading.Lock] = {} - _meta_lock = threading.Lock() - - def _get_lock(tokenizer_id: int) -> threading.Lock: - if tokenizer_id not in _locks: - with _meta_lock: - if tokenizer_id not in _locks: - _locks[tokenizer_id] = threading.Lock() - return _locks[tokenizer_id] - - def _patched_encode_plus(self, *args, **kwargs): - lock = _get_lock(id(self._tokenizer)) - with lock: - return _original_encode_plus(self, *args, **kwargs) - - PreTrainedTokenizerFast._encode_plus = _patched_encode_plus - - def monkey_patch_minimax_m2_for_lora(): """Patch vLLM's MiniMaxM2 model for LoRA compatibility. @@ -457,7 +322,7 @@ def _patched_forward(self, hidden_states): def monkey_patch_harmony_stop_token_propagation(): - """Fix: vLLM 0.17.0 doesn't merge harmony stop tokens into per-request SamplingParams. + """Fix: vLLM doesn't merge harmony stop tokens into per-request SamplingParams. 
The harmony mode sets stop_token_ids (including <|call|> and <|return|>) in default_sampling_params at server init, but ChatCompletionRequest.to_sampling_params() diff --git a/src/prime_rl/inference/vllm/server.py b/src/prime_rl/inference/vllm/server.py index b294aa447b..609b44f84d 100644 --- a/src/prime_rl/inference/vllm/server.py +++ b/src/prime_rl/inference/vllm/server.py @@ -7,9 +7,6 @@ from fastapi.responses import JSONResponse, StreamingResponse from starlette.datastructures import State from vllm.engine.protocol import EngineClient -from vllm.entrypoints.chat_utils import load_chat_template -from vllm.entrypoints.cli.serve import run_api_server_worker_proc -from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.api_server import init_app_state from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionResponse from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args @@ -123,6 +120,9 @@ # NemotronH "nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16": "qwen3_coder", "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16": "qwen3_coder", + # Gemma4 + "google/gemma-4-26B-A4B-it": "gemma4", + "google/gemma-4-31B-it": "gemma4", } @@ -136,30 +136,23 @@ def resolve_tool_call_parser(model_name: str, tool_call_parser: str | None) -> s logger = get_logger() from prime_rl.inference.patches import ( monkey_patch_harmony_stop_token_propagation, - monkey_patch_hermes_tool_parser_thread_safety, monkey_patch_load_lora_adapter, - monkey_patch_prometheus_stat_logger_for_lora_in_dp_mode, monkey_patch_tokenize_params_validation, - monkey_patch_tokenizer_thread_safety, ) from prime_rl.inference.vllm.serving_chat_with_tokens import ( ChatCompletionRequestWithTokens, OpenAIServingChatWithTokens, ) -# NOTE: Fix harmony stop token propagation for GPT-OSS models (vLLM 0.17.0 bug) +# NOTE: Fix harmony stop token propagation for GPT-OSS models +# Upstream issue still open: https://github.com/vllm-project/vllm/issues/22519 
monkey_patch_harmony_stop_token_propagation() -# NOTE: Monkeypatch PrometheusStatLogger to avoid NotImplementedError for LoRA in DP mode -monkey_patch_prometheus_stat_logger_for_lora_in_dp_mode() # NOTE: Monkeypatch LoadLoRAAdapter to allow loading the same adapter multiple times +# May be removable if we pass load_inplace=True (supported since vLLM 0.18, PR #31326) monkey_patch_load_lora_adapter() # NOTE: Monkeypatch TokenizeParams to fix overly conservative validation +# Still needed in vLLM 0.19 — upstream rejects prompt_len > max_model_len - max_tokens monkey_patch_tokenize_params_validation() -# NOTE: Monkeypatch Hermes tool parser to fix "Already borrowed" RuntimeError under concurrent load -monkey_patch_hermes_tool_parser_thread_safety() -# NOTE: Monkeypatch HF tokenizer to fix "Already borrowed" RuntimeError during concurrent chat template processing -# Can be removed once https://github.com/vllm-project/vllm/pull/36557 is merged and we upgrade vllm -monkey_patch_tokenizer_thread_safety() logger = init_logger("vllm.entrypoints.openai.api_server") @@ -279,82 +272,36 @@ async def custom_init_app_state( ): """ Modifies init_app_state: - 1. Set up the custom OpenAIServingChatWithTokens state. - 2. Monkey-patch to allow updating lora adapters in-place. + 1. Call the original init_app_state to set up standard state. + 2. Replace the serving_chat with our OpenAIServingChatWithTokens wrapper. 
""" - # Setup the regular app state first (in-place) await init_app_state(engine_client, state, args, supported_tasks) - # NOTE: Initialize the custom OpenAIServingChatWithTokens state here - # TODO: Here, we repeat some calls done in init_app_state to be able to - # correctly set up the OpenAIServingChatWithTokens state, which is a bit - # brittle, and could probably be made nicer - if args.enable_log_requests: - request_logger = RequestLogger(max_log_len=args.max_log_len) + if "generate" in supported_tasks and state.openai_serving_chat is not None: + original_chat = state.openai_serving_chat + serving_chat = object.__new__(OpenAIServingChatWithTokens) + serving_chat.__dict__.update(original_chat.__dict__) + state.openai_serving_chat = serving_chat + state.openai_serving_chat_with_tokens = serving_chat else: - request_logger = None - - resolved_chat_template = load_chat_template(args.chat_template) - - chat_kwargs = dict( - request_logger=request_logger, - chat_template=resolved_chat_template, - chat_template_content_format=args.chat_template_content_format, - trust_request_chat_template=args.trust_request_chat_template, - return_tokens_as_token_ids=args.return_tokens_as_token_ids, - enable_auto_tools=args.enable_auto_tool_choice, - exclude_tools_when_tool_choice_none=args.exclude_tools_when_tool_choice_none, - tool_parser=args.tool_call_parser, - reasoning_parser=args.structured_outputs_config.reasoning_parser, - enable_prompt_tokens_details=args.enable_prompt_tokens_details, - enable_force_include_usage=args.enable_force_include_usage, - enable_log_outputs=args.enable_log_outputs, - ) - if hasattr(args, "log_error_stack"): - chat_kwargs["log_error_stack"] = args.log_error_stack - - serving_chat = OpenAIServingChatWithTokens( - engine_client, - state.openai_serving_models, - args.response_role, - **chat_kwargs, - ) - state.openai_serving_chat = serving_chat if "generate" in supported_tasks else None - state.openai_serving_chat_with_tokens = serving_chat if 
"generate" in supported_tasks else None - - -def custom_run_api_server_worker_proc(listen_address, sock, args, client_config=None, **uvicorn_kwargs) -> None: - """ - Modifies run_api_server_worker_proc: - 1. Re-import our module to ensure monkey patches are applied in child processes - """ - # NOTE: This hack ensures that monkey patches are applied in child processes - # to make our custom routes work in multi-API-server settings. - import prime_rl.inference.vllm.server # noqa: F401 - - run_api_server_worker_proc(listen_address, sock, args, client_config, **uvicorn_kwargs) + state.openai_serving_chat_with_tokens = None -import vllm.entrypoints.cli.serve import vllm.entrypoints.openai.api_server from vllm.entrypoints.openai.api_server import build_app as _original_build_app -def custom_build_app(args: Namespace, supported_tasks: tuple): +def custom_build_app(args: Namespace, supported_tasks: tuple, model_config=None): """ Wrap build_app to include our custom router. """ - app = _original_build_app(args, supported_tasks) + app = _original_build_app(args, supported_tasks, model_config) app.include_router(router) return app -# Also monkey patch run_api_server_worker_proc for multi-api-server mode -# This is needed because worker processes spawned by run_multi_api_server -# re-import modules and would otherwise use the original run_server_worker vllm.entrypoints.openai.api_server.init_app_state = custom_init_app_state vllm.entrypoints.openai.api_server.build_app = custom_build_app -vllm.entrypoints.cli.serve.run_api_server_worker_proc = custom_run_api_server_worker_proc # Adapted from vllm/entrypoints/cli/serve.py diff --git a/src/prime_rl/inference/vllm/serving_chat_with_tokens.py b/src/prime_rl/inference/vllm/serving_chat_with_tokens.py index 9d37e72d2a..e7f4b911f6 100644 --- a/src/prime_rl/inference/vllm/serving_chat_with_tokens.py +++ b/src/prime_rl/inference/vllm/serving_chat_with_tokens.py @@ -250,7 +250,7 @@ async def create_chat_completion_with_tokens( 
request_metadata, reasoning_parser, ) - except GenerationError as e: - return self._convert_generation_error_to_response(e) + except GenerationError: + raise # Let FastAPI's global generation_error_handler handle it except ValueError as e: return self.create_error_response(e) diff --git a/src/prime_rl/trainer/models/__init__.py b/src/prime_rl/trainer/models/__init__.py index b2e2068217..2daeb771f0 100644 --- a/src/prime_rl/trainer/models/__init__.py +++ b/src/prime_rl/trainer/models/__init__.py @@ -10,6 +10,7 @@ from prime_rl.trainer.models.afmoe import AfmoeConfig, AfmoeForCausalLM from prime_rl.trainer.models.base import PreTrainedModelPrimeRL +from prime_rl.trainer.models.gemma4 import Gemma4ForCausalLM, Gemma4TextConfig from prime_rl.trainer.models.glm4_moe import Glm4MoeConfig, Glm4MoeForCausalLM from prime_rl.trainer.models.glm_moe_dsa import GlmMoeDsaConfig, GlmMoeDsaForCausalLM from prime_rl.trainer.models.layers.lm_head import PrimeLmOutput, cast_float_and_contiguous @@ -20,6 +21,7 @@ from prime_rl.trainer.models.qwen3_moe import Qwen3MoeConfig, Qwen3MoeForCausalLM # Make custom config discoverable by AutoConfig +AutoConfig.register("gemma4_text", Gemma4TextConfig, exist_ok=True) AutoConfig.register("afmoe", AfmoeConfig, exist_ok=True) AutoConfig.register("glm4_moe", Glm4MoeConfig, exist_ok=True) AutoConfig.register("glm_moe_dsa", GlmMoeDsaConfig, exist_ok=True) @@ -29,6 +31,7 @@ AutoConfig.register("qwen3_5_moe_text", Qwen3_5MoeConfig, exist_ok=True) _CUSTOM_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, OrderedDict()) +_CUSTOM_CAUSAL_LM_MAPPING.register(Gemma4TextConfig, Gemma4ForCausalLM, exist_ok=True) _CUSTOM_CAUSAL_LM_MAPPING.register(LlamaConfig, LlamaForCausalLM, exist_ok=True) _CUSTOM_CAUSAL_LM_MAPPING.register(AfmoeConfig, AfmoeForCausalLM, exist_ok=True) _CUSTOM_CAUSAL_LM_MAPPING.register(Glm4MoeConfig, Glm4MoeForCausalLM, exist_ok=True) @@ -62,6 +65,7 @@ def supports_custom_impl(model_config: PretrainedConfig) -> bool: # Used by 
get_model() to dispatch VLMs that have a custom text model implementation. # Points to the same unified class — the config drives text-only vs VLM behavior. _CUSTOM_VLM_MAPPING: dict[str, type] = { + "gemma4": Gemma4ForCausalLM, "qwen3_5_moe": Qwen3_5MoeForCausalLM, } diff --git a/src/prime_rl/trainer/models/afmoe/configuration_afmoe.py b/src/prime_rl/trainer/models/afmoe/configuration_afmoe.py index 2090c72d51..24e8013722 100644 --- a/src/prime_rl/trainer/models/afmoe/configuration_afmoe.py +++ b/src/prime_rl/trainer/models/afmoe/configuration_afmoe.py @@ -1,5 +1,4 @@ from transformers.configuration_utils import PretrainedConfig, layer_type_validation -from transformers.modeling_rope_utils import rope_config_validation from transformers.utils import logging logger = logging.get_logger(__name__) @@ -106,7 +105,7 @@ def __init__( # Validate rope configs if self.rope_scaling is not None and "type" in self.rope_scaling: self.rope_scaling["rope_type"] = self.rope_scaling["type"] - rope_config_validation(self) + self.standardize_rope_params() super().__init__( tie_word_embeddings=tie_word_embeddings, diff --git a/src/prime_rl/trainer/models/afmoe/modeling_afmoe.py b/src/prime_rl/trainer/models/afmoe/modeling_afmoe.py index 136ad70306..e26f919140 100644 --- a/src/prime_rl/trainer/models/afmoe/modeling_afmoe.py +++ b/src/prime_rl/trainer/models/afmoe/modeling_afmoe.py @@ -508,7 +508,7 @@ def forward( if not isinstance(causal_mask_mapping := attention_mask, dict): mask_kwargs = { "config": self.config, - "input_embeds": inputs_embeds, + "inputs_embeds": inputs_embeds, "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": None, diff --git a/src/prime_rl/trainer/models/gemma4/__init__.py b/src/prime_rl/trainer/models/gemma4/__init__.py new file mode 100644 index 0000000000..300b15f0e2 --- /dev/null +++ b/src/prime_rl/trainer/models/gemma4/__init__.py @@ -0,0 +1,5 @@ +from transformers.models.gemma4.configuration_gemma4 import Gemma4TextConfig 
+ +from prime_rl.trainer.models.gemma4.modeling_gemma4 import Gemma4ForCausalLM + +__all__ = ["Gemma4TextConfig", "Gemma4ForCausalLM"] diff --git a/src/prime_rl/trainer/models/gemma4/modeling_gemma4.py b/src/prime_rl/trainer/models/gemma4/modeling_gemma4.py new file mode 100644 index 0000000000..6a0ae367e9 --- /dev/null +++ b/src/prime_rl/trainer/models/gemma4/modeling_gemma4.py @@ -0,0 +1,760 @@ +"""Gemma4 custom implementation for PrimeRL training. + +Supports both dense (31B) and MoE (26B-A4B) variants, in text-only and VLM modes: +- Hybrid sliding window + global attention (5:1 pattern) +- Dual RoPE: theta=10K for sliding, theta=1M + partial_rotary_factor=0.25 for global +- K=V sharing on global attention layers (no v_proj) +- QKV norms (q/k with scale, v without scale) +- Attention scaling = 1.0 (QK norms handle magnitude) +- Logit softcapping (tanh at 30.0) +- Per-layer learnable scalar +- Scaled embeddings (× sqrt(hidden_size)) + +VLM mode is auto-detected from the config (presence of vision_config). In VLM mode, +the HF vision tower is used as-is and only the language model uses our custom impl. 
+""" + +from typing import Optional, Union + +import torch +from torch import Tensor, nn +from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig +from transformers.generation import GenerationMixin +from transformers.modeling_layers import GradientCheckpointingLayer +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.models.gemma4.configuration_gemma4 import Gemma4TextConfig +from transformers.processing_utils import Unpack +from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple + +from prime_rl.trainer.models.base import PreTrainedModelPrimeRL +from prime_rl.trainer.models.layers.lm_head import PrimeLmOutput +from prime_rl.trainer.models.layers.norms import RMSNorm, RMSNormConfig + +# flash-attention-2 +try: + from flash_attn import flash_attn_varlen_func +except ImportError: + flash_attn_varlen_func = None # type: ignore + +# flash-attention-3 +try: + from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func +except ImportError: + flash_attn_3_varlen_func = None # type: ignore + +try: + from flash_attn.cute import flash_attn_varlen_func as flash_attn_4_varlen_func +except ImportError: + flash_attn_4_varlen_func = None # type: ignore + + +# --------------------------------------------------------------------------- +# Norms +# --------------------------------------------------------------------------- + + +class Gemma4RMSNormNoScale(nn.Module): + """RMSNorm without a learnable scale parameter, for value normalization.""" + + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * 
torch.rsqrt(variance + self.eps) + return hidden_states.to(input_dtype) + + +# --------------------------------------------------------------------------- +# Scaled word embedding +# --------------------------------------------------------------------------- + + +class Gemma4ScaledWordEmbedding(nn.Embedding): + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0): + super().__init__(num_embeddings, embedding_dim, padding_idx) + self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False) + + def forward(self, input_ids: torch.Tensor): + return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype) + + +# --------------------------------------------------------------------------- +# Dual RoPE (per layer type) +# --------------------------------------------------------------------------- + + +def _compute_default_rope_parameters(config, device=None, seq_len=None, layer_type=None, head_dim_key=None): + config.standardize_rope_params() + rope_params = config.rope_parameters[layer_type] if layer_type else config.rope_parameters + base = rope_params["rope_theta"] + partial_rotary_factor = rope_params.get("partial_rotary_factor", 1.0) + head_dim = getattr(config, head_dim_key, None) if head_dim_key else None + if head_dim is None: + head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + dim = int(head_dim * partial_rotary_factor) + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)) + return inv_freq, 1.0 + + +class Gemma4DualRotaryEmbedding(nn.Module): + """Stores separate inv_freq buffers per layer type (sliding vs full).""" + + def __init__(self, config: Gemma4TextConfig, device=None): + super().__init__() + self.config = config + self.layer_types = set(config.layer_types) + self.attention_scaling = {} + + for layer_type in self.layer_types: + rope_params = 
config.rope_parameters[layer_type] + rope_type = rope_params["rope_type"] + + if rope_type != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type] + else: + rope_init_fn = _compute_default_rope_parameters + + kwargs = {"device": device, "layer_type": layer_type} + if layer_type == "full_attention" and rope_type == "proportional": + kwargs["head_dim_key"] = "global_head_dim" + + inv_freq, attn_scaling = rope_init_fn(config, **kwargs) + self.register_buffer(f"{layer_type}_inv_freq", inv_freq, persistent=False) + self.attention_scaling[layer_type] = attn_scaling + + @torch.no_grad() + def forward(self, x: torch.Tensor, position_ids: torch.Tensor, layer_type: str): + inv_freq = getattr(self, f"{layer_type}_inv_freq") + attn_scaling = self.attention_scaling[layer_type] + + inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * attn_scaling + sin = emb.sin() * attn_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# --------------------------------------------------------------------------- +# RoPE application +# --------------------------------------------------------------------------- + + +def _rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def _apply_rotary_pos_emb_single(t, cos, sin, unsqueeze_dim=1): + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + rotary_dim = cos.shape[-1] + t_rot, t_pass = t[..., :rotary_dim], t[..., rotary_dim:] + t_embed = (t_rot * cos) + (_rotate_half(t_rot) * sin) + return torch.cat([t_embed, 
# ---------------------------------------------------------------------------
# Attention
# ---------------------------------------------------------------------------

_FLASH_ATTN_FUNCS = {
    2: flash_attn_varlen_func,
    3: flash_attn_3_varlen_func,
    4: flash_attn_4_varlen_func,
}


class Gemma4Attention(nn.Module):
    """Gemma4 attention with hybrid sliding/global layers, K=V sharing, and QKV norms.

    Sliding layers attend within ``config.sliding_window``; global layers may use a
    larger ``global_head_dim`` and, when ``attention_k_eq_v`` is set, reuse the K
    projection output as V (no separate ``v_proj``).
    """

    def __init__(self, config: "Gemma4TextConfig", layer_idx: int, flash_attn_version: int = 2):
        super().__init__()
        self.layer_idx = layer_idx
        self.layer_type = config.layer_types[layer_idx]
        self.is_sliding = self.layer_type == "sliding_attention"
        self.sliding_window = config.sliding_window if self.is_sliding else None

        # Global attention layers may use a larger head_dim.
        self.head_dim = config.global_head_dim if (not self.is_sliding and config.global_head_dim) else config.head_dim
        self.num_attention_heads = config.num_attention_heads

        # K=V sharing: only on global layers when enabled.
        self.use_kv_sharing = config.attention_k_eq_v and not self.is_sliding
        num_kv_heads = (
            config.num_global_key_value_heads
            if self.use_kv_sharing and config.num_global_key_value_heads
            else config.num_key_value_heads
        )
        self.num_key_value_heads = num_kv_heads
        self.num_key_value_groups = config.num_attention_heads // num_kv_heads

        # Projections. v_proj is absent when K=V sharing is active.
        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(config.hidden_size, num_kv_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = (
            nn.Linear(config.hidden_size, num_kv_heads * self.head_dim, bias=config.attention_bias)
            if not self.use_kv_sharing
            else None
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )

        # QKV norms (V gets a scale-free RMSNorm).
        self.q_norm = RMSNorm(RMSNormConfig(hidden_size=self.head_dim, eps=config.rms_norm_eps))
        self.k_norm = RMSNorm(RMSNormConfig(hidden_size=self.head_dim, eps=config.rms_norm_eps))
        self.v_norm = Gemma4RMSNormNoScale(self.head_dim, eps=config.rms_norm_eps)

        # Flash attention; FA4 breaks torch.compile tracing, so call it eagerly.
        self._flash_attn_version = flash_attn_version
        self._flash_attn_func = _FLASH_ATTN_FUNCS[flash_attn_version]
        if self._flash_attn_version == 4:
            self._flash_attn_call = torch._dynamo.disable(self._flash_attn_func)
        else:
            self._flash_attn_call = self._flash_attn_func

    def _compute_flash_attention(self, q, k, v, cu_seqlens, max_seqlen):
        """Run the varlen flash-attention kernel on packed [total_tokens, heads, dim] tensors."""
        args = [q, k, v, cu_seqlens, cu_seqlens]
        # FA4's varlen API infers max_seqlen; FA2/FA3 require it explicitly.
        if self._flash_attn_version != 4:
            args.extend([max_seqlen, max_seqlen])
        kwargs = {"causal": True, "softmax_scale": 1.0}
        if self.sliding_window is not None:
            kwargs["window_size"] = (self.sliding_window - 1, 0)
        out = self._flash_attn_call(*args, **kwargs)
        if isinstance(out, tuple):
            out = out[0]
        return out

    def _compute_sdpa_attention(self, q, k, v, cu_seqlens):
        """SDPA fallback for head_dim > 256 (global attention layers).

        Handles packed sequences by building a block-diagonal causal mask from cu_seqlens.
        """
        # q/k/v: [total_tokens, heads, dim] -> [1, heads, total_tokens, dim]
        q = q.unsqueeze(0).transpose(1, 2)
        k = k.unsqueeze(0).transpose(1, 2)
        v = v.unsqueeze(0).transpose(1, 2)

        # GQA: repeat k/v heads to match q heads.
        if k.shape[1] != q.shape[1]:
            n_rep = q.shape[1] // k.shape[1]
            k = k.repeat_interleave(n_rep, dim=1)
            v = v.repeat_interleave(n_rep, dim=1)

        total_len = q.shape[2]
        if cu_seqlens is not None and len(cu_seqlens) > 2:
            # Multiple packed sequences — block-diagonal causal mask.
            mask = torch.full((total_len, total_len), float("-inf"), device=q.device, dtype=q.dtype)
            for i in range(len(cu_seqlens) - 1):
                start, end = cu_seqlens[i].item(), cu_seqlens[i + 1].item()
                seq_len = end - start
                # Fix: the original wrapped this in a no-op torch.tril(torch.zeros(...));
                # a zero matrix with -inf above the diagonal is the same causal mask.
                causal = torch.zeros(seq_len, seq_len, device=q.device, dtype=q.dtype)
                causal.masked_fill_(
                    torch.triu(torch.ones(seq_len, seq_len, device=q.device, dtype=torch.bool), diagonal=1),
                    float("-inf"),
                )
                mask[start:end, start:end] = causal
            out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, scale=1.0)
        else:
            out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True, scale=1.0)

        # [1, heads, total_tokens, dim] -> [total_tokens, heads, dim]
        return out.transpose(1, 2).squeeze(0)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        cu_seqlens: torch.LongTensor | None = None,
        max_seqlen: int | None = None,
    ) -> tuple[torch.Tensor, None]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)
        cos, sin = position_embeddings

        # Q projection + norm + RoPE; layout stays [batch, tokens, heads, head_dim].
        # Fix: the original transposed q/k/v to [B, H, T, D] and immediately transposed
        # back before the kernel call — both transpose pairs were no-ops and are removed.
        query_states = self.q_proj(hidden_states).view(hidden_shape)
        query_states = self.q_norm(query_states)
        query_states = _apply_rotary_pos_emb_single(query_states, cos, sin, unsqueeze_dim=2)

        # K projection + norm + RoPE
        key_states = self.k_proj(hidden_states).view(hidden_shape)
        # V: either from v_proj or reuse the (pre-norm) k_proj output (K=V sharing).
        value_states = self.v_proj(hidden_states).view(hidden_shape) if self.v_proj is not None else key_states

        key_states = self.k_norm(key_states)
        key_states = _apply_rotary_pos_emb_single(key_states, cos, sin, unsqueeze_dim=2)

        value_states = self.v_norm(value_states)

        # Kernels consume packed [total_tokens, heads, dim]; batch dim is always 1 here.
        # FlashAttention only supports head_dim <= 256; fall back to SDPA for global layers.
        if self.head_dim > 256:
            attn_output = self._compute_sdpa_attention(query_states[0], key_states[0], value_states[0], cu_seqlens)
        else:
            attn_output = self._compute_flash_attention(
                query_states[0], key_states[0], value_states[0], cu_seqlens, max_seqlen
            )
        attn_output = attn_output.contiguous().view(1, attn_output.shape[0], -1)
        attn_output = self.o_proj(attn_output)
        return attn_output, None


# ---------------------------------------------------------------------------
# MLP
# ---------------------------------------------------------------------------


class Gemma4MLP(nn.Module):
    """Standard gated MLP: down(act(gate(x)) * up(x))."""

    def __init__(self, config: "Gemma4TextConfig"):
        super().__init__()
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_activation]

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
# ---------------------------------------------------------------------------
# Decoder layer
# ---------------------------------------------------------------------------


def _get_flash_attn_version(attn_impl: str) -> int:
    """Map an attn-implementation string to a flash-attention major version (default: 2)."""
    return {"flash_attention_2": 2, "flash_attention_3": 3, "fa4": 4}.get(attn_impl, 2)


class Gemma4DecoderLayer(GradientCheckpointingLayer):
    """Dense Gemma4 decoder layer: attention + MLP, each sandwiched by pre/post norms."""

    def __init__(self, config: "Gemma4TextConfig", layer_idx: int):
        super().__init__()
        self.self_attn = Gemma4Attention(
            config, layer_idx, flash_attn_version=_get_flash_attn_version(config._attn_implementation)
        )
        self.mlp = Gemma4MLP(config)

        # Gemma-style 4 layernorms: pre/post for each of the two sub-blocks.
        self.input_layernorm = RMSNorm(RMSNormConfig(hidden_size=config.hidden_size, eps=config.rms_norm_eps))
        self.post_attention_layernorm = RMSNorm(RMSNormConfig(hidden_size=config.hidden_size, eps=config.rms_norm_eps))
        self.pre_feedforward_layernorm = RMSNorm(RMSNormConfig(hidden_size=config.hidden_size, eps=config.rms_norm_eps))
        self.post_feedforward_layernorm = RMSNorm(
            RMSNormConfig(hidden_size=config.hidden_size, eps=config.rms_norm_eps)
        )

        # Per-layer output scalar (loaded from checkpoints elsewhere; defaults to 1).
        self.register_buffer("layer_scalar", torch.ones(1))

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        cu_seqlens: torch.LongTensor | None = None,
        max_seqlen: int | None = None,
    ) -> torch.Tensor:
        # Attention sub-block: norm -> attn -> norm -> residual add.
        shortcut = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = shortcut + hidden_states

        # FFN sub-block: norm -> mlp -> norm -> residual add.
        shortcut = hidden_states
        hidden_states = self.pre_feedforward_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_feedforward_layernorm(hidden_states)
        hidden_states = shortcut + hidden_states

        hidden_states *= self.layer_scalar
        return hidden_states


# ---------------------------------------------------------------------------
# VLM helpers
# ---------------------------------------------------------------------------


def _has_vlm_keys(state_dict: dict[str, Tensor]) -> bool:
    """True when the state dict uses the VLM key layout (model.language_model.*)."""
    return any(k.startswith("model.language_model.") for k in state_dict)


def _remap_lm_keys(state_dict: dict[str, Tensor], to_flat: bool = True) -> None:
    """Remap language model keys between VLM and flat format for weight conversion.

    to_flat=True: model.language_model.* -> model.*
    to_flat=False: model.* -> model.language_model.*

    Vision keys (model.vision_tower.*, model.embed_vision.*) are never touched.
    """
    VISION_PREFIXES = ("model.vision_tower.", "model.embed_vision.")
    src = "model.language_model." if to_flat else "model."
    dst = "model." if to_flat else "model.language_model."
    # Snapshot the matching keys first: we mutate the dict while renaming.
    to_move = [k for k in list(state_dict) if k.startswith(src) and not k.startswith(VISION_PREFIXES)]
    for k in to_move:
        state_dict[dst + k[len(src):]] = state_dict.pop(k)


# ---------------------------------------------------------------------------
# PreTrained base
# ---------------------------------------------------------------------------


@auto_docstring
class Gemma4PreTrainedModel(PreTrainedModelPrimeRL):
    config_class = Gemma4TextConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Gemma4DecoderLayer"]
    _supports_flash_attn = True
    _supports_sdpa = True

    @classmethod
    def _has_moe_keys(cls, state_dict: dict[str, Tensor]) -> bool:
        return any("experts." in k for k in state_dict)

    @classmethod
    def is_hf_state_dict(cls, state_dict: dict[str, Tensor]) -> bool:
        if not cls._has_moe_keys(state_dict):
            return True  # Dense: HF format = PrimeRL format
        return any("experts.gate_up_proj" in k or "experts.0.gate_proj" in k for k in state_dict)

    @classmethod
    def is_prime_state_dict(cls, state_dict: dict[str, Tensor]) -> bool:
        if not cls._has_moe_keys(state_dict):
            return True  # Dense: HF format = PrimeRL format
        return any("experts.w1" in k for k in state_dict)

    @classmethod
    def _convert_with_lm_remap(cls, state_dict: dict[str, Tensor], convert_fn) -> dict[str, Tensor]:
        """Run convert_fn on state_dict, temporarily flattening VLM language-model keys."""
        vlm = _has_vlm_keys(state_dict)
        if vlm:
            _remap_lm_keys(state_dict, to_flat=True)
        convert_fn(state_dict)
        if vlm:
            _remap_lm_keys(state_dict, to_flat=False)
        return state_dict

    @classmethod
    def convert_to_hf(cls, state_dict: dict[str, Tensor]) -> dict[str, Tensor]:
        from prime_rl.trainer.models.gemma4_moe.converting_gemma4_moe import convert_prime_to_hf

        return cls._convert_with_lm_remap(state_dict, convert_prime_to_hf)

    @classmethod
    def convert_to_prime(cls, state_dict: dict[str, Tensor]) -> dict[str, Tensor]:
        from prime_rl.trainer.models.gemma4_moe.converting_gemma4_moe import convert_hf_to_prime

        return cls._convert_with_lm_remap(state_dict, convert_hf_to_prime)

    @classmethod
    def convert_layer_to_hf(cls, state_dict: dict[str, Tensor], layer_idx: int) -> dict[str, Tensor]:
        from prime_rl.trainer.models.gemma4_moe.converting_gemma4_moe import convert_prime_layer_to_hf

        return cls._convert_with_lm_remap(state_dict, lambda sd: convert_prime_layer_to_hf(sd, layer_idx))

    @classmethod
    def convert_layer_to_prime(cls, state_dict: dict[str, Tensor], layer_idx: int) -> dict[str, Tensor]:
        from prime_rl.trainer.models.gemma4_moe.converting_gemma4_moe import convert_hf_layer_to_prime

        return cls._convert_with_lm_remap(state_dict, lambda sd: convert_hf_layer_to_prime(sd, layer_idx))
# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------


@auto_docstring
class Gemma4Model(Gemma4PreTrainedModel):
    """Gemma4 transformer body: scaled embeddings, hybrid decoder stack, dual RoPE."""

    def __init__(self, config: "Gemma4TextConfig"):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = Gemma4ScaledWordEmbedding(
            config.vocab_size,
            config.hidden_size,
            self.padding_idx,
            embed_scale=config.hidden_size**0.5,
        )
        # MoE configs swap in the sparse decoder layer; imported lazily to avoid a cycle.
        if config.enable_moe_block:
            from prime_rl.trainer.models.gemma4_moe.modeling_gemma4_moe import Gemma4MoeDecoderLayer

            layer_cls = Gemma4MoeDecoderLayer
        else:
            layer_cls = Gemma4DecoderLayer
        self.layers = nn.ModuleList(layer_cls(config, idx) for idx in range(config.num_hidden_layers))
        self.norm = RMSNorm(RMSNormConfig(hidden_size=config.hidden_size, eps=config.rms_norm_eps))
        self.rotary_emb = Gemma4DualRotaryEmbedding(config)
        self.gradient_checkpointing = False

        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> "BaseModelOutputWithPast":
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if self.config._attn_implementation in ("flash_attention_2", "flash_attention_3", "fa4"):
            # Recover packed-sequence lengths from position_ids: every position that
            # resets to 0 (after the first token) starts a new sequence, so the
            # position just before each reset (plus one) is that sequence's length.
            flat = position_ids.view(-1)
            seqlens = torch.cat(
                [
                    flat[0:1],
                    flat[:-1][(flat == 0)[1:]] + 1,
                    flat[-1:] + 1,
                ]
            )
            max_seqlen = seqlens.max().item()
            cu_seqlens = seqlens.cumsum(dim=0, dtype=torch.int32)
            torch._dynamo.mark_dynamic(cu_seqlens, 0)
        else:
            max_seqlen = None
            cu_seqlens = None

        hidden_states = inputs_embeds

        # One (cos, sin) pair per distinct layer type (sliding vs full attention).
        rope_cache = {
            layer_type: self.rotary_emb(hidden_states, position_ids, layer_type)
            for layer_type in set(self.config.layer_types)
        }

        for idx, decoder_layer in enumerate(self.layers):
            hidden_states = decoder_layer(
                hidden_states,
                position_embeddings=rope_cache[self.config.layer_types[idx]],
                cu_seqlens=cu_seqlens,
                max_seqlen=max_seqlen,
            )

        hidden_states = self.norm(hidden_states)
        return BaseModelOutputWithPast(last_hidden_state=hidden_states)


# ---------------------------------------------------------------------------
# VLM composite model
# ---------------------------------------------------------------------------


class Gemma4VLMModel(nn.Module):
    """Composite VLM body: HF vision tower + custom PrimeRL text model."""

    def __init__(self, config: "PretrainedConfig"):
        super().__init__()
        self.config = config
        # Lazy import keeps the text-only path free of the HF vision stack.
        from transformers.models.gemma4.modeling_gemma4 import Gemma4MultimodalEmbedder, Gemma4VisionModel

        self.vision_tower = Gemma4VisionModel._from_config(config.vision_config)
        self.embed_vision = Gemma4MultimodalEmbedder(config.vision_config, config.text_config)
        self.language_model = Gemma4Model(config.text_config)

    def get_input_embeddings(self):
        return self.language_model.embed_tokens

    def set_input_embeddings(self, value):
        self.language_model.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        pixel_values: torch.Tensor | None = None,
        **kwargs,
    ) -> "BaseModelOutputWithPast":
        if inputs_embeds is None:
            inputs_embeds = self.language_model.embed_tokens(input_ids)

        if pixel_values is not None:
            # Encode images and splice the features over the image-token placeholders.
            pixel_values = pixel_values.type(self.vision_tower.dtype)
            vision_output = self.vision_tower(pixel_values, return_dict=True)
            image_features = self.embed_vision(vision_output.last_hidden_state)
            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)

            image_mask = input_ids == self.config.image_token_id
            image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_features)

        if position_ids is None:
            position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0)

        return self.language_model(
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
        )
# ---------------------------------------------------------------------------
# Causal LM (unified text-only + VLM)
# ---------------------------------------------------------------------------


@auto_docstring
class Gemma4ForCausalLM(Gemma4PreTrainedModel, GenerationMixin):
    """Unified Gemma4 model for both text-only and VLM configs.

    When the config has a ``vision_config``, builds a composite model (HF vision
    tower + custom text model); otherwise builds the text-only model.
    """

    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        self._is_vlm = hasattr(config, "vision_config")

        if self._is_vlm:
            self.model = Gemma4VLMModel(config)
            text_config = config.text_config
            # Embeddings live one level deeper in the composite model.
            self._tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
        else:
            self.model = Gemma4Model(config)
            text_config = config

        self.vocab_size = text_config.vocab_size
        # NOTE(review): lm_head is constructed as a plain nn.Linear, yet forward() calls
        # it with (hidden_states, labels, temperature=...) — presumably the PrimeRL base
        # class swaps in a fused LM head during post_init; confirm.
        self.lm_head = nn.Linear(text_config.hidden_size, text_config.vocab_size, bias=False)
        self.final_logit_softcapping = text_config.final_logit_softcapping
        self.post_init()

    def get_input_embeddings(self):
        return self.model.get_input_embeddings() if self._is_vlm else self.model.embed_tokens

    def set_input_embeddings(self, value):
        if self._is_vlm:
            self.model.set_input_embeddings(value)
        else:
            self.model.embed_tokens = value

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        temperature: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        **kwargs: "Unpack[TransformersKwargs]",
    ) -> "PrimeLmOutput":
        if position_ids is None:
            # Default to a single contiguous sequence over whichever input was given.
            ref = inputs_embeds if inputs_embeds is not None else input_ids
            position_ids = torch.arange(ref.shape[1], device=ref.device).unsqueeze(0)

        model_kwargs = dict(input_ids=input_ids, position_ids=position_ids, inputs_embeds=inputs_embeds)
        if self._is_vlm:
            model_kwargs["pixel_values"] = pixel_values
        outputs = self.model(**model_kwargs)

        hidden_states = outputs.last_hidden_state
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        return self.lm_head(
            hidden_states[:, slice_indices, :],
            labels[:, slice_indices] if labels is not None else None,
            temperature=temperature,
        )

    def _get_text_config(self):
        return self.config.text_config if self._is_vlm else self.config

    def init_buffers_post_meta(self):
        """Re-create non-persistent buffers (embed scale, RoPE inv_freq) after meta-device init."""
        text_config = self._get_text_config()

        lm = self.model.language_model if self._is_vlm else self.model
        embed_tokens, rotary_emb = lm.embed_tokens, lm.rotary_emb

        # embed_scale is non-persistent, so it is not restored from checkpoints.
        embed_tokens.embed_scale.fill_(text_config.hidden_size**0.5)

        # Re-initialize the per-layer-type inv_freq buffers in place.
        for layer_type in rotary_emb.layer_types:
            params = text_config.rope_parameters[layer_type]
            rope_type = params["rope_type"]
            init_fn = _compute_default_rope_parameters if rope_type == "default" else ROPE_INIT_FUNCTIONS[rope_type]

            buf = getattr(rotary_emb, f"{layer_type}_inv_freq")
            init_kwargs = {"device": buf.device, "layer_type": layer_type}
            if layer_type == "full_attention" and rope_type == "proportional":
                init_kwargs["head_dim_key"] = "global_head_dim"

            inv_freq, scaling = init_fn(text_config, **init_kwargs)
            buf.copy_(inv_freq)
            rotary_emb.attention_scaling[layer_type] = scaling
"""Weight conversion between HF and PrimeRL formats for Gemma4 MoE.

HF format:      experts.gate_up_proj (num_experts, 2*moe_dim, dim)
                experts.down_proj    (num_experts, dim, moe_dim)
                router.norm.weight, router.proj.weight, router.scale, router.per_expert_scale
PrimeRL format: experts.w1 (num_experts, moe_dim, dim), experts.w2 (num_experts, dim, moe_dim),
                experts.w3 (num_experts, moe_dim, dim)
                router keys stay the same (no rename needed since Gemma4 router has its own structure)
"""

from torch import Tensor


def get_max_layer_num(state_dict: dict[str, Tensor]) -> int:
    """Return the number of decoder layers referenced by `state_dict` (0 when none).

    Fix: the original called max() on an empty generator and raised ValueError for
    state dicts with no "model.layers." keys; now returns 0 so the convert loops no-op.
    Assumes flat keys of the form "model.layers.<idx>.*" — TODO confirm for remapped dicts.
    """
    layer_indices = [int(k.split(".")[2]) for k in state_dict if "model.layers." in k]
    return max(layer_indices) + 1 if layer_indices else 0


def convert_hf_layer_to_prime(state_dict: dict[str, Tensor], layer_idx: int):
    """Convert one Gemma4 MoE layer from HF fused format to PrimeRL w1/w2/w3 format, in place.

    No-op when the layer has no fused expert weights (dense layer or already converted).
    """
    i = layer_idx
    gate_up_key = f"model.layers.{i}.experts.gate_up_proj"
    down_key = f"model.layers.{i}.experts.down_proj"

    if gate_up_key not in state_dict:
        return

    gate_up_proj = state_dict.pop(gate_up_key)
    down_proj = state_dict.pop(down_key)

    # (num_experts, 2*moe_dim, dim): first half is gate (w1), second half is up (w3).
    moe_dim = gate_up_proj.shape[1] // 2
    state_dict[f"model.layers.{i}.experts.w1"] = gate_up_proj[:, :moe_dim, :]
    state_dict[f"model.layers.{i}.experts.w3"] = gate_up_proj[:, moe_dim:, :]
    state_dict[f"model.layers.{i}.experts.w2"] = down_proj


def convert_prime_layer_to_hf(state_dict: dict[str, Tensor], layer_idx: int):
    """Convert one Gemma4 MoE layer from PrimeRL w1/w2/w3 to HF per-expert format, in place.

    No-op when the layer has no PrimeRL expert weights.
    """
    i = layer_idx
    w1_key = f"model.layers.{i}.experts.w1"

    if w1_key not in state_dict:
        return

    w1 = state_dict.pop(w1_key)
    w2 = state_dict.pop(f"model.layers.{i}.experts.w2")
    w3 = state_dict.pop(f"model.layers.{i}.experts.w3")

    num_experts = w1.shape[0]
    for j in range(num_experts):
        state_dict[f"model.layers.{i}.experts.{j}.gate_proj.weight"] = w1[j]
        state_dict[f"model.layers.{i}.experts.{j}.down_proj.weight"] = w2[j]
        state_dict[f"model.layers.{i}.experts.{j}.up_proj.weight"] = w3[j]


def convert_hf_to_prime(state_dict: dict[str, Tensor]):
    """Convert every layer's expert weights from HF to PrimeRL format, in place."""
    num_layers = get_max_layer_num(state_dict)
    for i in range(num_layers):
        convert_hf_layer_to_prime(state_dict, i)


def convert_prime_to_hf(state_dict: dict[str, Tensor]):
    """Convert every layer's expert weights from PrimeRL to HF format, in place."""
    num_layers = get_max_layer_num(state_dict)
    for i in range(num_layers):
        convert_prime_layer_to_hf(state_dict, i)
+ +Extends the dense Gemma4 model with: +- 128 sparse experts (top-8) running in parallel with a shared dense MLP +- Custom router: RMSNorm (no scale) + learnable scale + softmax + renormalize + per_expert_scale +- Router input is the pre-MLP residual (not the MLP output) +- Expert activation: gelu_pytorch_tanh (not silu) +- MoE weight conversion: HF fused gate_up_proj → PrimeRL w1/w2/w3 +""" + +from typing import Optional, Union + +import torch +import torch.nn.functional as F +from torch import Tensor, nn +from transformers.activations import ACT2FN +from transformers.generation import GenerationMixin +from transformers.modeling_layers import GradientCheckpointingLayer +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.models.gemma4.configuration_gemma4 import Gemma4TextConfig +from transformers.processing_utils import Unpack +from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple + +from prime_rl.trainer.models.base import PreTrainedModelPrimeRL +from prime_rl.trainer.models.gemma4.modeling_gemma4 import ( + Gemma4Attention, + Gemma4DualRotaryEmbedding, + Gemma4MLP, + Gemma4RMSNormNoScale, + Gemma4ScaledWordEmbedding, + _compute_default_rope_parameters, + _get_flash_attn_version, +) +from prime_rl.trainer.models.gemma4_moe.converting_gemma4_moe import ( + convert_hf_layer_to_prime, + convert_hf_to_prime, + convert_prime_layer_to_hf, + convert_prime_to_hf, +) +from prime_rl.trainer.models.layers.lm_head import PrimeLmOutput +from prime_rl.trainer.models.layers.norms import RMSNorm, RMSNormConfig + +# --------------------------------------------------------------------------- +# Gemma4 MoE Router +# --------------------------------------------------------------------------- + + +class Gemma4Router(nn.Module): + """Gemma4 router: RMSNorm(no_scale) → scale → linear → softmax → topk → renorm → per_expert_scale.""" + + def __init__(self, 
config: Gemma4TextConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.num_experts = config.num_experts + self.top_k = config.top_k_experts + self.scalar_root_size = self.hidden_size**-0.5 + + self.norm = Gemma4RMSNormNoScale(self.hidden_size, eps=config.rms_norm_eps) + self.proj = nn.Linear(config.hidden_size, config.num_experts, bias=False) + self.scale = nn.Parameter(torch.ones(self.hidden_size)) + self.per_expert_scale = nn.Parameter(torch.ones(config.num_experts)) + + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + hidden_states = self.norm(hidden_states) + hidden_states = hidden_states * self.scale * self.scalar_root_size + + expert_scores = self.proj(hidden_states) + router_probs = F.softmax(expert_scores.float(), dim=-1) + + top_k_weights, top_k_index = torch.topk(router_probs, k=self.top_k, dim=-1) + top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True) + top_k_weights = top_k_weights * self.per_expert_scale[top_k_index] + + num_tokens_per_expert = torch.histc( + top_k_index.reshape(-1).float(), + bins=self.num_experts, + min=0, + max=self.num_experts, + ) + + return top_k_weights.to(hidden_states.dtype), top_k_index, num_tokens_per_expert + + +# --------------------------------------------------------------------------- +# Gemma4 MoE Experts (w1/w2/w3 format with gelu_pytorch_tanh) +# --------------------------------------------------------------------------- + + +def _run_experts_for_loop( + w1: torch.Tensor, + w2: torch.Tensor, + w3: torch.Tensor, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor, + act_fn, +) -> torch.Tensor: + num_tokens_per_expert_list = num_tokens_per_expert.tolist() + num_padding = x.shape[0] - sum(num_tokens_per_expert_list) + x_splits = torch.split(x[: sum(num_tokens_per_expert_list)], num_tokens_per_expert_list, dim=0) + out_splits = [] + for expert_idx, x_expert in enumerate(x_splits): + h = act_fn(torch.matmul(x_expert, 
w1[expert_idx].transpose(-2, -1))) + h = h * torch.matmul(x_expert, w3[expert_idx].transpose(-2, -1)) + h = torch.matmul(h, w2[expert_idx].transpose(-2, -1)) + out_splits.append(h) + out = torch.cat(out_splits, dim=0) + if num_padding > 0: + out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1])))) + return out + + +def _run_experts_grouped_mm( + w1: torch.Tensor, + w2: torch.Tensor, + w3: torch.Tensor, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor, + act_fn, +) -> torch.Tensor: + offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) + h = act_fn(torch._grouped_mm(x.bfloat16(), w1.bfloat16().transpose(-2, -1), offs=offsets)) + h = h * torch._grouped_mm(x.bfloat16(), w3.bfloat16().transpose(-2, -1), offs=offsets) + out = torch._grouped_mm(h, w2.bfloat16().transpose(-2, -1), offs=offsets).type_as(x) + return out + + +class Gemma4Experts(nn.Module): + """Expert weights in PrimeRL w1/w2/w3 format with configurable activation.""" + + def __init__(self, config: Gemma4TextConfig, use_grouped_mm: bool = True): + super().__init__() + self.num_experts = config.num_experts + dim = config.hidden_size + hidden_dim = config.moe_intermediate_size + self.w1 = nn.Parameter(torch.empty(self.num_experts, hidden_dim, dim)) + self.w2 = nn.Parameter(torch.empty(self.num_experts, dim, hidden_dim)) + self.w3 = nn.Parameter(torch.empty(self.num_experts, hidden_dim, dim)) + self.act_fn = ACT2FN[config.hidden_activation] + self.use_grouped_mm = use_grouped_mm + + def forward(self, x: torch.Tensor, num_tokens_per_expert: torch.Tensor) -> torch.Tensor: + if self.use_grouped_mm: + return _run_experts_grouped_mm(self.w1, self.w2, self.w3, x, num_tokens_per_expert, self.act_fn) + return _run_experts_for_loop(self.w1, self.w2, self.w3, x, num_tokens_per_expert, self.act_fn) + + +# --------------------------------------------------------------------------- +# MoE Decoder Layer +# --------------------------------------------------------------------------- + 
class Gemma4MoeDecoderLayer(GradientCheckpointingLayer):
    """Gemma4 decoder layer where a shared dense MLP and sparse top-k experts run in parallel."""

    def __init__(self, config: "Gemma4TextConfig", layer_idx: int):
        super().__init__()
        self.self_attn = Gemma4Attention(
            config, layer_idx, flash_attn_version=_get_flash_attn_version(config._attn_implementation)
        )
        self.mlp = Gemma4MLP(config)

        # Standard 4 layernorms (same layout as the dense layer).
        self.input_layernorm = RMSNorm(RMSNormConfig(hidden_size=config.hidden_size, eps=config.rms_norm_eps))
        self.post_attention_layernorm = RMSNorm(RMSNormConfig(hidden_size=config.hidden_size, eps=config.rms_norm_eps))
        self.pre_feedforward_layernorm = RMSNorm(RMSNormConfig(hidden_size=config.hidden_size, eps=config.rms_norm_eps))
        self.post_feedforward_layernorm = RMSNorm(
            RMSNormConfig(hidden_size=config.hidden_size, eps=config.rms_norm_eps)
        )

        # MoE components.
        self.router = Gemma4Router(config)
        self.experts = Gemma4Experts(config, use_grouped_mm=getattr(config, "use_grouped_mm", True))

        # Extra norms for the parallel MoE path.
        self.post_feedforward_layernorm_1 = RMSNorm(
            RMSNormConfig(hidden_size=config.hidden_size, eps=config.rms_norm_eps)
        )
        self.post_feedforward_layernorm_2 = RMSNorm(
            RMSNormConfig(hidden_size=config.hidden_size, eps=config.rms_norm_eps)
        )
        self.pre_feedforward_layernorm_2 = RMSNorm(
            RMSNormConfig(hidden_size=config.hidden_size, eps=config.rms_norm_eps)
        )

        # Per-layer output scalar (defaults to 1; loaded from checkpoints elsewhere).
        self.register_buffer("layer_scalar", torch.ones(1))

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        cu_seqlens: torch.LongTensor | None = None,
        max_seqlen: int | None = None,
    ) -> torch.Tensor:
        # --- self-attention block ---
        shortcut = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = shortcut + hidden_states

        # --- parallel dense-MLP + sparse-MoE block ---
        shortcut = hidden_states
        hidden_states = self.pre_feedforward_layernorm(hidden_states)
        dense_out = self.post_feedforward_layernorm_1(self.mlp(hidden_states))

        # Router consumes the pre-MLP residual, not the MLP output.
        tokens = shortcut.reshape(-1, shortcut.shape[-1])
        top_k_weights, top_k_index, num_tokens_per_expert = self.router(tokens)

        # Sort token copies by assigned expert so each expert sees a contiguous slab.
        flat_experts = top_k_index.reshape(-1)
        order = torch.argsort(flat_experts, stable=True)
        sorted_scores = top_k_weights.view(-1)[order]
        sorted_token_ids = order // self.router.top_k

        dim = tokens.shape[-1]
        gather_idx = sorted_token_ids.reshape(-1, 1).expand(-1, dim)

        # Pre-norm for expert input, then scale the INPUT by the routing weight.
        expert_input = self.pre_feedforward_layernorm_2(tokens)
        routed_input = torch.gather(expert_input, dim=0, index=gather_idx)
        routed_input = (routed_input.float() * sorted_scores.reshape(-1, 1)).to(routed_input.dtype)

        routed_output = self.experts(routed_input, num_tokens_per_expert)

        # Scatter expert outputs back (summing the top-k contributions per token).
        sparse_out = torch.zeros_like(tokens)
        sparse_out.scatter_add_(dim=0, index=gather_idx, src=routed_output)
        sparse_out = self.post_feedforward_layernorm_2(sparse_out.reshape(shortcut.shape))

        # Combine shared MLP + sparse experts, post-norm, residual add.
        hidden_states = self.post_feedforward_layernorm(dense_out + sparse_out)
        hidden_states = shortcut + hidden_states

        hidden_states *= self.layer_scalar
        return hidden_states


# ---------------------------------------------------------------------------
# PreTrained base
# ---------------------------------------------------------------------------


@auto_docstring
class Gemma4MoePreTrainedModel(PreTrainedModelPrimeRL):
    config_class = Gemma4TextConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Gemma4MoeDecoderLayer"]
    _supports_flash_attn = True

    @classmethod
    def is_hf_state_dict(cls, state_dict: dict[str, Tensor]) -> bool:
        """HF format: fused gate_up_proj or per-expert gate_proj keys."""
        return any("experts.gate_up_proj" in k or "experts.0.gate_proj" in k for k in state_dict)

    @classmethod
    def is_prime_state_dict(cls, state_dict: dict[str, Tensor]) -> bool:
        """PrimeRL format: stacked experts.w1/w2/w3 keys."""
        return any("experts.w1" in k for k in state_dict)

    @classmethod
    def convert_to_hf(cls, state_dict: dict[str, Tensor]) -> dict[str, Tensor]:
        convert_prime_to_hf(state_dict)
        return state_dict

    @classmethod
    def convert_to_prime(cls, state_dict: dict[str, Tensor]) -> dict[str, Tensor]:
        convert_hf_to_prime(state_dict)
        return state_dict

    @classmethod
    def convert_layer_to_hf(cls, state_dict: dict[str, Tensor], layer_idx: int) -> dict[str, Tensor]:
        convert_prime_layer_to_hf(state_dict, layer_idx)
        return state_dict

    @classmethod
    def convert_layer_to_prime(cls, state_dict: dict[str, Tensor], layer_idx: int) -> dict[str, Tensor]:
        convert_hf_layer_to_prime(state_dict, layer_idx)
        return state_dict
range(config.num_hidden_layers)] + ) + self.norm = RMSNorm(RMSNormConfig(hidden_size=config.hidden_size, eps=config.rms_norm_eps)) + self.rotary_emb = Gemma4DualRotaryEmbedding(config) + self.gradient_checkpointing = False + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if self.config._attn_implementation in ("flash_attention_2", "flash_attention_3", "fa4"): + flat_position_ids = position_ids.view(-1) + seqlens = torch.cat( + [ + flat_position_ids[0:1], + flat_position_ids[:-1][(flat_position_ids == 0)[1:]] + 1, + flat_position_ids[-1:] + 1, + ] + ) + max_seqlen = seqlens.max().item() + cu_seqlens = seqlens.cumsum(dim=0, dtype=torch.int32) + torch._dynamo.mark_dynamic(cu_seqlens, 0) + else: + max_seqlen = None + cu_seqlens = None + + hidden_states = inputs_embeds + + unique_layer_types = set(self.config.layer_types) + position_embeddings = {} + for layer_type in unique_layer_types: + position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type) + + for layer_idx, decoder_layer in enumerate(self.layers): + layer_type = self.config.layer_types[layer_idx] + hidden_states = decoder_layer( + hidden_states, + position_embeddings=position_embeddings[layer_type], + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + + hidden_states = self.norm(hidden_states) + return BaseModelOutputWithPast(last_hidden_state=hidden_states) + + +# --------------------------------------------------------------------------- +# Causal LM +# --------------------------------------------------------------------------- + + +@auto_docstring +class 
Gemma4MoeForCausalLM(Gemma4MoePreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} + + def __init__(self, config: Gemma4TextConfig): + super().__init__(config) + self.model = Gemma4MoeModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.final_logit_softcapping = config.final_logit_softcapping + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + temperature: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> PrimeLmOutput: + if position_ids is None: + if inputs_embeds is not None: + position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0) + else: + position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).unsqueeze(0) + + outputs = self.model( + input_ids=input_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + ) + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + return self.lm_head( + hidden_states[:, slice_indices, :], + labels[:, slice_indices] if labels is not None else None, + temperature=temperature, + ) + + def init_buffers_post_meta(self): + self.model.embed_tokens.embed_scale.fill_(self.config.hidden_size**0.5) + rotary_emb = self.model.rotary_emb + for layer_type in rotary_emb.layer_types: + rope_params = self.config.rope_parameters[layer_type] + rope_type = rope_params["rope_type"] + if rope_type != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type] + else: + rope_init_fn = _compute_default_rope_parameters + + kwargs = {"device": getattr(rotary_emb, 
f"{layer_type}_inv_freq").device, "layer_type": layer_type} + if layer_type == "full_attention" and rope_type == "proportional": + kwargs["head_dim_key"] = "global_head_dim" + + inv_freq, attn_scaling = rope_init_fn(self.config, **kwargs) + getattr(rotary_emb, f"{layer_type}_inv_freq").copy_(inv_freq) + rotary_emb.attention_scaling[layer_type] = attn_scaling diff --git a/src/prime_rl/trainer/models/minimax_m2/configuration_minimax_m2.py b/src/prime_rl/trainer/models/minimax_m2/configuration_minimax_m2.py index a09d03d648..6fcceec309 100644 --- a/src/prime_rl/trainer/models/minimax_m2/configuration_minimax_m2.py +++ b/src/prime_rl/trainer/models/minimax_m2/configuration_minimax_m2.py @@ -1,5 +1,4 @@ from transformers.configuration_utils import PretrainedConfig -from transformers.modeling_rope_utils import rope_config_validation class MiniMaxM2Config(PretrainedConfig): @@ -124,7 +123,7 @@ def __init__( if self.rope_scaling is not None and "type" in self.rope_scaling: self.rope_scaling["rope_type"] = self.rope_scaling["type"] - rope_config_validation(self) + self.standardize_rope_params() # MoE arguments self.num_local_experts = num_local_experts diff --git a/src/prime_rl/trainer/perf.py b/src/prime_rl/trainer/perf.py index a4aeaf9b2c..b22a05fd72 100644 --- a/src/prime_rl/trainer/perf.py +++ b/src/prime_rl/trainer/perf.py @@ -135,16 +135,16 @@ def get_active_mm_params(config: PretrainedConfig) -> float: sparse_mlp_params = 0 # Some MoE models (e.g. DeepSeek) use moe_intermediate_size, others (e.g. 
Granite) just use intermediate_size - moe_intermediate_size = getattr(config, "moe_intermediate_size", intermediate_size) - if hasattr(config, "num_shared_experts"): # Shared experts + moe_intermediate_size = getattr(config, "moe_intermediate_size", None) or intermediate_size + if hasattr(config, "num_shared_experts") and config.num_shared_experts: # Shared experts sparse_mlp_params += num_sparse_layers * config.num_shared_experts * 3 * moe_intermediate_size * hidden_size - if hasattr(config, "num_experts_per_tok"): # Routed experts + if hasattr(config, "num_experts_per_tok") and config.num_experts_per_tok: # Routed experts sparse_mlp_params += ( num_sparse_layers * config.num_experts_per_tok * 3 * moe_intermediate_size * hidden_size ) if hasattr(config, "n_routed_experts"): # DeepSeek Router sparse_mlp_params += num_sparse_layers * config.n_routed_experts * hidden_size - elif hasattr(config, "num_experts"): # Qwen Router + elif hasattr(config, "num_experts") and config.num_experts is not None: # Qwen Router sparse_mlp_params += num_sparse_layers * config.num_experts * hidden_size else: sparse_mlp_params = 0 diff --git a/src/prime_rl/trainer/rl/train.py b/src/prime_rl/trainer/rl/train.py index 626ed151c7..c363f406b5 100644 --- a/src/prime_rl/trainer/rl/train.py +++ b/src/prime_rl/trainer/rl/train.py @@ -1,3 +1,5 @@ +import prime_rl._compat # noqa: F401 — patch ring_flash_attn compat before import + from contextlib import nullcontext import time from datetime import timedelta diff --git a/src/prime_rl/trainer/sft/data.py b/src/prime_rl/trainer/sft/data.py index 34ad868df2..80b7a823b3 100644 --- a/src/prime_rl/trainer/sft/data.py +++ b/src/prime_rl/trainer/sft/data.py @@ -124,6 +124,7 @@ def __init__( seq_len: int = 128, non_dp_size: int = 1, loss_mask_config: LossMaskConfig = LossMaskConfig(), + chat_template_kwargs: dict | None = None, max_examples: int | None = None, max_epochs: int | None = None, ): @@ -136,6 +137,7 @@ def __init__( self.seed = seed self.seq_len 
= seq_len self.loss_mask_config = loss_mask_config + self.chat_template_kwargs = chat_template_kwargs or {} self.max_examples = max_examples self.max_epochs = max_epochs @@ -205,12 +207,16 @@ def should_mask(message: dict) -> bool: case _: raise ValueError(f"Invalid message role: {message['role']}") + per_example_kwargs = example.get("chat_template_kwargs", {}) + if self.chat_template_kwargs: + per_example_kwargs = {**self.chat_template_kwargs, **per_example_kwargs} + input_ids, loss_mask = build_incremental_token_mask( self.tokenizer, messages, role_to_mask=should_mask, tools=tools, - chat_template_kwargs=example.get("chat_template_kwargs", {}), + chat_template_kwargs=per_example_kwargs, collapse_consecutive_tool_messages=True, ) @@ -547,6 +553,7 @@ def setup_dataset( seed=config.seed, seq_len=config.seq_len, loss_mask_config=config.loss_mask, + chat_template_kwargs=config.chat_template_kwargs, non_dp_size=non_dp_size, max_epochs=max_epochs, ) diff --git a/src/prime_rl/trainer/sft/train.py b/src/prime_rl/trainer/sft/train.py index afca134c9f..ce3a5c02a5 100644 --- a/src/prime_rl/trainer/sft/train.py +++ b/src/prime_rl/trainer/sft/train.py @@ -1,3 +1,5 @@ +import prime_rl._compat # noqa: F401 — patch ring_flash_attn compat before import + import time from contextlib import nullcontext from datetime import timedelta diff --git a/src/prime_rl/utils/vlm.py b/src/prime_rl/utils/vlm.py index 15b5f1bf3b..18040c1c0a 100644 --- a/src/prime_rl/utils/vlm.py +++ b/src/prime_rl/utils/vlm.py @@ -25,6 +25,7 @@ class VLMModelInfo: # Central registry: model_type -> architecture info. 
VLM_REGISTRY: dict[str, VLMModelInfo] = { + "gemma4": VLMModelInfo(vision_encoder_attr="model.vision_tower", language_model_attr="model.language_model"), "qwen3_vl": VLMModelInfo(vision_encoder_attr="model.visual", language_model_attr="model.language_model"), "qwen3_5": VLMModelInfo(vision_encoder_attr="model.visual", language_model_attr="model.language_model"), "qwen3_5_moe": VLMModelInfo(vision_encoder_attr="model.visual", language_model_attr="model.language_model"), diff --git a/tests/unit/train/models/afmoe_hf_modeling/modeling_afmoe.py b/tests/unit/train/models/afmoe_hf_modeling/modeling_afmoe.py index d7892dfda6..e5fd36c96e 100644 --- a/tests/unit/train/models/afmoe_hf_modeling/modeling_afmoe.py +++ b/tests/unit/train/models/afmoe_hf_modeling/modeling_afmoe.py @@ -525,7 +525,7 @@ def forward( if not isinstance(causal_mask_mapping := attention_mask, dict): mask_kwargs = { "config": self.config, - "input_embeds": inputs_embeds, + "inputs_embeds": inputs_embeds, "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, diff --git a/tests/unit/train/models/test_gemma4.py b/tests/unit/train/models/test_gemma4.py new file mode 100644 index 0000000000..8ace0ae83b --- /dev/null +++ b/tests/unit/train/models/test_gemma4.py @@ -0,0 +1,140 @@ +import pytest +import torch +from transformers.models.gemma4.configuration_gemma4 import Gemma4TextConfig + +from prime_rl.trainer.models.gemma4 import Gemma4ForCausalLM +from prime_rl.trainer.models.layers.lm_head import inject_prime_lm_head + +pytestmark = [pytest.mark.gpu] + + +def _tiny_config(**overrides): + defaults = dict( + vocab_size=256, + hidden_size=256, + num_hidden_layers=4, + num_attention_heads=4, + num_key_value_heads=2, + head_dim=64, + global_head_dim=128, + intermediate_size=512, + attention_k_eq_v=True, + num_global_key_value_heads=1, + sliding_window=128, + max_position_embeddings=512, + final_logit_softcapping=30.0, + hidden_activation="gelu_pytorch_tanh", + 
_attn_implementation="flash_attention_2", + ) + defaults.update(overrides) + return Gemma4TextConfig(**defaults) + + +def test_gemma4_forward(): + config = _tiny_config() + model = Gemma4ForCausalLM(config).to(device="cuda", dtype=torch.bfloat16) + inject_prime_lm_head(model) + + input_ids = torch.randint(0, 256, (1, 32), device="cuda") + position_ids = torch.arange(32, device="cuda").unsqueeze(0) + + with torch.no_grad(): + output = model(input_ids=input_ids, position_ids=position_ids) + assert "logits" in output + assert output["logits"].shape == (1, 32, 256) + + +def test_gemma4_backward(): + config = _tiny_config() + model = Gemma4ForCausalLM(config).to(device="cuda", dtype=torch.bfloat16) + inject_prime_lm_head(model) + + input_ids = torch.randint(0, 256, (1, 32), device="cuda") + position_ids = torch.arange(32, device="cuda").unsqueeze(0) + + output = model(input_ids=input_ids, position_ids=position_ids) + logits = output["logits"] + assert logits is not None + loss = logits.sum() + loss.backward() + + # Check gradients flow + has_grad = False + for name, param in model.named_parameters(): + if param.requires_grad and param.grad is not None: + has_grad = True + break + assert has_grad, "No gradients found" + + +def test_gemma4_no_kv_sharing(): + """Test without K=V sharing (like a hypothetical smaller model).""" + config = _tiny_config(attention_k_eq_v=False, num_global_key_value_heads=None) + model = Gemma4ForCausalLM(config).to(device="cuda", dtype=torch.bfloat16) + inject_prime_lm_head(model) + + input_ids = torch.randint(0, 256, (1, 32), device="cuda") + position_ids = torch.arange(32, device="cuda").unsqueeze(0) + + with torch.no_grad(): + output = model(input_ids=input_ids, position_ids=position_ids) + assert output["logits"].shape == (1, 32, 256) + + +def _tiny_moe_config(**overrides): + defaults = dict( + vocab_size=256, + hidden_size=128, + num_hidden_layers=4, + num_attention_heads=4, + num_key_value_heads=2, + head_dim=32, + global_head_dim=64, + 
intermediate_size=256, + attention_k_eq_v=True, + num_global_key_value_heads=1, + sliding_window=128, + max_position_embeddings=512, + final_logit_softcapping=30.0, + hidden_activation="gelu_pytorch_tanh", + _attn_implementation="flash_attention_2", + enable_moe_block=True, + num_experts=8, + top_k_experts=2, + moe_intermediate_size=64, + ) + defaults.update(overrides) + return Gemma4TextConfig(**defaults) + + +def test_gemma4_moe_forward(): + config = _tiny_moe_config() + model = Gemma4ForCausalLM(config).to(device="cuda", dtype=torch.bfloat16) + inject_prime_lm_head(model) + + input_ids = torch.randint(0, 256, (1, 32), device="cuda") + position_ids = torch.arange(32, device="cuda").unsqueeze(0) + + with torch.no_grad(): + output = model(input_ids=input_ids, position_ids=position_ids) + assert output["logits"].shape == (1, 32, 256) + + +def test_gemma4_moe_backward(): + config = _tiny_moe_config() + model = Gemma4ForCausalLM(config).to(device="cuda", dtype=torch.bfloat16) + inject_prime_lm_head(model) + + input_ids = torch.randint(0, 256, (1, 32), device="cuda") + position_ids = torch.arange(32, device="cuda").unsqueeze(0) + + output = model(input_ids=input_ids, position_ids=position_ids) + loss = output["logits"].sum() + loss.backward() + + has_grad = False + for name, param in model.named_parameters(): + if param.requires_grad and param.grad is not None: + has_grad = True + break + assert has_grad, "No gradients found" diff --git a/tests/unit/train/models/test_qwen3_5_moe_vlm.py b/tests/unit/train/models/test_qwen3_5_moe_vlm.py index e0901c60e3..c629f6be25 100644 --- a/tests/unit/train/models/test_qwen3_5_moe_vlm.py +++ b/tests/unit/train/models/test_qwen3_5_moe_vlm.py @@ -21,6 +21,7 @@ def _tiny_vlm_config(): tc.vocab_size = 256 tc.hidden_size = 256 tc.num_hidden_layers = 2 + tc.layer_types = ["linear_attention", "full_attention"] tc.num_attention_heads = 4 tc.num_key_value_heads = 2 tc.head_dim = 64 diff --git a/uv.lock b/uv.lock index 65744a9cf6..3956d73499 
100644 --- a/uv.lock +++ b/uv.lock @@ -14,7 +14,7 @@ supported-markers = [ overrides = [ { name = "nvidia-cudnn-cu12", specifier = ">=9.15" }, { name = "nvidia-cutlass-dsl", specifier = ">=4.4.1" }, - { name = "transformers", git = "https://github.com/huggingface/transformers.git?rev=5c1c72b" }, + { name = "transformers", git = "https://github.com/huggingface/transformers.git?rev=c1c3424" }, ] [[package]] @@ -475,7 +475,7 @@ wheels = [ [[package]] name = "compressed-tensors" -version = "0.13.0" +version = "0.14.0.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "loguru", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, @@ -483,9 +483,9 @@ dependencies = [ { name = "torch", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "transformers", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/fc/65/88dd1c58fb9d0ded51b5c86471b937a1525f91fad2211a6f051dc1ea822d/compressed_tensors-0.13.0.tar.gz", hash = "sha256:23893824d3498ea3f1a829f14a8fa85f9a5e76a34c711a038b8d7c619ca9a67c", size = 200995, upload-time = "2025-12-16T16:03:55.397Z" } +sdist = { url = "https://files.pythonhosted.org/packages/eb/f1/4c9b01ceaf82ad58ad00919223e09b8e74d4073a2ba8e3ab2f97521ef65c/compressed_tensors-0.14.0.1.tar.gz", hash = "sha256:5ad3841184b6f5020e06059b2463191c5c57a144bb97cab9159978d8118839b1", size = 226393, upload-time = "2026-03-11T17:04:35.57Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/0b/b5/61ac2563c62490922b603c09113a083fd74af3630ec3931e769484d6dcb5/compressed_tensors-0.13.0-py3-none-any.whl", hash = "sha256:3518799c9baf034eb642efb551db6b0537b8713d45a64fe4def26f7f8d6cabec", size = 192620, upload-time = 
"2025-12-16T16:03:53.041Z" }, + { url = "https://files.pythonhosted.org/packages/0a/26/16a13993ecf4fdc9c39d63b3a6daabafd32a452cf68b81aa9eb3b8170913/compressed_tensors-0.14.0.1-py3-none-any.whl", hash = "sha256:46c4940a3a779d3d97108c294bfcd9acf4bd0491f7c6737c320f0e815ec732e4", size = 196454, upload-time = "2026-03-11T17:04:33.2Z" }, ] [[package]] @@ -559,19 +559,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/af/f3/6b032a554019cfb3447e671798c1bd3e79b5f1af20d10253f56cea269ef2/cuda_python-12.9.4-py3-none-any.whl", hash = "sha256:d2cacea882a69863f1e7d27ee71d75f0684f4c76910aff839067e4f89c902279", size = 7594, upload-time = "2025-10-21T14:55:12.846Z" }, ] -[[package]] -name = "cupy-cuda12x" -version = "13.6.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "fastrlock", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "numpy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/12/c5/7e7fc4816d0de0154e5d9053242c3a08a0ca8b43ee656a6f7b3b95055a7b/cupy_cuda12x-13.6.0-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:a6970ceefe40f9acbede41d7fe17416bd277b1bd2093adcde457b23b578c5a59", size = 127334633, upload-time = "2025-08-18T08:24:43.065Z" }, - { url = "https://files.pythonhosted.org/packages/e0/95/d7e1295141e7d530674a3cc567e13ed0eb6b81524cb122d797ed996b5bea/cupy_cuda12x-13.6.0-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:79b0cacb5e8b190ef409f9e03f06ac8de1b021b0c0dda47674d446f5557e0eb1", size = 112886268, upload-time = "2025-08-18T08:24:49.294Z" }, -] - [[package]] name = "datasets" version = "4.0.0" @@ -848,18 +835,6 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/1e/5c/9516828037b1680de14bae4d93b2868b26417059e1e5cb7d8ad7be649197/fastcore-1.12.31-py3-none-any.whl", hash = "sha256:2634f117fe3a2f6c250f8500f70211c96158399936e90fa6d1010d2d516bf03a", size = 98509, upload-time = "2026-03-22T02:21:19.274Z" }, ] -[[package]] -name = "fastrlock" -version = "0.8.3" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/73/b1/1c3d635d955f2b4bf34d45abf8f35492e04dbd7804e94ce65d9f928ef3ec/fastrlock-0.8.3.tar.gz", hash = "sha256:4af6734d92eaa3ab4373e6c9a1dd0d5ad1304e172b1521733c6c3b3d73c8fa5d", size = 79327, upload-time = "2024-12-17T11:03:39.638Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/57/21/ea1511b0ef0d5457efca3bf1823effb9c5cad4fc9dca86ce08e4d65330ce/fastrlock-0.8.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:85a49a1f1e020097d087e1963e42cea6f307897d5ebe2cb6daf4af47ffdd3eed", size = 52201, upload-time = "2024-12-17T11:02:19.512Z" }, - { url = "https://files.pythonhosted.org/packages/80/07/cdecb7aa976f34328372f1c4efd6c9dc1b039b3cc8d3f38787d640009a25/fastrlock-0.8.3-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5f13ec08f1adb1aa916c384b05ecb7dbebb8df9ea81abd045f60941c6283a670", size = 53924, upload-time = "2024-12-17T11:02:20.85Z" }, - { url = "https://files.pythonhosted.org/packages/88/6d/59c497f8db9a125066dd3a7442fab6aecbe90d6fec344c54645eaf311666/fastrlock-0.8.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:0ea4e53a04980d646def0f5e4b5e8bd8c7884288464acab0b37ca0c65c482bfe", size = 52140, upload-time = "2024-12-17T11:02:22.263Z" }, - { url = "https://files.pythonhosted.org/packages/62/04/9138943c2ee803d62a48a3c17b69de2f6fa27677a6896c300369e839a550/fastrlock-0.8.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:38340f6635bd4ee2a4fb02a3a725759fe921f2ca846cb9ca44531ba739cc17b4", size = 53261, upload-time = 
"2024-12-17T11:02:24.418Z" }, -] - [[package]] name = "filelock" version = "3.19.1" @@ -932,9 +907,17 @@ dependencies = [ { name = "transformers", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] +[[package]] +name = "flashinfer-cubin" +version = "0.6.6" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/12/e8/826f9452bc5f76b94d7eb025f03dcaf1b51b9ed7790386c0285191e69be4/flashinfer_cubin-0.6.6-py3-none-any.whl", hash = "sha256:36508dfc792eb5ecfb15d2c140a7702812e1fa1ab0fb03929b2ed55e3e8191f3", size = 267661457, upload-time = "2026-03-11T01:36:36.538Z" }, +] + [[package]] name = "flashinfer-python" -version = "0.6.4" +version = "0.6.6" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "apache-tvm-ffi", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, @@ -951,9 +934,9 @@ dependencies = [ { name = "torch", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "tqdm", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/77/45/15645d2a4ee81d08206f3e132a77323e48312f510462415d7cd1122eba43/flashinfer_python-0.6.4.tar.gz", hash = "sha256:e6ab798bd1030e5ff7a3bc6952f36386c406928f60b79cf964a6db7aa7ccde75", size = 5337134, upload-time = "2026-02-19T07:33:36.647Z" } +sdist = { url = "https://files.pythonhosted.org/packages/03/70/c5a235297351021f5d3d3233523a85f5a6468495587489ad2f257e8eafe2/flashinfer_python-0.6.6.tar.gz", hash = "sha256:0730ba7c7aad332961933bcebc5119762797161ede57d955f6fd199818ed1d92", size = 5344156, upload-time = "2026-03-11T01:36:21.434Z" } wheels = [ - { url = 
"https://files.pythonhosted.org/packages/17/9a/d2bab76d2bb15062c6a2329614653e4f8bec9c78eec9069856ef0c7c0a79/flashinfer_python-0.6.4-py3-none-any.whl", hash = "sha256:105596b505892ae330af84e250ee0eb6fc2c3a22e8dc42bd46de1b90d36004c8", size = 7819999, upload-time = "2026-02-19T07:33:34.82Z" }, + { url = "https://files.pythonhosted.org/packages/e0/61/385d06755f3ab66333018285657adf0daf8a90a129448231fd09e315bd2e/flashinfer_python-0.6.6-py3-none-any.whl", hash = "sha256:078f158636969eec1a0d3dea19c3ca90b426b66df89bbf7b7b8276ce2ec08148", size = 7817047, upload-time = "2026-03-11T01:36:19.198Z" }, ] [[package]] @@ -1089,19 +1072,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/df/81/dcf11b5915a39024d4a98ef14b16cb0c636a4f2f26ef657982d3144c6544/grpcio-1.78.0rc2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:ea66e360e5ea032a1f6dde926915ba683edca811ed6f0e0620c52e264ba364e4", size = 7698266, upload-time = "2026-01-16T07:28:51.342Z" }, ] -[[package]] -name = "grpcio-reflection" -version = "1.78.0rc2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "grpcio", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "protobuf", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/f7/4f/7c5032a6365cf964ef3e1fea7df76294d2804e6414802e36000d57f0eeb3/grpcio_reflection-1.78.0rc2.tar.gz", hash = "sha256:d925a43ef37f93a2129575111bf4add284d344eecc7e610b65f3b61a59c671b7", size = 19111, upload-time = "2026-01-16T08:01:53.077Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/61/c3/a1082fb098bb23c19c5df32b78c83bd000289592e079fa2daf7dddba7a43/grpcio_reflection-1.78.0rc2-py3-none-any.whl", hash = "sha256:c4a5f8e9681a83c0ddc80b8648d74144bf6c7bad1e7e2598c77754a8e7998d61", size = 22838, upload-time = 
"2026-01-16T08:01:41.466Z" }, -] - [[package]] name = "grpclib" version = "0.4.9" @@ -1139,14 +1109,14 @@ wheels = [ [[package]] name = "hf-xet" -version = "1.3.0b0" +version = "1.4.3" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/41/4b/59c9a123813f1db5441f037d9a0e9171bd480c4ff3a9562976a8bf8e49ad/hf_xet-1.3.0b0.tar.gz", hash = "sha256:ece497f54c80992e1b145a89065443f6acf9a6b51d8e4648e53e3ad650fbec06", size = 615265, upload-time = "2026-01-28T20:37:21.892Z" } +sdist = { url = "https://files.pythonhosted.org/packages/53/92/ec9ad04d0b5728dca387a45af7bc98fbb0d73b2118759f5f6038b61a57e8/hf_xet-1.4.3.tar.gz", hash = "sha256:8ddedb73c8c08928c793df2f3401ec26f95be7f7e516a7bee2fbb546f6676113", size = 670477, upload-time = "2026-03-31T22:40:07.874Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/7e/34/a16aa436c3e59007678cee07f5cf3929ba053b14ae16dffd3be1270d3927/hf_xet-1.3.0b0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fa63330e14196071fafc0e369a8e9d3f847335f10d33ca152537fb47bf263440", size = 58044866, upload-time = "2026-01-28T20:36:31.13Z" }, - { url = "https://files.pythonhosted.org/packages/d0/74/2202cc67e82a6eb64e42314e92ff2ee798e6dd5ee394967880b1370e878e/hf_xet-1.3.0b0-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:1f8a48df4e67ab695ae802f0d4d07c3d28fed64ea12decef13f8a8550783a42d", size = 53103717, upload-time = "2026-01-28T20:36:26.633Z" }, - { url = "https://files.pythonhosted.org/packages/8d/eb/9cbf85387377adaef317918318d1921b456625fa2535f39e642ed77076e4/hf_xet-1.3.0b0-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:ae20bc5405c06538ba820e6a3f818df793fee554f83cf071caa641d0b36f08f8", size = 53485235, upload-time = "2026-01-28T20:37:05.554Z" }, - { url = "https://files.pythonhosted.org/packages/0d/28/302fae85503e423e356042a3332e3b2b714b30ce27db2fe415260973bf0e/hf_xet-1.3.0b0-cp37-abi3-musllinux_1_2_x86_64.whl", hash = 
"sha256:a566da3478ae73ccd6bca8cb8d1ef85bcd4c36e79912cbfafb5b33890a0f1301", size = 55093706, upload-time = "2026-01-28T20:37:09.561Z" }, + { url = "https://files.pythonhosted.org/packages/df/9a/a24b26dc8a65f0ecc0fe5be981a19e61e7ca963b85e062c083f3a9100529/hf_xet-1.4.3-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fc360b70c815bf340ed56c7b8c63aacf11762a4b099b2fe2c9bd6d6068668c08", size = 4212320, upload-time = "2026-03-31T22:39:42.922Z" }, + { url = "https://files.pythonhosted.org/packages/53/60/46d493db155d2ee2801b71fb1b0fd67696359047fdd8caee2c914cc50c79/hf_xet-1.4.3-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:39f2d2e9654cd9b4319885733993807aab6de9dfbd34c42f0b78338d6617421f", size = 3991546, upload-time = "2026-03-31T22:39:41.335Z" }, + { url = "https://files.pythonhosted.org/packages/bc/f5/067363e1c96c6b17256910830d1b54099d06287e10f4ec6ec4e7e08371fc/hf_xet-1.4.3-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:49ad8a8cead2b56051aa84d7fce3e1335efe68df3cf6c058f22a65513885baac", size = 4193200, upload-time = "2026-03-31T22:40:01.936Z" }, + { url = "https://files.pythonhosted.org/packages/42/4b/53951592882d9c23080c7644542fda34a3813104e9e11fa1a7d82d419cb8/hf_xet-1.4.3-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:7716d62015477a70ea272d2d68cd7cad140f61c52ee452e133e139abfe2c17ba", size = 4429392, upload-time = "2026-03-31T22:40:03.492Z" }, ] [[package]] @@ -1209,7 +1179,7 @@ wheels = [ [[package]] name = "huggingface-hub" -version = "1.4.1" +version = "1.9.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "filelock", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, @@ -1218,14 +1188,13 @@ dependencies = [ { name = "httpx", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "packaging", marker = "(platform_machine == 
'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "pyyaml", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "shellingham", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "tqdm", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "typer-slim", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "typer", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "typing-extensions", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c4/fc/eb9bc06130e8bbda6a616e1b80a7aa127681c448d6b49806f61db2670b61/huggingface_hub-1.4.1.tar.gz", hash = "sha256:b41131ec35e631e7383ab26d6146b8d8972abc8b6309b963b306fbcca87f5ed5", size = 642156, upload-time = "2026-02-06T09:20:03.013Z" } +sdist = { url = "https://files.pythonhosted.org/packages/88/bb/62c7aa86f63a05e2f9b96642fdef9b94526a23979820b09f5455deff4983/huggingface_hub-1.9.0.tar.gz", hash = "sha256:0ea5be7a56135c91797cae6ad726e38eaeb6eb4b77cefff5c9d38ba0ecf874f7", size = 750326, upload-time = "2026-04-03T08:35:55.888Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/d5/ae/2f6d96b4e6c5478d87d606a1934b5d436c4a2bce6bb7c6fdece891c128e3/huggingface_hub-1.4.1-py3-none-any.whl", hash = "sha256:9931d075fb7a79af5abc487106414ec5fba2c0ae86104c0c62fd6cae38873d18", size = 553326, upload-time = "2026-02-06T09:20:00.728Z" }, + { url = 
"https://files.pythonhosted.org/packages/73/37/0d15d16150e1829f3e90962c99f28257f6de9e526a680b4c6f5acdb54fd2/huggingface_hub-1.9.0-py3-none-any.whl", hash = "sha256:2999328c058d39fd19ab748dd09bd4da2fbaa4f4c1ddea823eab103051e14a1f", size = 637355, upload-time = "2026-04-03T08:35:53.897Z" }, ] [[package]] @@ -1499,16 +1468,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/43/6a/ca128561b22b60bd5a0c4ea26649e68c8556b82bc70a0c396eebc977fe86/jupyterlab_widgets-3.0.15-py3-none-any.whl", hash = "sha256:d59023d7d7ef71400d51e6fee9a88867f6e65e10a4201605d2d7f3e8f012a31c", size = 216571, upload-time = "2025-05-05T12:32:29.534Z" }, ] -[[package]] -name = "kaldi-native-fbank" -version = "1.22.3" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/3a/2c/84076b352107ce12d56f28c313f1aca1be332d953dd96aec7b84976e6d53/kaldi-native-fbank-1.22.3.tar.gz", hash = "sha256:387bf87225c6b83c93ae652eeaef1b4d531994b6e398e7a77189de340674f9af", size = 71013, upload-time = "2025-10-09T02:31:21.487Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/43/28/6f4fd8953c0b3f30de4526fd024095032abcdc25b6736c77a891687c604e/kaldi_native_fbank-1.22.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:f5a44b4a83cf9bf13d3f77858928068b06d3ec2238c27ff2e39393fbf7749c9f", size = 298887, upload-time = "2025-10-09T02:30:53.739Z" }, - { url = "https://files.pythonhosted.org/packages/84/90/01ef7331c52b1eaf9916f3f7a535155aac2e9e2ddad12a141613d92758c7/kaldi_native_fbank-1.22.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f16e74372fe9e20abb4183f98a8e2288d5ee4c48d04d94b6160311170e007661", size = 322002, upload-time = "2025-10-09T02:30:13.04Z" }, -] - [[package]] name = "kubernetes" version = "35.0.0" @@ -1813,7 +1772,7 @@ wheels = [ [[package]] name = "mistral-common" -version = "1.9.1" +version = "1.11.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = 
"jsonschema", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, @@ -1825,9 +1784,9 @@ dependencies = [ { name = "tiktoken", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "typing-extensions", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/db/ce/685b8127a326478e05501cb4c9ca23d1cd9f37e16c465a1e832c75aea709/mistral_common-1.9.1.tar.gz", hash = "sha256:550583d70a395c3586cfb748ffab53bd1d7c3409507f0efc0118bff30ffb26e9", size = 6338922, upload-time = "2026-02-12T10:53:41.639Z" } +sdist = { url = "https://files.pythonhosted.org/packages/61/97/753c85b5c0a19f4331ac99e0300ac8da06d4b29b629c9cb03064b38561bd/mistral_common-1.11.0.tar.gz", hash = "sha256:439b7fa38f9c3f020154af51bdf30eb81def507643017d8ce9f798384ec47ec3", size = 6355512, upload-time = "2026-04-01T13:54:12.36Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ac/72/a38bb1fd9fd4d4ef990341c9dd1a7c8061f1951e10efa6d50c0a3f04eced/mistral_common-1.9.1-py3-none-any.whl", hash = "sha256:9e2b2520b6f67bac2e2bb06fcf985b7a1277b01938da2b7cda8cf0fdbfa92e91", size = 6518623, upload-time = "2026-02-12T10:53:39.457Z" }, + { url = "https://files.pythonhosted.org/packages/60/e4/73ad3c27e3fb613c3ce0953c928202c46cddebac3989b87be1b6f305a9f6/mistral_common-1.11.0-py3-none-any.whl", hash = "sha256:1d3ecaf7c3aa7338cb37b596fd0fb294485753958ee8e7254a6cc23eb30b249b", size = 6531513, upload-time = "2026-04-01T13:54:16.536Z" }, ] [package.optional-dependencies] @@ -2813,10 +2772,10 @@ requires-dist = [ { name = "torch", specifier = ">=2.9.0", index = "https://download.pytorch.org/whl/cu128" }, { name = "torchdata", specifier = ">=0.11.0" }, { name = "torchtitan", git = 
"https://github.com/pytorch/torchtitan?rev=a1fdd7e" }, - { name = "transformers", git = "https://github.com/huggingface/transformers.git?rev=5c1c72b" }, + { name = "transformers", git = "https://github.com/huggingface/transformers.git?rev=c1c3424" }, { name = "uvloop", specifier = ">=0.21.0" }, { name = "verifiers", git = "https://github.com/PrimeIntellect-ai/verifiers.git?rev=d3c830c" }, - { name = "vllm", specifier = ">=0.17.0" }, + { name = "vllm", specifier = ">=0.19.0" }, { name = "vllm-router", marker = "platform_machine == 'x86_64' and extra == 'disagg'", url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.14/vllm_router-0.1.14-cp38-abi3-linux_x86_64.whl" }, { name = "wandb", specifier = ">=0.24.2" }, { name = "wiki-search", marker = "extra == 'envs'", index = "https://hub.primeintellect.ai/primeintellect/simple/" }, @@ -3237,30 +3196,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8e/28/e688da97806474963b2983a8231572733de11ed48f8899410c0afc8bdf15/quack_kernels-0.3.6-py3-none-any.whl", hash = "sha256:973d974ccca816014006af8f136294ba0f15b1c324197f8b54ffc6500666e0e5", size = 198865, upload-time = "2026-03-24T13:29:59.107Z" }, ] -[[package]] -name = "ray" -version = "2.49.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "click", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "filelock", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "jsonschema", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "msgpack", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "packaging", marker = "(platform_machine == 
'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "protobuf", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "pyyaml", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "requests", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/c6/ba/77eae921fc2595087516df6efc0ca03caacc14d16592341985916f1aed13/ray-2.49.1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:cf63b916e399c2a4d484249611a96cee283cef32e544126115a4ad3e872c34eb", size = 69287149, upload-time = "2025-09-03T00:25:47.605Z" }, - { url = "https://files.pythonhosted.org/packages/00/02/c81260c0f94bd34a1442ea488bdd433dfc9e6ed6211c9a59bc4157b8e00e/ray-2.49.1-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:484064fca02732e0b6f4a09dad0d1fb6abd4ca4b6d9bf7c26ab7a17a2887cd09", size = 70114899, upload-time = "2025-09-03T00:25:53.144Z" }, -] - -[package.optional-dependencies] -cgraph = [ - { name = "cupy-cuda12x", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, -] - [[package]] name = "referencing" version = "0.36.2" @@ -3277,14 +3212,14 @@ wheels = [ [[package]] name = "regex" -version = "2025.9.1" +version = "2026.3.32" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/b2/5a/4c63457fbcaf19d138d72b2e9b39405954f98c0349b31c601bfcb151582c/regex-2025.9.1.tar.gz", hash = "sha256:88ac07b38d20b54d79e704e38aa3bd2c0f8027432164226bdee201a1c0c9c9ff", size = 400852, upload-time = "2025-09-01T22:10:10.479Z" } +sdist = { url = 
"https://files.pythonhosted.org/packages/81/93/5ab3e899c47fa7994e524447135a71cd121685a35c8fe35029005f8b236f/regex-2026.3.32.tar.gz", hash = "sha256:f1574566457161678297a116fa5d1556c5a4159d64c5ff7c760e7c564bf66f16", size = 415605, upload-time = "2026-03-28T21:49:22.012Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/0f/74/f933a607a538f785da5021acf5323961b4620972e2c2f1f39b6af4b71db7/regex-2025.9.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e9dc5991592933a4192c166eeb67b29d9234f9c86344481173d1bc52f73a7104", size = 797441, upload-time = "2025-09-01T22:08:39.108Z" }, - { url = "https://files.pythonhosted.org/packages/b2/02/5c891bb5fe0691cc1bad336e3a94b9097fbcf9707ec8ddc1dce9f0397289/regex-2025.9.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:47829ffaf652f30d579534da9085fe30c171fa2a6744a93d52ef7195dc38218b", size = 801991, upload-time = "2025-09-01T22:08:44.072Z" }, - { url = "https://files.pythonhosted.org/packages/f1/ae/fd10d6ad179910f7a1b3e0a7fde1ef8bb65e738e8ac4fd6ecff3f52252e4/regex-2025.9.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:1e978e5a35b293ea43f140c92a3269b6ab13fe0a2bf8a881f7ac740f5a6ade85", size = 786651, upload-time = "2025-09-01T22:08:46.079Z" }, - { url = "https://files.pythonhosted.org/packages/93/fa/b4c6dbdedc85ef4caec54c817cd5f4418dbfa2453214119f2538082bf666/regex-2025.9.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:656563e620de6908cd1c9d4f7b9e0777e3341ca7db9d4383bcaa44709c90281e", size = 788138, upload-time = "2025-09-01T22:08:51.933Z" }, + { url = "https://files.pythonhosted.org/packages/4a/cf/1955bb5567bc491bd63068e17f75ab0c9ff5e9d08466beec7e347f5e768d/regex-2026.3.32-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:66a5083c3ffe5a5a95f8281ea47a88072d4f24001d562d1d9d28d4cdc005fec5", size = 796431, upload-time = "2026-03-28T21:46:33.101Z" }, + { url = 
"https://files.pythonhosted.org/packages/0a/fe/661043d1c263b0d9d10c6ff4e9c9745f3df9641c62b51f96a3473638e7ce/regex-2026.3.32-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f54840bea73541652f1170dc63402a5b776fc851ad36a842da9e5163c1f504a0", size = 801512, upload-time = "2026-03-28T21:46:38.587Z" }, + { url = "https://files.pythonhosted.org/packages/b6/c8/d833397b70cd1bacfcdc0a611f0e2c1f5b91fee8eedd88affcee770cbbb6/regex-2026.3.32-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:66d3126afe7eac41759cd5f0b3b246598086e88e70527c0d68c9e615b81771c4", size = 785837, upload-time = "2026-03-28T21:46:42.926Z" }, + { url = "https://files.pythonhosted.org/packages/18/f4/04ed04ebf335a44083695c22772be6a42efa31900415555563acf02cb4de/regex-2026.3.32-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b56993a7aeb4140c4770f4f7965c9e5af4f024457d06e23c01b0d47501cb18ed", size = 788332, upload-time = "2026-03-28T21:46:50.454Z" }, ] [[package]] @@ -3888,8 +3823,8 @@ wheels = [ [[package]] name = "transformers" -version = "5.3.0.dev0" -source = { git = "https://github.com/huggingface/transformers.git?rev=5c1c72b#5c1c72be5f864d10d0efe8ece0768d9ed6ee4fdd" } +version = "5.5.0" +source = { git = "https://github.com/huggingface/transformers.git?rev=c1c3424#c1c34249fa27deefbd4a377dfbf883a39baf5c6d" } dependencies = [ { name = "huggingface-hub", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "numpy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, @@ -3938,18 +3873,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7a/ed/d6fca788b51d0d4640c4bc82d0e85bad4b49809bca36bf4af01b4dcb66a7/typer-0.23.0-py3-none-any.whl", hash = "sha256:79f4bc262b6c37872091072a3cb7cb6d7d79ee98c0c658b4364bdcde3c42c913", size = 56668, upload-time = 
"2026-02-11T15:22:21.075Z" }, ] -[[package]] -name = "typer-slim" -version = "0.23.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "typer", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/1f/8a/881cfd399a119db89619dc1b93d36e2fb6720ddb112bceff41203f1abd72/typer_slim-0.23.0.tar.gz", hash = "sha256:be8b60243df27cfee444c6db1b10a85f4f3e54d940574f31a996f78aa35a8254", size = 4773, upload-time = "2026-02-11T15:22:19.106Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/07/3e/ba3a222c80ee070d9497ece3e1fe77253c142925dd4c90f04278aac0a9eb/typer_slim-0.23.0-py3-none-any.whl", hash = "sha256:1d693daf22d998a7b1edab8413cdcb8af07254154ce3956c1664dc11b01e2f8b", size = 3399, upload-time = "2026-02-11T15:22:17.792Z" }, -] - [[package]] name = "types-certifi" version = "2021.10.8.3" @@ -4131,7 +4054,7 @@ wheels = [ [[package]] name = "vllm" -version = "0.17.0" +version = "0.19.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiohttp", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, @@ -4146,12 +4069,10 @@ dependencies = [ { name = "einops", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "fastapi", extra = ["standard"], marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "filelock", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "flashinfer-cubin", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { 
name = "flashinfer-python", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "gguf", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "grpcio", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "grpcio-reflection", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "ijson", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "kaldi-native-fbank", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "lark", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "llguidance", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "lm-format-enforcer", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, @@ -4162,6 +4083,7 @@ dependencies = [ { name = "ninja", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "numba", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "numpy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "nvidia-cudnn-frontend", marker = 
"(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "nvidia-cutlass-dsl", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "openai", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "openai-harmony", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, @@ -4184,7 +4106,6 @@ dependencies = [ { name = "pyyaml", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "pyzmq", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "quack-kernels", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "ray", extra = ["cgraph"], marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "regex", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "requests", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "sentencepiece", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, @@ -4202,10 +4123,10 @@ dependencies = [ { name = "watchfiles", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "xgrammar", marker = 
"(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/13/d5/af83a4262ca4d5692a93b3c322ae954e3e6c4e23f8f9db3ab87bd79c919e/vllm-0.17.0.tar.gz", hash = "sha256:b0b62e58ef4eb633ef371f2726976372cf6dfcb7ff2ea9ddf7194c1930d5629a", size = 30541311, upload-time = "2026-03-07T03:54:54.333Z" } +sdist = { url = "https://files.pythonhosted.org/packages/03/14/c330a72309051f762b357a2e41d5015bedbb106ad1e16a231bdfda2e2163/vllm-0.19.0.tar.gz", hash = "sha256:81e59cf87175e7a62eb8d9acf5989484bbd17089d5eface353f89067bda282d9", size = 31071745, upload-time = "2026-04-03T04:04:52.833Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/f2/72/78a48668f2631def18bbaaa331d7878bcfc5c3137455422aafb0748e1261/vllm-0.17.0-cp38-abi3-manylinux_2_31_aarch64.whl", hash = "sha256:310fb82fe061ed75dceeb4aeb803cd8ee0d590337ec720f7abfb03a69314d710", size = 385329399, upload-time = "2026-03-07T03:54:34.261Z" }, - { url = "https://files.pythonhosted.org/packages/25/4f/972726f9a501f01203b5c4796e1932abbe435fae6d7715a4c3f1aad14a58/vllm-0.17.0-cp38-abi3-manylinux_2_31_x86_64.whl", hash = "sha256:0296670a09d392ee43455d9bebf590d05a9bc2ebce5e25e2919222fc815158da", size = 432927988, upload-time = "2026-03-07T03:54:02.312Z" }, + { url = "https://files.pythonhosted.org/packages/c8/51/467e7a8cb4838022daa731b7f8b34c228691e36f938e1803c3a702c7bd69/vllm-0.19.0-cp38-abi3-manylinux_2_31_aarch64.whl", hash = "sha256:6ab90ccca5d7ca3bd2c8f90133f0fac85e8f4af582a1c67c6cc3f63c615521e3", size = 384650557, upload-time = "2026-04-03T04:05:52.513Z" }, + { url = "https://files.pythonhosted.org/packages/b7/08/6a431731e4c163bc1fab85b63e269d84104aad0fba98dac1af34fdc5077f/vllm-0.19.0-cp38-abi3-manylinux_2_31_x86_64.whl", hash = "sha256:2d0e5fae45367bdbf111fcad68f4c0f8fdddd2f2fb643e52f0f2daebef7b41cf", size = 432281473, upload-time = "2026-04-03T04:05:22.07Z" }, ] [[package]] @@ -4360,7 +4281,7 
@@ wheels = [ [[package]] name = "xgrammar" -version = "0.1.29" +version = "0.1.33" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "numpy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, @@ -4370,10 +4291,10 @@ dependencies = [ { name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "typing-extensions", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/02/a3/70dbe3ffd331a1e7e1ad5a95690a4086e6c7cdb8089f5c7eda712219ccec/xgrammar-0.1.29.tar.gz", hash = "sha256:cf195afa81b489eebf35d4c6f37f27136d05420739ab4a6f7f065c938d7e4baa", size = 2321317, upload-time = "2025-12-19T08:23:54.53Z" } +sdist = { url = "https://files.pythonhosted.org/packages/db/43/e5dfddb1d2a4fccf3e3a88f103e88698cdefc3182f4e169a359ffe1c1794/xgrammar-0.1.33.tar.gz", hash = "sha256:8dbe5fc3d76651ab1fac7a68fc2a118b885fa0ec7189927fb6e0dce0081aea99", size = 2398956, upload-time = "2026-03-27T10:16:36.582Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/57/94/18793c64bf0368075a34c06e196bf002f1e6ab0aee332268f44e8d356d5a/xgrammar-0.1.29-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6eb370a16b27a683e5f2b9e429ab41440c69977d4a504849ed61831b94cc704c", size = 34705239, upload-time = "2025-12-19T08:23:28.369Z" }, - { url = "https://files.pythonhosted.org/packages/3e/da/4c14e3e00be698009b52700f15326a23272b4b00475939b6acc86b151188/xgrammar-0.1.29-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:79e6e4f5cd33be77418cf91efc482f2b3d773d309891224383bc8a4948ad7b07", size = 34906135, upload-time = "2025-12-19T08:23:30.838Z" }, + { url = 
"https://files.pythonhosted.org/packages/4e/04/43d4baca876f5ae1b45897ec30a59801a2da37f16da1fcd85f9555e4c125/xgrammar-0.1.33-cp312-cp312-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9c803e60d791854c5d1f271ece7e1f34d73c82dd4a8b2a06b7af5331482a78ac", size = 42133168, upload-time = "2026-03-27T10:15:16.994Z" }, + { url = "https://files.pythonhosted.org/packages/f0/a8/672833a3cff027253793aa999401d8364896ebf396967e475c7a878b895f/xgrammar-0.1.33-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:52b8eaa533282a0efb0835db6998ae72e7b3c7875d7a52e360ffebff9b78c30a", size = 42205803, upload-time = "2026-03-27T10:15:21.599Z" }, ] [[package]]