
Commit 31f25cb

enable moe/mlp fusion
1 parent dcf59c7

File tree

1 file changed (+9, -2 lines)


tensorrt_llm/_torch/models/modeling_llama.py

Lines changed: 9 additions & 2 deletions
@@ -347,8 +347,8 @@ def __init__(
         # self.fusion_config.PRE_MOE_FUSION = model_config.mapping.has_tp(
         # )
         # TODO: re-enable these fusions
-        self.fusion_config.PRE_MOE_FUSION = False
-        self.fusion_config.POST_MLP_FUSION = False
+        # self.fusion_config.PRE_MOE_FUSION = False
+        # self.fusion_config.POST_MLP_FUSION = False
 
         self.self_attn = Llama4Attention(
             model_config,
@@ -374,6 +374,9 @@ def __init__(
 
             # self.fusion_config.POST_MLP_FUSION = model_config.mapping.has_tp(
             # )
+            self.fusion_config.PRE_MLP_FUSION = model_config.mapping.has_tp()
+            self.fusion_config.POST_MLP_FUSION = model_config.mapping.has_tp()
+
         else:
             self.feed_forward = Llama4MoE(
                 num_experts=config.num_local_experts,
@@ -385,6 +388,10 @@ def __init__(
                 aux_stream=aux_stream,
                 dtype=config.torch_dtype)
 
+            self.fusion_config.PRE_MOE_FUSION = model_config.mapping.has_tp()
+            self.fusion_config.POST_MOE_FUSION = model_config.mapping.has_tp()
+
+
             # self.fusion_config.POST_MOE_FUSION = model_config.mapping.has_tp(
             # )
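
For readers skimming the diff: the net effect is that the per-layer fusion flags now follow model_config.mapping.has_tp() instead of being hard-coded to False, so the MLP/MoE fusions are enabled only when tensor parallelism is actually in use. Below is a minimal, standalone sketch of that pattern; FusionConfig, Mapping, and configure_fusion are simplified stand-ins written for illustration, not the actual TensorRT-LLM definitions.

from dataclasses import dataclass


@dataclass
class Mapping:
    # Simplified stand-in for the parallel mapping (assumption, not the TRT-LLM class).
    tp_size: int = 1

    def has_tp(self) -> bool:
        return self.tp_size > 1


@dataclass
class FusionConfig:
    # Per-layer fusion switches mirroring the flags toggled in the diff.
    PRE_MLP_FUSION: bool = False
    POST_MLP_FUSION: bool = False
    PRE_MOE_FUSION: bool = False
    POST_MOE_FUSION: bool = False


def configure_fusion(fusion: FusionConfig, mapping: Mapping, is_mlp_layer: bool) -> None:
    # As in the commit, every flag is gated on has_tp(): the fusions matter on the
    # tensor-parallel path, so they are only turned on when TP is in use.
    if is_mlp_layer:
        fusion.PRE_MLP_FUSION = mapping.has_tp()
        fusion.POST_MLP_FUSION = mapping.has_tp()
    else:
        fusion.PRE_MOE_FUSION = mapping.has_tp()
        fusion.POST_MOE_FUSION = mapping.has_tp()


if __name__ == "__main__":
    cfg = FusionConfig()
    configure_fusion(cfg, Mapping(tp_size=4), is_mlp_layer=False)
    print(cfg)  # MoE flags become True only because tp_size > 1

With tp_size=1 the same call leaves every flag False, which matches the previously hard-disabled behavior on single-GPU runs.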
