Open
Changes from 54 commits
Commits
65 commits
db44fc2
[PyTorch][CP] Fix THD AllGather CP: offset-based approach with proper…
sudhakarsingh27 Apr 7, 2026
1a5ca4c
[PyTorch][CP] Enable THD+all_gather tests in test_attention_with_cp
sudhakarsingh27 Apr 7, 2026
b4db9eb
[PyTorch][Fused Attn] Fix max_logit masking for non-zero-starting cu_…
sudhakarsingh27 Apr 7, 2026
7491ab6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 7, 2026
b957725
some cleanup of ag+thd impl and gate e e te test for flash+ag+thd
sudhakarsingh27 Apr 10, 2026
c89173c
Merge branch 'cp_thd_swa_with_ag' of github.com:sudhakarsingh27/Trans…
sudhakarsingh27 Apr 10, 2026
18e41bd
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into cp_th…
sudhakarsingh27 Apr 10, 2026
0b48746
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 10, 2026
608106d
improve the logic and remvoe for loop from the code
sudhakarsingh27 Apr 13, 2026
4b95130
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 13, 2026
15af3af
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into cp_th…
sudhakarsingh27 Apr 13, 2026
5bec5b3
Merge branch 'cp_thd_swa_with_ag' of github.com:sudhakarsingh27/Trans…
sudhakarsingh27 Apr 13, 2026
89b1066
AG+THD SWA: extend KV visibility for right window and rename a2a-spec…
sudhakarsingh27 Apr 16, 2026
55fc2cd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 16, 2026
f499f59
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into cp_th…
sudhakarsingh27 Apr 20, 2026
2569a65
Merge branch 'cp_thd_swa_with_ag' of github.com:sudhakarsingh27/Trans…
sudhakarsingh27 Apr 20, 2026
4e4212f
resolved merge conflicts with main
sudhakarsingh27 Apr 23, 2026
10e4cfc
[PyTorch] Add pad_between_seqs support for FlashAttention 3 with CP
sudhakarsingh27 Apr 24, 2026
2a49dee
[PyTorch] Add pad_between_seqs tests for CP and non-CP FlashAttention
sudhakarsingh27 Apr 24, 2026
34e3d62
[QA] Add CP deterministic tests to L3 and support TE_PATH in FA test
sudhakarsingh27 Apr 24, 2026
4745f98
[PyTorch] Fix FA3 deterministic gate to match upstream backward const…
sudhakarsingh27 Apr 24, 2026
4be004f
[PyTorch] Disable FlashAttention 4 for pad_between_seqs with THD
sudhakarsingh27 Apr 24, 2026
c476f15
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into flash…
sudhakarsingh27 Apr 24, 2026
a2b0f1b
[QA] Fix cutlass-dsl utils shadow in FA versions test
sudhakarsingh27 Apr 25, 2026
0ee22c7
merge conflicts with main
sudhakarsingh27 Apr 26, 2026
dfc1472
Merge branch 'main' of https://github.com/NVIDIA/TransformerEngine in…
sudhakarsingh27 Apr 26, 2026
ac38d4f
merge flash attn pad bw seqs
sudhakarsingh27 Apr 26, 2026
b94e175
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into flash…
sudhakarsingh27 Apr 28, 2026
7ebe3d9
fixes after merging with flash_attn_pad_bw_seqs branchj
sudhakarsingh27 Apr 28, 2026
ddaa196
Merge branch 'main' of https://github.com/NVIDIA/TransformerEngine in…
sudhakarsingh27 Apr 28, 2026
fc9182f
skip tests which OOM in deterministic+backward+hopper+large_configs a…
sudhakarsingh27 Apr 29, 2026
636666f
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into flash…
sudhakarsingh27 Apr 29, 2026
7928bc9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 29, 2026
1585ebb
Merge branch 'flash_attn_pad_bw_seqs' of github.com:sudhakarsingh27/T…
sudhakarsingh27 Apr 29, 2026
7ecad01
[PyTorch][CP] Replace Python-loop THD reorder with kernel-backed perm…
sudhakarsingh27 Apr 29, 2026
d8bf5c5
Merge remote-tracking branch 'sudhakar_repo/flash_attn_pad_bw_seqs' i…
sudhakarsingh27 Apr 29, 2026
cc104d3
[PyTorch][CP] Fix AllGather SBHD forward: set cu_seqlens_kv_per_step
sudhakarsingh27 Apr 29, 2026
2464f43
make cp det and nondet tests run in parallel whenever possible
sudhakarsingh27 Apr 30, 2026
26e9f6f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 30, 2026
611d876
[PyTorch][CP] Fix THD AllGather forward stream race on k_ag/v_ag
sudhakarsingh27 Apr 30, 2026
0aae820
Merge branch 'cp_thd_swa_with_ag' of github.com:sudhakarsingh27/Trans…
sudhakarsingh27 Apr 30, 2026
789ccf0
Merge branch 'main' into flash_attn_pad_bw_seqs
sudhakarsingh27 May 1, 2026
0a32185
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into flash…
sudhakarsingh27 May 4, 2026
c33cf2d
Merge branch 'flash_attn_pad_bw_seqs' of github.com:sudhakarsingh27/T…
sudhakarsingh27 May 4, 2026
353361a
Merge remote-tracking branch 'sudhakar_repo/flash_attn_pad_bw_seqs' i…
sudhakarsingh27 May 4, 2026
a1062d9
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into cp_th…
sudhakarsingh27 May 5, 2026
1d4e170
Add THD + FlashAttention v3 support to AllGather CP backend
sudhakarsingh27 May 5, 2026
29785a0
Refactor AG THD window logic into shared get_kv_seq_info_after_all_ga…
sudhakarsingh27 May 6, 2026
09b01c9
[PyTorch][CP] Address PR 2829 self-review: clarify THD mask/cu_seqlens
sudhakarsingh27 May 22, 2026
a329afb
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into cp_th…
sudhakarsingh27 May 28, 2026
2dc5c15
[PyTorch] Fused thd_reorder kernel + sync-free CP THD reorder
sudhakarsingh27 May 30, 2026
5f606ae
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 30, 2026
24a95ab
[PyTorch] Sync-free thd_valid_copy kernel for AllGather CP THD fwd/bwd
sudhakarsingh27 May 30, 2026
b1faebb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 30, 2026
8c44fcb
[PyTorch] Fix FA3 all_gather THD allocator-reuse race in fused reorder
sudhakarsingh27 Jun 1, 2026
628f73c
[PyTorch] Serialize FA3 AG calls on GPU
sudhakarsingh27 Jun 2, 2026
b897900
[PyTorch] Avoid D2H sync in THD max-logit mask
sudhakarsingh27 Jun 2, 2026
669342a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 2, 2026
0e926c4
[PyTorch] Address THD AG review and lint issues
sudhakarsingh27 Jun 3, 2026
ed28a8b
Merge NVIDIA main into CP THD SWA branch
sudhakarsingh27 Jun 3, 2026
696ea9b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 3, 2026
90ab1c7
[PyTorch] Add THD helper kernel tests
sudhakarsingh27 Jun 3, 2026
a72e70b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 3, 2026
a8831fb
Merge remote-tracking branch 'origin/main' into codex/pr2829-review-c…
sudhakarsingh27 Jun 5, 2026
dfc3a97
[PyTorch] Clean up THD AG review comments
sudhakarsingh27 Jun 5, 2026
5 changes: 2 additions & 3 deletions tests/pytorch/attention/run_attention_with_cp.py
@@ -626,9 +626,8 @@ def run_dpa_with_cp(
cu_seqlens_q = get_cu_seqlens_on_cp_rank(
cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True
)
- num_pads_q = (cu_seqlens_q_padded - cu_seqlens_q)[1:] - (
- cu_seqlens_q_padded - cu_seqlens_q
- )[:-1]
+ cu_pads_q = cu_seqlens_q_padded - cu_seqlens_q
+ num_pads_q = cu_pads_q[1:] - cu_pads_q[:-1]
cu_seqlens_kv_padded = cu_seqlens_kv_padded // world_size
cu_seqlens_kv = get_cu_seqlens_on_cp_rank(
cu_seqlens_kv, cu_seqlens_kv_padded, world_size, rank, True, True
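
The refactor in this hunk names the cumulative padding once and then takes adjacent differences. A minimal sketch of the same arithmetic on plain Python lists follows; the helper name `per_sequence_pads` and the example lengths are illustrative assumptions, not part of the test file, which operates on tensors.

```python
def per_sequence_pads(cu_seqlens_padded, cu_seqlens):
    """Per-sequence padding counts from two cumulative-length arrays."""
    # Cumulative padding accumulated up to each sequence boundary.
    cu_pads = [p - v for p, v in zip(cu_seqlens_padded, cu_seqlens)]
    # Adjacent differences turn the cumulative count into a per-sequence count.
    return [b - a for a, b in zip(cu_pads, cu_pads[1:])]


# Hypothetical batch: valid lengths 5, 3, 6 padded to 8, 4, 8.
cu_seqlens = [0, 5, 8, 14]
cu_seqlens_padded = [0, 8, 12, 20]
num_pads = per_sequence_pads(cu_seqlens_padded, cu_seqlens)
print(num_pads)
```

Computing the cumulative difference once avoids evaluating the same subtraction twice, which is all the two-line replacement changes.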
29 changes: 25 additions & 4 deletions tests/pytorch/attention/test_attention_with_cp.py
@@ -319,6 +319,14 @@ def test_cp_with_flash_attention(cp_pool, dtype, model, qkv_format, cp_comm_type
if cp_comm_type == "a2a+p2p":
pytest.skip("pad_between_seqs is not yet supported with A2A+P2P CP comm type!")

if pad_between_seqs:
if qkv_format != "thd":
pytest.skip("pad_between_seqs only applies to THD format!")
if not FlashAttentionUtils.v3_is_installed:
pytest.skip("pad_between_seqs with CP requires Flash Attention v3!")
if cp_comm_type == "a2a+p2p":
pytest.skip("pad_between_seqs is not yet supported with A2A+P2P CP comm type!")

config = model_configs_flash_attn[model]
config.context_parallel = True
config.cp_comm_type = cp_comm_type
@@ -328,8 +336,17 @@ def test_cp_with_flash_attention(cp_pool, dtype, model, qkv_format, cp_comm_type
if config.attn_bias_type != "no_bias" and cp_comm_type in ["all_gather", "a2a", "a2a+p2p"]:
pytest.skip("No support for bias with cp_comm_type={all_gather, a2a, a2a+p2p}!")

- if qkv_format == "thd" and cp_comm_type in ["all_gather", "a2a+p2p"]:
- pytest.skip("No support for THD format with cp_comm_type={all_gather, a2a+p2p}!")
+ if qkv_format == "thd":
+ if cp_comm_type == "a2a+p2p":
+ pytest.skip(
+ "CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format"
+ " yet!"
+ )
+ if cp_comm_type == "all_gather" and not FlashAttentionUtils.v3_is_installed:
+ pytest.skip(
+ "THD + all_gather requires FA3 (seqused_k) to separate tensor offsets from"
+ " visibility limits in the gathered KV buffer."
+ )
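
The new gating for the flash-attention test can be read as a small decision function. The sketch below is a hypothetical helper (`thd_flash_skip_reason` does not exist in the test suite) that returns the skip reason, or `None` when the combination is expected to run; it only mirrors the THD branch shown in this hunk and ignores every other skip condition in the test.

```python
def thd_flash_skip_reason(qkv_format, cp_comm_type, fa3_installed):
    """Hypothetical mirror of the THD skip rules in test_cp_with_flash_attention."""
    if qkv_format != "thd":
        return None  # These rules only apply to THD inputs.
    if cp_comm_type == "a2a+p2p":
        return "A2A+P2P (Hierarchical A2A) does not support THD format yet"
    if cp_comm_type == "all_gather" and not fa3_installed:
        return "THD + all_gather requires FA3 (seqused_k)"
    return None


# THD + all_gather now runs when FA3 is present and is skipped otherwise.
print(thd_flash_skip_reason("thd", "all_gather", fa3_installed=True))
print(thd_flash_skip_reason("thd", "all_gather", fa3_installed=False))
```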

if (
config.window_size != (-1, 0)
@@ -538,8 +555,12 @@ def test_cp_with_fused_attention(
if config.attn_bias_type != "no_bias" and cp_comm_type in ["all_gather", "a2a", "a2a+p2p"]:
pytest.skip("No support for bias with cp_comm_type={all_gather, a2a, a2a+p2p}!")

- if qkv_format == "thd" and cp_comm_type in ["all_gather", "a2a+p2p"]:
- pytest.skip("No support for THD format with cp_comm_type={all_gather, a2a+p2p}!")
+ if qkv_format == "thd":
+ if cp_comm_type == "a2a+p2p":
+ pytest.skip(
+ "CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format"
+ " yet!"
+ )

if (config.window_size[0] != -1 or config.window_size[1] not in [-1, 0]) and cp_comm_type in [
"p2p",
166 changes: 160 additions & 6 deletions transformer_engine/common/fused_attn/context_parallel.cu
@@ -77,6 +77,18 @@ __forceinline__ __device__ int binary_search(int target, int *array, int len) {
return left - 1;
}

// Dual-chunk source index for THD CP partitioning. cu_seqlens_s must already be divided by
// world_size. Single source of truth shared by thd_partition_indices_kernel and
// thd_reorder_kernel so the two never diverge.
__forceinline__ __device__ int thd_partition_src_index(int token_id, int *cu_seqlens_s, int batch,
int world_size, int rank) {
int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1);
int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id];
int index = token_id - cu_seqlens_s[seq_id];
int offset = index < seq_len / 2 ? rank : (world_size - 1) * 2 - rank;
return index + cu_seqlens_s[seq_id] * world_size + seq_len / 2 * offset;
}
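
To make the dual-chunk mapping concrete, here is a host-side Python model of `thd_partition_src_index`. It is a sketch under one assumption: `binary_search` is taken to return the index of the sequence containing `token_id`, which `bisect_right(...) - 1` reproduces for in-range tokens. The example `cu_seqlens` values are illustrative.

```python
from bisect import bisect_right


def thd_partition_src_index(token_id, cu_seqlens_s, world_size, rank):
    """Python model of the device helper; cu_seqlens_s is already divided by world_size."""
    # Assumed behaviour of binary_search: index of the sequence holding token_id.
    seq_id = bisect_right(cu_seqlens_s, token_id) - 1
    seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]
    index = token_id - cu_seqlens_s[seq_id]
    # First half of the local sequence comes from chunk `rank`,
    # second half from the mirrored chunk `2 * world_size - 1 - rank`.
    offset = rank if index < seq_len // 2 else (world_size - 1) * 2 - rank
    return index + cu_seqlens_s[seq_id] * world_size + (seq_len // 2) * offset


# Two sequences of 8 and 4 padded tokens, split across 2 ranks.
world_size = 2
cu_seqlens = [0, 8, 12]
cu_seqlens_s = [c // world_size for c in cu_seqlens]
per_rank = cu_seqlens_s[-1]
rank0 = [thd_partition_src_index(t, cu_seqlens_s, world_size, 0) for t in range(per_rank)]
rank1 = [thd_partition_src_index(t, cu_seqlens_s, world_size, 1) for t in range(per_rank)]
print(rank0)
print(rank1)
```

Each sequence is cut into `2 * world_size` chunks; rank `r` takes chunk `r` and chunk `2 * world_size - 1 - r`, so the two ranks together cover every source token exactly once.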

/***************************************************************************************************
* Support THD format for Context Parallel: Generate partitioned indices for input tokens
**************************************************************************************************/
@@ -96,12 +108,82 @@ __global__ void thd_partition_indices_kernel(int *output, int *cu_seqlens, int b
int num_threads = blockDim.x * gridDim.x;

for (int token_id = tid; token_id < total_tokens / world_size; token_id += num_threads) {
- int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1);
- int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id];
- int index = token_id - cu_seqlens_s[seq_id];
- int offset = index < seq_len / 2 ? rank : (world_size - 1) * 2 - rank;
- index += cu_seqlens_s[seq_id] * world_size + seq_len / 2 * offset;
- output[token_id] = index;
+ output[token_id] = thd_partition_src_index(token_id, cu_seqlens_s, batch, world_size, rank);
}
}

/***************************************************************************************************
* Support THD format for Context Parallel: fused dual-chunk reorder (gather/scatter)
* out[gi] = inp[src(gi)] (gather, to_rank_sharded) or out[src(gi)] = inp[gi] (scatter,
* to_contiguous). src is computed inline (no materialized index tensor). Modeled on
* thd_read_half_tensor_kernel: warp-per-token, cu_seqlens_s in shared, float4 vectorized copy.
* hidden_size_in_bytes must be a multiple of 16 (same assumption as thd_read_half_tensor).
**************************************************************************************************/
__global__ void thd_reorder_kernel(void *out, void *inp, int *cu_seqlens, int batch,
int total_tokens, int world_size, int hidden_size_in_bytes,
bool scatter) {
extern __shared__ int cu_seqlens_s[];
for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
cu_seqlens_s[i] = cu_seqlens[i] / world_size;
}
__syncthreads();

int warpid = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
int laneid = threadIdx.x % 32;
int num_warps = (blockDim.x * gridDim.x) / 32;
int tpr = total_tokens / world_size;
int num_float4s_per_token = hidden_size_in_bytes / sizeof(float4);

for (int gi = warpid; gi < total_tokens; gi += num_warps) {
int rank = gi / tpr;
int token_id = gi % tpr;
int src = thd_partition_src_index(token_id, cu_seqlens_s, batch, world_size, rank);
int rd = scatter ? gi : src;
int wr = scatter ? src : gi;
float4 *src_tok = reinterpret_cast<float4 *>(reinterpret_cast<char *>(inp) +
static_cast<size_t>(rd) * hidden_size_in_bytes);
float4 *dst_tok = reinterpret_cast<float4 *>(reinterpret_cast<char *>(out) +
static_cast<size_t>(wr) * hidden_size_in_bytes);
for (int idx = laneid; idx < num_float4s_per_token; idx += 32) dst_tok[idx] = src_tok[idx];
}
}
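
The gather and scatter modes of `thd_reorder_kernel` are inverses of each other. The following Python model illustrates that on a list of token ids. It is a sequential sketch, not the kernel: it ignores warps and vectorized copies, and it assumes `binary_search` locates the containing sequence as `bisect_right(...) - 1` does.

```python
from bisect import bisect_right


def src_index(token_id, cu_seqlens_s, world_size, rank):
    """Dual-chunk source index (model of thd_partition_src_index)."""
    seq_id = bisect_right(cu_seqlens_s, token_id) - 1
    seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]
    index = token_id - cu_seqlens_s[seq_id]
    offset = rank if index < seq_len // 2 else (world_size - 1) * 2 - rank
    return index + cu_seqlens_s[seq_id] * world_size + (seq_len // 2) * offset


def thd_reorder(inp, cu_seqlens, world_size, scatter):
    """Sequential model of the kernel: one loop iteration per token row."""
    total_tokens = len(inp)
    cu_seqlens_s = [c // world_size for c in cu_seqlens]
    tokens_per_rank = total_tokens // world_size
    out = [None] * total_tokens
    for gi in range(total_tokens):
        rank, token_id = divmod(gi, tokens_per_rank)
        src = src_index(token_id, cu_seqlens_s, world_size, rank)
        if scatter:
            out[src] = inp[gi]  # rank-sharded -> contiguous
        else:
            out[gi] = inp[src]  # contiguous -> rank-sharded
    return out


cu_seqlens = [0, 8, 12]
contiguous = list(range(12))
sharded = thd_reorder(contiguous, cu_seqlens, world_size=2, scatter=False)
restored = thd_reorder(sharded, cu_seqlens, world_size=2, scatter=True)
print(sharded)
print(restored)
```

Because both directions share one index function, a gather followed by a scatter returns the original layout, which is the property the "single source of truth" comment is protecting.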

/***************************************************************************************************
* Support THD format for Context Parallel: copy the VALID token rows of a per-step output/grad
* into the destination accumulator, leaving padded tails untouched. Sync-free replacement for the
* per-batch `.item()` slice-copy loops in the AllGather CP THD fwd/bwd. cu_seqlens_padded gives a
* token's segment + local offset in the padded layout; cu_seqlens gives each segment's valid
* length. Warp-per-token, float4 vectorized, modeled on thd_reorder_kernel.
**************************************************************************************************/
__global__ void thd_valid_copy_kernel(void *out, void *inp, int *cu_seqlens_padded, int *cu_seqlens,
int batch, int total_tokens, int hidden_size_in_bytes) {
extern __shared__ int padded_s[]; // [0..batch] padded boundaries
int *valid_s = padded_s + (batch + 1); // [0..batch] valid boundaries
for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
padded_s[i] = cu_seqlens_padded[i];
valid_s[i] = cu_seqlens[i];
}
__syncthreads();

int warpid = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
int laneid = threadIdx.x % 32;
int num_warps = (blockDim.x * gridDim.x) / 32;
int num_float4s_per_token = hidden_size_in_bytes / sizeof(float4);

for (int token_id = warpid; token_id < total_tokens; token_id += num_warps) {
int seq_id = binary_search(token_id, padded_s, batch + 1);
int local = token_id - padded_s[seq_id];
int valid_len = valid_s[seq_id + 1] - valid_s[seq_id];
// local can be negative when a segment's padded start is shifted past earlier tokens (e.g.
// step-1 chunks: cu_seqlens_padded[:-1] += chunk_size). Those tokens are outside any valid
// run, so skip them -- otherwise the first chunk's already-written rows get clobbered.
if (local >= 0 && local < valid_len) {
float4 *src_tok = reinterpret_cast<float4 *>(
reinterpret_cast<char *>(inp) + static_cast<size_t>(token_id) * hidden_size_in_bytes);
float4 *dst_tok = reinterpret_cast<float4 *>(
reinterpret_cast<char *>(out) + static_cast<size_t>(token_id) * hidden_size_in_bytes);
for (int idx = laneid; idx < num_float4s_per_token; idx += 32) dst_tok[idx] = src_tok[idx];
}
}
}
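
The `local >= 0` guard is easiest to see on a small example. Below is a sequential Python model of `thd_valid_copy_kernel`. It assumes `binary_search` returns a segment index clamped to `[0, batch - 1]`, which is inferred from the comment about negative `local` rather than from the function body (not shown in this diff). The input values are illustrative.

```python
from bisect import bisect_right


def thd_valid_copy(out, inp, cu_seqlens_padded, cu_seqlens):
    """Copy only valid token rows from inp to out at identical indices."""
    batch = len(cu_seqlens) - 1
    for token_id in range(len(inp)):
        # Assumption: the device binary_search never returns an index outside [0, batch - 1].
        seq_id = min(max(bisect_right(cu_seqlens_padded, token_id) - 1, 0), batch - 1)
        local = token_id - cu_seqlens_padded[seq_id]
        valid_len = cu_seqlens[seq_id + 1] - cu_seqlens[seq_id]
        # Tokens before a shifted segment start (local < 0) and padded tails are skipped.
        if 0 <= local < valid_len:
            out[token_id] = inp[token_id]
    return out


# Unshifted case: valid lengths 5 and 3 inside padded segments of 8 and 4.
inp = list(range(100, 112))
plain = thd_valid_copy([-1] * 12, inp, [0, 8, 12], [0, 5, 8])

# Shifted case: the single segment starts at row 2, so rows 0..1 have negative local offsets.
shifted = thd_valid_copy([-1] * 8, list(range(100, 108)), [2, 8], [0, 3])
print(plain)
print(shifted)
```

Rows left at `-1` stand for accumulator entries that the copy leaves untouched.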

@@ -678,6 +760,57 @@ void thd_get_partitioned_indices(const Tensor &cu_seqlens, Tensor output, int to
NVTE_CHECK_CUDA(cudaGetLastError());
}

void thd_reorder(const Tensor &inp, const Tensor &cu_seqlens, Tensor &out, int world_size,
bool scatter, int total_tokens, cudaStream_t stream) {
using namespace transformer_engine;
NVTE_CHECK(cu_seqlens.dtype() == DType::kInt32);
NVTE_CHECK(cu_seqlens.dim() == 1);
auto cu_seqlens_shape = cu_seqlens.shape();
NVTE_CHECK(cu_seqlens_shape[0] >= 2);
NVTE_CHECK(world_size > 0);
NVTE_CHECK(total_tokens > 0 && total_tokens % (world_size * 2) == 0);

auto inp_shape = inp.shape();
size_t row_elems = 1;
for (int i = 1; i < inp.dim(); i++) row_elems *= inp_shape[i];
int hidden_size_in_bytes = (row_elems * typeToNumBits(inp.dtype())) / 8;
NVTE_CHECK(hidden_size_in_bytes % 16 == 0); // 128-bit load/store

int batch = cu_seqlens_shape[0] - 1;
constexpr unsigned int block = 256;
unsigned int grid = (static_cast<unsigned int>(total_tokens) * 32 + block - 1) / block;
thd_reorder_kernel<<<grid, block, sizeof(int) * (batch + 1), stream>>>(
out.data.dptr, inp.data.dptr, reinterpret_cast<int *>(cu_seqlens.data.dptr), batch,
total_tokens, world_size, hidden_size_in_bytes, scatter);
NVTE_CHECK_CUDA(cudaGetLastError());
}
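
The launch arithmetic in `thd_reorder` can be checked by hand. The sketch below recomputes it in Python; the function name `thd_reorder_launch_config`, the returned dictionary, and the example shape (16 heads of dimension 128 in a 16-bit dtype) are assumptions for illustration only.

```python
def thd_reorder_launch_config(total_tokens, row_elems, bits_per_elem, batch, world_size,
                              block=256, warp_size=32):
    """Recompute the host-side checks and launch geometry of thd_reorder."""
    if total_tokens <= 0 or total_tokens % (world_size * 2) != 0:
        raise ValueError("total_tokens must be a positive multiple of 2 * world_size")
    hidden_size_in_bytes = row_elems * bits_per_elem // 8
    if hidden_size_in_bytes % 16 != 0:
        raise ValueError("row size must be a multiple of 16 bytes (one float4)")
    # One warp per token row: total_tokens * 32 threads, rounded up to whole blocks.
    grid = (total_tokens * warp_size + block - 1) // block
    float4_per_token = hidden_size_in_bytes // 16
    return {
        "grid": grid,
        "block": block,
        "shared_bytes": 4 * (batch + 1),  # one int32 per cu_seqlens entry
        "float4_per_token": float4_per_token,
        "copies_per_lane": -(-float4_per_token // warp_size),  # ceiling division
    }


cfg = thd_reorder_launch_config(total_tokens=4096, row_elems=16 * 128, bits_per_elem=16,
                                batch=3, world_size=2)
print(cfg)
```

With this grid every token row gets its own warp, so the grid-stride loop in the kernel runs at most once per warp; the loop only matters if the launch is ever capped to fewer blocks.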

void thd_valid_copy(const Tensor &inp, const Tensor &cu_seqlens_padded, const Tensor &cu_seqlens,
Tensor &out, int total_tokens, cudaStream_t stream) {
using namespace transformer_engine;
NVTE_CHECK(cu_seqlens.dtype() == DType::kInt32);
NVTE_CHECK(cu_seqlens_padded.dtype() == DType::kInt32);
NVTE_CHECK(cu_seqlens.dim() == 1 && cu_seqlens_padded.dim() == 1);
auto cu_seqlens_shape = cu_seqlens.shape();
NVTE_CHECK(cu_seqlens_shape[0] >= 2);
NVTE_CHECK(cu_seqlens_padded.shape()[0] == cu_seqlens_shape[0]);
NVTE_CHECK(total_tokens > 0);

auto inp_shape = inp.shape();
size_t row_elems = 1;
for (int i = 1; i < inp.dim(); i++) row_elems *= inp_shape[i];
int hidden_size_in_bytes = (row_elems * typeToNumBits(inp.dtype())) / 8;
NVTE_CHECK(hidden_size_in_bytes % 16 == 0); // 128-bit load/store

int batch = cu_seqlens_shape[0] - 1;
constexpr unsigned int block = 256;
unsigned int grid = (static_cast<unsigned int>(total_tokens) * 32 + block - 1) / block;
thd_valid_copy_kernel<<<grid, block, sizeof(int) * 2 * (batch + 1), stream>>>(
out.data.dptr, inp.data.dptr, reinterpret_cast<int *>(cu_seqlens_padded.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, total_tokens, hidden_size_in_bytes);
NVTE_CHECK_CUDA(cudaGetLastError());
}

} // namespace context_parallel
} // namespace transformer_engine

@@ -750,3 +883,24 @@ void nvte_cp_thd_get_partitioned_indices(const NVTETensor &cu_seqlens, NVTETenso
*convertNVTETensorCheck(output), total_tokens,
world_size, rank, stream);
}

void nvte_cp_thd_reorder(const NVTETensor &inp, const NVTETensor &cu_seqlens, NVTETensor out,
int world_size, int scatter, int total_tokens, cudaStream_t stream) {
NVTE_API_CALL(nvte_cp_thd_reorder);
using namespace transformer_engine;

context_parallel::thd_reorder(*convertNVTETensorCheck(inp), *convertNVTETensorCheck(cu_seqlens),
*convertNVTETensorCheck(out), world_size, scatter != 0,
total_tokens, stream);
}

void nvte_cp_thd_valid_copy(const NVTETensor &inp, const NVTETensor &cu_seqlens_padded,
const NVTETensor &cu_seqlens, NVTETensor out, int total_tokens,
cudaStream_t stream) {
NVTE_API_CALL(nvte_cp_thd_valid_copy);
using namespace transformer_engine;

context_parallel::thd_valid_copy(
*convertNVTETensorCheck(inp), *convertNVTETensorCheck(cu_seqlens_padded),
*convertNVTETensorCheck(cu_seqlens), *convertNVTETensorCheck(out), total_tokens, stream);
}
Original file line number Diff line number Diff line change
@@ -533,6 +533,41 @@ void nvte_cp_thd_get_partitioned_indices(const NVTETensor &cu_seqlens, NVTETenso
int total_tokens, int world_size, int rank,
cudaStream_t stream);

/*! \brief Fused dual-chunk THD reorder for Context Parallel (gather or scatter).
*
* Computes the dual-chunk source index inline (no materialized index tensor) and copies each
* token row. scatter=0: out[gi]=inp[src(gi)] (contiguous->rank-sharded); scatter=1:
* out[src(gi)]=inp[gi] (rank-sharded->contiguous). Row size must be a multiple of 16 bytes.
*
* \param[in] inp Input THD tensor [total_tokens, ...].
* \param[in] cu_seqlens Padded cumulative sequence lengths, [batch_size + 1], int32.
* \param[out] out Output tensor, same shape/dtype as inp.
* \param[in] world_size Context-parallel size.
* \param[in] scatter 0 = gather (rank-sharded), 1 = scatter (contiguous).
* \param[in] total_tokens Total padded tokens (= inp.shape[0]).
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_cp_thd_reorder(const NVTETensor &inp, const NVTETensor &cu_seqlens, NVTETensor out,
int world_size, int scatter, int total_tokens, cudaStream_t stream);

/*! \brief Copy valid token rows of a per-step THD output/grad into an accumulator (CP AllGather).
*
* Sync-free replacement for the per-batch `.item()` slice-copy loops in the AllGather CP THD
* fwd/bwd. For each segment, copies rows [cu_seqlens_padded[b], cu_seqlens_padded[b]+valid_len_b)
* from inp to out at identical indices, leaving padded tails of out untouched. Row size must be a
* multiple of 16 bytes.
*
* \param[in] inp Per-step THD source tensor [total_tokens, ...].
* \param[in] cu_seqlens_padded Padded cumulative sequence lengths, [batch_size + 1], int32.
* \param[in] cu_seqlens Valid cumulative sequence lengths, [batch_size + 1], int32.
* \param[in,out] out Destination accumulator, same shape/dtype as inp.
* \param[in] total_tokens Total padded tokens (= inp.shape[0]).
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_cp_thd_valid_copy(const NVTETensor &inp, const NVTETensor &cu_seqlens_padded,
const NVTETensor &cu_seqlens, NVTETensor out, int total_tokens,
cudaStream_t stream);

/*! \brief Convert tensor from THD to BSHD format.
*
* \warning This API is **experimental** and subject to change.