From e445000ba10e9dc8c3cf55c686a9469658cf36c5 Mon Sep 17 00:00:00 2001 From: allenphilipj Date: Thu, 4 Jun 2026 11:52:35 +0100 Subject: [PATCH] Fix GroupedLinear FP8 graph weight update flag Signed-off-by: allenphilipj --- transformer_engine/pytorch/module/grouped_linear.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 15ec3fe322..c1d45511df 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -1696,6 +1696,15 @@ def forward( f"does not match number of GEMMs ({num_gemms})." ) + if FP8GlobalStateManager.fp8_graph_capturing(): + skip_fp8_weight_update = ( + FP8GlobalStateManager.quantization_state.skip_fp8_weight_update_tensor + ) + else: + skip_fp8_weight_update = None + if skip_fp8_weight_update is not None: + is_first_microbatch = False + # Preprocess input tensor if isinstance(inp, QuantizedTensorStorage): raise TypeError("GroupedLinear doesn't support input tensor in FP8.") @@ -1754,7 +1763,7 @@ def forward( is_grad_enabled, weight_workspaces, cache_weight, - None, # skip_fp8_weight_update + skip_fp8_weight_update, self.save_original_input, debug, )