fastdeploy/model_executor/layers/moe/moe.py (3 changes: 2 additions & 1 deletion)

@@ -21,6 +21,7 @@
 from paddleformers.utils.log import logger

 from fastdeploy import envs
+from fastdeploy.model_executor.forward_meta import ForwardMeta
 from fastdeploy.model_executor.layers.utils import get_tensor
 from fastdeploy.model_executor.utils import h2d_copy, slice_fn
 from fastdeploy.platforms import current_platform
@@ -593,7 +594,7 @@ def forward_split_allgather(self, x: paddle.Tensor, gate: nn.Layer):
         out = multi_outs[:token_num, :]
         return out

-    def forward(self, x: paddle.Tensor, gate: nn.Layer):
+    def forward(self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta):
         """
         Defines the forward computation of the moe layer.
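The remaining files in this diff thread the new argument through every call site. As a rough illustration of the pattern (not FastDeploy's real classes), here is a minimal, runnable sketch in which ForwardMeta and the MoE layer are simplified stand-ins:

from dataclasses import dataclass
from typing import Callable, List

@dataclass
class ForwardMeta:
    # Hypothetical stand-in: the real FastDeploy ForwardMeta bundles per-step
    # metadata (batch layout, attention backend state, cudagraph flags, ...).
    step_use_cudagraph: bool = False

class ToyFusedMoE:
    """Toy MoE layer: routes each token to one expert chosen by a gate callable."""

    def __init__(self, experts: List[Callable[[float], float]]):
        self.experts = experts

    def forward(self, x: List[float], gate: Callable[[float], int],
                forward_meta: ForwardMeta) -> List[float]:
        # forward_meta is now part of the signature; real kernels read
        # step/batch information from it, this toy routing just carries it.
        return [self.experts[gate(v)](v) for v in x]

moe = ToyFusedMoE(experts=[lambda v: v * 2, lambda v: v + 1])
print(moe.forward([0.5, 3.0],
                  gate=lambda v: 0 if v < 1 else 1,
                  forward_meta=ForwardMeta()))  # [1.0, 4.0]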
fastdeploy/model_executor/models/deepseek_v3.py (10 changes: 5 additions & 5 deletions)

@@ -104,7 +104,7 @@ def load_state_dict(self, state_dict):
         self.up_gate_proj.load_state_dict(state_dict)
         self.down_proj.load_state_dict(state_dict)

-    def forward(self, x):
+    def forward(self, x, forward_meta):
         """ """
         gate_up_out = self.up_gate_proj(x)
         act_out = self.act_fn(gate_up_out)
@@ -187,10 +187,10 @@ def load_state_dict(self, state_dict):
         self.experts.load_state_dict(state_dict)
         self.shared_experts.load_state_dict(state_dict)

-    def forward(self, hidden_states: paddle.Tensor):
+    def forward(self, hidden_states: paddle.Tensor, forward_meta: ForwardMeta):
         """ """
-        shared_experts_out = self.shared_experts(hidden_states)
-        moe_out = self.experts(hidden_states, self.gate)
+        shared_experts_out = self.shared_experts(hidden_states, forward_meta)
+        moe_out = self.experts(hidden_states, self.gate, forward_meta)
         moe_out = moe_out + shared_experts_out
         # We do to TP all reduce after the sum of experts.
         if self.tp_size > 1:
@@ -514,7 +514,7 @@ def forward(
         hidden_states = self.self_attn(forward_meta, hidden_states, position_ids, mask_encoder_batch)

         hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
-        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.mlp(hidden_states, forward_meta)
         return hidden_states, residual
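In the decoder layer the attention block already receives forward_meta; the change simply extends it to the MLP branch. A toy sketch of that flow, assuming illustrative class and attribute names:

class ToyDecoderLayer:
    def __init__(self, attn, mlp):
        self.attn, self.mlp = attn, mlp

    def forward(self, hidden_states, forward_meta):
        # forward_meta was already consumed by attention ...
        hidden_states = self.attn(hidden_states, forward_meta)
        # ... and is now passed along to the (MoE) MLP as well.
        return self.mlp(hidden_states, forward_meta)

layer = ToyDecoderLayer(
    attn=lambda h, meta: [v + 1 for v in h],
    mlp=lambda h, meta: [v * 2 for v in h],
)
print(layer.forward([1.0, 2.0], forward_meta={"step": 0}))  # [4.0, 6.0]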
fastdeploy/model_executor/models/ernie4_5_moe.py (23 changes: 17 additions & 6 deletions)

@@ -95,7 +95,7 @@ def load_state_dict(self, state_dict):
         self.up_gate_proj.load_state_dict(state_dict)
         self.down_proj.load_state_dict(state_dict)

-    def forward(self, hidden_states: paddle.Tensor):
+    def forward(self, hidden_states: paddle.Tensor, forward_meta: ForwardMeta):
         gate_up_out = self.up_gate_proj(hidden_states)
         act_out = self.act_fn(gate_up_out)
         down_out = self.down_proj(act_out)
@@ -213,10 +213,18 @@ def load_state_dict(self, state_dict):
     def update_state_dict(self, state_dict):
         self.fused_moe.load_state_dict(state_dict, True)

-    def forward(self, hidden_states: paddle.Tensor):
-        out = self.experts(hidden_states, self.gate)
+    def forward(
+        self,
+        hidden_states: paddle.Tensor,
+        forward_meta: ForwardMeta,
+    ):
+        out = self.experts(
+            x=hidden_states,
+            gate=self.gate,
+            forward_meta=forward_meta,
+        )
         if self.num_shared_experts > 0:
-            s_x = self.shared_experts(hidden_states)
+            s_x = self.shared_experts(hidden_states, forward_meta)
             out = out + s_x
         return out

@@ -344,7 +352,10 @@ def forward(
             residual,
         )

-        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.mlp(
+            hidden_states=hidden_states,
+            forward_meta=forward_meta,
+        )

         return hidden_states, residual

@@ -621,7 +632,7 @@ def empty_input_forward(self):
             self.fd_config.model_config.moe_layer_start_index,
             self.fd_config.model_config.num_hidden_layers,
         ):
-            self.ernie.layers[i].mlp.experts(fake_hidden_states, self.ernie.layers[i].mlp.gate)
+            self.ernie.layers[i].mlp.experts(fake_hidden_states, self.ernie.layers[i].mlp.gate, self.forward_meta)

     def forward(
         self,
fastdeploy/model_executor/models/ernie4_5_mtp.py (4 changes: 2 additions & 2 deletions)

@@ -436,7 +436,7 @@ def compute_logits(self, hidden_states: paddle.Tensor):

         return logits

-    def empty_input_forward(self):
+    def empty_input_forward(self, forward_meta):
         """
         empty_input_forward
         """
@@ -448,7 +448,7 @@ def empty_input_forward(self):
             self.fd_config.model_config.moe_layer_start_index,
             self.fd_config.model_config.num_hidden_layers,
         ):
-            self.ernie.layers[i].mlp.fused_moe(fake_hidden_states)
+            self.ernie.layers[i].mlp.fused_moe(hidden_states=fake_hidden_states, forward_meta=forward_meta)

     def forward(
         self,
fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py (20 changes: 10 additions & 10 deletions)

@@ -169,8 +169,8 @@ def __init__(
             model_format="",
         )

-    def forward(self, hidden_states: paddle.Tensor):
-        out = self.experts(hidden_states, self.gate)
+    def forward(self, hidden_states: paddle.Tensor, forward_meta: ForwardMeta):
+        out = self.experts(hidden_states, self.gate, forward_meta)
         return out

     def load_state_dict(self, state_dict):
@@ -269,9 +269,9 @@ def load_state_dict(self, state_dict):
         if self.num_shared_experts > 0:
             self.shared_experts.load_state_dict(state_dict)

-    def forward(self, hidden_states: paddle.Tensor, vl_moe_meta: VLMoEMeta):
+    def forward(self, hidden_states: paddle.Tensor, forward_meta: ForwardMeta, vl_moe_meta: VLMoEMeta):
         if self.num_shared_experts > 0:
-            shared_experts_out = self.shared_experts(hidden_states)
+            shared_experts_out = self.shared_experts(hidden_states, forward_meta)
         hidden_states, text_input, image_input = text_image_gather_scatter(
             hidden_states,
             vl_moe_meta.text_input,
@@ -281,8 +281,8 @@ def forward(self, hidden_states: paddle.Tensor, vl_moe_meta: VLMoEMeta):
             vl_moe_meta.image_index,
             True,
         )
-        text_out = self.text_fused_moe(text_input)
-        image_out = self.image_fused_moe(image_input)
+        text_out = self.text_fused_moe(text_input, forward_meta)
+        image_out = self.image_fused_moe(image_input, forward_meta)
         hidden_states, _, _ = text_image_gather_scatter(
             hidden_states,
             text_out,
@@ -388,9 +388,9 @@ def forward(
         hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)

         if isinstance(self.mlp, Ernie4_5_VLMoE):
-            hidden_states = self.mlp(hidden_states, vl_moe_meta)
+            hidden_states = self.mlp(hidden_states, forward_meta, vl_moe_meta)
         else:
-            hidden_states = self.mlp(hidden_states)
+            hidden_states = self.mlp(hidden_states, forward_meta)

         return hidden_states, residual
@@ -757,8 +757,8 @@ def empty_input_forward(self):
             self.fd_config.model_config.moe_layer_start_index,
             self.fd_config.model_config.num_hidden_layers,
         ):
-            self.ernie.layers[i].mlp.text_fused_moe(fake_hidden_states)
-            self.ernie.layers[i].mlp.image_fused_moe(fake_hidden_states)
+            self.ernie.layers[i].mlp.text_fused_moe(fake_hidden_states, self.forward_meta)
+            self.ernie.layers[i].mlp.image_fused_moe(fake_hidden_states, self.forward_meta)

     def get_input_embeddings(
         self,
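In the multimodal model both the text branch and the image branch now receive the same forward_meta. A toy sketch of the split-route-merge flow; the list-based gather/scatter and lambda experts below are simplified stand-ins for text_image_gather_scatter and the fused MoE modules:

def split_route_merge(tokens, is_image, text_moe, image_moe, forward_meta):
    # Split tokens into text/image groups (stand-in for text_image_gather_scatter).
    text_in = [t for t, img in zip(tokens, is_image) if not img]
    image_in = [t for t, img in zip(tokens, is_image) if img]

    # Both branches see the same per-step metadata.
    text_out = iter(text_moe(text_in, forward_meta))
    image_out = iter(image_moe(image_in, forward_meta))

    # Scatter results back into the original token order.
    return [next(image_out) if img else next(text_out) for img in is_image]

print(split_route_merge(
    tokens=[1.0, 2.0, 3.0],
    is_image=[False, True, False],
    text_moe=lambda xs, meta: [x + 10 for x in xs],
    image_moe=lambda xs, meta: [x * 100 for x in xs],
    forward_meta={"step": 0},  # placeholder for a ForwardMeta instance
))  # [11.0, 200.0, 13.0]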
fastdeploy/model_executor/models/glm4_moe.py (11 changes: 7 additions & 4 deletions)

@@ -85,7 +85,7 @@ def __init__(
             act_method=fd_config.model_config.hidden_act,
         )

-    def forward(self, x):
+    def forward(self, x, forward_meta):
         """ """
         gate_up_out = self.up_gate_proj(x)
         act_out = self.act_fn(gate_up_out)
@@ -161,9 +161,9 @@ def __init__(
             reduce_results=False,
         )

-    def forward(self, x):
+    def forward(self, x, forward_meta):
         shared_experts_out = self.shared_experts(x)

Review comment (Copilot AI, Nov 20, 2025): Missing forward_meta parameter in the shared_experts call. shared_experts is an instance of Glm4MoeMLP, which now requires forward_meta as the second parameter (line 88). The call should be: shared_experts_out = self.shared_experts(x, forward_meta)

Suggested change:
-        shared_experts_out = self.shared_experts(x)
+        shared_experts_out = self.shared_experts(x, forward_meta)

-        out = self.experts(x, self.gate)
+        out = self.experts(x, self.gate, forward_meta)
         out = out + shared_experts_out
         # We do to TP all reduce after the sum of experts.
         if self.tensor_parallel_size > 1:
@@ -306,7 +306,10 @@ def forward(
         # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)

-        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.mlp(
+            hidden_states,
+            forward_meta,
+        )

         return hidden_states, residual
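The review comment above points at the main hazard of this kind of refactor: any call site left on the old arity fails at runtime once the callee gains a required positional parameter. A minimal illustration with toy classes (not the FastDeploy ones):

class ToyMLP:
    def forward(self, x, forward_meta):  # new signature: forward_meta is required
        return [v * 2 for v in x]

mlp = ToyMLP()
meta = {"step": 0}  # placeholder for a ForwardMeta instance

try:
    mlp.forward([1.0, 2.0])  # stale call site still using the old signature
except TypeError as exc:
    print(exc)  # missing 1 required positional argument: 'forward_meta'

print(mlp.forward([1.0, 2.0], meta))  # updated call site works: [2.0, 4.0]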
fastdeploy/model_executor/models/gpt_oss.py (6 changes: 3 additions & 3 deletions)

@@ -124,8 +124,8 @@ def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str = ""):
             model_format="",
         )

-    def forward(self, hidden_states: paddle.Tensor):
-        expert_output = self.experts(hidden_states, self.router)
+    def forward(self, hidden_states: paddle.Tensor, forward_meta: ForwardMeta):
+        expert_output = self.experts(hidden_states, self.router, forward_meta)
         return expert_output

@@ -173,7 +173,7 @@ def forward(
         # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)

-        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.mlp(hidden_states, forward_meta)
         return hidden_states, residual
fastdeploy/model_executor/models/qwen2.py (4 changes: 2 additions & 2 deletions)

@@ -90,7 +90,7 @@ def load_state_dict(self, state_dict):
         self.up_gate_proj.load_state_dict(state_dict)
         self.down_proj.load_state_dict(state_dict)

-    def forward(self, x):
+    def forward(self, x, forward_meta):
         """ """
         gate_up_out = self.up_gate_proj(x)
         act_out = self.act_fn(gate_up_out)
@@ -206,7 +206,7 @@ def forward(
         # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)

-        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.mlp(hidden_states, forward_meta)

         return hidden_states, residual
fastdeploy/model_executor/models/qwen3moe.py (10 changes: 5 additions & 5 deletions)

@@ -79,8 +79,8 @@ def __init__(
             weight_dtype="float32",
         )

-    def forward(self, x):
-        return self.experts(x, self.gate)
+    def forward(self, x, forward_meta):
+        return self.experts(x, self.gate, forward_meta)

     def load_state_dict(self, state_dict):
         """ """
@@ -127,7 +127,7 @@ def load_state_dict(self, state_dict):
         self.up_gate_proj.load_state_dict(state_dict)
         self.down_proj.load_state_dict(state_dict)

-    def forward(self, x):
+    def forward(self, x, forward_meta):
         """ """
         gate_up_out = self.up_gate_proj(x)
         act_out = self.act_fn(gate_up_out)
@@ -206,7 +206,7 @@ def forward(
         # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)

-        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.mlp(hidden_states, forward_meta)

         return hidden_states, residual
@@ -430,7 +430,7 @@ def empty_input_forward(self):
             self.fd_config.model_config.moe_layer_start_index,
             self.fd_config.model_config.num_hidden_layers,
         ):
-            self.model.layers[i].mlp.experts(fake_hidden_states, self.model.layers[i].mlp.gate)
+            self.model.layers[i].mlp.experts(fake_hidden_states, self.model.layers[i].mlp.gate, self.forward_meta)

     def forward(
         self,
fastdeploy/spec_decode/mtp.py (2 changes: 1 addition & 1 deletion)

@@ -904,7 +904,7 @@ def _propose(self, step_use_cudagraph: bool = False):
             self._get_self_hidden_states(hidden_states)
         else:
             if hasattr(self.model, "empty_input_forward"):
-                self.model.empty_input_forward()
+                self.model.empty_input_forward(self.forward_meta)

     def _get_self_hidden_states(self, hidden_states):
         target_hidden_states = eagle_get_self_hidden_states(
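On the speculative-decoding side the proposer now hands its own forward_meta to the drafter's warm-up path. A reduced sketch of that call flow, with hypothetical names mirroring the diff:

class ToyDraftModel:
    def empty_input_forward(self, forward_meta):
        # Run the MoE layers on a fake batch (e.g. to keep parallel collectives
        # in step); forward_meta supplies the metadata those layers now expect.
        print("warm-up forward with", forward_meta)

class ToyProposer:
    def __init__(self):
        self.model = ToyDraftModel()
        self.forward_meta = {"step": 0}  # placeholder for the real ForwardMeta

    def _propose(self, has_hidden_states: bool):
        if has_hidden_states:
            pass  # a normal draft step would run here
        elif hasattr(self.model, "empty_input_forward"):
            self.model.empty_input_forward(self.forward_meta)

ToyProposer()._propose(has_hidden_states=False)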