File tree Expand file tree Collapse file tree 1 file changed +9
-2
lines changed
tensorrt_llm/_torch/models Expand file tree Collapse file tree 1 file changed +9
-2
lines changed Original file line number Diff line number Diff line change @@ -347,8 +347,8 @@ def __init__(
347
347
# self.fusion_config.PRE_MOE_FUSION = model_config.mapping.has_tp(
348
348
# )
349
349
# TODO: re-enable these fusions
350
- self.fusion_config.PRE_MOE_FUSION = False
351
- self.fusion_config.POST_MLP_FUSION = False
350
+ # self.fusion_config.PRE_MOE_FUSION = False
351
+ # self.fusion_config.POST_MLP_FUSION = False
352
352
353
353
self.self_attn = Llama4Attention(
354
354
model_config,
@@ -374,6 +374,9 @@ def __init__(
374
374
375
375
# self.fusion_config.POST_MLP_FUSION = model_config.mapping.has_tp(
376
376
# )
377
+ self.fusion_config.PRE_MLP_FUSION = model_config.mapping.has_tp()
378
+ self.fusion_config.POST_MLP_FUSION = model_config.mapping.has_tp()
379
+
377
380
else:
378
381
self.feed_forward = Llama4MoE(
379
382
num_experts=config.num_local_experts,
@@ -385,6 +388,10 @@ def __init__(
385
388
aux_stream=aux_stream,
386
389
dtype=config.torch_dtype)
387
390
391
+ self.fusion_config.PRE_MOE_FUSION = model_config.mapping.has_tp()
392
+ self.fusion_config.POST_MOE_FUSION = model_config.mapping.has_tp()
393
+
394
+
388
395
# self.fusion_config.POST_MOE_FUSION = model_config.mapping.has_tp(
389
396
# )
390
397
You can’t perform that action at this time.
0 commit comments