Skip to content

Fix SequenceDescriptor API for TransformerEngine >= 2.12#3589

Closed
yeandy wants to merge 3 commits into
AI-Hypercomputer:mainfrom
ROCm:yeandy/te-sequence-descriptor-api
Closed

Fix SequenceDescriptor API for TransformerEngine >= 2.12#3589
yeandy wants to merge 3 commits into
AI-Hypercomputer:mainfrom
ROCm:yeandy/te-sequence-descriptor-api

Conversation

@yeandy

@yeandy yeandy commented Apr 7, 2026

Copy link
Copy Markdown
Contributor

Description

  • Fix SequenceDescriptor.from_segment_ids_and_pos() calls in cudnn_flash_attention to pass the two new required keyword arguments (is_thd, is_segment_ids_reordered) added in TransformerEngine v2.12
  • Add regression test test_gpu_packed_attention_hf that exercises the actual SequenceDescriptor codepath with real HF data

TransformerEngine v2.12 (NVIDIA/TransformerEngine#2523, merged Dec 31, 2025) made is_thd and is_segment_ids_reordered required keyword-only arguments on SequenceDescriptor.from_segment_ids_and_pos() to fix incorrect segment position calculation for THD layouts. This is a breaking change documented in the TE v2.12 release notes.
Since MaxText's cuda12-requirements.txt pins transformer-engine>=2.9.0, any fresh install picks up TE >= 2.12, which breaks any configuration using attention="cudnn_flash_te" + packing=True + real data (dataset_type != "synthetic") with:

TypeError: SequenceDescriptor.from_segment_ids_and_pos() missing 2 required
keyword-only arguments: 'is_thd' and 'is_segment_ids_reordered'

The ROCm fork of TransformerEngine picked up the same change via its IFU 2.12 merge (ccda1a5, Apr 3, 2026).

Fix

The packing branch in cudnn_flash_attention uses qkv_layout="THD_THD_THD" with standard (non-reordered) segment IDs, so the correct values are:

  • is_thd=True — THD packed layout
  • is_segment_ids_reordered=False — no load-balancing reordering applied

Why this wasn't caught by existing tests

The existing test_gpu_packed_attention uses dataset_type=synthetic, which takes the elif branch in cudnn_flash_attention and sets attn_mask = None — completely bypassing SequenceDescriptor.from_segment_ids_and_pos(). The new test_gpu_packed_attention_hf uses HF parquet data to exercise the actual codepath.

Tests

[X] Verify the fix with a training run using attention=cudnn_flash_te + packing=True + non-synthetic data on a GPU with TE >= 2.12
[X] test_gpu_packed_attention (existing, synthetic) still passes
[X] test_gpu_packed_attention_hf (new, HF data) passes on sm90+ hardware

Convergence Validation (H200)

Verified numerical correctness with a 3-way convergence comparison on 8× H200 (sm90), 5000 training steps, streaming allenai/c4 from HuggingFace:

Run Container TE Version Attention API
1 (reference) nvcr.io/nvidia/jax:26.03-maxtext-py3 2.13.0 dot_product N/A
2 (PR fix) nvcr.io/nvidia/jax:26.03-maxtext-py3 2.13.0 cudnn_flash_te is_thd=True, is_segment_ids_reordered=False
3 (old TE) nvcr.io/nvidia/jax:26.01-maxtext-py3 2.10.0 cudnn_flash_te Old API (no is_thd kwarg)

Model config: base_emb_dim=1024, base_mlp_dim=4096, base_num_decoder_layers=8, head_dim=128, per_device_batch_size=1, packing=True, max_target_length=512, tokenizer=google-t5/t5-large

Loss comparison at milestones (click to expand)
Step dot_product (ref) te+fix (TE 2.13) te+old (TE 2.10) fix vs ref old vs ref
0 10.9452 10.9453 10.9453 0.00% 0.00%
100 6.7611 6.7614 6.7615 0.01% 0.01%
200 6.4777 6.4832 6.4967 0.08% 0.29%
500 5.8637 5.8760 5.8821 0.21% 0.32%
1000 5.6764 5.6769 5.6863 0.01% 0.17%
1500 5.9349 5.9302 5.9131 0.08% 0.37%
2000 5.5790 5.5958 5.5812 0.30% 0.04%
2500 5.5399 5.5718 5.5403 0.58% 0.01%
3000 5.5533 5.5516 5.5358 0.03% 0.31%
3500 5.5888 5.6207 5.5903 0.57% 0.03%
4000 5.2008 5.2163 5.2319 0.30% 0.60%
4500 5.4246 5.4533 5.4582 0.53% 0.62%
4999 5.3046 5.3288 5.3191 0.46% 0.27%

Results:

  • Final loss: 5.305 (ref) / 5.329 (fix, TE 2.13) / 5.319 (old, TE 2.10)
  • Avg per-step divergence vs reference: 0.36% (fix) / 0.26% (old TE)
  • All three runs converge to the same loss range (~5.3). Per-step differences stay < 0.7%, consistent with expected numerical noise from different attention kernel implementations.

Conclusion: The PR fix (is_thd=True, is_segment_ids_reordered=False) on TE 2.13 is numerically equivalent to the old TE 2.10 behavior. Correctness is preserved.

Training logs — Runs 1 & 2 (TE 2.13, nvcr.io/nvidia/jax:26.03-maxtext-py3)
======================================================================
GPUs:               8 x NVIDIA H200
Compute cap:        9.0
JAX:                0.9.0.dev20260205+6e29effa6
TransformerEngine:  2.13.0+287770466
======================================================================

Starting: conv_dot_product  attention=dot_product  steps=5000
Done: conv_dot_product  218.5s

Starting: conv_cudnn_te_fix  attention=cudnn_flash_te  steps=5000
Done: conv_cudnn_te_fix  417.1s
Training log — Run 3 (TE 2.10, nvcr.io/nvidia/jax:26.01-maxtext-py3)
======================================================================
GPUs:               8 x NVIDIA H200
Compute cap:        9.0
JAX:                0.8.1.dev20251212+6ab1fef24
TransformerEngine:  2.10.0+769ed7783
======================================================================

Starting: conv_cudnn_te_old  attention=cudnn_flash_te  steps=5000
Done: conv_cudnn_te_old  227.1s

Convergence Validation (MI355X / ROCm)

Verified numerical correctness with a 3-way convergence comparison on 8× MI355X (gfx950), 5000 training steps, streaming allenai/c4 from HuggingFace:

Run Environment TE Version Attention API
1 (reference) ROCm 7.2 / JAX 0.8.2 / TE 2.12 2.12.0.dev0+1c949a56 dot_product N/A
2 (PR fix) ROCm 7.2 / JAX 0.8.2 / TE 2.12 2.12.0.dev0+1c949a56 cudnn_flash_te is_thd=True, is_segment_ids_reordered=False
3 (old TE) ROCm 7.1 / JAX 0.8.2 / TE 2.8 2.8.0.dev0+aec00a7f cudnn_flash_te Old API (no is_thd kwarg)

Pytest Integration Tests (TE 2.12.0)

Test Result
test_gpu_packed_attention ✅ PASSED
test_gpu_packed_attention_hf ✅ PASSED

Note: Tests required patching the compute_capability >= 9.0 skip guard to accept ROCm gfx950 devices.

Convergence Results

Step dot_product (ref) te+fix (TE 2.12) te+old (TE 2.8) fix vs ref old vs ref
0 10.945 10.945 10.945 0.00% 0.00%
500 5.874 5.864 5.883 0.17% 0.16%
1000 5.687 5.663 5.680 0.43% 0.13%
2000 5.589 5.608 5.622 0.34% 0.59%
3000 5.543 5.569 5.581 0.47% 0.69%
4000 5.204 5.215 5.249 0.21% 0.85%
4999 5.291 5.291 5.363 0.00% 1.35%

Results:

  • Final loss: 5.291 (ref) / 5.291 (fix) / 5.363 (old TE)
  • Avg per-step divergence vs reference: 0.25% (fix) / 0.62% (old TE)
  • All three runs converge to the same loss range. The fix (TE 2.12) tracks the reference more closely than old TE (2.8).

NOTES

The runners that the CI uses A100 for test, so these tests test_gpu_packed_attention and test_gpu_packed_attention _hf don't actually get run, right?

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

yeandy added 3 commits April 7, 2026 09:59
TransformerEngine v2.12 (NVIDIA/TransformerEngine#2523) made `is_thd`
and `is_segment_ids_reordered` required keyword arguments on
`SequenceDescriptor.from_segment_ids_and_pos()` to fix incorrect
segment position calculation for THD layouts.

Since the packing branch in `cudnn_flash_attention` uses
`qkv_layout="THD_THD_THD"` with standard (non-reordered) segment IDs,
the correct values are `is_thd=True, is_segment_ids_reordered=False`.

Without this fix, any configuration using `attention="cudnn_flash_te"`
with `packing=True` and real data (`dataset_type != "synthetic"`) fails
with:
  TypeError: SequenceDescriptor.from_segment_ids_and_pos() missing 2
  required keyword-only arguments: 'is_thd' and
  'is_segment_ids_reordered'
The existing test_gpu_packed_attention uses dataset_type=synthetic,
which bypasses SequenceDescriptor.from_segment_ids_and_pos() entirely
(takes the elif branch in cudnn_flash_attention). This meant the
TE v2.12 API breakage went undetected.

Add test_gpu_packed_attention_hf that uses HF parquet data with
packing=True + attention=cudnn_flash_te, exercising the actual
SequenceDescriptor codepath. This serves as a regression test for
NVIDIA/TransformerEngine#2523.
@yeandy

yeandy commented Apr 16, 2026

Copy link
Copy Markdown
Contributor Author

@NuojCheng can you please take a look? I see you have most recently worked on this section of the code.

I can mark "Ready for review" now, but it will ping a bunch of reviewers. Want to minimize noise 😄 But we still may need to mark as ready if we want to run the GH actions workflow to run tests?

@yeandy

yeandy commented Apr 27, 2026

Copy link
Copy Markdown
Contributor Author

Hi @NuojCheng, following up to see how best to proceed. Thanks!

@github-actions

Copy link
Copy Markdown

This PR has been automatically marked as stale because it has not had recent activity. It will be closed soon if no further activity occurs. Thank you for your contributions.

@github-actions github-actions Bot added the stale Automatically applied to stale PRs. label May 27, 2026
@yeandy yeandy closed this May 28, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

stale Automatically applied to stale PRs.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant