diff --git a/src/prime_rl/configs/trainer.py b/src/prime_rl/configs/trainer.py
index d8291ce4a2..023dcd984b 100644
--- a/src/prime_rl/configs/trainer.py
+++ b/src/prime_rl/configs/trainer.py
@@ -100,6 +100,13 @@ class LoRAConfig(BaseConfig):
         ),
     ] = 16
 
+    moe_lora_mode: Annotated[
+        Literal["per_projection", "perft_e"],
+        Field(
+            description="MoE LoRA strategy. 'per_projection' applies a separate LoRA to each of w1/w2/w3. 'perft_e' applies a single bypass LoRA to the entire MoE block (PErFT-E).",
+        ),
+    ] = "per_projection"
+
     alpha: Annotated[
         float,
         Field(
@@ -133,6 +140,13 @@ class LoRAConfig(BaseConfig):
             "experts",
         ]
 
+    save_dtype: Annotated[
+        Literal["bfloat16", "float32"],
+        Field(
+            description="Dtype to cast adapter weights to when saving.",
+        ),
+    ] = "bfloat16"
+
     modules_to_save: Annotated[
         list[str],
         Field(
diff --git a/src/prime_rl/trainer/ckpt.py b/src/prime_rl/trainer/ckpt.py
index 4734cfc563..17e3084783 100644
--- a/src/prime_rl/trainer/ckpt.py
+++ b/src/prime_rl/trainer/ckpt.py
@@ -32,6 +32,11 @@
 from prime_rl.utils.logger import get_logger
 from prime_rl.utils.utils import get_all_ckpt_steps, get_ckpt_dir, get_step_path, get_weights_dir
 
+DTYPE_MAP = {
+    "bfloat16": torch.bfloat16,
+    "float32": torch.float32,
+}
+
 
 def _try_rmtree(path: Path, logger) -> None:
     """Remove a directory tree, logging and skipping on failure."""
@@ -306,12 +311,13 @@ def mark_stable(self, step: int) -> None:
         (step_path / "STABLE").touch()
 
     def get_run_adapter_state_dict(self) -> dict[str, Tensor]:
-        lora_state_dict = {
-            f"base_model.model.{key}": (value.full_tensor() if isinstance(value, DTensor) else value).to(
-                "cpu", non_blocking=False
-            )
-            for key, value in get_multi_run_manager().get_state_dict_for_run(0).items()
-        }
+        assert self.lora_config is not None, "LoRA config is required to get run adapter state dict"
+        save_dtype = DTYPE_MAP[self.lora_config.save_dtype]
+        lora_state_dict = {}
+        for key, value in get_multi_run_manager().get_state_dict_for_run(0).items():
+            tensor = value.full_tensor() if isinstance(value, DTensor) else value
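+            # .to() returns a converted CPU copy in save_dtype; the on-device training weights keep their dtype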
Something went wrong.") diff --git a/src/prime_rl/trainer/lora.py b/src/prime_rl/trainer/lora.py index 5afac23363..aa2f501fdb 100644 --- a/src/prime_rl/trainer/lora.py +++ b/src/prime_rl/trainer/lora.py @@ -7,6 +7,7 @@ from prime_rl.configs.trainer import LoRAConfig from prime_rl.trainer.models.layers.lora import MultiLoRALinear, MultiLoRAModule from prime_rl.trainer.models.layers.lora.multi_moe import MultiLoRAGroupedExperts +from prime_rl.trainer.models.layers.lora.multi_perfte import MultiLoRAPERFTE from prime_rl.trainer.models.layers.moe import GroupedExperts from prime_rl.trainer.runs import get_multi_run_manager from prime_rl.utils.logger import get_logger @@ -171,7 +172,8 @@ def apply_lora_to_model(model: nn.Module, config: LoRAConfig) -> None: ) # Handle GroupedExperts (MoE) elif isinstance(base_module, GroupedExperts): - lora_module = MultiLoRAGroupedExperts( + moe_cls = MultiLoRAPERFTE if config.moe_lora_mode == "perft_e" else MultiLoRAGroupedExperts + lora_module = moe_cls( base_layer=base_module, rank=config.rank, n_adapters=n_loras, diff --git a/src/prime_rl/trainer/models/layers/lora/__init__.py b/src/prime_rl/trainer/models/layers/lora/__init__.py index ecea8bfe3d..275bae1ace 100644 --- a/src/prime_rl/trainer/models/layers/lora/__init__.py +++ b/src/prime_rl/trainer/models/layers/lora/__init__.py @@ -7,11 +7,13 @@ ) from prime_rl.trainer.models.layers.lora.multi_linear import MultiLoRALinear from prime_rl.trainer.models.layers.lora.multi_moe import MultiLoRAGroupedExperts +from prime_rl.trainer.models.layers.lora.multi_perfte import MultiLoRAPERFTE __all__ = [ "MultiLoRAModule", "MultiLoRALinear", "MultiLoRAGroupedExperts", + "MultiLoRAPERFTE", "set_lora_num_tokens", "get_lora_num_tokens", "set_multilora_scaling", diff --git a/src/prime_rl/trainer/models/layers/lora/multi_perfte.py b/src/prime_rl/trainer/models/layers/lora/multi_perfte.py new file mode 100644 index 0000000000..da96895965 --- /dev/null +++ b/src/prime_rl/trainer/models/layers/lora/multi_perfte.py @@ -0,0 +1,177 @@ +import math + +import torch +from torch import nn +from torch.distributed.tensor import DTensor + +from prime_rl.trainer.models.layers.lora.base import MultiLoRAModule, get_lora_num_tokens, get_multilora_scaling +from prime_rl.trainer.models.layers.lora.multi_moe import _run_lora_for_loop, _run_lora_grouped_mm +from prime_rl.trainer.models.layers.moe import GroupedExperts + + +class MultiLoRAPERFTE(MultiLoRAModule): + """ + GroupedExperts + multi-LoRA as a single bypass path. + Runs the base MoE unmodified and adds a single LoRA: out = moe(x) + B @ A @ x. 
+ """ + + def __init__( + self, + base_layer: GroupedExperts, + rank: int, + n_adapters: int, + alpha: float = 32.0, + dropout: float = 0.0, + use_grouped_mm: bool = True, + ): + super().__init__(base_layer) + if rank <= 0 or n_adapters <= 0: + raise ValueError("rank and n_adapters must be > 0") + + self.num_experts = base_layer.num_experts + self.dim = base_layer.w1.shape[2] + + if rank % 8 != 0 or self.dim % 8 != 0: + use_grouped_mm = False + + self.rank = rank + self.n_adapters = n_adapters + self.alpha = alpha + self.lora_dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity() + self.use_grouped_mm = use_grouped_mm + + self._lora_num_tokens = get_lora_num_tokens() + self._scaling_factors = get_multilora_scaling() + + # Single LoRA pair per adapter: A maps dim -> rank, B maps rank -> dim + self.lora_A = nn.ParameterList( + [ + nn.Parameter( + torch.empty( + self.num_experts, + rank, + self.dim, + device=base_layer.w1.device, + dtype=base_layer.w1.dtype, + ) + ) + for _ in range(n_adapters) + ] + ) + self.lora_B = nn.ParameterList( + [ + nn.Parameter( + torch.empty( + self.num_experts, + self.dim, + rank, + device=base_layer.w1.device, + dtype=base_layer.w1.dtype, + ) + ) + for _ in range(n_adapters) + ] + ) + + self.reset_parameters() + + def reset_parameters(self, index: int | None = None) -> None: + if index is None: + for i in range(self.n_adapters): + self.reset_parameters(i) + else: + nn.init.kaiming_uniform_(self.lora_A[index], a=math.sqrt(5)) + nn.init.zeros_(self.lora_B[index]) + + def named_parameters_for_adapter(self, idx: int) -> list[tuple[str, nn.Parameter]]: + return [ + ("lora_A", self.lora_A[idx]), # [num_experts, rank, dim] + ("lora_B", self.lora_B[idx]), # [num_experts, dim, rank] + ] + + def get_lora_param_counts(self) -> tuple[int, int]: + adapter_params = self.lora_A[0].numel() + self.lora_B[0].numel() + adapted_params = self.base_layer.w1.numel() + self.base_layer.w2.numel() + self.base_layer.w3.numel() + return adapter_params, adapted_params + + def state_dict_for_adapter(self, idx: int) -> dict[str, torch.Tensor]: + """Get state dict for a specific adapter as 3D tensors. + + Returns: + Dict with keys "lora_A.weight" [num_experts, rank, dim] + and "lora_B.weight" [num_experts, dim, rank]. 
+ """ + detached_a = self.lora_A[idx].detach() + detached_b = self.lora_B[idx].detach() + + if isinstance(detached_a, DTensor): + detached_a = detached_a.full_tensor() + detached_b = detached_b.full_tensor() + + return { + "lora_A.weight": detached_a.clone(), + "lora_B.weight": detached_b.clone(), + } + + def forward(self, x: torch.Tensor, num_tokens_per_expert: torch.Tensor) -> torch.Tensor: + # Base MoE computation (EP handled by @expert_parallel decorator) + y_moe = self.base_layer(x, num_tokens_per_expert) + + # Select active adapter + adapter_idx = self._lora_num_tokens.argmax().item() + lora_a = self.lora_A[adapter_idx] + lora_b = self.lora_B[adapter_idx] + scaling = self._scaling_factors[adapter_idx].item() + + # EP handling for LoRA path + permuted_indices = None + if isinstance(lora_a, DTensor): + from torchtitan.distributed.expert_parallel import TOKEN_GROUP_ALIGN_SIZE_M + from torchtitan.experiments.kernels.moe.indices import generate_permute_indices + + lora_a = lora_a.to_local() + lora_b = lora_b.to_local() + + experts_per_ep_rank = lora_a.shape[0] + num_ep_ranks = num_tokens_per_expert.shape[0] // experts_per_ep_rank + + with torch.no_grad(): + permuted_indices, num_tokens_per_expert, _ = generate_permute_indices( + num_tokens_per_expert, + experts_per_ep_rank, + num_ep_ranks, + x.shape[0] + experts_per_ep_rank * TOKEN_GROUP_ALIGN_SIZE_M, + TOKEN_GROUP_ALIGN_SIZE_M, + ) + + x = torch.vstack((x, x.new_zeros((x.shape[-1])))) + input_shape = x.shape + x = x[permuted_indices, :] + + # LoRA path + offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) + lora_x = self.lora_dropout(x) + + if self.use_grouped_mm: + lora_out = _run_lora_grouped_mm(lora_x, lora_a, lora_b, offsets) + else: + lora_out = _run_lora_for_loop(lora_x, lora_a, lora_b, num_tokens_per_expert) + + # EP unpermute + if permuted_indices is not None: + if lora_out.shape[0] < len(permuted_indices): + num_padding = len(permuted_indices) - lora_out.shape[0] + lora_out = torch.vstack((lora_out, lora_out.new_zeros((num_padding, lora_out.shape[-1])))) + out_unpermuted = lora_out.new_zeros(input_shape) + out_unpermuted[permuted_indices, :] = lora_out + lora_out = out_unpermuted[:-1] + + return y_moe + scaling * lora_out + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(base={self.base_layer}, rank={self.rank}, " + f"n_adapters={self.n_adapters}, num_experts={self.num_experts}, " + f"alpha={self.alpha}, dropout={self.lora_dropout}, " + f"use_grouped_mm={self.use_grouped_mm})" + ) diff --git a/src/prime_rl/trainer/rl/broadcast/filesystem.py b/src/prime_rl/trainer/rl/broadcast/filesystem.py index 55a92c832d..0674499179 100644 --- a/src/prime_rl/trainer/rl/broadcast/filesystem.py +++ b/src/prime_rl/trainer/rl/broadcast/filesystem.py @@ -3,6 +3,7 @@ from pathlib import Path from typing import Literal +import torch import torch.nn as nn from torch.distributed.tensor import DTensor @@ -19,6 +20,11 @@ from prime_rl.trainer.world import get_world from prime_rl.utils.utils import get_broadcast_dir, get_step_path +_DTYPE_MAP = { + "bfloat16": torch.bfloat16, + "float32": torch.float32, +} + class FileSystemWeightBroadcast(WeightBroadcast): """Broadcast weights into the inference engine via shared filesystem.""" @@ -60,11 +66,12 @@ def broadcast_weights(self, model: nn.Module, step: int) -> None: # For adapter-only, MultiRunManager creates state dict directly for each run # All ranks must participate in DTensor gathering, but only master saves state_dict = 
+            save_dtype = _DTYPE_MAP[self.lora_config.save_dtype]
             for key, value in state_dict.items():
                 if isinstance(value, DTensor):
                     value = value.full_tensor()
                 if self.world.is_master:
-                    state_dict[key] = value.to("cpu", non_blocking=False)
+                    state_dict[key] = value.to(device="cpu", dtype=save_dtype, non_blocking=False)
 
             # TODO: Broadcast ready to update in sync, then we dont need to gather on not ready
             if self.world.is_master: