We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent efce53b commit caa6032Copy full SHA for caa6032
1 file changed
transformer_engine/jax/cpp_extensions/gemm.py
@@ -2586,9 +2586,7 @@ def grouped_gemm(
2586
raise ValueError("rhs must be pre-swizzled for MXFP8 1D scaling")
2587
2588
if use_v2_ffi:
2589
- alpha_beta_numel = (
2590
- num_gemms if _v2_grouped_gemm_supports_per_group_alpha_beta() else 1
2591
- )
+ alpha_beta_numel = num_gemms if _v2_grouped_gemm_supports_per_group_alpha_beta() else 1
2592
additional_arg_0 = jnp.ones((alpha_beta_numel,), jnp.float32) # alpha
2593
additional_arg_1 = jnp.zeros((alpha_beta_numel,), jnp.float32) # beta
2594
else:
0 commit comments