diff --git a/verl/workers/actor/megatron_actor.py b/verl/workers/actor/megatron_actor.py index 9764178134f..dfc9620f488 100644 --- a/verl/workers/actor/megatron_actor.py +++ b/verl/workers/actor/megatron_actor.py @@ -160,11 +160,10 @@ def __init__( from verl.models.mcore.mtp_patch import patch_postprocess for model in self.actor_module: + patch_postprocess(model) if self.mtp_config: from verl.models.mcore.mtp_patch import patch_mtp_layer_get_embeddings - patch_postprocess(model) - if self.mtp_config.detach_encoder: patch_mtp_layer_get_embeddings(model) diff --git a/verl/workers/engine/megatron/transformer_impl.py b/verl/workers/engine/megatron/transformer_impl.py index 784faf084c0..aa26853521c 100644 --- a/verl/workers/engine/megatron/transformer_impl.py +++ b/verl/workers/engine/megatron/transformer_impl.py @@ -303,6 +303,10 @@ def _build_megatron_module(self): def _maybe_enable_fused_kernels(self): if not self.engine_config.use_fused_kernels: + from verl.models.mcore.mtp_patch import patch_postprocess + + for model in self.module: + patch_postprocess(model) return if self.is_value_model or self.model_config.mtp.enable: