Add MXFP8 attention #2719
Merged
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
This reverts commit d9ff566. Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
for more information, see https://pre-commit.ci
Collaborator, Author: /te-ci L1
Collaborator, Author: /te-ci L1
This reverts commit f09961a. Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Collaborator, Author: /te-ci L1
Collaborator, Author: /te-ci L1
Collaborator, Author: /te-ci L1
ptrendx approved these changes on Apr 21, 2026
This was referenced Apr 22, 2026
KshitijLakhani pushed a commit that referenced this pull request on Apr 22, 2026
* initial implementation for mxfp8
* semi-working FP8; broken F16
* clean up last commit
* comment out F16 pass
* pull in grouped_quantize for MXFP8
* grouped tensor - pytorch
* quantize mxfp8
* fix shapes/strides
* fix unfused; clean up
* split d to d_qk/d_v; attempt at bwd
* fix last merge
* update FE
* attempt at SWA/MLA
* remove prints
* remove leftover prints
* Revert "update FE"
  This reverts commit d9ff566.
* update FE
* fix MLA O strides; add bottom_right_diagonal
* attempt at bwd
* fix get_quantizers; attempt at bwd
* fix fprop; add o_format
* attempt at bwd with o_format/d_out_format/dqkv_layout
* fix dtype/o_format/etc in bwd calls
* fix generateMatrixStridesWithFormats and _v1; fix padding for mxfp8
* fix upon last commit for paddedsizes
* add mxfp8 env var
* disable FA for mxfp8
* add mha test
* attempt at bwd; force determinism; fix shapes
* remove prints
* update FE
* update FE from pre-merge branch to post-merge develop
* allow MXFP8 linear + f16 attn
* test cp a2a
* remove prints temporarily
* test cp p2p
* minor fixes for mla
* open up a2a for mla
* test ag
* tweaks for last commit
* enable mla ag
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix merge
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix merge
* revert to main grouped tensor impl
* minor tweaks to return to main
* remove prints
* fix combine_and_quantize for f16
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* minor tweaks
* tweak tests
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix ds descale_o
* Revert "fix ds descale_o"
  This reverts commit cd0bd82e239ff01210338b4e34cb8784109d22ec.
* minor fixes for p2p and ag
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* tweak cp test skips
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* update FE
* fix bwd KV tensors
* tweak recipe control and backend selection
* tweak quantizer logic
* minor fixes after last two commits
* improve generate strides
* minor fixes for previous commit
* fix bwd for current/delayed
* tweak test configs
* fix dO/dO_f16 strides
* fix tests: SWA logic/test configs
* fix ag
* add fp8 sink attn
* fix a2a comm for F16
* remove nan/inf print in test
* fix fa a2a
* fix fa a2a+p2p f16
* update FE to include new fixes
* fix thd for bwd
* refactor a2a for fu/fa
* update FE to fix d64
* refactor ag
* refactor p2p/a2a+p2p; mostly regarding shapes
* add shadow f16 fwd
* update FE to fix SWA/BRCM
* switch to GH FE temporarily
* switch back to GL FE
* update FE to latest commit
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* update group tensor usage after merge main
* env vars for qdq(q,k), o_f16 tests
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* allow other recipes than mxfp8
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix grouped tensor for MLA
* change cp test configs
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* add shadow f16 bwd
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix a2a+p2p for sbhd
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix last commit and causal flag for fa
* enable fp8 sink and disable fp8_mha
* minor cleanup for cp/non-cp
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* update FE for FP8 sink
* fix TE for FP8 sink
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* temporary: random sink/print sink
* Revert "temporary: random sink/print sink"
  This reverts commit 706095f.
* replace d_out_format with do_format
* fix compare_and_assert for None cases
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* remove logic for b and simplify logic for dqkv types
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* minor fix for ndim_q/kv
* add explanation of fp8_output/grad in MHA
* tidy up FP8 checks for bhsd/learnable
* remove leading underscores in nvte_convert_qkv_format
* simplify logic in generateMatrixStridesWithLayout
* clean up strides/ifelse-recipe logic
* tweak checks in utils.py
* tweak UnfusedDPA
* enable testing for ag+swa and disable fp8_mha
* tweak FusedAttn, fp8/f16 tensor naming/docstring
* replace d_out_format with do_format
* fix lint
* clean up a2a
* clean up ag
* clean up p2p/a2a+p2p
* tweak test configs
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* qdq dO in bwd shadow f16 path
* tweak qdq dO logic
* remove prints in shadow paths
* update FE to allow non-determinism
* fuse qkv transposes; first pass
* remap parallelism to grid(bh, splits, 3) block(s/splits x d); use nvec = 128 bits
* allocate contiguous block for qkv
* fix grouped tensor row/col scale_inv offsets
* use fused permute kernels
* quantize row/col as needed in fwd/bwd, non-cp/cp
* Revert "quantize row/col as needed in fwd/bwd, non-cp/cp"
  This reverts commit ca53769.
* Reapply "quantize row/col as needed in fwd/bwd, non-cp/cp"
  This reverts commit f19e852.
* fix v_col format when row is quantized
* add back necessary bwd quants for shadow paths/cp a2a
* remove ZInv for all layouts except T3HD
* fix cp p2p with zinv
* temporarily switch to GH FE main
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* switch back to GL FE
* fix ag after merge main
* add condition for qdq(do) to not affect other tests
* fix custom_mha_fp8 test
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix amax dqkv
* fix fp8_recipe in DPA utils
* remove use of amax for mxfp8
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* add o_format/do_format/dqkv_layout to cache indicators for fp8 and f16
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* enable sink attn + FP8 in CP
* update FE to GH v1.22.0
* fix for inconsistent kwarg name in permute to grouped tensor
* add TMA permute
* Revert "add TMA permute"
  This reverts commit 2532a50.
* TMA load for bhsd transposes
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix some lint
* temp: quant+perm+swizzle, rope, perm_fused
* remove mla_rope for now; clean up quant+permute+pad_swizzle; create multi_tensor_swizzle
* fix last commit
* implement narrow-m for col swizzle; reorder to pad+perm+swizzle
* fused pad into perm; remove at::zeros as zeros done in perm kernels
* remove shadow code
* minor fix for permute shapes
* check smem size before entering narrow-k/m kernels
* expand permute to multi_tensor_
* refactor qkv/do quant; create a fast_path call
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix lint
* cleanup grouped tensor fix
* remove _with_amax for create_unquantized_tensor
* fix last commit
* reimplement inplace_multi_tensor_swizzle
* fix last commit; set swizzled flag in python
* remove permute_to_grouped_tensor_bwd; clean up fwd
* add doxygen for multi_tensor_swizzle
* clean up nvte_convert_qkv_format
* fixes based on code review
* group layouts/formats in APIs
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* rename nvte_convert_qkv_format to nvte_convert_qkv_shape
* remove MXFP8 create_unquantized_tensor
* rename permute_to_grouped_tensor to transpose_to_bhsd
* add multi_tensor_swizzle_xx_unchecked and split the calls/paths
* straighten up indexing for multi_tensor_pad
* batch up kernel calls per-16-tensors for pad and permute
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* remove nvec128; rename nvec64 back to nvec
* add Macros/arch specifics for compilation
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix lint
* attempt 1: MLA RoPE
* Revert "attempt 1: MLA RoPE"
  This reverts commit 7922924.
* fix kv_cache tests for Fused, is_page=True
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* attempt 2: MLA RoPE
* use DIVUP/_TO_MULTIPLE for pad_s_d_for_mxfp8
* remove CUDNN_VERSION 8900 macros
* add narrow-k/m swizzle tests
* compile flash_attn.cu with special archs
* Revert "attempt 2: MLA RoPE"
  This reverts commit 3b854b2.
* make contiguous instead of check is_contiguous
* remove unused s_q/s_kv
* remove unused issue_tma_store_strided
* add version gate for mxfp8 for CPP users
* replace nvte_get_qkv_shape with AttentionShape
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* populate nvte_ changes to Jax
* update FE to 1.22.1
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix minor merge issue
* revert to FE 1.21 since it's what mxfp8 needs
* udpate jax attention shapes
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* Revert "revert to FE 1.21 since it's what mxfp8 needs"
  This reverts commit f09961a.
* pick FE 1.22 to support mxfp8 and avoid rng issue in 1.22.1
* fix CP AG test on Hopper
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci

---------

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
YigongQin pushed a commit to YigongQin/TransformerEngine that referenced this pull request on Apr 23, 2026
change cp test configs Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add shadow f16 bwd Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix a2a+p2p for sbhd Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix last commit and causal flag for fa Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * enable fp8 sink and disable fp8_mha Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * minor cleanup for cp/non-cp Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update FE for FP8 sink Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix TE for FP8 sink Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * temporary: random sink/print sink Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Revert "temporary: random sink/print sink" This reverts commit 706095f. 
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * replace d_out_format with do_format Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix compare_and_assert for None cases Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove logic for b and simplify logic for dqkv types Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * minor fix for ndim_q/kv Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add explanation of fp8_output/grad in MHA Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * tidy up FP8 checks for bhsd/learnable Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove leading underscores in nvte_convert_qkv_format Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * simplify logic in generateMatrixStridesWithLayout Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * clean up strides/ifelse-recipe logic Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * tweak checks in utils.py Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * tweak UnfusedDPA Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * enable testing for ag+swa and disable fp8_mha Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * tweak FusedAttn, fp8/f16 tensor naming/docstring Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * replace d_out_format with do_format Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix lint Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * clean up a2a Signed-off-by: Charlene 
Yang <8636796+cyanguwa@users.noreply.github.com> * clean up ag Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * clean up p2p/a2a+p2p Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * tweak test configs Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * qdq dO in bwd shadow f16 path Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * tweak qdq dO logic Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove prints in shadow paths Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update FE to allow non-determinism Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fuse qkv transposes; first pass Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remap parallelism to grid(bh, splits, 3) block(s/splits x d); use nvec = 128 bits Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * allocate contiguous block for qkv Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix grouped tensor row/col scale_inv offsets Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * use fused permute kernels Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * quantize row/col as needed in fwd/bwd, non-cp/cp Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Revert "quantize row/col as needed in fwd/bwd, non-cp/cp" This reverts commit ca53769. Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Reapply "quantize row/col as needed in fwd/bwd, non-cp/cp" This reverts commit f19e852. 
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix v_col format when row is quantized Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add back necessary bwd quants for shadow paths/cp a2a Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove ZInv for all layouts except T3HD Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix cp p2p with zinv Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * temporarily switch to GH FE main Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * switch back to GL FE Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix ag after merge main Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add condition for qdq(do) to not affect other tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix custom_mha_fp8 test Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix amax dqkv Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix fp8_recipe in DPA utils Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove use of amax for mxfp8 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add o_format/do_format/dqkv_layout to cache indicators for fp8 and f16 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * enable sink attn + FP8 in CP Signed-off-by: Charlene Yang 
<8636796+cyanguwa@users.noreply.github.com> * update FE to GH v1.22.0 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix for inconsistent kwarg name in permute to grouped tensor Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add TMA permute Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Revert "add TMA permute" This reverts commit 2532a50. Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * TMA load for bhsd transposes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix some lint Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * temp: quant+perm+swizzle, rope, perm_fused Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove mla_rope for now; clean up quant+permute+pad_swizzle; create multi_tensor_swizzle Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix last commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * implement narrow-m for col swizzle; reorder to pad+perm+swizzle Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fused pad into perm; remove at::zeros as zeros done in perm kernels Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove shadow code Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * minor fix for permute shapes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * check smem size before entering narrow-k/m kernels Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * expand permute to multi_tensor_ Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * refactor qkv/do quant; create a fast_path call Signed-off-by: Charlene Yang 
<8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix lint Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * cleanup grouped tensor fix Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove _with_amax for create_unquantized_tensor Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix last commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * reimplement inplace_multi_tensor_swizzle Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix last commit; set swizzled flag in python Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove permute_to_grouped_tensor_bwd; clean up fwd Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add doxygen for multi_tensor_swizzle Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * clean up nvte_convert_qkv_format Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fixes based on code review Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * group layouts/formats in APIs Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * rename nvte_convert_qkv_format to nvte_convert_qkv_shape Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove MXFP8 create_unquantized_tensor Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * rename permute_to_grouped_tensor to transpose_to_bhsd Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add multi_tensor_swizzle_xx_unchecked and split the calls/paths Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * straighten up indexing for multi_tensor_pad 
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * batch up kernel calls per-16-tensors for pad and permute Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove nvec128; rename nvec64 back to nvec Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add Macros/arch specifics for compilation Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix lint Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * attempt 1: MLA RoPE Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Revert "attempt 1: MLA RoPE" This reverts commit 7922924. Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix kv_cache tests for Fused, is_page=True Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * attempt 2: MLA RoPE Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * use DIVUP/_TO_MULTIPLE for pad_s_d_for_mxfp8 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove CUDNN_VERSION 8900 macros Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add narrow-k/m swizzle tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * compile flash_attn.cu with special archs Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Revert "attempt 2: MLA RoPE" This reverts commit 3b854b2. 
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * make contiguous instead of check is_contiguous Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove unused s_q/s_kv Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove unused issue_tma_store_strided Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add version gate for mxfp8 for CPP users Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * replace nvte_get_qkv_shape with AttentionShape Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * populate nvte_ changes to Jax Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update FE to 1.22.1 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix minor merge issue Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * revert to FE 1.21 since it's what mxfp8 needs Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * udpate jax attention shapes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Revert "revert to FE 1.21 since it's what mxfp8 needs" This reverts commit f09961a. 
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * pick FE 1.22 to support mxfp8 and avoid rng issue in 1.22.1 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix CP AG test on Hopper Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
sudhakarsingh27 added a commit to sudhakarsingh27/TransformerEngine that referenced this pull request Apr 24, 2026
The merge of main (which included PR NVIDIA#2719 "Add MXFP8 attention") renamed ctx.qkv_format to ctx.dqkv_format in the A2A backward path and refactored flash_attn_a2a_communicate to accept separate cu_seqlens_q_padded / cu_seqlens_kv_padded. The merge conflict resolution introduced two bugs:

1. Bare `qkv_format` (undefined) instead of `ctx.dqkv_format` in the pad_between_seqs condition, so the branch that computes fa_cu_seqlens_q/kv with padded values was never taken.
2. `cu_seqlens_q`/`cu_seqlens_kv` passed to get_fa_args instead of `fa_cu_seqlens_q`/`fa_cu_seqlens_kv`, so even if the condition had fired, the un-padded values would have been used.

Together these caused FA3 backward to see wrong sequence boundaries, writing non-zero gradients into padding positions of dq/dk/dv in the A2A CP comm type.

Validated: pad_between_seqs=True + a2a + thd tests pass 3/3 runs (was 3/3 FAILED before this fix).

Known issue (pre-existing on main, not from this branch):
- test_cp_with_fused_attention[True-None-False-False-False-{p2p,a2a}-thd-cp_1_0-bf16] fails with cuDNN "CompositeSoftmaxNode" error when return_max_logit=True. This is a cuDNN graph validation bug introduced by PR NVIDIA#2719 on main.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
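The two fixes described in this commit message can be sketched roughly as follows. This is a minimal illustration, not the actual TransformerEngine code: `select_fa_cu_seqlens` is a hypothetical helper, `ctx` is mocked with `SimpleNamespace`, and the sketch assumes the padded branch simply switches to the padded offsets and that the relevant format is `"thd"`.

```python
from types import SimpleNamespace


def select_fa_cu_seqlens(ctx, cu_seqlens_q, cu_seqlens_kv,
                         cu_seqlens_q_padded, cu_seqlens_kv_padded):
    """Hypothetical helper: pick the sequence offsets forwarded to the FA backward call."""
    fa_cu_seqlens_q, fa_cu_seqlens_kv = cu_seqlens_q, cu_seqlens_kv
    # Fix 1: test ctx.dqkv_format (the renamed attribute), not a bare `qkv_format`.
    if ctx.pad_between_seqs and ctx.dqkv_format == "thd":
        # Assumption: the padded branch just selects the padded offsets.
        fa_cu_seqlens_q, fa_cu_seqlens_kv = cu_seqlens_q_padded, cu_seqlens_kv_padded
    # Fix 2: return the fa_* values so the caller forwards them, not the raw inputs.
    return fa_cu_seqlens_q, fa_cu_seqlens_kv


# Mocked autograd context; the real ctx carries many more attributes.
ctx = SimpleNamespace(pad_between_seqs=True, dqkv_format="thd")
fa_q, fa_kv = select_fa_cu_seqlens(ctx, [0, 3, 7], [0, 3, 7], [0, 4, 8], [0, 4, 8])
print(fa_q, fa_kv)
```

In the real backward path the selected values would then be passed to `get_fa_args`; that call is omitted here because its signature is not shown in this thread.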
faradawn pushed a commit to faradawn/TransformerEngine that referenced this pull request May 14, 2026
* initial implementation for mxfp8 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * semi-working FP8; broken F16 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * clean up last commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * comment out F16 pass Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * pull in grouped_quantize for MXFP8 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * grouped tensor - pytorch Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * quantize mxfp8 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix shapes/strides Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix unfused; clean up Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * split d to d_qk/d_v; attempt at bwd Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix last merge Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update FE Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * attempt at SWA/MLA Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove prints Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove leftover prints Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Revert "update FE" This reverts commit d9ff566. 
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update FE Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix MLA O strides; add bottom_right_diagonal Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * attempt at bwd Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix get_quantizers; attempt at bwd Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix fprop; add o_format Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * attempt at bwd with o_format/d_out_format/dqkv_layout Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix dtype/o_format/etc in bwd calls Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix generateMatrixStridesWithFormats and _v1; fix padding for mxfp8 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix upon last commit for paddedsizes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add mxfp8 env var Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * disable FA for mxfp8 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add mha test Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * attempt at bwd; force determinism; fix shapes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove prints Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update FE Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update FE from pre-merge branch to post-merge develop Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * allow MXFP8 linear + f16 attn Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * test cp a2a Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove prints temporarily 
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * test cp p2p Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * minor fixes for mla Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * open up a2a for mla Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * test ag Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * tweaks for last commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * enable mla ag Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix merge Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix merge Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * revert to main grouped tensor impl Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * minor tweaks to return to main Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove prints Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix combine_and_quantize for f16 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * minor tweaks Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * tweak tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix ds descale_o Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Revert "fix ds descale_o" This reverts commit cd0bd82e239ff01210338b4e34cb8784109d22ec. 
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * minor fixes for p2p and ag Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * tweak cp test skips Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update FE Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix bwd KV tensors Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * tweak recipe control and backend selection Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * tweak quantizer logic Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * minor fixes after last two commits Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * improve generate strides Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * minor fixes for previous commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix bwd for current/delayed Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * tweak test configs Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix dO/dO_f16 strides Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix tests: SWA logic/test configs Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix ag Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add fp8 sink attn Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix a2a comm for F16 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove nan/inf print in test Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix fa a2a Signed-off-by: 
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix fa a2a+p2p f16 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update FE to include new fixes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix thd for bwd Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * refactor a2a for fu/fa Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update FE to fix d64 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * refactor ag Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * refactor p2p/a2a+p2p; mostly regarding shapes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add shadow f16 fwd Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update FE to fix SWA/BRCM Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * switch to GH FE temporarily Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * switch back to GL FE Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update FE to latest commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update group tensor usage after merge main Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * env vars for qdq(q,k), o_f16 tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * allow other recipes than mxfp8 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix grouped tensor for MLA Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * 
change cp test configs Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add shadow f16 bwd Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix a2a+p2p for sbhd Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix last commit and causal flag for fa Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * enable fp8 sink and disable fp8_mha Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * minor cleanup for cp/non-cp Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update FE for FP8 sink Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix TE for FP8 sink Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * temporary: random sink/print sink Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Revert "temporary: random sink/print sink" This reverts commit 706095f. 
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * replace d_out_format with do_format Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix compare_and_assert for None cases Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove logic for b and simplify logic for dqkv types Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * minor fix for ndim_q/kv Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add explanation of fp8_output/grad in MHA Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * tidy up FP8 checks for bhsd/learnable Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove leading underscores in nvte_convert_qkv_format Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * simplify logic in generateMatrixStridesWithLayout Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * clean up strides/ifelse-recipe logic Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * tweak checks in utils.py Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * tweak UnfusedDPA Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * enable testing for ag+swa and disable fp8_mha Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * tweak FusedAttn, fp8/f16 tensor naming/docstring Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * replace d_out_format with do_format Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix lint Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * clean up a2a Signed-off-by: Charlene 
Yang <8636796+cyanguwa@users.noreply.github.com> * clean up ag Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * clean up p2p/a2a+p2p Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * tweak test configs Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * qdq dO in bwd shadow f16 path Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * tweak qdq dO logic Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove prints in shadow paths Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update FE to allow non-determinism Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fuse qkv transposes; first pass Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remap parallelism to grid(bh, splits, 3) block(s/splits x d); use nvec = 128 bits Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * allocate contiguous block for qkv Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix grouped tensor row/col scale_inv offsets Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * use fused permute kernels Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * quantize row/col as needed in fwd/bwd, non-cp/cp Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Revert "quantize row/col as needed in fwd/bwd, non-cp/cp" This reverts commit ca53769. Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Reapply "quantize row/col as needed in fwd/bwd, non-cp/cp" This reverts commit f19e852. 
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix v_col format when row is quantized Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add back necessary bwd quants for shadow paths/cp a2a Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove ZInv for all layouts except T3HD Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix cp p2p with zinv Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * temporarily switch to GH FE main Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * switch back to GL FE Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix ag after merge main Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add condition for qdq(do) to not affect other tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix custom_mha_fp8 test Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix amax dqkv Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix fp8_recipe in DPA utils Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove use of amax for mxfp8 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add o_format/do_format/dqkv_layout to cache indicators for fp8 and f16 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * enable sink attn + FP8 in CP Signed-off-by: Charlene Yang 
<8636796+cyanguwa@users.noreply.github.com> * update FE to GH v1.22.0 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix for inconsistent kwarg name in permute to grouped tensor Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add TMA permute Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Revert "add TMA permute" This reverts commit 2532a50. Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * TMA load for bhsd transposes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix some lint Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * temp: quant+perm+swizzle, rope, perm_fused Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove mla_rope for now; clean up quant+permute+pad_swizzle; create multi_tensor_swizzle Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix last commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * implement narrow-m for col swizzle; reorder to pad+perm+swizzle Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fused pad into perm; remove at::zeros as zeros done in perm kernels Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove shadow code Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * minor fix for permute shapes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * check smem size before entering narrow-k/m kernels Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * expand permute to multi_tensor_ Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * refactor qkv/do quant; create a fast_path call Signed-off-by: Charlene Yang 
<8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix lint Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * cleanup grouped tensor fix Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove _with_amax for create_unquantized_tensor Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix last commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * reimplement inplace_multi_tensor_swizzle Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix last commit; set swizzled flag in python Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove permute_to_grouped_tensor_bwd; clean up fwd Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add doxygen for multi_tensor_swizzle Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * clean up nvte_convert_qkv_format Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fixes based on code review Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * group layouts/formats in APIs Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * rename nvte_convert_qkv_format to nvte_convert_qkv_shape Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove MXFP8 create_unquantized_tensor Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * rename permute_to_grouped_tensor to transpose_to_bhsd Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add multi_tensor_swizzle_xx_unchecked and split the calls/paths Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * straighten up indexing for multi_tensor_pad 
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * batch up kernel calls per-16-tensors for pad and permute Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove nvec128; rename nvec64 back to nvec Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add Macros/arch specifics for compilation Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix lint Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * attempt 1: MLA RoPE Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Revert "attempt 1: MLA RoPE" This reverts commit 7922924. Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix kv_cache tests for Fused, is_page=True Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * attempt 2: MLA RoPE Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * use DIVUP/_TO_MULTIPLE for pad_s_d_for_mxfp8 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove CUDNN_VERSION 8900 macros Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add narrow-k/m swizzle tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * compile flash_attn.cu with special archs Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Revert "attempt 2: MLA RoPE" This reverts commit 3b854b2. 
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * make contiguous instead of check is_contiguous Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove unused s_q/s_kv Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove unused issue_tma_store_strided Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add version gate for mxfp8 for CPP users Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * replace nvte_get_qkv_shape with AttentionShape Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * populate nvte_ changes to Jax Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update FE to 1.22.1 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix minor merge issue Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * revert to FE 1.21 since it's what mxfp8 needs Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * udpate jax attention shapes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Revert "revert to FE 1.21 since it's what mxfp8 needs" This reverts commit f09961a. 
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * pick FE 1.22 to support mxfp8 and avoid rng issue in 1.22.1 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix CP AG test on Hopper Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Description
* `FusedAttention` backend (fwd+bwd, BSHD/SBHD, TE-PyTorch)
* `o_format`, `do_format`, and `dqkv_layout`
* `mxfp8_quantize_fast_path()` to quantize multiple tensors and pad/permute/swizzle the `scale_inv`s for a more efficient quantization pipeline; added `qkv_scale_inv_format` and `do_scale_inv_format` to indicate the `scale_inv`s' format
* `multi_tensor_transpose_to_bhsd()` to permute tensors from BSHD/SBHD to BHSD, with the TMA path optimized for the FP16/BF16 dtype, D=192/128 cases, a `fallback_vec_aligned` path for Byte, D=8/4, and a `fallback_non_vec_aligned` path for Byte, D=6
* `multi_tensor_pad_last_dim()` to pad multiple tensors' D to a multiple of 4 for rowwise and a multiple of 128 for columnwise
* `multi_tensor_swizzle_row_scaling_narrow_k_kernel` and `multi_tensor_swizzle_col_scaling_narrow_m_kernel` to optimize for small K/M dimensions
* `MXFP8PaddedSizes`, `pad_s_d_for_mxfp8`, `generateMatrixStridesWithFormat`, `generateMatrixStridesWithLayout` for size/stride computation
* `nvte_convert_qkv_shape`
* `cp_comm_type={'a2a', 'p2p', 'a2a+p2p', 'all_gather'}`, with `{'a2a', 'p2p', 'a2a+p2p', 'all_gather'}` for MHA/GQA/MQA/MLA, `{'a2a', 'all_gather'}` for SWA, and `{'a2a'}` for sink attention
* `scale_inv_offsets` calculation in `GroupedTensor`

Type of change
Changes
Please see Description.
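
For readers skimming the Description, here is a minimal Python sketch of two pieces of shape arithmetic it mentions: padding D up to a multiple of 4 (rowwise) or 128 (columnwise), and reordering a BSHD/SBHD shape to BHSD. The helper names `round_up_to_multiple`, `padded_last_dim`, and `to_bhsd_shape` are hypothetical and written only for illustration. They are not functions from this PR, and the actual kernels may handle these steps differently.

```python
def round_up_to_multiple(value: int, multiple: int) -> int:
    """Round `value` up to the nearest multiple of `multiple`."""
    if multiple <= 0:
        raise ValueError("multiple must be positive")
    # Ceil-divide, then scale back up.
    return ((value + multiple - 1) // multiple) * multiple


def padded_last_dim(d: int, columnwise: bool) -> int:
    """Hypothetical helper: padded size of the last dimension D.

    Assumes the rule stated in the Description: a multiple of 4 for
    rowwise data and a multiple of 128 for columnwise data.
    """
    return round_up_to_multiple(d, 128 if columnwise else 4)


def to_bhsd_shape(shape: tuple, fmt: str) -> tuple:
    """Hypothetical helper: reorder a 4-D shape given in `fmt`
    ("bshd" or "sbhd") into BHSD order. Only the shape is computed;
    no tensor data is moved.
    """
    if len(shape) != 4 or sorted(fmt) != sorted("bhsd"):
        raise ValueError("expected a 4-D shape and a format such as 'bshd' or 'sbhd'")
    # For each target axis, look up where it sits in the source format.
    return tuple(shape[fmt.index(axis)] for axis in "bhsd")


if __name__ == "__main__":
    print(padded_last_dim(6, columnwise=False))
    print(padded_last_dim(192, columnwise=True))
    print(to_bhsd_shape((512, 2, 16, 64), "sbhd"))
```

The sketch works on sizes only, which keeps it independent of any tensor library.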
Checklist: