[megatron] fix: enable_routing_replay fails with MLATransformerConfig… #5884
… (mbridge)

When using R3 router replay with DeepSeek models via vanilla mbridge, _build_tf_config() passes enable_routing_replay=True through bridge.set_extra_args(**override_transformer_config). This calls MLATransformerConfig.__init__(), which doesn't accept this kwarg:

TypeError: MLATransformerConfig.__init__() got an unexpected keyword argument 'enable_routing_replay'

MLATransformerConfig is a dataclass that generates its own __init__ and doesn't call the patched TransformerConfig.__init__ that the router replay patch modifies to accept enable_routing_replay.

Fix: remove enable_routing_replay from override_transformer_config and set it directly as an attribute on tf_config/provider after construction.

Tested on 10B DeepSeek MoE (megatron + sglang, 32x H100, R3 mode).
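The failure mode described above can be reproduced without Megatron. The following is a minimal sketch using hypothetical stand-in classes (BaseConfig and MLAConfig are illustrative names, not the real TransformerConfig or MLATransformerConfig), and the patch function only approximates what a router replay patch might do. It shows that patching the parent's __init__ has no effect on a dataclass subclass, because the subclass carries its own generated __init__.

```python
from dataclasses import dataclass


@dataclass
class BaseConfig:  # hypothetical stand-in for TransformerConfig
    hidden_size: int = 8


@dataclass
class MLAConfig(BaseConfig):  # hypothetical stand-in for MLATransformerConfig
    kv_lora_rank: int = 4


def patch_base_init() -> None:
    """Wrap BaseConfig.__init__ so it accepts enable_routing_replay (assumed patch shape)."""
    original_init = BaseConfig.__init__

    def patched_init(self, *args, enable_routing_replay=False, **kwargs):
        original_init(self, *args, **kwargs)
        self.enable_routing_replay = enable_routing_replay

    BaseConfig.__init__ = patched_init


patch_base_init()

# The patched base class accepts the extra keyword.
base = BaseConfig(enable_routing_replay=True)

# The subclass has its own generated __init__, so the patched one never runs.
try:
    MLAConfig(enable_routing_replay=True)
    subclass_error = None
except TypeError as exc:
    subclass_error = str(exc)

# Approach taken by the fix: construct first, then set the attribute.
cfg = MLAConfig()
cfg.enable_routing_replay = True
```

Setting the attribute after construction works here because the stand-in dataclass is neither frozen nor slotted; a config class declared with frozen=True or slots=True would need a different approach.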
Code Review
This pull request refactors how the 'enable_routing_replay' flag is handled in 'transformer_impl.py' by moving its assignment from the 'override_transformer_config' dictionary to direct attribute setting on the provider and configuration objects. This change avoids potential 'TypeError' issues with dataclass-based configurations that do not support arbitrary keyword arguments. I have suggested adding a defensive 'pop' operation to ensure the flag is removed from the override dictionary if present, preventing unexpected initialization errors.
verl/workers/engine/megatron/transformer_impl.py (166-167)
To ensure that enable_routing_replay does not cause a TypeError when initializing dataclass-based configurations (like MLATransformerConfig), it should be explicitly removed from the override_transformer_config dictionary. While the logic that was adding it has been removed, it could still be present if provided via self.engine_config.override_transformer_config. The attribute is now correctly set directly on the config or provider objects later in the function.
override_transformer_config.pop("enable_routing_replay", None)
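The suggested pop combines naturally with setting the attribute after construction. The sketch below illustrates that pattern under stated assumptions: DemoConfig and build_config are hypothetical names, not code from transformer_impl.py, and discarding a user-supplied value in favor of the engine's own setting is one possible policy that the review does not prescribe.

```python
from dataclasses import dataclass


@dataclass
class DemoConfig:  # hypothetical stand-in for a dataclass-based transformer config
    hidden_size: int = 8
    num_layers: int = 2


def build_config(user_overrides: dict, routing_replay: bool) -> DemoConfig:
    """Build a config without ever passing enable_routing_replay to __init__."""
    overrides = dict(user_overrides)  # copy so the caller's dict is left untouched
    # A user-supplied override could still carry the flag; drop it defensively.
    overrides.pop("enable_routing_replay", None)
    cfg = DemoConfig(**overrides)
    # Set the flag as a plain attribute after construction.
    cfg.enable_routing_replay = routing_replay
    return cfg


user_overrides = {"num_layers": 4, "enable_routing_replay": True}
cfg = build_config(user_overrides, routing_replay=True)
```

Using pop with a default of None keeps the call safe whether or not the key is present, so the same code path serves configs that never mention the flag.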
What does this PR do?
Fixes the R3 router replay crash when used with DeepSeek models via vanilla mbridge (MLATransformerConfig).

_build_tf_config() passes enable_routing_replay=True through bridge.set_extra_args(**override_transformer_config), which calls MLATransformerConfig.__init__(). But MLATransformerConfig is a Python dataclass: it generates its own __init__ and doesn't call the patched TransformerConfig.__init__ that accepts enable_routing_replay. Result:

TypeError: MLATransformerConfig.__init__() got an unexpected keyword argument 'enable_routing_replay'

This affects any mbridge model config that is a dataclass subclass of TransformerConfig (DeepSeek, potentially others).

Fix: Remove enable_routing_replay from the override_transformer_config dict. Set it directly as an attribute on tf_config/provider after construction.

Related: #4567 (similar fix for Qwen3VLTransformerConfig, still open). This PR is more generic: it works for any dataclass config subclass.

Checklist Before Starting

[{modules}] {type}: {description}

Test
Tested on 10B DeepSeek MoE model (megatron engine + sglang rollout, R3 mode, 32x H100 GPUs, PP=2 TP=2 EP=2):
- No longer raised: TypeError: MLATransformerConfig.__init__() got an unexpected keyword argument 'enable_routing_replay'
- rollout_corr/log_ppl_diff = 0.0003 (sglang and megatron log-probs match)
- rollout_corr/kl = 0.0003
- actor/grad_norm = 0.44 (stable, normal range)

API and Usage Example
No API changes. The existing R3 router replay config works as before.
Design & Code Changes
Single file change in verl/workers/engine/megatron/transformer_impl.py:

- Remove enable_routing_replay from the override_transformer_config dict (it was passed to bridge.set_extra_args(), which forwards to the dataclass __init__)
- Set tf_config.enable_routing_replay = True after config creation (vanilla mbridge path)
- Set provider.enable_routing_replay = True after overrides (non-vanilla mbridge path)

Checklist Before Submitting

- tests/special_e2e/run_ppo_trainer_megatron.sh with ROUTING_REPLAY_MODE=R3 covers this path when run on MoE models.