diff --git a/examples/distributed/README.md b/examples/distributed/README.md
index 48cf85488b..f7b4cfe3d6 100644
--- a/examples/distributed/README.md
+++ b/examples/distributed/README.md
@@ -1,13 +1,11 @@
 # Distributed Examples
 
 This directory contains examples demonstrating distributed computing capabilities using TileLang.
+These examples are sorted into two categories:
+- Examples under the `nvshmem` folder, as well as the inter-node examples, depend on the NVSHMEM library for distributed communication.
+- The other examples have no external dependencies and rely only on TileScale IPC.
 
-For example,
-```
-./tilelang/distributed/launch.sh examples/distributed/example_allgather.py
-```
-
-## Prerequisites
+## `nvshmem` examples
 
 Before running the examples, you need to build NVSHMEM library for device-side code generation.
 
@@ -28,3 +26,17 @@ Then you can test python import:
 ```bash
 python -c "import pynvshmem"
 ```
+
+Finally, run the examples like this:
+```bash
+TILELANG_USE_DISTRIBUTED=1 bash ./tilelang/distributed/launch.sh examples/distributed/nvshmem/example_allgather.py
+```
+
+## IPC-based examples
+
+Simply run via Python:
+```bash
+TILELANG_USE_DISTRIBUTED=1 python examples/distributed/intranode/example_allgather_gemm_overlapped.py
+```
+
+> Tip: To suppress noisy NCCL IB logs, consider running with `NCCL_IB_DISABLE=1`.
diff --git a/examples/distributed/example_allgather_gemm.py b/examples/distributed/example_allgather_gemm.py
deleted file mode 100644
index 702f1264ad..0000000000
--- a/examples/distributed/example_allgather_gemm.py
+++ /dev/null
@@ -1,113 +0,0 @@
-import torch
-import pynvshmem
-import os
-import tilelang
-import tilelang.language as T
-from tilelang.profiler import TensorSupplyType
-from tilelang.distributed import init_distributed
-
-
-def allgather_gemm(PE_num, M, N, K, block_M, block_N, block_K, dtype="float16"):
-    accum_dtype = "float"
-
-    @T.prim_func
-    def main(
-            A: T.Buffer((M, K), dtype),
-            A_ag: T.Buffer((M * PE_num, K), dtype),
-            B: T.Buffer((K, N), dtype),
-            signal: T.Buffer((PE_num,), "uint64"),
-            C: T.Buffer((M * PE_num, N), dtype),
-    ):
-        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
-            A_shared = T.alloc_shared((block_M, block_K), dtype)
-            B_shared = T.alloc_shared((block_K, block_N), dtype)
-            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
-
-            mype = T.alloc_local([1], "int32")
-            npes = T.alloc_local([1], "int32")
-            peer = T.alloc_local([1], "int32")
-
-            mype[0] = T.get_pe()
-            npes[0] = T.get_pe_num()
-
-            T.copy(A[by * block_M, bx * block_K], A_shared)
-            T.copy(A_shared, A_ag[mype[0] * M, bx * block_K])
-            for k in T.serial(PE_num - 1):
-                peer[0] = (mype[0] + 1 + k) % npes[0]
-                T.putmem_signal_nbi_block(
-                    T.address_of(A_ag[mype[0] * M, 0]),
-                    T.address_of(A[0, 0]),
-                    block_M * block_K * 2,
-                    T.address_of(signal[k]),
-                    k + 1,
-                    9,
-                    peer[0],
-                )
-            for k in T.serial(PE_num - 1):
-                T.signal_wait_until(T.address_of(signal[k]), 0, k + 1)
-
-            for bk in T.serial(PE_num):
-                T.clear(C_local)
-                for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
-                    T.copy(A_ag[bk * M, k * block_K], A_shared)
-                    T.copy(B[k * block_K, bx * block_N], B_shared)
-                    T.gemm(A_shared, B_shared, C_local)
-                T.copy(C_local, C[bk * M, bx * block_N])
-
-    return main
-
-
-tilelang.disable_cache()
-M, N, K, block_M, block_N, block_K = 64, 64, 64, 64, 64, 64
-dtype = torch.float16
-
-RANK = int(os.environ.get("RANK", 0))
-WORLD_SIZE, RANK, LOCAL_RANK, TP_GROUP = init_distributed(return_tp_group=True)
-PE_num = WORLD_SIZE
-func = allgather_gemm(PE_num, M, N, K,
block_M, block_N, block_K) -kernel = tilelang.compile(func, out_idx=-1, pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) - -# Get CUDA Source -if RANK == 0: - print(kernel.get_kernel_source()) - -profiler = kernel.get_profiler(tensor_supply_type=TensorSupplyType.Randn) - -A_tensor = torch.arange(M * PE_num * K, dtype=dtype).cuda() * 0.001 -A_tensor = A_tensor.reshape(M * PE_num, K) -B_tensor = torch.arange(K * N, dtype=dtype).cuda() * 0.001 -B_tensor = B_tensor.reshape(K, N) - -print("A_tensor:", A_tensor) -print("B_tensor:", B_tensor) - - -def ref_program(A, B): - return A @ B - - -C_ref = ref_program(A_tensor, B_tensor) -print("C_ref:", C_ref) - -# profiler.init_distributed() -A_local = pynvshmem.nvshmem_create_tensor([M, K], dtype) -A_local[:].copy_(A_tensor[M * RANK : M * (RANK + 1), :]) - -A_ag_local = pynvshmem.nvshmem_create_tensor([M * PE_num, K], dtype) -A_ag_local.fill_(0) - -B_local = pynvshmem.nvshmem_create_tensor([K, N], dtype) -B_local[:].copy_(B_tensor) - -signal_local = pynvshmem.nvshmem_create_tensor([PE_num], torch.uint64) -signal_local.fill_(0) - -out = kernel(A_local, A_ag_local, B_local, signal_local) -print("out:", out) - -ref_cpu = C_ref.cpu() -for i in range(PE_num): - if i == RANK: - out_cpu = out.cpu() - assert torch.allclose(out_cpu, ref_cpu, atol=1e-2, rtol=1e-2) - print(f"rank {i} check passed.") diff --git a/examples/distributed/example_nvshmem.py b/examples/distributed/example_nvshmem.py deleted file mode 100644 index 8f8de69ed5..0000000000 --- a/examples/distributed/example_nvshmem.py +++ /dev/null @@ -1,58 +0,0 @@ -import tilelang -import tilelang.language as T - -import tvm - - -@tvm.register_func("tilelang_callback_cuda_postproc", override=True) -def tilelang_callback_cuda_postproc(code, _): - code = """ -#include -#include -#include -#include -#include -#include -#include -#include - -extern "C" __global__ void main_kernel(short* __restrict__ A, short* __restrict__ B); -extern "C" __global__ void __launch_bounds__(128) main_kernel(short* __restrict__ A, short* __restrict__ B) { - int mype[1]; - extern __shared__ __align__(1024) short A_shared[]; - mype[0] = nvshmem_my_pe(); - if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { - printf("mype: %d\\n", mype[0]); - } -}""" - return code - - -def dist_test(M, N, block_M, block_N, dtype="int16"): - @T.prim_func - def main( - A: T.Buffer((M, N), dtype), - B: T.Buffer((M, N), dtype), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): - A_shared = T.alloc_shared((block_M, block_N), dtype) - mype = T.alloc_local([1], "int32") - - mype[0] = T.get_pe() - T.copy(A[by * block_M, bx * block_N], A_shared) - T.copy(A_shared, B[by * block_M, bx * block_N]) - - return main - - -func = dist_test(128, 128, 128, 128) - -kernel = tilelang.compile(func, out_idx=-1) - -# Get CUDA Source -print(kernel.get_kernel_source()) - -profiler = kernel.get_profiler() -out = profiler.run_once() - -print(out) diff --git a/examples/distributed/gemm_rs_utils.py b/examples/distributed/gemm_rs_utils.py deleted file mode 100644 index 0a6634c393..0000000000 --- a/examples/distributed/gemm_rs_utils.py +++ /dev/null @@ -1,240 +0,0 @@ -import dataclasses -from typing import List - -import torch - -import pynvshmem - -SIGNAL_DTYPE = torch.uint64 - - -class BarrierAllContext: - """ - You may use this to barrier all ranks in global, or just in intra-node team. - - NOTE: nvshmem_barrier_all is slower for intra-node only. 
- """ - - def __init__(self, is_intra_node): - self.is_intra_node = is_intra_node - # TODO: implement these for intra-node - # if self.is_intra_node: - # self.rank = pynvshmem.nvshmem_my_pe() - # self.local_rank = pynvshmem.nvshmem_team_my_pe(pynvshmem.Team.NODE) - # self.num_local_ranks = pynvshmem.nvshmem_team_n_pes(pynvshmem.Team.NODE) - # self.symm_barrier = pynvshmem.nvshmem_create_tensor((1, ), torch.int32) - # self.symm_barrier.fill_(0) - # pynvshmem.nvshmem_barrier_all() - - -@dataclasses.dataclass -class ReduceScatter2DContext: - max_M: int - N: int - rank: int - world_size: int - local_world_size: int - dtype: torch.dtype - overlap_with_gemm: bool - - # comm buffer - scatter_bufs: List[torch.Tensor] - rs_per_node_bufs: List[torch.Tensor] - p2p_bufs: List[torch.Tensor] - - # barrier bufs - signal_bufs: List[torch.Tensor] # need reset: signal_buf = scatter_signal | rs_per_node_signal - - # intra-node barrier - barrier: BarrierAllContext - - # stream - reduction_stream: torch.cuda.Stream - p2p_stream: torch.cuda.Stream - - # sms - num_sync_sms: int - num_p2p_sms: int - num_reduction_sms: int - - # preprocess to reduce cpu overhead - # comm barriers - scatter_signal_bufs: List[torch.Tensor] = dataclasses.field(init=False) - rs_per_node_signal_bufs: List[torch.Tensor] = dataclasses.field(init=False) - - local_rank: int = dataclasses.field(init=False) - node_id: int = dataclasses.field(init=False) - nnodes: int = dataclasses.field(init=False) - - scatter_signal_buf_list_for_each_node: List[torch.Tensor] = dataclasses.field(init=False) - - def __post_init__(self): - self.local_rank = self.rank % self.local_world_size - self.node_id = self.rank // self.local_world_size - assert self.world_size % self.local_world_size == 0 - assert self.max_M % self.world_size == 0 - assert len(self.signal_bufs) == self.local_world_size - self.nnodes = self.world_size // self.local_world_size - self.scatter_signal_buf_list_for_each_node = [] - for buf in self.signal_bufs: - assert buf.shape[0] >= 2 * self.world_size - - self.scatter_signal_bufs = [buf[: self.world_size] for buf in self.signal_bufs] - self.rs_per_node_signal_bufs = [buf[self.world_size : self.world_size * 2] for buf in self.signal_bufs] - - for node_id in range(self.nnodes): - self.scatter_signal_buf_list_for_each_node.append( - self.scatter_signal_bufs[self.local_rank][node_id * self.local_world_size : (node_id + 1) * self.local_world_size] - ) - - def reset_barriers(self) -> int: - # self.scatter_signal_bufs[self.local_rank].fill_(0) - # self.rs_per_node_signal_bufs[self.local_rank].fill_(0) - self.signal_bufs[self.local_rank].fill_(0) - - def get_scatter_bufs_and_signal_for_each_node(self, input, node_id): - M = input.shape[0] - M_per_rank = M // self.world_size - M_per_node = M_per_rank * self.local_world_size - M_start = node_id * M_per_node - M_end = M_start + M_per_node - scatter_bufs_intra_node = [self.scatter_bufs[i][M_start:M_end] for i in range(self.local_world_size)] - return scatter_bufs_intra_node, self.scatter_signal_buf_list_for_each_node[node_id] - - @property - def rs_per_node_buf(self) -> torch.Tensor: - return self.rs_per_node_bufs[self.local_rank] - - @property - def rs_per_node_signal_buf(self) -> torch.Tensor: - return self.rs_per_node_signal_bufs[self.local_rank] - - @property - def p2p_buf(self) -> torch.Tensor: - return self.p2p_bufs[self.local_rank] - - @property - def num_rs_sms(self) -> int: - if self.nnodes > 1: - return self.num_sync_sms + self.num_p2p_sms + self.num_reduction_sms - else: - # for intra node 
rs, no need sm. - return 0 - - @property - def scatter_signal_buf(self) -> torch.Tensor: - return self.scatter_signal_bufs[self.local_rank] - - -def create_reduce_scater_2d_ctx( - max_M, N, rank, world_size, local_world_size, dtype, overlap_with_gemm=True, num_reduction_sms=15 -) -> ReduceScatter2DContext: - """ - for num_reduction_sms: tunable param, 16 are enough for H800 - For H800, we overlap local reduce and inter-node p2p with intra-node scatter. - The reduction kernel bandwidth is not a bottleneck if it exceeds 450GB, so only a few SMs are needed. - For machines with higher intra_node bandwidth(e.g. H100), we may need to increase the number of SMs or redesign overlapping. - """ - assert world_size % local_world_size == 0 - assert max_M % world_size == 0 - - scatter_bufs = pynvshmem.nvshmem_create_tensor_list_intra_node([max_M, N], dtype) - - rs_per_node_bufs = pynvshmem.nvshmem_create_tensor_list_intra_node([max_M // local_world_size, N], dtype) - - p2p_bufs = pynvshmem.nvshmem_create_tensor_list_intra_node([max_M // local_world_size, N], dtype) - - # signal_buf: scatter_signal | rs_per_node_signal - num_signal_bufs = 2 - signal_bufs = pynvshmem.nvshmem_create_tensor_list_intra_node( - [ - world_size * num_signal_bufs, - ], - SIGNAL_DTYPE, - ) - - # TODO: implement barrier_all_on_stream - # barrier_all_on_stream(None, torch.cuda.current_stream()) - - p2p_stream: torch.cuda.Stream = torch.cuda.Stream(priority=-1) - reduction_stream: torch.cuda.Stream = torch.cuda.Stream(priority=-1) - - num_sync_sms = 0 - num_p2p_sms = 1 - ctx = ReduceScatter2DContext( - max_M=max_M, - N=N, - rank=rank, - world_size=world_size, - local_world_size=local_world_size, - dtype=dtype, - overlap_with_gemm=overlap_with_gemm, - scatter_bufs=scatter_bufs, - rs_per_node_bufs=rs_per_node_bufs, - p2p_bufs=p2p_bufs, - signal_bufs=signal_bufs, - barrier=BarrierAllContext(True), - reduction_stream=reduction_stream, - p2p_stream=p2p_stream, - num_sync_sms=num_sync_sms, - num_p2p_sms=num_p2p_sms, - num_reduction_sms=num_reduction_sms, - ) - return ctx - - -################### context ################### -@dataclasses.dataclass -class GEMMReduceScatterTensorParallelContext: - rs_ctx: ReduceScatter2DContext - output_dtype: torch.dtype - - # gemm bufs (symm address) - gemm_out_bufs: List[torch.Tensor] - - # stream - rs_stream: torch.cuda.Stream - - # gemm kernel config - num_gemm_sms: int - BLOCK_M: int = 128 - BLOCK_N: int = 256 - BLOCK_K: int = 64 - GROUP_M: int = 8 - stages: int = 3 - - def update(self, rs_stream, output_dtype=None, BLOCK_M=128, BLOCK_N=256, BLOCK_K=64, GROUP_M=8, stages=3): - self.rs_stream = rs_stream - self.output_dtype = output_dtype - self.BLOCK_M = BLOCK_M - self.BLOCK_N = BLOCK_N - self.BLOCK_K = BLOCK_K - self.GROUP_M = GROUP_M - self.stages = stages - - def get_gemm_out_buf(self, input): - M, _ = input.shape - local_rank = self.rs_ctx.local_rank - return self.gemm_out_bufs[local_rank][:M] - - -def create_gemm_rs_context( - max_M, N, rank, world_size, local_world_size, output_dtype, rs_stream, BLOCK_M=128, BLOCK_N=256, BLOCK_K=64, GROUP_M=8, stages=3 -) -> GEMMReduceScatterTensorParallelContext: - rs_ctx = create_reduce_scater_2d_ctx(max_M, N, rank, world_size, local_world_size, output_dtype, overlap_with_gemm=True) - NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count - num_gemm_sms = NUM_SMS - rs_ctx.num_rs_sms - gemm_out_bufs = pynvshmem.nvshmem_create_tensor_list_intra_node([max_M, N], output_dtype) - ctx = GEMMReduceScatterTensorParallelContext( - rs_ctx=rs_ctx, - 
output_dtype=output_dtype, - gemm_out_bufs=gemm_out_bufs, - rs_stream=rs_stream, - num_gemm_sms=num_gemm_sms, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - BLOCK_K=BLOCK_K, - GROUP_M=GROUP_M, - stages=stages, - ) - return ctx diff --git a/examples/distributed/example_allgather_gemm_overlapped.py b/examples/distributed/intranode/example_allgather_gemm_overlapped.py similarity index 100% rename from examples/distributed/example_allgather_gemm_overlapped.py rename to examples/distributed/intranode/example_allgather_gemm_overlapped.py diff --git a/examples/distributed/example_gemm_rs_overlapped.py b/examples/distributed/intranode/example_gemm_rs_overlapped.py similarity index 99% rename from examples/distributed/example_gemm_rs_overlapped.py rename to examples/distributed/intranode/example_gemm_rs_overlapped.py index 3ce6d35357..db54f7877e 100644 --- a/examples/distributed/example_gemm_rs_overlapped.py +++ b/examples/distributed/intranode/example_gemm_rs_overlapped.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import tilelang import tilelang.language as T import argparse diff --git a/examples/distributed/intranode/example_post_attn_all2all_transpose_intranode.py b/examples/distributed/intranode/example_post_attn_all2all_transpose_intranode.py new file mode 100644 index 0000000000..2a92ac0546 --- /dev/null +++ b/examples/distributed/intranode/example_post_attn_all2all_transpose_intranode.py @@ -0,0 +1,143 @@ +""" +Intranode post-attention all-to-all (transpose) using tilescale IPC API. + +Input: [B, H_PE, S, D] — partial heads, full sequence per rank +Output: [B, S_PE, NH, D] — partial sequence, full heads per rank + +Rank r sends src[:, :, p*S_PE:(p+1)*S_PE, :] (shape [B, H_PE, S_PE, D]) +to rank p's dst[:, :, r*H_PE:(r+1)*H_PE, :] (shape [B, S_PE, H_PE, D]) +after transposing dims 1 and 2. 
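+
+For example, with PE_num = 2 (so H_PE = NH // 2 and S_PE = S // 2) the four
+transfers are:
+
+    rank 0 -> rank 0: src[:, :, 0:S_PE, :]  -> dst[:, :, 0:H_PE, :]
+    rank 0 -> rank 1: src[:, :, S_PE:S, :]  -> dst[:, :, 0:H_PE, :]
+    rank 1 -> rank 0: src[:, :, 0:S_PE, :]  -> dst[:, :, H_PE:NH, :]
+    rank 1 -> rank 1: src[:, :, S_PE:S, :]  -> dst[:, :, H_PE:NH, :]
+
+with every transferred block going from [B, H_PE, S_PE, D] on the sender to
+[B, S_PE, H_PE, D] on the receiver, matching `torch_reference` below.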
+""" + +import argparse +import torch +import torch.distributed as dist +import torch.multiprocessing +import tilelang +import tilelang.language as T +from tilelang.distributed import init_dist, perf_fn + +dtype_map = { + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "float32": torch.float32, +} + + +def torch_reference(src, group, H_PE, S_PE): + """dist.all_to_all reference implementation.""" + PE_num = dist.get_world_size(group) + B, _, S, D = src.shape + NH = H_PE * PE_num + + # send [B, H_PE, S_PE, D] to each rank + input_list = [src[:, :, p * S_PE : (p + 1) * S_PE, :].contiguous() for p in range(PE_num)] + output_list = [torch.empty(B, H_PE, S_PE, D, dtype=src.dtype, device=src.device) for _ in range(PE_num)] + dist.all_to_all(output_list, input_list, group=group) + + result = torch.empty(B, S_PE, NH, D, dtype=src.dtype, device=src.device) + for r in range(PE_num): + # output_list[r] is [B, H_PE, S_PE, D] from rank r; transpose to [B, S_PE, H_PE, D] + result[:, :, r * H_PE : (r + 1) * H_PE, :] = output_list[r].transpose(1, 2) + return result + + +def kernel_post_attn_all2all_transpose(PE_num, B, NH, S_PE, D, dtype="float16"): + H_PE = NH // PE_num + S = S_PE * PE_num + NUM_BLOCKS_X = B * S_PE + + @T.prim_func + def main( + data_src: T.Tensor((B, H_PE, S, D), dtype), + data_dst: T.Tensor((B, S_PE, NH, D), dtype), + ): + with T.Kernel(NUM_BLOCKS_X, PE_num, threads=128) as (bx, target_pe): + rank = T.alloc_local([1], "uint64") + rank[0] = T.get_rank() + + batch_idx = bx // S_PE + seq_idx = bx % S_PE + src_seq_idx = target_pe * S_PE + seq_idx + + for head_idx in T.serial(H_PE): + T.put_block( + src=T.address_of(data_src[batch_idx, head_idx, src_seq_idx, 0]), + dst=T.address_of(data_dst[batch_idx, seq_idx, rank[0] * H_PE + head_idx, 0]), + size=D, + dst_pe=target_pe, + ) + + return main + + +def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): + dtype = dtype_map[args.dtype] + device = "cuda" + B, NH, S, D = args.batch_size, args.num_heads, args.seq_len, args.head_dim + PE_num = num_local_ranks + assert S % PE_num == 0 and NH % PE_num == 0 + S_PE = S // PE_num + H_PE = NH // PE_num + + rank, num_ranks, group = init_dist(local_rank, num_local_ranks) + allocator = tilelang.get_allocator( + size=2**30, + device=device, + is_distributed=True, + local_rank=local_rank, + num_local_ranks=num_local_ranks, + group=group, + ) + + func = kernel_post_attn_all2all_transpose(PE_num, B, NH, S_PE, D, args.dtype) + kernel = tilelang.compile(func) + kernel.initialize(allocator=allocator) + + if local_rank == 0 and args.print_source: + print(kernel.get_kernel_source()) + + src_bufs = tilelang.tensor((B, H_PE, S, D), dtype, allocator=allocator, return_peers=True) + dst_bufs = tilelang.tensor((B, S_PE, NH, D), dtype, allocator=allocator, return_peers=True) + + src_bufs[local_rank].normal_(mean=0.0, std=0.5) + dst_bufs[local_rank].zero_() + dist.barrier(group) + + torch_out = torch_reference(src_bufs[local_rank], group, H_PE, S_PE) + dist.barrier(group) + + def ipc_all2all(): + kernel(src_bufs[local_rank], dst_bufs[local_rank]) + torch.cuda.synchronize() + dist.barrier(group) + + ipc_all2all() + + result = dst_bufs[local_rank].clone() + if torch.allclose(result, torch_out, atol=1e-3, rtol=1e-3): + print(f"rank {local_rank} check passed. \u2705") + else: + diff = (result - torch_out).abs() + print(f"rank {local_rank} check FAILED. 
max_diff={diff.max():.5f}") + + t = perf_fn(ipc_all2all, warmup=args.warmup, rep=args.repeat) + print(f"rank {local_rank} avg time: {t:.3f} ms") + + dist.destroy_process_group() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--num-processes", type=int, default=2) + parser.add_argument("--batch_size", type=int, default=2) + parser.add_argument("--num_heads", type=int, default=16) + parser.add_argument("--seq_len", type=int, default=256) + parser.add_argument("--head_dim", type=int, default=64) + parser.add_argument("--dtype", type=str, default="float16", choices=list(dtype_map)) + parser.add_argument("--warmup", type=int, default=3) + parser.add_argument("--repeat", type=int, default=5) + parser.add_argument("--print_source", action="store_true") + args = parser.parse_args() + + torch.multiprocessing.spawn(main, args=(args.num_processes, args), nprocs=args.num_processes) diff --git a/examples/distributed/intranode/example_pre_attn_all2all_intranode.py b/examples/distributed/intranode/example_pre_attn_all2all_intranode.py new file mode 100644 index 0000000000..03eb86ed27 --- /dev/null +++ b/examples/distributed/intranode/example_pre_attn_all2all_intranode.py @@ -0,0 +1,139 @@ +""" +Intranode pre-attention all-to-all using tilescale IPC API. + +Input: [B, NH, S_PE, D] — full heads, partial sequence per rank +Output: [B, H_PE, S, D] — partial heads, full sequence per rank + +Rank r sends src[:, p*H_PE:(p+1)*H_PE, :, :] to rank p's +dst[:, :, r*S_PE:(r+1)*S_PE, :]. +""" + +import argparse +import torch +import torch.distributed as dist +import torch.multiprocessing +import tilelang +import tilelang.language as T +from tilelang.distributed import init_dist, perf_fn + +dtype_map = { + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "float32": torch.float32, +} + + +def torch_reference(src, group, H_PE, S_PE): + """dist.all_to_all reference implementation.""" + PE_num = dist.get_world_size(group) + B, NH, _, D = src.shape + + input_list = [src[:, p * H_PE : (p + 1) * H_PE, :, :].contiguous() for p in range(PE_num)] + output_list = [torch.empty(B, H_PE, S_PE, D, dtype=src.dtype, device=src.device) for _ in range(PE_num)] + dist.all_to_all(output_list, input_list, group=group) + + S = S_PE * PE_num + result = torch.empty(B, H_PE, S, D, dtype=src.dtype, device=src.device) + for r in range(PE_num): + result[:, :, r * S_PE : (r + 1) * S_PE, :] = output_list[r] + return result + + +def kernel_pre_attn_all2all(PE_num, B, NH, S_PE, D, dtype="float16"): + H_PE = NH // PE_num + S = S_PE * PE_num + NUM_BLOCKS_X = B * H_PE + + @T.prim_func + def main( + data_src: T.Tensor((B, NH, S_PE, D), dtype), + data_dst: T.Tensor((B, H_PE, S, D), dtype), + ): + with T.Kernel(NUM_BLOCKS_X, PE_num, threads=128) as (bx, target_pe): + rank = T.alloc_local([1], "uint64") + rank[0] = T.get_rank() + + batch_idx = bx // H_PE + head_idx = bx % H_PE + src_head_idx = target_pe * H_PE + head_idx + + T.put_block( + src=T.address_of(data_src[batch_idx, src_head_idx, 0, 0]), + dst=T.address_of(data_dst[batch_idx, head_idx, rank[0] * S_PE, 0]), + size=S_PE * D, + dst_pe=target_pe, + ) + + return main + + +def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): + dtype = dtype_map[args.dtype] + device = "cuda" + B, NH, S, D = args.batch_size, args.num_heads, args.seq_len, args.head_dim + PE_num = num_local_ranks + assert S % PE_num == 0 and NH % PE_num == 0 + S_PE = S // PE_num + H_PE = NH // PE_num + + rank, num_ranks, group = init_dist(local_rank, 
num_local_ranks) + allocator = tilelang.get_allocator( + size=2**30, + device=device, + is_distributed=True, + local_rank=local_rank, + num_local_ranks=num_local_ranks, + group=group, + ) + + func = kernel_pre_attn_all2all(PE_num, B, NH, S_PE, D, args.dtype) + kernel = tilelang.compile(func) + kernel.initialize(allocator=allocator) + + if local_rank == 0 and args.print_source: + print(kernel.get_kernel_source()) + + src_bufs = tilelang.tensor((B, NH, S_PE, D), dtype, allocator=allocator, return_peers=True) + dst_bufs = tilelang.tensor((B, H_PE, S, D), dtype, allocator=allocator, return_peers=True) + + src_bufs[local_rank].normal_(mean=0.0, std=0.5) + dst_bufs[local_rank].zero_() + dist.barrier(group) + + torch_out = torch_reference(src_bufs[local_rank], group, H_PE, S_PE) + dist.barrier(group) + + def ipc_all2all(): + kernel(src_bufs[local_rank], dst_bufs[local_rank]) + torch.cuda.synchronize() + dist.barrier(group) + + ipc_all2all() + + result = dst_bufs[local_rank].clone() + if torch.allclose(result, torch_out, atol=1e-3, rtol=1e-3): + print(f"rank {local_rank} check passed.") + else: + diff = (result - torch_out).abs() + print(f"rank {local_rank} check FAILED. max_diff={diff.max():.5f}") + + t = perf_fn(ipc_all2all, warmup=args.warmup, rep=args.repeat) + print(f"rank {local_rank} avg time: {t:.3f} ms") + + dist.destroy_process_group() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--num-processes", type=int, default=2) + parser.add_argument("--batch_size", type=int, default=2) + parser.add_argument("--num_heads", type=int, default=16) + parser.add_argument("--seq_len", type=int, default=256) + parser.add_argument("--head_dim", type=int, default=64) + parser.add_argument("--dtype", type=str, default="float16", choices=list(dtype_map)) + parser.add_argument("--warmup", type=int, default=3) + parser.add_argument("--repeat", type=int, default=5) + parser.add_argument("--print_source", action="store_true") + args = parser.parse_args() + + torch.multiprocessing.spawn(main, args=(args.num_processes, args), nprocs=args.num_processes) diff --git a/examples/distributed/intranode/example_pre_attn_all2all_transpose_intranode.py b/examples/distributed/intranode/example_pre_attn_all2all_transpose_intranode.py new file mode 100644 index 0000000000..8c57407d73 --- /dev/null +++ b/examples/distributed/intranode/example_pre_attn_all2all_transpose_intranode.py @@ -0,0 +1,143 @@ +""" +Intranode pre-attention all-to-all (transpose) using tilescale IPC API. + +Input: [B, S_PE, NH, D] — partial sequence, full heads per rank +Output: [B, H_PE, S, D] — partial heads, full sequence per rank + +Rank r sends src[:, :, p*H_PE:(p+1)*H_PE, :] (shape [B, S_PE, H_PE, D]) +to rank p's dst[:, :, r*S_PE:(r+1)*S_PE, :] (shape [B, H_PE, S_PE, D]) +after transposing dims 1 and 2. 
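+
+For example, with PE_num = 2 (so H_PE = NH // 2 and S_PE = S // 2) the four
+transfers are:
+
+    rank 0 -> rank 0: src[:, :, 0:H_PE, :]  -> dst[:, :, 0:S_PE, :]
+    rank 0 -> rank 1: src[:, :, H_PE:NH, :] -> dst[:, :, 0:S_PE, :]
+    rank 1 -> rank 0: src[:, :, 0:H_PE, :]  -> dst[:, :, S_PE:S, :]
+    rank 1 -> rank 1: src[:, :, H_PE:NH, :] -> dst[:, :, S_PE:S, :]
+
+with every transferred block going from [B, S_PE, H_PE, D] on the sender to
+[B, H_PE, S_PE, D] on the receiver, matching `torch_reference` below.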
+""" + +import argparse +import torch +import torch.distributed as dist +import torch.multiprocessing +import tilelang +import tilelang.language as T +from tilelang.distributed import init_dist, perf_fn + +dtype_map = { + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "float32": torch.float32, +} + + +def torch_reference(src, group, H_PE, S_PE): + """dist.all_to_all reference implementation.""" + PE_num = dist.get_world_size(group) + B, _, NH, D = src.shape + + # send [B, S_PE, H_PE, D] to each rank + input_list = [src[:, :, p * H_PE : (p + 1) * H_PE, :].contiguous() for p in range(PE_num)] + output_list = [torch.empty(B, S_PE, H_PE, D, dtype=src.dtype, device=src.device) for _ in range(PE_num)] + dist.all_to_all(output_list, input_list, group=group) + + S = S_PE * PE_num + result = torch.empty(B, H_PE, S, D, dtype=src.dtype, device=src.device) + for r in range(PE_num): + # output_list[r] is [B, S_PE, H_PE, D] from rank r; transpose to [B, H_PE, S_PE, D] + result[:, :, r * S_PE : (r + 1) * S_PE, :] = output_list[r].transpose(1, 2) + return result + + +def kernel_pre_attn_all2all_transpose(PE_num, B, NH, S_PE, D, dtype="float16"): + H_PE = NH // PE_num + S = S_PE * PE_num + NUM_BLOCKS_X = B * H_PE + + @T.prim_func + def main( + data_src: T.Tensor((B, S_PE, NH, D), dtype), + data_dst: T.Tensor((B, H_PE, S, D), dtype), + ): + with T.Kernel(NUM_BLOCKS_X, PE_num, threads=128) as (bx, target_pe): + rank = T.alloc_local([1], "uint64") + rank[0] = T.get_rank() + + batch_idx = bx // H_PE + head_idx = bx % H_PE + src_head_idx = target_pe * H_PE + head_idx + + for seq_idx in T.serial(S_PE): + T.put_block( + src=T.address_of(data_src[batch_idx, seq_idx, src_head_idx, 0]), + dst=T.address_of(data_dst[batch_idx, head_idx, rank[0] * S_PE + seq_idx, 0]), + size=D, + dst_pe=target_pe, + ) + + return main + + +def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): + dtype = dtype_map[args.dtype] + device = "cuda" + B, NH, S, D = args.batch_size, args.num_heads, args.seq_len, args.head_dim + PE_num = num_local_ranks + assert S % PE_num == 0 and NH % PE_num == 0 + S_PE = S // PE_num + H_PE = NH // PE_num + + rank, num_ranks, group = init_dist(local_rank, num_local_ranks) + allocator = tilelang.get_allocator( + size=2**30, + device=device, + is_distributed=True, + local_rank=local_rank, + num_local_ranks=num_local_ranks, + group=group, + ) + + func = kernel_pre_attn_all2all_transpose(PE_num, B, NH, S_PE, D, args.dtype) + kernel = tilelang.compile(func) + kernel.initialize(allocator=allocator) + + if local_rank == 0 and args.print_source: + print(kernel.get_kernel_source()) + + src_bufs = tilelang.tensor((B, S_PE, NH, D), dtype, allocator=allocator, return_peers=True) + dst_bufs = tilelang.tensor((B, H_PE, S, D), dtype, allocator=allocator, return_peers=True) + + src_bufs[local_rank].normal_(mean=0.0, std=0.5) + dst_bufs[local_rank].zero_() + dist.barrier(group) + + torch_out = torch_reference(src_bufs[local_rank], group, H_PE, S_PE) + dist.barrier(group) + + def ipc_all2all(): + kernel(src_bufs[local_rank], dst_bufs[local_rank]) + torch.cuda.synchronize() + dist.barrier(group) + + ipc_all2all() + + result = dst_bufs[local_rank].clone() + if torch.allclose(result, torch_out, atol=1e-3, rtol=1e-3): + print(f"rank {local_rank} check passed.") + else: + diff = (result - torch_out).abs() + print(f"rank {local_rank} check FAILED. 
max_diff={diff.max():.5f}") + + t = perf_fn(ipc_all2all, warmup=args.warmup, rep=args.repeat) + print(f"rank {local_rank} avg time: {t:.3f} ms") + + dist.destroy_process_group() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--num-processes", type=int, default=2) + parser.add_argument("--batch_size", type=int, default=2) + parser.add_argument("--num_heads", type=int, default=16) + parser.add_argument("--seq_len", type=int, default=256) + parser.add_argument("--head_dim", type=int, default=64) + parser.add_argument("--dtype", type=str, default="float16", choices=list(dtype_map)) + parser.add_argument("--warmup", type=int, default=3) + parser.add_argument("--repeat", type=int, default=5) + parser.add_argument("--print_source", action="store_true") + args = parser.parse_args() + + torch.multiprocessing.spawn(main, args=(args.num_processes, args), nprocs=args.num_processes) diff --git a/examples/distributed/intranode/example_reduce_scatter.py b/examples/distributed/intranode/example_reduce_scatter.py new file mode 100644 index 0000000000..892cf530f7 --- /dev/null +++ b/examples/distributed/intranode/example_reduce_scatter.py @@ -0,0 +1,68 @@ +import argparse +import torch +import torch.distributed as dist +import torch.multiprocessing +import tilelang +from tilelang.distributed import init_dist, perf_fn +from reduce_scatter import reduce_scatter_2d_op, create_reduce_scater_2d_ctx + + +def torch_reduce_scatter( + pg: torch.distributed.ProcessGroup, + input: torch.Tensor, + num_local_ranks: int, +) -> torch.Tensor: + M, N = input.shape + output = torch.empty((M // num_local_ranks, N), dtype=input.dtype, device=input.device) + torch.distributed.reduce_scatter_tensor(output, input, group=pg) + return output + + +def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): + dtype = torch.float16 + M = args.M + N = args.N + M_per_rank = M // num_local_ranks + + rank, num_ranks, group = init_dist(local_rank, num_local_ranks) + assert rank == local_rank and num_ranks == num_local_ranks, "only support single node for now" + + allocator = tilelang.get_allocator( + size=2**30, device="cuda", is_distributed=True, local_rank=local_rank, num_local_ranks=num_local_ranks, group=group + ) + + input_tensor = tilelang.tensor((M, N), dtype, allocator=allocator).normal_() / 10 + output_tensor = tilelang.tensor((M_per_rank, N), dtype, allocator=allocator) + + ctx = create_reduce_scater_2d_ctx(M, N, local_rank, num_local_ranks, num_local_ranks, dtype, allocator, overlap_with_gemm=False) + + dist.barrier() + + tilelang_out = reduce_scatter_2d_op(input_tensor, ctx, output_tensor) + torch_out = torch_reduce_scatter(group, input_tensor, num_local_ranks) + + atol = 1e-2 + rtol = 1e-2 + if torch.allclose(torch_out, tilelang_out, atol=atol, rtol=rtol): + print(f"rank {local_rank} check passed. ✅") + else: + print(f"rank {local_rank} check failed. 
❌") + print(f"max diff: {(torch_out - tilelang_out).abs().max()}") + + tl_t = perf_fn(lambda: reduce_scatter_2d_op(input_tensor, ctx, output_tensor), warmup=5, rep=10) + input_bytes = M * N * torch.finfo(dtype).bits // 8 + algbw = input_bytes / tl_t / 1e6 # GB/s + print(f"rank {local_rank} tilelang reduce_scatter time: {tl_t:.2f} ms, algbw: {algbw:.2f} GB/s") + + dist.destroy_process_group() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--num-processes", type=int, default=2, help="Number of processes to spawn (default: 2)") + parser.add_argument("--M", type=int, default=8192, help="M dimension") + parser.add_argument("--N", type=int, default=8192, help="N dimension") + args = parser.parse_args() + num_processes = args.num_processes + + torch.multiprocessing.spawn(main, args=(num_processes, args), nprocs=num_processes) diff --git a/examples/distributed/example_sp_ag_attention_intra_node.py b/examples/distributed/intranode/example_sp_ag_attention_intra_node.py similarity index 100% rename from examples/distributed/example_sp_ag_attention_intra_node.py rename to examples/distributed/intranode/example_sp_ag_attention_intra_node.py diff --git a/examples/distributed/reduce_scatter.py b/examples/distributed/intranode/reduce_scatter.py similarity index 99% rename from examples/distributed/reduce_scatter.py rename to examples/distributed/intranode/reduce_scatter.py index 8f39a7a95b..3da29e1434 100644 --- a/examples/distributed/reduce_scatter.py +++ b/examples/distributed/intranode/reduce_scatter.py @@ -16,8 +16,6 @@ import torch.distributed as dist import tilelang.language as T -tilelang.disable_cache() - @dataclasses.dataclass class ReduceScatter2DContext: diff --git a/examples/distributed/sp_ag_attention_intra_node.py b/examples/distributed/intranode/sp_ag_attention_intra_node.py similarity index 100% rename from examples/distributed/sp_ag_attention_intra_node.py rename to examples/distributed/intranode/sp_ag_attention_intra_node.py diff --git a/examples/distributed/intranode/test_intranode.py b/examples/distributed/intranode/test_intranode.py new file mode 100644 index 0000000000..c5565eb652 --- /dev/null +++ b/examples/distributed/intranode/test_intranode.py @@ -0,0 +1,60 @@ +import torch +import tilelang +import tilelang.testing + +import example_allgather_gemm_overlapped +import example_reduce_scatter +import example_gemm_rs_overlapped +import example_sp_ag_attention_intra_node +import example_pre_attn_all2all_intranode +import example_pre_attn_all2all_transpose_intranode +import example_post_attn_all2all_transpose_intranode + + +@tilelang.testing.requires_distributed +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_eq(9, 0) +def test_example_allgather_gemm_overlapped(): + torch.multiprocessing.spawn(example_allgather_gemm_overlapped.main, args=(2, None), nprocs=2) + + +@tilelang.testing.requires_distributed +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_eq(9, 0) +def test_example_reduce_scatter(): + torch.multiprocessing.spawn(example_reduce_scatter.main, args=(2, None), nprocs=2) + + +@tilelang.testing.requires_distributed +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_eq(9, 0) +def test_example_gemm_rs_overlapped(): + torch.multiprocessing.spawn(example_gemm_rs_overlapped.main, args=(2, None), nprocs=2) + + +@tilelang.testing.requires_distributed +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_eq(9, 0) +def 
test_example_sp_ag_attention_intra_node(): + torch.multiprocessing.spawn(example_sp_ag_attention_intra_node.main, args=(2, None), nprocs=2) + + +@tilelang.testing.requires_distributed +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_eq(9, 0) +def test_example_pre_attn_all2all_intranode(): + torch.multiprocessing.spawn(example_pre_attn_all2all_intranode.main, args=(2, None), nprocs=2) + + +@tilelang.testing.requires_distributed +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_eq(9, 0) +def test_example_pre_attn_all2all_transpose_intranode(): + torch.multiprocessing.spawn(example_pre_attn_all2all_transpose_intranode.main, args=(2, None), nprocs=2) + + +@tilelang.testing.requires_distributed +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_eq(9, 0) +def test_example_post_attn_all2all_transpose_intranode(): + torch.multiprocessing.spawn(example_post_attn_all2all_transpose_intranode.main, args=(2, None), nprocs=2) diff --git a/examples/distributed/example_all_to_all.py b/examples/distributed/nvshmem/example_all_to_all.py similarity index 99% rename from examples/distributed/example_all_to_all.py rename to examples/distributed/nvshmem/example_all_to_all.py index dd0157c893..128f9f670a 100644 --- a/examples/distributed/example_all_to_all.py +++ b/examples/distributed/nvshmem/example_all_to_all.py @@ -7,8 +7,6 @@ import argparse import random -tilelang.disable_cache() - def all_to_all(PE_num, TOKEN_NUM, TOPK, HIDDEN, EXPERT_NUM, dtype="float16"): EXPERTS_PER_RANK = EXPERT_NUM // PE_num diff --git a/examples/distributed/example_allgather.py b/examples/distributed/nvshmem/example_allgather.py similarity index 100% rename from examples/distributed/example_allgather.py rename to examples/distributed/nvshmem/example_allgather.py diff --git a/examples/distributed/example_cannon.py b/examples/distributed/nvshmem/example_cannon.py similarity index 99% rename from examples/distributed/example_cannon.py rename to examples/distributed/nvshmem/example_cannon.py index ad25a41e7a..ef0da0d5af 100644 --- a/examples/distributed/example_cannon.py +++ b/examples/distributed/nvshmem/example_cannon.py @@ -7,8 +7,6 @@ import math import argparse -tilelang.disable_cache() - def cannon(MESH, M, N, K, block_M, block_N, block_K, dtype="float16", specialize=False): M_local = T.ceildiv(M, MESH) diff --git a/examples/distributed/nvshmem/example_nvshmem.py b/examples/distributed/nvshmem/example_nvshmem.py new file mode 100644 index 0000000000..41cc374025 --- /dev/null +++ b/examples/distributed/nvshmem/example_nvshmem.py @@ -0,0 +1,32 @@ +import tilelang +import tilelang.language as T + + +def dist_test(M, N, block_M, block_N, dtype="int16"): + @T.prim_func + def main( + A: T.Buffer((M, N), dtype), + B: T.Buffer((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_N), dtype) + mype = T.alloc_local([1], "int32") + + mype[0] = T.get_pe() + T.copy(A[by * block_M, bx * block_N], A_shared) + T.copy(A_shared, B[by * block_M, bx * block_N]) + + return main + + +func = dist_test(128, 128, 128, 128) + +kernel = tilelang.compile(func, out_idx=-1) + +# Get CUDA Source +print(kernel.get_kernel_source()) + +profiler = kernel.get_profiler() +out = profiler.run_once() + +print(out) diff --git a/examples/distributed/example_overlapping_allgather.py b/examples/distributed/nvshmem/example_overlapping_allgather.py similarity index 100% rename from 
examples/distributed/example_overlapping_allgather.py rename to examples/distributed/nvshmem/example_overlapping_allgather.py diff --git a/examples/distributed/example_post_attn_all2all_transpose.py b/examples/distributed/nvshmem/example_post_attn_all2all_transpose.py similarity index 100% rename from examples/distributed/example_post_attn_all2all_transpose.py rename to examples/distributed/nvshmem/example_post_attn_all2all_transpose.py diff --git a/examples/distributed/example_pre_attn_all2all.py b/examples/distributed/nvshmem/example_pre_attn_all2all.py similarity index 100% rename from examples/distributed/example_pre_attn_all2all.py rename to examples/distributed/nvshmem/example_pre_attn_all2all.py diff --git a/examples/distributed/example_pre_attn_all2all_transpose.py b/examples/distributed/nvshmem/example_pre_attn_all2all_transpose.py similarity index 100% rename from examples/distributed/example_pre_attn_all2all_transpose.py rename to examples/distributed/nvshmem/example_pre_attn_all2all_transpose.py diff --git a/examples/distributed/example_simple_shift.py b/examples/distributed/nvshmem/example_simple_shift.py similarity index 100% rename from examples/distributed/example_simple_shift.py rename to examples/distributed/nvshmem/example_simple_shift.py diff --git a/examples/distributed/example_summa.py b/examples/distributed/nvshmem/example_summa.py similarity index 99% rename from examples/distributed/example_summa.py rename to examples/distributed/nvshmem/example_summa.py index 640a31de6b..216d145d24 100644 --- a/examples/distributed/example_summa.py +++ b/examples/distributed/nvshmem/example_summa.py @@ -3,12 +3,11 @@ import pynvshmem import tilelang import tilelang.language as T +from tilelang.carver.arch import driver from tilelang.distributed import init_distributed, dtype_map import math import argparse -tilelang.disable_cache() - def summa(MESH, M, N, K, block_M, block_N, block_K, dtype="float16"): M_local = T.ceildiv(M, MESH) @@ -16,7 +15,7 @@ def summa(MESH, M, N, K, block_M, block_N, block_K, dtype="float16"): K_local = T.ceildiv(K, MESH) accum_dtype = "float32" - sm_num = 132 # 132 SMs for H100 + sm_num = driver.get_num_sms() total_tiles = T.ceildiv(M_local, block_M) * T.ceildiv(N_local, block_N) @T.prim_func diff --git a/examples/distributed/triton_sp.py b/examples/distributed/triton_sp.py deleted file mode 100644 index 1b99a5fac3..0000000000 --- a/examples/distributed/triton_sp.py +++ /dev/null @@ -1,730 +0,0 @@ -################################################################################ -# -# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates -# -# Permission is hereby granted, free of charge, to any person obtaining -# a copy of this software and associated documentation files -# (the "Software"), to deal in the Software without restriction, -# including without limitation the rights to use, copy, modify, merge, -# publish, distribute, sublicense, and/or sell copies of the Software, -# and to permit persons to whom the Software is furnished to do so, -# subject to the following conditions: -# -# The above copyright notice and this permission notice shall be -# included in all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
-# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY -# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, -# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE -# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -# -################################################################################ - -import torch -import triton -import triton.language as tl -from triton.language import core as tlc -import triton_dist.language as dl - -from typing import Optional -import itertools - -from triton_dist.utils import nvshmem_create_tensor, nvshmem_free_tensor_sync, cuda_stream_max_priority, supports_p2p_native_atomic -from triton_dist.kernels.nvidia.common_ops import barrier_all_intra_node_atomic_cas_block -from triton.language.extra.cuda.language_extra import tid, __syncthreads, st - - -@tlc.extern -def load_v4_b32(ptr, _semantic=None): - return tl.inline_asm_elementwise( - asm="ld.global.v4.b32 {$0,$1,$2,$3}, [$4];", - constraints=("=r,=r,=r,=r,l"), - args=[ptr], - dtype=(tl.int32, tl.int32, tl.int32, tl.int32), - is_pure=False, - pack=1, - _semantic=_semantic, - ) - - -@tlc.extern -def store_v4_b32(ptr, val0, val1, val2, val3, _semantic=None): - return tl.inline_asm_elementwise( - asm=""" - st.global.v4.b32 [$1], {$2,$3,$4,$5}; - mov.u32 $0, 0; - """, - constraints=("=r,l,r,r,r,r"), # no use output - args=[ptr, val0, val1, val2, val3], - dtype=tl.int32, - is_pure=False, - pack=1, - _semantic=_semantic, - ) - - -@tlc.extern -def load_v4_b32_cond(ptr, mask, _semantic=None): - return tl.inline_asm_elementwise( - asm=""" - { - .reg .pred %p0; - setp.eq.s32 %p0, $5, 1; - @%p0 ld.global.v4.b32 {$0,$1,$2,$3}, [$4]; - } - """, - constraints=("=r,=r,=r,=r,l,r"), - args=[ptr, mask.to(tl.int32, _semantic=_semantic)], - dtype=(tl.int32, tl.int32, tl.int32, tl.int32), - is_pure=False, - pack=1, - _semantic=_semantic, - ) - - -@tlc.extern -def store_v4_b32_cond(ptr, val0, val1, val2, val3, mask, _semantic=None): - return tl.inline_asm_elementwise( - asm=""" - { - .reg .pred %p0; - setp.eq.s32 %p0, $6, 1; - @%p0 st.global.v4.b32 [$1], {$2,$3,$4,$5}; - mov.u32 $0, 0; - } - """, - constraints=("=r,l,r,r,r,r,r"), # no use output - args=[ptr, val0, val1, val2, val3, mask.to(tl.int32, _semantic=_semantic)], - dtype=tl.int32, - is_pure=False, - pack=1, - _semantic=_semantic, - ) - - -@triton.jit -def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_GEMM_SMS): - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + (tile_id % group_size_m) - pid_n = (tile_id % num_pid_in_group) // group_size_m - return pid_m, pid_n - - -def _matmul_launch_metadata(grid, kernel, args): - ret = {} - M, N, K, WS = args["M"], args["N"], args["K"], args.get("WARP_SPECIALIZE", False) - ws_str = "_ws" if WS else "" - ret["name"] = f"{kernel.name}{ws_str} [M={M}, N={N}, K={K}]" - if "c_ptr" in args: - bytes_per_elem = args["c_ptr"].element_size() - else: - bytes_per_elem = 1 if args["FP8_OUTPUT"] else 2 - ret[f"flops{bytes_per_elem * 8}"] = 2.0 * M * N * K - ret["bytes"] = bytes_per_elem * (M * K + N * K + M * N) - return ret - - -def _kernel_consumer_gemm_persistent_repr(proxy): - constexprs = proxy.constants - cap_major, cap_minor = torch.cuda.get_device_capability() - a_dtype = proxy.signature["a_ptr"].lstrip("*") - b_dtype = proxy.signature["b_ptr"].lstrip("*") - c_dtype = proxy.signature["c_ptr"].lstrip("*") - BM, BN, BK = constexprs["BLOCK_SIZE_M"], 
constexprs["BLOCK_SIZE_N"], constexprs["BLOCK_SIZE_K"] - - return ( - f"cutlass_triton3x_sm{cap_major}{cap_minor}_a2a_consumer_gemm_persistent_tensorop_{a_dtype}_{b_dtype}_{c_dtype}_{BM}x{BN}x{BK}_ntn" - ) - - -@triton.jit(do_not_specialize=["sp_rank"], launch_metadata=_matmul_launch_metadata, repr=_kernel_consumer_gemm_persistent_repr) -def matmul_kernel_descriptor_persistent( - a_ptr, - b_ptr, - bias_ptr, - c_ptr, # - gemm_barrier_ptr, - sp_rank, - sp_size: tl.constexpr, - M, - N: tl.constexpr, - K: tl.constexpr, # - A2A_TILE_M: tl.constexpr, - A2A_TILE_N: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, # - BLOCK_SIZE_N: tl.constexpr, # - BLOCK_SIZE_K: tl.constexpr, # - GROUP_SIZE_M: tl.constexpr, # - EPILOGUE_SUBTILE: tl.constexpr, # - NUM_GEMM_SMS: tl.constexpr, # - WARP_SPECIALIZE: tl.constexpr, # - HAS_BIAS: tl.constexpr, -): - # Matmul using TMA and device-side descriptor creation - dtype = c_ptr.dtype.element_ty - start_pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_tiles = num_pid_m * num_pid_n - - tl.static_assert(K % sp_size == 0, f"K {K} must be divisible by sp_size {sp_size}") - K_per_sp_rank: tl.constexpr = K // sp_size - tl.static_assert(K_per_sp_rank % BLOCK_SIZE_K == 0, f"K_per_sp_rank {K_per_sp_rank} must be divisible by BLOCK_SIZE_K {BLOCK_SIZE_K}") - k_tiles: tl.constexpr = K // BLOCK_SIZE_K - - tl.static_assert(A2A_TILE_N % BLOCK_SIZE_K == 0, f"A2A_TILE_N {A2A_TILE_N} must be divisible by BLOCK_SIZE_N {BLOCK_SIZE_K}") - NUM_K_PER_TILE: tl.constexpr = A2A_TILE_N // BLOCK_SIZE_K - # This is used for k-swizzle - # k_tiles_per_rank: tl.constexpr = K_per_sp_rank // BLOCK_SIZE_K - # k_vec_tiles_per_rank: tl.constexpr = k_tiles_per_rank // NUM_K_PER_TILE - - a_desc = tl.make_tensor_descriptor( - a_ptr, - shape=[M, K], - strides=[K, 1], - block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K], - ) - b_desc = tl.make_tensor_descriptor( - b_ptr, - shape=[N, K], - strides=[K, 1], - block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K], - ) - c_desc = tl.make_tensor_descriptor( - c_ptr, - shape=[M, N], - strides=[N, 1], - block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N if not EPILOGUE_SUBTILE else BLOCK_SIZE_N // 2], - ) - - # tile_id_c is used in the epilogue to break the dependency between - # the prologue and the epilogue - tile_id_c = start_pid - NUM_GEMM_SMS - num_pid_in_group = GROUP_SIZE_M * num_pid_n - - for tile_id in tl.range(start_pid, num_tiles, NUM_GEMM_SMS, flatten=False, warp_specialize=WARP_SPECIALIZE): - pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_GEMM_SMS) - offs_am = pid_m * BLOCK_SIZE_M - offs_bn = pid_n * BLOCK_SIZE_N - - chunk_beg = pid_m * BLOCK_SIZE_M // A2A_TILE_M - chunk_end = (min((pid_m + 1) * BLOCK_SIZE_M, M) - 1) // A2A_TILE_M - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for ki in range(k_tiles): - # k-swizzle: as the all-to-all comes in non-serial order, a swizzle may help in performance - # vec = NUM_K_PER_TILE - # vec = 4 - # ki_vec = ki // vec - # ki_elem = ki % vec - # swizzle_ki_vec = (ki_vec % sp_size + sp_rank) % sp_size - # ki = (swizzle_ki_vec * k_vec_tiles_per_rank + ki_vec // sp_size) * vec + ki_elem - - if ki % NUM_K_PER_TILE == 0: - for chunk_id in range(chunk_beg, chunk_end + 1): - token = dl.wait( - gemm_barrier_ptr + chunk_id * (k_tiles // NUM_K_PER_TILE) + ki // NUM_K_PER_TILE, - 1, - scope="gpu", - semantic="acquire", - waitValue=1, - ) - a_desc = dl.consume_token(a_desc, token) - offs_k = ki * BLOCK_SIZE_K - a = a_desc.load([offs_am, 
offs_k]) - b = b_desc.load([offs_bn, offs_k]) - accumulator = tl.dot(a, b.T, accumulator) - - tile_id_c += NUM_GEMM_SMS - pid_m, pid_n = _compute_pid(tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_GEMM_SMS) - offs_cm = pid_m * BLOCK_SIZE_M - offs_cn = pid_n * BLOCK_SIZE_N - - if HAS_BIAS: - offs_bias_n = tl.arange(0, BLOCK_SIZE_N) - bias_data = tl.load(bias_ptr + offs_cn + offs_bias_n, mask=(offs_cn + offs_bias_n < N)).to(tl.float32) - accumulator = accumulator + bias_data[None, :] - - if EPILOGUE_SUBTILE: - acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2)) - acc = tl.permute(acc, (0, 2, 1)) - acc0, acc1 = tl.split(acc) - c0 = acc0.to(dtype) - c_desc.store([offs_cm, offs_cn], c0) - c1 = acc1.to(dtype) - c_desc.store([offs_cm, offs_cn + BLOCK_SIZE_N // 2], c1) - else: - c = accumulator.to(dtype) - c_desc.store([offs_cm, offs_cn], c) - - -def matmul_descriptor_persistent(sp_rank, sp_size, a, b, bias, c, gemm_barrier, gemm_config: triton.Config, warp_specialize: bool = False): - # Check constraints. - assert a.shape[1] == b.shape[1], "Incompatible dimensions" # b is transposed - assert a.dtype == b.dtype, "Incompatible dtypes" - - M, K = a.shape - N, K = b.shape - - # TMA descriptors require a global memory allocation - def alloc_fn(size: int, alignment: int, stream: Optional[int]): - return torch.empty(size, device="cuda", dtype=torch.int8) - - triton.set_allocator(alloc_fn) - - def grid(META): - return (min(META["NUM_GEMM_SMS"], triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])),) - - matmul_kernel_descriptor_persistent[grid]( - a, - b, - bias, - c, - gemm_barrier, # - sp_rank, - sp_size, - M, - N, - K, # - EPILOGUE_SUBTILE=False, # - WARP_SPECIALIZE=warp_specialize, # - **gemm_config.all_kwargs(), # - HAS_BIAS=1 if bias is not None else 0, - ) - return c - - -@triton.jit(do_not_specialize=["rank", "sp_rank"]) -def kernel_all2all_push_intra_node_nvl( - attn_out_ptr, - a2a_out_ptr, - cum_seqlen_cpu_tuple, - cum_seqlen_gpu_ptr, - barrier_ptr, - intra_node_sync_buf_ptr, - local_head: tl.constexpr, - global_head, - head_dim: tl.constexpr, - sp_size: tl.constexpr, - rank, - sp_rank, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - NUM_COMM_SM: tl.constexpr, - FUSE_SYNC: tl.constexpr, - SUPPORT_ATOMIC: tl.constexpr, - VEC: tl.constexpr, - SKIP_BARRIER: tl.constexpr = False, -): - pid = tl.program_id(0) - if SKIP_BARRIER: - num_pids = tl.num_programs(0) - empty_pids = num_pids - NUM_COMM_SM - if pid < empty_pids: - return - pid = pid - empty_pids - - if FUSE_SYNC: - tl.static_assert(SUPPORT_ATOMIC, "FUSE_SYNC requires SUPPORT_ATOMIC to be True") - barrier_all_intra_node_atomic_cas_block(sp_rank, rank, sp_size, intra_node_sync_buf_ptr + pid * sp_size) - - for i in tl.static_range(sp_size + 1): - tl.store(cum_seqlen_gpu_ptr + i, cum_seqlen_cpu_tuple[i]) - __syncthreads() - - rank_offset = rank - sp_rank - - offs_m = tl.arange(0, BLOCK_M) - offs_n = tl.arange(0, BLOCK_N // VEC) - - if sp_size <= NUM_COMM_SM: - tl.static_assert(NUM_COMM_SM % sp_size == 0, f"NUM_COMM_SM {NUM_COMM_SM} must be divisible by sp_size {sp_size}") - NUM_SM_PER_SP: tl.constexpr = NUM_COMM_SM // sp_size - NUM_SP_PER_SM: tl.constexpr = 1 - else: - tl.static_assert(sp_size % NUM_COMM_SM == 0, f"sp_size {sp_size} must be divisible by NUM_COMM_SM {NUM_COMM_SM}") - NUM_SM_PER_SP: tl.constexpr = 1 - NUM_SP_PER_SM: tl.constexpr = sp_size // NUM_COMM_SM - - for tile in range(NUM_SP_PER_SM): - remote_sp_rank = pid * NUM_SP_PER_SM // NUM_SM_PER_SP + 
tile - remote_rank = remote_sp_rank + rank_offset - remote_a2a_out_ptr = dl.symm_at(a2a_out_ptr, remote_rank) - remote_barrier_ptr = dl.symm_at(barrier_ptr, remote_rank) - pid_in_sp = pid % NUM_SM_PER_SP - seq_beg = tl.load(cum_seqlen_gpu_ptr + remote_sp_rank) - seq_end = tl.load(cum_seqlen_gpu_ptr + remote_sp_rank + 1) - remote_seq_len = seq_end - seq_beg - num_tile_m = tl.cdiv(remote_seq_len, BLOCK_M) - tl.static_assert( - local_head * head_dim % BLOCK_N == 0, f"local_head * head_dim {local_head * head_dim} must be divisible by BLOCK_N {BLOCK_N}" - ) - num_tile_n = local_head * head_dim // BLOCK_N - - for tile_id_m_outer_n_tail in range(0, tl.cdiv(num_tile_m, GROUP_SIZE_M) * num_tile_n): - tile_id_m_outer_tail = tile_id_m_outer_n_tail // num_tile_n - tile_id_n_tail = tile_id_m_outer_n_tail % num_tile_n - for tile_id_m_inner_tail in range(pid_in_sp, GROUP_SIZE_M, NUM_SM_PER_SP): - tile_id_m_tail = tile_id_m_outer_tail * GROUP_SIZE_M + tile_id_m_inner_tail - if tile_id_m_tail < num_tile_m: - attn_offs_m = seq_beg + tile_id_m_tail * BLOCK_M + offs_m - attn_mask_m = attn_offs_m < seq_end - attn_offs_n = tile_id_n_tail * BLOCK_N + offs_n * VEC - data0, data1, data2, data3 = load_v4_b32_cond( - attn_out_ptr + attn_offs_m[:, None] * local_head * head_dim + attn_offs_n[None, :], mask=attn_mask_m[:, None] - ) - - out_offs_m = tile_id_m_tail * BLOCK_M + offs_m - out_mask_m = out_offs_m < remote_seq_len - out_offs_n = sp_rank * local_head * head_dim + tile_id_n_tail * BLOCK_N + offs_n * VEC - store_v4_b32_cond( - remote_a2a_out_ptr + out_offs_m[:, None] * global_head * head_dim + out_offs_n[None, :], - data0, - data1, - data2, - data3, - mask=out_mask_m[:, None], - ) - - if not SKIP_BARRIER: - __syncthreads() - notify_barrier_ptr = ( - remote_barrier_ptr + tile_id_m_tail * num_tile_n * sp_size + sp_rank * num_tile_n + tile_id_n_tail - ) - thread_idx = tid(0) - if thread_idx == 0: - st(notify_barrier_ptr, 1, scope="sys", semantic="release") - - -class SpUlysessOAll2AllGemmKernel: - def __init__( - self, - world_group: torch.distributed.ProcessGroup, - nnodes: int, - sp_size: int, - max_batch: int, - num_head: int, - max_seqlen: int, - head_dim: int, - max_num_comm_buf: int, - input_dtype=torch.bfloat16, - output_dtype=torch.bfloat16, - a2a_only: bool = True, - fuse_sync: bool = True, - ): - self.world_group = world_group - self.world_size = world_group.size() - self.rank = world_group.rank() - self.nnodes = nnodes - assert self.world_size % nnodes == 0, f"world_size {self.world_size} must be divisible by nnodes {nnodes}" - self.local_world_size = self.world_size // nnodes - self.local_rank = self.rank % self.local_world_size - self.sp_size = sp_size - assert self.local_world_size % self.sp_size == 0, f"local_world_size {self.local_world_size} must be divisible by sp_size {sp_size}" - self.sp_rank = self.local_rank % self.sp_size - self.max_batch = max_batch - self.num_head = num_head - self.max_seqlen = max_seqlen - self.head_dim = head_dim - self.max_num_comm_buf = max_num_comm_buf - self.input_dtype = input_dtype - self.output_dtype = output_dtype - self.a2a_only = a2a_only - assert self.a2a_only, "Only support a2a_only mode" - self.fuse_sync = fuse_sync - - self.compute_stream = torch.cuda.Stream(priority=cuda_stream_max_priority()) - self.cp_event = torch.cuda.Event(enable_timing=False) - self.ready_event = torch.cuda.Event(enable_timing=False) - self.compute_event = torch.cuda.Event(enable_timing=False) - - self.p2p_atomic_supported = supports_p2p_native_atomic() - self.max_sms = 
torch.cuda.get_device_properties("cuda").multi_processor_count - - # GEMM config - self.BLOCK_SIZE_M = 128 - self.BLOCK_SIZE_N = 256 - self.BLOCK_SIZE_K = 64 - self.GROUP_SIZE_M = 4 - self.A2A_TILE_M = 128 - self.A2A_TILE_N = 256 - self.max_gemm_sms = torch.cuda.get_device_properties("cuda").multi_processor_count - self.num_warps = 8 - self.num_stages = 3 - self.warp_specialize = False - - self.init_symm_buffer() - self.init_local_buffer() - - def __del__(self): - self.finalize() - - def finalize(self): - self.deinit_symm_buffer() - - def init_symm_buffer(self): - max_local_seq = self.max_seqlen // self.sp_size - self._comm_output_buffer = nvshmem_create_tensor( - [self.max_num_comm_buf, self.max_batch, max_local_seq, self.num_head * self.head_dim], self.input_dtype - ) - self._barrier_buffer = nvshmem_create_tensor( - [triton.cdiv(self.max_batch * self.max_seqlen, self.BLOCK_SIZE_M) * self.num_head], torch.int32 - ) - self._barrier_buffer.zero_() - self._intra_node_sync_buffer = nvshmem_create_tensor([self.sp_size * self.max_sms], torch.int32) - self._intra_node_sync_buffer.zero_() - self._sp_group_sync_buffer = nvshmem_create_tensor([self.world_size], torch.int32) - self._sp_group_sync_buffer.zero_() - - def deinit_symm_buffer(self): - if hasattr(self, "_comm_output_buffer"): - nvshmem_free_tensor_sync(self._comm_output_buffer) - del self._comm_output_buffer - if hasattr(self, "_barrier_buffer"): - nvshmem_free_tensor_sync(self._barrier_buffer) - del self._barrier_buffer - if hasattr(self, "_intra_node_sync_buffer"): - nvshmem_free_tensor_sync(self._intra_node_sync_buffer) - del self._intra_node_sync_buffer - if hasattr(self, "_sp_group_sync_buffer"): - nvshmem_free_tensor_sync(self._sp_group_sync_buffer) - del self._sp_group_sync_buffer - - def init_local_buffer(self): - self._cum_seqlen_gpu = torch.empty([self.sp_size + 1], dtype=torch.int32, device="cuda") - - def sp_group_barrier_all_intra_node(self, stream=None): - stream = torch.cuda.current_stream() if stream is None else stream - sp_local_rank = self.local_rank % self.sp_size - with torch.cuda.stream(stream): - barrier_all_intra_node_atomic_cas_block[(1,)](sp_local_rank, self.rank, self.sp_size, self._sp_group_sync_buffer) - - def reset_cusum_seq_lens(self, local_seqlen, seq_lens_cpu=None): - if seq_lens_cpu is None: - seq_lens_cpu = [local_seqlen] * self.sp_size - else: - seq_lens_cpu = seq_lens_cpu.tolist() - assert local_seqlen == seq_lens_cpu[self.local_rank % self.sp_size], ( - f"local_seqlen {local_seqlen} != seq_lens_cpu[{self.local_rank % self.sp_size}]={seq_lens_cpu[self.local_rank % self.sp_size]}" - ) - cum_seqlen_cpu = [0] + list(itertools.accumulate(seq_lens_cpu)) - self._cum_seq_len_cpu_tuple = tuple(cum_seqlen_cpu) - - def forward( - self, - inputs: torch.Tensor, - weight: torch.Tensor, - seq_lens_cpu: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, - output: Optional[torch.Tensor] = None, - a2a_output: Optional[torch.Tensor] = None, - transpose_weight: bool = False, - num_comm_sms: int = -1, - sm_margin: int = 0, - ): - if num_comm_sms == -1: - num_comm_sms = self.world_size - assert num_comm_sms >= 0, "num_comm_sms must be non-negative" - assert len(weight.shape) == 2, f"weight must be 2D tensor, got {len(weight)}D" - assert len(inputs.shape) == 4, f"inputs must be 4D tensor, got {len(inputs)}D" - bs, total_seq_len, local_head, head_dim = inputs.shape - assert head_dim == self.head_dim, f"head_dim {head_dim} must be equal to self.head_dim {self.head_dim}" - assert weight.is_contiguous(), 
f"weight must be contiguous, got {weight.shape}" - assert inputs.is_contiguous(), f"inputs must be contiguous, got {inputs.shape}" - assert not transpose_weight, "transpose_weight is not supported in this kernel" - - if not transpose_weight: - N = weight.shape[0] - K = weight.shape[1] - else: - N = weight.shape[1] - K = weight.shape[0] - - if seq_lens_cpu is not None: - assert seq_lens_cpu.is_cpu, "seq_lens_cpu must be a CPU tensor" - assert seq_lens_cpu.dtype == torch.int32, "seq_lens_cpu must be int32" - assert seq_lens_cpu.is_contiguous(), "seq_lens_cpu must be contiguous" - - seq_lens_cpu_tuple = tuple(seq_lens_cpu.tolist()) - local_seq_len = seq_lens_cpu_tuple[self.sp_rank] - M = local_seq_len * bs - else: - assert total_seq_len % self.sp_size == 0, f"total_seq_len {total_seq_len} must be divisible by sp_size {self.sp_size}" - local_seq_len = total_seq_len // self.sp_size - M = local_seq_len * bs - - self.reset_cusum_seq_lens(local_seqlen=local_seq_len, seq_lens_cpu=seq_lens_cpu) - - gemm_input_a = self._comm_output_buffer.view(-1)[: M * K].view([M, K]) - - cur_stream = torch.cuda.current_stream() - - self._barrier_buffer.zero_() - if not self.fuse_sync: - self.sp_group_barrier_all_intra_node(cur_stream) - - self.ready_event.record(cur_stream) - self.compute_stream.wait_event(self.ready_event) - - grid = (num_comm_sms,) - kernel_all2all_push_intra_node_nvl[grid]( - inputs, - gemm_input_a, - self._cum_seq_len_cpu_tuple, - self._cum_seqlen_gpu, - self._barrier_buffer, - self._intra_node_sync_buffer, # no need to initialize - local_head, - local_head * self.sp_size, - self.head_dim, - self.sp_size, - self.rank, - self.sp_rank, - self.A2A_TILE_M, - self.A2A_TILE_N, - self.GROUP_SIZE_M, - num_comm_sms, - self.fuse_sync, - self.p2p_atomic_supported, - VEC=(16 // inputs.dtype.itemsize), - num_warps=32, - ) - - if output is None: - output = torch.empty([bs, local_seq_len, N], device=inputs.device, dtype=self.output_dtype) - - assert len(output.shape) == 3, f"output must be 4D tensor, got {len(output)}D" - assert output.shape[0] == bs, f"output batch size {output.shape[0]} must be equal to input batch size {bs}" - assert output.shape[1] == local_seq_len, f"output seq_len {output.shape[1]} must be equal to local_seq_len {local_seq_len}" - assert output.shape[2] == N, f"output head {output.shape[2]} must be equal to output size {N}" - assert output.is_contiguous(), f"output must be contiguous, got {output.shape}" - - assert self.max_gemm_sms - num_comm_sms - sm_margin > 0, ( - f"max_gemm_sms {self.max_gemm_sms} - num_comm_sms {num_comm_sms} - sm_margin {sm_margin} must be greater than 0" - ) - gemm_config = triton.Config( - { - "BLOCK_SIZE_M": self.BLOCK_SIZE_M, - "BLOCK_SIZE_N": self.BLOCK_SIZE_N, - "BLOCK_SIZE_K": self.BLOCK_SIZE_K, - "GROUP_SIZE_M": self.GROUP_SIZE_M, - "A2A_TILE_M": self.A2A_TILE_M, - "A2A_TILE_N": self.A2A_TILE_N, - "NUM_GEMM_SMS": self.max_gemm_sms - num_comm_sms - sm_margin, - }, - num_stages=self.num_stages, - num_warps=self.num_warps, - ) - - with torch.cuda.stream(self.compute_stream): - matmul_descriptor_persistent( - self.sp_rank, self.sp_size, gemm_input_a, weight, bias, output, self._barrier_buffer, gemm_config, self.warp_specialize - ) - - if a2a_output is not None: - assert a2a_output.shape == (bs, local_seq_len, local_head * self.sp_size, head_dim), ( - f"a2a_output shape {a2a_output.shape} must be equal to (bs, local_seq_len, local_head * self.sp_size, head_dim) ({bs}, {local_seq_len}, {local_head * self.sp_size}, {head_dim})" - ) - assert 
a2a_output.is_contiguous(), f"a2a_output must be contiguous, got {a2a_output.shape}" - a2a_output.copy_(gemm_input_a.view(bs, local_seq_len, local_head * self.sp_size * head_dim)) - ret = (output, a2a_output) - else: - ret = (output,) - - self.compute_event.record(self.compute_stream) - cur_stream.wait_event(self.compute_event) - - return ret - - def post_attn_a2a( - self, - inputs: torch.Tensor, - seq_lens_cpu: Optional[torch.Tensor] = None, - return_comm_buf: bool = False, - comm_buf_idx: int = 0, - num_comm_sms: int = -1, - ): - if num_comm_sms == -1: - num_comm_sms = self.world_size - assert num_comm_sms >= 0, "num_comm_sms must be non-negative" - assert len(inputs.shape) == 4, f"inputs must be 4D tensor, got {len(inputs)}D" - bs, total_seq_len, local_head, head_dim = inputs.shape - assert head_dim == self.head_dim, f"head_dim {head_dim} must be equal to self.head_dim {self.head_dim}" - assert inputs.is_contiguous(), f"inputs must be contiguous, got {inputs.shape}" - - if seq_lens_cpu is not None: - assert seq_lens_cpu.is_cpu, "seq_lens_cpu must be a CPU tensor" - assert seq_lens_cpu.dtype == torch.int32, "seq_lens_cpu must be int32" - assert seq_lens_cpu.is_contiguous(), "seq_lens_cpu must be contiguous" - - seq_lens_cpu_tuple = tuple(seq_lens_cpu.tolist()) - local_seq_len = seq_lens_cpu_tuple[self.sp_rank] - M = local_seq_len * bs - else: - assert total_seq_len % self.sp_size == 0, f"total_seq_len {total_seq_len} must be divisible by sp_size {self.sp_size}" - local_seq_len = total_seq_len // self.sp_size - M = local_seq_len * bs - - K = local_head * self.sp_size * head_dim - - self.reset_cusum_seq_lens(local_seqlen=local_seq_len, seq_lens_cpu=seq_lens_cpu) - - assert comm_buf_idx < self.max_num_comm_buf, f"comm_buf_idx {comm_buf_idx} must be less than num_comm_buf {self.max_num_comm_buf}" - gemm_input_a = self._comm_output_buffer[comm_buf_idx].view(-1)[: M * K].view([M, K]) - - cur_stream = torch.cuda.current_stream() - - if not self.fuse_sync: - self.sp_group_barrier_all_intra_node(cur_stream) - - grid = (self.max_gemm_sms,) - kernel_all2all_push_intra_node_nvl[grid]( - inputs, - gemm_input_a, - self._cum_seq_len_cpu_tuple, - self._cum_seqlen_gpu, - self._barrier_buffer, - self._intra_node_sync_buffer, # no need to initialize - local_head, - local_head * self.sp_size, - self.head_dim, - self.sp_size, - self.rank, - self.sp_rank, - 256, - 256, - 16, - num_comm_sms, - self.fuse_sync, - self.p2p_atomic_supported, - VEC=(16 // inputs.dtype.itemsize), - SKIP_BARRIER=True, - num_warps=32, - ) - - if return_comm_buf: - return gemm_input_a - else: - self.sp_group_barrier_all_intra_node(cur_stream) - return gemm_input_a.clone() - - def post_attn_a2a_no_cpy( - self, - inputs: torch.Tensor, - seq_lens_cpu: Optional[torch.Tensor] = None, - comm_buf_idx: int = 0, - num_comm_sms: int = -1, - ): - return self.post_attn_a2a( - inputs, - seq_lens_cpu, - return_comm_buf=True, - comm_buf_idx=comm_buf_idx, - num_comm_sms=num_comm_sms, - )
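
For context on how the removed `SpUlysessOAll2AllGemmKernel` was presumably driven, here is a minimal usage sketch. It is not part of the diff: the process-group setup, tensor shapes, and the `num_comm_sms` choice are assumptions for illustration, and NVSHMEM/symmetric-memory initialization is presumed to have happened before the constructor allocates its buffers via `nvshmem_create_tensor`.

```python
# Hypothetical usage sketch (not part of this diff). Assumes torch.distributed
# and NVSHMEM are already initialized for the participating ranks.
import torch
import torch.distributed as dist

dist.init_process_group("nccl")
world_group = dist.group.WORLD
rank, world_size = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank % torch.cuda.device_count())

sp_size = world_size                      # single intra-node SP group (assumption)
bs, seqlen, num_head, head_dim = 1, 8192, 32, 128
local_head = num_head // sp_size          # heads held per rank before the a2a
hidden = num_head * head_dim

op = SpUlysessOAll2AllGemmKernel(
    world_group=world_group,
    nnodes=1,
    sp_size=sp_size,
    max_batch=bs,
    num_head=num_head,
    max_seqlen=seqlen,
    head_dim=head_dim,
    max_num_comm_buf=2,
)

# Attention output on this rank: full sequence length, local heads only.
attn_out = torch.randn(bs, seqlen, local_head, head_dim,
                       device="cuda", dtype=torch.bfloat16)
# Output-projection weight in (N, K) layout; transpose_weight is unsupported.
w_o = torch.randn(hidden, hidden, device="cuda", dtype=torch.bfloat16)

# Heads-to-sequence all-to-all overlapped with the projection GEMM.
(proj_out,) = op.forward(attn_out, w_o, num_comm_sms=8)  # comm-SM count is an assumption
print(proj_out.shape)  # expected: (bs, seqlen // sp_size, hidden)
```

The design intent visible in `forward` is that the all-to-all push kernel and the persistent GEMM run concurrently: the push kernel writes tiles into the symmetric communication buffer on each peer and releases a per-tile flag in `_barrier_buffer`, while the GEMM, launched on `compute_stream` with the remaining SMs, consumes each tile as soon as its flag is set rather than waiting for the full exchange.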