diff --git a/src/prime_rl/trainer/model.py b/src/prime_rl/trainer/model.py index 188b0f61f7..e1a068a8e2 100644 --- a/src/prime_rl/trainer/model.py +++ b/src/prime_rl/trainer/model.py @@ -225,6 +225,13 @@ def get_model( config.name, attn_implementation=config.attn, trust_remote_code=config.trust_remote_code ), ) + if not is_vlm_training and getattr(model_config, "model_type", "") == "qwen3_5": + logger.info(f"Using text-only Qwen3.5 config path for {config.name}") + text_config = cast(PretrainedConfig, model_config.text_config) + text_config._attn_implementation = getattr(model_config, "_attn_implementation", config.attn) + text_config._name_or_path = getattr(model_config, "_name_or_path", config.name) + model_config = text_config + model_config.use_cache = False is_vlm_arch = is_vlm_architecture(model_config) diff --git a/src/prime_rl/trainer/models/__init__.py b/src/prime_rl/trainer/models/__init__.py index b2e2068217..d5fdc7c521 100644 --- a/src/prime_rl/trainer/models/__init__.py +++ b/src/prime_rl/trainer/models/__init__.py @@ -7,6 +7,8 @@ from transformers.models.auto.auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update from transformers.models.auto.configuration_auto import CONFIG_MAPPING_NAMES from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.qwen3.configuration_qwen3 import Qwen3Config +from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5TextConfig from prime_rl.trainer.models.afmoe import AfmoeConfig, AfmoeForCausalLM from prime_rl.trainer.models.base import PreTrainedModelPrimeRL @@ -16,6 +18,8 @@ from prime_rl.trainer.models.llama import LlamaForCausalLM from prime_rl.trainer.models.minimax_m2 import MiniMaxM2Config, MiniMaxM2ForCausalLM from prime_rl.trainer.models.nemotron_h import NemotronHConfig, NemotronHForCausalLM +from prime_rl.trainer.models.qwen3 import Qwen3ForCausalLM +from prime_rl.trainer.models.qwen3_5 import Qwen3_5ForCausalLM from prime_rl.trainer.models.qwen3_5_moe import Qwen3_5MoeConfig, Qwen3_5MoeForCausalLM from prime_rl.trainer.models.qwen3_moe import Qwen3MoeConfig, Qwen3MoeForCausalLM @@ -35,6 +39,8 @@ _CUSTOM_CAUSAL_LM_MAPPING.register(GlmMoeDsaConfig, GlmMoeDsaForCausalLM, exist_ok=True) _CUSTOM_CAUSAL_LM_MAPPING.register(MiniMaxM2Config, MiniMaxM2ForCausalLM, exist_ok=True) _CUSTOM_CAUSAL_LM_MAPPING.register(NemotronHConfig, NemotronHForCausalLM, exist_ok=True) +_CUSTOM_CAUSAL_LM_MAPPING.register(Qwen3Config, Qwen3ForCausalLM, exist_ok=True) +_CUSTOM_CAUSAL_LM_MAPPING.register(Qwen3_5TextConfig, Qwen3_5ForCausalLM, exist_ok=True) _CUSTOM_CAUSAL_LM_MAPPING.register(Qwen3MoeConfig, Qwen3MoeForCausalLM, exist_ok=True) _CUSTOM_CAUSAL_LM_MAPPING.register(Qwen3_5MoeConfig, Qwen3_5MoeForCausalLM, exist_ok=True) diff --git a/src/prime_rl/trainer/models/layers/attn.py b/src/prime_rl/trainer/models/layers/attn.py index f639245e8b..ffeb649086 100644 --- a/src/prime_rl/trainer/models/layers/attn.py +++ b/src/prime_rl/trainer/models/layers/attn.py @@ -325,9 +325,11 @@ def _ring_compute_attention(self, q, k, v, cu_seqlens, max_seqlen): FlashAttention._compute_attention = _ring_compute_attention from prime_rl.trainer.models.afmoe.modeling_afmoe import AfmoeFlashAttention - - AfmoeFlashAttention._compute_attention = _ring_compute_attention - + from prime_rl.trainer.models.qwen3.modeling_qwen3 import Qwen3FlashAttention + from prime_rl.trainer.models.qwen3_5.modeling_qwen3_5 import Qwen3_5FlashAttention from prime_rl.trainer.models.qwen3_5_moe.modeling_qwen3_5_moe import 
Qwen3_5MoeGatedFlashAttention + AfmoeFlashAttention._compute_attention = _ring_compute_attention + Qwen3FlashAttention._compute_attention = _ring_compute_attention + Qwen3_5FlashAttention._compute_attention = _ring_compute_attention Qwen3_5MoeGatedFlashAttention._compute_attention = _ring_compute_attention diff --git a/src/prime_rl/trainer/models/qwen3/__init__.py b/src/prime_rl/trainer/models/qwen3/__init__.py new file mode 100644 index 0000000000..068905d73a --- /dev/null +++ b/src/prime_rl/trainer/models/qwen3/__init__.py @@ -0,0 +1,14 @@ +from transformers.models.qwen3.configuration_qwen3 import Qwen3Config + +from prime_rl.trainer.models.qwen3.modeling_qwen3 import ( + Qwen3ForCausalLM, + Qwen3Model, + Qwen3PreTrainedModel, +) + +__all__ = [ + "Qwen3Config", + "Qwen3ForCausalLM", + "Qwen3Model", + "Qwen3PreTrainedModel", +] diff --git a/src/prime_rl/trainer/models/qwen3/modeling_qwen3.py b/src/prime_rl/trainer/models/qwen3/modeling_qwen3.py new file mode 100644 index 0000000000..3ba21da045 --- /dev/null +++ b/src/prime_rl/trainer/models/qwen3/modeling_qwen3.py @@ -0,0 +1,500 @@ +import functools +from dataclasses import dataclass +from typing import Optional, Union + +import torch +import torch.nn.functional as F +from torch import Tensor, nn +from transformers.cache_utils import Cache +from transformers.generation import GenerationMixin +from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask +from transformers.modeling_layers import GradientCheckpointingLayer +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.models.qwen3.configuration_qwen3 import Qwen3Config +from transformers.processing_utils import Unpack +from transformers.utils import TransformersKwargs + +from prime_rl.trainer.models.base import PreTrainedModelPrimeRL +from prime_rl.trainer.models.layers.lm_head import PrimeLmOutput +from prime_rl.trainer.models.layers.mlp import MLP, MLPConfig +from prime_rl.trainer.models.layers.norms import RMSNorm, RMSNormConfig +from prime_rl.trainer.models.layers.rotary_emb import RotaryEmbedding, RotaryEmbeddingConfig, apply_rotary_pos_emb + +try: + from flash_attn import flash_attn_varlen_func +except ImportError: + flash_attn_varlen_func = None # type: ignore + +try: + from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func +except ImportError: + flash_attn_3_varlen_func = None # type: ignore + +try: + from flash_attn.cute import flash_attn_varlen_func as flash_attn_4_varlen_func +except ImportError: + flash_attn_4_varlen_func = None # type: ignore + + +@dataclass +class Qwen3AttentionConfig: + hidden_size: int + head_dim: int + num_attention_heads: int + num_key_value_heads: int + rms_norm_eps: float + attention_bias: bool = False + attention_dropout: float = 0.0 + sliding_window: int | None = None + + +class Qwen3RMSNorm(RMSNorm): + pass + + +def _repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + if n_rep == 1: + return hidden_states + batch, num_kv_heads, slen, head_dim = hidden_states.shape + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim) + + +class Qwen3AttentionBase(nn.Module): + def __init__(self, config: Qwen3AttentionConfig): + super().__init__() + self.head_dim = config.head_dim + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = 
config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.sliding_window = config.sliding_window + + self.q_proj = nn.Linear( + config.hidden_size, + config.num_attention_heads * self.head_dim, + bias=config.attention_bias, + ) + self.k_proj = nn.Linear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.v_proj = nn.Linear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, + config.hidden_size, + bias=config.attention_bias, + ) + self.q_norm = Qwen3RMSNorm(RMSNormConfig(hidden_size=self.head_dim, eps=config.rms_norm_eps)) + self.k_norm = Qwen3RMSNorm(RMSNormConfig(hidden_size=self.head_dim, eps=config.rms_norm_eps)) + + def output_proj(self, attn_output: torch.Tensor, input_shape: tuple[int, ...]) -> torch.Tensor: + if attn_output.dim() == 4: + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.contiguous().view(*input_shape, -1) + return self.o_proj(attn_output) + + +class Qwen3SDPAAttention(Qwen3AttentionBase): + def attn_projections( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + return query_states, key_states, value_states + + def _attention_core( + self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + key_states = _repeat_kv(key_states, self.num_key_value_groups) + value_states = _repeat_kv(value_states, self.num_key_value_groups) + dropout_p = self.attention_dropout if self.training else 0.0 + return F.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=dropout_p, + is_causal=attention_mask is None, + scale=self.scaling, + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None = None, + cu_seqlens: torch.LongTensor | None = None, + max_seqlen: int | None = None, + ) -> tuple[torch.Tensor, None]: + del cu_seqlens, max_seqlen + input_shape = hidden_states.shape[:-1] + query_states, key_states, value_states = self.attn_projections(hidden_states, position_embeddings) + attn_output = self._attention_core(query_states, key_states, value_states, attention_mask=attention_mask) + return self.output_proj(attn_output, input_shape), None + + +class Qwen3FlashAttention(Qwen3AttentionBase): + _funcs = { + 2: flash_attn_varlen_func, + 3: flash_attn_3_varlen_func, + 4: flash_attn_4_varlen_func, + } + + def __init__(self, config: Qwen3AttentionConfig, flash_attn_version: int = 4): + super().__init__(config) + self._flash_attn_version = flash_attn_version + self.func = self._funcs[flash_attn_version] + self._flash_attn_call = self.func + if 
self._flash_attn_version == 4: + self._flash_attn_call = torch._dynamo.disable(self.func) + + def _compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_seqlens, max_seqlen): + args = [q, k, v, cu_seqlens, cu_seqlens] + if self._flash_attn_version != 4: + args.extend([max_seqlen, max_seqlen]) + kwargs: dict[str, object] = {"causal": True} + if self.sliding_window is not None: + kwargs["window_size"] = (self.sliding_window - 1, 0) + out = self._flash_attn_call(*args, **kwargs) + if isinstance(out, tuple): + out = out[0] + return out + + def _attention_core( + self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + cu_seqlens: torch.LongTensor | None = None, + max_seqlen: int | None = None, + ) -> torch.Tensor: + return self._compute_attention(query_states[0], key_states[0], value_states[0], cu_seqlens, max_seqlen) + + def attn_projections( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)) + key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)) + value_states = self.v_proj(hidden_states).view(hidden_shape) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + return query_states, key_states, value_states + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None = None, + cu_seqlens: torch.LongTensor | None = None, + max_seqlen: int | None = None, + ) -> tuple[torch.Tensor, None]: + del attention_mask + input_shape = hidden_states.shape[:-1] + query_states, key_states, value_states = self.attn_projections(hidden_states, position_embeddings) + attn_output = self._attention_core( + query_states, + key_states, + value_states, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + return self.output_proj(attn_output, input_shape), None + + +QWEN3_ATTN_IMPL2CLASS = { + "sdpa": Qwen3SDPAAttention, + "flash_attention_2": functools.partial(Qwen3FlashAttention, flash_attn_version=2), + "flash_attention_3": functools.partial(Qwen3FlashAttention, flash_attn_version=3), + "fa4": functools.partial(Qwen3FlashAttention, flash_attn_version=4), +} + + +def _create_rotary_emb(config: Qwen3Config) -> RotaryEmbedding: + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type", "default")) + else: + rope_type = "default" + + rotary_config = RotaryEmbeddingConfig( + max_position_embeddings=config.max_position_embeddings, + rope_type=rope_type, + model_config=config, + ) + return RotaryEmbedding(rotary_config) + + +def _get_qwen3_attention(config: Qwen3Config, layer_idx: int) -> nn.Module: + is_sliding = config.layer_types[layer_idx] == "sliding_attention" + attn_config = Qwen3AttentionConfig( + hidden_size=config.hidden_size, + head_dim=getattr(config, "head_dim", config.hidden_size // config.num_attention_heads), + num_attention_heads=config.num_attention_heads, + num_key_value_heads=config.num_key_value_heads, + 
rms_norm_eps=config.rms_norm_eps, + attention_bias=config.attention_bias, + attention_dropout=config.attention_dropout, + sliding_window=config.sliding_window if is_sliding else None, + ) + + attn_impl = config._attn_implementation + if attn_impl == "eager": + attn_impl = "sdpa" + if attn_impl not in QWEN3_ATTN_IMPL2CLASS: + supported = sorted(QWEN3_ATTN_IMPL2CLASS) + raise ValueError(f"Qwen3 attention does not support {config._attn_implementation!r}. Supported: {supported}") + return QWEN3_ATTN_IMPL2CLASS[attn_impl](attn_config) + + +class Qwen3DecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Qwen3Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = _get_qwen3_attention(config, layer_idx) + self.mlp = MLP( + MLPConfig( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + gate_act=config.hidden_act, + bias=False, + ) + ) + self.input_layernorm = Qwen3RMSNorm(RMSNormConfig(hidden_size=config.hidden_size, eps=config.rms_norm_eps)) + self.post_attention_layernorm = Qwen3RMSNorm( + RMSNormConfig(hidden_size=config.hidden_size, eps=config.rms_norm_eps) + ) + self.attention_type = config.layer_types[layer_idx] + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + cu_seqlens: torch.LongTensor | None = None, + max_seqlen: int | None = None, + ) -> torch.FloatTensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_embeddings=position_embeddings, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class Qwen3PreTrainedModel(PreTrainedModelPrimeRL): + config: Qwen3Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen3DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = False + _supports_attention_backend = True + _can_compile_fullgraph = True + _can_record_outputs = { + "hidden_states": Qwen3DecoderLayer, + } + + @classmethod + def is_hf_state_dict(cls, state_dict: dict[str, Tensor]) -> bool: + return True + + @classmethod + def is_prime_state_dict(cls, state_dict: dict[str, Tensor]) -> bool: + return True + + @classmethod + def convert_to_hf(cls, state_dict: dict[str, Tensor]) -> dict[str, Tensor]: + return state_dict + + @classmethod + def convert_to_prime(cls, state_dict: dict[str, Tensor]) -> dict[str, Tensor]: + return state_dict + + @classmethod + def convert_layer_to_hf(cls, state_dict: dict[str, Tensor], layer_idx: int) -> dict[str, Tensor]: + del layer_idx + return state_dict + + @classmethod + def convert_layer_to_prime(cls, state_dict: dict[str, Tensor], layer_idx: int) -> dict[str, Tensor]: + del layer_idx + return state_dict + + +class Qwen3Model(Qwen3PreTrainedModel): + def __init__(self, config: Qwen3Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( 
+ [Qwen3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Qwen3RMSNorm(RMSNormConfig(hidden_size=config.hidden_size, eps=config.rms_norm_eps)) + self.rotary_emb = _create_rotary_emb(config) + self.gradient_checkpointing = False + self.has_sliding_layers = "sliding_attention" in config.layer_types + + self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: torch.Tensor | dict[str, torch.Tensor] | None = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if position_ids is None: + position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0) + + if not isinstance(attention_mask, dict): + mask_kwargs = { + "config": self.config, + "inputs_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device), + "past_key_values": None, + "position_ids": position_ids, + } + causal_mask_mapping: dict[str, torch.Tensor] = {"full_attention": create_causal_mask(**mask_kwargs)} + if self.has_sliding_layers: + causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs) + else: + causal_mask_mapping = attention_mask + + if self.config._attn_implementation in ("flash_attention_2", "flash_attention_3", "fa4"): + flat_position_ids = position_ids.view(-1) + seqlens = torch.cat( + [ + flat_position_ids[0:1], + flat_position_ids[:-1][(flat_position_ids == 0)[1:]] + 1, + flat_position_ids[-1:] + 1, + ] + ) + max_seqlen = seqlens.max().item() + cu_seqlens = seqlens.cumsum(dim=0, dtype=torch.int32) + torch._dynamo.mark_dynamic(cu_seqlens, 0) + else: + max_seqlen = None + cu_seqlens = None + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], + position_embeddings=position_embeddings, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + + hidden_states = self.norm(hidden_states) + return BaseModelOutputWithPast(last_hidden_state=hidden_states) + + +class Qwen3ForCausalLM(Qwen3PreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config: Qwen3Config): + super().__init__(config) + self.model = Qwen3Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + temperature: Optional[torch.Tensor] = None, + **kwargs: 
Unpack[TransformersKwargs], + ) -> PrimeLmOutput: + del cache_position, kwargs + assert use_cache is None, "use_cache is not supported for custom qwen3 for now" + assert past_key_values is None, "past_key_values is not supported for custom qwen3 for now" + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + ) + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + return self.lm_head( + hidden_states[:, slice_indices, :], + labels[:, slice_indices] if labels is not None else None, + temperature=temperature, + ) + + def init_buffers_post_meta(self) -> None: + buffer_names = [name for name, _ in self.named_buffers()] + if "model.rotary_emb.inv_freq" in buffer_names: + rotary_emb = self.model.rotary_emb + inv_freq, rotary_emb.attention_scaling = rotary_emb.rope_init_fn( + rotary_emb.config, rotary_emb.inv_freq.device + ) + rotary_emb.inv_freq.copy_(inv_freq) + + +__all__ = ["Qwen3ForCausalLM", "Qwen3Model", "Qwen3PreTrainedModel"] diff --git a/src/prime_rl/trainer/models/qwen3_5/__init__.py b/src/prime_rl/trainer/models/qwen3_5/__init__.py new file mode 100644 index 0000000000..0d143a5305 --- /dev/null +++ b/src/prime_rl/trainer/models/qwen3_5/__init__.py @@ -0,0 +1,14 @@ +from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5TextConfig + +from prime_rl.trainer.models.qwen3_5.modeling_qwen3_5 import ( + Qwen3_5ForCausalLM, + Qwen3_5PreTrainedModel, + Qwen3_5TextModel, +) + +__all__ = [ + "Qwen3_5TextConfig", + "Qwen3_5ForCausalLM", + "Qwen3_5TextModel", + "Qwen3_5PreTrainedModel", +] diff --git a/src/prime_rl/trainer/models/qwen3_5/modeling_qwen3_5.py b/src/prime_rl/trainer/models/qwen3_5/modeling_qwen3_5.py new file mode 100644 index 0000000000..8748e2662e --- /dev/null +++ b/src/prime_rl/trainer/models/qwen3_5/modeling_qwen3_5.py @@ -0,0 +1,444 @@ +import functools +from typing import Optional, Union + +import torch +import torch.nn.functional as F +from torch import Tensor, nn +from transformers.cache_utils import Cache +from transformers.generation import GenerationMixin +from transformers.masking_utils import create_causal_mask +from transformers.modeling_layers import GradientCheckpointingLayer +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5TextConfig +from transformers.processing_utils import Unpack +from transformers.utils import TransformersKwargs +from transformers.utils.generic import maybe_autocast + +from prime_rl.trainer.models.base import PreTrainedModelPrimeRL +from prime_rl.trainer.models.layers.lm_head import PrimeLmOutput +from prime_rl.trainer.models.layers.mlp import MLP, MLPConfig +from prime_rl.trainer.models.qwen3_5_moe.modeling_qwen3_5_moe import ( + Qwen3_5MoeGatedAttentionConfig, + Qwen3_5MoeGatedDeltaNet, + Qwen3_5MoeGatedFlashAttention, + Qwen3_5MoeGatedSDPAAttention, + Qwen3_5MoeRMSNorm, + _repeat_kv, +) + + +class Qwen3_5TextRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor + + def __init__(self, config: Qwen3_5TextConfig, device=None): + super().__init__() + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + self.config = config + + self.rope_type = self.config.rope_parameters["rope_type"] + self.rope_init_fn = 
self.compute_default_rope_parameters + if self.rope_type != "default": + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False) + self.mrope_section = config.rope_parameters.get("mrope_section", [11, 11, 10]) + + @staticmethod + def compute_default_rope_parameters( + config: Qwen3_5TextConfig | None = None, + device: Optional[torch.device] = None, + seq_len: int | None = None, + ) -> tuple[torch.Tensor, float]: + del seq_len + base = config.rope_parameters["rope_theta"] + partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0) + head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + dim = int(head_dim * partial_rotary_factor) + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, 1.0 + + @torch.no_grad() + @dynamic_rope_update + def forward(self, x: torch.Tensor, position_ids: torch.LongTensor) -> tuple[torch.Tensor, torch.Tensor]: + if position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) + position_ids_expanded = position_ids[:, :, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with maybe_autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) + freqs = self.apply_interleaved_mrope(freqs, self.mrope_section) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + def apply_interleaved_mrope(self, freqs: torch.Tensor, mrope_section: list[int]) -> torch.Tensor: + freqs_t = freqs[0] + for dim, offset in enumerate((1, 2), start=1): + length = mrope_section[dim] * 3 + idx = slice(offset, length, 3) + freqs_t[..., idx] = freqs[dim, ..., idx] + return freqs_t + + +Qwen3_5RMSNorm = Qwen3_5MoeRMSNorm + + +def apply_mask_to_padding_states(hidden_states: torch.Tensor, attention_mask: torch.Tensor | None) -> torch.Tensor: + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + hidden_states = (hidden_states * attention_mask[:, :, None]).to(hidden_states.dtype) + return hidden_states + + +class Qwen3_5GatedDeltaNet(Qwen3_5MoeGatedDeltaNet): + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) + return super().forward(hidden_states) + + +class Qwen3_5SDPAAttention(Qwen3_5MoeGatedSDPAAttention): + def _attention_core( + self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + key_states = _repeat_kv(key_states, self.num_key_value_groups) + value_states = _repeat_kv(value_states, self.num_key_value_groups) + dropout_p = self.attention_dropout if self.training else 0.0 + return F.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=dropout_p, + 
is_causal=attention_mask is None, + scale=self.scaling, + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None = None, + cu_seqlens: torch.LongTensor | None = None, + max_seqlen: int | None = None, + ) -> tuple[torch.Tensor, None]: + del cu_seqlens, max_seqlen + query_states, key_states, value_states, gate = self.attn_projections(hidden_states, position_embeddings) + attn_output = self._attention_core( + query_states, + key_states, + value_states, + attention_mask=attention_mask, + ) + return self.output_proj(attn_output, gate), None + + +class Qwen3_5FlashAttention(Qwen3_5MoeGatedFlashAttention): + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None = None, + cu_seqlens: torch.LongTensor | None = None, + max_seqlen: int | None = None, + ) -> tuple[torch.Tensor, None]: + del attention_mask + query_states, key_states, value_states, gate = self.attn_projections(hidden_states, position_embeddings) + attn_output = self._attention_core( + query_states, + key_states, + value_states, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + return self.output_proj(attn_output, gate), None + + +QWEN35_ATTN_IMPL2CLASS = { + "sdpa": Qwen3_5SDPAAttention, + "flash_attention_2": functools.partial(Qwen3_5FlashAttention, flash_attn_version=2), + "flash_attention_3": functools.partial(Qwen3_5FlashAttention, flash_attn_version=3), + "fa4": functools.partial(Qwen3_5FlashAttention, flash_attn_version=4), +} + + +def _get_qwen3_5_attention(config: Qwen3_5TextConfig) -> nn.Module: + attn_config = Qwen3_5MoeGatedAttentionConfig( + hidden_size=config.hidden_size, + head_dim=config.head_dim, + num_attention_heads=config.num_attention_heads, + num_key_value_heads=config.num_key_value_heads, + rms_norm_eps=config.rms_norm_eps, + attention_bias=config.attention_bias, + attention_dropout=config.attention_dropout, + ) + + attn_impl = config._attn_implementation + if attn_impl == "eager": + attn_impl = "sdpa" + if attn_impl not in QWEN35_ATTN_IMPL2CLASS: + supported = sorted(QWEN35_ATTN_IMPL2CLASS) + raise ValueError(f"Qwen3.5 attention does not support {config._attn_implementation!r}. 
Supported: {supported}") + return QWEN35_ATTN_IMPL2CLASS[attn_impl](attn_config) + + +class Qwen3_5DecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Qwen3_5TextConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.layer_type = config.layer_types[layer_idx] + if self.layer_type == "linear_attention": + self.linear_attn = Qwen3_5GatedDeltaNet(config) + elif self.layer_type == "full_attention": + self.self_attn = _get_qwen3_5_attention(config) + self.mlp = MLP( + MLPConfig( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + gate_act=config.hidden_act, + bias=False, + ) + ) + self.input_layernorm = Qwen3_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen3_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None = None, + cu_seqlens: torch.LongTensor | None = None, + max_seqlen: int | None = None, + ) -> torch.FloatTensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + if self.layer_type == "linear_attention": + hidden_states = self.linear_attn(hidden_states, attention_mask=attention_mask) + else: + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class Qwen3_5PreTrainedModel(PreTrainedModelPrimeRL): + config: Qwen3_5TextConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen3_5DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = False + _supports_attention_backend = True + _can_compile_fullgraph = False + _can_record_outputs = { + "hidden_states": Qwen3_5DecoderLayer, + } + + @classmethod + def is_hf_state_dict(cls, state_dict: dict[str, Tensor]) -> bool: + return True + + @classmethod + def is_prime_state_dict(cls, state_dict: dict[str, Tensor]) -> bool: + return True + + @classmethod + def convert_to_hf(cls, state_dict: dict[str, Tensor]) -> dict[str, Tensor]: + return state_dict + + @classmethod + def convert_to_prime(cls, state_dict: dict[str, Tensor]) -> dict[str, Tensor]: + return state_dict + + @classmethod + def convert_layer_to_hf(cls, state_dict: dict[str, Tensor], layer_idx: int) -> dict[str, Tensor]: + del layer_idx + return state_dict + + @classmethod + def convert_layer_to_prime(cls, state_dict: dict[str, Tensor], layer_idx: int) -> dict[str, Tensor]: + del layer_idx + return state_dict + + +class Qwen3_5TextModel(Qwen3_5PreTrainedModel): + def __init__(self, config: Qwen3_5TextConfig): + super().__init__(config) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) + self.layers = nn.ModuleList( + [Qwen3_5DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Qwen3_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen3_5TextRotaryEmbedding(config=config) + self.gradient_checkpointing = False + self.post_init() + + def forward( + 
self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + + if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand(4, inputs_embeds.shape[0], -1) + elif position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(4, position_ids.shape[0], -1) + + if position_ids.ndim == 3 and position_ids.shape[0] == 4: + text_position_ids = position_ids[0] + rotary_position_ids = position_ids[1:] + else: + text_position_ids = position_ids if position_ids.ndim == 2 else None + rotary_position_ids = position_ids + + causal_mask = create_causal_mask( + config=self.config, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=None, + position_ids=text_position_ids, + ) + linear_attn_mask = self._update_linear_attn_mask(attention_mask, cache_position) + + if self.config._attn_implementation in ("flash_attention_2", "flash_attention_3", "fa4"): + base_position_ids = text_position_ids + if base_position_ids is None: + base_position_ids = rotary_position_ids[0] if rotary_position_ids.ndim == 3 else rotary_position_ids + flat_position_ids = base_position_ids.view(-1) + seqlens = torch.cat( + [ + flat_position_ids[0:1], + flat_position_ids[:-1][(flat_position_ids == 0)[1:]] + 1, + flat_position_ids[-1:] + 1, + ] + ) + max_seqlen = seqlens.max().item() + cu_seqlens = seqlens.cumsum(dim=0, dtype=torch.int32) + torch._dynamo.mark_dynamic(cu_seqlens, 0) + else: + max_seqlen = None + cu_seqlens = None + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, rotary_position_ids) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + layer_mask = linear_attn_mask if decoder_layer.layer_type == "linear_attention" else causal_mask + hidden_states = decoder_layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=layer_mask, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + + hidden_states = self.norm(hidden_states) + return BaseModelOutputWithPast(last_hidden_state=hidden_states) + + def _update_linear_attn_mask( + self, + attention_mask: torch.Tensor | None, + cache_position: torch.LongTensor, + ) -> torch.Tensor | None: + linear_attn_mask = attention_mask + if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)): + linear_attn_mask = None + return linear_attn_mask + + +class Qwen3_5ForCausalLM(Qwen3_5PreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config: Qwen3_5TextConfig): + super().__init__(config) + self.model = Qwen3_5TextModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: 
Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + temperature: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> PrimeLmOutput: + del cache_position, kwargs + assert use_cache is None, "use_cache is not supported for custom qwen3.5 for now" + assert past_key_values is None, "past_key_values is not supported for custom qwen3.5 for now" + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + ) + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + return self.lm_head( + hidden_states[:, slice_indices, :], + labels[:, slice_indices] if labels is not None else None, + temperature=temperature, + ) + + def init_buffers_post_meta(self) -> None: + buffer_names = [name for name, _ in self.named_buffers()] + if "model.rotary_emb.inv_freq" in buffer_names: + rotary_emb = self.model.rotary_emb + inv_freq, rotary_emb.attention_scaling = rotary_emb.rope_init_fn( + rotary_emb.config, rotary_emb.inv_freq.device + ) + rotary_emb.inv_freq.copy_(inv_freq) + rotary_emb.original_inv_freq.copy_(inv_freq) + + +__all__ = ["Qwen3_5ForCausalLM", "Qwen3_5TextModel", "Qwen3_5PreTrainedModel"]
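# --- Illustrative sketch (not part of the patch above) ---
# Both new modeling files (Qwen3Model.forward and Qwen3_5TextModel.forward) derive the
# flash-attention varlen metadata from packed position_ids that restart at 0 at every
# sequence boundary. The helper below reproduces that exact bookkeeping in isolation so
# the resulting cu_seqlens/max_seqlen are easy to verify by hand. The function name
# `packed_seqlens` is illustrative only and does not exist in the diff; the body mirrors
# the in-model computation, assuming position ids that begin at 0 for each packed sequence.
import torch


def packed_seqlens(position_ids: torch.LongTensor) -> tuple[torch.Tensor, int]:
    # position_ids: shape (batch, seq_len), e.g. [[0,1,2,3, 0,1,2, 0,1]] for three packed
    # sequences of lengths 4, 3 and 2.
    flat_position_ids = position_ids.view(-1)
    seqlens = torch.cat(
        [
            flat_position_ids[0:1],                                    # leading 0
            flat_position_ids[:-1][(flat_position_ids == 0)[1:]] + 1,  # length of each completed sequence
            flat_position_ids[-1:] + 1,                                # length of the trailing sequence
        ]
    )
    max_seqlen = int(seqlens.max().item())
    cu_seqlens = seqlens.cumsum(dim=0, dtype=torch.int32)  # cumulative boundaries for the varlen kernels
    return cu_seqlens, max_seqlen


# Worked example matching the comment above.
cu_seqlens, max_seqlen = packed_seqlens(torch.tensor([[0, 1, 2, 3, 0, 1, 2, 0, 1]]))
assert cu_seqlens.tolist() == [0, 4, 7, 9] and max_seqlen == 4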