From 1de9bb7d19463dd8cb919564ab9bc089e7f78cfe Mon Sep 17 00:00:00 2001 From: taivu1998 <46636857+taivu1998@users.noreply.github.com> Date: Wed, 1 Apr 2026 20:13:15 -0700 Subject: [PATCH 1/2] Fix LoRA merged weight exports --- docs/checkpointing.md | 4 +- src/prime_rl/configs/trainer.py | 2 +- src/prime_rl/trainer/ckpt.py | 11 +- src/prime_rl/trainer/lora.py | 85 ++++++++++++++++ src/prime_rl/trainer/weights.py | 7 -- tests/integration/test_sft_lora.py | 20 ++++ tests/unit/train/test_lora.py | 156 +++++++++++++++++++++++++++++ 7 files changed, 274 insertions(+), 11 deletions(-) create mode 100644 tests/unit/train/test_lora.py diff --git a/docs/checkpointing.md b/docs/checkpointing.md index 49f33be2b5..d2b5b2765e 100644 --- a/docs/checkpointing.md +++ b/docs/checkpointing.md @@ -10,6 +10,8 @@ The default checkpoint directory is `checkpoints` and each checkpoint step will Checkpointing is configured with the config key `--ckpt`. One can specify the interval (`--ckpt.interval`), whether to save checkpoints asynchronoously (`--ckpt.save-async`), how many recent step checkpoints to keep on disk (`--ckpt.keep-last`), and keep checkpoints at every N steps permanently (`--ckpt.keep-interval`). By default, we do not checkpoint to save disk space. +When weight export is enabled under `ckpt.weights`, `weights/step_` is intended to be a standalone Hugging Face-compatible checkpoint. If LoRA is enabled, the full checkpoint remains merged and `save_adapter_separately=true` additionally writes an adapter-only export under `weights/step_/lora_adapters`. + ## SFT Let's split the reverse text training SFT example, which does 40 steps by default, into two runs of 20 steps each. @@ -54,4 +56,4 @@ uv run rl \ --orchestrator @ path/to/orch.toml \ --max-steps 20 \ --ckpt.resume-step 10 -``` \ No newline at end of file +``` diff --git a/src/prime_rl/configs/trainer.py b/src/prime_rl/configs/trainer.py index 1508a0a683..1858c91234 100644 --- a/src/prime_rl/configs/trainer.py +++ b/src/prime_rl/configs/trainer.py @@ -538,7 +538,7 @@ class WeightCheckpointConfig(BaseConfig): save_adapter_separately: Annotated[ bool, Field( - description="Whether to save LoRA adapters separately before merging into full model weights.", + description="Whether to also save LoRA adapters separately in addition to merged full model weights.", ), ] = False diff --git a/src/prime_rl/trainer/ckpt.py b/src/prime_rl/trainer/ckpt.py index 4734cfc563..5248a72106 100644 --- a/src/prime_rl/trainer/ckpt.py +++ b/src/prime_rl/trainer/ckpt.py @@ -20,7 +20,7 @@ from transformers.tokenization_utils import PreTrainedTokenizer from prime_rl.configs.trainer import CheckpointConfig, LoRAConfig, WeightCheckpointConfig -from prime_rl.trainer.lora import has_lora_layers, save_lora_config +from prime_rl.trainer.lora import has_lora_layers, merge_lora_state_dict, save_lora_config from prime_rl.trainer.models import PreTrainedModelPrimeRL from prime_rl.trainer.optim import CPUOffloadOptimizer from prime_rl.trainer.runs import Progress, get_multi_run_manager @@ -385,12 +385,19 @@ def save( state_dict = gather_weights_on_master(model, self.world.is_master, dtype=torch.bfloat16) self.logger.debug(f"Gathered weights on master rank in {time.perf_counter() - start_time:.2f} seconds") + lora_enabled = has_lora_layers(model) + if lora_enabled: + self.logger.debug("Merging LoRA weights into full weight checkpoint") + start_time = time.perf_counter() + state_dict = merge_lora_state_dict(model, state_dict) + self.logger.debug(f"Merged LoRA weights in {time.perf_counter() - start_time:.2f} seconds") + # Remove tied weight keys to match original model format if getattr(model.config, "tie_word_embeddings", False): for key in getattr(model, "_tied_weights_keys", []): state_dict.pop(key, None) - if has_lora_layers(model) and self.config.save_adapter_separately: + if lora_enabled and self.config.save_adapter_separately: self.logger.debug("Getting run adapter state dict for weight checkpoint") start_time = time.perf_counter() lora_state_dict = self.get_run_adapter_state_dict() diff --git a/src/prime_rl/trainer/lora.py b/src/prime_rl/trainer/lora.py index 5afac23363..8a4bfb7c37 100644 --- a/src/prime_rl/trainer/lora.py +++ b/src/prime_rl/trainer/lora.py @@ -217,6 +217,91 @@ def has_lora_layers(model: nn.Module) -> bool: return False +def _module_key(prefix: str, suffix: str) -> str: + return f"{prefix}.{suffix}" if prefix else suffix + + +def _get_lora_scaling(module: MultiLoRAModule, run_idx: int) -> float: + return float(module._scaling_factors[run_idx].item()) + + +def _compute_lora_delta( + lora_a: torch.Tensor, + lora_b: torch.Tensor, + scaling: float, + output_dtype: torch.dtype, +) -> torch.Tensor: + matmul_dtype = torch.float32 if output_dtype in {torch.float16, torch.bfloat16} else output_dtype + delta = torch.matmul(lora_b.to(dtype=matmul_dtype), lora_a.to(dtype=matmul_dtype)) + return (delta * scaling).to(dtype=output_dtype) + + +def _merge_multilora_linear( + state_dict: dict[str, torch.Tensor], + prefix: str, + module: MultiLoRALinear, + run_idx: int, +) -> None: + base_key = _module_key(prefix, "base_layer.weight") + lora_a_key = _module_key(prefix, f"lora_A.{run_idx}") + lora_b_key = _module_key(prefix, f"lora_B.{run_idx}") + + base_weight = state_dict[base_key] + delta = _compute_lora_delta( + state_dict[lora_a_key], + state_dict[lora_b_key], + scaling=_get_lora_scaling(module, run_idx), + output_dtype=base_weight.dtype, + ) + state_dict[base_key] = base_weight + delta + + +def _merge_multilora_grouped_experts( + state_dict: dict[str, torch.Tensor], + prefix: str, + module: MultiLoRAGroupedExperts, + run_idx: int, +) -> None: + scaling = _get_lora_scaling(module, run_idx) + weight_specs = ( + ("w1", "w1_lora_A", "w1_lora_B"), + ("w2", "w2_lora_A", "w2_lora_B"), + ("w3", "w3_lora_A", "w3_lora_B"), + ) + for base_name, lora_a_name, lora_b_name in weight_specs: + base_key = _module_key(prefix, f"base_layer.{base_name}") + lora_a_key = _module_key(prefix, f"{lora_a_name}.{run_idx}") + lora_b_key = _module_key(prefix, f"{lora_b_name}.{run_idx}") + + base_weight = state_dict[base_key] + delta = _compute_lora_delta( + state_dict[lora_a_key], + state_dict[lora_b_key], + scaling=scaling, + output_dtype=base_weight.dtype, + ) + state_dict[base_key] = base_weight + delta + + +def merge_lora_state_dict( + model: nn.Module, + state_dict: dict[str, torch.Tensor], + run_idx: int = 0, +) -> dict[str, torch.Tensor]: + """Merge a run's LoRA weights into gathered full-model weights.""" + if not state_dict: + return state_dict + + merged_state_dict = dict(state_dict) + for name, module in model.named_modules(): + if isinstance(module, MultiLoRALinear): + _merge_multilora_linear(merged_state_dict, name, module, run_idx) + elif isinstance(module, MultiLoRAGroupedExperts): + _merge_multilora_grouped_experts(merged_state_dict, name, module, run_idx) + + return clean_lora_state_dict(merged_state_dict) + + def clean_lora_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """Remove LoRA parameters and fix LoRA base layer key names for HF compatibility.""" clean_state_dict = {} diff --git a/src/prime_rl/trainer/weights.py b/src/prime_rl/trainer/weights.py index 840197aec8..a6b0fb6edf 100644 --- a/src/prime_rl/trainer/weights.py +++ b/src/prime_rl/trainer/weights.py @@ -19,9 +19,6 @@ WEIGHTS_NAME, ) -from prime_rl.trainer.lora import ( - clean_lora_state_dict, -) from prime_rl.utils.logger import get_logger PYTORCH_WRAPPER_PREFIXES = ["_fsdp_wrapped_module.", "_orig_module.", "_checkpoint_wrapped_module."] @@ -147,10 +144,6 @@ def gather_weights_on_master( cpu_state[key] = value.to("cpu", non_blocking=False) torch.distributed.barrier() - # Always clean up the state dict for HF compatibility - if any(".base_layer." in key or "lora_A" in key or "lora_B" in key for key in cpu_state.keys()): - cpu_state = clean_lora_state_dict(cpu_state) - return cpu_state diff --git a/tests/integration/test_sft_lora.py b/tests/integration/test_sft_lora.py index ffa9edc976..80ac3d15a1 100644 --- a/tests/integration/test_sft_lora.py +++ b/tests/integration/test_sft_lora.py @@ -21,6 +21,14 @@ def assert_adapter_checkpoint(adapter_dir: Path) -> None: assert all(key.startswith("base_model.model.") for key in state_dict) +def assert_full_checkpoint(weights_dir: Path) -> None: + state_dict = load_state_dict(weights_dir) + assert state_dict + assert all("lora_A" not in key and "lora_B" not in key for key in state_dict) + assert all(".base_layer." not in key for key in state_dict) + assert all(not key.startswith("base_model.model.") for key in state_dict) + + @pytest.fixture(scope="module") def wandb_name(branch_name: str) -> str: """Fixture for W&B name for SFT LoRA CI integration tests.""" @@ -104,6 +112,12 @@ def test_adapter_checkpoint_written(sft_lora_process: ProcessResult, output_dir: assert_adapter_checkpoint(adapter_dir) +def test_full_checkpoint_written(sft_lora_process: ProcessResult, output_dir: Path): + """Tests that the full checkpoint stays HF-compatible when adapters are also exported.""" + weights_dir = output_dir / "weights" / "step_10" + assert_full_checkpoint(weights_dir) + + def test_no_error_resume(sft_lora_resume_process: ProcessResult): """Tests that the SFT LoRA resume process does not fail.""" assert sft_lora_resume_process.returncode == 0, f"Process has non-zero return code ({sft_lora_resume_process})" @@ -122,3 +136,9 @@ def test_adapter_checkpoint_written_resume(sft_lora_resume_process: ProcessResul """Tests that the adapter checkpoint is written after resuming with valid PEFT-compatible keys.""" adapter_dir = output_dir / "weights" / "step_20" / "lora_adapters" assert_adapter_checkpoint(adapter_dir) + + +def test_full_checkpoint_written_resume(sft_lora_resume_process: ProcessResult, output_dir: Path): + """Tests that the resumed full checkpoint stays HF-compatible when adapters are also exported.""" + weights_dir = output_dir / "weights" / "step_20" + assert_full_checkpoint(weights_dir) diff --git a/tests/unit/train/test_lora.py b/tests/unit/train/test_lora.py new file mode 100644 index 0000000000..40a11b98da --- /dev/null +++ b/tests/unit/train/test_lora.py @@ -0,0 +1,156 @@ +import torch +from torch import nn + +from prime_rl.trainer.lora import merge_lora_state_dict +from prime_rl.trainer.models.layers.lora import set_lora_num_tokens, set_multilora_scaling +from prime_rl.trainer.models.layers.lora.multi_linear import MultiLoRALinear +from prime_rl.trainer.models.layers.lora.multi_moe import MultiLoRAGroupedExperts +from prime_rl.trainer.models.layers.moe import GroupedExperts + + +def _set_lora_globals(num_adapters: int, scaling_factors: list[float]) -> None: + set_lora_num_tokens(torch.ones(num_adapters, dtype=torch.int32), reset_reference=True) + set_multilora_scaling(torch.tensor(scaling_factors, dtype=torch.bfloat16), reset_reference=True) + + +def test_merge_lora_state_dict_merges_linear_weights() -> None: + _set_lora_globals(num_adapters=2, scaling_factors=[0.5, 1.5]) + + class Model(nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = MultiLoRALinear(nn.Linear(3, 2, bias=False), rank=2, n_adapters=2) + self.modules_to_save = nn.Linear(2, 2, bias=False) + + model = Model() + base_weight = torch.tensor( + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + ] + ) + lora_a_run_0 = torch.tensor( + [ + [1.0, 0.0, 1.0], + [0.0, 1.0, 1.0], + ] + ) + lora_b_run_0 = torch.tensor( + [ + [1.0, 2.0], + [3.0, 4.0], + ] + ) + lora_a_run_1 = torch.tensor( + [ + [2.0, 1.0, 0.0], + [1.0, 0.0, 2.0], + ] + ) + lora_b_run_1 = torch.tensor( + [ + [1.0, 3.0], + [2.0, 4.0], + ] + ) + modules_to_save_weight = torch.tensor( + [ + [7.0, 8.0], + [9.0, 10.0], + ] + ) + + merged_state_dict = merge_lora_state_dict( + model, + { + "linear.base_layer.weight": base_weight, + "linear.lora_A.0": lora_a_run_0, + "linear.lora_B.0": lora_b_run_0, + "linear.lora_A.1": lora_a_run_1, + "linear.lora_B.1": lora_b_run_1, + "modules_to_save.weight": modules_to_save_weight, + }, + run_idx=1, + ) + + expected_weight = base_weight + 1.5 * torch.matmul(lora_b_run_1, lora_a_run_1) + torch.testing.assert_close(merged_state_dict["linear.weight"], expected_weight) + torch.testing.assert_close(merged_state_dict["modules_to_save.weight"], modules_to_save_weight) + assert all("lora_A" not in key and "lora_B" not in key for key in merged_state_dict) + assert "linear.base_layer.weight" not in merged_state_dict + + +def test_merge_lora_state_dict_merges_grouped_expert_weights() -> None: + _set_lora_globals(num_adapters=1, scaling_factors=[0.5]) + + class Model(nn.Module): + def __init__(self) -> None: + super().__init__() + self.moe = MultiLoRAGroupedExperts( + GroupedExperts(dim=3, hidden_dim=4, num_experts=2, use_grouped_mm=False), + rank=2, + n_adapters=1, + ) + + model = Model() + base_w1 = torch.arange(24, dtype=torch.float32).view(2, 4, 3) + base_w2 = torch.arange(24, 48, dtype=torch.float32).view(2, 3, 4) + base_w3 = torch.arange(48, 72, dtype=torch.float32).view(2, 4, 3) + + w1_lora_a = torch.tensor( + [ + [[1.0, 0.0, 2.0], [0.0, 1.0, 1.0]], + [[2.0, 1.0, 0.0], [1.0, 0.0, 1.0]], + ] + ) + w1_lora_b = torch.tensor( + [ + [[1.0, 2.0], [0.0, 1.0], [2.0, 1.0], [1.0, 0.0]], + [[0.0, 1.0], [1.0, 2.0], [2.0, 0.0], [1.0, 1.0]], + ] + ) + w2_lora_a = torch.tensor( + [ + [[1.0, 0.0, 2.0, 1.0], [0.0, 1.0, 1.0, 0.0]], + [[2.0, 1.0, 0.0, 1.0], [1.0, 0.0, 1.0, 2.0]], + ] + ) + w2_lora_b = torch.tensor( + [ + [[1.0, 0.0], [2.0, 1.0], [0.0, 1.0]], + [[1.0, 2.0], [0.0, 1.0], [2.0, 0.0]], + ] + ) + w3_lora_a = torch.tensor( + [ + [[0.0, 1.0, 2.0], [1.0, 0.0, 1.0]], + [[1.0, 2.0, 0.0], [0.0, 1.0, 1.0]], + ] + ) + w3_lora_b = torch.tensor( + [ + [[2.0, 1.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], + [[1.0, 0.0], [2.0, 1.0], [1.0, 2.0], [0.0, 1.0]], + ] + ) + + merged_state_dict = merge_lora_state_dict( + model, + { + "moe.base_layer.w1": base_w1, + "moe.base_layer.w2": base_w2, + "moe.base_layer.w3": base_w3, + "moe.w1_lora_A.0": w1_lora_a, + "moe.w1_lora_B.0": w1_lora_b, + "moe.w2_lora_A.0": w2_lora_a, + "moe.w2_lora_B.0": w2_lora_b, + "moe.w3_lora_A.0": w3_lora_a, + "moe.w3_lora_B.0": w3_lora_b, + }, + ) + + torch.testing.assert_close(merged_state_dict["moe.w1"], base_w1 + 0.5 * torch.matmul(w1_lora_b, w1_lora_a)) + torch.testing.assert_close(merged_state_dict["moe.w2"], base_w2 + 0.5 * torch.matmul(w2_lora_b, w2_lora_a)) + torch.testing.assert_close(merged_state_dict["moe.w3"], base_w3 + 0.5 * torch.matmul(w3_lora_b, w3_lora_a)) + assert all("lora_A" not in key and "lora_B" not in key for key in merged_state_dict) + assert all(".base_layer." not in key for key in merged_state_dict) From 6a6ae0817dac5e6ebce4ce617e47ed2ea4f6b8ed Mon Sep 17 00:00:00 2001 From: taivu1998 <46636857+taivu1998@users.noreply.github.com> Date: Sat, 4 Apr 2026 11:29:32 -0700 Subject: [PATCH 2/2] Add changelog and merged-only LoRA export coverage --- CHANGELOG.md | 1 + .../sft_lora/start_merged_only.toml | 29 +++ tests/integration/test_sft_lora.py | 206 ++++++++++++------ 3 files changed, 171 insertions(+), 65 deletions(-) create mode 100644 configs/ci/integration/sft_lora/start_merged_only.toml diff --git a/CHANGELOG.md b/CHANGELOG.md index 1f7bf13850..eae1460e88 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ Documenting changes which affect configuration usage patterns (added/moved/removed/renamed fields, notable logic changes). +- **`ckpt.weights.save_adapter_separately` / LoRA weight exports**: LoRA weight exports now always write `weights/step_` as a standalone merged Hugging Face-compatible checkpoint. When `save_adapter_separately = true`, PRIME additionally writes adapter-only weights under `weights/step_/lora_adapters`; this flag no longer changes the semantics of the main exported checkpoint. (2026-04-04) - **`log.file` and `log.env_worker_logs` removed**: Removed `log.file` (from `LogConfig` and `SharedLogConfig`) and `log.env_worker_logs` (from `LogConfig`). Python file logging is replaced by deployment-level capture. Existing configs using these fields must delete them. Log paths unified: `.stdout` files renamed to `.log`, SLURM logs moved from `slurm/` to `logs/`. (2026-03-31) - **`trainer.log.ranks_filter` (NEW)**: Added `ranks_filter: list[int]` to `TrainerLogConfig` (default: `[0]`). Controls which ranks appear in trainer console output via torchrun's `--local-ranks-filter`. (2026-03-31) - **`wandb.log_extras.sample_ratio` / monitor sample logging defaults**: `wandb.log_extras.sample_ratio` is now actually applied to W&B sample-table logging via the shared monitor sampler (it was previously a no-op for WandB). Separately, the orchestrator no longer hard-caps sample logging to 8 rollouts before monitor-level sampling runs, so when monitor `sample_ratio` is `None`, monitors now receive and may log the full rollout batch for a step instead of at most 8 rollouts. This affects both W&B and Prime monitor sample logging behavior. (2026-03-27) diff --git a/configs/ci/integration/sft_lora/start_merged_only.toml b/configs/ci/integration/sft_lora/start_merged_only.toml new file mode 100644 index 0000000000..4ac8714cde --- /dev/null +++ b/configs/ci/integration/sft_lora/start_merged_only.toml @@ -0,0 +1,29 @@ +max_steps = 10 + +[ckpt] + +[ckpt.weights] +save_adapter_separately = false + +[model] +name = "PrimeIntellect/Qwen3-0.6B" + +[model.lora] +rank = 8 +target_modules = [ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", +] + +[data] +name = "PrimeIntellect/Reverse-Text-SFT" +batch_size = 4 +seq_len = 1024 + +[optim] +lr = 1.5e-5 diff --git a/tests/integration/test_sft_lora.py b/tests/integration/test_sft_lora.py index 80ac3d15a1..d32707ff91 100644 --- a/tests/integration/test_sft_lora.py +++ b/tests/integration/test_sft_lora.py @@ -12,6 +12,10 @@ TIMEOUT = 300 # 5 minutes +def assert_process_succeeded(process: ProcessResult) -> None: + assert process.returncode == 0, f"Process has non-zero return code ({process})" + + def assert_adapter_checkpoint(adapter_dir: Path) -> None: assert (adapter_dir / "adapter_config.json").exists() state_dict = load_state_dict(adapter_dir) @@ -21,6 +25,10 @@ def assert_adapter_checkpoint(adapter_dir: Path) -> None: assert all(key.startswith("base_model.model.") for key in state_dict) +def assert_no_adapter_checkpoint(adapter_dir: Path) -> None: + assert not adapter_dir.exists() + + def assert_full_checkpoint(weights_dir: Path) -> None: state_dict = load_state_dict(weights_dir) assert state_dict @@ -29,6 +37,13 @@ def assert_full_checkpoint(weights_dir: Path) -> None: assert all(not key.startswith("base_model.model.") for key in state_dict) +def get_trainer_log_lines(run_output_dir: Path) -> list[str]: + trainer_log_path = run_output_dir / "logs" / "trainer.log" + print(f"Checking trainer path in {trainer_log_path}") + with open(trainer_log_path, "r") as f: + return strip_escape_codes(f.read()).splitlines() + + @pytest.fixture(scope="module") def wandb_name(branch_name: str) -> str: """Fixture for W&B name for SFT LoRA CI integration tests.""" @@ -36,109 +51,170 @@ def wandb_name(branch_name: str) -> str: @pytest.fixture(scope="module") -def sft_lora_process( +def separate_adapter_output_dir(output_dir: Path) -> Path: + return output_dir / "separate_adapter" + + +@pytest.fixture(scope="module") +def merged_only_output_dir(output_dir: Path) -> Path: + return output_dir / "merged_only" + + +@pytest.fixture(scope="module") +def run_sft_lora( run_process: Callable[..., ProcessResult], wandb_project: str, wandb_name: str, - output_dir: Path, +): + def _run_sft_lora(config_path: str, run_output_dir: Path, run_name_suffix: str, clean_output_dir: bool) -> ProcessResult: + cmd = [ + "uv", + "run", + "sft", + "@", + config_path, + ] + if clean_output_dir: + cmd.append("--clean-output-dir") + cmd.extend( + [ + "--deployment.num-gpus", + "2", + "--wandb.project", + wandb_project, + "--wandb.name", + f"{wandb_name}-{run_name_suffix}", + "--output-dir", + run_output_dir.as_posix(), + ] + ) + + return run_process(cmd, timeout=TIMEOUT) + + return _run_sft_lora + + +@pytest.fixture(scope="module") +def sft_lora_process( + run_sft_lora, + separate_adapter_output_dir: Path, ) -> ProcessResult: - """Fixture for running SFT LoRA CI integration test""" - cmd = [ - "uv", - "run", - "sft", - "@", + """Fixture for running SFT LoRA CI integration test with separate adapter exports.""" + return run_sft_lora( "configs/ci/integration/sft_lora/start.toml", - "--deployment.num-gpus", - "2", - "--clean-output-dir", - "--wandb.project", - wandb_project, - "--wandb.name", - wandb_name, - "--output-dir", - output_dir.as_posix(), - ] - - return run_process(cmd, timeout=TIMEOUT) + separate_adapter_output_dir, + "separate-adapter", + clean_output_dir=True, + ) @pytest.fixture(scope="module") def sft_lora_resume_process( sft_lora_process, # Resume training can only start when regular SFT LoRA process is finished - run_process: Callable[..., ProcessResult], - wandb_project: str, - wandb_name: str, - output_dir: Path, + run_sft_lora, + separate_adapter_output_dir: Path, ) -> ProcessResult: - """Fixture for resuming SFT LoRA CI integration test""" - wandb_name += "-resume" - cmd = [ - "uv", - "run", - "sft", - "@", + """Fixture for resuming the SFT LoRA CI integration test with separate adapter exports.""" + if sft_lora_process.returncode != 0: + pytest.skip("Initial SFT LoRA process failed") + return run_sft_lora( "configs/ci/integration/sft_lora/resume.toml", - "--deployment.num-gpus", - "2", - "--wandb.project", - wandb_project, - "--wandb.name", - wandb_name, - "--output-dir", - output_dir.as_posix(), - ] + separate_adapter_output_dir, + "separate-adapter-resume", + clean_output_dir=False, + ) - return run_process(cmd, timeout=TIMEOUT) + +@pytest.fixture(scope="module") +def sft_lora_merged_only_process( + run_sft_lora, + merged_only_output_dir: Path, +) -> ProcessResult: + """Fixture for running SFT LoRA CI integration test without separate adapter exports.""" + return run_sft_lora( + "configs/ci/integration/sft_lora/start_merged_only.toml", + merged_only_output_dir, + "merged-only", + clean_output_dir=True, + ) def test_no_error(sft_lora_process: ProcessResult): """Tests that the SFT LoRA process does not fail.""" - assert sft_lora_process.returncode == 0, f"Process has non-zero return code ({sft_lora_process})" + assert_process_succeeded(sft_lora_process) -def test_loss_goes_down(sft_lora_process: ProcessResult, output_dir: Path): - """Tests that the loss goes down in the SFT LoRA process""" - trainer_log_path = output_dir / "logs" / "trainer.log" - print(f"Checking trainer path in {trainer_log_path}") - with open(trainer_log_path, "r") as f: - trainer_stdout = strip_escape_codes(f.read()).splitlines() - check_loss_goes_down(trainer_stdout) +def test_loss_goes_down(sft_lora_process: ProcessResult, separate_adapter_output_dir: Path): + """Tests that the loss goes down in the SFT LoRA process.""" + assert_process_succeeded(sft_lora_process) + check_loss_goes_down(get_trainer_log_lines(separate_adapter_output_dir)) -def test_adapter_checkpoint_written(sft_lora_process: ProcessResult, output_dir: Path): +def test_adapter_checkpoint_written(sft_lora_process: ProcessResult, separate_adapter_output_dir: Path): """Tests that the adapter checkpoint is written with valid PEFT-compatible keys.""" - adapter_dir = output_dir / "weights" / "step_10" / "lora_adapters" + assert_process_succeeded(sft_lora_process) + adapter_dir = separate_adapter_output_dir / "weights" / "step_10" / "lora_adapters" assert_adapter_checkpoint(adapter_dir) -def test_full_checkpoint_written(sft_lora_process: ProcessResult, output_dir: Path): +def test_full_checkpoint_written(sft_lora_process: ProcessResult, separate_adapter_output_dir: Path): """Tests that the full checkpoint stays HF-compatible when adapters are also exported.""" - weights_dir = output_dir / "weights" / "step_10" + assert_process_succeeded(sft_lora_process) + weights_dir = separate_adapter_output_dir / "weights" / "step_10" assert_full_checkpoint(weights_dir) +def test_no_error_merged_only(sft_lora_merged_only_process: ProcessResult): + """Tests that the merged-only SFT LoRA process does not fail.""" + assert_process_succeeded(sft_lora_merged_only_process) + + +def test_loss_goes_down_merged_only(sft_lora_merged_only_process: ProcessResult, merged_only_output_dir: Path): + """Tests that the loss goes down when only merged full checkpoints are exported.""" + assert_process_succeeded(sft_lora_merged_only_process) + check_loss_goes_down(get_trainer_log_lines(merged_only_output_dir)) + + +def test_full_checkpoint_written_merged_only( + sft_lora_merged_only_process: ProcessResult, + merged_only_output_dir: Path, +): + """Tests that merged-only LoRA exports remain HF-compatible.""" + assert_process_succeeded(sft_lora_merged_only_process) + weights_dir = merged_only_output_dir / "weights" / "step_10" + assert_full_checkpoint(weights_dir) + + +def test_adapter_checkpoint_not_written_merged_only( + sft_lora_merged_only_process: ProcessResult, + merged_only_output_dir: Path, +): + """Tests that merged-only LoRA exports do not create an adapter sidecar directory.""" + assert_process_succeeded(sft_lora_merged_only_process) + adapter_dir = merged_only_output_dir / "weights" / "step_10" / "lora_adapters" + assert_no_adapter_checkpoint(adapter_dir) + + def test_no_error_resume(sft_lora_resume_process: ProcessResult): """Tests that the SFT LoRA resume process does not fail.""" - assert sft_lora_resume_process.returncode == 0, f"Process has non-zero return code ({sft_lora_resume_process})" + assert_process_succeeded(sft_lora_resume_process) -def test_loss_goes_down_resume(sft_lora_resume_process: ProcessResult, output_dir: Path): - """Tests that the loss goes down in the SFT LoRA resume process""" - trainer_log_path = output_dir / "logs" / "trainer.log" - print(f"Checking trainer path in {trainer_log_path}") - with open(trainer_log_path, "r") as f: - trainer_stdout = strip_escape_codes(f.read()).splitlines() - check_loss_goes_down(trainer_stdout) +def test_loss_goes_down_resume(sft_lora_resume_process: ProcessResult, separate_adapter_output_dir: Path): + """Tests that the loss goes down in the SFT LoRA resume process.""" + assert_process_succeeded(sft_lora_resume_process) + check_loss_goes_down(get_trainer_log_lines(separate_adapter_output_dir)) -def test_adapter_checkpoint_written_resume(sft_lora_resume_process: ProcessResult, output_dir: Path): +def test_adapter_checkpoint_written_resume(sft_lora_resume_process: ProcessResult, separate_adapter_output_dir: Path): """Tests that the adapter checkpoint is written after resuming with valid PEFT-compatible keys.""" - adapter_dir = output_dir / "weights" / "step_20" / "lora_adapters" + assert_process_succeeded(sft_lora_resume_process) + adapter_dir = separate_adapter_output_dir / "weights" / "step_20" / "lora_adapters" assert_adapter_checkpoint(adapter_dir) -def test_full_checkpoint_written_resume(sft_lora_resume_process: ProcessResult, output_dir: Path): +def test_full_checkpoint_written_resume(sft_lora_resume_process: ProcessResult, separate_adapter_output_dir: Path): """Tests that the resumed full checkpoint stays HF-compatible when adapters are also exported.""" - weights_dir = output_dir / "weights" / "step_20" + assert_process_succeeded(sft_lora_resume_process) + weights_dir = separate_adapter_output_dir / "weights" / "step_20" assert_full_checkpoint(weights_dir)