Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/prime_rl/trainer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,13 @@ def get_model(
config.name, attn_implementation=config.attn, trust_remote_code=config.trust_remote_code
),
)
if not is_vlm_training and getattr(model_config, "model_type", "") == "qwen3_5":
logger.info(f"Using text-only Qwen3.5 config path for {config.name}")
text_config = cast(PretrainedConfig, model_config.text_config)
text_config._attn_implementation = getattr(model_config, "_attn_implementation", config.attn)
text_config._name_or_path = getattr(model_config, "_name_or_path", config.name)
model_config = text_config

model_config.use_cache = False
is_vlm_arch = is_vlm_architecture(model_config)

Expand Down
6 changes: 6 additions & 0 deletions src/prime_rl/trainer/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from transformers.models.auto.auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
from transformers.models.auto.configuration_auto import CONFIG_MAPPING_NAMES
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5TextConfig

from prime_rl.trainer.models.afmoe import AfmoeConfig, AfmoeForCausalLM
from prime_rl.trainer.models.base import PreTrainedModelPrimeRL
Expand All @@ -16,6 +18,8 @@
from prime_rl.trainer.models.llama import LlamaForCausalLM
from prime_rl.trainer.models.minimax_m2 import MiniMaxM2Config, MiniMaxM2ForCausalLM
from prime_rl.trainer.models.nemotron_h import NemotronHConfig, NemotronHForCausalLM
from prime_rl.trainer.models.qwen3 import Qwen3ForCausalLM
from prime_rl.trainer.models.qwen3_5 import Qwen3_5ForCausalLM
from prime_rl.trainer.models.qwen3_5_moe import Qwen3_5MoeConfig, Qwen3_5MoeForCausalLM
from prime_rl.trainer.models.qwen3_moe import Qwen3MoeConfig, Qwen3MoeForCausalLM

Expand All @@ -35,6 +39,8 @@
_CUSTOM_CAUSAL_LM_MAPPING.register(GlmMoeDsaConfig, GlmMoeDsaForCausalLM, exist_ok=True)
_CUSTOM_CAUSAL_LM_MAPPING.register(MiniMaxM2Config, MiniMaxM2ForCausalLM, exist_ok=True)
_CUSTOM_CAUSAL_LM_MAPPING.register(NemotronHConfig, NemotronHForCausalLM, exist_ok=True)
_CUSTOM_CAUSAL_LM_MAPPING.register(Qwen3Config, Qwen3ForCausalLM, exist_ok=True)
_CUSTOM_CAUSAL_LM_MAPPING.register(Qwen3_5TextConfig, Qwen3_5ForCausalLM, exist_ok=True)
_CUSTOM_CAUSAL_LM_MAPPING.register(Qwen3MoeConfig, Qwen3MoeForCausalLM, exist_ok=True)
_CUSTOM_CAUSAL_LM_MAPPING.register(Qwen3_5MoeConfig, Qwen3_5MoeForCausalLM, exist_ok=True)

Expand Down
8 changes: 5 additions & 3 deletions src/prime_rl/trainer/models/layers/attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,9 +325,11 @@ def _ring_compute_attention(self, q, k, v, cu_seqlens, max_seqlen):
FlashAttention._compute_attention = _ring_compute_attention

from prime_rl.trainer.models.afmoe.modeling_afmoe import AfmoeFlashAttention

AfmoeFlashAttention._compute_attention = _ring_compute_attention

from prime_rl.trainer.models.qwen3.modeling_qwen3 import Qwen3FlashAttention
from prime_rl.trainer.models.qwen3_5.modeling_qwen3_5 import Qwen3_5FlashAttention
from prime_rl.trainer.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeGatedFlashAttention

AfmoeFlashAttention._compute_attention = _ring_compute_attention
Qwen3FlashAttention._compute_attention = _ring_compute_attention
Qwen3_5FlashAttention._compute_attention = _ring_compute_attention
Qwen3_5MoeGatedFlashAttention._compute_attention = _ring_compute_attention
14 changes: 14 additions & 0 deletions src/prime_rl/trainer/models/qwen3/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""Public re-exports for the Qwen3 dense-model package.

Bundles the upstream Hugging Face ``Qwen3Config`` together with this
project's own modeling classes so callers can import everything they need
from ``prime_rl.trainer.models.qwen3`` directly.
"""

from transformers.models.qwen3.configuration_qwen3 import Qwen3Config

from prime_rl.trainer.models.qwen3.modeling_qwen3 import (
    Qwen3ForCausalLM,
    Qwen3Model,
    Qwen3PreTrainedModel,
)

# Explicit public API of this package (config + the three modeling entry points).
__all__ = [
    "Qwen3Config",
    "Qwen3ForCausalLM",
    "Qwen3Model",
    "Qwen3PreTrainedModel",
]
Loading
Loading