14 changes: 14 additions & 0 deletions src/prime_rl/configs/trainer.py
@@ -100,6 +100,13 @@ class LoRAConfig(BaseConfig):
        ),
    ] = 16

    moe_lora_mode: Annotated[
        Literal["per_projection", "perft_e"],
        Field(
            description="MoE LoRA strategy. 'per_projection' applies separate LoRA to each of w1/w2/w3. 'perft_e' applies a single bypass LoRA to the entire MoE block (PErFT-E).",
        ),
    ] = "per_projection"
New config fields missing CHANGELOG entry

Low Severity

Two new config fields are added to LoRAConfig in src/prime_rl/configs/trainer.py (moe_lora_mode and save_dtype) without a corresponding CHANGELOG.md update. Per the project rule, any PR modifying configuration structures (including added fields) in src/prime_rl/*/config.py must update the changelog.


Triggered by project rule: BugBot Instructions


    alpha: Annotated[
        float,
        Field(
@@ -133,6 +140,13 @@ class LoRAConfig(BaseConfig):
            "experts",
        ]

    save_dtype: Annotated[
        Literal["bfloat16", "float32"],
        Field(
            description="Dtype to cast adapter weights to when saving.",
        ),
    ] = "bfloat16"

    modules_to_save: Annotated[
        list[str],
        Field(
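For context on how the two new fields combine, a minimal usage sketch: the field names match this diff, while the other constructor arguments and defaults are assumptions.

```python
from prime_rl.configs.trainer import LoRAConfig

# Hypothetical construction; only moe_lora_mode and save_dtype are new in this PR.
config = LoRAConfig(
    rank=16,
    moe_lora_mode="perft_e",  # single bypass LoRA over each MoE block (PErFT-E)
    save_dtype="bfloat16",    # adapter weights cast to bf16 when saved/broadcast
)
```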
18 changes: 12 additions & 6 deletions src/prime_rl/trainer/ckpt.py
@@ -32,6 +32,11 @@
from prime_rl.utils.logger import get_logger
from prime_rl.utils.utils import get_all_ckpt_steps, get_ckpt_dir, get_step_path, get_weights_dir

DTYPE_MAP = {
    "bfloat16": torch.bfloat16,
    "float32": torch.float32,
}
Duplicate DTYPE_MAP definitions across files

Low Severity

This PR introduces two new identical DTYPE_MAP / _DTYPE_MAP dictionaries (in ckpt.py and filesystem.py) that duplicate the existing DTYPE_MAP already defined in src/prime_rl/trainer/model.py. All three map the same two strings to the same torch dtypes. A single shared definition would avoid inconsistency risk if new dtypes are added later.

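As a possible fix along the lines the comment suggests, the mapping could live in one shared module that ckpt.py, filesystem.py, and model.py all import; the module path below is hypothetical:

```python
# Hypothetical shared module, e.g. src/prime_rl/trainer/dtypes.py
import torch

# Single source of truth for the string-to-dtype mapping used by
# checkpointing, weight broadcast, and model loading.
DTYPE_MAP: dict[str, torch.dtype] = {
    "bfloat16": torch.bfloat16,
    "float32": torch.float32,
}
```

Each caller would then do `from prime_rl.trainer.dtypes import DTYPE_MAP` instead of redefining the dict.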



def _try_rmtree(path: Path, logger) -> None:
"""Remove a directory tree, logging and skipping on failure."""
@@ -306,12 +311,13 @@ def mark_stable(self, step: int) -> None:
         (step_path / "STABLE").touch()

     def get_run_adapter_state_dict(self) -> dict[str, Tensor]:
-        lora_state_dict = {
-            f"base_model.model.{key}": (value.full_tensor() if isinstance(value, DTensor) else value).to(
-                "cpu", non_blocking=False
-            )
-            for key, value in get_multi_run_manager().get_state_dict_for_run(0).items()
-        }
+        assert self.lora_config is not None, "LoRA config is required to get run adapter state dict"
+        save_dtype = DTYPE_MAP[self.lora_config.save_dtype]
+        lora_state_dict = {}
+        for key, value in get_multi_run_manager().get_state_dict_for_run(0).items():
+            tensor = value.full_tensor() if isinstance(value, DTensor) else value
+            tensor = tensor.to(device="cpu", dtype=save_dtype, non_blocking=False)
+            lora_state_dict[f"base_model.model.{key}"] = tensor

         if not lora_state_dict:
             raise ValueError("The LoRA state dict is empty. Something went wrong.")
4 changes: 3 additions & 1 deletion src/prime_rl/trainer/lora.py
@@ -6,6 +6,7 @@

from prime_rl.configs.trainer import LoRAConfig
from prime_rl.trainer.models.layers.lora import MultiLoRALinear, MultiLoRAModule
from prime_rl.trainer.models.layers.lora.fused_moe import MultiLoRAPERFTE
from prime_rl.trainer.models.layers.lora.multi_moe import MultiLoRAGroupedExperts
from prime_rl.trainer.models.layers.moe import GroupedExperts
from prime_rl.trainer.runs import get_multi_run_manager
@@ -171,7 +172,8 @@ def apply_lora_to_model(model: nn.Module, config: LoRAConfig) -> None:
             )
         # Handle GroupedExperts (MoE)
         elif isinstance(base_module, GroupedExperts):
-            lora_module = MultiLoRAGroupedExperts(
+            moe_cls = MultiLoRAPERFTE if config.moe_lora_mode == "perft_e" else MultiLoRAGroupedExperts
+            lora_module = moe_cls(
                 base_layer=base_module,
                 rank=config.rank,
                 n_adapters=n_loras,
2 changes: 2 additions & 0 deletions src/prime_rl/trainer/models/layers/lora/__init__.py
@@ -5,13 +5,15 @@
    set_lora_num_tokens,
    set_multilora_scaling,
)
from prime_rl.trainer.models.layers.lora.fused_moe import MultiLoRAPERFTE
from prime_rl.trainer.models.layers.lora.multi_linear import MultiLoRALinear
from prime_rl.trainer.models.layers.lora.multi_moe import MultiLoRAGroupedExperts

__all__ = [
    "MultiLoRAModule",
    "MultiLoRALinear",
    "MultiLoRAGroupedExperts",
    "MultiLoRAPERFTE",
    "set_lora_num_tokens",
    "get_lora_num_tokens",
    "set_multilora_scaling",
177 changes: 177 additions & 0 deletions src/prime_rl/trainer/models/layers/lora/fused_moe.py
@@ -0,0 +1,177 @@
import math

import torch
from torch import nn
from torch.distributed.tensor import DTensor

from prime_rl.trainer.models.layers.lora.base import MultiLoRAModule, get_lora_num_tokens, get_multilora_scaling
from prime_rl.trainer.models.layers.lora.multi_moe import _run_lora_for_loop, _run_lora_grouped_mm
from prime_rl.trainer.models.layers.moe import GroupedExperts


class MultiLoRAPERFTE(MultiLoRAModule):
    """
    GroupedExperts + multi-LoRA as a single bypass path.
    Runs the base MoE unmodified and adds a single LoRA: out = moe(x) + B @ A @ x.
    """

    def __init__(
        self,
        base_layer: GroupedExperts,
        rank: int,
        n_adapters: int,
        alpha: float = 32.0,
        dropout: float = 0.0,
        use_grouped_mm: bool = True,
    ):
        super().__init__(base_layer)
        if rank <= 0 or n_adapters <= 0:
            raise ValueError("rank and n_adapters must be > 0")

        self.num_experts = base_layer.num_experts
        self.dim = base_layer.w1.shape[2]

        if rank % 8 != 0 or self.dim % 8 != 0:
            use_grouped_mm = False

        self.rank = rank
        self.n_adapters = n_adapters
        self.alpha = alpha
        self.lora_dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()
        self.use_grouped_mm = use_grouped_mm

        self._lora_num_tokens = get_lora_num_tokens()
        self._scaling_factors = get_multilora_scaling()

        # Single LoRA pair per adapter: A maps dim -> rank, B maps rank -> dim
        self.lora_A = nn.ParameterList(
            [
                nn.Parameter(
                    torch.empty(
                        self.num_experts,
                        rank,
                        self.dim,
                        device=base_layer.w1.device,
                        dtype=base_layer.w1.dtype,
                    )
                )
                for _ in range(n_adapters)
            ]
        )
        self.lora_B = nn.ParameterList(
            [
                nn.Parameter(
                    torch.empty(
                        self.num_experts,
                        self.dim,
                        rank,
                        device=base_layer.w1.device,
                        dtype=base_layer.w1.dtype,
                    )
                )
                for _ in range(n_adapters)
            ]
        )

        self.reset_parameters()

    def reset_parameters(self, index: int | None = None) -> None:
        if index is None:
            for i in range(self.n_adapters):
                self.reset_parameters(i)
        else:
            nn.init.kaiming_uniform_(self.lora_A[index], a=math.sqrt(5))
            nn.init.zeros_(self.lora_B[index])

    def named_parameters_for_adapter(self, idx: int) -> list[tuple[str, nn.Parameter]]:
        return [
            ("lora_A", self.lora_A[idx]),  # [num_experts, rank, dim]
            ("lora_B", self.lora_B[idx]),  # [num_experts, dim, rank]
        ]

    def get_lora_param_counts(self) -> tuple[int, int]:
        adapter_params = self.lora_A[0].numel() + self.lora_B[0].numel()
        adapted_params = self.base_layer.w1.numel() + self.base_layer.w2.numel() + self.base_layer.w3.numel()
        return adapter_params, adapted_params

    def state_dict_for_adapter(self, idx: int) -> dict[str, torch.Tensor]:
        """Get state dict for a specific adapter as 3D tensors.

        Returns:
            Dict with keys "lora_A.weight" [num_experts, rank, dim]
            and "lora_B.weight" [num_experts, dim, rank].
        """
        detached_a = self.lora_A[idx].detach()
        detached_b = self.lora_B[idx].detach()

        if isinstance(detached_a, DTensor):
            detached_a = detached_a.full_tensor()
            detached_b = detached_b.full_tensor()

        return {
            "lora_A.weight": detached_a.clone(),
            "lora_B.weight": detached_b.clone(),
        }

    def forward(self, x: torch.Tensor, num_tokens_per_expert: torch.Tensor) -> torch.Tensor:
        # Base MoE computation (EP handled by @expert_parallel decorator)
        y_moe = self.base_layer(x, num_tokens_per_expert)

        # Select active adapter
        adapter_idx = self._lora_num_tokens.argmax().item()
        lora_a = self.lora_A[adapter_idx]
        lora_b = self.lora_B[adapter_idx]
        scaling = self._scaling_factors[adapter_idx].item()

        # EP handling for LoRA path
        permuted_indices = None
        if isinstance(lora_a, DTensor):
            from torchtitan.distributed.expert_parallel import TOKEN_GROUP_ALIGN_SIZE_M
            from torchtitan.experiments.kernels.moe.indices import generate_permute_indices

            lora_a = lora_a.to_local()
            lora_b = lora_b.to_local()

            experts_per_ep_rank = lora_a.shape[0]
            num_ep_ranks = num_tokens_per_expert.shape[0] // experts_per_ep_rank

            with torch.no_grad():
                permuted_indices, num_tokens_per_expert, _ = generate_permute_indices(
                    num_tokens_per_expert,
                    experts_per_ep_rank,
                    num_ep_ranks,
                    x.shape[0] + experts_per_ep_rank * TOKEN_GROUP_ALIGN_SIZE_M,
                    TOKEN_GROUP_ALIGN_SIZE_M,
                )

            x = torch.vstack((x, x.new_zeros((x.shape[-1]))))
            input_shape = x.shape
            x = x[permuted_indices, :]

        # LoRA path
        offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
        lora_x = self.lora_dropout(x)

        if self.use_grouped_mm:
            lora_out = _run_lora_grouped_mm(lora_x, lora_a, lora_b, offsets)
        else:
            lora_out = _run_lora_for_loop(lora_x, lora_a, lora_b, num_tokens_per_expert)

        # EP unpermute
        if permuted_indices is not None:
            if lora_out.shape[0] < len(permuted_indices):
                num_padding = len(permuted_indices) - lora_out.shape[0]
                lora_out = torch.vstack((lora_out, lora_out.new_zeros((num_padding, lora_out.shape[-1]))))
            out_unpermuted = lora_out.new_zeros(input_shape)
            out_unpermuted[permuted_indices, :] = lora_out
            lora_out = out_unpermuted[:-1]

        return y_moe + scaling * lora_out

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(base={self.base_layer}, rank={self.rank}, "
            f"n_adapters={self.n_adapters}, num_experts={self.num_experts}, "
            f"alpha={self.alpha}, dropout={self.lora_dropout}, "
            f"use_grouped_mm={self.use_grouped_mm})"
        )
9 changes: 8 additions & 1 deletion src/prime_rl/trainer/rl/broadcast/filesystem.py
@@ -3,6 +3,7 @@
from pathlib import Path
from typing import Literal

import torch
import torch.nn as nn
from torch.distributed.tensor import DTensor

@@ -19,6 +20,11 @@
from prime_rl.trainer.world import get_world
from prime_rl.utils.utils import get_broadcast_dir, get_step_path

_DTYPE_MAP = {
    "bfloat16": torch.bfloat16,
    "float32": torch.float32,
}


class FileSystemWeightBroadcast(WeightBroadcast):
"""Broadcast weights into the inference engine via shared filesystem."""
@@ -60,11 +66,12 @@ def broadcast_weights(self, model: nn.Module, step: int) -> None:
             # For adapter-only, MultiRunManager creates state dict directly for each run
             # All ranks must participate in DTensor gathering, but only master saves
             state_dict = self.multi_run_manager.get_state_dict_for_run(idx)
+            save_dtype = _DTYPE_MAP[self.lora_config.save_dtype]
             for key, value in state_dict.items():
                 if isinstance(value, DTensor):
                     value = value.full_tensor()
                 if self.world.is_master:
-                    state_dict[key] = value.to("cpu", non_blocking=False)
+                    state_dict[key] = value.to(device="cpu", dtype=save_dtype, non_blocking=False)

             # TODO: Broadcast ready to update in sync, then we dont need to gather on not ready
             if self.world.is_master: