diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 29292f946b..dd2771afcc 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -31,6 +31,7 @@ from ..sharding import ( all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp_tpsp, + get_num_devices_in_mesh, ) from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor from ..quantize import ( @@ -88,6 +89,14 @@ def is_norm_zero_centered_gamma_in_weight_dtype(scaling_mode: ScalingMode) -> bo FUSED_MXFP8_NORM_CUDNN_MIN_VERSION = (9, 10, 0) +def _is_custom_partitioning_outer_trace(is_outer): + # custom_partitioning traces cls.impl once with global avals before partition() lowers the + # per-shard implementation. That trace must not bind the single-GPU inner primitive, because + # the inner primitive sizes temporary workspace from its input avals. For eager and + # single-device JIT paths this jaxpr is executed directly, so those paths must call inner. + return is_outer and get_num_devices_in_mesh() > 1 + + class NormFwdPrimitive(BasePrimitive): """ Layer Normalization Forward FP8 Primitive @@ -120,9 +129,67 @@ def abstract( is_outer, ): """ - LayerNorm fwd inner primitive abstract + LayerNorm fwd inner primitive abstract. """ - del amax_scope, transpose_batch_sequence + outputs = NormFwdPrimitive._abstract_outputs( + x_aval, + scale_aval, + amax_aval, + gamma_aval, + beta_aval, + norm_type=norm_type, + zero_centered_gamma=zero_centered_gamma, + epsilon=epsilon, + out_dtype=out_dtype, + scaling_mode=scaling_mode, + quantize_layout=quantize_layout, + scale_dtype=scale_dtype, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, + is_outer=is_outer, + ) + + (wkspace_info,) = transformer_engine_jax.get_norm_fwd_workspace_sizes( + x_aval.size // gamma_aval.size, # batch size + gamma_aval.size, # hidden size + jax_dtype_to_te_dtype(x_aval.dtype), # itype + jax_dtype_to_te_dtype(gamma_aval.dtype), # wtype + jax_dtype_to_te_dtype(out_dtype), + norm_type, + scaling_mode, + zero_centered_gamma, + epsilon, + get_forward_sm_margin(), + True, # is_training + ) + wkspace_aval = jax.core.ShapedArray( + shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) + ) + + return (*outputs, wkspace_aval) + + @staticmethod + def _abstract_outputs( + x_aval, + scale_aval, + amax_aval, + gamma_aval, + beta_aval, + *, + norm_type, + zero_centered_gamma, + epsilon, + out_dtype, + scaling_mode, + quantize_layout, + scale_dtype, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, + is_outer, + ): + del amax_scope, transpose_batch_sequence, zero_centered_gamma, epsilon assert not output_amax_when_no_scaling or ( scaling_mode == ScalingMode.NO_SCALING.value and not is_norm_fwd_cudnn_enabled(scaling_mode) @@ -196,24 +263,7 @@ def abstract( shape=colwise_scale_inv_shape, dtype=scale_dtype ) - (wkspace_info,) = transformer_engine_jax.get_norm_fwd_workspace_sizes( - x_aval.size // gamma_aval.size, # batch size - gamma_aval.size, # hidden size - jax_dtype_to_te_dtype(x_aval.dtype), # itype - jax_dtype_to_te_dtype(gamma_aval.dtype), # wtype - jax_dtype_to_te_dtype(out_dtype), - norm_type, - scaling_mode, - zero_centered_gamma, - epsilon, - get_forward_sm_margin(), - True, # is_training - ) - wkspace_aval = jax.core.ShapedArray( - shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) - ) - - return ( + outputs = ( out_aval, colwise_out_aval, scale_inv_aval, @@ -221,33 +271,42 @@ def abstract( updated_amax_aval, mu_aval, rsigma_aval, - wkspace_aval, ) + return outputs @staticmethod def outer_abstract(*args, **kwargs): """ - LayerNorm fwd outer primitive abstract + LayerNorm fwd outer primitive abstract. """ - ( - out_aval, - colwise_out_aval, - scale_inv_aval, - colwise_scale_inv_aval, - updated_amax_aval, - mu_aval, - rsigma_aval, - _, - ) = NormFwdPrimitive.abstract(*args, **kwargs) - return ( - out_aval, - colwise_out_aval, - scale_inv_aval, - colwise_scale_inv_aval, - updated_amax_aval, - mu_aval, - rsigma_aval, - ) + return NormFwdPrimitive._abstract_outputs(*args, **kwargs) + + @staticmethod + def _custom_partitioning_trace_outputs( + x, + *, + norm_type, + out_dtype, + scaling_mode, + quantize_layout, + scale_dtype, + ): + out = jnp.empty(x.shape, dtype=out_dtype) + colwise_out_shape = x.shape if quantize_layout.has_colwise else (1,) + colwise_out = jnp.empty(colwise_out_shape, dtype=out_dtype) + + rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( + scaling_mode + ).get_scale_shape_2x(x.shape, is_padded=False) + scale_inv = jnp.empty(rowwise_scale_inv_shape, dtype=scale_dtype) + colwise_scale_inv_shape = colwise_scale_inv_shape if quantize_layout.has_colwise else (1,) + colwise_scale_inv = jnp.empty(colwise_scale_inv_shape, dtype=scale_dtype) + + updated_amax = jnp.empty((1,), dtype=jnp.float32) + mu_shape = (1,) if norm_type == NVTE_Norm_Type.RMSNorm else x.shape[:-1] + mu = jnp.empty(mu_shape, dtype=jnp.float32) + rsigma = jnp.empty(x.shape[:-1], dtype=jnp.float32) + return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, mu, rsigma @staticmethod def lowering( @@ -348,7 +407,16 @@ def impl( """ to describe implementation """ - del is_outer + if _is_custom_partitioning_outer_trace(is_outer): + return NormFwdPrimitive._custom_partitioning_trace_outputs( + x, + norm_type=norm_type, + out_dtype=out_dtype, + scaling_mode=scaling_mode, + quantize_layout=quantize_layout, + scale_dtype=scale_dtype, + ) + assert ( NormFwdPrimitive.inner_primitive is not None ), "NormFwdPrimitive.inner_primitive has not been registered" @@ -400,6 +468,12 @@ def impl( rsigma, ) # Exclude wkspace + @staticmethod + def outer_impl(*args, **kwargs): + kwargs = dict(kwargs) + kwargs["is_outer"] = False + return NormFwdPrimitive.impl(*args, **kwargs) + @staticmethod def batcher( batched_args, @@ -630,7 +704,7 @@ def sharded_impl(x, scale, amax, gamma, beta): amax_scope=amax_scope, transpose_batch_sequence=transpose_batch_sequence, output_amax_when_no_scaling=output_amax_when_no_scaling, - is_outer=True, + is_outer=False, ) if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: global_updated_amax = all_reduce_max_along_all_axes_except_PP( @@ -726,15 +800,62 @@ class NormBwdPrimitive(BasePrimitive): name = "te_norm_backward_ffi" multiple_results = True - impl_static_args = (5, 6) # norm_type, zero_centered_gamma + impl_static_args = (5, 6, 7) # norm_type, zero_centered_gamma, is_outer inner_primitive = None outer_primitive = None @staticmethod - def abstract(dz_aval, x_aval, mu_aval, rsigma_aval, gamma_aval, norm_type, zero_centered_gamma): + def abstract( + dz_aval, + x_aval, + mu_aval, + rsigma_aval, + gamma_aval, + norm_type, + zero_centered_gamma, + is_outer, + ): """ bwd inner primitive abstract """ + outputs = NormBwdPrimitive._abstract_outputs( + dz_aval, + x_aval, + mu_aval, + rsigma_aval, + gamma_aval, + norm_type, + zero_centered_gamma, + is_outer, + ) + + (wkspace_info,) = transformer_engine_jax.get_norm_bwd_workspace_sizes( + x_aval.size // gamma_aval.size, # batch size + gamma_aval.size, # hidden size + jax_dtype_to_te_dtype(x_aval.dtype), # input te_dtype + jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype + norm_type, + zero_centered_gamma, + get_backward_sm_margin(), + ) + wkspace_aval = outputs[0].update( + shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) + ) + + return (*outputs, wkspace_aval) + + @staticmethod + def _abstract_outputs( + dz_aval, + x_aval, + mu_aval, + rsigma_aval, + gamma_aval, + norm_type, + zero_centered_gamma, + is_outer, + ): + del zero_centered_gamma, is_outer w_dtype = dtypes.canonicalize_dtype(gamma_aval.dtype) rsigma_dtype = dtypes.canonicalize_dtype(rsigma_aval.dtype) @@ -764,39 +885,29 @@ def abstract(dz_aval, x_aval, mu_aval, rsigma_aval, gamma_aval, norm_type, zero_ if norm_type != NVTE_Norm_Type.LayerNorm: dbeta_aval = dbeta_aval.update(shape=(1,)) - (wkspace_info,) = transformer_engine_jax.get_norm_bwd_workspace_sizes( - x_aval.size // gamma_aval.size, # batch size - gamma_aval.size, # hidden size - jax_dtype_to_te_dtype(x_aval.dtype), # input te_dtype - jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype - norm_type, - zero_centered_gamma, - get_backward_sm_margin(), - ) - wkspace_aval = dx_aval.update( - shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) - ) - - return ( - dx_aval, - dgamma_aval, - dbeta_aval, - wkspace_aval, - ) + return dx_aval, dgamma_aval, dbeta_aval @staticmethod def outer_abstract(*args, **kwargs): """ - LayerNorm bwd outer primitive abstract + bwd outer primitive abstract """ - dx_aval, dgamma_aval, dbeta_aval, _ = NormBwdPrimitive.abstract(*args, **kwargs) - return dx_aval, dgamma_aval, dbeta_aval + return NormBwdPrimitive._abstract_outputs(*args, **kwargs) + + @staticmethod + def _custom_partitioning_trace_outputs(dz, gamma, *, norm_type): + dx = jnp.empty_like(dz) + dgamma = jnp.empty_like(gamma) + dbeta_shape = gamma.shape if norm_type == NVTE_Norm_Type.LayerNorm else (1,) + dbeta = jnp.empty(dbeta_shape, dtype=gamma.dtype) + return dx, dgamma, dbeta @staticmethod - def lowering(ctx, dz, x, mu, rsigma, gamma, *, norm_type, zero_centered_gamma): + def lowering(ctx, dz, x, mu, rsigma, gamma, *, norm_type, zero_centered_gamma, is_outer): """ bwd lowering rules """ + del is_outer g_type = ir.RankedTensorType(gamma.type) g_shape = g_type.shape b_type = ir.RankedTensorType(gamma.type) @@ -824,17 +935,35 @@ def lowering(ctx, dz, x, mu, rsigma, gamma, *, norm_type, zero_centered_gamma): ) @staticmethod - def impl(dz, x, mu, rsigma, gamma, norm_type, zero_centered_gamma): + def impl(dz, x, mu, rsigma, gamma, norm_type, zero_centered_gamma, is_outer): + if _is_custom_partitioning_outer_trace(is_outer): + return NormBwdPrimitive._custom_partitioning_trace_outputs( + dz, gamma, norm_type=norm_type + ) + assert ( NormBwdPrimitive.inner_primitive is not None ), "NormBwdPrimitive.inner_primitive has not been registered" dx, dgamma, dbeta, _ = NormBwdPrimitive.inner_primitive.bind( - dz, x, mu, rsigma, gamma, norm_type=norm_type, zero_centered_gamma=zero_centered_gamma + dz, + x, + mu, + rsigma, + gamma, + norm_type=norm_type, + zero_centered_gamma=zero_centered_gamma, + is_outer=False, ) return dx, dgamma, dbeta @staticmethod - def batcher(batched_args, batch_dims, *, norm_type, zero_centered_gamma): + def outer_impl(*args, **kwargs): + kwargs = dict(kwargs) + kwargs["is_outer"] = False + return NormBwdPrimitive.impl(*args, **kwargs) + + @staticmethod + def batcher(batched_args, batch_dims, *, norm_type, zero_centered_gamma, is_outer): check_valid_batch_dims(batch_dims) assert ( NormBwdPrimitive.outer_primitive is not None @@ -852,13 +981,16 @@ def batcher(batched_args, batch_dims, *, norm_type, zero_centered_gamma): gamma, norm_type=norm_type, zero_centered_gamma=zero_centered_gamma, + is_outer=is_outer, ), out_bdims, ) @staticmethod - def infer_sharding_from_operands(norm_type, zero_centered_gamma, mesh, arg_infos, result_infos): - del norm_type, zero_centered_gamma, result_infos + def infer_sharding_from_operands( + norm_type, zero_centered_gamma, is_outer, mesh, arg_infos, result_infos + ): + del norm_type, zero_centered_gamma, is_outer, result_infos x_spec = get_padded_spec(arg_infos[1]) if x_spec[-1] is not None: warnings.warn( @@ -883,8 +1015,8 @@ def infer_sharding_from_operands(norm_type, zero_centered_gamma, mesh, arg_infos return dx_sharding, dgamma_sharding, dbeta_sharding @staticmethod - def partition(norm_type, zero_centered_gamma, mesh, arg_infos, result_infos): - del result_infos + def partition(norm_type, zero_centered_gamma, is_outer, mesh, arg_infos, result_infos): + del result_infos, is_outer x_spec = get_padded_spec(arg_infos[1]) if x_spec[-1] is not None: warnings.warn( @@ -931,6 +1063,7 @@ def sharded_impl(dz, x, mu, rsigma, gamma): gamma, norm_type=norm_type, zero_centered_gamma=zero_centered_gamma, + is_outer=False, ) global_dgamma = all_reduce_sum_along_dp_fsdp_tpsp(local_dgamma, mesh) if norm_type == NVTE_Norm_Type.LayerNorm: @@ -1255,6 +1388,7 @@ def layernorm_bwd( gamma, norm_type=NVTE_Norm_Type.LayerNorm, zero_centered_gamma=zero_centered_gamma, + is_outer=True, ) @@ -1500,6 +1634,7 @@ def rmsnorm_bwd( gamma, norm_type=NVTE_Norm_Type.RMSNorm, zero_centered_gamma=zero_centered_gamma, + is_outer=True, ) return (dx, dgamma)