Skip to content

Commit 72149be

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 3afce1f commit 72149be

2 files changed

Lines changed: 2 additions & 8 deletions

File tree

transformer_engine/pytorch/module/layernorm_mlp.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1332,9 +1332,7 @@ def fc2_wgrad_gemm(
13321332
dact = dact_func(fc2_dgrad, fc1_out.to(ctx.activation_dtype), None, **act_params)
13331333
fc1_bias_grad = dact.sum(dim=0)
13341334
dact = ctx.fc1_grad_output_quantizer(dact)
1335-
elif (
1336-
_act_func(ctx.activation, fp8_recipe_bwd)[2] is not None and use_fp8_bwd
1337-
):
1335+
elif _act_func(ctx.activation, fp8_recipe_bwd)[2] is not None and use_fp8_bwd:
13381336
# Fusion: gemm, bias + gelu + quantize
13391337
dbias_dact_quantize_func = _act_func(ctx.activation, fp8_recipe_bwd)[2]
13401338
fc1_bias_grad, dact = dbias_dact_quantize_func(

transformer_engine/pytorch/ops/fuser.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -294,11 +294,7 @@ def backward(
294294
FP8GlobalStateManager.is_fp8_enabled()
295295
and FP8GlobalStateManager.keep_backward_unquantized()
296296
)
297-
if (
298-
func_ctx.is_first_module
299-
and not keep_backward_unquantized
300-
and not _is_graph_capturing()
301-
):
297+
if func_ctx.is_first_module and not keep_backward_unquantized and not _is_graph_capturing():
302298
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
303299

304300
return (

0 commit comments

Comments
 (0)