2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -2,6 +2,8 @@

Documenting changes which affect configuration usage patterns (added/moved/removed/renamed fields, notable logic changes).

- **`model.ac`**: `trainer.model.ac` no longer bakes in a config-level default. When it is unset, supported custom implementations default to `mode = "selective"` with `targets = ["norm"]`, while HF implementations leave activation checkpointing disabled. Bare `[model.ac]` / `--model.ac` again mean explicit full-layer checkpointing, and selective AC remains custom-only. (2026-04-01)
- **`model.ac.targets`**: Added `moe_act` as a selective activation checkpoint target for custom MoE layers. This checkpoints the routed expert activation function inside expert MLPs without checkpointing the broader routed-expert path; if both `moe_act` and `routed_experts` are selected, the broader `routed_experts` hook wins to avoid double-wrapping nested checkpoints. (2026-04-01)
- **`log.file` and `log.env_worker_logs` removed**: Removed `log.file` (from `LogConfig` and `SharedLogConfig`) and `log.env_worker_logs` (from `LogConfig`). Python file logging is replaced by deployment-level capture. Existing configs using these fields must delete them. Log paths unified: `.stdout` files renamed to `.log`, SLURM logs moved from `slurm/` to `logs/`. (2026-03-31)
- **`trainer.log.ranks_filter` (NEW)**: Added `ranks_filter: list[int]` to `TrainerLogConfig` (default: `[0]`). Controls which ranks appear in trainer console output via torchrun's `--local-ranks-filter`. (2026-03-31)
- **`wandb.log_extras.sample_ratio` / monitor sample logging defaults**: `wandb.log_extras.sample_ratio` is now actually applied to W&B sample-table logging via the shared monitor sampler (it was previously a no-op for WandB). Separately, the orchestrator no longer hard-caps sample logging to 8 rollouts before monitor-level sampling runs, so when monitor `sample_ratio` is `None`, monitors now receive and may log the full rollout batch for a step instead of at most 8 rollouts. This affects both W&B and Prime monitor sample logging behavior. (2026-03-27)
93 changes: 78 additions & 15 deletions docs/memory_usage.md
@@ -4,7 +4,6 @@ While most of our parallelism techniques in prime-rl are designed to scale train

These techniques target the trainer part of prime-rl.


## TLDR: config to use for maximum memory reduction with reasonable throughput

```toml
optim_cpu_offload = true
[trainer.model.compile]

[trainer.model.ac]
mode = "full"
freq = 1

[trainer.model.ac_offloading]
max_inflight_activations = 1
```

Activation checkpointing discards intermediate activations during the forward pass and recomputes them during the backward pass, trading compute for memory.

If `trainer.model.ac` is unset, supported custom implementations default to selective AC on the cheapest target:

> **Member:** huh? where is it defined? I don't like this tbh
>
> **Member:** would agree here. a config unset should not default to doing it anyway
>
> **Member:** we can ofc default to selective AC if this is reasonable, but it should be explicit in the configs imo (e.g. an agent should see `ac.mode = "selective"` instead of `ac.mode = None`)

```toml
[trainer.model.ac]
mode = "selective"
targets = ["norm"]
freq = 1
```

HF or other non-custom implementations do not auto-enable activation checkpointing when `trainer.model.ac` is unset.

To explicitly enable full-layer checkpointing, use:

```toml
[trainer.model.ac]
freq = 1
```

This is equivalent to:

```toml
[trainer.model.ac]
mode = "full"
freq = 1
```

`freq` controls how often layers are checkpointed: every `freq` layers. Lower values yield lower memory usage (e.g. `freq = 1` checkpoints every layer).
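
For intuition, the `freq` rule can be sketched as checkpointing every `freq`-th layer index. This is an illustrative selection policy, not the actual prime-rl wrapping code, and `layers_to_checkpoint` is a hypothetical helper:

```python
def layers_to_checkpoint(num_layers: int, freq: int) -> list[int]:
    """Hypothetical freq-based selection: checkpoint every `freq`-th layer.

    freq = 1 checkpoints every layer; freq = 2 every other layer, and so on.
    """
    if freq < 1:
        raise ValueError("freq must be >= 1")
    return [layer_id for layer_id in range(num_layers) if layer_id % freq == 0]


print(layers_to_checkpoint(6, 1))  # [0, 1, 2, 3, 4, 5]
print(layers_to_checkpoint(6, 2))  # [0, 2, 4]
```

Lower `freq` means more layers are checkpointed, so less activation memory is held but more recompute happens in the backward pass.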
## Selective activation checkpointing tuning

Selective AC is only supported with the custom model implementation. It lets you add memory savings incrementally before switching all the way to `mode = "full"`.

The target lists below are ordered as rough tuning heuristics, from the best memory-saved-per-recompute tradeoff to the worst. Start on the left and add targets as needed. The runtime treats `targets` as a set, so the order in your config file does not matter.

```toml
[trainer.model]
impl = "custom"

[trainer.model.ac]
mode = "selective"
targets = ["norm", "attn_proj", "moe_act"]
```

Available targets by model family:
- `llama`: `norm`, `attn_proj`, `mlp`
- `minimax_m2`: `norm`, `moe_act`, `attn_proj`, `routed_experts`
- `qwen3_moe` / `glm4_moe`: `norm`, `mlp`, `moe_act`, `attn_proj`, `routed_experts`
- `afmoe`: `norm`, `mlp`, `moe_act`, `attn_proj`, `linear_attn`, `routed_experts`
- `qwen3_5_moe` / `nemotron_h`: `norm`, `moe_act`, `attn_proj`, `linear_attn`, `routed_experts`
- `glm_moe_dsa`: `norm`, `mlp`, `mla_up_proj`, `moe_act`, `attn_proj`, `routed_experts`

Notes:
- These lists are unions across the model. On mixed-architecture families, not every target applies to every layer.
- `qwen3_moe`, `glm4_moe`, and `glm_moe_dsa` contain both dense and MoE layers. `mlp` applies to dense layers, while `moe_act` and `routed_experts` apply to MoE layers.
- `afmoe` only exposes `linear_attn` on its sliding-window attention layers.
- `qwen3_5_moe` only exposes `linear_attn` on its GatedDeltaNet layers.
- `nemotron_h` only exposes `linear_attn` on Mamba layers, while `moe_act` and `routed_experts` only apply to its LatentMoE layers.
- `glm_moe_dsa` only exposes `mla_up_proj` on its sparse MLA attention layers.
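
The notes above amount to a per-layer set intersection: requested targets only take effect on layers that support them, and the broader `routed_experts` hook subsumes `moe_act`. A hedged sketch of that rule (names like `LAYER_SUPPORT` and `effective_targets` are illustrative, not prime-rl internals):

```python
# Hypothetical supported-target table for a mixed dense/MoE model family.
LAYER_SUPPORT = {
    "dense_layer": frozenset({"norm", "attn_proj", "mlp"}),
    "moe_layer": frozenset({"norm", "attn_proj", "moe_act", "routed_experts"}),
}


def effective_targets(layer_kind: str, requested: list[str]) -> set[str]:
    """Intersect requested targets (order-insensitive) with per-layer support.

    Mirrors the documented precedence: when `routed_experts` is enabled,
    the narrower `moe_act` hook is dropped to avoid nesting checkpoints.
    """
    enabled = set(requested) & LAYER_SUPPORT[layer_kind]
    if "routed_experts" in enabled:
        enabled.discard("moe_act")
    return enabled


# `mlp` silently does not apply to the MoE layer, and config order is irrelevant.
print(effective_targets("moe_layer", ["moe_act", "routed_experts", "mlp"]))
```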

When selective tuning is not enough, switch to `mode = "full"`.

## Activation offloading

Activation offloading moves activations to CPU memory to reduce the memory usage of the trainer. It can be used in combination with activation checkpointing.

If you set `trainer.model.ac_offloading` without `trainer.model.ac`, prime-rl also enables explicit full-layer activation checkpointing.

To use activation offloading with the custom-model selective setup, configure both explicitly:

```toml
[trainer.model]
impl = "custom"

[trainer.model.ac]
mode = "selective"
targets = ["norm"]

[trainer.model.ac_offloading]
max_inflight_activations = 5
```

To use activation offloading with full AC, use:

```toml
[trainer.model.ac]
```

To enable it, use:

```toml
fused_lm_head_token_chunk_size = "auto"
```


## Expert parallelism

While expert parallelism splits the weights of the experts across all GPUs, much like FSDP, using EP still reduces memory usage by shrinking the communication size and therefore the FSDP all-gather buffer.

EP is only available for models with MoE layers using the custom model implementation.

```toml
[trainer.model]
impl = "custom"
ep = 8
@@ -82,20 +147,19 @@ Context parallelism splits the context into smaller chunks to reduce the memory

CP is only available for certain models and only with the custom model implementation.

```toml
[trainer.model]
impl = "custom"
cp = 2
```

We recommend CP 2 or CP 4 for most 128K-sequence-length training runs; it can be pushed to 8.
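
As back-of-the-envelope intuition (illustrative arithmetic, assuming activation memory scales linearly with the local token count):

```python
def activation_tokens_per_gpu(seq_len: int, cp: int) -> int:
    """Each CP rank holds roughly seq_len / cp tokens' worth of activations."""
    assert seq_len % cp == 0, "seq_len must divide evenly across CP ranks"
    return seq_len // cp


# At 128K context, CP 4 cuts per-GPU activation footprint roughly 4x.
for cp in (1, 2, 4, 8):
    print(cp, activation_tokens_per_gpu(128 * 1024, cp))
```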


## torch compile

Enabling torch.compile can reduce the memory usage for certain model architectures, especially MoE with the custom model implementation.
Enabling `torch.compile` can reduce the memory usage for certain model architectures, especially MoE with the custom model implementation.

```toml
[trainer.model.compile]
```

@@ -105,8 +169,8 @@ Offloading the optimizer states to CPU can reduce the memory usage of the traine

In RL, in contrast with pretraining, we end up with many gradient accumulation steps, so the cost of offloading the optimizer states is much lower; in practice it is barely noticeable.

```toml
[trainer.model]
optim_cpu_offload = true
```

@@ -116,7 +180,7 @@ FSDP CPU offloading offloads the parameters, gradients, and optimizer states to

This will make training significantly slower and is not recommended most of the time.

```toml
[trainer.model]
fsdp_cpu_offload = true
```
@@ -125,8 +189,7 @@

LoRA training significantly reduces the memory usage of the trainer at the cost of smaller gradient updates.

```toml
[trainer.model.lora]
rank = 8
```
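
For intuition on the savings, compare trainable parameter counts for a single linear layer with and without LoRA (illustrative arithmetic; `full_params` and `lora_params` are hypothetical helpers):

```python
def full_params(d_in: int, d_out: int) -> int:
    """Trainable parameters of a full d_in x d_out linear layer."""
    return d_in * d_out


def lora_params(d_in: int, d_out: int, rank: int) -> int:
    """LoRA trains two low-rank factors: A (d_in x rank) and B (rank x d_out)."""
    return rank * (d_in + d_out)


d = 4096
print(full_params(d, d))     # 16777216
print(lora_params(d, d, 8))  # 65536, a 256x reduction in trainable weights
```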

3 changes: 2 additions & 1 deletion docs/troubleshooting.md
@@ -11,7 +11,8 @@ ulimit -n 32000
> I'm getting CUDA out of memory errors.

Assuming this is happening on the RL or SFT trainer, you can try the following:
- Supported custom implementations already default to selective activation checkpointing with `targets = ["norm"]`.
- If you still OOM, enable full activation checkpointing with `--model.ac` or, on custom implementations, add more selective targets as described in `docs/memory_usage.md`.
- Reduce the micro batch size (`--data.micro-batch-size`) and sequence length (`--data.seq-len`)
- (*Experimental*) Use context parallelism with `--model.cp`

8 changes: 6 additions & 2 deletions src/prime_rl/configs/trainer.py
@@ -63,7 +63,7 @@ class ActivationCheckpointConfig(BaseConfig):
targets: Annotated[
list[str],
Field(
description="Selective checkpoint targets. `norm` checkpoints every norm module inside selected layers (decoder, attention, MLA, etc.). `attn_proj` checkpoints projection-side attention work outside the kernel, including input/output projections, attention-local norms, RoPE, gating, and model-specific MLA projection helpers where exposed. `mlp` checkpoints the entire dense MLP forward (not applicable to MoE layers). `moe_act` checkpoints routed MoE activation functions inside expert MLPs, such as SiLU/GLU or relu2, without checkpointing the surrounding expert projections. `mla_up_proj` checkpoints MLA Q/KV up-projection work where supported. `routed_experts` checkpoints the broader routed expert path in MoE layers (including LatentMoE). `linear_attn` checkpoints supported token mixers outside the standard softmax-attention path, including NemotronH Mamba layers, Qwen3.5-MoE GatedDeltaNet layers, and AFMoE sliding-window attention layers.",
),
] = ["norm"]

@@ -197,7 +197,11 @@ class ModelConfig(BaseModelConfig):
ac: Annotated[
ActivationCheckpointConfig | None,
Field(
description=(
"Activation checkpointing config. If None, supported custom implementations default to "
'selective activation checkpointing with `targets = ["norm"]`; other '
"implementations leave activation checkpointing disabled."
),
),
] = None

26 changes: 24 additions & 2 deletions src/prime_rl/trainer/model.py
@@ -619,6 +619,22 @@ def reshard_module(model: nn.Module):
module.reshard()


def get_default_activation_checkpoint_config(resolved_impl: str) -> ActivationCheckpointConfig | None:
if resolved_impl == "custom":
return ActivationCheckpointConfig(mode="selective")
return None


def get_effective_activation_checkpoint_config(
config: ModelConfig, model: nn.Module
) -> ActivationCheckpointConfig | None:
if config.ac is not None:
return config.ac

resolved_impl = "custom" if isinstance(model, PreTrainedModelPrimeRL) else "hf"
return get_default_activation_checkpoint_config(resolved_impl)


def apply_ac(model: nn.Module, ac_config: ActivationCheckpointConfig):
logger = get_logger()
language_model = get_language_model(model)
@@ -797,8 +813,14 @@ def setup_model(
freeze_all_except_lora_and_specified(model, config.lora)

# the right order is AC -> Compile -> FSDP
effective_ac_config = get_effective_activation_checkpoint_config(config, model)
if effective_ac_config is not None:
if config.ac is None:
logger.info(
"Defaulting activation checkpointing to selective mode "
"(freq=1, targets=['norm']) for the custom implementation."
)
apply_ac(model, effective_ac_config)
if config.compile is not None:
apply_compile(model, config.compile)

27 changes: 26 additions & 1 deletion src/prime_rl/trainer/models/layers/checkpointing.py
@@ -15,6 +15,8 @@
`attn_projections(...)`.
- `mlp`: expose a dense `layer.mlp.forward(...)`. A module is treated as dense
when it does not define `_run_routed_experts` or `tokens_per_expert`.
- `moe_act`: expose `layer.mlp.experts.moe_act(...)` or `layer.mlp.moe_act(...)`
for routed MoE activations between expert input and output projections.
- `routed_experts`: expose `layer.mlp._run_routed_experts(...)` for the MoE
expert path, and optionally `layer.mlp._run_local_routed_experts(...)` when
local expert compute is separated from dispatch/combine.
Expand All @@ -33,7 +35,9 @@
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

SELECTIVE_AC_TARGETS = frozenset(
{"norm", "attn_proj", "mlp", "moe_act", "mla_up_proj", "routed_experts", "linear_attn"}
)
_PATCHED_METHODS_ATTR = "_prime_rl_selective_ac_patched_methods"


@@ -96,19 +100,36 @@ def _get_linear_attn_module(layer: nn.Module) -> nn.Module | None:
return None


def _get_moe_act_module(mlp: nn.Module | None) -> nn.Module | None:
if mlp is None:
return None

if hasattr(mlp, "moe_act"):
return mlp

experts = getattr(mlp, "experts", None)
if experts is not None and hasattr(experts, "moe_act"):
return experts

return None


def get_supported_targets(layer: nn.Module) -> frozenset[str]:
"""Infer which selective activation checkpoint targets a decoder layer supports."""
supported_targets = {"norm"}
self_attn = getattr(layer, "self_attn", None)
mlp = getattr(layer, "mlp", None)
linear_attn = _get_linear_attn_module(layer)
moe_act_module = _get_moe_act_module(mlp)

if _supports_attn_proj(self_attn):
supported_targets.add("attn_proj")
if self_attn is not None and hasattr(self_attn, "mla_up_proj"):
supported_targets.add("mla_up_proj")
if mlp is not None and _is_dense_mlp(mlp):
supported_targets.add("mlp")
if moe_act_module is not None:
supported_targets.add("moe_act")
if mlp is not None and (hasattr(mlp, "_run_routed_experts") or hasattr(mlp, "_run_local_routed_experts")):
supported_targets.add("routed_experts")
if linear_attn is not None:
@@ -127,7 +148,9 @@ def set_selective_activation_checkpointing(layer: nn.Module, targets: Iterable[s
self_attn = getattr(layer, "self_attn", None)
mlp = getattr(layer, "mlp", None)
linear_attn = _get_linear_attn_module(layer)
moe_act_module = _get_moe_act_module(mlp)
attn_proj_is_subsumed = "linear_attn" in enabled_targets and linear_attn is self_attn
moe_act_is_subsumed = "routed_experts" in enabled_targets

if self_attn is not None and "attn_proj" in enabled_targets and not attn_proj_is_subsumed:
checkpoint_method(self_attn, "attn_projections")
@@ -136,6 +159,8 @@ def set_selective_activation_checkpointing(layer: nn.Module, targets: Iterable[s
checkpoint_method(self_attn, "mla_up_proj")
if mlp is not None and "mlp" in enabled_targets:
checkpoint_method(mlp, "forward")
if moe_act_module is not None and "moe_act" in enabled_targets and not moe_act_is_subsumed:
checkpoint_method(moe_act_module, "moe_act")
if mlp is not None and "routed_experts" in enabled_targets:
if hasattr(mlp, "_run_routed_experts"):
checkpoint_method(mlp, "_run_routed_experts")