[PyTorch][CP] Add THD format support for AllGather-based Context Parallelism#2829
[PyTorch][CP] Add THD format support for AllGather-based Context Parallelism#2829sudhakarsingh27 wants to merge 63 commits into
Conversation
… cu_seqlens - Use per-step cu_seqlens_q_padded to select Q chunks instead of tensor slicing - Use padded cu_seqlens_kv for K/V reordering (ensures divisibility) - Add cu_seqlens_kv and cu_seqlens_kv_padded to AllGather function signature - Compute per-step Q and KV cu_seqlens correctly from actual seqlens - Support non-causal attention (all KV visible) - Zero-initialize out/dq for THD to avoid garbage in padding regions - Save per-step cu_seqlens in ctx for backward (avoid recomputation) Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Remove skip gates that blocked THD format with all_gather CP comm type. Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…seqlens_q_padded The interleaved valid mask computation assumed cu_seqlens_q_padded starts at 0. With the CP offset-based approach, cu_seqlens_q_padded can start at a non-zero offset, causing a size mismatch. Use absolute positions from cu_seqlens_q_padded to build the valid mask instead. Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
1164a15 to
b4db9eb
Compare
for more information, see https://pre-commit.ci
| if qkv_format == "thd": | ||
| # [cp*t, h, d] -> reorder to contiguous per-sequence order -> [t_full, h, d] | ||
| chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device) | ||
| k_ag = reorder_seq_chunks_after_a2a_before_attn_thd( |
There was a problem hiding this comment.
This reorder_seq_chunks_after_a2a_before_attn_thd and the other releated method are not "a2a" specific now, rename them to something like dualchunk_to_contiguous_order_thd and the other one contiguous_to_dualchunk_order_thd
There was a problem hiding this comment.
Resolved on the current branch, with final cleanup in 0e926c42. The THD reorder entry points are now reorder_thd_sequences_to_rank_sharded and reorder_thd_sequences_to_contiguous, and the stale Python permutation helpers were removed. Both wrappers call the fused tex.thd_reorder path, so this logic is no longer A2A-named or A2A-specific.
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…formerEngine into cp_thd_swa_with_ag
for more information, see https://pre-commit.ci
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
for more information, see https://pre-commit.ci
…formerEngine into cp_thd_swa_with_ag
Greptile SummaryThis PR adds THD (variable-length sequence) format support to
Confidence Score: 5/5The THD AllGather forward and backward paths are logically sound and the previously identified stream-race and dead-variable issues are resolved; no correctness regressions are visible in the changed paths. All previously flagged stream synchronization issues are addressed. The new CUDA kernels are verified by unit tests against legacy Python references. The offset-based Q-chunking, per-step KV visibility, scatter_add_ mask fix, and dK/dV accumulation logic are correct. Only minor dead-code and documentation observations remain. transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py (complex per-step THD logic); transformer_engine/common/fused_attn/context_parallel.cu (new CUDA kernels) Important Files Changed
Sequence DiagramsequenceDiagram
participant C as Caller
participant F as AttnFuncWithCPAndKVAllGather.forward
participant AG as gather_along_first_dim
participant R as reorder_thd_sequences_to_contiguous
participant S0 as Step 0 current_stream
participant S1 as Step 1 cp_stream
participant V as tex.thd_valid_copy
participant AR as AllReduce max_logit
C->>F: q,k,v [t,h,d] cu_seqlens cu_seqlens_padded
F->>AG: "k,v gather to k_ag,v_ag [cp*t,h,d]"
F->>R: k_ag reorder to contiguous per-sequence order
Note over F: cp_stream.wait_stream current_stream
F->>F: compute per-step cu_seqlens_q_padded
F->>F: compute per-step cu_seqlens_kv
par step 0 on current_stream
F->>S0: fused_attn_fwd q_full k_ag v_ag cu_seqlens_q_padded_step0
S0-->>F: out_per_step[0]
and step 1 on cp_stream
F->>S1: fused_attn_fwd q_full k_ag v_ag cu_seqlens_q_padded_step1
S1-->>F: out_per_step[1]
end
F->>V: thd_valid_copy out out_per_step[0] padded_step0 valid_step0
F->>V: thd_valid_copy out out_per_step[1] padded_step1 valid_step1
Note over F: current_stream.wait_stream cp_stream
F->>AR: all_reduce max_logit MAX
F-->>C: out [T,h,d] max_logit [h]
Reviews (12): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile |
| # dK/dV: add full tensor (kernel zeros non-valid positions) | ||
| if i > 1: | ||
| flash_attn_streams[i - 1].wait_event(dkv_update_done) | ||
| dk.add_(dk_per_step[i - 1]) | ||
| dv.add_(dv_per_step[i - 1]) |
There was a problem hiding this comment.
THD backward dK/dV relies on unverified cuDNN zeroing behavior
The comment says "kernel zeros non-valid positions", but this assumption is not documented in the cuDNN/TE spec. The A2A backward for THD uses tex.thd_grad_correction(dk, dk_, cu_seqlens_kv_padded, ...) specifically to handle the valid/padding boundary — a plain add_() was not considered sufficient there. If fused_attn_bwd leaves positions beyond cu_seqlens_kv_per_step[i] uninitialised in its output, both steps contribute garbage at non-overlapping KV ranges, which propagates through reduce_scatter_along_first_dim into the final dK/dV.
Before merging, either confirm (and document) that NVTE_F16_arbitrary_seqlen zeros non-valid dK/dV entries, or add explicit zeroing/use tex.thd_grad_correction if applicable to the contiguous-KV layout.
There was a problem hiding this comment.
Updated in 0e926c42. The code no longer claims or depends on cuDNN zeroing non-valid positions. The dK/dV accumulation is treated row-wise: padding rows may accumulate, but valid rows are independent and only valid rows are copied/scattered after the per-step reductions. Revalidated today with the focused FusedAttention THD all_gather pytest (4 passed, 476 skipped, 9824 deselected) and the FA3 THD all_gather bucket32k correctness runner.
| # [AG+THD] Is this needed? | ||
| visible_actual = [ | ||
| torch.minimum(actual_seqlens_kv, visible_padded_split) | ||
| for visible_padded_split in visible_padded | ||
| ] |
There was a problem hiding this comment.
Unresolved development comment left in production code
# [AG+THD] Is this needed? reads like an open question from a debug session. The torch.minimum clamp is required: for sequences whose length is not a multiple of 2 * cp_size, padded_chunk_sizes_kv * (chunk_id + 1) can exceed actual_seqlens_kv[b], causing cu_seqlens_kv passed to the kernel to count padding as valid tokens. The comment should be resolved or removed.
There was a problem hiding this comment.
Resolved on the current branch. The open-question development comment is gone; the clamp remains because padded chunk endpoints can exceed actual sequence lengths, and the current comments distinguish padded offsets from valid/unpadded KV visibility.
| if ctx.qkv_format == "thd": | ||
| cu_seqlens_kv_padded = ctx.cu_seqlens_kv_padded | ||
| thd_cu_seqlens_q_per_step = ctx.thd_cu_seqlens_q_per_step | ||
| cu_seqlens_q_padded_rank = cu_seqlens_q_padded * 2 |
There was a problem hiding this comment.
Resolved in 0e926c42; the unused backward cu_seqlens_q_padded_rank assignment was removed.
…ific helpers The AllGather THD path was not extending KV visibility beyond the causal boundary when window_size had a right component > 0, meaning tokens right of the diagonal were invisible to the kernel. Fix by adding window_size[1] to visible_padded (clamped at actual seqlen) and max_seqlen_kv_. Also rename reorder helpers to backend-neutral names since AllGather now uses them too, and add a clarifying comment for non-causal KV cu_seqlens. Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
for more information, see https://pre-commit.ci
| pytest.skip("CP implementation with KV all-gather does not support THD format yet!") | ||
| pytest.skip( | ||
| "FlashAttention does not support THD padding; use FusedAttention for" | ||
| " THD+all_gather CP." |
There was a problem hiding this comment.
Maybe swap the words a little bit so it doesn't sounds like FlashAttention doesn't support THD, but just our CP implementation with it doesn't? (Also, THD implies padding in our terminology?)
There was a problem hiding this comment.
Resolved in the current test wording. The skip message now says the CP implementation with QKVO A2A+P2P / hierarchical A2A does not support THD format, rather than implying FlashAttention itself does not support THD.
| s = step_padded[b].item() | ||
| sz = (step_valid[b + 1] - step_valid[b]).item() | ||
| if sz > 0: | ||
| out[s : s + sz].copy_(out_per_step[i - 1][s : s + sz]) |
There was a problem hiding this comment.
This "for" loop might be too costly. Could this logic be written in another way in Python, or simply in C++/CUDA? Do we have some thd kernels that already do this?
There was a problem hiding this comment.
Resolved by moving the THD reorder path to the fused kernel. The current helpers call tex.thd_reorder directly, and 0e926c42 removed the remaining dead Python permutation builders.
| s = step_padded[b].item() | ||
| sz = (step_valid[b + 1] - step_valid[b]).item() | ||
| if sz > 0: | ||
| dq[s : s + sz].copy_(dq_per_step[i - 1][s : s + sz]) |
There was a problem hiding this comment.
Same here, regarding the "for" loop.
There was a problem hiding this comment.
Resolved by the same fused reorder path as the forward case. The current THD contiguous/rank-sharded reorder helpers call tex.thd_reorder; 0e926c42 removed the stale Python permutation builders.
| pytest.skip("THD format does not support post_scale_bias yet!") | ||
| if qkv_format == "thd": | ||
| if cp_comm_type == "all_gather": | ||
| pytest.skip("CP implementation with KV all-gather does not support THD format yet!") |
There was a problem hiding this comment.
A general comment - please run the CP file with "test_essential=False" offline because the essential tests may not cover everything.
There was a problem hiding this comment.
Status: not claiming this broad offline request as fully closed yet. Current validation covers the focused FusedAttention THD all_gather pytest, FA3 THD all_gather bucket32k correctness, and prior cp2/cp4/cp8 AG-vs-a2a sweeps. Plan is to run the full CP file with test_essential=False offline after the lint/review blockers are clear and report that result back here.
…formerEngine into cp_thd_swa_with_ag
| elif qkv_format == "thd": | ||
| # Copy valid token ranges from this step's output. | ||
| # Each step writes at different positions (no overlap, no correction needed). | ||
| step_padded = thd_cu_seqlens_q_padded_per_step[i - 1] | ||
| step_valid = thd_cu_seqlens_q_per_step[i - 1] | ||
| batch_size = step_valid.shape[0] - 1 | ||
| for b in range(batch_size): | ||
| s = step_padded[b].item() | ||
| sz = (step_valid[b + 1] - step_valid[b]).item() | ||
| if sz > 0: | ||
| out[s : s + sz].copy_(out_per_step[i - 1][s : s + sz]) |
There was a problem hiding this comment.
THD forward output copy runs on wrong stream for step 1
For step 1 (i == 2), the with torch.cuda.stream(flash_attn_streams[i - 1]) block streams the copy onto cp_stream. But out_per_step[1] is produced on cp_stream by the step-1 attention kernel, and the copy also runs on cp_stream, so there is no race between the kernel and the copy.
However, the if return_max_logit: block at line 3172–3173 runs outside the with block (i.e. on the default stream) and reads max_logit_per_step[i - 1] — which was produced on cp_stream for step 1. There is no current_stream.wait_stream(cp_stream) before this point (that sync only happens at line 3175, after the loop). As a result the torch.maximum kernel launched at line 3173 on the default stream can race against the step-1 attention kernel still running on cp_stream.
This is a pre-existing pattern, but it is newly exercised by THD+AllGather since this is the first path that has both return_max_logit=True and dual-stream THD step execution. Moving the max_logit merge inside the with torch.cuda.stream(flash_attn_streams[i - 1]): block would eliminate the race.
There was a problem hiding this comment.
Addressed on the current branch. The max_logit merge now waits before the default-stream torch.maximum reads a per-step result produced on cp_stream:
torch.cuda.current_stream().wait_stream(flash_attn_streams[i - 1])
max_logit = torch.maximum(max_logit, max_logit_per_step[i - 1])That wait is intentionally before the merge; the later post-loop wait_stream(cp_stream) was too late for this read.
Validation: FusedAttention all_gather THD pytest passed (4 passed, 476 skipped, 9824 deselected).
There was a problem hiding this comment.
Still resolved on the current branch. The earlier fix adds the needed stream waits before default-stream max-logit merging, and today
There was a problem hiding this comment.
Complete validation note for the previous reply: after 0e926c42, the focused FusedAttention THD all_gather pytest still passes (4 passed, 476 skipped, 9824 deselected) and the FA3 THD all_gather bucket32k correctness runner exits 0.
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
The CP THD load-balancing reorder built its gather/scatter permutation via tex.thd_get_partitioned_indices and then advanced-indexed in Python, with a per-call cu_seqlens.tolist() D2H sync to size the permutation. That sync serialized the host against the stream and left the GPU ~41% idle in the a2a THD path (nvbug 6179415), so THD ran well behind BSHD under CP even though both feed cu_seqlens to the same attention kernels. Replace it with a single fused CUDA kernel (nvte_cp_thd_reorder): one warp per token, float4-vectorized copy, dual-chunk source index computed on-device from cu_seqlens. total_tokens comes from x.shape so the launch needs no D2H sync. The kernel reuses the existing common THD helpers (binary_search and the factored thd_partition_src_index) rather than duplicating index math. Bit-identical to the Python reorder; 4-12x faster on the reorder itself and removes the GPU bubble (THD/BSHD a2a gap at cp2/nseg1 closes from ~1.5x to ~1.15x, the remainder being intrinsic cuDNN varlen-vs-dense kernel cost). _get_thd_reorder_perms is now sync-free (total_tokens passed in from x.shape; the .item() fallback is dead/defensive only). Also adds temporary, env-gated (NVTE_CP_PROFILE / NVTE_CP_NVTX) CUDA-event and NVTX instrumentation in context_parallel.py used to locate this bubble; inert unless enabled and slated for removal in PR cleanup. Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
for more information, see https://pre-commit.ci
| for b_idx in range(b): | ||
| start = cu_seqlens_q_padded[b_idx].item() | ||
| n_valid = actual_seqlens[b_idx].item() | ||
| valid[start : start + n_valid] = True |
There was a problem hiding this comment.
D2H synchronization on the training hot path
The new Python loop introduces 2 × batch_size device-to-host synchronizations via .item() on every forward step in which return_max_logit=True (the AllGather CP path) and pad_between_seqs=True. The old code was fully vectorized and stayed entirely on the GPU. With batch sizes typical for variable-length training (32+), this adds 64+ blocking D2H round-trips per step, stalling the GPU pipeline and reducing throughput.
A vectorized replacement that still handles non-zero starting offsets without .item():
tq = max_tensor.shape[0]
starts = cu_seqlens_q_padded[:-1].to(device=max_tensor.device)
ends = (starts + actual_seqlens).clamp(max=tq)
delta = torch.zeros(tq + 1, dtype=torch.int32, device=max_tensor.device)
delta.scatter_add_(0, starts.long(), torch.ones(b, dtype=torch.int32, device=max_tensor.device))
delta.scatter_add_(0, ends.long(), -torch.ones(b, dtype=torch.int32, device=max_tensor.device))
valid = delta[:tq].cumsum(0) > 0This keeps the fix for non-zero step-1 offsets while avoiding all D2H traffic.
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
There was a problem hiding this comment.
Fixed in b897900d. The THD max-logit path no longer builds the valid-token mask with per-batch .item() calls. It now builds the same interval mask on GPU using start/end deltas and scatter_add_, preserving non-zero cu_seqlens_q_padded offsets without D2H synchronization.
Validation: python3 -m black --check transformer_engine/pytorch/cpp_extensions/fused_attn.py passed, and FusedAttention all_gather THD pytest passed (4 passed, 476 skipped, 9824 deselected).
There was a problem hiding this comment.
Still resolved by b897900d and revalidated today after the review-fix commit. The THD max-logit mask is built on GPU with start/end deltas and scatter_add_, without the previous per-batch .item() loop. Focused FusedAttention THD all_gather pytest passed again after 0e926c42.
The AllGather CP THD path copied each step's valid output/dQ rows into the accumulator with a per-batch Python loop that called .item() twice per segment to read the padded start and valid length. Those D2H syncs stalled the host so it could not run ahead to keep the NCCL all-gather overlapped with compute: at cp2/nseg1/8k the dQ-copy loop alone occupied ~63% of the backward wall (mostly idle), leaving THD ~1.55x BSHD even though THD does only ~1.14x the GPU work (overlap 1.57x vs BSHD's 2.14x). Replace both loops with a single fused kernel (nvte_cp_thd_valid_copy): warp-per-token, float4-vectorized, segment found by binary_search on cu_seqlens_padded, copies row t iff its local offset is in [0, valid_len). The local>=0 guard is required because step-1 chunk starts are shifted past earlier tokens (cu_seqlens_padded[:-1] += chunk_size), so a token before a segment's padded start must be skipped rather than clobbering an already-written first-chunk row. total_tokens comes from x.shape, so no D2H. Reuses the existing binary_search helper; mirrors thd_reorder_kernel. Bit-identical to the Python loop (incl. the shifted step-1 layout). Restores comm/compute overlap (1.57x -> 1.78x): all_gather cp2/nseg1 8k THD/BSHD 1.55x -> 1.35x, and at realistic multi-sequence packing THD now beats BSHD (nseg=4 0.97x, nseg=8 0.94x). All CP all_gather THD tests pass. Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
for more information, see https://pre-commit.ci
| if qkv_format == "thd" and not use_fused_attention and not use_flash_attn_3: | ||
| assert False, ( |
There was a problem hiding this comment.
the condition should be inside the assert
There was a problem hiding this comment.
Fixed in 0e926c42; the condition is now the assert predicate with the explanatory message, instead of if ...: assert False.
The fused tex.thd_reorder (2dc5c15) runs on the main stream and, unlike the index_select it replaced, its buffers are not allocator/stream-tracked. Its faster sync-free execution lets the host run ahead into the 2-stream per-step loop (step 1 on cp_stream), and PyTorch's caching allocator can recycle a block the reorder is still using -> cudaErrorIllegalAddress for FA3 + all_gather + THD on larger-token packs (bucket32k/64k/128k, mixed32k). Reproduces serially; masked by CUDA_LAUNCH_BLOCKING and PYTORCH_NO_CUDA_MEMORY_CACHING, confirming an allocator-reuse race. FusedAttention AG and FA3 a2a are unaffected; pre-kernel TE ran these FA3 AG configs fine, so this is a regression from the kernel work. Drain the main stream after the AG reorder, before the per-step loop allocates. cp_stream already waits on main, so the cp_stream-vs-main per-step overlap is preserved; cost is ~0.5-4.3% (one sync/forward, before the compute loop) vs the 11-30% the kernels recovered. Env-gated AG_REORDER_SYNC (default on) for A/B. A finer allocator/event fix (needs a memory-snapshot trace to pin the block, likely FA3-internal scratch) can replace this conservative drain later. Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
FA3 hopper allocates per-call scheduler workspace internally, including semaphores that are recorded by the CUDA caching allocator only on the allocation stream. The THD all_gather CP path can launch consecutive FA3 per-step calls on separate streams, allowing a later step to reuse a just-freed workspace block while the prior step's scheduler kernel is still active, which shows up as illegal memory access under overlap. Replace the interim host-side reorder synchronize with FA3-scoped GPU stream waits between consecutive per-step forward and backward calls. This preserves the FusedAttention overlap path and avoids a host drain while preventing overlapping FA3 internal workspace lifetimes. Tests: - CUDA_VISIBLE_DEVICES=0,1 NVTE_BATCH_MHA_P2P_COMM=1 TE_PATH=/perfhome/llms/repos/te_repos/ag_thd_swa/TransformerEngine python3 -m pytest tests/pytorch/attention/test_attention_with_cp.py -k 'fused_attention and thd and all_gather and not fp8 and not bias' -q - CUDA_VISIBLE_DEVICES=0,1 NVTE_BATCH_MHA_P2P_COMM=1 torchrun --nproc-per-node=2 --master-port=47151 /perfhome/llms/repos/te_repos/ag_thd_swa/cp_bench/tests/pytorch/attention/run_attention_with_cp.py dtype=bf16 model=bucket32k qkv_format=thd kernel_backend=FlashAttention cp_comm_type=all_gather benchmark=0 log_level=WARNING thd_seqlen_pattern=24576,28672,30720,32768 Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
The THD max-logit path built the valid-token mask with per-sequence .item() calls when cu_seqlens_q_padded was provided. In AllGather CP this is on the training hot path and blocks the host for every batch entry. Build the same interval mask on GPU using scatter_add over start/end deltas. This preserves support for non-zero padded offsets while avoiding the device-to-host synchronizations called out in PR review. Tests: - python3 -m black --check transformer_engine/pytorch/cpp_extensions/fused_attn.py - CUDA_VISIBLE_DEVICES=0,1 NVTE_BATCH_MHA_P2P_COMM=1 TE_PATH=/perfhome/llms/repos/te_repos/ag_thd_swa/TransformerEngine python3 -m pytest tests/pytorch/attention/test_attention_with_cp.py -k 'fused_attention and thd and all_gather and not fp8 and not bias' -q Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Resolve the pybind.cpp tail-binding conflict by preserving upstream grouped_mlp_experimental bindings inside the helper split that keeps PYBIND11_MODULE under cpplint's function-size limit. Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
for more information, see https://pre-commit.ci
for more information, see https://pre-commit.ci
| assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!" | ||
| assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!" |
There was a problem hiding this comment.
Why did we add these two additional checks?
| assert ( | ||
| "padding" not in attn_mask_type | ||
| qkv_format == "thd" or "padding" not in attn_mask_type | ||
| ), f"No support for cp_comm_type='all_gather' and {attn_mask_type=}." |
There was a problem hiding this comment.
this is a repeated check, from just a few lines above
| seq_dim = qkv_format.index("s") | ||
| assert ( | ||
| q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 | ||
| ), "Sequence length per GPU needs to be divisible by 2!" |
There was a problem hiding this comment.
why are we adding additional checks which didn't exist before
| if qkv_format == "thd": | ||
| out = torch.zeros(o_shape, dtype=fwd_nominal_dtype, device=q.device) | ||
| else: | ||
| out = torch.empty_like(q) |
There was a problem hiding this comment.
is it better to just have torch.zeros for both paths?
| # Copy valid token ranges from this step's output. | ||
| # Each step writes at different positions (no overlap, no correction needed). | ||
| # Sync-free fused copy of every segment's valid rows in one launch (replaces | ||
| # a per-batch .item() slice-copy loop that stalled comm/compute overlap). | ||
| tex.thd_valid_copy( |
There was a problem hiding this comment.
Information from local experiments leaking into the comments. The comments should only tell what the code does and only in most important/pertinent cases share extra information. In this case, if the user hasn't seen the python version, they wouldn't know how to make sense of per-batch .item() or comm/compute overlap
| # calls overlapping across streams, the allocator can hand step i's FA3 the block | ||
| # step i-1's FA3 just freed while step i-1's scheduler kernel is still atomically | ||
| # incrementing it -> shared semaphore -> out-of-range tiles -> illegal memory access | ||
| # (FA3-only, only under overlap; root cause in mb_runs/FA3_AG_RECORDSTREAM_RACE.md |
There was a problem hiding this comment.
Again, leaking experiments data in the comments
|
|
||
| if qkv_format == "thd": | ||
| # [cp*t, h, d] -> reorder to contiguous per-sequence order -> [t_full, h, d] | ||
| # Use padded cu_seqlens since reorder computes slice boundaries via integer |
There was a problem hiding this comment.
also note here that the padded cu_seqlens is global
Description
Add THD (variable-length sequence) format support to
AttnFuncWithCPAndKVAllGather. Previously, AllGather-based CP only supported fixed-length formats (bshd/sbhd). THD format packs variable-length sequences into a single[t, h, d]tensor tracked bycu_seqlens, which is needed for workloads with heterogeneous sequence lengths.The key challenge is that AllGather CP splits Q across 2 steps (one per local chunk), but THD tensors can't be naively sliced like fixed-length formats. This PR uses an offset-based approach: the full Q tensor is passed to the cuDNN kernel each step, with per-step
cu_seqlens_q_paddedvalues directing the kernel to read the correct chunk. This avoids tensor slicing entirely and leverages cuDNN's back-padding convention (valid tokens at the beginning of each padded allocation).Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
cu_seqlens_q_paddedselects which chunk the kernel reads from the full Q tensor, instead of slicing Q per stepreorder_seq_chunks_*_thdhelpers (originally for A2A) to reorder all-gathered KV into contiguous per-sequence ordercu_seqlens_q_paddedin the valid-token mask (step 1's padded offsets don't start at 0)Checklist:
Latest update: THD AllGather CP fixes (2026-06-02)
628f73cc). This prevents overlapping FA3 internal scheduler workspace lifetimes without a host-side synchronize; FusedAttention overlap remains unchanged..item()D2H synchronizations from the THD max-logit valid-token mask by building the interval mask on GPU withscatter_add_(b897900d).Local validation:
CUDA_VISIBLE_DEVICES=0,1 NVTE_BATCH_MHA_P2P_COMM=1 TE_PATH=$(pwd) python3 -m pytest tests/pytorch/attention/test_attention_with_cp.py -k 'fused_attention and thd and all_gather and not fp8 and not bias' -q-> 4 passed, 476 skipped, 9824 deselected.CUDA_VISIBLE_DEVICES=0,1 NVTE_BATCH_MHA_P2P_COMM=1 torchrun --nproc-per-node=2 --master-port=47151 /perfhome/llms/repos/te_repos/ag_thd_swa/cp_bench/tests/pytorch/attention/run_attention_with_cp.py dtype=bf16 model=bucket32k qkv_format=thd kernel_backend=FlashAttention cp_comm_type=all_gather benchmark=0 log_level=WARNING thd_seqlen_pattern=24576,28672,30720,32768-> exit 0.