Skip to content

Commit 1ea48eb

Browse files
[PyTorch] Propagate FP8 graph weight update flag in GroupedLinear (#3052)
Fix GroupedLinear FP8 graph weight update flag Signed-off-by: allenphilipj <allenphilipj@users.noreply.github.com> Co-authored-by: allenphilipj <allenphilipj@users.noreply.github.com> Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
1 parent 720ec27 commit 1ea48eb

1 file changed

Lines changed: 10 additions & 1 deletion

File tree

transformer_engine/pytorch/module/grouped_linear.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1696,6 +1696,15 @@ def forward(
16961696
f"does not match number of GEMMs ({num_gemms})."
16971697
)
16981698

1699+
if FP8GlobalStateManager.fp8_graph_capturing():
1700+
skip_fp8_weight_update = (
1701+
FP8GlobalStateManager.quantization_state.skip_fp8_weight_update_tensor
1702+
)
1703+
else:
1704+
skip_fp8_weight_update = None
1705+
if skip_fp8_weight_update is not None:
1706+
is_first_microbatch = False
1707+
16991708
# Preprocess input tensor
17001709
if isinstance(inp, QuantizedTensorStorage):
17011710
raise TypeError("GroupedLinear doesn't support input tensor in FP8.")
@@ -1754,7 +1763,7 @@ def forward(
17541763
is_grad_enabled,
17551764
weight_workspaces,
17561765
cache_weight,
1757-
None, # skip_fp8_weight_update
1766+
skip_fp8_weight_update,
17581767
self.save_original_input,
17591768
debug,
17601769
)

0 commit comments

Comments
 (0)