Commit b90edef

[megatron] fix: enable_routing_replay fails with MLATransformerConfig… (#5884)
### What does this PR do?

Fixes an R3 router replay crash when DeepSeek models are used via vanilla mbridge (`MLATransformerConfig`).

`_build_tf_config()` passes `enable_routing_replay=True` through `bridge.set_extra_args(**override_transformer_config)`, which calls `MLATransformerConfig.__init__()`. But `MLATransformerConfig` is a Python dataclass: it generates its own `__init__` and never calls the patched `TransformerConfig.__init__` that accepts `enable_routing_replay`. Result:

```
TypeError: MLATransformerConfig.__init__() got an unexpected keyword argument 'enable_routing_replay'
```

This affects any mbridge model config that is a dataclass subclass of `TransformerConfig` (DeepSeek, potentially others).

**Fix:** Remove `enable_routing_replay` from the `override_transformer_config` dict and instead set it directly as an attribute on `tf_config` / `provider` after construction.

Related: #4567 (a similar fix for `Qwen3VLTransformerConfig`, still open). This PR is more generic and works for any dataclass config subclass.

### Checklist Before Starting

- [x] Search for similar PRs: [enable_routing_replay MLATransformerConfig](https://github.com/verl-project/verl/pulls?q=is%3Apr+enable_routing_replay+MLATransformerConfig); no existing PR covers this specific bug.
- [x] Format the PR title as `[{modules}] {type}: {description}`.

### Test

Tested on a 10B DeepSeek MoE model (megatron engine + sglang rollout, R3 mode, 32x H100 GPUs, PP=2 TP=2 EP=2):

- Without the fix: `TypeError: MLATransformerConfig.__init__() got an unexpected keyword argument 'enable_routing_replay'`
- With the fix: R3 router replay works correctly:
  - `rollout_corr/log_ppl_diff = 0.0003` (sglang and megatron log-probs match)
  - `rollout_corr/kl = 0.0003`
  - `actor/grad_norm = 0.44` (stable, in the normal range)
  - Training runs for 10+ steps without issues

### API and Usage Example

No API changes. The existing R3 router replay config works as before:

```bash
actor_rollout_ref.actor.megatron.router_replay.mode=R3
actor_rollout_ref.rollout.enable_rollout_routing_replay=True
```

### Design & Code Changes

Single-file change in `verl/workers/engine/megatron/transformer_impl.py`:

1. **Remove** `enable_routing_replay` from the `override_transformer_config` dict (it was passed to `bridge.set_extra_args()`, which forwards it to the dataclass `__init__`).
2. **Add** `tf_config.enable_routing_replay = True` after config creation (vanilla mbridge path).
3. **Add** `provider.enable_routing_replay = True` after the overrides are applied (non-vanilla mbridge path).

### Checklist Before Submitting

- [x] Read the Contribute Guide.
- [ ] Apply pre-commit checks.
- [ ] Add / Update the documentation. N/A: no doc changes needed.
- [ ] Add unit or end-to-end test(s). Not feasible: this requires mbridge, a DeepSeek model, and a multi-GPU setup. The existing `tests/special_e2e/run_ppo_trainer_megatron.sh` with `ROUTING_REPLAY_MODE=R3` covers this path when run on MoE models.
- [ ] Once your PR is ready for CI, send a message in the `ci-request` channel.
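The dataclass failure mode described in this PR can be reproduced with a minimal, self-contained sketch. The two classes below are hypothetical stand-ins for the real Megatron configs, and the monkey-patch mimics how the parent `__init__` gets extended to accept an extra kwarg:

```python
from dataclasses import dataclass


@dataclass
class TransformerConfig:  # stand-in, not the real Megatron class
    num_layers: int = 1


# Patch the parent __init__ to accept an extra kwarg, mimicking how
# enable_routing_replay is bolted onto TransformerConfig.
_orig_init = TransformerConfig.__init__

def _patched_init(self, *args, enable_routing_replay=False, **kwargs):
    _orig_init(self, *args, **kwargs)
    self.enable_routing_replay = enable_routing_replay

TransformerConfig.__init__ = _patched_init


@dataclass
class MLATransformerConfig(TransformerConfig):  # stand-in dataclass subclass
    q_lora_rank: int = 0


# The patched parent accepts the kwarg:
cfg = TransformerConfig(enable_routing_replay=True)
print(cfg.enable_routing_replay)  # True

# But @dataclass generated a fresh __init__ on the subclass at class-creation
# time, so the patched parent __init__ is shadowed and never called:
try:
    MLATransformerConfig(enable_routing_replay=True)
except TypeError as e:
    print(e)  # unexpected keyword argument 'enable_routing_replay'
```

This is why the bug hits any dataclass subclass of `TransformerConfig`, regardless of which model family defines it.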
1 parent 74dc16c commit b90edef

1 file changed (+10, -2 lines)


verl/workers/engine/megatron/transformer_impl.py

Lines changed: 10 additions & 2 deletions
```diff
@@ -163,8 +163,6 @@ def _build_tf_config(self):
         self.dtype = PrecisionType.to_dtype(self.param_dtype)
 
         override_transformer_config = mapping_string_to_attn_backend({**self.engine_config.override_transformer_config})
-        if self.enable_routing_replay:
-            override_transformer_config["enable_routing_replay"] = True
 
         self.provider = None
         self.vanilla_bridge = self.engine_config.vanilla_mbridge
@@ -229,6 +227,9 @@ def _build_tf_config(self):
             for key, value in override_transformer_config.items():
                 setattr(provider, key, value)
 
+            if self.enable_routing_replay:
+                provider.enable_routing_replay = True
+
             provider.finalize()
             self.provider = provider
             tf_config = None  # Will be set after model creation
@@ -237,6 +238,13 @@ def _build_tf_config(self):
             if not self.bridge:
                 self.weight_converter = get_mcore_weight_converter(self.model_config.hf_config, self.dtype)
 
+        # Set enable_routing_replay directly on tf_config instead of passing through
+        # override_transformer_config, because dataclass subclasses like MLATransformerConfig
+        # generate their own __init__ and don't inherit the patched TransformerConfig.__init__
+        # that accepts this kwarg.
+        if self.enable_routing_replay and tf_config is not None:
+            tf_config.enable_routing_replay = True
+
         if torch.distributed.get_rank() == 0:
             if tf_config is not None:
                 print(f"TF config: {tf_config}")
```
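The set-attribute-after-construction pattern the fix relies on can be shown in isolation. This is a sketch with a hypothetical minimal `MLATransformerConfig`, not the real Megatron class:

```python
from dataclasses import dataclass


@dataclass
class MLATransformerConfig:  # hypothetical stand-in for the real config class
    num_layers: int = 1


# Passing enable_routing_replay through the generated __init__ would raise
# TypeError; assigning it as a plain attribute afterwards side-steps the
# generated __init__ entirely.
tf_config = MLATransformerConfig(num_layers=4)

enable_routing_replay = True
if enable_routing_replay and tf_config is not None:
    tf_config.enable_routing_replay = True

print(tf_config.enable_routing_replay)  # True
```

This works because plain dataclasses (those not declared with `slots=True`) accept arbitrary attribute assignment, so the approach is valid for any config subclass regardless of its `__init__` signature.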
