7 changes: 7 additions & 0 deletions src/prime_rl/trainer/model.py
@@ -225,6 +225,13 @@ def get_model(
            config.name, attn_implementation=config.attn, trust_remote_code=config.trust_remote_code
        ),
    )

    if not is_vlm_training and getattr(model_config, "model_type", "") == "qwen3_5":
        logger.info(f"Using text-only Qwen3.5 config path for {config.name}")
        text_config = cast(PretrainedConfig, model_config.text_config)
        text_config._attn_implementation = getattr(model_config, "_attn_implementation", config.attn)
        text_config._name_or_path = getattr(model_config, "_name_or_path", config.name)
        model_config = text_config

    model_config.use_cache = False
    is_vlm_arch = is_vlm_architecture(model_config)

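The new branch unwraps Qwen3.5's multimodal wrapper config to its inner text config, so text-only runs never touch the vision tower while keeping the attention implementation and checkpoint path consistent. A minimal sketch of the same unwrap in isolation; the checkpoint name is hypothetical, and only `model_type == "qwen3_5"` and the `text_config` attribute are taken from the diff:

from transformers import AutoConfig

# Hypothetical checkpoint name, for illustration only.
model_config = AutoConfig.from_pretrained("org/qwen3.5-checkpoint", trust_remote_code=True)

if getattr(model_config, "model_type", "") == "qwen3_5":
    # The wrapper keeps the LM hyperparameters in .text_config; carrying the
    # attention implementation over mirrors what get_model does above.
    text_config = model_config.text_config
    text_config._attn_implementation = getattr(model_config, "_attn_implementation", "flash_attention_2")
    model_config = text_config
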
6 changes: 6 additions & 0 deletions src/prime_rl/trainer/models/__init__.py
@@ -1,12 +1,14 @@
## A bit of context here: this basically copies AutoModelForCausalLM from transformers, but uses our own models instead

from collections import OrderedDict

from transformers import AutoConfig
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto.auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
from transformers.models.auto.configuration_auto import CONFIG_MAPPING_NAMES
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5TextConfig

from prime_rl.trainer.models.afmoe import AfmoeConfig, AfmoeForCausalLM
from prime_rl.trainer.models.base import PreTrainedModelPrimeRL
@@ -16,6 +18,8 @@
from prime_rl.trainer.models.llama import LlamaForCausalLM
from prime_rl.trainer.models.minimax_m2 import MiniMaxM2Config, MiniMaxM2ForCausalLM
from prime_rl.trainer.models.nemotron_h import NemotronHConfig, NemotronHForCausalLM
from prime_rl.trainer.models.qwen3 import Qwen3ForCausalLM
from prime_rl.trainer.models.qwen3_5 import Qwen3_5ForCausalLM
from prime_rl.trainer.models.qwen3_5_moe import Qwen3_5MoeConfig, Qwen3_5MoeForCausalLM
from prime_rl.trainer.models.qwen3_moe import Qwen3MoeConfig, Qwen3MoeForCausalLM

@@ -35,6 +39,8 @@
_CUSTOM_CAUSAL_LM_MAPPING.register(GlmMoeDsaConfig, GlmMoeDsaForCausalLM, exist_ok=True)
_CUSTOM_CAUSAL_LM_MAPPING.register(MiniMaxM2Config, MiniMaxM2ForCausalLM, exist_ok=True)
_CUSTOM_CAUSAL_LM_MAPPING.register(NemotronHConfig, NemotronHForCausalLM, exist_ok=True)
_CUSTOM_CAUSAL_LM_MAPPING.register(Qwen3Config, Qwen3ForCausalLM, exist_ok=True)
_CUSTOM_CAUSAL_LM_MAPPING.register(Qwen3_5TextConfig, Qwen3_5ForCausalLM, exist_ok=True)
_CUSTOM_CAUSAL_LM_MAPPING.register(Qwen3MoeConfig, Qwen3MoeForCausalLM, exist_ok=True)
_CUSTOM_CAUSAL_LM_MAPPING.register(Qwen3_5MoeConfig, Qwen3_5MoeForCausalLM, exist_ok=True)

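For context on the pattern: `_LazyAutoMapping` plus `_BaseAutoModelClass` is how `AutoModelForCausalLM` dispatches upstream, so `from_config(...)` resolves the model class from the config type at call time. A self-contained sketch of the mechanism, assuming an installed transformers that ships Qwen3; the demo names are illustrative, not prime_rl's exports:

from collections import OrderedDict

from transformers.models.auto.auto_factory import _BaseAutoModelClass, _LazyAutoMapping
from transformers.models.auto.configuration_auto import CONFIG_MAPPING_NAMES
from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM  # stand-in for the patched class

# Start from an empty name-based table; registered pairs take precedence.
_DEMO_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, OrderedDict())
_DEMO_MAPPING.register(Qwen3Config, Qwen3ForCausalLM, exist_ok=True)

class DemoAutoModelForCausalLM(_BaseAutoModelClass):
    _model_mapping = _DEMO_MAPPING

# Dispatch happens on the config type.
tiny = Qwen3Config(hidden_size=64, num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=2)
model = DemoAutoModelForCausalLM.from_config(tiny)
assert isinstance(model, Qwen3ForCausalLM)
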
8 changes: 5 additions & 3 deletions src/prime_rl/trainer/models/layers/attn.py
@@ -325,9 +325,11 @@ def _ring_compute_attention(self, q, k, v, cu_seqlens, max_seqlen):
FlashAttention._compute_attention = _ring_compute_attention

from prime_rl.trainer.models.afmoe.modeling_afmoe import AfmoeFlashAttention

AfmoeFlashAttention._compute_attention = _ring_compute_attention

from prime_rl.trainer.models.qwen3.modeling_qwen3 import Qwen3FlashAttention
from prime_rl.trainer.models.qwen3_5.modeling_qwen3_5 import Qwen3_5FlashAttention
from prime_rl.trainer.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeGatedFlashAttention

AfmoeFlashAttention._compute_attention = _ring_compute_attention
Qwen3FlashAttention._compute_attention = _ring_compute_attention
Qwen3_5FlashAttention._compute_attention = _ring_compute_attention
Qwen3_5MoeGatedFlashAttention._compute_attention = _ring_compute_attention
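
The block above is a plain method monkey-patch: each FlashAttention variant has its `_compute_attention` swapped for the ring implementation, so context-parallel runs share one attention code path. A toy sketch of the mechanism with stand-in classes; the real `_ring_compute_attention` shards the sequence across ranks and rotates K/V blocks around a ring:

class ToyFlashAttention:
    def _compute_attention(self, q, k, v, cu_seqlens, max_seqlen):
        return "local"

def _toy_ring_compute_attention(self, q, k, v, cu_seqlens, max_seqlen):
    # Stand-in body; the real patch exchanges K/V shards between ranks.
    return "ring"

# Assigning to the class rebinds the method for existing and future instances.
ToyFlashAttention._compute_attention = _toy_ring_compute_attention
assert ToyFlashAttention()._compute_attention(None, None, None, None, None) == "ring"
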
14 changes: 14 additions & 0 deletions src/prime_rl/trainer/models/qwen3/__init__.py
@@ -0,0 +1,14 @@
from transformers.models.qwen3.configuration_qwen3 import Qwen3Config

from prime_rl.trainer.models.qwen3.modeling_qwen3 import (
    Qwen3ForCausalLM,
    Qwen3Model,
    Qwen3PreTrainedModel,
)

__all__ = [
    "Qwen3Config",
    "Qwen3ForCausalLM",
    "Qwen3Model",
    "Qwen3PreTrainedModel",
]
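
With the re-exports above, downstream code imports the config and the patched model from one place. A small usage sketch, assuming the modeling file keeps the upstream constructor signature; the tiny sizes are illustrative:

from prime_rl.trainer.models.qwen3 import Qwen3Config, Qwen3ForCausalLM

config = Qwen3Config(hidden_size=64, num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=2)
model = Qwen3ForCausalLM(config)
print(sum(p.numel() for p in model.parameters()))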