Fix SequenceDescriptor API for TransformerEngine >= 2.12#3589
Closed
yeandy wants to merge 3 commits into
Closed
Conversation
TransformerEngine v2.12 (NVIDIA/TransformerEngine#2523) made `is_thd` and `is_segment_ids_reordered` required keyword arguments on `SequenceDescriptor.from_segment_ids_and_pos()` to fix incorrect segment position calculation for THD layouts. Since the packing branch in `cudnn_flash_attention` uses `qkv_layout="THD_THD_THD"` with standard (non-reordered) segment IDs, the correct values are `is_thd=True, is_segment_ids_reordered=False`. Without this fix, any configuration using `attention="cudnn_flash_te"` with `packing=True` and real data (`dataset_type != "synthetic"`) fails with: TypeError: SequenceDescriptor.from_segment_ids_and_pos() missing 2 required keyword-only arguments: 'is_thd' and 'is_segment_ids_reordered'
The existing test_gpu_packed_attention uses dataset_type=synthetic, which bypasses SequenceDescriptor.from_segment_ids_and_pos() entirely (takes the elif branch in cudnn_flash_attention). This meant the TE v2.12 API breakage went undetected. Add test_gpu_packed_attention_hf that uses HF parquet data with packing=True + attention=cudnn_flash_te, exercising the actual SequenceDescriptor codepath. This serves as a regression test for NVIDIA/TransformerEngine#2523.
Contributor
Author
|
@NuojCheng can you please take a look? I see you have most recently worked on this section of the code. I can mark "Ready for review" now, but it will ping a bunch of reviewers. Want to minimize noise 😄 But we still may need to mark as ready if we want to run the GH actions workflow to run tests? |
Contributor
Author
|
Hi @NuojCheng, following up to see how best to proceed. Thanks! |
|
This PR has been automatically marked as stale because it has not had recent activity. It will be closed soon if no further activity occurs. Thank you for your contributions. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
SequenceDescriptor.from_segment_ids_and_pos()calls incudnn_flash_attentionto pass the two new required keyword arguments (is_thd,is_segment_ids_reordered) added in TransformerEngine v2.12test_gpu_packed_attention_hfthat exercises the actualSequenceDescriptorcodepath with real HF dataTransformerEngine v2.12 (NVIDIA/TransformerEngine#2523, merged Dec 31, 2025) made
is_thdandis_segment_ids_reorderedrequired keyword-only arguments onSequenceDescriptor.from_segment_ids_and_pos()to fix incorrect segment position calculation for THD layouts. This is a breaking change documented in the TE v2.12 release notes.Since MaxText's cuda12-requirements.txt pins
transformer-engine>=2.9.0, any fresh install picks upTE >= 2.12, which breaks any configuration usingattention="cudnn_flash_te"+packing=True+ real data (dataset_type != "synthetic") with:The ROCm fork of TransformerEngine picked up the same change via its IFU 2.12 merge (ccda1a5, Apr 3, 2026).
Fix
The packing branch in
cudnn_flash_attentionusesqkv_layout="THD_THD_THD"with standard (non-reordered) segment IDs, so the correct values are:is_thd=True— THD packed layoutis_segment_ids_reordered=False— no load-balancing reordering appliedWhy this wasn't caught by existing tests
The existing
test_gpu_packed_attentionusesdataset_type=synthetic, which takes theelifbranch incudnn_flash_attentionand setsattn_mask = None— completely bypassingSequenceDescriptor.from_segment_ids_and_pos(). The newtest_gpu_packed_attention_hfuses HF parquet data to exercise the actual codepath.Tests
[X] Verify the fix with a training run using attention=cudnn_flash_te + packing=True + non-synthetic data on a GPU with TE >= 2.12
[X] test_gpu_packed_attention (existing, synthetic) still passes
[X] test_gpu_packed_attention_hf (new, HF data) passes on sm90+ hardware
Convergence Validation (H200)
Verified numerical correctness with a 3-way convergence comparison on 8× H200 (sm90), 5000 training steps, streaming
allenai/c4from HuggingFace:nvcr.io/nvidia/jax:26.03-maxtext-py3dot_productnvcr.io/nvidia/jax:26.03-maxtext-py3cudnn_flash_teis_thd=True, is_segment_ids_reordered=Falsenvcr.io/nvidia/jax:26.01-maxtext-py3cudnn_flash_teis_thdkwarg)Model config:
base_emb_dim=1024,base_mlp_dim=4096,base_num_decoder_layers=8,head_dim=128,per_device_batch_size=1,packing=True,max_target_length=512,tokenizer=google-t5/t5-largeLoss comparison at milestones (click to expand)
Results:
Conclusion: The PR fix (
is_thd=True, is_segment_ids_reordered=False) on TE 2.13 is numerically equivalent to the old TE 2.10 behavior. Correctness is preserved.Training logs — Runs 1 & 2 (TE 2.13,
nvcr.io/nvidia/jax:26.03-maxtext-py3)Training log — Run 3 (TE 2.10,
nvcr.io/nvidia/jax:26.01-maxtext-py3)Convergence Validation (MI355X / ROCm)
Verified numerical correctness with a 3-way convergence comparison on 8× MI355X (gfx950), 5000 training steps, streaming
allenai/c4from HuggingFace:dot_productcudnn_flash_teis_thd=True, is_segment_ids_reordered=Falsecudnn_flash_teis_thdkwarg)Pytest Integration Tests (TE 2.12.0)
test_gpu_packed_attentiontest_gpu_packed_attention_hfNote: Tests required patching the
compute_capability >= 9.0skip guard to accept ROCmgfx950devices.Convergence Results
Results:
NOTES
The runners that the CI uses A100 for test, so these tests
test_gpu_packed_attentionandtest_gpu_packed_attention _hfdon't actually get run, right?Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-reviewlabel.