Skip to content

Commit caa6032

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent efce53b commit caa6032

1 file changed

Lines changed: 1 addition & 3 deletions

File tree

  • transformer_engine/jax/cpp_extensions

transformer_engine/jax/cpp_extensions/gemm.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2586,9 +2586,7 @@ def grouped_gemm(
25862586
raise ValueError("rhs must be pre-swizzled for MXFP8 1D scaling")
25872587

25882588
if use_v2_ffi:
2589-
alpha_beta_numel = (
2590-
num_gemms if _v2_grouped_gemm_supports_per_group_alpha_beta() else 1
2591-
)
2589+
alpha_beta_numel = num_gemms if _v2_grouped_gemm_supports_per_group_alpha_beta() else 1
25922590
additional_arg_0 = jnp.ones((alpha_beta_numel,), jnp.float32) # alpha
25932591
additional_arg_1 = jnp.zeros((alpha_beta_numel,), jnp.float32) # beta
25942592
else:

0 commit comments

Comments
 (0)