Skip to content
Draft
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 30 additions & 49 deletions transformer_engine/jax/cpp_extensions/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def abstract(
is_outer,
):
"""
LayerNorm fwd inner primitive abstract
LayerNorm fwd primitive abstract.
"""
del amax_scope, transpose_batch_sequence
assert not output_amax_when_no_scaling or (
Expand Down Expand Up @@ -196,6 +196,18 @@ def abstract(
shape=colwise_scale_inv_shape, dtype=scale_dtype
)

outputs = (
out_aval,
colwise_out_aval,
scale_inv_aval,
colwise_scale_inv_aval,
updated_amax_aval,
mu_aval,
rsigma_aval,
)
if is_outer:
return outputs

(wkspace_info,) = transformer_engine_jax.get_norm_fwd_workspace_sizes(
x_aval.size // gamma_aval.size, # batch size
gamma_aval.size, # hidden size
Expand All @@ -213,41 +225,7 @@ def abstract(
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
)

return (
out_aval,
colwise_out_aval,
scale_inv_aval,
colwise_scale_inv_aval,
updated_amax_aval,
mu_aval,
rsigma_aval,
wkspace_aval,
)

@staticmethod
def outer_abstract(*args, **kwargs):
"""
LayerNorm fwd outer primitive abstract
"""
(
out_aval,
colwise_out_aval,
scale_inv_aval,
colwise_scale_inv_aval,
updated_amax_aval,
mu_aval,
rsigma_aval,
_,
) = NormFwdPrimitive.abstract(*args, **kwargs)
return (
out_aval,
colwise_out_aval,
scale_inv_aval,
colwise_scale_inv_aval,
updated_amax_aval,
mu_aval,
rsigma_aval,
)
return (*outputs, wkspace_aval)

@staticmethod
def lowering(
Expand Down Expand Up @@ -348,20 +326,10 @@ def impl(
"""
to describe implementation
"""
del is_outer
assert (
NormFwdPrimitive.inner_primitive is not None
), "NormFwdPrimitive.inner_primitive has not been registered"
(
out,
colwise_out,
scale_inv,
colwise_scale_inv,
updated_amax,
mu,
rsigma,
_,
) = NormFwdPrimitive.inner_primitive.bind(
outputs = NormFwdPrimitive.inner_primitive.bind(
x,
scale,
amax,
Expand All @@ -377,8 +345,21 @@ def impl(
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=output_amax_when_no_scaling,
is_outer=False,
is_outer=is_outer,
)
if is_outer:
out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, mu, rsigma = outputs
else:
(
out,
colwise_out,
scale_inv,
colwise_scale_inv,
updated_amax,
mu,
rsigma,
_,
) = outputs
Comment thread
jberchtold-nvidia marked this conversation as resolved.
Outdated
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
scaling_mode
).get_scale_shape_2x(x.shape, is_padded=False)
Expand Down Expand Up @@ -630,7 +611,7 @@ def sharded_impl(x, scale, amax, gamma, beta):
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=output_amax_when_no_scaling,
is_outer=True,
is_outer=False,
)
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
global_updated_amax = all_reduce_max_along_all_axes_except_PP(
Expand Down
Loading