63 commits
db44fc2
[PyTorch][CP] Fix THD AllGather CP: offset-based approach with proper…
sudhakarsingh27 Apr 7, 2026
1a5ca4c
[PyTorch][CP] Enable THD+all_gather tests in test_attention_with_cp
sudhakarsingh27 Apr 7, 2026
b4db9eb
[PyTorch][Fused Attn] Fix max_logit masking for non-zero-starting cu_…
sudhakarsingh27 Apr 7, 2026
7491ab6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 7, 2026
b957725
some cleanup of ag+thd impl and gate e e te test for flash+ag+thd
sudhakarsingh27 Apr 10, 2026
c89173c
Merge branch 'cp_thd_swa_with_ag' of github.com:sudhakarsingh27/Trans…
sudhakarsingh27 Apr 10, 2026
18e41bd
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into cp_th…
sudhakarsingh27 Apr 10, 2026
0b48746
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 10, 2026
608106d
improve the logic and remove for loop from the code
sudhakarsingh27 Apr 13, 2026
4b95130
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 13, 2026
15af3af
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into cp_th…
sudhakarsingh27 Apr 13, 2026
5bec5b3
Merge branch 'cp_thd_swa_with_ag' of github.com:sudhakarsingh27/Trans…
sudhakarsingh27 Apr 13, 2026
89b1066
AG+THD SWA: extend KV visibility for right window and rename a2a-spec…
sudhakarsingh27 Apr 16, 2026
55fc2cd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 16, 2026
f499f59
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into cp_th…
sudhakarsingh27 Apr 20, 2026
2569a65
Merge branch 'cp_thd_swa_with_ag' of github.com:sudhakarsingh27/Trans…
sudhakarsingh27 Apr 20, 2026
4e4212f
resolved merge conflicts with main
sudhakarsingh27 Apr 23, 2026
10e4cfc
[PyTorch] Add pad_between_seqs support for FlashAttention 3 with CP
sudhakarsingh27 Apr 24, 2026
2a49dee
[PyTorch] Add pad_between_seqs tests for CP and non-CP FlashAttention
sudhakarsingh27 Apr 24, 2026
34e3d62
[QA] Add CP deterministic tests to L3 and support TE_PATH in FA test
sudhakarsingh27 Apr 24, 2026
4745f98
[PyTorch] Fix FA3 deterministic gate to match upstream backward const…
sudhakarsingh27 Apr 24, 2026
4be004f
[PyTorch] Disable FlashAttention 4 for pad_between_seqs with THD
sudhakarsingh27 Apr 24, 2026
c476f15
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into flash…
sudhakarsingh27 Apr 24, 2026
a2b0f1b
[QA] Fix cutlass-dsl utils shadow in FA versions test
sudhakarsingh27 Apr 25, 2026
0ee22c7
merge conflicts with main
sudhakarsingh27 Apr 26, 2026
dfc1472
Merge branch 'main' of https://github.com/NVIDIA/TransformerEngine in…
sudhakarsingh27 Apr 26, 2026
ac38d4f
merge flash attn pad bw seqs
sudhakarsingh27 Apr 26, 2026
b94e175
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into flash…
sudhakarsingh27 Apr 28, 2026
7ebe3d9
fixes after merging with flash_attn_pad_bw_seqs branch
sudhakarsingh27 Apr 28, 2026
ddaa196
Merge branch 'main' of https://github.com/NVIDIA/TransformerEngine in…
sudhakarsingh27 Apr 28, 2026
fc9182f
skip tests which OOM in deterministic+backward+hopper+large_configs a…
sudhakarsingh27 Apr 29, 2026
636666f
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into flash…
sudhakarsingh27 Apr 29, 2026
7928bc9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 29, 2026
1585ebb
Merge branch 'flash_attn_pad_bw_seqs' of github.com:sudhakarsingh27/T…
sudhakarsingh27 Apr 29, 2026
7ecad01
[PyTorch][CP] Replace Python-loop THD reorder with kernel-backed perm…
sudhakarsingh27 Apr 29, 2026
d8bf5c5
Merge remote-tracking branch 'sudhakar_repo/flash_attn_pad_bw_seqs' i…
sudhakarsingh27 Apr 29, 2026
cc104d3
[PyTorch][CP] Fix AllGather SBHD forward: set cu_seqlens_kv_per_step
sudhakarsingh27 Apr 29, 2026
2464f43
make cp det and nondet tests run in parallel whenever possible
sudhakarsingh27 Apr 30, 2026
26e9f6f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 30, 2026
611d876
[PyTorch][CP] Fix THD AllGather forward stream race on k_ag/v_ag
sudhakarsingh27 Apr 30, 2026
0aae820
Merge branch 'cp_thd_swa_with_ag' of github.com:sudhakarsingh27/Trans…
sudhakarsingh27 Apr 30, 2026
789ccf0
Merge branch 'main' into flash_attn_pad_bw_seqs
sudhakarsingh27 May 1, 2026
0a32185
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into flash…
sudhakarsingh27 May 4, 2026
c33cf2d
Merge branch 'flash_attn_pad_bw_seqs' of github.com:sudhakarsingh27/T…
sudhakarsingh27 May 4, 2026
353361a
Merge remote-tracking branch 'sudhakar_repo/flash_attn_pad_bw_seqs' i…
sudhakarsingh27 May 4, 2026
a1062d9
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into cp_th…
sudhakarsingh27 May 5, 2026
1d4e170
Add THD + FlashAttention v3 support to AllGather CP backend
sudhakarsingh27 May 5, 2026
29785a0
Refactor AG THD window logic into shared get_kv_seq_info_after_all_ga…
sudhakarsingh27 May 6, 2026
09b01c9
[PyTorch][CP] Address PR 2829 self-review: clarify THD mask/cu_seqlens
sudhakarsingh27 May 22, 2026
a329afb
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into cp_th…
sudhakarsingh27 May 28, 2026
2dc5c15
[PyTorch] Fused thd_reorder kernel + sync-free CP THD reorder
sudhakarsingh27 May 30, 2026
5f606ae
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 30, 2026
24a95ab
[PyTorch] Sync-free thd_valid_copy kernel for AllGather CP THD fwd/bwd
sudhakarsingh27 May 30, 2026
b1faebb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 30, 2026
8c44fcb
[PyTorch] Fix FA3 all_gather THD allocator-reuse race in fused reorder
sudhakarsingh27 Jun 1, 2026
628f73c
[PyTorch] Serialize FA3 AG calls on GPU
sudhakarsingh27 Jun 2, 2026
b897900
[PyTorch] Avoid D2H sync in THD max-logit mask
sudhakarsingh27 Jun 2, 2026
669342a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 2, 2026
0e926c4
[PyTorch] Address THD AG review and lint issues
sudhakarsingh27 Jun 3, 2026
ed28a8b
Merge NVIDIA main into CP THD SWA branch
sudhakarsingh27 Jun 3, 2026
696ea9b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 3, 2026
90ab1c7
[PyTorch] Add THD helper kernel tests
sudhakarsingh27 Jun 3, 2026
a72e70b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 3, 2026
5 changes: 2 additions & 3 deletions tests/pytorch/attention/run_attention_with_cp.py
@@ -626,9 +626,8 @@ def run_dpa_with_cp(
         cu_seqlens_q = get_cu_seqlens_on_cp_rank(
             cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True
         )
-        num_pads_q = (cu_seqlens_q_padded - cu_seqlens_q)[1:] - (
-            cu_seqlens_q_padded - cu_seqlens_q
-        )[:-1]
+        cu_pads_q = cu_seqlens_q_padded - cu_seqlens_q
+        num_pads_q = cu_pads_q[1:] - cu_pads_q[:-1]
         cu_seqlens_kv_padded = cu_seqlens_kv_padded // world_size
         cu_seqlens_kv = get_cu_seqlens_on_cp_rank(
             cu_seqlens_kv, cu_seqlens_kv_padded, world_size, rank, True, True
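The change above names the cumulative padding count (`cu_pads_q`) once and then takes adjacent differences to get per-sequence padding. A minimal pure-Python sketch of the same arithmetic follows; the helper name `pads_per_sequence` and the sample lengths are made up for illustration, and plain lists stand in for the tensors used in the test.

```python
def pads_per_sequence(cu_seqlens_padded, cu_seqlens):
    """Return the number of padding tokens attributed to each sequence."""
    # Cumulative padding count at every sequence boundary.
    cu_pads = [p - s for p, s in zip(cu_seqlens_padded, cu_seqlens)]
    # Adjacent differences turn cumulative counts into per-sequence counts.
    return [hi - lo for lo, hi in zip(cu_pads[:-1], cu_pads[1:])]


# Hypothetical batch: valid lengths 3, 5, 2 padded to 4, 8, 4.
cu_seqlens = [0, 3, 8, 10]
cu_seqlens_padded = [0, 4, 12, 16]
print(pads_per_sequence(cu_seqlens_padded, cu_seqlens))  # [1, 3, 2]
```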
29 changes: 25 additions & 4 deletions tests/pytorch/attention/test_attention_with_cp.py
@@ -319,6 +319,14 @@ def test_cp_with_flash_attention(cp_pool, dtype, model, qkv_format, cp_comm_type
         if cp_comm_type == "a2a+p2p":
             pytest.skip("pad_between_seqs is not yet supported with A2A+P2P CP comm type!")

+    if pad_between_seqs:
+        if qkv_format != "thd":
+            pytest.skip("pad_between_seqs only applies to THD format!")
+        if not FlashAttentionUtils.v3_is_installed:
+            pytest.skip("pad_between_seqs with CP requires Flash Attention v3!")
+        if cp_comm_type == "a2a+p2p":
+            pytest.skip("pad_between_seqs is not yet supported with A2A+P2P CP comm type!")
+
     config = model_configs_flash_attn[model]
     config.context_parallel = True
     config.cp_comm_type = cp_comm_type
@@ -328,8 +336,17 @@ def test_cp_with_flash_attention(cp_pool, dtype, model, qkv_format, cp_comm_type
     if config.attn_bias_type != "no_bias" and cp_comm_type in ["all_gather", "a2a", "a2a+p2p"]:
         pytest.skip("No support for bias with cp_comm_type={all_gather, a2a, a2a+p2p}!")

-    if qkv_format == "thd" and cp_comm_type in ["all_gather", "a2a+p2p"]:
-        pytest.skip("No support for THD format with cp_comm_type={all_gather, a2a+p2p}!")
+    if qkv_format == "thd":
+        if cp_comm_type == "a2a+p2p":
+            pytest.skip(
+                "CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format"
+                " yet!"
+            )
+        if cp_comm_type == "all_gather" and not FlashAttentionUtils.v3_is_installed:
+            pytest.skip(
+                "THD + all_gather requires FA3 (seqused_k) to separate tensor offsets from"
+                " visibility limits in the gathered KV buffer."
+            )

     if (
         config.window_size != (-1, 0)
@@ -538,8 +555,12 @@ def test_cp_with_fused_attention(
     if config.attn_bias_type != "no_bias" and cp_comm_type in ["all_gather", "a2a", "a2a+p2p"]:
         pytest.skip("No support for bias with cp_comm_type={all_gather, a2a, a2a+p2p}!")

-    if qkv_format == "thd" and cp_comm_type in ["all_gather", "a2a+p2p"]:
-        pytest.skip("No support for THD format with cp_comm_type={all_gather, a2a+p2p}!")
+    if qkv_format == "thd":
+        if cp_comm_type == "a2a+p2p":
+            pytest.skip(
+                "CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format"
+                " yet!"
+            )

     if (config.window_size[0] != -1 or config.window_size[1] not in [-1, 0]) and cp_comm_type in [
         "p2p",
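The new FlashAttention skip message says that THD + all_gather relies on `seqused_k` to keep tensor offsets separate from visibility limits. As a rough illustration only, the pure-Python sketch below shows that separation: padded cumulative offsets say where each sequence starts in the buffer, and a per-sequence "used" length says how many of its tokens may be attended to. The function name `visible_kv_ranges` and the sample values are hypothetical; this is not code from TransformerEngine or FlashAttention.

```python
def visible_kv_ranges(cu_seqlens_kv_padded, seqused_k):
    """Return the (start, end) token range visible to attention for each sequence.

    cu_seqlens_kv_padded gives buffer offsets (where each sequence's slot begins);
    seqused_k gives how many tokens of each slot are visible.
    """
    ranges = []
    slots = zip(cu_seqlens_kv_padded[:-1], cu_seqlens_kv_padded[1:])
    for (start, end), used in zip(slots, seqused_k):
        if used > end - start:
            raise ValueError("visible length exceeds the sequence's slot in the buffer")
        ranges.append((start, start + used))
    return ranges


# Hypothetical gathered KV buffer: two slots of 8 tokens, 5 and 8 tokens visible.
print(visible_kv_ranges([0, 8, 16], [5, 8]))  # [(0, 5), (8, 16)]
```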
185 changes: 183 additions & 2 deletions tests/pytorch/attention/test_cp_utils.py
@@ -3,6 +3,8 @@
 # See LICENSE for license information.

 """Unit tests for context parallel utils."""
+
+import itertools
 import torch
 import unittest
 from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import (
@@ -11,9 +13,16 @@
     generate_positional_ids_for_cp,
 )

+try:
+    import transformer_engine_torch as tex
+except ImportError:
+    tex = None
+

 class TestSequencePadding(unittest.TestCase):
-    def test_padding_with_custom_padding_values_sequences_shorter_than_divisibility_factor(self):
+    def test_padding_with_custom_padding_values_sequences_shorter_than_divisibility_factor(
+        self,
+    ):
         """Test with custom padding values for all tensors."""
         # Setup

@@ -467,7 +476,36 @@ def test_sequences_longer_than_divisibility_factor(self):
         )

         expected_positional_ids = torch.tensor(
-            [0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7]
+            [
+                0,
+                1,
+                2,
+                3,
+                4,
+                5,
+                6,
+                7,
+                0,
+                1,
+                2,
+                3,
+                4,
+                5,
+                6,
+                7,
+                8,
+                9,
+                10,
+                11,
+                0,
+                1,
+                2,
+                3,
+                4,
+                5,
+                6,
+                7,
+            ]
         )

         expected_cu_seqlens_padded = torch.tensor([0, 8, 20, 28])
@@ -710,5 +748,148 @@ def test_integration_with_padding_and_cp_slicing(self):
         self.assertTrue(torch.equal(input_ids_r0, expected_input_ids_r0))


+def _legacy_reorder_thd_to_rank_sharded(x, cu_seqlens, cp_size, seq_dim=0):
+    total_slices_of_any_sequence = 2 * cp_size
+    slice_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]) // total_slices_of_any_sequence
+
+    indices = [
+        (
+            torch.arange(
+                seq_start + (cp_rank * slice_size),
+                seq_start + ((cp_rank + 1) * slice_size),
+                device=cu_seqlens.device,
+            ),
+            torch.arange(
+                seq_start + ((total_slices_of_any_sequence - cp_rank - 1) * slice_size),
+                seq_start + ((total_slices_of_any_sequence - cp_rank) * slice_size),
+                device=cu_seqlens.device,
+            ),
+        )
+        for cp_rank in range(cp_size)
+        for slice_size, seq_start in zip(slice_sizes, cu_seqlens[:-1])
+    ]
+
+    indices = list(itertools.chain(*indices))
+    indices = torch.cat(indices)
+    return x.index_select(seq_dim, indices)
+
+
+def _legacy_reorder_thd_to_contiguous(x, cu_seqlens, seq_chunk_ids, cp_size, seq_dim=0):
+    max_cum_seqlen_per_cp_rank = cu_seqlens[-1] // cp_size
+    cu_seqlens_on_any_cp_rank = cu_seqlens // cp_size
+
+    indices = [
+        torch.arange(
+            (
+                start + max_cum_seqlen_per_cp_rank * (chunk_id // 2)
+                if loc < cp_size
+                else (start + end) // 2 + max_cum_seqlen_per_cp_rank * (chunk_id // 2)
+            ),
+            (
+                (start + end) // 2 + max_cum_seqlen_per_cp_rank * (chunk_id // 2)
+                if loc < cp_size
+                else end + max_cum_seqlen_per_cp_rank * (chunk_id // 2)
+            ),
+            device=cu_seqlens.device,
+        )
+        for start, end in zip(cu_seqlens_on_any_cp_rank[:-1], cu_seqlens_on_any_cp_rank[1:])
+        for loc, chunk_id in enumerate(seq_chunk_ids)
+    ]
+
+    indices = torch.cat(indices)
+    return x.index_select(seq_dim, indices)
+
+
+def _legacy_valid_copy(out, inp, cu_seqlens_padded, cu_seqlens):
+    batch_size = cu_seqlens.shape[0] - 1
+    for b in range(batch_size):
+        s = cu_seqlens_padded[b].item()
+        sz = (cu_seqlens[b + 1] - cu_seqlens[b]).item()
+        if sz > 0:
+            out[s : s + sz].copy_(inp[s : s + sz])
+
+
+@unittest.skipIf(
+    not torch.cuda.is_available() or tex is None,
+    "THD kernel tests require CUDA and transformer_engine_torch",
+)
+class TestTHDKernels(unittest.TestCase):
+    def test_thd_reorder_matches_legacy_python_reorder(self):
+        cp_size = 4
+        cu_seqlens = torch.tensor([0, 8, 24, 40], dtype=torch.int32, device="cuda")
+        x = torch.arange(40 * 2 * 4, dtype=torch.float16, device="cuda").view(40, 2, 4)
+
+        rank_sharded = tex.thd_reorder(x, cu_seqlens, cp_size, False, x.shape[0])
+        ref_rank_sharded = _legacy_reorder_thd_to_rank_sharded(x, cu_seqlens, cp_size)
+        self.assertTrue(torch.equal(rank_sharded, ref_rank_sharded))
+
+        seq_chunk_ids = torch.empty(2 * cp_size, dtype=torch.int32, device="cuda")
+        for rank in range(cp_size):
+            seq_chunk_ids[rank] = 2 * rank
+            seq_chunk_ids[rank + cp_size] = 2 * cp_size - 2 * rank - 1
+        contiguous = tex.thd_reorder(rank_sharded, cu_seqlens, cp_size, True, rank_sharded.shape[0])
+        ref_contiguous = _legacy_reorder_thd_to_contiguous(
+            rank_sharded, cu_seqlens, seq_chunk_ids, cp_size
+        )
+        self.assertTrue(torch.equal(contiguous, ref_contiguous))
+        self.assertTrue(torch.equal(contiguous, x))
+
+    def test_thd_get_partitioned_indices_matches_dual_chunk_expected_indices(self):
+        cu_seqlens = torch.tensor([0, 8, 16], dtype=torch.int32, device="cuda")
+
+        rank0 = tex.thd_get_partitioned_indices(cu_seqlens, 16, 2, 0)
+        rank1 = tex.thd_get_partitioned_indices(cu_seqlens, 16, 2, 1)
+
+        expected_rank0 = torch.tensor([0, 1, 6, 7, 8, 9, 14, 15], dtype=torch.int32, device="cuda")
+        expected_rank1 = torch.tensor(
+            [2, 3, 4, 5, 10, 11, 12, 13], dtype=torch.int32, device="cuda"
+        )
+        self.assertTrue(torch.equal(rank0, expected_rank0))
+        self.assertTrue(torch.equal(rank1, expected_rank1))
+
+    def test_thd_valid_copy_matches_legacy_slice_copy_loop(self):
+        cu_seqlens_padded = torch.tensor([2, 6, 12], dtype=torch.int32, device="cuda")
+        cu_seqlens = torch.tensor([0, 3, 7], dtype=torch.int32, device="cuda")
+        inp = torch.arange(12 * 2 * 4, dtype=torch.float16, device="cuda").view(12, 2, 4)
+        out = torch.full_like(inp, -1)
+        expected = torch.full_like(inp, -1)
+
+        _legacy_valid_copy(expected, inp, cu_seqlens_padded, cu_seqlens)
+        tex.thd_valid_copy(out, inp, cu_seqlens_padded, cu_seqlens)
+        self.assertTrue(torch.equal(out, expected))
+
+    def test_thd_read_half_tensor_reads_each_sequence_half(self):
+        cu_seqlens = torch.tensor([0, 8, 16], dtype=torch.int32, device="cuda")
+        q = torch.arange(16 * 2 * 4, dtype=torch.float16, device="cuda").view(16, 2, 4)
+        kv = torch.arange(2 * 16 * 2 * 4, dtype=torch.float16, device="cuda").view(2, 16, 2, 4)
+
+        q_first = tex.thd_read_half_tensor(q, cu_seqlens, 0)
+        q_second = tex.thd_read_half_tensor(q, cu_seqlens, 1)
+        kv_first = tex.thd_read_half_tensor(kv, cu_seqlens, 0)
+        kv_second = tex.thd_read_half_tensor(kv, cu_seqlens, 1)
+
+        expected_first = torch.cat([q[0:4], q[8:12]], dim=0)
+        expected_second = torch.cat([q[4:8], q[12:16]], dim=0)
+        self.assertTrue(torch.equal(q_first, expected_first))
+        self.assertTrue(torch.equal(q_second, expected_second))
+        self.assertTrue(torch.equal(kv_first, torch.stack([expected_first, expected_first + 128])))
+        self.assertTrue(
+            torch.equal(kv_second, torch.stack([expected_second, expected_second + 128]))
+        )
+
+    def test_thd_read_second_half_lse_handles_packed_and_batch_major_lse(self):
+        cu_seqlens = torch.tensor([0, 8, 16], dtype=torch.int32, device="cuda")
+        lse = torch.arange(2 * 2 * 8, dtype=torch.float32, device="cuda").view(2, 2, 8)
+        packed_lse = torch.arange(2 * 16, dtype=torch.float32, device="cuda").view(2, 16)
+
+        second_half_lse = tex.thd_read_second_half_lse(lse, cu_seqlens, False, 4)
+        packed_second_half_lse = tex.thd_read_second_half_lse(packed_lse, cu_seqlens, True, 8)
+
+        expected = lse[:, :, 4:8]
+        expected_packed = torch.cat([packed_lse[:, 4:8], packed_lse[:, 12:16]], dim=1)
+        self.assertTrue(torch.equal(second_half_lse, expected))
+        self.assertTrue(torch.equal(packed_second_half_lse, expected_packed))
+
+
 if __name__ == "__main__":
     unittest.main()
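The expected indices in `test_thd_get_partitioned_indices_matches_dual_chunk_expected_indices` follow a dual-chunk split: each sequence is cut into `2 * cp_size` equal chunks, and a rank takes chunk `rank` from the front plus its mirror chunk from the back. Here is a small pure-Python sketch of that rule, assuming every sequence length is divisible by `2 * cp_size`. The helper name `dual_chunk_indices` is invented for illustration and is not the `tex` kernel.

```python
def dual_chunk_indices(cu_seqlens, cp_size, rank):
    """Token indices owned by `rank` under a dual-chunk split of each sequence."""
    num_chunks = 2 * cp_size
    indices = []
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        chunk = (end - start) // num_chunks
        # Front chunk `rank`, then its mirror chunk counted from the back.
        for chunk_id in (rank, num_chunks - 1 - rank):
            lo = start + chunk_id * chunk
            indices.extend(range(lo, lo + chunk))
    return indices


cu_seqlens = [0, 8, 16]
print(dual_chunk_indices(cu_seqlens, 2, 0))  # [0, 1, 6, 7, 8, 9, 14, 15]
print(dual_chunk_indices(cu_seqlens, 2, 1))  # [2, 3, 4, 5, 10, 11, 12, 13]
```

Pairing an early chunk with a late one is the usual way to even out causal-attention work across ranks, since early tokens attend to few keys and late tokens to many.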
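`test_thd_reorder_matches_legacy_python_reorder` checks that reordering to the rank-sharded layout and back returns the original tensor. The sketch below shows the same round trip on plain lists, under the assumption that the rank-sharded layout is "for each rank, for each sequence, front chunk then mirror chunk", as in `_legacy_reorder_thd_to_rank_sharded`. The function names are hypothetical, and the inverse is built by inverting the permutation rather than by the index arithmetic the legacy helper uses.

```python
def rank_sharded_order(cu_seqlens, cp_size):
    """Source index of every token in the rank-sharded layout."""
    num_chunks = 2 * cp_size
    order = []
    for rank in range(cp_size):
        for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
            chunk = (end - start) // num_chunks
            for chunk_id in (rank, num_chunks - 1 - rank):
                lo = start + chunk_id * chunk
                order.extend(range(lo, lo + chunk))
    return order


def to_rank_sharded(x, cu_seqlens, cp_size):
    """Gather tokens into rank-sharded order."""
    return [x[i] for i in rank_sharded_order(cu_seqlens, cp_size)]


def to_contiguous(x_sharded, cu_seqlens, cp_size):
    """Scatter tokens back to their original positions (inverse permutation)."""
    order = rank_sharded_order(cu_seqlens, cp_size)
    out = [None] * len(order)
    for pos, src in enumerate(order):
        out[src] = x_sharded[pos]
    return out


cu_seqlens = [0, 8, 24, 40]  # same lengths as the test; cp_size = 4
x = list(range(40))
sharded = to_rank_sharded(x, cu_seqlens, 4)
print(sharded[:10])  # [0, 7, 8, 9, 22, 23, 24, 25, 38, 39]
print(to_contiguous(sharded, cu_seqlens, 4) == x)  # True
```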
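For `test_thd_valid_copy_matches_legacy_slice_copy_loop`, the reference behaviour is: for each sequence, start at its padded offset and copy only as many tokens as the sequence actually has, leaving padding slots in `out` untouched. A pure-Python sketch of that loop on flat lists follows, mirroring `_legacy_valid_copy` with the test's offsets. The name `valid_copy` is illustrative and is not the `tex.thd_valid_copy` kernel.

```python
def valid_copy(out, inp, cu_seqlens_padded, cu_seqlens):
    """Copy only the valid tokens of each sequence from `inp` into `out` in place."""
    batch_size = len(cu_seqlens) - 1
    for b in range(batch_size):
        start = cu_seqlens_padded[b]                # where the sequence sits in the buffer
        size = cu_seqlens[b + 1] - cu_seqlens[b]    # how many of its tokens are valid
        out[start:start + size] = inp[start:start + size]
    return out


inp = list(range(12))
out = [-1] * 12
valid_copy(out, inp, [2, 6, 12], [0, 3, 7])
print(out)  # [-1, -1, 2, 3, 4, -1, 6, 7, 8, 9, -1, -1]
```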