Commit 506ef2d

Fix grouped linear FP8 calibration loop
1 parent 3fffa55 commit 506ef2d

1 file changed

Lines changed: 2 additions & 5 deletions

transformer_engine/pytorch/module/grouped_linear.py

@@ -599,11 +599,8 @@ def forward(

             if fp8_calibration:
                 for i in range(num_gemms):
-                    # amax of input
-                    for i in range(num_gemms):
-                        input_quantizers[i].calibrate(inputmats[i])
-                    for i in range(num_gemms):
-                        weight_quantizers[i].calibrate(weights[i])
+                    input_quantizers[i].calibrate(inputmats[i])
+                    weight_quantizers[i].calibrate(weights[i])

             if cpu_offloading:
                 mark_not_offload(*weights_fp8, *weights)
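
The effect of the change can be illustrated with a small standalone sketch. It assumes the removed inner loops were nested inside the outer `for i in range(num_gemms)` loop, as the diff suggests. `StubQuantizer`, `calibrate_nested`, `calibrate_flat`, and `run` are hypothetical names introduced only for this illustration and are not part of Transformer Engine; the stub only counts calls and tracks a maximum absolute value over plain Python lists.

```python
class StubQuantizer:
    """Hypothetical stand-in for a quantizer; not the Transformer Engine class."""

    def __init__(self):
        self.calls = 0
        self.amax = 0.0

    def calibrate(self, values):
        # Count the call and keep the largest absolute value seen so far.
        self.calls += 1
        self.amax = max(self.amax, max(abs(v) for v in values))


def calibrate_nested(input_quantizers, weight_quantizers, inputmats, weights):
    """Loop shape before the commit: both inner loops rerun on every outer pass."""
    num_gemms = len(inputmats)
    for i in range(num_gemms):
        for i in range(num_gemms):
            input_quantizers[i].calibrate(inputmats[i])
        for i in range(num_gemms):
            weight_quantizers[i].calibrate(weights[i])


def calibrate_flat(input_quantizers, weight_quantizers, inputmats, weights):
    """Loop shape after the commit: one pass, one call per quantizer."""
    num_gemms = len(inputmats)
    for i in range(num_gemms):
        input_quantizers[i].calibrate(inputmats[i])
        weight_quantizers[i].calibrate(weights[i])


def run(calibrate_fn, num_gemms=3):
    # Made-up data: one small list of floats per GEMM.
    inputmats = [[float(i), -2.0 * (i + 1)] for i in range(num_gemms)]
    weights = [[0.5 * (i + 1), -0.25] for i in range(num_gemms)]
    input_quantizers = [StubQuantizer() for _ in range(num_gemms)]
    weight_quantizers = [StubQuantizer() for _ in range(num_gemms)]
    calibrate_fn(input_quantizers, weight_quantizers, inputmats, weights)
    return input_quantizers, weight_quantizers


if __name__ == "__main__":
    for fn in (calibrate_nested, calibrate_flat):
        iq, wq = run(fn)
        print(fn.__name__, [q.calls for q in iq], [q.calls for q in wq])
```

With this stub, the nested form calls `calibrate` on each quantizer `num_gemms` times, while the flat form calls it once and records the same maximum. Whether repeated calls are harmless for the real quantizers depends on how their `calibrate` method updates state, which this sketch does not model.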
