Add explicit RL micro-batch token cap and fix RL token accounting#2183

Open
taivu1998 wants to merge 1 commit into PrimeIntellect-ai:main from taivu1998:fix/1514-rl-micro-batch-max-tokens

Conversation


@taivu1998 taivu1998 commented Apr 2, 2026

Summary

Closes #1514.

This adds an explicit RL trainer control for local micro-batch packing without reintroducing the old SFT-style micro_batch_size semantics.

The new API is trainer.micro_batch_max_tokens, which caps how many text tokens the RL trainer packs into each local micro batch while preserving the existing RL step semantics:

  • one trainer step still performs one optimizer step
  • checkpoint cadence is unchanged
  • weight-broadcast cadence is unchanged
  • fake RL data remains controlled by trainer.data.fake.batch_size
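
As a usage illustration, the new knob might be set like this (assuming a TOML config layout; the section names here just mirror the dotted paths above and are not taken from an actual prime_rl config file):

```toml
[model]
seq_len = 4096          # per-sample truncation limit (unchanged)

[trainer]
# Cap on tokens packed into each local RL micro batch; must not exceed
# model.seq_len. Omitting it (None) keeps current behavior.
micro_batch_max_tokens = 2048
```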

Motivation

Issue #1514 asks for RL-side gradient accumulation control similar to other training paths. In practice, RL already accumulates gradients implicitly across the packed local micro batches that make up one trainer step, but that behavior was only controlled indirectly by packing against model.seq_len.

This change makes that control explicit and safer by:

  • exposing a dedicated RL knob instead of reviving micro_batch_size
  • keeping sample truncation (model.seq_len) separate from local packing capacity (micro_batch_max_tokens)
  • preserving the existing optimizer-step semantics that the async RL pipeline depends on

What Changed

Config

  • Added trainer.micro_batch_max_tokens: int | None
  • Default remains None, which preserves current behavior by falling back to model.seq_len
  • Validation rejects values above model.seq_len
  • Validation also rejects this knob for fake RL data, since fake RL already has its own local micro-batch control via data.fake.batch_size

RL batching and packing

  • Decoupled per-sample truncation from per-micro-batch packing in the RL batch preparation path
  • Threaded micro_batch_max_tokens through the real RL data loader and packers
  • Kept run-isolation behavior intact for multi-run / LoRA packing
  • Added sample_count metadata on packed micro batches so sample accounting remains correct after packing and padding
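
A minimal sketch of the packing behavior described above, assuming a simple greedy packer (the `MicroBatch` shape and `pack_samples` helper are hypothetical, not the prime_rl API):

```python
# Illustrative greedy packer honoring a per-micro-batch token cap.
from dataclasses import dataclass


@dataclass
class MicroBatch:
    token_ids: list[int]
    sample_count: int  # kept so sample accounting survives packing/padding


def pack_samples(samples: list[list[int]], max_tokens: int) -> list[MicroBatch]:
    """Pack already-truncated samples into micro batches of <= max_tokens tokens."""
    batches: list[MicroBatch] = []
    current: list[int] = []
    count = 0
    for sample in samples:
        # Per-sample truncation (model.seq_len) happens upstream; since the
        # cap cannot exceed seq_len is assumed here that a single truncated
        # sample fits within max_tokens.
        if current and len(current) + len(sample) > max_tokens:
            batches.append(MicroBatch(current, count))
            current, count = [], 0
        current.extend(sample)
        count += 1
    if current:
        batches.append(MicroBatch(current, count))
    return batches
```

Lowering the cap forces more, smaller local micro batches for the same trainer step, which is exactly how the knob trades throughput for peak memory.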

Trainer accounting and logging

  • Replaced the old RL token-accounting assumption that used the first micro-batch shape as the whole-step token count
  • RL throughput/progress now use the actual packed local token count across all micro batches in the step
  • Added logging/monitor metrics for:
    • local tokens
    • local loss tokens
    • local samples
    • local micro-batch count
    • max packed local micro-batch tokens
    • configured micro_batch_max_tokens
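
The corrected accounting can be sketched as a sum over all packed micro batches rather than extrapolating from the first one; the `MicroBatch` fields below are illustrative assumptions:

```python
# Step-level token accounting summed over all packed micro batches,
# instead of first-micro-batch-shape x batch-count.
from dataclasses import dataclass


@dataclass
class MicroBatch:
    num_tokens: int       # packed/padded tokens in this micro batch
    num_loss_tokens: int  # tokens that contribute to the loss
    sample_count: int


def step_metrics(micro_batches: list[MicroBatch]) -> dict[str, int]:
    """Aggregate per-step local metrics across all micro batches."""
    return {
        "local_tokens": sum(mb.num_tokens for mb in micro_batches),
        "local_loss_tokens": sum(mb.num_loss_tokens for mb in micro_batches),
        "local_samples": sum(mb.sample_count for mb in micro_batches),
        "local_micro_batches": len(micro_batches),
        "max_micro_batch_tokens": max(mb.num_tokens for mb in micro_batches),
    }
```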

Docs and tests

  • Updated docs that still referenced nonexistent RL micro-batch-size flags
  • Added tests covering:
    • config validation for micro_batch_max_tokens
    • decoupled sample truncation vs packing cap
    • packer behavior when the cap forces more local micro batches

Design Notes

A key goal here is to stay aligned with the current RL architecture instead of importing SFT semantics wholesale.

This PR intentionally does not:

  • add RL micro_batch_size
  • change the meaning of one RL trainer step
  • modify fake-data batching behavior
  • change orchestrator scheduling or checkpoint step behavior

That keeps the implementation small and makes the new knob do exactly one thing: lower per-forward memory pressure by reducing the token budget of each local RL micro batch.

Verification

I added focused unit coverage for the new config and packing behavior.

Local command attempts:

uv run pytest tests/unit/test_configs.py tests/unit/orchestrator/test_batch.py tests/unit/train/rl/test_packer.py -q

On this machine (macOS), uv-based test execution was blocked because the repo lockfile only supports Linux environments. I still verified the patch with:

  • python3 -m py_compile on all changed Python files
  • git diff --check
  • manual diff review of the RL config, batching, packing, trainer-accounting, and docs changes

Note

Medium Risk
Touches RL trainer batching/packing and progress/throughput accounting; mistakes could skew metrics or change effective gradient accumulation, though validation and unit tests reduce risk.

Overview
Adds an explicit RL configuration knob, trainer.micro_batch_max_tokens, to cap tokens packed into each local micro-batch (defaulting to model.seq_len) and validates it can’t exceed model.seq_len and can’t be used with fake RL data.

Threads this cap through real RL packing (prepare_batch/packers) while tracking a new sample_count on MicroBatch to keep sample/token accounting correct across packing and dummy padding. Updates the RL training loop to compute throughput/progress from the actual packed micro-batches (tokens, loss tokens, samples, micro-batch count) and logs these new metrics; docs and unit tests are updated accordingly.

Written by Cursor Bugbot for commit dcd22f8.


@cursor (cursor bot) left a comment


Cursor Bugbot has reviewed your changes and found 1 potential issue.




Missing CHANGELOG entry for new config field

Low Severity

A new config field trainer.micro_batch_max_tokens was added to src/prime_rl/configs/trainer.py (with validation in validate_micro_batch_max_tokens) but CHANGELOG.md has no corresponding entry. Per project rules, any PR modifying configuration structures in src/prime_rl/*/config.py or src/prime_rl/configs/trainer.py must update CHANGELOG.md.


Triggered by project rule: BugBot Instructions



Development

Successfully merging this pull request may close these issues.

Gradient Accumulation?

1 participant