From dcd22f8300aaff7566cd4e1326ce10ca08ff6770 Mon Sep 17 00:00:00 2001 From: taivu1998 <46636857+taivu1998@users.noreply.github.com> Date: Thu, 2 Apr 2026 00:42:35 -0700 Subject: [PATCH 1/2] Add RL micro batch token cap --- docs/benchmarking.md | 7 ++- docs/memory_usage.md | 15 ++++++- docs/troubleshooting.md | 4 +- src/prime_rl/configs/trainer.py | 29 ++++++++++++ src/prime_rl/trainer/batch.py | 9 +++- src/prime_rl/trainer/rl/data.py | 8 ++++ src/prime_rl/trainer/rl/packer.py | 47 ++++++++++++++++++-- src/prime_rl/trainer/rl/train.py | 48 +++++++++++++++----- src/prime_rl/transport/types.py | 1 + tests/unit/orchestrator/test_batch.py | 27 +++++++++++ tests/unit/test_configs.py | 16 +++++++ tests/unit/train/rl/test_packer.py | 64 +++++++++++++++++++++++++++ 12 files changed, 254 insertions(+), 21 deletions(-) diff --git a/docs/benchmarking.md b/docs/benchmarking.md index f75b03b23a..19744ed470 100644 --- a/docs/benchmarking.md +++ b/docs/benchmarking.md @@ -47,11 +47,14 @@ uv run trainer ... --data.fake --bench Benchmark different batch configurations, i.e. the (micro) batch size and sequence length ```bash -uv run trainer ... --model.seq-len 4096 --data.fake.batch-size 64 --data.fake.micro-batch-size 2 --bench +uv run trainer ... --model.seq-len 4096 --data.fake.batch-size 64 --bench ``` *Note, that it is not yet possible to benchmark the RL trainer against real data when benchmarking the RL trainer in isolation.* +When training the real RL path against orchestrator-produced batches, use `--micro-batch-max-tokens` to cap the +number of tokens packed into each local trainer micro batch. This knob does not apply to the fake RL benchmark path. + ### Inference To benchmark the inference engine in isolation, start the inference server with the correct configuration file and run the orchestrator with the `--bench` flag. @@ -76,4 +79,4 @@ uv run rl \ --orchestrator @ path/to/orch.toml \ --inference @ path/to/infer.toml \ --bench -``` \ No newline at end of file +``` diff --git a/docs/memory_usage.md b/docs/memory_usage.md index b36c117254..53a0b2584b 100644 --- a/docs/memory_usage.md +++ b/docs/memory_usage.md @@ -63,6 +63,20 @@ To enable it, use: fused_lm_head_token_chunk_size = auto ``` +## RL micro batch cap + +RL already performs gradient accumulation implicitly by packing multiple local micro batches into each optimizer step. +If you need lower trainer memory usage without changing trainer-step semantics, cap the number of tokens in each local +RL micro batch: + +```toml +[trainer] +micro_batch_max_tokens = 4096 +``` + +This only applies to real RL batches from the orchestrator. Lower values increase the number of local micro batches +per trainer step while keeping one optimizer step per trainer step. + ## Expert parallelism @@ -129,4 +143,3 @@ LoRA training significantly reduces the memory usage of the trainer at the cost [trainer.model.lora] rank = 8 ``` - diff --git a/docs/troubleshooting.md b/docs/troubleshooting.md index 408ad0df87..9453961ed4 100644 --- a/docs/troubleshooting.md +++ b/docs/troubleshooting.md @@ -12,10 +12,10 @@ ulimit -n 32000 Assuming this is happening on the RL or SFT trainer, you can try the following: - Use full activation checkpointing (`--model.ac`) -- Reduce the the micro batch size (`--data.micro-batch-size`) and sequence length (`--data.seq-len`) +- Reduce the sequence length (`--data.seq-len` for SFT, `--model.seq-len` for RL) +- Reduce the trainer micro batch size (`--data.micro-batch-size` for SFT, `--micro-batch-max-tokens` for RL) - (*Experimental*) Use context parallelism with `--model.cp` > I cannot pass my TOML config file Check that you *did* leave a whitespace between the `@` and the config file (e.g. `uv run ... @ path/to/config.toml` instead of `uv run ... @path/to/config.toml`). Also, make sure that your TOML config matches the configuration schema. If not, the Pydantic error message (which arguably is quite ugly) will hopefully point you in the right direction. - diff --git a/src/prime_rl/configs/trainer.py b/src/prime_rl/configs/trainer.py index 1508a0a683..8af0a0bc46 100644 --- a/src/prime_rl/configs/trainer.py +++ b/src/prime_rl/configs/trainer.py @@ -728,6 +728,19 @@ class TrainerConfig(BaseConfig): # The data configuration data: DataLoaderConfig = DataLoaderConfig() + micro_batch_max_tokens: Annotated[ + int | None, + Field( + ge=1, + description=( + "Maximum number of text tokens to pack into each local RL micro batch. " + "Defaults to model.seq_len, which preserves the current behavior. " + "Lower values increase the number of local micro batches per trainer step without changing " + "optimizer-step or checkpoint semantics." + ), + ), + ] = None + # The loss configuration loss: LossConfig = DefaultLossConfig() @@ -850,6 +863,22 @@ def auto_setup_bench(self): self.ckpt = None return self + @model_validator(mode="after") + def validate_micro_batch_max_tokens(self): + if self.micro_batch_max_tokens is None: + return self + + if self.micro_batch_max_tokens > self.model.seq_len: + raise ValueError("micro_batch_max_tokens must be less than or equal to model.seq_len") + + if self.data.fake is not None: + raise ValueError( + "micro_batch_max_tokens is only supported for real RL batches. " + "Fake RL data already uses data.fake.batch_size to control local micro batches." + ) + + return self + @model_validator(mode="after") def dont_do_massive_traces(self): if self.trace_path: diff --git a/src/prime_rl/trainer/batch.py b/src/prime_rl/trainer/batch.py index e5580ebe28..ed2c3798cd 100644 --- a/src/prime_rl/trainer/batch.py +++ b/src/prime_rl/trainer/batch.py @@ -121,6 +121,7 @@ def packed_samples_into_micro_bs( bin_content.routed_experts = [] bin_content.routed_experts.extend(sample.routed_experts) bin_content.position_ids.extend(sample.position_ids) + bin_content.sample_count += sample.sample_count bin_content.lora_num_tokens[idx] += len(sample.input_ids) break else: @@ -168,6 +169,7 @@ def _make_dummy_batch(source: MicroBatch) -> MicroBatch: dummy = copy.deepcopy(source) dummy.advantages = [0.0] * len(dummy.input_ids) dummy.loss_mask = [False] * len(dummy.input_ids) + dummy.sample_count = 0 return dummy @@ -186,11 +188,13 @@ def prepare_batch( num_train_workers: int, idxs: list[int], num_loras: int, + micro_batch_max_tokens: int | None = None, pad_to_multiple_of: int = 1, ) -> list[list[MicroBatch]]: """ Prepare a batch of problems for each GPU. Each batch is a list of micro batches. - Each micro batch is shape [1, seq_len], the number of samples is not fixed per micro batch. + Each micro batch is shape [1, <= micro_batch_max_tokens], and each individual sample is + truncated to seq_len before packing. FSDP requires all ranks to execute the same operations at each step. If one rank processes a multimodal batch (triggering the vision encoder) while another processes @@ -199,7 +203,8 @@ def prepare_batch( """ all_samples = [(idx, prepare_sample(rollout, seq_len)) for idx, rollout in zip(idxs, rollouts)] - micro_batches = packed_samples_into_micro_bs(all_samples, seq_len, num_loras) + max_micro_batch_tokens = micro_batch_max_tokens or seq_len + micro_batches = packed_samples_into_micro_bs(all_samples, max_micro_batch_tokens, num_loras) micro_batches = [pad_micro_batch(micro_batch, pad_to_multiple_of) for micro_batch in micro_batches] # Separate by modality so each step index has uniform modality across all ranks diff --git a/src/prime_rl/trainer/rl/data.py b/src/prime_rl/trainer/rl/data.py index ee72d420dc..b10c662f26 100644 --- a/src/prime_rl/trainer/rl/data.py +++ b/src/prime_rl/trainer/rl/data.py @@ -27,6 +27,7 @@ class TensorMicroBatch(TypedDict): # Batch level lora_num_tokens: Int[Tensor, "n_loras"] + sample_count: int # MoE router replay routed_experts: Int[Tensor, "batch seq layers topk"] | None @@ -73,6 +74,7 @@ def get_batch(self) -> list[TensorMicroBatch]: def _get_sample_micro_batch(self, generator: torch.Generator) -> TensorMicroBatch: total_seq_len = 0 + sample_count = 0 input_ids = [] position_ids = [] @@ -87,6 +89,7 @@ def _get_sample_micro_batch(self, generator: torch.Generator) -> TensorMicroBatc input_ids.append(tmp_input_ids) position_ids.append(tmp_position_ids) + sample_count += 1 input_ids = torch.cat(input_ids, dim=0) position_ids = torch.cat(position_ids, dim=0) @@ -105,6 +108,7 @@ def _get_sample_micro_batch(self, generator: torch.Generator) -> TensorMicroBatc "temperatures": torch.ones(input_ids.shape[0]).unsqueeze(0), "loss_mask": loss_mask.unsqueeze(0), "lora_num_tokens": lora_num_tokens, + "sample_count": sample_count, "routed_experts": None, "pixel_values": None, "image_grid_thw": None, @@ -130,6 +134,7 @@ def _get_micro_batch(self, generator: torch.Generator) -> TensorMicroBatch: "temperatures": torch.ones(self.seq_len).unsqueeze(0), "loss_mask": torch.ones(self.seq_len, dtype=torch.bool).unsqueeze(0), "lora_num_tokens": lora_num_tokens, + "sample_count": 1, "routed_experts": None, "pixel_values": None, "image_grid_thw": None, @@ -145,6 +150,7 @@ def __init__( start_step: int, dp_world_size: int, seq_len: int, + micro_batch_max_tokens: int, pad_to_multiple_of: int, tokenizer: PreTrainedTokenizer, config: TransportConfig, @@ -155,6 +161,7 @@ def __init__( self.packer: BasePacker = setup_packer( dp_world_size=dp_world_size, seq_len=seq_len, + micro_batch_max_tokens=micro_batch_max_tokens, tokenizer=tokenizer, transport_config=config, pad_to_multiple_of=pad_to_multiple_of, @@ -197,6 +204,7 @@ def _micro_batch_to_tensor(self, micro_batch: MicroBatch) -> TensorMicroBatch: loss_mask=torch.tensor(micro_batch.loss_mask, dtype=torch.bool).unsqueeze(0), temperatures=torch.tensor(micro_batch.temperatures, dtype=torch.float).unsqueeze(0), lora_num_tokens=torch.tensor(micro_batch.lora_num_tokens, dtype=torch.int32), + sample_count=micro_batch.sample_count, # Multimodal fields - no batch dimension for these as they are variable-sized pixel_values=torch.frombuffer(bytearray(micro_batch.pixel_values), dtype=torch.float32).reshape( micro_batch.pixel_values_shape diff --git a/src/prime_rl/trainer/rl/packer.py b/src/prime_rl/trainer/rl/packer.py index cf9dcfa02e..d552aacdee 100644 --- a/src/prime_rl/trainer/rl/packer.py +++ b/src/prime_rl/trainer/rl/packer.py @@ -29,6 +29,7 @@ def __init__( self, dp_world_size: int, seq_len: int, + micro_batch_max_tokens: int, pad_to_multiple_of: int, tokenizer: PreTrainedTokenizer, config: TransportConfig, @@ -38,6 +39,7 @@ def __init__( self.multi_run_manager = get_multi_run_manager() self.dp_world_size = dp_world_size self.seq_len = seq_len + self.micro_batch_max_tokens = micro_batch_max_tokens self.pad_to_multiple_of = pad_to_multiple_of self.tokenizer = tokenizer self.receiver = setup_training_batch_receiver(config) @@ -81,12 +83,21 @@ def __init__( self, dp_world_size: int, seq_len: int, + micro_batch_max_tokens: int, pad_to_multiple_of: int, tokenizer: PreTrainedTokenizer, config: TransportConfig, start_step: int = 0, ): - super().__init__(dp_world_size, seq_len, pad_to_multiple_of, tokenizer, config, start_step) + super().__init__( + dp_world_size, + seq_len, + micro_batch_max_tokens, + pad_to_multiple_of, + tokenizer, + config, + start_step, + ) assert self.multi_run_manager.max_runs == 1, "SinglePacker only supports one run" def pack(self): @@ -106,6 +117,7 @@ def pack(self): micro_batch_grid = prepare_batch( rollouts=batch.examples, seq_len=self.seq_len, + micro_batch_max_tokens=self.micro_batch_max_tokens, pad_to_multiple_of=self.pad_to_multiple_of, num_train_workers=self.dp_world_size, idxs=[0] * len(batch.examples), @@ -120,12 +132,21 @@ def __init__( self, dp_world_size: int, seq_len: int, + micro_batch_max_tokens: int, pad_to_multiple_of: int, tokenizer: PreTrainedTokenizer, config: TransportConfig, start_step: int = 0, ): - super().__init__(dp_world_size, seq_len, pad_to_multiple_of, tokenizer, config, start_step) + super().__init__( + dp_world_size, + seq_len, + micro_batch_max_tokens, + pad_to_multiple_of, + tokenizer, + config, + start_step, + ) # Per-run buffer: stores (TrainingSample, step) tuples self.buffers: list[deque[tuple[TrainingSample, int]]] = [ deque() for _ in range(self.multi_run_manager.max_runs) @@ -323,6 +344,7 @@ def pack(self): run_micro_batch_grid = prepare_batch( rollouts=run_samples, seq_len=self.seq_len, + micro_batch_max_tokens=self.micro_batch_max_tokens, pad_to_multiple_of=self.pad_to_multiple_of, num_train_workers=self.dp_world_size, idxs=[run_idx] * len(run_samples), @@ -338,6 +360,7 @@ def pack(self): def setup_packer( dp_world_size: int, seq_len: int, + micro_batch_max_tokens: int, pad_to_multiple_of: int, tokenizer: PreTrainedTokenizer, transport_config: TransportConfig, @@ -345,6 +368,22 @@ def setup_packer( ) -> BasePacker: multi_run_manager = get_multi_run_manager() if multi_run_manager.max_runs == 1: - return SinglePacker(dp_world_size, seq_len, pad_to_multiple_of, tokenizer, transport_config, start_step) + return SinglePacker( + dp_world_size, + seq_len, + micro_batch_max_tokens, + pad_to_multiple_of, + tokenizer, + transport_config, + start_step, + ) else: - return MultiPacker(dp_world_size, seq_len, pad_to_multiple_of, tokenizer, transport_config, start_step) + return MultiPacker( + dp_world_size, + seq_len, + micro_batch_max_tokens, + pad_to_multiple_of, + tokenizer, + transport_config, + start_step, + ) diff --git a/src/prime_rl/trainer/rl/train.py b/src/prime_rl/trainer/rl/train.py index 626ed151c7..86aa9adc08 100644 --- a/src/prime_rl/trainer/rl/train.py +++ b/src/prime_rl/trainer/rl/train.py @@ -75,6 +75,7 @@ def train(config: TrainerConfig): json_logging=config.log.json_logging, ) logger.info(f"Starting RL trainer in {world} in {config.output_dir}") + effective_micro_batch_max_tokens = config.micro_batch_max_tokens or config.model.seq_len # Print warning if running in benchmark mode if config.bench is not None: @@ -201,13 +202,19 @@ def load_run_checkpoint(_optimizer, idx: int) -> None: # Set up the data loader (Optionally, use a fake data loader for debugging) logger.info(f"Initializing data loader ({config.data})") if config.data.fake: + logger.info("Using fake RL data loader; each local micro batch uses model.seq_len tokens") dataloader = FakeDataLoader(config.data.fake, config.model.seq_len, parallel_dims.get_mesh("dp").size()) else: + logger.info( + "Using RL micro batch cap of " + f"{effective_micro_batch_max_tokens} tokens (model.seq_len={config.model.seq_len})" + ) dataloader = DataLoader( config.output_dir, progress.step, parallel_dims.get_mesh("dp").size(), config.model.seq_len, + effective_micro_batch_max_tokens, config.model.cp, tokenizer, config.rollout_transport, @@ -304,19 +311,21 @@ def load_run_checkpoint(_optimizer, idx: int) -> None: load_data_time = time.perf_counter() - load_data_start_time logger.debug(f"Loaded batch in {load_data_time:.2f} seconds") - batch_size = len(micro_batches) + num_local_micro_batches = len(micro_batches) memory_profiler = None if config.memory_profiler_path is not None: memory_profiler = MemoryProfiler(progress.step, config.memory_profiler_path) forward_backward_start_time = time.perf_counter() - seq_len = micro_batches[0]["input_ids"].shape[1] + max_local_micro_batch_tokens = max(micro_batch["input_ids"].shape[1] for micro_batch in micro_batches) + num_local_tokens = sum(micro_batch["input_ids"].numel() for micro_batch in micro_batches) + num_local_loss_tokens = sum(int(micro_batch["loss_mask"].sum().item()) for micro_batch in micro_batches) + num_local_samples = sum(micro_batch["sample_count"] for micro_batch in micro_batches) # Normalize by the local number of unmasked tokens in the batch (per-batch length normalization) - loss_scale = sum(micro_batch["loss_mask"].sum().item() for micro_batch in micro_batches) - loss_scale = max(loss_scale, 1) + loss_scale = max(num_local_loss_tokens, 1) - logger.debug(f"Starting forward and backward pass ({batch_size=})") + logger.debug(f"Starting forward and backward pass ({num_local_micro_batches=})") tensors = Tensors() # Used to accumulate tensor statistics across micro-batches and ranks for logging cp_enabled = parallel_dims.cp_enabled cp_rank = parallel_dims.world_mesh["cp"].get_local_rank() if cp_enabled else 0 @@ -455,7 +464,11 @@ def load_run_checkpoint(_optimizer, idx: int) -> None: tensors[key].append(loss_tensor) # Debug log with *local, micro step* stats - micro_step_message = f"Micro Step {micro_step}/{len(micro_batches)} | Loss: {tensors['loss'][-1].mean().item():.4f} | Entropy: {tensors['entropy'][-1].mean().item():.4f}" + micro_step_message = ( + f"Micro Step {micro_step + 1}/{num_local_micro_batches} | " + f"Loss: {tensors['loss'][-1].mean().item():.4f} | " + f"Entropy: {tensors['entropy'][-1].mean().item():.4f}" + ) if "mismatch_kl" in tensors: micro_step_message += f" | Mismatch KL: {tensors['mismatch_kl'][-1].mean().item():.4f}" if "max_vio" in tensors: @@ -492,22 +505,37 @@ def load_run_checkpoint(_optimizer, idx: int) -> None: tensor_stats = tensors.compute_stats() # Compute step metrics - num_local_tokens = seq_len * batch_size num_tokens = parallel_dims.get_mesh("dp").size() * num_local_tokens progress.total_tokens += num_tokens - progress.total_samples += batch_size - perf_counter = get_perf_counter(model, seq_len) + progress.total_samples += num_local_samples + perf_counter = get_perf_counter(model, effective_micro_batch_max_tokens) perf_counter.count_tokens(num_tokens) throughput = perf_counter.get_tokens_per_second() or 0 mfu = perf_counter.get_mfu() or 0 peak_memory = torch.cuda.max_memory_reserved() / 1024**3 # GiB + progress_metrics = { + "progress/total_tokens": progress.total_tokens, + "progress/total_samples": progress.total_samples, + "batch/local_tokens": num_local_tokens, + "batch/local_loss_tokens": num_local_loss_tokens, + "batch/local_samples": num_local_samples, + "batch/local_micro_batches": num_local_micro_batches, + "batch/local_micro_batch_tokens_max": max_local_micro_batch_tokens, + "batch/micro_batch_max_tokens": effective_micro_batch_max_tokens, + "step": progress.step, + } + monitor.log(progress_metrics, step=progress.step) + # Log step metrics step_time = time.perf_counter() - step_start_time step_message = f"Step {progress.step} | Time: {step_time:.2f}s | Loss: {tensor_stats['loss/mean']:.4f} | Entropy: {tensor_stats['entropy/mean']:.4f}" if "mismatch_kl/mean" in tensor_stats: step_message += f" | Mismatch KL: {tensor_stats['mismatch_kl/mean']:.4f}" - step_message += f" | Grad. Norm: {grad_norm:.4f} | LR: {current_lr:.2e} | Throughput: {throughput:.0f} tokens/s | MFU: {mfu:.1f}% | Peak Mem.: {peak_memory:.1f} GiB" + step_message += ( + f" | Micro Batches: {num_local_micro_batches} | Grad. Norm: {grad_norm:.4f} | LR: {current_lr:.2e}" + f" | Throughput: {throughput:.0f} tokens/s | MFU: {mfu:.1f}% | Peak Mem.: {peak_memory:.1f} GiB" + ) if "max_vio/mean" in tensor_stats: step_message += f" | Max Vio: {tensor_stats['max_vio/mean']:.4f}" logger.success(step_message) diff --git a/src/prime_rl/transport/types.py b/src/prime_rl/transport/types.py index 8579112fa4..3f7cc1a2fa 100644 --- a/src/prime_rl/transport/types.py +++ b/src/prime_rl/transport/types.py @@ -51,3 +51,4 @@ class MicroBatch(msgspec.Struct, array_like=True, gc=False, omit_defaults=True): pixel_values_shape: list[int] | None = None # [num_patches, patch_dim] # image_grid_thw: grid dimensions [num_images, 3] where each entry is [temporal, height, width] image_grid_thw: list[list[int]] | None = None + sample_count: int = 1 diff --git a/tests/unit/orchestrator/test_batch.py b/tests/unit/orchestrator/test_batch.py index cbe7c70f1e..4cd7ec7875 100644 --- a/tests/unit/orchestrator/test_batch.py +++ b/tests/unit/orchestrator/test_batch.py @@ -77,6 +77,33 @@ def test_prepare_batch_packs_different_temperatures(make_training_example): assert flat_batches[0].temperatures[:4] == [0.7, 0.7, 0.7, 0.7] # Second sample (4 tokens): all get temp 1.1 assert flat_batches[0].temperatures[4:8] == [1.1, 1.1, 1.1, 1.1] + assert flat_batches[0].sample_count == 2 + + +def test_prepare_batch_decouples_sample_truncation_from_micro_batch_cap(): + sample = TrainingSample( + prompt_ids=[1, 2, 3], + prompt_mask=[False, False, False], + completion_ids=[4, 5, 6], + completion_mask=[True, True, True], + completion_logprobs=[-0.1, -0.2, -0.3], + completion_temperatures=[1.0, 1.0, 1.0], + advantage=1.0, + ) + + batches_per_gpu = prepare_batch( + rollouts=[sample, sample], + seq_len=4, + micro_batch_max_tokens=8, + num_train_workers=1, + idxs=[0, 0], + num_loras=1, + ) + + flat_batches = [batch for worker_batches in batches_per_gpu for batch in worker_batches] + assert len(flat_batches) == 1 + assert len(flat_batches[0].input_ids) == 8 + assert flat_batches[0].sample_count == 2 def test_prepare_sample_with_routed_experts(): diff --git a/tests/unit/test_configs.py b/tests/unit/test_configs.py index ffcc18f270..923182d7ca 100644 --- a/tests/unit/test_configs.py +++ b/tests/unit/test_configs.py @@ -159,3 +159,19 @@ def test_removed_fused_lm_head_chunk_size_field_is_rejected(): def test_selective_activation_checkpointing_requires_custom_impl(): with pytest.raises(ValidationError, match="Selective activation checkpointing requires model.impl='custom'"): TrainerModelConfig.model_validate({"impl": "hf", "ac": {"mode": "selective"}}) + + +def test_rl_micro_batch_max_tokens_must_not_exceed_seq_len(): + with pytest.raises(ValidationError, match="micro_batch_max_tokens must be less than or equal to model.seq_len"): + TrainerConfig.model_validate({"model": {"seq_len": 1024}, "micro_batch_max_tokens": 2048}) + + +def test_rl_micro_batch_max_tokens_is_rejected_for_fake_data(): + with pytest.raises(ValidationError, match="micro_batch_max_tokens is only supported for real RL batches"): + TrainerConfig.model_validate( + { + "model": {"seq_len": 1024}, + "micro_batch_max_tokens": 512, + "data": {"fake": {"batch_size": 8}}, + } + ) diff --git a/tests/unit/train/rl/test_packer.py b/tests/unit/train/rl/test_packer.py index c661ec0df5..75f38faf4a 100644 --- a/tests/unit/train/rl/test_packer.py +++ b/tests/unit/train/rl/test_packer.py @@ -88,6 +88,7 @@ def fake_sender(_output_dir, _data_world_size, _current_step, _config): packer = MultiPacker( dp_world_size=1, seq_len=4, + micro_batch_max_tokens=4, pad_to_multiple_of=1, tokenizer=None, config=FileSystemTransportConfig(), @@ -107,3 +108,66 @@ def fake_sender(_output_dir, _data_world_size, _current_step, _config): sender = sender_holder["sender"] assert len(sender.sent) == 1 assert len(sender.sent[0][0]) == 1 + + +def test_packer_respects_micro_batch_max_tokens(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + reset_world() + runs._MULTI_RUN_MANAGER = None + manager = setup_multi_run_manager(output_dir=tmp_path, max_runs=1, device=torch.device("cpu")) + + create_run_with_config(tmp_path, "run_test123") + manager.discover_runs() + run_idx = manager.id_2_idx["run_test123"] + + class DummyReceiver: + def receive(self): + return [] + + def reset_run(self, idx: int) -> None: + pass + + class DummySender: + def __init__(self): + self.sent = [] + + def send(self, micro_batch_grid): + self.sent.append(micro_batch_grid) + + sender_holder: dict[str, DummySender] = {} + + def fake_receiver(_config): + return DummyReceiver() + + def fake_sender(_output_dir, _data_world_size, _current_step, _config): + sender = DummySender() + sender_holder["sender"] = sender + return sender + + monkeypatch.setattr("prime_rl.trainer.rl.packer.setup_training_batch_receiver", fake_receiver) + monkeypatch.setattr("prime_rl.trainer.rl.packer.setup_micro_batch_sender", fake_sender) + + packer = MultiPacker( + dp_world_size=1, + seq_len=4, + micro_batch_max_tokens=2, + pad_to_multiple_of=1, + tokenizer=None, + config=FileSystemTransportConfig(), + start_step=0, + ) + + packer.buffers[run_idx].append((make_training_sample(), 0)) + packer.buffers[run_idx].append((make_training_sample(), 0)) + + packer.pack() + + progress = manager.progress[run_idx] + assert progress.total_samples == 2 + assert progress.total_tokens == 4 + assert progress.step == 1 + + sender = sender_holder["sender"] + assert len(sender.sent) == 1 + assert len(sender.sent[0][0]) == 2 + assert [micro_batch.sample_count for micro_batch in sender.sent[0][0]] == [1, 1] + assert [len(micro_batch.input_ids) for micro_batch in sender.sent[0][0]] == [2, 2] From 2c370e61433301b4b7d3dc03f1a135366cb42c37 Mon Sep 17 00:00:00 2001 From: taivu1998 <46636857+taivu1998@users.noreply.github.com> Date: Sat, 4 Apr 2026 11:29:55 -0700 Subject: [PATCH 2/2] Fix RL DP progress accounting --- CHANGELOG.md | 1 + src/prime_rl/trainer/rl/stats.py | 39 ++++++++++++++++++ src/prime_rl/trainer/rl/train.py | 29 +++++++++---- tests/unit/train/rl/test_stats.py | 67 +++++++++++++++++++++++++++++++ 4 files changed, 129 insertions(+), 7 deletions(-) create mode 100644 src/prime_rl/trainer/rl/stats.py create mode 100644 tests/unit/train/rl/test_stats.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 1f7bf13850..cba5e254c6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ Documenting changes which affect configuration usage patterns (added/moved/removed/renamed fields, notable logic changes). +- **`trainer.micro_batch_max_tokens`**: Added optional cap for packed local RL micro-batches. Defaults to `model.seq_len`, so existing configs keep the previous behavior. Lower values reduce per-forward RL trainer memory usage by splitting one trainer step into more local micro-batches without changing optimizer-step or checkpoint semantics. This applies only to real RL batches from the orchestrator; fake RL data still uses `data.fake.batch_size`. RL trainer token and sample progress metrics now sum the actual packed work across DP ranks instead of assuming uniform packing. (2026-04-04) - **`log.file` and `log.env_worker_logs` removed**: Removed `log.file` (from `LogConfig` and `SharedLogConfig`) and `log.env_worker_logs` (from `LogConfig`). Python file logging is replaced by deployment-level capture. Existing configs using these fields must delete them. Log paths unified: `.stdout` files renamed to `.log`, SLURM logs moved from `slurm/` to `logs/`. (2026-03-31) - **`trainer.log.ranks_filter` (NEW)**: Added `ranks_filter: list[int]` to `TrainerLogConfig` (default: `[0]`). Controls which ranks appear in trainer console output via torchrun's `--local-ranks-filter`. (2026-03-31) - **`wandb.log_extras.sample_ratio` / monitor sample logging defaults**: `wandb.log_extras.sample_ratio` is now actually applied to W&B sample-table logging via the shared monitor sampler (it was previously a no-op for WandB). Separately, the orchestrator no longer hard-caps sample logging to 8 rollouts before monitor-level sampling runs, so when monitor `sample_ratio` is `None`, monitors now receive and may log the full rollout batch for a step instead of at most 8 rollouts. This affects both W&B and Prime monitor sample logging behavior. (2026-03-27) diff --git a/src/prime_rl/trainer/rl/stats.py b/src/prime_rl/trainer/rl/stats.py new file mode 100644 index 0000000000..c2b8e567e9 --- /dev/null +++ b/src/prime_rl/trainer/rl/stats.py @@ -0,0 +1,39 @@ +from dataclasses import dataclass +from typing import Any, Mapping + +import torch +import torch.distributed as dist + + +@dataclass(frozen=True) +class LocalBatchStats: + num_micro_batches: int + num_tokens: int + num_loss_tokens: int + num_samples: int + max_micro_batch_tokens: int + + +def get_local_batch_stats(micro_batches: list[Mapping[str, Any]]) -> LocalBatchStats: + return LocalBatchStats( + num_micro_batches=len(micro_batches), + num_tokens=sum(int(micro_batch["input_ids"].numel()) for micro_batch in micro_batches), + num_loss_tokens=sum(int(micro_batch["loss_mask"].sum().item()) for micro_batch in micro_batches), + num_samples=sum(int(micro_batch["sample_count"]) for micro_batch in micro_batches), + max_micro_batch_tokens=max(int(micro_batch["input_ids"].shape[1]) for micro_batch in micro_batches), + ) + + +def aggregate_dp_count( + num_local_value: int, + *, + dp_world_size: int, + dp_group, + device: torch.device, +) -> int: + if dp_world_size == 1: + return num_local_value + + num_value = torch.tensor(num_local_value, device=device, dtype=torch.long) + dist.all_reduce(num_value, op=dist.ReduceOp.SUM, group=dp_group) + return int(num_value.item()) diff --git a/src/prime_rl/trainer/rl/train.py b/src/prime_rl/trainer/rl/train.py index 86aa9adc08..391af9d7be 100644 --- a/src/prime_rl/trainer/rl/train.py +++ b/src/prime_rl/trainer/rl/train.py @@ -41,6 +41,7 @@ ) from prime_rl.trainer.parallel_dims import get_parallel_dims from prime_rl.trainer.perf import get_perf_counter +from prime_rl.trainer.rl.stats import aggregate_dp_count, get_local_batch_stats from prime_rl.trainer.utils import ( GarbageCollection, MemoryProfiler, @@ -311,16 +312,17 @@ def load_run_checkpoint(_optimizer, idx: int) -> None: load_data_time = time.perf_counter() - load_data_start_time logger.debug(f"Loaded batch in {load_data_time:.2f} seconds") - num_local_micro_batches = len(micro_batches) + batch_stats = get_local_batch_stats(micro_batches) + num_local_micro_batches = batch_stats.num_micro_batches memory_profiler = None if config.memory_profiler_path is not None: memory_profiler = MemoryProfiler(progress.step, config.memory_profiler_path) forward_backward_start_time = time.perf_counter() - max_local_micro_batch_tokens = max(micro_batch["input_ids"].shape[1] for micro_batch in micro_batches) - num_local_tokens = sum(micro_batch["input_ids"].numel() for micro_batch in micro_batches) - num_local_loss_tokens = sum(int(micro_batch["loss_mask"].sum().item()) for micro_batch in micro_batches) - num_local_samples = sum(micro_batch["sample_count"] for micro_batch in micro_batches) + max_local_micro_batch_tokens = batch_stats.max_micro_batch_tokens + num_local_tokens = batch_stats.num_tokens + num_local_loss_tokens = batch_stats.num_loss_tokens + num_local_samples = batch_stats.num_samples # Normalize by the local number of unmasked tokens in the batch (per-batch length normalization) loss_scale = max(num_local_loss_tokens, 1) @@ -505,9 +507,22 @@ def load_run_checkpoint(_optimizer, idx: int) -> None: tensor_stats = tensors.compute_stats() # Compute step metrics - num_tokens = parallel_dims.get_mesh("dp").size() * num_local_tokens + dp_mesh = parallel_dims.get_mesh("dp") + reduction_device = torch.device("cuda", torch.cuda.current_device()) + num_tokens = aggregate_dp_count( + num_local_tokens, + dp_world_size=dp_mesh.size(), + dp_group=dp_mesh.get_group(), + device=reduction_device, + ) + num_samples = aggregate_dp_count( + num_local_samples, + dp_world_size=dp_mesh.size(), + dp_group=dp_mesh.get_group(), + device=reduction_device, + ) progress.total_tokens += num_tokens - progress.total_samples += num_local_samples + progress.total_samples += num_samples perf_counter = get_perf_counter(model, effective_micro_batch_max_tokens) perf_counter.count_tokens(num_tokens) throughput = perf_counter.get_tokens_per_second() or 0 diff --git a/tests/unit/train/rl/test_stats.py b/tests/unit/train/rl/test_stats.py new file mode 100644 index 0000000000..98af0ce9aa --- /dev/null +++ b/tests/unit/train/rl/test_stats.py @@ -0,0 +1,67 @@ +import torch +import torch.distributed as dist + +from prime_rl.trainer.rl.stats import aggregate_dp_count, get_local_batch_stats + + +def test_get_local_batch_stats_uses_actual_packed_tokens(): + micro_batches = [ + { + "input_ids": torch.ones((1, 3), dtype=torch.long), + "loss_mask": torch.tensor([[True, False, True]]), + "sample_count": 2, + }, + { + "input_ids": torch.ones((1, 5), dtype=torch.long), + "loss_mask": torch.tensor([[True, True, True, False, False]]), + "sample_count": 1, + }, + ] + + stats = get_local_batch_stats(micro_batches) + + assert stats.num_micro_batches == 2 + assert stats.num_tokens == 8 + assert stats.num_loss_tokens == 5 + assert stats.num_samples == 3 + assert stats.max_micro_batch_tokens == 5 + + +def test_aggregate_dp_count_skips_collective_for_single_rank(monkeypatch): + called = False + + def fake_all_reduce(*args, **kwargs): + nonlocal called + called = True + + monkeypatch.setattr(dist, "all_reduce", fake_all_reduce) + + num_tokens = aggregate_dp_count( + 17, + dp_world_size=1, + dp_group=None, + device=torch.device("cpu"), + ) + + assert num_tokens == 17 + assert not called + + +def test_aggregate_dp_count_sums_across_dp_group(monkeypatch): + calls = [] + + def fake_all_reduce(tensor, op, group): + calls.append((int(tensor.item()), op, group)) + tensor.add_(13) + + monkeypatch.setattr(dist, "all_reduce", fake_all_reduce) + + num_tokens = aggregate_dp_count( + 17, + dp_world_size=2, + dp_group="dp-group", + device=torch.device("cpu"), + ) + + assert num_tokens == 30 + assert calls == [(17, dist.ReduceOp.SUM, "dp-group")]