feat: add router replay (R3) for megatron engine (#1207)
Conversation
Code Review
This pull request implements Router Replay (R3) to align Mixture-of-Experts (MoE) routing decisions between rollout inference and training, preventing performance degradation caused by weight staleness in RL. The changes include monkey-patches for Megatron-Core components, engine-level wrappers for micro-batch scheduling, and workflow integrations to propagate routing indices from SGLang. Feedback focuses on critical architectural issues regarding global state and thread safety, specifically the risks of patching class-level iterators and using global lists for router instances. Additionally, there are recommendations to fix potential data loss in uneven batch splitting and to optimize performance by removing GPU-CPU synchronization points in the data processing pipeline.
R3 Metrics

Compare the metric results with router replay (R3) enabled versus disabled. R3 metrics are tested on [Moonlight-16B-A3B](https://huggingface.co/moonshotai/Moonlight-16B-A3B-Instruct); see the example config at examples/math/gsm8k_grpo_megatron_r3.yaml. The first three metrics in the table are the primary evaluation metrics from the paper; additionally, I introduced two supplementary metrics to further verify the effectiveness of R3. Let Δ = log π_train − log π_infer, and let T be the set of valid response tokens.
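The Δ-based diagnostics defined above can be sketched as follows. This is an illustrative stand-in, not the PR's actual code; the function and tensor names (`r3_metrics`, `logp_train`, `logp_rollout`, `mask`) are hypothetical, and the mask plays the role of the valid-token set T:

```python
import math
import torch

def r3_metrics(logp_train, logp_rollout, mask):
    """Token-weighted diagnostics over valid tokens (illustrative names).

    logp_train / logp_rollout: per-token log-probs of the same responses.
    mask: 1.0 for valid response tokens, 0.0 for padding.
    """
    delta = (logp_train - logp_rollout) * mask   # Δ, zeroed on padding
    n = mask.sum().clamp(min=1.0)                # |T|, guarded against 0
    valid = mask.bool()
    return {
        # mean |Δ| and mean Δ² over valid tokens
        "rollout_train_logp_abs_diff": (delta.abs().sum() / n).item(),
        "rollout_train_logp_sq_diff": (delta.pow(2).sum() / n).item(),
        # k3 estimator of KL(π_rollout ‖ π_train): E[exp(Δ) - 1 - Δ]
        "rollout_train_k3_kl": (((delta.exp() - 1 - delta) * mask).sum() / n).item(),
        # F(τ): fraction of tokens whose probability ratio leaves [1/τ, τ]
        "rollout_train_extreme_frac_tau2": (((delta.abs() > math.log(2.0)) & valid).sum() / n).item(),
        "rollout_train_extreme_frac_tau5": (((delta.abs() > math.log(5.0)) & valid).sum() / n).item(),
    }
```

When the two log-prob tensors agree exactly, every metric is zero; the k3 estimator is non-negative by construction, so it stays a usable KL proxy even for small batches.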
@garrett4wade @rchardx @nuzant Hi guys, I’d really appreciate it if you could help review the PR code when you have some spare time. I’m looking forward to your suggestions!







Description
This PR implements Rollout Routing Replay (R3) for MoE models, addressing training instability caused by inference-training routing discrepancy in asynchronous RL training. R3 records expert routing indices from the inference engine and replays them during training, ensuring consistent expert selection regardless of weight staleness.
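The record-then-replay mechanism described above can be sketched as a tiny stand-alone class. This is a simplified stand-in for the actual Megatron patch, not its implementation; `RouterReplaySketch`, `route`, and the two-mode enum are illustrative names:

```python
import torch

class RouterReplaySketch:
    """Illustrative per-layer record/replay of top-k expert routing."""

    RECORD, REPLAY = "record", "replay"

    def __init__(self, topk: int):
        self.topk = topk
        self.mode = self.RECORD
        self.recorded_idx = None  # expert indices captured at rollout time

    def route(self, scores: torch.Tensor):
        # scores: [num_tokens, num_experts] router scores.
        if self.mode == self.REPLAY and self.recorded_idx is not None:
            # Replay: reuse the recorded indices. gather() reads the current
            # scores at those indices, so gradients still flow through the
            # router even though expert selection is fixed.
            topk_idx = self.recorded_idx
            topk_scores = scores.gather(1, topk_idx)
        else:
            # Record: ordinary top-k routing, indices saved for later replay.
            topk_scores, topk_idx = torch.topk(scores, self.topk, dim=1)
            self.recorded_idx = topk_idx
        return topk_scores, topk_idx
```

In replay mode the same experts are selected as at rollout time even if the training weights (and hence the scores) have drifted, which is exactly the consistency property R3 relies on.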
Key Changes
Core MoE Patch (router_replay_patch.py):
- `RouterReplay` class (one per MoE layer) with `RECORD`/`REPLAY_FORWARD`/`REPLAY_BACKWARD` actions
- `patched_routing`: replaces `TopKRouter.routing`; uses `scores.gather(1, target_topk_idx)` in replay mode instead of `torch.topk`, preserving gradient flow
- Patches `TransformerConfig.__init__`, `TopKRouter.__init__`, `TopKRouter.routing`, `MoEAlltoAllTokenDispatcher.preprocess`

Data Distribution (router_replay_utils.py):
- `set_router_replay_data`: 4-step pipeline: right-pad → left-align → CP split → TP/SP scatter → PP layer slice → Dense/MoE mapping
- `RouterReplayHelper`: locates `RouterReplay` instances by `(pp_rank, vp_stage)`
- `get_num_layers_to_build`, `get_moe_num_layers_to_build` (PP/VP aware)

MegatronEngine Integration (megatron_engine_r3_patch.py):
- `forward_backward_batch`: retrieves `routed_experts` via side-channel, splits per micro-batch, injects replay setup via per-instance class swap, toggles forward/backward replay mode, cleans up in `finally`

Actor & Workflow Integration (actor_r3_patch.py, rlvr_r3_patch.py):
- Delivers `routed_experts` per mini-batch via the engine side-channel (bypasses `pack_tensor_dict` 4D incompatibility)
- `resolve_r3_moe_config` auto-resolves `num_moe_layers`/`topk` from the HF config; `extract_routed_experts` converts SGLang numpy output to a left-padded torch tensor

SGLang Integration (sglang_r3_patch.py, sglang_remote.py):
- Returns `routed_experts` as base64 in `TokenizerManager._handle_batch_output` (fixes `jsonable_encoder` silently flattening `torch.Tensor` to `{}` when `skip_tokenizer_init=True`)
- Validates `num_sgl_token` divisibility

Orchestrator & Config (rl_trainer.py, cli_args.py):
- `return_routed_experts=True` → auto-sets `enable_router_replay`, resolves MoE config, forces `skip_tokenizer_init=True`, validates SGLang-only support

Supported Parallelism
- TP/SP: `scatter_to_sequence_parallel_region` + `seq_align_to` by `tp_size`
- PP: `get_current_rank_layer_info` slices per PP rank's MoE layers
- VP: `vp_stage` in `RouterReplayHelper`
- CP: `split_packed_seqs_for_context_parallel` before TP scatter; `seq_align_to = tp_size * cp_size * 2` when `cp_size > 1`; `cp_local` disabled to avoid shape mismatch

New Metrics
All metrics are computed under the `compute_logp/r3` scope at `_compute_logp` time (training weights = rollout weights, no optimizer drift), using `stats_tracker.stat` with an `n_valid_tokens` denominator (token-weighted global mean/min/max across ranks). The denominator key `n_valid_tokens` is filtered from wandb/tensorboard export to avoid clutter.

- `rollout_train_logp_abs_diff`: `mean(|logp_train - logp_rollout|)` over valid tokens
- `rollout_train_logp_sq_diff`: `mean((logp_train - logp_rollout)²)` over valid tokens
- `rollout_train_k3_kl`: `mean(exp(Δ) - 1 - Δ)` where `Δ = logp_train - logp_rollout`; a k3 estimator of `KL(π_rollout ‖ π_train)`
- `rollout_train_extreme_frac_tau2`: `F(τ=2)`, the fraction of tokens with `|Δ| > ln(2)`, i.e. probability ratio outside `[1/2, 2]`
- `rollout_train_extreme_frac_tau5`: `F(τ=5)`, the fraction of tokens with `|Δ| > ln(5)`
- `r3_enabled`: `1.0` if the R3 side-channel is active, `0.0` otherwise

All computed under `torch.no_grad()` on detached tensors; negligible overhead.

Related Paper
https://arxiv.org/abs/2510.11370 (Ma et al., arXiv:2510.11370, 2025) — proposes R3 to reduce training-inference policy KL divergence and prevent MoE RL training collapse.

Related Issue
Fixes #(issue)
Type of Change
Checklist
- `pre-commit run --all-files`
- `./docs/build_all.sh`
- `main`
- `/review-pr` command
- `/create-pr`

Breaking Change Details (if applicable):
N/A
Additional Context
- `return_routed_experts=False` (default) → all R3 code inactive, zero overhead
- Only the SGLang backend supports `return_routed_experts`; config validation raises an explicit error otherwise
- `routed_experts` delivered via `engine._r3_pending_routed_experts` to bypass `pack_tensor_dict` 4D incompatibility
- `sglang_r3_patch` must be installed on the inference server to fix `torch.Tensor` serialization when `skip_tokenizer_init=True`
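The base64 transport mentioned above (working around `jsonable_encoder` dropping raw tensors) can be sketched as a simple encode/decode pair. This is an assumed wire format for illustration, not the PR's actual one; `encode_routed_experts` and `decode_routed_experts` are hypothetical names:

```python
import base64
import json
import numpy as np

def encode_routed_experts(arr: np.ndarray) -> dict:
    """Pack an index array into a JSON-safe dict (base64 payload + metadata)."""
    return {
        "data": base64.b64encode(arr.tobytes()).decode("ascii"),
        "dtype": str(arr.dtype),
        "shape": list(arr.shape),
    }

def decode_routed_experts(msg: dict) -> np.ndarray:
    """Reverse of encode_routed_experts: bytes -> typed, shaped array."""
    buf = base64.b64decode(msg["data"])
    return np.frombuffer(buf, dtype=msg["dtype"]).reshape(msg["shape"])
```

Because every field of the dict is a plain string or list, it survives JSON serialization intact, whereas a raw `torch.Tensor` in the response payload would be silently flattened to `{}`.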