Add MXFP8 attention unit test with linear and rope layers #3033

Merged
cyanguwa merged 2 commits into NVIDIA:main from layalir:add_linear_mxfp8_unit_test
Jun 4, 2026
Conversation

@layalir layalir commented May 22, 2026

Contributor

Add a DSv3-shaped MXFP8 attention unit test covering the training path:

  • Adds MLA RoPE utilities for the DSv3 671B attention shape.
  • Adds an end-to-end MXFP8 path: Linear(QKV) -> MLA RoPE -> DotProductAttention -> Linear(out).
  • Exercises MXFP8 forward and backward through TE's real DotProductAttention wrapper.
  • Runs BF16 reference comparison by default.
  • Runs the performance benchmark by default and reports fprop and bprop timing separately from the same benchmark collection.
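The last bullet, timing fprop and bprop separately within one benchmark collection, can be sketched as follows. This is a minimal CPU-only illustration, not the harness from the PR: `time_fwd_bwd`, `forward`, and `backward` are hypothetical names, and a GPU benchmark would also need device synchronization (for example CUDA events) around each phase, which this sketch omits.

```python
import time


def time_fwd_bwd(forward, backward, warmup=3, iters=10):
    """Time forward and backward separately inside the same iterations.

    `forward` takes no arguments and returns an output; `backward` consumes
    that output. Both are placeholder callables for this sketch.
    Returns (mean_forward_ms, mean_backward_ms).
    """
    # Warm-up iterations are executed but not measured.
    for _ in range(warmup):
        backward(forward())

    fwd_s = 0.0
    bwd_s = 0.0
    for _ in range(iters):
        t0 = time.perf_counter()
        out = forward()
        t1 = time.perf_counter()
        backward(out)
        t2 = time.perf_counter()
        # Each iteration contributes to both totals, so the two numbers
        # come from the same set of runs.
        fwd_s += t1 - t0
        bwd_s += t2 - t1

    return 1e3 * fwd_s / iters, 1e3 * bwd_s / iters
```

Measuring both phases in one loop keeps the forward and backward numbers comparable, since they share warm-up state and inputs.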

Validation

Local checks:

  • python -m py_compile tests/pytorch/attention/test_linear_mxfp8_attention.py tests/pytorch/attention/mla_rope_utils.py
  • git diff --check

GB300 dlcluster validation:

  • Job: 1062811
  • GPU: NVIDIA GB300
  • CUDA capability: (10, 3)
  • cuDNN: (9, 21, 1)
  • MXFP8 available: (True, '')
  • Command: python -m pytest tests/pytorch/attention/test_linear_mxfp8_attention.py -v -s
  • Result: 3 passed

Perf output:

[PERF] b=1 s=4096:
  BF16 fprop:  7.219 ms  (567397 tok/s)
  BF16 bprop:  15.179 ms  (269844 tok/s)
  MXFP8 fprop: 4.718 ms  (868181 tok/s)
  MXFP8 bprop: 9.215 ms  (444492 tok/s)
  Fprop speedup: 1.53x
  Bprop speedup: 1.65x
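
The throughput and speedup figures above follow from the millisecond timings. A short sketch of that arithmetic, assuming tokens per step is `b * s` and speedup is the BF16 time divided by the MXFP8 time (the helper names are illustrative, not taken from the test file):

```python
def tokens_per_second(batch, seq_len, elapsed_ms):
    """Tokens processed per second for one step of `batch * seq_len` tokens."""
    return batch * seq_len / (elapsed_ms / 1e3)


def speedup(baseline_ms, candidate_ms):
    """How many times faster the candidate is than the baseline."""
    return baseline_ms / candidate_ms


# Timings copied from the [PERF] output above (b=1, s=4096).
timings_ms = {
    "BF16 fprop": 7.219,
    "BF16 bprop": 15.179,
    "MXFP8 fprop": 4.718,
    "MXFP8 bprop": 9.215,
}

for name, ms in timings_ms.items():
    print(f"{name}: {tokens_per_second(1, 4096, ms):.0f} tok/s")

print(f"Fprop speedup: {speedup(timings_ms['BF16 fprop'], timings_ms['MXFP8 fprop']):.2f}x")
print(f"Bprop speedup: {speedup(timings_ms['BF16 bprop'], timings_ms['MXFP8 bprop']):.2f}x")
```

The recomputed tok/s values can differ from the log in the last few digits, because the logged millisecond timings are already rounded to three decimals.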

github-actions Bot added the community-contribution label May 22, 2026
greptile-apps Bot commented May 22, 2026

Contributor

Greptile Summary

This PR adds a DSv3 671B-shaped MXFP8 end-to-end attention unit test, covering LayerNormLinear(Q/KV) → Triton MLA RoPE → DotProductAttention → Linear(out) in both forward and backward. A new mla_rope_utils.py provides the MLA RoPE Triton kernels (with PyTorch fallback), and test_linear_mxfp8_attention.py exercises accuracy, gradient flow, and performance benchmarking.

  • mla_rope_utils.py: Implements interleaved-to-neox RoPE for Q in-place and a split KV kernel that outputs separate K/V tensors with the shared positional embedding broadcast to all heads; backward kernels correctly accumulate dEMB across all heads before inverting the rotation.
  • test_linear_mxfp8_attention.py: Three pytest tests (test_accuracy, test_backward, test_performance) gated on MXFP8 availability; previously flagged issues (split fp8_autocast scope, weight-cache invalidation in benchmark warmup) have been resolved.
  • qa/L0_pytorch_unittest/test.sh: One-line addition to include the new test in the CI pipeline, which will be skipped on non-MXFP8 hardware.
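
The RoPE behavior summarized in the first bullet can be illustrated on a single vector in plain Python. This is a sketch under assumptions, not the Triton kernels from mla_rope_utils.py: it assumes "interleaved-to-neox" means reordering interleaved pairs into the half-split (NeoX) layout before rotating, and all function names are made up for illustration.

```python
import math


def interleaved_to_neox(x):
    """Reorder [x0, x1, x2, x3, ...] into [x0, x2, ..., x1, x3, ...]."""
    return x[0::2] + x[1::2]


def neox_to_interleaved(x):
    """Inverse of interleaved_to_neox for an even-length vector."""
    half = len(x) // 2
    out = [0.0] * len(x)
    out[0::2] = x[:half]
    out[1::2] = x[half:]
    return out


def rope_neox(x, angles):
    """Rotate a half-split vector: pair i is (x[i], x[i + half])."""
    half = len(x) // 2
    lo, hi = x[:half], x[half:]
    out_lo = [a * math.cos(t) - b * math.sin(t) for a, b, t in zip(lo, hi, angles)]
    out_hi = [a * math.sin(t) + b * math.cos(t) for a, b, t in zip(lo, hi, angles)]
    return out_lo + out_hi


def rope_neox_backward(grad_out, angles):
    """The rotation is orthogonal, so its gradient is the rotation by -angle."""
    return rope_neox(grad_out, [-t for t in angles])


def shared_pe_backward(grads_per_head, angles):
    """Gradient for an embedding broadcast to every head.

    Sum the per-head gradients first, then invert the rotation once.
    """
    summed = [sum(column) for column in zip(*grads_per_head)]
    return rope_neox_backward(summed, angles)
```

Because the rotation is linear and identical for every head, summing the head gradients before inverting the rotation gives the same result as inverting per head and then summing, with one rotation instead of one per head.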

Confidence Score: 5/5

Safe to merge; only test files are added with no modifications to production TE code, and tests are correctly gated on MXFP8 availability.

All changes are test infrastructure. The MLA RoPE Triton kernel math is correct, the Triton masking concern from a prior round was fixed, and the single fp8_autocast scope now covers the full forward path. The two remaining concerns are a hard speedup assertion in CI and a private-symbol import, both quality-of-life issues rather than correctness bugs.

tests/pytorch/attention/test_linear_mxfp8_attention.py — the performance assertion and private symbol import are worth revisiting before this test runs on new CI pools.
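
One way to soften a hard speedup assertion like the one flagged above is to make it opt-in. The sketch below is a hypothetical alternative, not what the merged test does: `check_speedup` and the `TE_PERF_STRICT` environment variable are invented for illustration. It warns by default and fails only when strict mode is requested.

```python
import os
import warnings


def check_speedup(bf16_ms, mxfp8_ms, min_speedup=1.0, strict=None):
    """Return the BF16/MXFP8 speedup; fail only when strict mode is on."""
    if strict is None:
        # TE_PERF_STRICT is a made-up variable name for this sketch.
        strict = os.environ.get("TE_PERF_STRICT", "0") == "1"

    speedup = bf16_ms / mxfp8_ms
    if speedup <= min_speedup:
        message = f"MXFP8 speedup {speedup:.2f}x is not above {min_speedup:.2f}x"
        if strict:
            raise AssertionError(message)
        # On shared or noisy CI pools, report the regression without failing.
        warnings.warn(message)
    return speedup
```

With this shape, a dedicated performance pool could enable strict mode while general CI runs only record the measurement.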

Important Files Changed

| Filename | Overview |
| --- | --- |
| qa/L0_pytorch_unittest/test.sh | Adds the new MXFP8 attention test to the CI runner; one-liner addition with no structural issues. |
| tests/pytorch/attention/mla_rope_utils.py | New Triton-backed MLA RoPE utility: fwd/bwd kernels for Q and KV, with PyTorch fallback. Previous in-bounds masking concern was addressed using absolute head offsets. Overall logic is correct for the tested b=1 case. |
| tests/pytorch/attention/test_linear_mxfp8_attention.py | New end-to-end MXFP8 attention test covering accuracy, gradient flow, and a timing benchmark. Previous issues (split fp8_autocast scope, is_first_microbatch weight-cache invalidation) were resolved. The performance test adds a hard speedup > 1.0 assertion to CI that could be brittle; a private TE internal symbol is also imported. |

Reviews (5): Last reviewed commit: "Merge branch 'main' into add_linear_mxfp..."

Comment thread tests/pytorch/attention/mla_rope_utils.py Outdated
Comment thread tests/pytorch/attention/test_linear_mxfp8_attention.py Outdated
Comment thread tests/pytorch/attention/test_linear_mxfp8_attention.py Outdated
@cyanguwa

Collaborator

Thanks for the contribution! Could you please:

  • fix the DCO (there are instructions in the DCO link)
  • address Greptile comments
  • add test_linear_mxfp8_attention.py to qa/L0_pytorch_unittest/test.sh, similar to this

Comment thread tests/pytorch/attention/test_linear_mxfp8_attention.py Outdated
Comment thread tests/pytorch/attention/test_linear_mxfp8_attention.py Outdated
Comment thread tests/pytorch/attention/test_linear_mxfp8_attention.py Outdated
Signed-off-by: Layali Rashid <lrashid@nvidia.com>
@layalir layalir force-pushed the add_linear_mxfp8_unit_test branch from c2a41f1 to 46c6a44 June 3, 2026 00:34
cyanguwa commented Jun 3, 2026

Collaborator

/te-ci pytorch L0

@cyanguwa cyanguwa merged commit 64311fe into NVIDIA:main Jun 4, 2026
10 of 13 checks passed