Skip to content
Draft
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 67 additions & 65 deletions transformer_engine/jax/cpp_extensions/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def abstract(
is_outer,
):
"""
LayerNorm fwd inner primitive abstract
LayerNorm fwd primitive abstract.
"""
del amax_scope, transpose_batch_sequence
assert not output_amax_when_no_scaling or (
Expand Down Expand Up @@ -196,6 +196,18 @@ def abstract(
shape=colwise_scale_inv_shape, dtype=scale_dtype
)

outputs = (
out_aval,
colwise_out_aval,
scale_inv_aval,
colwise_scale_inv_aval,
updated_amax_aval,
mu_aval,
rsigma_aval,
)
if is_outer:
return outputs

(wkspace_info,) = transformer_engine_jax.get_norm_fwd_workspace_sizes(
x_aval.size // gamma_aval.size, # batch size
gamma_aval.size, # hidden size
Expand All @@ -213,41 +225,7 @@ def abstract(
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
)

return (
out_aval,
colwise_out_aval,
scale_inv_aval,
colwise_scale_inv_aval,
updated_amax_aval,
mu_aval,
rsigma_aval,
wkspace_aval,
)

@staticmethod
def outer_abstract(*args, **kwargs):
"""
LayerNorm fwd outer primitive abstract
"""
(
out_aval,
colwise_out_aval,
scale_inv_aval,
colwise_scale_inv_aval,
updated_amax_aval,
mu_aval,
rsigma_aval,
_,
) = NormFwdPrimitive.abstract(*args, **kwargs)
return (
out_aval,
colwise_out_aval,
scale_inv_aval,
colwise_scale_inv_aval,
updated_amax_aval,
mu_aval,
rsigma_aval,
)
return (*outputs, wkspace_aval)

@staticmethod
def lowering(
Expand Down Expand Up @@ -348,10 +326,16 @@ def impl(
"""
to describe implementation
"""
del is_outer
assert (
NormFwdPrimitive.inner_primitive is not None
), "NormFwdPrimitive.inner_primitive has not been registered"
outputs = NormFwdPrimitive.inner_primitive.bind(
x,
scale,
amax,
gamma,
beta,
norm_type=norm_type,
(
out,
colwise_out,
Expand All @@ -377,7 +361,7 @@ def impl(
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=output_amax_when_no_scaling,
is_outer=False,
is_outer=False, # inner_primitive always emits 8 outputs (incl. workspace)
)
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
scaling_mode
Expand Down Expand Up @@ -630,7 +614,7 @@ def sharded_impl(x, scale, amax, gamma, beta):
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=output_amax_when_no_scaling,
is_outer=True,
is_outer=False,
)
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
global_updated_amax = all_reduce_max_along_all_axes_except_PP(
Expand Down Expand Up @@ -726,14 +710,23 @@ class NormBwdPrimitive(BasePrimitive):

name = "te_norm_backward_ffi"
multiple_results = True
impl_static_args = (5, 6) # norm_type, zero_centered_gamma
impl_static_args = (5, 6, 7) # norm_type, zero_centered_gamma, is_outer
inner_primitive = None
outer_primitive = None

@staticmethod
def abstract(dz_aval, x_aval, mu_aval, rsigma_aval, gamma_aval, norm_type, zero_centered_gamma):
def abstract(
dz_aval,
x_aval,
mu_aval,
rsigma_aval,
gamma_aval,
norm_type,
zero_centered_gamma,
is_outer,
):
"""
bwd inner primitive abstract
bwd primitive abstract
"""
w_dtype = dtypes.canonicalize_dtype(gamma_aval.dtype)
rsigma_dtype = dtypes.canonicalize_dtype(rsigma_aval.dtype)
Expand Down Expand Up @@ -764,6 +757,10 @@ def abstract(dz_aval, x_aval, mu_aval, rsigma_aval, gamma_aval, norm_type, zero_
if norm_type != NVTE_Norm_Type.LayerNorm:
dbeta_aval = dbeta_aval.update(shape=(1,))

outputs = (dx_aval, dgamma_aval, dbeta_aval)
if is_outer:
return outputs

(wkspace_info,) = transformer_engine_jax.get_norm_bwd_workspace_sizes(
x_aval.size // gamma_aval.size, # batch size
gamma_aval.size, # hidden size
Expand All @@ -777,26 +774,14 @@ def abstract(dz_aval, x_aval, mu_aval, rsigma_aval, gamma_aval, norm_type, zero_
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
)

return (
dx_aval,
dgamma_aval,
dbeta_aval,
wkspace_aval,
)
return (*outputs, wkspace_aval)

@staticmethod
def outer_abstract(*args, **kwargs):
"""
LayerNorm bwd outer primitive abstract
"""
dx_aval, dgamma_aval, dbeta_aval, _ = NormBwdPrimitive.abstract(*args, **kwargs)
return dx_aval, dgamma_aval, dbeta_aval

@staticmethod
def lowering(ctx, dz, x, mu, rsigma, gamma, *, norm_type, zero_centered_gamma):
def lowering(ctx, dz, x, mu, rsigma, gamma, *, norm_type, zero_centered_gamma, is_outer):
"""
bwd lowering rules
"""
del is_outer
g_type = ir.RankedTensorType(gamma.type)
g_shape = g_type.shape
b_type = ir.RankedTensorType(gamma.type)
Expand Down Expand Up @@ -824,17 +809,28 @@ def lowering(ctx, dz, x, mu, rsigma, gamma, *, norm_type, zero_centered_gamma):
)

@staticmethod
def impl(dz, x, mu, rsigma, gamma, norm_type, zero_centered_gamma):
def impl(dz, x, mu, rsigma, gamma, norm_type, zero_centered_gamma, is_outer):
assert (
NormBwdPrimitive.inner_primitive is not None
), "NormBwdPrimitive.inner_primitive has not been registered"
dx, dgamma, dbeta, _ = NormBwdPrimitive.inner_primitive.bind(
dz, x, mu, rsigma, gamma, norm_type=norm_type, zero_centered_gamma=zero_centered_gamma
outputs = NormBwdPrimitive.inner_primitive.bind(
dz,
x,
mu,
rsigma,
gamma,
norm_type=norm_type,
zero_centered_gamma=zero_centered_gamma,
is_outer=is_outer,
)
if is_outer:
dx, dgamma, dbeta = outputs
else:
dx, dgamma, dbeta, _ = outputs
return dx, dgamma, dbeta

@staticmethod
def batcher(batched_args, batch_dims, *, norm_type, zero_centered_gamma):
def batcher(batched_args, batch_dims, *, norm_type, zero_centered_gamma, is_outer):
check_valid_batch_dims(batch_dims)
assert (
NormBwdPrimitive.outer_primitive is not None
Expand All @@ -852,13 +848,16 @@ def batcher(batched_args, batch_dims, *, norm_type, zero_centered_gamma):
gamma,
norm_type=norm_type,
zero_centered_gamma=zero_centered_gamma,
is_outer=is_outer,
),
out_bdims,
)

@staticmethod
def infer_sharding_from_operands(norm_type, zero_centered_gamma, mesh, arg_infos, result_infos):
del norm_type, zero_centered_gamma, result_infos
def infer_sharding_from_operands(
norm_type, zero_centered_gamma, is_outer, mesh, arg_infos, result_infos
):
del norm_type, zero_centered_gamma, is_outer, result_infos
x_spec = get_padded_spec(arg_infos[1])
if x_spec[-1] is not None:
warnings.warn(
Expand All @@ -883,8 +882,8 @@ def infer_sharding_from_operands(norm_type, zero_centered_gamma, mesh, arg_infos
return dx_sharding, dgamma_sharding, dbeta_sharding

@staticmethod
def partition(norm_type, zero_centered_gamma, mesh, arg_infos, result_infos):
del result_infos
def partition(norm_type, zero_centered_gamma, is_outer, mesh, arg_infos, result_infos):
del result_infos, is_outer
x_spec = get_padded_spec(arg_infos[1])
if x_spec[-1] is not None:
warnings.warn(
Expand Down Expand Up @@ -931,6 +930,7 @@ def sharded_impl(dz, x, mu, rsigma, gamma):
gamma,
norm_type=norm_type,
zero_centered_gamma=zero_centered_gamma,
is_outer=False,
)
global_dgamma = all_reduce_sum_along_dp_fsdp_tpsp(local_dgamma, mesh)
if norm_type == NVTE_Norm_Type.LayerNorm:
Expand Down Expand Up @@ -1255,6 +1255,7 @@ def layernorm_bwd(
gamma,
norm_type=NVTE_Norm_Type.LayerNorm,
zero_centered_gamma=zero_centered_gamma,
is_outer=True,
)


Expand Down Expand Up @@ -1500,6 +1501,7 @@ def rmsnorm_bwd(
gamma,
norm_type=NVTE_Norm_Type.RMSNorm,
zero_centered_gamma=zero_centered_gamma,
is_outer=True,
)
return (dx, dgamma)

Expand Down
Loading