
Commit 12d5c09

[Models] Add forward_meta to moe models' forward function
Parent: 3e3558f

File tree

11 files changed: +55 −38 lines
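
The change is mechanical but broad: every MoE-bearing forward path now accepts the per-step ForwardMeta, so phase-dependent branching reads the current step's metadata instead of a field on the shared model config. A minimal sketch of the recurring pattern, using stand-in classes rather than the actual FastDeploy types:

from dataclasses import dataclass


@dataclass
class MoEPhase:
    phase: str  # "prefill" appears in the diffs below; other values (e.g. "decode") are assumed


@dataclass
class ForwardMeta:
    moe_phase: MoEPhase  # per-step metadata supplied by the model runner


class FusedMoESketch:
    # Before this commit: forward(self, x, gate)
    # After it: the decoder layer threads forward_meta through as well.
    def forward(self, x, gate, forward_meta: ForwardMeta):
        if forward_meta.moe_phase.phase == "prefill":
            ...  # prefill dispatch (expert-parallel runner, buffer cleanup)
        else:
            ...  # decode dispatch
        return x  # expert math elided in this sketch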

fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py

Lines changed: 3 additions & 1 deletion
@@ -19,6 +19,7 @@
 import paddle
 from paddle import nn
 
+from fastdeploy.model_executor.forward_meta import ForwardMeta
 from fastdeploy.model_executor.utils import (
     TensorTracker,
     default_weight_loader,
@@ -198,13 +199,14 @@ def apply(
         layer: nn.Layer,
         x: paddle.Tensor,
         gate: nn.Layer,
+        forward_meta: ForwardMeta,
     ) -> paddle.Tensor:
         """
         Paddle Cutlass compute Fused MoE.
         """
         if layer.ep_size > 1:
             is_moe_start_layer = layer.layer_idx == layer.fd_config.model_config.moe_layer_start_index
-            if layer.fd_config.model_config.moe_phase.phase == "prefill":
+            if forward_meta.moe_phase.phase == "prefill":
                 if layer.fd_config.scheduler_config.splitwise_role == "mixed" and is_moe_start_layer:
                     self.ep_prefill_runner.clean_low_latency_buffer()
                 return self.apply_ep_prefill(layer, x, gate)
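
Reading the phase from forward_meta matters most for the "mixed" splitwise role visible above, where one deployment serves prefill and decode together and the phase can differ from step to step. A hedged sketch of how a model runner might populate the metadata each iteration (the batch field below is illustrative, not the FastDeploy API):

from dataclasses import dataclass


@dataclass
class MoEPhase:
    phase: str


@dataclass
class ForwardMeta:
    moe_phase: MoEPhase


def build_forward_meta(batch) -> ForwardMeta:
    # Assumption: a batch with prompt tokens still to process is a prefill
    # step; otherwise this iteration only decodes. batch.num_prompt_tokens
    # is a hypothetical field for illustration.
    phase = "prefill" if batch.num_prompt_tokens > 0 else "decode"
    return ForwardMeta(moe_phase=MoEPhase(phase=phase))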

fastdeploy/model_executor/layers/moe/moe.py

Lines changed: 4 additions & 3 deletions
@@ -21,6 +21,7 @@
 from paddleformers.utils.log import logger
 
 from fastdeploy import envs
+from fastdeploy.model_executor.forward_meta import ForwardMeta
 from fastdeploy.model_executor.layers.utils import get_tensor
 from fastdeploy.model_executor.utils import h2d_copy, slice_fn
 from fastdeploy.platforms import current_platform
@@ -593,7 +594,7 @@ def forward_split_allgather(self, x: paddle.Tensor, gate: nn.Layer):
         out = multi_outs[:token_num, :]
         return out
 
-    def forward(self, x: paddle.Tensor, gate: nn.Layer):
+    def forward(self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta):
         """
         Defines the forward computation of the moe layer.
 
@@ -611,7 +612,7 @@ def forward(self, x: paddle.Tensor, gate: nn.Layer):
             and (not self.fd_config.parallel_config.use_sequence_parallel_moe)
             and token_num >= self.tp_size
         ):
-            out = self.forward_split_allgather(x, gate)
+            out = self.forward_split_allgather(x, gate, forward_meta)
         else:
-            out = self.quant_method.apply(self, x, gate)
+            out = self.quant_method.apply(self, x, gate, forward_meta)
         return out
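
For context, FusedMoE.forward now dispatches between a tensor-parallel split/allgather path and the backend kernel path, passing forward_meta down either way. A simplified, self-contained rendering of that branch (the collectives are elided and the callables are stand-ins; how token_num is computed is an assumption):

def moe_forward(x, gate, forward_meta, cfg, split_allgather, quant_apply):
    token_num = x.shape[0]  # assumption: tokens along dim 0
    if (
        cfg.tp_size > 1
        and not cfg.use_sequence_parallel_moe
        and token_num >= cfg.tp_size
    ):
        # each TP rank computes its token slice, results are all-gathered
        return split_allgather(x, gate, forward_meta)
    # otherwise the (possibly quantized) fused kernel handles the full batch
    return quant_apply(x, gate, forward_meta)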

fastdeploy/model_executor/models/deepseek_v3.py

Lines changed: 3 additions & 3 deletions
@@ -104,7 +104,7 @@ def load_state_dict(self, state_dict):
         self.up_gate_proj.load_state_dict(state_dict)
         self.down_proj.load_state_dict(state_dict)
 
-    def forward(self, x):
+    def forward(self, x, forward_meta):
         """ """
         gate_up_out = self.up_gate_proj(x)
         act_out = self.act_fn(gate_up_out)
@@ -187,7 +187,7 @@ def load_state_dict(self, state_dict):
         self.experts.load_state_dict(state_dict)
         self.shared_experts.load_state_dict(state_dict)
 
-    def forward(self, hidden_states: paddle.Tensor):
+    def forward(self, hidden_states: paddle.Tensor, forward_meta: ForwardMeta):
         """ """
         shared_experts_out = self.shared_experts(hidden_states)
         moe_out = self.experts(hidden_states, self.gate)
@@ -514,7 +514,7 @@ def forward(
         hidden_states = self.self_attn(forward_meta, hidden_states, position_ids, mask_encoder_batch)
 
         hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
-        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.mlp(hidden_states, forward_meta)
         return hidden_states, residual

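Note that the dense DeepseekV3MLP.forward also gains a forward_meta parameter it never reads. The payoff is at the decoder-layer call site above: self.mlp(hidden_states, forward_meta) works whether mlp is the dense block or the MoE block. A sketch of that design choice with stub classes (stand-ins, not the FastDeploy types):

class DenseMLPSketch:
    def forward(self, hidden_states, forward_meta):
        # forward_meta accepted only to keep the call site uniform; unused here
        return hidden_states  # projection math elided


class MoESketch:
    def __init__(self, experts, gate):
        self.experts = experts
        self.gate = gate

    def forward(self, hidden_states, forward_meta):
        # forward_meta is consumed by the fused-experts backend
        return self.experts(hidden_states, self.gate, forward_meta)


# Either way, the decoder layer can write:
#     hidden_states = self.mlp(hidden_states, forward_meta)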
fastdeploy/model_executor/models/ernie4_5_moe.py

Lines changed: 16 additions & 5 deletions
@@ -95,7 +95,7 @@ def load_state_dict(self, state_dict):
         self.up_gate_proj.load_state_dict(state_dict)
         self.down_proj.load_state_dict(state_dict)
 
-    def forward(self, hidden_states: paddle.Tensor):
+    def forward(self, hidden_states: paddle.Tensor, forward_meta: ForwardMeta):
         gate_up_out = self.up_gate_proj(hidden_states)
         act_out = self.act_fn(gate_up_out)
         down_out = self.down_proj(act_out)
@@ -213,8 +213,16 @@ def load_state_dict(self, state_dict):
     def update_state_dict(self, state_dict):
         self.fused_moe.load_state_dict(state_dict, True)
 
-    def forward(self, hidden_states: paddle.Tensor):
-        out = self.experts(hidden_states, self.gate)
+    def forward(
+        self,
+        hidden_states: paddle.Tensor,
+        forward_meta: ForwardMeta,
+    ):
+        out = self.experts(
+            x=hidden_states,
+            gate=self.gate,
+            forward_meta=forward_meta,
+        )
         if self.num_shared_experts > 0:
             s_x = self.shared_experts(hidden_states)
             out = out + s_x
@@ -344,7 +352,10 @@ def forward(
             residual,
         )
 
-        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.mlp(
+            forward_meta=forward_meta,
+            hidden_states=hidden_states,
+        )
 
         return hidden_states, residual
 
@@ -621,7 +632,7 @@ def empty_input_forward(self):
             self.fd_config.model_config.moe_layer_start_index,
             self.fd_config.model_config.num_hidden_layers,
         ):
-            self.ernie.layers[i].mlp.experts(fake_hidden_states, self.ernie.layers[i].mlp.gate)
+            self.ernie.layers[i].mlp.experts(fake_hidden_states, self.ernie.layers[i].mlp.gate, self.forward_meta)
 
     def forward(
         self,
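
The empty_input_forward hunk shows the other consumer of the threaded metadata: warm-up passes over the MoE layers now hand self.forward_meta to the experts. A hedged sketch of that loop in isolation (only the call shape follows the diff; the placeholder-tensor construction and hidden_size field are assumptions):

import paddle


def warm_up_moe_layers(model, fd_config, forward_meta):
    # Assumption: a zero-token placeholder is enough to exercise the kernels.
    fake_hidden_states = paddle.empty(
        [0, fd_config.model_config.hidden_size], dtype="bfloat16"
    )
    for i in range(
        fd_config.model_config.moe_layer_start_index,
        fd_config.model_config.num_hidden_layers,
    ):
        layer_mlp = model.layers[i].mlp
        layer_mlp.experts(fake_hidden_states, layer_mlp.gate, forward_meta)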

fastdeploy/model_executor/models/ernie4_5_mtp.py

Lines changed: 2 additions & 2 deletions
@@ -436,7 +436,7 @@ def compute_logits(self, hidden_states: paddle.Tensor):
 
         return logits
 
-    def empty_input_forward(self):
+    def empty_input_forward(self, forward_meta):
         """
         empty_input_forward
         """
@@ -448,7 +448,7 @@ def empty_input_forward(self):
             self.fd_config.model_config.moe_layer_start_index,
             self.fd_config.model_config.num_hidden_layers,
         ):
-            self.ernie.layers[i].mlp.fused_moe(fake_hidden_states)
+            self.ernie.layers[i].mlp.fused_moe(hidden_states=fake_hidden_states, forward_meta=forward_meta)
 
     def forward(
         self,
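
Two conventions coexist after this commit: the MTP model's empty_input_forward takes forward_meta as an explicit argument, while the other models read the runner-owned self.forward_meta (as in the ernie4_5_moe and qwen3moe hunks). A minimal contrast (stub classes, not the FastDeploy implementations):

class MTPStyleSketch:
    def empty_input_forward(self, forward_meta):
        # metadata supplied per call by whoever drives the warm-up
        ...


class AttributeStyleSketch:
    def empty_input_forward(self):
        # metadata read from an attribute the model runner keeps up to date
        forward_meta = self.forward_meta
        ...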

fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py

Lines changed: 9 additions & 9 deletions
@@ -169,8 +169,8 @@ def __init__(
             model_format="",
         )
 
-    def forward(self, hidden_states: paddle.Tensor):
-        out = self.experts(hidden_states, self.gate)
+    def forward(self, hidden_states: paddle.Tensor, forward_meta: ForwardMeta):
+        out = self.experts(hidden_states, self.gate, forward_meta)
         return out
 
     def load_state_dict(self, state_dict):
@@ -269,7 +269,7 @@ def load_state_dict(self, state_dict):
         if self.num_shared_experts > 0:
             self.shared_experts.load_state_dict(state_dict)
 
-    def forward(self, hidden_states: paddle.Tensor, vl_moe_meta: VLMoEMeta):
+    def forward(self, hidden_states: paddle.Tensor, forward_meta: ForwardMeta, vl_moe_meta: VLMoEMeta):
         if self.num_shared_experts > 0:
             shared_experts_out = self.shared_experts(hidden_states)
         hidden_states, text_input, image_input = text_image_gather_scatter(
@@ -281,8 +281,8 @@ def forward(self, hidden_states: paddle.Tensor, vl_moe_meta: VLMoEMeta):
             vl_moe_meta.image_index,
             True,
         )
-        text_out = self.text_fused_moe(text_input)
-        image_out = self.image_fused_moe(image_input)
+        text_out = self.text_fused_moe(text_input, forward_meta)
+        image_out = self.image_fused_moe(image_input, forward_meta)
         hidden_states, _, _ = text_image_gather_scatter(
             hidden_states,
             text_out,
@@ -388,9 +388,9 @@ def forward(
         hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
 
         if isinstance(self.mlp, Ernie4_5_VLMoE):
-            hidden_states = self.mlp(hidden_states, vl_moe_meta)
+            hidden_states = self.mlp(hidden_states, forward_meta, vl_moe_meta)
         else:
-            hidden_states = self.mlp(hidden_states)
+            hidden_states = self.mlp(hidden_states, forward_meta)
 
         return hidden_states, residual
 
@@ -757,8 +757,8 @@ def empty_input_forward(self):
             self.fd_config.model_config.moe_layer_start_index,
             self.fd_config.model_config.num_hidden_layers,
         ):
-            self.ernie.layers[i].mlp.text_fused_moe(fake_hidden_states)
-            self.ernie.layers[i].mlp.image_fused_moe(fake_hidden_states)
+            self.ernie.layers[i].mlp.text_fused_moe(fake_hidden_states, self.forward_meta)
+            self.ernie.layers[i].mlp.image_fused_moe(fake_hidden_states, self.forward_meta)
 
     def get_input_embeddings(
         self,
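
In the VL model the token stream is scattered into text and image groups, each with its own fused-MoE stack; both now see the same per-step forward_meta. A stand-in sketch of the flow (text_image_gather_scatter's real signature is more involved; split and merge below are placeholders):

def vl_moe_forward(hidden_states, forward_meta, text_moe, image_moe, split, merge):
    # split: stand-in for the text/image gather-scatter shown in the diff
    text_input, image_input = split(hidden_states)
    text_out = text_moe(text_input, forward_meta)     # same metadata ...
    image_out = image_moe(image_input, forward_meta)  # ... for both expert stacks
    # merge: stand-in for the reverse gather-scatter
    return merge(text_out, image_out)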

fastdeploy/model_executor/models/glm4_moe.py

Lines changed: 7 additions & 4 deletions
@@ -85,7 +85,7 @@ def __init__(
             act_method=fd_config.model_config.hidden_act,
         )
 
-    def forward(self, x):
+    def forward(self, x, forward_meta):
         """ """
         gate_up_out = self.up_gate_proj(x)
         act_out = self.act_fn(gate_up_out)
@@ -161,9 +161,9 @@ def __init__(
             reduce_results=False,
         )
 
-    def forward(self, x):
+    def forward(self, x, forward_meta):
         shared_experts_out = self.shared_experts(x)
-        out = self.experts(x, self.gate)
+        out = self.experts(x, self.gate, forward_meta)
         out = out + shared_experts_out
         # We do to TP all reduce after the sum of experts.
         if self.tensor_parallel_size > 1:
@@ -306,7 +306,10 @@ def forward(
         # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
 
-        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.mlp(
+            hidden_states,
+            forward_meta,
+        )
 
         return hidden_states, residual
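
The Glm4 hunk keeps an ordering worth noting: the shared experts are built with reduce_results=False, the routed-expert and shared-expert outputs are summed first, and only then does the tensor-parallel all-reduce run, so one collective covers both contributions. A stand-in sketch of that ordering (all callables are placeholders):

def glm4_moe_block(x, gate, experts, shared_experts, forward_meta, tp_size, all_reduce):
    shared_out = shared_experts(x)        # partial result, no reduce inside
    out = experts(x, gate, forward_meta)  # routed experts, now phase-aware
    out = out + shared_out
    if tp_size > 1:
        out = all_reduce(out)             # a single collective after the sum
    return out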

fastdeploy/model_executor/models/gpt_oss.py

Lines changed: 3 additions & 3 deletions
@@ -124,8 +124,8 @@ def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str = ""):
             model_format="",
         )
 
-    def forward(self, hidden_states: paddle.Tensor):
-        expert_output = self.experts(hidden_states, self.router)
+    def forward(self, hidden_states: paddle.Tensor, forward_meta: ForwardMeta):
+        expert_output = self.experts(hidden_states, self.router, forward_meta)
         return expert_output
 
 
@@ -173,7 +173,7 @@ def forward(
         # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
 
-        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.mlp(hidden_states, forward_meta)
         return hidden_states, residual

fastdeploy/model_executor/models/qwen2.py

Lines changed: 2 additions & 2 deletions
@@ -90,7 +90,7 @@ def load_state_dict(self, state_dict):
         self.up_gate_proj.load_state_dict(state_dict)
         self.down_proj.load_state_dict(state_dict)
 
-    def forward(self, x):
+    def forward(self, x, forward_meta):
         """ """
         gate_up_out = self.up_gate_proj(x)
         act_out = self.act_fn(gate_up_out)
@@ -206,7 +206,7 @@ def forward(
         # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
 
-        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.mlp(hidden_states, forward_meta)
 
         return hidden_states, residual

fastdeploy/model_executor/models/qwen3moe.py

Lines changed: 5 additions & 5 deletions
@@ -79,8 +79,8 @@ def __init__(
             weight_dtype="float32",
         )
 
-    def forward(self, x):
-        return self.experts(x, self.gate)
+    def forward(self, x, forward_meta):
+        return self.experts(x, self.gate, forward_meta)
 
     def load_state_dict(self, state_dict):
         """ """
@@ -127,7 +127,7 @@ def load_state_dict(self, state_dict):
         self.up_gate_proj.load_state_dict(state_dict)
         self.down_proj.load_state_dict(state_dict)
 
-    def forward(self, x):
+    def forward(self, x, forward_meta):
         """ """
         gate_up_out = self.up_gate_proj(x)
         act_out = self.act_fn(gate_up_out)
@@ -206,7 +206,7 @@ def forward(
         # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
 
-        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.mlp(hidden_states, forward_meta)
 
        return hidden_states, residual
 
@@ -430,7 +430,7 @@ def empty_input_forward(self):
             self.fd_config.model_config.moe_layer_start_index,
             self.fd_config.model_config.num_hidden_layers,
         ):
-            self.model.layers[i].mlp.experts(fake_hidden_states, self.model.layers[i].mlp.gate)
+            self.model.layers[i].mlp.experts(fake_hidden_states, self.model.layers[i].mlp.gate, self.forward_meta)
 
     def forward(
         self,
