From 1347eaeff1e7d7aadeb231bc9cdfa64b2e60d02c Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Thu, 4 Jun 2026 15:36:02 -0700 Subject: [PATCH 1/5] Fix norm workspace on global shapes Signed-off-by: Jeremy Berchtold --- .../jax/cpp_extensions/normalization.py | 79 +++++++------------ 1 file changed, 30 insertions(+), 49 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 29292f946b..55bb42df28 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -120,7 +120,7 @@ def abstract( is_outer, ): """ - LayerNorm fwd inner primitive abstract + LayerNorm fwd primitive abstract. """ del amax_scope, transpose_batch_sequence assert not output_amax_when_no_scaling or ( @@ -196,6 +196,18 @@ def abstract( shape=colwise_scale_inv_shape, dtype=scale_dtype ) + outputs = ( + out_aval, + colwise_out_aval, + scale_inv_aval, + colwise_scale_inv_aval, + updated_amax_aval, + mu_aval, + rsigma_aval, + ) + if is_outer: + return outputs + (wkspace_info,) = transformer_engine_jax.get_norm_fwd_workspace_sizes( x_aval.size // gamma_aval.size, # batch size gamma_aval.size, # hidden size @@ -213,41 +225,7 @@ def abstract( shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) ) - return ( - out_aval, - colwise_out_aval, - scale_inv_aval, - colwise_scale_inv_aval, - updated_amax_aval, - mu_aval, - rsigma_aval, - wkspace_aval, - ) - - @staticmethod - def outer_abstract(*args, **kwargs): - """ - LayerNorm fwd outer primitive abstract - """ - ( - out_aval, - colwise_out_aval, - scale_inv_aval, - colwise_scale_inv_aval, - updated_amax_aval, - mu_aval, - rsigma_aval, - _, - ) = NormFwdPrimitive.abstract(*args, **kwargs) - return ( - out_aval, - colwise_out_aval, - scale_inv_aval, - colwise_scale_inv_aval, - updated_amax_aval, - mu_aval, - rsigma_aval, - ) + return (*outputs, wkspace_aval) @staticmethod def lowering( @@ -348,20 +326,10 @@ def impl( """ to describe implementation """ - del is_outer assert ( NormFwdPrimitive.inner_primitive is not None ), "NormFwdPrimitive.inner_primitive has not been registered" - ( - out, - colwise_out, - scale_inv, - colwise_scale_inv, - updated_amax, - mu, - rsigma, - _, - ) = NormFwdPrimitive.inner_primitive.bind( + outputs = NormFwdPrimitive.inner_primitive.bind( x, scale, amax, @@ -377,8 +345,21 @@ def impl( amax_scope=amax_scope, transpose_batch_sequence=transpose_batch_sequence, output_amax_when_no_scaling=output_amax_when_no_scaling, - is_outer=False, + is_outer=is_outer, ) + if is_outer: + out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, mu, rsigma = outputs + else: + ( + out, + colwise_out, + scale_inv, + colwise_scale_inv, + updated_amax, + mu, + rsigma, + _, + ) = outputs rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( scaling_mode ).get_scale_shape_2x(x.shape, is_padded=False) @@ -630,7 +611,7 @@ def sharded_impl(x, scale, amax, gamma, beta): amax_scope=amax_scope, transpose_batch_sequence=transpose_batch_sequence, output_amax_when_no_scaling=output_amax_when_no_scaling, - is_outer=True, + is_outer=False, ) if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: global_updated_amax = all_reduce_max_along_all_axes_except_PP( From c2bd57d7b70990d4ce86834dd7130de8b947ee3f Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Thu, 4 Jun 2026 15:46:54 -0700 Subject: [PATCH 2/5] Bwd Signed-off-by: Jeremy Berchtold --- .../jax/cpp_extensions/normalization.py | 70 ++++++++++++------- 1 file changed, 44 insertions(+), 26 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 55bb42df28..108979ec42 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -707,14 +707,23 @@ class NormBwdPrimitive(BasePrimitive): name = "te_norm_backward_ffi" multiple_results = True - impl_static_args = (5, 6) # norm_type, zero_centered_gamma + impl_static_args = (5, 6, 7) # norm_type, zero_centered_gamma, is_outer inner_primitive = None outer_primitive = None @staticmethod - def abstract(dz_aval, x_aval, mu_aval, rsigma_aval, gamma_aval, norm_type, zero_centered_gamma): + def abstract( + dz_aval, + x_aval, + mu_aval, + rsigma_aval, + gamma_aval, + norm_type, + zero_centered_gamma, + is_outer, + ): """ - bwd inner primitive abstract + bwd primitive abstract """ w_dtype = dtypes.canonicalize_dtype(gamma_aval.dtype) rsigma_dtype = dtypes.canonicalize_dtype(rsigma_aval.dtype) @@ -745,6 +754,10 @@ def abstract(dz_aval, x_aval, mu_aval, rsigma_aval, gamma_aval, norm_type, zero_ if norm_type != NVTE_Norm_Type.LayerNorm: dbeta_aval = dbeta_aval.update(shape=(1,)) + outputs = (dx_aval, dgamma_aval, dbeta_aval) + if is_outer: + return outputs + (wkspace_info,) = transformer_engine_jax.get_norm_bwd_workspace_sizes( x_aval.size // gamma_aval.size, # batch size gamma_aval.size, # hidden size @@ -758,26 +771,14 @@ def abstract(dz_aval, x_aval, mu_aval, rsigma_aval, gamma_aval, norm_type, zero_ shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) ) - return ( - dx_aval, - dgamma_aval, - dbeta_aval, - wkspace_aval, - ) - - @staticmethod - def outer_abstract(*args, **kwargs): - """ - LayerNorm bwd outer primitive abstract - """ - dx_aval, dgamma_aval, dbeta_aval, _ = NormBwdPrimitive.abstract(*args, **kwargs) - return dx_aval, dgamma_aval, dbeta_aval + return (*outputs, wkspace_aval) @staticmethod - def lowering(ctx, dz, x, mu, rsigma, gamma, *, norm_type, zero_centered_gamma): + def lowering(ctx, dz, x, mu, rsigma, gamma, *, norm_type, zero_centered_gamma, is_outer): """ bwd lowering rules """ + del is_outer g_type = ir.RankedTensorType(gamma.type) g_shape = g_type.shape b_type = ir.RankedTensorType(gamma.type) @@ -805,17 +806,28 @@ def lowering(ctx, dz, x, mu, rsigma, gamma, *, norm_type, zero_centered_gamma): ) @staticmethod - def impl(dz, x, mu, rsigma, gamma, norm_type, zero_centered_gamma): + def impl(dz, x, mu, rsigma, gamma, norm_type, zero_centered_gamma, is_outer): assert ( NormBwdPrimitive.inner_primitive is not None ), "NormBwdPrimitive.inner_primitive has not been registered" - dx, dgamma, dbeta, _ = NormBwdPrimitive.inner_primitive.bind( - dz, x, mu, rsigma, gamma, norm_type=norm_type, zero_centered_gamma=zero_centered_gamma + outputs = NormBwdPrimitive.inner_primitive.bind( + dz, + x, + mu, + rsigma, + gamma, + norm_type=norm_type, + zero_centered_gamma=zero_centered_gamma, + is_outer=is_outer, ) + if is_outer: + dx, dgamma, dbeta = outputs + else: + dx, dgamma, dbeta, _ = outputs return dx, dgamma, dbeta @staticmethod - def batcher(batched_args, batch_dims, *, norm_type, zero_centered_gamma): + def batcher(batched_args, batch_dims, *, norm_type, zero_centered_gamma, is_outer): check_valid_batch_dims(batch_dims) assert ( NormBwdPrimitive.outer_primitive is not None @@ -833,13 +845,16 @@ def batcher(batched_args, batch_dims, *, norm_type, zero_centered_gamma): gamma, norm_type=norm_type, zero_centered_gamma=zero_centered_gamma, + is_outer=is_outer, ), out_bdims, ) @staticmethod - def infer_sharding_from_operands(norm_type, zero_centered_gamma, mesh, arg_infos, result_infos): - del norm_type, zero_centered_gamma, result_infos + def infer_sharding_from_operands( + norm_type, zero_centered_gamma, is_outer, mesh, arg_infos, result_infos + ): + del norm_type, zero_centered_gamma, is_outer, result_infos x_spec = get_padded_spec(arg_infos[1]) if x_spec[-1] is not None: warnings.warn( @@ -864,8 +879,8 @@ def infer_sharding_from_operands(norm_type, zero_centered_gamma, mesh, arg_infos return dx_sharding, dgamma_sharding, dbeta_sharding @staticmethod - def partition(norm_type, zero_centered_gamma, mesh, arg_infos, result_infos): - del result_infos + def partition(norm_type, zero_centered_gamma, is_outer, mesh, arg_infos, result_infos): + del result_infos, is_outer x_spec = get_padded_spec(arg_infos[1]) if x_spec[-1] is not None: warnings.warn( @@ -912,6 +927,7 @@ def sharded_impl(dz, x, mu, rsigma, gamma): gamma, norm_type=norm_type, zero_centered_gamma=zero_centered_gamma, + is_outer=False, ) global_dgamma = all_reduce_sum_along_dp_fsdp_tpsp(local_dgamma, mesh) if norm_type == NVTE_Norm_Type.LayerNorm: @@ -1236,6 +1252,7 @@ def layernorm_bwd( gamma, norm_type=NVTE_Norm_Type.LayerNorm, zero_centered_gamma=zero_centered_gamma, + is_outer=True, ) @@ -1481,6 +1498,7 @@ def rmsnorm_bwd( gamma, norm_type=NVTE_Norm_Type.RMSNorm, zero_centered_gamma=zero_centered_gamma, + is_outer=True, ) return (dx, dgamma) From 490f77050eef8fca0ea39c8765a184d678d06a4a Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Date: Thu, 4 Jun 2026 15:49:06 -0700 Subject: [PATCH 3/5] Apply suggestions from code review Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> --- .../jax/cpp_extensions/normalization.py | 31 ++++++++++--------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 108979ec42..0cb0178e79 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -336,6 +336,22 @@ def impl( gamma, beta, norm_type=norm_type, + ( + out, + colwise_out, + scale_inv, + colwise_scale_inv, + updated_amax, + mu, + rsigma, + _, + ) = NormFwdPrimitive.inner_primitive.bind( + x, + scale, + amax, + gamma, + beta, + norm_type=norm_type, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon, out_dtype=out_dtype, @@ -345,21 +361,8 @@ def impl( amax_scope=amax_scope, transpose_batch_sequence=transpose_batch_sequence, output_amax_when_no_scaling=output_amax_when_no_scaling, - is_outer=is_outer, + is_outer=False, # inner_primitive always emits 8 outputs (incl. workspace) ) - if is_outer: - out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, mu, rsigma = outputs - else: - ( - out, - colwise_out, - scale_inv, - colwise_scale_inv, - updated_amax, - mu, - rsigma, - _, - ) = outputs rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( scaling_mode ).get_scale_shape_2x(x.shape, is_padded=False) From 3597e76bc560f78351847e87c70861d11bbd0eef Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Thu, 4 Jun 2026 16:00:00 -0700 Subject: [PATCH 4/5] fixes Signed-off-by: Jeremy Berchtold --- .../jax/cpp_extensions/normalization.py | 43 +++++++++++-------- 1 file changed, 26 insertions(+), 17 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 0cb0178e79..d522ce03aa 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -336,22 +336,6 @@ def impl( gamma, beta, norm_type=norm_type, - ( - out, - colwise_out, - scale_inv, - colwise_scale_inv, - updated_amax, - mu, - rsigma, - _, - ) = NormFwdPrimitive.inner_primitive.bind( - x, - scale, - amax, - gamma, - beta, - norm_type=norm_type, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon, out_dtype=out_dtype, @@ -361,8 +345,21 @@ def impl( amax_scope=amax_scope, transpose_batch_sequence=transpose_batch_sequence, output_amax_when_no_scaling=output_amax_when_no_scaling, - is_outer=False, # inner_primitive always emits 8 outputs (incl. workspace) + is_outer=is_outer, ) + if is_outer: + out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, mu, rsigma = outputs + else: + ( + out, + colwise_out, + scale_inv, + colwise_scale_inv, + updated_amax, + mu, + rsigma, + _, + ) = outputs rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( scaling_mode ).get_scale_shape_2x(x.shape, is_padded=False) @@ -384,6 +381,12 @@ def impl( rsigma, ) # Exclude wkspace + @staticmethod + def outer_impl(*args, **kwargs): + kwargs = dict(kwargs) + kwargs["is_outer"] = False + return NormFwdPrimitive.impl(*args, **kwargs) + @staticmethod def batcher( batched_args, @@ -829,6 +832,12 @@ def impl(dz, x, mu, rsigma, gamma, norm_type, zero_centered_gamma, is_outer): dx, dgamma, dbeta, _ = outputs return dx, dgamma, dbeta + @staticmethod + def outer_impl(*args, **kwargs): + kwargs = dict(kwargs) + kwargs["is_outer"] = False + return NormBwdPrimitive.impl(*args, **kwargs) + @staticmethod def batcher(batched_args, batch_dims, *, norm_type, zero_centered_gamma, is_outer): check_valid_batch_dims(batch_dims) From 241d8214a35bea4e743e6dbfc69edbde7e729342 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Fri, 5 Jun 2026 07:05:12 -0700 Subject: [PATCH 5/5] Fix single-GPU eager case Signed-off-by: Jeremy Berchtold --- .../jax/cpp_extensions/normalization.py | 242 +++++++++++++----- 1 file changed, 183 insertions(+), 59 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index d522ce03aa..dd2771afcc 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -31,6 +31,7 @@ from ..sharding import ( all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp_tpsp, + get_num_devices_in_mesh, ) from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor from ..quantize import ( @@ -88,6 +89,14 @@ def is_norm_zero_centered_gamma_in_weight_dtype(scaling_mode: ScalingMode) -> bo FUSED_MXFP8_NORM_CUDNN_MIN_VERSION = (9, 10, 0) +def _is_custom_partitioning_outer_trace(is_outer): + # custom_partitioning traces cls.impl once with global avals before partition() lowers the + # per-shard implementation. That trace must not bind the single-GPU inner primitive, because + # the inner primitive sizes temporary workspace from its input avals. For eager and + # single-device JIT paths this jaxpr is executed directly, so those paths must call inner. + return is_outer and get_num_devices_in_mesh() > 1 + + class NormFwdPrimitive(BasePrimitive): """ Layer Normalization Forward FP8 Primitive @@ -120,9 +129,67 @@ def abstract( is_outer, ): """ - LayerNorm fwd primitive abstract. + LayerNorm fwd inner primitive abstract. """ - del amax_scope, transpose_batch_sequence + outputs = NormFwdPrimitive._abstract_outputs( + x_aval, + scale_aval, + amax_aval, + gamma_aval, + beta_aval, + norm_type=norm_type, + zero_centered_gamma=zero_centered_gamma, + epsilon=epsilon, + out_dtype=out_dtype, + scaling_mode=scaling_mode, + quantize_layout=quantize_layout, + scale_dtype=scale_dtype, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, + is_outer=is_outer, + ) + + (wkspace_info,) = transformer_engine_jax.get_norm_fwd_workspace_sizes( + x_aval.size // gamma_aval.size, # batch size + gamma_aval.size, # hidden size + jax_dtype_to_te_dtype(x_aval.dtype), # itype + jax_dtype_to_te_dtype(gamma_aval.dtype), # wtype + jax_dtype_to_te_dtype(out_dtype), + norm_type, + scaling_mode, + zero_centered_gamma, + epsilon, + get_forward_sm_margin(), + True, # is_training + ) + wkspace_aval = jax.core.ShapedArray( + shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) + ) + + return (*outputs, wkspace_aval) + + @staticmethod + def _abstract_outputs( + x_aval, + scale_aval, + amax_aval, + gamma_aval, + beta_aval, + *, + norm_type, + zero_centered_gamma, + epsilon, + out_dtype, + scaling_mode, + quantize_layout, + scale_dtype, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, + is_outer, + ): + del amax_scope, transpose_batch_sequence, zero_centered_gamma, epsilon assert not output_amax_when_no_scaling or ( scaling_mode == ScalingMode.NO_SCALING.value and not is_norm_fwd_cudnn_enabled(scaling_mode) @@ -205,27 +272,41 @@ def abstract( mu_aval, rsigma_aval, ) - if is_outer: - return outputs + return outputs - (wkspace_info,) = transformer_engine_jax.get_norm_fwd_workspace_sizes( - x_aval.size // gamma_aval.size, # batch size - gamma_aval.size, # hidden size - jax_dtype_to_te_dtype(x_aval.dtype), # itype - jax_dtype_to_te_dtype(gamma_aval.dtype), # wtype - jax_dtype_to_te_dtype(out_dtype), - norm_type, - scaling_mode, - zero_centered_gamma, - epsilon, - get_forward_sm_margin(), - True, # is_training - ) - wkspace_aval = jax.core.ShapedArray( - shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) - ) + @staticmethod + def outer_abstract(*args, **kwargs): + """ + LayerNorm fwd outer primitive abstract. + """ + return NormFwdPrimitive._abstract_outputs(*args, **kwargs) - return (*outputs, wkspace_aval) + @staticmethod + def _custom_partitioning_trace_outputs( + x, + *, + norm_type, + out_dtype, + scaling_mode, + quantize_layout, + scale_dtype, + ): + out = jnp.empty(x.shape, dtype=out_dtype) + colwise_out_shape = x.shape if quantize_layout.has_colwise else (1,) + colwise_out = jnp.empty(colwise_out_shape, dtype=out_dtype) + + rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( + scaling_mode + ).get_scale_shape_2x(x.shape, is_padded=False) + scale_inv = jnp.empty(rowwise_scale_inv_shape, dtype=scale_dtype) + colwise_scale_inv_shape = colwise_scale_inv_shape if quantize_layout.has_colwise else (1,) + colwise_scale_inv = jnp.empty(colwise_scale_inv_shape, dtype=scale_dtype) + + updated_amax = jnp.empty((1,), dtype=jnp.float32) + mu_shape = (1,) if norm_type == NVTE_Norm_Type.RMSNorm else x.shape[:-1] + mu = jnp.empty(mu_shape, dtype=jnp.float32) + rsigma = jnp.empty(x.shape[:-1], dtype=jnp.float32) + return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, mu, rsigma @staticmethod def lowering( @@ -326,10 +407,29 @@ def impl( """ to describe implementation """ + if _is_custom_partitioning_outer_trace(is_outer): + return NormFwdPrimitive._custom_partitioning_trace_outputs( + x, + norm_type=norm_type, + out_dtype=out_dtype, + scaling_mode=scaling_mode, + quantize_layout=quantize_layout, + scale_dtype=scale_dtype, + ) + assert ( NormFwdPrimitive.inner_primitive is not None ), "NormFwdPrimitive.inner_primitive has not been registered" - outputs = NormFwdPrimitive.inner_primitive.bind( + ( + out, + colwise_out, + scale_inv, + colwise_scale_inv, + updated_amax, + mu, + rsigma, + _, + ) = NormFwdPrimitive.inner_primitive.bind( x, scale, amax, @@ -345,21 +445,8 @@ def impl( amax_scope=amax_scope, transpose_batch_sequence=transpose_batch_sequence, output_amax_when_no_scaling=output_amax_when_no_scaling, - is_outer=is_outer, + is_outer=False, ) - if is_outer: - out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, mu, rsigma = outputs - else: - ( - out, - colwise_out, - scale_inv, - colwise_scale_inv, - updated_amax, - mu, - rsigma, - _, - ) = outputs rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( scaling_mode ).get_scale_shape_2x(x.shape, is_padded=False) @@ -729,8 +816,46 @@ def abstract( is_outer, ): """ - bwd primitive abstract + bwd inner primitive abstract """ + outputs = NormBwdPrimitive._abstract_outputs( + dz_aval, + x_aval, + mu_aval, + rsigma_aval, + gamma_aval, + norm_type, + zero_centered_gamma, + is_outer, + ) + + (wkspace_info,) = transformer_engine_jax.get_norm_bwd_workspace_sizes( + x_aval.size // gamma_aval.size, # batch size + gamma_aval.size, # hidden size + jax_dtype_to_te_dtype(x_aval.dtype), # input te_dtype + jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype + norm_type, + zero_centered_gamma, + get_backward_sm_margin(), + ) + wkspace_aval = outputs[0].update( + shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) + ) + + return (*outputs, wkspace_aval) + + @staticmethod + def _abstract_outputs( + dz_aval, + x_aval, + mu_aval, + rsigma_aval, + gamma_aval, + norm_type, + zero_centered_gamma, + is_outer, + ): + del zero_centered_gamma, is_outer w_dtype = dtypes.canonicalize_dtype(gamma_aval.dtype) rsigma_dtype = dtypes.canonicalize_dtype(rsigma_aval.dtype) @@ -760,24 +885,22 @@ def abstract( if norm_type != NVTE_Norm_Type.LayerNorm: dbeta_aval = dbeta_aval.update(shape=(1,)) - outputs = (dx_aval, dgamma_aval, dbeta_aval) - if is_outer: - return outputs + return dx_aval, dgamma_aval, dbeta_aval - (wkspace_info,) = transformer_engine_jax.get_norm_bwd_workspace_sizes( - x_aval.size // gamma_aval.size, # batch size - gamma_aval.size, # hidden size - jax_dtype_to_te_dtype(x_aval.dtype), # input te_dtype - jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype - norm_type, - zero_centered_gamma, - get_backward_sm_margin(), - ) - wkspace_aval = dx_aval.update( - shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) - ) + @staticmethod + def outer_abstract(*args, **kwargs): + """ + bwd outer primitive abstract + """ + return NormBwdPrimitive._abstract_outputs(*args, **kwargs) - return (*outputs, wkspace_aval) + @staticmethod + def _custom_partitioning_trace_outputs(dz, gamma, *, norm_type): + dx = jnp.empty_like(dz) + dgamma = jnp.empty_like(gamma) + dbeta_shape = gamma.shape if norm_type == NVTE_Norm_Type.LayerNorm else (1,) + dbeta = jnp.empty(dbeta_shape, dtype=gamma.dtype) + return dx, dgamma, dbeta @staticmethod def lowering(ctx, dz, x, mu, rsigma, gamma, *, norm_type, zero_centered_gamma, is_outer): @@ -813,10 +936,15 @@ def lowering(ctx, dz, x, mu, rsigma, gamma, *, norm_type, zero_centered_gamma, i @staticmethod def impl(dz, x, mu, rsigma, gamma, norm_type, zero_centered_gamma, is_outer): + if _is_custom_partitioning_outer_trace(is_outer): + return NormBwdPrimitive._custom_partitioning_trace_outputs( + dz, gamma, norm_type=norm_type + ) + assert ( NormBwdPrimitive.inner_primitive is not None ), "NormBwdPrimitive.inner_primitive has not been registered" - outputs = NormBwdPrimitive.inner_primitive.bind( + dx, dgamma, dbeta, _ = NormBwdPrimitive.inner_primitive.bind( dz, x, mu, @@ -824,12 +952,8 @@ def impl(dz, x, mu, rsigma, gamma, norm_type, zero_centered_gamma, is_outer): gamma, norm_type=norm_type, zero_centered_gamma=zero_centered_gamma, - is_outer=is_outer, + is_outer=False, ) - if is_outer: - dx, dgamma, dbeta = outputs - else: - dx, dgamma, dbeta, _ = outputs return dx, dgamma, dbeta @staticmethod