[perf, fsdp, trainer] feat: Skip training for zero-advantage responses to speed up RL. #5838
sheilaliuxl wants to merge 3 commits into verl-project:main
Conversation
Code Review
This pull request introduces a feature to filter out zero-advantage responses during the actor update in PPO training, aiming to save compute by skipping policy gradient calculations for sequences that contribute no meaningful gradient. The implementation includes logic to pad the filtered batch with the shortest zero-advantage samples to maintain alignment across data parallel groups and mini-batches, along with corrections to loss normalization and MFU metrics to ensure consistency with unfiltered training baselines. The reviewer highlighted concerns regarding gradient scaling consistency when match_loss_curve is disabled and potential performance overhead from ghost optimizer steps, suggesting further optimization for gradient clipping and scaling.
Force-pushed from 746b673 to 3adf36d
What does this PR do?
Skip training for zero-advantage responses to speed up RL.
At high accuracy, fewer than 25% of samples have a non-zero advantage (see `critic/advantages/zero_adv_ratio`), and the rest don't contribute to the `pg_loss` at all. After filtering, there are two ways to form mini-batches:

1. With 4 mini-batches (`match_loss_curve = True`): keep the same number of mini-batches as before, but with fewer samples each; run `num_ghost_opt_steps` ghost optimizer steps when needed (very unlikely).
2. With 1 single mini-batch (`match_loss_curve = False`): sized by `micro-bs` (static batch size) or `tokens_per_gpu` (dynamic batch size); faster than both the baseline and option 1.

Both options closely match the baseline loss curve, and train faster due to the skip.

Checklist Before Starting
- Title format: `[{modules}] {type}: {description}` (this will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `veomni`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`, `fully_async`, `one_step_off`, like `[megatron, fsdp, doc]`
- `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`

Test
1. Qwen2.5 3B LoRA: 2 epochs on GSM8k. Metrics comparison, dropping the noisy first 10 steps and the last 3 steps:
   - `timing_s/update_actor` avg: 3.853 -> 2.278 (-40.9%), a 1.7x reduction; std: 0.183 -> 0.072 (-60.7%), a 2.5x reduction
   - `critic/rewards/mean` avg: 0.909 -> 0.915
   - `val-core/openai/gsm8k/acc/mean@1` avg: 0.855 -> 0.861

API and Usage Example
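The PR's actual config snippet is not captured in this page. As a hedged sketch, enabling the feature might look like the following, where only `match_loss_curve` and `num_ghost_opt_steps` are option names taken from the PR text; the key layout and the `skip_zero_advantage` switch are assumptions:

```python
# Hypothetical config sketch: only match_loss_curve and num_ghost_opt_steps
# are named in the PR; the enclosing layout and the enabling flag are assumed.
actor_overrides = {
    "skip_zero_advantage": True,   # hypothetical switch for this feature
    "match_loss_curve": True,      # True: keep the same #mini-batches as baseline
    "num_ghost_opt_steps": 0,      # ghost optimizer steps, rarely needed
}
```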
Add the feature to your RL config with FSDP.

Known Limitation
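The limitation in this section stems from the entropy and KL terms not being advantage-weighted. A sketch of the kind of hard-coded guard involved (the actual check and error message in the PR may differ):

```python
def check_skip_preconditions(calculate_entropy: bool, entropy_coeff: float,
                             use_kl_loss: bool, kl_loss_coeff: float) -> None:
    """Zero-advantage skipping is only lossless when entropy and KL losses are
    inactive, because those two terms are not weighted by the advantage."""
    entropy_active = calculate_entropy and entropy_coeff != 0.0
    kl_active = use_kl_loss and kl_loss_coeff != 0.0
    if entropy_active or kl_active:
        raise RuntimeError(
            "Skipping zero-advantage responses requires entropy_loss "
            "and kl_loss to be disabled."
        )
```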
The skip only covers `pg_loss`, not `entropy_loss` and `kl_loss`; the code is currently hard-coded to raise an `Exception` if either of those is active. If `kl_loss_coef` is sufficiently small, it is up to the user to turn this feature on anyway: no longer lossless, but a good approximation, since the KL divergence is still kept for non-zero-advantage samples.

Design & Code Changes
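Since `pg_loss` uses the advantage as a weight factor, responses whose advantages are all zero contribute no gradient. A minimal pure-Python sketch of the filter-and-pad idea described in this PR (the real code operates on batched tensors in `dp_actor.py`; the function and argument names here are illustrative):

```python
def select_training_indices(advantages, response_lengths, pad_to_multiple):
    """Pick responses to keep for the actor update.

    advantages: list of per-token advantage lists, one per response
    response_lengths: token count of each response
    pad_to_multiple: kept count must split evenly across DP ranks / mini-batches
    """
    nonzero = [any(a != 0 for a in adv) for adv in advantages]
    keep = [i for i, nz in enumerate(nonzero) if nz]
    # Pad with the *shortest* zero-advantage responses so every DP rank and
    # mini-batch stays aligned; their pg_loss is exactly zero, so padding
    # costs a little compute but never changes the gradient.
    deficit = (-len(keep)) % pad_to_multiple
    if deficit:
        zeros = sorted((i for i, nz in enumerate(nonzero) if not nz),
                       key=lambda i: response_lengths[i])
        keep += zeros[:deficit]
    return keep
```

With `match_loss_curve = True`, the kept samples are then split into the original number of mini-batches; per the PR description, ghost optimizer steps (`num_ghost_opt_steps`) cover the rare case where this split comes up short.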
`dp_actor.py` has 3 loss components; when both #2 and #3 are off, we can skip the loss computation completely, since `pg_loss` has `advantage` as a weight factor:

1. `policy_loss`: weighted by `adv`
2. `entropy_loss`: gated by `(calculate_entropy: bool, entropy_coeff: float)`
3. `kl_loss`: gated by `(use_kl_loss: bool, kl_loss_coeff: float)`

Checklist Before Submitting
Important
Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
- Run `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`.
- Request CI via the `ci-request` channel in the `verl` Slack workspace. (If not accessible, please try the Feishu group.)