@@ -169,8 +169,8 @@ def __init__(
169169 model_format = "" ,
170170 )
171171
172- def forward (self , hidden_states : paddle .Tensor ):
173- out = self .experts (hidden_states , self .gate )
172+ def forward (self , hidden_states : paddle .Tensor , forward_meta : ForwardMeta ):
173+ out = self .experts (hidden_states , self .gate , forward_meta )
174174 return out
175175
176176 def load_state_dict (self , state_dict ):
@@ -269,7 +269,7 @@ def load_state_dict(self, state_dict):
269269 if self .num_shared_experts > 0 :
270270 self .shared_experts .load_state_dict (state_dict )
271271
272- def forward (self , hidden_states : paddle .Tensor , vl_moe_meta : VLMoEMeta ):
272+ def forward (self , hidden_states : paddle .Tensor , forward_meta : ForwardMeta , vl_moe_meta : VLMoEMeta ):
273273 if self .num_shared_experts > 0 :
274274 shared_experts_out = self .shared_experts (hidden_states )
275275 hidden_states , text_input , image_input = text_image_gather_scatter (
@@ -281,8 +281,8 @@ def forward(self, hidden_states: paddle.Tensor, vl_moe_meta: VLMoEMeta):
281281 vl_moe_meta .image_index ,
282282 True ,
283283 )
284- text_out = self .text_fused_moe (text_input )
285- image_out = self .image_fused_moe (image_input )
284+ text_out = self .text_fused_moe (text_input , forward_meta )
285+ image_out = self .image_fused_moe (image_input , forward_meta )
286286 hidden_states , _ , _ = text_image_gather_scatter (
287287 hidden_states ,
288288 text_out ,
@@ -388,9 +388,9 @@ def forward(
388388 hidden_states , residual = self .post_attention_layernorm (hidden_states , residual )
389389
390390 if isinstance (self .mlp , Ernie4_5_VLMoE ):
391- hidden_states = self .mlp (hidden_states , vl_moe_meta )
391+ hidden_states = self .mlp (hidden_states , forward_meta , vl_moe_meta )
392392 else :
393- hidden_states = self .mlp (hidden_states )
393+ hidden_states = self .mlp (hidden_states , forward_meta )
394394
395395 return hidden_states , residual
396396
@@ -757,8 +757,8 @@ def empty_input_forward(self):
757757 self .fd_config .model_config .moe_layer_start_index ,
758758 self .fd_config .model_config .num_hidden_layers ,
759759 ):
760- self .ernie .layers [i ].mlp .text_fused_moe (fake_hidden_states )
761- self .ernie .layers [i ].mlp .image_fused_moe (fake_hidden_states )
760+ self .ernie .layers [i ].mlp .text_fused_moe (fake_hidden_states , self . forward_meta )
761+ self .ernie .layers [i ].mlp .image_fused_moe (fake_hidden_states , self . forward_meta )
762762
763763 def get_input_embeddings (
764764 self ,
0 commit comments