
[perf, fsdp, trainer] feat: Skip training for zero-advantage responses to speed up RL. #5838

Open

sheilaliuxl wants to merge 3 commits into verl-project:main from sheilaliuxl:dev-skip-zero-adv

Conversation

Contributor

@sheilaliuxl sheilaliuxl commented Apr 1, 2026

What does this PR do?

Skip training for zero-advantage responses to speed up RL.

At high accuracy, fewer than 25% of samples have a non-zero advantage (see critic/advantages/zero_adv_ratio), and the zero-advantage samples don't contribute to pg_loss at all. With 4 mini-batches:

  1. It can train with the same number of mini-batches as before, but with fewer samples per mini-batch: match_loss_curve = True
    • This option is designed to match the baseline as closely as possible, including num_ghost_opt_steps optimizer steps when needed (very unlikely)
  2. It can collapse training into a single mini-batch: match_loss_curve = False
    • This option is designed to train as fast as possible with the same micro-bs (static bs) or tokens_per_gpu (dynamic bs), faster than both the baseline and option #1

Both options closely match the baseline loss curve, and both train faster thanks to the skip.
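The two modes above can be sketched as follows. This is an illustrative sketch only, not the PR's actual code; the function name `filter_zero_adv` and its signature are hypothetical, with only `match_loss_curve` taken from the PR's config:

```python
def filter_zero_adv(samples, advantages, num_minibatches, match_loss_curve):
    """Drop samples whose advantage is exactly zero, then repartition.

    samples:    list of training samples
    advantages: per-sample advantage values (0.0 => no pg_loss contribution)
    """
    kept = [s for s, a in zip(samples, advantages) if a != 0.0]
    if match_loss_curve:
        # Option 1: keep the same number of mini-batches as the baseline,
        # each simply containing fewer samples.
        n = num_minibatches
    else:
        # Option 2: collapse into a single mini-batch, trading exact
        # loss-curve matching for maximum speed.
        n = 1
    size = max(1, (len(kept) + n - 1) // n)  # ceil split across n chunks
    return [kept[i:i + size] for i in range(0, len(kept), size)]
```

For example, with 16 samples of which 4 have non-zero advantage and num_minibatches=4, match_loss_curve=True yields 4 mini-batches of 1 sample each, while match_loss_curve=False yields 1 mini-batch of 4 samples.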

Checklist Before Starting

  • Search for similar PRs. Paste at least one query link here: https://github.com/verl-project/verl/pulls?q=is%3Apr+zero+adv+is%3Aopen
  • Format the PR title as [{modules}] {type}: {description} (This will be checked by the CI)
    • {modules} include fsdp, megatron, veomni, sglang, vllm, rollout, trainer, ci, training_utils, recipe, hardware, deployment, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data, cfg, reward, fully_async, one_step_off
    • If this PR involves multiple modules, separate them with , like [megatron, fsdp, doc]
    • {type} is in feat, fix, refactor, chore, test

Test

1. Qwen2.5 3B LoRA: 2 epochs for GSM8k

Metrics comparison, dropping the noisy first 10 steps and last 3 steps:

  1. timing_s/update_actor
    • avg: 3.853 -> 2.278 (-40.9%) or 1.7x reduction
    • std: 0.183 -> 0.072 (-60.7%) or 2.5x reduction
  2. critic/rewards/mean
    • avg: 0.909 -> 0.915
  3. val-core/openai/gsm8k/acc/mean@1
    • avg: 0.855 -> 0.861
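The reduction figures quoted for timing_s/update_actor follow directly from the raw averages; a quick arithmetic check:

```python
# Baseline vs. match_loss_curve=True run, values taken from the metrics above.
base_avg, new_avg = 3.853, 2.278
base_std, new_std = 0.183, 0.072

pct_avg = (new_avg - base_avg) / base_avg * 100   # percentage change in avg
x_avg = base_avg / new_avg                        # reduction factor in avg
pct_std = (new_std - base_std) / base_std * 100   # percentage change in std
x_std = base_std / new_std                        # reduction factor in std

print(round(pct_avg, 1), round(x_avg, 1),  # -40.9, 1.7
      round(pct_std, 1), round(x_std, 1))  # -60.7, 2.5
```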
sliuxl@c889f3b957d9 20:35 ~/qwen2d5 $ M=timing_s/update_actor; conda run -n et python ~/tb.py --files "sliuxl-mverl--qwen2d5_3b_inst--gsm8k/*/events*" --metric "$M"
2026-03-31 20:35:42,832 [tb.py:206] WARNING - Unknonw metric `timing_s/update_actor`!
2026-03-31 20:35:43,069 [tb.py:235] INFO - Set benchmark with index 0
2026-03-31 20:35:43,070 [tb.py:264] INFO - [103/116] sliuxl-mverl--qwen2d5_3b_inst--gsm8k/verl_feature_config_pbtxt-dc7fdce1--13954-run-mverl-rl-0adv-05-j-01-ns01-tp01-bs16-pad0-base--20260331.152146/events.out.tfevents.1774996037.ip-10-4-145-27.2208577.0: (avg (std), med) = (  3.853 (  0.183),   3.814)  # timing_s/update_actor
2026-03-31 20:35:43,337 [tb.py:264] INFO - [103/116] sliuxl-mverl--qwen2d5_3b_inst--gsm8k/verl_feature_config_pbtxt-dc7fdce1--13955-run-mverl-rl-0adv-05-j-02-ns01-tp01-bs16-pad0-rm-0adv-yes-mlc--20260331.152151/events.out.tfevents.1775000615.ip-10-4-145-27.2332233.0: (avg (std), med) = (  2.278 (  0.072),   2.254)  # timing_s/update_actor
2026-03-31 20:35:43,606 [tb.py:264] INFO - [103/116] sliuxl-mverl--qwen2d5_3b_inst--gsm8k/verl_feature_config_pbtxt-dc7fdce1--13956-run-mverl-rl-0adv-05-j-03-ns01-tp01-bs16-pad0-rm-0adv-no-mlc--20260331.152156/events.out.tfevents.1775004996.ip-10-4-145-27.2451478.0: (avg (std), med) = (  1.652 (  0.473),   1.527)  # timing_s/update_actor

sliuxl@c889f3b957d9 19:56 ~/qwen2d5 $ M=critic/rewards/mean; conda run -n et python ~/tb.py --files "sliuxl-mverl--qwen2d5_3b_inst--gsm8k/*/events*" --metric "$M"
2026-03-31 19:57:02,837 [tb.py:206] WARNING - Unknonw metric `critic/rewards/mean`!
2026-03-31 19:57:03,085 [tb.py:235] INFO - Set benchmark with index 0
2026-03-31 19:57:03,086 [tb.py:264] INFO - [103/116] sliuxl-mverl--qwen2d5_3b_inst--gsm8k/verl_feature_config_pbtxt-dc7fdce1--13954-run-mverl-rl-0adv-05-j-01-ns01-tp01-bs16-pad0-base--20260331.152146/events.out.tfevents.1774996037.ip-10-4-145-27.2208577.0: (avg (std), med) = (  0.909 (  0.028),   0.913)  # critic/rewards/mean
2026-03-31 19:57:03,348 [tb.py:264] INFO - [103/116] sliuxl-mverl--qwen2d5_3b_inst--gsm8k/verl_feature_config_pbtxt-dc7fdce1--13955-run-mverl-rl-0adv-05-j-02-ns01-tp01-bs16-pad0-rm-0adv-yes-mlc--20260331.152151/events.out.tfevents.1775000615.ip-10-4-145-27.2332233.0: (avg (std), med) = (  0.915 (  0.026),   0.918)  # critic/rewards/mean
2026-03-31 19:57:03,630 [tb.py:264] INFO - [103/116] sliuxl-mverl--qwen2d5_3b_inst--gsm8k/verl_feature_config_pbtxt-dc7fdce1--13956-run-mverl-rl-0adv-05-j-03-ns01-tp01-bs16-pad0-rm-0adv-no-mlc--20260331.152156/events.out.tfevents.1775004996.ip-10-4-145-27.2451478.0: (avg (std), med) = (  0.904 (  0.027),   0.906)  # critic/rewards/mean


sliuxl@c889f3b957d9 20:35 ~/qwen2d5 $ M=val-core/openai/gsm8k/acc/mean@1; conda run -n et python ~/tb.py --files "sliuxl-mverl--qwen2d5_3b_inst--gsm8k/*/events*" --metric "$M"
2026-03-31 20:38:39,184 [tb.py:206] WARNING - Unknonw metric `val-core/openai/gsm8k/acc/mean@1`!
2026-03-31 20:38:39,432 [tb.py:235] INFO - Set benchmark with index 0
2026-03-31 20:38:39,434 [tb.py:264] INFO - [12/25] sliuxl-mverl--qwen2d5_3b_inst--gsm8k/verl_feature_config_pbtxt-dc7fdce1--13954-run-mverl-rl-0adv-05-j-01-ns01-tp01-bs16-pad0-base--20260331.152146/events.out.tfevents.1774996037.ip-10-4-145-27.2208577.0: (avg (std), med) = (  0.855 (  0.004),   0.857)  # val-core/openai/gsm8k/acc/mean@1
2026-03-31 20:38:39,695 [tb.py:264] INFO - [12/25] sliuxl-mverl--qwen2d5_3b_inst--gsm8k/verl_feature_config_pbtxt-dc7fdce1--13955-run-mverl-rl-0adv-05-j-02-ns01-tp01-bs16-pad0-rm-0adv-yes-mlc--20260331.152151/events.out.tfevents.1775000615.ip-10-4-145-27.2332233.0: (avg (std), med) = (  0.861 (  0.005),   0.863)  # val-core/openai/gsm8k/acc/mean@1
2026-03-31 20:38:39,972 [tb.py:264] INFO - [12/25] sliuxl-mverl--qwen2d5_3b_inst--gsm8k/verl_feature_config_pbtxt-dc7fdce1--13956-run-mverl-rl-0adv-05-j-03-ns01-tp01-bs16-pad0-rm-0adv-no-mlc--20260331.152156/events.out.tfevents.1775004996.ip-10-4-145-27.2451478.0: (avg (std), med) = (  0.859 (  0.004),   0.860)  # val-core/openai/gsm8k/acc/mean@1
[Screenshot: metric curves comparison, 2026-03-31 19:53:42]

API and Usage Example

Add to your RL config with FSDP:

    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.use_kl_loss=False \

    algorithm.filter_zero_adv.enable=True \
    algorithm.filter_zero_adv.match_loss_curve=True \

    trainer.use_legacy_worker_impl=auto \

Known Limitation

  • Strict: Matching the baseline loss curve requires turning off both entropy_loss and kl_loss; this is currently hard-coded to raise an exception if the condition is violated
  • When kl_loss_coef is sufficiently small, it's up to the user to turn on this feature: training is no longer lossless, but it remains a good approximation because the KL divergence is still applied to samples with non-zero advantage
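The strict check described above might look like the following sketch. The function name and argument list are illustrative, not the PR's actual code; only the config field names come from the usage example:

```python
def check_filter_zero_adv_compat(enable, entropy_coeff, use_kl_loss):
    """Guard sketch: skipping zero-advantage samples is only lossless when
    the loss reduces to pg_loss alone, i.e. entropy and KL terms are off."""
    if not enable:
        return
    if entropy_coeff != 0:
        raise ValueError(
            "filter_zero_adv requires entropy_coeff=0: the entropy term "
            "is nonzero even for zero-advantage samples")
    if use_kl_loss:
        raise ValueError(
            "filter_zero_adv requires use_kl_loss=False: the KL term "
            "is nonzero even for zero-advantage samples")
```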

Design & Code Changes

dp_actor.py computes 3 loss components; when both #2 and #3 are off, we can skip the loss computation entirely for zero-advantage samples, because pg_loss has the advantage as a weight factor:

  1. policy_loss $\propto$ adv
  2. entropy_loss: controlled by (calculate_entropy: bool, entropy_coeff: float)
  3. kl_loss: controlled by (use_kl_loss: bool, kl_loss_coef: float)
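A minimal per-sample sketch of why the skip is safe (plain Python, illustrative function with PPO clipping omitted; only the flag names follow the config above): with the entropy and KL terms disabled, each sample's loss is advantage-weighted, so advantage == 0 implies loss == 0 and zero gradient.

```python
def sample_losses(log_prob_ratio, advantage, entropy, kl,
                  entropy_coeff=0.0, kl_loss_coef=0.0,
                  use_kl_loss=False, calculate_entropy=False):
    """Illustrative per-sample loss with the 3 components listed above."""
    loss = -log_prob_ratio * advantage                 # #1: policy loss ∝ adv
    if calculate_entropy and entropy_coeff != 0.0:
        loss -= entropy_coeff * entropy                # #2: entropy bonus
    if use_kl_loss and kl_loss_coef != 0.0:
        loss += kl_loss_coef * kl                      # #3: KL penalty
    return loss
```

With both extra terms off, sample_losses(0.3, 0.0, 1.2, 0.05) is exactly 0.0, so the sample can be dropped; enabling either extra term makes the loss nonzero even at zero advantage, which is why the feature requires them off.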

Checklist Before Submitting

Important

Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a feature to filter out zero-advantage responses during the actor update in PPO training, aiming to save compute by skipping policy gradient calculations for sequences that contribute no meaningful gradient. The implementation includes logic to pad the filtered batch with the shortest zero-advantage samples to maintain alignment across data parallel groups and mini-batches, along with corrections to loss normalization and MFU metrics to ensure consistency with unfiltered training baselines. The reviewer highlighted concerns regarding gradient scaling consistency when match_loss_curve is disabled and potential performance overhead from ghost optimizer steps, suggesting further optimization for gradient clipping and scaling.
