[PyTorch][CP] Add THD format support for AllGather-based Context Parallelism #2829

Open
sudhakarsingh27 wants to merge 63 commits into NVIDIA:main from sudhakarsingh27:cp_thd_swa_with_ag

Conversation

@sudhakarsingh27
Member

@sudhakarsingh27 sudhakarsingh27 commented Apr 3, 2026

Description

Add THD (variable-length sequence) format support to AttnFuncWithCPAndKVAllGather. Previously, AllGather-based CP only supported fixed-length formats (bshd/sbhd). THD format packs variable-length sequences into a single [t, h, d] tensor tracked by cu_seqlens, which is needed for workloads with heterogeneous sequence lengths.

The key challenge is that AllGather CP splits Q across 2 steps (one per local chunk), but THD tensors can't be naively sliced like fixed-length formats. This PR uses an offset-based approach: the full Q tensor is passed to the cuDNN kernel each step, with per-step cu_seqlens_q_padded values directing the kernel to read the correct chunk. This avoids tensor slicing entirely and leverages cuDNN's back-padding convention (valid tokens at the beginning of each padded allocation).
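To make the offset-based approach concrete, here is a minimal sketch of how the two per-step `cu_seqlens_q_padded` arrays could be derived on one rank. It uses plain Python lists in place of tensors, and the function name `per_step_q_offsets` is hypothetical (not taken from the PR). It assumes each sequence's local padded allocation holds exactly two equal chunks, with step 1's starts shifted by one chunk size.

```python
from itertools import accumulate


def per_step_q_offsets(padded_seqlens_local):
    """Return (step0, step1) cu_seqlens_q_padded lists for one rank.

    padded_seqlens_local: padded length of each sequence on this rank.
    Assumption: every sequence holds two equal chunks, so lengths are even.
    """
    assert all(n % 2 == 0 for n in padded_seqlens_local)
    cu = [0] + list(accumulate(padded_seqlens_local))
    chunk = [n // 2 for n in padded_seqlens_local]
    # Step 0 reads the first chunk: starts are the plain padded offsets.
    step0 = list(cu)
    # Step 1 reads the second chunk: shift each start by that sequence's
    # chunk size; the final entry (total token count) stays unchanged.
    step1 = [s + c for s, c in zip(cu[:-1], chunk)] + [cu[-1]]
    return step0, step1


step0, step1 = per_step_q_offsets([8, 4, 12])
print(step0)  # [0, 8, 12, 24]
print(step1)  # [4, 10, 18, 24]
```

Both steps see the same full Q buffer; only the start offsets differ, which is why no slicing is needed.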

Fixes # (issue)

Type of change

  • New feature (non-breaking change which adds functionality)

Changes

Please list the changes introduced in this PR:

  • Offset-based Q chunking: Per-step cu_seqlens_q_padded selects which chunk the kernel reads from the full Q tensor, instead of slicing Q per step
  • Per-step KV cu_seqlens: Computes visible KV token counts per step for causal masking (chunks 0..chunk_id) and non-causal (all tokens)
  • THD reorder reuse: Reuses the existing reorder_seq_chunks_*_thd helpers (originally for A2A) to reorder all-gathered KV into contiguous per-sequence order
  • max_logit masking fix: Handles non-zero-starting cu_seqlens_q_padded in the valid-token mask (step 1's padded offsets don't start at 0)
  • Test gates: Enables THD+all_gather for FusedAttention tests; skips FlashAttention (no THD padding support)

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Latest update: THD AllGather CP fixes (2026-06-02)

  • Added fused THD reorder / valid-copy kernels for the AllGather CP THD path.
  • Replaced the interim FA3 crash workaround with FA3-scoped GPU stream waits between consecutive per-step FA3 forward/backward calls (628f73cc). This prevents overlapping FA3 internal scheduler workspace lifetimes without a host-side synchronize; FusedAttention overlap remains unchanged.
  • Removed .item() D2H synchronizations from the THD max-logit valid-token mask by building the interval mask on GPU with scatter_add_ (b897900d).

Local validation:

  • CUDA_VISIBLE_DEVICES=0,1 NVTE_BATCH_MHA_P2P_COMM=1 TE_PATH=$(pwd) python3 -m pytest tests/pytorch/attention/test_attention_with_cp.py -k 'fused_attention and thd and all_gather and not fp8 and not bias' -q -> 4 passed, 476 skipped, 9824 deselected.
  • CUDA_VISIBLE_DEVICES=0,1 NVTE_BATCH_MHA_P2P_COMM=1 torchrun --nproc-per-node=2 --master-port=47151 /perfhome/llms/repos/te_repos/ag_thd_swa/cp_bench/tests/pytorch/attention/run_attention_with_cp.py dtype=bf16 model=bucket32k qkv_format=thd kernel_backend=FlashAttention cp_comm_type=all_gather benchmark=0 log_level=WARNING thd_seqlen_pattern=24576,28672,30720,32768 -> exit 0.
  • Focused serial sweep: all_gather cp2/4/8 and a2a cp2/4/8 causal, plus all_gather cp2 SWA windows, completed 72/72 OK with 0 crashes.

… cu_seqlens

- Use per-step cu_seqlens_q_padded to select Q chunks instead of tensor slicing
- Use padded cu_seqlens_kv for K/V reordering (ensures divisibility)
- Add cu_seqlens_kv and cu_seqlens_kv_padded to AllGather function signature
- Compute per-step Q and KV cu_seqlens correctly from actual seqlens
- Support non-causal attention (all KV visible)
- Zero-initialize out/dq for THD to avoid garbage in padding regions
- Save per-step cu_seqlens in ctx for backward (avoid recomputation)

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Remove skip gates that blocked THD format with all_gather CP comm type.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…seqlens_q_padded

The interleaved valid mask computation assumed cu_seqlens_q_padded starts
at 0. With the CP offset-based approach, cu_seqlens_q_padded can start at
a non-zero offset, causing a size mismatch. Use absolute positions from
cu_seqlens_q_padded to build the valid mask instead.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
if qkv_format == "thd":
    # [cp*t, h, d] -> reorder to contiguous per-sequence order -> [t_full, h, d]
    chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device)
    k_ag = reorder_seq_chunks_after_a2a_before_attn_thd(
Member Author

@sudhakarsingh27 sudhakarsingh27 Apr 3, 2026

reorder_seq_chunks_after_a2a_before_attn_thd and its related counterpart are no longer "a2a"-specific. Rename them to something like dualchunk_to_contiguous_order_thd and contiguous_to_dualchunk_order_thd.

Member Author

Resolved on the current branch, with final cleanup in 0e926c42. The THD reorder entry points are now reorder_thd_sequences_to_rank_sharded and reorder_thd_sequences_to_contiguous, and the stale Python permutation helpers were removed. Both wrappers call the fused tex.thd_reorder path, so this logic is no longer A2A-named or A2A-specific.

Comment thread transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py Outdated
Comment thread transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py Outdated
@sudhakarsingh27 sudhakarsingh27 changed the title from "Cp thd swa with ag" to "[PyTorch][CP] Add THD format support for AllGather-based Context Parallelism" Apr 13, 2026
@sudhakarsingh27 sudhakarsingh27 marked this pull request as ready for review April 13, 2026 21:53
Comment thread transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py Outdated
Comment thread transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py Outdated
@greptile-apps
Contributor

greptile-apps Bot commented Apr 13, 2026

Greptile Summary

This PR adds THD (variable-length sequence) format support to AttnFuncWithCPAndKVAllGather, complementing the existing bshd/sbhd paths. The core mechanism is an offset-based Q-chunking approach: the full Q tensor is passed to each kernel step, and per-step cu_seqlens_q_padded values steer the kernel to the correct in-tensor chunk — avoiding explicit tensor slicing for variable-length sequences.

  • New CUDA kernels (thd_reorder_kernel, thd_valid_copy_kernel) fuse the dual-chunk gather/scatter permutation and the valid-token copy operations; the old Python index_select-based helpers are now thin wrappers around tex.thd_reorder / tex.thd_valid_copy.
  • Per-step KV visibility (cu_seqlens_kv_per_step) is computed upfront for causal and sliding-window attention, with each step trimming the effective KV range as needed.
  • Stream synchronization issues found in earlier review rounds (CP stream race on K/V reorder, default-stream race on max_logit merge) have been fixed; FA3 consecutive-call workspace collisions are serialized with GPU stream waits.

Confidence Score: 5/5

The THD AllGather forward and backward paths are logically sound and the previously identified stream-race and dead-variable issues are resolved; no correctness regressions are visible in the changed paths.

All previously flagged stream synchronization issues are addressed. The new CUDA kernels are verified by unit tests against legacy Python references. The offset-based Q-chunking, per-step KV visibility, scatter_add_ mask fix, and dK/dV accumulation logic are correct. Only minor dead-code and documentation observations remain.

transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py (complex per-step THD logic); transformer_engine/common/fused_attn/context_parallel.cu (new CUDA kernels)

Important Files Changed

Filename: Overview

  • transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py: Core AllGather CP implementation extended for THD: offset-based Q-chunking, per-step cu_seqlens for KV visibility, stream synchronization, and reorder helpers renamed/refactored to use new CUDA kernels. Previously flagged stream-race and dead-variable issues addressed.
  • transformer_engine/common/fused_attn/context_parallel.cu: Two new CUDA kernels, thd_reorder_kernel (fused dual-chunk gather/scatter) and thd_valid_copy_kernel (sync-free valid-token copy), plus the shared thd_partition_src_index device function. Binary search and shared-memory bounds verified correct.
  • transformer_engine/pytorch/cpp_extensions/fused_attn.py: Replaced per-batch .item() D2H loop in max_logit masking with GPU-resident scatter_add_ interval mask; correctly handles non-zero starting offsets in cu_seqlens_q_padded for the offset-based AllGather THD path.
  • transformer_engine/pytorch/csrc/extensions/attention.cpp: C++ bindings for thd_reorder and thd_valid_copy; input validation and stream plumbing are consistent with existing thd helpers.
  • tests/pytorch/attention/test_cp_utils.py: New TestTHDKernels class validates thd_reorder and thd_valid_copy against legacy Python reference implementations across multiple cp_size and sequence configurations.

Sequence Diagram

sequenceDiagram
    participant C as Caller
    participant F as AttnFuncWithCPAndKVAllGather.forward
    participant AG as gather_along_first_dim
    participant R as reorder_thd_sequences_to_contiguous
    participant S0 as Step 0 current_stream
    participant S1 as Step 1 cp_stream
    participant V as tex.thd_valid_copy
    participant AR as AllReduce max_logit
    C->>F: q,k,v [t,h,d] cu_seqlens cu_seqlens_padded
    F->>AG: "k,v gather to k_ag,v_ag [cp*t,h,d]"
    F->>R: k_ag reorder to contiguous per-sequence order
    Note over F: cp_stream.wait_stream current_stream
    F->>F: compute per-step cu_seqlens_q_padded
    F->>F: compute per-step cu_seqlens_kv
    par step 0 on current_stream
        F->>S0: fused_attn_fwd q_full k_ag v_ag cu_seqlens_q_padded_step0
        S0-->>F: out_per_step[0]
    and step 1 on cp_stream
        F->>S1: fused_attn_fwd q_full k_ag v_ag cu_seqlens_q_padded_step1
        S1-->>F: out_per_step[1]
    end
    F->>V: thd_valid_copy out out_per_step[0] padded_step0 valid_step0
    F->>V: thd_valid_copy out out_per_step[1] padded_step1 valid_step1
    Note over F: current_stream.wait_stream cp_stream
    F->>AR: all_reduce max_logit MAX
    F-->>C: out [T,h,d] max_logit [h]
Reviews (12): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

Comment on lines +3436 to +3440
# dK/dV: add full tensor (kernel zeros non-valid positions)
if i > 1:
flash_attn_streams[i - 1].wait_event(dkv_update_done)
dk.add_(dk_per_step[i - 1])
dv.add_(dv_per_step[i - 1])
Contributor

P1 THD backward dK/dV relies on unverified cuDNN zeroing behavior

The comment says "kernel zeros non-valid positions", but this assumption is not documented in the cuDNN/TE spec. The A2A backward for THD uses tex.thd_grad_correction(dk, dk_, cu_seqlens_kv_padded, ...) specifically to handle the valid/padding boundary — a plain add_() was not considered sufficient there. If fused_attn_bwd leaves positions beyond cu_seqlens_kv_per_step[i] uninitialised in its output, both steps contribute garbage at non-overlapping KV ranges, which propagates through reduce_scatter_along_first_dim into the final dK/dV.

Before merging, either confirm (and document) that NVTE_F16_arbitrary_seqlen zeros non-valid dK/dV entries, or add explicit zeroing/use tex.thd_grad_correction if applicable to the contiguous-KV layout.

Member Author

Updated in 0e926c42. The code no longer claims or depends on cuDNN zeroing non-valid positions. The dK/dV accumulation is treated row-wise: padding rows may accumulate, but valid rows are independent and only valid rows are copied/scattered after the per-step reductions. Revalidated today with the focused FusedAttention THD all_gather pytest (4 passed, 476 skipped, 9824 deselected) and the FA3 THD all_gather bucket32k correctness runner.

Comment on lines +3007 to +3011
# [AG+THD] Is this needed?
visible_actual = [
    torch.minimum(actual_seqlens_kv, visible_padded_split)
    for visible_padded_split in visible_padded
]
Contributor

P2 Unresolved development comment left in production code

# [AG+THD] Is this needed? reads like an open question from a debug session. The torch.minimum clamp is required: for sequences whose length is not a multiple of 2 * cp_size, padded_chunk_sizes_kv * (chunk_id + 1) can exceed actual_seqlens_kv[b], causing cu_seqlens_kv passed to the kernel to count padding as valid tokens. The comment should be resolved or removed.

Member Author

Resolved on the current branch. The open-question development comment is gone; the clamp remains because padded chunk endpoints can exceed actual sequence lengths, and the current comments distinguish padded offsets from valid/unpadded KV visibility.
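The need for the clamp can be seen in a small worked sketch. It uses plain Python lists instead of tensors, and the function name `visible_kv_cu_seqlens` is hypothetical. It assumes causal masking where a Q chunk with index `chunk_id` sees KV chunks 0..chunk_id, and that each padded sequence is split into `2 * cp_size` equal chunks.

```python
from itertools import accumulate


def visible_kv_cu_seqlens(actual_seqlens_kv, padded_seqlens_kv, cp_size, chunk_id):
    """cu_seqlens_kv for a causal step whose Q chunk index is `chunk_id`."""
    num_chunks = 2 * cp_size
    visible = []
    for actual, padded in zip(actual_seqlens_kv, padded_seqlens_kv):
        assert padded % num_chunks == 0 and actual <= padded
        chunk_size = padded // num_chunks
        visible_padded = chunk_size * (chunk_id + 1)  # chunks 0..chunk_id
        # Clamp: a padded chunk endpoint may lie past the real sequence end,
        # and padding must never be counted as valid KV tokens.
        visible.append(min(actual, visible_padded))
    return [0] + list(accumulate(visible))


# Sequence 0 has 10 real tokens padded to 12 (cp_size=2, chunk size 3).
print(visible_kv_cu_seqlens([10, 16], [12, 16], cp_size=2, chunk_id=3))  # [0, 10, 26]
print(visible_kv_cu_seqlens([10, 16], [12, 16], cp_size=2, chunk_id=0))  # [0, 3, 7]
```

Without the `min`, the last chunk of sequence 0 would report 12 visible tokens, two of which are padding.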

if ctx.qkv_format == "thd":
    cu_seqlens_kv_padded = ctx.cu_seqlens_kv_padded
    thd_cu_seqlens_q_per_step = ctx.thd_cu_seqlens_q_per_step
    cu_seqlens_q_padded_rank = cu_seqlens_q_padded * 2
Contributor

P2 Dead variable in backward pass

cu_seqlens_q_padded_rank is computed here but never read in the backward. The padded offsets are loaded from ctx.thd_cu_seqlens_q_padded_per_step a few lines later. This line can be removed.

Member Author

Resolved in 0e926c42; the unused backward cu_seqlens_q_padded_rank assignment was removed.

sudhakarsingh27 and others added 2 commits April 16, 2026 11:28
…ific helpers

The AllGather THD path was not extending KV visibility beyond the causal
boundary when window_size had a right component > 0, meaning tokens right
of the diagonal were invisible to the kernel. Fix by adding window_size[1]
to visible_padded (clamped at actual seqlen) and max_seqlen_kv_.

Also rename reorder helpers to backend-neutral names since AllGather now
uses them too, and add a clarifying comment for non-causal KV cu_seqlens.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
-    pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
+    pytest.skip(
+        "FlashAttention does not support THD padding; use FusedAttention for"
+        " THD+all_gather CP."
Collaborator

Maybe reword this a little so it doesn't sound like FlashAttention doesn't support THD, when it is only our CP implementation with it that doesn't? (Also, does THD imply padding in our terminology?)

Member Author

Resolved in the current test wording. The skip message now says the CP implementation with QKVO A2A+P2P / hierarchical A2A does not support THD format, rather than implying FlashAttention itself does not support THD.

s = step_padded[b].item()
sz = (step_valid[b + 1] - step_valid[b]).item()
if sz > 0:
    out[s : s + sz].copy_(out_per_step[i - 1][s : s + sz])
Collaborator

This "for" loop might be too costly. Could this logic be written in another way in Python, or simply in C++/CUDA? Do we have some thd kernels that already do this?

Member Author

Resolved by moving the THD reorder path to the fused kernel. The current helpers call tex.thd_reorder directly, and 0e926c42 removed the remaining dead Python permutation builders.

s = step_padded[b].item()
sz = (step_valid[b + 1] - step_valid[b]).item()
if sz > 0:
    dq[s : s + sz].copy_(dq_per_step[i - 1][s : s + sz])
Collaborator

Same here, regarding the "for" loop.

Member Author

Resolved by the same fused reorder path as the forward case. The current THD contiguous/rank-sharded reorder helpers call tex.thd_reorder; 0e926c42 removed the stale Python permutation builders.

    pytest.skip("THD format does not support post_scale_bias yet!")
if qkv_format == "thd":
    if cp_comm_type == "all_gather":
        pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
Collaborator

A general comment - please run the CP file with "test_essential=False" offline because the essential tests may not cover everything.

Member Author

Status: not claiming this broad offline request as fully closed yet. Current validation covers the focused FusedAttention THD all_gather pytest, FA3 THD all_gather bucket32k correctness, and prior cp2/cp4/cp8 AG-vs-a2a sweeps. Plan is to run the full CP file with test_essential=False offline after the lint/review blockers are clear and report that result back here.

Comment on lines +3160 to +3170
elif qkv_format == "thd":
    # Copy valid token ranges from this step's output.
    # Each step writes at different positions (no overlap, no correction needed).
    step_padded = thd_cu_seqlens_q_padded_per_step[i - 1]
    step_valid = thd_cu_seqlens_q_per_step[i - 1]
    batch_size = step_valid.shape[0] - 1
    for b in range(batch_size):
        s = step_padded[b].item()
        sz = (step_valid[b + 1] - step_valid[b]).item()
        if sz > 0:
            out[s : s + sz].copy_(out_per_step[i - 1][s : s + sz])
Contributor

P1 THD forward output copy runs on wrong stream for step 1

For step 1 (i == 2), the with torch.cuda.stream(flash_attn_streams[i - 1]) block streams the copy onto cp_stream. But out_per_step[1] is produced on cp_stream by the step-1 attention kernel, and the copy also runs on cp_stream, so there is no race between the kernel and the copy.

However, the if return_max_logit: block at line 3172–3173 runs outside the with block (i.e. on the default stream) and reads max_logit_per_step[i - 1] — which was produced on cp_stream for step 1. There is no current_stream.wait_stream(cp_stream) before this point (that sync only happens at line 3175, after the loop). As a result the torch.maximum kernel launched at line 3173 on the default stream can race against the step-1 attention kernel still running on cp_stream.

This is a pre-existing pattern, but it is newly exercised by THD+AllGather since this is the first path that has both return_max_logit=True and dual-stream THD step execution. Moving the max_logit merge inside the with torch.cuda.stream(flash_attn_streams[i - 1]): block would eliminate the race.

Member Author

Addressed on the current branch. The max_logit merge now waits before the default-stream torch.maximum reads a per-step result produced on cp_stream:

torch.cuda.current_stream().wait_stream(flash_attn_streams[i - 1])
max_logit = torch.maximum(max_logit, max_logit_per_step[i - 1])

That wait is intentionally before the merge; the later post-loop wait_stream(cp_stream) was too late for this read.

Validation: FusedAttention all_gather THD pytest passed (4 passed, 476 skipped, 9824 deselected).

Member Author

Still resolved on the current branch. The earlier fix adds the needed stream waits before default-stream max-logit merging, and today

Member Author

Complete validation note for the previous reply: after 0e926c42, the focused FusedAttention THD all_gather pytest still passes (4 passed, 476 skipped, 9824 deselected) and the FA3 THD all_gather bucket32k correctness runner exits 0.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
@ptrendx ptrendx added this to the 2.15 milestone Apr 23, 2026
The CP THD load-balancing reorder built its gather/scatter permutation
via tex.thd_get_partitioned_indices and then advanced-indexed in Python,
with a per-call cu_seqlens.tolist() D2H sync to size the permutation.
That sync serialized the host against the stream and left the GPU ~41%
idle in the a2a THD path (nvbug 6179415), so THD ran well behind BSHD
under CP even though both feed cu_seqlens to the same attention kernels.

Replace it with a single fused CUDA kernel (nvte_cp_thd_reorder): one
warp per token, float4-vectorized copy, dual-chunk source index computed
on-device from cu_seqlens. total_tokens comes from x.shape so the launch
needs no D2H sync. The kernel reuses the existing common THD helpers
(binary_search and the factored thd_partition_src_index) rather than
duplicating index math. Bit-identical to the Python reorder; 4-12x faster
on the reorder itself and removes the GPU bubble (THD/BSHD a2a gap at
cp2/nseg1 closes from ~1.5x to ~1.15x, the remainder being intrinsic
cuDNN varlen-vs-dense kernel cost).

_get_thd_reorder_perms is now sync-free (total_tokens passed in from
x.shape; the .item() fallback is dead/defensive only).

Also adds temporary, env-gated (NVTE_CP_PROFILE / NVTE_CP_NVTX) CUDA-event
and NVTX instrumentation in context_parallel.py used to locate this bubble;
inert unless enabled and slated for removal in PR cleanup.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
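The dual-chunk source-index computation described in the commit message above can be illustrated with a small reference model in plain Python. This is only a sketch under stated assumptions, not the kernel's code: the names `gather_dualchunk` and `dualchunk_src_index` are hypothetical, and the layout assumed here is that rank r holds chunks r and 2*cp_size-1-r of every sequence, with the all-gathered buffer concatenating rank buffers in rank order.

```python
from bisect import bisect_right


def gather_dualchunk(tokens, cu_seqlens_padded, cp_size):
    """Build the all-gathered buffer under the assumed dual-chunk layout."""
    gathered = []
    for rank in range(cp_size):
        for b in range(len(cu_seqlens_padded) - 1):
            start = cu_seqlens_padded[b]
            chunk = (cu_seqlens_padded[b + 1] - start) // (2 * cp_size)
            for c in (rank, 2 * cp_size - 1 - rank):
                gathered.extend(tokens[start + c * chunk : start + (c + 1) * chunk])
    return gathered


def dualchunk_src_index(dst, cu_seqlens_padded, cp_size):
    """Index in the gathered buffer that feeds contiguous position `dst`."""
    t_local = cu_seqlens_padded[-1] // cp_size        # tokens held per rank
    b = bisect_right(cu_seqlens_padded, dst) - 1      # sequence containing dst
    seq_start = cu_seqlens_padded[b]
    chunk = (cu_seqlens_padded[b + 1] - seq_start) // (2 * cp_size)
    c, off = divmod(dst - seq_start, chunk)
    # Chunk c sits in the first half on rank c, or in the second half on
    # rank 2*cp_size-1-c.
    rank, half = (c, 0) if c < cp_size else (2 * cp_size - 1 - c, 1)
    return rank * t_local + seq_start // cp_size + half * chunk + off


cu = [0, 8, 12]                     # two padded sequences, cp_size = 2
tokens = list(range(cu[-1]))        # token value == contiguous position
gathered = gather_dualchunk(tokens, cu, cp_size=2)
restored = [gathered[dualchunk_src_index(d, cu, 2)] for d in range(cu[-1])]
print(gathered)  # [0, 1, 6, 7, 8, 11, 2, 3, 4, 5, 9, 10]
print(restored)  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
```

Because the total token count comes from the buffer length and every index is a pure function of `cu_seqlens_padded`, nothing here needs a host-side read of per-sequence lengths, which is the property the fused kernel relies on.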
Comment on lines +376 to +379
for b_idx in range(b):
    start = cu_seqlens_q_padded[b_idx].item()
    n_valid = actual_seqlens[b_idx].item()
    valid[start : start + n_valid] = True
Contributor

P1 D2H synchronization on the training hot path

The new Python loop introduces 2 × batch_size device-to-host synchronizations via .item() on every forward step in which return_max_logit=True (the AllGather CP path) and pad_between_seqs=True. The old code was fully vectorized and stayed entirely on the GPU. With batch sizes typical for variable-length training (32+), this adds 64+ blocking D2H round-trips per step, stalling the GPU pipeline and reducing throughput.

A vectorized replacement that still handles non-zero starting offsets without .item():

tq = max_tensor.shape[0]
starts = cu_seqlens_q_padded[:-1].to(device=max_tensor.device)
ends   = (starts + actual_seqlens).clamp(max=tq)
delta = torch.zeros(tq + 1, dtype=torch.int32, device=max_tensor.device)
delta.scatter_add_(0, starts.long(), torch.ones(b, dtype=torch.int32, device=max_tensor.device))
delta.scatter_add_(0, ends.long(),  -torch.ones(b, dtype=torch.int32, device=max_tensor.device))
valid = delta[:tq].cumsum(0) > 0

This keeps the fix for non-zero step-1 offsets while avoiding all D2H traffic.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Member Author

Fixed in b897900d. The THD max-logit path no longer builds the valid-token mask with per-batch .item() calls. It now builds the same interval mask on GPU using start/end deltas and scatter_add_, preserving non-zero cu_seqlens_q_padded offsets without D2H synchronization.

Validation: python3 -m black --check transformer_engine/pytorch/cpp_extensions/fused_attn.py passed, and FusedAttention all_gather THD pytest passed (4 passed, 476 skipped, 9824 deselected).
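For readers unfamiliar with the start/end delta trick, the following plain-Python stand-in shows the arithmetic behind the interval mask. It is a sketch only: the function name `valid_token_mask` is hypothetical, and the Python loop here merely plays the role that `scatter_add_` and `cumsum` play on the GPU.

```python
from itertools import accumulate


def valid_token_mask(starts, valid_lens, total_tokens):
    """Mark positions [start, start + valid_len) as valid for each sequence."""
    delta = [0] * (total_tokens + 1)      # one extra slot for an end at total
    for s, n in zip(starts, valid_lens):
        e = min(s + n, total_tokens)      # same clamp as the review suggestion
        delta[s] += 1                     # interval opens at its start
        delta[e] -= 1                     # interval closes at its end
    running = accumulate(delta[:total_tokens])   # prefix sum
    return [r > 0 for r in running]


# Step-1 style offsets: the first start is 4, not 0.
mask = valid_token_mask(starts=[4, 10, 18], valid_lens=[3, 2, 6], total_tokens=24)
print([i for i, m in enumerate(mask) if m])
# [4, 5, 6, 10, 11, 18, 19, 20, 21, 22, 23]
```

The mask depends only on absolute start positions, so it works unchanged when the padded offsets do not begin at zero.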

Member Author

Still resolved by b897900d and revalidated today after the review-fix commit. The THD max-logit mask is built on GPU with start/end deltas and scatter_add_, without the previous per-batch .item() loop. Focused FusedAttention THD all_gather pytest passed again after 0e926c42.

sudhakarsingh27 and others added 2 commits May 30, 2026 00:23
The AllGather CP THD path copied each step's valid output/dQ rows into the
accumulator with a per-batch Python loop that called .item() twice per
segment to read the padded start and valid length. Those D2H syncs stalled
the host so it could not run ahead to keep the NCCL all-gather overlapped
with compute: at cp2/nseg1/8k the dQ-copy loop alone occupied ~63% of the
backward wall (mostly idle), leaving THD ~1.55x BSHD even though THD does
only ~1.14x the GPU work (overlap 1.57x vs BSHD's 2.14x).

Replace both loops with a single fused kernel (nvte_cp_thd_valid_copy):
warp-per-token, float4-vectorized, segment found by binary_search on
cu_seqlens_padded, copies row t iff its local offset is in [0, valid_len).
The local>=0 guard is required because step-1 chunk starts are shifted past
earlier tokens (cu_seqlens_padded[:-1] += chunk_size), so a token before a
segment's padded start must be skipped rather than clobbering an
already-written first-chunk row. total_tokens comes from x.shape, so no D2H.
Reuses the existing binary_search helper; mirrors thd_reorder_kernel.

Bit-identical to the Python loop (incl. the shifted step-1 layout). Restores
comm/compute overlap (1.57x -> 1.78x): all_gather cp2/nseg1 8k THD/BSHD
1.55x -> 1.35x, and at realistic multi-sequence packing THD now beats BSHD
(nseg=4 0.97x, nseg=8 0.94x). All CP all_gather THD tests pass.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
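The copy rule from the commit message above (copy row t only if its local offset falls in [0, valid_len)) can be modeled in a few lines of plain Python. This is a hedged reference sketch, not the kernel: `thd_valid_copy_ref` is a hypothetical name, rows are stand-in strings, and clamping the segment index to 0 for tokens before the first start is an assumption made so that the `local >= 0` guard is exercised.

```python
from bisect import bisect_right


def thd_valid_copy_ref(dst, src, cu_seqlens_padded, cu_seqlens_valid):
    """Copy src[t] into dst[t] for every valid token of every segment."""
    starts = cu_seqlens_padded[:-1]
    for t in range(len(src)):
        # Segment lookup by binary search; clamp to 0 (assumption) so tokens
        # before the first shifted start yield a negative local offset.
        seg = max(bisect_right(starts, t) - 1, 0)
        local = t - starts[seg]
        valid_len = cu_seqlens_valid[seg + 1] - cu_seqlens_valid[seg]
        if 0 <= local < valid_len:
            dst[t] = src[t]
    return dst


out = [None] * 12
step0 = [f"a{t}" for t in range(12)]
step1 = [f"b{t}" for t in range(12)]
# Two sequences with local padded lengths 8 and 4 (chunk sizes 4 and 2).
thd_valid_copy_ref(out, step0, [0, 8, 12], [0, 4, 6])    # first chunks
thd_valid_copy_ref(out, step1, [4, 10, 12], [0, 3, 4])   # shifted starts
print(out)
# ['a0', 'a1', 'a2', 'a3', 'b4', 'b5', 'b6', None, 'a8', 'a9', 'b10', None]
```

In the second call, tokens 0..3 get a negative local offset and are skipped, so the rows written by the first call survive; rows 7 and 11 are padding and are never written.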
Comment thread transformer_engine/common/fused_attn/context_parallel.cu
Comment on lines +3253 to +3254
if qkv_format == "thd" and not use_fused_attention and not use_flash_attn_3:
    assert False, (
Member Author

the condition should be inside the assert

Member Author

Fixed in 0e926c42; the condition is now the assert predicate with the explanatory message, instead of if ...: assert False.
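As a small illustration of this change, the unsupported case can be negated (via De Morgan) and used directly as the assert predicate. The wrapper function `check_thd_backend` and the message text below are hypothetical; the original message is not shown in the excerpt.

```python
def check_thd_backend(qkv_format, use_fused_attention, use_flash_attn_3):
    # Before: `if qkv_format == "thd" and not use_fused_attention
    #          and not use_flash_attn_3: assert False, (...)`
    # After: the negated condition is the assert predicate itself.
    assert (
        qkv_format != "thd" or use_fused_attention or use_flash_attn_3
    ), "Hypothetical message: THD with all_gather CP needs FusedAttention or FA3."


check_thd_backend("thd", use_fused_attention=True, use_flash_attn_3=False)    # passes
check_thd_backend("bshd", use_fused_attention=False, use_flash_attn_3=False)  # passes
```

Both forms behave the same at runtime; the single-assert form keeps the supported condition visible at the point of the check.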

The fused tex.thd_reorder (2dc5c15) runs on the main stream and, unlike the
index_select it replaced, its buffers are not allocator/stream-tracked. Its
faster sync-free execution lets the host run ahead into the 2-stream per-step
loop (step 1 on cp_stream), and PyTorch's caching allocator can recycle a block
the reorder is still using -> cudaErrorIllegalAddress for FA3 + all_gather + THD
on larger-token packs (bucket32k/64k/128k, mixed32k). Reproduces serially;
masked by CUDA_LAUNCH_BLOCKING and PYTORCH_NO_CUDA_MEMORY_CACHING, confirming an
allocator-reuse race. FusedAttention AG and FA3 a2a are unaffected; pre-kernel
TE ran these FA3 AG configs fine, so this is a regression from the kernel work.

Drain the main stream after the AG reorder, before the per-step loop allocates.
cp_stream already waits on main, so the cp_stream-vs-main per-step overlap is
preserved; cost is ~0.5-4.3% (one sync/forward, before the compute loop) vs the
11-30% the kernels recovered. Env-gated AG_REORDER_SYNC (default on) for A/B.
A finer allocator/event fix (needs a memory-snapshot trace to pin the block,
likely FA3-internal scratch) can replace this conservative drain later.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
FA3 hopper allocates per-call scheduler workspace internally, including semaphores that are recorded by the CUDA caching allocator only on the allocation stream. The THD all_gather CP path can launch consecutive FA3 per-step calls on separate streams, allowing a later step to reuse a just-freed workspace block while the prior step's scheduler kernel is still active, which shows up as illegal memory access under overlap.

Replace the interim host-side reorder synchronize with FA3-scoped GPU stream waits between consecutive per-step forward and backward calls. This preserves the FusedAttention overlap path and avoids a host drain while preventing overlapping FA3 internal workspace lifetimes.

Tests:

- CUDA_VISIBLE_DEVICES=0,1 NVTE_BATCH_MHA_P2P_COMM=1 TE_PATH=/perfhome/llms/repos/te_repos/ag_thd_swa/TransformerEngine python3 -m pytest tests/pytorch/attention/test_attention_with_cp.py -k 'fused_attention and thd and all_gather and not fp8 and not bias' -q

- CUDA_VISIBLE_DEVICES=0,1 NVTE_BATCH_MHA_P2P_COMM=1 torchrun --nproc-per-node=2 --master-port=47151 /perfhome/llms/repos/te_repos/ag_thd_swa/cp_bench/tests/pytorch/attention/run_attention_with_cp.py dtype=bf16 model=bucket32k qkv_format=thd kernel_backend=FlashAttention cp_comm_type=all_gather benchmark=0 log_level=WARNING thd_seqlen_pattern=24576,28672,30720,32768

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
sudhakarsingh27 and others added 7 commits June 2, 2026 14:19
The THD max-logit path built the valid-token mask with per-sequence .item() calls when cu_seqlens_q_padded was provided. In AllGather CP this is on the training hot path and blocks the host for every batch entry.

Build the same interval mask on GPU using scatter_add over start/end deltas. This preserves support for non-zero padded offsets while avoiding the device-to-host synchronizations called out in PR review.

Tests:

- python3 -m black --check transformer_engine/pytorch/cpp_extensions/fused_attn.py

- CUDA_VISIBLE_DEVICES=0,1 NVTE_BATCH_MHA_P2P_COMM=1 TE_PATH=/perfhome/llms/repos/te_repos/ag_thd_swa/TransformerEngine python3 -m pytest tests/pytorch/attention/test_attention_with_cp.py -k 'fused_attention and thd and all_gather and not fp8 and not bias' -q

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Resolve the pybind.cpp tail-binding conflict by preserving upstream grouped_mlp_experimental bindings inside the helper split that keeps PYBIND11_MODULE under cpplint's function-size limit.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Comment on lines +3102 to +3103
assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!"
assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!"
Member Author

Why did we add these two additional checks?

assert (
-    "padding" not in attn_mask_type
+    qkv_format == "thd" or "padding" not in attn_mask_type
), f"No support for cp_comm_type='all_gather' and {attn_mask_type=}."
Member Author

this is a repeated check, from just a few lines above

Comment on lines +3168 to +3171
seq_dim = qkv_format.index("s")
assert (
    q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0
), "Sequence length per GPU needs to be divisible by 2!"
Member Author

Why are we adding additional checks that didn't exist before?

Comment on lines +3289 to +3292
if qkv_format == "thd":
    out = torch.zeros(o_shape, dtype=fwd_nominal_dtype, device=q.device)
else:
    out = torch.empty_like(q)
Member Author

is it better to just have torch.zeros for both paths?

Comment on lines +3574 to +3578
# Copy valid token ranges from this step's output.
# Each step writes at different positions (no overlap, no correction needed).
# Sync-free fused copy of every segment's valid rows in one launch (replaces
# a per-batch .item() slice-copy loop that stalled comm/compute overlap).
tex.thd_valid_copy(
Member Author

Information from local experiments is leaking into the comments. Comments should describe what the code does and share extra information only in the most important or pertinent cases. Here, a reader who hasn't seen the Python version wouldn't know how to make sense of "per-batch .item()" or "comm/compute overlap".

# calls overlapping across streams, the allocator can hand step i's FA3 the block
# step i-1's FA3 just freed while step i-1's scheduler kernel is still atomically
# incrementing it -> shared semaphore -> out-of-range tiles -> illegal memory access
# (FA3-only, only under overlap; root cause in mb_runs/FA3_AG_RECORDSTREAM_RACE.md
Member Author

Again, experiment data is leaking into the comments.


if qkv_format == "thd":
    # [cp*t, h, d] -> reorder to contiguous per-sequence order -> [t_full, h, d]
    # Use padded cu_seqlens since reorder computes slice boundaries via integer
Member Author

also note here that the padded cu_seqlens is global

3 participants