diff --git a/CHANGELOG.md b/CHANGELOG.md index 1f7bf13850..ee8071bad8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ Documenting changes which affect configuration usage patterns (added/moved/removed/renamed fields, notable logic changes). +- **`trainer.model.lora.init_adapter_path`**: Added optional RL LoRA warm-start path. Loads a saved LoRA adapter into the created adapter slot before training updates. (2026-04-01) - **`log.file` and `log.env_worker_logs` removed**: Removed `log.file` (from `LogConfig` and `SharedLogConfig`) and `log.env_worker_logs` (from `LogConfig`). Python file logging is replaced by deployment-level capture. Existing configs using these fields must delete them. Log paths unified: `.stdout` files renamed to `.log`, SLURM logs moved from `slurm/` to `logs/`. (2026-03-31) - **`trainer.log.ranks_filter` (NEW)**: Added `ranks_filter: list[int]` to `TrainerLogConfig` (default: `[0]`). Controls which ranks appear in trainer console output via torchrun's `--local-ranks-filter`. (2026-03-31) - **`wandb.log_extras.sample_ratio` / monitor sample logging defaults**: `wandb.log_extras.sample_ratio` is now actually applied to W&B sample-table logging via the shared monitor sampler (it was previously a no-op for WandB). Separately, the orchestrator no longer hard-caps sample logging to 8 rollouts before monitor-level sampling runs, so when monitor `sample_ratio` is `None`, monitors now receive and may log the full rollout batch for a step instead of at most 8 rollouts. This affects both W&B and Prime monitor sample logging behavior. 
(2026-03-27) diff --git a/configs/ci/integration/rl_lora_init/resume.toml b/configs/ci/integration/rl_lora_init/resume.toml new file mode 100644 index 0000000000..e4e324be75 --- /dev/null +++ b/configs/ci/integration/rl_lora_init/resume.toml @@ -0,0 +1,34 @@ +max_steps = 25 +seq_len = 2048 + +[ckpt] +resume_step = 20 + +[model] +name = "PrimeIntellect/Qwen3-0.6B" + +[trainer.optim] +lr = 2e-5 + +[trainer.model.lora] +rank = 8 + +[trainer.ckpt.weights] +save_adapter_separately = true + +[orchestrator] +batch_size = 64 +rollouts_per_example = 16 + +[orchestrator.model.lora] +name = "r8-init-cont" + +[orchestrator.sampling] +max_tokens = 128 + +[[orchestrator.env]] +id = "reverse-text" + +[inference] +enable_lora = true +gpu_memory_utilization = 0.7 diff --git a/configs/ci/integration/rl_lora_init/sft.toml b/configs/ci/integration/rl_lora_init/sft.toml new file mode 100644 index 0000000000..cf77873551 --- /dev/null +++ b/configs/ci/integration/rl_lora_init/sft.toml @@ -0,0 +1,29 @@ +max_steps = 100 + +[ckpt] + +[ckpt.weights] +save_adapter_separately = true + +[model] +name = "PrimeIntellect/Qwen3-0.6B" + +[model.lora] +rank = 8 +target_modules = [ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", +] + +[data] +name = "PrimeIntellect/Reverse-Text-SFT" +batch_size = 16 +seq_len = 1024 + +[optim] +lr = 2e-4 diff --git a/configs/ci/integration/rl_lora_init/start.toml b/configs/ci/integration/rl_lora_init/start.toml new file mode 100644 index 0000000000..082e71739f --- /dev/null +++ b/configs/ci/integration/rl_lora_init/start.toml @@ -0,0 +1,33 @@ +max_steps = 20 +seq_len = 2048 + +[ckpt] + +[model] +name = "PrimeIntellect/Qwen3-0.6B" + +[trainer.optim] +lr = 2e-5 + +[trainer.model.lora] +rank = 8 + +[trainer.ckpt.weights] +save_adapter_separately = true + +[orchestrator] +batch_size = 64 +rollouts_per_example = 16 + +[orchestrator.model.lora] +name = "r8-init-cont" + +[orchestrator.sampling] +max_tokens = 128 + 
+[[orchestrator.env]] +id = "reverse-text" + +[inference] +enable_lora = true +gpu_memory_utilization = 0.7 diff --git a/src/prime_rl/configs/trainer.py b/src/prime_rl/configs/trainer.py index 1508a0a683..723815c8d2 100644 --- a/src/prime_rl/configs/trainer.py +++ b/src/prime_rl/configs/trainer.py @@ -158,6 +158,13 @@ class LoRAConfig(BaseConfig): ), ] = [] + init_adapter_path: Annotated[ + Path | None, + Field( + description="Optional adapter path to warm-start training. Loaded into the created adapter slot before RL updates.", + ), + ] = None + class DebugModelConfig(BaseConfig): """Debugging feature around model and distributed training.""" diff --git a/src/prime_rl/trainer/lora.py b/src/prime_rl/trainer/lora.py index 5afac23363..c82ae6ac5b 100644 --- a/src/prime_rl/trainer/lora.py +++ b/src/prime_rl/trainer/lora.py @@ -1,8 +1,13 @@ +import json import re +from dataclasses import dataclass +from pathlib import Path from typing import Dict, List import torch import torch.nn as nn +from safetensors.torch import load_file +from torch.distributed.tensor import DTensor, distribute_tensor from prime_rl.configs.trainer import LoRAConfig from prime_rl.trainer.models.layers.lora import MultiLoRALinear, MultiLoRAModule @@ -11,6 +16,12 @@ from prime_rl.trainer.runs import get_multi_run_manager from prime_rl.utils.logger import get_logger +_MOE_LORA_KEY_RE = re.compile( + r"(?P.*\.experts)\.(?P\d+)\.(?Pgate_proj|down_proj|up_proj)\.(?Plora_[AB])(?:\.(?:default|\d+))?(?:\.weight)?" +) +ADAPTER_CONFIG_NAME = "adapter_config.json" +ADAPTER_WEIGHT_FILENAMES = ("adapter_model.safetensors", "adapter_model.bin") + def strip_lora_from_state_dict(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """Strip LoRA from the state dict.""" @@ -22,6 +33,13 @@ def strip_lora_from_state_dict(state_dict: dict[str, torch.Tensor]) -> dict[str, return new_state_dict +def _strip_adapter_export_prefix(key: str) -> str: + prefix = "base_model.model." 
+ if key.startswith(prefix): + return key[len(prefix) :] + return key + + def _get_module_by_name(model: nn.Module, module_name: str) -> nn.Module: """Get a module by its fully qualified name.""" parts = module_name.split(".") @@ -160,7 +178,6 @@ def apply_lora_to_model(model: nn.Module, config: LoRAConfig) -> None: for module_name in target_modules: base_module = _get_module_by_name(model, module_name) - # Handle Linear layers if isinstance(base_module, nn.Linear): lora_module = MultiLoRALinear( base_layer=base_module, @@ -276,6 +293,230 @@ def save_lora_config(model: nn.Module, save_path, rank: int, alpha: float, dropo "modules_to_save": sorted(list(modules_to_save)) if modules_to_save else None, } - config_path = save_path / "adapter_config.json" + config_path = save_path / ADAPTER_CONFIG_NAME with open(config_path, "w") as f: json.dump(adapter_config, f, indent=2) + + +def resolve_adapter_dir(init_adapter_path: Path) -> Path: + """Resolve init_adapter_path to the adapter directory containing adapter_config.json.""" + if init_adapter_path.is_file(): + if init_adapter_path.name not in ADAPTER_WEIGHT_FILENAMES: + raise ValueError( + f"init_adapter_path file must be one of {ADAPTER_WEIGHT_FILENAMES}, got {init_adapter_path}" + ) + adapter_dir = init_adapter_path.parent + elif init_adapter_path.is_dir(): + adapter_dir = init_adapter_path + else: + raise ValueError(f"init_adapter_path does not exist: {init_adapter_path}") + if not (adapter_dir / ADAPTER_CONFIG_NAME).exists(): + raise ValueError(f"Adapter directory is missing {ADAPTER_CONFIG_NAME}: {adapter_dir}") + return adapter_dir + + +def _resolve_adapter_weights(adapter_dir: Path) -> Path: + for name in ADAPTER_WEIGHT_FILENAMES: + candidate = adapter_dir / name + if candidate.exists(): + return candidate + raise ValueError(f"No {ADAPTER_WEIGHT_FILENAMES} found in {adapter_dir}") + + +def _load_adapter_config(adapter_dir: Path, lora_config: LoRAConfig) -> dict: + with open(adapter_dir / ADAPTER_CONFIG_NAME) as f: 
+ adapter_config = json.load(f) + if adapter_config.get("peft_type") != "LORA": + raise ValueError( + f"init_adapter_path only supports LoRA adapters, got peft_type={adapter_config.get('peft_type')}" + ) + if adapter_config.get("r") != lora_config.rank: + raise ValueError(f"init_adapter_path rank mismatch: expected {lora_config.rank}, got {adapter_config.get('r')}") + if adapter_config.get("lora_alpha") != lora_config.alpha: + raise ValueError( + f"init_adapter_path alpha mismatch: expected {lora_config.alpha}, got {adapter_config.get('lora_alpha')}" + ) + modules_to_save = adapter_config.get("modules_to_save") + if modules_to_save not in (None, []): + raise ValueError("init_adapter_path does not support adapters with modules_to_save") + return adapter_config + + +def _normalize_lora_key(key: str) -> str: + key = _strip_adapter_export_prefix(key) + if key.endswith(".weight"): + key = key[: -len(".weight")] + key = re.sub(r"\.(lora_[AB])\.(default|\d+)", r".\1.0", key) + key = re.sub(r"\.(lora_[AB])$", r".\1.0", key) + return key + + +def _parse_moe_lora_key(key: str) -> tuple[str, int] | None: + key = _strip_adapter_export_prefix(key) + m = _MOE_LORA_KEY_RE.fullmatch(key) + if m is None: + return None + proj_map = {"gate_proj": "w1", "down_proj": "w2", "up_proj": "w3"} + return f"{m.group('prefix')}.{proj_map[m.group('proj')]}_{m.group('ab')}.0", int(m.group("eid")) + + +def _set_adapter_idx_suffix(key: str, adapter_idx: int) -> str: + if not key.endswith(".0"): + raise ValueError(f"Expected adapter-slot suffix '.0' in key: {key}") + return f"{key[:-2]}.{adapter_idx}" + + +def _is_model_lora_key_for_adapter(key: str, adapter_idx: int) -> bool: + # Support the current standard LoRA and grouped-expert adapter suffixes. 
+ suffixes = ( + f"lora_A.{adapter_idx}", + f"lora_B.{adapter_idx}", + f"w1_lora_A.{adapter_idx}", + f"w1_lora_B.{adapter_idx}", + f"w2_lora_A.{adapter_idx}", + f"w2_lora_B.{adapter_idx}", + f"w3_lora_A.{adapter_idx}", + f"w3_lora_B.{adapter_idx}", + ) + return key.endswith(suffixes) + + +def _get_model_lora_state_keys(model_state: dict[str, torch.Tensor], adapter_idx: int) -> set[str]: + return {key for key in model_state if _is_model_lora_key_for_adapter(key, adapter_idx)} + + +def _get_load_path_label(mapped_keys: dict[str, torch.Tensor]) -> str: + has_moe = any(".experts." in key for key in mapped_keys) + has_dense = any(".lora_" in key and ".experts." not in key for key in mapped_keys) + if has_dense and has_moe: + return "mixed" + if has_moe: + return "moe" + return "normal" + + +def _raise_lora_key_mismatch( + *, + init_adapter_path: Path, + adapter_idx: int, + expected_keys: set[str], + actual_keys: set[str], + load_path: str, +) -> None: + missing = sorted(expected_keys - actual_keys) + unexpected = sorted(actual_keys - expected_keys) + matched = len(expected_keys & actual_keys) + raise ValueError( + "LoRA key mismatch. 
" + f"adapter_path={init_adapter_path}, adapter_idx={adapter_idx}, load_path={load_path}, " + f"loaded_keys={len(actual_keys)}, matched_model_keys={matched}/{len(expected_keys)}, " + f"missing_sample={missing[:5]}, unexpected_sample={unexpected[:5]}" + ) + + +@dataclass(frozen=True) +class PreparedInitAdapter: + init_adapter_path: Path + slot0_tensors: dict[str, torch.Tensor] + load_path: str + + def apply_to_model(self, model: nn.Module, adapter_idx: int = 0) -> None: + model_state = model.state_dict() + expected_keys = _get_model_lora_state_keys(model_state, adapter_idx) + mapped = { + _set_adapter_idx_suffix(key, adapter_idx): value + for key, value in self.slot0_tensors.items() + } + + if set(mapped) != expected_keys: + _raise_lora_key_mismatch( + init_adapter_path=self.init_adapter_path, + adapter_idx=adapter_idx, + expected_keys=expected_keys, + actual_keys=set(mapped), + load_path=self.load_path, + ) + + aligned = {} + for key, value in mapped.items(): + target = model_state[key] + if value.shape != target.shape: + raise ValueError( + f"LoRA tensor shape mismatch for {key}: expected {target.shape}, got {value.shape}" + ) + value = value.to(dtype=target.dtype) + if isinstance(target, DTensor): + aligned[key] = distribute_tensor( + value.to(device=target.device), + target.device_mesh, + target.placements, + ) + else: + aligned[key] = value.to(device=target.device) + + model.load_state_dict(aligned, strict=False) + + def register_creation_hook(self, model: nn.Module) -> None: + if getattr(model, "_prime_init_adapter_creation_hook_registered", False): + return + + def _apply_prepared_init_adapter(idx: int, _run_id: str) -> None: + self.apply_to_model(model, adapter_idx=idx) + + get_multi_run_manager().register_creation_hook(_apply_prepared_init_adapter) + setattr(model, "_prime_init_adapter_creation_hook_registered", True) + + +def prepare_init_adapter(model: nn.Module, init_adapter_path: Path, lora_config: LoRAConfig) -> PreparedInitAdapter: + adapter_dir = 
resolve_adapter_dir(init_adapter_path) + _load_adapter_config(adapter_dir, lora_config) + weights_path = _resolve_adapter_weights(adapter_dir) + raw = ( + load_file(str(weights_path), device="cpu") + if weights_path.suffix == ".safetensors" + else torch.load(weights_path, map_location="cpu", weights_only=True) + ) + + mapped: dict[str, torch.Tensor] = {} + moe_parts: dict[str, dict[int, torch.Tensor]] = {} + for key, value in raw.items(): + if "lora_A" not in key and "lora_B" not in key: + continue + moe = _parse_moe_lora_key(key) + if moe is not None: + target_key, expert_id = moe + moe_parts.setdefault(target_key, {})[expert_id] = value + else: + mapped[_normalize_lora_key(key)] = value + + for target_key, parts in moe_parts.items(): + count = len(parts) + if set(parts) != set(range(count)): + raise ValueError(f"Missing MoE expert slices for {target_key}") + mapped[target_key] = torch.stack([parts[i] for i in range(count)], dim=0) + + if not mapped: + raise ValueError("No LoRA tensors found in init adapter") + + model_state = model.state_dict() + expected_keys = _get_model_lora_state_keys(model_state, 0) + load_path = _get_load_path_label(mapped) + if set(mapped) != expected_keys: + _raise_lora_key_mismatch( + init_adapter_path=init_adapter_path, + adapter_idx=0, + expected_keys=expected_keys, + actual_keys=set(mapped), + load_path=load_path, + ) + + for key, value in mapped.items(): + target = model_state[key] + if value.shape != target.shape: + raise ValueError(f"LoRA tensor shape mismatch for {key}: expected {target.shape}, got {value.shape}") + + return PreparedInitAdapter( + init_adapter_path=init_adapter_path, + slot0_tensors={key: value.detach().to("cpu", non_blocking=False) for key, value in mapped.items()}, + load_path=load_path, + ) diff --git a/src/prime_rl/trainer/model.py b/src/prime_rl/trainer/model.py index 188b0f61f7..1917e81e2a 100644 --- a/src/prime_rl/trainer/model.py +++ b/src/prime_rl/trainer/model.py @@ -23,7 +23,12 @@ from 
prime_rl.configs.trainer import ActivationCheckpointConfig, CompileConfig, ModelConfig, TokenizerConfig from prime_rl.trainer.distributed import DeepEPExpertParallel -from prime_rl.trainer.lora import apply_lora_to_model, freeze_all_except_lora_and_specified, strip_lora_from_state_dict +from prime_rl.trainer.lora import ( + apply_lora_to_model, + freeze_all_except_lora_and_specified, + prepare_init_adapter, + strip_lora_from_state_dict, +) from prime_rl.trainer.models import ( AutoModelForCausalLMPrimeRL, PreTrainedModelPrimeRL, @@ -831,6 +836,15 @@ def setup_model( else: load_dcp_from_hf(model, config, parallel_dims) + prepared_init_adapter = None + if config.lora is not None and config.lora.init_adapter_path is not None: + prepared_init_adapter = prepare_init_adapter(model, config.lora.init_adapter_path, config.lora) + + if prepared_init_adapter is not None: + if not loading_from_checkpoint_later: + prepared_init_adapter.apply_to_model(model) + prepared_init_adapter.register_creation_hook(model) + _reset_runtime_moe_buffers(model) return model diff --git a/tests/integration/test_rl_lora_init_continuation.py b/tests/integration/test_rl_lora_init_continuation.py new file mode 100644 index 0000000000..77f1deebcb --- /dev/null +++ b/tests/integration/test_rl_lora_init_continuation.py @@ -0,0 +1,212 @@ +from pathlib import Path +from typing import Callable + +import pytest + +from prime_rl.trainer.weights import load_state_dict +from tests.conftest import ProcessResult +from tests.utils import ( + check_loss_goes_down, + check_metric_in_range, + check_no_error, + check_reward_goes_up, + check_reward_in_range, + strip_escape_codes, +) + +pytestmark = [pytest.mark.gpu, pytest.mark.slow] + + +TIMEOUT = 600 +SFT_STEP = 100 +RL_STEP = 20 + + +def read_log_lines(*paths: Path) -> list[str]: + for path in paths: + if path.exists(): + with open(path, "r") as f: + return strip_escape_codes(f.read()).splitlines() + raise FileNotFoundError(f"None of the expected log paths exist: 
{[p.as_posix() for p in paths]}") + + +def assert_adapter_checkpoint(adapter_dir: Path) -> None: + assert (adapter_dir / "adapter_config.json").exists() + state_dict = load_state_dict(adapter_dir) + assert state_dict + assert all(key.startswith("base_model.model.") for key in state_dict) + assert any(key.endswith("lora_A.weight") for key in state_dict) + assert any(key.endswith("lora_B.weight") for key in state_dict) + + +@pytest.fixture(scope="module") +def wandb_name(branch_name: str) -> str: + return f"test-rl-lora-init-continuation-{branch_name}" + + +@pytest.fixture(scope="module") +def sft_output_dir(output_dir: Path) -> Path: + path = output_dir / "sft_lora" + path.mkdir(parents=True, exist_ok=True) + return path + + +@pytest.fixture(scope="module") +def rl_output_dir(output_dir: Path) -> Path: + path = output_dir / "rl_lora_init" + path.mkdir(parents=True, exist_ok=True) + return path + + +@pytest.fixture(scope="module") +def sft_lora_process( + run_process: Callable[..., ProcessResult], + wandb_project: str, + branch_name: str, + sft_output_dir: Path, +) -> ProcessResult: + cmd = [ + "uv", + "run", + "sft", + "@", + "configs/ci/integration/rl_lora_init/sft.toml", + "--deployment.num-gpus", + "2", + "--clean-output-dir", + "--wandb.project", + wandb_project, + "--wandb.name", + f"test-sft-lora-init-continuation-{branch_name}", + "--output-dir", + sft_output_dir.as_posix(), + ] + return run_process(cmd, timeout=TIMEOUT) + + +@pytest.fixture(scope="module") +def init_adapter_dir(sft_lora_process: ProcessResult, sft_output_dir: Path) -> Path: + assert sft_lora_process.returncode == 0, f"SFT process has non-zero return code ({sft_lora_process})" + adapter_dir = sft_output_dir / "weights" / f"step_{SFT_STEP}" / "lora_adapters" + assert_adapter_checkpoint(adapter_dir) + return adapter_dir + + +@pytest.fixture(scope="module") +def rl_process( + init_adapter_dir: Path, + run_process: Callable[..., ProcessResult], + rl_output_dir: Path, + wandb_project: str, + 
wandb_name: str, +) -> ProcessResult: + cmd = [ + "uv", + "run", + "rl", + "@", + "configs/ci/integration/rl_lora_init/start.toml", + "--trainer.model.lora.init_adapter_path", + init_adapter_dir.as_posix(), + "--wandb.project", + wandb_project, + "--wandb.name", + wandb_name, + "--output-dir", + rl_output_dir.as_posix(), + ] + return run_process(cmd, env={"VLLM_ALLOW_RUNTIME_LORA_UPDATING": "True"}, timeout=TIMEOUT) + + +@pytest.fixture(scope="module") +def rl_resume_process( + rl_process, + init_adapter_dir: Path, + run_process: Callable[..., ProcessResult], + rl_output_dir: Path, + wandb_project: str, + wandb_name: str, +) -> ProcessResult: + assert rl_process.returncode == 0, f"RL init process has non-zero return code ({rl_process})" + cmd = [ + "uv", + "run", + "rl", + "@", + "configs/ci/integration/rl_lora_init/resume.toml", + "--trainer.model.lora.init_adapter_path", + init_adapter_dir.as_posix(), + "--wandb.project", + wandb_project, + "--wandb.name", + f"{wandb_name}-resume", + "--output-dir", + rl_output_dir.as_posix(), + ] + return run_process(cmd, env={"VLLM_ALLOW_RUNTIME_LORA_UPDATING": "True"}, timeout=TIMEOUT) + + +def test_sft_lora_no_error(sft_lora_process: ProcessResult): + assert sft_lora_process.returncode == 0, f"Process has non-zero return code ({sft_lora_process})" + + +def test_sft_lora_loss_goes_down(sft_lora_process: ProcessResult, sft_output_dir: Path): + trainer_stdout = read_log_lines( + sft_output_dir / "logs" / "trainer.log", + sft_output_dir / "logs" / "trainer" / "rank_0.log", + ) + check_loss_goes_down(trainer_stdout) + check_metric_in_range(trainer_stdout, metric_name="Loss", pattern=r"Loss:\s*(\d+\.\d{4})", min_threshold=None, max_threshold=1.5) + + +def test_init_adapter_checkpoint_written(init_adapter_dir: Path): + assert_adapter_checkpoint(init_adapter_dir) + + +@pytest.fixture(scope="module") +def rl_no_error(rl_process: ProcessResult, rl_output_dir: Path): + check_no_error(rl_process, rl_output_dir) + + +def 
test_reward_goes_up(rl_process: ProcessResult, rl_no_error, rl_output_dir: Path): + orchestrator_stdout = read_log_lines( + rl_output_dir / "logs" / "orchestrator.log", + rl_output_dir / "logs" / "orchestrator.stdout", + ) + check_reward_goes_up(orchestrator_stdout) + + +def test_reward_in_range(rl_process: ProcessResult, rl_no_error, rl_output_dir: Path): + orchestrator_stdout = read_log_lines( + rl_output_dir / "logs" / "orchestrator.log", + rl_output_dir / "logs" / "orchestrator.stdout", + ) + check_reward_in_range(orchestrator_stdout, min_threshold=0.65) + + +def test_rl_adapter_checkpoint_written(rl_process: ProcessResult, rl_no_error, rl_output_dir: Path): + adapter_dir = rl_output_dir / "weights" / f"step_{RL_STEP}" / "lora_adapters" + assert_adapter_checkpoint(adapter_dir) + + +@pytest.fixture(scope="module") +def rl_resume_no_error(rl_resume_process: ProcessResult, rl_output_dir: Path): + check_no_error(rl_resume_process, rl_output_dir) + + +def test_reward_in_range_resume(rl_resume_process: ProcessResult, rl_resume_no_error, rl_output_dir: Path): + orchestrator_stdout = read_log_lines( + rl_output_dir / "logs" / "orchestrator.log", + rl_output_dir / "logs" / "orchestrator.stdout", + ) + check_reward_in_range(orchestrator_stdout, min_threshold=0.65) + + +def test_resume_restores_training_progress(rl_resume_process: ProcessResult, rl_resume_no_error, rl_output_dir: Path): + trainer_stdout = read_log_lines( + rl_output_dir / "logs" / "trainer.log", + rl_output_dir / "logs" / "trainer" / "rank_0.log", + ) + + assert any(f"Resuming training from checkpoint step {RL_STEP}" in line for line in trainer_stdout) + assert any(f"Starting from step {RL_STEP}" in line for line in trainer_stdout) diff --git a/tests/unit/test_lora_broadcast_export.py b/tests/unit/test_lora_broadcast_export.py new file mode 100644 index 0000000000..f04e9d7859 --- /dev/null +++ b/tests/unit/test_lora_broadcast_export.py @@ -0,0 +1,90 @@ +import tempfile +from pathlib import Path + +import 
torch +import torch.distributed as dist +import torch.nn as nn + +from prime_rl.configs.trainer import LoRAConfig +from prime_rl.trainer.ckpt import WeightCheckpointManager +from prime_rl.trainer.models.layers.lora.multi_linear import MultiLoRALinear +from prime_rl.trainer.models.layers.lora.multi_moe import MultiLoRAGroupedExperts +from prime_rl.trainer.models.layers.moe import GroupedExperts +from prime_rl.trainer.runs import get_multi_run_manager, setup_multi_run_manager + + +def test_get_state_dict_for_run_stays_in_internal_key_space() -> None: + with tempfile.TemporaryDirectory() as td: + init_file = Path(td) / "pg_init" + dist.init_process_group(backend="gloo", init_method=f"file://{init_file}", rank=0, world_size=1) + try: + setup_multi_run_manager(Path(td), 1, torch.device("cpu"), LoRAConfig(rank=8, target_modules=["q_proj"])) + mgr = get_multi_run_manager() + mod = MultiLoRALinear(nn.Linear(4, 6, bias=False), rank=8, n_adapters=1) + mgr.register_module("model.layers.0.self_attn.q_proj", mod) + + state = mgr.get_state_dict_for_run(0) + + assert "model.layers.0.self_attn.q_proj.lora_A.weight" in state + assert "model.layers.0.self_attn.q_proj.lora_B.weight" in state + assert not any(key.startswith("base_model.model.") for key in state) + finally: + dist.destroy_process_group() + + +def test_get_state_dict_for_run_keeps_moe_keys_internal_and_unprefixed() -> None: + class _FakeExperts(GroupedExperts): + def __init__(self): + nn.Module.__init__(self) + self.w1 = torch.zeros((2, 4, 4)) + self.w2 = torch.zeros((2, 4, 4)) + self.w3 = torch.zeros((2, 4, 4)) + self.num_experts = 2 + self.hidden_size = 4 + self.intermediate_size = 4 + + def forward(self, hidden_states, num_tokens_per_expert): + raise NotImplementedError + + with tempfile.TemporaryDirectory() as td: + init_file = Path(td) / "pg_init" + dist.init_process_group(backend="gloo", init_method=f"file://{init_file}", rank=0, world_size=1) + try: + setup_multi_run_manager(Path(td), 1, torch.device("cpu"), 
LoRAConfig(rank=8, target_modules=["experts"])) + mgr = get_multi_run_manager() + mod = MultiLoRAGroupedExperts(_FakeExperts(), rank=8, n_adapters=1) + mgr.register_module("model.layers.0.mlp.experts", mod) + + state = mgr.get_state_dict_for_run(0) + + assert "model.layers.0.mlp.experts.0.gate_proj.lora_A.weight" in state + assert "model.layers.0.mlp.experts.0.gate_proj.lora_B.weight" in state + assert "model.layers.0.mlp.experts.1.up_proj.lora_A.weight" in state + assert not any(key.startswith("base_model.model.") for key in state) + finally: + dist.destroy_process_group() + + +def test_weight_checkpoint_adapter_export_adds_peft_prefix_exactly_once() -> None: + with tempfile.TemporaryDirectory() as td: + init_file = Path(td) / "pg_init" + dist.init_process_group(backend="gloo", init_method=f"file://{init_file}", rank=0, world_size=1) + try: + lora_config = LoRAConfig(rank=8, target_modules=["q_proj"]) + setup_multi_run_manager(Path(td), 1, torch.device("cpu"), lora_config) + mgr = get_multi_run_manager() + mod = MultiLoRALinear(nn.Linear(4, 6, bias=False), rank=8, n_adapters=1) + mgr.register_module("model.layers.0.self_attn.q_proj", mod) + + ckpt_config = type( + "Cfg", + (), + {"save_format": "safetensors", "save_sharded": True, "save_adapter_separately": True}, + )() + state = WeightCheckpointManager(Path(td), ckpt_config, lora_config).get_run_adapter_state_dict() + + assert "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight" in state + assert "base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight" in state + assert not any(key.startswith("base_model.model.base_model.model.") for key in state) + finally: + dist.destroy_process_group() diff --git a/tests/unit/train/test_lora_init_continuation.py b/tests/unit/train/test_lora_init_continuation.py new file mode 100644 index 0000000000..caef799a4c --- /dev/null +++ b/tests/unit/train/test_lora_init_continuation.py @@ -0,0 +1,332 @@ +import tempfile +from pathlib import Path + +import pytest 
+import torch +import torch.distributed as dist +from safetensors.torch import save_file + +from prime_rl.configs.trainer import LoRAConfig +from prime_rl.trainer.lora import ( + prepare_init_adapter, +) + + +class _FakeModel: + pass + + +class _PreparedAdapterRecorder: + def __init__(self) -> None: + self.calls: list[tuple[object, int]] = [] + + def apply_to_model(self, model, adapter_idx: int = 0) -> None: + self.calls.append((model, adapter_idx)) + + def register_creation_hook(self, model) -> None: + if getattr(model, "_prime_init_adapter_creation_hook_registered", False): + return + + def _apply_prepared_init_adapter(idx: int, _run_id: str) -> None: + self.apply_to_model(model, adapter_idx=idx) + + from prime_rl.trainer.lora import get_multi_run_manager + + get_multi_run_manager().register_creation_hook(_apply_prepared_init_adapter) + setattr(model, "_prime_init_adapter_creation_hook_registered", True) + + +def _write_adapter_config(adapter_dir: Path, rank: int = 1, alpha: int = 1) -> None: + (adapter_dir / "adapter_config.json").write_text( + f'{{"peft_type":"LORA","r":{rank},"lora_alpha":{alpha},"modules_to_save":null}}' + ) + + +def test_prepared_init_adapter_registers_creation_hook_for_created_run_slots() -> None: + model = _FakeModel() + prepared_adapter = _PreparedAdapterRecorder() + creation_hooks = [] + + class _Manager: + def register_creation_hook(self, hook): + creation_hooks.append(hook) + + from unittest.mock import patch + + with patch("prime_rl.trainer.lora.get_multi_run_manager", return_value=_Manager()): + prepared_adapter.register_creation_hook(model) + assert len(creation_hooks) == 1 + creation_hooks[0](12, "run_default") + + assert prepared_adapter.calls == [(model, 12)] + + +def test_prepared_init_adapter_registers_creation_hook_only_once() -> None: + model = _FakeModel() + prepared_adapter = _PreparedAdapterRecorder() + creation_hooks = [] + + class _Manager: + def register_creation_hook(self, hook): + creation_hooks.append(hook) + + from 
unittest.mock import patch + + with patch("prime_rl.trainer.lora.get_multi_run_manager", return_value=_Manager()): + prepared_adapter.register_creation_hook(model) + prepared_adapter.register_creation_hook(model) + + assert len(creation_hooks) == 1 + + +def test_prepared_init_adapter_creation_hook_does_not_apply_current_slot() -> None: + """Registering the creation hook must not apply the adapter to the current slot; it is applied only when the hook later fires with a new adapter index.""" + model = _FakeModel() + prepared_adapter = _PreparedAdapterRecorder() + creation_hooks = [] + + class _Manager: + def register_creation_hook(self, hook): + creation_hooks.append(hook) + + from unittest.mock import patch + + with patch("prime_rl.trainer.lora.get_multi_run_manager", return_value=_Manager()): + prepared_adapter.register_creation_hook(model) + + assert prepared_adapter.calls == [] + assert len(creation_hooks) == 1 + creation_hooks[0](7, "run_after_resume") + assert prepared_adapter.calls == [(model, 7)] + + +def test_prepare_init_adapter_supports_dtensor_targets_and_preserves_values(tmp_path: Path) -> None: + """DTensor targets: loaded values remain DTensors (have a ``device_mesh``) and the saved float64 weights are preserved, cast to the target's float32 dtype.""" + adapter_dir = tmp_path / "adapter" + adapter_dir.mkdir() + _write_adapter_config(adapter_dir) + save_file( + { + "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight": torch.tensor([[1.0, 2.0]], dtype=torch.float64), + "base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight": torch.tensor([[3.0], [4.0]], dtype=torch.float64), + }, + str(adapter_dir / "adapter_model.safetensors"), + ) + + with tempfile.TemporaryDirectory() as tmpdir: + init_file = Path(tmpdir) / "pg_init" + dist.init_process_group(backend="gloo", init_method=f"file://{init_file}", rank=0, world_size=1) + try: + from torch.distributed.device_mesh import init_device_mesh + from torch.distributed.tensor import Replicate, distribute_tensor + + mesh = init_device_mesh("cpu", (1,)) + + class _Model: + def __init__(self): + self._state = { + "model.layers.0.self_attn.q_proj.lora_A.0": distribute_tensor( + torch.zeros((1, 2), dtype=torch.float32), mesh, [Replicate()] + ), + "model.layers.0.self_attn.q_proj.lora_B.0": distribute_tensor( + 
torch.zeros((2, 1), dtype=torch.float32), mesh, [Replicate()] + ), + "model.layers.0.self_attn.q_proj.lora_A.3": distribute_tensor( + torch.zeros((1, 2), dtype=torch.float32), mesh, [Replicate()] + ), + "model.layers.0.self_attn.q_proj.lora_B.3": distribute_tensor( + torch.zeros((2, 1), dtype=torch.float32), mesh, [Replicate()] + ), + } + self.loaded = None + + def state_dict(self): + return self._state + + def load_state_dict(self, aligned, strict=False): + self.loaded = aligned + + model = _Model() + prepared = prepare_init_adapter(model, adapter_dir, LoRAConfig(rank=1, alpha=1)) + prepared.apply_to_model(model, adapter_idx=3) + + assert model.loaded is not None + assert all(hasattr(value, "device_mesh") for value in model.loaded.values()) + assert torch.equal( + model.loaded["model.layers.0.self_attn.q_proj.lora_A.3"].to_local(), + torch.tensor([[1.0, 2.0]], dtype=torch.float32), + ) + assert torch.equal( + model.loaded["model.layers.0.self_attn.q_proj.lora_B.3"].to_local(), + torch.tensor([[3.0], [4.0]], dtype=torch.float32), + ) + finally: + dist.destroy_process_group() + + +def test_prepare_init_adapter_nonzero_adapter_idx_preserves_layer_zero_path(tmp_path: Path) -> None: + """A nonzero target slot writes only that slot's keys: with adapter_idx=12 only the ``.12`` keys are loaded, and the trailing slot index is not confused with the ``layers.0`` path segment.""" + adapter_dir = tmp_path / "adapter" + adapter_dir.mkdir() + _write_adapter_config(adapter_dir) + save_file( + { + "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight": torch.tensor([[1.0, 2.0]]), + "base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight": torch.tensor([[3.0], [4.0]]), + }, + str(adapter_dir / "adapter_model.safetensors"), + ) + + class _Model: + def __init__(self): + self._state = { + "model.layers.0.self_attn.q_proj.lora_A.0": torch.zeros((1, 2)), + "model.layers.0.self_attn.q_proj.lora_B.0": torch.zeros((2, 1)), + "model.layers.0.self_attn.q_proj.lora_A.12": torch.zeros((1, 2)), + "model.layers.0.self_attn.q_proj.lora_B.12": torch.zeros((2, 1)), + } + self.loaded = None + + def state_dict(self): + return self._state + + def load_state_dict(self, 
aligned, strict=False): + self.loaded = aligned + + model = _Model() + prepared = prepare_init_adapter(model, adapter_dir, LoRAConfig(rank=1, alpha=1)) + prepared.apply_to_model(model, adapter_idx=12) + + assert model.loaded is not None + assert set(model.loaded) == { + "model.layers.0.self_attn.q_proj.lora_A.12", + "model.layers.0.self_attn.q_proj.lora_B.12", + } + + +def test_prepare_init_adapter_nonzero_adapter_idx_preserves_moe_layer_zero_path_and_values(tmp_path: Path) -> None: + """MoE targets at a nonzero slot: per-expert gate/down/up LoRA weights from the safetensors file are stacked into the fused w1/w2/w3 tensors for slot 4 only, cast to float32, with the ``layers.0`` path preserved.""" + adapter_dir = tmp_path / "adapter" + adapter_dir.mkdir() + _write_adapter_config(adapter_dir) + save_file( + { + "base_model.model.model.layers.0.mlp.experts.0.gate_proj.lora_A.weight": torch.tensor([[1.0, 2.0]], dtype=torch.float64), + "base_model.model.model.layers.0.mlp.experts.0.gate_proj.lora_B.weight": torch.tensor([[3.0], [4.0]], dtype=torch.float64), + "base_model.model.model.layers.0.mlp.experts.1.gate_proj.lora_A.weight": torch.tensor([[5.0, 6.0]], dtype=torch.float64), + "base_model.model.model.layers.0.mlp.experts.1.gate_proj.lora_B.weight": torch.tensor([[7.0], [8.0]], dtype=torch.float64), + "base_model.model.model.layers.0.mlp.experts.0.down_proj.lora_A.weight": torch.tensor([[11.0, 12.0]], dtype=torch.float64), + "base_model.model.model.layers.0.mlp.experts.0.down_proj.lora_B.weight": torch.tensor([[13.0], [14.0]], dtype=torch.float64), + "base_model.model.model.layers.0.mlp.experts.1.down_proj.lora_A.weight": torch.tensor([[15.0, 16.0]], dtype=torch.float64), + "base_model.model.model.layers.0.mlp.experts.1.down_proj.lora_B.weight": torch.tensor([[17.0], [18.0]], dtype=torch.float64), + "base_model.model.model.layers.0.mlp.experts.0.up_proj.lora_A.weight": torch.tensor([[21.0, 22.0]], dtype=torch.float64), + "base_model.model.model.layers.0.mlp.experts.0.up_proj.lora_B.weight": torch.tensor([[23.0], [24.0]], dtype=torch.float64), + "base_model.model.model.layers.0.mlp.experts.1.up_proj.lora_A.weight": torch.tensor([[25.0, 26.0]], dtype=torch.float64), + 
"base_model.model.model.layers.0.mlp.experts.1.up_proj.lora_B.weight": torch.tensor([[27.0], [28.0]], dtype=torch.float64), + }, + str(adapter_dir / "adapter_model.safetensors"), + ) + + class _Model: + def __init__(self): + stacked = torch.zeros((2, 1, 2), dtype=torch.float32) + stacked_b = torch.zeros((2, 2, 1), dtype=torch.float32) + self._state = { + "model.layers.0.mlp.experts.w1_lora_A.0": stacked.clone(), + "model.layers.0.mlp.experts.w1_lora_B.0": stacked_b.clone(), + "model.layers.0.mlp.experts.w2_lora_A.0": stacked.clone(), + "model.layers.0.mlp.experts.w2_lora_B.0": stacked_b.clone(), + "model.layers.0.mlp.experts.w3_lora_A.0": stacked.clone(), + "model.layers.0.mlp.experts.w3_lora_B.0": stacked_b.clone(), + "model.layers.0.mlp.experts.w1_lora_A.4": stacked.clone(), + "model.layers.0.mlp.experts.w1_lora_B.4": stacked_b.clone(), + "model.layers.0.mlp.experts.w2_lora_A.4": stacked.clone(), + "model.layers.0.mlp.experts.w2_lora_B.4": stacked_b.clone(), + "model.layers.0.mlp.experts.w3_lora_A.4": stacked.clone(), + "model.layers.0.mlp.experts.w3_lora_B.4": stacked_b.clone(), + } + self.loaded = None + + def state_dict(self): + return self._state + + def load_state_dict(self, aligned, strict=False): + self.loaded = aligned + + model = _Model() + prepared = prepare_init_adapter(model, adapter_dir, LoRAConfig(rank=1, alpha=1)) + prepared.apply_to_model(model, adapter_idx=4) + + assert model.loaded is not None + assert set(model.loaded) == { + "model.layers.0.mlp.experts.w1_lora_A.4", + "model.layers.0.mlp.experts.w1_lora_B.4", + "model.layers.0.mlp.experts.w2_lora_A.4", + "model.layers.0.mlp.experts.w2_lora_B.4", + "model.layers.0.mlp.experts.w3_lora_A.4", + "model.layers.0.mlp.experts.w3_lora_B.4", + } + assert torch.equal( + model.loaded["model.layers.0.mlp.experts.w1_lora_A.4"], + torch.tensor([[[1.0, 2.0]], [[5.0, 6.0]]], dtype=torch.float32), + ) + assert torch.equal( + model.loaded["model.layers.0.mlp.experts.w2_lora_A.4"], + torch.tensor([[[11.0, 12.0]], 
[[15.0, 16.0]]], dtype=torch.float32), + ) + assert torch.equal( + model.loaded["model.layers.0.mlp.experts.w3_lora_B.4"], + torch.tensor([[[23.0], [24.0]], [[27.0], [28.0]]], dtype=torch.float32), + ) + + +def test_prepared_init_adapter_is_reused_without_reloading_files(tmp_path: Path) -> None: + """The prepared adapter is reused in memory: with ``load_file`` patched to raise, applying to two different slots must not re-read the adapter files.""" + adapter_dir = tmp_path / "adapter" + adapter_dir.mkdir() + _write_adapter_config(adapter_dir) + save_file( + { + "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight": torch.tensor([[1.0, 2.0]]), + "base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight": torch.tensor([[3.0], [4.0]]), + }, + str(adapter_dir / "adapter_model.safetensors"), + ) + + class _Model: + def __init__(self): + self._state = { + "model.layers.0.self_attn.q_proj.lora_A.0": torch.zeros((1, 2)), + "model.layers.0.self_attn.q_proj.lora_B.0": torch.zeros((2, 1)), + "model.layers.0.self_attn.q_proj.lora_A.1": torch.zeros((1, 2)), + "model.layers.0.self_attn.q_proj.lora_B.1": torch.zeros((2, 1)), + } + self.loaded = None + + def state_dict(self): + return self._state + + def load_state_dict(self, aligned, strict=False): + self.loaded = aligned + + model = _Model() + prepared = prepare_init_adapter(model, adapter_dir, LoRAConfig(rank=1, alpha=1)) + + from unittest.mock import patch + + with patch("prime_rl.trainer.lora.load_file", side_effect=AssertionError("adapter files should not be re-read")): + prepared.apply_to_model(model, adapter_idx=0) + prepared.apply_to_model(model, adapter_idx=1) + + assert model.loaded is not None + + +def test_prepare_init_adapter_rejects_modules_to_save(tmp_path: Path) -> None: + """An adapter_config.json containing ``modules_to_save`` is rejected with a ValueError whose message mentions ``modules_to_save``.""" + adapter_dir = tmp_path / "adapter" + adapter_dir.mkdir() + (adapter_dir / "adapter_config.json").write_text( + '{"peft_type":"LORA","r":1,"lora_alpha":1,"modules_to_save":["lm_head"]}' + ) + save_file({}, str(adapter_dir / "adapter_model.safetensors")) + + class _Model: + def state_dict(self): + return {} + + with pytest.raises(ValueError, match="modules_to_save"): + 
prepare_init_adapter(_Model(), adapter_dir, LoRAConfig(rank=1, alpha=1))