Conversation

@Wanglongzhi2001 (Collaborator) commented Nov 20, 2025

Motivation

In some scenarios, such as chunked MoE, we need to update MoE state during the forward pass. It is reasonable to store this state on forward_meta, so we need to add a forward_meta parameter to FusedMoE's forward function.
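For illustration, the resulting call shape could look like the sketch below. It assumes the existing FusedMoE.forward(x, gate) arguments, a ForwardMeta class as used elsewhere in FastDeploy, and an assumed argument order for quant_method.apply; the body is a simplification, not the actual implementation:

    class FusedMoE(nn.Layer):
        def forward(self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta):
            # forward_meta now rides along with every MoE forward call, so the
            # backend can read per-step state (e.g. forward_meta.moe_phase) and
            # chunked-MoE bookkeeping can be stored on it between chunks.
            return self.quant_method.apply(self, x, gate, forward_meta)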

Modifications

Add forward_meta to MoE models' forward functions.

Usage or Command

No change

Accuracy Tests

No change.

Checklist

  • Add at least one tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code and run pre-commit before committing.
  • Add unit tests. If there are no unit tests, please explain the reason in this PR.
  • Provide accuracy results.
  • If the current PR is submitted to the release branch, make sure it has also been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

Copilot AI review requested due to automatic review settings November 20, 2025 05:58
paddle-bot bot commented Nov 20, 2025

Thanks for your contribution!

Copilot AI (Contributor) left a comment

Pull Request Overview

This PR adds the forward_meta parameter to the forward functions of MoE (Mixture of Experts) models so that MoE phase information is available during forward computation. The change is needed because forward_meta.moe_phase.phase is used in the fused MoE backend to decide between the prefill and decode execution paths.
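As a rough illustration of that check (MoEMethodBase and forward_meta.moe_phase.phase are named in this review; the phase value and the prefill/decode helpers below are hypothetical placeholders, not the actual backend code):

    class MoEMethodBase:
        def apply(self, layer, x, gate, forward_meta):
            # Pick the execution path based on the MoE phase carried by forward_meta.
            if forward_meta.moe_phase.phase == "prefill":  # phase value assumed for illustration
                return self.apply_prefill(layer, x, gate)  # hypothetical prefill path
            return self.apply_decode(layer, x, gate)  # hypothetical decode path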

Key Changes:

  • Updated the core MoE layer to accept and propagate the forward_meta parameter through the computation pipeline
  • Modified all MoE and MLP forward methods across multiple model architectures to include the forward_meta parameter
  • Updated the speculative decoding module to pass forward_meta to empty_input_forward calls

Reviewed Changes

Copilot reviewed 11 out of 11 changed files in this pull request and generated 6 comments.

Summary per file:

  • fastdeploy/model_executor/layers/moe/moe.py: Added forward_meta parameter to FusedMoE forward method and propagated it to quant_method.apply calls
  • fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py: Added forward_meta parameter to MoEMethodBase.apply and uses it to check moe_phase
  • fastdeploy/model_executor/models/qwen3moe.py: Updated Qwen3MoeBlock and Qwen3MLP forward signatures to include forward_meta
  • fastdeploy/model_executor/models/qwen2.py: Updated Qwen2MLP forward signature to include forward_meta
  • fastdeploy/model_executor/models/gpt_oss.py: Updated GptOssMoe forward signature to include and propagate forward_meta
  • fastdeploy/model_executor/models/glm4_moe.py: Updated Glm4MoeMLP and Glm4Moe forward signatures to include forward_meta
  • fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py: Updated Ernie4_5_VLMoE and related classes to include forward_meta; updated empty_input_forward calls
  • fastdeploy/model_executor/models/ernie4_5_mtp.py: Updated empty_input_forward signature to accept forward_meta parameter
  • fastdeploy/model_executor/models/ernie4_5_moe.py: Updated Ernie4_5_MLP and Ernie4_5_MoE forward signatures to include forward_meta; updated empty_input_forward calls
  • fastdeploy/model_executor/models/deepseek_v3.py: Updated DeepSeekV3MLP and DeepSeekV3MoE forward signatures to include forward_meta
  • fastdeploy/spec_decode/mtp.py: Updated empty_input_forward call to pass forward_meta parameter
Comments suppressed due to low confidence (1)

fastdeploy/model_executor/layers/moe/moe.py:615

    out = self.forward_split_allgather(x, gate, forward_meta)

Diff context for the comment below:

    def forward(self, hidden_states: paddle.Tensor, vl_moe_meta: VLMoEMeta):
    def forward(self, hidden_states: paddle.Tensor, forward_meta: ForwardMeta, vl_moe_meta: VLMoEMeta):
    if self.num_shared_experts > 0:
    shared_experts_out = self.shared_experts(hidden_states)
Copilot AI Nov 20, 2025

Missing forward_meta parameter in the shared_experts call. The shared_experts is an instance of Ernie4_5_VLMLP (which inherits from Ernie4_5_MLP) and now requires forward_meta as the second parameter. The call should be: shared_experts_out = self.shared_experts(hidden_states, forward_meta)

Suggested change
shared_experts_out = self.shared_experts(hidden_states)
shared_experts_out = self.shared_experts(hidden_states, forward_meta)

Diff context (moe.py, around line 615):

    and token_num >= self.tp_size
    ):
    out = self.forward_split_allgather(x, gate)
    out = self.forward_split_allgather(x, gate, forward_meta)
Copilot AI Nov 20, 2025

The forward_split_allgather method is called with the forward_meta parameter (line 615), but its signature at line 576 does not include it. The signature should be updated to def forward_split_allgather(self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta): and the parameter should be passed to quant_method.apply at line 591.
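Roughly, the updated helper would then look like the fragment below (a sketch only: the tensor-parallel split and allgather steps are elided, and the argument order of quant_method.apply is assumed):

    def forward_split_allgather(self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta):
        # split the incoming tokens across tensor-parallel ranks (details elided)
        ...
        # forward_meta is now threaded through to the backend apply call
        out = self.quant_method.apply(self, x, gate, forward_meta)
        # allgather the partial outputs back to the full batch (details elided)
        ...
        return out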

Comment on lines 356 to 357
forward_meta=forward_meta,
hidden_states=hidden_states,
Copilot AI Nov 20, 2025

Inconsistent parameter ordering: The method signature has forward(self, hidden_states: paddle.Tensor, forward_meta: ForwardMeta) (line 98), but the call uses forward_meta=forward_meta, hidden_states=hidden_states (lines 356-357). While this works with keyword arguments, it's inconsistent with the positional order. Consider using positional order: self.mlp(hidden_states, forward_meta) for better consistency with the method signature.

Suggested change
forward_meta=forward_meta,
hidden_states=hidden_states,
hidden_states,
forward_meta,

Diff context for the comment below:

    forward_meta=forward_meta,
    )
    if self.num_shared_experts > 0:
    s_x = self.shared_experts(hidden_states)
Copilot AI Nov 20, 2025

Missing forward_meta parameter in the shared_experts call. The shared_experts is an instance of Ernie4_5_MLP which now requires forward_meta as the second parameter (line 98). The call should be: s_x = self.shared_experts(hidden_states, forward_meta)

Suggested change
s_x = self.shared_experts(hidden_states)
s_x = self.shared_experts(hidden_states, forward_meta)


Diff context for the comment below:

    def forward(self, x):
    def forward(self, x, forward_meta):
    shared_experts_out = self.shared_experts(x)
Copilot AI Nov 20, 2025

Missing forward_meta parameter in the shared_experts call. The shared_experts is an instance of Glm4MoeMLP which now requires forward_meta as the second parameter (line 88). The call should be: shared_experts_out = self.shared_experts(x, forward_meta)

Suggested change
shared_experts_out = self.shared_experts(x)
shared_experts_out = self.shared_experts(x, forward_meta)

Comment on lines 192 to 193
shared_experts_out = self.shared_experts(hidden_states)
moe_out = self.experts(hidden_states, self.gate)
Copilot AI Nov 20, 2025

Missing forward_meta parameter in both shared_experts and experts calls. Both methods now require forward_meta. The calls should be:

  • shared_experts_out = self.shared_experts(hidden_states, forward_meta)
  • moe_out = self.experts(hidden_states, self.gate, forward_meta)

Suggested change
shared_experts_out = self.shared_experts(hidden_states)
moe_out = self.experts(hidden_states, self.gate)
shared_experts_out = self.shared_experts(hidden_states, forward_meta)
moe_out = self.experts(hidden_states, self.gate, forward_meta)
