
feat: add router replay(R3) for megatron engine#1207

Open
TaoZex wants to merge 119 commits into areal-project:main from TaoZex:final_moe

Conversation

@TaoZex (Collaborator) commented Apr 18, 2026

Description

This PR implements Rollout Routing Replay (R3) for MoE models, addressing training instability caused by inference-training routing discrepancy in asynchronous RL training. R3 records expert routing indices from the inference engine and replays them during training, ensuring consistent expert selection regardless of weight staleness.

Key Changes

Core MoE Patch (router_replay_patch.py):

  • RouterReplay class (one per MoE layer) with RECORD/REPLAY_FORWARD/REPLAY_BACKWARD actions
  • patched_routing: replaces TopKRouter.routing — uses scores.gather(1, target_topk_idx) in replay mode instead of torch.topk, preserving gradient flow
  • Four monkey-patches: TransformerConfig.__init__, TopKRouter.__init__, TopKRouter.routing, MoEAlltoAllTokenDispatcher.preprocess
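A minimal sketch of the replay-mode selection described above (tensor shapes inferred from the description; this is not the actual patch code):

```python
import torch

def routing_with_replay(scores: torch.Tensor, target_topk_idx=None, topk: int = 2):
    """Sketch of the patched routing idea (shapes assumed).

    scores:          [num_tokens, num_experts] router probabilities.
    target_topk_idx: [num_tokens, topk] expert ids recorded at rollout time,
                     or None for normal top-k routing (RECORD-style path).
    """
    if target_topk_idx is None:
        # Normal path: pick the top-k experts from the current router scores.
        topk_scores, topk_idx = torch.topk(scores, k=topk, dim=1)
    else:
        # Replay path: reuse the recorded expert indices and gather their
        # scores, so expert selection matches rollout exactly while gradients
        # still flow through `scores` into the router weights.
        topk_idx = target_topk_idx
        topk_scores = scores.gather(1, topk_idx)
    return topk_scores, topk_idx

logits = torch.randn(4, 8, requires_grad=True)
scores = logits.softmax(dim=-1)
recorded = torch.randint(0, 8, (4, 2))      # stand-in for rollout indices
probs, idx = routing_with_replay(scores, recorded)
probs.sum().backward()                      # gradients reach the router logits
```

The key point is that `gather` is differentiable in `scores`, whereas forcing the indices through `torch.topk` would not reproduce the rollout's selection once weights drift.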

Data Distribution (router_replay_utils.py):

  • set_router_replay_data: 4-step pipeline — right-pad→left-align → CP split → TP/SP scatter → PP layer slice → Dense/MoE mapping
  • RouterReplayHelper: locates RouterReplay instances by (pp_rank, vp_stage)
  • Layer allocation helpers: get_num_layers_to_build, get_moe_num_layers_to_build (PP/VP aware)
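The PP layer-slice step can be illustrated with a toy helper (the even-split assumption and the name below are mine, not the PR's actual `get_num_layers_to_build` logic):

```python
def moe_layers_for_pp_rank(num_layers: int, pp_size: int, pp_rank: int) -> list:
    """Toy sketch of the PP layer-slice idea: with an even split of layers
    across pipeline ranks, return the global indices of the layers a given
    rank builds; replayed routing data is sliced to exactly these layers."""
    assert num_layers % pp_size == 0, "sketch assumes an even split"
    per_rank = num_layers // pp_size
    start = pp_rank * per_rank
    return list(range(start, start + per_rank))

# e.g. 8 MoE layers over 4 PP ranks: rank 2 builds global layers 4 and 5
layers = moe_layers_for_pp_rank(8, 4, 2)
```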

MegatronEngine Integration (megatron_engine_r3_patch.py):

  • Wraps forward_backward_batch: retrieves routed_experts via side-channel, splits per micro-batch, injects replay setup via per-instance class swap, toggles forward/backward replay mode, cleans up in finally
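The inject-then-clean-up lifecycle can be sketched as a context manager (the `DummyRouter`, `set_target`, and `clear` names are stand-ins, not the PR's API):

```python
from contextlib import contextmanager

import torch

class DummyRouter:
    """Stand-in for a per-layer RouterReplay instance (real API assumed)."""
    def __init__(self):
        self.target = None
    def set_target(self, idx):
        self.target = idx
    def clear(self):
        self.target = None

@contextmanager
def replay_scope(routers, routed_experts_chunk):
    # routed_experts_chunk: [tokens, num_moe_layers, topk] for one micro-batch
    for layer_id, r in enumerate(routers):
        r.set_target(routed_experts_chunk[:, layer_id])
    try:
        yield
    finally:
        for r in routers:
            r.clear()   # runs even if forward/backward raised, as in the PR

routers = [DummyRouter() for _ in range(2)]
chunk = torch.randint(0, 8, (6, 2, 2))
with replay_scope(routers, chunk):
    active = all(r.target is not None for r in routers)
cleaned = all(r.target is None for r in routers)
```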

Actor & Workflow Integration (actor_r3_patch.py, rlvr_r3_patch.py):

  • Actor: splits routed_experts per mini-batch, delivers via engine side-channel (bypasses pack_tensor_dict 4D incompatibility)
  • Workflow: resolve_r3_moe_config auto-resolves num_moe_layers/topk from HF config; extract_routed_experts converts SGLang numpy output to left-padded torch tensor
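The numpy-to-left-padded conversion might look roughly like this (helper name and `-1` pad value are my assumptions):

```python
import numpy as np
import torch

def to_left_padded(per_seq: list, max_len: int) -> torch.Tensor:
    """Sketch of the conversion described above: each rollout returns a
    [seq_len, num_moe_layers, topk] numpy array of expert ids; stack them
    into one left-padded tensor so valid tokens sit at the right edge,
    matching left-padded input_ids."""
    num_layers, topk = per_seq[0].shape[1:]
    out = torch.full((len(per_seq), max_len, num_layers, topk), -1, dtype=torch.long)
    for i, arr in enumerate(per_seq):
        t = torch.from_numpy(np.asarray(arr)).long()
        out[i, max_len - t.shape[0]:] = t   # pad on the left with -1
    return out

seqs = [np.zeros((3, 2, 2), dtype=np.int64), np.ones((5, 2, 2), dtype=np.int64)]
packed = to_left_padded(seqs, max_len=5)
```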

SGLang Integration (sglang_r3_patch.py, sglang_remote.py):

  • Server patch: pre-encodes routed_experts as base64 in TokenizerManager._handle_batch_output (fixes jsonable_encoder silently flattening torch.Tensor to {} when skip_tokenizer_init=True)
  • Client: decodes base64, validates num_sgl_token divisibility
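The server/client base64 round-trip can be sketched as follows (field names are illustrative, not the exact wire format):

```python
import base64

import numpy as np
import torch

def encode_routed_experts(t: torch.Tensor) -> dict:
    """Serialize a tensor to a JSON-safe dict by hand, so a generic JSON
    encoder cannot silently drop it (the failure mode described above)."""
    arr = t.cpu().numpy()
    return {
        "dtype": str(arr.dtype),
        "shape": list(arr.shape),
        "data": base64.b64encode(arr.tobytes()).decode("ascii"),
    }

def decode_routed_experts(payload: dict) -> torch.Tensor:
    """Client-side inverse: base64 -> bytes -> numpy -> torch."""
    raw = base64.b64decode(payload["data"])
    arr = np.frombuffer(raw, dtype=payload["dtype"]).reshape(payload["shape"])
    return torch.from_numpy(arr.copy())   # copy: frombuffer is read-only

t = torch.randint(0, 64, (5, 3, 2), dtype=torch.int32)
rt = decode_routed_experts(encode_routed_experts(t))
```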

Orchestrator & Config (rl_trainer.py, cli_args.py):

  • return_routed_experts=True → auto-sets enable_router_replay, resolves MoE config, forces skip_tokenizer_init=True, validates SGLang-only support
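The config coupling described above can be sketched like so (keys are assumptions, not the PR's actual `cli_args` fields):

```python
def resolve_r3_config(cfg: dict) -> dict:
    """Hypothetical sketch: enabling return_routed_experts switches on
    router replay, forces skip_tokenizer_init, and rejects non-SGLang
    inference backends with an explicit error."""
    if cfg.get("return_routed_experts"):
        if cfg.get("inference_backend") != "sglang":
            raise ValueError("return_routed_experts is only supported with SGLang")
        cfg["enable_router_replay"] = True
        cfg["skip_tokenizer_init"] = True
    return cfg

cfg = resolve_r3_config({"return_routed_experts": True, "inference_backend": "sglang"})
```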

Supported Parallelism

| Dimension | Supported | Mechanism |
| --- | --- | --- |
| TP | ✅ | `scatter_to_sequence_parallel_region` + `seq_align_to` by `tp_size` |
| PP | ✅ | `get_current_rank_layer_info` slices per PP rank's MoE layers |
| VP | ✅ | Cumulative offset by `vp_stage` in `RouterReplayHelper` |
| CP | ✅ | `split_packed_seqs_for_context_parallel` before TP scatter; `seq_align_to = tp_size * cp_size * 2` when `cp_size > 1`; `cp_local` disabled to avoid shape mismatch |
| DP | ✅ | Data flows with mini-batches; no conflict |
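The alignment rule in the table above (pad the token dimension to a multiple of the split factor so TP/CP shards divide evenly) can be sketched as follows; the helper name comes from the description, but the signature and `-1` pad value are assumptions:

```python
import torch

def seq_align_to(x: torch.Tensor, align: int, pad_value: int = -1) -> torch.Tensor:
    """Pad the leading (token) dimension of `x` up to a multiple of `align`,
    e.g. align = tp_size, or tp_size * cp_size * 2 when cp_size > 1."""
    seq_len = x.shape[0]
    pad = (-seq_len) % align
    if pad == 0:
        return x
    tail = x.new_full((pad, *x.shape[1:]), pad_value)
    return torch.cat([x, tail], dim=0)

x = torch.arange(10).unsqueeze(-1)   # 10 tokens
y = seq_align_to(x, align=8)         # e.g. tp_size=2, cp_size=2 -> 2*2*2 = 8
```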

New Metrics

All metrics are computed under the `compute_logp/r3` scope at `_compute_logp` time (training weights equal rollout weights, so there is no optimizer drift), using `stats_tracker.stat` with the `n_valid_tokens` denominator (token-weighted global mean/min/max across ranks). The denominator key `n_valid_tokens` is filtered out of the wandb/tensorboard export to avoid clutter.

| Metric | Definition | Meaning |
| --- | --- | --- |
| `rollout_train_logp_abs_diff` | mean(\|logp_train − logp_rollout\|) over valid tokens | Mean absolute logprob divergence; R3 should reduce this to BF16 kernel noise |
| `rollout_train_logp_sq_diff` | mean((logp_train − logp_rollout)²) over valid tokens | Mean squared divergence; more sensitive to outliers |
| `rollout_train_k3_kl` | mean(exp(Δ) − 1 − Δ), where Δ = logp_train − logp_rollout | Schulman k3 KL estimator: unbiased, non-negative estimator of KL(π_rollout ‖ π_train) |
| `rollout_train_extreme_frac_tau2` | F(τ=2): fraction of tokens with \|Δ\| > ln(2) | Extreme-token fraction from the Router Replay paper (Eq. 3); tokens whose importance ratio leaves [1/2, 2] |
| `rollout_train_extreme_frac_tau5` | F(τ=5): fraction of tokens with \|Δ\| > ln(5) | Same with threshold τ=5; more severe outliers |
| `r3_enabled` | 1.0 if the R3 side-channel is active, 0.0 otherwise | Binary flag for A/B comparison in wandb |

All metrics are computed under torch.no_grad() on detached tensors, with negligible overhead.
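The five divergence metrics above can be computed in one pass; a sketch (function name and argument layout are mine, not the PR's):

```python
import math

import torch

@torch.no_grad()
def r3_divergence_metrics(logp_train, logp_rollout, mask):
    """Sketch of the metric definitions in the table above, computed on
    detached tensors under no_grad; `mask` marks valid response tokens."""
    d = (logp_train - logp_rollout)[mask.bool()]
    return {
        "rollout_train_logp_abs_diff": d.abs().mean().item(),
        "rollout_train_logp_sq_diff": d.pow(2).mean().item(),
        "rollout_train_k3_kl": (d.exp() - 1 - d).mean().item(),
        "rollout_train_extreme_frac_tau2": (d.abs() > math.log(2)).float().mean().item(),
        "rollout_train_extreme_frac_tau5": (d.abs() > math.log(5)).float().mean().item(),
    }

# When training and rollout policies agree exactly, every metric is zero.
m = r3_divergence_metrics(torch.zeros(6), torch.zeros(6), torch.ones(6))
```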

Related Paper

https://arxiv.org/abs/2510.11370 (Ma et al., arXiv:2510.11370, 2025) — proposes R3 to reduce training-inference policy KL divergence and prevent MoE RL training collapse.

Related Issue

Fixes #(issue)

Type of Change

  • 🐛 Bug fix
  • ✨ New feature
  • 💥 Breaking change
  • 📝 Documentation update
  • ♻️ Refactoring
  • ⚡ Performance improvement
  • ✅ Test coverage improvement

Checklist

  • I have read the Contributing Guide
  • Pre-commit hooks pass (pre-commit run --all-files)
  • Relevant tests pass; new tests added for new functionality
  • Documentation updated (if applicable; built with ./docs/build_all.sh)
  • Branch is up to date with main
  • Self-reviewed via /review-pr command
  • This PR was created by a coding agent via /create-pr
  • This PR is a breaking change

Breaking Change Details (if applicable):

N/A

Additional Context

  • Backward Compatible: return_routed_experts=False (default) → all R3 code inactive, zero overhead
  • SGLang Only: vLLM backend does not support return_routed_experts; config validation raises explicit error
  • Side-Channel Delivery: routed_experts delivered via engine._r3_pending_routed_experts to bypass pack_tensor_dict 4D incompatibility
  • Server Patch Required: sglang_r3_patch must be installed on inference server to fix torch.Tensor serialization when skip_tokenizer_init=True

@gemini-code-assist (Contributor, Bot) left a comment

Code Review

This pull request implements Router Replay (R3) to align Mixture-of-Experts (MoE) routing decisions between rollout inference and training, preventing performance degradation caused by weight staleness in RL. The changes include monkey-patches for Megatron-Core components, engine-level wrappers for micro-batch scheduling, and workflow integrations to propagate routing indices from SGLang. Feedback focuses on critical architectural issues regarding global state and thread safety, specifically the risks of patching class-level iterators and using global lists for router instances. Additionally, there are recommendations to fix potential data loss in uneven batch splitting and to optimize performance by removing GPU-CPU synchronization points in the data processing pipeline.

Review comment threads:

  • areal/engine/megatron_engine_r3_patch.py (outdated)
  • areal/engine/router_replay_patch.py (outdated)
  • areal/engine/megatron_engine_r3_patch.py (outdated)
  • areal/trainer/ppo/actor_r3_patch.py (outdated)
  • areal/engine/router_replay_utils.py (outdated)
  • areal/engine/router_replay_patch.py
@TaoZex (Collaborator, Author) commented May 7, 2026

  1. Unit test results (including test_r3_mask_alignment.py and test_router_replay.py) (screenshot)
  2. End-to-end test results (including test_router_replay_e2e.py) (screenshot)

@TaoZex (Collaborator, Author) commented May 7, 2026

R3 Metric

Comparing the metric results with router replay (R3) enabled versus disabled. The R3 metrics are tested on [Moonlight-16B-A3B-Instruct](https://huggingface.co/moonshotai/Moonlight-16B-A3B-Instruct); see the example config at examples/math/gsm8k_grpo_megatron_r3.yaml:

The first three metrics in the table are the primary evaluation metrics from the paper; additionally, I introduced two supplementary metrics to further verify the effectiveness of R3.

Let Δ = log π_train − log π_infer, T = set of valid response tokens.

rollout_train_k3_kl

$$\frac{1}{|T|}\sum_{t \in T}\bigl[e^{\Delta_t} - 1 - \Delta_t\bigr]$$

Unbiased, non-negative estimator of $\text{KL}(\pi_{\text{infer}} \,\|\, \pi_{\text{train}})$ (Schulman k3), since tokens are sampled from $\pi_{\text{infer}}$.

image
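A quick sanity check (toy distributions of my choosing) that the k3 estimator's expectation equals the true KL:

```python
import math

# For a small categorical pair, the exact expectation of
# k3 = r - 1 - log r (with r = pi_train/pi_infer) under pi_infer equals
# KL(pi_infer || pi_train): E[r] = 1, and E[-log r] is the KL by definition.
pi_infer = [0.5, 0.3, 0.2]   # rollout/inference policy (sampling distribution)
pi_train = [0.4, 0.4, 0.2]   # training policy

kl = sum(q * math.log(q / p) for q, p in zip(pi_infer, pi_train))
e_k3 = sum(q * (p / q - 1 - math.log(p / q)) for q, p in zip(pi_infer, pi_train))
```

Each per-sample term $r - 1 - \log r$ is non-negative, which is why the metric never dips below zero even on small batches.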

rollout_train_extreme_frac_tau2

$$\frac{1}{|T|}\sum_{t \in T} \mathbf{1}\bigl[\lvert\Delta_t\rvert > \ln 2\bigr]$$

Fraction of tokens whose probability ratio falls outside the range $[1/2,\,2]$.

image

rollout_train_extreme_frac_tau5

$$\frac{1}{|T|}\sum_{t \in T} \mathbf{1}\bigl[\lvert\Delta_t\rvert > \ln 5\bigr]$$

Fraction of tokens whose probability ratio falls outside the range $[1/5,\,5]$.

image

rollout_train_logp_abs_diff

$$\frac{1}{|T|}\sum_{t \in T} \lvert\Delta_t\rvert$$

Mean absolute log-probability difference between training and inference engines.

image

rollout_train_logp_sq_diff

$$\frac{1}{|T|}\sum_{t \in T} \Delta_t^2$$

Mean squared log-probability difference, more sensitive to extreme outliers.

image

@TaoZex TaoZex changed the title [WIP]feat: add router replay for megatron engine feat: add router replay(R3) for megatron engine May 7, 2026
@TaoZex (Collaborator, Author) commented May 7, 2026

@garrett4wade @rchardx @nuzant Hi guys, I'd really appreciate it if you could help review the PR code when you have some spare time; I'm looking forward to your suggestions!
