Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

Documenting changes which affect configuration usage patterns (added/moved/removed/renamed fields, notable logic changes).

- **`ckpt.weights.save_adapter_separately` / LoRA weight exports**: LoRA weight exports now always write `weights/step_<N>` as a standalone merged Hugging Face-compatible checkpoint. When `save_adapter_separately = true`, PRIME additionally writes adapter-only weights under `weights/step_<N>/lora_adapters`; this flag no longer changes the semantics of the main exported checkpoint. (2026-04-04)
- **`log.file` and `log.env_worker_logs` removed**: Removed `log.file` (from `LogConfig` and `SharedLogConfig`) and `log.env_worker_logs` (from `LogConfig`). Python file logging is replaced by deployment-level capture. Existing configs using these fields must delete them. Log paths unified: `.stdout` files renamed to `.log`, SLURM logs moved from `slurm/` to `logs/`. (2026-03-31)
- **`trainer.log.ranks_filter` (NEW)**: Added `ranks_filter: list[int]` to `TrainerLogConfig` (default: `[0]`). Controls which ranks appear in trainer console output via torchrun's `--local-ranks-filter`. (2026-03-31)
- **`wandb.log_extras.sample_ratio` / monitor sample logging defaults**: `wandb.log_extras.sample_ratio` is now actually applied to W&B sample-table logging via the shared monitor sampler (it was previously a no-op for WandB). Separately, the orchestrator no longer hard-caps sample logging to 8 rollouts before monitor-level sampling runs, so when monitor `sample_ratio` is `None`, monitors now receive and may log the full rollout batch for a step instead of at most 8 rollouts. This affects both W&B and Prime monitor sample logging behavior. (2026-03-27)
Expand Down
29 changes: 29 additions & 0 deletions configs/ci/integration/sft_lora/start_merged_only.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
max_steps = 10

[ckpt]

[ckpt.weights]
save_adapter_separately = false

[model]
name = "PrimeIntellect/Qwen3-0.6B"

[model.lora]
rank = 8
target_modules = [
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
]

[data]
name = "PrimeIntellect/Reverse-Text-SFT"
batch_size = 4
seq_len = 1024

[optim]
lr = 1.5e-5
4 changes: 3 additions & 1 deletion docs/checkpointing.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ The default checkpoint directory is `checkpoints` and each checkpoint step will

Checkpointing is configured with the config key `--ckpt`. One can specify the interval (`--ckpt.interval`), whether to save checkpoints asynchronously (`--ckpt.save-async`), how many recent step checkpoints to keep on disk (`--ckpt.keep-last`), and keep checkpoints at every N steps permanently (`--ckpt.keep-interval`). By default, we do not checkpoint to save disk space.

When weight export is enabled under `ckpt.weights`, `weights/step_<N>` is intended to be a standalone Hugging Face-compatible checkpoint. If LoRA is enabled, the full checkpoint remains merged and `save_adapter_separately=true` additionally writes an adapter-only export under `weights/step_<N>/lora_adapters`.

## SFT

Let's split the reverse text training SFT example, which does 40 steps by default, into two runs of 20 steps each.
Expand Down Expand Up @@ -54,4 +56,4 @@ uv run rl \
--orchestrator @ path/to/orch.toml \
--max-steps 20 \
--ckpt.resume-step 10
```
```
2 changes: 1 addition & 1 deletion src/prime_rl/configs/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ class WeightCheckpointConfig(BaseConfig):
save_adapter_separately: Annotated[
bool,
Field(
description="Whether to save LoRA adapters separately before merging into full model weights.",
description="Whether to also save LoRA adapters separately in addition to merged full model weights.",
),
] = False

Expand Down
11 changes: 9 additions & 2 deletions src/prime_rl/trainer/ckpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from transformers.tokenization_utils import PreTrainedTokenizer

from prime_rl.configs.trainer import CheckpointConfig, LoRAConfig, WeightCheckpointConfig
from prime_rl.trainer.lora import has_lora_layers, save_lora_config
from prime_rl.trainer.lora import has_lora_layers, merge_lora_state_dict, save_lora_config
from prime_rl.trainer.models import PreTrainedModelPrimeRL
from prime_rl.trainer.optim import CPUOffloadOptimizer
from prime_rl.trainer.runs import Progress, get_multi_run_manager
Expand Down Expand Up @@ -385,12 +385,19 @@ def save(
state_dict = gather_weights_on_master(model, self.world.is_master, dtype=torch.bfloat16)
self.logger.debug(f"Gathered weights on master rank in {time.perf_counter() - start_time:.2f} seconds")

lora_enabled = has_lora_layers(model)
if lora_enabled:
self.logger.debug("Merging LoRA weights into full weight checkpoint")
start_time = time.perf_counter()
state_dict = merge_lora_state_dict(model, state_dict)
self.logger.debug(f"Merged LoRA weights in {time.perf_counter() - start_time:.2f} seconds")

# Remove tied weight keys to match original model format
if getattr(model.config, "tie_word_embeddings", False):
for key in getattr(model, "_tied_weights_keys", []):
state_dict.pop(key, None)

if has_lora_layers(model) and self.config.save_adapter_separately:
if lora_enabled and self.config.save_adapter_separately:
self.logger.debug("Getting run adapter state dict for weight checkpoint")
start_time = time.perf_counter()
lora_state_dict = self.get_run_adapter_state_dict()
Expand Down
85 changes: 85 additions & 0 deletions src/prime_rl/trainer/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,91 @@ def has_lora_layers(model: nn.Module) -> bool:
return False


def _module_key(prefix: str, suffix: str) -> str:
return f"{prefix}.{suffix}" if prefix else suffix


def _get_lora_scaling(module: MultiLoRAModule, run_idx: int) -> float:
    """Return the LoRA scaling factor for one run as a plain Python float.

    Reads the module's per-run `_scaling_factors` tensor and extracts the
    entry at `run_idx` off-device via `.item()`.
    """
    scaling_tensor = module._scaling_factors[run_idx]
    return float(scaling_tensor.item())


def _compute_lora_delta(
lora_a: torch.Tensor,
lora_b: torch.Tensor,
scaling: float,
output_dtype: torch.dtype,
) -> torch.Tensor:
matmul_dtype = torch.float32 if output_dtype in {torch.float16, torch.bfloat16} else output_dtype
delta = torch.matmul(lora_b.to(dtype=matmul_dtype), lora_a.to(dtype=matmul_dtype))
return (delta * scaling).to(dtype=output_dtype)


def _merge_multilora_linear(
    state_dict: dict[str, torch.Tensor],
    prefix: str,
    module: MultiLoRALinear,
    run_idx: int,
) -> None:
    """Fold one run's LoRA A/B pair into the base linear weight, in place.

    Mutates `state_dict`: replaces the `base_layer.weight` entry under
    `prefix` with `base + scaling * (B @ A)` for the given run index.
    """
    base_weight_key = _module_key(prefix, "base_layer.weight")
    base_weight = state_dict[base_weight_key]
    update = _compute_lora_delta(
        state_dict[_module_key(prefix, f"lora_A.{run_idx}")],
        state_dict[_module_key(prefix, f"lora_B.{run_idx}")],
        scaling=_get_lora_scaling(module, run_idx),
        output_dtype=base_weight.dtype,
    )
    state_dict[base_weight_key] = base_weight + update


def _merge_multilora_grouped_experts(
    state_dict: dict[str, torch.Tensor],
    prefix: str,
    module: MultiLoRAGroupedExperts,
    run_idx: int,
) -> None:
    """Fold one run's LoRA pairs into each grouped-expert base weight, in place.

    Applies the merge to the three expert projections (w1, w2, w3), each of
    which carries its own `<name>_lora_A` / `<name>_lora_B` parameter pair.
    The scaling factor is shared across all three projections.
    """
    scaling = _get_lora_scaling(module, run_idx)
    for base_name in ("w1", "w2", "w3"):
        base_key = _module_key(prefix, f"base_layer.{base_name}")
        base_weight = state_dict[base_key]
        delta = _compute_lora_delta(
            state_dict[_module_key(prefix, f"{base_name}_lora_A.{run_idx}")],
            state_dict[_module_key(prefix, f"{base_name}_lora_B.{run_idx}")],
            scaling=scaling,
            output_dtype=base_weight.dtype,
        )
        state_dict[base_key] = base_weight + delta


def merge_lora_state_dict(
    model: nn.Module,
    state_dict: dict[str, torch.Tensor],
    run_idx: int = 0,
) -> dict[str, torch.Tensor]:
    """Merge a run's LoRA weights into gathered full-model weights.

    Walks the model's modules, and for every LoRA-wrapped layer folds the
    run's adapter delta into the corresponding base weight of a copy of
    `state_dict`. The merged dict is then stripped of LoRA-only keys via
    `clean_lora_state_dict` for HF compatibility. An empty input (e.g. a
    non-master rank that gathered nothing) is returned unchanged.
    """
    if not state_dict:
        return state_dict

    merged = dict(state_dict)
    for module_name, module in model.named_modules():
        if isinstance(module, MultiLoRALinear):
            _merge_multilora_linear(merged, module_name, module, run_idx)
        elif isinstance(module, MultiLoRAGroupedExperts):
            _merge_multilora_grouped_experts(merged, module_name, module, run_idx)

    return clean_lora_state_dict(merged)


def clean_lora_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""Remove LoRA parameters and fix LoRA base layer key names for HF compatibility."""
clean_state_dict = {}
Expand Down
7 changes: 0 additions & 7 deletions src/prime_rl/trainer/weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@
WEIGHTS_NAME,
)

from prime_rl.trainer.lora import (
clean_lora_state_dict,
)
from prime_rl.utils.logger import get_logger

PYTORCH_WRAPPER_PREFIXES = ["_fsdp_wrapped_module.", "_orig_module.", "_checkpoint_wrapped_module."]
Expand Down Expand Up @@ -147,10 +144,6 @@ def gather_weights_on_master(
cpu_state[key] = value.to("cpu", non_blocking=False)
torch.distributed.barrier()

# Always clean up the state dict for HF compatibility
if any(".base_layer." in key or "lora_A" in key or "lora_B" in key for key in cpu_state.keys()):
cpu_state = clean_lora_state_dict(cpu_state)

return cpu_state


Expand Down
Loading