2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -2,6 +2,8 @@

Documenting changes which affect configuration usage patterns (added/moved/removed/renamed fields, notable logic changes).

- **`model.ac`**: Trainer model configs now enable activation checkpointing by default with `mode = "selective"` and `targets = ["norm"]`. Bare `[model.ac]` / `--model.ac` now use selective AC instead of full-layer checkpointing. To force the previous full-layer behavior, set `mode = "full"`. On models without selective hooks, the default norm-only setting falls back to full checkpointing. (2026-04-01)
- **`model.ac.targets`**: Added `moe_act` as a selective activation checkpoint target for custom MoE layers. This checkpoints the routed expert activation function inside expert MLPs without checkpointing the broader routed-expert path; if both `moe_act` and `routed_experts` are selected, the broader `routed_experts` hook wins to avoid double-wrapping nested checkpoints. (2026-04-01)
- **`log.file` and `log.env_worker_logs` removed**: Removed `log.file` (from `LogConfig` and `SharedLogConfig`) and `log.env_worker_logs` (from `LogConfig`). Python file logging is replaced by deployment-level capture. Existing configs using these fields must delete them. Log paths unified: `.stdout` files renamed to `.log`, SLURM logs moved from `slurm/` to `logs/`. (2026-03-31)
- **`trainer.log.ranks_filter` (NEW)**: Added `ranks_filter: list[int]` to `TrainerLogConfig` (default: `[0]`). Controls which ranks appear in trainer console output via torchrun's `--local-ranks-filter`. (2026-03-31)
- **`wandb.log_extras.sample_ratio` / monitor sample logging defaults**: `wandb.log_extras.sample_ratio` is now actually applied to W&B sample-table logging via the shared monitor sampler (it was previously a no-op for WandB). Separately, the orchestrator no longer hard-caps sample logging to 8 rollouts before monitor-level sampling runs, so when monitor `sample_ratio` is `None`, monitors now receive and may log the full rollout batch for a step instead of at most 8 rollouts. This affects both W&B and Prime monitor sample logging behavior. (2026-03-27)
2 changes: 1 addition & 1 deletion benchmarks/scripts/run_single_benchmark.py
@@ -146,7 +146,7 @@ def build_command(config: BenchmarkConfig) -> list[str]:

# Add activation checkpointing if enabled
if config.ac == "Recompute":
cmd.append("--model.ac")
cmd.extend(["--model.ac", "--model.ac.mode", "full"])
elif config.ac == "Selective":
cmd.extend(["--model.ac", "--model.ac.mode", "selective"])
if config.selective_targets:
72 changes: 55 additions & 17 deletions docs/memory_usage.md
@@ -4,7 +4,6 @@ While most of our parallelism techniques in prime-rl are designed to scale train

These techniques target the trainer part of prime-rl.


## TLDR: config to use for maximum memory usage reduction with correct throughput

```toml
@@ -19,6 +18,7 @@ optim_cpu_offload = true
[trainer.model.compile]

[trainer.model.ac]
mode = "full"
freq = 1

[trainer.model.ac_offloading]
@@ -29,25 +29,66 @@ max_inflight_activations = 1

Activation checkpointing discards intermediate activations during the forward pass and recomputes them during the backward pass, trading compute for memory.

To enable it, use:
Trainer configs now enable activation checkpointing by default with selective AC on the cheapest target:

```toml
[trainer.model.ac]
mode = "selective"
targets = ["norm"]
freq = 1
```

You do not need to add this section unless you want to override the defaults.

`freq` controls how often layers are checkpointed: every `freq` layers. Lower values yield lower memory usage (e.g. `freq = 1` checkpoints every layer).
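
As a rough sketch of the `freq` semantics (assuming 0-based layer indices; the exact starting offset in prime-rl may differ):

```python
# Illustrative sketch, not the prime-rl implementation: which decoder layers
# get wrapped in activation checkpointing for a given `freq`.
def checkpointed_layers(num_layers: int, freq: int) -> list[int]:
    # Every `freq`-th layer is checkpointed; freq = 1 checkpoints all layers.
    return [i for i in range(num_layers) if i % freq == 0]

print(checkpointed_layers(8, 1))  # all 8 layers
print(checkpointed_layers(8, 2))  # every other layer
```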

## Activation offloading
To force full-layer checkpointing instead of the default selective AC, use:

Activation offloading offloads the activations to CPU to reduce the memory usage of the trainer. It can be used in combination with activation checkpointing.
```toml
[trainer.model.ac]
mode = "full"
freq = 1
```

To enable it, use:
If a model or layer does not expose selective AC hooks, prime-rl falls back to full layer checkpointing. The default `targets = ["norm"]` is chosen so this fallback stays safe.
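
The validation and fallback decision can be sketched as follows, mirroring the logic this PR adds to `apply_ac` (function and argument names here are illustrative, not the real API):

```python
# Sketch of the selective-AC validation/fallback decision: unsupported targets
# raise, except for the norm-only default on models with no selective hooks,
# which falls back to full checkpointing.
def resolve_selective_ac(targets, model_supported, selective_layers):
    requested = frozenset(targets)
    unsupported = requested - model_supported
    norm_only_fallback = selective_layers == 0 and requested == frozenset({"norm"})
    if unsupported and not norm_only_fallback:
        raise ValueError(f"Unsupported selective AC targets: {sorted(unsupported)}")
    return "full" if norm_only_fallback else "selective"
```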

## Selective activation checkpointing tuning

Selective AC lets you add memory savings incrementally before switching all the way to `mode = "full"`.

The target orders below are rough tuning heuristics, listed from best memory-saved-per-recompute tradeoff to worst. Start on the left and add targets as needed. The runtime treats `targets` as a set, so the order in your config file does not matter.

```toml
[trainer.model.ac]
freq = 1
mode = "selective"
targets = ["norm", "attn_proj", "moe_act"]
```

Available targets by model family:
- `llama`: `norm`, `attn_proj`, `mlp`
- `minimax_m2`: `norm`, `moe_act`, `attn_proj`, `routed_experts`
- `qwen3_moe` / `glm4_moe`: `norm`, `mlp`, `moe_act`, `attn_proj`, `routed_experts`
- `afmoe`: `norm`, `mlp`, `moe_act`, `attn_proj`, `linear_attn`, `routed_experts`
- `qwen3_5_moe` / `nemotron_h`: `norm`, `moe_act`, `attn_proj`, `linear_attn`, `routed_experts`
- `glm_moe_dsa`: `norm`, `mlp`, `mla_up_proj`, `moe_act`, `attn_proj`, `routed_experts`

Notes:
- These lists are unions across the model. On mixed-architecture families, not every target applies to every layer.
- `qwen3_moe`, `glm4_moe`, and `glm_moe_dsa` contain both dense and MoE layers. `mlp` applies to dense layers, while `moe_act` and `routed_experts` apply to MoE layers.
- `afmoe` only exposes `linear_attn` on its sliding-window attention layers.
- `qwen3_5_moe` only exposes `linear_attn` on its GatedDeltaNet layers.
- `nemotron_h` only exposes `linear_attn` on Mamba layers, while `moe_act` and `routed_experts` only apply to its LatentMoE layers.
- `glm_moe_dsa` only exposes `mla_up_proj` on its sparse MLA attention layers.
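
The notes above can be summarized in a small sketch: each layer reports the targets it supports, the model-level list is the union, and each layer applies only the intersection. The per-layer sets here are hypothetical examples for a mixed dense/MoE model, not real model introspection:

```python
# Hypothetical per-layer capability sets for a mixed dense/MoE model.
dense_layer = frozenset({"norm", "attn_proj", "mlp"})
moe_layer = frozenset({"norm", "attn_proj", "moe_act", "routed_experts"})

# The docs list the union across layers per model family.
model_supported = dense_layer | moe_layer

def effective_targets(requested, layer_supported):
    # A layer skips targets it does not support (e.g. `mlp` on MoE layers).
    return frozenset(requested) & layer_supported

print(sorted(effective_targets(["norm", "mlp", "moe_act"], moe_layer)))
```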

When selective tuning is not enough, switch to `mode = "full"`.

## Activation offloading

Activation offloading offloads the activations to CPU to reduce the memory usage of the trainer. It can be used in combination with activation checkpointing.

Selective AC is already enabled by default, so to add activation offloading on top, use:

```toml
[trainer.model.ac_offloading]
max_inflight_activations = 5
```
@@ -63,14 +104,13 @@ To enable it, use:
fused_lm_head_token_chunk_size = auto
```


## Expert parallelism

While expert parallelism shards the expert weights across all GPUs just like FSDP, EP still reduces memory usage by shrinking the communication size and therefore the FSDP buffers.

EP is only available for models with MoE layers using the custom model implementation.

```
```toml
[trainer.model]
impl = "custom"
ep = 8
@@ -82,20 +122,19 @@ Context parallelism splits the context into smaller chunks to reduce the memory

CP is only available for certain models and only with the custom model implementation.

```
```toml
[trainer.model]
impl = "custom"
cp = 2
```

We recommend CP 2 or CP 4 for most 128K-sequence-length training runs; it can be pushed to 8 if needed.
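
The memory intuition is a simple back-of-the-envelope calculation: CP shards the sequence dimension, so each rank holds roughly `seq_len / cp` tokens worth of activations (constants and attention-specific terms elided):

```python
# Rough sketch of activation footprint scaling under context parallelism.
def activation_tokens_per_gpu(seq_len: int, cp: int) -> int:
    # Each CP rank processes a 1/cp slice of the sequence.
    return seq_len // cp

print(activation_tokens_per_gpu(131072, 2))  # half the tokens per rank
print(activation_tokens_per_gpu(131072, 8))
```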


## torch compile

Enabling torch.compile can reduce the memory usage for certain model architectures, especially MoE with the custom model implementation.
Enabling `torch.compile` can reduce the memory usage for certain model architectures, especially MoE with the custom model implementation.

```
```toml
[trainer.model.compile]
```

@@ -105,8 +144,8 @@ Offloading the optimizer states to CPU can reduce the memory usage of the traine

In RL, in contrast with pretraining, we end up with many gradient accumulation steps, so the cost of offloading the optimizer states is not as high as in pretraining, and indeed barely noticeable.
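
The savings are easy to estimate. A minimal sketch, assuming an Adam-style optimizer with two fp32 moment buffers per parameter (the actual layout depends on the optimizer and dtype settings):

```python
# Optimizer state memory for Adam-style optimizers: exp_avg and exp_avg_sq,
# each one fp32 value (4 bytes) per parameter.
def adam_state_bytes(num_params: int) -> int:
    return num_params * 2 * 4

# ~8 GB of optimizer state per 1B parameters moves off the GPU when offloaded.
print(adam_state_bytes(1_000_000_000) / 1e9)
```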

```
[trainer.optim]
```toml
[trainer.model]
optim_cpu_offload = true
```

@@ -116,7 +155,7 @@ FSDP CPU offloading offloads the parameters, gradients, and optimizer states to

This will make training significantly slower and is not recommended most of the time.

```
```toml
[trainer.model]
fsdp_cpu_offload = true
```
@@ -125,8 +164,7 @@

LoRA training significantly reduces the memory usage of the trainer at the cost of smaller gradient updates.

```
```toml
[trainer.model.lora]
rank = 8
```

3 changes: 2 additions & 1 deletion docs/troubleshooting.md
@@ -11,7 +11,8 @@ ulimit -n 32000
> I'm getting CUDA out of memory errors.

Assuming this is happening on the RL or SFT trainer, you can try the following:
- Use full activation checkpointing (`--model.ac`)
- Trainer configs already use selective activation checkpointing with `targets = ["norm"]` by default.
- If you still OOM, switch to full activation checkpointing (`--model.ac.mode full`) or add more selective targets as described in `docs/memory_usage.md`.
- Reduce the micro batch size (`--data.micro-batch-size`) and sequence length (`--data.seq-len`)
- (*Experimental*) Use context parallelism with `--model.cp`

20 changes: 13 additions & 7 deletions src/prime_rl/configs/trainer.py
@@ -48,9 +48,9 @@ class ActivationCheckpointConfig(BaseConfig):
mode: Annotated[
Literal["full", "selective"],
Field(
description="Whether to checkpoint whole transformer blocks (`full`) or selected subcomponents inside supported custom decoder layers (`selective`).",
description="Whether to checkpoint whole transformer blocks (`full`) or selected subcomponents inside supported decoder layers (`selective`). Defaults to `selective`.",
),
] = "full"
] = "selective"

freq: Annotated[
int,
@@ -63,7 +63,7 @@
targets: Annotated[
list[str],
Field(
description="Selective checkpoint targets. `norm` checkpoints every norm module inside selected layers (decoder, attention, MLA, etc.). `attn_proj` checkpoints projection-side attention work outside the kernel, including input/output projections, attention-local norms, RoPE, gating, and model-specific MLA projection helpers where exposed. `mlp` checkpoints the entire dense MLP forward (not applicable to MoE layers). `mla_up_proj` checkpoints MLA Q/KV up-projection work where supported. `routed_experts` checkpoints routed expert compute in MoE layers (including LatentMoE). `linear_attn` checkpoints supported token mixers outside the standard softmax-attention path, including NemotronH Mamba layers, Qwen3.5-MoE GatedDeltaNet layers, and AFMoE sliding-window attention layers.",
description="Selective checkpoint targets. `norm` checkpoints every norm module inside selected layers (decoder, attention, MLA, etc.). `attn_proj` checkpoints projection-side attention work outside the kernel, including input/output projections, attention-local norms, RoPE, gating, and model-specific MLA projection helpers where exposed. `mlp` checkpoints the entire dense MLP forward (not applicable to MoE layers). `moe_act` checkpoints routed MoE activation functions inside expert MLPs, such as SiLU/GLU or relu2, without checkpointing the surrounding expert projections. `mla_up_proj` checkpoints MLA Q/KV up-projection work where supported. `routed_experts` checkpoints the broader routed expert path in MoE layers (including LatentMoE). `linear_attn` checkpoints supported token mixers outside the standard softmax-attention path, including NemotronH Mamba layers, Qwen3.5-MoE GatedDeltaNet layers, and AFMoE sliding-window attention layers.",
),
] = ["norm"]

@@ -197,9 +197,9 @@ class ModelConfig(BaseModelConfig):
ac: Annotated[
ActivationCheckpointConfig | None,
Field(
description="Whether to apply activation checkpointing to the model. If None, will not apply activation checkpointing.",
description='Activation checkpointing config. Defaults to selective activation checkpointing with `targets = ["norm"]`. Set to None to disable activation checkpointing.',
),
] = None
] = ActivationCheckpointConfig(mode="selective")

ac_offloading: Annotated[
ActivationOffloadingConfig | None,
@@ -385,8 +385,14 @@ def ac_offloading_requires_ac(self):

@model_validator(mode="after")
def selective_ac_only_with_custom_impl(self):
if self.ac is not None and self.ac.mode == "selective" and self.impl not in ("custom", "auto"):
raise ValueError("Selective activation checkpointing requires model.impl='custom' or 'auto'")
if self.ac is None or self.ac.mode != "selective" or self.impl in ("custom", "auto"):
return self

if frozenset(self.ac.targets) != frozenset({"norm"}):
raise ValueError(
"Selective activation checkpointing with model.impl='hf' only supports model.ac.targets=['norm']; "
"use model.impl='custom' or 'auto', switch to full AC, or disable AC."
)
return self

@model_validator(mode="after")
Expand Down
11 changes: 9 additions & 2 deletions src/prime_rl/trainer/model.py
@@ -646,12 +646,19 @@ def apply_ac(model: nn.Module, ac_config: ActivationCheckpointConfig):

if ac_config.mode == "selective":
unsupported_targets = frozenset(target_list) - model_supported_targets
if unsupported_targets:
allow_norm_only_full_fallback = selective_layers == 0 and frozenset(target_list) == frozenset({"norm"})

if unsupported_targets and not allow_norm_only_full_fallback:
raise ValueError(
f"Selective activation checkpoint targets {sorted(unsupported_targets)} are not supported "
f"by the selected model layers. Supported targets across the model: {sorted(model_supported_targets)}"
)
if fallback_layer_types:
if allow_norm_only_full_fallback:
logger.warning(
"The selected model does not expose selective activation checkpointing hooks; "
"falling back to full checkpointing for all selected layers."
)
elif fallback_layer_types:
logger.warning(
"Selective activation checkpointing is not supported for layer types "
f"{sorted(fallback_layer_types)}; falling back to full checkpointing for those layers."
27 changes: 26 additions & 1 deletion src/prime_rl/trainer/models/layers/checkpointing.py
@@ -15,6 +15,8 @@
`attn_projections(...)`.
- `mlp`: expose a dense `layer.mlp.forward(...)`. A module is treated as dense
when it does not define `_run_routed_experts` or `tokens_per_expert`.
- `moe_act`: expose `layer.mlp.experts.moe_act(...)` or `layer.mlp.moe_act(...)`
for routed MoE activations between expert input and output projections.
- `routed_experts`: expose `layer.mlp._run_routed_experts(...)` for the MoE
expert path, and optionally `layer.mlp._run_local_routed_experts(...)` when
local expert compute is separated from dispatch/combine.
@@ -33,7 +35,9 @@
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

SELECTIVE_AC_TARGETS = frozenset({"norm", "attn_proj", "mlp", "mla_up_proj", "routed_experts", "linear_attn"})
SELECTIVE_AC_TARGETS = frozenset(
{"norm", "attn_proj", "mlp", "moe_act", "mla_up_proj", "routed_experts", "linear_attn"}
)
_PATCHED_METHODS_ATTR = "_prime_rl_selective_ac_patched_methods"


@@ -96,19 +100,36 @@ def _get_linear_attn_module(layer: nn.Module) -> nn.Module | None:
return None


def _get_moe_act_module(mlp: nn.Module | None) -> nn.Module | None:
if mlp is None:
return None

if hasattr(mlp, "moe_act"):
return mlp

experts = getattr(mlp, "experts", None)
if experts is not None and hasattr(experts, "moe_act"):
return experts

return None


def get_supported_targets(layer: nn.Module) -> frozenset[str]:
"""Infer which selective activation checkpoint targets a decoder layer supports."""
supported_targets = {"norm"}
self_attn = getattr(layer, "self_attn", None)
mlp = getattr(layer, "mlp", None)
linear_attn = _get_linear_attn_module(layer)
moe_act_module = _get_moe_act_module(mlp)

if _supports_attn_proj(self_attn):
supported_targets.add("attn_proj")
if self_attn is not None and hasattr(self_attn, "mla_up_proj"):
supported_targets.add("mla_up_proj")
if mlp is not None and _is_dense_mlp(mlp):
supported_targets.add("mlp")
if moe_act_module is not None:
supported_targets.add("moe_act")
if mlp is not None and (hasattr(mlp, "_run_routed_experts") or hasattr(mlp, "_run_local_routed_experts")):
supported_targets.add("routed_experts")
if linear_attn is not None:
@@ -127,7 +148,9 @@ def set_selective_activation_checkpointing(layer: nn.Module, targets: Iterable[s
self_attn = getattr(layer, "self_attn", None)
mlp = getattr(layer, "mlp", None)
linear_attn = _get_linear_attn_module(layer)
moe_act_module = _get_moe_act_module(mlp)
attn_proj_is_subsumed = "linear_attn" in enabled_targets and linear_attn is self_attn
moe_act_is_subsumed = "routed_experts" in enabled_targets

if self_attn is not None and "attn_proj" in enabled_targets and not attn_proj_is_subsumed:
checkpoint_method(self_attn, "attn_projections")
Expand All @@ -136,6 +159,8 @@ def set_selective_activation_checkpointing(layer: nn.Module, targets: Iterable[s
checkpoint_method(self_attn, "mla_up_proj")
if mlp is not None and "mlp" in enabled_targets:
checkpoint_method(mlp, "forward")
if moe_act_module is not None and "moe_act" in enabled_targets and not moe_act_is_subsumed:
checkpoint_method(moe_act_module, "moe_act")
if mlp is not None and "routed_experts" in enabled_targets:
if hasattr(mlp, "_run_routed_experts"):
checkpoint_method(mlp, "_run_routed_experts")