Skip to content

Commit 06d37ee

Browse files
Hopper BF16 grouped GEMM v2 support + native-JAX bias for now
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
1 parent abdb406 commit 06d37ee

2 files changed

Lines changed: 80 additions & 23 deletions

File tree

transformer_engine/jax/cpp_extensions/gemm.py

Lines changed: 78 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1519,8 +1519,8 @@ def abstract(
15191519
additional_args: Either
15201520
* group_offsets: 1D array containing offsets for each group (not yet implemented)
15211521
OR
1522-
* alpha: 1D array of shape (G,) containing alpha values for each group
1523-
* beta: 1D array of shape (G,) containing beta values for each group
1522+
* alpha: 1D array of shape (G,) or (1,) containing alpha values
1523+
* beta: 1D array of shape (G,) or (1,) containing beta values
15241524
lhs_is_trans: Boolean indicating if the left-hand side matrix is transposed
15251525
rhs_is_trans: Boolean indicating if the right-hand side matrix is transposed
15261526
scaling_mode: Scaling mode for the GEMM operations
@@ -1606,12 +1606,17 @@ def abstract(
16061606
f" GEMM primitive, but got {len(additional_args)} arguments."
16071607
)
16081608
alpha_aval, beta_aval = additional_args
1609-
if alpha_aval.shape != (num_groups,):
1610-
raise ValueError(f"Expected alpha shape {(num_groups,)}, got {alpha_aval.shape}")
1609+
valid_alpha_beta_shapes = ((num_groups,), (1,))
1610+
if alpha_aval.shape not in valid_alpha_beta_shapes:
1611+
raise ValueError(
1612+
f"Expected alpha shape {(num_groups,)} or (1,), got {alpha_aval.shape}"
1613+
)
16111614
if alpha_aval.dtype != jnp.float32:
16121615
raise ValueError(f"Expected alpha dtype float32, got {alpha_aval.dtype}")
1613-
if beta_aval.shape != (num_groups,):
1614-
raise ValueError(f"Expected beta shape {(num_groups,)}, got {beta_aval.shape}")
1616+
if beta_aval.shape not in valid_alpha_beta_shapes:
1617+
raise ValueError(
1618+
f"Expected beta shape {(num_groups,)} or (1,), got {beta_aval.shape}"
1619+
)
16151620
if beta_aval.dtype != jnp.float32:
16161621
raise ValueError(f"Expected beta dtype float32, got {beta_aval.dtype}")
16171622

@@ -2091,6 +2096,11 @@ def _should_enforce_v2_grouped_gemm() -> bool:
20912096
) from e
20922097

20932098

2099+
def _v2_grouped_gemm_supports_per_group_alpha_beta() -> bool:
2100+
"""Whether nvte_grouped_gemm accepts per-group alpha/beta on all visible devices."""
2101+
return get_min_device_compute_capability() >= 100
2102+
2103+
20942104
def _is_v2_grouped_gemm_supported(
20952105
scaling_mode: ScalingMode,
20962106
dtype: jnp.dtype,
@@ -2111,24 +2121,31 @@ def _is_v2_grouped_gemm_supported(
21112121
),
21122122
)
21132123

2114-
# nvte_grouped_gemm (the v2 kernel) requires SM100+ (Blackwell or newer).
2115-
# Fall back to the v1 path on SM90 (Hopper) and older architectures.
2116-
if get_min_device_compute_capability() < 100:
2124+
# nvte_grouped_gemm (the v2 kernel) supports BF16 on SM90+ (Hopper or newer).
2125+
# MXFP8 remains gated to SM100+ below.
2126+
if get_min_device_compute_capability() < 90:
21172127
return (
21182128
False,
21192129
(
2120-
"The TE V2 grouped GEMM requires SM100+ (Blackwell or newer) but current min device"
2130+
"The TE V2 grouped GEMM requires SM90+ (Hopper or newer) but current min device"
21212131
f" compute capability is {get_min_device_compute_capability()}."
21222132
),
21232133
)
21242134

2125-
if has_bias:
2126-
return False, "Grouped GEMM with bias is not supported in the TE V2 grouped GEMM kernel."
2127-
21282135
if scaling_mode == ScalingMode.NO_SCALING and dtype == jnp.bfloat16:
21292136
return True, ""
21302137

21312138
if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
2139+
if get_min_device_compute_capability() < 100:
2140+
return (
2141+
False,
2142+
(
2143+
"The TE V2 grouped GEMM for MXFP8 requires SM100+ (Blackwell or newer) but"
2144+
" current min device compute capability is"
2145+
f" {get_min_device_compute_capability()}."
2146+
),
2147+
)
2148+
21322149
# V2 MXFP8 requires that the total first dimension of both operands (up to
21332150
# axis_boundary) is divisible by 128, matching the quantize V2 kernel requirement.
21342151
# Individual group sizes must also be 128-aligned (dynamic constraint).
@@ -2188,9 +2205,10 @@ def _is_v2_grouped_gemm_supported(
21882205
return (
21892206
False,
21902207
(
2191-
"The TE V2 grouped GEMM currently only supports non-quantized BF16 and MXFP8 with 1D"
2192-
" block scaling, but NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled and the input"
2193-
f" parameters do not meet these requirements (scaling_mode= {scaling_mode},"
2208+
"The TE V2 grouped GEMM currently only supports non-quantized BF16, and MXFP8 with"
2209+
" 1D block scaling on SM100+, but NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled and"
2210+
" the input parameters do not meet these requirements"
2211+
f" (scaling_mode= {scaling_mode},"
21942212
f" dtype={dtype}, has_bias={has_bias}, lhs_shape={lhs_shape}, rhs_shape={rhs_shape},"
21952213
f" lhs_axis_boundary={lhs_axis_boundary}, rhs_axis_boundary={rhs_axis_boundary})."
21962214
),
@@ -2390,6 +2408,35 @@ def _get_num_gemms(
23902408
)
23912409

23922410

2411+
def _add_grouped_gemm_bias(
2412+
out: jnp.ndarray,
2413+
bias: jnp.ndarray,
2414+
out_first_dims: Optional[jnp.ndarray],
2415+
out_last_dims: Optional[jnp.ndarray],
2416+
out_shape: Tuple[int, ...],
2417+
num_gemms: int,
2418+
n_dim: int,
2419+
) -> jnp.ndarray:
2420+
"""Add grouped GEMM bias in JAX for V2 kernels that do not fuse bias."""
2421+
if out_last_dims is not None:
2422+
raise NotImplementedError("V2 grouped GEMM bias is not supported for ragged last dims")
2423+
2424+
bias = bias.astype(out.dtype)
2425+
bias_2d = bias.reshape((num_gemms, n_dim))
2426+
if out_first_dims is not None:
2427+
out_2d = out.reshape((-1, n_dim))
2428+
bias_rows = jnp.repeat(
2429+
bias_2d,
2430+
out_first_dims,
2431+
axis=0,
2432+
total_repeat_length=out_2d.shape[0],
2433+
)
2434+
return (out_2d + bias_rows).reshape(out_shape)
2435+
2436+
bias_shape = (num_gemms,) + (1,) * (out.ndim - 2) + (n_dim,)
2437+
return out + bias_2d.reshape(bias_shape)
2438+
2439+
23932440
def grouped_gemm(
23942441
lhs: Union[GroupedNoScaleTensor, GroupedScaledTensor1x],
23952442
rhs: Union[GroupedNoScaleTensor, GroupedScaledTensor1x],
@@ -2509,7 +2556,8 @@ def grouped_gemm(
25092556
num_gemms,
25102557
N_dim,
25112558
), f"bias shape {bias.shape} does not match expected shape {(num_gemms, N_dim)}"
2512-
bias = jnp.empty((), jnp.float32) if bias is None else bias
2559+
else:
2560+
N_dim = 0
25132561

25142562
if group_offset is not None:
25152563
raise RuntimeError(
@@ -2538,18 +2586,23 @@ def grouped_gemm(
25382586
raise ValueError("rhs must be pre-swizzled for MXFP8 1D scaling")
25392587

25402588
if use_v2_ffi:
2541-
additional_arg_0 = jnp.ones((num_gemms,), jnp.float32) # alpha
2542-
additional_arg_1 = jnp.zeros((num_gemms,), jnp.float32) # beta
2589+
alpha_beta_numel = (
2590+
num_gemms if _v2_grouped_gemm_supports_per_group_alpha_beta() else 1
2591+
)
2592+
additional_arg_0 = jnp.ones((alpha_beta_numel,), jnp.float32) # alpha
2593+
additional_arg_1 = jnp.zeros((alpha_beta_numel,), jnp.float32) # beta
25432594
else:
25442595
additional_arg_0 = jnp.zeros((1,), jnp.int32) # group_offset
25452596
additional_arg_1 = jnp.zeros((0,), jnp.int32) # unused placeholder
2597+
bias_for_ffi = jnp.empty((), jnp.float32) if (bias is None or use_v2_ffi) else bias
2598+
has_bias_for_ffi = has_bias and not use_v2_ffi
25462599

25472600
(out,) = GroupedGemmPrimitive.outer_primitive.bind(
25482601
lhs.data,
25492602
lhs.scale_inv if isinstance(lhs, GroupedScaledTensor1x) else jnp.empty((0,), jnp.float32),
25502603
rhs.data,
25512604
rhs.scale_inv if isinstance(rhs, GroupedScaledTensor1x) else jnp.empty((0,), jnp.float32),
2552-
bias,
2605+
bias_for_ffi,
25532606
lhs.first_dims if lhs.first_dims is not None else empty_gs,
25542607
lhs.last_dims if lhs.last_dims is not None else empty_gs,
25552608
rhs.first_dims if rhs.first_dims is not None else empty_gs,
@@ -2562,7 +2615,7 @@ def grouped_gemm(
25622615
rhs_is_trans=rhs_is_trans,
25632616
scaling_mode=scaling_mode.value,
25642617
out_dtype=out_dtype,
2565-
has_bias=has_bias,
2618+
has_bias=has_bias_for_ffi,
25662619
use_async_d2h_group_sizes=use_async_d2h_group_sizes,
25672620
use_v2_ffi=use_v2_ffi,
25682621
lhs_axis_boundary=lhs_axis_boundary,
@@ -2573,4 +2626,8 @@ def grouped_gemm(
25732626
rhs_left_size=int(rhs_left_size),
25742627
rhs_right_size=int(rhs_right_size),
25752628
)
2629+
if use_v2_ffi and has_bias:
2630+
out = _add_grouped_gemm_bias(
2631+
out, bias, out_first_dims, out_last_dims, out_shape, num_gemms, N_dim
2632+
)
25762633
return out

transformer_engine/jax/csrc/extensions/gemm.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -953,10 +953,10 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty
953953
DType::kByte);
954954

955955
TensorWrapper alpha_tensor(static_cast<void *>(alpha.untyped_data()),
956-
std::vector<size_t>{num_gemms},
956+
std::vector<size_t>{alpha.element_count()},
957957
convert_ffi_datatype_to_te_dtype(alpha.element_type()));
958958
TensorWrapper beta_tensor(static_cast<void *>(beta.untyped_data()),
959-
std::vector<size_t>{num_gemms},
959+
std::vector<size_t>{beta.element_count()},
960960
convert_ffi_datatype_to_te_dtype(beta.element_type()));
961961

962962
// Build grouped tensors from XLA buffer shapes and group_sizes — no m/n/k derivation needed.

0 commit comments

Comments
 (0)