diff --git a/CHANGELOG.md b/CHANGELOG.md index 1e4842cebe..5ef35fdebe 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,8 @@ Documenting changes which affect configuration usage patterns (added/moved/removed/renamed fields, notable logic changes). +- **`model.ac`**: `trainer.model.ac` no longer bakes in a config-level default. When it is unset, supported custom implementations default to `mode = "selective"` with `targets = ["norm"]`, while HF implementations leave activation checkpointing disabled. Bare `[model.ac]` / `--model.ac` again mean explicit full-layer checkpointing, and selective AC remains custom-only. (2026-04-01) +- **`model.ac.targets`**: Added `moe_act` as a selective activation checkpoint target for custom MoE layers. This checkpoints the routed expert activation function inside expert MLPs without checkpointing the broader routed-expert path; if both `moe_act` and `routed_experts` are selected, the broader `routed_experts` hook wins to avoid double-wrapping nested checkpoints. (2026-04-01) - **`log.file` and `log.env_worker_logs` removed**: Removed `log.file` (from `LogConfig` and `SharedLogConfig`) and `log.env_worker_logs` (from `LogConfig`). Python file logging is replaced by deployment-level capture. Existing configs using these fields must delete them. Log paths unified: `.stdout` files renamed to `.log`, SLURM logs moved from `slurm/` to `logs/`. (2026-03-31) - **`trainer.log.ranks_filter` (NEW)**: Added `ranks_filter: list[int]` to `TrainerLogConfig` (default: `[0]`). Controls which ranks appear in trainer console output via torchrun's `--local-ranks-filter`. (2026-03-31) - **`wandb.log_extras.sample_ratio` / monitor sample logging defaults**: `wandb.log_extras.sample_ratio` is now actually applied to W&B sample-table logging via the shared monitor sampler (it was previously a no-op for WandB). 
Separately, the orchestrator no longer hard-caps sample logging to 8 rollouts before monitor-level sampling runs, so when monitor `sample_ratio` is `None`, monitors now receive and may log the full rollout batch for a step instead of at most 8 rollouts. This affects both W&B and Prime monitor sample logging behavior. (2026-03-27) diff --git a/docs/memory_usage.md b/docs/memory_usage.md index b36c117254..0063b6b91b 100644 --- a/docs/memory_usage.md +++ b/docs/memory_usage.md @@ -4,7 +4,6 @@ While most of our parallelism techniques in prime-rl are designed to scale train These techniques target the trainer part of prime-rl. - ## TLDR: config to use for maximum memory usage reduction with correct throughput ```toml @@ -19,6 +18,7 @@ optim_cpu_offload = true [trainer.model.compile] [trainer.model.ac] +mode = "full" freq = 1 [trainer.model.ac_offloading] @@ -29,20 +29,86 @@ max_inflight_activations = 1 Activation checkpointing discards intermediate activations during the forward pass and recomputes them during the backward pass, trading compute for memory. -To enable it, use: +If `trainer.model.ac` is unset, supported custom implementations default to selective AC on the cheapest target: + +```toml +[trainer.model.ac] +mode = "selective" +targets = ["norm"] +freq = 1 +``` + +HF or other non-custom implementations do not auto-enable activation checkpointing when `trainer.model.ac` is unset. + +To explicitly enable full-layer checkpointing, use: + +```toml +[trainer.model.ac] +freq = 1 +``` + +This is equivalent to: ```toml [trainer.model.ac] +mode = "full" freq = 1 ``` -`freq` controls how often layers are checkpointed: every `freq` layers. Lower values yield lower memory usage (e.g. `freq = 1` checkpoints every layer). +## Selective activation checkpointing tuning + +Selective AC is only supported with the custom model implementation. It lets you add memory savings incrementally before switching all the way to `mode = "full"`. 
+ +The target lists below are rough tuning heuristics, ordered from the best memory-saved/recompute tradeoff to the worst. Start with the leftmost targets and add more as needed. The runtime treats `targets` as a set, so the order in your config file does not matter. + +```toml +[trainer.model] +impl = "custom" + +[trainer.model.ac] +mode = "selective" +targets = ["norm", "attn_proj", "moe_act"] +``` + +Available targets by model family: +- `llama`: `norm`, `attn_proj`, `mlp` +- `minimax_m2`: `norm`, `moe_act`, `attn_proj`, `routed_experts` +- `qwen3_moe` / `glm4_moe`: `norm`, `mlp`, `moe_act`, `attn_proj`, `routed_experts` +- `afmoe`: `norm`, `mlp`, `moe_act`, `attn_proj`, `linear_attn`, `routed_experts` +- `qwen3_5_moe` / `nemotron_h`: `norm`, `moe_act`, `attn_proj`, `linear_attn`, `routed_experts` +- `glm_moe_dsa`: `norm`, `mlp`, `mla_up_proj`, `moe_act`, `attn_proj`, `routed_experts` + +Notes: +- These lists are unions across the model. On mixed-architecture families, not every target applies to every layer. +- `qwen3_moe`, `glm4_moe`, and `glm_moe_dsa` contain both dense and MoE layers. `mlp` applies to dense layers, while `moe_act` and `routed_experts` apply to MoE layers. +- `afmoe` only exposes `linear_attn` on its sliding-window attention layers. +- `qwen3_5_moe` only exposes `linear_attn` on its GatedDeltaNet layers. +- `nemotron_h` only exposes `linear_attn` on Mamba layers, while `moe_act` and `routed_experts` only apply to its LatentMoE layers. +- `glm_moe_dsa` only exposes `mla_up_proj` on its sparse MLA attention layers. + +When selective tuning is not enough, switch to `mode = "full"`. ## Activation offloading Activation offloading offloads the activations to CPU to reduce the memory usage of the trainer. It can be used in combination with activation checkpointing. -To enable it, use: +If you set `trainer.model.ac_offloading` without `trainer.model.ac`, prime-rl also enables explicit full-layer activation checkpointing.
+ +To use activation offloading with the custom-model selective setup, configure both explicitly: + +```toml +[trainer.model] +impl = "custom" + +[trainer.model.ac] +mode = "selective" +targets = ["norm"] + +[trainer.model.ac_offloading] +max_inflight_activations = 5 +``` + +To use activation offloading with full AC, use: ```toml [trainer.model.ac] @@ -63,14 +129,13 @@ To enable it, use: fused_lm_head_token_chunk_size = auto ``` - ## Expert parallelism While expert parallelism splits the weights of the experts across all GPUs like FSDP, using EP still reduces memory usage by reducing the communication size and therefore the FSDP buffer. EP is only available for models with MoE layers using the custom model implementation. -``` +```toml [trainer.model] impl = "custom" ep = 8 @@ -82,7 +147,7 @@ Context parallelism splits the context into smaller chunks to reduce the memory CP is only available for certain models and only with the custom model implementation. -``` +```toml [trainer.model] impl = "custom" cp = 2 @@ -90,12 +155,11 @@ cp = 2 We recommend CP 2 or CP 4 for most 128K sequence length training runs. Can be pushed to 8. - ## torch compile -Enabling torch.compile can reduce the memory usage for certain model architectures, especially MoE with the custom model implementation. +Enabling `torch.compile` can reduce the memory usage for certain model architectures, especially MoE with the custom model implementation. -``` +```toml [trainer.model.compile] ``` @@ -105,8 +169,8 @@ Offloading the optimizer states to CPU can reduce the memory usage of the traine In RL, in contrast with pretraining, we end up with many gradient accumulation steps, so the cost of offloading the optimizer states is not as high as in pretraining, and indeed barely noticeable. 
-``` -[trainer.optim] +```toml +[trainer.model] optim_cpu_offload = true ``` @@ -116,7 +180,7 @@ FSDP CPU offloading offloads the parameters, gradients, and optimizer states to This will make training significantly slower and is not recommended most of the time. -``` +```toml [trainer.model] fsdp_cpu_offload = true ``` @@ -125,8 +189,7 @@ fsdp_cpu_offload = true LoRA training significantly reduces the memory usage of the trainer at the cost of smaller gradient updates. -``` +```toml [trainer.model.lora] rank = 8 ``` - diff --git a/docs/troubleshooting.md b/docs/troubleshooting.md index 408ad0df87..13dfce3272 100644 --- a/docs/troubleshooting.md +++ b/docs/troubleshooting.md @@ -11,7 +11,8 @@ ulimit -n 32000 > I'm getting CUDA out of memory errors. Assuming this is happening on the RL or SFT trainer, you can try the following: -- Use full activation checkpointing (`--model.ac`) +- Supported custom implementations already default to selective activation checkpointing with `targets = ["norm"]`. +- If you still OOM, enable full activation checkpointing with `--model.ac` or, on custom implementations, add more selective targets as described in `docs/memory_usage.md`. - Reduce the micro batch size (`--data.micro-batch-size`) and sequence length (`--data.seq-len`) - (*Experimental*) Use context parallelism with `--model.cp` diff --git a/src/prime_rl/configs/trainer.py b/src/prime_rl/configs/trainer.py index 3cceb51296..3256f5857e 100644 --- a/src/prime_rl/configs/trainer.py +++ b/src/prime_rl/configs/trainer.py @@ -63,7 +63,7 @@ class ActivationCheckpointConfig(BaseConfig): targets: Annotated[ list[str], Field( - description="Selective checkpoint targets. `norm` checkpoints every norm module inside selected layers (decoder, attention, MLA, etc.). `attn_proj` checkpoints projection-side attention work outside the kernel, including input/output projections, attention-local norms, RoPE, gating, and model-specific MLA projection helpers where exposed.
`mlp` checkpoints the entire dense MLP forward (not applicable to MoE layers). `mla_up_proj` checkpoints MLA Q/KV up-projection work where supported. `routed_experts` checkpoints routed expert compute in MoE layers (including LatentMoE). `linear_attn` checkpoints supported token mixers outside the standard softmax-attention path, including NemotronH Mamba layers, Qwen3.5-MoE GatedDeltaNet layers, and AFMoE sliding-window attention layers.", + description="Selective checkpoint targets. `norm` checkpoints every norm module inside selected layers (decoder, attention, MLA, etc.). `attn_proj` checkpoints projection-side attention work outside the kernel, including input/output projections, attention-local norms, RoPE, gating, and model-specific MLA projection helpers where exposed. `mlp` checkpoints the entire dense MLP forward (not applicable to MoE layers). `moe_act` checkpoints routed MoE activation functions inside expert MLPs, such as SiLU/GLU or relu2, without checkpointing the surrounding expert projections. `mla_up_proj` checkpoints MLA Q/KV up-projection work where supported. `routed_experts` checkpoints the broader routed expert path in MoE layers (including LatentMoE). `linear_attn` checkpoints supported token mixers outside the standard softmax-attention path, including NemotronH Mamba layers, Qwen3.5-MoE GatedDeltaNet layers, and AFMoE sliding-window attention layers.", ), ] = ["norm"] @@ -197,7 +197,11 @@ class ModelConfig(BaseModelConfig): ac: Annotated[ ActivationCheckpointConfig | None, Field( - description="Whether to apply activation checkpointing to the model. If None, will not apply activation checkpointing.", + description=( + "Activation checkpointing config. If None, supported custom implementations default to " + 'selective activation checkpointing with `targets = ["norm"]`; other ' + "implementations leave activation checkpointing disabled." 
+ ), ), ] = None diff --git a/src/prime_rl/trainer/model.py b/src/prime_rl/trainer/model.py index b5e1fcd4c5..60edad91e1 100644 --- a/src/prime_rl/trainer/model.py +++ b/src/prime_rl/trainer/model.py @@ -619,6 +619,22 @@ def reshard_module(model: nn.Module): module.reshard() +def get_default_activation_checkpoint_config(resolved_impl: str) -> ActivationCheckpointConfig | None: + if resolved_impl == "custom": + return ActivationCheckpointConfig(mode="selective") + return None + + +def get_effective_activation_checkpoint_config( + config: ModelConfig, model: nn.Module +) -> ActivationCheckpointConfig | None: + if config.ac is not None: + return config.ac + + resolved_impl = "custom" if isinstance(model, PreTrainedModelPrimeRL) else "hf" + return get_default_activation_checkpoint_config(resolved_impl) + + def apply_ac(model: nn.Module, ac_config: ActivationCheckpointConfig): logger = get_logger() language_model = get_language_model(model) @@ -797,8 +813,14 @@ def setup_model( freeze_all_except_lora_and_specified(model, config.lora) # the right order is AC -> Compile -> FSDP - if config.ac is not None: - apply_ac(model, config.ac) + effective_ac_config = get_effective_activation_checkpoint_config(config, model) + if effective_ac_config is not None: + if config.ac is None: + logger.info( + "Defaulting activation checkpointing to selective mode " + "(freq=1, targets=['norm']) for the custom implementation." + ) + apply_ac(model, effective_ac_config) if config.compile is not None: apply_compile(model, config.compile) diff --git a/src/prime_rl/trainer/models/layers/checkpointing.py b/src/prime_rl/trainer/models/layers/checkpointing.py index ef4c41cd7d..5daa7b7758 100644 --- a/src/prime_rl/trainer/models/layers/checkpointing.py +++ b/src/prime_rl/trainer/models/layers/checkpointing.py @@ -15,6 +15,8 @@ `attn_projections(...)`. - `mlp`: expose a dense `layer.mlp.forward(...)`. A module is treated as dense when it does not define `_run_routed_experts` or `tokens_per_expert`. 
+- `moe_act`: expose `layer.mlp.experts.moe_act(...)` or `layer.mlp.moe_act(...)` + for routed MoE activations between expert input and output projections. - `routed_experts`: expose `layer.mlp._run_routed_experts(...)` for the MoE expert path, and optionally `layer.mlp._run_local_routed_experts(...)` when local expert compute is separated from dispatch/combine. @@ -33,7 +35,9 @@ import torch.nn as nn from torch.utils.checkpoint import checkpoint -SELECTIVE_AC_TARGETS = frozenset({"norm", "attn_proj", "mlp", "mla_up_proj", "routed_experts", "linear_attn"}) +SELECTIVE_AC_TARGETS = frozenset( + {"norm", "attn_proj", "mlp", "moe_act", "mla_up_proj", "routed_experts", "linear_attn"} +) _PATCHED_METHODS_ATTR = "_prime_rl_selective_ac_patched_methods" @@ -96,12 +100,27 @@ def _get_linear_attn_module(layer: nn.Module) -> nn.Module | None: return None +def _get_moe_act_module(mlp: nn.Module | None) -> nn.Module | None: + if mlp is None: + return None + + if hasattr(mlp, "moe_act"): + return mlp + + experts = getattr(mlp, "experts", None) + if experts is not None and hasattr(experts, "moe_act"): + return experts + + return None + + def get_supported_targets(layer: nn.Module) -> frozenset[str]: """Infer which selective activation checkpoint targets a decoder layer supports.""" supported_targets = {"norm"} self_attn = getattr(layer, "self_attn", None) mlp = getattr(layer, "mlp", None) linear_attn = _get_linear_attn_module(layer) + moe_act_module = _get_moe_act_module(mlp) if _supports_attn_proj(self_attn): supported_targets.add("attn_proj") @@ -109,6 +128,8 @@ def get_supported_targets(layer: nn.Module) -> frozenset[str]: supported_targets.add("mla_up_proj") if mlp is not None and _is_dense_mlp(mlp): supported_targets.add("mlp") + if moe_act_module is not None: + supported_targets.add("moe_act") if mlp is not None and (hasattr(mlp, "_run_routed_experts") or hasattr(mlp, "_run_local_routed_experts")): supported_targets.add("routed_experts") if linear_attn is not None: @@ 
-127,7 +148,9 @@ def set_selective_activation_checkpointing(layer: nn.Module, targets: Iterable[s self_attn = getattr(layer, "self_attn", None) mlp = getattr(layer, "mlp", None) linear_attn = _get_linear_attn_module(layer) + moe_act_module = _get_moe_act_module(mlp) attn_proj_is_subsumed = "linear_attn" in enabled_targets and linear_attn is self_attn + moe_act_is_subsumed = "routed_experts" in enabled_targets if self_attn is not None and "attn_proj" in enabled_targets and not attn_proj_is_subsumed: checkpoint_method(self_attn, "attn_projections") @@ -136,6 +159,8 @@ def set_selective_activation_checkpointing(layer: nn.Module, targets: Iterable[s checkpoint_method(self_attn, "mla_up_proj") if mlp is not None and "mlp" in enabled_targets: checkpoint_method(mlp, "forward") + if moe_act_module is not None and "moe_act" in enabled_targets and not moe_act_is_subsumed: + checkpoint_method(moe_act_module, "moe_act") if mlp is not None and "routed_experts" in enabled_targets: if hasattr(mlp, "_run_routed_experts"): checkpoint_method(mlp, "_run_routed_experts") diff --git a/src/prime_rl/trainer/models/layers/moe.py b/src/prime_rl/trainer/models/layers/moe.py index f6f5649cce..6eb0ab1773 100644 --- a/src/prime_rl/trainer/models/layers/moe.py +++ b/src/prime_rl/trainer/models/layers/moe.py @@ -85,83 +85,6 @@ def init_weights(self, init_std: float): nn.init.trunc_normal_(self.w3, mean=0.0, std=init_std) -# TODO: keeping this for-loop implementation for comparison -# and readability, may remove later -def _run_experts_for_loop_impl( - w1: torch.Tensor, - w2: torch.Tensor, - w3: torch.Tensor, - x: torch.Tensor, - num_tokens_per_expert: torch.Tensor, -) -> torch.Tensor: - # NOTE: this would incur a synchronization between device and host - num_tokens_per_expert = num_tokens_per_expert.tolist() - - # side-effect code due to the usage of generate_permute_indices - num_padding = x.shape[0] - sum(num_tokens_per_expert) - - # a tuple of tensors indexed by experts - # each with shape 
(tokens_per_expert(varying), dim) - x = torch.split( - x[: sum(num_tokens_per_expert)], - split_size_or_sections=num_tokens_per_expert, - dim=0, - ) - out_experts_splits = [] - for expert_idx, x_expert in enumerate(x): - h = F.silu(torch.matmul(x_expert, w1[expert_idx].transpose(-2, -1))) - h = h * torch.matmul(x_expert, w3[expert_idx].transpose(-2, -1)) - h = torch.matmul(h, w2[expert_idx].transpose(-2, -1)) - # h shape (tokens_per_expert(varying), dim) - out_experts_splits.append(h) - out = torch.cat(out_experts_splits, dim=0) - - # side-effect code due to the usage of generate_permute_indices - out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1])))) - - return out - - -@expert_parallel -def _run_experts_for_loop( - w1: torch.Tensor, - w2: torch.Tensor, - w3: torch.Tensor, - x: torch.Tensor, - num_tokens_per_expert: torch.Tensor, -) -> torch.Tensor: - return _run_experts_for_loop_impl(w1, w2, w3, x, num_tokens_per_expert) - - -@expert_parallel -def _run_experts_grouped_mm( - w1: torch.Tensor, - w2: torch.Tensor, - w3: torch.Tensor, - x: torch.Tensor, - num_tokens_per_expert: torch.Tensor, -) -> torch.Tensor: - return _run_experts_grouped_mm_impl(w1, w2, w3, x, num_tokens_per_expert) - - -def _run_experts_grouped_mm_impl( - w1: torch.Tensor, - w2: torch.Tensor, - w3: torch.Tensor, - x: torch.Tensor, - num_tokens_per_expert: torch.Tensor, -) -> torch.Tensor: - offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) - # grouped mm between a 2D tensor and a 3D tensor - assert x.dim() == 2 - - h = F.silu(torch._grouped_mm(x.bfloat16(), w1.bfloat16().transpose(-2, -1), offs=offsets)) - h = h * torch._grouped_mm(x.bfloat16(), w3.bfloat16().transpose(-2, -1), offs=offsets) - out = torch._grouped_mm(h, w2.bfloat16().transpose(-2, -1), offs=offsets).type_as(x) - - return out - - class GroupedExperts(nn.Module): def __init__( self, @@ -177,17 +100,65 @@ def __init__( self.w3 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim)) 
self.use_grouped_mm = use_grouped_mm self.ep_comm_backend: EPCommBackend = "torch" + self._run_experts_for_loop = expert_parallel(self._run_experts_for_loop_impl) + self._run_experts_grouped_mm = expert_parallel(self._run_experts_grouped_mm_impl) def set_ep_comm_backend(self, backend: EPCommBackend) -> None: self.ep_comm_backend = backend + def moe_act(self, fc1_output: torch.Tensor, gate_output: torch.Tensor) -> torch.Tensor: + return F.silu(fc1_output) * gate_output + + def _run_experts_for_loop_impl( + self, + w1: torch.Tensor, + w2: torch.Tensor, + w3: torch.Tensor, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor, + ) -> torch.Tensor: + # NOTE: this would incur a synchronization between device and host + num_tokens_per_expert = num_tokens_per_expert.tolist() + num_padding = x.shape[0] - sum(num_tokens_per_expert) + + x = torch.split( + x[: sum(num_tokens_per_expert)], + split_size_or_sections=num_tokens_per_expert, + dim=0, + ) + out_experts_splits = [] + for expert_idx, x_expert in enumerate(x): + fc1_output = torch.matmul(x_expert, w1[expert_idx].transpose(-2, -1)) + gate_output = torch.matmul(x_expert, w3[expert_idx].transpose(-2, -1)) + hidden_states = self.moe_act(fc1_output, gate_output) + out_experts_splits.append(torch.matmul(hidden_states, w2[expert_idx].transpose(-2, -1))) + + out = torch.cat(out_experts_splits, dim=0) + return torch.vstack((out, out.new_zeros((num_padding, out.shape[-1])))) + + def _run_experts_grouped_mm_impl( + self, + w1: torch.Tensor, + w2: torch.Tensor, + w3: torch.Tensor, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor, + ) -> torch.Tensor: + offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) + assert x.dim() == 2 + + fc1_output = torch._grouped_mm(x.bfloat16(), w1.bfloat16().transpose(-2, -1), offs=offsets) + gate_output = torch._grouped_mm(x.bfloat16(), w3.bfloat16().transpose(-2, -1), offs=offsets) + hidden_states = self.moe_act(fc1_output, gate_output) + return 
torch._grouped_mm(hidden_states, w2.bfloat16().transpose(-2, -1), offs=offsets).type_as(x) + def _forward_deepep(self, x: torch.Tensor, num_tokens_per_expert: torch.Tensor) -> torch.Tensor: w1 = self.w1.to_local() w2 = self.w2.to_local() w3 = self.w3.to_local() if self.use_grouped_mm: - return _run_experts_grouped_mm_impl(w1, w2, w3, x, num_tokens_per_expert) - return _run_experts_for_loop_impl(w1, w2, w3, x, num_tokens_per_expert) + return self._run_experts_grouped_mm_impl(w1, w2, w3, x, num_tokens_per_expert) + return self._run_experts_for_loop_impl(w1, w2, w3, x, num_tokens_per_expert) def forward( self, @@ -198,9 +169,9 @@ def forward( return self._forward_deepep(x, num_tokens_per_expert) if self.use_grouped_mm: - return _run_experts_grouped_mm(self.w1, self.w2, self.w3, x, num_tokens_per_expert) + return self._run_experts_grouped_mm(self.w1, self.w2, self.w3, x, num_tokens_per_expert) else: - return _run_experts_for_loop(self.w1, self.w2, self.w3, x, num_tokens_per_expert) + return self._run_experts_for_loop(self.w1, self.w2, self.w3, x, num_tokens_per_expert) def init_weights(self, init_std: float): nn.init.trunc_normal_(self.w1, mean=0.0, std=0.02) @@ -596,68 +567,6 @@ def relu2(x: torch.Tensor) -> torch.Tensor: return F.relu(x).square() -def _run_nongated_experts_for_loop_impl( - w1: torch.Tensor, - w2: torch.Tensor, - _w3: torch.Tensor, - x: torch.Tensor, - num_tokens_per_expert: torch.Tensor, -) -> torch.Tensor: - num_tokens_per_expert = num_tokens_per_expert.tolist() - num_padding = x.shape[0] - sum(num_tokens_per_expert) - - x = torch.split( - x[: sum(num_tokens_per_expert)], - split_size_or_sections=num_tokens_per_expert, - dim=0, - ) - out_experts_splits = [] - for expert_idx, x_expert in enumerate(x): - h = relu2(torch.matmul(x_expert, w1[expert_idx].transpose(-2, -1))) - h = torch.matmul(h, w2[expert_idx].transpose(-2, -1)) - out_experts_splits.append(h) - out = torch.cat(out_experts_splits, dim=0) - out = torch.vstack((out, 
out.new_zeros((num_padding, out.shape[-1])))) - return out - - -@expert_parallel -def _run_nongated_experts_for_loop( - w1: torch.Tensor, - w2: torch.Tensor, - _w3: torch.Tensor, - x: torch.Tensor, - num_tokens_per_expert: torch.Tensor, -) -> torch.Tensor: - return _run_nongated_experts_for_loop_impl(w1, w2, _w3, x, num_tokens_per_expert) - - -def _run_nongated_experts_grouped_mm_impl( - w1: torch.Tensor, - w2: torch.Tensor, - _w3: torch.Tensor, - x: torch.Tensor, - num_tokens_per_expert: torch.Tensor, -) -> torch.Tensor: - offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) - assert x.dim() == 2 - - h = relu2(torch._grouped_mm(x.bfloat16(), w1.bfloat16().transpose(-2, -1), offs=offsets)) - out = torch._grouped_mm(h, w2.bfloat16().transpose(-2, -1), offs=offsets).type_as(x) - return out - - -@expert_parallel -def _run_nongated_experts_grouped_mm( - w1: torch.Tensor, - w2: torch.Tensor, - _w3: torch.Tensor, - x: torch.Tensor, - num_tokens_per_expert: torch.Tensor, -) -> torch.Tensor: - return _run_nongated_experts_grouped_mm_impl(w1, w2, _w3, x, num_tokens_per_expert) - - class NonGatedGroupedExperts(nn.Module): def __init__( self, @@ -674,17 +583,61 @@ def __init__( self.w3 = nn.Parameter(torch.empty(0)) self.use_grouped_mm = use_grouped_mm self.ep_comm_backend: EPCommBackend = "torch" + self._run_experts_for_loop = expert_parallel(self._run_experts_for_loop_impl) + self._run_experts_grouped_mm = expert_parallel(self._run_experts_grouped_mm_impl) def set_ep_comm_backend(self, backend: EPCommBackend) -> None: self.ep_comm_backend = backend + def moe_act(self, fc1_output: torch.Tensor) -> torch.Tensor: + return relu2(fc1_output) + + def _run_experts_for_loop_impl( + self, + w1: torch.Tensor, + w2: torch.Tensor, + _w3: torch.Tensor, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor, + ) -> torch.Tensor: + num_tokens_per_expert = num_tokens_per_expert.tolist() + num_padding = x.shape[0] - sum(num_tokens_per_expert) + + x = torch.split( + x[: 
sum(num_tokens_per_expert)], + split_size_or_sections=num_tokens_per_expert, + dim=0, + ) + out_experts_splits = [] + for expert_idx, x_expert in enumerate(x): + fc1_output = torch.matmul(x_expert, w1[expert_idx].transpose(-2, -1)) + hidden_states = self.moe_act(fc1_output) + out_experts_splits.append(torch.matmul(hidden_states, w2[expert_idx].transpose(-2, -1))) + out = torch.cat(out_experts_splits, dim=0) + return torch.vstack((out, out.new_zeros((num_padding, out.shape[-1])))) + + def _run_experts_grouped_mm_impl( + self, + w1: torch.Tensor, + w2: torch.Tensor, + _w3: torch.Tensor, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor, + ) -> torch.Tensor: + offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) + assert x.dim() == 2 + + fc1_output = torch._grouped_mm(x.bfloat16(), w1.bfloat16().transpose(-2, -1), offs=offsets) + hidden_states = self.moe_act(fc1_output) + return torch._grouped_mm(hidden_states, w2.bfloat16().transpose(-2, -1), offs=offsets).type_as(x) + def _forward_deepep(self, x: torch.Tensor, num_tokens_per_expert: torch.Tensor) -> torch.Tensor: w1 = self.w1.to_local() w2 = self.w2.to_local() w3 = self.w3.to_local() if self.use_grouped_mm: - return _run_nongated_experts_grouped_mm_impl(w1, w2, w3, x, num_tokens_per_expert) - return _run_nongated_experts_for_loop_impl(w1, w2, w3, x, num_tokens_per_expert) + return self._run_experts_grouped_mm_impl(w1, w2, w3, x, num_tokens_per_expert) + return self._run_experts_for_loop_impl(w1, w2, w3, x, num_tokens_per_expert) def forward( self, @@ -694,9 +647,9 @@ def forward( if self.ep_comm_backend == "deepep": return self._forward_deepep(x, num_tokens_per_expert) if self.use_grouped_mm: - return _run_nongated_experts_grouped_mm(self.w1, self.w2, self.w3, x, num_tokens_per_expert) + return self._run_experts_grouped_mm(self.w1, self.w2, self.w3, x, num_tokens_per_expert) else: - return _run_nongated_experts_for_loop(self.w1, self.w2, self.w3, x, num_tokens_per_expert) + return 
self._run_experts_for_loop(self.w1, self.w2, self.w3, x, num_tokens_per_expert) def init_weights(self, init_std: float): nn.init.trunc_normal_(self.w1, mean=0.0, std=0.02) diff --git a/tests/unit/test_configs.py b/tests/unit/test_configs.py index ffcc18f270..799998d8dc 100644 --- a/tests/unit/test_configs.py +++ b/tests/unit/test_configs.py @@ -156,6 +156,12 @@ def test_removed_fused_lm_head_chunk_size_field_is_rejected(): TrainerModelConfig.model_validate({"fused_lm_head_chunk_size": "auto"}) +def test_activation_checkpointing_is_disabled_in_config_by_default(): + config = TrainerModelConfig.model_validate({}) + + assert config.ac is None + + def test_selective_activation_checkpointing_requires_custom_impl(): with pytest.raises(ValidationError, match="Selective activation checkpointing requires model.impl='custom'"): TrainerModelConfig.model_validate({"impl": "hf", "ac": {"mode": "selective"}}) diff --git a/tests/unit/train/models/test_checkpointing.py b/tests/unit/train/models/test_checkpointing.py index 90fe3f32c3..44afbc2fc3 100644 --- a/tests/unit/train/models/test_checkpointing.py +++ b/tests/unit/train/models/test_checkpointing.py @@ -1,5 +1,6 @@ import torch.nn as nn +from prime_rl.trainer.model import get_default_activation_checkpoint_config from prime_rl.trainer.models.layers.checkpointing import ( get_supported_targets, set_selective_activation_checkpointing, @@ -37,7 +38,16 @@ def __init__(self): self.mamba = DummyMamba() +class DummyMoEExperts(nn.Module): + def moe_act(self, hidden_states, gate_output=None): + return hidden_states + + class DummyMoEMlp(nn.Module): + def __init__(self): + super().__init__() + self.experts = DummyMoEExperts() + def forward(self, hidden_states): return hidden_states @@ -74,3 +84,34 @@ def test_routed_experts_checkpointing_patches_local_and_global_helpers(): set_selective_activation_checkpointing(layer, ["routed_experts"]) assert getattr(layer.mlp, _PATCHED_METHODS_ATTR) == frozenset({"_run_local_routed_experts", 
"_run_routed_experts"}) + + +def test_moe_act_checkpointing_patches_expert_activation(): + layer = DummyMoELayer() + + assert "moe_act" in get_supported_targets(layer) + + set_selective_activation_checkpointing(layer, ["moe_act"]) + + assert getattr(layer.mlp.experts, _PATCHED_METHODS_ATTR) == frozenset({"moe_act"}) + + +def test_routed_experts_subsumes_moe_act_checkpointing(): + layer = DummyMoELayer() + + set_selective_activation_checkpointing(layer, ["moe_act", "routed_experts"]) + + assert getattr(layer.mlp, _PATCHED_METHODS_ATTR) == frozenset({"_run_local_routed_experts", "_run_routed_experts"}) + assert not hasattr(layer.mlp.experts, _PATCHED_METHODS_ATTR) + + +def test_custom_impl_gets_default_selective_activation_checkpointing(): + config = get_default_activation_checkpoint_config("custom") + + assert config is not None + assert config.mode == "selective" + assert config.targets == ["norm"] + + +def test_hf_impl_has_no_default_activation_checkpointing(): + assert get_default_activation_checkpoint_config("hf") is None
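
For reviewers, the target-precedence rules this diff introduces can be condensed into a small standalone sketch. This is a hypothetical model of the logic (the real implementation lives in `src/prime_rl/trainer/models/layers/checkpointing.py`; the function name `resolve_wrapped_targets` is invented for illustration): `routed_experts` subsumes `moe_act`, and `linear_attn` subsumes `attn_proj` when the linear-attention module is the layer's `self_attn` module.

```python
# Hypothetical sketch of the selective-AC precedence rules described in this
# diff; not the repo code. Mirrors the documented behavior: the broader
# routed_experts hook wins over moe_act to avoid double-wrapping nested
# checkpoints, and linear_attn wins over attn_proj when they target the
# same module.
SELECTIVE_AC_TARGETS = frozenset(
    {"norm", "attn_proj", "mlp", "moe_act", "mla_up_proj", "routed_experts", "linear_attn"}
)


def resolve_wrapped_targets(
    enabled: set[str],
    supported: set[str],
    linear_attn_is_self_attn: bool = False,
) -> frozenset[str]:
    """Return the targets that would actually be checkpoint-wrapped for one layer."""
    targets = set(enabled) & set(supported) & SELECTIVE_AC_TARGETS
    # routed_experts already checkpoints the whole expert path, which contains
    # the moe_act call, so wrapping moe_act again would nest checkpoints.
    if "routed_experts" in targets:
        targets.discard("moe_act")
    # On layers where the linear-attention module *is* self_attn (e.g. AFMoE
    # sliding-window layers), linear_attn subsumes attn_proj.
    if "linear_attn" in targets and linear_attn_is_self_attn:
        targets.discard("attn_proj")
    return frozenset(targets)
```

Under these assumptions, `resolve_wrapped_targets({"moe_act", "routed_experts"}, supported)` yields only `routed_experts`, matching the new `test_routed_experts_subsumes_moe_act_checkpointing` test above.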