Skip to content

Commit cad802f

Browse files
pggPL and pre-commit-ci[bot]
authored and committed
[PyTorch] ONNX test fix + export for FP8 attention (#2598)
* jit bug fix
  Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
* fix
  Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix
  Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
* fix
  Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
* fixes
  Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* lint fixes
  Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix
  Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 3da26cd commit cad802f

6 files changed

Lines changed: 97 additions & 16 deletions

File tree

qa/L1_pytorch_onnx_unittest/test.sh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,5 @@
66
: ${XML_LOG_DIR:=/logs}
77
mkdir -p "$XML_LOG_DIR"
88

9-
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py
9+
# NVTE_UnfusedDPA_Emulate_FP8=1 enables FP8 attention emulation when no native backend is available
10+
NVTE_UnfusedDPA_Emulate_FP8=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py

tests/pytorch/test_onnx_export.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -713,6 +713,14 @@ def test_export_layernorm_mlp_activation(seed_default_rng, activation):
713713
_test_export_layernorm_mlp(activation=activation)
714714

715715

716+
# Quantization recipes with fp8_dpa=True for attention emulation export test
717+
dpa_quantization_recipes = [None] # None = no quantization
718+
if fp8_available:
719+
dpa_quantization_recipes.append(recipe.DelayedScaling(fp8_dpa=True))
720+
dpa_quantization_recipes.append(recipe.Float8CurrentScaling(fp8_dpa=True))
721+
722+
723+
@pytest.mark.parametrize("fp8_recipe", dpa_quantization_recipes)
716724
@pytest.mark.parametrize(
717725
"precision, use_mask, attn_mask_type",
718726
[
@@ -730,6 +738,7 @@ def test_export_core_attention(
730738
precision: torch.dtype,
731739
use_mask: bool,
732740
attn_mask_type: str,
741+
fp8_recipe: recipe.Recipe,
733742
):
734743
# Set dimensions (these are arbitrary).
735744
seq_len, batch_size, num_attention_heads, kv_channels = (64, 4, 1, 64)
@@ -749,22 +758,25 @@ def test_export_core_attention(
749758

750759
mask_str = get_attn_mask_str(use_mask, attn_mask_type)
751760
high_prec_str = dtype2str(precision)
752-
fname = f"te.core_attention{mask_str}{high_prec_str}.onnx"
761+
fp8_str = "_fp8_dpa" if fp8_recipe is not None else ""
762+
fname = f"te.core_attention{fp8_str}{mask_str}{high_prec_str}.onnx"
763+
764+
is_fp8 = fp8_recipe is not None
753765

754766
model = te.attention.DotProductAttention(
755767
num_attention_heads=num_attention_heads,
756768
kv_channels=kv_channels,
757-
attention_dropout=0.5,
758769
qkv_format=qkv_format,
759770
attn_mask_type=attn_mask_type,
760771
).to(device="cuda")
761-
do_export(model, inp, fname, input_names=input_names, fp8_recipe=None)
762-
te_outputs = te_infer(model, inp, is_fp8=False, fp8_recipe=None)
772+
do_export(model, inp, fname, input_names=input_names, fp8_recipe=fp8_recipe)
773+
te_outputs = te_infer(model, inp, is_fp8=is_fp8, fp8_recipe=fp8_recipe)
763774
serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names)
764775
if precision in (torch.bfloat16,):
765776
return
777+
atol = 5e-1 if is_fp8 else 1e-2
766778
validate_result(
767-
fname, inp, model, is_fp8=True, atol=1e-2, input_names=input_names, te_outputs=te_outputs
779+
fname, inp, model, is_fp8=True, atol=atol, input_names=input_names, te_outputs=te_outputs
768780
)
769781

770782

transformer_engine/pytorch/attention/dot_product_attention/backends.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,11 @@ class FP8EmulationFunc(torch.autograd.Function):
164164
@staticmethod
165165
def forward(ctx, tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layout):
166166
# pylint: disable=missing-function-docstring
167+
if is_in_onnx_export_mode():
168+
return FP8EmulationFunc.onnx_forward(
169+
tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layout
170+
)
171+
167172
if quantizer_name == "QKV_quantizer":
168173
query_layer, key_layer, value_layer = [
169174
x.contiguous() for x in [tensor1, tensor2, tensor3]
@@ -202,6 +207,47 @@ def backward(ctx, grad1, grad2, grad3):
202207
tensors = grad1, grad2, grad3
203208
return tensors[0], tensors[1], tensors[2], None, None, None
204209

210+
@staticmethod
211+
def onnx_forward(tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layout=None):
212+
"""
213+
ONNX-compatible forward for FP8 emulation using operations with defined ONNX translations.
214+
"""
215+
# pylint: disable=unused-argument
216+
is_qkv_quantizer = quantizer_name == "QKV_quantizer"
217+
assert isinstance(
218+
quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
219+
), "ONNX FP8 emulation path supports only Float8 quantizers."
220+
221+
if is_qkv_quantizer:
222+
# Flatten + concatenate + quantize + split. Equivalent to combine_and_quantize Case 3.
223+
orig_dtype = tensor1.dtype
224+
shapes = [tensor1.shape, tensor2.shape, tensor3.shape]
225+
numels = [tensor1.numel(), tensor2.numel(), tensor3.numel()]
226+
227+
# Flatten and concatenate
228+
combined = torch.cat(
229+
[tensor1.reshape(-1), tensor2.reshape(-1), tensor3.reshape(-1)], dim=0
230+
)
231+
232+
# Quantize + dequantize combined tensor using quantizer's ONNX methods
233+
combined_fp8 = quantizer.onnx_quantize(combined)
234+
out = quantizer.onnx_dequantize(combined_fp8).to(orig_dtype)
235+
236+
# Split back
237+
out1 = out[: numels[0]].reshape(shapes[0])
238+
out2 = out[numels[0] : numels[0] + numels[1]].reshape(shapes[1])
239+
out3 = out[numels[0] + numels[1] :].reshape(shapes[2])
240+
241+
return out1, out2, out3
242+
if quantizer_name in ["S_quantizer", "O_quantizer"]:
243+
# Emulate FP8 on single tensor using quantizer's ONNX methods
244+
orig_dtype = tensor1.dtype
245+
t_fp8 = quantizer.onnx_quantize(tensor1)
246+
out = quantizer.onnx_dequantize(t_fp8).to(orig_dtype)
247+
return out, tensor2, tensor3
248+
# Pass-through
249+
return tensor1, tensor2, tensor3
250+
205251

206252
class UnfusedDotProductAttention(torch.nn.Module):
207253
"""Parallel attention w/o QKV and Proj Gemms

transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1552,7 +1552,9 @@ def forward(
15521552
)
15531553

15541554
if use_unfused_attention:
1555-
allow_emulation = os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1"
1555+
allow_emulation = (
1556+
os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1" or is_in_onnx_export_mode()
1557+
)
15561558
if checkpoint_core_attention:
15571559
return self._checkpointed_attention_forward(
15581560
self.unfused_attention,

transformer_engine/pytorch/attention/dot_product_attention/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,9 @@ def get_attention_backend(
479479
logger.debug("Disabling FlashAttention 3 for FP8 training")
480480
use_flash_attention_3 = False
481481
if use_unfused_attention:
482-
allow_emulation = os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1"
482+
allow_emulation = (
483+
os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1" or is_in_onnx_export_mode()
484+
)
483485
if not allow_emulation:
484486
logger.debug("Disabling UnfusedDotProductAttention for FP8 attention")
485487
use_unfused_attention = False

transformer_engine/pytorch/jit.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,17 +46,35 @@ def wrapper(*args, **kwargs):
4646

4747
# Decorator to disable Torch Dynamo
4848
# See: https://github.com/NVIDIA/TransformerEngine/issues/308
49-
no_torch_dynamo = lambda recursive=True: lambda func: func
5049
if torch.__version__ >= "2":
5150
import torch._dynamo
5251

53-
if torch.__version__ >= "2.1":
54-
no_torch_dynamo = lambda recursive=True: lambda f: (
55-
f if is_in_onnx_export_mode() else torch._dynamo.disable(f, recursive=recursive)
56-
)
57-
else:
58-
# no "recursive" option in pyTorch 2.0 - it acts as if recursive was True
59-
no_torch_dynamo = lambda recursive=True: torch._dynamo.disable
52+
def no_torch_dynamo(recursive=True):
53+
"""Decorator to disable Torch Dynamo, except during ONNX export."""
54+
55+
def decorator(f):
56+
# PyTorch 2.0 has no "recursive" option; it behaves as if recursive were True
57+
disabled_f = (
58+
torch._dynamo.disable(f, recursive=recursive)
59+
if torch.__version__ >= "2.1"
60+
else torch._dynamo.disable(f)
61+
)
62+
63+
@wraps(f)
64+
def wrapper(*args, **kwargs):
65+
if is_in_onnx_export_mode():
66+
return f(*args, **kwargs)
67+
return disabled_f(*args, **kwargs)
68+
69+
return wrapper
70+
71+
return decorator
72+
73+
else:
74+
# Fallback for PyTorch < 2.0: no-op decorator
75+
def no_torch_dynamo(recursive=True): # pylint: disable=unused-argument
76+
"""No-op decorator for PyTorch < 2.0."""
77+
return lambda func: func
6078

6179

6280
def set_jit_fusion_options() -> None:

0 commit comments

Comments
 (0)