Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ aime2025 = { index = "primeintellect" }
deepdive = { index = "primeintellect" }
mini_swe_agent_plus = { index = "primeintellect" }
deep-ep = { url = "https://github.com/samsja/flash-attn-builds/releases/download/v0.2/deep_ep-1.2.1+73b6ea4-cp312-cp312-linux_x86_64.whl" }
deep-gemm = { url = "https://github.com/samsja/flash-attn-builds/releases/download/v0.2/deep_gemm-2.3.0+477618c-cp312-cp312-linux_x86_64.whl" }
deep-gemm = { path = "tools/wheels/deep_gemm-2.3.0+d30fc36-cp312-cp312-linux_x86_64.whl" }
nixl-cu12 = { url = "https://github.com/samsja/flash-attn-builds/releases/download/v0.2/nixl_cu12-0.10.1-cp312-cp312-linux_x86_64.whl" }
flash-linear-attention = { git = "https://github.com/fla-org/flash-linear-attention" }

Expand Down
14 changes: 14 additions & 0 deletions src/prime_rl/configs/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,14 @@ class ModelConfig(BaseModelConfig):
),
] = True

fp8: Annotated[
    bool,
    Field(
        # NOTE: keep this description in sync with the `fp8_only_with_custom_impl`
        # validator, which accepts both 'custom' and 'auto' implementations.
        description="Whether to use FP8 training via DeepGEMM. Replaces nn.Linear layers with FP8 blockwise linear "
        "and uses FP8 grouped GEMM for MoE experts. Requires SM90 (Hopper) GPUs and model.impl='custom' or 'auto'.",
    ),
] = False
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CHANGELOG not updated for new config field

Low Severity

A new fp8 config field was added to ModelConfig in src/prime_rl/configs/trainer.py, but CHANGELOG.md was not updated. Per the project rule, any PR that modifies configuration structures (added fields) in src/prime_rl/configs/*.py must include a corresponding CHANGELOG.md entry.

Fix in Cursor Fix in Web

Triggered by project rule: BugBot Instructions


freeze_moe_router: Annotated[
bool,
Field(
Expand Down Expand Up @@ -401,6 +409,12 @@ def flash_attention_4_only_with_custom_impl(self):
raise ValueError("Flash attention 4 is only supported with the custom implementation")
return self

@model_validator(mode="after")
def fp8_only_with_custom_impl(self):
    """Validate that FP8 training is only enabled with a supported model implementation."""
    # FP8 blockwise linear / grouped GEMM paths exist only in the custom
    # model implementation ('auto' may resolve to it), so reject anything else.
    impl_supports_fp8 = self.impl in ("custom", "auto")
    if self.fp8 and not impl_supports_fp8:
        raise ValueError("FP8 training is only supported with model.impl='custom' or 'auto'.")
    return self

@model_validator(mode="after")
def validate_ep_comm_backend(self):
if self.ep_comm_backend == "torch":
Expand Down
5 changes: 5 additions & 0 deletions src/prime_rl/trainer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
set_selective_activation_checkpointing,
supports_selective_activation_checkpointing,
)
from prime_rl.trainer.models.layers.fp8_linear import replace_linear_with_fp8_blockwise_linear
from prime_rl.trainer.models.layers.lm_head import inject_prime_lm_head
from prime_rl.trainer.models.layers.moe import LatentMoE, MoE
from prime_rl.trainer.parallel_dims import ParallelDims
Expand Down Expand Up @@ -244,6 +245,7 @@ def get_model(
if subconfig is not None and hasattr(subconfig, "use_cache"):
subconfig.use_cache = False
model_config.use_grouped_mm = config.moe_use_grouped_mm
model_config.fp8 = config.fp8

# Ensure pad_token_id is set (some models like Qwen3MoE don't have it).
# In transformers v5, token IDs moved from PretrainedConfig to GenerationConfig.
Expand Down Expand Up @@ -783,6 +785,9 @@ def setup_model(

inject_prime_lm_head(model, chunk_size=lm_head_chunk_size, fused_cross_entropy=fused_cross_entropy)

if config.fp8:
replace_linear_with_fp8_blockwise_linear(model)

# Apply LoRA before FSDP setup
if config.lora is not None:
apply_lora_to_model(model, config.lora)
Expand Down
6 changes: 3 additions & 3 deletions src/prime_rl/trainer/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,28 @@
## A bit of context here: this basically copies AutoModelForCausalLM from transformers, but uses our own models instead.

from collections import OrderedDict

from transformers import AutoConfig
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto.auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
from transformers.models.auto.configuration_auto import CONFIG_MAPPING_NAMES
from transformers.models.llama.configuration_llama import LlamaConfig

from prime_rl.trainer.models.afmoe import AfmoeConfig, AfmoeForCausalLM
from prime_rl.trainer.models.base import PreTrainedModelPrimeRL
from prime_rl.trainer.models.glm4_moe import Glm4MoeConfig, Glm4MoeForCausalLM
from prime_rl.trainer.models.glm_moe_dsa import GlmMoeDsaConfig, GlmMoeDsaForCausalLM
# from prime_rl.trainer.models.glm_moe_dsa import GlmMoeDsaConfig, GlmMoeDsaForCausalLM
from prime_rl.trainer.models.layers.lm_head import PrimeLmOutput, cast_float_and_contiguous
from prime_rl.trainer.models.llama import LlamaForCausalLM
from prime_rl.trainer.models.minimax_m2 import MiniMaxM2Config, MiniMaxM2ForCausalLM
from prime_rl.trainer.models.nemotron_h import NemotronHConfig, NemotronHForCausalLM
from prime_rl.trainer.models.qwen3_5_moe import Qwen3_5MoeConfig, Qwen3_5MoeForCausalLM
from prime_rl.trainer.models.qwen3_moe import Qwen3MoeConfig, Qwen3MoeForCausalLM

Check failure on line 20 in src/prime_rl/trainer/models/__init__.py

View workflow job for this annotation

GitHub Actions / Ruff

Ruff (I001)

src/prime_rl/trainer/models/__init__.py:3:1: I001 Import block is un-sorted or un-formatted

# Make custom config discoverable by AutoConfig
AutoConfig.register("afmoe", AfmoeConfig, exist_ok=True)
AutoConfig.register("glm4_moe", Glm4MoeConfig, exist_ok=True)
AutoConfig.register("glm_moe_dsa", GlmMoeDsaConfig, exist_ok=True)
# AutoConfig.register("glm_moe_dsa", GlmMoeDsaConfig, exist_ok=True)
AutoConfig.register("minimax_m2", MiniMaxM2Config, exist_ok=True)
AutoConfig.register("nemotron_h", NemotronHConfig, exist_ok=True)
AutoConfig.register("qwen3_moe", Qwen3MoeConfig, exist_ok=True)
Expand All @@ -32,7 +32,7 @@
_CUSTOM_CAUSAL_LM_MAPPING.register(LlamaConfig, LlamaForCausalLM, exist_ok=True)
_CUSTOM_CAUSAL_LM_MAPPING.register(AfmoeConfig, AfmoeForCausalLM, exist_ok=True)
_CUSTOM_CAUSAL_LM_MAPPING.register(Glm4MoeConfig, Glm4MoeForCausalLM, exist_ok=True)
_CUSTOM_CAUSAL_LM_MAPPING.register(GlmMoeDsaConfig, GlmMoeDsaForCausalLM, exist_ok=True)
# _CUSTOM_CAUSAL_LM_MAPPING.register(GlmMoeDsaConfig, GlmMoeDsaForCausalLM, exist_ok=True)
_CUSTOM_CAUSAL_LM_MAPPING.register(MiniMaxM2Config, MiniMaxM2ForCausalLM, exist_ok=True)
_CUSTOM_CAUSAL_LM_MAPPING.register(NemotronHConfig, NemotronHForCausalLM, exist_ok=True)
_CUSTOM_CAUSAL_LM_MAPPING.register(Qwen3MoeConfig, Qwen3MoeForCausalLM, exist_ok=True)
Expand Down
1 change: 1 addition & 0 deletions src/prime_rl/trainer/models/afmoe/modeling_afmoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,7 @@ def __init__(self, config: AfmoeConfig, layer_idx: int):
top_k=config.num_experts_per_tok,
use_grouped_mm=getattr(config, "use_grouped_mm", True),
load_balance_coeff=getattr(config, "load_balance_coeff", None),
fp8=getattr(config, "fp8", False),
)
if self.moe_enabled:
self.mlp = MoE(moe_args, dim=config.hidden_size, hidden_dim=config.moe_intermediate_size)
Expand Down
1 change: 1 addition & 0 deletions src/prime_rl/trainer/models/glm4_moe/modeling_glm4_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def __init__(self, config: Glm4MoeConfig, layer_idx: int):
top_k=config.num_experts_per_tok,
load_balance_coeff=1e-3,
use_grouped_mm=config.use_grouped_mm,
fp8=getattr(config, "fp8", False),
)
mlp_config = MLPConfig(
hidden_size=config.hidden_size,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def __init__(self, config: GlmMoeDsaConfig, layer_idx: int):
top_k=config.num_experts_per_tok,
load_balance_coeff=1e-3,
use_grouped_mm=config.use_grouped_mm,
fp8=getattr(config, "fp8", False),
)
mlp_config = MLPConfig(
hidden_size=config.hidden_size,
Expand Down
Loading
Loading