AITER Native Padding Support and BSHD + Padding --> THD + Padding conversion#354
Conversation
[ROCm] support v2 bwd native padding
wangye805
left a comment
There was a problem hiding this comment.
Generally, I think we can try to remove all memset except for dq, dq_acc. We can confirm with aiter/ck people
Those failures were due to a mix of not correctly dispatching to the |
wangye805
left a comment
There was a problem hiding this comment.
For those newly added hybrid qkv formats in upstream (NVTE_SBHD_2BSHD, NVTE_BSHD_2SBHD, NVTE_THD_2BSHD, and NVTE_THD_2SBHD): in addition to the SBHD_2BSHD pytest failures, are we able to correctly handle all other 3? Or is there only SBHD_2BSHD pytests now?
NV upstream is separating format and is_ragged on q/kv and do subsequent processings accordingly:
Maybe we can try similar technique. If I recall correctly, we need padding/unpadding for just q in SBHD_2BSHD and for just k/v in BSHD_2SBHD.
Or it's okay if you want to leave this for another PR.
By the way, there is an "extra line" comment you may have ignored :-)
|
In fact, I saw some level 3 pytorch cp pytest failures by run level 3 ci locally: Attached you can find the detailed log |
658c105 to
871cb4e
Compare
|
Manual CI runs in local MI300X machine: |
…version (#354) * [ROCm] manually pick up fwd native padding support from Meekail's PR * Initial update * Updated stride * Corrected typing in allocation portions * Applied Ye's patch * [ROCm] manually pick Meekail's PR to support native padding for bwd * [ROCm] jax use runtime segment * [ROCm] get runtime max_seqlen as well * [ROCm] support v2 bwd native padding * Updated conversion to include bwd pass * Added BWD BSHD-->THD conversion and minor logic refactor * Corrected softmax lse bug * Updated logic flow and re-caclulation * [ROCm] manually pick Meekail's PR to support native padding for bwd [ROCm] support v2 bwd native padding * Added env var guard * Updated ptr variables and streamlined dispatch * Added env guard * Corrected bshd_to_thd conversion arguments * Corrected logical flow * Guarded memset and corrected allocation * Remove V3 API check and guard memsets * PR comments * Updated documentation * PR review reconciliation - Updated debug message for BSHD-->THD conversion - Added env variable to gate FWD output memset for padding - Removed guards on memsets for d{Q,K,V} matrices * Added explicit test * Formatting for bwd debug * Resolved error when using mixed formats e.g. sbhd_2bshd * Updated guard on flash-attention forced support * Added check for SBHD_2BSHD * Added guard on dk/dv memset * Removed env var gating for dk/dv zero padding, formatting * Added inline comment to test * Corrected Softmax LSE buffer allocation * Correct Softmax LSE buffer memory allocation * Adjusted fwd pass softmax lse allocation * Adjusted bwd pass softmax conversion allocation * Minor reversions * [ROCm] fix the aiter fwd v3 cu_seqlen/cu_seqlen_padded api issue * Update README.rst to fix formatting * [ROCm] update aiter commit with swa fix --------- Co-authored-by: Ye Wang <yewang12@amd.com>
* AITER Native Padding Support and BSHD + Padding --> THD + Padding conversion (#354) * [ROCm] manually pick up fwd native padding support from Meekail's PR * Initial update * Updated stride * Corrected typing in allocation portions * Applied Ye's patch * [ROCm] manually pick Meekail's PR to support native padding for bwd * [ROCm] jax use runtime segment * [ROCm] get runtime max_seqlen as well * [ROCm] support v2 bwd native padding * Updated conversion to include bwd pass * Added BWD BSHD-->THD conversion and minor logic refactor * Corrected softmax lse bug * Updated logic flow and re-caclulation * [ROCm] manually pick Meekail's PR to support native padding for bwd [ROCm] support v2 bwd native padding * Added env var guard * Updated ptr variables and streamlined dispatch * Added env guard * Corrected bshd_to_thd conversion arguments * Corrected logical flow * Guarded memset and corrected allocation * Remove V3 API check and guard memsets * PR comments * Updated documentation * PR review reconciliation - Updated debug message for BSHD-->THD conversion - Added env variable to gate FWD output memset for padding - Removed guards on memsets for d{Q,K,V} matrices * Added explicit test * Formatting for bwd debug * Resolved error when using mixed formats e.g. sbhd_2bshd * Updated guard on flash-attention forced support * Added check for SBHD_2BSHD * Added guard on dk/dv memset * Removed env var gating for dk/dv zero padding, formatting * Added inline comment to test * Corrected Softmax LSE buffer allocation * Correct Softmax LSE buffer memory allocation * Adjusted fwd pass softmax lse allocation * Adjusted bwd pass softmax conversion allocation * Minor reversions * [ROCm] fix the aiter fwd v3 cu_seqlen/cu_seqlen_padded api issue * Update README.rst to fix formatting * [ROCm] update aiter commit with swa fix --------- Co-authored-by: Ye Wang <yewang12@amd.com> * Updated aiter (dropped from cherry-pick) * Update to include 7.2 fix in AITER/CK * Update test --------- Co-authored-by: Ye Wang <yewang12@amd.com>
Description
Feature update PR which includes several iterative changes for client-driven optimization targets. This PR includes both API changes for CK/AITER as well as changes in internal integration. See the list of changes for specifics.
Note that this will not be ready for merger until ROCm/aiter#1212 is merged in and this PR's AITER commit is updated.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
max_seqlencalculation gated by new env varNVTE_CK_RUNTIME_MAX_SEQLENv3_api_checksupport (temporary)pad_between_seqs(need to follow-up with a PR cleaning up test suite for oldpad_between_seqsedge-cases)NVTE_CK_RUNTIME_NUM_SEGMENTSto guard runtime-calculation of the number of segments in the JAX integrationChecklist: