Skip to content
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2104,14 +2104,6 @@ def forward(
"Amax reduction across TP+CP group is necessary when using context parallelism"
" with FP8!"
)
if fp8_recipe.float8_current_scaling() and context_parallel:
all_quantizers = dpa_utils.get_attention_quantizers(fp8, quantizers)
for q in all_quantizers:
if isinstance(q, Float8CurrentScalingQuantizer):
q.with_amax_reduction = True
q.amax_reduction_group = (
cp_group[0] if cp_comm_type == "a2a+p2p" else cp_group
)

if context_parallel:
assert (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1466,6 +1466,9 @@ def forward(
cp_group = cp_group[1]
cp_size = get_distributed_world_size(cp_group)
rank = get_distributed_rank(cp_group)
# Amax reduction group for current scaling (a2a+p2p vs p2p choice).
fp8_amax_reduction_group = cp_group_a2a if cp_group_a2a is not None else cp_group
ctx.fp8_amax_reduction_group = fp8_amax_reduction_group
send_dst = cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a]
recv_src = cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a]
device_compute_capability = get_device_compute_capability()
Expand Down Expand Up @@ -1577,7 +1580,12 @@ def forward(
# q, k, v: torch.Tensor, dtype=torch.uint8
q_f16 = q
q_fp8, k_fp8, v_fp8, qkv_layout, _ = combine_and_quantize(
qkv_layout, q, k, v, QKV_quantizer
qkv_layout,
q,
k,
v,
QKV_quantizer,
amax_reduction_group=fp8_amax_reduction_group,
)
if not fp8_recipe.mxfp8():
q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data]
Expand Down Expand Up @@ -2127,7 +2135,7 @@ def forward(
and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16)
and not fp8_recipe.mxfp8()
):
out_fp8 = O_quantizer(out_f16)
out_fp8 = O_quantizer(out_f16, amax_reduction_group=fp8_amax_reduction_group)
out_ret = out_fp8 if (fp8 and is_output_fp8) else out_f16

ctx.layer_number = layer_number
Expand Down Expand Up @@ -2265,7 +2273,7 @@ def backward(ctx, dout, *_args):
and not isinstance(dout, QuantizedTensorStorage)
and not ctx.fp8_recipe.mxfp8()
):
dout = ctx.dO_quantizer(dout)
dout = ctx.dO_quantizer(dout, amax_reduction_group=ctx.fp8_amax_reduction_group)
if ctx.use_fused_attention:
dout._data = dout._data.contiguous()
elif ctx.use_fused_attention:
Expand Down Expand Up @@ -2383,7 +2391,7 @@ def backward(ctx, dout, *_args):
if isinstance(dout, QuantizedTensorStorage):
dout_fp8 = dout
elif not ctx.fp8_recipe.mxfp8():
dout_fp8 = ctx.dO_quantizer(dout)
dout_fp8 = ctx.dO_quantizer(dout, amax_reduction_group=ctx.fp8_amax_reduction_group)
if not ctx.fp8_recipe.mxfp8():
dout = dout_fp8._data

Expand Down Expand Up @@ -2903,7 +2911,14 @@ def backward(ctx, dout, *_args):
dv[cu_seqlens_kv_padded[-1] :].fill_(0)

if ctx.fp8 and ctx.is_input_fp8:
dq, dk, dv, _, _ = combine_and_quantize(ctx.qkv_layout, dq, dk, dv, ctx.dQKV_quantizer)
dq, dk, dv, _, _ = combine_and_quantize(
ctx.qkv_layout,
dq,
dk,
dv,
ctx.dQKV_quantizer,
amax_reduction_group=ctx.fp8_amax_reduction_group,
)

if ctx.fp8:
# print quantizers
Expand Down Expand Up @@ -3055,6 +3070,9 @@ def forward(

cp_size = get_distributed_world_size(cp_group)
rank = get_distributed_rank(cp_group)
# Amax reduction group for current scaling (all_gather CP).
fp8_amax_reduction_group = cp_group
ctx.fp8_amax_reduction_group = fp8_amax_reduction_group
qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format
o_format = qkv_format
_, seq_dim_qkv, _ = get_bsh_dims(qkv_format)
Expand Down Expand Up @@ -3155,7 +3173,12 @@ def forward(
fused_attn_backend = tex.NVTE_Fused_Attn_Backend.NVTE_FP8
if not is_input_fp8 and not fp8_recipe.mxfp8():
q_fp8, k_fp8, v_fp8, qkv_layout, _ = combine_and_quantize(
qkv_layout, q, k, v, QKV_quantizer
qkv_layout,
q,
k,
v,
QKV_quantizer,
amax_reduction_group=fp8_amax_reduction_group,
)
if not fp8_recipe.mxfp8():
q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data]
Expand Down Expand Up @@ -3379,7 +3402,7 @@ def forward(
)
)
if (fp8 and is_output_fp8) or bwd_requires_o_fp8:
out_fp8 = O_quantizer(out_f16)
out_fp8 = O_quantizer(out_f16, amax_reduction_group=fp8_amax_reduction_group)
out_ret = out_fp8 if is_output_fp8 else out_f16

# save tensors for backward
Expand Down Expand Up @@ -3532,7 +3555,7 @@ def backward(ctx, dout, *_args):
if isinstance(dout, QuantizedTensorStorage):
dout_fp8 = dout
elif not ctx.fp8_recipe.mxfp8():
dout = ctx.dO_quantizer(dout)
dout = ctx.dO_quantizer(dout, amax_reduction_group=ctx.fp8_amax_reduction_group)
dout_fp8 = dout
if not ctx.fp8_recipe.mxfp8():
dout = dout_fp8._data
Expand Down Expand Up @@ -3842,7 +3865,14 @@ def backward(ctx, dout, *_args):

# quantize if necessary
if ctx.fp8 and ctx.is_input_fp8:
dq, dk, dv, _, _ = combine_and_quantize(ctx.dqkv_layout, dq, dk, dv, ctx.dQKV_quantizer)
dq, dk, dv, _, _ = combine_and_quantize(
ctx.dqkv_layout,
dq,
dk,
dv,
ctx.dQKV_quantizer,
amax_reduction_group=ctx.fp8_amax_reduction_group,
)

nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.backward")
return (
Expand Down Expand Up @@ -3918,6 +3948,9 @@ def forward(
nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward")

cp_size = get_distributed_world_size(cp_group)
# Amax reduction group for current scaling (a2a CP).
fp8_amax_reduction_group = cp_group
ctx.fp8_amax_reduction_group = fp8_amax_reduction_group
qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format
original_qkv_layout = qkv_layout
orig_q_shape, orig_k_shape, orig_v_shape = q.shape, k.shape, v.shape
Expand Down Expand Up @@ -4023,7 +4056,12 @@ def forward(
q_fp8, k_fp8, v_fp8 = q, k, v
elif not fp8_recipe.mxfp8():
q_fp8, k_fp8, v_fp8, qkv_layout, _ = combine_and_quantize(
qkv_layout, q, k, v, QKV_quantizer
qkv_layout,
q,
k,
v,
QKV_quantizer,
amax_reduction_group=fp8_amax_reduction_group,
)
if not fp8_recipe.mxfp8():
q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data]
Expand Down Expand Up @@ -4137,7 +4175,7 @@ def forward(
out_f16 = out_
if bwd_requires_o_fp8:
if not isinstance(out_, QuantizedTensorStorage):
out_fp8 = O_quantizer(out_)
out_fp8 = O_quantizer(out_, amax_reduction_group=fp8_amax_reduction_group)
out_part = out_fp8
if bwd_requires_o_f16:
if isinstance(out_, QuantizedTensorStorage):
Expand Down Expand Up @@ -4212,7 +4250,7 @@ def forward(
out_fp8 = Float8Tensor.make_like(out_fp8, data=out_, dtype=fwd_nominal_dtype)
if is_output_fp8:
if fp8_recipe.float8_current_scaling() or fp8_recipe.mxfp8():
out_fp8 = O_quantizer(out_)
out_fp8 = O_quantizer(out_, amax_reduction_group=fp8_amax_reduction_group)
out_f16 = out_
else:
if fp8_recipe.delayed():
Expand Down Expand Up @@ -4358,7 +4396,7 @@ def backward(ctx, dout, *_args):
if isinstance(dout, QuantizedTensorStorage):
dout_fp8 = dout
elif not ctx.fp8_recipe.mxfp8():
dout = ctx.dO_quantizer(dout)
dout = ctx.dO_quantizer(dout, amax_reduction_group=ctx.fp8_amax_reduction_group)
dout_fp8 = dout
if not ctx.fp8_recipe.mxfp8():
dout = dout._data
Expand Down Expand Up @@ -4576,7 +4614,12 @@ def backward(ctx, dout, *_args):
ctx.fp8_recipe.float8_current_scaling() or ctx.fp8_recipe.mxfp8()
) and ctx.is_input_fp8:
dq, dk, dv, _, _ = combine_and_quantize(
ctx.dqkv_layout, dq, dk, dv, ctx.dQKV_quantizer
ctx.dqkv_layout,
dq,
dk,
dv,
ctx.dQKV_quantizer,
amax_reduction_group=ctx.fp8_amax_reduction_group,
)
if ctx.fp8_recipe.delayed():
dq, dk, dv = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2783,8 +2783,13 @@ def combine_and_quantize(
used_in_forward=True,
used_in_backward=False,
keep_same_data_and_scale_inv_format=False,
amax_reduction_group=None,
):
"""Combine Q, K, V tensors based on qkv_layout and quantize them together."""
"""Combine Q, K, V tensors based on qkv_layout and quantize them together.

``amax_reduction_group`` is an optional process group used to all-reduce the
amax during quantization (current scaling, context parallelism).
"""
if isinstance(qkv_quantizer, MXFP8Quantizer):
qkv_format, q_format, kv_format = get_qkv_format(qkv_layout)
assert qkv_format in ("bshd", "sbhd"), (
Expand Down Expand Up @@ -2885,7 +2890,7 @@ def combine_and_quantize(
case 1:
dim = qkv_layout.find("3")
qkv = combine_tensors([q, k, v], dim)
qkv_fp8 = qkv_quantizer(qkv)
qkv_fp8 = qkv_quantizer(qkv, amax_reduction_group=amax_reduction_group)
q_data, k_data, v_data = SplitAlongDim.apply(qkv_fp8._data, dim, [1, 1, 1], True)
case 2:
dim = qkv_layout.split("_")[1].find("2")
Expand All @@ -2896,7 +2901,7 @@ def combine_and_quantize(
numels = [x.numel() for x in tensors]
numels = [sum(numels[:i]) for i in range(num_tensors + 1)]
qkv = torch.cat([x.view(-1) for x in tensors], dim=0)
qkv_fp8 = qkv_quantizer(qkv)
qkv_fp8 = qkv_quantizer(qkv, amax_reduction_group=amax_reduction_group)
q_data, kv_data = [
qkv_fp8._data[numels[i] : numels[i + 1]].view(shapes[i]) for i in range(num_tensors)
]
Expand All @@ -2908,7 +2913,7 @@ def combine_and_quantize(
numels = [x.numel() for x in tensors]
numels = [sum(numels[:i]) for i in range(num_tensors + 1)]
qkv = torch.cat([x.view(-1) for x in tensors], dim=0)
qkv_fp8 = qkv_quantizer(qkv)
qkv_fp8 = qkv_quantizer(qkv, amax_reduction_group=amax_reduction_group)
q_data, k_data, v_data = [
qkv_fp8._data[numels[i] : numels[i + 1]].view(shapes[i]) for i in range(num_tensors)
]
Expand Down
7 changes: 7 additions & 0 deletions transformer_engine/pytorch/csrc/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,13 @@ std::unique_ptr<Quantizer> convert_quantizer(py::handle quantizer) {
NVTE_ERROR("Unexpected type for quantizer");
}

c10::intrusive_ptr<dist_group_type> convert_amax_reduction_group(py::handle amax_reduction_group) {
if (amax_reduction_group.is_none()) {
return {};
}
return amax_reduction_group.cast<c10::intrusive_ptr<dist_group_type>>();
}

transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid,
const std::string& fp8_recipe) {
// if e4m3 or hybrid + forward
Expand Down
Loading
Loading