Add MXFP8 attention #2719

Merged
ptrendx merged 242 commits into NVIDIA:main from cyanguwa:add_mxfp8 on Apr 21, 2026

Conversation

@cyanguwa
Collaborator

@cyanguwa cyanguwa commented Mar 1, 2026

Description

  1. Added MXFP8 support to the FusedAttention backend (forward and backward, BSHD/SBHD, TE-PyTorch)
  2. Decoupled input and output formats in the APIs by introducing o_format, do_format, and dqkv_layout
  3. Implemented mxfp8_quantize_fast_path(), which quantizes multiple tensors and pads, permutes, and swizzles their scale_invs for a more efficient quantization pipeline; added qkv_scale_inv_format and do_scale_inv_format to indicate the format of the scale_invs
  4. Implemented multi_tensor_transpose_to_bhsd() to permute tensors from BSHD/SBHD to BHSD, with three paths: a TMA path optimized for FP16/BF16 with D=192/128, a fallback_vec_aligned path for Byte with D=8/4, and a fallback_non_vec_aligned path for Byte with D=6
  5. Implemented multi_tensor_pad_last_dim() to pad the last dimension (D) of multiple tensors to a multiple of 4 (%4) for rowwise and a multiple of 128 (%128) for columnwise
  6. Implemented multi_tensor_swizzle_row_scaling_narrow_k_kernel and multi_tensor_swizzle_col_scaling_narrow_m_kernel to optimize swizzling for small K/M dimensions
  7. Added MXFP8PaddedSizes, pad_s_d_for_mxfp8, generateMatrixStridesWithFormat, and generateMatrixStridesWithLayout for size and stride computation
  8. Fixed a bug in the O/dQKV shape logic for MLA; added the utility nvte_convert_qkv_shape
  9. Refactored CP and added MXFP8 support for cp_comm_type={'a2a', 'p2p', 'a2a+p2p', 'all_gather'}: all four types for MHA/GQA/MQA/MLA, {'a2a', 'all_gather'} for SWA, and {'a2a'} for sink attention
  10. Fixed the scale_inv_offsets calculation in GroupedTensor
  11. MXFP8 attention requires cudnn-frontend v1.22+ and cuDNN 9.21+
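
To illustrate items 4 and 5 above, here is a minimal Python sketch of the shape arithmetic only: the axis order that takes BSHD or SBHD to BHSD, and rounding the last dimension up to a multiple of 4 (rowwise) or 128 (columnwise). The helper names (round_up_to_multiple, padded_last_dim, axes_to_bhsd, shape_in_bhsd) are hypothetical and are not part of this PR; the real work is done by multi_tensor_transpose_to_bhsd() and multi_tensor_pad_last_dim(), whose signatures are not reproduced here.

```python
from typing import Tuple

# Alignment values taken from item 5 of the description.
LAST_DIM_ALIGNMENT = {"rowwise": 4, "columnwise": 128}


def round_up_to_multiple(size: int, multiple: int) -> int:
    """Return the smallest multiple of `multiple` that is >= `size`."""
    if size < 0 or multiple <= 0:
        raise ValueError("size must be >= 0 and multiple must be > 0")
    return ((size + multiple - 1) // multiple) * multiple


def padded_last_dim(shape: Tuple[int, ...], usage: str) -> Tuple[int, ...]:
    """Pad only the last dimension of `shape` for the given usage (hypothetical helper)."""
    *leading, last = shape
    return (*leading, round_up_to_multiple(last, LAST_DIM_ALIGNMENT[usage]))


def axes_to_bhsd(src_format: str) -> Tuple[int, ...]:
    """Axis order that permutes a 4-D tensor from `src_format` (e.g. 'bshd', 'sbhd') to BHSD."""
    src = src_format.lower()
    if sorted(src) != sorted("bhsd"):
        raise ValueError(f"unsupported format: {src_format!r}")
    # For each target axis, find where it lives in the source layout.
    return tuple(src.index(axis) for axis in "bhsd")


def shape_in_bhsd(shape: Tuple[int, ...], src_format: str) -> Tuple[int, ...]:
    """Shape of the tensor after it has been permuted to BHSD."""
    return tuple(shape[i] for i in axes_to_bhsd(src_format))
```

For example, shape_in_bhsd((2, 512, 16, 128), "bshd") gives (2, 16, 512, 128), and padded_last_dim((2, 16, 512, 6), "columnwise") gives (2, 16, 512, 128).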

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please see Description.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
cyanguwa and others added 5 commits April 17, 2026 14:34
@cyanguwa
Collaborator Author

/te-ci L1

@cyanguwa
Collaborator Author

/te-ci L1

pre-commit-ci Bot and others added 3 commits April 18, 2026 01:22
@cyanguwa
Collaborator Author

/te-ci L1

@cyanguwa
Collaborator Author

/te-ci L1

cyanguwa and others added 2 commits April 20, 2026 17:24
@cyanguwa
Collaborator Author

/te-ci L1

@ptrendx merged commit ee5dcec into NVIDIA:main on Apr 21, 2026
49 of 53 checks passed
KshitijLakhani pushed a commit that referenced this pull request Apr 22, 2026
* initial implementation for mxfp8

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* semi-working FP8; broken F16

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up last commit

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* comment out F16 pass

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* pull in grouped_quantize for MXFP8

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* grouped tensor - pytorch

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* quantize mxfp8

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix shapes/strides

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix unfused; clean up

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* split d to d_qk/d_v; attempt at bwd

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix last merge

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update FE

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* attempt at SWA/MLA

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove prints

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove leftover prints

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Revert "update FE"

This reverts commit d9ff566.

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update FE

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix MLA O strides; add bottom_right_diagonal

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* attempt at bwd

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix get_quantizers; attempt at bwd

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix fprop; add o_format

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* attempt at bwd with o_format/d_out_format/dqkv_layout

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix dtype/o_format/etc in bwd calls

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix generateMatrixStridesWithFormats and _v1; fix padding for mxfp8

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix upon last commit for paddedsizes

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add mxfp8 env var

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* disable FA for mxfp8

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add mha test

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* attempt at bwd; force determinism; fix shapes

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove prints

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update FE

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update FE from pre-merge branch to post-merge develop

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* allow MXFP8 linear + f16 attn

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* test cp a2a

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove prints temporarily

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* test cp p2p

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fixes for mla

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* open up a2a for mla

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* test ag

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* tweaks for last commit

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* enable mla ag

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix merge

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix merge

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* revert to main grouped tensor impl

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor tweaks to return to main

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove prints

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix combine_and_quantize for f16

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor tweaks

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* tweak tests

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix ds descale_o

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Revert "fix ds descale_o"

This reverts commit cd0bd82e239ff01210338b4e34cb8784109d22ec.

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fixes for p2p and ag

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* tweak cp test skips

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update FE

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix bwd KV tensors

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* tweak recipe control and backend selection

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* tweak quantizer logic

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fixes after last two commits

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* improve generate strides

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fixes for previous commit

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix bwd for current/delayed

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* tweak test configs

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix dO/dO_f16 strides

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix tests: SWA logic/test configs

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix ag

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add fp8 sink attn

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix a2a comm for F16

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove nan/inf print in test

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix fa a2a

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix fa a2a+p2p f16

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update FE to include new fixes

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix thd for bwd

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* refactor a2a for fu/fa

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update FE to fix d64

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* refactor ag

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* refactor p2p/a2a+p2p; mostly regarding shapes

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add shadow f16 fwd

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update FE to fix SWA/BRCM

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* switch to GH FE temporarily

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* switch back to GL FE

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update FE to latest commit

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update group tensor usage after merge main

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* env vars for qdq(q,k), o_f16 tests

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* allow other recipes than mxfp8

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix grouped tensor for MLA

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* change cp test configs

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add shadow f16 bwd

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix a2a+p2p for sbhd

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix last commit and causal flag for fa

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* enable fp8 sink and disable fp8_mha

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor cleanup for cp/non-cp

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update FE for FP8 sink

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix TE for FP8 sink

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* temporary: random sink/print sink

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Revert "temporary: random sink/print sink"

This reverts commit 706095f.

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* replace d_out_format with do_format

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix compare_and_assert for None cases

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove logic for b and simplify logic for dqkv types

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor fix for ndim_q/kv

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add explanation of fp8_output/grad in MHA

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* tidy up FP8 checks for bhsd/learnable

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove leading underscores in nvte_convert_qkv_format

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* simplify logic in generateMatrixStridesWithLayout

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up strides/ifelse-recipe logic

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* tweak checks in utils.py

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* tweak UnfusedDPA

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* enable testing for ag+swa and disable fp8_mha

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* tweak FusedAttn, fp8/f16 tensor naming/docstring

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* replace d_out_format with do_format

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up a2a

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up ag

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up p2p/a2a+p2p

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* tweak test configs

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* qdq dO in bwd shadow f16 path

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* tweak qdq dO logic

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove prints in shadow paths

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update FE to allow non-determinism

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fuse qkv transposes; first pass

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remap parallelism to grid(bh, splits, 3) block(s/splits x d); use nvec = 128 bits

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* allocate contiguous block for qkv

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix grouped tensor row/col scale_inv offsets

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* use fused permute kernels

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* quantize row/col as needed in fwd/bwd, non-cp/cp

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Revert "quantize row/col as needed in fwd/bwd, non-cp/cp"

This reverts commit ca53769.

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Reapply "quantize row/col as needed in fwd/bwd, non-cp/cp"

This reverts commit f19e852.

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix v_col format when row is quantized

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add back necessary bwd quants for shadow paths/cp a2a

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove ZInv for all layouts except T3HD

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix cp p2p with zinv

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* temporarily switch to GH FE main

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* switch back to GL FE

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix ag after merge main

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add condition for qdq(do) to not affect other tests

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix custom_mha_fp8 test

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix amax dqkv

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix fp8_recipe in DPA utils

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove use of amax for mxfp8

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add o_format/do_format/dqkv_layout to cache indicators for fp8 and f16

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* enable sink attn + FP8 in CP

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update FE to GH v1.22.0

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix for inconsistent kwarg name in permute to grouped tensor

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add TMA permute

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Revert "add TMA permute"

This reverts commit 2532a50.
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* TMA load for bhsd transposes

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix some lint

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* temp: quant+perm+swizzle, rope, perm_fused

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove mla_rope for now; clean up quant+permute+pad_swizzle; create multi_tensor_swizzle

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix last commit

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* implement narrow-m for col swizzle; reorder to pad+perm+swizzle

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fused pad into perm; remove at::zeros as zeros done in perm kernels

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove shadow code

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fix for permute shapes

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* check smem size before entering narrow-k/m kernels

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* expand permute to multi_tensor_

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* refactor qkv/do quant; create a fast_path call

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix lint

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* cleanup grouped tensor fix

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove _with_amax for create_unquantized_tensor

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix last commit

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* reimplement inplace_multi_tensor_swizzle

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix last commit; set swizzled flag in python

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove permute_to_grouped_tensor_bwd; clean up fwd

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add doxygen for multi_tensor_swizzle

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up nvte_convert_qkv_format

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fixes based on code review

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* group layouts/formats in APIs

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* rename nvte_convert_qkv_format to nvte_convert_qkv_shape

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove MXFP8 create_unquantized_tensor

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* rename permute_to_grouped_tensor to transpose_to_bhsd

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add multi_tensor_swizzle_xx_unchecked and split the calls/paths

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* straighten up indexing for multi_tensor_pad

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* batch up kernel calls per-16-tensors for pad and permute

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove nvec128; rename nvec64 back to nvec

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add Macros/arch specifics for compilation

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix lint

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* attempt 1: MLA RoPE

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Revert "attempt 1: MLA RoPE"

This reverts commit 7922924.
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix kv_cache tests for Fused, is_page=True

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* attempt 2: MLA RoPE

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* use DIVUP/_TO_MULTIPLE for pad_s_d_for_mxfp8

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove CUDNN_VERSION 8900 macros

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add narrow-k/m swizzle tests

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* compile flash_attn.cu with special archs

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Revert "attempt 2: MLA RoPE"

This reverts commit 3b854b2.
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* make contiguous instead of check is_contiguous

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove unused s_q/s_kv

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove unused issue_tma_store_strided

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add version gate for mxfp8 for CPP users

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* replace nvte_get_qkv_shape with AttentionShape

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* populate nvte_ changes to Jax

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update FE to 1.22.1

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix minor merge issue

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* revert to FE 1.21 since it's what mxfp8 needs

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update jax attention shapes

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Revert "revert to FE 1.21 since it's what mxfp8 needs"

This reverts commit f09961a.
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* pick FE 1.22 to support mxfp8 and avoid rng issue in 1.22.1

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix CP AG test on Hopper

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
YigongQin pushed a commit to YigongQin/TransformerEngine that referenced this pull request Apr 23, 2026
@ptrendx ptrendx added this to the 2.15 milestone Apr 23, 2026
sudhakarsingh27 added a commit to sudhakarsingh27/TransformerEngine that referenced this pull request Apr 24, 2026
The merge of main (which included PR NVIDIA#2719 "Add MXFP8 attention")
renamed ctx.qkv_format to ctx.dqkv_format in the A2A backward path
and refactored flash_attn_a2a_communicate to accept separate
cu_seqlens_q_padded / cu_seqlens_kv_padded. The merge conflict
resolution introduced two bugs:

1. Bare `qkv_format` (undefined) instead of `ctx.dqkv_format` in the
   pad_between_seqs condition, so the branch that computes
   fa_cu_seqlens_q/kv with padded values was never taken.

2. `cu_seqlens_q`/`cu_seqlens_kv` passed to get_fa_args instead of
   `fa_cu_seqlens_q`/`fa_cu_seqlens_kv`, so even if the condition
   had fired, the un-padded values would have been used.

Together these caused FA3 backward to see wrong sequence boundaries,
writing non-zero gradients into padding positions of dq/dk/dv in the
A2A CP comm type.
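
To illustrate why un-padded offsets lead to writes into padding positions, here is a minimal pure-Python sketch. It is not the TransformerEngine or FlashAttention code: the helper names, the sequence lengths, and the pad-to-multiple-of-4 choice are all assumptions made for illustration, and it only models which positions of a padded THD buffer each set of cumulative offsets points at.

```python
from itertools import accumulate


def cumulative(lengths):
    """Return cumulative offsets [0, l0, l0 + l1, ...] (hypothetical helper)."""
    return [0] + list(accumulate(lengths))


def token_positions(starts, lengths):
    """Buffer positions covered by sequences that begin at `starts` and span `lengths`."""
    return {s + i for s, n in zip(starts, lengths) for i in range(n)}


# Assumed example: two sequences of 3 and 5 tokens, each padded to a multiple of 4.
seq_lens = [3, 5]
padded_lens = [4, 8]

cu_seqlens = cumulative(seq_lens)            # [0, 3, 8]
cu_seqlens_padded = cumulative(padded_lens)  # [0, 4, 12]

# Where the real tokens actually live in the padded buffer.
real = token_positions(cu_seqlens_padded[:-1], seq_lens)
# Where a kernel would look if it were handed the un-padded offsets.
assumed = token_positions(cu_seqlens[:-1], seq_lens)

padding = set(range(cu_seqlens_padded[-1])) - real
touched_padding = sorted(assumed & padding)  # padding slots treated as tokens
missed_tokens = sorted(real - assumed)       # real tokens that are skipped

print("padding positions treated as tokens:", touched_padding)
print("real tokens skipped:", missed_tokens)
```

In this toy layout the second sequence is assumed to start at offset 3 instead of 4, so one padding slot is treated as a token and one real token is skipped, which is the kind of boundary mismatch described above.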

Validated: pad_between_seqs=True + a2a + thd tests pass 3/3 runs
(was 3/3 FAILED before this fix).

Known issue (pre-existing on main, not from this branch):
- test_cp_with_fused_attention[True-None-False-False-False-{p2p,a2a}-thd-cp_1_0-bf16]
  fails with cuDNN "CompositeSoftmaxNode" error when return_max_logit=True.
  This is a cuDNN graph validation bug introduced by PR NVIDIA#2719 on main.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
faradawn pushed a commit to faradawn/TransformerEngine that referenced this pull request May 14, 2026
* initial implementation for mxfp8

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* semi-working FP8; broken F16

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up last commit

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* comment out F16 pass

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* pull in grouped_quantize for MXFP8

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* grouped tensor - pytorch

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* quantize mxfp8

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix shapes/strides

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix unfused; clean up

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* split d to d_qk/d_v; attempt at bwd

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix last merge

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update FE

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* attempt at SWA/MLA

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove prints

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove leftover prints

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Revert "update FE"

This reverts commit d9ff566.

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update FE

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix MLA O strides; add bottom_right_diagonal

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* attempt at bwd

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix get_quantizers; attempt at bwd

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix fprop; add o_format

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* attempt at bwd with o_format/d_out_format/dqkv_layout

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix dtype/o_format/etc in bwd calls

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix generateMatrixStridesWithFormats and _v1; fix padding for mxfp8

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix upon last commit for paddedsizes

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add mxfp8 env var

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* disable FA for mxfp8

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add mha test

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* attempt at bwd; force determinism; fix shapes

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove prints

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update FE

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update FE from pre-merge branch to post-merge develop

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* allow MXFP8 linear + f16 attn

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* test cp a2a

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove prints temporarily

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* test cp p2p

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fixes for mla

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* open up a2a for mla

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* test ag

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* tweaks for last commit

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* enable mla ag

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix merge

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix merge

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* revert to main grouped tensor impl

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor tweaks to return to main

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove prints

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix combine_and_quantize for f16

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor tweaks

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* tweak tests

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix ds descale_o

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Revert "fix ds descale_o"

This reverts commit cd0bd82e239ff01210338b4e34cb8784109d22ec.

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fixes for p2p and ag

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* tweak cp test skips

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update FE

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix bwd KV tensors

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* tweak recipe control and backend selection

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* tweak quantizer logic

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fixes after last two commits

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* improve generate strides

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fixes for previous commit

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix bwd for current/delayed

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* tweak test configs

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix dO/dO_f16 strides

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix tests: SWA logic/test configs

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix ag

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add fp8 sink attn

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix a2a comm for F16

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove nan/inf print in test

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix fa a2a

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix fa a2a+p2p f16

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update FE to include new fixes

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix thd for bwd

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* refactor a2a for fu/fa

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update FE to fix d64

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* refactor ag

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* refactor p2p/a2a+p2p; mostly regarding shapes

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add shadow f16 fwd

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update FE to fix SWA/BRCM

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* switch to GH FE temporarily

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* switch back to GL FE

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update FE to latest commit

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update group tensor usage after merge main

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* env vars for qdq(q,k), o_f16 tests

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* allow other recipes than mxfp8

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix grouped tensor for MLA

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* change cp test configs

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add shadow f16 bwd

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix a2a+p2p for sbhd

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix last commit and causal flag for fa

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* enable fp8 sink and disable fp8_mha

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor cleanup for cp/non-cp

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update FE for FP8 sink

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix TE for FP8 sink

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* temporary: random sink/print sink

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Revert "temporary: random sink/print sink"

This reverts commit 706095f.

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* replace d_out_format with do_format

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix compare_and_assert for None cases

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove logic for b and simplify logic for dqkv types

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor fix for ndim_q/kv

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add explanation of fp8_output/grad in MHA

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* tidy up FP8 checks for bhsd/learnable

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove leading underscores in nvte_convert_qkv_format

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* simplify logic in generateMatrixStridesWithLayout

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up strides/ifelse-recipe logic

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* tweak checks in utils.py

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* tweak UnfusedDPA

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* enable testing for ag+swa and disable fp8_mha

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* tweak FusedAttn, fp8/f16 tensor naming/docstring

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* replace d_out_format with do_format

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up a2a

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up ag

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up p2p/a2a+p2p

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* tweak test configs

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* qdq dO in bwd shadow f16 path

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* tweak qdq dO logic

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove prints in shadow paths

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update FE to allow non-determinism

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fuse qkv transposes; first pass

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remap parallelism to grid(bh, splits, 3) block(s/splits x d); use nvec = 128 bits

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* allocate contiguous block for qkv

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix grouped tensor row/col scale_inv offsets

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* use fused permute kernels

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* quantize row/col as needed in fwd/bwd, non-cp/cp

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Revert "quantize row/col as needed in fwd/bwd, non-cp/cp"

This reverts commit ca53769.

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Reapply "quantize row/col as needed in fwd/bwd, non-cp/cp"

This reverts commit f19e852.

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix v_col format when row is quantized

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add back necessary bwd quants for shadow paths/cp a2a

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove ZInv for all layouts except T3HD

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix cp p2p with zinv

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* temporarily switch to GH FE main

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* switch back to GL FE

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix ag after merge main

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add condition for qdq(do) to not affect other tests

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix custom_mha_fp8 test

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix amax dqkv

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix fp8_recipe in DPA utils

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove use of amax for mxfp8

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add o_format/do_format/dqkv_layout to cache indicators for fp8 and f16

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* enable sink attn + FP8 in CP

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update FE to GH v1.22.0

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix for inconsistent kwarg name in permute to grouped tensor

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add TMA permute

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Revert "add TMA permute"

This reverts commit 2532a50.

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* TMA load for bhsd transposes

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix some lint

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* temp: quant+perm+swizzle, rope, perm_fused

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove mla_rope for now; clean up quant+permute+pad_swizzle; create multi_tensor_swizzle

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix last commit

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* implement narrow-m for col swizzle; reorder to pad+perm+swizzle

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fused pad into perm; remove at::zeros as zeros done in perm kernels

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove shadow code

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fix for permute shapes

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* check smem size before entering narrow-k/m kernels

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* expand permute to multi_tensor_

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* refactor qkv/do quant; create a fast_path call

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix lint

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* cleanup grouped tensor fix

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove _with_amax for create_unquantized_tensor

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix last commit

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* reimplement inplace_multi_tensor_swizzle

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix last commit; set swizzled flag in python

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove permute_to_grouped_tensor_bwd; clean up fwd

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add doxygen for multi_tensor_swizzle

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up nvte_convert_qkv_format

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fixes based on code review

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* group layouts/formats in APIs

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* rename nvte_convert_qkv_format to nvte_convert_qkv_shape

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove MXFP8 create_unquantized_tensor

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* rename permute_to_grouped_tensor to transpose_to_bhsd

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add multi_tensor_swizzle_xx_unchecked and split the calls/paths

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* straighten up indexing for multi_tensor_pad

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* batch up kernel calls per-16-tensors for pad and permute

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove nvec128; rename nvec64 back to nvec

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add Macros/arch specifics for compilation

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix lint

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* attempt 1: MLA RoPE

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Revert "attempt 1: MLA RoPE"

This reverts commit 7922924.

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix kv_cache tests for Fused, is_page=True

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* attempt 2: MLA RoPE

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* use DIVUP/_TO_MULTIPLE for pad_s_d_for_mxfp8

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove CUDNN_VERSION 8900 macros

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add narrow-k/m swizzle tests

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* compile flash_attn.cu with special archs

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Revert "attempt 2: MLA RoPE"

This reverts commit 3b854b2.

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* make contiguous instead of check is_contiguous

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove unused s_q/s_kv

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove unused issue_tma_store_strided

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add version gate for mxfp8 for CPP users

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* replace nvte_get_qkv_shape with AttentionShape

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* populate nvte_ changes to Jax

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update FE to 1.22.1

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix minor merge issue

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* revert to FE 1.21 since it's what mxfp8 needs

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update jax attention shapes

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Revert "revert to FE 1.21 since it's what mxfp8 needs"

This reverts commit f09961a.

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* pick FE 1.22 to support mxfp8 and avoid rng issue in 1.22.1

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix CP AG test on Hopper

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>