diff --git a/benchmark/matmul/benchmark_matmul_sp.py b/benchmark/matmul/benchmark_matmul_sp.py index 7ecffc26a2..cdadbe0400 100644 --- a/benchmark/matmul/benchmark_matmul_sp.py +++ b/benchmark/matmul/benchmark_matmul_sp.py @@ -4,21 +4,14 @@ import torch from triton.testing import do_bench -import tilelang import tilelang.language as T from tilelang.autotuner import autotune from tilelang import jit -from tilelang.contrib import nvcc -from tilelang.layout import make_cutlass_metadata_layout +from tilelang.utils.sparse import get_e_factor -# Configure logger logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) -arch = nvcc.get_target_compute_version() - -ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")} - def ref_program(A, B): """ @@ -88,7 +81,7 @@ def get_configs(M, N, K): return configs -def matmul_sp(M, N, K, in_dtype, accum_dtype): +def matmul_sp(M, N, K, in_dtype, accum_dtype, e_dtype): """ Create an autotuned matrix multiplication kernel for matrices of shape: - A: (M, K) @@ -164,8 +157,8 @@ def kernel( A TVM Tensor Language function (T.prim_func) that computes matmul. """ # Use half-precision for input data to reduce memory bandwidth, - # accumulate in float for better numerical accuracy - e_factor, e_dtype = ARCH_INFO[arch] + # accumulate in float for better numerical accuracy + e_factor = get_e_factor(in_dtype, e_dtype) @T.prim_func def main( @@ -200,15 +193,8 @@ def main( # Clear out the accumulation buffer T.clear(C_local) - T.disable_warp_group_reg_alloc() T.use_swizzle(panel_size=10, enable=enable_rasterization) - T.annotate_layout( - { - E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, block_k=block_K), - E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, block_k=block_K), - } - ) # Loop over sub-blocks in K dimension, pipelined by num_stages for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): # Load a sub-block of A from global memory into A_shared @@ -219,11 +205,13 @@ def main( T.copy(B[k * block_K, bx * block_N], B_shared) # Perform a partial matrix multiplication: # C_local += A_shared @ B_shared - T.gemm_sp_v2( + T.gemm_sp( A_shared, E_shared, B_shared, C_local, + transpose_A=False, + transpose_E=False, transpose_B=False, policy=policy, ) @@ -242,8 +230,8 @@ def main( parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M") parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N") parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K") - parser.add_argument("--disable_cache", action="store_true") parser.add_argument("--accum_dtype", type=str, default="float", choices=["float", "float16"], help="Accumulation datatype") + parser.add_argument("--e_dtype", type=str, default="int16", choices=["int16", "int8", "int32"], help="Metadata E datatype") parser.add_argument( "--bench_torch_sparse", type=str, @@ -253,16 +241,13 @@ def main( ) args = parser.parse_args() - if args.disable_cache: - tilelang.disable_cache() - M, N, K = args.m, args.n, args.k # Compute total floating-point operations to measure throughput total_flops = 2 * M * N * K # matmul(...)
returns (best_latency, best_config, ref_latency) - best_result = matmul_sp(M, N, K, T.float16, args.accum_dtype) + best_result = matmul_sp(M, N, K, T.float16, args.accum_dtype, e_dtype=args.e_dtype) best_latency = best_result.latency best_config = best_result.config A = torch.randn(M, K, dtype=torch.float16, device="cuda") diff --git a/benchmark/matmul/benchmark_matmul_sp_compress.py b/benchmark/matmul/benchmark_matmul_sp_compress.py new file mode 100644 index 0000000000..fdb79eff3d --- /dev/null +++ b/benchmark/matmul/benchmark_matmul_sp_compress.py @@ -0,0 +1,84 @@ +import argparse +from typing import Optional + +import torch +from tilelang.profiler import do_bench +from tilelang.utils.sparse import compress, randn_semi_sparse, randint_semi_sparse, torch_compress + +SUPPORTED_DTYPE_NAMES = ["float16", "bfloat16", "float32", "int8"] +SUPPORTED_META_DTYPE_NAMES = ["int8", "int16", "int32"] + + +def _resolve_torch_dtype(name: str) -> torch.dtype: + dtype = getattr(torch, name, None) + if dtype is None: + raise ValueError(f"Unsupported torch dtype: {name}") + return dtype + + +def _generate_semi_sparse(m: int, k: int, dtype: torch.dtype, device: str = "cuda") -> torch.Tensor: + if dtype in (torch.int8, torch.uint8): + return randint_semi_sparse(m, k, low=-64, high=64, dtype=dtype, device=device) + return randn_semi_sparse(m, k, dtype=dtype, device=device) + + +def _compress_bytes(input_tensor: torch.Tensor, sparse_tensor: torch.Tensor, meta_tensor: torch.Tensor) -> int: + return ( + input_tensor.numel() * input_tensor.element_size() + + sparse_tensor.numel() * sparse_tensor.element_size() + + meta_tensor.numel() * meta_tensor.element_size() + ) + + +def benchmark_compress( + m: int, + k: int, + dtype: torch.dtype, + meta_dtype: Optional[torch.dtype] = None, # noqa: FA100 +): + a0 = _generate_semi_sparse(m, k, dtype) + + sparse0, meta0 = compress(a0, meta_dtype=meta_dtype) + ref_sparse0, ref_meta0 = torch_compress(a0, meta_dtype=meta_dtype) + + bytes_per_compress = _compress_bytes(a0, sparse0, meta0) + bytes_per_torch = _compress_bytes(a0, ref_sparse0, ref_meta0) + + tl_latency_ms = do_bench(lambda: compress(a0, meta_dtype=meta_dtype)) + torch_latency_ms = do_bench(lambda: torch_compress(a0, meta_dtype=meta_dtype)) + + tl_latency_s = tl_latency_ms * 1e-3 + torch_latency_s = torch_latency_ms * 1e-3 + tl_throughput_gbps = bytes_per_compress / tl_latency_s / 1e9 + torch_throughput_gbps = bytes_per_torch / torch_latency_s / 1e9 + + return bytes_per_compress, bytes_per_torch, tl_latency_ms, torch_latency_ms, tl_throughput_gbps, torch_throughput_gbps + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Benchmark two TileLang compress operators by memory throughput") + parser.add_argument("--m", type=int, default=16384, help="Matrix rows") + parser.add_argument("--k", type=int, default=16384, help="Matrix columns") + parser.add_argument("--dtype", type=str, default="float16", choices=SUPPORTED_DTYPE_NAMES, help="Input dtype") + parser.add_argument( + "--meta_dtype", + type=str, + default=None, + choices=SUPPORTED_META_DTYPE_NAMES, + help="Metadata dtype (defaults to the library choice for the input dtype)", + ) + args = parser.parse_args() + + dtype = _resolve_torch_dtype(args.dtype) + meta_dtype = _resolve_torch_dtype(args.meta_dtype) if args.meta_dtype is not None else None + + bytes_per_compress, bytes_per_torch, tl_latency_ms, torch_latency_ms, tl_gbps, torch_gbps = benchmark_compress( + args.m, + args.k, + dtype=dtype, + meta_dtype=meta_dtype, + ) + + 
print(f"M={args.m} K={args.k} dtype={args.dtype} meta={args.meta_dtype or 'default'}") + print(f"tilelang: {tl_latency_ms:.4f} ms, {tl_gbps:.3f} GB/s (bytes={bytes_per_compress})") + print(f"torch: {torch_latency_ms:.4f} ms, {torch_gbps:.3f} GB/s (bytes={bytes_per_torch})") diff --git a/docs/deeplearning_operators/matmul_sparse.md b/docs/deeplearning_operators/matmul_sparse.md index 09dcc6460d..2036ba78f9 100644 --- a/docs/deeplearning_operators/matmul_sparse.md +++ b/docs/deeplearning_operators/matmul_sparse.md @@ -38,224 +38,58 @@ To utilize sparse Tensor Cores, a dense tensor must first be **compressed** into Both `PyTorch` and `vLLM` use `CUTLASS` as their computation backend (see references [here](https://github.com/pytorch/pytorch/blob/a8d6afb511a69687bbb2b7e88a3cf67917e1697e/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredOps.cu#L47) and [here](https://github.com/vllm-project/vllm/blob/a5dd03c1ebc5e4f56f3c9d3dc0436e9c582c978f/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh#L116)), leveraging `CUTLASS`’s built-in compressor (or reimplementing it in `PyTorch`). -A set of **CUTLASS-compatible** compressors is provided in `tilelang.utils.sparse`, where a dense tensor—along with other required arguments (e.g., block_K for sm90, transpose options)—can be passed in to perform the compression. +A compressor is provided in `tilelang.utils.sparse`. Pass in a dense 2:4-sparse tensor and optionally a metadata dtype to get back the compressed values and metadata: ```python from tilelang.utils.sparse import compress -A_sparse, E = compress(A, transposed=trans_A, block_k=block_K) +A_sparse, E = compress(A) # default: int16 metadata for fp16/bf16 +A_sparse, E = compress(A.t().contiguous()) # compress the transposed layout ``` -Here, `A_sparse` contains all the non-zero elements of `A`, while `E` stores the corresponding metadata (indexing information) required to reconstruct the original sparse pattern. +Here, `A_sparse` contains all the non-zero elements of `A`, while `E` stores the corresponding metadata (indexing information) required to reconstruct the original sparse pattern. The metadata uses a natural row-major layout that `T.gemm_sp` consumes directly — no additional layout annotation is needed. -> NOTE: When using CUTLASS compressor, there is no naive position correspondence between the positions in `A_sparse`/`A` and `E`. (i.e. the 4-element group at [n, k] doesn't match the 4-bit metadata at [n, k] if you consider metadata as int4 tensor) -The metadata is reordered internally to optimize memory access patterns (e.g., for ldsm instructions and vectorized loads). -For more information, see **A note on `gemm_sp` and `gemm_sp_v2`**. +## `T.gemm_sp` -## `T.gemm_sp` with CUTLASS's compressor +A 2:4 sparse GEMM kernel is similar to its dense counterpart, except that it also requires loading the metadata into shared memory and passing it to `T.gemm_sp`. -:::{warning} - -It is strongly recommended to use T.gemm_sp_v2 due to its greater flexibility and faster compilation time. - -::: - -A 2:4 sparse GEMM kernel is similar to its dense counterpart, except that it also requires handling the associated metadata. - -Check comments in below kernel code for required modification. +The default metadata dtype for fp16/bf16 is `int16` with an E-factor of 16 (one `int16` value covers 16 K-elements). For int8/float8 the default is `int32` with E-factor 32. 
```python -def matmul_sp_sm80( - M, - N, - K, - block_M, - block_N, - block_K, - in_dtype, - out_dtype, - accum_dtype, - num_stages, - threads, - trans_A, - trans_B, +import tilelang.language as T +from tilelang.utils.sparse import get_e_factor + +def matmul_sp( + M, N, K, + block_M, block_N, block_K, + in_dtype, accum_dtype, e_dtype, + num_stages, threads, + policy=T.GemmWarpPolicy.Square, ): - is_8_bit = "8" in in_dtype - metadata_dtype = 'int32' if is_8_bit else 'int16' - E_factor = SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype] # Calculate shape for given datatypes - A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M) - B_shape = (K, N) if not trans_B else (N, K) - A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M) - B_shared_shape = (block_K, block_N) if not trans_B else (block_N, block_K) - - import tilelang.language as T + e_factor = get_e_factor(in_dtype, e_dtype) @T.prim_func def main( - A_sparse: T.Tensor(A_sparse_shape, in_dtype), - E: T.Tensor((M, K // E_factor), metadata_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), + A_sparse: T.Tensor((M, K // 2), in_dtype), + E: T.Tensor((M, K // e_factor), e_dtype), + B: T.Tensor((K, N), in_dtype), + C: T.Tensor((M, N), accum_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype) - B_shared = T.alloc_shared(B_shared_shape, in_dtype) - E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype) # Allocate smem for metadata - C_frag = T.alloc_fragment((block_M, block_N), accum_dtype) - T.annotate_layout({ # Annotate reordered cutlass metadata layout - E: - make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"), - E_shared: - make_cutlass_metadata_layout( - E_shared, mma_dtype=in_dtype, arch="8.0"), - }) - T.clear(C_frag) + A_shared = T.alloc_shared((block_M, block_K // 2), in_dtype) + B_shared = T.alloc_shared((block_K, block_N), in_dtype) + E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), accum_dtype) + T.clear(C_local) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - T.copy(E[by * block_M, k * block_K // E_factor], E_shared) - if trans_A: - T.copy(A_sparse[k * block_K // 2, by * block_M], A_shared) - else: - T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared) - if trans_B: - T.copy(B[bx * block_N, k * block_K], B_shared) - else: - T.copy(B[k * block_K, bx * block_N], B_shared) - T.gemm_sp(A_shared, E_shared, B_shared, C_frag, trans_A, trans_B) # Call gemm_sp with non-zero values and metadata - T.copy(C_frag, C[by * block_M, bx * block_N]) + T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared) + T.copy(E[by * block_M, k * block_K // e_factor], E_shared) + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm_sp(A_shared, E_shared, B_shared, C_local, + transpose_A=False, transpose_E=False, transpose_B=False, + policy=policy) + T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) return main ``` - -Under the hood, `gemm_sp` invokes templates adapted from `CUTLASS`, and a compatible metadata layout must be specified using `T.annotate_layout`. - -## `T.gemm_sp_v2` with a custom compressor - -To migrate to `gemm_sp_v2`, simply replace occurrences of `gemm_sp`. 
- -Unlike `gemm_sp`, `gemm_sp_v2` can operate without `T.annotate_layout`, and it also supports user-defined layouts and compressors. - -The metadata is stored in a `(u)int8`/`(u)int16`/`(u)int32` tensor, where **each 4-bit chunk represents two 2-bit indices** of non-zero elements within four consecutive elements. Here, we start with an `int16` example, which is the **default dtype** for `bf16` and `fp16` on Ampere GPUs. - -Suppose we have the following row vector: -```python -t = tensor([[0, 7, 0, 3], [1, 5, 0, 0], [0, 0, 2, 4], [9, 0, 9, 0]], dtype=torch.float16).flatten() -``` - -The non-zero elements and their corresponding indices are: - -```python -t_sp = tensor([[7, 3], [1, 5], [2, 4], [9, 9]], dtype=torch.float16).flatten() -indices = tensor([[1, 3], [0, 1], [2, 3], [0, 2]], dtype=torch.float16).flatten() -``` - -The corresponding uint16 metadata is: -```python -# metadata_bits = tensor([0b1101, 0b0100, 0b1110, 0b1000]) -# Note: storage uses little-endian order: tensor(0b1000111001001101, dtype=torch.int16) -# Note: the above code is not runnable in python as the interpreter won't take the binary -# as 2's complement -metadata_int16 = tensor(-29107) -``` - -You can decode an int16 metadata tensor using the following utility: -```python -def decode_metadata(meta: torch.Tensor) -> torch.Tensor: - assert meta.dtype is torch.int16 - groups_per_meta = 16 // 4 - out = [] - for g in range(groups_per_meta): - group_bits = (meta >> (g * 4)) & 0xF - idx0 = group_bits & 0x3 - idx1 = (group_bits >> 2) & 0x3 - out.append(torch.stack([idx0, idx1], dim=-1)) - return torch.concat(out, dim=-1).view(meta.shape[0], -1) -``` - -The compressor can be implement at either `PyTorch`/`NumPy` level or kernel level. - -For example, `PyTorch` provides an Ampere compressor [here](https://github.com/pytorch/pytorch/blob/267d0197bfca0232488d51dd1ff735d619adc2cf/torch/sparse/_semi_structured_conversions.py#L47-L179). Note that in this implementation, a [permutation](https://github.com/pytorch/pytorch/blob/267d0197bfca0232488d51dd1ff735d619adc2cf/torch/sparse/_semi_structured_conversions.py#L173-L175) is applied to match CUTLASS’s metadata layout. If you do not annotate a metadata layout when using `gemm_sp_v2`, your compressor should replicate the same behavior as the PyTorch example—but without using the `_calculate_meta_reordering_scatter_offsets` function. - -If you want to use a custom metadata layout in your kernel, one approach is to define the layout in `TileLang` and then apply the same layout to both your compressor kernel and the matmul_sp kernel. 
- -```python - -@tilelang.jit(out_idx=[1, 2], pass_configs={ - tilelang.PassConfigKey.TIR_DISABLE_VECTORIZE: True, -}) -def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout): - e_factor, e_dtype = ARCH_INFO["8.0"] - e_K = K // e_factor - elem, group = 2, 4 - - assert M % block_M == 0, "M must be divisible by block_M" - assert K % block_K == 0, "K must be divisible by block_K" - assert K % e_factor == 0, "K must be divisible by e_factor" - assert block_K % e_factor == 0, "block_K must be divisible by e_factor" - - @T.prim_func - def kernel( - A: T.Tensor((M, K), dtype), - A_sp: T.Tensor((M, K // 2), dtype), - E: T.Tensor((M, e_K), e_dtype), - ): - with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(K, block_K), threads=block_M) as (bx, by): - A_shared = T.alloc_shared((block_M, block_K), dtype) - A_sp_shared = T.alloc_shared((block_M, block_K // 2), dtype) - E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype) - if use_cutlass_layout: # NOTE: Make sure compressor metadata layout - T.annotate_layout({ # is same with your computation kernel - E: - make_cutlass_metadata_layout( - E, mma_dtype="float16", arch="8.0", block_k=block_K), - E_shared: - make_cutlass_metadata_layout( - E_shared, - mma_dtype="float16", - arch="8.0", - block_k=block_K), - }) - T.clear(A_sp_shared) - T.clear(E_shared) - non_zero_cnt = T.alloc_local((1, ), dtype="uint8") - non_zero_elt_log_idx = T.alloc_local((elem, ), dtype="uint8") - T.copy(A[bx * block_M, by * block_K], A_shared) - for tm in T.Parallel(block_M): - for g_i in range(0, block_K // group): - a_k = g_i * group - T.clear(non_zero_cnt) - T.clear(non_zero_elt_log_idx) - for i in range(group): - val = A_shared[tm, a_k + i] - if val != 0.0: - non_zero_elt_log_idx[non_zero_cnt[0]] = i - A_sp_shared[tm, a_k // 2 + non_zero_cnt[0]] = val - non_zero_cnt[0] += 1 - if non_zero_cnt[0] == 1 and non_zero_elt_log_idx[0] == 3: - non_zero_elt_log_idx[0] = 0 - non_zero_elt_log_idx[1] = 3 - A_sp_shared[tm, a_k // 2 + 1] = A_sp_shared[tm, a_k // 2] - A_sp_shared[tm, a_k // 2] = 0.0 - elif non_zero_cnt[0] == 1: - A_sp_shared[tm, a_k // 2 + 1] = 0 - non_zero_elt_log_idx[1] = 3 - for i in T.serial(elem): - val = non_zero_elt_log_idx[i] - E_shared[tm, a_k // e_factor] |= T.shift_left(val, 4 * (g_i % (e_factor // group)) + 2 * i) - T.copy(A_sp_shared, A_sp[bx * block_M, by * block_K // 2]) - T.copy(E_shared, E[bx * block_M, by * block_K // e_factor]) - - return kernel -``` - -## A note on `gemm_sp` and `gemm_sp_v2` - -Initially, `T.gemm_sp` followed the same design as `T.gemm`, lowering to a `CUTLASS` template. This inherently requires metadata to be reordered offline following a predetermined layout. - -However, fixing a specific layout introduces several potential issues: - -1. Painful debugging experience: Debugging a failed kernel becomes difficult due to the reordered indexing, including permutations and swizzling. - -2. Limited flexibility: For example, concatenating two compressed tensors, such as `A_sparse_0` and `A_sparse_1`, into a new `A_sparse` makes sense. However, concatenating their metadata `E_0` and `E_1` may not be valid unless the layout allows it mathematically. - -3. Alignment requirements: `CUTLASS` enforces strict alignment checks, and many hyperparameter configurations can lead to compilation errors. (For reference, sm8x was implemented in `CUTLASS 2`.) - -`T.gemm_sp_v2` was designed to address these limitations, following the approach of `T.gemm`. It lowers directly to PTX, removing the need for a fixed metadata layout. 
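The row-major metadata claim in the updated docs can be sanity-checked with a small decoder. The sketch below reuses the `decode_metadata` helper that the old docs carried and the `torch_compress` reference benchmarked in `benchmark_matmul_sp_compress.py`; it assumes each 4-element group holds exactly two non-zeros (which `randn_semi_sparse` produces) and that kept values are stored in ascending index order, as `torch_compress` does:

```python
# Sanity check: int16 metadata decodes to per-group kept indices in row-major order.
import torch
from tilelang.utils.sparse import randn_semi_sparse, torch_compress

def decode_metadata(meta: torch.Tensor) -> torch.Tensor:
    """Unpack each int16 word into four 4-bit groups of two 2-bit indices."""
    assert meta.dtype is torch.int16
    out = []
    for g in range(16 // 4):                  # four 4-bit groups per word
        group_bits = (meta >> (g * 4)) & 0xF
        idx0 = group_bits & 0x3               # index of first kept element
        idx1 = (group_bits >> 2) & 0x3        # index of second kept element
        out.append(torch.stack([idx0, idx1], dim=-1))
    return torch.concat(out, dim=-1).view(meta.shape[0], -1)

a = randn_semi_sparse(32, 64, device="cuda", dtype=torch.half)
a_sparse, e = torch_compress(a)               # int16 metadata for an fp16 input
idx = decode_metadata(e.cpu())                # (32, 32): two indices per 4-group
cols = torch.arange(0, 64, 4).repeat_interleave(2) + idx[0].long()
assert torch.equal(a[0, cols.cuda()], a_sparse[0])  # row 0 round-trips in order
```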
diff --git a/examples/gemm_sp/example_custom_compress.py b/examples/gemm_sp/example_custom_compress.py deleted file mode 100644 index 4b03ae83da..0000000000 --- a/examples/gemm_sp/example_custom_compress.py +++ /dev/null @@ -1,342 +0,0 @@ -import argparse - -import tilelang -import tilelang.language as T - -from tilelang.layout import make_cutlass_metadata_layout -from tilelang.utils.sparse import randn_semi_sparse -from tilelang.utils.tensor import torch_assert_close - -from tilelang.profiler import do_bench - -import torch - -torch.manual_seed(42) - -DEFAULT_CONFIG = { # take best config from autotune script - "4090": { - T.float: { - "block_M": 128, - "block_N": 64, - "block_K": 64, - "num_stages": 1, - "thread_num": 128, - "policy": T.GemmWarpPolicy.Square, - "enable_rasterization": True, - }, - T.float16: { - "block_M": 256, - "block_N": 128, - "block_K": 64, - "num_stages": 2, - "thread_num": 128, - "policy": T.GemmWarpPolicy.Square, - "enable_rasterization": True, - }, - }, - "h20": { - T.float: { - "block_M": 128, - "block_N": 64, - "block_K": 128, - "num_stages": 3, - "thread_num": 128, - "policy": T.GemmWarpPolicy.Square, - "enable_rasterization": True, - }, - T.float16: { - "block_M": 128, - "block_N": 64, - "block_K": 128, - "num_stages": 3, - "thread_num": 128, - "policy": T.GemmWarpPolicy.Square, - "enable_rasterization": True, - }, - }, -} - -ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")} - - -@tilelang.jit(out_idx=[-1]) -def matmul_sp_fp16_custom_compress( - M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, thread_num, policy, enable_rasterization, use_cutlass_layout -): - e_factor, e_dtype = (16, T.int16) - - @T.prim_func - def gemm_sp_fp16_custom_compress( - A_sparse: T.Tensor((M, K // 2), T.float16), - E: T.Tensor((M, K // e_factor), e_dtype), - B: T.Tensor((K, N), T.float16), - C: T.Tensor((M, N), accum_dtype), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): - A_shared = T.alloc_shared((block_M, block_K // 2), T.float16) - E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype) - B_shared = T.alloc_shared((block_K, block_N), T.float16) - C_shared = T.alloc_shared((block_M, block_N), accum_dtype) - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - if use_cutlass_layout: - T.annotate_layout( - { - E: make_cutlass_metadata_layout(E, mma_dtype=T.float16, arch="8.0", block_k=block_K), - E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=T.float16, arch="8.0", block_k=block_K), - } - ) - T.clear(C_local) - T.disable_warp_group_reg_alloc() - T.use_swizzle(panel_size=10, enable=enable_rasterization) - for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared) - T.copy(E[by * block_M, k * block_K // e_factor], E_shared) - T.copy(B[k * block_K, bx * block_N], B_shared) - T.gemm_sp_v2(A_shared, E_shared, B_shared, C_local, False, False, policy=policy) - - T.copy(C_local, C_shared) - T.copy(C_shared, C[by * block_M, bx * block_N]) - - return gemm_sp_fp16_custom_compress - - -def torch_compress(dense): - """ - A naive compression function, where each 4-bit meta matches 4 elements in original matrix in row major layout. 
- """ - if dense.dim() != 2: - raise RuntimeError(f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor") - - m, k = dense.shape - - meta_dtype = torch.int8 - if dense.dtype == torch.int8: - meta_dtype = torch.int32 - elif dense.dtype in [torch.half, torch.bfloat16, torch.float]: - meta_dtype = torch.int16 - else: - raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix") - quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 - if quadbits_per_meta_elem not in (4, 8): - raise RuntimeError("Invalid number of elements per meta element calculated") - - if meta_dtype == torch.int32: - if m % 16 != 0: - raise RuntimeError(f"Number of rows of dense matrix {m} must be divisible by 16") - else: - if m % 32 != 0: - raise RuntimeError(f"Number of rows of dense matrix {m} must be divisible by 32") - if k % (4 * quadbits_per_meta_elem) != 0: - raise RuntimeError(f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}") - - if dense.dtype != torch.float: - ksparse = 4 - dense_4 = dense.view(-1, k // ksparse, ksparse) - m0, m1, _m2, m3 = (dense_4 != 0).unbind(-1) - else: - ksparse = 2 - dense_2 = dense.view(-1, k // ksparse, ksparse) - m0, _m2 = m1, m3 = (dense_2 != 0).unbind(-1) - meta_ncols = k // (ksparse * quadbits_per_meta_elem) - - # Encoding quadruples of True/False values as follows: - # [True, True, False, False] -> 0b0100 - # [True, False, True, False] -> 0b1000 - # [False, True, True, False] -> 0b1001 - # [True, False, False, True ] -> 0b1100 - # [False, True, False, True ] -> 0b1101 - # [False, False, True, True ] -> 0b1110 - # Thus, lower two bits in the encoding are index of the True value - # at the lowest index in the quadruple, and the higher two bits in - # the encoding are index of the other True value in the quadruple. - # In case there are less than two True values, than False value or - # values at some index or indices are considered True for the - # encoding. In case there are more than two True values, then the - # excess True value(s) at some indices are considered False for - # the encoding. The exact encodings used for these cases are as - # follows: - # [False, False, False, False] -> 0b1110 - # [False, False, False, True ] -> 0b1110 - # [False, False, True, False] -> 0b1110 - # [False, True, False, False] -> 0b1001 - # [False, True, True, True ] -> 0b1101 - # [True, False, False, False] -> 0b1000 - # [True, False, True, True ] -> 0b1100 - # [True, True, False, True ] -> 0b0100 - # [True, True, True, False] -> 0b0100 - # [True, True, True, True ] -> 0b0100 - # These particular encodings are chosen, with the help of Espresso - # logic minimizer software, for the purpose of minimization of - # corresponding Boolean functions, that translate non-zero flags - # into encoding bits. Note also possible choices for the first - # and last of these encodings were limited only to (0b0100, - # 0b1110), in order to produce valid encodings for 1:2 sparsity - # case. 
- - expr0 = m0 & m1 - expr1 = ~m0 & m1 - expr2 = ~m0 & ~m1 - bit0 = expr1 - bit1 = expr2 - bit2 = expr0 | expr2 | m3 - bit3 = expr1 | ~m1 - idxs0 = bit0 | (bit1.to(torch.int64) << 1) - idxs1 = bit2 | (bit3.to(torch.int64) << 1) - - if dense.dtype != torch.float: - sparse0 = dense_4.gather(-1, idxs0.unsqueeze(-1)) # type: ignore[possibly-undefined] - sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) - sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) - else: - sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(m, k // 2) # type: ignore[possibly-undefined] - - meta_4 = idxs0 | (idxs1 << 2) - meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype) - - if quadbits_per_meta_elem == 4: - meta = meta_n[:, :, 0] | (meta_n[:, :, 1] << 4) | (meta_n[:, :, 2] << 8) | (meta_n[:, :, 3] << 12) - elif quadbits_per_meta_elem == 8: - meta = ( - meta_n[:, :, 0] - | (meta_n[:, :, 1] << 4) - | (meta_n[:, :, 2] << 8) - | (meta_n[:, :, 3] << 12) - | (meta_n[:, :, 4] << 16) - | (meta_n[:, :, 5] << 20) - | (meta_n[:, :, 6] << 24) - | (meta_n[:, :, 7] << 28) - ) - - return (sparse, meta) - - -def decode_metadata(meta: torch.Tensor) -> torch.Tensor: - assert meta.dtype is torch.int16 - groups_per_meta = 16 // 4 # 4 groups per uint16 - out = [] - for g in range(groups_per_meta): - group_bits = (meta >> (g * 4)) & 0xF - idx0 = group_bits & 0x3 - idx1 = (group_bits >> 2) & 0x3 - out.append(torch.stack([idx0, idx1], dim=-1)) - return torch.concat(out, dim=-1).view(meta.shape[0], -1) - - -@tilelang.jit( - out_idx=[1, 2], - pass_configs={ - tilelang.PassConfigKey.TIR_DISABLE_VECTORIZE: True, - }, -) -def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout): - e_factor, e_dtype = ARCH_INFO["8.0"] - e_K = K // e_factor - elem, group = 2, 4 - - assert M % block_M == 0, "M must be divisible by block_M" - assert K % block_K == 0, "K must be divisible by block_K" - assert K % e_factor == 0, "K must be divisible by e_factor" - assert block_K % e_factor == 0, "block_K must be divisible by e_factor" - - @T.prim_func - def kernel( - A: T.Tensor((M, K), dtype), - A_sp: T.Tensor((M, K // 2), dtype), - E: T.Tensor((M, e_K), e_dtype), - ): - with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(K, block_K), threads=block_M) as (bx, by): - A_shared = T.alloc_shared((block_M, block_K), dtype) - A_sp_shared = T.alloc_shared((block_M, block_K // 2), dtype) - E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype) - if use_cutlass_layout: - T.annotate_layout( - { - E: make_cutlass_metadata_layout(E, mma_dtype=T.float16, arch="8.0", block_k=block_K), - E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=T.float16, arch="8.0", block_k=block_K), - } - ) - T.clear(A_sp_shared) - T.clear(E_shared) - # TODO: alloc_var seems buggy here - non_zero_cnt = T.alloc_local((1,), dtype=T.uint8) - non_zero_elt_log_idx = T.alloc_local((elem,), dtype=T.uint8) - T.copy(A[bx * block_M, by * block_K], A_shared) - for tm in T.Parallel(block_M): - for g_i in range(0, block_K // group): - a_k = g_i * group - non_zero_cnt[0] = 0 - for i in range(elem): - non_zero_elt_log_idx[i] = 0 - for i in range(group): - val = A_shared[tm, a_k + i] - if val != 0.0: - non_zero_elt_log_idx[non_zero_cnt[0]] = i - A_sp_shared[tm, a_k // 2 + non_zero_cnt[0]] = val - non_zero_cnt[0] += 1 - # TODO: use T.device_assert(non_zero_cnt <= 2) after rebasing main - if non_zero_cnt[0] == 1 and non_zero_elt_log_idx[0] == 3: - non_zero_elt_log_idx[0] = 0 - non_zero_elt_log_idx[1] = 3 - A_sp_shared[tm, a_k // 2 + 1] = 
A_sp_shared[tm, a_k // 2] - A_sp_shared[tm, a_k // 2] = 0.0 - elif non_zero_cnt[0] == 1: - A_sp_shared[tm, a_k // 2 + 1] = 0 - non_zero_elt_log_idx[1] = 3 - for i in T.serial(elem): - val = non_zero_elt_log_idx[i] - E_shared[tm, a_k // e_factor] |= T.shift_left(val, 4 * (g_i % (e_factor // group)) + 2 * i) - T.copy(A_sp_shared, A_sp[bx * block_M, by * block_K // 2]) - T.copy(E_shared, E[bx * block_M, by * block_K // e_factor]) - - return kernel - - -def main(M=1024, N=1024, K=1024, use_cutlass_layout=False, use_torch_compressor=False, accum_dtype=T.float, cfg="4090"): - kernel = matmul_sp_fp16_custom_compress(M, N, K, accum_dtype, **DEFAULT_CONFIG[cfg][accum_dtype], use_cutlass_layout=use_cutlass_layout) - - a = randn_semi_sparse(M, K, device="cuda", dtype=torch.half) - b = torch.randn(K, N, device="cuda", dtype=torch.half) - - if use_torch_compressor: - assert not use_cutlass_layout, "torch sparse must be used with naive layout" - a_sparse, e = torch_compress(a) - else: - a_sparse, e = compress_kernel(M, K, 32, 32, T.float16, use_cutlass_layout=use_cutlass_layout)(a) - - c = kernel(a_sparse, e, b) - - ref_c = a @ b - - assert not c.isnan().any(), "Reference result contains NaNs, please report an issue" - torch_assert_close(c, ref_c.to(c.dtype), rtol=1e-3, atol=1e-3) - print(f"Precision check passed. Max diff: {(c - ref_c).abs().max()}, Mean diff: {(c - ref_c).abs().mean()}") - - latency = do_bench(lambda: kernel(a_sparse, e, b)) - ref_latency = do_bench(lambda: a @ b) - - total_flops = 2 * M * N * K - tflops = total_flops / latency / 1e9 - ref_tflops = total_flops / ref_latency / 1e9 - print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency / 1e3} s") - print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency / 1e3:} s") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark") - parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M") - parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N") - parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K") - parser.add_argument("--use_cutlass_layout", action="store_true", help="Use cutlass layout for E tensor") - parser.add_argument("--use_torch_compressor", action="store_true", help="Use torch sparse for reference") - parser.add_argument("--accum_dtype", type=str, default=T.float, choices=[T.float, T.float16], help="Accumulation datatype") - parser.add_argument("--cfg", type=str, choices=["4090"], default="4090") - args = parser.parse_args() - main( - M=args.m, - N=args.n, - K=args.k, - use_cutlass_layout=args.use_cutlass_layout, - use_torch_compressor=args.use_torch_compressor, - accum_dtype=args.accum_dtype, - cfg=args.cfg, - ) diff --git a/examples/gemm_sp/example_gemm_sp.py b/examples/gemm_sp/example_gemm_sp.py index 769ea67362..3f23e6a397 100644 --- a/examples/gemm_sp/example_gemm_sp.py +++ b/examples/gemm_sp/example_gemm_sp.py @@ -3,64 +3,15 @@ import tilelang import tilelang.language as T -from tilelang.layout import make_cutlass_metadata_layout -from tilelang.utils.sparse import compress, randn_semi_sparse -from tilelang.contrib import nvcc +from tilelang.utils.sparse import compress, randn_semi_sparse, get_e_factor from tilelang.profiler import do_bench import torch -arch = nvcc.get_target_compute_version() - -DEFAULT_CONFIG = { # take best config from autotune script - "4090": { - T.float: { - "block_M": 128, - "block_N": 64, - "block_K": 64, - "num_stages": 1, - "thread_num": 128, - "policy": 
T.GemmWarpPolicy.Square, - "enable_rasterization": True, - }, - T.float16: { - "block_M": 256, - "block_N": 128, - "block_K": 64, - "num_stages": 2, - "thread_num": 128, - "policy": T.GemmWarpPolicy.Square, - "enable_rasterization": True, - }, - }, - "h20": { - T.float: { - "block_M": 128, - "block_N": 64, - "block_K": 128, - "num_stages": 3, - "thread_num": 128, - "policy": T.GemmWarpPolicy.Square, - "enable_rasterization": True, - }, - T.float16: { - "block_M": 128, - "block_N": 64, - "block_K": 128, - "num_stages": 3, - "thread_num": 128, - "policy": T.GemmWarpPolicy.Square, - "enable_rasterization": True, - }, - }, -} - -ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")} - @tilelang.jit(out_idx=[-1]) -def matmul_sp_fp16(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, thread_num, policy, enable_rasterization): - e_factor, e_dtype = ARCH_INFO[arch] +def matmul_sp_fp16(M, N, K, accum_dtype, e_dtype, block_M, block_N, block_K, num_stages, thread_num, policy, enable_rasterization): + e_factor = get_e_factor(T.float16, e_dtype) @T.prim_func def gemm_sp_fp16( @@ -79,17 +30,11 @@ def gemm_sp_fp16( T.clear(C_local) T.disable_warp_group_reg_alloc() T.use_swizzle(panel_size=10, enable=enable_rasterization) - T.annotate_layout( - { - E: make_cutlass_metadata_layout(E, mma_dtype=T.float16, block_k=block_K, arch=arch), - E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=T.float16, block_k=block_K, arch=arch), - } - ) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared) T.copy(E[by * block_M, k * block_K // e_factor], E_shared) T.copy(B[k * block_K, bx * block_N], B_shared) - T.gemm_sp(A_shared, E_shared, B_shared, C_local, False, False, policy=policy) + T.gemm_sp(A_shared, E_shared, B_shared, C_local, transpose_A=False, transpose_E=False, transpose_B=False, policy=policy) T.copy(C_local, C_shared) T.copy(C_shared, C[by * block_M, bx * block_N]) @@ -97,13 +42,26 @@ def gemm_sp_fp16( return gemm_sp_fp16 -def main(M=1024, N=1024, K=1024, accum_dtype=T.float, cfg="h20"): - kernel = matmul_sp_fp16(M, N, K, accum_dtype, **DEFAULT_CONFIG[cfg][accum_dtype]) +def main( + M=1024, + N=1024, + K=1024, + accum_dtype=T.float, + e_dtype=T.int16, + block_M=128, + block_N=128, + block_K=64, + num_stages=2, + thread_num=128, + policy=T.GemmWarpPolicy.Square, + enable_rasterization=True, +): + kernel = matmul_sp_fp16(M, N, K, accum_dtype, e_dtype, block_M, block_N, block_K, num_stages, thread_num, policy, enable_rasterization) a = randn_semi_sparse(M, K, device="cuda", dtype=torch.half) b = torch.randn(K, N, device="cuda", dtype=torch.half) - a_sparse, e = compress(a, transposed=False, block_k=DEFAULT_CONFIG[cfg][accum_dtype]["block_K"], arch=arch) + a_sparse, e = compress(a, meta_dtype=e_dtype.as_torch()) c = kernel(a_sparse, e, b) ref_c = a @ b @@ -123,11 +81,32 @@ def main(M=1024, N=1024, K=1024, accum_dtype=T.float, cfg="h20"): if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark") + parser = argparse.ArgumentParser(description="Sparse FP16 MatMul Example") parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M") parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N") parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K") - parser.add_argument("--accum_dtype", type=str, default=T.float, choices=[T.float, T.float16], help="Accumulation datatype") - parser.add_argument("--cfg", 
type=str, choices=["4090", "h20"], default="4090") + parser.add_argument( + "--e_dtype", + default=T.int16, + choices=[T.int8, T.int16, T.int32], + help="Data type for metadata E, which controls the sparsity pattern. Note that int8 and int32 are only supported on sm90+", + ) + parser.add_argument("--accum_dtype", default=T.float, choices=[T.float, T.float16], help="Accumulation datatype") + parser.add_argument("--block_M", type=int, default=128) + parser.add_argument("--block_N", type=int, default=256) + parser.add_argument("--block_K", type=int, default=128) + parser.add_argument("--num_stages", type=int, default=2) + parser.add_argument("--thread_num", type=int, default=256) args = parser.parse_args() - main(M=args.m, N=args.n, K=args.k, accum_dtype=args.accum_dtype, cfg=args.cfg) + main( + M=args.m, + N=args.n, + K=args.k, + accum_dtype=args.accum_dtype, + e_dtype=args.e_dtype, + block_M=args.block_M, + block_N=args.block_N, + block_K=args.block_K, + num_stages=args.num_stages, + thread_num=args.thread_num, + ) diff --git a/examples/gemm_sp/test_example_gemm_sp.py b/examples/gemm_sp/test_example_gemm_sp.py index aa1a747f24..f99402bb22 100644 --- a/examples/gemm_sp/test_example_gemm_sp.py +++ b/examples/gemm_sp/test_example_gemm_sp.py @@ -1,17 +1,9 @@ import tilelang.testing -import example_custom_compress import example_gemm_sp @tilelang.testing.requires_cuda -@tilelang.testing.requires_cuda_compute_version_eq(9, 0) -def test_example_custom_compress(): - example_custom_compress.main() - - -@tilelang.testing.requires_cuda -@tilelang.testing.requires_cuda_compute_version_eq(9, 0) def test_example_gemm_sp(): example_gemm_sp.main() diff --git a/src/backend/cuda/codegen/codegen_cuda.cc b/src/backend/cuda/codegen/codegen_cuda.cc index 832a8fba90..eac8bb5f28 100644 --- a/src/backend/cuda/codegen/codegen_cuda.cc +++ b/src/backend/cuda/codegen/codegen_cuda.cc @@ -601,12 +601,18 @@ std::string CodeGenTileLangCUDA::Finish() { if (need_wgmma_instruction_h_) { decl_stream << "#include \n"; } + if (need_wgmma_sp_instruction_h_) { + decl_stream << "#include \n"; + } if (need_tcgen05mma_instruction_h_) { decl_stream << "#include \n"; } if (need_mma_sm70_instruction_h_) { decl_stream << "#include \n"; } + if (need_mma_sp_instruction_h_) { + decl_stream << "#include \n"; + } if (need_tcgen05_common_h_) { decl_stream << "#include \n"; } @@ -2632,16 +2638,70 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { std::string b_offset = this->PrintExpr(op->args[9]); std::string c_ref = this->PrintExpr(op->args[10]); std::string c_offset = this->PrintExpr(op->args[11]); - std::string metadata = this->PrintExpr(op->args[12]); - std::string metadata_offset = this->PrintExpr(op->args[13]); - std::string sparse_selector = this->PrintExpr(op->args[14]); - bool saturate = Downcast<Bool>(op->args[15])->value; + std::string e_ref = this->PrintExpr(op->args[12]); + std::string e_offset = this->PrintExpr(op->args[13]); + int64_t sparse_selector = Downcast<IntImm>(op->args[14])->value; + + auto dtype_a_enum = tl::codegen::ptx::DTypeFromString(A_dtype); + auto dtype_b_enum = tl::codegen::ptx::DTypeFromString(B_dtype); + auto dtype_c_enum = tl::codegen::ptx::DTypeFromString(C_dtype); + auto [m, n, k] = tl::codegen::ptx::ParseMMAShape(shape); + + need_mma_sp_instruction_h_ = true; this->PrintIndent(); - std::string asm_code = PrintMMAAssembly( - shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_offset, - b_ref, b_offset, c_ref, c_offset, metadata, metadata_offset, - sparse_selector, "", true,
saturate); - this->stream << asm_code; + + std::string mma_call = + "tl::mma_sp_sync<(AType), (BType), (CType), (M), (N), (K), (TransA), " + "(TransB), (SparseSel)>(" + "reinterpret_cast<(CRegType)*>((C_ptr) + (C_offset)), " + "reinterpret_cast<const (ARegType)*>((A_ptr) + (A_offset)), " + "reinterpret_cast<const (BRegType)*>((B_ptr) + (B_offset)), " + "reinterpret_cast((E_ptr) + (E_offset)));\n"; + tl::codegen::Replacer replacer; + + // TF32 workaround: float32 A/B in TF32 context maps to kTensorFloat32. + std::string AType = tl::codegen::ptx::DTypeEnumToString(dtype_a_enum); + if (AType == "tl::DataType::kFloat32") { + AType = "tl::DataType::kTensorFloat32"; + } + std::string BType = tl::codegen::ptx::DTypeEnumToString(dtype_b_enum); + if (BType == "tl::DataType::kFloat32") { + BType = "tl::DataType::kTensorFloat32"; + } + std::string ARegType = tl::codegen::GetMMARegisterType(dtype_a_enum); + if (ARegType == "float") { + ARegType = "uint32_t"; + } + std::string BRegType = tl::codegen::GetMMARegisterType(dtype_b_enum); + if (BRegType == "float") { + BRegType = "uint32_t"; + } + + replacer.register_rule("(AType)", AType); + replacer.register_rule("(BType)", BType); + replacer.register_rule("(CType)", + tl::codegen::ptx::DTypeEnumToString(dtype_c_enum)); + replacer.register_rule("(M)", std::to_string(m)); + replacer.register_rule("(N)", std::to_string(n)); + replacer.register_rule("(K)", std::to_string(k)); + replacer.register_rule("(TransA)", A_layout == "row" ? "false" : "true"); + replacer.register_rule("(TransB)", B_layout == "row" ? "false" : "true"); + replacer.register_rule("(SparseSel)", sparse_selector == 0 + ? "SM80::MMA::SparseSel::Zero" + : "SM80::MMA::SparseSel::One"); + replacer.register_rule("(ARegType)", ARegType); + replacer.register_rule("(BRegType)", BRegType); + replacer.register_rule("(CRegType)", + tl::codegen::GetMMARegisterType(dtype_c_enum)); + replacer.register_rule("(A_ptr)", a_ref); + replacer.register_rule("(A_offset)", a_offset); + replacer.register_rule("(B_ptr)", b_ref); + replacer.register_rule("(B_offset)", b_offset); + replacer.register_rule("(C_ptr)", c_ref); + replacer.register_rule("(C_offset)", c_offset); + replacer.register_rule("(E_ptr)", e_ref); + replacer.register_rule("(E_offset)", e_offset); + this->stream << replacer.rewrite(mma_call); } else if (op->op.same_as(tl::ptx_wgmma_ss())) { // arg 0: dtype // arg 1: shape @@ -2786,6 +2846,143 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { replacer.register_rule("(scale_out)", scale_out); wgmma_call = replacer.rewrite(wgmma_call); this->stream << wgmma_call; + } else if (op->op.same_as(tl::ptx_wgmma_sp_ss())) { + ICHECK_EQ(op->args.size(), 18U) << "ptx_wgmma_sp_ss args is " << op->args; + std::string wgmma_prefix = Downcast<StringImm>(op->args[0])->value; + bool a_is_k_major = Downcast<Bool>(op->args[1])->value; + bool b_is_k_major = Downcast<Bool>(op->args[2])->value; + std::string A_dtype = Downcast<StringImm>(op->args[3])->value; + std::string B_dtype = Downcast<StringImm>(op->args[4])->value; + std::string C_dtype = Downcast<StringImm>(op->args[5])->value; + std::string a_desc = this->PrintExpr(op->args[6]); + std::string A_offset = this->PrintExpr(op->args[7]); + std::string e_data = this->PrintExpr(op->args[8]); + std::string E_offset = this->PrintExpr(op->args[9]); + std::string sparse_selector = this->PrintExpr(op->args[10]); + std::string b_desc = this->PrintExpr(op->args[11]); + std::string B_offset = this->PrintExpr(op->args[12]); + std::string c_data = this->PrintExpr(op->args[13]); + std::string C_offset = this->PrintExpr(op->args[14]); + 
std::string scale_out = this->PrintExpr(op->args[15]); + bool scale_in_a = Downcast<Bool>(op->args[16])->value; + bool scale_in_b = Downcast<Bool>(op->args[17])->value; + + this->PrintIndent(); + auto [m, n, k] = tl::codegen::ptx::ParseMMAShape(wgmma_prefix); + need_wgmma_sp_instruction_h_ = true; + int sparse_sel_val_ss = Downcast<IntImm>(op->args[10])->value; + std::string spsel_ss = (sparse_sel_val_ss == 0) + ? "cute::SM90::GMMA::SparseSel::Zero" + : "cute::SM90::GMMA::SparseSel::One"; + std::string wgmma_sp_asm_code = + "tl::wgmma_sp_ss<(AType), (BType), (CType), (M), (N), (K), (tnspA), " + "(tnspB), (scaleA), (scaleB), (spsel)>(uint64_t((desc_a) + " + "(A_offset)), " + "uint64_t((desc_b) + (B_offset)), " + "reinterpret_cast((C_data)) " + "+ (C_offset), (scale_out), *reinterpret_cast((e_data) + " + "(E_offset)));\n"; + + tl::codegen::Replacer replacer; + std::string AType = tl::codegen::ptx::DTypeEnumToString(A_dtype); + if (AType == "tl::DataType::kFloat32") { + AType = "tl::DataType::kTensorFloat32"; + } + std::string BType = tl::codegen::ptx::DTypeEnumToString(B_dtype); + if (BType == "tl::DataType::kFloat32") { + BType = "tl::DataType::kTensorFloat32"; + } + + replacer.register_rule("(AType)", AType); + replacer.register_rule("(BType)", BType); + replacer.register_rule("(CType)", + tl::codegen::ptx::DTypeEnumToString(C_dtype)); + replacer.register_rule("(M)", std::to_string(m)); + replacer.register_rule("(N)", std::to_string(n)); + replacer.register_rule("(K)", std::to_string(k)); + replacer.register_rule("(tnspA)", a_is_k_major ? "false" : "true"); + replacer.register_rule("(tnspB)", b_is_k_major ? "false" : "true"); + replacer.register_rule("(scaleA)", scale_in_a ? "1" : "-1"); + replacer.register_rule("(scaleB)", scale_in_b ? "1" : "-1"); + replacer.register_rule("(spsel)", spsel_ss); + replacer.register_rule("(desc_a)", a_desc); + replacer.register_rule("(A_offset)", A_offset); + replacer.register_rule("(e_data)", e_data); + replacer.register_rule("(E_offset)", E_offset); + replacer.register_rule("(desc_b)", b_desc); + replacer.register_rule("(B_offset)", B_offset); + replacer.register_rule("(C_data)", c_data); + replacer.register_rule("(C_offset)", C_offset); + replacer.register_rule("(scale_out)", scale_out); + wgmma_sp_asm_code = replacer.rewrite(wgmma_sp_asm_code); + this->stream << wgmma_sp_asm_code; + } else if (op->op.same_as(tl::ptx_wgmma_sp_rs())) { + ICHECK_EQ(op->args.size(), 17U) << "ptx_wgmma_sp_rs args is " << op->args; + std::string wgmma_prefix = Downcast<StringImm>(op->args[0])->value; + bool b_is_k_major = Downcast<Bool>(op->args[1])->value; + std::string A_dtype = Downcast<StringImm>(op->args[2])->value; + std::string B_dtype = Downcast<StringImm>(op->args[3])->value; + std::string C_dtype = Downcast<StringImm>(op->args[4])->value; + std::string a_ref = this->PrintExpr(op->args[5]); + std::string A_offset = this->PrintExpr(op->args[6]); + std::string e_data = this->PrintExpr(op->args[7]); + std::string E_offset = this->PrintExpr(op->args[8]); + std::string sparse_selector = this->PrintExpr(op->args[9]); + std::string b_desc = this->PrintExpr(op->args[10]); + std::string B_offset = this->PrintExpr(op->args[11]); + std::string c_data = this->PrintExpr(op->args[12]); + std::string C_offset = this->PrintExpr(op->args[13]); + std::string scale_out = this->PrintExpr(op->args[14]); + bool scale_in_a = Downcast<Bool>(op->args[15])->value; + bool scale_in_b = Downcast<Bool>(op->args[16])->value; + + auto [m, n, k] = tl::codegen::ptx::ParseMMAShape(wgmma_prefix); + need_wgmma_sp_instruction_h_ = true; + this->PrintIndent(); + int sparse_sel_val_rs = 
Downcast<IntImm>(op->args[9])->value; + std::string spsel_rs = (sparse_sel_val_rs == 0) + ? "cute::SM90::GMMA::SparseSel::Zero" + : "cute::SM90::GMMA::SparseSel::One"; + std::string wgmma_sp_rs_asm_code = + "tl::wgmma_sp_rs<(AType), (BType), (CType), (M), (N), (K), false, " + "(tnspB), (scaleA), (scaleB), (spsel)>(reinterpret_cast((A_ptr) + " + "(A_offset)), uint64_t((desc_b) + (B_offset)), " + "reinterpret_cast((C_data)) + (C_offset), (scale_out), " + "*reinterpret_cast((e_data) + (E_offset)));\n"; + + tl::codegen::Replacer replacer; + std::string AType = tl::codegen::ptx::DTypeEnumToString(A_dtype); + if (AType == "tl::DataType::kFloat32") { + AType = "tl::DataType::kTensorFloat32"; + } + std::string BType = tl::codegen::ptx::DTypeEnumToString(B_dtype); + if (BType == "tl::DataType::kFloat32") { + BType = "tl::DataType::kTensorFloat32"; + } + + replacer.register_rule("(AType)", AType); + replacer.register_rule("(BType)", BType); + replacer.register_rule("(CType)", + tl::codegen::ptx::DTypeEnumToString(C_dtype)); + replacer.register_rule("(M)", std::to_string(m)); + replacer.register_rule("(N)", std::to_string(n)); + replacer.register_rule("(K)", std::to_string(k)); + replacer.register_rule("(tnspB)", b_is_k_major ? "false" : "true"); + replacer.register_rule("(scaleA)", scale_in_a ? "1" : "-1"); + replacer.register_rule("(scaleB)", scale_in_b ? "1" : "-1"); + replacer.register_rule("(spsel)", spsel_rs); + replacer.register_rule("(A_ptr)", a_ref); + replacer.register_rule("(A_offset)", A_offset); + replacer.register_rule("(e_data)", e_data); + replacer.register_rule("(E_offset)", E_offset); + replacer.register_rule("(desc_b)", b_desc); + replacer.register_rule("(B_offset)", B_offset); + replacer.register_rule("(C_data)", c_data); + replacer.register_rule("(C_offset)", C_offset); + replacer.register_rule("(scale_out)", scale_out); + wgmma_sp_rs_asm_code = replacer.rewrite(wgmma_sp_rs_asm_code); + this->stream << wgmma_sp_rs_asm_code; } else if (op->op.same_as(tl::ptx_tcgen05_mma_ss())) { ICHECK_EQ(op->args.size(), 15U) << "ptx_tcgen05_mma_ss args is " << op->args; diff --git a/src/backend/cuda/codegen/codegen_cuda.h b/src/backend/cuda/codegen/codegen_cuda.h index 656e7dbf38..6663a54297 100644 --- a/src/backend/cuda/codegen/codegen_cuda.h +++ b/src/backend/cuda/codegen/codegen_cuda.h @@ -125,6 +125,10 @@ class CodeGenTileLangCUDA final : public CodeGenC { bool need_tcgen05mma_instruction_h_{false}; // whether need tl mma_sm70 instruction header bool need_mma_sm70_instruction_h_{false}; + // whether need tl mma_sp instruction header + bool need_mma_sp_instruction_h_{false}; + // whether need tl wgmma_sp instruction header + bool need_wgmma_sp_instruction_h_{false}; // whether need tcgen_05 common header bool need_tcgen05_common_h_{false}; // whether need cast_smem_ptr_to_int helper function diff --git a/src/backend/cuda/op/gemm_sp.cc b/src/backend/cuda/op/gemm_sp.cc index f7eaffd80a..47306a6332 100644 --- a/src/backend/cuda/op/gemm_sp.cc +++ b/src/backend/cuda/op/gemm_sp.cc @@ -1,19 +1,23 @@ /*! * \file tl/backend/cuda/op/gemm_sp.cc - * \brief CUDA implementation for tl.gemm_sp lowering and layout inference. - */ - + * \brief CUDA implementation for tl.gemm_sp instruction selection.
+ */ #include "op/gemm_sp.h" +#include "op/gemm.h" -#include "layout/layout.h" #include "op/builtin.h" +#include "op/tcgen5_meta.h" #include "op/utils.h" #include "target/utils.h" -#include -#include +#include +#include -#include +#include +#include +#include +#include +#include namespace tvm { namespace tl { @@ -24,191 +28,293 @@ namespace cuda { namespace { -constexpr const char *kCudaMMA = "cuda.mma"; -constexpr const char *kCudaWGMMA = "cuda.wgmma"; +constexpr const char *kCudaMMASP = "cuda.mma.sp"; +constexpr const char *kCudaWGMMASP = "cuda.wgmma.sp"; +constexpr const char *kCudaTCGEN05SP = "cuda.tcgen05.sp"; + +bool CheckWGMMA(const GemmSPNode &op) { + if (op.B.scope() != "shared.dyn" && op.B.scope() != "shared") { + return false; + } + + if (op.C->dtype == DataType::Float(16) || + op.C->dtype == DataType::Float(32)) { + if (op.A->dtype == DataType::Float(16) && + op.B->dtype == DataType::Float(16)) + return op.K % 32 == 0; + else if (op.A->dtype == DataType::BFloat(16) && + op.B->dtype == DataType::BFloat(16)) + return op.K % 32 == 0; + else if (op.A->dtype == DataType::Float(32) && + op.B->dtype == DataType::Float(32)) + return (!op.trans_A) && op.trans_B && op.K % 16 == 0; + else if (op.A->dtype.is_float8() && op.B->dtype.is_float8()) + return (!op.trans_A) && op.trans_B && op.K % 64 == 0; + else + return false; + } else if (op.C->dtype == DataType::Int(32)) { + if ((op.A->dtype == DataType::Int(8) || op.A->dtype == DataType::UInt(8)) && + (op.B->dtype == DataType::Int(8) || op.B->dtype == DataType::UInt(8))) + return (!op.trans_A) && op.trans_B && op.K % 64 == 0; + else + return false; + } else { + return false; + } +} + +// TODO @botbw: support tcgen5mma.sp for sparse inputs when it's available +bool AllowTcgen5Mma(const GemmSPNode &op, Target target) { + bool scope_ok = (IsSharedBuffer(op.A) || op.A.scope() == "shared.tmem") && + IsSharedBuffer(op.B) && op.C.scope() == "shared.tmem"; + if (!TargetIsSm100(target) || !scope_ok) + return false; + DataType ab_dtype = + (op.A.scope() == "shared.tmem") ? op.B->dtype : op.A->dtype; + return GetTCGEN5MMAMeta(op.M, op.N, op.K, ab_dtype, op.C->dtype).first; +} + +bool AllowWgmma(const GemmSPNode &op, int block_size, Target target) { + tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current(); -String SelectGemmInst(const GemmSPNode &op, int block_size, Target target) { int warp_size = TargetGetWarpSize(target); - bool maybe_wgmma = TargetIsHopper(target) && op.m_ >= 64 && - (block_size / warp_size % 4 == 0); - return maybe_wgmma ? String(kCudaWGMMA) : String(kCudaMMA); + int num_warps = block_size / warp_size; + return !ctxt->GetConfig(kDisableWGMMA, Optional<Bool>()).value_or(false) && + TargetIsHopper(target) && op.M >= 64 && num_warps % 4 == 0 && + CheckWGMMA(op); } -bool UseWgmma(String gemm_inst) { return gemm_inst == kCudaWGMMA; } +void FatalWgmmaUnavailable(const GemmSPNode &op, Target target) { + LOG(FATAL) << "T.wgmma_gemm() requires Hopper WGMMA lowering, but " "constraints were not satisfied. Got target=" << target << ", A(scope=" << op.A.scope() << ", dtype=" << op.A->dtype << "), B(scope=" << op.B.scope() << ", dtype=" << op.B->dtype << "), C(scope=" << op.C.scope() << ", dtype=" << op.C->dtype << "), M=" << op.M << ", N=" << op.N + << ", K=" << op.K << "."; +} +void FatalTcgen5Unavailable(const GemmSPNode &op, Target target) { + LOG(FATAL) << "T.tcgen05_gemm() requires Blackwell TCGEN5MMA lowering, " "but constraints were not satisfied. Got target=" + << target << ", A(scope=" << op.A.scope() + << ", dtype=" << op.A->dtype << "), B(scope=" << op.B.scope() + << ", dtype=" << op.B->dtype << "), C(scope=" << op.C.scope() + << ", dtype=" << op.C->dtype << "), M=" << op.M + << ", N=" << op.N << ", K=" << op.K << "."; +}
Got target=" + // << target << ", A(scope=" << op.A.scope() + // << ", dtype=" << op.A->dtype << "), B(scope=" << op.B.scope() + // << ", dtype=" << op.B->dtype << "), C(scope=" << op.C.scope() + // << ", dtype=" << op.C->dtype << "), M=" << op.M + // << ", N=" << op.N << ", K=" << op.K << "."; +} -struct GemmSP { - static std::pair - ComputeWarpPartition(const GemmSPWarpPolicyNode &policy, int M, int N, - int block_size, Target target, String gemm_inst, - int bits) { - int num_warps = block_size / TargetGetWarpSize(target); +std::pair +ComputeDefaultWarpPartition(const GemmSPWarpPolicyNode &policy, int M, int N, + int num_warps, int k_n_per_warp) { + int m_warp = 1, n_warp = 1; + constexpr int kMPerWarp = 16; - ICHECK(gemm_inst == kCudaMMA || gemm_inst == kCudaWGMMA) - << "CUDA GemmSP currently only supports MMA and WGMMA"; - auto [m_warp, n_warp] = - static_cast(policy).computeWarpPartition( - M, N, block_size, target, gemm_inst); - - // Special handling for gemm_sp when the tiling size is not a multiple. - // This should be consistent with shape check in gemm_sp_sm80.h. - int m_atom_size = bits == 16 ? 32 : 16; - int n_atom_size = bits == 16 ? 32 : 16; - static const char *err_msg = - "Cannot arrange the warp shape to be a multiple of atom size, please " - "reduce num threads or increase tiling size"; - if (TargetIsAmpere(target)) { - int warp_shape_m = M / m_warp; - int warp_shape_n = N / n_warp; - if (warp_shape_m % m_atom_size) { // GemmWarpPolicy::kFullRow - m_warp = M / m_atom_size; - ICHECK(m_warp > 0) << err_msg; - n_warp = num_warps / m_warp; - warp_shape_n = N / n_warp; - ICHECK(warp_shape_n % n_atom_size == 0) << err_msg; - } else if (warp_shape_n % n_atom_size != - 0) { // GemmWarpPolicy::kFullColumn - n_warp = N / n_atom_size; - ICHECK(n_warp > 0) << err_msg; - m_warp = num_warps / n_warp; - warp_shape_m = M / m_warp; - ICHECK(warp_shape_m % m_atom_size == 0) << err_msg; + ICHECK(M % kMPerWarp == 0) + << "M must be divisible by " << kMPerWarp << ", but got " << M; + ICHECK(N % k_n_per_warp == 0) + << "N must be divisible by " << k_n_per_warp << ", but got " << N; + + if (policy.isFullRow()) { + m_warp = num_warps; + n_warp = 1; + if (M % (m_warp * kMPerWarp) != 0) { + int max_m_warps = M / kMPerWarp; + m_warp = max_m_warps; + n_warp = num_warps / m_warp; + if (n_warp == 0) + n_warp = 1; + } + } else if (policy.isFullCol()) { + m_warp = 1; + n_warp = num_warps; + if (N % (n_warp * k_n_per_warp) != 0) { + int max_n_warps = N / k_n_per_warp; + n_warp = max_n_warps; + m_warp = num_warps / n_warp; + if (m_warp == 0) + m_warp = 1; + } + } else if (policy.isSquare()) { + int max_m_warps = M / kMPerWarp; + float ideal_ratio = N > 0 ? 
static_cast<float>(M) / N : 1.0f; + + int best_m = 1; + int best_n = 1; + float best_balance = std::numeric_limits<float>::max(); + for (int m = 1; m <= max_m_warps && m <= num_warps; m++) { + int n = num_warps / m; + + float m_per_warp = static_cast<float>(M) / (m * kMPerWarp); + float n_per_warp = static_cast<float>(N) / (n * k_n_per_warp); + if (m_per_warp < 1 || n_per_warp < 1) + continue; + if (m * n != num_warps) + continue; + + float balance = std::abs(m_per_warp / n_per_warp - ideal_ratio); + if (balance < best_balance) { + best_balance = balance; + best_m = m; + best_n = n; } - ICHECK(m_warp * n_warp == num_warps) - << "m_warp * n_warp must equal num_warps, please report an issue " - "when encounter this" - << ", m_warp: " << m_warp << ", n_warp: " << n_warp << ", num_warps" - << num_warps; - policy.m_warp = m_warp; - policy.n_warp = n_warp; } - return {m_warp, n_warp}; + + m_warp = best_m; + n_warp = best_n; + } else { + ICHECK(0) << "Unknown GemmSPWarpPolicy"; } - static Stmt Lower(const GemmSPNode &op, const LowerArgs &T, - arith::Analyzer *analyzer) { - auto block_size = *as_const_int(T.thread_bounds->extent); - auto gemm_inst = SelectGemmInst(op, block_size, T.target); - bool maybe_wgmma = UseWgmma(gemm_inst); - auto [warp_m, warp_n] = op.policy_->computeWarpPartition( - op.m_, op.n_, block_size, T.target, gemm_inst, op.a_->dtype.bits()); - - std::stringstream ss; - std::string op_name = "tl::gemm_sp_ss"; - ICHECK(IsSharedBuffer(op.a_) && IsSharedBuffer(op.b_)) - << "Only support shared.dyn scope for A and B, but received " - << op.a_.scope() << " and " << op.b_.scope(); - ICHECK(IsSharedBuffer(op.e_)) - << "Only support shared.dyn scope for E as copy from smem to rmem are " - "delegated to cute implementation, found " - << op.e_.scope(); - ss << op_name << "<" << op.m_ << ", " << op.n_ << ", " << op.k_ << ", "; - ss << warp_m << ", " << warp_n << ", "; - ss << op.transA_ << ", " << op.transB_; - ss << ", " << op.clearAccum_; - if (TargetIsHopper(T.target)) { - ss << ", " << (maybe_wgmma ? "true" : "false"); + ICHECK(m_warp * n_warp == num_warps) + << "m_warp * n_warp must equal num_warps, m_warp: " << m_warp + << ", n_warp: " << n_warp << ", num_warps: " << num_warps; + policy.m_warp = m_warp; + policy.n_warp = n_warp; + return {m_warp, n_warp}; +} + +std::pair<int, int> +ComputeWgmmaWarpPartition(const GemmSPWarpPolicyNode &policy, int M, int N, + int num_warps) { + ICHECK(num_warps % 4 == 0) << "Warp-Group MMA requires 128*k threads."; + + int m_warp = 1, n_warp = 1; + constexpr int kMPerWarp = 16; + constexpr int kNPerWarp = 8; + constexpr int kGroup = 4; + + ICHECK(M % kMPerWarp == 0) + << "M must be divisible by " << kMPerWarp << ", but got " << M; + ICHECK(N % kNPerWarp == 0) + << "N must be divisible by " << kNPerWarp << ", but got " << N; + + m_warp = kGroup; + n_warp = num_warps / m_warp; + + if (policy.isFullRow()) { + for (int cand = num_warps; cand >= kGroup; cand -= kGroup) { + if (M % (cand * kMPerWarp) == 0) { + m_warp = cand; + n_warp = num_warps / m_warp; + break; + } + } + } else if (policy.isFullCol()) { + int cand_n = n_warp; + if (N % (cand_n * kNPerWarp) != 0) { + int max_n = N / kNPerWarp; + for (int n = std::min(cand_n, max_n); n >= 1; --n) { + if (num_warps % n == 0 && (num_warps / n) % kGroup == 0) { + n_warp = n; + m_warp = num_warps / n_warp; + break; + } + } } - if (op.wgWait_ != 0) { - ss << ", " << op.wgWait_; + } else if (policy.isSquare()) { + int max_m = M / kMPerWarp; + int max_n = N / kNPerWarp; + + float ideal = N > 0 ?
+
+std::pair<int, int>
+ComputeWgmmaWarpPartition(const GemmSPWarpPolicyNode &policy, int M, int N,
+                          int num_warps) {
+  ICHECK(num_warps % 4 == 0) << "Warp-Group MMA requires 128*k threads.";
+
+  int m_warp = 1, n_warp = 1;
+  constexpr int kMPerWarp = 16;
+  constexpr int kNPerWarp = 8;
+  constexpr int kGroup = 4;
+
+  ICHECK(M % kMPerWarp == 0)
+      << "M must be divisible by " << kMPerWarp << ", but got " << M;
+  ICHECK(N % kNPerWarp == 0)
+      << "N must be divisible by " << kNPerWarp << ", but got " << N;
+
+  m_warp = kGroup;
+  n_warp = num_warps / m_warp;
+
+  if (policy.isFullRow()) {
+    for (int cand = num_warps; cand >= kGroup; cand -= kGroup) {
+      if (M % (cand * kMPerWarp) == 0) {
+        m_warp = cand;
+        n_warp = num_warps / m_warp;
+        break;
+      }
+    }
+  } else if (policy.isFullCol()) {
+    int cand_n = n_warp;
+    if (N % (cand_n * kNPerWarp) != 0) {
+      int max_n = N / kNPerWarp;
+      for (int n = std::min(cand_n, max_n); n >= 1; --n) {
+        if (num_warps % n == 0 && (num_warps / n) % kGroup == 0) {
+          n_warp = n;
+          m_warp = num_warps / n_warp;
+          break;
+        }
+      }
     }
-    if (op.wgWait_ != 0) {
-      ss << ", " << op.wgWait_;
+  } else if (policy.isSquare()) {
+    int max_m = M / kMPerWarp;
+    int max_n = N / kNPerWarp;
+
+    float ideal = N > 0 ? static_cast<float>(M) / N : 1.f;
+    float best_score = std::numeric_limits<float>::max();
+    int best_m = kGroup, best_n = n_warp;
+
+    for (int m = kGroup; m <= num_warps && m <= max_m; m += kGroup) {
+      if (num_warps % m)
+        continue;
+      int n = num_warps / m;
+      if (n > max_n)
+        continue;
+
+      float m_per_warp = static_cast<float>(M) / (m * kMPerWarp);
+      float n_per_warp = static_cast<float>(N) / (n * kNPerWarp);
+      float score = std::abs(m_per_warp / n_per_warp - ideal);
+
+      if (score < best_score) {
+        best_score = score;
+        best_m = m;
+        best_n = n;
+      }
     }
-    ss << ">";
-
-    PrimExpr Aptr =
-        MakeAccessPtrFromRegion(op.aRegion_, /*r*/ 1, /*require_2d*/ true);
-    PrimExpr Bptr =
-        MakeAccessPtrFromRegion(op.bRegion_, /*r*/ 1, /*require_2d*/ true);
-    PrimExpr Cptr =
-        MakeAccessPtrFromRegion(op.cRegion_, /*rw*/ 3, /*require_2d*/ true);
-    PrimExpr Eptr =
-        MakeAccessPtrFromRegion(op.eRegion_, /*r*/ 1, /*require_2d*/ false);
-
-    auto new_call =
-        Call(DataType::Handle(), tl::tl_gemm_sp(),
-             Array<PrimExpr>{StringImm(ss.str()), Aptr, Bptr, Cptr, Eptr});
-    return Evaluate(new_call);
+    m_warp = best_m;
+    n_warp = best_n;
+  } else {
+    ICHECK(0) << "Unknown GemmSPWarpPolicy";
   }

-  static LayoutMap InferLayout(const GemmSPNode &op, const LayoutInferArgs &T,
-                               InferLevel level) {
-    LayoutMap results;
-    ICHECK(IsFragmentBuffer(op.c_));
-    auto thread_range = T.thread_bounds;
-    auto block_size = *as_const_int(thread_range->extent);
-    if (TargetIsHopper(T.target)) {
-      auto gemm_inst = SelectGemmInst(op, block_size, T.target);
-      bool maybe_wgmma = UseWgmma(gemm_inst);
-      auto [warp_m, warp_n] = op.policy_->computeWarpPartition(
-          op.m_, op.n_, block_size, T.target, gemm_inst, op.a_->dtype.bits());
-      auto fragment =
-          maybe_wgmma
-              ? makeGemmFragmentCHopper(op.m_, op.n_, op.m_ / warp_m,
-                                        op.n_ / warp_n, op.c_->dtype.bits())
-              : makeGemmFragmentC(op.m_, op.n_, op.m_ / warp_m, op.n_ / warp_n,
-                                  op.c_->dtype.bits());
-      results.Set(op.c_, fragment->BindThreadRange(thread_range));
-      if (IsSharedBuffer(op.a_)) {
-        int dim_A = op.a_->shape.size();
-        const int64_t mat_stride = *as_const_int(op.a_->shape[dim_A - 2]);
-        const int64_t mat_continuous = *as_const_int(op.a_->shape[dim_A - 1]);
-        auto layout =
-            makeGemmABLayoutHopper(mat_stride, mat_continuous, mat_continuous,
-                                   op.a_->dtype.bits(), op.transA_ ? 1 : 2);
-        results.Set(op.a_, ExpandLayoutToMatchBuffer(layout, op.a_));
-      } else {
-        ICHECK(false) << "Not implemented";
-      }
+  ICHECK(m_warp * n_warp == num_warps)
+      << "m_warp * n_warp must equal num_warps, m_warp: " << m_warp
+      << ", n_warp: " << n_warp << ", num_warps: " << num_warps;
+  policy.m_warp = m_warp;
+  policy.n_warp = n_warp;
+  return {m_warp, n_warp};
+}
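The WGMMA path differs from the default one in that warps are only schedulable in groups of four along M, so every candidate m_warp stays a multiple of kGroup. A hedged standalone sketch of the kFullRow scan above (names hypothetical, not part of the patch):

#include <cstdio>

int PickFullRowMWarp(int M, int num_warps) {
  constexpr int kMPerWarp = 16, kGroup = 4;
  for (int cand = num_warps; cand >= kGroup; cand -= kGroup) {
    if (M % (cand * kMPerWarp) == 0)
      return cand;  // largest group-aligned m_warp whose tile divides M
  }
  return kGroup;    // fall back to a single warp group along M
}

int main() {
  // 8 warps (two warp groups), M = 64: 8*16 = 128 does not divide 64,
  // so the scan falls back to m_warp = 4 and n_warp = 2.
  int m_warp = PickFullRowMWarp(/*M=*/64, /*num_warps=*/8);
  std::printf("m_warp=%d n_warp=%d\n", m_warp, 8 / m_warp);
  return 0;
}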
-      if (IsSharedBuffer(op.b_)) {
-        int dim_B = op.b_->shape.size();
-        const int64_t mat_stride = *as_const_int(op.b_->shape[dim_B - 2]);
-        const int64_t mat_continuous = *as_const_int(op.b_->shape[dim_B - 1]);
-        const int64_t continuity =
-            op.transB_ ? mat_continuous : mat_continuous / warp_n;
-        auto layout =
-            makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
-                                   op.b_->dtype.bits(), op.transB_ ? 2 : 1);
-        results.Set(op.b_, ExpandLayoutToMatchBuffer(layout, op.b_));
-      } else {
-        ICHECK(false) << "WGMMA only support B in shared.";
-      }
-    } else if (TargetIsAmpere(T.target)) {
-      auto [warp_m, warp_n] = op.policy_->computeWarpPartition(
-          op.m_, op.n_, block_size, T.target, String(kCudaMMA),
-          op.a_->dtype.bits());
-      auto fragment = makeGemmSparseFragmentC(
-          op.m_, op.n_, op.m_ / warp_m, op.n_ / warp_n, op.c_->dtype.bits());
-      results.Set(op.c_, fragment->BindThreadRange(thread_range));
-
-      if (IsSharedBuffer(op.a_)) {
-        int dim_A = op.a_->shape.size();
-        const int64_t mat_stride = *as_const_int(op.a_->shape[dim_A - 2]);
-        const int64_t mat_continuous = *as_const_int(op.a_->shape[dim_A - 1]);
-        auto layout = makeGemmSparseAmpereABLayout(mat_stride, mat_continuous,
-                                                   op.a_->dtype.bits());
-        results.Set(op.a_, ExpandLayoutToMatchBuffer(layout, op.a_));
-      } else if (IsFragmentBuffer(op.a_)) {
-        ICHECK(false) << "Not Implemented";
-      } else {
-        ICHECK(0);
-      }
-      if (IsSharedBuffer(op.b_)) {
-        int dim_B = op.b_->shape.size();
-        const int64_t mat_stride = *as_const_int(op.b_->shape[dim_B - 2]);
-        const int64_t mat_continuous = *as_const_int(op.b_->shape[dim_B - 1]);
-        auto layout = makeGemmSparseAmpereABLayout(mat_stride, mat_continuous,
-                                                   op.b_->dtype.bits());
-        results.Set(op.b_, ExpandLayoutToMatchBuffer(layout, op.b_));
-      } else if (IsFragmentBuffer(op.b_)) {
-        ICHECK(false) << "Not Implemented";
-      } else {
-        ICHECK(0);
+} // namespace
+
+struct GemmSP {
+  static String SelectInst(const GemmSPNode &op, int block_size,
+                           Target target) {
+    if (op.isWgmma_) {
+      if (!AllowWgmma(op, block_size, target)) {
+        FatalWgmmaUnavailable(op, target);
       }
-    } else {
-      ICHECK(0) << "Architecture is not supported: " << T.target->str();
+      return kCudaWGMMASP;
+    }
+    if (op.isTcgen05_) {
+      FatalTcgen5Unavailable(op, target);
+
+      // if (!AllowTcgen5Mma(op, target)) {
+      //   FatalTcgen5Unavailable(op, target);
+      // }
+      // return kCudaTCGEN05SP;
+    }
+
+    if (AllowTcgen5Mma(op, target)) {
+      LOG(WARNING) << "TCGEN5MMASP is not yet available for sparse GEMM. "
+                      "Falling back to WGMMA or MMA.";
+      // return kCudaTCGEN05SP;
     }
-    return results;
+    if (AllowWgmma(op, block_size, target)) {
+      return kCudaWGMMASP;
+    }
+    return kCudaMMASP;
+  }
+
+  static std::pair<int, int>
+  ComputeWarpPartition(const GemmSPWarpPolicyNode &policy, int M, int N,
+                       int block_size, Target target, String gemm_inst) {
+    int num_warps = block_size / TargetGetWarpSize(target);
+    if (gemm_inst == kCudaTCGEN05SP) {
+      policy.m_warp = 1;
+      policy.n_warp = num_warps;
+      return {1, num_warps};
+    }
+    if (gemm_inst == kCudaWGMMASP) {
+      return ComputeWgmmaWarpPartition(policy, M, N, num_warps);
+    }
+    int k_n_per_warp =
+        (TargetIsVolta(target) || TargetIsTuring(target)) ? 16 : 8;
+    return ComputeDefaultWarpPartition(policy, M, N, num_warps, k_n_per_warp);
+  }
+
+  static bool ReuseExistingSharedLayout(String gemm_inst) {
+    return gemm_inst == kCudaMMASP;
+  }
+
+  static String InstructionKind(String gemm_inst) {
+    if (gemm_inst == kCudaWGMMASP) {
+      return "wgmma.sp";
+    }
+    if (gemm_inst == kCudaTCGEN05SP) {
+      return "tcgen5mma.sp";
+    }
+    if (gemm_inst == kCudaMMASP) {
+      return "mma.sp";
+    }
+    return "unknown";
+  }
+};
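SelectInst therefore encodes a strict priority: an explicit request (is_wgmma / is_tcgen05) is validated or fails fast, TCGEN05 is warned about and skipped for sparse GEMM, and otherwise WGMMA is preferred over plain MMA. A condensed sketch of that decision shape (illustrative only, plain C++):

#include <stdexcept>

enum class SparseInst { kMMASP, kWGMMASP };

SparseInst Select(bool want_wgmma, bool wgmma_ok) {
  if (want_wgmma) {
    if (!wgmma_ok)
      throw std::runtime_error("WGMMA requested but unavailable");
    return SparseInst::kWGMMASP;  // explicit request, validated
  }
  return wgmma_ok ? SparseInst::kWGMMASP : SparseInst::kMMASP;  // best effort
}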
@@ -224,16 +330,58 @@ bool RegisterCudaGemmSP() {
   RegisterGemmSPImpl(GemmSPImpl{
       "cuda.GemmSP",
       MatchCudaGemmSPTarget,
+      cuda::GemmSP::SelectInst,
       cuda::GemmSP::ComputeWarpPartition,
-      cuda::GemmSP::Lower,
-      cuda::GemmSP::InferLayout,
+      cuda::GemmSP::ReuseExistingSharedLayout,
+      cuda::GemmSP::InstructionKind,
   });
   return true;
 }

 const bool cuda_gemm_sp_registered = RegisterCudaGemmSP();

 } // namespace

+// TVM_FFI_STATIC_INIT_BLOCK() {
+//   namespace refl = tvm::ffi::reflection;
+//   refl::GlobalDef().def(
+//       "tl.get_tcgen5_mma_meta", [](int M, int N, int K, DataType ab_dtype,
+//                                    DataType c_dtype, bool disable_2cta) {
+//         auto [success, meta] =
+//             GetTCGEN5MMAMeta(M, N, K, ab_dtype, c_dtype, disable_2cta);
+//         Array<Integer> result;
+//         if (success) {
+//           result.push_back(Integer(meta.atom_m));
+//           result.push_back(Integer(meta.atom_n));
+//           result.push_back(Integer(meta.atom_k));
+//           result.push_back(Integer(meta.enable_ws));
+//           result.push_back(Integer(meta.enable_2cta));
+//         }
+//         return result;
+//       });
+//   refl::GlobalDef().def(
+//       "tl.get_tcgen5_instr_desc",
+//       [](int atom_m, int atom_n, int atom_k, DataType ab_dtype,
+//          DataType c_dtype, bool a_is_k_major, bool b_is_k_major,
+//          int scale_in_a, int scale_in_b) {
+//         uint32_t desc = GetTCGEN5InstrDesc(atom_m, atom_n, atom_k, ab_dtype,
+//                                            c_dtype, a_is_k_major,
+//                                            b_is_k_major, scale_in_a,
+//                                            scale_in_b);
+//         return Integer(static_cast<int64_t>(desc));
+//       });
+//   refl::GlobalDef().def("tl.get_tcgen5_blockscaled_instr_desc",
+//                         [](int atom_m, int atom_n, DataType ab_dtype,
+//                            bool a_is_k_major, bool b_is_k_major,
+//                            int scale_in_a, int scale_in_b, int a_sf_id,
+//                            int b_sf_id) {
+//                           uint32_t desc = GetTCGEN5BlockScaledInstrDesc(
+//                               atom_m, atom_n, ab_dtype, a_is_k_major,
+//                               b_is_k_major, scale_in_a, scale_in_b, a_sf_id,
+//                               b_sf_id);
+//                           return Integer(static_cast<int64_t>(desc));
+//                         });
+// }

 } // namespace tl
 } // namespace tvm
diff --git a/src/op/builtin.cc b/src/op/builtin.cc
index 508f90c654..837e5b1bcf 100644
--- a/src/op/builtin.cc
+++ b/src/op/builtin.cc
@@ -207,6 +207,16 @@ TIR_DEFINE_TL_BUILTIN(ptx_wgmma_rs)
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kOpaque));

+TIR_DEFINE_TL_BUILTIN(ptx_wgmma_sp_ss)
+    .set_num_inputs(18)
+    .set_attr<TCallEffectKind>("TCallEffectKind",
+                               Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_TL_BUILTIN(ptx_wgmma_sp_rs)
+    .set_num_inputs(17)
+    .set_attr<TCallEffectKind>("TCallEffectKind",
+                               Integer(CallEffectKind::kOpaque));
+
 TIR_DEFINE_TL_BUILTIN(ptx_tcgen05_mma_ss)
     .set_num_inputs(14)
     .set_attr<TCallEffectKind>("TCallEffectKind",
diff --git a/src/op/builtin.h b/src/op/builtin.h
index 69a65ab10a..564401d1ef 100644
--- a/src/op/builtin.h
+++ b/src/op/builtin.h
@@ -369,6 +369,16 @@ TVM_DLL const Op &ptx_wgmma_ss();
  */
 TVM_DLL const Op &ptx_wgmma_rs();

+/*!
+ * \brief tvm intrinsic for sparse ptx wgmma shared-shared instructions.
+ */
+TVM_DLL const Op &ptx_wgmma_sp_ss();
+
+/*!
+ * \brief tvm intrinsic for sparse ptx wgmma register-shared instructions.
+ */
+TVM_DLL const Op &ptx_wgmma_sp_rs();
+
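The two new builtins mirror their dense wgmma counterparts, with one extra input carrying the sparsity-metadata operand (hence 18 inputs for the shared-shared variant). A hedged sketch of how a lowering pass might emit such a call; the operand contents and helper name are illustrative, only the Call/Evaluate shape follows the pattern used elsewhere in this file:

#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
// plus the tl builtin header declaring ptx_wgmma_sp_ss()

using namespace tvm;
using namespace tvm::tir;

// Hypothetical helper: wraps 18 already-built operands into an opaque
// tl.ptx_wgmma_sp_ss call (the real operand order is fixed by the codegen).
Stmt EmitWgmmaSpSS(Array<PrimExpr> operands) {
  return Evaluate(
      Call(DataType::Handle(), tvm::tl::ptx_wgmma_sp_ss(), operands));
}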
 /*!
  * \brief tvm intrinsic for tcgen05 mma shared-shared instructions.
  */
diff --git a/src/op/gemm.cc b/src/op/gemm.cc
index 959186da5a..f472156c36 100644
--- a/src/op/gemm.cc
+++ b/src/op/gemm.cc
@@ -261,7 +261,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
 }

 TIR_REGISTER_TL_TILE_OP(Gemm, gemm)
-    .set_num_inputs(5)
+    .set_num_inputs(-1)
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kOpaque));

@@ -275,7 +275,7 @@ TVM_REGISTER_OP("tl.tileop.wgmma_gemm")
                     IntImm(DataType::Int(32), 1));
            return Gemm(args, ann);
          })
-    .set_num_inputs(5)
+    .set_num_inputs(-1)
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kOpaque));

@@ -289,7 +289,7 @@ TVM_REGISTER_OP("tl.tileop.tcgen05_gemm")
                     IntImm(DataType::Int(32), 1));
            return Gemm(args, ann);
          })
-    .set_num_inputs(5)
+    .set_num_inputs(-1)
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kOpaque));

@@ -299,9 +299,6 @@ TVM_REGISTER_OP("tl.GemmWarpPolicy")
 TVM_FFI_STATIC_INIT_BLOCK() {
   GemmNode::RegisterReflection();
   GemmWarpPolicyNode::RegisterReflection();
-}
-
-TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.GemmWarpPolicyComputeWarpPartition",
                         [](GemmWarpPolicy policy, int M, int N, int block_size,
diff --git a/src/op/gemm_sp.cc b/src/op/gemm_sp.cc
index fa7b92de73..15d2d6f058 100644
--- a/src/op/gemm_sp.cc
+++ b/src/op/gemm_sp.cc
@@ -5,17 +5,24 @@
  */

 #include "gemm_sp.h"
+#include "utils.h"
+#include "builtin.h"

+#include
+#include
 #include
 #include

-#include "utils.h"
+#include "../target/utils.h"
+#include "tvm/ffi/string.h"

 #include

 namespace tvm {
 namespace tl {

+using namespace tir;
+
 namespace {

 std::vector<GemmSPImpl> &GemmSPImplRegistry() {
@@ -29,90 +36,103 @@ const GemmSPImpl &ResolveGemmSPImpl(Target target) {
   for (const GemmSPImpl &impl : registry) {
     if (impl.match_target(target)) {
       ICHECK(matched_impl == nullptr)
           << "tl.gemm_sp found multiple target-specific implementations for "
           << target->ToDebugString() << ": " << matched_impl->name << " and "
           << impl.name;
       matched_impl = &impl;
     }
   }
   ICHECK(matched_impl != nullptr)
       << "tl.gemm_sp requires a target-specific implementation, but no "
          "gemm_sp implementation is registered for "
       << target->ToDebugString();
   return *matched_impl;
 }

 } // namespace

+std::pair<int, int> GemmSPWarpPolicyNode::computeWarpPartition(
+    int M, int N, int block_size, Target target, String gemm_inst) const {
+  return ResolveGemmSPImpl(target).compute_warp_partition(
+      *this, M, N, block_size, target, gemm_inst);
+}
+
 void RegisterGemmSPImpl(GemmSPImpl impl) {
   ICHECK(impl.name != nullptr);
   ICHECK(impl.match_target != nullptr);
+  ICHECK(impl.select_inst != nullptr);
   ICHECK(impl.compute_warp_partition != nullptr);
-  ICHECK(impl.lower != nullptr);
-  ICHECK(impl.infer_layout != nullptr);
+  ICHECK(impl.reuse_existing_shared_layout != nullptr);
+  ICHECK(impl.instruction_kind != nullptr);
   GemmSPImplRegistry().push_back(impl);
 }

-std::pair<int, int> GemmSPWarpPolicyNode::computeWarpPartition(int M, int N,
-                                                               int block_size,
-                                                               Target target,
-                                                               String gemm_inst,
-                                                               int bits) const {
-  return ResolveGemmSPImpl(target).compute_warp_partition(
-      *this, M, N, block_size, target, gemm_inst, bits);
-}
-
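With lower and infer_layout removed from the table, a backend now only supplies instruction selection, warp partitioning, and two small queries. A hypothetical registration for another target would follow the same shape as the CUDA one above (the backend name, predicate, and rocm:: helpers are invented for illustration):

// Sketch only: registering a second GemmSP backend through the same table.
bool RegisterRocmGemmSP() {
  RegisterGemmSPImpl(GemmSPImpl{
      "rocm.GemmSP",
      /*match_target=*/[](Target t) { return t->kind->name == "hip"; },
      rocm::GemmSP::SelectInst,            // pick an mfma-style instruction
      rocm::GemmSP::ComputeWarpPartition,  // m_warp x n_warp for the tile
      rocm::GemmSP::ReuseExistingSharedLayout,
      rocm::GemmSP::InstructionKind,
  });
  return true;
}

ResolveGemmSPImpl then ICHECKs that exactly one registered predicate matches a given target, so predicates must stay mutually exclusive.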
 /**
- * @brief Construct a GemmSP operator node from TL call arguments and a buffer
- * map.
+ * @brief Construct a GemmSP operator from serialized TL arguments.
  *
- * Parses the expected call argument tuple and fills an internal GemmSPNode:
- * - Buffers: A (args[0]), E (args[1]), B (args[2]), C (args[3]) are looked up
- * in vmap.
- * - Booleans: trans_A (args[4]), trans_B (args[5]).
- * - Dimensions: M (args[6]), N (args[7]), K (args[8]) as integers.
- * - Warp policy: policy (args[9]) mapped to GemmWarpPolicy.
- * - clear_accum: boolean flag (args[10]).
- * - Optional kPack (args[11]): must be 1 or 2 (checked via ICHECK).
- * - Optional wg_wait (args[12]): integer workgroup wait parameter.
+ * Deserializes operator parameters from `args` and resolves buffer references,
+ * populating an internal GemmSPNode with buffers, transpose flags, M/N/K,
+ * warp policy, clear_accum, strides, offsets, and optional kPack/wg_wait.
  *
- * The populated GemmSPNode is stored in the instance's internal data_ pointer.
- *
- * @param args Positional TL call arguments in the above order.
- *
- * @note An ICHECK failure is raised if a provided kPack is not 1 or 2.
+ * @param args Positional serialized arguments produced by the TL frontend:
+ *        expected layout is:
+ *        [Aptr, Eptr, Bptr, Cptr, trans_A (Bool), trans_E (Bool),
+ *         trans_B (Bool), M (Int), N (Int), K (Int), policy (Int),
+ *         clear_accum (Bool), stride_A (Int), stride_B (Int),
+ *         offset_A (Int), offset_B (Int),
+ *         (optional) kPack (Int), (optional) wg_wait (Int)]
  */
 GemmSP::GemmSP(Array<PrimExpr> args, Map<String, Any> annotations) {
   ObjectPtr<GemmSPNode> node = tvm::ffi::make_object<GemmSPNode>();
+
   auto a_access = NormalizeToAccessRegion(args[0], kAccessRead);
   auto e_access = NormalizeToAccessRegion(args[1], kAccessRead);
   auto b_access = NormalizeToAccessRegion(args[2], kAccessRead);
   auto c_access = NormalizeToAccessRegion(args[3], kAccessReadWrite);
+
   node->aRegion_ = a_access.region;
   node->eRegion_ = e_access.region;
   node->bRegion_ = b_access.region;
   node->cRegion_ = c_access.region;
   node->SetAccessRegions({a_access, e_access, b_access, c_access});

-  node->a_ = node->aRegion_->buffer;
-  node->e_ = node->eRegion_->buffer;
-  node->b_ = node->bRegion_->buffer;
-  node->c_ = node->cRegion_->buffer;
-  node->transA_ = args[4].as<Bool>().value();
-  node->transB_ = args[5].as<Bool>().value();
-  node->m_ = args[6].as<IntImm>().value()->value;
-  node->n_ = args[7].as<IntImm>().value()->value;
-  node->k_ = args[8].as<IntImm>().value()->value;
-  node->policy_ = GemmSPWarpPolicy(args[9].as<IntImm>().value()->value);
-  node->clearAccum_ = args[10].as<Bool>().value();
-  if (args.size() > 11) {
-    node->kPack_ = args[11].as<IntImm>().value()->value;
-    if (node->kPack_ != 1 && node->kPack_ != 2) {
+
+  node->A = node->aRegion_->buffer;
+  node->E = node->eRegion_->buffer;
+  node->B = node->bRegion_->buffer;
+  node->C = node->cRegion_->buffer;
+
+  node->trans_A = args[4].as<Bool>().value();
+  node->trans_E = args[5].as<Bool>().value();
+  node->trans_B = args[6].as<Bool>().value();
+  node->M = args[7].as<IntImm>().value()->value;
+  node->N = args[8].as<IntImm>().value()->value;
+  node->K = args[9].as<IntImm>().value()->value;
+  node->policy = GemmSPWarpPolicy(args[10].as<IntImm>().value()->value);
+  node->clear_accum = args[11].as<Bool>().value();
+  node->stride_A = args[12].as<IntImm>().value()->value;
+  node->stride_B = args[13].as<IntImm>().value()->value;
+  node->offset_A = args[14].as<IntImm>().value()->value;
+  node->offset_B = args[15].as<IntImm>().value()->value;
+  if (args.size() > 16) {
+    node->kPack = args[16].as<IntImm>().value()->value;
+    if (node->kPack != 1 && node->kPack != 2) {
       ICHECK(false) << "kPack must be 1 or 2";
     }
   }
-  if (args.size() > 12) {
-    node->wgWait_ = args[12].as<IntImm>().value()->value;
+  if (args.size() > 17) {
+    node->wg_wait = args[17].as<IntImm>().value()->value;
   }
+  if (auto val = annotations.Get("is_wgmma")) {
+    const auto *int_val = val->as<IntImmNode>();
+    ICHECK(int_val) << "is_wgmma annotation must be IntImmNode";
+    node->isWgmma_ = int_val->value != 0;
+  }
+  if (auto val = annotations.Get("is_tcgen05")) {
+    const auto *int_val = val->as<IntImmNode>();
+    ICHECK(int_val) << "is_tcgen05 annotation must be IntImmNode";
+    node->isTcgen05_ = int_val->value != 0;
+  }
+
   data_ = std::move(node);
 }
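For concreteness, here is a hypothetical frontend-side argument vector matching the layout documented above; Aptr/Eptr/Bptr/Cptr stand in for the access-pointer expressions and all values are illustrative, not a real call site:

// Sketch only: builds the positional args the constructor above parses,
// for a 256x256x256 fp16 tile with no transposes.
Array<PrimExpr> MakeGemmSPArgs(PrimExpr Aptr, PrimExpr Eptr, PrimExpr Bptr,
                               PrimExpr Cptr) {
  auto i32 = [](int64_t v) { return IntImm(DataType::Int(32), v); };
  return {Aptr,        Eptr,        Bptr,      Cptr,
          Bool(false), Bool(false), Bool(false),  // trans_A, trans_E, trans_B
          i32(256),    i32(256),    i32(256),     // M, N, K
          i32(0),                                 // policy (e.g. kSquare)
          Bool(true),                             // clear_accum
          i32(256),    i32(256),                  // stride_A, stride_B
          i32(0),      i32(0)};                   // offset_A, offset_B
  // optional kPack / wg_wait would follow at args[16] / args[17]
}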
@@ -121,80 +141,147 @@ AccessRegions GemmSPNode::GetAccessRegions() const {
   result.reads.push_back(aRegion_);
   result.reads.push_back(eRegion_);
   result.reads.push_back(bRegion_);
-  if (!clearAccum_) {
+  if (!is_one(clear_accum)) {
     result.reads.push_back(cRegion_);
   }
   result.writes.push_back(cRegion_);
   return result;
 }

-/**
- * @brief Create a deep copy of this GemmSPNode wrapped as a TileOperator.
- *
- * Returns a new TileOperator that owns a copy of this node. The cloned node
- * duplicates all fields of the original; subsequent modifications to the
- * clone do not affect the original node.
- *
- * @return TileOperator A TileOperator holding a cloned GemmSPNode.
- */
 TileOperator GemmSPNode::Clone() const {
   auto op = tvm::ffi::make_object<GemmSPNode>(*this);
   return GemmSP(op);
 }

-/**
- * @brief Lower this GemmSP node through the registered backend.
- *
- * @param T Lowering context containing thread bounds and target.
- * @return Stmt The backend-specific lowered statement.
- */
+String GemmSPNode::getGemmSPInstructionKey(int block_size,
+                                           Target target) const {
+  return ResolveGemmSPImpl(target).select_inst(*this, block_size, target);
+}
+
+String GemmSPNode::getGemmSPInstructionKind(int block_size,
+                                            Target target) const {
+  const GemmSPImpl &impl = ResolveGemmSPImpl(target);
+  return impl.instruction_kind(impl.select_inst(*this, block_size, target));
+}
+
 Stmt GemmSPNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
-  return ResolveGemmSPImpl(T.target).lower(*this, T, analyzer);
+  if (const auto f = ffi::Function::GetGlobal("tl.gemm_sp.lower")) {
+    auto prim_func =
+        Downcast<PrimFunc>((*f)(tvm::ffi::GetRef<GemmSP>(this), T.target,
+                                T.layout_map, T.thread_bounds, T.thread_var));
+    ICHECK(prim_func->attrs.defined());
+    auto global_symbol = prim_func->attrs.GetAttr<String>("global_symbol");
+    ICHECK(global_symbol.has_value());
+    if (prim_func->body.as<BlockRealizeNode>()) {
+      BlockRealize block_realize = Downcast<BlockRealize>(prim_func->body);
+      auto block = block_realize->block;
+      {
+        BlockNode *n = block.CopyOnWrite();
+        n->name_hint = global_symbol.value();
+        n->annotations.Set(tl::attr::kLexicalAllocScope,
+                           IntImm(DataType::Int(32), 1));
+      }
+      return BlockRealize(block_realize->iter_values, block_realize->predicate,
+                          block);
+    }
+    // wrap with block realize node
+    Map<String, Any> block_annotations;
+    block_annotations.Set(tl::attr::kLexicalAllocScope,
+                          IntImm(DataType::Int(32), 1));
+    return BlockRealize(
+        /*iter_values=*/Array<PrimExpr>(),
+        /*predicate=*/const_true(),
+        /*block=*/
+        Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{},
+              /*name_hint=*/global_symbol.value(), prim_func->body,
+              /*init=*/Optional<Stmt>(), /*alloc_buffers=*/{},
+              /*match_buffers=*/{}, /*annotations=*/block_annotations));
+  } else {
+    LOG(FATAL) << "No lower function found for gemm_sp";
+  }
 }

-/**
- * @brief Infers and returns the memory/layout mapping for the GemmSP operator.
- *
- * Delegates target-specific layout inference to the registered GemmSP backend.
- * The function caches its work: if layout inference has already completed
- * (completed_ == true) it returns an empty LayoutMap.
- *
- * Precondition:
- * - C.scope() must be "local.fragment".
- * - * @param T LayoutInferArgs containing thread bounds and target. - * @param level Currently unused inference detail level. - * @return LayoutMap mapping A, B, and C to their inferred layouts (or empty if - * inference was already completed). - */ LayoutMap GemmSPNode::InferLayout(const LayoutInferArgs &T, InferLevel level) const { if (completed_) return {}; - LayoutMap results = ResolveGemmSPImpl(T.target).infer_layout(*this, T, level); + LayoutMap results; + if (const auto f = ffi::Function::GetGlobal("tl.gemm_sp.infer_layout")) { + auto inferred_layouts = Downcast( + (*f)(tvm::ffi::GetRef(this), T.target, T.thread_bounds)); + auto block_size = *as_const_int(T.thread_bounds->extent); + String gemm_inst = getGemmSPInstructionKey(block_size, T.target); + bool reuse_existing_shared_layout = + ResolveGemmSPImpl(T.target).reuse_existing_shared_layout(gemm_inst); + for (auto kv : inferred_layouts) { + const Buffer &buf = kv.first; + const Layout &layout = kv.second; + if (reuse_existing_shared_layout && IsSharedBuffer(buf) && + T.layout_map.count(buf)) { + continue; + } + if (auto frag = layout.as()) { + results.Set(buf, frag.value()->BindThreadRange(T.thread_bounds)); + } else { + results.Set(buf, layout); + } + } + } else { + LOG(FATAL) << "No infer layout function found for gemm_sp"; + } + completed_ = true; return results; } TIR_REGISTER_TL_TILE_OP(GemmSP, gemm_sp) - .set_num_inputs(5) + .set_num_inputs(-1) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TVM_REGISTER_OP("tl.GemmSPWarpPolicy") - .set_attr("TScriptPrinterName", "GemmSPWarpPolicy"); +TVM_REGISTER_OP("tl.tileop.wgmma_gemm_sp") + .set_attr("TScriptPrinterName", "wgmma_gemm_sp") + .set_attr("TLOpBuilder", + [](Array args, + Map annotations) { + Map ann = annotations; + ann.Set("is_wgmma", + IntImm(DataType::Int(32), 1)); + return GemmSP(args, ann); + }) + .set_num_inputs(-1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TVM_REGISTER_OP("tl.tileop.tcgen05_gemm_sp") + .set_attr("TScriptPrinterName", "tcgen05_gemm_sp") + .set_attr("TLOpBuilder", + [](Array args, + Map annotations) { + Map ann = annotations; + ann.Set("is_tcgen05", + IntImm(DataType::Int(32), 1)); + return GemmSP(args, ann); + }) + .set_num_inputs(-1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); TVM_FFI_STATIC_INIT_BLOCK() { - GemmSPNode::RegisterReflection(); GemmSPWarpPolicyNode::RegisterReflection(); + GemmSPNode::RegisterReflection(); namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def( - "tl.GemmSPWarpPolicyComputeWarpPartition", - [](GemmSPWarpPolicy policy, int M, int N, int block_size, Target target, - String gemm_inst, int bits) { - policy->computeWarpPartition(M, N, block_size, target, gemm_inst, bits); - return; - }); + refl::GlobalDef().def("tl.GemmSPWarpPolicyComputeWarpPartition", + [](GemmSPWarpPolicy policy, int M, int N, + int block_size, Target target, String gemm_inst) { + policy->computeWarpPartition(M, N, block_size, target, + gemm_inst); + }); + refl::GlobalDef().def("tl.GemmSPGetGemmInstructionKey", + [](GemmSP gemm, int block_size, Target target) { + return gemm->getGemmSPInstructionKey(block_size, + target); + }); } } // namespace tl } // namespace tvm diff --git a/src/op/gemm_sp.h b/src/op/gemm_sp.h index b42178c753..7e043124eb 100644 --- a/src/op/gemm_sp.h +++ b/src/op/gemm_sp.h @@ -16,13 +16,14 @@ namespace tl { using namespace tir; -class GemmSPWarpPolicyNode : public GemmWarpPolicyNode { +class GemmSPWarpPolicyNode : public Object { public: - std::pair 
computeWarpPartition(int M, int N, int block_size, - Target target, String gemm_inst, - int bits) const; + mutable int m_warp{0}; + mutable int n_warp{0}; + int policy_type; + TVM_FFI_DECLARE_OBJECT_INFO("tl.GemmSPWarpPolicy", GemmSPWarpPolicyNode, - GemmWarpPolicyNode); + Object); static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -31,6 +32,21 @@ class GemmSPWarpPolicyNode : public GemmWarpPolicyNode { .def_ro("m_warp", &GemmSPWarpPolicyNode::m_warp) .def_ro("n_warp", &GemmSPWarpPolicyNode::n_warp); } + + std::pair computeWarpPartition(int M, int N, int block_size, + Target target, + String gemm_inst) const; + + bool isSquare() const { + return policy_type == int(GemmWarpPolicyType::kSquare); + } + bool isFullRow() const { + return policy_type == int(GemmWarpPolicyType::kFullRow); + } + bool isFullCol() const { + return policy_type == int(GemmWarpPolicyType::kFullCol); + } + bool isFree() const { return policy_type == int(GemmWarpPolicyType::kFree); } }; class GemmSPWarpPolicy : public ObjectRef { @@ -61,47 +77,65 @@ class GemmSPWarpPolicy : public ObjectRef { class GemmSPNode : public TileOperatorNode { public: - BufferRegion aRegion_, bRegion_, cRegion_, eRegion_; - tir::Buffer a_, b_, c_, e_; - bool transA_, transB_; - int m_, n_, k_; - bool clearAccum_ = false; - // Backend-specific K packing parameter. - int kPack_ = 1; - int wgWait_ = 0; - - mutable GemmSPWarpPolicy policy_; + bool CheckWGMMA() const; + tir::Buffer A, E, B, C; + // pointer to the A, E, B, C + BufferRegion aRegion_, eRegion_, bRegion_, cRegion_; + bool trans_A, trans_B, trans_E; + int M, N, K; + int stride_A, stride_B; + int offset_A, offset_B; + PrimExpr clear_accum = const_false(); + // k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack + // only will be enabled under cdna mfma instructions + int kPack = 1; + int wg_wait = 0; + bool isWgmma_ = false; + bool isTcgen05_ = false; + mutable GemmSPWarpPolicy policy; TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.GemmSP", GemmSPNode, TileOperatorNode); - Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; - LayoutMap InferLayout(const LayoutInferArgs &T, - InferLevel level) const override; - AccessRegions GetAccessRegions() const override; - - TileOperator Clone() const; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef() - .def_ro("policy", &GemmSPNode::policy_) + .def_ro("A", &GemmSPNode::A) + .def_ro("E", &GemmSPNode::E) + .def_ro("B", &GemmSPNode::B) + .def_ro("C", &GemmSPNode::C) .def_ro("aRegion", &GemmSPNode::aRegion_) + .def_ro("eRegion", &GemmSPNode::eRegion_) .def_ro("bRegion", &GemmSPNode::bRegion_) .def_ro("cRegion", &GemmSPNode::cRegion_) - .def_ro("eRegion", &GemmSPNode::eRegion_) - .def_ro("a", &GemmSPNode::a_) - .def_ro("b", &GemmSPNode::b_) - .def_ro("c", &GemmSPNode::c_) - .def_ro("e", &GemmSPNode::e_) - .def_ro("transA", &GemmSPNode::transA_) - .def_ro("transB", &GemmSPNode::transB_) - .def_ro("m", &GemmSPNode::m_) - .def_ro("n", &GemmSPNode::n_) - .def_ro("k", &GemmSPNode::k_) - .def_ro("clearAccum", &GemmSPNode::clearAccum_) - .def_ro("kPack", &GemmSPNode::kPack_) - .def_ro("wgWait", &GemmSPNode::wgWait_); + .def_ro("trans_A", &GemmSPNode::trans_A) + .def_ro("trans_B", &GemmSPNode::trans_B) + .def_ro("trans_E", &GemmSPNode::trans_E) + .def_ro("M", &GemmSPNode::M) + .def_ro("N", &GemmSPNode::N) + .def_ro("K", &GemmSPNode::K) + .def_ro("stride_A", &GemmSPNode::stride_A) + .def_ro("stride_B", &GemmSPNode::stride_B) + .def_ro("offset_A", &GemmSPNode::offset_A) + 
.def_ro("offset_B", &GemmSPNode::offset_B) + .def_ro("clear_accum", &GemmSPNode::clear_accum) + .def_ro("kPack", &GemmSPNode::kPack) + .def_ro("wg_wait", &GemmSPNode::wg_wait) + .def_ro("isWgmma", &GemmSPNode::isWgmma_) + .def_ro("isTcgen05", &GemmSPNode::isTcgen05_) + .def_ro("policy", &GemmSPNode::policy); } + Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; + LayoutMap InferLayout(const LayoutInferArgs &T, + InferLevel level) const override; + AccessRegions GetAccessRegions() const override; + + TileOperator Clone() const; + + // Target-specific GEMM SP instruction key. + String getGemmSPInstructionKey(int block_size, Target target) const; + String getGemmSPInstructionKind(int block_size, Target target) const; + private: mutable bool completed_ = false; }; @@ -112,15 +146,15 @@ struct GemmSPImpl { const char *name; GemmSPTargetPredicate match_target; + String (*select_inst)(const GemmSPNode &op, int block_size, Target target); + std::pair (*compute_warp_partition)( const GemmSPWarpPolicyNode &policy, int M, int N, int block_size, - Target target, String gemm_inst, int bits); + Target target, String gemm_inst); - Stmt (*lower)(const GemmSPNode &op, const LowerArgs &T, - arith::Analyzer *analyzer); + bool (*reuse_existing_shared_layout)(String gemm_inst); - LayoutMap (*infer_layout)(const GemmSPNode &op, const LayoutInferArgs &T, - InferLevel level); + String (*instruction_kind)(String gemm_inst); }; void RegisterGemmSPImpl(GemmSPImpl impl); diff --git a/src/op/gemm_sp_py.cc b/src/op/gemm_sp_py.cc deleted file mode 100644 index f5fb269d39..0000000000 --- a/src/op/gemm_sp_py.cc +++ /dev/null @@ -1,180 +0,0 @@ -/*! - * \file tl/op/gemm_sp_py.cc - * \brief Implementation of Sparse General Matrix Multiplication (GEMM_SP) - * operators - */ - -#include "gemm_sp_py.h" -#include "utils.h" - -#include "builtin.h" -#include -#include -#include -#include - -#include "tvm/ffi/string.h" - -namespace tvm { -namespace tl { - -using namespace tir; - -/** - * @brief Construct a Gemm operator from serialized TL arguments and a buffer - * map. - * - * This constructor deserializes operator parameters from `args` and resolves - * buffer references via `vmap`, populating an internal GemmSPPyNode with: - * - device pointers for A, E, B, C and their corresponding Buffer objects, - * - transpose flags for A and B, - * - matrix dimensions M, N, K, - * - warp allocation policy and clear_accum flag, - * - strides and memory offsets for A and B, - * - optional kPack (must be 1 or 2) and optional wg_wait. - * - * The populated GemmSPPyNode is stored into the wrapper's internal `data_`. - * - * @param args Positional serialized arguments produced by the TL frontend: - * expected layout is: - * [Aptr, Eptr, Bptr, Cptr, trans_A (Bool), trans_B (Bool), - * M (Int), N (Int), K (Int), policy (Int), clear_accum (Bool), - * stride_A (Int), stride_B (Int), offset_A (Int), offset_B (Int), - * (optional) kPack (Int), (optional) wg_wait (Int)] - * @param vmap Mapping from access pointer vars to Buffer objects used to - * resolve the Buffer corresponding to each pointer argument. - * - * @note If `kPack` is provided it must be 1 or 2; otherwise the constructor - * fails with an ICHECK (runtime assertion). No other validation is - * performed here. 
- */ -GemmSPPy::GemmSPPy(Array args, Map annotations) { - ObjectPtr node = tvm::ffi::make_object(); - - auto a_access = NormalizeToAccessRegion(args[0], kAccessRead); - auto e_access = NormalizeToAccessRegion(args[1], kAccessRead); - auto b_access = NormalizeToAccessRegion(args[2], kAccessRead); - auto c_access = NormalizeToAccessRegion(args[3], kAccessReadWrite); - - node->aRegion_ = a_access.region; - node->eRegion_ = e_access.region; - node->bRegion_ = b_access.region; - node->cRegion_ = c_access.region; - node->SetAccessRegions({a_access, e_access, b_access, c_access}); - - node->A = node->aRegion_->buffer; - node->E = node->eRegion_->buffer; - node->B = node->bRegion_->buffer; - node->C = node->cRegion_->buffer; - - node->trans_A = args[4].as().value(); - node->trans_B = args[5].as().value(); - node->trans_E = args[6].as().value(); - node->M = args[7].as().value()->value; - node->N = args[8].as().value()->value; - node->K = args[9].as().value()->value; - node->policy = GemmWarpPolicy(args[10].as().value()->value); - node->clear_accum = args[11].as().value(); - node->stride_A = args[12].as().value()->value; - node->stride_B = args[13].as().value()->value; - node->offset_A = args[14].as().value()->value; - node->offset_B = args[15].as().value()->value; - if (args.size() > 16) { - node->kPack = args[16].as().value()->value; - if (node->kPack != 1 && node->kPack != 2) { - ICHECK(false) << "kPack must be 1 or 2"; - } - } - if (args.size() > 17) { - node->wg_wait = args[17].as().value()->value; - } - data_ = std::move(node); -} - -AccessRegions GemmSPPyNode::GetAccessRegions() const { - AccessRegions result; - result.reads.push_back(aRegion_); - result.reads.push_back(eRegion_); - result.reads.push_back(bRegion_); - if (!is_one(clear_accum)) { - result.reads.push_back(cRegion_); - } - result.writes.push_back(cRegion_); - return result; -} - -/** - * @brief Create a copy of this GemmSPPyNode as a TileOperator. - * - * Constructs a new GemmSPPyNode by copying the current node state and returns - * it wrapped in a GemmSPPy TileOperator. - * - * @return TileOperator A GemmSPPy operator that owns a copy of this node. 
- */ -TileOperator GemmSPPyNode::Clone() const { - auto op = tvm::ffi::make_object(*this); - return GemmSPPy(op); -} - -Stmt GemmSPPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { - if (const auto f = ffi::Function::GetGlobal("tl.gemm_sp_py.lower")) { - auto prim_func = - Downcast((*f)(tvm::ffi::GetRef(this), T.target, - T.thread_bounds, T.thread_var)); - ICHECK(prim_func->attrs.defined()); - auto global_symbol = prim_func->attrs.GetAttr("global_symbol"); - ICHECK(global_symbol.has_value()); - if (prim_func->body.as()) { - BlockRealize block_realize = Downcast(prim_func->body); - auto block = block_realize->block; - { - BlockNode *n = block.CopyOnWrite(); - n->name_hint = global_symbol.value(); - n->annotations.Set(tl::attr::kLexicalAllocScope, - IntImm(DataType::Int(32), 1)); - } - return BlockRealize(block_realize->iter_values, block_realize->predicate, - block); - } - // warp with block realize node - Map block_annotations; - block_annotations.Set(tl::attr::kLexicalAllocScope, - IntImm(DataType::Int(32), 1)); - return BlockRealize( - /*iter_values=*/Array(), - /*predicate=*/const_true(), - /*block=*/ - Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, - /*name_hint=*/global_symbol.value(), prim_func->body, - /*init=*/Optional(), /*alloc_buffers=*/{}, - /*match_buffers=*/{}, /*annotations=*/block_annotations)); - } else { - LOG(FATAL) << "No lower function found for gemm_sp_py"; - } -} - -LayoutMap GemmSPPyNode::InferLayout(const LayoutInferArgs &T, - InferLevel level) const { - if (completed_) - return {}; - LayoutMap results; - - if (const auto f = ffi::Function::GetGlobal("tl.gemm_sp_py.infer_layout")) { - results = Downcast( - (*f)(tvm::ffi::GetRef(this), T.target, T.thread_bounds)); - } else { - LOG(FATAL) << "No infer layout function found for gemm_sp_py"; - } - - completed_ = true; - return results; -} - -TIR_REGISTER_TL_TILE_OP(GemmSPPy, gemm_sp_py) - .set_num_inputs(5) - .set_attr("TCallEffectKind", - Integer(CallEffectKind::kOpaque)); - -TVM_FFI_STATIC_INIT_BLOCK() { GemmSPPyNode::RegisterReflection(); } -} // namespace tl -} // namespace tvm diff --git a/src/op/gemm_sp_py.h b/src/op/gemm_sp_py.h deleted file mode 100644 index db564ec1c0..0000000000 --- a/src/op/gemm_sp_py.h +++ /dev/null @@ -1,92 +0,0 @@ -/*! - * \file tl/op/gemm_sp_py.h - * \brief Define gemm_sp_py operator. - * - */ - -// TODO: @botbw: remove redundant code with gemm.h - -#ifndef TVM_TL_OP_GEMM_SP_PY_H_ -#define TVM_TL_OP_GEMM_SP_PY_H_ - -#include "gemm_sp.h" -#include "operator.h" - -namespace tvm { - -namespace tl { - -using namespace tir; - -class GemmSPPyNode : public TileOperatorNode { -public: - tir::Buffer A, E, B, C; - // pointer to the A, E, B, C - BufferRegion aRegion_, eRegion_, bRegion_, cRegion_; - bool trans_A, trans_B, trans_E; - int M, N, K; - int stride_A, stride_B; - int offset_A, offset_B; - PrimExpr clear_accum = const_false(); - // Backend-specific K packing parameter. 
- int kPack = 1; - int wg_wait = 0; - - // use GemmWarp Policy here as the atom size are flexible in v2 - mutable GemmWarpPolicy policy; - - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.GemmSPPy", GemmSPPyNode, - TileOperatorNode); - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("A", &GemmSPPyNode::A) - .def_ro("E", &GemmSPPyNode::E) - .def_ro("B", &GemmSPPyNode::B) - .def_ro("C", &GemmSPPyNode::C) - .def_ro("aRegion", &GemmSPPyNode::aRegion_) - .def_ro("eRegion", &GemmSPPyNode::eRegion_) - .def_ro("bRegion", &GemmSPPyNode::bRegion_) - .def_ro("cRegion", &GemmSPPyNode::cRegion_) - .def_ro("trans_A", &GemmSPPyNode::trans_A) - .def_ro("trans_B", &GemmSPPyNode::trans_B) - .def_ro("trans_E", &GemmSPPyNode::trans_E) - .def_ro("M", &GemmSPPyNode::M) - .def_ro("N", &GemmSPPyNode::N) - .def_ro("K", &GemmSPPyNode::K) - .def_ro("stride_A", &GemmSPPyNode::stride_A) - .def_ro("stride_B", &GemmSPPyNode::stride_B) - .def_ro("offset_A", &GemmSPPyNode::offset_A) - .def_ro("offset_B", &GemmSPPyNode::offset_B) - .def_ro("clear_accum", &GemmSPPyNode::clear_accum) - .def_ro("kPack", &GemmSPPyNode::kPack) - .def_ro("wg_wait", &GemmSPPyNode::wg_wait) - .def_ro("policy", &GemmSPPyNode::policy); - } - - Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; - LayoutMap InferLayout(const LayoutInferArgs &T, - InferLevel level) const override; - AccessRegions GetAccessRegions() const override; - - TileOperator Clone() const; - -private: - mutable bool completed_ = false; -}; - -class GemmSPPy : public TileOperator { -public: - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmSPPy, TileOperator, - GemmSPPyNode); - TVM_DLL - GemmSPPy(Array args, - Map annotations = Map()); - static const Op &Get(); -}; - -} // namespace tl -} // namespace tvm - -#endif // TVM_TL_OP_GEMM_SP_PY_H_ diff --git a/src/tl_templates/cuda/compress_sm90.cu b/src/tl_templates/cuda/compress_sm90.cu deleted file mode 100644 index 8bb236dd83..0000000000 --- a/src/tl_templates/cuda/compress_sm90.cu +++ /dev/null @@ -1,167 +0,0 @@ -#include - -#include - -#include "cute/tensor.hpp" -#include "cutlass/cutlass.h" -#include "cutlass/gemm/collective/collective_builder.hpp" -#include "cutlass/gemm/device/gemm_universal_adapter.h" -#include "cutlass/transform/device/transform_universal_adapter.hpp" -#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" -#include "cutlass/util/host_tensor.h" -#include "cutlass/util/packed_stride.hpp" - -using namespace cute; - -#define CUTLASS_CHECK(status) \ - { \ - cutlass::Status error = status; \ - if (error != cutlass::Status::kSuccess) { \ - std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) \ - << " at: " << __LINE__ << std::endl; \ - exit(EXIT_FAILURE); \ - } \ - } - -#define CUDA_CHECK(status) \ - { \ - cudaError_t error = status; \ - if (error != cudaSuccess) { \ - std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \ - << " at line: " << __LINE__ << std::endl; \ - exit(EXIT_FAILURE); \ - } \ - } -template -std::tuple compress_impl(torch::Tensor A) { - using ElementA = T; - using ElementE = uint8_t; - using LayoutTagA = conditional_t; - using ProblemShape = cute::Shape; - - using StrideA = cutlass::gemm::TagToStrideA_t; - using StrideE = StrideA; - - // NOTE: this is derived from sparse sm90 mma atoms - // Ref: https://github.com/NVIDIA/cutlass/blob/dc4817921edda44a549197ff3a9dcf5df0636e7b/include/cute/atom/mma_traits_sm90_gmma_sparse.hpp - using SparseE = conditional_t<(sizeof_bits_v == 32), 
cute::sparse_elem<4, ElementE>, cute::sparse_elem<8, ElementE>>; - static constexpr GMMA::Major GmmaMajorA = transposed ? cute::SM90::GMMA::Major::MN : cute::SM90::GMMA::Major::K; - using SparseConfig = cutlass::Sm90GemmSparseConfig< - cute::sparse_elem<2, ElementA>, GmmaMajorA, - SparseE, cute::C>; - - using CompressorUtility = - cutlass::transform::kernel::StructuredSparseCompressorUtility< - ProblemShape, ElementA, LayoutTagA, SparseConfig>; - - using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor< - ProblemShape, ElementA, LayoutTagA, SparseConfig, cutlass::arch::Sm90>; - - using Compressor = cutlass::transform::device::TransformUniversalAdapter; - - TORCH_CHECK(A.is_contiguous(), "A need to be contiguous"); - TORCH_CHECK(A.dim() == 2, "Might support batch dim in the future "); - - int M = -1; - int K = -1; - int N = -1; // not used, but required for config - int L = 1; - if constexpr(transposed) { - M = A.size(1); - K = A.size(0); - } else { - M = A.size(0); - K = A.size(1); - } - - ProblemShape problem_shape = make_tuple(M, N, K, L); - StrideA stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); - - CompressorUtility compressor_utility(problem_shape, stride_A); - int ME = compressor_utility.get_metadata_m_physical(); - int KE = compressor_utility.get_metadata_k_physical(); - int KC = compressor_utility.get_tensorA_k_physical(); - - StrideE stride_E = cutlass::make_cute_packed_stride(StrideE{}, cute::make_shape(ME, KE, L)); - auto dtype = A.dtype().toScalarType(); - torch::Tensor A_compressed = torch::zeros(KC * M, - torch::TensorOptions().dtype(dtype).device(A.device())); - torch::Tensor E = torch::zeros({ME, KE}, - torch::TensorOptions().dtype(torch::kUInt8).device(A.device())); - - cutlass::KernelHardwareInfo hw_info; - hw_info.device_id = A.device().index(); - hw_info.sm_count = - cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); - - typename Compressor::Arguments arguments{problem_shape, - { - A.data_ptr(), - stride_A, - A_compressed.data_ptr(), - E.data_ptr(), - }, - {hw_info}}; - - Compressor compressor_op; - size_t workspace_size = Compressor::get_workspace_size(arguments); - cutlass::device_memory::allocation workspace(workspace_size); - - CUTLASS_CHECK(compressor_op.can_implement(arguments)); - CUTLASS_CHECK(compressor_op.initialize(arguments, workspace.get())); - CUTLASS_CHECK(compressor_op.run()); - CUDA_CHECK(cudaDeviceSynchronize()); - - if constexpr (transposed) { - return std::make_tuple(A_compressed.view({KC, M}), E); - } else { - return std::make_tuple(A_compressed.view({M, KC}), E); - } -} - -// block <= 128 -// Ref https://github.com/NVIDIA/cutlass/blob/c2ad7c5b20f131c4ba33601860f1da3f9c9df0f3/include/cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl#L145-L146 -#define DISPATCH_BLOCK_K(TYPE, BLOCK_K, FACTOR, TENSOR, TRANSPOSED) \ - [&]() -> std::tuple { \ - switch (BLOCK_K) { \ - case int(32 * FACTOR): return compress_impl(TENSOR); \ - case int(64 * FACTOR): return compress_impl(TENSOR); \ - case int(128 * FACTOR): return compress_impl(TENSOR); \ - default: \ - TORCH_CHECK(false, "Unsupported block_k: ", BLOCK_K); \ - } \ - }() - -#define DISPATCH_CONTIGUOUS(TRANSPOSED) \ - [&]() -> std::tuple { \ - switch (dtype) { \ - case torch::kFloat32: \ - return DISPATCH_BLOCK_K(float, block_k, 0.5, A, TRANSPOSED); \ - case torch::kFloat16: \ - case torch::kBFloat16: \ - return DISPATCH_BLOCK_K(cute::half_t, block_k, 1, A, TRANSPOSED); \ - case torch::kFloat8_e4m3fn: \ - 
return DISPATCH_BLOCK_K(cute::float_e4m3_t, block_k, 2, A, TRANSPOSED); \ - case torch::kFloat8_e5m2: \ - return DISPATCH_BLOCK_K(cute::float_e5m2_t, block_k, 2, A, TRANSPOSED); \ - case torch::kChar: \ - return DISPATCH_BLOCK_K(int8_t, block_k, 2, A, TRANSPOSED); \ - case torch::kByte: \ - return DISPATCH_BLOCK_K(uint8_t, block_k, 2, A, TRANSPOSED); \ - default: \ - TORCH_CHECK(false, "Unsupported dtype"); \ - } \ - }() - -std::tuple compress_sm90(torch::Tensor A, int64_t block_k, bool transposed) { - auto dtype = A.dtype().toScalarType(); - return transposed ? DISPATCH_CONTIGUOUS(true) : DISPATCH_CONTIGUOUS(false); -} - -#undef DISPATCH_BLOCK_K -#undef DISPATCH_CONTIGUOUS - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("compress_sm90", torch::wrap_pybind_function(compress_sm90), - "compress_sm90"); -} diff --git a/src/tl_templates/cuda/gemm_sp.h b/src/tl_templates/cuda/gemm_sp.h deleted file mode 100644 index f40a7bd0f8..0000000000 --- a/src/tl_templates/cuda/gemm_sp.h +++ /dev/null @@ -1,6 +0,0 @@ -#pragma once -#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 900)) -#include "gemm_sp_sm90.h" -#else(defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 800)) -#include "gemm_sp_sm80.h" -#endif diff --git a/src/tl_templates/cuda/gemm_sp_sm80.h b/src/tl_templates/cuda/gemm_sp_sm80.h deleted file mode 100644 index f1fc860092..0000000000 --- a/src/tl_templates/cuda/gemm_sp_sm80.h +++ /dev/null @@ -1,270 +0,0 @@ -#include -#include - -namespace tl { - -static int const kSparse = 2; -template struct ShapeCheck { - static constexpr bool value = false; -}; - -template struct ShapeCheck { - static constexpr bool value = - (Shape::kM % 32 == 0) && (Shape::kN % 32 == 0) && (Shape::kK % 32 == 0); -}; - -template struct ShapeCheck { - static constexpr bool value = - ShapeCheck::value; // Same as half -}; - -template struct ShapeCheck { - static constexpr bool value = - (Shape::kM % 16 == 0) && (Shape::kN % 16 == 0) && (Shape::kK % 64 == 0); -}; - -template struct ShapeCheck { - static constexpr bool value = - (Shape::kM % 16 == 0) && (Shape::kN % 16 == 0) && (Shape::kK % 64 == 0); -}; - -// ref: -// https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h -template struct DispatchInstructionShape { - static_assert(!std::is_same_v, - "Unsupported type for DispatchInstructionShape"); -}; - -template <> struct DispatchInstructionShape { - using Shape = cutlass::gemm::GemmShape<16, 8, 32>; - using Operator = cutlass::arch::OpMultiplyAdd; -}; - -template <> struct DispatchInstructionShape { - using Shape = cutlass::gemm::GemmShape<16, 8, 32>; - using Operator = cutlass::arch::OpMultiplyAdd; -}; - -// TODO: Not supported for now -// template<> -// struct DispatchInstructionShape { -// using Shape = cutlass::gemm::GemmShape<16, 8, 16>; -// using Operator = cutlass::arch::OpMultiplyAdd; -// }; - -template <> struct DispatchInstructionShape { - using Shape = cutlass::gemm::GemmShape<16, 8, 64>; - using Operator = cutlass::arch::OpMultiplyAddSaturate; -}; - -template <> struct DispatchInstructionShape { - using Shape = cutlass::gemm::GemmShape<16, 8, 64>; - using Operator = cutlass::arch::OpMultiplyAddSaturate; -}; - -// TODO: Not supported for now -// template<> -// struct DispatchInstructionShape { -// using Shape = cutlass::gemm::GemmShape<16, 8, 128>; -// using Operator = cutlass::arch::OpMultiplyAddSaturate; -// }; - -template -struct DispatchSharedMemoryLayoutA; - -template -struct DispatchSharedMemoryLayoutA { - using SmemLayoutA = 
cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< - cutlass::sizeof_bits::value, K / kSparse>; -}; - -template -struct DispatchSharedMemoryLayoutA { - static int const Crosswise_A = - cutlass::platform::min(int(128 / sizeof(T)), M); - using SmemLayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< - cutlass::sizeof_bits::value, Crosswise_A>; -}; - -template -struct DispatchSharedMemoryLayoutB; - -template -struct DispatchSharedMemoryLayoutB { - static_assert( - cutlass::sizeof_bits::value != 8, - "int8, uint8, float8 only support column major layout for matrix B"); - static int const Crosswise_B = - cutlass::platform::min(int(128 / sizeof(T)), N); - using SmemLayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< - cutlass::sizeof_bits::value, Crosswise_B>; -}; - -template -struct DispatchSharedMemoryLayoutB { - static int const kCrosswiseB = (K > (1024 / cutlass::sizeof_bits::value)) - ? (1024 / cutlass::sizeof_bits::value) - : K; - using SmemLayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< - cutlass::sizeof_bits::value, kCrosswiseB>; -}; - -template struct DispatchType { - static_assert(std::is_same::value, "Unsupported dtype"); -}; - -template <> struct DispatchType { - using Type = cutlass::half_t; -}; - -template <> struct DispatchType { - using Type = cutlass::bfloat16_t; -}; - -template <> struct DispatchType { - using Type = uint8_t; -}; - -template <> struct DispatchType { - using Type = int8_t; -}; - -template -class GemmTensorOp { -public: - static_assert(Shape::kM % num_warp_m == 0); - static_assert(Shape::kN % num_warp_n == 0); - using ElementA = typename DispatchType::Type; - using ElementB = typename DispatchType::Type; - using ElementC = C_type_raw; - - static_assert(std::is_same_v, - "A and B are not the same type"); - static_assert(ShapeCheck::value, - "Invalid shape for ElementA"); - - using LayoutA = - typename std::conditional_t; - using LayoutB = - typename std::conditional_t; - using LayoutC = cutlass::layout::RowMajor; - using ThreadblockShape = Shape; - using SmemLayoutA = - typename DispatchSharedMemoryLayoutA::SmemLayoutA; - using SmemLayoutB = - typename DispatchSharedMemoryLayoutB::SmemLayoutB; - - using WarpShape = cutlass::gemm::GemmShape; - using InstructionShape = typename DispatchInstructionShape::Shape; - using Operator = typename DispatchInstructionShape::Operator; - static_assert(WarpShape::kK % InstructionShape::kK == 0, - "K dimension must be divisible by instruction shape K."); - - // instruction/warp config - using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< - cutlass::arch::SparseMma, - cutlass::MatrixShape<1, 1>>; - using MmaWarp = - cutlass::gemm::warp::SparseMmaTensorOp; - static_assert(kSparse == MmaWarp::kSparse, "not 2:4 structured sparse"); - - using SmemLayoutE = typename MmaWarp::LayoutE; - static_assert(std::is_same_v, - "Meta data layout must be ColumnMajor for sparse mma."); - - // other traits - using FragmentA = typename MmaWarp::FragmentA; - using FragmentB = typename MmaWarp::FragmentB; - using FragmentC = typename MmaWarp::FragmentC; - using FragmentE = typename MmaWarp::FragmentE; - - using IteratorA = typename MmaWarp::IteratorA; - using IteratorB = typename MmaWarp::IteratorB; - using IteratorE = typename MmaWarp::IteratorE; - - using TensorRefA = typename IteratorA::TensorRef; - using TensorRefB = typename IteratorB::TensorRef; - using TensorRefE = typename IteratorE::TensorRef; - using ElementE = typename TensorRefE::Element; - - static int const kElementsPerElementE = 
MmaWarp::kElementsPerElementE; - static_assert(kSparse == MmaWarp::kSparse, "not 2:4 structured sparse"); - - using ShapeA = cutlass::MatrixShape; - using ShapeB = cutlass::MatrixShape; - using ShapeE = - cutlass::MatrixShape; - - static int constexpr kKgroups = WarpShape::kK / InstructionShape::kK; - - template - static CUTLASS_DEVICE void - body(A_type_raw *pA, E_type_raw *pE, B_type_raw *pB, FragmentC &accum, - const int warp_idx_m, const int warp_idx_n, const int lane_id) { - MmaWarp mma_op; - FragmentA frag_a; - FragmentB frag_b; - FragmentE frag_e; - const TensorRefA ref_A( - (ElementA *)pA, - MmaWarp::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn})); - const TensorRefE ref_E( - (ElementE *)pE, - MmaWarp::LayoutE::packed({ShapeE::kRow, ShapeE::kColumn})); - const TensorRefB ref_B( - (ElementB *)pB, - MmaWarp::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn})); - IteratorA iter_A(ref_A, lane_id); - IteratorE iter_E(ref_E, lane_id); - IteratorB iter_B(ref_B, lane_id); - iter_A.add_tile_offset({warp_idx_m, 0}); - iter_E.add_tile_offset({warp_idx_m, 0}); - iter_B.add_tile_offset({0, warp_idx_n}); - if constexpr (clear_accum) { - accum.clear(); - } - CUTLASS_PRAGMA_UNROLL - for (int k = 0; k < kKgroups; ++k) { - iter_A.load(frag_a); - iter_E.load(frag_e); - iter_B.load(frag_b); - ++iter_A; - ++iter_E; - ++iter_B; - mma_op(accum, frag_a, frag_b, accum, frag_e); - } - } -}; - -template -TL_DEVICE void gemm_sp_ss(A_type *pA, B_type *pB, C_type *accum, E_type *pE) { - using MMA = - GemmTensorOp, num_warp_m, num_warp_n, - trans_A, trans_B, clear_accum, A_type, B_type, C_type>; - using FragmentC = typename MMA::FragmentC; - - int warp_id = threadIdx.x / 32; - int lane_id = threadIdx.x % 32; - MMA::body(pA, pE, pB, *(FragmentC *)(accum), warp_id % num_warp_m, - warp_id / num_warp_m, lane_id); -} - -} // namespace tl diff --git a/src/tl_templates/cuda/gemm_sp_sm90.h b/src/tl_templates/cuda/gemm_sp_sm90.h deleted file mode 100644 index 522fc11ee9..0000000000 --- a/src/tl_templates/cuda/gemm_sp_sm90.h +++ /dev/null @@ -1,234 +0,0 @@ -#pragma once - -#include -#include -#include - -namespace cute { -namespace tl_wgmma_sp { -template -class GemmTensorOp { -public: - static_assert(num_warp_m % 4 == 0, "num_warp_m must be a multiple of 4"); - - using A_type_cute = typename tl::to_cute_type::type; - using B_type_cute = typename tl::to_cute_type::type; - using A_type = conditional_t::value, - tfloat32_t, A_type_cute>; - using B_type = conditional_t::value, - tfloat32_t, B_type_cute>; - using C_type = C_type_raw; - - static constexpr bool need_tfloat32_cast = - std::is_same::value && - std::is_same::value; - - static constexpr GMMA::Major GmmaMajorA = - trans_A ? GMMA::Major::MN : GMMA::Major::K; - static constexpr GMMA::Major GmmaMajorB = - trans_B ? 
GMMA::Major::K : GMMA::Major::MN; - - using TiledMma = decltype(make_tiled_mma( - GMMA::ss_op_selector_sparse< - A_type, B_type, C_type, - Shape, Int, Int>, - GmmaMajorA, GmmaMajorB>(), - Layout, Int, _1>>{})); - - using ElementAMma = typename TiledMma::ValTypeA; - using ElementAMmaSparsity = Int; - using ElementBMma = typename TiledMma::ValTypeB; - using ElementEMma = typename TiledMma::ValTypeE; - using ElementEMmaSparsity = Int; - using E_type_raw = typename ElementEMma::raw_type; - - using SparseConfig = - cutlass::Sm90GemmSparseConfig{}, _128{}))>; - - using LayoutA = decltype(SparseConfig::deduce_layoutA()); - using LayoutE = decltype(SparseConfig::deduce_layoutE()); - - using SmemLayoutAtomA = - decltype(cutlass::gemm::collective::detail::ss_smem_selector_sparse< - GmmaMajorA, A_type, Int, Int, ElementAMmaSparsity>()); - using SmemLayoutAtomB = - decltype(cutlass::gemm::collective::detail::ss_smem_selector< - GmmaMajorB, B_type, Int, Int>()); - - using SmemLayoutAtomE_ = typename SparseConfig::TensorEAtom; - using SmemLayoutAtomE = - ComposedLayout, - smem_sparse_ptr_flag_bits>, - SmemLayoutAtomE_>; - - using SmemLayoutA = decltype(tile_to_shape( - SmemLayoutAtomA{}, Shape, Int>{}, - conditional_t, Step<_1, _2>>{})); - using SmemLayoutB = decltype(tile_to_shape( - SmemLayoutAtomB{}, Shape, Int>{}, - conditional_t, Step<_2, _1>>{})); - using SmemLayoutE = decltype(tile_to_shape( - SmemLayoutAtomE{}, Shape, Int>{}, - conditional_t, Step<_1, _2>>{})); - - using SmemCopyAtomE = AutoVectorizingCopy; - - template - static CUTE_DEVICE void body(A_type_raw *pA, B_type_raw *pB, C_type_raw *pC, - E_type_raw *pE) { - const int tid = threadIdx.x; - Tensor sA = - make_tensor(make_smem_ptr(recast_ptr(pA)), SmemLayoutA{}); - Tensor sB = - make_tensor(make_smem_ptr(recast_ptr(pB)), SmemLayoutB{}); - Tensor sE = as_position_independent_swizzle_tensor( - make_tensor(make_smem_ptr(recast_ptr(pE)), SmemLayoutE{})); - - TiledMma tiled_mma; - auto thr_mma = tiled_mma.get_thread_slice(tid); - - Tensor tCsA = thr_mma.partition_A(sA); - Tensor tCsB = thr_mma.partition_B(sB); - Tensor tCsE = partition_E(thr_mma, sE(_, _)); - - Tensor tCrA = thr_mma.make_fragment_A(tCsA); - Tensor tCrB = thr_mma.make_fragment_B(tCsB); - Tensor tCrE = make_fragment_like(tCsE); - - auto copy_atom_E = Copy_Atom{}; - auto smem_tiled_copy_E = make_tiled_copy_E(copy_atom_E, tiled_mma); - auto smem_thr_copy_E = smem_tiled_copy_E.get_thread_slice(tid); - Tensor tEsE = smem_thr_copy_E.partition_S(sE); - Tensor tErE = smem_thr_copy_E.retile_D(tCrE); - - Tensor acc = - make_tensor(make_rmem_ptr(pC), - partition_shape_C(tiled_mma, Shape, Int>{})); - - warpgroup_fence_operand(acc); - warpgroup_arrive(); - if constexpr (clear_accum) { - tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; - } - copy(smem_tiled_copy_E, tEsE, tErE); - - CUTLASS_PRAGMA_UNROLL - for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { - // warpgroup_arrive(); - // (V,M) x (V,N) => (V,M,N) - gemm(tiled_mma, make_zip_tensor(tCrA(_, _, k_block), tCrE(_, _, k_block)), - tCrB(_, _, k_block), acc); - tiled_mma.accumulate_ = GMMA::ScaleOut::One; - } - - warpgroup_commit_batch(); - if constexpr (wg_wait >= 0) { - warpgroup_wait(); - } - warpgroup_fence_operand(acc); - } - - template - CUTE_HOST_DEVICE static constexpr auto - thrfrg_E(TiledMMA const &mma, - ETensor &&etensor) { - using TiledMma = TiledMMA; - - CUTE_STATIC_ASSERT_V(rank(etensor) >= Int<2>{}); - - // Reorder the tensor for the TiledAtom - auto t_tile = make_tile(get<0>(PermutationMNK{}), 
get<2>(PermutationMNK{})); - auto t_tensor = logical_divide(etensor, t_tile); // (PermM,PermK) - - // Tile the tensor for the Atom - auto e_tile = - make_tile(make_layout(size<0>(typename TiledMma::AtomShape_MNK{})), - make_layout(size<2>(typename TiledMma::AtomShape_MNK{}))); - auto e_tensor = - zipped_divide(t_tensor, e_tile); // ((AtomM,AtomK),(RestM,RestK)) - - // Transform the Atom mode from (M,K) to (Thr,Val) - using AtomLayoutE_TV = typename TiledMma::Atom::Traits::ELayout; - auto tv_tensor = - e_tensor.compose(AtomLayoutE_TV{}, _); // ((ThrV,FrgV),(RestM,RestK)) - - // Tile the tensor for the Thread - auto thr_tile = - make_tile(_, make_tile(make_layout(size<1>(mma.thr_layout_vmnk_)), - make_layout(size<3>(mma.thr_layout_vmnk_)))); - auto thr_tensor = zipped_divide( - tv_tensor, thr_tile); // ((ThrV,(ThrM,ThrK)),(FrgV,(RestM,RestK))) - - return thr_tensor; - } - - template - CUTE_HOST_DEVICE static constexpr auto - get_layoutE_TV(TiledMMA const &mma) { - // (M,K) -> (M,K) - auto ref_E = make_layout(make_shape(tile_size<0>(mma), tile_size<2>(mma))); - // (ethrid,val) -> (M,K) - auto layoutE_TV = thrfrg_E(mma, ref_E); - - // (ThrV,(ThrM,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK)) - auto etile = make_tile( - _, make_tile(make_layout(make_shape(size<1>(mma.thr_layout_vmnk_), - size<2>(mma.thr_layout_vmnk_)), - make_stride(Int<1>{}, Int<0>{})), - _)); - - // thr_idx -> (ThrV,ThrM,ThrN,ThrK) - auto thridx_2_thrid = right_inverse(mma.thr_layout_vmnk_); - - // (thr_idx,val) -> (M,K) - return layoutE_TV.compose(etile, _).compose(thridx_2_thrid, _); - } - - template - CUTE_HOST_DEVICE static constexpr auto - partition_E(ThrMMA const &thr_mma, ETensor &&etensor) { - auto thr_tensor = make_tensor(static_cast(etensor).data(), - thrfrg_E(thr_mma, etensor.layout())); - - auto thr_vmk = make_coord( - get<0>(thr_mma.thr_vmnk_), - make_coord(get<1>(thr_mma.thr_vmnk_), get<3>(thr_mma.thr_vmnk_))); - return thr_tensor(thr_vmk, - make_coord(_, repeat(thr_tensor)>(_))); - } - - template - CUTE_HOST_DEVICE static constexpr auto - make_tiled_copy_E(Copy_Atom const ©_atom, - TiledMMA const &mma) { - return make_tiled_copy_impl( - copy_atom, get_layoutE_TV(mma), - make_shape(tile_size<0>(mma), tile_size<2>(mma))); - } -}; - -} // namespace tl_wgmma_sp -} // namespace cute - -namespace tl { -template , - typename E_type = typename GMMA::ElementEMma::raw_type> -TL_DEVICE void gemm_sp_ss(A_type *pA, B_type *pB, C_type *accum, E_type *pE) { - static_assert(use_wgmma, "only wgmma is supported for now"); - if constexpr (use_wgmma) { - GMMA::body(pA, pB, accum, pE); - } else { - CUTE_GCC_UNREACHABLE; - } -} -} // namespace tl diff --git a/src/tl_templates/cuda/instruction/cute_extension/mma_sm80_sparse.hpp b/src/tl_templates/cuda/instruction/cute_extension/mma_sm80_sparse.hpp new file mode 100644 index 0000000000..d60dca813b --- /dev/null +++ b/src/tl_templates/cuda/instruction/cute_extension/mma_sm80_sparse.hpp @@ -0,0 +1,585 @@ +// NOTE: CUTLASS didn't implement this for sm8x +#pragma once + +#include +#include + +// Config +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) +#define CUTE_ARCH_SPARSE_MMA_SM80_ENABLED +#endif + +namespace SM80 { +namespace MMA { + +enum class SparseSel : int { Zero = 0, One = 1 }; + +namespace SPARSE { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SM80_16x8x16_F16F16F16F16_TN { + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[2]; + using CRegisters = 
uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void fma(uint32_t &d0, uint32_t &d1, + uint32_t const &a0, uint32_t const &a1, + uint32_t const &b0, uint32_t const &b1, + uint32_t const &c0, uint32_t const &c1, + uint32_t const &e) { +#if defined(CUTE_ARCH_SPARSE_MMA_SM80_ENABLED) +#if ((__CUDACC_VER_MAJOR__ > 12) || \ + (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) + asm volatile("mma.sp::ordered_metadata.sync.aligned.m16n8k16.row.col.f16." + "f16.f16.f16 " + "{%0,%1}, {%2,%3}, {%4,%5}, {%6,%7}, %8, %9;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(b0), "r"(b1), "r"(c0), "r"(c1), "r"(e), + "n"(int32_t(spsel))); +#else + asm volatile("mma.sp.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " + "{%0,%1}, {%2,%3}, {%4,%5}, {%6,%7}, %8, %9;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(b0), "r"(b1), "r"(c0), "r"(c1), "r"(e), + "n"(int32_t(spsel))); +#endif +#else + CUTE_INVALID_CONTROL_PATH("SM80_16x8x16_F16F16F16F16_TN requires " + "CUTE_ARCH_SPARSE_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SM80_16x8x32_F16F16F16F16_TN { + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t &d0, uint32_t &d1, uint32_t const &a0, uint32_t const &a1, + uint32_t const &a2, uint32_t const &a3, uint32_t const &b0, + uint32_t const &b1, uint32_t const &b2, uint32_t const &b3, + uint32_t const &c0, uint32_t const &c1, uint32_t const &e) { +#if defined(CUTE_ARCH_SPARSE_MMA_SM80_ENABLED) +#if ((__CUDACC_VER_MAJOR__ > 12) || \ + (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) + asm volatile("mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f16." + "f16.f16.f16 " + "{%0,%1}, {%2,%3,%4,%5}, {%6,%7,%8,%9}, {%10,%11}, %12, %13;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), + "r"(b2), "r"(b3), "r"(c0), "r"(c1), "r"(e), + "n"(int32_t(spsel))); +#else + asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f16.f16.f16.f16 " + "{%0,%1}, {%2,%3,%4,%5}, {%6,%7,%8,%9}, {%10,%11}, %12, %13;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), + "r"(b2), "r"(b3), "r"(c0), "r"(c1), "r"(e), + "n"(int32_t(spsel))); +#endif +#else + CUTE_INVALID_CONTROL_PATH("SM80_16x8x32_F16F16F16F16_TN requires " + "CUTE_ARCH_SPARSE_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SM80_16x8x16_F32F16F16F32_TN { + using DRegisters = float[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void fma(float &d0, float &d1, float &d2, float &d3, + uint32_t const &a0, uint32_t const &a1, + uint32_t const &b0, uint32_t const &b1, + float const &c0, float const &c1, + float const &c2, float const &c3, + uint32_t const &e) { +#if defined(CUTE_ARCH_SPARSE_MMA_SM80_ENABLED) +#if ((__CUDACC_VER_MAJOR__ > 12) || \ + (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) + asm volatile("mma.sp::ordered_metadata.sync.aligned.m16n8k16.row.col.f32." 
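+                 // `mma.sp::ordered_metadata` (PTX 8.5, hence the CUDA 12.5+
+                 // guard above, per the PTX ISA) additionally requires the
+                 // 2-bit indices inside each metadata nibble to be stored in
+                 // ascending order, in exchange for better SASS from ptxas;
+                 // the #else branch below keeps the order-agnostic `mma.sp`
+                 // encoding for older toolchains.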
+ "f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11}, %12, %13;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(b0), "r"(b1), "f"(c0), "f"(c1), + "f"(c2), "f"(c3), "r"(e), "n"(int32_t(spsel))); +#else + asm volatile("mma.sp.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11}, %12, %13;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(b0), "r"(b1), "f"(c0), "f"(c1), + "f"(c2), "f"(c3), "r"(e), "n"(int32_t(spsel))); +#endif +#else + CUTE_INVALID_CONTROL_PATH("SM80_16x8x16_F32F16F16F32_TN requires " + "CUTE_ARCH_SPARSE_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SM80_16x8x32_F32F16F16F32_TN { + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float &d0, float &d1, float &d2, float &d3, uint32_t const &a0, + uint32_t const &a1, uint32_t const &a2, uint32_t const &a3, + uint32_t const &b0, uint32_t const &b1, uint32_t const &b2, + uint32_t const &b3, float const &c0, float const &c1, float const &c2, + float const &c3, uint32_t const &e) { +#if defined(CUTE_ARCH_SPARSE_MMA_SM80_ENABLED) +#if ((__CUDACC_VER_MAJOR__ > 12) || \ + (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) + asm volatile("mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32." + "f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, " + "{%12,%13,%14,%15}, %16, %17;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), + "r"(b2), "r"(b3), "f"(c0), "f"(c1), "f"(c2), "f"(c3), "r"(e), + "n"(int32_t(spsel))); +#else + asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, " + "{%12,%13,%14,%15}, %16, %17;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), + "r"(b2), "r"(b3), "f"(c0), "f"(c1), "f"(c2), "f"(c3), "r"(e), + "n"(int32_t(spsel))); +#endif +#else + CUTE_INVALID_CONTROL_PATH("SM80_16x8x32_F32F16F16F32_TN requires " + "CUTE_ARCH_SPARSE_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SM80_16x8x16_F32BF16BF16F32_TN { + using DRegisters = float[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void fma(float &d0, float &d1, float &d2, float &d3, + uint32_t const &a0, uint32_t const &a1, + uint32_t const &b0, uint32_t const &b1, + float const &c0, float const &c1, + float const &c2, float const &c3, + uint32_t const &e) { +#if defined(CUTE_ARCH_SPARSE_MMA_SM80_ENABLED) +#if ((__CUDACC_VER_MAJOR__ > 12) || \ + (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) + asm volatile("mma.sp::ordered_metadata.sync.aligned.m16n8k16.row.col.f32." 
+ "bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11}, %12, %13;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(b0), "r"(b1), "f"(c0), "f"(c1), + "f"(c2), "f"(c3), "r"(e), "n"(int32_t(spsel))); +#else + asm volatile("mma.sp.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11}, %12, %13;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(b0), "r"(b1), "f"(c0), "f"(c1), + "f"(c2), "f"(c3), "r"(e), "n"(int32_t(spsel))); +#endif +#else + CUTE_INVALID_CONTROL_PATH("SM80_16x8x16_F32BF16BF16F32_TN requires " + "CUTE_ARCH_SPARSE_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SM80_16x8x32_F32BF16BF16F32_TN { + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float &d0, float &d1, float &d2, float &d3, uint32_t const &a0, + uint32_t const &a1, uint32_t const &a2, uint32_t const &a3, + uint32_t const &b0, uint32_t const &b1, uint32_t const &b2, + uint32_t const &b3, float const &c0, float const &c1, float const &c2, + float const &c3, uint32_t const &e) { +#if defined(CUTE_ARCH_SPARSE_MMA_SM80_ENABLED) +#if ((__CUDACC_VER_MAJOR__ > 12) || \ + (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) + asm volatile("mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32." + "bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, " + "{%12,%13,%14,%15}, %16, %17;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), + "r"(b2), "r"(b3), "f"(c0), "f"(c1), "f"(c2), "f"(c3), "r"(e), + "n"(int32_t(spsel))); +#else + asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, " + "{%12,%13,%14,%15}, %16, %17;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), + "r"(b2), "r"(b3), "f"(c0), "f"(c1), "f"(c2), "f"(c3), "r"(e), + "n"(int32_t(spsel))); +#endif +#else + CUTE_INVALID_CONTROL_PATH("SM80_16x8x32_F32BF16BF16F32_TN requires " + "CUTE_ARCH_SPARSE_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SM80_16x8x8_F32TF32TF32F32_TN { + using DRegisters = float[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void fma(float &d0, float &d1, float &d2, float &d3, + uint32_t const &a0, uint32_t const &a1, + uint32_t const &b0, uint32_t const &b1, + float const &c0, float const &c1, + float const &c2, float const &c3, + uint32_t const &e) { +#if defined(CUTE_ARCH_SPARSE_MMA_SM80_ENABLED) +#if ((__CUDACC_VER_MAJOR__ > 12) || \ + (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) + asm volatile("mma.sp::ordered_metadata.sync.aligned.m16n8k8.row.col.f32." 
+ "tf32.tf32.f32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11}, %12, %13;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(b0), "r"(b1), "f"(c0), "f"(c1), + "f"(c2), "f"(c3), "r"(e), "n"(int32_t(spsel))); +#else + asm volatile("mma.sp.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11}, %12, %13;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(b0), "r"(b1), "f"(c0), "f"(c1), + "f"(c2), "f"(c3), "r"(e), "n"(int32_t(spsel))); +#endif +#else + CUTE_INVALID_CONTROL_PATH("SM80_16x8x8_F32TF32TF32F32_TN requires " + "CUTE_ARCH_SPARSE_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SM80_16x8x16_F32TF32TF32F32_TN { + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float &d0, float &d1, float &d2, float &d3, uint32_t const &a0, + uint32_t const &a1, uint32_t const &a2, uint32_t const &a3, + uint32_t const &b0, uint32_t const &b1, uint32_t const &b2, + uint32_t const &b3, float const &c0, float const &c1, float const &c2, + float const &c3, uint32_t const &e) { +#if defined(CUTE_ARCH_SPARSE_MMA_SM80_ENABLED) +#if ((__CUDACC_VER_MAJOR__ > 12) || \ + (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) + asm volatile("mma.sp::ordered_metadata.sync.aligned.m16n8k16.row.col.f32." + "tf32.tf32.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, " + "{%12,%13,%14,%15}, %16, %17;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), + "r"(b2), "r"(b3), "f"(c0), "f"(c1), "f"(c2), "f"(c3), "r"(e), + "n"(int32_t(spsel))); +#else + asm volatile("mma.sp.sync.aligned.m16n8k16.row.col.f32.tf32.tf32.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, " + "{%12,%13,%14,%15}, %16, %17;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), + "r"(b2), "r"(b3), "f"(c0), "f"(c1), "f"(c2), "f"(c3), "r"(e), + "n"(int32_t(spsel))); +#endif +#else + CUTE_INVALID_CONTROL_PATH("SM80_16x8x16_F32TF32TF32F32_TN requires " + "CUTE_ARCH_SPARSE_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template struct SM80_16x8x32_S32S8S8S32_TN { + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void fma(uint32_t &d0, uint32_t &d1, uint32_t &d2, + uint32_t &d3, uint32_t const &a0, + uint32_t const &a1, uint32_t const &b0, + uint32_t const &b1, uint32_t const &c0, + uint32_t const &c1, uint32_t const &c2, + uint32_t const &c3, uint32_t const &e) { +#if defined(CUTE_ARCH_SPARSE_MMA_SM80_ENABLED) +#if ((__CUDACC_VER_MAJOR__ > 12) || \ + (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) + asm volatile( + "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11}, %12, 0x0;\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(b0), "r"(b1), "r"(c0), "r"(c1), "r"(c2), + "r"(c3), "r"(e)); +#else + asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11}, %12, 0x0;\n" + : "=r"(d0), 
"=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(b0), "r"(b1), "r"(c0), "r"(c1), + "r"(c2), "r"(c3), "r"(e)); +#endif +#else + CUTE_INVALID_CONTROL_PATH("SM80_16x8x32_S32S8S8S32_TN requires " + "CUTE_ARCH_SPARSE_MMA_SM80_ENABLED"); +#endif + } +}; + +template <> struct SM80_16x8x32_S32S8S8S32_TN { + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void fma(uint32_t &d0, uint32_t &d1, uint32_t &d2, + uint32_t &d3, uint32_t const &a0, + uint32_t const &a1, uint32_t const &b0, + uint32_t const &b1, uint32_t const &c0, + uint32_t const &c1, uint32_t const &c2, + uint32_t const &c3, uint32_t const &e) { + CUTE_INVALID_CONTROL_PATH( + "SM80_16x8x32_S32S8S8S32_TN with SparseSel::One is invalid"); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template struct SM80_16x8x64_S32S8S8S32_TN { + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void fma(uint32_t &d0, uint32_t &d1, uint32_t &d2, + uint32_t &d3, uint32_t const &a0, + uint32_t const &a1, uint32_t const &a2, + uint32_t const &a3, uint32_t const &b0, + uint32_t const &b1, uint32_t const &b2, + uint32_t const &b3, uint32_t const &c0, + uint32_t const &c1, uint32_t const &c2, + uint32_t const &c3, uint32_t const &e) { +#if defined(CUTE_ARCH_SPARSE_MMA_SM80_ENABLED) +#if ((__CUDACC_VER_MAJOR__ > 12) || \ + (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) + asm volatile( + "mma.sp::ordered_metadata.sync.aligned.m16n8k64.row.col.s32.s8.s8.s32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, " + "%16, 0x0;\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), "r"(b2), + "r"(b3), "r"(c0), "r"(c1), "r"(c2), "r"(c3), "r"(e)); +#else + asm volatile("mma.sp.sync.aligned.m16n8k64.row.col.s32.s8.s8.s32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, " + "{%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), + "r"(b2), "r"(b3), "r"(c0), "r"(c1), "r"(c2), "r"(c3), + "r"(e)); +#endif +#else + CUTE_INVALID_CONTROL_PATH("SM80_16x8x64_S32S8S8S32_TN requires " + "CUTE_ARCH_SPARSE_MMA_SM80_ENABLED"); +#endif + } +}; + +template <> struct SM80_16x8x64_S32S8S8S32_TN { + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void fma(uint32_t &d0, uint32_t &d1, uint32_t &d2, + uint32_t &d3, uint32_t const &a0, + uint32_t const &a1, uint32_t const &a2, + uint32_t const &a3, uint32_t const &b0, + uint32_t const &b1, uint32_t const &b2, + uint32_t const &b3, uint32_t const &c0, + uint32_t const &c1, uint32_t const &c2, + uint32_t const &c3, uint32_t const &e) { + CUTE_INVALID_CONTROL_PATH( + "SM80_16x8x64_S32S8S8S32_TN with SparseSel::One is invalid"); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template struct SM80_16x8x32_S32U8U8S32_TN { + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void 
fma(uint32_t &d0, uint32_t &d1, uint32_t &d2, + uint32_t &d3, uint32_t const &a0, + uint32_t const &a1, uint32_t const &b0, + uint32_t const &b1, uint32_t const &c0, + uint32_t const &c1, uint32_t const &c2, + uint32_t const &c3, uint32_t const &e) { +#if defined(CUTE_ARCH_SPARSE_MMA_SM80_ENABLED) +#if ((__CUDACC_VER_MAJOR__ > 12) || \ + (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) + asm volatile( + "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11}, %12, 0x0;\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(b0), "r"(b1), "r"(c0), "r"(c1), "r"(c2), + "r"(c3), "r"(e)); +#else + asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11}, %12, 0x0;\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(b0), "r"(b1), "r"(c0), "r"(c1), + "r"(c2), "r"(c3), "r"(e)); +#endif +#else + CUTE_INVALID_CONTROL_PATH("SM80_16x8x32_S32U8U8S32_TN requires " + "CUTE_ARCH_SPARSE_MMA_SM80_ENABLED"); +#endif + } +}; + +template <> struct SM80_16x8x32_S32U8U8S32_TN { + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void fma(uint32_t &d0, uint32_t &d1, uint32_t &d2, + uint32_t &d3, uint32_t const &a0, + uint32_t const &a1, uint32_t const &b0, + uint32_t const &b1, uint32_t const &c0, + uint32_t const &c1, uint32_t const &c2, + uint32_t const &c3, uint32_t const &e) { + CUTE_INVALID_CONTROL_PATH( + "SM80_16x8x32_S32U8U8S32_TN with SparseSel::One is invalid"); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template struct SM80_16x8x64_S32U8U8S32_TN { + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void fma(uint32_t &d0, uint32_t &d1, uint32_t &d2, + uint32_t &d3, uint32_t const &a0, + uint32_t const &a1, uint32_t const &a2, + uint32_t const &a3, uint32_t const &b0, + uint32_t const &b1, uint32_t const &b2, + uint32_t const &b3, uint32_t const &c0, + uint32_t const &c1, uint32_t const &c2, + uint32_t const &c3, uint32_t const &e) { +#if defined(CUTE_ARCH_SPARSE_MMA_SM80_ENABLED) +#if ((__CUDACC_VER_MAJOR__ > 12) || \ + (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) + asm volatile( + "mma.sp::ordered_metadata.sync.aligned.m16n8k64.row.col.s32.u8.u8.s32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, " + "%16, 0x0;\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), "r"(b2), + "r"(b3), "r"(c0), "r"(c1), "r"(c2), "r"(c3), "r"(e)); +#else + asm volatile("mma.sp.sync.aligned.m16n8k64.row.col.s32.u8.u8.s32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, " + "{%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), + "r"(b2), "r"(b3), "r"(c0), "r"(c1), "r"(c2), "r"(c3), + "r"(e)); +#endif +#else + CUTE_INVALID_CONTROL_PATH("SM80_16x8x64_S32U8U8S32_TN requires " + "CUTE_ARCH_SPARSE_MMA_SM80_ENABLED"); +#endif + } +}; + +template <> struct SM80_16x8x64_S32U8U8S32_TN { + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[4]; + using 
ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void fma(uint32_t &d0, uint32_t &d1, uint32_t &d2, + uint32_t &d3, uint32_t const &a0, + uint32_t const &a1, uint32_t const &a2, + uint32_t const &a3, uint32_t const &b0, + uint32_t const &b1, uint32_t const &b2, + uint32_t const &b3, uint32_t const &c0, + uint32_t const &c1, uint32_t const &c2, + uint32_t const &c3, uint32_t const &e) { + CUTE_INVALID_CONTROL_PATH( + "SM80_16x8x64_S32U8U8S32_TN with SparseSel::One is invalid"); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace SPARSE +} // namespace MMA +} // namespace SM80 diff --git a/src/tl_templates/cuda/instruction/cute_extension/mma_sm89_sparse.hpp b/src/tl_templates/cuda/instruction/cute_extension/mma_sm89_sparse.hpp new file mode 100644 index 0000000000..6e3a06fdb9 --- /dev/null +++ b/src/tl_templates/cuda/instruction/cute_extension/mma_sm89_sparse.hpp @@ -0,0 +1,207 @@ +#pragma once + +#include "mma_sm80_sparse.hpp" // for SM80::MMA::SparseSel + +#if (__CUDACC_VER_MAJOR__ > 12) || \ + (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4) +#define CUTE_ARCH_SPARSE_MMA_SM89_SUPPORTED +#endif + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890) +#if defined(CUTE_ARCH_SPARSE_MMA_SM89_SUPPORTED) +#define CUTE_ARCH_SPARSE_MMA_SM89_ENABLED +#endif +#endif + +namespace SM89 { +namespace MMA { + +using SM80::MMA::SparseSel; + +namespace SPARSE { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SM89_16x8x64_F32E4M3E4M3F32_TN { + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float &d0, float &d1, float &d2, float &d3, uint32_t const &a0, + uint32_t const &a1, uint32_t const &a2, uint32_t const &a3, + uint32_t const &b0, uint32_t const &b1, uint32_t const &b2, + uint32_t const &b3, float const &c0, float const &c1, float const &c2, + float const &c3, uint32_t const &e) { + static_assert(spsel == SparseSel::Zero, + "SM89 fp8 sparse mma only supports SparseSel::Zero"); +#if defined(CUTE_ARCH_SPARSE_MMA_SM89_ENABLED) +#if ((__CUDACC_VER_MAJOR__ > 12) || \ + (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) + asm volatile("mma.sp::ordered_metadata.sync.aligned.m16n8k64.row.col.f32." 
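+                 // fp8 sparse MMA is an SM89 (Ada) addition exposed only by
+                 // newer toolchains, hence the dedicated SUPPORTED/ENABLED
+                 // gates at the top of this file; only sparsity selector 0
+                 // is defined for it, which the static_assert on spsel
+                 // enforces before this asm is ever reached.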
+ "e4m3.e4m3.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, " + "{%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), + "r"(b2), "r"(b3), "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + asm volatile("mma.sp.sync.aligned.m16n8k64.row.col.f32.e4m3.e4m3.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, " + "{%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), + "r"(b2), "r"(b3), "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#endif +#else + CUTE_INVALID_CONTROL_PATH("SM89_16x8x64_F32E4M3E4M3F32_TN requires " + "CUTE_ARCH_SPARSE_MMA_SM89_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SM89_16x8x64_F32E4M3E5M2F32_TN { + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float &d0, float &d1, float &d2, float &d3, uint32_t const &a0, + uint32_t const &a1, uint32_t const &a2, uint32_t const &a3, + uint32_t const &b0, uint32_t const &b1, uint32_t const &b2, + uint32_t const &b3, float const &c0, float const &c1, float const &c2, + float const &c3, uint32_t const &e) { + static_assert(spsel == SparseSel::Zero, + "SM89 fp8 sparse mma only supports SparseSel::Zero"); +#if defined(CUTE_ARCH_SPARSE_MMA_SM89_ENABLED) +#if ((__CUDACC_VER_MAJOR__ > 12) || \ + (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) + asm volatile("mma.sp::ordered_metadata.sync.aligned.m16n8k64.row.col.f32." + "e4m3.e5m2.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, " + "{%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), + "r"(b2), "r"(b3), "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + asm volatile("mma.sp.sync.aligned.m16n8k64.row.col.f32.e4m3.e5m2.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, " + "{%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), + "r"(b2), "r"(b3), "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#endif +#else + CUTE_INVALID_CONTROL_PATH("SM89_16x8x64_F32E4M3E5M2F32_TN requires " + "CUTE_ARCH_SPARSE_MMA_SM89_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SM89_16x8x64_F32E5M2E4M3F32_TN { + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float &d0, float &d1, float &d2, float &d3, uint32_t const &a0, + uint32_t const &a1, uint32_t const &a2, uint32_t const &a3, + uint32_t const &b0, uint32_t const &b1, uint32_t const &b2, + uint32_t const &b3, float const &c0, float const &c1, float const &c2, + float const &c3, uint32_t const &e) { + static_assert(spsel == SparseSel::Zero, + "SM89 fp8 sparse mma only supports SparseSel::Zero"); +#if defined(CUTE_ARCH_SPARSE_MMA_SM89_ENABLED) +#if ((__CUDACC_VER_MAJOR__ > 12) || \ + (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) + asm volatile("mma.sp::ordered_metadata.sync.aligned.m16n8k64.row.col.f32." 
+ "e5m2.e4m3.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, " + "{%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), + "r"(b2), "r"(b3), "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + asm volatile("mma.sp.sync.aligned.m16n8k64.row.col.f32.e5m2.e4m3.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, " + "{%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), + "r"(b2), "r"(b3), "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#endif +#else + CUTE_INVALID_CONTROL_PATH("SM89_16x8x64_F32E5M2E4M3F32_TN requires " + "CUTE_ARCH_SPARSE_MMA_SM89_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SM89_16x8x64_F32E5M2E5M2F32_TN { + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float &d0, float &d1, float &d2, float &d3, uint32_t const &a0, + uint32_t const &a1, uint32_t const &a2, uint32_t const &a3, + uint32_t const &b0, uint32_t const &b1, uint32_t const &b2, + uint32_t const &b3, float const &c0, float const &c1, float const &c2, + float const &c3, uint32_t const &e) { + static_assert(spsel == SparseSel::Zero, + "SM89 fp8 sparse mma only supports SparseSel::Zero"); +#if defined(CUTE_ARCH_SPARSE_MMA_SM89_ENABLED) +#if ((__CUDACC_VER_MAJOR__ > 12) || \ + (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) + asm volatile("mma.sp::ordered_metadata.sync.aligned.m16n8k64.row.col.f32." + "e5m2.e5m2.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, " + "{%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), + "r"(b2), "r"(b3), "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + asm volatile("mma.sp.sync.aligned.m16n8k64.row.col.f32.e5m2.e5m2.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, " + "{%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), + "r"(b2), "r"(b3), "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#endif +#else + CUTE_INVALID_CONTROL_PATH("SM89_16x8x64_F32E5M2E5M2F32_TN requires " + "CUTE_ARCH_SPARSE_MMA_SM89_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace SPARSE +} // namespace MMA +} // namespace SM89 diff --git a/src/tl_templates/cuda/instruction/mma_sp.h b/src/tl_templates/cuda/instruction/mma_sp.h new file mode 100644 index 0000000000..2c5b65fe63 --- /dev/null +++ b/src/tl_templates/cuda/instruction/mma_sp.h @@ -0,0 +1,186 @@ +#pragma once + +#include "../common.h" +#include "cute_extension/mma_sm80_sparse.hpp" +#include "cute_extension/mma_sm89_sparse.hpp" + +#ifndef __CUDACC_RTC__ +#include +#include +#endif + +namespace tl { + +#ifndef TL_ALWAYS_FALSE_V_DEFINED +#define TL_ALWAYS_FALSE_V_DEFINED +template inline constexpr bool always_false_v = false; +#endif + +namespace detail { + +template struct MmaSpImplTraits { + using DReg = std::remove_extent_t; + using AReg = std::remove_extent_t; + using BReg = std::remove_extent_t; + using CReg = std::remove_extent_t; + using EReg = std::remove_extent_t; + + static constexpr int kDRegs = std::extent_v; + static constexpr int kARegs = 
std::extent_v; + static constexpr int kBRegs = std::extent_v; + static constexpr int kCRegs = std::extent_v; + static constexpr int kERegs = std::extent_v; +}; + +template +TL_DEVICE void +call_fma_sp_impl(typename MmaSpImplTraits::DReg *d, + const typename MmaSpImplTraits::AReg *a, + const typename MmaSpImplTraits::BReg *b, + const typename MmaSpImplTraits::CReg *c, + const typename MmaSpImplTraits::EReg *e, + std::index_sequence, std::index_sequence, + std::index_sequence, std::index_sequence, + std::index_sequence) { + Impl::fma(d[DIdx]..., a[AIdx]..., b[BIdx]..., c[CIdx]..., e[EIdx]...); +} + +template +TL_DEVICE void call_fma_sp(typename MmaSpImplTraits::DReg *d, + const typename MmaSpImplTraits::AReg *a, + const typename MmaSpImplTraits::BReg *b, + const typename MmaSpImplTraits::CReg *c, + const typename MmaSpImplTraits::EReg *e) { + call_fma_sp_impl( + d, a, b, c, e, std::make_index_sequence::kDRegs>{}, + std::make_index_sequence::kARegs>{}, + std::make_index_sequence::kBRegs>{}, + std::make_index_sequence::kCRegs>{}, + std::make_index_sequence::kERegs>{}); +} + +template +struct MmaSpDispatcher { + using CRegType = void; + using ARegType = void; + using BRegType = void; + + static TL_DEVICE void exec(CRegType *, const ARegType *, const BRegType *, + const CRegType *, const uint32_t *) { + static_assert(always_false_v>, + "tl::mma_sp_sync: unsupported configuration"); + } +}; + +#define TL_DEFINE_MMA_SP_DISPATCHER(ATypeEnum, BTypeEnum, CTypeEnum, MValue, \ + NValue, KValue, TransAValue, TransBValue, \ + ImplTemplate) \ + template \ + struct MmaSpDispatcher { \ + using Impl = ImplTemplate; \ + using Traits = MmaSpImplTraits; \ + using CRegType = typename Traits::DReg; \ + using ARegType = typename Traits::AReg; \ + using BRegType = typename Traits::BReg; \ + static_assert( \ + std::is_same_v, \ + "tl::mma_sp_sync requires matching accumulator/output regs"); \ + static TL_DEVICE void exec(CRegType *d, const ARegType *a, \ + const BRegType *b, const CRegType *c, \ + const uint32_t *e) { \ + call_fma_sp(d, a, b, c, \ + reinterpret_cast(e)); \ + } \ + }; + +// FP16 — logical K=16 (A holds K/2=8 actual elements, 2 regs) +TL_DEFINE_MMA_SP_DISPATCHER(kFloat16, kFloat16, kFloat16, 16, 8, 16, false, + true, + SM80::MMA::SPARSE::SM80_16x8x16_F16F16F16F16_TN) +TL_DEFINE_MMA_SP_DISPATCHER(kFloat16, kFloat16, kFloat32, 16, 8, 16, false, + true, + SM80::MMA::SPARSE::SM80_16x8x16_F32F16F16F32_TN) + +// FP16 — logical K=32 (A holds K/2=16 actual elements, 4 regs) +TL_DEFINE_MMA_SP_DISPATCHER(kFloat16, kFloat16, kFloat16, 16, 8, 32, false, + true, + SM80::MMA::SPARSE::SM80_16x8x32_F16F16F16F16_TN) +TL_DEFINE_MMA_SP_DISPATCHER(kFloat16, kFloat16, kFloat32, 16, 8, 32, false, + true, + SM80::MMA::SPARSE::SM80_16x8x32_F32F16F16F32_TN) + +// BF16 — logical K=16 +TL_DEFINE_MMA_SP_DISPATCHER(kBFloat16, kBFloat16, kFloat32, 16, 8, 16, false, + true, + SM80::MMA::SPARSE::SM80_16x8x16_F32BF16BF16F32_TN) + +// BF16 — logical K=32 +TL_DEFINE_MMA_SP_DISPATCHER(kBFloat16, kBFloat16, kFloat32, 16, 8, 32, false, + true, + SM80::MMA::SPARSE::SM80_16x8x32_F32BF16BF16F32_TN) + +// TF32 — logical K=8 (A holds K/2=4 actual elements, 2 regs) +TL_DEFINE_MMA_SP_DISPATCHER(kTensorFloat32, kTensorFloat32, kFloat32, 16, 8, 8, + false, true, + SM80::MMA::SPARSE::SM80_16x8x8_F32TF32TF32F32_TN) + +// TF32 — logical K=16 (A holds K/2=8 actual elements, 4 regs) +TL_DEFINE_MMA_SP_DISPATCHER(kTensorFloat32, kTensorFloat32, kFloat32, 16, 8, 16, + false, true, + SM80::MMA::SPARSE::SM80_16x8x16_F32TF32TF32F32_TN) + +// INT8 — logical 
K=32 (A holds K/2=16, 2 regs); SparseSel::One is invalid +TL_DEFINE_MMA_SP_DISPATCHER(kInt8, kInt8, kInt32, 16, 8, 32, false, true, + SM80::MMA::SPARSE::SM80_16x8x32_S32S8S8S32_TN) +TL_DEFINE_MMA_SP_DISPATCHER(kUInt8, kUInt8, kInt32, 16, 8, 32, false, true, + SM80::MMA::SPARSE::SM80_16x8x32_S32U8U8S32_TN) + +// INT8 — logical K=64 (A holds K/2=32, 4 regs); SparseSel::One is invalid +TL_DEFINE_MMA_SP_DISPATCHER(kInt8, kInt8, kInt32, 16, 8, 64, false, true, + SM80::MMA::SPARSE::SM80_16x8x64_S32S8S8S32_TN) +TL_DEFINE_MMA_SP_DISPATCHER(kUInt8, kUInt8, kInt32, 16, 8, 64, false, true, + SM80::MMA::SPARSE::SM80_16x8x64_S32U8U8S32_TN) + +// FP8 — logical K=64 (A holds K/2=32, 4 regs); only SparseSel::Zero is valid on +// SM89 +TL_DEFINE_MMA_SP_DISPATCHER(kFloat8_e4m3, kFloat8_e4m3, kFloat32, 16, 8, 64, + false, true, + SM89::MMA::SPARSE::SM89_16x8x64_F32E4M3E4M3F32_TN) +TL_DEFINE_MMA_SP_DISPATCHER(kFloat8_e4m3, kFloat8_e5m2, kFloat32, 16, 8, 64, + false, true, + SM89::MMA::SPARSE::SM89_16x8x64_F32E4M3E5M2F32_TN) +TL_DEFINE_MMA_SP_DISPATCHER(kFloat8_e5m2, kFloat8_e4m3, kFloat32, 16, 8, 64, + false, true, + SM89::MMA::SPARSE::SM89_16x8x64_F32E5M2E4M3F32_TN) +TL_DEFINE_MMA_SP_DISPATCHER(kFloat8_e5m2, kFloat8_e5m2, kFloat32, 16, 8, 64, + false, true, + SM89::MMA::SPARSE::SM89_16x8x64_F32E5M2E5M2F32_TN) + +#undef TL_DEFINE_MMA_SP_DISPATCHER + +} // namespace detail + +template +TL_DEVICE void mma_sp_sync( + typename detail::MmaSpDispatcher::CRegType *c, + const typename detail::MmaSpDispatcher::ARegType *a, + const typename detail::MmaSpDispatcher::BRegType *b, + const uint32_t *e) { + using Dispatcher = detail::MmaSpDispatcher; + static_assert(!std::is_void_v, + "tl::mma_sp_sync: unsupported configuration"); + Dispatcher::exec(c, a, b, c, e); +} + +} // namespace tl diff --git a/src/tl_templates/cuda/instruction/wgmma_sp.h b/src/tl_templates/cuda/instruction/wgmma_sp.h new file mode 100644 index 0000000000..ca22759206 --- /dev/null +++ b/src/tl_templates/cuda/instruction/wgmma_sp.h @@ -0,0 +1,467 @@ +#pragma once + +#include "../common.h" +#include "wgmma.h" +#include +#include + +#ifndef __CUDACC_RTC__ +#include +#include +#endif + +namespace tl { + +namespace detail { + +template struct CallWgmmaSpSS { + using CReg = std::remove_extent_t; + static constexpr int kCRegs = std::extent_v; + static_assert(sizeof(CReg) == sizeof(uint32_t), + "tl::wgmma_sp_ss expects 32-bit accumulator registers."); + + template + TL_DEVICE static void Run(uint64_t desc_a, uint64_t desc_b, CReg *c, + uint32_t e, cute::SM90::GMMA::ScaleOut scale, + std::index_sequence) { + Impl::fma(desc_a, desc_b, c[Idx]..., e, scale); + } + + TL_DEVICE static void exec(uint64_t desc_a, uint64_t desc_b, uint32_t *c_raw, + uint32_t e, bool scale_out) { + auto scale = scale_out ? 
cute::SM90::GMMA::ScaleOut::One + : cute::SM90::GMMA::ScaleOut::Zero; + auto c = reinterpret_cast(c_raw); + Run(desc_a, desc_b, c, e, scale, std::make_index_sequence{}); + } +}; + +template struct CallWgmmaSpRS { + using AReg = std::remove_extent_t; + using CReg = std::remove_extent_t; + static constexpr int kARegs = std::extent_v; + static constexpr int kCRegs = std::extent_v; + static_assert(sizeof(AReg) == sizeof(uint32_t), + "tl::wgmma_sp_rs expects 32-bit register operands for A."); + static_assert(sizeof(CReg) == sizeof(uint32_t) || + sizeof(CReg) == sizeof(float), + "tl::wgmma_sp_rs expects 32-bit accumulator registers."); + + template + TL_DEVICE static void Run(const AReg *a, uint64_t desc_b, CReg *c, uint32_t e, + cute::SM90::GMMA::ScaleOut scale, + std::index_sequence, + std::index_sequence) { + Impl::fma(a[AIdx]..., desc_b, c[CIdx]..., e, scale); + } + + TL_DEVICE static void exec(const uint32_t *a_raw, uint64_t desc_b, + uint32_t *c_raw, uint32_t e, bool scale_out) { + auto scale = scale_out ? cute::SM90::GMMA::ScaleOut::One + : cute::SM90::GMMA::ScaleOut::Zero; + auto a = reinterpret_cast(a_raw); + auto c = reinterpret_cast(c_raw); + Run(a, desc_b, c, e, scale, std::make_index_sequence{}, + std::make_index_sequence{}); + } +}; + +} // namespace detail + +template +struct WgmmaSpSSImpl { + static_assert(detail::IsValidScale, + "tl::wgmma_sp_ss: invalid scaleA"); + static_assert(detail::IsValidScale, + "tl::wgmma_sp_ss: invalid scaleB"); + TL_DEVICE static void execute(uint64_t, uint64_t, uint32_t *, bool, + uint32_t) { + static_assert(always_false_v>, + "tl::wgmma_sp_ss: unsupported configuration"); + } +}; + +template +struct WgmmaSpRSImpl { + static_assert(detail::IsValidScale, + "tl::wgmma_sp_rs: invalid scaleA"); + static_assert(detail::IsValidScale, + "tl::wgmma_sp_rs: invalid scaleB"); + TL_DEVICE static void execute(const uint32_t *, uint64_t, uint32_t *, bool, + uint32_t) { + static_assert(always_false_v>, + "tl::wgmma_sp_rs: unsupported configuration"); + } +}; + +#define TL_WGMMA_SP_DEFINE_SS_GENERAL(AType, BType, CType, M, N, K, ImplName) \ + template \ + struct WgmmaSpSSImpl { \ + static_assert(detail::IsValidScale, \ + "tl::wgmma_sp_ss: invalid scaleA"); \ + static_assert(detail::IsValidScale, \ + "tl::wgmma_sp_ss: invalid scaleB"); \ + using Impl = cute::SM90::GMMA::SPARSE::ImplName< \ + detail::MajorValue::value, detail::MajorValue::value, \ + detail::ScaleInValue::value, \ + detail::ScaleInValue::value, spsel>; \ + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, \ + uint32_t *c, bool scale_out, uint32_t e) { \ + detail::CallWgmmaSpSS::exec(desc_a, desc_b, c, e, scale_out); \ + } \ + }; + +#define TL_WGMMA_SP_DEFINE_SS_TN(AType, BType, CType, M, N, K, ImplName) \ + template \ + struct WgmmaSpSSImpl { \ + static_assert(detail::IsValidScale, \ + "tl::wgmma_sp_ss: invalid scaleA"); \ + static_assert(detail::IsValidScale, \ + "tl::wgmma_sp_ss: invalid scaleB"); \ + using Impl = cute::SM90::GMMA::SPARSE::ImplName< \ + detail::ScaleInValue::value, \ + detail::ScaleInValue::value, spsel>; \ + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, \ + uint32_t *c, bool scale_out, uint32_t e) { \ + detail::CallWgmmaSpSS::exec(desc_a, desc_b, c, e, scale_out); \ + } \ + }; + +#define TL_WGMMA_SP_DEFINE_SS_TN_FIXED_SCALE(AType, BType, CType, M, N, K, \ + ImplName) \ + template \ + struct WgmmaSpSSImpl { \ + static_assert(detail::IsValidScale, \ + "tl::wgmma_sp_ss: invalid scaleA"); \ + static_assert(detail::IsValidScale, \ + "tl::wgmma_sp_ss: 
invalid scaleB"); \ + static_assert(scaleA == 1 && scaleB == 1, \ + "tl::wgmma_sp_ss: only +1 scaling supported for this " \ + "sparse WGMMA"); \ + using Impl = cute::SM90::GMMA::SPARSE::ImplName; \ + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, \ + uint32_t *c, bool scale_out, uint32_t e) { \ + detail::CallWgmmaSpSS::exec(desc_a, desc_b, c, e, scale_out); \ + } \ + }; + +#define TL_WGMMA_SP_DEFINE_RS_GENERAL(AType, BType, CType, M, N, K, ImplName) \ + template \ + struct WgmmaSpRSImpl { \ + static_assert(!tnspA, "tl::wgmma_sp_rs: operand A must be K-major"); \ + static_assert(detail::IsValidScale, \ + "tl::wgmma_sp_rs: invalid scaleA"); \ + static_assert(detail::IsValidScale, \ + "tl::wgmma_sp_rs: invalid scaleB"); \ + using Impl = cute::SM90::GMMA::SPARSE::ImplName< \ + detail::MajorValue::value, detail::MajorValue::value, \ + detail::ScaleInValue::value, \ + detail::ScaleInValue::value, spsel>; \ + TL_DEVICE static void execute(const uint32_t *a, uint64_t desc_b, \ + uint32_t *c, bool scale_out, uint32_t e) { \ + detail::CallWgmmaSpRS::exec(a, desc_b, c, e, scale_out); \ + } \ + }; + +#define TL_WGMMA_SP_DEFINE_RS_TN(AType, BType, CType, M, N, K, ImplName) \ + template \ + struct WgmmaSpRSImpl { \ + static_assert(detail::IsValidScale, \ + "tl::wgmma_sp_rs: invalid scaleA"); \ + static_assert(detail::IsValidScale, \ + "tl::wgmma_sp_rs: invalid scaleB"); \ + using Impl = cute::SM90::GMMA::SPARSE::ImplName< \ + detail::ScaleInValue::value, \ + detail::ScaleInValue::value, spsel>; \ + TL_DEVICE static void execute(const uint32_t *a, uint64_t desc_b, \ + uint32_t *c, bool scale_out, uint32_t e) { \ + detail::CallWgmmaSpRS::exec(a, desc_b, c, e, scale_out); \ + } \ + }; + +#define TL_WGMMA_SP_DEFINE_RS_TN_FIXED_SCALE(AType, BType, CType, M, N, K, \ + ImplName) \ + template \ + struct WgmmaSpRSImpl { \ + static_assert(detail::IsValidScale, \ + "tl::wgmma_sp_rs: invalid scaleA"); \ + static_assert(detail::IsValidScale, \ + "tl::wgmma_sp_rs: invalid scaleB"); \ + static_assert(scaleA == 1 && scaleB == 1, \ + "tl::wgmma_sp_rs: only +1 scaling supported for this " \ + "sparse WGMMA"); \ + using Impl = cute::SM90::GMMA::SPARSE::ImplName; \ + TL_DEVICE static void execute(const uint32_t *a, uint64_t desc_b, \ + uint32_t *c, bool scale_out, uint32_t e) { \ + detail::CallWgmmaSpRS::exec(a, desc_b, c, e, scale_out); \ + } \ + }; + +#define TL_WGMMA_SP_FOREACH_N_FLOAT_MUL8(OP) \ + OP(8) \ + OP(16) \ + OP(24) \ + OP(32) \ + OP(40) \ + OP(48) \ + OP(56) \ + OP(64) \ + OP(72) \ + OP(80) \ + OP(88) \ + OP(96) \ + OP(104) \ + OP(112) \ + OP(120) \ + OP(128) \ + OP(136) \ + OP(144) \ + OP(152) \ + OP(160) \ + OP(168) \ + OP(176) \ + OP(184) \ + OP(192) \ + OP(200) \ + OP(208) \ + OP(216) \ + OP(224) \ + OP(232) \ + OP(240) \ + OP(248) \ + OP(256) + +#define TL_WGMMA_SP_FOREACH_N_INT32_MUL8(OP) \ + OP(8) \ + OP(16) \ + OP(24) \ + OP(32) \ + OP(48) \ + OP(64) \ + OP(80) \ + OP(96) \ + OP(112) \ + OP(128) \ + OP(144) \ + OP(160) \ + OP(176) \ + OP(192) \ + OP(208) \ + OP(224) \ + OP(240) \ + OP(256) + +#define TL_WGMMA_SP_DEFINE_F16_F16_F16_SS(N) \ + TL_WGMMA_SP_DEFINE_SS_GENERAL(kFloat16, kFloat16, kFloat16, 64, N, 32, \ + GMMA_64x##N##x32_F16F16F16_SS) +#define TL_WGMMA_SP_DEFINE_F16_F16_F32_SS(N) \ + TL_WGMMA_SP_DEFINE_SS_GENERAL(kFloat16, kFloat16, kFloat32, 64, N, 32, \ + GMMA_64x##N##x32_F32F16F16_SS) +#define TL_WGMMA_SP_DEFINE_BF16_BF16_F32_SS(N) \ + TL_WGMMA_SP_DEFINE_SS_GENERAL(kBFloat16, kBFloat16, kFloat32, 64, N, 32, \ + GMMA_64x##N##x32_F32BF16BF16_SS) + +#define 
TL_WGMMA_SP_DEFINE_F32_TF32_SS_TN(N) \ + TL_WGMMA_SP_DEFINE_SS_TN(kTensorFloat32, kTensorFloat32, kFloat32, 64, N, \ + 16, GMMA_64x##N##x16_F32TF32TF32_SS_TN) + +#define TL_WGMMA_SP_DEFINE_S32_S8S8_SS_TN(N) \ + TL_WGMMA_SP_DEFINE_SS_TN_FIXED_SCALE(kInt8, kInt8, kInt32, 64, N, 64, \ + GMMA_64x##N##x64_S32S8S8_SS_TN) +#define TL_WGMMA_SP_DEFINE_S32_S8U8_SS_TN(N) \ + TL_WGMMA_SP_DEFINE_SS_TN_FIXED_SCALE(kInt8, kUInt8, kInt32, 64, N, 64, \ + GMMA_64x##N##x64_S32S8U8_SS_TN) +#define TL_WGMMA_SP_DEFINE_S32_U8S8_SS_TN(N) \ + TL_WGMMA_SP_DEFINE_SS_TN_FIXED_SCALE(kUInt8, kInt8, kInt32, 64, N, 64, \ + GMMA_64x##N##x64_S32U8S8_SS_TN) +#define TL_WGMMA_SP_DEFINE_S32_U8U8_SS_TN(N) \ + TL_WGMMA_SP_DEFINE_SS_TN_FIXED_SCALE(kUInt8, kUInt8, kInt32, 64, N, 64, \ + GMMA_64x##N##x64_S32U8U8_SS_TN) + +#define TL_WGMMA_SP_DEFINE_F16_E4M3E4M3_SS_TN(N) \ + TL_WGMMA_SP_DEFINE_SS_TN(kFloat8_e4m3, kFloat8_e4m3, kFloat16, 64, N, 64, \ + GMMA_64x##N##x64_F16E4M3E4M3_SS_TN) +#define TL_WGMMA_SP_DEFINE_F32_E4M3E4M3_SS_TN(N) \ + TL_WGMMA_SP_DEFINE_SS_TN(kFloat8_e4m3, kFloat8_e4m3, kFloat32, 64, N, 64, \ + GMMA_64x##N##x64_F32E4M3E4M3_SS_TN) +#define TL_WGMMA_SP_DEFINE_F16_E4M3E5M2_SS_TN(N) \ + TL_WGMMA_SP_DEFINE_SS_TN(kFloat8_e4m3, kFloat8_e5m2, kFloat16, 64, N, 64, \ + GMMA_64x##N##x64_F16E4M3E5M2_SS_TN) +#define TL_WGMMA_SP_DEFINE_F32_E4M3E5M2_SS_TN(N) \ + TL_WGMMA_SP_DEFINE_SS_TN(kFloat8_e4m3, kFloat8_e5m2, kFloat32, 64, N, 64, \ + GMMA_64x##N##x64_F32E4M3E5M2_SS_TN) +#define TL_WGMMA_SP_DEFINE_F16_E5M2E4M3_SS_TN(N) \ + TL_WGMMA_SP_DEFINE_SS_TN(kFloat8_e5m2, kFloat8_e4m3, kFloat16, 64, N, 64, \ + GMMA_64x##N##x64_F16E5M2E4M3_SS_TN) +#define TL_WGMMA_SP_DEFINE_F32_E5M2E4M3_SS_TN(N) \ + TL_WGMMA_SP_DEFINE_SS_TN(kFloat8_e5m2, kFloat8_e4m3, kFloat32, 64, N, 64, \ + GMMA_64x##N##x64_F32E5M2E4M3_SS_TN) +#define TL_WGMMA_SP_DEFINE_F16_E5M2E5M2_SS_TN(N) \ + TL_WGMMA_SP_DEFINE_SS_TN(kFloat8_e5m2, kFloat8_e5m2, kFloat16, 64, N, 64, \ + GMMA_64x##N##x64_F16E5M2E5M2_SS_TN) +#define TL_WGMMA_SP_DEFINE_F32_E5M2E5M2_SS_TN(N) \ + TL_WGMMA_SP_DEFINE_SS_TN(kFloat8_e5m2, kFloat8_e5m2, kFloat32, 64, N, 64, \ + GMMA_64x##N##x64_F32E5M2E5M2_SS_TN) + +TL_WGMMA_SP_FOREACH_N_FLOAT_MUL8(TL_WGMMA_SP_DEFINE_F16_F16_F16_SS); +TL_WGMMA_SP_FOREACH_N_FLOAT_MUL8(TL_WGMMA_SP_DEFINE_F16_F16_F32_SS); +TL_WGMMA_SP_FOREACH_N_FLOAT_MUL8(TL_WGMMA_SP_DEFINE_BF16_BF16_F32_SS); +TL_WGMMA_SP_FOREACH_N_FLOAT_MUL8(TL_WGMMA_SP_DEFINE_F32_TF32_SS_TN); + +TL_WGMMA_SP_FOREACH_N_INT32_MUL8(TL_WGMMA_SP_DEFINE_S32_S8S8_SS_TN); +TL_WGMMA_SP_FOREACH_N_INT32_MUL8(TL_WGMMA_SP_DEFINE_S32_S8U8_SS_TN); +TL_WGMMA_SP_FOREACH_N_INT32_MUL8(TL_WGMMA_SP_DEFINE_S32_U8S8_SS_TN); +TL_WGMMA_SP_FOREACH_N_INT32_MUL8(TL_WGMMA_SP_DEFINE_S32_U8U8_SS_TN); + +TL_WGMMA_SP_FOREACH_N_FLOAT_MUL8(TL_WGMMA_SP_DEFINE_F16_E4M3E4M3_SS_TN); +TL_WGMMA_SP_FOREACH_N_FLOAT_MUL8(TL_WGMMA_SP_DEFINE_F32_E4M3E4M3_SS_TN); +TL_WGMMA_SP_FOREACH_N_FLOAT_MUL8(TL_WGMMA_SP_DEFINE_F16_E4M3E5M2_SS_TN); +TL_WGMMA_SP_FOREACH_N_FLOAT_MUL8(TL_WGMMA_SP_DEFINE_F32_E4M3E5M2_SS_TN); +TL_WGMMA_SP_FOREACH_N_FLOAT_MUL8(TL_WGMMA_SP_DEFINE_F16_E5M2E4M3_SS_TN); +TL_WGMMA_SP_FOREACH_N_FLOAT_MUL8(TL_WGMMA_SP_DEFINE_F32_E5M2E4M3_SS_TN); +TL_WGMMA_SP_FOREACH_N_FLOAT_MUL8(TL_WGMMA_SP_DEFINE_F16_E5M2E5M2_SS_TN); +TL_WGMMA_SP_FOREACH_N_FLOAT_MUL8(TL_WGMMA_SP_DEFINE_F32_E5M2E5M2_SS_TN); + +#define TL_WGMMA_SP_DEFINE_F16_F16_F16_RS(N) \ + TL_WGMMA_SP_DEFINE_RS_GENERAL(kFloat16, kFloat16, kFloat16, 64, N, 32, \ + GMMA_64x##N##x32_F16F16F16_RS) +#define TL_WGMMA_SP_DEFINE_F16_F16_F32_RS(N) \ + TL_WGMMA_SP_DEFINE_RS_GENERAL(kFloat16, kFloat16, 
kFloat32, 64, N, 32, \ + GMMA_64x##N##x32_F32F16F16_RS) +#define TL_WGMMA_SP_DEFINE_BF16_BF16_F32_RS(N) \ + TL_WGMMA_SP_DEFINE_RS_GENERAL(kBFloat16, kBFloat16, kFloat32, 64, N, 32, \ + GMMA_64x##N##x32_F32BF16BF16_RS) + +#define TL_WGMMA_SP_DEFINE_F32_TF32_RS_TN(N) \ + TL_WGMMA_SP_DEFINE_RS_TN(kTensorFloat32, kTensorFloat32, kFloat32, 64, N, \ + 16, GMMA_64x##N##x16_F32TF32TF32_RS_TN) + +#define TL_WGMMA_SP_DEFINE_S32_S8S8_RS_TN(N) \ + TL_WGMMA_SP_DEFINE_RS_TN_FIXED_SCALE(kInt8, kInt8, kInt32, 64, N, 64, \ + GMMA_64x##N##x64_S32S8S8_RS_TN) +#define TL_WGMMA_SP_DEFINE_S32_S8U8_RS_TN(N) \ + TL_WGMMA_SP_DEFINE_RS_TN_FIXED_SCALE(kInt8, kUInt8, kInt32, 64, N, 64, \ + GMMA_64x##N##x64_S32S8U8_RS_TN) +#define TL_WGMMA_SP_DEFINE_S32_U8S8_RS_TN(N) \ + TL_WGMMA_SP_DEFINE_RS_TN_FIXED_SCALE(kUInt8, kInt8, kInt32, 64, N, 64, \ + GMMA_64x##N##x64_S32U8S8_RS_TN) +#define TL_WGMMA_SP_DEFINE_S32_U8U8_RS_TN(N) \ + TL_WGMMA_SP_DEFINE_RS_TN_FIXED_SCALE(kUInt8, kUInt8, kInt32, 64, N, 64, \ + GMMA_64x##N##x64_S32U8U8_RS_TN) + +#define TL_WGMMA_SP_DEFINE_F16_E4M3E4M3_RS_TN(N) \ + TL_WGMMA_SP_DEFINE_RS_TN(kFloat8_e4m3, kFloat8_e4m3, kFloat16, 64, N, 64, \ + GMMA_64x##N##x64_F16E4M3E4M3_RS_TN) +#define TL_WGMMA_SP_DEFINE_F32_E4M3E4M3_RS_TN(N) \ + TL_WGMMA_SP_DEFINE_RS_TN(kFloat8_e4m3, kFloat8_e4m3, kFloat32, 64, N, 64, \ + GMMA_64x##N##x64_F32E4M3E4M3_RS_TN) +#define TL_WGMMA_SP_DEFINE_F16_E4M3E5M2_RS_TN(N) \ + TL_WGMMA_SP_DEFINE_RS_TN(kFloat8_e4m3, kFloat8_e5m2, kFloat16, 64, N, 64, \ + GMMA_64x##N##x64_F16E4M3E5M2_RS_TN) +#define TL_WGMMA_SP_DEFINE_F32_E4M3E5M2_RS_TN(N) \ + TL_WGMMA_SP_DEFINE_RS_TN(kFloat8_e4m3, kFloat8_e5m2, kFloat32, 64, N, 64, \ + GMMA_64x##N##x64_F32E4M3E5M2_RS_TN) +#define TL_WGMMA_SP_DEFINE_F16_E5M2E4M3_RS_TN(N) \ + TL_WGMMA_SP_DEFINE_RS_TN(kFloat8_e5m2, kFloat8_e4m3, kFloat16, 64, N, 64, \ + GMMA_64x##N##x64_F16E5M2E4M3_RS_TN) +#define TL_WGMMA_SP_DEFINE_F32_E5M2E4M3_RS_TN(N) \ + TL_WGMMA_SP_DEFINE_RS_TN(kFloat8_e5m2, kFloat8_e4m3, kFloat32, 64, N, 64, \ + GMMA_64x##N##x64_F32E5M2E4M3_RS_TN) +#define TL_WGMMA_SP_DEFINE_F16_E5M2E5M2_RS_TN(N) \ + TL_WGMMA_SP_DEFINE_RS_TN(kFloat8_e5m2, kFloat8_e5m2, kFloat16, 64, N, 64, \ + GMMA_64x##N##x64_F16E5M2E5M2_RS_TN) +#define TL_WGMMA_SP_DEFINE_F32_E5M2E5M2_RS_TN(N) \ + TL_WGMMA_SP_DEFINE_RS_TN(kFloat8_e5m2, kFloat8_e5m2, kFloat32, 64, N, 64, \ + GMMA_64x##N##x64_F32E5M2E5M2_RS_TN) + +TL_WGMMA_SP_FOREACH_N_FLOAT_MUL8(TL_WGMMA_SP_DEFINE_F16_F16_F16_RS); +TL_WGMMA_SP_FOREACH_N_FLOAT_MUL8(TL_WGMMA_SP_DEFINE_F16_F16_F32_RS); +TL_WGMMA_SP_FOREACH_N_FLOAT_MUL8(TL_WGMMA_SP_DEFINE_BF16_BF16_F32_RS); +TL_WGMMA_SP_FOREACH_N_FLOAT_MUL8(TL_WGMMA_SP_DEFINE_F32_TF32_RS_TN); + +TL_WGMMA_SP_FOREACH_N_INT32_MUL8(TL_WGMMA_SP_DEFINE_S32_S8S8_RS_TN); +TL_WGMMA_SP_FOREACH_N_INT32_MUL8(TL_WGMMA_SP_DEFINE_S32_S8U8_RS_TN); +TL_WGMMA_SP_FOREACH_N_INT32_MUL8(TL_WGMMA_SP_DEFINE_S32_U8S8_RS_TN); +TL_WGMMA_SP_FOREACH_N_INT32_MUL8(TL_WGMMA_SP_DEFINE_S32_U8U8_RS_TN); + +TL_WGMMA_SP_FOREACH_N_FLOAT_MUL8(TL_WGMMA_SP_DEFINE_F16_E4M3E4M3_RS_TN); +TL_WGMMA_SP_FOREACH_N_FLOAT_MUL8(TL_WGMMA_SP_DEFINE_F32_E4M3E4M3_RS_TN); +TL_WGMMA_SP_FOREACH_N_FLOAT_MUL8(TL_WGMMA_SP_DEFINE_F16_E4M3E5M2_RS_TN); +TL_WGMMA_SP_FOREACH_N_FLOAT_MUL8(TL_WGMMA_SP_DEFINE_F32_E4M3E5M2_RS_TN); +TL_WGMMA_SP_FOREACH_N_FLOAT_MUL8(TL_WGMMA_SP_DEFINE_F16_E5M2E4M3_RS_TN); +TL_WGMMA_SP_FOREACH_N_FLOAT_MUL8(TL_WGMMA_SP_DEFINE_F32_E5M2E4M3_RS_TN); +TL_WGMMA_SP_FOREACH_N_FLOAT_MUL8(TL_WGMMA_SP_DEFINE_F16_E5M2E5M2_RS_TN); +TL_WGMMA_SP_FOREACH_N_FLOAT_MUL8(TL_WGMMA_SP_DEFINE_F32_E5M2E5M2_RS_TN); + +#undef 
TL_WGMMA_SP_DEFINE_F16_F16_F16_SS +#undef TL_WGMMA_SP_DEFINE_F16_F16_F32_SS +#undef TL_WGMMA_SP_DEFINE_BF16_BF16_F32_SS +#undef TL_WGMMA_SP_DEFINE_F32_TF32_SS_TN +#undef TL_WGMMA_SP_DEFINE_S32_S8S8_SS_TN +#undef TL_WGMMA_SP_DEFINE_S32_S8U8_SS_TN +#undef TL_WGMMA_SP_DEFINE_S32_U8S8_SS_TN +#undef TL_WGMMA_SP_DEFINE_S32_U8U8_SS_TN +#undef TL_WGMMA_SP_DEFINE_F16_E4M3E4M3_SS_TN +#undef TL_WGMMA_SP_DEFINE_F32_E4M3E4M3_SS_TN +#undef TL_WGMMA_SP_DEFINE_F16_E4M3E5M2_SS_TN +#undef TL_WGMMA_SP_DEFINE_F32_E4M3E5M2_SS_TN +#undef TL_WGMMA_SP_DEFINE_F16_E5M2E4M3_SS_TN +#undef TL_WGMMA_SP_DEFINE_F32_E5M2E4M3_SS_TN +#undef TL_WGMMA_SP_DEFINE_F16_E5M2E5M2_SS_TN +#undef TL_WGMMA_SP_DEFINE_F32_E5M2E5M2_SS_TN +#undef TL_WGMMA_SP_DEFINE_F16_F16_F16_RS +#undef TL_WGMMA_SP_DEFINE_F16_F16_F32_RS +#undef TL_WGMMA_SP_DEFINE_BF16_BF16_F32_RS +#undef TL_WGMMA_SP_DEFINE_F32_TF32_RS_TN +#undef TL_WGMMA_SP_DEFINE_S32_S8S8_RS_TN +#undef TL_WGMMA_SP_DEFINE_S32_S8U8_RS_TN +#undef TL_WGMMA_SP_DEFINE_S32_U8S8_RS_TN +#undef TL_WGMMA_SP_DEFINE_S32_U8U8_RS_TN +#undef TL_WGMMA_SP_DEFINE_F16_E4M3E4M3_RS_TN +#undef TL_WGMMA_SP_DEFINE_F32_E4M3E4M3_RS_TN +#undef TL_WGMMA_SP_DEFINE_F16_E4M3E5M2_RS_TN +#undef TL_WGMMA_SP_DEFINE_F32_E4M3E5M2_RS_TN +#undef TL_WGMMA_SP_DEFINE_F16_E5M2E4M3_RS_TN +#undef TL_WGMMA_SP_DEFINE_F32_E5M2E4M3_RS_TN +#undef TL_WGMMA_SP_DEFINE_F16_E5M2E5M2_RS_TN +#undef TL_WGMMA_SP_DEFINE_F32_E5M2E5M2_RS_TN +#undef TL_WGMMA_SP_FOREACH_N_FLOAT_MUL8 +#undef TL_WGMMA_SP_FOREACH_N_INT32_MUL8 +#undef TL_WGMMA_SP_DEFINE_SS_TN_FIXED_SCALE +#undef TL_WGMMA_SP_DEFINE_SS_GENERAL +#undef TL_WGMMA_SP_DEFINE_SS_TN +#undef TL_WGMMA_SP_DEFINE_RS_TN_FIXED_SCALE +#undef TL_WGMMA_SP_DEFINE_RS_GENERAL +#undef TL_WGMMA_SP_DEFINE_RS_TN + +template +TL_DEVICE void wgmma_sp_ss(uint64_t desc_a, uint64_t desc_b, uint32_t *c, + bool scale_out, uint32_t e) { + WgmmaSpSSImpl::execute(desc_a, desc_b, c, scale_out, e); +} + +template +TL_DEVICE void wgmma_sp_rs(const uint32_t *a, uint64_t desc_b, uint32_t *c, + bool scale_out, uint32_t e) { + WgmmaSpRSImpl::execute(a, desc_b, c, scale_out, e); +} + +} // namespace tl diff --git a/src/transform/lower_opaque_block.cc b/src/transform/lower_opaque_block.cc index a56e370e4a..dc75da4995 100644 --- a/src/transform/lower_opaque_block.cc +++ b/src/transform/lower_opaque_block.cc @@ -107,7 +107,7 @@ class OpaqueBlockLower : public StmtExprMutator { } // Step 5. Materialize a lexical scope boundary only for blocks that were // explicitly marked by an earlier semantic lowering pass (for example - // gemm/gemm_sp_py). We intentionally avoid re-inferring this from the + // gemm/gemm_sp). We intentionally avoid re-inferring this from the // lowered alloc_buffers here because provenance has already been blurred by // earlier allocation planning/hoisting passes. 
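+// For instance, a shared buffer that such a block originally allocated may
+// have been hoisted to the kernel root by the time this pass runs, so the
+// block's alloc_buffers would no longer identify it; the explicit annotation
+// checked below stays reliable regardless (illustrative scenario only).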
if (HasLexicalAllocScopeAnnotation(new_block->annotations)) { diff --git a/testing/python/issue/test_tilelang_issue_tma_no_ws.py b/testing/python/issue/test_tilelang_issue_tma_no_ws.py index a0ab91da27..22d966c02b 100644 --- a/testing/python/issue/test_tilelang_issue_tma_no_ws.py +++ b/testing/python/issue/test_tilelang_issue_tma_no_ws.py @@ -3,9 +3,10 @@ import tilelang import tilelang.testing from tilelang import language as T -from tilelang.layout import make_cutlass_metadata_layout import torch +from tilelang.utils.sparse import get_e_factor + def _compile_tvm_ffi(func, pass_configs, **kwargs): tilelang.disable_cache() @@ -319,32 +320,27 @@ def test_sparse_ws_regular_metadata_copy_stays_in_producer(): num_stages = 2 threads = 128 + e_factor = get_e_factor(T.float16, T.uint8) + @T.prim_func def sparse_tensorcore_metadata_copy( A_sparse: T.Tensor((M, K // 2), T.float16), - E: T.Tensor((M, K // 8), "uint8"), + E: T.Tensor((M, K // e_factor), T.uint8), B: T.Tensor((K, N), T.float16), C: T.Tensor((M, N), T.float16), ): with T.Kernel(T.ceildiv(N, block_n), T.ceildiv(M, block_m), threads=threads) as (bx, by): A_shared = T.alloc_shared((block_m, block_k // 2), T.float16) B_shared = T.alloc_shared((block_k, block_n), T.float16) - E_shared = T.alloc_shared((block_m, block_k // 8), "uint8") + E_shared = T.alloc_shared((block_m, block_k // e_factor), T.uint8) C_local = T.alloc_fragment((block_m, block_n), T.float32) - T.annotate_layout( - { - E: make_cutlass_metadata_layout(E, mma_dtype=T.float16, arch="9.0", block_k=block_k), - E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=T.float16, arch="9.0", block_k=block_k), - } - ) - T.clear(C_local) for k in T.Pipelined(T.ceildiv(K, block_k), num_stages=num_stages): - T.copy(E[by * block_m, k * block_k // 8], E_shared) + T.copy(E[by * block_m, k * block_k // e_factor], E_shared) T.copy(A_sparse[by * block_m, k * block_k // 2], A_shared) T.copy(B[k * block_k, bx * block_n], B_shared) - T.gemm_sp(A_shared, E_shared, B_shared, C_local, False, False) + T.gemm_sp(A_shared, E_shared, B_shared, C_local, transpose_A=False, transpose_E=False, transpose_B=False) T.copy(C_local, C[by * block_m, bx * block_n]) @@ -354,12 +350,10 @@ def sparse_tensorcore_metadata_copy( src = kernel.get_kernel_source() producer_idx = src.index("if (128 <= ((int)threadIdx.x)) {") consumer_idx = src.index("} else {", producer_idx) - metadata_copy_idx = src.index("*(uchar2*)(E +") - gemm_idx = src.index("tl::gemm_sp_ss<") + metadata_copy_idx = src.index("tl::tma_load(E_desc") assert producer_idx < metadata_copy_idx < consumer_idx - assert consumer_idx < gemm_idx - assert "*(uchar2*)(E +" not in src[consumer_idx:] + assert "tl::tma_load(E_desc" not in src[consumer_idx:] @tilelang.testing.requires_cuda_compute_version(9, 0) diff --git a/testing/python/issue/test_tilelang_issue_ws_simt_copy_full_producer_extent.py b/testing/python/issue/test_tilelang_issue_ws_simt_copy_full_producer_extent.py index 34b1b1dba1..a132e14cab 100644 --- a/testing/python/issue/test_tilelang_issue_ws_simt_copy_full_producer_extent.py +++ b/testing/python/issue/test_tilelang_issue_ws_simt_copy_full_producer_extent.py @@ -3,7 +3,6 @@ import tilelang import tilelang.testing from tilelang import language as T -from tilelang.layout import make_cutlass_metadata_layout def _compile_tvm_ffi(func, pass_configs=None): @@ -31,33 +30,21 @@ def test_ws_keeps_full_producer_extent_for_lowered_simt_copy(): @T.prim_func def main( A_sparse: T.Tensor((M, K // 2), T.float16), - E: T.Tensor((M, K // e_factor), "uint8"), + 
E: T.Tensor((M, K // e_factor), T.int8), B: T.Tensor((K, N), T.float16), C: T.Tensor((M, N), T.float32), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared((block_M, block_K // 2), T.float16) B_shared = T.alloc_shared((block_K, block_N), T.float16) - E_shared = T.alloc_shared((block_M, block_K // e_factor), "uint8") + E_shared = T.alloc_shared((block_M, block_K // e_factor), T.int8) C_frag = T.alloc_fragment((block_M, block_N), T.float32) - T.annotate_layout( - { - E: make_cutlass_metadata_layout(E, mma_dtype=T.float16, arch="9.0", block_k=block_K), - E_shared: make_cutlass_metadata_layout( - E_shared, - mma_dtype=T.float16, - arch="9.0", - block_k=block_K, - ), - } - ) - T.disable_warp_group_reg_alloc() T.clear(C_frag) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): T.copy(E[by * block_M, k * block_K // e_factor], E_shared) T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared) T.copy(B[k * block_K, bx * block_N], B_shared) - T.gemm_sp(A_shared, E_shared, B_shared, C_frag, False, False) + T.gemm_sp(A_shared, E_shared, B_shared, C_frag, transpose_A=False, transpose_E=False, transpose_B=False) T.copy(C_frag, C[by * block_M, bx * block_N]) kernel = _compile_tvm_ffi(main) diff --git a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py index de7808d9f0..b2682e897f 100644 --- a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py +++ b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py @@ -1,188 +1,299 @@ import pytest -import torch -import tilelang -import tilelang.testing -import tilelang.language as T - -from tilelang.utils.sparse import compress, randn_semi_sparse, randint_semi_sparse -from tilelang.layout import make_cutlass_metadata_layout +from tilelang.utils.sparse import compress, randn_semi_sparse, randint_semi_sparse, get_e_factor from tilelang.utils.tensor import torch_assert_close -from tilelang.cuda.intrinsics.macro.mma_sp_macro_generator import SparseTensorCoreIntrinEmitter - -torch.backends.cuda.matmul.allow_tf32 = False - -def generate_dense_input(M, N, K, trans_A, trans_B, in_dtype): - is_8bit = "8" in in_dtype - is_unsigned = "uint" in in_dtype - is_int = "int" in in_dtype - if is_int: - if is_8bit: - low, high = (0, 4) if is_unsigned else (-2, 2) - else: - low, high = (0, 128) if is_unsigned else (-64, 64) - A = randint_semi_sparse(M, K, low=low, high=high, dtype=T.dtype(in_dtype).as_torch(), device="cuda", transposed=trans_A) - B = torch.randint(size=(N, K) if trans_B else (K, N), low=low, high=high, dtype=T.dtype(in_dtype).as_torch(), device="cuda") - else: - A = randn_semi_sparse(M, K, dtype=torch.float32, device="cuda", transposed=trans_A).to(T.dtype(in_dtype).as_torch()) - B = torch.randn((N, K) if trans_B else (K, N), device="cuda", dtype=torch.float32).to(T.dtype(in_dtype).as_torch()) - return A, B +import tilelang.testing +import torch +import tilelang.language as T -def matmul_sp_sm90( +def matmul( M, N, K, block_M, block_N, block_K, + trans_A, + trans_B, in_dtype, out_dtype, accum_dtype, + metadata_dtype, + E_factor, num_stages, threads, - trans_A, - trans_B, ): - E_factor = 4 if in_dtype == T.float32 else 8 A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M) - B_shape = (K, N) if not trans_B else (N, K) + B_shape = (N, K) if trans_B else (K, N) A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M) - B_shared_shape = (block_K, 
block_N) if not trans_B else (block_N, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + E_shape = (M, K // E_factor) if not trans_A else (K // E_factor, M) + E_shared_shape = (block_M, block_K // E_factor) if not trans_A else (block_K // E_factor, block_M) @T.prim_func def main( A_sparse: T.Tensor(A_sparse_shape, in_dtype), - E: T.Tensor((M, K // E_factor), "uint8"), + E: T.Tensor(E_shape, metadata_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, in_dtype) - E_shared = T.alloc_shared((block_M, block_K // E_factor), "uint8") + E_shared = T.alloc_shared(E_shared_shape, metadata_dtype) C_frag = T.alloc_fragment((block_M, block_N), accum_dtype) - T.annotate_layout( - { - E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="9.0", block_k=block_K), - E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="9.0", block_k=block_K), - } - ) - T.disable_warp_group_reg_alloc() T.clear(C_frag) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - T.copy(E[by * block_M, k * block_K // E_factor], E_shared) if trans_A: + T.copy(E[k * block_K // E_factor, by * block_M], E_shared) T.copy(A_sparse[k * block_K // 2, by * block_M], A_shared) else: + T.copy(E[by * block_M, k * block_K // E_factor], E_shared) T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared) if trans_B: T.copy(B[bx * block_N, k * block_K], B_shared) else: T.copy(B[k * block_K, bx * block_N], B_shared) - T.gemm_sp(A_shared, E_shared, B_shared, C_frag, trans_A, trans_B) + T.gemm_sp(A_shared, E_shared, B_shared, C_frag, trans_A, trans_A, trans_B) T.copy(C_frag, C[by * block_M, bx * block_N]) return main -def matmul_sp_sm80( +def run_gemm_ss( M, N, K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, block_M, block_N, block_K, + num_stages, + num_threads, + meta_dtype, +): + metadata_dtype = meta_dtype + program = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + metadata_dtype, + get_e_factor(in_dtype, metadata_dtype), + num_stages, + num_threads, + ) + + kernel = tilelang.compile( + program, + out_idx=[3], + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, + ) + A, B = generate_dense_input(M, N, K, trans_A, trans_B, in_dtype) + + A_sparse, E = compress(A.t().contiguous() if trans_A else A, meta_dtype=meta_dtype.as_torch()) + if trans_A: + A_sparse = A_sparse.t().contiguous() + E = E.t().contiguous() + C_sp = kernel(A_sparse, E, B) + + def _matmul(A, B): + if trans_A: + A = A.T + if trans_B: + B = B.T + A = A.to(torch.float32) + B = B.to(torch.float32) + return torch.matmul(A, B) + + C = _matmul(A, B) + + torch_assert_close( + C_sp.to(out_dtype.as_torch()).to(torch.float32), + C.to(out_dtype.as_torch()).to(torch.float32), + rtol=1e-3, + atol=1e-3, + base_name="tilelang_sp", + ref_name="ref_dense", + ) + + +def generate_dense_input(M, N, K, trans_A, trans_B, in_dtype, seed=0): + torch.manual_seed(seed) + is_8bit = "8" in in_dtype + is_unsigned = "uint" in in_dtype + is_int = "int" in in_dtype + if is_int: + if is_8bit: + low, high = (0, 4) if is_unsigned else (-2, 2) + else: + low, high = (0, 128) if is_unsigned else (-64, 64) + A = randint_semi_sparse(M, K, low=low, high=high, dtype=in_dtype.as_torch(), device="cuda", transposed=trans_A) + B = 
torch.randint(size=(N, K) if trans_B else (K, N), low=low, high=high, dtype=in_dtype.as_torch(), device="cuda") + else: + A = randn_semi_sparse(M, K, dtype=in_dtype.as_torch(), device="cuda", transposed=trans_A) + B = torch.randn((N, K) if trans_B else (K, N), device="cuda", dtype=torch.float32).to(in_dtype.as_torch()) + + return A, B + + +@tilelang.testing.requires_cuda +@pytest.mark.parametrize( + "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads, meta_dtype", + [ + (128, 128, 32, False, True, T.float16, T.float16, T.float, 128, 128, 32, 2, 128, T.int16), + (128, 128, 64, False, True, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128, T.int32), + (128, 128, 32, False, False, T.float16, T.float16, T.float, 128, 128, 32, 2, 128, T.int16), + (64, 128, 32, True, False, T.float16, T.float16, T.float, 64, 128, 32, 2, 128, T.int16), + (64, 128, 32, True, True, T.float16, T.float16, T.float, 64, 128, 32, 2, 128, T.int16), + (128, 8, 64, False, True, T.float16, T.float16, T.float, 128, 8, 32, 0, 128, T.int16), + (128, 128, 32, False, True, T.bfloat16, T.bfloat16, T.float32, 128, 128, 32, 2, 128, T.int16), + (64, 128, 128, True, True, T.int8, T.int8, T.int32, 64, 128, 128, 2, 128, T.int32), + (128, 128, 64, False, True, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128, T.int16), + (128, 128, 64, False, True, T.float8_e5m2, T.float8_e5m2, T.float32, 128, 128, 64, 2, 128, T.int32), + ], +) +def test_gemm_ss( + M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads, meta_dtype +): + run_gemm_ss( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages, + num_threads, + meta_dtype=meta_dtype, + ) + + +def matmul_rs( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, in_dtype, out_dtype, accum_dtype, + metadata_dtype, + E_factor, num_stages, threads, - trans_A, - trans_B, ): - is_8_bit = "8" in in_dtype - metadata_dtype = T.int32 if is_8_bit else T.int16 - E_factor = SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype] A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M) - B_shape = (K, N) if not trans_B else (N, K) + B_shape = (N, K) if trans_B else (K, N) A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M) - B_shared_shape = (block_K, block_N) if not trans_B else (block_N, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + A_frag_shape = A_shared_shape + E_shape = (M, K // E_factor) if not trans_A else (K // E_factor, M) + E_shared_shape = (block_M, block_K // E_factor) if not trans_A else (block_K // E_factor, block_M) + + import tilelang.language as T @T.prim_func def main( A_sparse: T.Tensor(A_sparse_shape, in_dtype), - E: T.Tensor((M, K // E_factor), metadata_dtype), + E: T.Tensor(E_shape, metadata_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, in_dtype) - E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype) + E_shared = T.alloc_shared(E_shared_shape, metadata_dtype) + A_frag = T.alloc_fragment(A_frag_shape, in_dtype) C_frag = T.alloc_fragment((block_M, block_N), accum_dtype) T.annotate_layout( { - E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"), - E_shared: 
make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"), + A_shared: tilelang.layout.make_swizzled_layout(A_shared), } ) T.clear(C_frag) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - T.copy(E[by * block_M, k * block_K // E_factor], E_shared) if trans_A: + T.copy(E[k * block_K // E_factor, by * block_M], E_shared) T.copy(A_sparse[k * block_K // 2, by * block_M], A_shared) else: + T.copy(E[by * block_M, k * block_K // E_factor], E_shared) T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared) if trans_B: T.copy(B[bx * block_N, k * block_K], B_shared) else: T.copy(B[k * block_K, bx * block_N], B_shared) - T.gemm_sp(A_shared, E_shared, B_shared, C_frag, trans_A, trans_B) + T.copy(A_shared, A_frag) + T.gemm_sp(A_frag, E_shared, B_shared, C_frag, trans_A, trans_A, trans_B) T.copy(C_frag, C[by * block_M, bx * block_N]) return main -def normalize(tensor, max_range=100.0): - assert max_range <= 448.0 - max_v = tensor.abs().max().clamp(1e-4) - scaler = max_range / max_v - return tensor * scaler - - -def calc_diff(x, y): - x, y = x.double(), y.double() - denominator = (x * x + y * y).sum() - sim = 2 * (x * y).sum() / denominator - return 1 - sim - - -def run_gemm_sp( - kernel, +def run_gemm_rs( M, N, K, + trans_A, + trans_B, in_dtype, out_dtype, + dtypeAccum, + block_M, + block_N, block_K, - trans_A, - trans_B, + num_stages, + num_threads, + meta_dtype, ): - kernel = tilelang.compile( - kernel, - out_idx=[-1], + metadata_dtype = meta_dtype + program = matmul_rs( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + metadata_dtype, + get_e_factor(in_dtype, metadata_dtype), + num_stages, + num_threads, ) - A, B = generate_dense_input( - M=M, - N=N, - K=K, - trans_A=trans_A, - trans_B=trans_B, - in_dtype=in_dtype, + kernel = tilelang.compile( + program, + out_idx=[3], + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) - A_sparse, E = compress(A, transposed=trans_A, block_k=block_K) - + A, B = generate_dense_input(M, N, K, trans_A, trans_B, in_dtype) + A_sparse, E = compress(A.t().contiguous() if trans_A else A, meta_dtype=meta_dtype.as_torch()) + if trans_A: + A_sparse = A_sparse.t().contiguous() + E = E.t().contiguous() C_sp = kernel(A_sparse, E, B) def _matmul(A, B): @@ -190,149 +301,460 @@ def _matmul(A, B): A = A.T if trans_B: B = B.T - if "float8" in in_dtype or "int8" in in_dtype: - A = A.to(torch.float32) - B = B.to(torch.float32) + A = A.to(torch.float32) + B = B.to(torch.float32) return torch.matmul(A, B) C = _matmul(A, B) - if "float8" in in_dtype: - diff = calc_diff(C_sp, C) - assert diff < 1e-3, f"{diff=}" - else: - torch_assert_close( - C_sp.to(torch.float32), - C.to(torch.float32), - rtol=1e-3, - atol=1e-3, - base_name="tilelang_sp", - ref_name="ref_dense", - ) - print("pass") + torch_assert_close( + C_sp.to(out_dtype.as_torch()).to(torch.float32), + C.to(out_dtype.as_torch()).to(torch.float32), + rtol=1e-3, + atol=1e-3, + base_name="tilelang_sp", + ref_name="ref_dense", + ) @tilelang.testing.requires_cuda -@tilelang.testing.requires_cuda_compute_version(9, 0) -def run_gemm_sp_sm90( +@pytest.mark.parametrize( + "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads, meta_dtype", + [ + (128, 256, 32, False, True, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128, T.int16), + (128, 128, 64, False, True, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128, T.int32), + (128, 256, 32, False, False, T.float16, 
T.float16, T.float32, 128, 256, 32, 2, 128, T.int16), + (64, 256, 32, True, False, T.float16, T.float16, T.float32, 64, 256, 32, 2, 128, T.int16), + (64, 256, 32, True, True, T.float16, T.float16, T.float32, 64, 256, 32, 2, 128, T.int16), + (128, 8, 64, False, True, T.float16, T.float16, T.float32, 128, 8, 32, 0, 128, T.int16), + (128, 256, 32, False, True, T.bfloat16, T.bfloat16, T.float32, 128, 256, 32, 2, 128, T.int16), + (64, 128, 128, True, True, T.int8, T.int8, T.int32, 64, 128, 128, 2, 128, T.int32), + (128, 128, 64, False, True, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128, T.int16), + (128, 128, 64, False, True, T.float8_e5m2, T.float8_e5m2, T.float32, 128, 128, 64, 2, 128, T.int32), + ], +) +def test_gemm_rs( + M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads, meta_dtype +): + run_gemm_rs( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages, + num_threads, + meta_dtype=meta_dtype, + ) + + +def matmul_sr( M, N, K, + block_M, + block_N, + block_K, + trans_A, + trans_B, in_dtype, out_dtype, accum_dtype, + metadata_dtype, + E_factor, + num_stages, + threads, +): + A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + B_frag_shape = B_shared_shape + E_shape = (M, K // E_factor) if not trans_A else (K // E_factor, M) + E_shared_shape = (block_M, block_K // E_factor) if not trans_A else (block_K // E_factor, block_M) + + import tilelang.language as T + + @T.prim_func + def main( + A_sparse: T.Tensor(A_sparse_shape, in_dtype), + E: T.Tensor(E_shape, metadata_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + E_shared = T.alloc_shared(E_shared_shape, metadata_dtype) + B_frag = T.alloc_fragment(B_frag_shape, in_dtype) + C_frag = T.alloc_fragment((block_M, block_N), accum_dtype) + T.annotate_layout( + { + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + } + ) + T.clear(C_frag) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(E[k * block_K // E_factor, by * block_M], E_shared) + T.copy(A_sparse[k * block_K // 2, by * block_M], A_shared) + else: + T.copy(E[by * block_M, k * block_K // E_factor], E_shared) + T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.copy(B_shared, B_frag) + T.gemm_sp(A_shared, E_shared, B_frag, C_frag, trans_A, trans_A, trans_B) + T.copy(C_frag, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm_sr( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, block_M, block_N, block_K, num_stages, num_threads, - trans_A, - trans_B, + meta_dtype, ): - kernel = matmul_sp_sm90( + metadata_dtype = meta_dtype + program = matmul_sr( M, N, K, block_M, block_N, block_K, + trans_A, + trans_B, in_dtype, out_dtype, - accum_dtype, + dtypeAccum, + metadata_dtype, + get_e_factor(in_dtype, metadata_dtype), num_stages, num_threads, - trans_A, - trans_B, ) - run_gemm_sp( - kernel, + + kernel = 
tilelang.compile( + program, + out_idx=[3], + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, + ) + A, B = generate_dense_input(M, N, K, trans_A, trans_B, in_dtype) + A_sparse, E = compress(A.t().contiguous() if trans_A else A, meta_dtype=meta_dtype.as_torch()) + if trans_A: + A_sparse = A_sparse.t().contiguous() + E = E.t().contiguous() + C_sp = kernel(A_sparse, E, B) + + def _matmul(A, B): + if trans_A: + A = A.T + if trans_B: + B = B.T + A = A.to(torch.float32) + B = B.to(torch.float32) + return torch.matmul(A, B) + + C = _matmul(A, B) + + torch_assert_close( + C_sp.to(out_dtype.as_torch()).to(torch.float32), + C.to(out_dtype.as_torch()).to(torch.float32), + rtol=1e-3, + atol=1e-3, + base_name="tilelang_sp", + ref_name="ref_dense", + ) + + +@tilelang.testing.requires_cuda +@pytest.mark.parametrize( + "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads, meta_dtype", + [ + (128, 256, 32, False, True, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128, T.int16), + (128, 128, 128, False, True, T.int8, T.int8, T.int32, 128, 128, 128, 2, 128, T.int32), + (128, 256, 32, False, False, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128, T.int16), + (64, 256, 32, True, False, T.float16, T.float16, T.float32, 64, 256, 32, 2, 128, T.int16), + (64, 256, 32, True, True, T.float16, T.float16, T.float32, 64, 256, 32, 2, 128, T.int16), + (128, 8, 64, False, True, T.float16, T.float16, T.float32, 128, 8, 32, 0, 128, T.int16), + (128, 256, 32, False, True, T.bfloat16, T.bfloat16, T.float32, 128, 256, 32, 2, 128, T.int16), + (64, 128, 128, True, True, T.int8, T.int8, T.int32, 64, 128, 128, 2, 128, T.int32), + (128, 128, 128, False, True, T.int8, T.int8, T.int32, 128, 128, 128, 2, 128, T.int16), + (128, 128, 64, False, True, T.float8_e5m2, T.float8_e5m2, T.float32, 128, 128, 64, 2, 128, T.int32), + ], +) +def test_gemm_sr( + M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads, meta_dtype +): + run_gemm_sr( M, N, K, + trans_A, + trans_B, in_dtype, out_dtype, + dtypeAccum, + block_M, + block_N, block_K, - trans_A, - trans_B, + num_stages, + num_threads, + meta_dtype=meta_dtype, ) -@tilelang.testing.requires_cuda -@tilelang.testing.requires_cuda_compute_version_ge(8, 0) -@tilelang.testing.requires_cuda_compute_version_le(8, 9) -def run_gemm_sp_sm80( +def matmul_rr( M, N, K, - in_dtype, - out_dtype, - accum_dtype, block_M, block_N, block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + metadata_dtype, + E_factor, num_stages, - num_threads, + threads, +): + A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + A_frag_shape = A_shared_shape + B_frag_shape = B_shared_shape + E_shape = (M, K // E_factor) if not trans_A else (K // E_factor, M) + E_shared_shape = (block_M, block_K // E_factor) if not trans_A else (block_K // E_factor, block_M) + + import tilelang.language as T + + @T.prim_func + def main( + A_sparse: T.Tensor(A_sparse_shape, in_dtype), + E: T.Tensor(E_shape, metadata_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = 
T.alloc_shared(B_shared_shape, in_dtype) + E_shared = T.alloc_shared(E_shared_shape, metadata_dtype) + A_frag = T.alloc_fragment(A_frag_shape, in_dtype) + B_frag = T.alloc_fragment(B_frag_shape, in_dtype) + C_frag = T.alloc_fragment((block_M, block_N), accum_dtype) + T.annotate_layout( + { + A_shared: tilelang.layout.make_swizzled_layout(A_shared), + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + } + ) + T.clear(C_frag) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(E[k * block_K // E_factor, by * block_M], E_shared) + T.copy(A_sparse[k * block_K // 2, by * block_M], A_shared) + else: + T.copy(E[by * block_M, k * block_K // E_factor], E_shared) + T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.copy(A_shared, A_frag) + T.copy(B_shared, B_frag) + T.gemm_sp(A_frag, E_shared, B_frag, C_frag, trans_A, trans_A, trans_B) + T.copy(C_frag, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm_rr( + M, + N, + K, trans_A, trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=128, + meta_dtype=T.int16, ): - kernel = matmul_sp_sm80( + metadata_dtype = meta_dtype + program = matmul_rr( M, N, K, block_M, block_N, block_K, + trans_A, + trans_B, in_dtype, out_dtype, - accum_dtype, + dtypeAccum, + metadata_dtype, + get_e_factor(in_dtype, metadata_dtype), num_stages, num_threads, - trans_A, - trans_B, ) - run_gemm_sp( - kernel, - M, - N, - K, - in_dtype, - out_dtype, - block_K, - trans_A, - trans_B, + + kernel = tilelang.compile( + program, + out_idx=[3], + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, + ) + A, B = generate_dense_input(M, N, K, trans_A, trans_B, in_dtype) + A_sparse, E = compress(A.t().contiguous() if trans_A else A, meta_dtype=meta_dtype.as_torch()) + if trans_A: + A_sparse = A_sparse.t().contiguous() + E = E.t().contiguous() + C_sp = kernel(A_sparse, E, B) + + def _matmul(A, B): + if trans_A: + A = A.T + if trans_B: + B = B.T + A = A.to(torch.float32) + B = B.to(torch.float32) + return torch.matmul(A, B) + + C = _matmul(A, B) + + torch_assert_close( + C_sp.to(out_dtype.as_torch()).to(torch.float32), + C.to(out_dtype.as_torch()).to(torch.float32), + rtol=1e-3, + atol=1e-3, + base_name="tilelang_sp", + ref_name="ref_dense", ) @tilelang.testing.requires_cuda -@tilelang.testing.requires_cuda_compute_version_eq(9, 0) @pytest.mark.parametrize( - "M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B", + "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads, meta_dtype", [ - (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 2, 128, False, False), - (512, 1024, 768, T.float16, T.float32, T.float32, 128, 128, 128, 2, 128, False, False), - (512, 1024, 768, T.float16, T.float32, T.float32, 64, 128, 256, 2, 128, False, False), - (512, 1024, 768, T.float8_e4m3fn, T.float16, T.float16, 64, 64, 64, 2, 128, False, True), - (512, 1024, 768, T.int8, T.int32, T.int32, 64, 64, 64, 2, 128, False, True), + (128, 256, 32, False, True, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128, T.int16), + (128, 128, 64, False, True, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128, T.int32), + (128, 256, 32, False, False, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128, T.int16), + (64, 256, 32, True, 
False, T.float16, T.float16, T.float32, 64, 256, 32, 2, 128, T.int16), + (64, 256, 32, True, True, T.float16, T.float16, T.float32, 64, 256, 32, 2, 128, T.int16), + (128, 256, 32, False, True, T.bfloat16, T.bfloat16, T.float32, 128, 256, 32, 2, 128, T.int16), + (128, 8, 32, False, True, T.float16, T.float16, T.float32, 128, 8, 32, 2, 128, T.int16), + (128, 8, 64, False, True, T.int8, T.int8, T.int32, 128, 8, 64, 2, 128, T.int32), + (64, 128, 128, True, True, T.int8, T.int8, T.int32, 64, 128, 128, 2, 128, T.int32), + (128, 128, 64, False, True, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128, T.int16), + (128, 128, 64, False, True, T.float8_e5m2, T.float8_e5m2, T.float32, 128, 128, 64, 2, 128, T.int32), ], ) -def test_gemm_sp_sm90(M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B): - run_gemm_sp_sm90(M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B) +def test_gemm_rr( + M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads, meta_dtype +): + run_gemm_rr( + M, + N, + K, + trans_A, + trans_B, + in_dtype, + out_dtype, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages, + num_threads, + meta_dtype=meta_dtype, + ) @tilelang.testing.requires_cuda -@tilelang.testing.requires_cuda_compute_version_ge(8, 0) -@tilelang.testing.requires_cuda_compute_version_le(8, 9) @pytest.mark.parametrize( - "M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B", + "in_dtype, out_dtype, dtypeAccum, meta_dtype", [ - (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 0, 128, False, False), - (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 0, 128, False, True), - (512, 1024, 768, T.int8, T.int32, T.int32, 128, 128, 128, 0, 128, False, True), - (512, 1024, 768, T.int8, T.int32, T.int32, 64, 64, 64, 2, 128, False, True), + (T.float16, T.float16, T.float32, T.int16), + (T.int8, T.int8, T.int32, T.int32), + ( + T.float16, + T.float16, + T.float32, + T.int8, + ), + ( + T.bfloat16, + T.bfloat16, + T.float32, + T.int8, + ), + ( + T.bfloat16, + T.bfloat16, + T.float32, + T.int16, + ), + ( + T.int8, + T.int8, + T.int32, + T.int8, + ), + ( + T.int8, + T.int8, + T.int32, + T.int16, + ), + ( + T.float8_e5m2, + T.float8_e5m2, + T.float32, + T.int8, + ), + ( + T.float8_e5m2, + T.float8_e5m2, + T.float32, + T.int16, + ), + ( + T.float8_e5m2, + T.float8_e5m2, + T.float32, + T.int32, + ), ], ) -def test_gemm_sp_sm80(M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B): - run_gemm_sp_sm80(M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B) +def test_compress_dtype_combinations(in_dtype, out_dtype, dtypeAccum, meta_dtype): + run_gemm_ss(128, 128, 128, False, True, in_dtype, out_dtype, dtypeAccum, 128, 128, 64, 2, 128, meta_dtype=meta_dtype) if __name__ == "__main__": diff --git a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py deleted file mode 100644 index 921e3b4de2..0000000000 --- a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py +++ /dev/null @@ -1,624 +0,0 @@ -import pytest -from tilelang import tvm as tvm -from tilelang.utils.sparse import compress, randn_semi_sparse, randint_semi_sparse -from tilelang.utils.tensor import torch_assert_close 
-from tilelang.layout import make_cutlass_metadata_layout -from tilelang.cuda.intrinsics.macro.mma_sp_macro_generator import SparseTensorCoreIntrinEmitter - -import tilelang.testing -import torch -import tilelang.language as T - - -def matmul( - M, - N, - K, - block_M, - block_N, - block_K, - trans_A, - trans_B, - in_dtype, - out_dtype, - accum_dtype, - metadata_dtype, - E_factor, - num_stages, - threads, -): - A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M) - B_shape = (N, K) if trans_B else (K, N) - A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M) - B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) - - @T.prim_func - def main( - A_sparse: T.Tensor(A_sparse_shape, in_dtype), - E: T.Tensor((M, K // E_factor), metadata_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype) - B_shared = T.alloc_shared(B_shared_shape, in_dtype) - E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype) - C_frag = T.alloc_fragment((block_M, block_N), accum_dtype) - T.annotate_layout( - { - E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"), - E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"), - } - ) - T.clear(C_frag) - for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - T.copy(E[by * block_M, k * block_K // E_factor], E_shared) - if trans_A: - T.copy(A_sparse[k * block_K // 2, by * block_M], A_shared) - else: - T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared) - if trans_B: - T.copy(B[bx * block_N, k * block_K], B_shared) - else: - T.copy(B[k * block_K, bx * block_N], B_shared) - T.gemm_sp_v2(A_shared, E_shared, B_shared, C_frag, trans_A, trans_B) - T.copy(C_frag, C[by * block_M, bx * block_N]) - - return main - - -def run_gemm_ss( - M, - N, - K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - block_M, - block_N, - block_K, - num_stages=3, - num_threads=128, -): - metadata_dtype = T.int32 if ("8" in in_dtype) else T.int16 - program = matmul( - M, - N, - K, - block_M, - block_N, - block_K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - metadata_dtype, - SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype], # E_factor - num_stages, - num_threads, - ) - - kernel = tilelang.compile( - program, - out_idx=[3], - pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, - ) - A, B = generate_dense_input(M, N, K, trans_A, trans_B, in_dtype) - - A_sparse, E = compress(A, transposed=trans_A, block_k=block_K, arch="8.0") - C_sp = kernel(A_sparse, E, B) - - def _matmul(A, B): - if trans_A: - A = A.T - if trans_B: - B = B.T - A = A.to(torch.float32) - B = B.to(torch.float32) - return torch.matmul(A, B) - - C = _matmul(A, B) - - torch_assert_close( - C_sp.to(T.dtype(out_dtype).as_torch()).to(torch.float32), - C.to(T.dtype(out_dtype).as_torch()).to(torch.float32), - rtol=1e-3, - atol=1e-3, - base_name="tilelang_sp", - ref_name="ref_dense", - ) - print("pass") - - -def generate_dense_input(M, N, K, trans_A, trans_B, in_dtype): - is_8bit = "8" in in_dtype - is_unsigned = "uint" in in_dtype - is_int = "int" in in_dtype - if is_int: - if is_8bit: - low, high = (0, 4) if is_unsigned else (-2, 2) - else: - low, high = (0, 128) if is_unsigned else (-64, 64) - A = randint_semi_sparse(M, K, low=low, high=high, dtype=T.dtype(in_dtype).as_torch(), 
device="cuda", transposed=trans_A) - B = torch.randint(size=(N, K) if trans_B else (K, N), low=low, high=high, dtype=T.dtype(in_dtype).as_torch(), device="cuda") - else: - A = randn_semi_sparse(M, K, dtype=T.dtype(in_dtype).as_torch(), device="cuda", transposed=trans_A) - B = torch.randn((N, K) if trans_B else (K, N), device="cuda", dtype=torch.float32).to(T.dtype(in_dtype).as_torch()) - return A, B - - -@tilelang.testing.requires_cuda -@pytest.mark.parametrize( - "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads", - [ - (512, 1024, 768, False, True, T.float16, T.float16, T.float, 128, 128, 32, 2, 128), - (512, 1024, 768, False, False, T.float16, T.float16, T.float, 128, 128, 32, 2, 128), - (512, 1024, 768, True, False, T.float16, T.float16, T.float, 128, 128, 32, 2, 128), - (512, 1024, 768, True, True, T.float16, T.float16, T.float, 128, 128, 32, 2, 128), - (128, 8, 64, False, True, T.float16, T.float16, T.float, 128, 8, 32, 0, 128), - (128, 128, 128, False, True, T.int8, T.int32, T.int32, 128, 128, 64, 2, 128), - (128, 128, 128, False, False, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128), - (128, 128, 128, True, False, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128), - (128, 128, 128, True, True, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128), - (128, 128, 128, False, True, T.float8_e5m2, T.float8_e5m2, T.float32, 128, 128, 64, 2, 128), - (128, 128, 128, True, True, T.float8_e5m2, T.float8_e5m2, T.float32, 128, 128, 64, 2, 128), - ], -) -def test_gemm_ss(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads): - run_gemm_ss(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads) - - -def matmul_rs( - M, - N, - K, - block_M, - block_N, - block_K, - trans_A, - trans_B, - in_dtype, - out_dtype, - accum_dtype, - metadata_dtype, - E_factor, - num_stages, - threads, -): - A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M) - B_shape = (N, K) if trans_B else (K, N) - A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M) - B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) - A_frag_shape = A_shared_shape - - import tilelang.language as T - - @T.prim_func - def main( - A_sparse: T.Tensor(A_sparse_shape, in_dtype), - E: T.Tensor((M, K // E_factor), metadata_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype) - B_shared = T.alloc_shared(B_shared_shape, in_dtype) - E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype) - A_frag = T.alloc_fragment(A_frag_shape, in_dtype) - C_frag = T.alloc_fragment((block_M, block_N), accum_dtype) - T.annotate_layout( - { - A_shared: tilelang.layout.make_swizzled_layout(A_shared), - E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"), - E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"), - } - ) - T.clear(C_frag) - for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - T.copy(E[by * block_M, k * block_K // E_factor], E_shared) - if trans_A: - T.copy(A_sparse[k * block_K // 2, by * block_M], A_shared) - else: - T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared) - if trans_B: - T.copy(B[bx * block_N, k * block_K], B_shared) - else: - T.copy(B[k * block_K, bx * block_N], B_shared) - 
T.copy(A_shared, A_frag) - T.gemm_sp_v2(A_frag, E_shared, B_shared, C_frag, trans_A, trans_B) - T.copy(C_frag, C[by * block_M, bx * block_N]) - - return main - - -def run_gemm_rs( - M, - N, - K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - block_M, - block_N, - block_K, - num_stages=3, - num_threads=128, -): - metadata_dtype = T.int32 if ("8" in in_dtype) else T.int16 - program = matmul_rs( - M, - N, - K, - block_M, - block_N, - block_K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - metadata_dtype, - SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype], # E_factor - num_stages, - num_threads, - ) - kernel = tilelang.compile( - program, - out_idx=[3], - pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, - ) - A, B = generate_dense_input(M, N, K, trans_A, trans_B, in_dtype) - A_sparse, E = compress(A, transposed=trans_A, block_k=block_K, arch="8.0") - C_sp = kernel(A_sparse, E, B) - - def _matmul(A, B): - if trans_A: - A = A.T - if trans_B: - B = B.T - A = A.to(torch.float32) - B = B.to(torch.float32) - return torch.matmul(A, B) - - C = _matmul(A, B) - - torch_assert_close( - C_sp.to(T.dtype(out_dtype).as_torch()).to(torch.float32), - C.to(T.dtype(out_dtype).as_torch()).to(torch.float32), - rtol=1e-3, - atol=1e-3, - base_name="tilelang_sp", - ref_name="ref_dense", - ) - print("pass") - - -@tilelang.testing.requires_cuda -@pytest.mark.parametrize( - "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads", - [ - (512, 1024, 768, False, False, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128), - (512, 1024, 768, False, True, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128), - (512, 1024, 768, True, False, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128), - (512, 1024, 768, True, True, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128), - (128, 8, 64, False, True, T.float16, T.float16, T.float32, 128, 8, 32, 0, 128), - (128, 128, 128, False, True, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128), - (128, 128, 128, False, False, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128), - (128, 128, 128, True, False, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128), - (128, 128, 128, True, True, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128), - (128, 128, 128, True, True, T.float8_e5m2, T.float8_e5m2, T.float32, 128, 128, 64, 2, 128), - ], -) -def test_gemm_rs(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads): - run_gemm_rs(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads) - - -def matmul_sr( - M, - N, - K, - block_M, - block_N, - block_K, - trans_A, - trans_B, - in_dtype, - out_dtype, - accum_dtype, - metadata_dtype, - E_factor, - num_stages, - threads, -): - A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M) - B_shape = (N, K) if trans_B else (K, N) - A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M) - B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) - B_frag_shape = B_shared_shape - - import tilelang.language as T - - @T.prim_func - def main( - A_sparse: T.Tensor(A_sparse_shape, in_dtype), - E: T.Tensor((M, K // E_factor), metadata_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype) - B_shared = 
T.alloc_shared(B_shared_shape, in_dtype) - E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype) - B_frag = T.alloc_fragment(B_frag_shape, in_dtype) - C_frag = T.alloc_fragment((block_M, block_N), accum_dtype) - T.annotate_layout( - { - B_shared: tilelang.layout.make_swizzled_layout(B_shared), - E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"), - E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"), - } - ) - T.clear(C_frag) - for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - T.copy(E[by * block_M, k * block_K // E_factor], E_shared) - if trans_A: - T.copy(A_sparse[k * block_K // 2, by * block_M], A_shared) - else: - T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared) - if trans_B: - T.copy(B[bx * block_N, k * block_K], B_shared) - else: - T.copy(B[k * block_K, bx * block_N], B_shared) - T.copy(B_shared, B_frag) - T.gemm_sp_v2(A_shared, E_shared, B_frag, C_frag, trans_A, trans_B) - T.copy(C_frag, C[by * block_M, bx * block_N]) - - return main - - -def run_gemm_sr( - M, - N, - K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - block_M, - block_N, - block_K, - num_stages=3, - num_threads=128, -): - metadata_dtype = T.int32 if ("8" in in_dtype) else T.int16 - program = matmul_sr( - M, - N, - K, - block_M, - block_N, - block_K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - metadata_dtype, - SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype], # E_factor - num_stages, - num_threads, - ) - - kernel = tilelang.compile( - program, - out_idx=[3], - pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, - ) - A, B = generate_dense_input(M, N, K, trans_A, trans_B, in_dtype) - A_sparse, E = compress(A, transposed=trans_A, block_k=block_K, arch="8.0") - C_sp = kernel(A_sparse, E, B) - - def _matmul(A, B): - if trans_A: - A = A.T - if trans_B: - B = B.T - A = A.to(torch.float32) - B = B.to(torch.float32) - return torch.matmul(A, B) - - C = _matmul(A, B) - - torch_assert_close( - C_sp.to(T.dtype(out_dtype).as_torch()).to(torch.float32), - C.to(T.dtype(out_dtype).as_torch()).to(torch.float32), - rtol=1e-3, - atol=1e-3, - base_name="tilelang_sp", - ref_name="ref_dense", - ) - print("pass") - - -@tilelang.testing.requires_cuda -@pytest.mark.parametrize( - "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads", - [ - (512, 1024, 768, False, False, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128), - (512, 1024, 768, False, True, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128), - (512, 1024, 768, True, False, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128), - (512, 1024, 768, True, True, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128), - (128, 8, 64, False, True, T.float16, T.float16, T.float32, 128, 8, 32, 0, 128), - (128, 128, 128, False, True, T.int8, T.int8, T.int32, 128, 128, 128, 2, 128), - (128, 128, 128, False, False, T.int8, T.int8, T.int32, 128, 128, 128, 2, 128), - (128, 128, 128, True, False, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128), - (128, 128, 128, True, True, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128), - (128, 128, 128, True, True, T.float8_e5m2, T.float8_e5m2, T.float32, 128, 128, 64, 2, 128), - ], -) -def test_gemm_sr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads): - run_gemm_sr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, 
num_threads) - - -def matmul_rr( - M, - N, - K, - block_M, - block_N, - block_K, - trans_A, - trans_B, - in_dtype, - out_dtype, - accum_dtype, - metadata_dtype, - E_factor, - num_stages, - threads, -): - A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M) - B_shape = (N, K) if trans_B else (K, N) - A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M) - B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) - A_frag_shape = A_shared_shape - B_frag_shape = B_shared_shape - - import tilelang.language as T - - @T.prim_func - def main( - A_sparse: T.Tensor(A_sparse_shape, in_dtype), - E: T.Tensor((M, K // E_factor), metadata_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype) - B_shared = T.alloc_shared(B_shared_shape, in_dtype) - E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype) - A_frag = T.alloc_fragment(A_frag_shape, in_dtype) - B_frag = T.alloc_fragment(B_frag_shape, in_dtype) - C_frag = T.alloc_fragment((block_M, block_N), accum_dtype) - T.annotate_layout( - { - A_shared: tilelang.layout.make_swizzled_layout(A_shared), - B_shared: tilelang.layout.make_swizzled_layout(B_shared), - E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"), - E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"), - } - ) - T.clear(C_frag) - for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - T.copy(E[by * block_M, k * block_K // E_factor], E_shared) - if trans_A: - T.copy(A_sparse[k * block_K // 2, by * block_M], A_shared) - else: - T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared) - if trans_B: - T.copy(B[bx * block_N, k * block_K], B_shared) - else: - T.copy(B[k * block_K, bx * block_N], B_shared) - T.copy(A_shared, A_frag) - T.copy(B_shared, B_frag) - T.gemm_sp_v2(A_frag, E_shared, B_frag, C_frag, trans_A, trans_B) - T.copy(C_frag, C[by * block_M, bx * block_N]) - - return main - - -def run_gemm_rr( - M, - N, - K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - block_M, - block_N, - block_K, - num_stages=3, - num_threads=128, -): - metadata_dtype = T.int32 if ("8" in in_dtype) else T.int16 - program = matmul_rr( - M, - N, - K, - block_M, - block_N, - block_K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - metadata_dtype, - SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype], # E_factor - num_stages, - num_threads, - ) - - kernel = tilelang.compile( - program, - out_idx=[3], - pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, - ) - A, B = generate_dense_input(M, N, K, trans_A, trans_B, in_dtype) - A_sparse, E = compress(A, transposed=trans_A, block_k=block_K, arch="8.0") - C_sp = kernel(A_sparse, E, B) - - def _matmul(A, B): - if trans_A: - A = A.T - if trans_B: - B = B.T - A = A.to(torch.float32) - B = B.to(torch.float32) - return torch.matmul(A, B) - - C = _matmul(A, B) - - torch_assert_close( - C_sp.to(T.dtype(out_dtype).as_torch()).to(torch.float32), - C.to(T.dtype(out_dtype).as_torch()).to(torch.float32), - rtol=1e-3, - atol=1e-3, - base_name="tilelang_sp", - ref_name="ref_dense", - ) - print("pass") - - -@tilelang.testing.requires_cuda -@pytest.mark.parametrize( - "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads", - [ - (512, 1024, 768, False, False, 
T.float16, T.float16, T.float32, 128, 256, 32, 2, 128), - (512, 1024, 768, False, True, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128), - (512, 1024, 768, True, False, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128), - (512, 1024, 768, True, True, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128), - (512, 1024, 768, False, True, T.bfloat16, T.bfloat16, T.float32, 128, 256, 32, 2, 128), - (128, 8, 128, False, True, T.float16, T.float16, T.float32, 128, 8, 32, 2, 128), - (128, 8, 128, False, True, T.int8, T.int8, T.int32, 128, 8, 64, 2, 128), - (128, 128, 128, False, True, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128), - (128, 128, 128, False, False, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128), - (128, 128, 128, True, False, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128), - (128, 128, 128, True, True, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128), - (128, 128, 128, True, True, T.float8_e5m2, T.float8_e5m2, T.float32, 128, 128, 64, 2, 128), - ], -) -def test_gemm_rr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads): - run_gemm_rr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads) - - -if __name__ == "__main__": - tilelang.testing.main() diff --git a/testing/python/utils/test_compress_utils.py b/testing/python/utils/test_compress_utils.py index e8fc20539e..7912a0769f 100644 --- a/testing/python/utils/test_compress_utils.py +++ b/testing/python/utils/test_compress_utils.py @@ -1,39 +1,136 @@ +import pytest import torch import tilelang +from tilelang.utils.sparse import get_e_factor import tilelang.testing +import tilelang.language as T -from tilelang.utils.sparse import compress_sm90, randn_semi_sparse +from tilelang.utils.sparse import compress, randn_semi_sparse, randint_semi_sparse, torch_compress +from tilelang.utils.tensor import torch_assert_close -def _test_compress_sm90(M, K, block_k, dtype): - A = randn_semi_sparse(M, K, dtype=dtype, device="cuda") - A_sparse, E = compress_sm90(A, block_k, False) +def matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + metadata_dtype, + E_factor, + num_stages, + threads, +): + A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + E_shape = (M, K // E_factor) if not trans_A else (K // E_factor, M) + E_shared_shape = (block_M, block_K // E_factor) if not trans_A else (block_K // E_factor, block_M) + @T.prim_func + def main( + A_sparse: T.Tensor(A_sparse_shape, in_dtype), + E: T.Tensor(E_shape, metadata_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + E_shared = T.alloc_shared(E_shared_shape, metadata_dtype) + C_frag = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_frag) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(E[k * block_K // E_factor, by * block_M], E_shared) + T.copy(A_sparse[k * block_K // 2, by * block_M], A_shared) + else: + T.copy(E[by * block_M, k * block_K // E_factor], E_shared) + T.copy(A_sparse[by * block_M, k * block_K // 2], 
A_shared)
+                if trans_B:
+                    T.copy(B[bx * block_N, k * block_K], B_shared)
+                else:
+                    T.copy(B[k * block_K, bx * block_N], B_shared)
+                T.gemm_sp(A_shared, E_shared, B_shared, C_frag, trans_A, trans_A, trans_B)  # transpose_E follows transpose_A
+                T.copy(C_frag, C[by * block_M, bx * block_N])
 
-@tilelang.testing.requires_cuda
-@tilelang.testing.requires_cuda_compute_version(9, 0)
-def test_compress_sm90():
-    _test_compress_sm90(1024, 1024, 128, torch.float16)
-    _test_compress_sm90(1024, 1024, 64, torch.float16)
-    _test_compress_sm90(1024, 1024, 32, torch.float16)
+    return main
+
+
+def generate_dense_input(N, trans_A, trans_B, in_dtype, seed=0):
+    # Square (N, N) operands, so shapes are unchanged under transposition.
+    torch.manual_seed(seed)
+    is_8bit = "8" in str(in_dtype)
+    is_unsigned = "uint" in str(in_dtype)
+    is_int = "int" in str(in_dtype)
+    if is_int:
+        if is_8bit:
+            low, high = (0, 128) if is_unsigned else (-64, 64)
+        else:
+            low, high = (0, 256) if is_unsigned else (-128, 128)
+        A = randint_semi_sparse(N, N, low=low, high=high, dtype=in_dtype, device="cuda", transposed=trans_A)
+        B = torch.randint(size=(N, N), low=low, high=high, dtype=in_dtype, device="cuda")
+    else:
+        A = randn_semi_sparse(N, N, dtype=in_dtype, device="cuda", transposed=trans_A)
+        B = torch.randn((N, N), device="cuda", dtype=torch.float32).to(in_dtype)
+    return A, B
 
-    _test_compress_sm90(1024, 1024, 128, torch.bfloat16)
-    _test_compress_sm90(1024, 1024, 64, torch.bfloat16)
-    _test_compress_sm90(1024, 1024, 32, torch.bfloat16)
 
-    _test_compress_sm90(1024, 1024, 64, torch.float32)
-    _test_compress_sm90(1024, 1024, 32, torch.float32)
-    _test_compress_sm90(1024, 1024, 16, torch.float32)
+def _test_compress(dtype, meta_dtype):
+    A, B = generate_dense_input(64, in_dtype=dtype.as_torch(), trans_A=False, trans_B=False)
+    sp_tl, meta_tl = compress(A, meta_dtype=meta_dtype.as_torch())
+    sp_ref, meta_ref = torch_compress(A, meta_dtype=meta_dtype.as_torch())
+    # NOTE: when a 2:4 group contains multiple zeros, its compressed values and
+    # metadata are not unique, so comparing them elementwise may fail spuriously;
+    # instead, run both compressions through the same GEMM kernel and compare outputs.
+    program = matmul(
+        64,
+        64,
+        64,
+        64,
+        64,
+        64,
+        False,
+        False,
+        dtype,
+        dtype,
+        T.int32 if dtype == T.int8 else T.float32,
+        meta_dtype,
+        get_e_factor(dtype, meta_dtype),
+        0,
+        128,
+    )
+    kernel = tilelang.compile(
+        program,
+        out_idx=[3],
+        pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True},
+    )
 
-    _test_compress_sm90(1024, 1024, 256, torch.float8_e4m3fn)
-    _test_compress_sm90(1024, 1024, 128, torch.float8_e4m3fn)
-    _test_compress_sm90(1024, 1024, 64, torch.float8_e4m3fn)
+    C_tl = kernel(sp_tl, meta_tl, B)
+    C_ref = kernel(sp_ref, meta_ref, B)
+    torch_assert_close(C_tl, C_ref, atol=1e-2, rtol=1e-2)
 
-    _test_compress_sm90(1024, 1024, 256, torch.float8_e5m2)
-    _test_compress_sm90(1024, 1024, 128, torch.float8_e5m2)
-    _test_compress_sm90(1024, 1024, 64, torch.float8_e5m2)
+
+@tilelang.testing.requires_cuda
+@pytest.mark.parametrize(
+    "dtype, meta_dtype",
+    [
+        (T.int8, T.int8),
+        (T.int8, T.int16),
+        (T.int8, T.int32),
+        (T.float16, T.int8),
+        (T.float16, T.int16),
+        (T.float32, T.int8),
+        (T.float32, T.int16),
+    ],
+)
+def test_compress(dtype, meta_dtype):
+    _test_compress(dtype, meta_dtype)
 
 
 if __name__ == "__main__":
-    test_compress_sm90()
-    print("All tests passed.")
+    tilelang.testing.main()
diff --git a/tilelang/cuda/intrinsics/layout/mma_sp_layout.py b/tilelang/cuda/intrinsics/layout/mma_sp_layout.py
index c814e32307..e3405316bd 100644
--- a/tilelang/cuda/intrinsics/layout/mma_sp_layout.py
+++ 
b/tilelang/cuda/intrinsics/layout/mma_sp_layout.py @@ -87,7 +87,10 @@ def metadata_16bit_load_32x2_to_shared_16x2_layout_32bit(thread_id: int, local_i def metadata_8bit_load_32x4_to_shared_16x4_layout_16bit(thread_id: int, local_id: int) -> tuple[int, int]: - return metadata_8bit_load_32x4_to_shared_16x4_layout_32bit(thread_id, local_id) # same mapping for 16bit and 32bit + logical_id = get_logical_id_32bit(thread_id) + row = logical_id // 2 + (local_id // 2) * 8 + col = (logical_id % 2) * 2 + (local_id % 2) + return row, col def metadata_16bit_load_32x2_to_shared_16x2_layout_16bit(thread_id: int, local_id: int) -> tuple[int, int]: @@ -98,16 +101,16 @@ def get_logical_id_8bit(thread_id: int) -> int: return thread_id -def metadata_8bit_load_32x4_to_shared_16x4_layout_8bit(thread_id: int, local_id: int) -> tuple[int, int]: +def metadata_8bit_load_32x4_to_shared_16x8_layout_8bit(thread_id: int, local_id: int) -> tuple[int, int]: logical_id = get_logical_id_8bit(thread_id) - row = logical_id // 2 + local_id * 8 + row = logical_id // 4 + (logical_id % 2) * 8 col = (logical_id % 4) // 2 * 4 + local_id return row, col def metadata_16bit_load_32x2_to_shared_16x4_layout_8bit(thread_id: int, local_id: int) -> tuple[int, int]: logical_id = get_logical_id_8bit(thread_id) - row = logical_id // 2 + local_id * 8 + row = logical_id // 4 + (logical_id % 2) * 8 col = (logical_id % 4) // 2 * 2 + local_id return row, col diff --git a/tilelang/cuda/intrinsics/macro/mma_sp_macro_generator.py b/tilelang/cuda/intrinsics/macro/mma_sp_macro_generator.py index 826a0f58ec..362db42599 100644 --- a/tilelang/cuda/intrinsics/macro/mma_sp_macro_generator.py +++ b/tilelang/cuda/intrinsics/macro/mma_sp_macro_generator.py @@ -11,6 +11,7 @@ get_ldmatrix_offset, ) from tilelang.utils import is_fragment, get_buffer_region_from_load +from tilelang.utils.sparse import get_e_factor, get_e_replicate_factor from tilelang.cuda.intrinsics.layout.mma_sp_layout import ( shared_16x16_to_mma_sp_layout_sr_a, @@ -29,7 +30,7 @@ metadata_16bit_load_32x2_to_shared_16x2_layout_32bit, metadata_8bit_load_32x4_to_shared_16x4_layout_16bit, metadata_16bit_load_32x2_to_shared_16x2_layout_16bit, - metadata_8bit_load_32x4_to_shared_16x4_layout_8bit, + metadata_8bit_load_32x4_to_shared_16x8_layout_8bit, metadata_16bit_load_32x2_to_shared_16x4_layout_8bit, metadata_32bit_load_32x1_to_shared_16x2_layout_8bit, get_ldmatrix_offset_b, @@ -60,27 +61,6 @@ class SparseTensorCoreIntrinEmitter: "float8_e5m2": "e5m2", } - E_FACTOR_MAP = { # e_kdim = mma_kdim // e_factor - "float": {"int16": 8, "uint16": 8}, - "float32": {"int16": 8, "uint16": 8}, - "float16": {"int8": 8, "uint8": 8, "int16": 16, "uint16": 16, "int32": 32, "uint32": 32}, - "bfloat16": {"int8": 8, "uint8": 8, "int16": 16, "uint16": 16, "int32": 32, "uint32": 32}, - "int8": {"int8": 8, "uint8": 8, "int16": 16, "uint16": 16, "int32": 32, "uint32": 32}, - "uint8": {"int8": 8, "uint8": 8, "int16": 16, "uint16": 16, "int32": 32, "uint32": 32}, - "float8_e4m3": {"int8": 8, "uint8": 8, "int16": 16, "uint16": 16, "int32": 32, "uint32": 32}, - "float8_e5m2": {"int8": 8, "uint8": 8, "int16": 16, "uint16": 16, "int32": 32, "uint32": 32}, - } - - E_REPLICATE_FACTOR = { # metadata replicate every 4 consecutive threads - "float32": 2, - "float16": 2, # 2 of 4 consecutive threads provides - "bfloat16": 2, - "int8": 1, # 4 of 4 consecutive threads provides - "uint8": 1, - "float8_e4m3": 1, - "float8_e5m2": 1, - } - # Represent the thread binding in the form of (tx, warp_n, warp_m) is_m_first = False @@ -116,7 +96,7 
@@ def __init__( self.warp_row_tiles = warp_row_tiles self.warp_col_tiles = warp_col_tiles self.warp_k = warp_k - self.e_factor = self.E_FACTOR_MAP[self.a_dtype][self.e_dtype] + self.e_factor = get_e_factor(self.a_dtype, self.e_dtype) self._initialize_k_dim(a_dtype) self._initialize_abbrev(a_dtype, b_dtype, accum_dtype) self._initialize_micro_size(self.M_DIM, self.k_dim) @@ -143,9 +123,13 @@ def _initialize_k_dim(self, a_dtype=T.float16): def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16, warp_size=32): self.local_size_a = (m_dim * k_dim) // warp_size // self.SPARSE_FACTOR - self.local_size_e = (m_dim * k_dim) // self.e_factor // warp_size * self.E_REPLICATE_FACTOR[self.a_dtype] + self.local_size_e = (m_dim * k_dim) // self.e_factor * get_e_replicate_factor(self.a_dtype) // warp_size self.local_size_b = (n_dim * k_dim) // warp_size self.local_size_out = (m_dim * n_dim) // warp_size + assert self.local_size_a > 0, f"local_size_a must be greater than 0, got {self.local_size_a}" + assert self.local_size_e > 0, f"local_size_e must be greater than 0, got {self.local_size_e}" + assert self.local_size_b > 0, f"local_size_b must be greater than 0, got {self.local_size_b}" + assert self.local_size_out > 0, f"local_size_out must be greater than 0, got {self.local_size_out}" def _initialize_abbrev(self, a_dtype, b_dtype, accum_dtype): self.a_dtype_abbrv = self.dtype_abbrv[a_dtype] @@ -331,7 +315,7 @@ def mma_load_layout(i, j): if not ldmatrix_available: if DataType(e_dtype).bits == 8: if DataType(a_dtype).bits == 8: - mma_load_layout = metadata_8bit_load_32x4_to_shared_16x4_layout_8bit + mma_load_layout = metadata_8bit_load_32x4_to_shared_16x8_layout_8bit elif DataType(a_dtype).bits == 16: mma_load_layout = metadata_8bit_load_32x4_to_shared_16x4_layout_16bit elif DataType(a_dtype).bits == 32: diff --git a/tilelang/cuda/intrinsics/macro/wgmma_sp_macro_generator.py b/tilelang/cuda/intrinsics/macro/wgmma_sp_macro_generator.py new file mode 100644 index 0000000000..6ed04b129d --- /dev/null +++ b/tilelang/cuda/intrinsics/macro/wgmma_sp_macro_generator.py @@ -0,0 +1,613 @@ +from __future__ import annotations +from tilelang.cuda.intrinsics.macro.wgmma_macro_generator import SwizzleMode, gcd +from tilelang.cuda.intrinsics.macro.mma_sp_macro_generator import SparseTensorCoreIntrinEmitter +import tilelang.language as T +from tvm import DataType +from tvm.tir import PrimExpr, Buffer, Var, BufferRegion, IndexMap +from tilelang.utils import is_fragment, is_shared, retrive_ptr_from_buffer_region, is_full_region +from tilelang.cuda.intrinsics.layout.mma_layout import ( + shared_16x8_to_mma_32x4_layout_sr_a, + shared_16x16_to_mma_32x8_layout_sr_a, + shared_16x32_to_mma_32x16_layout_sr_a, +) +from tilelang.layout import ( + Layout, + make_full_bank_swizzled_layout, + make_half_bank_swizzled_layout, + make_quarter_bank_swizzled_layout, + make_linear_layout, +) +from tilelang.cuda.intrinsics.layout.mma_sp_layout import ( + metadata_8bit_load_32x4_to_shared_16x4_layout_32bit, + metadata_16bit_load_32x2_to_shared_16x2_layout_32bit, + metadata_8bit_load_32x4_to_shared_16x4_layout_16bit, + metadata_16bit_load_32x2_to_shared_16x2_layout_16bit, + metadata_8bit_load_32x4_to_shared_16x8_layout_8bit, + metadata_16bit_load_32x2_to_shared_16x4_layout_8bit, + metadata_32bit_load_32x1_to_shared_16x2_layout_8bit, +) + + +class WGSparseTensorCoreIntrinEmitter(SparseTensorCoreIntrinEmitter): + wgmma_prefix: str + + wgmma_inst_m: int + + wgmma_inst_n: int + + a_shared_layout: Layout = None + b_shared_layout: Layout = None + 
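+    # Sparse WGMMA shape summary: each instruction covers m64 x n{8..256} x
+    # k(512 // bits(a_dtype)). A arrives 2:4-compressed, so one K-step reads
+    # micro_size_k // SPARSE_FACTOR elements of A plus the matching E metadata
+    # fragment (loaded per warp via ldmatrix_e).
+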
+ def __init__( + self, + a_dtype: str = T.float16, + e_dtype: str = T.uint8, + b_dtype: str = T.float16, + accum_dtype: str = T.float16, + a_transposed: bool = False, + b_transposed: bool = False, + e_transposed: bool = False, + block_row_warps: int = 2, + block_col_warps: int = 2, + warp_row_tiles: int = 8, + warp_col_tiles: int = 8, + warp_k: int = 16, + reduce_k: int = 1, + num_elems_per_byte: int = 1, + is_m_first: bool | None = False, + thread_var: Var | None = None, + ): + assert reduce_k == 1, f"{reduce_k=} is not supported" + super().__init__( + a_dtype=a_dtype, + e_dtype=e_dtype, + b_dtype=b_dtype, + accum_dtype=accum_dtype, + a_transposed=a_transposed, + b_transposed=b_transposed, + e_transposed=e_transposed, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + warp_k=warp_k, + reduce_k=reduce_k, + num_elems_per_byte=num_elems_per_byte, + is_m_first=is_m_first, + thread_var=thread_var, + ) + self._initialize_wgmma_prefix(self.n_dim) + + def _assign_a_shared_layout(self, layout: Layout): + self.a_shared_layout = layout + return self + + def _assign_b_shared_layout(self, layout: Layout): + self.b_shared_layout = layout + return self + + def _initialize_wgmma_prefix(self, n_dim: int = 16): + inst_m, inst_n = 64, gcd(self.warp_col_tiles, 256) + assert inst_n % 8 == 0, ( + f"inst_n must be a multiple of 8, got {inst_n} (block_col_warps={self.block_col_warps}, warp_col_tiles={self.warp_col_tiles})" + ) + # Validate inst_n: Hopper WGMMA supports n in [8, 256] and multiple of 8 + assert 8 <= inst_n <= 256, ( + f"inst_n must be within [8, 256], got {inst_n} (block_col_warps={self.block_col_warps}, warp_col_tiles={self.warp_col_tiles})" + ) + # 512 bits per instruction for sparse wgmma + inst_k = 512 // DataType(self.a_dtype).bits + self.wgmma_inst_m = inst_m + self.wgmma_inst_n = inst_n + self.wgmma_prefix = f"m{inst_m}n{inst_n}k{inst_k}" + + def _determinate_swizzle_mode(self, buffer: Buffer, layout: Layout) -> SwizzleMode: + # same behavior to src/layout/gemm_layouts.cc::makeGemmABLayoutHopper + if layout is None or layout.is_equal(make_linear_layout(buffer)): + return SwizzleMode.NONE + elif layout.is_equal(make_quarter_bank_swizzled_layout(buffer)): + return SwizzleMode.SWIZZLE_32B + elif layout.is_equal(make_half_bank_swizzled_layout(buffer)): + return SwizzleMode.SWIZZLE_64B + elif layout.is_equal(make_full_bank_swizzled_layout(buffer)): + return SwizzleMode.SWIZZLE_128B + else: + raise ValueError(f"Unsupported swizzle mode: {layout}") + + def wgmma_ss( + self, + A_region: BufferRegion, + E_region: BufferRegion, + B_region: BufferRegion, + C_region: BufferRegion, + clear_accum: PrimExpr = False, + wg_wait: int = 0, + ): + assert is_shared(A_region), "A operand must be a shared buffer for wgmma_ss" + assert is_shared(E_region), "E operand must be a shared buffer for wgmma_ss" + + local_size_out = self.local_size_out + local_size_e = self.local_size_e + a_dtype_abbrv = self.a_dtype_abbrv + b_dtype_abbrv = self.b_dtype_abbrv + accum_dtype = self.accum_dtype + accum_dtype_abbrv = self.accum_dtype_abbrv + m_dim = self.block_row_warps * self.warp_row_tiles + warp_cols = self.warp_cols + micro_size_k = self.micro_size_k + k_dim, n_dim = self.warp_k, self.block_col_warps * self.warp_col_tiles + wgmma_prefix = self.wgmma_prefix + scale_in_a = 1 + scale_in_b = 1 + + assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}" + + a_is_k_major = not 
self.a_transposed + b_is_k_major = self.b_transposed + + a_swizzle_mode = self._determinate_swizzle_mode(A_region, self.a_shared_layout) + b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout) + + elems_in_bits = DataType(self.a_dtype).bits + elems_in_bytes = elems_in_bits // 8 + + a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // elems_in_bytes + b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none() else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes + accum_bits = DataType(accum_dtype).bits + accum_regs = ((m_dim // 64) * warp_cols * local_size_out * accum_bits + 31) // 32 + + a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim * elems_in_bytes) + a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 * elems_in_bytes) + + if not a_swizzle_mode.is_none(): + # swizzle mode doesn't require LBO/SBO to be 1 + # https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset + if a_is_k_major: + a_leading_byte_offset = 16 + a_stride_byte_offset = 8 * a_swizzle_mode.swizzle_byte_size() + else: + # MN Major + # LBO represents the distance between two atoms along the M dimension + # SBO represents the distance between two atoms along the K dimension + a_m_axis_atoms = m_dim // a_swizzle_atom_elems + if a_m_axis_atoms <= 1: + a_leading_byte_offset = 0 + else: + a_leading_byte_offset = 8 * a_swizzle_mode.swizzle_atom_size() * (a_swizzle_mode.swizzle_byte_size() // elems_in_bytes) + + if a_m_axis_atoms <= 1: + a_stride_byte_offset = 8 * elems_in_bytes * m_dim + else: + a_stride_byte_offset = 8 * elems_in_bytes * a_swizzle_atom_elems + + b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * elems_in_bytes) + b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else (8 * 8 * elems_in_bytes)) + if not b_swizzle_mode.is_none(): + # swizzle mode doesn't require LBO/SBO to be 1 + # https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset + if b_is_k_major: + b_leading_byte_offset = 16 + b_stride_byte_offset = 8 * b_swizzle_mode.swizzle_byte_size() + else: + # MN Major, K * N + # LBO represents the distance between two atoms along the N dimension + # SBO represents the distance between two atoms along the K dimension + b_n_axis_atoms = n_dim // b_swizzle_atom_elems + if b_n_axis_atoms <= 1: + b_leading_byte_offset = 0 + else: + b_leading_byte_offset = 8 * 8 * elems_in_bytes * k_dim + if b_n_axis_atoms <= 1: + b_stride_byte_offset = 8 * elems_in_bytes * n_dim + else: + b_stride_byte_offset = 8 * elems_in_bytes * b_swizzle_atom_elems + + # for example, if [n, k] where k is 128, we should split it into 2 atoms + # where max specially handles the case when n_dim is 8. 
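+        # Concretely (illustrative numbers, assuming fp16 operands with a
+        # 128-byte swizzle and micro_size_k = 32): a_swizzle_atom_elems is
+        # 128 / 2 = 64, while each sparse instruction consumes
+        # micro_size_k // SPARSE_FACTOR = 16 stored elements of A, so
+        # ak_atom_size = 64 // 16 = 4; for the dense B operand,
+        # bk_atom_size = 64 // 32 = 2.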
+        ak_atom_size = max(a_swizzle_atom_elems // (micro_size_k // self.SPARSE_FACTOR), 1)
+        bk_atom_size = max(b_swizzle_atom_elems // micro_size_k, 1)
+        wgmma_inst_m, wgmma_inst_n = self.wgmma_inst_m, self.wgmma_inst_n
+        num_inst_m = 4 * self.warp_row_tiles // wgmma_inst_m
+        num_inst_n = self.warp_col_tiles // wgmma_inst_n
+
+        thread_binding = self.get_thread_binding()
+
+        A_ptr = retrive_ptr_from_buffer_region(A_region)
+        B_ptr = retrive_ptr_from_buffer_region(B_region)
+        assert is_full_region(C_region), "Fragment output C must be a full region"
+        C_buf = C_region.buffer
+
+        @T.macro
+        def _warp_mma(A_ptr, B_ptr, C_buf):
+            tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
+            k_blocks = k_dim // micro_size_k
+            e_stage_elems = self.warp_rows * self.local_size_e
+            E_local = T.alloc_local((k_blocks * e_stage_elems), self.e_dtype)
+
+            desc_a = T.alloc_wgmma_desc()
+            desc_b = T.alloc_wgmma_desc()
+            T.initialize_wgmma_descriptor(desc_a, A_ptr, a_swizzle_mode, int(a_leading_byte_offset >> 4), int(a_stride_byte_offset >> 4))
+            T.initialize_wgmma_descriptor(desc_b, B_ptr, b_swizzle_mode, int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4))
+
+            for ki in T.unroll(k_blocks):
+                for i in T.unroll(num_inst_m):
+                    self.ldmatrix_e(E_local, E_region, i, warp_m, ki, ki)
+
+            # NOTE: cutlass doesn't fence metadata, we follow the same here
+            T.warpgroup_fence_operand(C_buf, num_regs=accum_regs)
+            T.warpgroup_arrive()
+
+            for ki in T.unroll(k_blocks):
+                for j in T.unroll(num_inst_n):
+                    for i in T.unroll(num_inst_m):
+                        e_local_offset = ki * e_stage_elems + i * local_size_e
+                        scale_out = T.Select(ki != 0, 1, T.Select(clear_accum, 0, 1))
+                        warp_i = (warp_m // 4) * num_inst_m + i
+                        warp_j = warp_n * num_inst_n + j
+                        A_offset = (
+                            (ki % ak_atom_size) * (micro_size_k // self.SPARSE_FACTOR)
+                            + warp_i * 64 * a_swizzle_atom_elems
+                            + (ki // ak_atom_size) * m_dim * a_swizzle_atom_elems
+                            if a_is_k_major
+                            else warp_i * 64 * (k_dim // self.SPARSE_FACTOR)
+                            + ki * a_swizzle_atom_elems * (micro_size_k // self.SPARSE_FACTOR)
+                        )
+                        B_offset = (
+                            (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems
+                            + (ki % bk_atom_size) * micro_size_k
+                            + warp_j * wgmma_inst_n * b_swizzle_atom_elems
+                            if b_is_k_major
+                            else (
+                                ki * b_swizzle_atom_elems * micro_size_k
+                                + warp_j * wgmma_inst_n * (k_dim if n_dim // b_swizzle_atom_elems > 1 else 1)
+                            )
+                        )
+                        C_offset = i * warp_cols * local_size_out + j * warp_cols * local_size_out // num_inst_n  # 4 warps as a unit
+                        T.ptx_wgmma_sp_ss(
+                            accum_dtype,
+                            wgmma_prefix,
+                            a_is_k_major,
+                            b_is_k_major,
+                            a_dtype_abbrv,
+                            b_dtype_abbrv,
+                            accum_dtype_abbrv,
+                            desc_a.data,
+                            (A_offset * elems_in_bytes) >> 4,
+                            E_local.data,
+                            e_local_offset,
+                            self.SPARSE_SELECTOR,
+                            desc_b.data,
+                            (B_offset * elems_in_bytes) >> 4,
+                            C_buf.data,
+                            C_offset,
+                            scale_out,
+                            scale_in_a,
+                            scale_in_b,
+                        )
+
+            T.warpgroup_commit_batch()
+            if wg_wait >= 0:
+                T.warpgroup_wait(wg_wait)
+            T.warpgroup_fence_operand(C_buf, num_regs=accum_regs)
+
+        return _warp_mma(A_ptr, B_ptr, C_buf)
+
+    def wgmma_rs(
+        self,
+        A_region: BufferRegion,
+        E_region: BufferRegion,
+        B_region: BufferRegion,
+        C_region: BufferRegion,
+        clear_accum: PrimExpr = False,
+        wg_wait: int = 0,
+    ):
+        assert is_fragment(A_region), "A operand must be a fragment buffer for wgmma_rs"
+        assert is_shared(E_region), "E operand must be a shared buffer for wgmma_rs"
+
+        local_size_a = self.local_size_a
+        local_size_out = self.local_size_out
+        local_size_e = self.local_size_e
+        a_dtype_abbrv = self.a_dtype_abbrv
+        b_dtype_abbrv =
self.b_dtype_abbrv + accum_dtype = self.accum_dtype + accum_dtype_abbrv = self.accum_dtype_abbrv + m_dim = self.block_row_warps * self.warp_row_tiles + warp_rows, warp_cols = self.warp_rows, self.warp_cols + micro_size_k = self.micro_size_k + k_dim, n_dim = self.warp_k, self.block_col_warps * self.warp_col_tiles + wgmma_prefix = self.wgmma_prefix + scale_in_a = 1 + scale_in_b = 1 + + assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}" + + elems_in_bytes = DataType(self.a_dtype).bits // 8 + a_bits = DataType(self.a_dtype).bits + accum_bits = DataType(accum_dtype).bits + a_regs = ((warp_rows * local_size_a * (k_dim // micro_size_k)) * a_bits + 31) // 32 + accum_regs = ((m_dim // 64) * warp_cols * local_size_out * accum_bits + 31) // 32 + b_is_k_major = self.b_transposed + + b_swizzle_mode = self._determinate_swizzle_mode(B_region, self.b_shared_layout) + b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none() else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes + + b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * elems_in_bytes) + b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else (8 * 8 * elems_in_bytes)) + if not b_swizzle_mode.is_none(): + if b_is_k_major: + b_leading_byte_offset = 16 + b_stride_byte_offset = 8 * b_swizzle_mode.swizzle_byte_size() + else: + b_n_axis_atoms = n_dim // b_swizzle_atom_elems + if b_n_axis_atoms <= 1: + b_leading_byte_offset = 0 + else: + b_leading_byte_offset = 8 * 8 * elems_in_bytes * k_dim + if b_n_axis_atoms <= 1: + b_stride_byte_offset = 8 * elems_in_bytes * n_dim + else: + b_stride_byte_offset = 8 * elems_in_bytes * b_swizzle_atom_elems + + bk_atom_size = max(b_swizzle_atom_elems // micro_size_k, 1) + wgmma_inst_m, wgmma_inst_n = self.wgmma_inst_m, self.wgmma_inst_n + num_inst_m = 4 * self.warp_row_tiles // wgmma_inst_m + num_inst_n = self.warp_col_tiles // wgmma_inst_n + + thread_binding = self.get_thread_binding() + + assert is_full_region(A_region), "Fragment input A must be a full region" + assert is_full_region(C_region), "Fragment output C must be a full region" + A_buf = A_region.buffer + B_ptr = retrive_ptr_from_buffer_region(B_region) + C_buf = C_region.buffer + + k_blocks = k_dim // micro_size_k + e_stage_elems = self.warp_rows * self.local_size_e + + @T.macro + def _warp_mma(A_buf, B_ptr, C_buf): + _, warp_n, warp_m = self.extract_thread_binding(thread_binding) + E_local = T.alloc_local((k_blocks * e_stage_elems), self.e_dtype) + + desc_b = T.alloc_wgmma_desc() + T.initialize_wgmma_descriptor(desc_b, B_ptr, b_swizzle_mode, int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4)) + + for ki in T.unroll(k_blocks): + for i in T.unroll(num_inst_m): + self.ldmatrix_e(E_local, E_region, i, warp_m, ki, ki) + + # NOTE: cutlass doesn't fence metadata, we follow the same here + T.warpgroup_fence_operand(A_buf, num_regs=a_regs) + T.warpgroup_fence_operand(C_buf, num_regs=accum_regs) + T.warpgroup_arrive() + + for ki in T.unroll(k_blocks): + for j in T.unroll(num_inst_n): + for i in T.unroll(num_inst_m): + e_local_offset = ki * e_stage_elems + i * local_size_e + scale_out = T.Select(ki != 0, 1, T.Select(clear_accum, 0, 1)) + warp_j = warp_n * num_inst_n + j + A_offset = ki * warp_rows * local_size_a + i * local_size_a + B_offset = ( + (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems + + warp_j * wgmma_inst_n * b_swizzle_atom_elems + + (ki % bk_atom_size) * micro_size_k + if b_is_k_major + else ( + ki * 
b_swizzle_atom_elems * micro_size_k + + warp_j * wgmma_inst_n * (k_dim if n_dim // b_swizzle_atom_elems > 1 else 1) + ) + ) + C_offset = i * warp_cols * local_size_out + j * warp_cols * local_size_out // num_inst_n + T.ptx_wgmma_sp_rs( + accum_dtype, + wgmma_prefix, + b_is_k_major, + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_buf.data, + A_offset, + E_local.data, + e_local_offset, + self.SPARSE_SELECTOR, + desc_b.data, + (B_offset * elems_in_bytes) >> 4, + C_buf.data, + C_offset, + scale_out, + scale_in_a, + scale_in_b, + ) + + T.warpgroup_commit_batch() + if wg_wait >= 0: + T.warpgroup_wait(wg_wait) + T.warpgroup_fence_operand(C_buf, num_regs=accum_regs) + T.warpgroup_fence_operand(A_buf, num_regs=a_regs) + + return _warp_mma(A_buf, B_ptr, C_buf) + + def ldmatrix_e(self, E_local_buf: Buffer, E_shared_buf: Buffer, inst_i: PrimExpr, warp_m: PrimExpr, ki: PrimExpr, ki_slot: PrimExpr): + num_inst_m = 4 * self.warp_row_tiles // self.wgmma_inst_m + micro_size_k = self.micro_size_k + local_size_e = self.local_size_e + e_stage_elems = self.warp_rows * local_size_e + a_dtype = self.a_dtype + e_dtype = self.e_dtype + trans = self.e_transposed + # ldmatrix cannot be used for int8 + trans case. + # include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sparse.h + ldmatrix_available = False # TODO: use ldmatrix when possible + + def mma_load_layout(i, j): + return i, j + + if not ldmatrix_available: + if DataType(e_dtype).bits == 8: + if DataType(a_dtype).bits == 8: + mma_load_layout = metadata_8bit_load_32x4_to_shared_16x8_layout_8bit + elif DataType(a_dtype).bits == 16: + mma_load_layout = metadata_8bit_load_32x4_to_shared_16x4_layout_16bit + elif DataType(a_dtype).bits == 32: + mma_load_layout = metadata_8bit_load_32x4_to_shared_16x4_layout_32bit + else: + raise ValueError(f"Unsupported a_dtype for e_dtype 8bit: {a_dtype}") + elif DataType(e_dtype).bits == 16: + if DataType(a_dtype).bits == 8: + mma_load_layout = metadata_16bit_load_32x2_to_shared_16x4_layout_8bit + elif DataType(a_dtype).bits == 16: + mma_load_layout = metadata_16bit_load_32x2_to_shared_16x2_layout_16bit + elif DataType(a_dtype).bits == 32: + mma_load_layout = metadata_16bit_load_32x2_to_shared_16x2_layout_32bit + else: + raise ValueError(f"Unsupported a_dtype for e_dtype 16bit: {a_dtype}") + elif DataType(e_dtype).bits == 32: + if DataType(a_dtype).bits == 8: + mma_load_layout = metadata_32bit_load_32x1_to_shared_16x2_layout_8bit + else: + raise ValueError(f"Unsupported a_dtype for e_dtype 32bit: {a_dtype}") + else: + raise ValueError(f"Unsupported dtype: {e_dtype}") + + thread_binding = self.get_thread_binding() + + E_region = self._legalize_to_buffer_region(E_shared_buf) + E_buf = E_region.buffer + E_base0 = E_region.region[-2].min + E_base1 = E_region.region[-1].min + E_other = [r.min for r in E_region.region[:-2]] + + @T.macro + def _warp_ldmatrix_e( + E_local_buf, + E_shared_buf, + inst_i, + ki, + thread_binding, + ): + wi = ((warp_m // 4) * num_inst_m + inst_i) * 64 + (warp_m % 4) * 16 + wk = (ki * micro_size_k) // self.e_factor + e_local_base = ki_slot * e_stage_elems + tx, _, _ = self.extract_thread_binding(thread_binding) + for j in T.serial(local_size_e): + mi, mk = mma_load_layout(tx, j) + E_local_buf[e_local_base + inst_i * local_size_e + j] = ( + E_shared_buf[tuple(E_other) + (E_base0 + wk + mk, E_base1 + wi + mi)] + if trans + else E_shared_buf[tuple(E_other) + (E_base0 + wi + mi, E_base1 + wk + mk)] + ) + + return _warp_ldmatrix_e(E_local_buf, E_buf, inst_i, ki, thread_binding) + + def 
make_mma_load_layout(self, local_buf: Buffer, matrix: str = "A") -> T.Fragment: + assert matrix == "A", "WGMMA sparse only supports A matrix load layout" + assert is_fragment(local_buf), f"local_buf must be a fragment, but got {local_buf.scope()}" + + dtype = self.a_dtype + dtype_bits = DataType(dtype).bits + transposed = self.a_transposed + + if dtype_bits == 32: + transform_func_sr_a = shared_16x8_to_mma_32x4_layout_sr_a + elif dtype_bits == 16: + transform_func_sr_a = shared_16x16_to_mma_32x8_layout_sr_a + elif dtype_bits == 8: + transform_func_sr_a = shared_16x32_to_mma_32x16_layout_sr_a + else: + raise ValueError(f"Unsupported dtype {dtype}") + + is_sr_axis_order = not transposed + + transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i) + + inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype=T.int32) + + def forward_thread(i: int, j: int) -> int: + lane_id, _ = inverse_mma_load_layout.map_indices([i, j]) + return lane_id + + def forward_index(i: int, j: int) -> int: + _, local_id = inverse_mma_load_layout.map_indices([i, j]) + return local_id + + micro_size_s = self.micro_size_x + # sparse: each instruction holds micro_size_k/SPARSE_FACTOR actual K elements + micro_size_r = self.micro_size_k // self.SPARSE_FACTOR + + base_fragment = T.Fragment( + [micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s], + forward_thread_fn=forward_thread, + forward_index_fn=forward_index, + ) + + warp_rows = self.warp_rows + # number of instructions in K direction + warp_r = self.warp_k // self.micro_size_k + block_s = self.block_row_warps + replicate = self.block_col_warps + + if is_sr_axis_order: + warp_fragment = base_fragment.repeat([block_s, 1], repeat_on_thread=True, lower_dim_first=False).replicate(replicate) + block_fragment = warp_fragment.repeat([warp_rows, warp_r], repeat_on_thread=False, lower_dim_first=False) + else: + warp_fragment = base_fragment.repeat([1, block_s], repeat_on_thread=True, lower_dim_first=False).replicate(replicate) + block_fragment = warp_fragment.repeat([warp_r, warp_rows], repeat_on_thread=False, lower_dim_first=True) + + return block_fragment + + def make_mma_store_layout(self, local_buf: Buffer) -> T.Fragment: + """ + Create a layout function for storing MMA results into a fragment buffer. + This layout is used in conjunction with `inverse_mma_store_layout` to + map fragment indices to threads and local indices. + + Parameters + ---------- + local_buf : tir.Buffer + The local buffer representing a fragment of a matrix. + + Returns + ------- + T.Fragment + A fragment object that describes how threads and indices + in `local_buf` are laid out. + + Raises + ------ + AssertionError + If `local_buf` is not detected to be a fragment buffer. + """ + inverse_mma_store_layout = self.get_store_index_map(inverse=True) + assert is_fragment(local_buf), "local_buf must be a fragment" + micro_size_x, micro_size_y = self.micro_size_x, self.micro_size_y + block_row_warps, block_col_warps = self.block_row_warps, self.block_col_warps + warp_rows, warp_cols = self.warp_rows, self.warp_cols + + def forward_thread(i: int, j: int) -> int: + """ + Given the row index `i` and column index `j` in the fragment, + map them to a thread index according to `inverse_mma_store_layout`. 
+ """ + lane_id, _ = inverse_mma_store_layout.map_indices([i, j]) + return lane_id + + def forward_index(i: int, j: int) -> int: + """ + Given the row index `i` and column index `j` in the fragment, + map them to a local index in a single thread according + to `inverse_mma_store_layout`. + """ + _, local_id = inverse_mma_store_layout.map_indices([i, j]) + return local_id + + # reproduce src/layout/gemm_layouts.cc::makeGemmFragmentCHopper + base_fragment = T.Fragment( + [micro_size_x, micro_size_y], + forward_thread_fn=forward_thread, + forward_index_fn=forward_index, + ) + warp_n_layout = base_fragment.repeat([1, warp_cols], False, False) + block_layout = warp_n_layout.repeat([block_row_warps, block_col_warps], True, False) + warp_m_layout = block_layout.repeat([warp_rows, 1], False, False) + return warp_m_layout diff --git a/tilelang/cuda/op/gemm_sp/__init__.py b/tilelang/cuda/op/gemm_sp/__init__.py index ed8ade6377..8e5fdd8f53 100644 --- a/tilelang/cuda/op/gemm_sp/__init__.py +++ b/tilelang/cuda/op/gemm_sp/__init__.py @@ -3,8 +3,18 @@ from __future__ import annotations from tilelang.tileop.gemm_sp.registry import register_gemm_sp_impl -from .gemm_sp_mma import GemmSPMMA -from tilelang.utils.target import target_is_cuda +from tilelang.cuda.op.gemm_sp.gemm_sp_mma import GEMM_SP_INST_MMA_SP, GemmSPMMA +from tilelang.cuda.op.gemm_sp.gemm_sp_wgmma import GEMM_SP_INST_WGMMA_SP, GemmSPWGMMA +from tilelang.utils.target import target_is_cuda, target_is_turing, target_is_volta -register_gemm_sp_impl("cuda.GemmSPMMA", target_is_cuda, GemmSPMMA) +def _match_mma(target) -> bool: + return target_is_cuda(target) and not (target_is_volta(target) or target_is_turing(target)) + + +def _match_wgmma(target) -> bool: + return target_is_cuda(target) + + +register_gemm_sp_impl("cuda.mma.sp", GEMM_SP_INST_MMA_SP, _match_mma, GemmSPMMA) +register_gemm_sp_impl("cuda.wgmma.sp", GEMM_SP_INST_WGMMA_SP, _match_wgmma, GemmSPWGMMA) diff --git a/tilelang/cuda/op/gemm_sp/gemm_sp_mma.py b/tilelang/cuda/op/gemm_sp/gemm_sp_mma.py index dc381f7047..a641eba41f 100644 --- a/tilelang/cuda/op/gemm_sp/gemm_sp_mma.py +++ b/tilelang/cuda/op/gemm_sp/gemm_sp_mma.py @@ -9,13 +9,12 @@ from tilelang.transform.simplify import _Simplify -GEMM_SP_INST_MMA = "cuda.mma" +GEMM_SP_INST_MMA_SP = "cuda.mma.sp" class GemmSPMMA(GemmSPBase): def infer_layout(self, target: Target, thread_nums: int): - # NOTE(wt): Actually gemm_sp v2 currently use GemmWarpPolicy - m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, GEMM_SP_INST_MMA) + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, GEMM_SP_INST_MMA_SP) warp_row_tiles = int(self.M // m_warp) warp_col_tiles = int(self.N // n_warp) mma_emitter = SparseTensorCoreIntrinEmitter( @@ -59,9 +58,8 @@ def infer_layout(self, target: Target, thread_nums: int): else: raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") - def lower(self, target: Target, thread_nums: int, thread_var: tir.Var): - # NOTE(wt): Actually gemm_sp v2 currently use GemmWarpPolicy - m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, GEMM_SP_INST_MMA) + def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var): + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, GEMM_SP_INST_MMA_SP) warp_row_tiles = int(self.M // m_warp) warp_col_tiles = int(self.N // n_warp) mma_emitter = SparseTensorCoreIntrinEmitter( diff --git 
a/tilelang/cuda/op/gemm_sp/gemm_sp_wgmma.py b/tilelang/cuda/op/gemm_sp/gemm_sp_wgmma.py new file mode 100644 index 0000000000..54e8968b84 --- /dev/null +++ b/tilelang/cuda/op/gemm_sp/gemm_sp_wgmma.py @@ -0,0 +1,158 @@ +from __future__ import annotations + +from tilelang.tileop.gemm_sp.gemm_sp_base import GemmSPBase +from tilelang.layout import ( + make_full_bank_swizzled_layout, + make_half_bank_swizzled_layout, + make_quarter_bank_swizzled_layout, + make_linear_layout, + Layout, +) +from tilelang.cuda.intrinsics.macro.wgmma_sp_macro_generator import WGSparseTensorCoreIntrinEmitter +from tilelang.utils.language import is_shared, is_fragment +from tvm.target import Target +from tvm.ir import Range +from tvm import tir +from tilelang import language as T +from tilelang.transform.simplify import _Simplify +from typing import Callable + +GEMM_SP_INST_WGMMA_SP = "cuda.wgmma.sp" + + +class GemmSPWGMMA(GemmSPBase): + def infer_shared_layout(self, continuity: int) -> Callable[[tir.Buffer], Layout]: + vectorized_size = 128 // self.in_dtype.bits + if continuity % (vectorized_size * 8) == 0: + return make_full_bank_swizzled_layout + elif continuity % (vectorized_size * 4) == 0: + return make_half_bank_swizzled_layout + elif continuity % (vectorized_size * 2) == 0: + return make_quarter_bank_swizzled_layout + else: + return make_linear_layout + + def infer_layout(self, target: Target, thread_nums: int): + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, GEMM_SP_INST_WGMMA_SP) + warp_row_tiles = int(self.M // m_warp) + warp_col_tiles = int(self.N // n_warp) + mma_emitter = WGSparseTensorCoreIntrinEmitter( + a_dtype=self.in_dtype, + e_dtype=self.e_dtype, + b_dtype=self.in_dtype, + accum_dtype=self.accum_dtype, + a_transposed=self.trans_A, + b_transposed=self.trans_B, + e_transposed=self.trans_E, + block_row_warps=m_warp, + block_col_warps=n_warp, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + warp_k=self.K, + ) + a_is_k_major = not self.trans_A + b_is_k_major = self.trans_B + a_continuity = self.K // mma_emitter.SPARSE_FACTOR if a_is_k_major else mma_emitter.wgmma_inst_m + b_continuity = self.K if b_is_k_major else mma_emitter.wgmma_inst_n + if self.is_gemm_ss(): + return { + # WGMMA does not support padding + self.A: self.infer_shared_layout(a_continuity)(self.A), + self.B: self.infer_shared_layout(b_continuity)(self.B), + self.C: mma_emitter.make_mma_store_layout(self.C), + } + elif self.is_gemm_rs(): + return { + self.A: mma_emitter.make_mma_load_layout(self.A, matrix="A"), + self.B: self.infer_shared_layout(b_continuity)(self.B), + self.C: mma_emitter.make_mma_store_layout(self.C), + } + else: + raise ValueError(f"Unsupported gemm combination for wgmma, A: {self.A.scope()}, B: {self.B.scope()}") + + def lower( + self, + layout_map: dict, + target: Target, + thread_nums: Range, + thread_var: tir.Var, + mbar_phase_expr: tir.PrimExpr | None = None, + ): + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, GEMM_SP_INST_WGMMA_SP) + warp_row_tiles = int(self.M // m_warp) + warp_col_tiles = int(self.N // n_warp) + mma_emitter = WGSparseTensorCoreIntrinEmitter( + a_dtype=self.in_dtype, + e_dtype=self.e_dtype, + b_dtype=self.in_dtype, + accum_dtype=self.accum_dtype, + a_transposed=self.trans_A, + b_transposed=self.trans_B, + e_transposed=self.trans_E, + block_row_warps=m_warp, + block_col_warps=n_warp, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + warp_k=self.K, + thread_var=thread_var, + 
) + + if self.A in layout_map: + mma_emitter._assign_a_shared_layout(layout_map[self.A]) + if self.B in layout_map: + mma_emitter._assign_b_shared_layout(layout_map[self.B]) + + # Get base offsets from regions + # All dimensions may have offsets, including the matrix dimensions + # However, for WGMMA, we pass the Buffer directly and handle offsets + # through proper indexing in the access_ptr call or buffer slicing + + # We use region for memory input to support strided gemm + # T.gemm(A_shared[0:128, :], B_shared, C_local) + A_region = self.ARegion + B_region = self.BRegion + C_region = self.CRegion + E_region = self.ERegion + + clear_accum = self.clear_accum + wg_wait = self.wg_wait + + if self.is_gemm_ss(): + # For WGMMA, we need to handle buffer region offsets + # If there are offsets, we create a BufferLoad inside the prim_func + # to properly generate offset access + + @T.prim_func + def _gemm_ssr() -> None: + """ + The inner macro that loads data from shared buffers A_shared and + B_shared into local fragments, then issues Tensor Core mma ops, + accumulating into C_local. + """ + + # Perform Matrix Multiplication with offset consideration + mma_emitter.wgmma_ss(A_region, E_region, B_region, C_region, clear_accum, wg_wait) + + # Simplify to optimize the index computing + # Must inline let statements to simplify the analysis + return _Simplify(_gemm_ssr, inline_let=True) + elif self.is_gemm_rs(): + + @T.prim_func + def _gemm_rsr() -> None: + mma_emitter.wgmma_rs(A_region, E_region, B_region, C_region, clear_accum, wg_wait) + + return _Simplify(_gemm_rsr, inline_let=True) + raise ValueError(f"Unsupported gemm combination for wgmma, A: {self.A.scope()}, B: {self.B.scope()}") + + def is_gemm_ss(self) -> bool: + return is_shared(self.A) and is_shared(self.B) + + def is_gemm_sr(self) -> bool: + return is_shared(self.A) and is_fragment(self.B) + + def is_gemm_rs(self) -> bool: + return is_fragment(self.A) and is_shared(self.B) + + def is_gemm_rr(self) -> bool: + return is_fragment(self.A) and is_fragment(self.B) diff --git a/tilelang/ir.py b/tilelang/ir.py index d3005e6235..d6bcea2f80 100644 --- a/tilelang/ir.py +++ b/tilelang/ir.py @@ -39,15 +39,11 @@ class GemmSPWarpPolicy(Node, Scriptable): m_warp: int n_warp: int - def compute_warp_partition(self, M: int, N: int, block_size: int, target: Target, gemm_inst: str, bits: int): - _ffi_api.GemmSPWarpPolicyComputeWarpPartition(self, int(M), int(N), int(block_size), target, gemm_inst, bits) + def compute_warp_partition(self, M: int, N: int, block_size: int, target: Target, gemm_inst: str): + _ffi_api.GemmSPWarpPolicyComputeWarpPartition(self, int(M), int(N), int(block_size), target, gemm_inst) return self.m_warp, self.n_warp -@tvm_ffi.register_object("tl.GemmSP") -class GemmSP(Node, Scriptable): ... - - @tvm_ffi.register_object("tl.FinalizeReducerOp") class FinalizeReducerOp(Node, Scriptable): ... 
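A minimal usage sketch of the updated `GemmSPWarpPolicy.compute_warp_partition` signature above: the trailing `bits` argument is gone, and the instruction is now selected by its registry key string. The helper name, tile sizes, thread count, and target below are assumptions for illustration, not part of the patch:

    from tvm.target import Target

    def pick_warp_partition(policy, M=128, N=128, block_size=128):
        # `policy` is a GemmSPWarpPolicy taken from a gemm_sp node; the last
        # argument is an instruction key such as "cuda.mma.sp" or
        # "cuda.wgmma.sp" rather than a bit width.
        m_warp, n_warp = policy.compute_warp_partition(M, N, block_size, Target("cuda"), "cuda.mma.sp")
        return m_warp, n_warp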
diff --git a/tilelang/language/__init__.py b/tilelang/language/__init__.py index 43c70563a2..6c5042e531 100644 --- a/tilelang/language/__init__.py +++ b/tilelang/language/__init__.py @@ -64,7 +64,11 @@ tcgen05_gemm_blockscaled, make_blockscaled_gemm_layout, ) -from .experimental.gemm_sp import gemm_sp, gemm_sp_v2 # noqa: F401 +from .experimental.gemm_sp_op import ( # noqa: F401 + gemm_sp, + wgmma_gemm_sp, + tcgen05_gemm_sp, +) from .fill_op import fill, clear # noqa: F401 from .reduce_op import ( reduce, # noqa: F401 diff --git a/tilelang/language/ast/ir.py b/tilelang/language/ast/ir.py index 7865b74be7..671538cca3 100644 --- a/tilelang/language/ast/ir.py +++ b/tilelang/language/ast/ir.py @@ -1884,6 +1884,8 @@ def wrapped(*args, **kwargs): ptx_mma_sp = _dtype_forward(_tir_op.ptx_mma_sp) ptx_wgmma_ss = _dtype_forward(_tir_op.ptx_wgmma_ss) ptx_wgmma_rs = _dtype_forward(_tir_op.ptx_wgmma_rs) +ptx_wgmma_sp_ss = _dtype_forward(_tir_op.ptx_wgmma_sp_ss) +ptx_wgmma_sp_rs = _dtype_forward(_tir_op.ptx_wgmma_sp_rs) ptx_tcgen05_mma_ss = _dtype_forward(_tir_op.ptx_tcgen05_mma_ss) ptx_tcgen05_mma_ts = _dtype_forward(_tir_op.ptx_tcgen05_mma_ts) ptx_tcgen05_mma_blockscaled_ss = _dtype_forward(_tir_op.ptx_tcgen05_mma_blockscaled_ss) @@ -2138,6 +2140,8 @@ def wrapped(*args, **kwargs): "ptx_mma_sp", "ptx_wgmma_ss", "ptx_wgmma_rs", + "ptx_wgmma_sp_ss", + "ptx_wgmma_sp_rs", "ptx_tcgen05_mma_ss", "ptx_tcgen05_mma_blockscaled_ss", "ptx_ldmatrix", diff --git a/tilelang/language/experimental/gemm_sp.py b/tilelang/language/experimental/gemm_sp.py deleted file mode 100644 index fa722b6890..0000000000 --- a/tilelang/language/experimental/gemm_sp.py +++ /dev/null @@ -1,224 +0,0 @@ -"""The language interface for tl programs.""" - -from __future__ import annotations -from tilelang.tileop.base import GemmWarpPolicy -import tilelang.language as T -from tvm import tir -from tilelang.utils.language import ( - to_buffer_region, - retrieve_shape, - retrieve_stride, - retrieve_offset, - prim_expr_equal, -) -from tilelang.language.utils import ( - buffer_region_to_tile_region, -) -from tilelang._typing import BufferLikeType - - -def gemm_sp( - A_sparse: BufferLikeType | tir.Var, - E: BufferLikeType | tir.Var, - B: BufferLikeType | tir.Var, - C: BufferLikeType | tir.Var, - transpose_A: bool = False, - transpose_B: bool = False, - policy: GemmWarpPolicy = GemmWarpPolicy.Square, - clear_accum: bool = False, - k_pack: int = 1, - wg_wait: int = 0, -): - """Perform a Sparse General Matrix Multiplication (GEMM-sp) operation. - - This function computes C = A @ B where A and B can optionally be transposed. - The operation supports various warp policies and accumulation modes. - - Args: - A_sparse (Union[BufferLikeType, tir.Var]): First input matrix dense values - E (Union[BufferLikeType, tir.Var]): First input matrix sparse metadata - B (Union[BufferLikeType, tir.Var]): Second input matrix - C (Union[BufferLikeType, tir.Var]): Output matrix for results - transpose_A (bool, optional): Whether to transpose matrix A. Defaults to False. - transpose_B (bool, optional): Whether to transpose matrix B. Defaults to False. - policy (GemmWarpPolicy, optional): Warp execution policy. Defaults to GemmWarpPolicy.Square. - clear_accum (bool, optional): Whether to clear accumulator before computation. Defaults to False. - k_pack (int, optional): Number of k dimensions packed into a single warp. Defaults to 1. - wg_wait (int, optional): Warp group wait count. Defaults to 0. 
- - Returns: - tir.Call: A handle to the GEMM operation - - Raises: - AssertionError: If the K dimensions of matrices A and B don't match - """ - - def legalize_arguments(arg: BufferLikeType | tir.Var): - """Convert let-bound variables to their corresponding buffers. - - Args: - arg (Union[BufferLikeType, tir.Var]): Input argument to legalize - - Returns: - Union[BufferLikeType, tir.Var]: The legalized argument - """ - if isinstance(arg, tir.Var) and T.has_let_value(arg): - return T.get_let_value(arg).buffer - return arg - - A_sparse = legalize_arguments(A_sparse) - B = legalize_arguments(B) - C = legalize_arguments(C) - M = C.shape[0] - N = C.shape[1] - K_A = A_sparse.shape[0] if transpose_A else A_sparse.shape[1] - K_B = B.shape[1] if transpose_B else B.shape[0] - assert K_A * 2 == K_B, f"T.gemm_sp K shape check failed: K_A = {K_A}, K_B = {K_B}" - # Build tl.region descriptors for operands - A_arg = to_buffer_region(A_sparse, access_type="r") - E_arg = to_buffer_region(E, access_type="r") - B_arg = to_buffer_region(B, access_type="r") - C_arg = to_buffer_region(C, access_type="rw") - return tir.call_intrin( - "handle", - tir.op.Op.get("tl.tileop.gemm_sp"), - A_arg, - E_arg, - B_arg, - C_arg, - transpose_A, - transpose_B, - M, - N, - K_B, - policy, - clear_accum, - k_pack, - wg_wait, - ) - - -# experimental currently, for fast compilation -def gemm_sp_v2( - A_sparse: BufferLikeType | tir.Var, - E: BufferLikeType | tir.Var, - B: BufferLikeType | tir.Var, - C: BufferLikeType | tir.Var, - transpose_A: bool = False, - transpose_B: bool = False, - transpose_E: bool = False, - policy: GemmWarpPolicy = GemmWarpPolicy.Square, - clear_accum: bool = False, - k_pack: int = 1, - wg_wait: int = 0, -): - """Perform a General Matrix Multiplication (GEMM) operation. - - This function computes C = A @ B where A and B can optionally be transposed. - The operation supports various warp policies and accumulation modes. - - Args: - A_sparse (Union[BufferLikeType, tir.Var]): First input matrix, contains only non-zero elements - E (Union[BufferLikeType, tir.Var]): The metadata of A_sparse, noted as E - B (Union[BufferLikeType, tir.Var]): Second input matrix - C (Union[BufferLikeType, tir.Var]): Output matrix for results - transpose_A (bool, optional): Whether to transpose matrix A. Defaults to False. - transpose_B (bool, optional): Whether to transpose matrix B. Defaults to False. - policy (GemmWarpPolicy, optional): Warp execution policy. Defaults to GemmWarpPolicy.Square. - clear_accum (bool, optional): Whether to clear accumulator before computation. Defaults to False. - k_pack (int, optional): Number of k dimensions packed into a single warp. Defaults to 1. - wg_wait (int, optional): Warp group wait count. Defaults to 0. - - Returns: - tir.Call: A handle to the GEMM operation - - Raises: - AssertionError: If the K dimensions of matrices A and B don't match - """ - - def legalize_arguments(arg: BufferLikeType | tir.Var) -> BufferLikeType: - """Convert let-bound variables to their corresponding buffers. 
- - Args: - arg (Union[BufferLikeType, tir.Var]): Input argument to legalize - - Returns: - Union[BufferLikeType, tir.Var]: The legalized argument - """ - if isinstance(arg, tir.Var) and T.has_let_value(arg): - return T.get_let_value(arg).buffer - return arg - - A_sparse = legalize_arguments(A_sparse) - E = legalize_arguments(E) - B = legalize_arguments(B) - C = legalize_arguments(C) - - A_region = to_buffer_region(A_sparse) - E_region = to_buffer_region(E) - B_region = to_buffer_region(B) - C_region = to_buffer_region(C) - - A_shape = retrieve_shape(A_sparse) - E_shape = retrieve_shape(E) # nolint: F841 - B_shape = retrieve_shape(B) - C_shape = retrieve_shape(C) - - A_stride = retrieve_stride(A_sparse) - B_stride = retrieve_stride(B) - - assert len(C_shape) == 2, "current only support C as a 2D tensor" - assert len(A_shape) >= 2, "current only support A as a 2D or higher-order tensor" - assert len(B_shape) >= 2, "current only support B as a 2D or higher-order tensor" - if len(A_shape) > 2: - for i in range(len(A_shape) - 2): - assert A_shape[i] == 1, ( - "current only support A as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions" - ) - if len(B_shape) > 2: - for i in range(len(B_shape) - 2): - assert B_shape[i] == 1, ( - "current only support B as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions" - ) - - M, N = C_shape - K = 2 * (A_shape[-2] if transpose_A else A_shape[-1]) - K_B = B_shape[-1] if transpose_B else B_shape[-2] - assert prim_expr_equal(K, K_B), f"T.gemm_sp K shape check failed: K_A (wo sparse) = {K}, K_B = {K_B}" - - stride_a = A_stride[-2] - stride_b = B_stride[-2] - - A_offset = retrieve_offset(A_sparse) - B_offset = retrieve_offset(B) - assert A_offset[-2] == 0, "The offset of the first dimension of A must be 0" - assert B_offset[-2] == 0, "The offset of the first dimension of B must be 0" - offset_a = A_offset[-1] - offset_b = B_offset[-1] - - A_arg = buffer_region_to_tile_region(A_region, "r", [r for r in A_shape]) - E_arg = buffer_region_to_tile_region(E_region, "r", [r for r in E_shape]) - B_arg = buffer_region_to_tile_region(B_region, "r", [r for r in B_shape]) - C_arg = buffer_region_to_tile_region(C_region, "rw", [r for r in C_shape]) - return tir.call_intrin( - "handle", - tir.op.Op.get("tl.tileop.gemm_sp_py"), - A_arg, - E_arg, - B_arg, - C_arg, - transpose_A, - transpose_B, - transpose_E, - M, - N, - K, - policy, - clear_accum, - stride_a, - stride_b, - offset_a, - offset_b, - k_pack, - wg_wait, - ) diff --git a/tilelang/language/experimental/gemm_sp_op.py b/tilelang/language/experimental/gemm_sp_op.py new file mode 100644 index 0000000000..e95d4bdead --- /dev/null +++ b/tilelang/language/experimental/gemm_sp_op.py @@ -0,0 +1,274 @@ +"""Sparse GEMM operators exposed on the TileLang language surface.""" + +from __future__ import annotations +from tilelang.tileop.base import GemmWarpPolicy +import tilelang.language as T +from tvm import tir +from tilelang.utils.language import ( + to_buffer_region, + retrieve_shape, + retrieve_stride, + retrieve_offset, + prim_expr_equal, +) +from tilelang.language.utils import ( + buffer_region_to_tile_region, +) +from tilelang._typing import BufferLikeType + + +def _gemm_sp_impl( + op_key: str, + A_sparse: BufferLikeType | tir.Var, + E: BufferLikeType | tir.Var, + B: BufferLikeType | tir.Var, + C: BufferLikeType | tir.Var, + transpose_A: bool = False, + transpose_E: bool = False, + transpose_B: bool = False, + policy: GemmWarpPolicy = GemmWarpPolicy.Square, 
+ clear_accum: bool = False, + k_pack: int = 1, + wg_wait: int = 0, +) -> tir.Call: + """Shared sparse GEMM implementation. + + Returns a call_intrin handle for the given op key. + """ + + def legalize_arguments(arg: BufferLikeType | tir.Var) -> BufferLikeType: + if isinstance(arg, tir.Var) and T.has_let_value(arg): + return T.get_let_value(arg).buffer + return arg + + A_sparse = legalize_arguments(A_sparse) + E = legalize_arguments(E) + B = legalize_arguments(B) + C = legalize_arguments(C) + + A_region = to_buffer_region(A_sparse) + E_region = to_buffer_region(E) + B_region = to_buffer_region(B) + C_region = to_buffer_region(C) + + A_shape = retrieve_shape(A_sparse) + E_shape = retrieve_shape(E) + B_shape = retrieve_shape(B) + C_shape = retrieve_shape(C) + + A_stride = retrieve_stride(A_sparse) + B_stride = retrieve_stride(B) + + assert len(C_shape) == 2, "current only support C as a 2D tensor" + assert len(A_shape) >= 2, "current only support A as a 2D or higher-order tensor" + assert len(B_shape) >= 2, "current only support B as a 2D or higher-order tensor" + if len(A_shape) > 2: + for i in range(len(A_shape) - 2): + assert A_shape[i] == 1, ( + "current only support A as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions" + ) + if len(B_shape) > 2: + for i in range(len(B_shape) - 2): + assert B_shape[i] == 1, ( + "current only support B as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions" + ) + + M, N = C_shape + K = 2 * (A_shape[-2] if transpose_A else A_shape[-1]) + K_B = B_shape[-1] if transpose_B else B_shape[-2] + assert prim_expr_equal(K, K_B), f"T.gemm_sp K shape check failed: K_A (wo sparse) = {K}, K_B = {K_B}" + + stride_a = A_stride[-2] + stride_b = B_stride[-2] + + A_offset = retrieve_offset(A_sparse) + B_offset = retrieve_offset(B) + assert A_offset[-2] == 0, "The offset of the first dimension of A must be 0" + assert B_offset[-2] == 0, "The offset of the first dimension of B must be 0" + offset_a = A_offset[-1] + offset_b = B_offset[-1] + + A_arg = buffer_region_to_tile_region(A_region, "r", [r for r in A_shape]) + E_arg = buffer_region_to_tile_region(E_region, "r", [r for r in E_shape]) + B_arg = buffer_region_to_tile_region(B_region, "r", [r for r in B_shape]) + C_arg = buffer_region_to_tile_region(C_region, "rw", [r for r in C_shape]) + return tir.call_intrin( + "handle", + tir.op.Op.get(op_key), + A_arg, + E_arg, + B_arg, + C_arg, + transpose_A, + transpose_E, + transpose_B, + M, + N, + K, + policy, + clear_accum, + stride_a, + stride_b, + offset_a, + offset_b, + k_pack, + wg_wait, + ) + + +def gemm_sp( + A_sparse: BufferLikeType | tir.Var, + E: BufferLikeType | tir.Var, + B: BufferLikeType | tir.Var, + C: BufferLikeType | tir.Var, + transpose_A: bool = False, + transpose_E: bool = False, + transpose_B: bool = False, + policy: GemmWarpPolicy = GemmWarpPolicy.Square, + clear_accum: bool = False, + k_pack: int = 1, + wg_wait: int = 0, +) -> tir.Call: + """TileLang sparse GEMM operator. + + This is the default synchronous sparse GEMM interface. On Hopper, if the + compiler selects WGMMA SP lowering, TileLang inserts the corresponding wait + implicitly. + + For manual asynchronous scheduling, use ``T.wgmma_gemm_sp(...)`` with + ``T.wait_wgmma(...)`` on Hopper, or ``T.tcgen05_gemm_sp(...)`` on Blackwell. + + Args: + A_sparse: Compressed sparse matrix containing only non-zero elements. + E: Metadata tensor encoding the sparsity pattern of A. + B: Dense input matrix. + C: Output accumulator matrix. 
+ transpose_A: Whether to transpose A. Defaults to False. + transpose_E: Whether to transpose E. Defaults to False. + transpose_B: Whether to transpose B. Defaults to False. + policy: Warp partition policy. Defaults to GemmSPWarpPolicy.Square. + clear_accum: Whether to zero the accumulator before computation. Defaults to False. + k_pack: Number of K dimensions packed per warp. Defaults to 1. + wg_wait: Warp group wait count. Defaults to 0. + + Returns: + tir.Call: A handle to the sparse GEMM operation. + """ + return _gemm_sp_impl( + "tl.tileop.gemm_sp", + A_sparse, + E, + B, + C, + transpose_A, + transpose_E, + transpose_B, + policy, + clear_accum, + k_pack, + wg_wait, + ) + + +def wgmma_gemm_sp( + A_sparse: BufferLikeType | tir.Var, + E: BufferLikeType | tir.Var, + B: BufferLikeType | tir.Var, + C: BufferLikeType | tir.Var, + transpose_A: bool = False, + transpose_E: bool = False, + transpose_B: bool = False, + policy: GemmWarpPolicy = GemmWarpPolicy.Square, + clear_accum: bool = False, +) -> tir.Call: + """Explicit Hopper WGMMA sparse GEMM without an implicit wait. + + This is the explicit asynchronous Hopper WGMMA counterpart to the default + synchronous ``T.gemm_sp(...)`` interface, with two stricter guarantees: + - it always requests the WGMMA SP lowering path + - it never auto-emits an inlined ``warpgroup_wait`` + + If the current target or operand pattern cannot use Hopper WGMMA SP, + compilation fails instead of silently falling back to MMA SP. + + Args: + A_sparse: Compressed sparse matrix containing only non-zero elements. + E: Metadata tensor encoding the sparsity pattern of A. + B: Dense input matrix. + C: Output accumulator matrix. + transpose_A: Whether to transpose A. Defaults to False. + transpose_E: Whether to transpose E. Defaults to False. + transpose_B: Whether to transpose B. Defaults to False. + policy: Warp partition policy. Defaults to GemmSPWarpPolicy.Square. + clear_accum: Whether to zero the accumulator before computation. Defaults to False. + + Returns: + tir.Call: A handle to the sparse GEMM operation. + """ + return _gemm_sp_impl( + "tl.tileop.wgmma_gemm_sp", + A_sparse, + E, + B, + C, + transpose_A, + transpose_E, + transpose_B, + policy, + clear_accum, + 1, + -1, + ) + + +def tcgen05_gemm_sp( + A_sparse: BufferLikeType | tir.Var, + E: BufferLikeType | tir.Var, + B: BufferLikeType | tir.Var, + C: BufferLikeType | tir.Var, + transpose_A: bool = False, + transpose_E: bool = False, + transpose_B: bool = False, + policy: GemmWarpPolicy = GemmWarpPolicy.Square, + clear_accum: bool = False, +) -> tir.Call: + """Explicit Blackwell TCGEN05 sparse GEMM without an implicit wait. + + This is the explicit asynchronous Blackwell TCGEN05 counterpart to the + default synchronous ``T.gemm_sp(...)`` interface, with two stricter + guarantees: + - it always requests the TCGEN05 SP lowering path + - it never auto-emits an inlined ``mbarrier_wait_parity`` + + If the current target or operand pattern cannot use Blackwell TCGEN05 SP, + compilation fails instead of silently falling back to another sparse GEMM + path. + + Args: + A_sparse: Compressed sparse matrix containing only non-zero elements. + E: Metadata tensor encoding the sparsity pattern of A. + B: Dense input matrix. + C: Output accumulator matrix. + transpose_A: Whether to transpose A. Defaults to False. + transpose_E: Whether to transpose E. Defaults to False. + transpose_B: Whether to transpose B. Defaults to False. + policy: Warp partition policy. Defaults to GemmSPWarpPolicy.Square. 
+ clear_accum: Whether to zero the accumulator before computation. Defaults to False. + + Returns: + tir.Call: A handle to the sparse GEMM operation. + """ + return _gemm_sp_impl( + "tl.tileop.tcgen05_gemm_sp", + A_sparse, + E, + B, + C, + transpose_A, + transpose_E, + transpose_B, + policy, + clear_accum, + 1, + 0, + ) diff --git a/tilelang/language/tir/ir.py b/tilelang/language/tir/ir.py index 384be21ccc..48b50b4744 100644 --- a/tilelang/language/tir/ir.py +++ b/tilelang/language/tir/ir.py @@ -290,6 +290,8 @@ def wrapped(*args, **kwargs): ptx_mma_sp = _dtype_forward(_tir_op.ptx_mma_sp) ptx_wgmma_ss = _dtype_forward(_tir_op.ptx_wgmma_ss) ptx_wgmma_rs = _dtype_forward(_tir_op.ptx_wgmma_rs) +ptx_wgmma_sp_ss = _dtype_forward(_tir_op.ptx_wgmma_sp_ss) +ptx_wgmma_sp_rs = _dtype_forward(_tir_op.ptx_wgmma_sp_rs) ptx_tcgen05_mma_ss = _dtype_forward(_tir_op.ptx_tcgen05_mma_ss) ptx_tcgen05_mma_ts = _dtype_forward(_tir_op.ptx_tcgen05_mma_ts) ptx_tcgen05_mma_blockscaled_ss = _dtype_forward(_tir_op.ptx_tcgen05_mma_blockscaled_ss) diff --git a/tilelang/language/tir/op.py b/tilelang/language/tir/op.py index 0aee38da35..4830d48a03 100644 --- a/tilelang/language/tir/op.py +++ b/tilelang/language/tir/op.py @@ -1140,6 +1140,94 @@ def ptx_wgmma_rs( ) +def ptx_wgmma_sp_ss( + dtype, + wgmma_prefix, + a_is_k_major, + b_is_k_major, + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_desc, + A_offset, + E_data, + E_offset, + sparse_selector, + B_desc, + B_offset, + C_data, + C_offset, + scale_out, + scale_in_a, + scale_in_b, +): + return call_intrin( + dtype, + _tvm_op.Op.get("tl.ptx_wgmma_sp_ss"), + wgmma_prefix, + a_is_k_major, + b_is_k_major, + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_desc, + A_offset, + E_data, + E_offset, + sparse_selector, + B_desc, + B_offset, + C_data, + C_offset, + scale_out, + scale_in_a, + scale_in_b, + ) + + +def ptx_wgmma_sp_rs( + dtype, + wgmma_prefix, + b_is_k_major, + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_buf, + A_offset, + E_buf, + E_offset, + sparse_selector, + B_desc, + B_offset, + C_data, + C_offset, + scale_out, + scale_in_a, + scale_in_b, +): + return call_intrin( + dtype, + _tvm_op.Op.get("tl.ptx_wgmma_sp_rs"), + wgmma_prefix, + b_is_k_major, + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_buf, + A_offset, + E_buf, + E_offset, + sparse_selector, + B_desc, + B_offset, + C_data, + C_offset, + scale_out, + scale_in_a, + scale_in_b, + ) + + def ptx_tcgen05_mma_ss( kind_dtype, desc_a, diff --git a/tilelang/tileop/__init__.py b/tilelang/tileop/__init__.py index 849ba541e1..3a6ccd6cdf 100644 --- a/tilelang/tileop/__init__.py +++ b/tilelang/tileop/__init__.py @@ -1,3 +1,3 @@ from .base import GemmWarpPolicy # noqa: F401 from .gemm import Gemm # noqa: F401 -from .gemm_sp import GemmSPPy # noqa: F401 +from .gemm_sp import GemmSP # noqa: F401 diff --git a/tilelang/tileop/gemm_sp/__init__.py b/tilelang/tileop/gemm_sp/__init__.py index 1a49b86ec3..a8d6185f18 100644 --- a/tilelang/tileop/gemm_sp/__init__.py +++ b/tilelang/tileop/gemm_sp/__init__.py @@ -1,38 +1,26 @@ from tilelang import tvm as tvm +from tilelang.tileop.gemm_sp.registry import resolve_gemm_sp_impl from tvm import tir from tvm.target import Target from tvm.ir.base import Node from tvm.ir import Range from tvm.runtime import Scriptable import tvm_ffi -from .registry import resolve_gemm_sp_impl -from tilelang.tileop.base import GemmWarpPolicy +from tilelang import _ffi_api +from tilelang.ir import GemmSPWarpPolicy -@tvm_ffi.register_global_func("tl.gemm_sp_py.infer_layout") 
-def gemm_sp_py_infer_layout(gemm_sp_py, target: Target, thread_bounds: Range): - thread_nums = thread_bounds.extent - return gemm_sp_py.infer_layout(target, thread_nums) - - -@tvm_ffi.register_global_func("tl.gemm_sp_py.lower") -def gemm_sp_py_lower(gemm_sp_py, target: Target, thread_bounds: Range, thread_var: tir.Var): - thread_nums = thread_bounds.extent - stmt = gemm_sp_py.lower(target, thread_nums, thread_var) - return stmt - - -@tvm_ffi.register_object("tl.GemmSPPy") -class GemmSPPy(Node, Scriptable): +@tvm_ffi.register_object("tl.GemmSP") +class GemmSP(Node, Scriptable): A: tir.Buffer E: tir.Buffer B: tir.Buffer C: tir.Buffer - APtr: tir.PrimExpr - EPtr: tir.PrimExpr - BPtr: tir.PrimExpr - CPtr: tir.PrimExpr + aRegion: tir.BufferRegion + eRegion: tir.BufferRegion + bRegion: tir.BufferRegion + cRegion: tir.BufferRegion M: int N: int @@ -40,20 +28,40 @@ class GemmSPPy(Node, Scriptable): trans_A: bool trans_B: bool + trans_E: bool stride_A: int stride_B: int offset_A: int offset_B: int clear_accum: bool - k_pack: int + kPack: int wg_wait: int - policy: GemmWarpPolicy + policy: GemmSPWarpPolicy + + @tvm_ffi.register_global_func("tl.gemm_sp.infer_layout") + def gemm_sp_infer_layout(self, target: Target, thread_bounds: Range): + thread_nums = thread_bounds.extent + return self.infer_layout(target, thread_nums) + + @tvm_ffi.register_global_func("tl.gemm_sp.lower") + def gemm_sp_lower(self, target: Target, layout_map: dict, thread_bounds: Range, thread_var: tir.Var): + thread_nums = thread_bounds.extent + stmt = self.lower(target, layout_map, thread_nums, thread_var) + return stmt def infer_layout(self, target: Target, thread_nums: int): - impl_class = resolve_gemm_sp_impl(target) + gemm_inst = self._select_gemm_instruction(thread_nums, target) + impl_class = self._get_implementation_class(gemm_inst, target) return impl_class(self).infer_layout(target, thread_nums) - def lower(self, target: Target, thread_nums: int, thread_var: tir.Var): - impl_class = resolve_gemm_sp_impl(target) - return impl_class(self).lower(target, thread_nums, thread_var) + def lower(self, target: Target, layout_map: dict, thread_nums: int, thread_var: tir.Var): + gemm_inst = self._select_gemm_instruction(thread_nums, target) + impl_class = self._get_implementation_class(gemm_inst, target) + return impl_class(self).lower(layout_map, target, thread_nums, thread_var) + + def _select_gemm_instruction(self, thread_nums: int, target: Target) -> str: + return str(_ffi_api.GemmSPGetGemmInstructionKey(self, int(thread_nums), target)) + + def _get_implementation_class(self, gemm_inst: str, target: Target): + return resolve_gemm_sp_impl(gemm_inst, target) diff --git a/tilelang/tileop/gemm_sp/gemm_sp_base.py b/tilelang/tileop/gemm_sp/gemm_sp_base.py index 3e6ae7c8fc..81e05162b8 100644 --- a/tilelang/tileop/gemm_sp/gemm_sp_base.py +++ b/tilelang/tileop/gemm_sp/gemm_sp_base.py @@ -3,7 +3,7 @@ from tvm.target import Target from tvm import tir from tilelang.utils.language import is_shared, is_fragment -from tilelang.tileop.base import GemmWarpPolicy +from tilelang.ir import GemmSPWarpPolicy from tvm.ir.base import Node @@ -127,5 +127,5 @@ def wg_wait(self) -> int: return self.gemm_sp_node.wg_wait @property - def policy(self) -> GemmWarpPolicy: + def policy(self) -> GemmSPWarpPolicy: return self.gemm_sp_node.policy diff --git a/tilelang/tileop/gemm_sp/gemm_sp_wgmma.py b/tilelang/tileop/gemm_sp/gemm_sp_wgmma.py new file mode 100644 index 0000000000..676164c900 --- /dev/null +++ b/tilelang/tileop/gemm_sp/gemm_sp_wgmma.py @@ -0,0 +1,157 
@@ +from __future__ import annotations + +from tilelang.tileop.gemm_sp.gemm_sp_base import GemmSPBase +from tilelang.tileop.gemm.inst import GemmInst +from tilelang.layout import ( + make_full_bank_swizzled_layout, + make_half_bank_swizzled_layout, + make_quarter_bank_swizzled_layout, + make_linear_layout, + Layout, +) +from tilelang.cuda.intrinsics.macro.wgmma_sp_macro_generator import WGSparseTensorCoreIntrinEmitter +from tilelang.utils.language import is_shared, is_fragment +from tvm.target import Target +from tvm.ir import Range +from tvm import tir +from tilelang import language as T +from tilelang.transform.simplify import _Simplify +from typing import Callable + + +class GemmSPWGMMA(GemmSPBase): + def infer_shared_layout(self, continuity: int) -> Callable[[tir.Buffer], Layout]: + vectorized_size = 128 // self.in_dtype.bits + if continuity % (vectorized_size * 8) == 0: + return make_full_bank_swizzled_layout + elif continuity % (vectorized_size * 4) == 0: + return make_half_bank_swizzled_layout + elif continuity % (vectorized_size * 2) == 0: + return make_quarter_bank_swizzled_layout + else: + return make_linear_layout + + def infer_layout(self, target: Target, thread_nums: int): + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, GemmInst.WGMMA) + warp_row_tiles = int(self.M // m_warp) + warp_col_tiles = int(self.N // n_warp) + mma_emitter = WGSparseTensorCoreIntrinEmitter( + a_dtype=self.in_dtype, + e_dtype=self.e_dtype, + b_dtype=self.in_dtype, + accum_dtype=self.accum_dtype, + a_transposed=self.trans_A, + b_transposed=self.trans_B, + e_transposed=self.trans_E, + block_row_warps=m_warp, + block_col_warps=n_warp, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + warp_k=self.K, + ) + a_is_k_major = not self.trans_A + b_is_k_major = self.trans_B + a_continuity = self.K // mma_emitter.SPARSE_FACTOR if a_is_k_major else mma_emitter.wgmma_inst_m + b_continuity = self.K if b_is_k_major else mma_emitter.wgmma_inst_n + if self.is_gemm_ss(): + return { + # WGMMA does not support padding + self.A: self.infer_shared_layout(a_continuity)(self.A), + self.B: self.infer_shared_layout(b_continuity)(self.B), + self.C: mma_emitter.make_mma_store_layout(self.C), + } + elif self.is_gemm_rs(): + return { + self.A: mma_emitter.make_mma_load_layout(self.A, matrix="A"), + self.B: self.infer_shared_layout(b_continuity)(self.B), + self.C: mma_emitter.make_mma_store_layout(self.C), + } + else: + raise ValueError(f"Unsupported gemm combination for wgmma, A: {self.A.scope()}, B: {self.B.scope()}") + + def lower( + self, + layout_map: dict, + target: Target, + thread_nums: Range, + thread_var: tir.Var, + mbar_phase_expr: tir.PrimExpr | None = None, + ): + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, GemmInst.WGMMA) + warp_row_tiles = int(self.M // m_warp) + warp_col_tiles = int(self.N // n_warp) + mma_emitter = WGSparseTensorCoreIntrinEmitter( + a_dtype=self.in_dtype, + e_dtype=self.e_dtype, + b_dtype=self.in_dtype, + accum_dtype=self.accum_dtype, + a_transposed=self.trans_A, + b_transposed=self.trans_B, + e_transposed=self.trans_E, + block_row_warps=m_warp, + block_col_warps=n_warp, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + warp_k=self.K, + thread_var=thread_var, + ) + + if self.A in layout_map: + mma_emitter._assign_a_shared_layout(layout_map[self.A]) + if self.B in layout_map: + mma_emitter._assign_b_shared_layout(layout_map[self.B]) + + # Get base offsets from regions + # All 
dimensions may have offsets, including the matrix dimensions + # However, for WGMMA, we pass the Buffer directly and handle offsets + # through proper indexing in the access_ptr call or buffer slicing + + # We use region for memory input to support strided gemm + # T.gemm(A_shared[0:128, :], B_shared, C_local) + A_region = self.ARegion + B_region = self.BRegion + C_region = self.CRegion + E_region = self.ERegion + + clear_accum = self.clear_accum + wg_wait = self.wg_wait + + if self.is_gemm_ss(): + # For WGMMA, we need to handle buffer region offsets + # If there are offsets, we create a BufferLoad inside the prim_func + # to properly generate offset access + + @T.prim_func + def _gemm_ssr() -> None: + """ + The inner macro that loads data from shared buffers A_shared and + B_shared into local fragments, then issues Tensor Core mma ops, + accumulating into C_local. + """ + + # Perform Matrix Multiplication with offset consideration + mma_emitter.wgmma_ss(A_region, E_region, B_region, C_region, clear_accum, wg_wait) + + # Simplify to optimize the index computing + # Must inline let statements to simplify the analysis + return _Simplify(_gemm_ssr, inline_let=True) + elif self.is_gemm_rs(): + + @T.prim_func + def _gemm_rsr() -> None: + mma_emitter.wgmma_rs(A_region, E_region, B_region, C_region, clear_accum, wg_wait) + + return _Simplify(_gemm_rsr, inline_let=True) + raise ValueError(f"Unsupported gemm combination for wgmma, A: {self.A.scope()}, B: {self.B.scope()}") + + def is_gemm_ss(self) -> bool: + return is_shared(self.A) and is_shared(self.B) + + def is_gemm_sr(self) -> bool: + return is_shared(self.A) and is_fragment(self.B) + + def is_gemm_rs(self) -> bool: + return is_fragment(self.A) and is_shared(self.B) + + def is_gemm_rr(self) -> bool: + return is_fragment(self.A) and is_fragment(self.B) diff --git a/tilelang/tileop/gemm_sp/registry.py b/tilelang/tileop/gemm_sp/registry.py index 81fd107c8f..ee3b60de66 100644 --- a/tilelang/tileop/gemm_sp/registry.py +++ b/tilelang/tileop/gemm_sp/registry.py @@ -12,6 +12,7 @@ @dataclass(frozen=True) class GemmSPImplEntry: name: str + inst_name: str predicate: GemmSPTargetPredicate impl_class: type @@ -21,11 +22,12 @@ class GemmSPImplEntry: def register_gemm_sp_impl( name: str, + inst_name: str, predicate: GemmSPTargetPredicate, impl_class: type, ) -> None: """Register a backend-specific GEMM_SP Python implementation class.""" - entry = GemmSPImplEntry(name, predicate, impl_class) + entry = GemmSPImplEntry(name, inst_name, predicate, impl_class) for idx, registered in enumerate(_GEMM_SP_IMPLS): if registered.name == name: _GEMM_SP_IMPLS[idx] = entry @@ -33,9 +35,9 @@ def register_gemm_sp_impl( _GEMM_SP_IMPLS.append(entry) -def resolve_gemm_sp_impl(target: Target) -> type: +def resolve_gemm_sp_impl(gemm_inst: str, target: Target) -> type: """Resolve the registered GEMM_SP implementation class for a target.""" - matches = [entry for entry in _GEMM_SP_IMPLS if entry.predicate(target)] + matches = [entry for entry in _GEMM_SP_IMPLS if entry.inst_name == gemm_inst and entry.predicate(target)] if not matches: raise ValueError(f"No GEMM_SP implementation registered for target {target}") if len(matches) > 1: diff --git a/tilelang/utils/sparse.py b/tilelang/utils/sparse.py index d56554e7f3..ff745702c1 100644 --- a/tilelang/utils/sparse.py +++ b/tilelang/utils/sparse.py @@ -1,145 +1,278 @@ -from __future__ import annotations -import contextlib -import os import torch -import warnings -from tilelang.contrib import nvcc -from tilelang.utils.tensor import 
is_float8_dtype, fp8_remove_negative_zeros_
-from torch.utils.cpp_extension import load, _import_module_from_library
-from tilelang import env
-
-# Include version information to ensure different versions use separate caches
-from tilelang import __version__
-
-# Define paths
-compress_util = os.path.join(env.TILELANG_TEMPLATE_PATH, "tl_templates/cuda/compress_sm90.cu")
-# Cache directory for compiled extensions
-_TORCH_CUDA_VERSION = torch.version.cuda or "cpu"
-_CACHE_DIR = os.path.join(
-    env.TILELANG_CACHE_DIR,
-    "sparse_compressor",
-    __version__,
-    f"torch_cuda_{_TORCH_CUDA_VERSION}",
+from typing import Optional
+import tilelang
+import tilelang.language as T
+from tilelang.language.dtypes import _TORCH_DTYPE_TO_STR, dtype
+
+# (elem, group): keep `elem` nonzeros out of every `group` consecutive elements
+GROUP_CONFIG: dict[dtype, tuple[int, int]] = {
+    T.float: (1, 2),
+    T.float16: (2, 4),
+    T.bfloat16: (2, 4),
+    T.int8: (2, 4),
+    T.uint8: (2, 4),
+    T.float8_e4m3: (2, 4),
+    T.float8_e5m2: (2, 4),
+}
+
+_BITS_PER_GROUP = 4  # metadata bits that encode one group
+
+
+def get_e_factor(a_dtype: dtype, meta_dtype: dtype) -> int:
+    """Return how many a_dtype elements are indexed by one meta_dtype element."""
+    _, group = GROUP_CONFIG[a_dtype]
+    return (dtype(meta_dtype).bits // _BITS_PER_GROUP) * group
+
+
+def get_e_replicate_factor(a_dtype: dtype) -> int:
+    """Return how many consecutive threads share the same logical metadata value."""
+    return 1 if dtype(a_dtype).bits <= 8 else 2
+
+
+def _to_tl_dtype(torch_dtype: torch.dtype) -> dtype:
+    return dtype(_TORCH_DTYPE_TO_STR[torch_dtype])
+
+
+_ELEM_PER_THREAD = 32  # dense elements processed by each thread
+_BLOCK_M = 16
+_BLOCK_K = 1024
+_DEFAULT_META_DTYPE = T.int16
+
+
+@tilelang.jit(
+    out_idx=[-2, -1],
+    pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True},
 )
-os.makedirs(_CACHE_DIR, exist_ok=True)
-
-
-def _torch_cuda_runtime_link_dir() -> str | None:
-    """Return a cache-local lib dir that makes -lcudart resolve to PyTorch's cudart."""
-    cuda_version = torch.version.cuda
-    if not cuda_version:
-        return None
-
-    major = cuda_version.split(".", maxsplit=1)[0]
-    torch_root = os.path.dirname(torch.__file__)
-    runtime_dir = os.path.abspath(os.path.join(torch_root, "..", "nvidia", f"cu{major}", "lib"))
-    runtime_lib = os.path.join(runtime_dir, f"libcudart.so.{major}")
-    if not os.path.exists(runtime_lib):
-        return None
-
-    link_dir = os.path.join(_CACHE_DIR, "cuda_runtime_lib")
-    os.makedirs(link_dir, exist_ok=True)
-    link_path = os.path.join(link_dir, "libcudart.so")
-    if os.path.lexists(link_path) and os.path.realpath(link_path) != os.path.realpath(runtime_lib):
-        os.remove(link_path)
-    with contextlib.suppress(FileExistsError):
-        os.symlink(runtime_lib, link_path)
-    return link_dir
-
-
-def _get_cached_lib():
-    name = "compress_lib"
-
-    if os.path.exists(os.path.join(_CACHE_DIR, f"{name}.so")):
-        try:
-            return _import_module_from_library(name, _CACHE_DIR, is_python_module=True)
-        except Exception:
-            pass
-
-    # Set TORCH_CUDA_ARCH_LIST
-    env._initialize_torch_cuda_arch_flags()
-    extra_ldflags = []
-    runtime_link_dir = _torch_cuda_runtime_link_dir()
-    if runtime_link_dir is not None:
-        extra_ldflags.append(f"-L{runtime_link_dir}")
-        runtime_dir = os.path.dirname(os.path.realpath(os.path.join(runtime_link_dir, "libcudart.so")))
-        extra_ldflags.append(f"-Wl,-rpath,{runtime_dir}")
-
-    # Compile if not cached or loading failed
-    return load(
-        name=name,
-        sources=[compress_util],
-        extra_cuda_cflags=[
-            "-O2",
-            "-std=c++17",
-            "-lineinfo",
-            f"-I{env.CUTLASS_INCLUDE_DIR}",
-            f"-I{env.CUTLASS_INCLUDE_DIR}/../tools/util/include",
-            "-arch=sm_90",
-        ],
-        extra_ldflags=extra_ldflags,
-        build_directory=_CACHE_DIR,
-    )
-
-
-def compress_sm90(A: torch.Tensor, block_k: int, transposed: bool) -> tuple[torch.Tensor, torch.Tensor]:
-    if block_k > 128:
-        block_k = 128
-        # Ref: https://github.com/NVIDIA/cutlass/blob/c2ad7c5b20f131c4ba33601860f1da3f9c9df0f3/include/cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl#L145-L146
-        warnings.warn(f"block_k {block_k} is too large, set to 128 for sm90 compression.", stacklevel=2)
-    # Load the library (will use cache if available)
-    compress_lib = _get_cached_lib()
-
-    return compress_lib.compress_sm90(A, block_k, transposed)
-
-
-def compress_sm80(A: torch.Tensor, transposed: bool) -> tuple[torch.Tensor, torch.Tensor]:
-    try:
-        from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
-    except ImportError as err:
-        raise ImportError(
-            "SparseSemiStructuredTensor is not available in this version of PyTorch. Please install a compatible version."
-        ) from err
-    orig_val = SparseSemiStructuredTensor._FORCE_CUTLASS
-    try:
-        SparseSemiStructuredTensor._FORCE_CUTLASS = True
-        if transposed is not False:
-            raise NotImplementedError("transposed flag is deprecated by pytorch")
-        compressed = to_sparse_semi_structured(A)
-        return compressed.packed, compressed.meta
-    finally:
-        SparseSemiStructuredTensor._FORCE_CUTLASS = orig_val
-
-
-def compress(A: torch.Tensor, transposed: bool, arch: str | None = None, **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
+def _compress_fn(D, dtype, meta_dtype, block_M=_BLOCK_M, block_K=_BLOCK_K, elem_per_thread=_ELEM_PER_THREAD):
+    e_factor = get_e_factor(dtype, meta_dtype)
+    S = T.dynamic("S")
+    assert elem_per_thread >= e_factor
+    assert block_K % elem_per_thread == 0
+
+    if dtype.bits <= 16:
+        elem, group = 2, 4
+
+        @T.prim_func
+        def compress_8bit_16bit_ordered_metadata(
+            dense: T.Tensor([S, D], dtype),
+            nonzero: T.Tensor([S, D * elem // group], dtype),
+            meta: T.Tensor([S, D // e_factor], meta_dtype),
+        ):
+            with T.Kernel(S // block_M, D // block_K, threads=(block_M, block_K // elem_per_thread)) as (bz, bk):
+                tm = T.get_thread_binding(0)
+                tn = T.get_thread_binding(1)
+                dense_local = T.alloc_local([elem_per_thread], dtype)
+                sparse_local = T.alloc_local([elem_per_thread * elem // group], dtype)
+                meta_local = T.alloc_local([elem_per_thread // e_factor], meta_dtype)
+                nz_idx = T.alloc_local([elem], T.uint8)
+                nz_count = T.alloc_var(dtype=T.uint8)
+
+                T.clear(sparse_local)
+                T.clear(meta_local)
+
+                k_base = bk * block_K
+                T.copy(
+                    dense[bz * block_M + tm, k_base + tn * elem_per_thread : k_base + (tn + 1) * elem_per_thread],
+                    dense_local,
+                )
+
+                for gid in T.unroll(elem_per_thread // group):
+                    T.clear(nz_idx)
+                    local_idx = gid * group
+
+                    # Scan the group and record nonzero positions; valid 2:4 input
+                    # yields at most `elem` nonzeros (checked by the assert below)
+                    nz_count = 0
+                    for i in T.unroll(group):
+                        nz_idx[nz_count] = T.if_then_else(dense_local[local_idx + i] != 0, i, nz_idx[nz_count])
+                        nz_count = T.if_then_else(dense_local[local_idx + i] != 0, nz_count + 1, nz_count)
+
+                    T.device_assert(nz_count <= elem, "More nonzeros than expected in a group")
+
+                    if nz_count == 1:
+                        if nz_idx[0] == 0:
+                            nz_idx[1] = 1
+                        else:
+                            nz_idx[0], nz_idx[1] = nz_idx[1], nz_idx[0]
+                    elif nz_count == 0:
+                        nz_idx[0], nz_idx[1] = 0, 1
+
+                    for i in T.unroll(elem):
+                        sparse_local[local_idx * elem // group + i] = dense_local[local_idx + nz_idx[i]]
+                        meta_local[local_idx // e_factor] |= T.shift_left(
+                            nz_idx[i].astype(meta_dtype),
+                            (4 * (gid % (e_factor // group)) + 2 * i),
+                        )
+
+                sparse_per_thread = elem_per_thread * elem // group
+                sparse_base = k_base * elem
// group + meta_base = k_base // e_factor + T.copy( + sparse_local, + nonzero[ + bz * block_M + tm, + sparse_base + tn * sparse_per_thread : sparse_base + (tn + 1) * sparse_per_thread, + ], + ) + T.copy( + meta_local, + meta[ + bz * block_M + tm, + meta_base + tn * (elem_per_thread // e_factor) : meta_base + (tn + 1) * (elem_per_thread // e_factor), + ], + ) + + return compress_8bit_16bit_ordered_metadata + elif dtype.bits == 32: + elem, group = 1, 2 + + @T.prim_func + def compress_32bit_ordered_metadata( + dense: T.Tensor([S, D], dtype), + nonzero: T.Tensor([S, D * elem // group], dtype), + meta: T.Tensor([S, D // e_factor], meta_dtype), + ): + with T.Kernel(S // block_M, D // block_K, threads=(block_M, block_K // elem_per_thread)) as (bz, bk): + tm = T.get_thread_binding(0) + tn = T.get_thread_binding(1) + dense_local = T.alloc_local([elem_per_thread], dtype) + sparse_local = T.alloc_local([elem_per_thread * elem // group], dtype) + meta_local = T.alloc_local([elem_per_thread // e_factor], meta_dtype) + nz_idx = T.alloc_local([elem], T.uint8) + nz_count = T.alloc_var(dtype=T.uint8) + + T.clear(sparse_local) + T.clear(meta_local) + + k_base = bk * block_K + T.copy( + dense[bz * block_M + tm, k_base + tn * elem_per_thread : k_base + (tn + 1) * elem_per_thread], + dense_local, + ) + + for gid in T.unroll(elem_per_thread // group): + T.clear(nz_idx) + local_idx = gid * group + + nz_count = 0 + for i in T.unroll(group): + nz_idx[nz_count] = T.if_then_else(dense_local[local_idx + i] != 0, i, nz_idx[nz_count]) + nz_count = T.if_then_else(dense_local[local_idx + i] != 0, nz_count + 1, nz_count) + + T.device_assert(nz_count <= elem, "More nonzeros than expected in a group") + + if nz_count == 0: + sparse_local[local_idx * elem // group] = 0 + meta_local[local_idx // e_factor] |= T.shift_left(0b0100, 4 * (gid % (e_factor // group))) + else: + sparse_local[local_idx * elem // group] = dense_local[local_idx + nz_idx[0]] + meta_local[local_idx // e_factor] |= T.shift_left( + T.if_then_else(nz_idx[0] == 0, 0b0100, 0b1110), + 4 * (gid % (e_factor // group)), + ) + + sparse_per_thread = elem_per_thread * elem // group + sparse_base = k_base * elem // group + meta_base = k_base // e_factor + T.copy( + sparse_local, + nonzero[ + bz * block_M + tm, + sparse_base + tn * sparse_per_thread : sparse_base + (tn + 1) * sparse_per_thread, + ], + ) + T.copy( + meta_local, + meta[ + bz * block_M + tm, + meta_base + tn * (elem_per_thread // e_factor) : meta_base + (tn + 1) * (elem_per_thread // e_factor), + ], + ) + + return compress_32bit_ordered_metadata + + +def torch_compress(dense: torch.Tensor, meta_dtype: Optional[torch.dtype] = None) -> tuple[torch.Tensor, torch.Tensor]: # noqa: FA100 """ - Compress a tensor using the appropriate method based on the CUDA architecture. + Reference 2:4 sparse compressor in pure PyTorch with natural row-major metadata. 
Modified from https://github.com/pytorch/pytorch/blob/bfa6895a345f6568624a4769238af6a9225e3fb8/torch/sparse/_semi_structured_conversions.py#L47 + + Each 4-bit chunk of the metadata integer encodes the two nonzero positions + within one group of 4 consecutive elements: + bits [1:0] = index of first nonzero (0-3) + bits [3:2] = index of second nonzero (0-3) + """ - if arch is None: - arch = nvcc.get_target_compute_version() - - compute_version = nvcc.parse_compute_version(arch) - - if compute_version >= (9, 0): - return compress_sm90(A, transposed=transposed, **kwargs) - elif compute_version >= (8, 0): - if transposed: - A = A.t().contiguous() - origin_dtype = A.dtype - if is_float8_dtype(origin_dtype): - fp8_remove_negative_zeros_(A) - A = A.view(torch.int8) - A_sp, E = compress_sm80(A, transposed=False) - if is_float8_dtype(origin_dtype): - A_sp = A_sp.view(origin_dtype) - if transposed: - A_sp = A_sp.t().contiguous() - return A_sp, E + if dense.dim() != 2: + raise RuntimeError(f"Expected 2D tensor, got {dense.dim()}D") + m, k = dense.shape + + is_32bit = dense.dtype == torch.float32 + ksparse = 2 if is_32bit else 4 + # int8 uses int32 metadata to match CUTLASS convention; all others use int16 + if meta_dtype is None: + meta_dtype = torch.int32 if dense.dtype == torch.int8 else torch.int16 + quadbits = meta_dtype.itemsize * 8 // 4 # 4-bit groups that fit in one meta element + + # 8-bit non-integer types (float8 variants) may not support gather; view as int8 + gather_dtype = torch.int8 if (dense.element_size() == 1 and dense.dtype != torch.int8) else None + work = dense.view(gather_dtype) if gather_dtype is not None else dense + + groups = work.view(-1, k // ksparse, ksparse) + nz = groups != 0 + if not is_32bit: + m0, m1, _m2, m3 = nz.unbind(-1) else: - raise ValueError(f"Unsupported CUDA compute version: {compute_version}. 
Supported versions are sm_80 and sm_90.")
+        # 1:2 (32-bit) case: duplicate the two per-element masks into the four
+        # names used by the 4-wide index math below
+        m0, _m2 = m1, m3 = nz.unbind(-1)
+
+    meta_ncols = k // (ksparse * quadbits)
+    expr0 = m0 & m1
+    expr1 = ~m0 & m1
+    expr2 = ~m0 & ~m1
+    idxs0 = expr1.to(torch.int64) | (expr2.to(torch.int64) << 1)
+    idxs1 = (expr0 | expr2 | m3).to(torch.int64) | ((expr1 | ~m1).to(torch.int64) << 1)
 
-def randn_semi_sparse(M: int, K: int, dtype=torch.float16, device="cuda", transposed: bool = False):
+    if not is_32bit:
+        sp0 = groups.gather(-1, idxs0.unsqueeze(-1))
+        sp1 = groups.gather(-1, idxs1.unsqueeze(-1))
+        sparse = torch.stack((sp0, sp1), dim=-1).view(m, k // 2)
+    else:
+        sparse = groups.gather(-1, idxs0.unsqueeze(-1) // 2).view(m, k // 2)
+
+    if gather_dtype is not None:
+        sparse = sparse.view(dense.dtype)
+
+    meta_4 = idxs0 | (idxs1 << 2)
+    meta_n = meta_4.view(-1, meta_ncols, quadbits).to(meta_dtype)
+    # Pack 4-bit chunks into each meta element (little-endian)
+    meta = meta_n[:, :, 0]
+    for i in range(1, quadbits):
+        meta = meta | (meta_n[:, :, i] << (4 * i))
+
+    return sparse, meta
+
+
+def compress(
+    A: torch.Tensor,
+    meta_dtype: Optional[torch.dtype] = None,  # noqa: FA100
+    block_m: Optional[int] = None,  # noqa: FA100
+    block_k: Optional[int] = None,  # noqa: FA100
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Compress a 2:4 (1:2 for 32-bit) semi-sparse 2D tensor into packed nonzeros and metadata."""
+    assert A.is_contiguous(), "Input must be contiguous"
+    assert A.dim() == 2, "Input must be 2D"
+
+    tl_meta_dtype = _to_tl_dtype(meta_dtype) if meta_dtype is not None else _DEFAULT_META_DTYPE
+    S, D = A.shape
+    block_m = min(_BLOCK_M, S) if block_m is None else block_m
+    block_k = min(_BLOCK_K, D) if block_k is None else block_k
+    assert block_k % _ELEM_PER_THREAD == 0, f"block_k={block_k} must be divisible by {_ELEM_PER_THREAD}"
+    assert D % block_k == 0, f"Last dim D={D} must be divisible by block_k={block_k}"
+    assert S % block_m == 0, f"Rows S={S} must be divisible by block_m={block_m}"
+    assert D % _ELEM_PER_THREAD == 0, f"Last dim D={D} must be divisible by {_ELEM_PER_THREAD}"
+
+    A_sparse, E = _compress_fn(D, _to_tl_dtype(A.dtype), tl_meta_dtype, block_m, block_k, _ELEM_PER_THREAD)(A)
+
+    return A_sparse, E
+
+
+def randn_semi_sparse(M: int, K: int, dtype: torch.dtype = torch.float16, device: torch.device = "cuda", transposed: bool = False):
     """
     Generate a random semi-sparse tensor. The generated tensor will have 2:4 sparsity along the K dimension.
     Args:
@@ -149,9 +282,7 @@ def randn_semi_sparse(M: int, K: int, dtype=torch.float16, device="cuda", transp
         device: Device to create the tensor on
         transposed (bool): If True, returns a transposed tensor of shape (K, M)
     """
-    elem, group = 2, 4
-    if dtype == torch.float32:
-        elem, group = 1, 2
+    elem, group = GROUP_CONFIG[_to_tl_dtype(dtype)]
     tensor = torch.randn((M, K), dtype=torch.float, device=device).view(M, -1, group)
     indice = tensor.topk(elem, dim=-1).indices
    tensor.scatter_(-1, indice, 0)
@@ -161,7 +292,15 @@ def randn_semi_sparse(M: int, K: int, dtype=torch.float16, device="cuda", transp
     return tensor.to(dtype)  # dtype like float8 might not have randn kernel
 
-def randint_semi_sparse(M: int, K: int, low: int, high: int, dtype=torch.int32, device="cuda", transposed: bool = False):
+def randint_semi_sparse(
+    M: int,
+    K: int,
+    low: int,
+    high: int,
+    dtype: torch.dtype = torch.int32,
+    device: torch.device = "cuda",
+    transposed: bool = False,
+):
     """
     Generate a random semi-sparse integer tensor. The generated tensor will have 2:4 sparsity along the K dimension.
Args: @@ -173,9 +312,7 @@ def randint_semi_sparse(M: int, K: int, low: int, high: int, dtype=torch.int32, device: Device to create the tensor on transposed (bool): If True, returns a transposed tensor of shape (K, M) """ - elem, group = 2, 4 - if dtype == torch.float32: - elem, group = 1, 2 + elem, group = GROUP_CONFIG[_to_tl_dtype(dtype)] tensor = torch.randint(low, high, (M, K), dtype=dtype, device=device).view(M, -1, group) indice = tensor.topk(elem, dim=-1).indices tensor.scatter_(-1, indice, 0) @@ -185,7 +322,7 @@ def randint_semi_sparse(M: int, K: int, low: int, high: int, dtype=torch.int32, return tensor -def arange_semi_sparse(M: int, K: int, dtype=torch.float16, device="cuda", transposed: bool = False): +def arange_semi_sparse(M: int, K: int, dtype: torch.dtype = torch.float16, device: torch.device = "cuda", transposed: bool = False): """ Generate a semi-sparse tensor with values from 0 to M*K-1. The generated tensor will have 2:4 sparsity along the K dimension. Args: @@ -195,9 +332,7 @@ def arange_semi_sparse(M: int, K: int, dtype=torch.float16, device="cuda", trans device: Device to create the tensor on transposed (bool): If True, returns a transposed tensor of shape (K, M) """ - elem, group = 2, 4 - if dtype == torch.float32: - elem, group = 1, 2 + elem, group = GROUP_CONFIG[_to_tl_dtype(dtype)] tensor = torch.arange(M * K, dtype=dtype, device=device).view(M, -1, group) indice = tensor.topk(elem, dim=-1).indices tensor.scatter_(-1, indice, 0)
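
Usage sketch for the compression utilities introduced in this patch: generate a semi-sparse operand, compress it with the TileLang kernel path, and cross-check against the pure-PyTorch reference. The shapes, the CUDA device, and the bit-for-bit agreement between the two compressors are assumptions of this sketch, not guarantees stated by the patch.

import torch
from tilelang.utils.sparse import compress, torch_compress, randn_semi_sparse

# 2:4-sparse fp16 operand along K (assumes a CUDA device is available)
A = randn_semi_sparse(256, 512, dtype=torch.float16, device="cuda")

A_sp, E = compress(A)                # TileLang kernel compressor
A_sp_ref, E_ref = torch_compress(A)  # pure-PyTorch reference

# Both compressors document the same natural row-major ("ordered") metadata
# encoding, so for valid 2:4 input the outputs are expected to match exactly
assert torch.equal(A_sp, A_sp_ref)
assert torch.equal(E, E_ref)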