[megatron] fix: enable_routing_replay fails with MLATransformerConfig…#5884

Merged
wuxibin89 merged 1 commit into verl-project:main from
NoonePauseferg:fix/mla-router-replay
Apr 7, 2026

Conversation

@NoonePauseferg
Contributor

What does this PR do?

Fixes R3 router replay crash when used with DeepSeek models via vanilla mbridge (MLATransformerConfig).

_build_tf_config() passes enable_routing_replay=True through bridge.set_extra_args(**override_transformer_config), which calls MLATransformerConfig.__init__(). However, MLATransformerConfig is a Python dataclass: it generates its own __init__, which never calls the patched TransformerConfig.__init__ that accepts enable_routing_replay. Result:

TypeError: MLATransformerConfig.__init__() got an unexpected keyword argument 'enable_routing_replay'

This affects any mbridge model config that is a dataclass subclass of TransformerConfig (DeepSeek, potentially others).
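The failure mode can be reproduced without Megatron or mbridge. The sketch below is illustrative only: `BaseConfig` and `ChildConfig` are hypothetical stand-ins for TransformerConfig and MLATransformerConfig, and `_patched_init` is an assumed approximation of what the router replay patch does, not its actual code.

```python
from dataclasses import dataclass


@dataclass
class BaseConfig:  # hypothetical stand-in for TransformerConfig
    hidden_size: int = 8


@dataclass
class ChildConfig(BaseConfig):  # hypothetical stand-in for MLATransformerConfig
    kv_lora_rank: int = 4


# Assumed shape of the patch: wrap the base __init__ so it accepts the new kwarg.
_orig_init = BaseConfig.__init__


def _patched_init(self, *args, enable_routing_replay=False, **kwargs):
    _orig_init(self, *args, **kwargs)
    self.enable_routing_replay = enable_routing_replay


BaseConfig.__init__ = _patched_init

# The patched base class accepts the kwarg.
base = BaseConfig(enable_routing_replay=True)

# The subclass got its own generated __init__ when @dataclass ran,
# so the patch on the base class never executes for it.
try:
    ChildConfig(enable_routing_replay=True)
    child_error = None
except TypeError as exc:
    child_error = str(exc)

# Setting the attribute after construction works for a regular
# (non-frozen, non-slotted) dataclass.
child = ChildConfig()
child.enable_routing_replay = True
```

Because `@dataclass` writes a fresh `__init__` into each decorated class, patching the parent only helps classes that inherit `__init__` unchanged.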

Fix: Remove enable_routing_replay from override_transformer_config dict. Set it directly as an attribute on tf_config / provider after construction.

Related: #4567 (similar fix for Qwen3VLTransformerConfig, still open). This PR is more generic — works for any dataclass config subclass.

Checklist Before Starting

Test

Tested on a 10B DeepSeek MoE model (megatron engine + sglang rollout, R3 mode, 32x H100 GPUs, PP=2 TP=2 EP=2):

  • Without fix: TypeError: MLATransformerConfig.__init__() got an unexpected keyword argument 'enable_routing_replay'
  • With fix: R3 router replay works correctly
    • rollout_corr/log_ppl_diff = 0.0003 (sglang and megatron log-probs match)
    • rollout_corr/kl = 0.0003
    • actor/grad_norm = 0.44 (stable, normal range)
    • Training runs for 10+ steps without issues

API and Usage Example

No API changes. Existing R3 router replay config works as before:

actor_rollout_ref.actor.megatron.router_replay.mode=R3
actor_rollout_ref.rollout.enable_rollout_routing_replay=True

Design & Code Changes

Single file change in verl/workers/engine/megatron/transformer_impl.py:

  1. Remove enable_routing_replay from override_transformer_config dict (was passed to bridge.set_extra_args() which forwards to dataclass __init__)
  2. Add tf_config.enable_routing_replay = True after config creation (vanilla mbridge path)
  3. Add provider.enable_routing_replay = True after overrides (non-vanilla mbridge path)
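The pattern behind these three changes can be sketched in isolation. This is a minimal illustration under stated assumptions, not the code in transformer_impl.py: `DemoConfig` and `build_config` are hypothetical names, and the `pop` is an extra defensive step for overrides supplied by the caller.

```python
from dataclasses import dataclass


@dataclass
class DemoConfig:  # hypothetical stand-in for a dataclass config subclass
    num_layers: int = 2


def build_config(config_cls, overrides, routing_replay_enabled):
    """Build a config without passing enable_routing_replay to __init__."""
    kwargs = dict(overrides)  # copy so the caller's dict is not mutated
    # Keep the flag out of the constructor call; a generated dataclass
    # __init__ would reject it as an unexpected keyword argument.
    kwargs.pop("enable_routing_replay", None)
    cfg = config_cls(**kwargs)
    # Set the flag as a plain attribute after construction instead.
    if routing_replay_enabled:
        cfg.enable_routing_replay = True
    return cfg


user_overrides = {"num_layers": 4, "enable_routing_replay": True}
cfg = build_config(DemoConfig, user_overrides, routing_replay_enabled=True)
```

Setting the attribute after construction does not depend on how the config class defines `__init__`, which is why the same approach applies to both the tf_config and provider paths.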

Checklist Before Submitting

  • Read the Contribute Guide.
  • Apply pre-commit checks.
  • Add / Update the documentation. — N/A, no doc changes needed.
  • Add unit or end-to-end test(s). — Not feasible: requires mbridge + DeepSeek model + multi-GPU setup. The existing tests/special_e2e/run_ppo_trainer_megatron.sh with ROUTING_REPLAY_MODE=R3 covers this path when run on MoE models.
  • Once your PR is ready for CI, send a message in the ci-request channel.

… (mbridge)

When using R3 router replay with DeepSeek models via vanilla mbridge,
_build_tf_config() passes enable_routing_replay=True through
bridge.set_extra_args(**override_transformer_config). This calls
MLATransformerConfig.__init__() which doesn't accept this kwarg:

  TypeError: MLATransformerConfig.__init__() got an unexpected keyword
  argument 'enable_routing_replay'

MLATransformerConfig is a dataclass that generates its own __init__ and
doesn't call the patched TransformerConfig.__init__ that the router
replay patch modifies to accept enable_routing_replay.

Fix: remove enable_routing_replay from override_transformer_config and
set it directly as an attribute on tf_config/provider after construction.

Tested on 10B DeepSeek MoE (megatron + sglang, 32x H100, R3 mode).
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request refactors how the 'enable_routing_replay' flag is handled in 'transformer_impl.py' by moving its assignment from the 'override_transformer_config' dictionary to direct attribute setting on the provider and configuration objects. This change avoids potential 'TypeError' issues with dataclass-based configurations that do not support arbitrary keyword arguments. I have suggested adding a defensive 'pop' operation to ensure the flag is removed from the override dictionary if present, preventing unexpected initialization errors.

I am having trouble creating individual review comments. Click here to see my feedback.

verl/workers/engine/megatron/transformer_impl.py (166-167)

high

To ensure that enable_routing_replay does not cause a TypeError when initializing dataclass-based configurations (like MLATransformerConfig), it should be explicitly removed from the override_transformer_config dictionary. While the logic that was adding it has been removed, it could still be present if provided via self.engine_config.override_transformer_config. The attribute is now correctly set directly on the config or provider objects later in the function.

        override_transformer_config.pop("enable_routing_replay", None)

@wuxibin89 wuxibin89 merged commit b90edef into verl-project:main Apr 7, 2026
55 of 67 checks passed