@@ -1519,8 +1519,8 @@ def abstract(
15191519 additional_args: Either
15201520 * group_offsets: 1D array containing offsets for each group (not yet implemented)
15211521 OR
1522- * alpha: 1D array of shape (G,) containing alpha values for each group
1523- * beta: 1D array of shape (G,) containing beta values for each group
1522+ * alpha: 1D array of shape (G,) or (1,) containing alpha values
1523+ * beta: 1D array of shape (G,) or (1,) containing beta values
15241524 lhs_is_trans: Boolean indicating if the left-hand side matrix is transposed
15251525 rhs_is_trans: Boolean indicating if the right-hand side matrix is transposed
15261526 scaling_mode: Scaling mode for the GEMM operations
@@ -1606,12 +1606,17 @@ def abstract(
16061606 f" GEMM primitive, but got { len (additional_args )} arguments."
16071607 )
16081608 alpha_aval , beta_aval = additional_args
1609- if alpha_aval .shape != (num_groups ,):
1610- raise ValueError (f"Expected alpha shape { (num_groups ,)} , got { alpha_aval .shape } " )
1609+ valid_alpha_beta_shapes = ((num_groups ,), (1 ,))
1610+ if alpha_aval .shape not in valid_alpha_beta_shapes :
1611+ raise ValueError (
1612+ f"Expected alpha shape { (num_groups ,)} or (1,), got { alpha_aval .shape } "
1613+ )
16111614 if alpha_aval .dtype != jnp .float32 :
16121615 raise ValueError (f"Expected alpha dtype float32, got { alpha_aval .dtype } " )
1613- if beta_aval .shape != (num_groups ,):
1614- raise ValueError (f"Expected beta shape { (num_groups ,)} , got { beta_aval .shape } " )
1616+ if beta_aval .shape not in valid_alpha_beta_shapes :
1617+ raise ValueError (
1618+ f"Expected beta shape { (num_groups ,)} or (1,), got { beta_aval .shape } "
1619+ )
16151620 if beta_aval .dtype != jnp .float32 :
16161621 raise ValueError (f"Expected beta dtype float32, got { beta_aval .dtype } " )
16171622
@@ -2091,6 +2096,11 @@ def _should_enforce_v2_grouped_gemm() -> bool:
20912096 ) from e
20922097
20932098
2099+ def _v2_grouped_gemm_supports_per_group_alpha_beta () -> bool :
2100+ """Whether nvte_grouped_gemm accepts per-group alpha/beta on all visible devices."""
2101+ return get_min_device_compute_capability () >= 100
2102+
2103+
20942104def _is_v2_grouped_gemm_supported (
20952105 scaling_mode : ScalingMode ,
20962106 dtype : jnp .dtype ,
@@ -2111,24 +2121,31 @@ def _is_v2_grouped_gemm_supported(
21112121 ),
21122122 )
21132123
2114- # nvte_grouped_gemm (the v2 kernel) requires SM100 + (Blackwell or newer).
2115- # Fall back to the v1 path on SM90 (Hopper) and older architectures .
2116- if get_min_device_compute_capability () < 100 :
2124+ # nvte_grouped_gemm (the v2 kernel) supports BF16 on SM90 + (Hopper or newer).
2125+ # MXFP8 remains gated to SM100+ below .
2126+ if get_min_device_compute_capability () < 90 :
21172127 return (
21182128 False ,
21192129 (
2120- "The TE V2 grouped GEMM requires SM100 + (Blackwell or newer) but current min device"
2130+ "The TE V2 grouped GEMM requires SM90 + (Hopper or newer) but current min device"
21212131 f" compute capability is { get_min_device_compute_capability ()} ."
21222132 ),
21232133 )
21242134
2125- if has_bias :
2126- return False , "Grouped GEMM with bias is not supported in the TE V2 grouped GEMM kernel."
2127-
21282135 if scaling_mode == ScalingMode .NO_SCALING and dtype == jnp .bfloat16 :
21292136 return True , ""
21302137
21312138 if scaling_mode == ScalingMode .MXFP8_1D_SCALING :
2139+ if get_min_device_compute_capability () < 100 :
2140+ return (
2141+ False ,
2142+ (
2143+ "The TE V2 grouped GEMM for MXFP8 requires SM100+ (Blackwell or newer) but"
2144+ " current min device compute capability is"
2145+ f" { get_min_device_compute_capability ()} ."
2146+ ),
2147+ )
2148+
21322149 # V2 MXFP8 requires that the total first dimension of both operands (up to
21332150 # axis_boundary) is divisible by 128, matching the quantize V2 kernel requirement.
21342151 # Individual group sizes must also be 128-aligned (dynamic constraint).
@@ -2188,9 +2205,10 @@ def _is_v2_grouped_gemm_supported(
21882205 return (
21892206 False ,
21902207 (
2191- "The TE V2 grouped GEMM currently only supports non-quantized BF16 and MXFP8 with 1D"
2192- " block scaling, but NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled and the input"
2193- f" parameters do not meet these requirements (scaling_mode= { scaling_mode } ,"
2208+ "The TE V2 grouped GEMM currently only supports non-quantized BF16, and MXFP8 with"
2209+ " 1D block scaling on SM100+, but NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled and"
2210+ " the input parameters do not meet these requirements"
2211+ f" (scaling_mode= { scaling_mode } ,"
21942212 f" dtype={ dtype } , has_bias={ has_bias } , lhs_shape={ lhs_shape } , rhs_shape={ rhs_shape } ,"
21952213 f" lhs_axis_boundary={ lhs_axis_boundary } , rhs_axis_boundary={ rhs_axis_boundary } )."
21962214 ),
@@ -2390,6 +2408,35 @@ def _get_num_gemms(
23902408 )
23912409
23922410
2411+ def _add_grouped_gemm_bias (
2412+ out : jnp .ndarray ,
2413+ bias : jnp .ndarray ,
2414+ out_first_dims : Optional [jnp .ndarray ],
2415+ out_last_dims : Optional [jnp .ndarray ],
2416+ out_shape : Tuple [int , ...],
2417+ num_gemms : int ,
2418+ n_dim : int ,
2419+ ) -> jnp .ndarray :
2420+ """Add grouped GEMM bias in JAX for V2 kernels that do not fuse bias."""
2421+ if out_last_dims is not None :
2422+ raise NotImplementedError ("V2 grouped GEMM bias is not supported for ragged last dims" )
2423+
2424+ bias = bias .astype (out .dtype )
2425+ bias_2d = bias .reshape ((num_gemms , n_dim ))
2426+ if out_first_dims is not None :
2427+ out_2d = out .reshape ((- 1 , n_dim ))
2428+ bias_rows = jnp .repeat (
2429+ bias_2d ,
2430+ out_first_dims ,
2431+ axis = 0 ,
2432+ total_repeat_length = out_2d .shape [0 ],
2433+ )
2434+ return (out_2d + bias_rows ).reshape (out_shape )
2435+
2436+ bias_shape = (num_gemms ,) + (1 ,) * (out .ndim - 2 ) + (n_dim ,)
2437+ return out + bias_2d .reshape (bias_shape )
2438+
2439+
23932440def grouped_gemm (
23942441 lhs : Union [GroupedNoScaleTensor , GroupedScaledTensor1x ],
23952442 rhs : Union [GroupedNoScaleTensor , GroupedScaledTensor1x ],
@@ -2509,7 +2556,8 @@ def grouped_gemm(
25092556 num_gemms ,
25102557 N_dim ,
25112558 ), f"bias shape { bias .shape } does not match expected shape { (num_gemms , N_dim )} "
2512- bias = jnp .empty ((), jnp .float32 ) if bias is None else bias
2559+ else :
2560+ N_dim = 0
25132561
25142562 if group_offset is not None :
25152563 raise RuntimeError (
@@ -2538,18 +2586,23 @@ def grouped_gemm(
25382586 raise ValueError ("rhs must be pre-swizzled for MXFP8 1D scaling" )
25392587
25402588 if use_v2_ffi :
2541- additional_arg_0 = jnp .ones ((num_gemms ,), jnp .float32 ) # alpha
2542- additional_arg_1 = jnp .zeros ((num_gemms ,), jnp .float32 ) # beta
2589+ alpha_beta_numel = (
2590+ num_gemms if _v2_grouped_gemm_supports_per_group_alpha_beta () else 1
2591+ )
2592+ additional_arg_0 = jnp .ones ((alpha_beta_numel ,), jnp .float32 ) # alpha
2593+ additional_arg_1 = jnp .zeros ((alpha_beta_numel ,), jnp .float32 ) # beta
25432594 else :
25442595 additional_arg_0 = jnp .zeros ((1 ,), jnp .int32 ) # group_offset
25452596 additional_arg_1 = jnp .zeros ((0 ,), jnp .int32 ) # unused placeholder
2597+ bias_for_ffi = jnp .empty ((), jnp .float32 ) if (bias is None or use_v2_ffi ) else bias
2598+ has_bias_for_ffi = has_bias and not use_v2_ffi
25462599
25472600 (out ,) = GroupedGemmPrimitive .outer_primitive .bind (
25482601 lhs .data ,
25492602 lhs .scale_inv if isinstance (lhs , GroupedScaledTensor1x ) else jnp .empty ((0 ,), jnp .float32 ),
25502603 rhs .data ,
25512604 rhs .scale_inv if isinstance (rhs , GroupedScaledTensor1x ) else jnp .empty ((0 ,), jnp .float32 ),
2552- bias ,
2605+ bias_for_ffi ,
25532606 lhs .first_dims if lhs .first_dims is not None else empty_gs ,
25542607 lhs .last_dims if lhs .last_dims is not None else empty_gs ,
25552608 rhs .first_dims if rhs .first_dims is not None else empty_gs ,
@@ -2562,7 +2615,7 @@ def grouped_gemm(
25622615 rhs_is_trans = rhs_is_trans ,
25632616 scaling_mode = scaling_mode .value ,
25642617 out_dtype = out_dtype ,
2565- has_bias = has_bias ,
2618+ has_bias = has_bias_for_ffi ,
25662619 use_async_d2h_group_sizes = use_async_d2h_group_sizes ,
25672620 use_v2_ffi = use_v2_ffi ,
25682621 lhs_axis_boundary = lhs_axis_boundary ,
@@ -2573,4 +2626,8 @@ def grouped_gemm(
25732626 rhs_left_size = int (rhs_left_size ),
25742627 rhs_right_size = int (rhs_right_size ),
25752628 )
2629+ if use_v2_ffi and has_bias :
2630+ out = _add_grouped_gemm_bias (
2631+ out , bias , out_first_dims , out_last_dims , out_shape , num_gemms , N_dim
2632+ )
25762633 return out
0 commit comments