Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions configs/gemma4/rl_dense.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Gemma4 31B dense RL test (reverse-text)
output_dir = "outputs/gemma4-dense-rl"
max_steps = 5
seq_len = 2048

[slurm]
job_name = "gemma4-dense-rl"

[deployment]
type = "single_node"
num_train_gpus = 4
num_infer_gpus = 4

[model]
name = "google/gemma-4-31B-it"

[wandb]
project = "gemma4-test"
name = "rl-dense"

[orchestrator]
batch_size = 128
rollouts_per_example = 16

[orchestrator.sampling]
max_tokens = 128

[[orchestrator.env]]
id = "reverse-text"

[trainer.model]
attn = "flash_attention_2"

[trainer.model.ac]

[trainer.optim]
lr = 3e-6

[inference]
43 changes: 43 additions & 0 deletions configs/gemma4/rl_moe.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Gemma4 26B-A4B MoE RL test (reverse-text)
output_dir = "outputs/gemma4-moe-rl"
max_steps = 5
seq_len = 2048

[slurm]
job_name = "gemma4-moe-rl"

[deployment]
type = "single_node"
num_train_gpus = 4
num_infer_gpus = 4

[model]
name = "google/gemma-4-26B-A4B-it"

[wandb]
project = "gemma4-test"
name = "rl-moe"

[orchestrator]
batch_size = 128
rollouts_per_example = 16

[orchestrator.sampling]
max_tokens = 128

[[orchestrator.env]]
id = "reverse-text"

[trainer.model]
attn = "flash_attention_2"

[trainer.model.ac]

[trainer.optim]
lr = 3e-6

[inference]
gpu_memory_utilization = 0.7

[inference.model]
max_model_len = 2048
29 changes: 29 additions & 0 deletions configs/gemma4/sft_dense.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Gemma4 31B dense SFT test
# Usage: uv run sft @ configs/gemma4/sft_dense.toml
output_dir = "outputs/gemma4-dense-sft"
max_steps = 10

[slurm]
job_name = "gemma4-dense-sft"

[deployment]
type = "single_node"
num_gpus = 8

[model]
name = "google/gemma-4-31B-it"
attn = "flash_attention_2"

[model.ac]

[wandb]
project = "gemma4-test"
name = "sft-dense"

[data]
name = "PrimeIntellect/Reverse-Text-SFT"
batch_size = 8
seq_len = 2048

[data.chat_template_kwargs]
enable_thinking = true
29 changes: 29 additions & 0 deletions configs/gemma4/sft_moe.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Gemma4 26B-A4B MoE SFT test
# Usage: uv run sft @ configs/gemma4/sft_moe.toml
output_dir = "outputs/gemma4-moe-sft"
max_steps = 10

[slurm]
job_name = "gemma4-moe-sft"

[deployment]
type = "single_node"
num_gpus = 8

[model]
name = "google/gemma-4-26B-A4B-it"
attn = "flash_attention_2"

[model.ac]

[wandb]
project = "gemma4-test"
name = "sft-moe"

[data]
name = "PrimeIntellect/Reverse-Text-SFT"
batch_size = 8
seq_len = 2048

[data.chat_template_kwargs]
enable_thinking = true
18 changes: 18 additions & 0 deletions configs/gemma4/sft_qwen3.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Qwen3 0.6B SFT sanity check
output_dir = "outputs/qwen3-sft"
max_steps = 5

[slurm]
job_name = "qwen3-sft"

[model]
name = "PrimeIntellect/Qwen3-0.6B"

[wandb]
project = "gemma4-test"
name = "sft-qwen3"

[data]
name = "PrimeIntellect/Reverse-Text-SFT"
batch_size = 4
seq_len = 1024
19 changes: 19 additions & 0 deletions configs/gemma4/sft_qwen35.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Qwen3.5 MoE text-only SFT test (reproduces the VLM bug via a text-only run)
# Usage: uv run sft @ configs/gemma4/sft_qwen35.toml
output_dir = "outputs/qwen35-moe-sft"
max_steps = 5

[slurm]
job_name = "qwen35-moe-sft"

[model]
name = "Qwen/Qwen3.5-35B-A3B"

[wandb]
project = "gemma4-test"
name = "sft-qwen35"

[data]
name = "PrimeIntellect/Reverse-Text-SFT"
batch_size = 4
seq_len = 2048
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ dependencies = [
"torch>=2.9.0",
"torchdata>=0.11.0",
"transformers",
"vllm>=0.17.0",
"vllm>=0.19.0",
"wandb>=0.24.2",
"ring-flash-attn>=0.1.8",
"prime>=0.5.37",
Expand Down Expand Up @@ -123,7 +123,7 @@ torch = { index = "pytorch-cu128" }
verifiers = { git = "https://github.com/PrimeIntellect-ai/verifiers.git", rev = "d3c830c" }
torchtitan = { git = "https://github.com/pytorch/torchtitan", rev = "a1fdd7e" }
dion = { git = "https://github.com/samsja/dion.git", rev = "d891eeb" }
transformers = { git = "https://github.com/huggingface/transformers.git", rev = "5c1c72b" }
transformers = { git = "https://github.com/huggingface/transformers.git", rev = "c1c3424" }
flash-attn-4 = { git = "https://github.com/Dao-AILab/flash-attention.git", subdirectory = "flash_attn/cute", rev = "abd9943b" }
pydantic-config = { git = "https://github.com/samsja/pydantic_config.git", branch = "main" }
vllm-router = { url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.14/vllm_router-0.1.14-cp38-abi3-linux_x86_64.whl" }
Expand Down
1 change: 1 addition & 0 deletions src/prime_rl/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
import prime_rl._compat # noqa: F401 — must run before ring_flash_attn is imported
19 changes: 19 additions & 0 deletions src/prime_rl/_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""Compatibility shim: ring_flash_attn + transformers >= 5.4.

ring_flash_attn 0.1.8 imports `is_flash_attn_greater_or_equal_2_10` from
`transformers.modeling_flash_attention_utils`. This symbol was removed from
that module in transformers 5.4 (still available as a deprecated function
in `transformers.utils.import_utils`, scheduled for removal in 5.8).

ring_flash_attn's except-branch is a no-op (imports the same symbol again),
so the import crashes on transformers >= 5.4. We patch the symbol back in as
`True` — the check is dead code since no one uses flash_attn < 2.1.0 anymore.

Upstream fix: https://github.com/zhuzilin/ring-flash-attention/pull/85
Remove this shim once ring_flash_attn ships a fixed version.
"""

import transformers.modeling_flash_attention_utils as _mfau

if not hasattr(_mfau, "is_flash_attn_greater_or_equal_2_10"):
_mfau.is_flash_attn_greater_or_equal_2_10 = True
10 changes: 9 additions & 1 deletion src/prime_rl/configs/sft.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Annotated, Literal, TypeAlias
from typing import Annotated, Any, Literal, TypeAlias

from pydantic import BaseModel, ConfigDict, Field, model_validator

Expand Down Expand Up @@ -86,6 +86,14 @@ class SFTDataConfig(BaseDataConfig):
# Configuring
loss_mask: LossMaskConfig = LossMaskConfig()

chat_template_kwargs: Annotated[
dict[str, Any] | None,
Field(
description="Extra keyword arguments passed to tokenizer.apply_chat_template(). "
"E.g. {'enable_thinking': true} for models with thinking-aware templates (Gemma4, Qwen3)."
),
] = None
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing CHANGELOG entry for new config field

Low Severity

A new chat_template_kwargs field is added to SFTDataConfig in src/prime_rl/configs/sft.py, but CHANGELOG.md has no corresponding entry. Per the project rule, any PR modifying configuration structures or usage patterns in config files must update the changelog.

Fix in Cursor Fix in Web

Triggered by project rule: BugBot Instructions


@model_validator(mode="after")
def validate_subsets_and_splits(self):
if self.subsets is not None or self.splits is not None:
Expand Down
Loading
Loading