1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -2,6 +2,7 @@

Documenting changes which affect configuration usage patterns (added/moved/removed/renamed fields, notable logic changes).

- **`trainer.micro_batch_max_tokens`**: Added optional cap for packed local RL micro-batches. Defaults to `model.seq_len`, so existing configs keep the previous behavior. Lower values reduce per-forward RL trainer memory usage by splitting one trainer step into more local micro-batches without changing optimizer-step or checkpoint semantics. This applies only to real RL batches from the orchestrator; fake RL data still uses `data.fake.batch_size`. RL trainer token and sample progress metrics now sum the actual packed work across DP ranks instead of assuming uniform packing. (2026-04-04)
- **`log.file` and `log.env_worker_logs` removed**: Removed `log.file` (from `LogConfig` and `SharedLogConfig`) and `log.env_worker_logs` (from `LogConfig`). Python file logging is replaced by deployment-level capture. Existing configs using these fields must delete them. Log paths unified: `.stdout` files renamed to `.log`, SLURM logs moved from `slurm/` to `logs/`. (2026-03-31)
- **`trainer.log.ranks_filter` (NEW)**: Added `ranks_filter: list[int]` to `TrainerLogConfig` (default: `[0]`). Controls which ranks appear in trainer console output via torchrun's `--local-ranks-filter`. (2026-03-31)
- **`wandb.log_extras.sample_ratio` / monitor sample logging defaults**: `wandb.log_extras.sample_ratio` is now actually applied to W&B sample-table logging via the shared monitor sampler (it was previously a no-op for WandB). Separately, the orchestrator no longer hard-caps sample logging to 8 rollouts before monitor-level sampling runs, so when monitor `sample_ratio` is `None`, monitors now receive and may log the full rollout batch for a step instead of at most 8 rollouts. This affects both W&B and Prime monitor sample logging behavior. (2026-03-27)
7 changes: 5 additions & 2 deletions docs/benchmarking.md
@@ -47,11 +47,14 @@ uv run trainer ... --data.fake --bench
Benchmark different batch configurations, i.e. the (micro) batch size and sequence length

```bash
uv run trainer ... --model.seq-len 4096 --data.fake.batch-size 64 --data.fake.micro-batch-size 2 --bench
uv run trainer ... --model.seq-len 4096 --data.fake.batch-size 64 --bench
```

*Note that it is not yet possible to benchmark the RL trainer against real data when running it in isolation.*

When training the real RL path against orchestrator-produced batches, use `--micro-batch-max-tokens` to cap the
number of tokens packed into each local trainer micro batch. This knob does not apply to the fake RL benchmark path.

### Inference

To benchmark the inference engine in isolation, start the inference server with the correct configuration file and run the orchestrator with the `--bench` flag.
@@ -76,4 +79,4 @@ uv run rl \
--orchestrator @ path/to/orch.toml \
--inference @ path/to/infer.toml \
--bench
```
```
15 changes: 14 additions & 1 deletion docs/memory_usage.md
@@ -63,6 +63,20 @@ To enable it, use:
fused_lm_head_token_chunk_size = auto
```

## RL micro batch cap

RL already performs gradient accumulation implicitly by packing multiple local micro batches into each optimizer step.
If you need lower trainer memory usage without changing trainer-step semantics, cap the number of tokens in each local
RL micro batch:

```toml
[trainer]
micro_batch_max_tokens = 4096
```

This only applies to real RL batches from the orchestrator. Lower values increase the number of local micro batches
per trainer step while keeping one optimizer step per trainer step.
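The packing effect can be illustrated with a torch-free sketch of greedy first-fit binning (a hypothetical helper for illustration, not the project's actual packer code):

```python
# Hypothetical sketch of how a token cap splits samples into micro batches.
# Greedy first-fit: place each sample into the first bin with room, else
# open a new bin. A lower cap yields more (smaller) micro batches.
def pack_samples(sample_lengths: list[int], max_tokens: int) -> list[list[int]]:
    bins: list[list[int]] = []
    for length in sample_lengths:
        for b in bins:
            if sum(b) + length <= max_tokens:
                b.append(length)
                break
        else:
            bins.append([length])
    return bins

lengths = [1500, 3000, 1000, 2500, 500]
print(len(pack_samples(lengths, 8192)))  # → 2 micro batches
print(len(pack_samples(lengths, 4096)))  # → 3 micro batches
```

Either way, all packed samples contribute to the same optimizer step; only the per-forward memory footprint changes.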


## Expert parallelism

@@ -129,4 +143,3 @@ LoRA training significantly reduces the memory usage of the trainer at the cost
[trainer.model.lora]
rank = 8
```

4 changes: 2 additions & 2 deletions docs/troubleshooting.md
@@ -12,10 +12,10 @@ ulimit -n 32000

Assuming this is happening on the RL or SFT trainer, you can try the following:
- Use full activation checkpointing (`--model.ac`)
- Reduce the micro batch size (`--data.micro-batch-size`) and sequence length (`--data.seq-len`)
- Reduce the sequence length (`--data.seq-len` for SFT, `--model.seq-len` for RL)
- Reduce the trainer micro batch size (`--data.micro-batch-size` for SFT, `--micro-batch-max-tokens` for RL)
- (*Experimental*) Use context parallelism with `--model.cp`

> I cannot pass my TOML config file

Check that you *did* leave a whitespace between the `@` and the config file (e.g. `uv run ... @ path/to/config.toml` instead of `uv run ... @path/to/config.toml`). Also, make sure that your TOML config matches the configuration schema. If not, the Pydantic error message (which arguably is quite ugly) will hopefully point you in the right direction.

29 changes: 29 additions & 0 deletions src/prime_rl/configs/trainer.py
@@ -728,6 +728,19 @@ class TrainerConfig(BaseConfig):
# The data configuration
data: DataLoaderConfig = DataLoaderConfig()

micro_batch_max_tokens: Annotated[
int | None,
Field(
ge=1,
description=(
"Maximum number of text tokens to pack into each local RL micro batch. "
"Defaults to model.seq_len, which preserves the current behavior. "
"Lower values increase the number of local micro batches per trainer step without changing "
"optimizer-step or checkpoint semantics."
),
),
] = None

# The loss configuration
loss: LossConfig = DefaultLossConfig()

@@ -850,6 +863,22 @@ def auto_setup_bench(self):
self.ckpt = None
return self

@model_validator(mode="after")
def validate_micro_batch_max_tokens(self):
if self.micro_batch_max_tokens is None:
return self

if self.micro_batch_max_tokens > self.model.seq_len:
raise ValueError("micro_batch_max_tokens must be less than or equal to model.seq_len")

if self.data.fake is not None:
raise ValueError(
"micro_batch_max_tokens is only supported for real RL batches. "
"Fake RL data already uses data.fake.batch_size to control local micro batches."
)

return self

@model_validator(mode="after")
def dont_do_massive_traces(self):
if self.trace_path:
9 changes: 7 additions & 2 deletions src/prime_rl/trainer/batch.py
@@ -121,6 +121,7 @@ def packed_samples_into_micro_bs(
bin_content.routed_experts = []
bin_content.routed_experts.extend(sample.routed_experts)
bin_content.position_ids.extend(sample.position_ids)
bin_content.sample_count += sample.sample_count
bin_content.lora_num_tokens[idx] += len(sample.input_ids)
break
else:
@@ -168,6 +169,7 @@ def _make_dummy_batch(source: MicroBatch) -> MicroBatch:
dummy = copy.deepcopy(source)
dummy.advantages = [0.0] * len(dummy.input_ids)
dummy.loss_mask = [False] * len(dummy.input_ids)
dummy.sample_count = 0
return dummy


@@ -186,11 +188,13 @@ def prepare_batch(
num_train_workers: int,
idxs: list[int],
num_loras: int,
micro_batch_max_tokens: int | None = None,
pad_to_multiple_of: int = 1,
) -> list[list[MicroBatch]]:
"""
Prepare a batch of problems for each GPU. Each batch is a list of micro batches.
Each micro batch is shape [1, seq_len], the number of samples is not fixed per micro batch.
Each micro batch is shape [1, <= micro_batch_max_tokens], and each individual sample is
truncated to seq_len before packing.

FSDP requires all ranks to execute the same operations at each step. If one rank
processes a multimodal batch (triggering the vision encoder) while another processes
Expand All @@ -199,7 +203,8 @@ def prepare_batch(
"""
all_samples = [(idx, prepare_sample(rollout, seq_len)) for idx, rollout in zip(idxs, rollouts)]

micro_batches = packed_samples_into_micro_bs(all_samples, seq_len, num_loras)
max_micro_batch_tokens = micro_batch_max_tokens or seq_len
micro_batches = packed_samples_into_micro_bs(all_samples, max_micro_batch_tokens, num_loras)
micro_batches = [pad_micro_batch(micro_batch, pad_to_multiple_of) for micro_batch in micro_batches]

# Separate by modality so each step index has uniform modality across all ranks
8 changes: 8 additions & 0 deletions src/prime_rl/trainer/rl/data.py
@@ -27,6 +27,7 @@ class TensorMicroBatch(TypedDict):

# Batch level
lora_num_tokens: Int[Tensor, "n_loras"]
sample_count: int

# MoE router replay
routed_experts: Int[Tensor, "batch seq layers topk"] | None
@@ -73,6 +74,7 @@ def get_batch(self) -> list[TensorMicroBatch]:

def _get_sample_micro_batch(self, generator: torch.Generator) -> TensorMicroBatch:
total_seq_len = 0
sample_count = 0
input_ids = []
position_ids = []

@@ -87,6 +89,7 @@ def _get_sample_micro_batch(self, generator: torch.Generator) -> TensorMicroBatc

input_ids.append(tmp_input_ids)
position_ids.append(tmp_position_ids)
sample_count += 1

input_ids = torch.cat(input_ids, dim=0)
position_ids = torch.cat(position_ids, dim=0)
@@ -105,6 +108,7 @@ def _get_sample_micro_batch(self, generator: torch.Generator) -> TensorMicroBatc
"temperatures": torch.ones(input_ids.shape[0]).unsqueeze(0),
"loss_mask": loss_mask.unsqueeze(0),
"lora_num_tokens": lora_num_tokens,
"sample_count": sample_count,
"routed_experts": None,
"pixel_values": None,
"image_grid_thw": None,
@@ -130,6 +134,7 @@ def _get_micro_batch(self, generator: torch.Generator) -> TensorMicroBatch:
"temperatures": torch.ones(self.seq_len).unsqueeze(0),
"loss_mask": torch.ones(self.seq_len, dtype=torch.bool).unsqueeze(0),
"lora_num_tokens": lora_num_tokens,
"sample_count": 1,
"routed_experts": None,
"pixel_values": None,
"image_grid_thw": None,
@@ -145,6 +150,7 @@
start_step: int,
dp_world_size: int,
seq_len: int,
micro_batch_max_tokens: int,
pad_to_multiple_of: int,
tokenizer: PreTrainedTokenizer,
config: TransportConfig,
@@ -155,6 +161,7 @@
self.packer: BasePacker = setup_packer(
dp_world_size=dp_world_size,
seq_len=seq_len,
micro_batch_max_tokens=micro_batch_max_tokens,
tokenizer=tokenizer,
transport_config=config,
pad_to_multiple_of=pad_to_multiple_of,
@@ -197,6 +204,7 @@ def _micro_batch_to_tensor(self, micro_batch: MicroBatch) -> TensorMicroBatch:
loss_mask=torch.tensor(micro_batch.loss_mask, dtype=torch.bool).unsqueeze(0),
temperatures=torch.tensor(micro_batch.temperatures, dtype=torch.float).unsqueeze(0),
lora_num_tokens=torch.tensor(micro_batch.lora_num_tokens, dtype=torch.int32),
sample_count=micro_batch.sample_count,
# Multimodal fields - no batch dimension for these as they are variable-sized
pixel_values=torch.frombuffer(bytearray(micro_batch.pixel_values), dtype=torch.float32).reshape(
micro_batch.pixel_values_shape
47 changes: 43 additions & 4 deletions src/prime_rl/trainer/rl/packer.py
@@ -29,6 +29,7 @@ def __init__(
self,
dp_world_size: int,
seq_len: int,
micro_batch_max_tokens: int,
pad_to_multiple_of: int,
tokenizer: PreTrainedTokenizer,
config: TransportConfig,
@@ -38,6 +39,7 @@
self.multi_run_manager = get_multi_run_manager()
self.dp_world_size = dp_world_size
self.seq_len = seq_len
self.micro_batch_max_tokens = micro_batch_max_tokens
self.pad_to_multiple_of = pad_to_multiple_of
self.tokenizer = tokenizer
self.receiver = setup_training_batch_receiver(config)
@@ -81,12 +83,21 @@
self,
dp_world_size: int,
seq_len: int,
micro_batch_max_tokens: int,
pad_to_multiple_of: int,
tokenizer: PreTrainedTokenizer,
config: TransportConfig,
start_step: int = 0,
):
super().__init__(dp_world_size, seq_len, pad_to_multiple_of, tokenizer, config, start_step)
super().__init__(
dp_world_size,
seq_len,
micro_batch_max_tokens,
pad_to_multiple_of,
tokenizer,
config,
start_step,
)
assert self.multi_run_manager.max_runs == 1, "SinglePacker only supports one run"

def pack(self):
@@ -106,6 +117,7 @@ def pack(self):
micro_batch_grid = prepare_batch(
rollouts=batch.examples,
seq_len=self.seq_len,
micro_batch_max_tokens=self.micro_batch_max_tokens,
pad_to_multiple_of=self.pad_to_multiple_of,
num_train_workers=self.dp_world_size,
idxs=[0] * len(batch.examples),
@@ -120,12 +132,21 @@ def __init__(
self,
dp_world_size: int,
seq_len: int,
micro_batch_max_tokens: int,
pad_to_multiple_of: int,
tokenizer: PreTrainedTokenizer,
config: TransportConfig,
start_step: int = 0,
):
super().__init__(dp_world_size, seq_len, pad_to_multiple_of, tokenizer, config, start_step)
super().__init__(
dp_world_size,
seq_len,
micro_batch_max_tokens,
pad_to_multiple_of,
tokenizer,
config,
start_step,
)
# Per-run buffer: stores (TrainingSample, step) tuples
self.buffers: list[deque[tuple[TrainingSample, int]]] = [
deque() for _ in range(self.multi_run_manager.max_runs)
@@ -323,6 +344,7 @@ def pack(self):
run_micro_batch_grid = prepare_batch(
rollouts=run_samples,
seq_len=self.seq_len,
micro_batch_max_tokens=self.micro_batch_max_tokens,
pad_to_multiple_of=self.pad_to_multiple_of,
num_train_workers=self.dp_world_size,
idxs=[run_idx] * len(run_samples),
@@ -338,13 +360,30 @@
def setup_packer(
dp_world_size: int,
seq_len: int,
micro_batch_max_tokens: int,
pad_to_multiple_of: int,
tokenizer: PreTrainedTokenizer,
transport_config: TransportConfig,
start_step: int = 0,
) -> BasePacker:
multi_run_manager = get_multi_run_manager()
if multi_run_manager.max_runs == 1:
return SinglePacker(dp_world_size, seq_len, pad_to_multiple_of, tokenizer, transport_config, start_step)
return SinglePacker(
dp_world_size,
seq_len,
micro_batch_max_tokens,
pad_to_multiple_of,
tokenizer,
transport_config,
start_step,
)
else:
return MultiPacker(dp_world_size, seq_len, pad_to_multiple_of, tokenizer, transport_config, start_step)
return MultiPacker(
dp_world_size,
seq_len,
micro_batch_max_tokens,
pad_to_multiple_of,
tokenizer,
transport_config,
start_step,
)
39 changes: 39 additions & 0 deletions src/prime_rl/trainer/rl/stats.py
@@ -0,0 +1,39 @@
from dataclasses import dataclass
from typing import Any, Mapping

import torch
import torch.distributed as dist


@dataclass(frozen=True)
class LocalBatchStats:
num_micro_batches: int
num_tokens: int
num_loss_tokens: int
num_samples: int
max_micro_batch_tokens: int


def get_local_batch_stats(micro_batches: list[Mapping[str, Any]]) -> LocalBatchStats:
return LocalBatchStats(
num_micro_batches=len(micro_batches),
num_tokens=sum(int(micro_batch["input_ids"].numel()) for micro_batch in micro_batches),
num_loss_tokens=sum(int(micro_batch["loss_mask"].sum().item()) for micro_batch in micro_batches),
num_samples=sum(int(micro_batch["sample_count"]) for micro_batch in micro_batches),
max_micro_batch_tokens=max(int(micro_batch["input_ids"].shape[1]) for micro_batch in micro_batches),
)
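The arithmetic in `get_local_batch_stats` can be exercised with plain-list stand-ins for the tensors (shapes assumed from the code above; this torch-free mirror is illustrative only):

```python
# Plain-list mirror of get_local_batch_stats: `input_ids` is a [1, tokens]
# nested list, `loss_mask` a flat list of booleans, `sample_count` an int.
def local_batch_stats(micro_batches: list[dict]) -> dict:
    return {
        "num_micro_batches": len(micro_batches),
        "num_tokens": sum(len(mb["input_ids"][0]) for mb in micro_batches),
        "num_loss_tokens": sum(sum(mb["loss_mask"]) for mb in micro_batches),
        "num_samples": sum(mb["sample_count"] for mb in micro_batches),
        "max_micro_batch_tokens": max(len(mb["input_ids"][0]) for mb in micro_batches),
    }

mbs = [
    {"input_ids": [[1, 2, 3, 4]], "loss_mask": [True, True, False, False], "sample_count": 2},
    {"input_ids": [[5, 6]], "loss_mask": [True, False], "sample_count": 1},
]
print(local_batch_stats(mbs))
# → {'num_micro_batches': 2, 'num_tokens': 6, 'num_loss_tokens': 3,
#    'num_samples': 3, 'max_micro_batch_tokens': 4}
```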


def aggregate_dp_count(
num_local_value: int,
*,
dp_world_size: int,
dp_group,
device: torch.device,
) -> int:
if dp_world_size == 1:
return num_local_value

num_value = torch.tensor(num_local_value, device=device, dtype=torch.long)
dist.all_reduce(num_value, op=dist.ReduceOp.SUM, group=dp_group)
return int(num_value.item())
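`aggregate_dp_count` is just a SUM all-reduce over DP ranks; its semantics can be mimicked without `torch.distributed` (illustrative stand-in only):

```python
# Illustrative stand-in: each list entry is one DP rank's local count;
# a SUM all-reduce gives every rank the same global total.
def aggregate_counts(local_counts: list[int]) -> int:
    if len(local_counts) == 1:  # dp_world_size == 1 fast path, no collective
        return local_counts[0]
    return sum(local_counts)

print(aggregate_counts([120, 95, 130]))  # → 345
```

Summing the actual per-rank counts like this is what lets the progress metrics reflect real packed work instead of assuming uniform packing across DP ranks.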