diff --git a/testing/python/language/test_tilelang_language_atom_mma.py b/testing/python/language/test_tilelang_language_atom_mma.py
new file mode 100644
index 000000000..1af57bebd
--- /dev/null
+++ b/testing/python/language/test_tilelang_language_atom_mma.py
@@ -0,0 +1,342 @@
+import torch
+import tilelang
+import tilelang.language as T
+import tilelang.testing
+from tilelang.intrinsics import (
+    TensorCoreIntrinEmitter,
+    WGMMATensorCoreIntrinEmitter,
+    TCGEN05TensorCoreIntrinEmitter,
+)
+from tilelang.cuda.intrinsics.layout.mma_layout import get_swizzle_layout
+from tilelang.layout import (
+    make_full_bank_swizzled_layout,
+    make_half_bank_swizzled_layout,
+    make_quarter_bank_swizzled_layout,
+    make_linear_layout,
+)
+
+
+def make_swizzle_layout(shared_buf):
+    dtype = shared_buf.dtype
+    shape = shared_buf.shape
+    if shape[-1] * T.dtype(dtype).bits == 512:
+
+        def transform_func(i, j):
+            return get_swizzle_layout(i, j, shape[-1], dtype)
+
+        return T.Layout(shape, transform_func)
+    return T.Layout(shape, lambda *args: args)
+
+
+def infer_wgmma_shared_layout(continuity, dtype):
+    vectorized_size = 128 // T.dtype(dtype).bits
+    if continuity % (vectorized_size * 8) == 0:
+        return make_full_bank_swizzled_layout
+    if continuity % (vectorized_size * 4) == 0:
+        return make_half_bank_swizzled_layout
+    if continuity % (vectorized_size * 2) == 0:
+        return make_quarter_bank_swizzled_layout
+    return make_linear_layout
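+
+
+# A worked example of the branches above (illustrative, assuming a 16-bit dtype):
+# vectorized_size = 128 // 16 = 8 elements per 16-byte vector, so a contiguous
+# dimension of 64 elements picks the full-bank swizzle, 32 the half-bank swizzle,
+# 16 the quarter-bank swizzle, and anything narrower falls back to linear.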
+
+
+# ---------------------------------------------------------------------------
+# SM80+ MMA (atom-level) -- correctness test
+# ---------------------------------------------------------------------------
+
+
+def make_mma_atom_kernel(M, N, K, in_dtype, out_dtype, accum_dtype):
+    micro_size_x = micro_size_y = micro_size_k = 16
+    block_row_warps = 2
+    block_col_warps = 2
+    warp_row_tiles = 32
+    warp_col_tiles = 32
+    chunk = 32 if in_dtype == T.float16 else 64
+
+    block_M = block_row_warps * warp_row_tiles
+    block_N = block_col_warps * warp_col_tiles
+    block_K = chunk
+    threads = 32 * block_row_warps * block_col_warps
+
+    emitter = TensorCoreIntrinEmitter(
+        a_dtype=in_dtype,
+        b_dtype=in_dtype,
+        accum_dtype=accum_dtype,
+        a_transposed=False,
+        b_transposed=True,
+        block_row_warps=block_row_warps,
+        block_col_warps=block_col_warps,
+        warp_row_tiles=warp_row_tiles,
+        warp_col_tiles=warp_col_tiles,
+        chunk=chunk,
+    )
+    warp_rows = emitter.warp_rows
+    warp_cols = emitter.warp_cols
+    local_size_a = emitter.local_size_a
+    local_size_b = emitter.local_size_b
+    local_size_c = emitter.local_size_out
+    num_inst_m = emitter.mma_num_inst_m
+    num_inst_n = emitter.mma_num_inst_n
+
+    @T.prim_func
+    def main(
+        A: T.Tensor((M, K), in_dtype),
+        B: T.Tensor((N, K), in_dtype),
+        C: T.Tensor((M, N), out_dtype),
+    ):
+        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
+            A_shared = T.alloc_shared((block_M, block_K), in_dtype)
+            B_shared = T.alloc_shared((block_N, block_K), in_dtype)
+            C_shared = T.alloc_shared((block_M // micro_size_x, block_N // micro_size_y, micro_size_x, micro_size_y), out_dtype)
+            A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
+            B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
+            C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
+
+            T.annotate_layout(
+                {
+                    A_shared: make_swizzle_layout(A_shared),
+                    B_shared: make_swizzle_layout(B_shared),
+                }
+            )
+
+            T.clear(C_local)
+
+            for ko in T.serial(K // block_K):
+                for i, k in T.Parallel(block_M, block_K):
+                    A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
+                for j, k in T.Parallel(block_N, block_K):
+                    B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]
+
+                for ki in T.serial(block_K // micro_size_k):
+                    emitter.ldmatrix_a(A_local, A_shared, ki)
+                    emitter.ldmatrix_b(B_local, B_shared, ki)
+                    for i, j in T.grid(num_inst_m, num_inst_n):
+                        emitter.mma_atom(A_local, B_local, C_local, i, j, ki)
+
+            emitter.stmatrix(C_local, C_shared)
+            for i, j in T.Parallel(block_M, block_N):
+                C[by * block_M + i, bx * block_N + j] = C_shared[i // micro_size_x, j // micro_size_y, i % micro_size_x, j % micro_size_y]
+
+    return main
+
+
+def _run_mma_atom(M, N, K, in_dtype, out_dtype, accum_dtype):
+    kernel = tilelang.compile(make_mma_atom_kernel(M, N, K, in_dtype, out_dtype, accum_dtype), target="cuda", out_idx=[2])
+    a = torch.randn(M, K, device="cuda", dtype=in_dtype.as_torch())
+    b = torch.randn(N, K, device="cuda", dtype=in_dtype.as_torch())
+    c = kernel(a, b)
+    ref = (a.float() @ b.T.float()).to(out_dtype.as_torch())
+    torch.testing.assert_close(c, ref, rtol=1e-2, atol=0.1)
+
+
+@tilelang.testing.requires_cuda
+@tilelang.testing.requires_cuda_compute_version_ge(8, 0)
+def test_mma_atom_gemm():
+    _run_mma_atom(128, 128, 128, T.float16, T.float16, T.float16)
+    _run_mma_atom(256, 256, 256, T.bfloat16, T.float32, T.float32)
+
+
+# ---------------------------------------------------------------------------
+# SM90 WGMMA (atom-level, SS variant) -- codegen and correctness test
+# ---------------------------------------------------------------------------
+def make_wgmma_atom_kernel(M, N, K, in_dtype, out_dtype, accum_dtype):
+    chunk = 32 if in_dtype == T.float16 else 64
+    block_row_warps = 4
+    block_col_warps = 1
+    warp_row_tiles = M // block_row_warps
+    warp_col_tiles = N // block_col_warps
+    block_K = chunk
+
+    emi = WGMMATensorCoreIntrinEmitter(
+        a_dtype=in_dtype,
+        b_dtype=in_dtype,
+        accum_dtype=accum_dtype,
+        a_transposed=False,
+        b_transposed=False,
+        block_row_warps=block_row_warps,
+        block_col_warps=block_col_warps,
+        warp_row_tiles=warp_row_tiles,
+        warp_col_tiles=warp_col_tiles,
+        chunk=chunk,
+    )
+    a_layout = infer_wgmma_shared_layout(K, in_dtype)
+    b_layout = infer_wgmma_shared_layout(emi.wgmma_inst_n, in_dtype)
+    num_inst_m = emi.wgmma_num_inst_m
+    num_inst_n = emi.wgmma_num_inst_n
+    num_k_atoms = emi.wgmma_num_k_atoms
+
+    @T.prim_func
+    def main(
+        A: T.Tensor((M, K), in_dtype),
+        B: T.Tensor((K, N), in_dtype),
+        C: T.Tensor((M, N), out_dtype),
+    ):
+        with T.Kernel(1, threads=128):
+            A_shared = T.alloc_shared((M, block_K), in_dtype)
+            B_shared = T.alloc_shared((block_K, N), in_dtype)
+            C_local = T.alloc_fragment((M, N), accum_dtype)
+
+            emi._assign_a_shared_layout(a_layout(A_shared))
+            emi._assign_b_shared_layout(b_layout(B_shared))
+            T.annotate_layout(
+                {
+                    A_shared: a_layout(A_shared),
+                    B_shared: b_layout(B_shared),
+                    C_local: emi.make_mma_store_layout(C_local),
+                }
+            )
+
+            T.copy(A[0:M, 0:block_K], A_shared)
+            T.copy(B[0:block_K, 0:N], B_shared)
+
+            a_params = emi.compute_wgmma_a_desc_params(A_shared)
+            b_params = emi.compute_wgmma_b_desc_params(B_shared)
+
+            desc_a = T.alloc_wgmma_desc()
+            desc_b = T.alloc_wgmma_desc()
+            emi.init_wgmma_a_desc(desc_a, A_shared, a_params)
+            emi.init_wgmma_b_desc(desc_b, B_shared, b_params)
+            emi.wgmma_fence_c(C_local)
+            emi.wgmma_arrive()
+
+            for n in T.unroll(num_inst_n):
+                for m in T.unroll(num_inst_m):
+                    for ki in T.unroll(num_k_atoms):
+                        emi.wgmma_ss_atom(desc_a, desc_b, C_local, m, n, ki, a_params, b_params, T.bool(True))
+
+            emi.wgmma_commit()
+            emi.wgmma_wait(0)
+            emi.wgmma_fence_c(C_local)
+
+            T.copy(C_local, C[0:M, 0:N])
+
+    return main
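+
+
+# Note (illustrative): the hand-written atom sequence above follows the order the
+# fused wgmma() macro emits internally: init descriptors -> fence C -> arrive ->
+# one wgmma_ss_atom per (m, n, ki) -> commit -> wait -> fence C.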
+
+
+def _run_wgmma_atom(M, N, K, in_dtype, out_dtype, accum_dtype):
+    kernel = tilelang.compile(
+        make_wgmma_atom_kernel(M, N, K, in_dtype, out_dtype, accum_dtype),
+        target="cuda",
+        out_idx=[2],
+    )
+    src = kernel.get_kernel_source()
+    assert "wgmma_ss" in src
+    assert "initialize_wgmma_descriptor" in src
+    assert "warpgroup_arrive" in src
+    assert "warpgroup_commit_batch" in src
+
+    a = torch.randn(M, K, device="cuda", dtype=in_dtype.as_torch())
+    b = torch.randn(K, N, device="cuda", dtype=in_dtype.as_torch())
+    c = kernel(a, b)
+    ref = (a.float() @ b.float()).to(out_dtype.as_torch())
+    torch.testing.assert_close(c, ref, rtol=1e-2, atol=0.1)
+
+
+@tilelang.testing.requires_cuda
+@tilelang.testing.requires_cuda_compute_version_eq(9, 0)
+def test_wgmma_atom_gemm():
+    _run_wgmma_atom(64, 64, 32, T.float16, T.float16, T.float32)
+
+
+# ---------------------------------------------------------------------------
+# SM100 TCGEN05MMA (atom-level, SS variant) -- codegen and correctness test
+# ---------------------------------------------------------------------------
+
+
+def make_tcgen05_atom_kernel(M, N, K, in_dtype, out_dtype, accum_dtype):
+    chunk = K
+    emi = TCGEN05TensorCoreIntrinEmitter(
+        a_dtype=in_dtype,
+        b_dtype=in_dtype,
+        accum_dtype=accum_dtype,
+        a_transposed=False,
+        b_transposed=True,
+        block_row_warps=1,
+        block_col_warps=1,
+        warp_row_tiles=M,
+        warp_col_tiles=N,
+        chunk=chunk,
+    )
+    emi.get_tcgen5_mma_meta(M, N, K, True)
+    a_layout = infer_wgmma_shared_layout(K, in_dtype)
+    b_layout = infer_wgmma_shared_layout(K, in_dtype)
+    num_inst_m = emi.tcgen05_num_inst_m
+    num_inst_n = emi.tcgen05_num_inst_n
+    num_k_atoms = emi.tcgen05_num_k_atoms
+    instr_desc = emi.compute_tcgen05_instr_desc()
+
+    @T.prim_func
+    def main(
+        A: T.Tensor((M, K), in_dtype),
+        B: T.Tensor((N, K), in_dtype),
+        D: T.Tensor((M, N), out_dtype),
+    ):
+        with T.Kernel(1, threads=128):
+            A_shared = T.alloc_shared((M, K), in_dtype)
+            B_shared = T.alloc_shared((N, K), in_dtype)
+            C_tmem = T.alloc_tmem((M, N), accum_dtype)
+            mbar = T.alloc_barrier(1)
+            C_local = T.alloc_fragment((M, N), accum_dtype)
+            C_shared = T.alloc_shared((M, N), out_dtype)
+
+            emi._assign_a_shared_layout(a_layout(A_shared))
+            emi._assign_b_shared_layout(b_layout(B_shared))
+            T.annotate_layout(
+                {
+                    A_shared: a_layout(A_shared),
+                    B_shared: b_layout(B_shared),
+                    C_tmem: emi.make_mma_store_layout(C_tmem),
+                }
+            )
+
+            for i, k in T.Parallel(M, K):
+                A_shared[i, k] = A[i, k]
+            for j, k in T.Parallel(N, K):
+                B_shared[j, k] = B[j, k]
+
+            a_params = emi.compute_tcgen05_a_desc_params(A_shared)
+            b_params = emi.compute_tcgen05_b_desc_params(B_shared)
+
+            if T.get_thread_binding() // 32 == 0:
+                desc_a = T.alloc_tcgen05_smem_desc()
+                desc_b = T.alloc_tcgen05_smem_desc()
+                emi.init_tcgen05_a_desc(desc_a, A_shared, a_params)
+                emi.init_tcgen05_b_desc(desc_b, B_shared, b_params)
+
+                for n in T.unroll(num_inst_n):
+                    for m in T.unroll(num_inst_m):
+                        for ki in T.unroll(0, num_k_atoms):
+                            emi.tcgen05_ss_atom(desc_a, desc_b, C_tmem, m, n, ki, a_params, b_params, instr_desc, T.bool(True))
+                emi.tcgen05_atom_arrive(mbar)
+            T.mbarrier_wait_parity(mbar, 0)
+
+            T.copy(C_tmem, C_local)
+            T.copy(C_local, C_shared)
+            T.copy(C_shared, D[0:M, 0:N])
+
+    return main
+
+
+@tilelang.testing.requires_cuda
+@tilelang.testing.requires_cuda_compute_version_eq(10, 0)
+def test_tcgen05_atom_gemm():
+    M, N, K = 128, 128, 128
+    in_dtype = T.bfloat16
+    out_dtype = T.bfloat16
+    kernel = tilelang.compile(
+        make_tcgen05_atom_kernel(M, N, K, in_dtype, out_dtype, T.float32),
+        target="cuda",
+        out_idx=[2],
+    )
+    src = kernel.get_kernel_source()
+    assert "tcgen05mma_ss" in src
+    assert "threadIdx.x) >> 5) == 0" in src  # the UMMA atoms are issued from warp 0 only
+
+    a = torch.randn(M, K, device="cuda", dtype=in_dtype.as_torch())
+    b = torch.randn(N, K, device="cuda", dtype=in_dtype.as_torch())
+    d = kernel(a, b)
+    ref = (a.float() @ b.T.float()).to(out_dtype.as_torch())
+    torch.testing.assert_close(d, ref, rtol=1e-2, atol=1e-2)
+
+
+if __name__ == "__main__":
+    tilelang.testing.main()
diff --git a/tilelang/cuda/intrinsics/macro/mma_macro_generator.py b/tilelang/cuda/intrinsics/macro/mma_macro_generator.py
index 6461dbd88..87da62999 100644
--- a/tilelang/cuda/intrinsics/macro/mma_macro_generator.py
+++ b/tilelang/cuda/intrinsics/macro/mma_macro_generator.py
@@ -482,6 +482,58 @@ def _warp_ldmatrix_b(
     def mma(self, A_local_buf: Buffer, B_local_buf: Buffer, C_local_buf: Buffer, k_inner: PrimExpr | None = 0):
         warp_rows = self.warp_rows
         warp_cols = self.warp_cols
+
+        @T.macro
+        def _warp_mma(A_local_buf, B_local_buf, C_local_buf):
+            for i, j in T.grid(warp_rows, warp_cols):
+                self.mma_atom(A_local_buf, B_local_buf, C_local_buf, i, j, k_inner)
+
+        return _warp_mma(A_local_buf, B_local_buf, C_local_buf)
+
+    # ---- Atom-level interface ----
+
+    @property
+    def mma_num_inst_m(self) -> int:
+        """Number of MMA instruction atoms along the M dimension."""
+        return self.warp_rows
+
+    @property
+    def mma_num_inst_n(self) -> int:
+        """Number of MMA instruction atoms along the N dimension."""
+        return self.warp_cols
+
+    def mma_atom(
+        self,
+        A_local_buf: Buffer,
+        B_local_buf: Buffer,
+        C_local_buf: Buffer,
+        inst_m_idx: PrimExpr | int,
+        inst_n_idx: PrimExpr | int,
+        k_inner: PrimExpr | int = 0,
+    ):
+        """Emit a single MMA atom for tile (inst_m_idx, inst_n_idx).
+
+        This is the atomic building block of ``mma()``. Calling this method
+        for every ``(i, j)`` in ``T.grid(mma_num_inst_m, mma_num_inst_n)``
+        produces identical TIR to a single ``mma()`` call.
+
+        Parameters
+        ----------
+        A_local_buf : Buffer
+            Fragment buffer for operand A.
+        B_local_buf : Buffer
+            Fragment buffer for operand B.
+        C_local_buf : Buffer
+            Accumulator fragment buffer.
+        inst_m_idx : int or PrimExpr
+            M-dimension atom index (0 .. mma_num_inst_m - 1).
+        inst_n_idx : int or PrimExpr
+            N-dimension atom index (0 .. mma_num_inst_n - 1).
+        k_inner : int or PrimExpr
+            K-inner step index used to offset A/B fragments.
+        """
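+        # Caller-side sketch of the equivalence stated above (buffer names are
+        # hypothetical):
+        #     for i, j in T.grid(self.mma_num_inst_m, self.mma_num_inst_n):
+        #         self.mma_atom(A_local, B_local, C_local, i, j, k_inner)
+        # emits the same TIR as a single self.mma(A_local, B_local, C_local, k_inner).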
+ """ + warp_rows = self.warp_rows + warp_cols = self.warp_cols local_size_a = self.local_size_a local_size_b = self.local_size_b local_size_out = self.local_size_out @@ -497,9 +549,29 @@ def mma(self, A_local_buf: Buffer, B_local_buf: Buffer, C_local_buf: Buffer, k_i a_local_stride: PrimExpr = k_inner * warp_rows * local_size_a if a_is_fragment else 0 b_local_stride: PrimExpr = k_inner * warp_cols * local_size_b if b_is_fragment else 0 + A_offset = a_local_stride + inst_m_idx * local_size_a + B_offset = b_local_stride + inst_n_idx * local_size_b + C_offset = inst_m_idx * warp_cols * local_size_out + inst_n_idx * local_size_out + @T.macro - def _warp_mma(A_local_buf, B_local_buf, C_local_buf): - for i, j in T.grid(warp_rows, warp_cols): + def _atom_mma(A_local_buf, B_local_buf, C_local_buf): + T.ptx_mma( + accum_dtype, + mma_prefix, + "row", + "col", + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_local_buf.data, + A_offset, + B_local_buf.data, + B_offset, + C_local_buf.data, + C_offset, + T.bool(False), + ) + if replicate_b: T.ptx_mma( accum_dtype, mma_prefix, @@ -509,32 +581,15 @@ def _warp_mma(A_local_buf, B_local_buf, C_local_buf): b_dtype_abbrv, accum_dtype_abbrv, A_local_buf.data, - a_local_stride + i * local_size_a, + A_offset, B_local_buf.data, - b_local_stride + j * local_size_b, + B_offset + lift(local_size_b) // 2, C_local_buf.data, - i * warp_cols * local_size_out + j * local_size_out, - T.bool(False), # saturate + C_offset + lift(local_size_out) // 2, + T.bool(False), ) - if replicate_b: - T.ptx_mma( - accum_dtype, - mma_prefix, - "row", - "col", - a_dtype_abbrv, - b_dtype_abbrv, - accum_dtype_abbrv, - A_local_buf.data, - a_local_stride + i * local_size_a, - B_local_buf.data, - b_local_stride + j * local_size_b + lift(local_size_b) // 2, - C_local_buf.data, - i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2, - T.bool(False), # saturate - ) - return _warp_mma(A_local_buf, B_local_buf, C_local_buf) + return _atom_mma(A_local_buf, B_local_buf, C_local_buf) def stmatrix(self, C_local_buf, C_buf, pid_m=None, pid_n=None): block_row_warps = self.block_row_warps diff --git a/tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py b/tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py index 7799b7915..f1f52c77c 100644 --- a/tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py +++ b/tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py @@ -1,4 +1,5 @@ from __future__ import annotations +from dataclasses import dataclass from enum import IntEnum import tilelang.language as T from .mma_macro_generator import TensorCoreIntrinEmitter as MMAIntrinEmitter @@ -19,6 +20,30 @@ lift = convert +@dataclass(frozen=True) +class TCGEN05DescriptorParams: + """Pre-computed parameters for TCGEN05 descriptor initialization and atom offset computation. + + Returned by ``compute_tcgen05_*_desc_params()`` and consumed by + ``init_tcgen05_*_desc()`` and ``tcgen05_*_atom()`` methods. 
+ """ + + swizzle_mode: int + """SwizzleMode enum value (passed directly to ``T.initialize_tcgen05_descriptor``).""" + leading_byte_offset: int + """LBO >> 4, ready to pass to ``T.initialize_tcgen05_descriptor``.""" + stride_byte_offset: int + """SBO >> 4, ready to pass to ``T.initialize_tcgen05_descriptor``.""" + swizzle_atom_elems: int + """Number of elements per swizzle atom along the non-K dimension.""" + k_atom_size: int + """``max(swizzle_atom_elems // micro_size_k, 1)``.""" + elems_in_bytes: int + """Byte width of a single element: ``(DataType(dtype).bits + 7) // 8``.""" + is_k_major: bool + """Whether the matrix is stored in K-major order (affects offset formula branching).""" + + class SwizzleMode(IntEnum): # SWIZZLE_NONE = 0, SWIZZLE_32B = 3, SWIZZLE_64B = 2, SWIZZLE_128B = 1 NONE = 0 @@ -185,207 +210,29 @@ def tcgen05mma_ss(self, A_buf: Buffer, B_buf: Buffer, C_local_buf: Buffer, mbar, clear_accum : PrimExpr Whether to zero the accumulator before the first MMA. """ - accum_dtype = self.accum_dtype - m_dim = self.block_row_warps * self.warp_row_tiles micro_size_k = self.micro_size_k - k_dim, n_dim = self.chunk, self.block_col_warps * self.warp_col_tiles - meta = self.meta - if len(meta) != 5: - raise ValueError( - f"Unsupported TCGEN5MMA configuration for desc generation: M={m_dim}, N={n_dim}, " - f"K={k_dim}, A dtype={self.a_dtype}, accum dtype={self.accum_dtype}" - ) - atom_m, atom_n, atom_k, enable_ws, enable_2cta = (int(x) for x in meta) - atom_m_per_cta = atom_m // 2 if enable_2cta else atom_m - n_dim_per_cta = n_dim // 2 if enable_2cta else n_dim - scale_in_a = 1 - scale_in_b = 1 - + k_dim = self.chunk assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}" - a_is_k_major = not self.a_transposed - b_is_k_major = self.b_transposed - a_swizzle_mode = self._determinate_swizzle_mode(A_buf, self.a_shared_layout) - b_swizzle_mode = self._determinate_swizzle_mode(B_buf, self.b_shared_layout) - - elems_in_bits = DataType(self.a_dtype).bits - elems_in_bytes = (elems_in_bits + 7) // 8 - a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // elems_in_bytes - b_swizzle_atom_elems = n_dim_per_cta if b_swizzle_mode.is_none() else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes - accum_dtype_in_bits = DataType(accum_dtype).bits - - # by default, we utilize non-swizzle layout offset - a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim * elems_in_bytes) - a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 * elems_in_bytes) - - if not a_swizzle_mode.is_none(): - # swizzle mode doesn't require LBO/SBO to be 1 - # https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset - if a_is_k_major: - a_leading_byte_offset = 16 - a_stride_byte_offset = 8 * a_swizzle_mode.swizzle_byte_size() - else: - # MN Major - # LBO represents the distance between two atoms along the M dimension - # SBO represents the distance between two atoms along the K dimension - a_m_axis_atoms = m_dim // a_swizzle_atom_elems - if a_m_axis_atoms <= 1: - a_leading_byte_offset = 0 - else: - a_leading_byte_offset = k_dim * a_swizzle_mode.swizzle_byte_size() - - if a_m_axis_atoms <= 1: - a_stride_byte_offset = 8 * elems_in_bytes * m_dim - else: - a_stride_byte_offset = 8 * elems_in_bytes * a_swizzle_atom_elems - - b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim_per_cta * elems_in_bytes) - b_stride_byte_offset = (8 * k_dim * 
+
+
 class SwizzleMode(IntEnum):
     # SWIZZLE_NONE = 0, SWIZZLE_32B = 3, SWIZZLE_64B = 2, SWIZZLE_128B = 1
     NONE = 0
@@ -185,207 +210,29 @@ def tcgen05mma_ss(self, A_buf: Buffer, B_buf: Buffer, C_local_buf: Buffer, mbar,
         clear_accum : PrimExpr
             Whether to zero the accumulator before the first MMA.
         """
-        accum_dtype = self.accum_dtype
-        m_dim = self.block_row_warps * self.warp_row_tiles
         micro_size_k = self.micro_size_k
-        k_dim, n_dim = self.chunk, self.block_col_warps * self.warp_col_tiles
-        meta = self.meta
-        if len(meta) != 5:
-            raise ValueError(
-                f"Unsupported TCGEN5MMA configuration for desc generation: M={m_dim}, N={n_dim}, "
-                f"K={k_dim}, A dtype={self.a_dtype}, accum dtype={self.accum_dtype}"
-            )
-        atom_m, atom_n, atom_k, enable_ws, enable_2cta = (int(x) for x in meta)
-        atom_m_per_cta = atom_m // 2 if enable_2cta else atom_m
-        n_dim_per_cta = n_dim // 2 if enable_2cta else n_dim
-        scale_in_a = 1
-        scale_in_b = 1
-
+        k_dim = self.chunk
         assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}"
 
-        a_is_k_major = not self.a_transposed
-        b_is_k_major = self.b_transposed
-        a_swizzle_mode = self._determinate_swizzle_mode(A_buf, self.a_shared_layout)
-        b_swizzle_mode = self._determinate_swizzle_mode(B_buf, self.b_shared_layout)
-
-        elems_in_bits = DataType(self.a_dtype).bits
-        elems_in_bytes = (elems_in_bits + 7) // 8
-        a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // elems_in_bytes
-        b_swizzle_atom_elems = n_dim_per_cta if b_swizzle_mode.is_none() else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
-        accum_dtype_in_bits = DataType(accum_dtype).bits
-
-        # by default, we utilize non-swizzle layout offset
-        a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim * elems_in_bytes)
-        a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 * elems_in_bytes)
-
-        if not a_swizzle_mode.is_none():
-            # swizzle mode doesn't require LBO/SBO to be 1
-            # https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset
-            if a_is_k_major:
-                a_leading_byte_offset = 16
-                a_stride_byte_offset = 8 * a_swizzle_mode.swizzle_byte_size()
-            else:
-                # MN Major
-                # LBO represents the distance between two atoms along the M dimension
-                # SBO represents the distance between two atoms along the K dimension
-                a_m_axis_atoms = m_dim // a_swizzle_atom_elems
-                if a_m_axis_atoms <= 1:
-                    a_leading_byte_offset = 0
-                else:
-                    a_leading_byte_offset = k_dim * a_swizzle_mode.swizzle_byte_size()
-
-                if a_m_axis_atoms <= 1:
-                    a_stride_byte_offset = 8 * elems_in_bytes * m_dim
-                else:
-                    a_stride_byte_offset = 8 * elems_in_bytes * a_swizzle_atom_elems
-
-        b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim_per_cta * elems_in_bytes)
-        b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (0 if n_dim_per_cta == 8 else (8 * 8 * elems_in_bytes))
-        if not b_swizzle_mode.is_none():
-            # swizzle mode doesn't require LBO/SBO to be 1
-            # https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset
-            if b_is_k_major:
-                b_leading_byte_offset = 16
-                b_stride_byte_offset = 8 * b_swizzle_mode.swizzle_byte_size()
-            else:
-                # MN Major, K * N
-                # LBO represents the distance between two atoms along the N dimension
-                # SBO represents the distance between two atoms along the K dimension
-                b_n_axis_atoms = n_dim_per_cta // b_swizzle_atom_elems
-                if b_n_axis_atoms <= 1:
-                    b_leading_byte_offset = 0
-                else:
-                    b_leading_byte_offset = 8 * 8 * elems_in_bytes * k_dim
-                if b_n_axis_atoms <= 1:
-                    b_stride_byte_offset = 8 * elems_in_bytes * n_dim_per_cta
-                else:
-                    b_stride_byte_offset = 8 * elems_in_bytes * b_swizzle_atom_elems
-
-        # for example, if [n, k] where k is 128, we should split it into 2 atoms
-        # where max specially handles the case when n_dim_per_cta is 8.
-        ak_atom_size = max(a_swizzle_atom_elems // micro_size_k, 1)
-        bk_atom_size = max(b_swizzle_atom_elems // micro_size_k, 1)
-
-        instr_desc = self.get_tcgen5_instr_desc(
-            atom_m,
-            atom_n,
-            atom_k,
-            a_is_k_major,
-            b_is_k_major,
-            scale_in_a,
-            scale_in_b,
-        )
-        # Allocate an instruction descriptor wrapper and initialize it
-        a_dtype_abbrv = self.a_dtype_abbrv
-        mask_zero = T.cast(0, T.int32)
-        mask0 = mask1 = mask2 = mask3 = mask_zero
-
-        # TCGEN05 only has one warp group
-        num_inst_m = self.block_row_warps * self.warp_row_tiles // atom_m_per_cta
-        num_inst_n = self.block_col_warps * self.warp_col_tiles // atom_n
-
-        # Helper to allow BufferRegion/BufferLoad as inputs
-        def access_ptr_from(buffer_or_load_or_region, access_type: str = "r"):
-            if isinstance(buffer_or_load_or_region, Buffer):
-                return buffer_or_load_or_region.access_ptr(access_type)
-            elif isinstance(buffer_or_load_or_region, BufferLoad):
-                buffer_load = buffer_or_load_or_region
-                offset, stride = 0, 1
-                buffer = buffer_load.buffer
-                for i, shape in enumerate(reversed(buffer.shape)):
-                    indice = buffer_load.indices[len(buffer_load.indices) - i - 1]
-                    if isinstance(indice, tvm.tir.Ramp):
-                        offset += indice.base * stride
-                    elif isinstance(indice, (tvm.tir.IntImm, tvm.tir.PrimExpr)):
-                        offset += indice * stride
-                    else:
-                        raise ValueError(f"Unsupported index type: {type(indice)}")
-                    stride *= shape
-                return buffer.access_ptr(access_type, offset=offset)
-            elif isinstance(buffer_or_load_or_region, BufferRegion):
-                buffer_region = buffer_or_load_or_region
-                buffer = buffer_region.buffer
-                offset, stride = 0, 1
-                for i, shape in enumerate(reversed(buffer.shape)):
-                    offset += buffer_region.region[len(buffer_region.region) - i - 1].min * stride
-                    stride *= shape
-                return buffer.access_ptr(access_type, offset=offset)
-            else:
-                raise ValueError(f"Unsupported buffer type: {type(buffer_or_load_or_region)}")
+        num_inst_m = self.tcgen05_num_inst_m
+        num_inst_n = self.tcgen05_num_inst_n
+        num_k_atoms = self.tcgen05_num_k_atoms
+        a_params = self.compute_tcgen05_a_desc_params(A_buf)
+        b_params = self.compute_tcgen05_b_desc_params(B_buf)
+        instr_desc = self.compute_tcgen05_instr_desc()
 
         @T.macro
         def _warp_mma_ss(A_buf, B_buf, C_local_buf, mbar):
-            # Allocate SMEM descriptors for A and B
             desc_a = T.alloc_tcgen05_smem_desc()
             desc_b = T.alloc_tcgen05_smem_desc()
-            A_ptr = access_ptr_from(A_buf, "r")
-            B_ptr = access_ptr_from(B_buf, "r")
-
-            T.initialize_tcgen05_descriptor(
-                desc_a,
-                A_ptr,
-                int(a_leading_byte_offset >> 4),
-                int(a_stride_byte_offset >> 4),
-                0,
-                False,
-                int(a_swizzle_mode),
-            )
-            T.initialize_tcgen05_descriptor(
-                desc_b,
-                B_ptr,
-                int(b_leading_byte_offset >> 4),
-                int(b_stride_byte_offset >> 4),
-                0,
-                False,
-                int(b_swizzle_mode),
-            )
+            self.init_tcgen05_a_desc(desc_a, A_buf, a_params)
+            self.init_tcgen05_b_desc(desc_b, B_buf, b_params)
 
-            tmem_col_step = atom_n // (128 // atom_m_per_cta)
             for j in T.unroll(num_inst_n):
                 for i in T.unroll(num_inst_m):
-                    for ki in T.unroll(0, (k_dim // micro_size_k)):
-                        scale_out = T.Select(ki != 0, 1, T.Select(clear_accum, 0, 1))
-                        A_elem_offset = (
-                            (ki % ak_atom_size) * micro_size_k
-                            + i * atom_m_per_cta * a_swizzle_atom_elems
-                            + (ki // ak_atom_size) * m_dim * a_swizzle_atom_elems
-                            if a_is_k_major
-                            else i * atom_m_per_cta * k_dim + ki * a_swizzle_atom_elems * micro_size_k
-                        )
-
-                        B_elem_offset = (
-                            (ki // bk_atom_size) * n_dim_per_cta * b_swizzle_atom_elems
-                            + (ki % bk_atom_size) * micro_size_k
-                            + j * atom_n * b_swizzle_atom_elems
-                            if b_is_k_major
-                            else (
-                                ki * b_swizzle_atom_elems * micro_size_k
-                                + j * atom_n * (k_dim if n_dim_per_cta // b_swizzle_atom_elems > 1 else 1)
-                            )
-                        )
-
-                        A_byte_offset = A_elem_offset * elems_in_bytes
-                        B_byte_offset = B_elem_offset * elems_in_bytes
-                        C_offset = (i * n_dim + j * tmem_col_step) * accum_dtype_in_bits // 32  # 32 bits per tmem bank
-
-                        T.ptx_tcgen05_mma_ss(
-                            a_dtype_abbrv,
-                            desc_a.data,
-                            A_byte_offset,
-                            desc_b.data,
-                            B_byte_offset,
-                            C_local_buf.data,
-                            C_offset,
-                            instr_desc,
-                            scale_out,
-                            mask0,
-                            mask1,
-                            mask2,
-                            mask3,
-                            enable_ws,
-                            enable_2cta,
-                        )
-            T.tcgen05_mma_arrive(mbar, arrive_2cta=enable_2cta)
+                    for ki in T.unroll(0, num_k_atoms):
+                        self.tcgen05_ss_atom(desc_a, desc_b, C_local_buf, i, j, ki, a_params, b_params, instr_desc, clear_accum)
+            self.tcgen05_atom_arrive(mbar)
 
         return _warp_mma_ss(A_buf, B_buf, C_local_buf, mbar)
@@ -410,103 +257,15 @@ def tcgen05mma_ts(self, A_buf, B_buf, C_local_buf, mbar, clear_accum: PrimExpr =
         clear_accum : PrimExpr
             Whether to zero the accumulator before the first MMA.
         """
-        accum_dtype = self.accum_dtype
-        m_dim = self.block_row_warps * self.warp_row_tiles
         micro_size_k = self.micro_size_k
-        k_dim, n_dim = self.chunk, self.block_col_warps * self.warp_col_tiles
-        meta = self.meta
-        if len(meta) != 5:
-            raise ValueError(
-                f"Unsupported TCGEN5MMA configuration for desc generation: M={m_dim}, N={n_dim}, "
-                f"K={k_dim}, A dtype={self.a_dtype}, accum dtype={self.accum_dtype}"
-            )
-        atom_m, atom_n, atom_k, enable_ws, enable_2cta = (int(x) for x in meta)
-        atom_m_per_cta = atom_m // 2 if enable_2cta else atom_m
-        n_dim_per_cta = n_dim // 2 if enable_2cta else n_dim
-        scale_in_a = 1
-        scale_in_b = 1
-
+        k_dim = self.chunk
         assert k_dim >= micro_size_k, f"k_dim must be >= {micro_size_k}, got {k_dim}"
 
-        b_is_k_major = self.b_transposed
-        b_swizzle_mode = self._determinate_swizzle_mode(B_buf, self.b_shared_layout)
-
-        a_dtype_in_bits = DataType(self.a_dtype).bits
-        elems_in_bits = a_dtype_in_bits
-        elems_in_bytes = (elems_in_bits + 7) // 8
-        b_swizzle_atom_elems = n_dim_per_cta if b_swizzle_mode.is_none() else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
-        accum_dtype_in_bits = DataType(accum_dtype).bits
-
-        # B descriptor parameters (same as SS)
-        b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim_per_cta * elems_in_bytes)
-        b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (0 if n_dim_per_cta == 8 else (8 * 8 * elems_in_bytes))
-        if not b_swizzle_mode.is_none():
-            if b_is_k_major:
-                b_leading_byte_offset = 16
-                b_stride_byte_offset = 8 * b_swizzle_mode.swizzle_byte_size()
-            else:
-                b_n_axis_atoms = n_dim_per_cta // b_swizzle_atom_elems
-                if b_n_axis_atoms <= 1:
-                    b_leading_byte_offset = 0
-                else:
-                    b_leading_byte_offset = 8 * 8 * elems_in_bytes * k_dim
-                if b_n_axis_atoms <= 1:
-                    b_stride_byte_offset = 8 * elems_in_bytes * n_dim_per_cta
-                else:
-                    b_stride_byte_offset = 8 * elems_in_bytes * b_swizzle_atom_elems
-
-        bk_atom_size = max(b_swizzle_atom_elems // micro_size_k, 1)
-
-        a_is_k_major = not self.a_transposed
-        instr_desc = self.get_tcgen5_instr_desc(
-            atom_m,
-            atom_n,
-            atom_k,
-            a_is_k_major,
-            b_is_k_major,
-            scale_in_a,
-            scale_in_b,
-        )
-        a_dtype_abbrv = self.a_dtype_abbrv
-        mask_zero = T.cast(0, T.int32)
-        mask0 = mask1 = mask2 = mask3 = mask_zero
-
-        num_inst_m = m_dim // atom_m_per_cta
-        num_inst_n = n_dim // atom_n
-
-        # TMEM column geometry for A operand
-        # Each TMEM column is 32 bits; row interleaving factor = 128 / atom_m
-        interleave = max(128 // atom_m, 1)
-        a_tmem_cols_per_k_atom = atom_k * a_dtype_in_bits // 32 // interleave
-        a_tmem_k_stride = k_dim * a_dtype_in_bits // 32 // interleave
-
-        def access_ptr_from(buffer_or_load_or_region, access_type: str = "r"):
-            if isinstance(buffer_or_load_or_region, Buffer):
-                return buffer_or_load_or_region.access_ptr(access_type)
-            elif isinstance(buffer_or_load_or_region, BufferLoad):
-                buffer_load = buffer_or_load_or_region
-                offset, stride = 0, 1
-                buffer = buffer_load.buffer
-                for i, shape in enumerate(reversed(buffer.shape)):
-                    indice = buffer_load.indices[len(buffer_load.indices) - i - 1]
-                    if isinstance(indice, tvm.tir.Ramp):
-                        offset += indice.base * stride
-                    elif isinstance(indice, (tvm.tir.IntImm, tvm.tir.PrimExpr)):
-                        offset += indice * stride
-                    else:
-                        raise ValueError(f"Unsupported index type: {type(indice)}")
-                    stride *= shape
-                return buffer.access_ptr(access_type, offset=offset)
-            elif isinstance(buffer_or_load_or_region, BufferRegion):
-                buffer_region = buffer_or_load_or_region
-                buffer = buffer_region.buffer
-                offset, stride = 0, 1
-                for i, shape in enumerate(reversed(buffer.shape)):
-                    offset += buffer_region.region[len(buffer_region.region) - i - 1].min * stride
-                    stride *= shape
-                return buffer.access_ptr(access_type, offset=offset)
-            else:
-                raise ValueError(f"Unsupported buffer type: {type(buffer_or_load_or_region)}")
+        num_inst_m = self.tcgen05_num_inst_m
+        num_inst_n = self.tcgen05_num_inst_n
+        num_k_atoms = self.tcgen05_num_k_atoms
+        b_params = self.compute_tcgen05_b_desc_params(B_buf)
+        instr_desc = self.compute_tcgen05_instr_desc()
 
         # Resolve the TMEM data pointer for A
         if isinstance(A_buf, BufferRegion):
@@ -519,59 +278,13 @@ def access_ptr_from(buffer_or_load_or_region, access_type: str = "r"):
         @T.macro
         def _warp_mma_ts(a_data, B_buf, C_local_buf, mbar):
             desc_b = T.alloc_tcgen05_smem_desc()
-            B_ptr = access_ptr_from(B_buf, "r")
-
-            T.initialize_tcgen05_descriptor(
-                desc_b,
-                B_ptr,
-                int(b_leading_byte_offset >> 4),
-                int(b_stride_byte_offset >> 4),
-                0,
-                False,
-                int(b_swizzle_mode),
-            )
+            self.init_tcgen05_b_desc(desc_b, B_buf, b_params)
 
-            tmem_col_step = atom_n // (128 // atom_m_per_cta)
             for j in T.unroll(num_inst_n):
                 for i in T.unroll(num_inst_m):
-                    for ki in T.unroll(0, (k_dim // micro_size_k)):
-                        scale_out = T.Select(ki != 0, 1, T.Select(clear_accum, 0, 1))
-
-                        # A: TMEM column offset
-                        A_tmem_offset = i * a_tmem_k_stride + ki * a_tmem_cols_per_k_atom
-
-                        # B: SMEM byte offset (same as SS)
-                        B_elem_offset = (
-                            (ki // bk_atom_size) * n_dim_per_cta * b_swizzle_atom_elems
-                            + (ki % bk_atom_size) * micro_size_k
-                            + j * atom_n * b_swizzle_atom_elems
-                            if b_is_k_major
-                            else (
-                                ki * b_swizzle_atom_elems * micro_size_k
-                                + j * atom_n * (k_dim if n_dim_per_cta // b_swizzle_atom_elems > 1 else 1)
-                            )
-                        )
-                        B_byte_offset = B_elem_offset * elems_in_bytes
-
-                        # C: TMEM column offset (same as SS)
-                        C_offset = (i * n_dim + j * tmem_col_step) * accum_dtype_in_bits // 32
-
-                        T.ptx_tcgen05_mma_ts(
-                            a_dtype_abbrv,
-                            a_data,
-                            A_tmem_offset,
-                            desc_b.data,
-                            B_byte_offset,
-                            C_local_buf.data,
-                            C_offset,
-                            instr_desc,
-                            scale_out,
-                            mask0,
-                            mask1,
-                            mask2,
-                            mask3,
-                        )
-            T.tcgen05_mma_arrive(mbar, arrive_2cta=enable_2cta)
+                    for ki in T.unroll(0, num_k_atoms):
+                        self.tcgen05_ts_atom(a_data, desc_b, C_local_buf, i, j, ki, b_params, instr_desc, clear_accum)
+            self.tcgen05_atom_arrive(mbar)
 
         return _warp_mma_ts(a_tmem_data, B_buf, C_local_buf, mbar)
@@ -592,23 +305,17 @@ def tcgen05mma_blockscaled(
         Uses ``tcgen05.mma.cta_group::1|2.kind::mxf8f6f4.block_scale`` PTX instruction.
         Scale factors must already reside in tensor memory.
         """
-        accum_dtype = self.accum_dtype
         m_dim = self.block_row_warps * self.warp_row_tiles
         micro_size_k = self.micro_size_k
         k_dim, n_dim = self.chunk, self.block_col_warps * self.warp_col_tiles
-        scale_in_a = 1
-        scale_in_b = 1
 
         assert k_dim >= micro_size_k
 
         a_is_k_major = not self.a_transposed
         b_is_k_major = self.b_transposed
         a_swizzle_mode = self._determinate_swizzle_mode(A_buf, self.a_shared_layout)
-        b_swizzle_mode = self._determinate_swizzle_mode(B_buf, self.b_shared_layout)
 
-        elems_in_bits = DataType(self.a_dtype).bits
-        elems_in_bytes = (elems_in_bits + 7) // 8
-        accum_dtype_in_bits = DataType(accum_dtype).bits
+        elems_in_bytes = (DataType(self.a_dtype).bits + 7) // 8
 
         if len(self.meta) != 5:
             self.get_tcgen5_mma_meta(m_dim, n_dim, k_dim, disable_2cta=False)
@@ -617,14 +324,12 @@
                 f"Unsupported TCGEN5MMA configuration for block-scaled: M={m_dim}, N={n_dim}, "
                 f"K={k_dim}, A dtype={self.a_dtype}, accum dtype={self.accum_dtype}"
             )
-        atom_m, atom_n, atom_k, _enable_ws, enable_2cta = (int(x) for x in self.meta)
-        enable_ws = 0
+        atom_m, atom_n, _, _, enable_2cta = self.tcgen05_meta_unpacked
        atom_m_per_cta = atom_m // 2 if enable_2cta else atom_m
-        n_dim_per_cta = n_dim // 2 if enable_2cta else n_dim
 
         a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // elems_in_bytes
-        b_swizzle_atom_elems = n_dim_per_cta if b_swizzle_mode.is_none() else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
 
+        # Block-scaled A LBO/SBO differ from regular SS (uses atom_m_per_cta instead of m_dim)
         a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * atom_m_per_cta * elems_in_bytes)
         a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 * elems_in_bytes)
         if not a_swizzle_mode.is_none():
@@ -638,64 +343,31 @@
                 8 * elems_in_bytes * a_swizzle_atom_elems if a_m_axis_atoms > 1 else 8 * elems_in_bytes * atom_m_per_cta
             )
 
-        b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim_per_cta * elems_in_bytes)
-        b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (0 if n_dim_per_cta == 8 else (8 * 8 * elems_in_bytes))
-        if not b_swizzle_mode.is_none():
-            if b_is_k_major:
-                b_leading_byte_offset = 16
-                b_stride_byte_offset = 8 * b_swizzle_mode.swizzle_byte_size()
-            else:
-                b_n_axis_atoms = n_dim_per_cta // b_swizzle_atom_elems
-                b_leading_byte_offset = b_swizzle_mode.swizzle_byte_size() * k_dim if b_n_axis_atoms > 1 else 0
-                b_stride_byte_offset = (
-                    8 * elems_in_bytes * b_swizzle_atom_elems if b_n_axis_atoms > 1 else 8 * elems_in_bytes * n_dim_per_cta
-                )
-
-        ak_atom_size = max(a_swizzle_atom_elems // micro_size_k, 1)
-        bk_atom_size = max(b_swizzle_atom_elems // micro_size_k, 1)
+        a_params = TCGEN05DescriptorParams(
+            swizzle_mode=int(a_swizzle_mode),
+            leading_byte_offset=int(a_leading_byte_offset >> 4),
+            stride_byte_offset=int(a_stride_byte_offset >> 4),
+            swizzle_atom_elems=a_swizzle_atom_elems,
+            k_atom_size=max(a_swizzle_atom_elems // micro_size_k, 1),
+            elems_in_bytes=elems_in_bytes,
+            is_k_major=a_is_k_major,
+        )
+        b_params = self.compute_tcgen05_b_desc_params(B_buf)
 
         base_instr_desc = self.get_tcgen5_blockscaled_instr_desc(
             atom_m,
             atom_n,
             a_is_k_major,
             b_is_k_major,
-            scale_in_a,
-            scale_in_b,
+            1,
+            1,
             0,
             0,
         )
-        a_dtype_abbrv = self.a_dtype_abbrv
 
         num_inst_m = m_dim // atom_m_per_cta
         num_inst_n = n_dim // atom_n
-
-        def access_ptr_from(buffer_or_load_or_region, access_type: str = "r"):
-            if isinstance(buffer_or_load_or_region, Buffer):
-                return buffer_or_load_or_region.access_ptr(access_type)
-            elif isinstance(buffer_or_load_or_region, BufferLoad):
-                buffer_load = buffer_or_load_or_region
-                offset, stride = 0, 1
-                buffer = buffer_load.buffer
-                for i, shape in enumerate(reversed(buffer.shape)):
-                    indice = buffer_load.indices[len(buffer_load.indices) - i - 1]
-                    if isinstance(indice, tvm.tir.Ramp):
-                        offset += indice.base * stride
-                    elif isinstance(indice, (tvm.tir.IntImm, tvm.tir.PrimExpr)):
-                        offset += indice * stride
-                    else:
-                        raise ValueError(f"Unsupported index type: {type(indice)}")
-                    stride *= shape
-                return buffer.access_ptr(access_type, offset=offset)
-            elif isinstance(buffer_or_load_or_region, BufferRegion):
-                buffer_region = buffer_or_load_or_region
-                buffer = buffer_region.buffer
-                offset, stride = 0, 1
-                for i, shape in enumerate(reversed(buffer.shape)):
-                    offset += buffer_region.region[len(buffer_region.region) - i - 1].min * stride
-                    stride *= shape
-                return buffer.access_ptr(access_type, offset=offset)
-            else:
-                raise ValueError(f"Unsupported buffer type: {type(buffer_or_load_or_region)}")
+        num_k_atoms = self.tcgen05_num_k_atoms
 
         if isinstance(SFA_tmem, BufferRegion):
             sfa_data = SFA_tmem.buffer.data
@@ -715,78 +387,30 @@ def access_ptr_from(buffer_or_load_or_region, access_type: str = "r"):
         def _warp_mma_blockscaled(A_buf, B_buf, C_local_buf, sfa_data, sfb_data, mbar):
             desc_a = T.alloc_tcgen05_smem_desc()
             desc_b = T.alloc_tcgen05_smem_desc()
-            A_ptr = access_ptr_from(A_buf, "r")
-            B_ptr = access_ptr_from(B_buf, "r")
-
-            T.initialize_tcgen05_descriptor(
-                desc_a,
-                A_ptr,
-                int(a_leading_byte_offset >> 4),
-                int(a_stride_byte_offset >> 4),
-                0,
-                False,
-                int(a_swizzle_mode),
-            )
-            T.initialize_tcgen05_descriptor(
-                desc_b,
-                B_ptr,
-                int(b_leading_byte_offset >> 4),
-                int(b_stride_byte_offset >> 4),
-                0,
-                False,
-                int(b_swizzle_mode),
-            )
+            self.init_tcgen05_a_desc(desc_a, A_buf, a_params)
+            self.init_tcgen05_b_desc(desc_b, B_buf, b_params)
 
-            tmem_col_step = atom_n // (128 // atom_m_per_cta)
             _sf_a = tvm.tir.const(sf_a_id, "int32") if isinstance(sf_a_id, int) else sf_a_id
             _sf_b = tvm.tir.const(sf_b_id, "int32") if isinstance(sf_b_id, int) else sf_b_id
             runtime_instr_desc = base_instr_desc | (_sf_a << 29) | (_sf_b << 4)
             for j in T.unroll(num_inst_n):
                 for i in T.unroll(num_inst_m):
-                    for ki in T.unroll(0, (k_dim // micro_size_k)):
-                        scale_out = T.Select(ki != 0, 1, T.Select(clear_accum, 0, 1))
-                        A_elem_offset = (
-                            (ki % ak_atom_size) * micro_size_k
-                            + i * atom_m_per_cta * a_swizzle_atom_elems
-                            + (ki // ak_atom_size) * m_dim * a_swizzle_atom_elems
-                            if a_is_k_major
-                            else i * atom_m_per_cta * k_dim + ki * a_swizzle_atom_elems * micro_size_k
-                        )
-                        B_elem_offset = (
-                            (ki // bk_atom_size) * n_dim_per_cta * b_swizzle_atom_elems
-                            + (ki % bk_atom_size) * micro_size_k
-                            + j * atom_n * b_swizzle_atom_elems
-                            if b_is_k_major
-                            else (
-                                ki * b_swizzle_atom_elems * micro_size_k
-                                + j * atom_n * (k_dim if n_dim_per_cta // b_swizzle_atom_elems > 1 else 1)
-                            )
-                        )
-
-                        A_byte_offset = A_elem_offset * elems_in_bytes
-                        B_byte_offset = B_elem_offset * elems_in_bytes
-                        C_offset = (i * n_dim + j * tmem_col_step) * accum_dtype_in_bits // 32
-
-                        T.ptx_tcgen05_mma_blockscaled_ss(
-                            a_dtype_abbrv,
-                            desc_a.data,
-                            A_byte_offset,
-                            desc_b.data,
-                            B_byte_offset,
-                            C_local_buf.data,
-                            C_offset,
-                            runtime_instr_desc,
-                            scale_out,
+                    for ki in T.unroll(0, num_k_atoms):
+                        self.tcgen05_blockscaled_atom(
+                            desc_a,
+                            desc_b,
+                            C_local_buf,
                             sfa_data,
-                            0,
                             sfb_data,
-                            0,
-                            0,
-                            0,
-                            enable_ws,
-                            enable_2cta,
+                            i,
+                            j,
+                            ki,
+                            a_params,
+                            b_params,
+                            runtime_instr_desc,
+                            clear_accum,
                         )
-            T.tcgen05_mma_arrive(mbar, arrive_2cta=enable_2cta)
+            self.tcgen05_atom_arrive(mbar)
 
         return _warp_mma_blockscaled(A_buf, B_buf, C_local_buf, sfa_data, sfb_data, mbar)
@@ -925,3 +549,534 @@ get_tcgen5_instr_desc(
             scale_in_b,
         )
         return lift(desc)
+
+    # ---- Atom-level interface ----
+
+    @property
+    def tcgen05_meta_unpacked(self) -> tuple:
+        """Return ``(atom_m, atom_n, atom_k, enable_ws, enable_2cta)`` as ints.
+
+        Requires ``self.meta`` to have been set via ``get_tcgen5_mma_meta()``.
+        """
+        assert len(self.meta) == 5, "TCGEN05 meta not initialized; call get_tcgen5_mma_meta() first"
+        return tuple(int(x) for x in self.meta)
+
+    @property
+    def tcgen05_num_inst_m(self) -> int:
+        """Number of TCGEN05MMA instruction atoms along M (SS variant)."""
+        atom_m, _, _, _, enable_2cta = self.tcgen05_meta_unpacked
+        atom_m_per_cta = atom_m // 2 if enable_2cta else atom_m
+        return self.block_row_warps * self.warp_row_tiles // atom_m_per_cta
+
+    @property
+    def tcgen05_num_inst_n(self) -> int:
+        """Number of TCGEN05MMA instruction atoms along N."""
+        _, atom_n, _, _, _ = self.tcgen05_meta_unpacked
+        return self.block_col_warps * self.warp_col_tiles // atom_n
+
+    @property
+    def tcgen05_num_k_atoms(self) -> int:
+        """Number of K-dimension micro-steps (``chunk // micro_size_k``)."""
+        return self.chunk // self.micro_size_k
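+
+    # For instance (hypothetical config): chunk=64 with micro_size_k=16 gives
+    # tcgen05_num_k_atoms == 4, i.e. four K micro-steps per (m, n) atom.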
+
+    @staticmethod
+    def _access_ptr_from(buffer_or_load_or_region, access_type: str = "r"):
+        """Resolve an access pointer from a Buffer, BufferLoad, or BufferRegion."""
+        if isinstance(buffer_or_load_or_region, Buffer):
+            return buffer_or_load_or_region.access_ptr(access_type)
+        elif isinstance(buffer_or_load_or_region, BufferLoad):
+            buffer_load = buffer_or_load_or_region
+            offset, stride = 0, 1
+            buffer = buffer_load.buffer
+            for i, shape in enumerate(reversed(buffer.shape)):
+                indice = buffer_load.indices[len(buffer_load.indices) - i - 1]
+                if isinstance(indice, tvm.tir.Ramp):
+                    offset += indice.base * stride
+                elif isinstance(indice, (tvm.tir.IntImm, tvm.tir.PrimExpr)):
+                    offset += indice * stride
+                else:
+                    raise ValueError(f"Unsupported index type: {type(indice)}")
+                stride *= shape
+            return buffer.access_ptr(access_type, offset=offset)
+        elif isinstance(buffer_or_load_or_region, BufferRegion):
+            buffer_region = buffer_or_load_or_region
+            buffer = buffer_region.buffer
+            offset, stride = 0, 1
+            for i, shape in enumerate(reversed(buffer.shape)):
+                offset += buffer_region.region[len(buffer_region.region) - i - 1].min * stride
+                stride *= shape
+            return buffer.access_ptr(access_type, offset=offset)
+        else:
+            raise ValueError(f"Unsupported buffer type: {type(buffer_or_load_or_region)}")
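+
+    # Example (buffers hypothetical): _access_ptr_from(A_shared) returns the base
+    # pointer, while _access_ptr_from(A_shared[64:128, 0:32]) folds the region's
+    # minimum indices into a flat element offset before calling access_ptr.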
+ """ + atom_m, atom_n, _, _, enable_2cta = self.tcgen05_meta_unpacked + n_dim = self.block_col_warps * self.warp_col_tiles + n_dim_per_cta = n_dim // 2 if enable_2cta else n_dim + k_dim = self.chunk + micro_size_k = self.micro_size_k + elems_in_bytes = (DataType(self.a_dtype).bits + 7) // 8 + b_is_k_major = self.b_transposed + + b_swizzle_mode = self._determinate_swizzle_mode(B_buf, self.b_shared_layout) + b_swizzle_atom_elems = n_dim_per_cta if b_swizzle_mode.is_none() else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes + + b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim_per_cta * elems_in_bytes) + b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (0 if n_dim_per_cta == 8 else (8 * 8 * elems_in_bytes)) + if not b_swizzle_mode.is_none(): + if b_is_k_major: + b_leading_byte_offset = 16 + b_stride_byte_offset = 8 * b_swizzle_mode.swizzle_byte_size() + else: + b_n_axis_atoms = n_dim_per_cta // b_swizzle_atom_elems + if b_n_axis_atoms <= 1: + b_leading_byte_offset = 0 + else: + b_leading_byte_offset = 8 * 8 * elems_in_bytes * k_dim + if b_n_axis_atoms <= 1: + b_stride_byte_offset = 8 * elems_in_bytes * n_dim_per_cta + else: + b_stride_byte_offset = 8 * elems_in_bytes * b_swizzle_atom_elems + + return TCGEN05DescriptorParams( + swizzle_mode=int(b_swizzle_mode), + leading_byte_offset=int(b_leading_byte_offset >> 4), + stride_byte_offset=int(b_stride_byte_offset >> 4), + swizzle_atom_elems=b_swizzle_atom_elems, + k_atom_size=max(b_swizzle_atom_elems // micro_size_k, 1), + elems_in_bytes=elems_in_bytes, + is_k_major=b_is_k_major, + ) + + def compute_tcgen05_a_desc_params(self, A_buf) -> TCGEN05DescriptorParams: + """Compute A descriptor parameters from the A shared buffer (SS variant). + + This is a pure-Python helper -- no TIR code is emitted. + + Parameters + ---------- + A_buf : Buffer or BufferRegion + The A operand in shared memory. + """ + m_dim = self.block_row_warps * self.warp_row_tiles + k_dim = self.chunk + micro_size_k = self.micro_size_k + elems_in_bytes = (DataType(self.a_dtype).bits + 7) // 8 + a_is_k_major = not self.a_transposed + + a_swizzle_mode = self._determinate_swizzle_mode(A_buf, self.a_shared_layout) + a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // elems_in_bytes + + a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim * elems_in_bytes) + a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 * elems_in_bytes) + if not a_swizzle_mode.is_none(): + if a_is_k_major: + a_leading_byte_offset = 16 + a_stride_byte_offset = 8 * a_swizzle_mode.swizzle_byte_size() + else: + a_m_axis_atoms = m_dim // a_swizzle_atom_elems + if a_m_axis_atoms <= 1: + a_leading_byte_offset = 0 + else: + a_leading_byte_offset = k_dim * a_swizzle_mode.swizzle_byte_size() + if a_m_axis_atoms <= 1: + a_stride_byte_offset = 8 * elems_in_bytes * m_dim + else: + a_stride_byte_offset = 8 * elems_in_bytes * a_swizzle_atom_elems + + return TCGEN05DescriptorParams( + swizzle_mode=int(a_swizzle_mode), + leading_byte_offset=int(a_leading_byte_offset >> 4), + stride_byte_offset=int(a_stride_byte_offset >> 4), + swizzle_atom_elems=a_swizzle_atom_elems, + k_atom_size=max(a_swizzle_atom_elems // micro_size_k, 1), + elems_in_bytes=elems_in_bytes, + is_k_major=a_is_k_major, + ) + + # -- Descriptor initialization (emit TIR) -- + + def init_tcgen05_b_desc(self, desc_b, B_buf, b_params: TCGEN05DescriptorParams): + """Emit TIR to initialize a pre-allocated TCGEN05 B descriptor. 
+ + Parameters + ---------- + desc_b : Buffer + A descriptor buffer allocated via ``T.alloc_tcgen05_smem_desc()``. + B_buf : Buffer or BufferRegion + The B operand in shared memory. + b_params : TCGEN05DescriptorParams + Pre-computed parameters from ``compute_tcgen05_b_desc_params()``. + """ + access_ptr_from = self._access_ptr_from + lbo = b_params.leading_byte_offset + sbo = b_params.stride_byte_offset + swizzle_mode = b_params.swizzle_mode + B_ptr = access_ptr_from(B_buf, "r") + + @T.macro + def _init_b(desc_b, B_ptr): + T.initialize_tcgen05_descriptor(desc_b, B_ptr, lbo, sbo, 0, False, swizzle_mode) + + return _init_b(desc_b, B_ptr) + + def init_tcgen05_a_desc(self, desc_a, A_buf, a_params: TCGEN05DescriptorParams): + """Emit TIR to initialize a pre-allocated TCGEN05 A descriptor (SS variant). + + Parameters + ---------- + desc_a : Buffer + A descriptor buffer allocated via ``T.alloc_tcgen05_smem_desc()``. + A_buf : Buffer or BufferRegion + The A operand in shared memory. + a_params : TCGEN05DescriptorParams + Pre-computed parameters from ``compute_tcgen05_a_desc_params()``. + """ + access_ptr_from = self._access_ptr_from + lbo = a_params.leading_byte_offset + sbo = a_params.stride_byte_offset + swizzle_mode = a_params.swizzle_mode + A_ptr = access_ptr_from(A_buf, "r") + + @T.macro + def _init_a(desc_a, A_ptr): + T.initialize_tcgen05_descriptor(desc_a, A_ptr, lbo, sbo, 0, False, swizzle_mode) + + return _init_a(desc_a, A_ptr) + + # -- Instruction descriptor computation -- + + def compute_tcgen05_instr_desc(self) -> PrimExpr: + """Compute the 64-bit instruction descriptor using current meta. + + Requires ``self.meta`` to have been set via ``get_tcgen5_mma_meta()``. + """ + atom_m, atom_n, atom_k, _, _ = self.tcgen05_meta_unpacked + a_is_k_major = not self.a_transposed + b_is_k_major = self.b_transposed + return self.get_tcgen5_instr_desc(atom_m, atom_n, atom_k, a_is_k_major, b_is_k_major, 1, 1) + + # -- Arrive -- + + def tcgen05_atom_arrive(self, mbar): + """Emit ``tcgen05_mma_arrive(mbar)``.""" + _, _, _, _, enable_2cta = self.tcgen05_meta_unpacked + + @T.macro + def _arrive(mbar): + T.tcgen05_mma_arrive(mbar, arrive_2cta=bool(enable_2cta)) + + return _arrive(mbar) + + # -- Atom emission -- + + def tcgen05_ss_atom( + self, + desc_a, + desc_b, + C_local_buf: Buffer, + inst_m_idx: int, + inst_n_idx: int, + ki: int, + a_params: TCGEN05DescriptorParams, + b_params: TCGEN05DescriptorParams, + instr_desc: PrimExpr, + clear_accum: PrimExpr = False, + ): + """Emit a single TCGEN05MMA SS instruction for atom ``(inst_m_idx, inst_n_idx, ki)``. + + Must be called after descriptor initialization and before ``tcgen05_atom_arrive()``. + + Parameters + ---------- + desc_a, desc_b : Buffer + Initialized A and B descriptors. + C_local_buf : Buffer + Accumulator buffer in tensor memory. + inst_m_idx : int + M-dimension atom index (0 .. tcgen05_num_inst_m - 1). + inst_n_idx : int + N-dimension atom index (0 .. tcgen05_num_inst_n - 1). + ki : int + K-dimension atom index (0 .. tcgen05_num_k_atoms - 1). + a_params : TCGEN05DescriptorParams + Pre-computed A descriptor parameters. + b_params : TCGEN05DescriptorParams + Pre-computed B descriptor parameters. + instr_desc : PrimExpr + Instruction descriptor from ``compute_tcgen05_instr_desc()``. + clear_accum : PrimExpr + Whether to zero the accumulator on the first K atom. 
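+        # Note: scale_out below is 0 only when ki == 0 and clear_accum is true,
+        # so the first K atom overwrites the accumulator and later atoms add to it.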
+ """ + atom_m, atom_n, _, enable_ws, enable_2cta = self.tcgen05_meta_unpacked + atom_m_per_cta = atom_m // 2 if enable_2cta else atom_m + n_dim = self.block_col_warps * self.warp_col_tiles + n_dim_per_cta = n_dim // 2 if enable_2cta else n_dim + m_dim = self.block_row_warps * self.warp_row_tiles + micro_size_k = self.micro_size_k + k_dim = self.chunk + accum_dtype_in_bits = DataType(self.accum_dtype).bits + a_dtype_abbrv = self.a_dtype_abbrv + a_elems_in_bytes = a_params.elems_in_bytes + b_elems_in_bytes = b_params.elems_in_bytes + ak_atom_size = a_params.k_atom_size + bk_atom_size = b_params.k_atom_size + a_swizzle_atom_elems = a_params.swizzle_atom_elems + b_swizzle_atom_elems = b_params.swizzle_atom_elems + mask_zero = T.cast(0, T.int32) + + # Pre-compute offsets + if a_params.is_k_major: + A_elem_offset = ( + (ki % ak_atom_size) * micro_size_k + + inst_m_idx * atom_m_per_cta * a_swizzle_atom_elems + + (ki // ak_atom_size) * m_dim * a_swizzle_atom_elems + ) + else: + A_elem_offset = inst_m_idx * atom_m_per_cta * k_dim + ki * a_swizzle_atom_elems * micro_size_k + + if b_params.is_k_major: + B_elem_offset = ( + (ki // bk_atom_size) * n_dim_per_cta * b_swizzle_atom_elems + + (ki % bk_atom_size) * micro_size_k + + inst_n_idx * atom_n * b_swizzle_atom_elems + ) + else: + B_elem_offset = ki * b_swizzle_atom_elems * micro_size_k + inst_n_idx * atom_n * ( + k_dim if n_dim_per_cta // b_swizzle_atom_elems > 1 else 1 + ) + + A_byte_offset = A_elem_offset * a_elems_in_bytes + B_byte_offset = B_elem_offset * b_elems_in_bytes + tmem_col_step = atom_n // (128 // atom_m_per_cta) + C_offset = (inst_m_idx * n_dim + inst_n_idx * tmem_col_step) * accum_dtype_in_bits // 32 + + @T.macro + def _ss_atom(desc_a, desc_b, C_local_buf): + scale_out = T.Select(ki != 0, 1, T.Select(clear_accum, 0, 1)) + T.ptx_tcgen05_mma_ss( + a_dtype_abbrv, + desc_a.data, + A_byte_offset, + desc_b.data, + B_byte_offset, + C_local_buf.data, + C_offset, + instr_desc, + scale_out, + mask_zero, + mask_zero, + mask_zero, + mask_zero, + enable_ws, + enable_2cta, + ) + + return _ss_atom(desc_a, desc_b, C_local_buf) + + def tcgen05_ts_atom( + self, + a_tmem_data, + desc_b, + C_local_buf: Buffer, + inst_m_idx: int, + inst_n_idx: int, + ki: int, + b_params: TCGEN05DescriptorParams, + instr_desc: PrimExpr, + clear_accum: PrimExpr = False, + ): + """Emit a single TCGEN05MMA TS instruction for atom ``(inst_m_idx, inst_n_idx, ki)``. + + A resides in tensor memory; B in shared memory. + + Parameters + ---------- + a_tmem_data : Var + Data pointer for the A operand in tensor memory (e.g., ``A_buf.data``). + desc_b : Buffer + Initialized B descriptor. + C_local_buf : Buffer + Accumulator buffer in tensor memory. + inst_m_idx : int + M-dimension atom index. + inst_n_idx : int + N-dimension atom index. + ki : int + K-dimension atom index. + b_params : TCGEN05DescriptorParams + Pre-computed B descriptor parameters. + instr_desc : PrimExpr + Instruction descriptor from ``compute_tcgen05_instr_desc()``. + clear_accum : PrimExpr + Whether to zero the accumulator on the first K atom. 
+ """ + atom_m, atom_n, atom_k, _, enable_2cta = self.tcgen05_meta_unpacked + atom_m_per_cta = atom_m // 2 if enable_2cta else atom_m + n_dim = self.block_col_warps * self.warp_col_tiles + n_dim_per_cta = n_dim // 2 if enable_2cta else n_dim + micro_size_k = self.micro_size_k + k_dim = self.chunk + a_dtype_in_bits = DataType(self.a_dtype).bits + accum_dtype_in_bits = DataType(self.accum_dtype).bits + a_dtype_abbrv = self.a_dtype_abbrv + b_elems_in_bytes = b_params.elems_in_bytes + bk_atom_size = b_params.k_atom_size + b_swizzle_atom_elems = b_params.swizzle_atom_elems + mask_zero = T.cast(0, T.int32) + + # TMEM column geometry for A + interleave = max(128 // atom_m, 1) + a_tmem_cols_per_k_atom = atom_k * a_dtype_in_bits // 32 // interleave + a_tmem_k_stride = k_dim * a_dtype_in_bits // 32 // interleave + + A_tmem_offset = inst_m_idx * a_tmem_k_stride + ki * a_tmem_cols_per_k_atom + + if b_params.is_k_major: + B_elem_offset = ( + (ki // bk_atom_size) * n_dim_per_cta * b_swizzle_atom_elems + + (ki % bk_atom_size) * micro_size_k + + inst_n_idx * atom_n * b_swizzle_atom_elems + ) + else: + B_elem_offset = ki * b_swizzle_atom_elems * micro_size_k + inst_n_idx * atom_n * ( + k_dim if n_dim_per_cta // b_swizzle_atom_elems > 1 else 1 + ) + B_byte_offset = B_elem_offset * b_elems_in_bytes + + tmem_col_step = atom_n // (128 // atom_m_per_cta) + C_offset = (inst_m_idx * n_dim + inst_n_idx * tmem_col_step) * accum_dtype_in_bits // 32 + + @T.macro + def _ts_atom(a_data, desc_b, C_local_buf): + scale_out = T.Select(ki != 0, 1, T.Select(clear_accum, 0, 1)) + T.ptx_tcgen05_mma_ts( + a_dtype_abbrv, + a_data, + A_tmem_offset, + desc_b.data, + B_byte_offset, + C_local_buf.data, + C_offset, + instr_desc, + scale_out, + mask_zero, + mask_zero, + mask_zero, + mask_zero, + ) + + return _ts_atom(a_tmem_data, desc_b, C_local_buf) + + def tcgen05_blockscaled_atom( + self, + desc_a, + desc_b, + C_local_buf: Buffer, + sfa_data, + sfb_data, + inst_m_idx: int, + inst_n_idx: int, + ki: int, + a_params: TCGEN05DescriptorParams, + b_params: TCGEN05DescriptorParams, + instr_desc: PrimExpr, + clear_accum: PrimExpr = False, + ): + """Emit a single TCGEN05MMA block-scaled SS instruction. + + Parameters + ---------- + desc_a, desc_b : Buffer + Initialized A and B descriptors. + C_local_buf : Buffer + Accumulator buffer in tensor memory. + sfa_data, sfb_data : Var + Scale factor data pointers in tensor memory. + inst_m_idx, inst_n_idx, ki : int + Atom indices. + a_params, b_params : TCGEN05DescriptorParams + Pre-computed descriptor parameters. + instr_desc : PrimExpr + Block-scaled instruction descriptor (with SF IDs already encoded). + clear_accum : PrimExpr + Whether to zero the accumulator on the first K atom. 
+ """ + atom_m, atom_n, _, enable_ws, enable_2cta = self.tcgen05_meta_unpacked + atom_m_per_cta = atom_m // 2 if enable_2cta else atom_m + n_dim = self.block_col_warps * self.warp_col_tiles + n_dim_per_cta = n_dim // 2 if enable_2cta else n_dim + m_dim = self.block_row_warps * self.warp_row_tiles + micro_size_k = self.micro_size_k + k_dim = self.chunk + accum_dtype_in_bits = DataType(self.accum_dtype).bits + a_dtype_abbrv = self.a_dtype_abbrv + a_elems_in_bytes = a_params.elems_in_bytes + b_elems_in_bytes = b_params.elems_in_bytes + ak_atom_size = a_params.k_atom_size + bk_atom_size = b_params.k_atom_size + a_swizzle_atom_elems = a_params.swizzle_atom_elems + b_swizzle_atom_elems = b_params.swizzle_atom_elems + + if a_params.is_k_major: + A_elem_offset = ( + (ki % ak_atom_size) * micro_size_k + + inst_m_idx * atom_m_per_cta * a_swizzle_atom_elems + + (ki // ak_atom_size) * m_dim * a_swizzle_atom_elems + ) + else: + A_elem_offset = inst_m_idx * atom_m_per_cta * k_dim + ki * a_swizzle_atom_elems * micro_size_k + + if b_params.is_k_major: + B_elem_offset = ( + (ki // bk_atom_size) * n_dim_per_cta * b_swizzle_atom_elems + + (ki % bk_atom_size) * micro_size_k + + inst_n_idx * atom_n * b_swizzle_atom_elems + ) + else: + B_elem_offset = ki * b_swizzle_atom_elems * micro_size_k + inst_n_idx * atom_n * ( + k_dim if n_dim_per_cta // b_swizzle_atom_elems > 1 else 1 + ) + + A_byte_offset = A_elem_offset * a_elems_in_bytes + B_byte_offset = B_elem_offset * b_elems_in_bytes + tmem_col_step = atom_n // (128 // atom_m_per_cta) + C_offset = (inst_m_idx * n_dim + inst_n_idx * tmem_col_step) * accum_dtype_in_bits // 32 + + @T.macro + def _bs_atom(desc_a, desc_b, C_local_buf, sfa_data, sfb_data): + scale_out = T.Select(ki != 0, 1, T.Select(clear_accum, 0, 1)) + T.ptx_tcgen05_mma_blockscaled_ss( + a_dtype_abbrv, + desc_a.data, + A_byte_offset, + desc_b.data, + B_byte_offset, + C_local_buf.data, + C_offset, + instr_desc, + scale_out, + sfa_data, + 0, + sfb_data, + 0, + 0, + 0, + enable_ws, + enable_2cta, + ) + + return _bs_atom(desc_a, desc_b, C_local_buf, sfa_data, sfb_data) diff --git a/tilelang/cuda/intrinsics/macro/wgmma_macro_generator.py b/tilelang/cuda/intrinsics/macro/wgmma_macro_generator.py index f31c12fb9..86af177d9 100644 --- a/tilelang/cuda/intrinsics/macro/wgmma_macro_generator.py +++ b/tilelang/cuda/intrinsics/macro/wgmma_macro_generator.py @@ -1,5 +1,6 @@ from __future__ import annotations import tilelang.language as T +from dataclasses import dataclass from enum import IntEnum from typing import Callable from .mma_macro_generator import TensorCoreIntrinEmitter as MMAIntrinEmitter @@ -24,6 +25,30 @@ lift = convert +@dataclass(frozen=True) +class WGMMADescriptorParams: + """Pre-computed parameters for WGMMA descriptor initialization and atom offset computation. + + Returned by ``compute_wgmma_*_desc_params()`` and consumed by + ``init_wgmma_*_desc()`` and ``wgmma_*_atom()`` methods. 
+ """ + + swizzle_mode: int + """SwizzleMode enum value (passed directly to ``T.initialize_wgmma_descriptor``).""" + leading_byte_offset: int + """LBO >> 4, ready to pass to ``T.initialize_wgmma_descriptor``.""" + stride_byte_offset: int + """SBO >> 4, ready to pass to ``T.initialize_wgmma_descriptor``.""" + swizzle_atom_elems: int + """Number of elements per swizzle atom along the non-K dimension.""" + k_atom_size: int + """``max(swizzle_atom_elems // micro_size_k, 1)``.""" + elems_in_bytes: int + """Byte width of a single element: ``DataType(dtype).bits // 8``.""" + is_k_major: bool + """Whether the matrix is stored in K-major order (affects offset formula branching).""" + + class SwizzleMode(IntEnum): # SWIZZLE_NONE = 0, SWIZZLE_32B = 3, SWIZZLE_64B = 2, SWIZZLE_128B = 1 NONE = 0 @@ -181,276 +206,514 @@ def wgmma( if is_fragment(A_region): return self.wgmma_rs(A_region, B_region, C_region, clear_accum, wg_wait) - local_size_out = self.local_size_out - a_dtype_abbrv = self.a_dtype_abbrv - b_dtype_abbrv = self.b_dtype_abbrv - accum_dtype = self.accum_dtype - accum_dtype_abbrv = self.accum_dtype_abbrv - m_dim = self.block_row_warps * self.warp_row_tiles - warp_cols = self.warp_cols + k_dim = self.chunk micro_size_k = self.micro_size_k - k_dim, n_dim = self.chunk, self.block_col_warps * self.warp_col_tiles - wgmma_prefix = self.wgmma_prefix - scale_in_a = 1 - scale_in_b = 1 + assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}" + assert is_full_region(C_region), "Fragment output C must be a full region" + C_buf = C_region.buffer + + num_inst_m = self.wgmma_num_inst_m + num_inst_n = self.wgmma_num_inst_n + num_k_atoms = self.wgmma_num_k_atoms + a_params = self.compute_wgmma_a_desc_params(A_region) + b_params = self.compute_wgmma_b_desc_params(B_region) + + @T.macro + def _warp_mma(C_buf): + desc_a = T.alloc_wgmma_desc() + desc_b = T.alloc_wgmma_desc() + self.init_wgmma_a_desc(desc_a, A_region, a_params) + self.init_wgmma_b_desc(desc_b, B_region, b_params) + self.wgmma_fence_c(C_buf) + self.wgmma_arrive() + + for j in T.unroll(num_inst_n): + for i in T.unroll(num_inst_m): + for ki in T.unroll(num_k_atoms): + self.wgmma_ss_atom(desc_a, desc_b, C_buf, i, j, ki, a_params, b_params, clear_accum) + + self.wgmma_commit() + if wg_wait >= 0: + self.wgmma_wait(wg_wait) + self.wgmma_fence_c(C_buf) + + return _warp_mma(C_buf) + + def wgmma_rs( + self, A_region: BufferRegion, B_region: BufferRegion, C_region: BufferRegion, clear_accum: PrimExpr = False, wg_wait: int = 0 + ): + k_dim = self.chunk + micro_size_k = self.micro_size_k assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}" - a_is_k_major = not self.a_transposed + assert is_full_region(A_region), "Fragment input A must be a full region" + assert is_full_region(C_region), "Fragment output C must be a full region" + A_buf = A_region.buffer + C_buf = C_region.buffer + + num_inst_m = self.wgmma_num_inst_m + num_inst_n = self.wgmma_num_inst_n + num_k_atoms = self.wgmma_num_k_atoms + b_params = self.compute_wgmma_b_desc_params(B_region) + + @T.macro + def _warp_mma(A_buf, C_buf): + desc_b = T.alloc_wgmma_desc() + self.init_wgmma_b_desc(desc_b, B_region, b_params) + self.wgmma_fence_a(A_buf) + self.wgmma_fence_c(C_buf) + self.wgmma_arrive() + + for j in T.unroll(0, num_inst_n): + for i in T.unroll(num_inst_m): + for ki in T.unroll(0, num_k_atoms): + self.wgmma_rs_atom(A_buf, desc_b, C_buf, i, j, ki, b_params, clear_accum) + + self.wgmma_commit() + 
if wg_wait >= 0: + self.wgmma_wait(wg_wait) + self.wgmma_fence_c(C_buf) + self.wgmma_fence_a(A_buf) + + return _warp_mma(A_buf, C_buf) + + # ---- Atom-level interface ---- + + @property + def wgmma_num_inst_m(self) -> int: + """Number of WGMMA instruction atoms along the M dimension.""" + return 4 * self.warp_row_tiles // self.wgmma_inst_m + + @property + def wgmma_num_inst_n(self) -> int: + """Number of WGMMA instruction atoms along the N dimension.""" + return self.warp_col_tiles // self.wgmma_inst_n + + @property + def wgmma_num_k_atoms(self) -> int: + """Number of K-dimension micro-steps (``chunk // micro_size_k``).""" + return self.chunk // self.micro_size_k + + @property + def wgmma_a_regs(self) -> int: + """Number of 32-bit registers occupied by the A fragment (RS variant).""" + a_bits = DataType(self.a_dtype).bits + k_dim = self.chunk + micro_size_k = self.micro_size_k + return ((self.warp_rows * self.local_size_a * (k_dim // micro_size_k)) * a_bits + 31) // 32 + + @property + def wgmma_accum_regs(self) -> int: + """Number of 32-bit registers occupied by the accumulator fragment.""" + m_dim = self.block_row_warps * self.warp_row_tiles + accum_bits = DataType(self.accum_dtype).bits + return ((m_dim // 64) * self.warp_cols * self.local_size_out * accum_bits + 31) // 32 + + # -- Descriptor parameter computation (pure Python, no TIR) -- + + def compute_wgmma_b_desc_params(self, B_region: BufferRegion) -> WGMMADescriptorParams: + """Compute B descriptor parameters from the B shared buffer region. + + This is a pure-Python helper -- no TIR code is emitted. + The returned ``WGMMADescriptorParams`` is passed to + ``init_wgmma_b_desc()`` and ``wgmma_*_atom()`` methods. + """ + n_dim = self.block_col_warps * self.warp_col_tiles + k_dim = self.chunk + micro_size_k = self.micro_size_k + elems_in_bytes = DataType(self.a_dtype).bits // 8 b_is_k_major = self.b_transposed - a_swizzle_mode = self._determinate_swizzle_mode(A_region, self.a_shared_layout) b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout) + b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none() else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes - elems_in_bits = DataType(self.a_dtype).bits - elems_in_bytes = elems_in_bits // 8 + b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * elems_in_bytes) + b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else (8 * 8 * elems_in_bytes)) + if not b_swizzle_mode.is_none(): + if b_is_k_major: + b_leading_byte_offset = 16 + b_stride_byte_offset = 8 * b_swizzle_mode.swizzle_byte_size() + else: + b_n_axis_atoms = n_dim // b_swizzle_atom_elems + if b_n_axis_atoms <= 1: + b_leading_byte_offset = 0 + else: + b_leading_byte_offset = 8 * 8 * elems_in_bytes * k_dim + if b_n_axis_atoms <= 1: + b_stride_byte_offset = 8 * elems_in_bytes * n_dim + else: + b_stride_byte_offset = 8 * elems_in_bytes * b_swizzle_atom_elems + + return WGMMADescriptorParams( + swizzle_mode=int(b_swizzle_mode), + leading_byte_offset=int(b_leading_byte_offset >> 4), + stride_byte_offset=int(b_stride_byte_offset >> 4), + swizzle_atom_elems=b_swizzle_atom_elems, + k_atom_size=max(b_swizzle_atom_elems // micro_size_k, 1), + elems_in_bytes=elems_in_bytes, + is_k_major=b_is_k_major, + ) + + def compute_wgmma_a_desc_params(self, A_region: BufferRegion) -> WGMMADescriptorParams: + """Compute A descriptor parameters from the A shared buffer region (SS variant). + + This is a pure-Python helper -- no TIR code is emitted. 
+ The returned ``WGMMADescriptorParams`` is passed to + ``init_wgmma_a_desc()`` and ``wgmma_ss_atom()`` methods. + """ + m_dim = self.block_row_warps * self.warp_row_tiles + k_dim = self.chunk + micro_size_k = self.micro_size_k + elems_in_bytes = DataType(self.a_dtype).bits // 8 + a_is_k_major = not self.a_transposed + a_swizzle_mode = self._determinate_swizzle_mode(A_region, self.a_shared_layout) a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // elems_in_bytes - b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none() else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes - accum_bits = DataType(accum_dtype).bits - accum_regs = ((m_dim // 64) * warp_cols * local_size_out * accum_bits + 31) // 32 - # by default, we utilize non-swizzle layout offset a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim * elems_in_bytes) a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 * elems_in_bytes) - if not a_swizzle_mode.is_none(): - # swizzle mode doesn't require LBO/SBO to be 1 - # https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset if a_is_k_major: a_leading_byte_offset = 16 a_stride_byte_offset = 8 * a_swizzle_mode.swizzle_byte_size() else: - # MN Major - # LBO represents the distance between two atoms along the M dimension - # SBO represents the distance between two atoms along the K dimension a_m_axis_atoms = m_dim // a_swizzle_atom_elems if a_m_axis_atoms <= 1: a_leading_byte_offset = 0 else: a_leading_byte_offset = 8 * a_swizzle_mode.swizzle_atom_size() * (a_swizzle_mode.swizzle_byte_size() // elems_in_bytes) - if a_m_axis_atoms <= 1: a_stride_byte_offset = 8 * elems_in_bytes * m_dim else: a_stride_byte_offset = 8 * elems_in_bytes * a_swizzle_atom_elems - b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * elems_in_bytes) - b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else (8 * 8 * elems_in_bytes)) - if not b_swizzle_mode.is_none(): - # swizzle mode doesn't require LBO/SBO to be 1 - # https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset - if b_is_k_major: - b_leading_byte_offset = 16 - b_stride_byte_offset = 8 * b_swizzle_mode.swizzle_byte_size() - else: - # MN Major, K * N - # LBO represents the distance between two atoms along the N dimension - # SBO represents the distance between two atoms along the K dimension - b_n_axis_atoms = n_dim // b_swizzle_atom_elems - if b_n_axis_atoms <= 1: - b_leading_byte_offset = 0 - else: - b_leading_byte_offset = 8 * 8 * elems_in_bytes * k_dim - if b_n_axis_atoms <= 1: - b_stride_byte_offset = 8 * elems_in_bytes * n_dim - else: - b_stride_byte_offset = 8 * elems_in_bytes * b_swizzle_atom_elems + return WGMMADescriptorParams( + swizzle_mode=int(a_swizzle_mode), + leading_byte_offset=int(a_leading_byte_offset >> 4), + stride_byte_offset=int(a_stride_byte_offset >> 4), + swizzle_atom_elems=a_swizzle_atom_elems, + k_atom_size=max(a_swizzle_atom_elems // micro_size_k, 1), + elems_in_bytes=elems_in_bytes, + is_k_major=a_is_k_major, + ) - # for example, if [n, k] where k is 128, we should split it into 2 atoms - # where max specially handles the case when n_dim is 8. 
- ak_atom_size = max(a_swizzle_atom_elems // micro_size_k, 1) - bk_atom_size = max(b_swizzle_atom_elems // micro_size_k, 1) - wgmma_inst_m, wgmma_inst_n = self.wgmma_inst_m, self.wgmma_inst_n - num_inst_m = 4 * self.warp_row_tiles // wgmma_inst_m - num_inst_n = self.warp_col_tiles // wgmma_inst_n + # -- Descriptor initialization (emit TIR) -- - thread_binding = self.get_thread_binding() + def init_wgmma_b_desc( + self, + desc_b: Buffer, + B_region: BufferRegion, + b_params: WGMMADescriptorParams, + ): + """Emit TIR to initialize a pre-allocated WGMMA B descriptor. - A_ptr = retrive_ptr_from_buffer_region(A_region) + Parameters + ---------- + desc_b : Buffer + A descriptor buffer allocated via ``T.alloc_wgmma_desc()``. + B_region : BufferRegion + The B operand shared memory region. + b_params : WGMMADescriptorParams + Pre-computed parameters from ``compute_wgmma_b_desc_params()``. + """ B_ptr = retrive_ptr_from_buffer_region(B_region) - assert is_full_region(C_region), "Fragment output C must be a full region" + swizzle_mode = b_params.swizzle_mode + lbo = b_params.leading_byte_offset + sbo = b_params.stride_byte_offset - C_buf = C_region.buffer + @T.macro + def _init_b_desc(desc_b, B_ptr): + T.initialize_wgmma_descriptor(desc_b, B_ptr, swizzle_mode, lbo, sbo) + + return _init_b_desc(desc_b, B_ptr) + + def init_wgmma_a_desc( + self, + desc_a: Buffer, + A_region: BufferRegion, + a_params: WGMMADescriptorParams, + ): + """Emit TIR to initialize a pre-allocated WGMMA A descriptor (SS variant). + + Parameters + ---------- + desc_a : Buffer + A descriptor buffer allocated via ``T.alloc_wgmma_desc()``. + A_region : BufferRegion + The A operand shared memory region. + a_params : WGMMADescriptorParams + Pre-computed parameters from ``compute_wgmma_a_desc_params()``. 
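+
+        Note that ``a_params.leading_byte_offset`` and
+        ``a_params.stride_byte_offset`` are already right-shifted by 4
+        (16-byte units), as expected by ``T.initialize_wgmma_descriptor``.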
+ """ + A_ptr = retrive_ptr_from_buffer_region(A_region) + swizzle_mode = a_params.swizzle_mode + lbo = a_params.leading_byte_offset + sbo = a_params.stride_byte_offset @T.macro - def _warp_mma(A_ptr, B_ptr, C_buf): - tx, warp_n, warp_m = self.extract_thread_binding(thread_binding) + def _init_a_desc(desc_a, A_ptr): + T.initialize_wgmma_descriptor(desc_a, A_ptr, swizzle_mode, lbo, sbo) - desc_a = T.alloc_wgmma_desc() - desc_b = T.alloc_wgmma_desc() - T.initialize_wgmma_descriptor(desc_a, A_ptr, a_swizzle_mode, int(a_leading_byte_offset >> 4), int(a_stride_byte_offset >> 4)) - T.initialize_wgmma_descriptor(desc_b, B_ptr, b_swizzle_mode, int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4)) + return _init_a_desc(desc_a, A_ptr) + + # -- Fence / Arrive / Commit / Wait primitives -- + + def wgmma_fence_a(self, A_buf: Buffer): + """Emit ``warpgroup_fence_operand`` for the A fragment buffer.""" + a_regs = self.wgmma_a_regs + + @T.macro + def _fence_a(A_buf): + T.warpgroup_fence_operand(A_buf, num_regs=a_regs) + + return _fence_a(A_buf) + + def wgmma_fence_c(self, C_buf: Buffer): + """Emit ``warpgroup_fence_operand`` for the accumulator buffer.""" + accum_regs = self.wgmma_accum_regs + + @T.macro + def _fence_c(C_buf): T.warpgroup_fence_operand(C_buf, num_regs=accum_regs) + + return _fence_c(C_buf) + + def wgmma_arrive(self): + """Emit ``warpgroup_arrive()``.""" + + @T.macro + def _arrive(): T.warpgroup_arrive() - for j in T.unroll(num_inst_n): - for i in T.unroll(num_inst_m): - for ki in T.unroll(k_dim // micro_size_k): - scale_out = T.Select(ki != 0, 1, T.Select(clear_accum, 0, 1)) - warp_i = (warp_m // 4) * num_inst_m + i - warp_j = warp_n * num_inst_n + j - A_offset = ( - (ki % ak_atom_size) * micro_size_k - + warp_i * 64 * a_swizzle_atom_elems - + (ki // ak_atom_size) * m_dim * a_swizzle_atom_elems - if a_is_k_major - else warp_i * 64 * k_dim + ki * a_swizzle_atom_elems * micro_size_k - ) - B_offset = ( - (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems - + (ki % bk_atom_size) * micro_size_k - + warp_j * wgmma_inst_n * b_swizzle_atom_elems - if b_is_k_major - else ( - ki * b_swizzle_atom_elems * micro_size_k - + warp_j * wgmma_inst_n * (k_dim if n_dim // b_swizzle_atom_elems > 1 else 1) - ) - ) - C_offset = i * warp_cols * local_size_out + j * warp_cols * local_size_out // num_inst_n # 4 warps as an unit - T.ptx_wgmma_ss( - accum_dtype, - wgmma_prefix, - a_is_k_major, - b_is_k_major, - a_dtype_abbrv, - b_dtype_abbrv, - accum_dtype_abbrv, - desc_a.data, - (A_offset * elems_in_bytes) >> 4, - desc_b.data, - (B_offset * elems_in_bytes) >> 4, - C_buf.data, - C_offset, - scale_out, - scale_in_a, - scale_in_b, - ) + return _arrive() + def wgmma_commit(self): + """Emit ``warpgroup_commit_batch()``.""" + + @T.macro + def _commit(): T.warpgroup_commit_batch() - if wg_wait >= 0: - T.warpgroup_wait(wg_wait) - T.warpgroup_fence_operand(C_buf, num_regs=accum_regs) - return _warp_mma(A_ptr, B_ptr, C_buf) + return _commit() - def wgmma_rs( - self, A_region: BufferRegion, B_region: BufferRegion, C_region: BufferRegion, clear_accum: PrimExpr = False, wg_wait: int = 0 + def wgmma_wait(self, n: int = 0): + """Emit ``warpgroup_wait(n)``.""" + + @T.macro + def _wait(): + T.warpgroup_wait(n) + + return _wait() + + # -- Atom emission -- + + def wgmma_rs_atom( + self, + A_buf: Buffer, + desc_b: Buffer, + C_buf: Buffer, + inst_m_idx: int, + inst_n_idx: int, + ki: int, + b_params: WGMMADescriptorParams, + clear_accum: PrimExpr = False, ): + """Emit a single WGMMA RS instruction for atom ``(inst_m_idx, 
inst_n_idx, ki)``. + + Must be called between a ``wgmma_fence_a``/``wgmma_fence_c``/``wgmma_arrive`` + sequence and a ``wgmma_commit``/``wgmma_wait`` sequence. + + Calling this for every ``(j, i, ki)`` in + ``T.grid(wgmma_num_inst_n, wgmma_num_inst_m, wgmma_num_k_atoms)`` + produces identical TIR to ``wgmma_rs()``. + + Parameters + ---------- + A_buf : Buffer + Fragment buffer for operand A (in registers). + desc_b : Buffer + Initialized B descriptor (from ``init_wgmma_b_desc``). + C_buf : Buffer + Accumulator fragment buffer. + inst_m_idx : int + M-dimension atom index (0 .. wgmma_num_inst_m - 1). + inst_n_idx : int + N-dimension atom index (0 .. wgmma_num_inst_n - 1). + ki : int + K-dimension atom index (0 .. wgmma_num_k_atoms - 1). + b_params : WGMMADescriptorParams + Pre-computed B descriptor parameters. + clear_accum : PrimExpr + Whether to zero the accumulator on the first K atom. + """ local_size_a = self.local_size_a local_size_out = self.local_size_out + warp_rows = self.warp_rows + warp_cols = self.warp_cols + micro_size_k = self.micro_size_k + n_dim = self.block_col_warps * self.warp_col_tiles + k_dim = self.chunk + wgmma_inst_n = self.wgmma_inst_n + num_inst_n = self.wgmma_num_inst_n a_dtype_abbrv = self.a_dtype_abbrv b_dtype_abbrv = self.b_dtype_abbrv accum_dtype = self.accum_dtype accum_dtype_abbrv = self.accum_dtype_abbrv - m_dim = self.block_row_warps * self.warp_row_tiles - warp_rows, warp_cols = self.warp_rows, self.warp_cols - micro_size_k = self.micro_size_k - k_dim, n_dim = self.chunk, self.block_col_warps * self.warp_col_tiles wgmma_prefix = self.wgmma_prefix - scale_in_a = 1 - scale_in_b = 1 + b_transposed = self.b_transposed + elems_in_bytes = b_params.elems_in_bytes + bk_atom_size = b_params.k_atom_size + b_swizzle_atom_elems = b_params.swizzle_atom_elems - assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}" + thread_binding = self.get_thread_binding() - elems_in_bytes = DataType(self.a_dtype).bits // 8 - a_bits = DataType(self.a_dtype).bits - accum_bits = DataType(accum_dtype).bits - a_regs = ((warp_rows * local_size_a * (k_dim // micro_size_k)) * a_bits + 31) // 32 - accum_regs = ((m_dim // 64) * warp_cols * local_size_out * accum_bits + 31) // 32 - b_is_k_major = self.b_transposed + A_offset = ki * warp_rows * local_size_a + inst_m_idx * local_size_a + C_offset = inst_m_idx * warp_cols * local_size_out + inst_n_idx * warp_cols * local_size_out // num_inst_n - b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout) - b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none() else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes + @T.macro + def _rs_atom(A_buf, desc_b, C_buf): + tx, warp_n, warp_m = self.extract_thread_binding(thread_binding) + warp_j = warp_n * num_inst_n + inst_n_idx + scale_out = T.Select(ki != 0, 1, T.Select(clear_accum, 0, 1)) + + B_offset = ( + (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems + + warp_j * wgmma_inst_n * b_swizzle_atom_elems + + (ki % bk_atom_size) * micro_size_k + if b_params.is_k_major + else ( + ki * b_swizzle_atom_elems * micro_size_k + warp_j * wgmma_inst_n * (k_dim if n_dim // b_swizzle_atom_elems > 1 else 1) + ) + ) + + T.ptx_wgmma_rs( + accum_dtype, + wgmma_prefix, + b_transposed, + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_buf.data, + A_offset, + desc_b.data, + (B_offset * elems_in_bytes) >> 4, + C_buf.data, + C_offset, + scale_out, + 1, + 1, + ) + + return _rs_atom(A_buf, desc_b, C_buf) + + def wgmma_ss_atom( + self, + 
desc_a: Buffer, + desc_b: Buffer, + C_buf: Buffer, + inst_m_idx: int, + inst_n_idx: int, + ki: int, + a_params: WGMMADescriptorParams, + b_params: WGMMADescriptorParams, + clear_accum: PrimExpr = False, + ): + """Emit a single WGMMA SS instruction for atom ``(inst_m_idx, inst_n_idx, ki)``. - b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * elems_in_bytes) - b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else (8 * 8 * elems_in_bytes)) - if not b_swizzle_mode.is_none(): - # swizzle mode doesn't require LBO/SBO to be 1 - # https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset - if b_is_k_major: - b_leading_byte_offset = 16 - b_stride_byte_offset = 8 * b_swizzle_mode.swizzle_byte_size() - else: - # MN Major - # LBO represents the distance between two atoms along the N dimension - # SBO represents the distance between two atoms along the K dimension - b_n_axis_atoms = n_dim // b_swizzle_atom_elems - if b_n_axis_atoms <= 1: - b_leading_byte_offset = 0 - else: - b_leading_byte_offset = 8 * 8 * elems_in_bytes * k_dim - if b_n_axis_atoms <= 1: - b_stride_byte_offset = 8 * elems_in_bytes * n_dim - else: - b_stride_byte_offset = 8 * elems_in_bytes * b_swizzle_atom_elems + Must be called between fence/arrive and commit/wait sequences. - bk_atom_size = max(b_swizzle_atom_elems // micro_size_k, 1) - wgmma_inst_m, wgmma_inst_n = self.wgmma_inst_m, self.wgmma_inst_n - num_inst_m = 4 * self.warp_row_tiles // wgmma_inst_m - num_inst_n = self.warp_col_tiles // wgmma_inst_n + Parameters + ---------- + desc_a : Buffer + Initialized A descriptor (from ``init_wgmma_a_desc``). + desc_b : Buffer + Initialized B descriptor (from ``init_wgmma_b_desc``). + C_buf : Buffer + Accumulator fragment buffer. + inst_m_idx : int + M-dimension atom index (0 .. wgmma_num_inst_m - 1). + inst_n_idx : int + N-dimension atom index (0 .. wgmma_num_inst_n - 1). + ki : int + K-dimension atom index (0 .. wgmma_num_k_atoms - 1). + a_params : WGMMADescriptorParams + Pre-computed A descriptor parameters. + b_params : WGMMADescriptorParams + Pre-computed B descriptor parameters. + clear_accum : PrimExpr + Whether to zero the accumulator on the first K atom. 
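+
+        Notes
+        -----
+        Iterating over every atom reproduces the TIR emitted by ``wgmma()``::
+
+            for j in T.unroll(num_inst_n):
+                for i in T.unroll(num_inst_m):
+                    for ki in T.unroll(num_k_atoms):
+                        self.wgmma_ss_atom(desc_a, desc_b, C_buf, i, j, ki,
+                                           a_params, b_params, clear_accum)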
+ """ + local_size_out = self.local_size_out + warp_cols = self.warp_cols + micro_size_k = self.micro_size_k + m_dim = self.block_row_warps * self.warp_row_tiles + n_dim = self.block_col_warps * self.warp_col_tiles + k_dim = self.chunk + wgmma_inst_n = self.wgmma_inst_n + num_inst_m = self.wgmma_num_inst_m + num_inst_n = self.wgmma_num_inst_n + a_dtype_abbrv = self.a_dtype_abbrv + b_dtype_abbrv = self.b_dtype_abbrv + accum_dtype = self.accum_dtype + accum_dtype_abbrv = self.accum_dtype_abbrv + wgmma_prefix = self.wgmma_prefix + a_is_k_major = not self.a_transposed + b_is_k_major = self.b_transposed + a_elems_in_bytes = a_params.elems_in_bytes + b_elems_in_bytes = b_params.elems_in_bytes + ak_atom_size = a_params.k_atom_size + bk_atom_size = b_params.k_atom_size + a_swizzle_atom_elems = a_params.swizzle_atom_elems + b_swizzle_atom_elems = b_params.swizzle_atom_elems thread_binding = self.get_thread_binding() - assert is_full_region(A_region), "Fragment input A must be a full region" - assert is_full_region(C_region), "Fragment output C must be a full region" - A_buf = A_region.buffer - B_ptr = retrive_ptr_from_buffer_region(B_region) - C_buf = C_region.buffer + C_offset = inst_m_idx * warp_cols * local_size_out + inst_n_idx * warp_cols * local_size_out // num_inst_n @T.macro - def _warp_mma(A_buf, B_ptr, C_buf): + def _ss_atom(desc_a, desc_b, C_buf): tx, warp_n, warp_m = self.extract_thread_binding(thread_binding) - - desc_b = T.alloc_wgmma_desc() - T.initialize_wgmma_descriptor(desc_b, B_ptr, b_swizzle_mode, int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4)) - T.warpgroup_fence_operand(A_buf, num_regs=a_regs) - T.warpgroup_fence_operand(C_buf, num_regs=accum_regs) - T.warpgroup_arrive() - - for j in T.unroll(0, num_inst_n): - for i in T.unroll(num_inst_m): - for ki in T.unroll(0, (k_dim // micro_size_k)): - warp_j = warp_n * num_inst_n + j - scale_out = T.Select(ki != 0, 1, T.Select(clear_accum, 0, 1)) - - A_offset = ki * warp_rows * local_size_a + i * local_size_a - B_offset = ( - (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems - + warp_j * wgmma_inst_n * b_swizzle_atom_elems - + (ki % bk_atom_size) * micro_size_k - if b_is_k_major - else ( - ki * b_swizzle_atom_elems * micro_size_k - + warp_j * wgmma_inst_n * (k_dim if n_dim // b_swizzle_atom_elems > 1 else 1) - ) - ) - C_offset = i * warp_cols * local_size_out + j * warp_cols * local_size_out // num_inst_n # 4 warps as an unit - T.ptx_wgmma_rs( - accum_dtype, - wgmma_prefix, - self.b_transposed, - a_dtype_abbrv, - b_dtype_abbrv, - accum_dtype_abbrv, - A_buf.data, - A_offset, - desc_b.data, - (B_offset * elems_in_bytes) >> 4, - C_buf.data, - C_offset, - scale_out, - scale_in_a, - scale_in_b, - ) - - T.warpgroup_commit_batch() - if wg_wait >= 0: - T.warpgroup_wait(wg_wait) - T.warpgroup_fence_operand(C_buf, num_regs=accum_regs) - T.warpgroup_fence_operand(A_buf, num_regs=a_regs) - - return _warp_mma(A_buf, B_ptr, C_buf) + scale_out = T.Select(ki != 0, 1, T.Select(clear_accum, 0, 1)) + warp_i = (warp_m // 4) * num_inst_m + inst_m_idx + warp_j = warp_n * num_inst_n + inst_n_idx + + A_offset = ( + (ki % ak_atom_size) * micro_size_k + + warp_i * 64 * a_swizzle_atom_elems + + (ki // ak_atom_size) * m_dim * a_swizzle_atom_elems + if a_is_k_major + else warp_i * 64 * k_dim + ki * a_swizzle_atom_elems * micro_size_k + ) + B_offset = ( + (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems + + (ki % bk_atom_size) * micro_size_k + + warp_j * wgmma_inst_n * b_swizzle_atom_elems + if b_is_k_major + else ( + ki * b_swizzle_atom_elems 
* micro_size_k + warp_j * wgmma_inst_n * (k_dim if n_dim // b_swizzle_atom_elems > 1 else 1) + ) + ) + + T.ptx_wgmma_ss( + accum_dtype, + wgmma_prefix, + a_is_k_major, + b_is_k_major, + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + desc_a.data, + (A_offset * a_elems_in_bytes) >> 4, + desc_b.data, + (B_offset * b_elems_in_bytes) >> 4, + C_buf.data, + C_offset, + scale_out, + 1, + 1, + ) + + return _ss_atom(desc_a, desc_b, C_buf) def make_mma_load_layout(self, local_buf: Buffer, matrix: str = "A") -> T.Fragment: """ diff --git a/tilelang/intrinsics/__init__.py b/tilelang/intrinsics/__init__.py index b944ae89d..7d4eee9cc 100644 --- a/tilelang/intrinsics/__init__.py +++ b/tilelang/intrinsics/__init__.py @@ -8,6 +8,15 @@ TensorCoreIntrinEmitterWithLadderTransform, # noqa: F401 ) +from tilelang.cuda.intrinsics.macro.wgmma_macro_generator import ( + TensorCoreIntrinEmitter as WGMMATensorCoreIntrinEmitter, # noqa: F401 + WGMMADescriptorParams, # noqa: F401 +) +from tilelang.cuda.intrinsics.macro.tcgen05_macro_generator import ( + TensorCoreIntrinEmitter as TCGEN05TensorCoreIntrinEmitter, # noqa: F401 + TCGEN05DescriptorParams, # noqa: F401 +) + from tilelang.cuda.intrinsics.layout.mma_layout import get_swizzle_layout # noqa: F401 from tilelang.cuda.intrinsics.layout.mma_layout import make_mma_swizzle_layout # noqa: F401