Skip to content

AITER Native Padding Support and BSHD + Padding --> THD + Padding conversion#354

Merged
wangye805 merged 49 commits into
devfrom
zain/aiter-native-bshd-thd
Jan 13, 2026
Merged

AITER Native Padding Support and BSHD + Padding --> THD + Padding conversion#354
wangye805 merged 49 commits into
devfrom
zain/aiter-native-bshd-thd

Conversation

@Micky774

@Micky774 Micky774 commented Oct 28, 2025

Copy link
Copy Markdown
Contributor

Description

Feature update PR which includes several iterative changes for client-driven optimization targets. This PR includes both API changes for CK/AITER as well as changes in internal integration. See the list of changes for specifics.

Note that this will not be ready for merger until ROCm/aiter#1212 is merged in and this PR's AITER commit is updated.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Integrated support for native padding kernels in fwd/bwd
  • Added BSHD + Padding --> THD + Padding conversion mechanism
  • Streamlined memory allocation logic
  • Added runtime max_seqlen calculation gated by new env var NVTE_CK_RUNTIME_MAX_SEQLEN
  • Adds v3_api_check support (temporary)
  • Implements new AITER/CK API
  • Update MQA post-processing kernels
  • Remove pad_between_seqs (need to follow-up with a PR cleaning up test suite for old pad_between_seqs edge-cases)
  • Added NVTE_CK_RUNTIME_NUM_SEGMENTS to guard runtime-calculation of the number of segments in the JAX integration

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Comment thread transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp Outdated
Comment thread transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp Outdated
Comment thread transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp Outdated
Comment thread transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp Outdated
Comment thread transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp Outdated
Comment thread transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp Outdated
Comment thread transformer_engine/jax/csrc/extensions/attention.cpp

@wangye805 wangye805 left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generally, I think we can try to remove all memset except for dq, dq_acc. We can confirm with aiter/ck people

Comment thread transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp
Comment thread transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp Outdated
Comment thread transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp
Comment thread transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp Outdated
Comment thread transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp Outdated
Comment thread transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp Outdated
Comment thread transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp Outdated
Comment thread transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp
Comment thread transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp Outdated
Comment thread transformer_engine/jax/csrc/extensions/attention.cpp Outdated
Comment thread transformer_engine/jax/csrc/extensions/attention.cpp
@Micky774

Copy link
Copy Markdown
Contributor Author

pytorch test_numerics also shows some fused-attn related failures: FAILED tests/pytorch/test_numerics.py::test_kv_cache_accuracy[False-FusedAttention-TransformerLayer-sbhd-False-126m-1-dtype1] - AssertionError: Outputs not close enough in tensor at idx=0. Maximum difference at location [0, 650] with -0.90625 vs 0.5654296875 (diff 1.4716796875).

Not sure whether this is related to our decision to remove memsettings.

Those failures were due to a mix of not correctly dispatching to the is_SBHD workflow when dealing with SBHD_2BSHD formats, and miscalculating stride in the case of the same format. Resolved now.

@wangye805 wangye805 left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For those newly added hybrid qkv formats in upstream (NVTE_SBHD_2BSHD, NVTE_BSHD_2SBHD, NVTE_THD_2BSHD, and NVTE_THD_2SBHD): in addition to the SBHD_2BSHD pytest failures, are we able to correctly handle all other 3? Or is there only SBHD_2BSHD pytests now?

NV upstream is separating format and is_ragged on q/kv and do subsequent processings accordingly:

NVTE_QKV_Format q_format = nvte_get_q_format(layout);
NVTE_QKV_Format kv_format = nvte_get_kv_format(layout);
bool is_ragged_q = (q_format == NVTE_QKV_Format::NVTE_THD);
bool is_ragged_kv = (kv_format == NVTE_QKV_Format::NVTE_THD);

Maybe we can try similar technique. If I recall correctly, we need padding/unpadding for just q in SBHD_2BSHD and for just k/v in BSHD_2SBHD.

Or it's okay if you want to leave this for another PR.

By the way, there is an "extra line" comment you may have ignored :-)

Comment thread transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp
Comment thread transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp
Comment thread transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp
Comment thread transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp Outdated
Comment thread transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp
@wangye805

Copy link
Copy Markdown
Collaborator

In fact, I saw some level 3 pytorch cp pytest failures by run level 3 ci locally:

=========================== short test summary info ============================
FAILED tests/pytorch/fused_attn/test_fused_attn_with_cp.py::test_cp_with_fused_attention[False-p2p-thd-cp_1_0-bf16] - subprocess.CalledProcessError: Command '['python3', '-m', 'torch.distributed.launch', '--nproc-per-node=2', '/workspace/te_native_bshd_thd/tests/pytorch/fused_attn/run_fused_attn_with_cp.py', 'dtype=bf16', 'model=cp_1_0', 'qkv_format=thd', 'kernel_backend=FusedAttention', 'cp_comm_type=p2p', 'fp8_mha=False']' returned non-zero exit status 1.
FAILED tests/pytorch/fused_attn/test_fused_attn_with_cp.py::test_cp_with_fused_attention[False-p2p-thd-cp_1_1-bf16] - subprocess.CalledProcessError: Command '['python3', '-m', 'torch.distributed.launch', '--nproc-per-node=2', '/workspace/te_native_bshd_thd/tests/pytorch/fused_attn/run_fused_attn_with_cp.py', 'dtype=bf16', 'model=cp_1_1', 'qkv_format=thd', 'kernel_backend=FusedAttention', 'cp_comm_type=p2p', 'fp8_mha=False']' returned non-zero exit status 1.
FAILED tests/pytorch/fused_attn/test_fused_attn_with_cp.py::test_cp_with_fused_attention[False-p2p-thd-cp_2_0-bf16] - subprocess.CalledProcessError: Command '['python3', '-m', 'torch.distributed.launch', '--nproc-per-node=2', '/workspace/te_native_bshd_thd/tests/pytorch/fused_attn/run_fused_attn_with_cp.py', 'dtype=bf16', 'model=cp_2_0', 'qkv_format=thd', 'kernel_backend=FusedAttention', 'cp_comm_type=p2p', 'fp8_mha=False']' returned non-zero exit status 1.
FAILED tests/pytorch/fused_attn/test_fused_attn_with_cp.py::test_cp_with_fused_attention[False-p2p-thd-cp_2_1-bf16] - subprocess.CalledProcessError: Command '['python3', '-m', 'torch.distributed.launch', '--nproc-per-node=2', '/workspace/te_native_bshd_thd/tests/pytorch/fused_attn/run_fused_attn_with_cp.py', 'dtype=bf16', 'model=cp_2_1', 'qkv_format=thd', 'kernel_backend=FusedAttention', 'cp_comm_type=p2p', 'fp8_mha=False']' returned non-zero exit status 1.
SKIPPED [48] tests/pytorch/fused_attn/test_fused_attn_with_cp.py:68: CP implementation with KV P2P does not support sliding window yet!
SKIPPED [16] tests/pytorch/fused_attn/test_fused_attn_with_cp.py:70: CP implementation with KV all-gather does not support THD format yet!
SKIPPED [24] tests/pytorch/fused_attn/test_fused_attn_with_cp.py:74: CP implementation with QKVO A2A does not support THD format yet!
SKIPPED [240] tests/pytorch/fused_attn/test_fused_attn_with_cp.py:133: FP8 attention has not been supported on ROCm yet!
SKIPPED [40] tests/pytorch/fused_attn/test_fused_attn_with_cp.py:153: CP implementation with KV P2P does not support sliding window yet!
SKIPPED [64] tests/pytorch/fused_attn/test_fused_attn_with_cp.py:137: THD format does not support post_scale_bias yet!
SKIPPED [32] tests/pytorch/fused_attn/test_fused_attn_with_cp.py:155: CP implementation with KV all-gather does not support bias yet!
SKIPPED [24] tests/pytorch/fused_attn/test_fused_attn_with_cp.py:139: CP implementation with KV all-gather does not support THD format yet!
SKIPPED [64] tests/pytorch/fused_attn/test_fused_attn_with_cp.py:157: CP implementation with QKVO A2A does not support bias yet!
SKIPPED [48] tests/pytorch/fused_attn/test_fused_attn_with_cp.py:141: CP implementation with QKVO A2A does not support THD format yet!
SKIPPED [104] tests/pytorch/fused_attn/test_fused_attn_with_cp.py:164: Only fp8 works with fp8_mha=True!
===== 4 failed, 204 passed, 704 skipped, 2 warnings in 3065.08s (0:51:05) ======
Error in test [ck] fused_attn/test_fused_attn_with_cp.py
Done [ck] fused_attn/test_fused_attn_with_cp.py
Got 1 test errors during run at level 3

Attached you can find the detailed log
torch_mgpu.txt

@wangye805 wangye805 force-pushed the zain/aiter-native-bshd-thd branch from 658c105 to 871cb4e Compare January 13, 2026 04:57
@wangye805

wangye805 commented Jan 13, 2026

Copy link
Copy Markdown
Collaborator

Manual CI runs in local MI300X machine:
core_sgpu_l3.txt
torch_mgpu_l3.txt
jax_sgpu_l3.txt

@wangye805

Copy link
Copy Markdown
Collaborator

jax_mgpu_l3.txt

@wangye805 wangye805 merged commit 4920d50 into dev Jan 13, 2026
1 of 2 checks passed
@Micky774 Micky774 deleted the zain/aiter-native-bshd-thd branch January 14, 2026 17:46
Micky774 added a commit that referenced this pull request Feb 25, 2026
…version (#354)

* [ROCm] manually pick up fwd native padding support from Meekail's PR

* Initial update

* Updated stride

* Corrected typing in allocation portions

* Applied Ye's patch

* [ROCm] manually pick Meekail's PR to support native padding for bwd

* [ROCm] jax use runtime segment

* [ROCm] get runtime max_seqlen as well

* [ROCm] support v2 bwd native padding

* Updated conversion to include bwd pass

* Added BWD BSHD-->THD conversion and minor logic refactor

* Corrected softmax lse bug

* Updated logic flow and re-caclulation

* [ROCm] manually pick Meekail's PR to support native padding for bwd

[ROCm] support v2 bwd native padding

* Added env var guard

* Updated ptr variables and streamlined dispatch

* Added env guard

* Corrected bshd_to_thd conversion arguments

* Corrected logical flow

* Guarded memset and corrected allocation

* Remove V3 API check and guard memsets

* PR comments

* Updated documentation

* PR review reconciliation

- Updated debug message for BSHD-->THD conversion
- Added env variable to gate FWD output memset for padding
- Removed guards on memsets for d{Q,K,V} matrices

* Added explicit test

* Formatting for bwd debug

* Resolved error when using mixed formats e.g. sbhd_2bshd

* Updated guard on flash-attention forced support

* Added check for SBHD_2BSHD

* Added guard on dk/dv memset

* Removed env var gating for dk/dv zero padding, formatting

* Added inline comment to test

* Corrected Softmax LSE buffer allocation

* Correct Softmax LSE buffer memory allocation

* Adjusted fwd pass softmax lse allocation

* Adjusted bwd pass softmax conversion allocation

* Minor reversions

* [ROCm] fix the aiter fwd v3 cu_seqlen/cu_seqlen_padded api issue

* Update README.rst to fix formatting

* [ROCm] update aiter commit with swa fix

---------

Co-authored-by: Ye Wang <yewang12@amd.com>
Micky774 added a commit that referenced this pull request Mar 2, 2026
* AITER Native Padding Support and BSHD + Padding --> THD + Padding conversion (#354)

* [ROCm] manually pick up fwd native padding support from Meekail's PR

* Initial update

* Updated stride

* Corrected typing in allocation portions

* Applied Ye's patch

* [ROCm] manually pick Meekail's PR to support native padding for bwd

* [ROCm] jax use runtime segment

* [ROCm] get runtime max_seqlen as well

* [ROCm] support v2 bwd native padding

* Updated conversion to include bwd pass

* Added BWD BSHD-->THD conversion and minor logic refactor

* Corrected softmax lse bug

* Updated logic flow and re-caclulation

* [ROCm] manually pick Meekail's PR to support native padding for bwd

[ROCm] support v2 bwd native padding

* Added env var guard

* Updated ptr variables and streamlined dispatch

* Added env guard

* Corrected bshd_to_thd conversion arguments

* Corrected logical flow

* Guarded memset and corrected allocation

* Remove V3 API check and guard memsets

* PR comments

* Updated documentation

* PR review reconciliation

- Updated debug message for BSHD-->THD conversion
- Added env variable to gate FWD output memset for padding
- Removed guards on memsets for d{Q,K,V} matrices

* Added explicit test

* Formatting for bwd debug

* Resolved error when using mixed formats e.g. sbhd_2bshd

* Updated guard on flash-attention forced support

* Added check for SBHD_2BSHD

* Added guard on dk/dv memset

* Removed env var gating for dk/dv zero padding, formatting

* Added inline comment to test

* Corrected Softmax LSE buffer allocation

* Correct Softmax LSE buffer memory allocation

* Adjusted fwd pass softmax lse allocation

* Adjusted bwd pass softmax conversion allocation

* Minor reversions

* [ROCm] fix the aiter fwd v3 cu_seqlen/cu_seqlen_padded api issue

* Update README.rst to fix formatting

* [ROCm] update aiter commit with swa fix

---------

Co-authored-by: Ye Wang <yewang12@amd.com>

* Updated aiter (dropped from cherry-pick)

* Update to include 7.2 fix in AITER/CK

* Update test

---------

Co-authored-by: Ye Wang <yewang12@amd.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants