diff --git a/verl/workers/engine/megatron/transformer_impl.py b/verl/workers/engine/megatron/transformer_impl.py
index 4d22879742a..31c4c314b6d 100644
--- a/verl/workers/engine/megatron/transformer_impl.py
+++ b/verl/workers/engine/megatron/transformer_impl.py
@@ -163,8 +163,6 @@ def _build_tf_config(self):
         self.dtype = PrecisionType.to_dtype(self.param_dtype)

         override_transformer_config = mapping_string_to_attn_backend({**self.engine_config.override_transformer_config})
-        if self.enable_routing_replay:
-            override_transformer_config["enable_routing_replay"] = True

         self.provider = None
         self.vanilla_bridge = self.engine_config.vanilla_mbridge
@@ -229,6 +227,9 @@ def _build_tf_config(self):
                 for key, value in override_transformer_config.items():
                     setattr(provider, key, value)

+            if self.enable_routing_replay:
+                provider.enable_routing_replay = True
+
             provider.finalize()
             self.provider = provider
             tf_config = None  # Will be set after model creation
@@ -237,6 +238,13 @@ def _build_tf_config(self):
             if not self.bridge:
                 self.weight_converter = get_mcore_weight_converter(self.model_config.hf_config, self.dtype)

+        # Set enable_routing_replay directly on tf_config instead of passing through
+        # override_transformer_config, because dataclass subclasses like MLATransformerConfig
+        # generate their own __init__ and don't inherit the patched TransformerConfig.__init__
+        # that accepts this kwarg.
+        if self.enable_routing_replay and tf_config is not None:
+            tf_config.enable_routing_replay = True
+
         if torch.distributed.get_rank() == 0:
             if tf_config is not None:
                 print(f"TF config: {tf_config}")