diff --git a/benchmark/matmul/benchmark_matmul_intrinsic.py b/benchmark/matmul/benchmark_matmul_intrinsic.py deleted file mode 100644 index 4ef860c210..0000000000 --- a/benchmark/matmul/benchmark_matmul_intrinsic.py +++ /dev/null @@ -1,316 +0,0 @@ -import argparse -import logging -from tilelang import tvm as tvm -from tvm import DataType -import tilelang as tl -import tilelang.language as T -from tilelang.intrinsics import get_swizzle_layout -from tilelang.intrinsics.mma_macro_generator import ( - TensorCoreIntrinEmitter, -) -from tilelang.transform import simplify_prim_func -from tilelang.autotuner import autotune -import itertools - -# Configure logger -logger = logging.getLogger(__name__) -logger.setLevel(logging.DEBUG) - - -def make_swizzle_layout(shared_buf): - dtype = shared_buf.dtype - shape = shared_buf.shape - - can_swizzle = shape[-1] * DataType(dtype).bits == 512 - if not can_swizzle: - return T.Layout(shape, lambda *args: args) - - def transform_func(i, j): - new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype) - return [new_warp_i, new_warp_j] - - return T.Layout(shape, transform_func) - - -@simplify_prim_func -def tl_matmul( - M, - N, - K, - in_dtype, - out_dtype, - accum_dtype, - block_row_warps=1, - block_col_warps=1, - warp_row_tiles=16, - warp_col_tiles=16, - chunk=32, - stage=2, - enable_rasteration=False, -): - assert in_dtype in [ - T.float16, - T.int8, - ], "Currently only float16 and int8 are supported" - assert out_dtype in [ - T.float16, - T.float32, - T.int32, - ], "Currently only float16, float32 and int32 are supported" - - micro_size_x = micro_size_y = micro_size_k = 16 - - if out_dtype == T.int32: - micro_size_k = 32 - - # This is a debug config - # chunk = 32 if in_dtype == T.float16 else 64 - shared_scope = "shared.dyn" - - block_M = block_row_warps * warp_row_tiles - block_N = block_col_warps * warp_col_tiles - block_K = chunk - - A_shape = (M, K) - B_shape = (N, K) - A_shared_shape = (block_M, block_K) - B_shared_shape = (block_N, block_K) - C_shared_shape = ( - block_M, - block_N, - ) - - warp_size = 32 - threads = warp_size * (block_row_warps * block_col_warps) - local_size_a = (micro_size_x * micro_size_k) // warp_size - local_size_b = (micro_size_y * micro_size_k) // warp_size - local_size_c = (micro_size_x * micro_size_y) // warp_size - warp_rows = warp_row_tiles // micro_size_x - warp_cols = warp_col_tiles // micro_size_y - - # MMA Wrapper to Auto Generate Code for MMA - mma_emitter = TensorCoreIntrinEmitter( - a_dtype=in_dtype, - b_dtype=in_dtype, - accum_dtype=accum_dtype, - a_transposed=False, - b_transposed=True, - block_row_warps=block_row_warps, - block_col_warps=block_col_warps, - warp_row_tiles=warp_row_tiles, - warp_col_tiles=warp_col_tiles, - chunk=chunk, - ) - - @T.prim_func - def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) - B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) - C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) - A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) - B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) - C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) - - T.annotate_layout( - { - A_shared: make_swizzle_layout(A_shared), - B_shared: make_swizzle_layout(B_shared), - } - ) - - # Improve L2 Cache - T.use_swizzle(panel_size=10, enable=enable_rasteration) - - T.clear(C_local) - - for ko in T.Pipelined((K // block_K), num_stages=stage): - # Load A into shared memory - for i, k in T.Parallel(block_M, block_K): - A_shared[i, k] = A[by * block_M + i, ko * block_K + k] - - # Load B into shared memory - for j, k in T.Parallel(block_N, block_K): - B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] - - for ki in T.serial(0, (block_K // micro_size_k)): - # Load A into fragment - mma_emitter.ldmatrix_a(A_local, A_shared, ki) - - # Load B into fragment - mma_emitter.ldmatrix_b(B_local, B_shared, ki) - - # Perform Matrix Multiplication - mma_emitter.mma(A_local, B_local, C_local) - - # Perform STMatrix - mma_emitter.stmatrix(C_local, C_shared) - - # Store shared into global - for i, j in T.Parallel(block_M, block_N): - C[by * block_M + i, bx * block_N + j] = C_shared[i, j] - - return main - - -def ref_program(A, B): - """Reference matrix multiplication program.""" - return A @ B.T - - -def get_configs(args, kwargs): - """ - Generate a list of configuration dictionaries that will be used for tuning. - - Parameters - ---------- - with_roller : bool - Whether to enable bitblas roller to deduce search spaces - - Returns - ------- - list of dict - Each configuration dict includes various block sizes, pipeline stages, - thread numbers, and other parameters to explore during autotuning. - """ - M, N, K = args[:3] - with_roller = args[6] - - if with_roller: - from tilelang.carver.template import MatmulTemplate - from tilelang.carver.arch import CUDA - from tilelang.carver.arch import CDNA - from tilelang.carver.roller.rasterization import NoRasterization - import torch - - arch = CUDA("cuda") if torch.version.hip is None else CDNA("hip") - topk = 10 - - carve_template = MatmulTemplate( - M=M, - N=N, - K=K, - in_dtype=T.float16, - out_dtype=T.float16, - accum_dtype=T.float16, - ).with_arch(arch) - - func = carve_template.equivalent_function() - assert func is not None, "Function is None" - - roller_hints = carve_template.recommend_hints(topk=topk) - - if roller_hints is None: - raise ValueError("No Roller Hints Found for TensorCore Scheduling") - - configs = [] - for hint in roller_hints: - config = {} - block_m, block_n = hint.block - warp_m, warp_n = hint.warp - config["block_row_warps"] = block_m // warp_m - config["block_col_warps"] = block_n // warp_n - config["warp_row_tiles"] = warp_m - config["warp_col_tiles"] = warp_n - config["chunk"] = hint.rstep[0] - config["stage"] = hint.pipeline_stage - config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization - configs.append(config) - for config in configs: - print(config) - else: - iter_params = dict( - block_row_warps=[1, 2, 4], - block_col_warps=[1, 2, 4], - warp_row_tiles=[16, 32, 64, 128], - warp_col_tiles=[16, 32, 64, 128], - chunk=[32, 64, 128, 256], - stage=[0, 2], - enable_rasteration=[True, False], - ) - return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] - - return configs - - -@autotune( - configs=get_configs, - warmup=3, - rep=5, - ref_prog=ref_program, - skip_check=True, -) -@tl.jit( - out_idx=[2], -) -def matmul( - M, - N, - K, - in_dtype=T.float16, - out_dtype=T.float16, - accum_dtype=T.float16, - with_roller=False, - block_row_warps=None, - block_col_warps=None, - warp_row_tiles=None, - warp_col_tiles=None, - chunk=None, - stage=None, - enable_rasteration=None, -): - """Create an autotuned tensor core matrix multiplication kernel.""" - - def kernel(): - return tl_matmul( - M, - N, - K, - in_dtype=in_dtype, - out_dtype=out_dtype, - accum_dtype=accum_dtype, - block_row_warps=block_row_warps, - block_col_warps=block_col_warps, - warp_row_tiles=warp_row_tiles, - warp_col_tiles=warp_col_tiles, - chunk=chunk, - stage=stage, - enable_rasteration=enable_rasteration, - ) - - return kernel() - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Autotuned TensorCore MatMul Benchmark") - parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M") - parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N") - parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K") - parser.add_argument("--with_roller", type=bool, default=False, help="Whether to use roller to deduce search spaces") - parser.add_argument("--dtype", type=str, default="float16", choices=["float16", "int8"], help="Input data type") - args = parser.parse_args() - - M, N, K = args.m, args.n, args.k - in_dtype = T.dtype(args.dtype) - out_dtype = T.float32 if in_dtype == T.int8 else T.float16 - accum_dtype = T.float32 if in_dtype == T.int8 else T.float16 - with_roller = args.with_roller - with_roller = True - # Compute total floating-point operations - total_flops = 2 * M * N * K - - # Run autotuning - best_result = matmul(M, N, K, in_dtype, out_dtype, accum_dtype, with_roller) - best_latency = best_result.latency - best_config = best_result.config - ref_latency = best_result.ref_latency - - # Print benchmark results - print(f"Best latency (s): {best_latency}") - print(f"Best TFlops: {total_flops / best_latency * 1e-9:.3f}") - print(f"Best config: {best_config}") - print(f"Reference TFlops: {total_flops / ref_latency * 1e-9:.3f}") diff --git a/docs/deeplearning_operators/matmul.md b/docs/deeplearning_operators/matmul.md index 12189eb8fa..076be9f0f8 100644 --- a/docs/deeplearning_operators/matmul.md +++ b/docs/deeplearning_operators/matmul.md @@ -62,7 +62,7 @@ Below is a simplified code snippet for a 1024 x 1024 x 1024 matrix multiplicatio ```python import tilelang import tilelang.language as T -from tilelang.intrinsics import make_mma_swizzle_layout +from tilelang.cuda.intrinsics import make_mma_swizzle_layout def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): @T.prim_func diff --git a/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py b/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py index 9d7ebcf88c..031783910d 100644 --- a/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py +++ b/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py @@ -7,12 +7,12 @@ import tilelang.language as T from tilelang import tvm as tvm from tvm import DataType -from tilelang.intrinsics.mma_layout import ( +from tilelang.cuda.intrinsics.layout.mma_layout import ( make_mma_swizzle_layout as make_swizzle_layout, ) import numpy as np -from tilelang.intrinsics.mma_macro_generator import ( +from tilelang.cuda.intrinsics.macro.mma_macro_generator import ( INT4TensorCoreIntrinEmitter, ) from tilelang.transform import simplify_prim_func diff --git a/examples/dequantize_gemm/example_dequant_gemm_fine_grained.py b/examples/dequantize_gemm/example_dequant_gemm_fine_grained.py index a870208083..3343b43267 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_fine_grained.py +++ b/examples/dequantize_gemm/example_dequant_gemm_fine_grained.py @@ -141,8 +141,8 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( accum_dtype, transform_b, ): - from tilelang.intrinsics.mma_layout import make_mma_swizzle_layout as make_swizzle_layout - from tilelang.intrinsics.mma_macro_generator import ( + from tilelang.cuda.intrinsics.layout.mma_layout import make_mma_swizzle_layout as make_swizzle_layout + from tilelang.cuda.intrinsics.macro.mma_macro_generator import ( TensorCoreIntrinEmitterWithLadderTransform, ) diff --git a/examples/gemm/README.md b/examples/gemm/README.md index 9ab7fb6614..fdd919dbf8 100644 --- a/examples/gemm/README.md +++ b/examples/gemm/README.md @@ -174,7 +174,7 @@ Below is a more advanced snippet that showcases how to apply memory layouts, ena import tilelang.language as T # `make_mma_swizzle_layout` is a python-defined layout function # that helps align data for MMA (Matrix Multiply-Accumulate) operations. -from tilelang.intrinsics import make_mma_swizzle_layout as make_swizzle_layout +from tilelang.cuda.intrinsics import make_mma_swizzle_layout as make_swizzle_layout def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float): @T.prim_func diff --git a/examples/gemm/example_gemm_intrinsics.py b/examples/gemm/example_gemm_intrinsics.py index 15e552587e..4c264c0e4f 100644 --- a/examples/gemm/example_gemm_intrinsics.py +++ b/examples/gemm/example_gemm_intrinsics.py @@ -2,8 +2,8 @@ from tvm import DataType import tilelang import tilelang.language as T -from tilelang.intrinsics import get_swizzle_layout -from tilelang.intrinsics.mma_macro_generator import ( +from tilelang.cuda.intrinsics import get_swizzle_layout +from tilelang.cuda.intrinsics.macro.mma_macro_generator import ( TensorCoreIntrinEmitter, ) diff --git a/examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py b/examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py index fc7fb44003..a82cb54084 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py +++ b/examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py @@ -6,7 +6,7 @@ import tilelang.language as T from tilelang.tileop.base import GemmWarpPolicy from tilelang.layout import make_swizzled_layout -from tilelang.intrinsics.mfma_macro_generator import MatrixCorePreshuffleIntrinEmitter +from tilelang.rocm.intrinsics.mfma_macro_generator import MatrixCorePreshuffleIntrinEmitter from tilelang.utils import determine_fp8_type tilelang.testing.set_random_seed(0) diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py b/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py deleted file mode 100644 index d9f749d9f2..0000000000 --- a/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py +++ /dev/null @@ -1,248 +0,0 @@ -import torch -from tilelang import tvm as tvm -import tilelang.testing -from tvm import DataType -import tilelang.language as T -from tilelang.intrinsics import get_swizzle_layout -from tilelang.intrinsics.mma_macro_generator import TensorCoreIntrinEmitter -from tilelang.intrinsics.mfma_macro_generator import MatrixCoreIntrinEmitter -from tilelang.utils import determine_fp8_type - -tilelang.testing.set_random_seed(0) - - -def make_swizzle_layout(shared_buf): - dtype = shared_buf.dtype - shape = shared_buf.shape - - can_swizzle = shape[-1] * DataType(dtype).bits == 512 - if not can_swizzle: - return T.Layout(shape, lambda *args: args) - - def transform_func(i, j): - new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype) - return [new_warp_i, new_warp_j] - - return T.Layout(shape, transform_func) - - -@tilelang.jit(out_idx=[2]) -def tl_matmul( - M, - N, - K, - in_dtype, - out_dtype, - accum_dtype, -): - assert in_dtype in [ - T.float16, - T.float8_e4m3fn, - T.float8_e4m3fnuz, - T.float8_e5m2, - T.float8_e5m2fnuz, - T.int8, - ], "Currently only float16, float8, and int8 are supported" - assert out_dtype in [ - T.float16, - T.float32, - T.int32, - ], "Currently only float16, float32 and int32 are supported" - - # This is a debug config - block_row_warps = 2 - block_col_warps = 2 - warp_row_tiles = 32 - warp_col_tiles = 32 - chunk = 32 if in_dtype == T.float16 else 64 - shared_scope = "shared.dyn" - - # Pipeline Stage - stage = 2 - - block_M = block_row_warps * warp_row_tiles - block_N = block_col_warps * warp_col_tiles - block_K = chunk - - A_shape = (M, K) - B_shape = (N, K) - A_shared_shape = (block_M, block_K) - B_shared_shape = (block_N, block_K) - is_hip = torch.version.hip is not None - # MMA Wrapper to Auto Generate Code for MMA/MFMA - if is_hip: - mma_emitter = MatrixCoreIntrinEmitter( - a_dtype=in_dtype, - b_dtype=in_dtype, - accum_dtype=accum_dtype, - a_transposed=False, - b_transposed=True, - block_row_warps=block_row_warps, - block_col_warps=block_col_warps, - warp_row_tiles=warp_row_tiles, - warp_col_tiles=warp_col_tiles, - chunk=chunk, - ) - else: - mma_emitter = TensorCoreIntrinEmitter( - a_dtype=in_dtype, - b_dtype=in_dtype, - accum_dtype=accum_dtype, - a_transposed=False, - b_transposed=True, - block_row_warps=block_row_warps, - block_col_warps=block_col_warps, - warp_row_tiles=warp_row_tiles, - warp_col_tiles=warp_col_tiles, - chunk=chunk, - ) - - micro_size_x = mma_emitter.M_DIM - micro_size_y = getattr(mma_emitter, "n_dim", getattr(mma_emitter, "N_DIM", micro_size_x)) - micro_size_k = mma_emitter.k_dim - C_shared_shape = ( - block_M // micro_size_x, - block_N // micro_size_y, - micro_size_x, - micro_size_y, - ) - - threads = mma_emitter.threads - local_size_a = mma_emitter.local_size_a - local_size_b = mma_emitter.local_size_b - local_size_c = mma_emitter.local_size_out - warp_rows = mma_emitter.warp_rows - warp_cols = mma_emitter.warp_cols - - @T.prim_func - def gemm_fp8_intrinsic( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) - B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) - C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) - A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) - B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) - C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) - - T.annotate_layout( - { - A_shared: make_swizzle_layout(A_shared), - B_shared: make_swizzle_layout(B_shared), - } - ) - - # Improve L2 Cache - T.use_swizzle(panel_size=10) - - T.clear(C_local) - - for ko in T.Pipelined((K // block_K), num_stages=stage): - # Load A into shared memory - for i, k in T.Parallel(block_M, block_K): - A_shared[i, k] = A[by * block_M + i, ko * block_K + k] - - # Load B into shared memory - for j, k in T.Parallel(block_N, block_K): - B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] - - for ki in T.serial(0, (block_K // micro_size_k)): - # Load A into fragment - mma_emitter.ldmatrix_a( - A_local, - A_shared, - ki, - ) - - # Load B into fragment - mma_emitter.ldmatrix_b( - B_local, - B_shared, - ki, - ) - - # Perform Matrix Multiplication - if is_hip: - mma_emitter.mfma(A_local, B_local, C_local, ki) - else: - mma_emitter.mma(A_local, B_local, C_local) - - # Perform STMatrix - mma_emitter.stmatrix( - C_local, - C_shared, - ) - - # Store shared into global - for i, j in T.Parallel(block_M, block_N): - C[by * block_M + i, bx * block_N + j] = C_shared[ - i // micro_size_x, - j // micro_size_y, - i % micro_size_x, - j % micro_size_y, - ] - - return gemm_fp8_intrinsic - - -def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): - kernel = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) - src_code = kernel.get_kernel_source() - # src_code is the generated cuda source - assert src_code is not None - - in_dtype = in_dtype.as_torch() - out_dtype = out_dtype.as_torch() - accum_dtype = accum_dtype.as_torch() - - if in_dtype in {torch.int8, torch.int32}: - A = torch.randint(-128, 128, (M, K), dtype=torch.int8).to(in_dtype).cuda() - B = torch.randint(-128, 128, (N, K), dtype=torch.int8).to(in_dtype).cuda() - elif in_dtype in {torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz}: - A = torch.randn(M, K).to(in_dtype).cuda() - B = torch.randn(N, K).to(in_dtype).cuda() - else: - A = torch.randn(M, K).to(in_dtype).cuda() - 0.5 - B = torch.randn(N, K).to(in_dtype).cuda() - 0.5 - - C = torch.zeros(M, N, device="cuda", dtype=accum_dtype) - - profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer) - - C = profiler(A, B) - - latency = profiler.do_bench(warmup=25) - - # Ensure that the latency is not None - assert latency is not None - - # Get Reference Result - ref_c = torch.matmul(A.to(accum_dtype), B.T.to(accum_dtype)).to(out_dtype) - torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) - - -def main(): - e4m3_dtype = determine_fp8_type() - assert_tl_matmul_correctness(128, 128, 128, e4m3_dtype, T.float32, T.float32) - e5m2_dtype = determine_fp8_type("e5m2") - assert_tl_matmul_correctness(128, 128, 128, e5m2_dtype, T.float32, T.float32) - - -def run_regression_perf(): - M, N, K = 4096, 4096, 4096 - out_dtype, accum_dtype = T.float32, T.float32 - in_dtype = determine_fp8_type() - kernel_e4m3 = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) - profiler_e4m3 = kernel_e4m3.get_profiler(tilelang.TensorSupplyType.Integer) - if torch.version.hip is None: - latency_e4m3 = profiler_e4m3.do_bench(backend="cupti") - else: - latency_e4m3 = profiler_e4m3.do_bench() - return latency_e4m3 - - -if __name__ == "__main__": - main() diff --git a/examples/gemm_fp8/regression_example_gemm_fp8.py b/examples/gemm_fp8/regression_example_gemm_fp8.py index 3ba2f4f274..5bf0c80505 100644 --- a/examples/gemm_fp8/regression_example_gemm_fp8.py +++ b/examples/gemm_fp8/regression_example_gemm_fp8.py @@ -1,17 +1,12 @@ import tilelang.testing import example_tilelang_gemm_fp8 import example_tilelang_gemm_fp8_2xAcc -import example_tilelang_gemm_fp8_intrinsic def regression_example_tilelang_gemm_fp8_2xAcc(): tilelang.testing.process_func(example_tilelang_gemm_fp8_2xAcc.run_regression_perf) -def regression_example_tilelang_gemm_fp8_intrinsic(): - tilelang.testing.process_func(example_tilelang_gemm_fp8_intrinsic.run_regression_perf) - - def regression_example_tilelang_gemm_fp8(): tilelang.testing.process_func(example_tilelang_gemm_fp8.run_regression_perf) diff --git a/examples/gemm_fp8/test_example_gemm_fp8.py b/examples/gemm_fp8/test_example_gemm_fp8.py index 19a9ee00a7..3b657d72ae 100644 --- a/examples/gemm_fp8/test_example_gemm_fp8.py +++ b/examples/gemm_fp8/test_example_gemm_fp8.py @@ -1,6 +1,5 @@ import tilelang.testing import example_tilelang_gemm_fp8_2xAcc -import example_tilelang_gemm_fp8_intrinsic import example_tilelang_gemm_fp8 @@ -8,10 +7,6 @@ def test_example_tilelang_gemm_fp8_2xAcc(): example_tilelang_gemm_fp8_2xAcc.main() -def test_example_tilelang_gemm_fp8_intrinsic(): - example_tilelang_gemm_fp8_intrinsic.main() - - def test_example_tilelang_gemm_fp8(): example_tilelang_gemm_fp8.main() diff --git a/examples/hadamard_transform/example_hadamard.py b/examples/hadamard_transform/example_hadamard.py index 65f463b71b..15efbf4467 100644 --- a/examples/hadamard_transform/example_hadamard.py +++ b/examples/hadamard_transform/example_hadamard.py @@ -1,6 +1,6 @@ import tilelang import tilelang.language as T -from tilelang.intrinsics import make_mma_swizzle_layout +from tilelang.cuda.intrinsics import make_mma_swizzle_layout import math import argparse diff --git a/examples/plot_layout/README.md b/examples/plot_layout/README.md index 8204e93d80..c2d3839e97 100644 --- a/examples/plot_layout/README.md +++ b/examples/plot_layout/README.md @@ -7,7 +7,7 @@ import tilelang.language as T from tvm import DataType from tvm.tir import IndexMap from typing import Literal, Callable -from tilelang.intrinsics.utils import get_mma_micro_size +from tilelang.cuda.intrinsics.layout.utils import get_mma_micro_size from tilelang.tools import plot_layout def make_mma_load_base_layout(dtype: str = T.float16, @@ -36,7 +36,7 @@ def make_mma_load_base_layout(dtype: str = T.float16, AssertionError If `local_buf` is not detected to be a fragment buffer. """ - from tilelang.intrinsics.mma_layout import ( + from tilelang.cuda.intrinsics.layout.mma_layout import ( shared_16x16_to_mma_32x8_layout_sr, shared_16x16_to_mma_32x8_layout_rs, shared_16x32_to_mma_32x16_layout, diff --git a/examples/plot_layout/fragment_mfma_load_a.py b/examples/plot_layout/fragment_mfma_load_a.py index d45cc227bc..20a5cbba48 100644 --- a/examples/plot_layout/fragment_mfma_load_a.py +++ b/examples/plot_layout/fragment_mfma_load_a.py @@ -1,9 +1,9 @@ import tilelang.language as T from typing import Literal, Callable from tvm.tir import IndexMap -from tilelang.intrinsics.utils import get_mma_micro_size +from tilelang.rocm.intrinsics.utils import get_mma_micro_size -from tilelang.intrinsics.mfma_layout import ( +from tilelang.rocm.intrinsics.mfma_layout import ( shared_16x4_to_local_64x1_layout_A, shared_16x16_to_local_64x4_layout_A, shared_16x32_to_local_64x8_layout_A, diff --git a/examples/plot_layout/fragment_mma_load_a.py b/examples/plot_layout/fragment_mma_load_a.py index df4a0b8870..7ac6bff30e 100644 --- a/examples/plot_layout/fragment_mma_load_a.py +++ b/examples/plot_layout/fragment_mma_load_a.py @@ -2,7 +2,7 @@ from typing import Literal, Callable from tvm import DataType from tvm.tir import IndexMap -from tilelang.intrinsics.utils import get_mma_micro_size +from tilelang.cuda.intrinsics.layout.utils import get_mma_micro_size def make_mma_load_base_layout(dtype: T.dtype = T.float16, matrix: Literal["A", "B"] = "A", transposed: bool = False) -> T.Fragment: @@ -26,7 +26,7 @@ def make_mma_load_base_layout(dtype: T.dtype = T.float16, matrix: Literal["A", " Describes how threads and indices in fragment are laid out. """ - from tilelang.intrinsics.mma_layout import ( + from tilelang.cuda.intrinsics.layout.mma_layout import ( shared_16x8_to_mma_32x4_layout_sr_a, shared_16x16_to_mma_32x8_layout_sr_a, shared_16x32_to_mma_32x16_layout_sr_a, diff --git a/src/backend/cuda/CMakeLists.txt b/src/backend/cuda/CMakeLists.txt index 40ac455411..5918282457 100644 --- a/src/backend/cuda/CMakeLists.txt +++ b/src/backend/cuda/CMakeLists.txt @@ -137,7 +137,7 @@ if(TILELANG_USE_CUDA_STUBS) endif() file(GLOB TILE_LANG_CUDA_SRCS - src/runtime/runtime.cc + src/backend/cuda/runtime.cc src/backend/cuda/codegen/ptx.cc src/backend/cuda/codegen/codegen_cuda.cc src/backend/cuda/codegen/codegen_py.cc diff --git a/src/runtime/runtime.cc b/src/backend/cuda/runtime.cc similarity index 99% rename from src/runtime/runtime.cc rename to src/backend/cuda/runtime.cc index f6112d8420..2c56dfef0d 100644 --- a/src/runtime/runtime.cc +++ b/src/backend/cuda/runtime.cc @@ -1,5 +1,5 @@ /*! - * \file tl/runtime/runtime.h + * \file tl/backend/cuda/runtime.cc * \brief Runtime functions. * */ diff --git a/src/runtime/runtime.h b/src/backend/cuda/runtime.h similarity index 80% rename from src/runtime/runtime.h rename to src/backend/cuda/runtime.h index 4b389fc03e..90540fd789 100644 --- a/src/runtime/runtime.h +++ b/src/backend/cuda/runtime.h @@ -1,11 +1,11 @@ /*! - * \file tl/runtime/runtime.h + * \file tl/backend/cuda/runtime.h * \brief Runtime functions. * */ -#ifndef TVM_TL_RUNTIME_RUNTIME_H_ -#define TVM_TL_RUNTIME_RUNTIME_H_ +#ifndef TVM_TL_BACKEND_CUDA_RUNTIME_H_ +#define TVM_TL_BACKEND_CUDA_RUNTIME_H_ namespace tvm { namespace tl { @@ -25,4 +25,4 @@ constexpr const char *tvm_cuda_stream_reset_access_policy_window = } // namespace tl } // namespace tvm -#endif // TVM_TL_RUNTIME_RUNTIME_H_ +#endif // TVM_TL_BACKEND_CUDA_RUNTIME_H_ diff --git a/src/transform/lower_hopper_intrin.cc b/src/transform/lower_hopper_intrin.cc index 18b405f2bb..e9ea2cdbc4 100644 --- a/src/transform/lower_hopper_intrin.cc +++ b/src/transform/lower_hopper_intrin.cc @@ -1,5 +1,5 @@ /*! - * \file lower hopper intrin.cc + * \file tl/transform/lower_hopper_intrin.cc * \brief Lower Hopper intrinsics cuda GPU(sm90+) */ @@ -13,8 +13,8 @@ #include #include -#include "../op/builtin.h" -#include "../runtime/runtime.h" +#include "backend/cuda/runtime.h" +#include "op/builtin.h" namespace tvm { namespace tl { diff --git a/src/transform/lower_l2_persistent_annotation.cc b/src/transform/lower_l2_persistent_annotation.cc index 1f7be710de..5f9f44a5c2 100644 --- a/src/transform/lower_l2_persistent_annotation.cc +++ b/src/transform/lower_l2_persistent_annotation.cc @@ -1,5 +1,5 @@ /*! - * \file lower_l2_persistent_annotation.cc + * \file tl/transform/lower_l2_persistent_annotation.cc * \brief Lower L2 persistent annotation */ @@ -9,8 +9,7 @@ #include #include -#include "../op/builtin.h" -#include "../runtime/runtime.h" +#include "op/builtin.h" namespace tvm { namespace tl { diff --git a/src/transform/persist_threadblock.cc b/src/transform/persist_threadblock.cc index b64ffdcce8..d9183d1e2b 100644 --- a/src/transform/persist_threadblock.cc +++ b/src/transform/persist_threadblock.cc @@ -1,6 +1,6 @@ /*! - * \file lower_l2_persistent_annotation.cc - * \brief Lower L2 persistent annotation + * \file tl/transform/persist_threadblock.cc + * \brief Persist thread blocks with cooperative groups. */ #include @@ -9,8 +9,7 @@ #include #include -#include "../op/builtin.h" -#include "../runtime/runtime.h" +#include "op/builtin.h" namespace tvm { namespace tl { diff --git a/testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py b/testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py index 3fe33aebf0..00fac1a3a3 100644 --- a/testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py +++ b/testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py @@ -3,8 +3,8 @@ import tilelang.testing from tilelang import tvm as tvm import tilelang.language as T -from tilelang.intrinsics import make_mfma_swizzle_layout as make_swizzle_layout -from tilelang.intrinsics.mfma_macro_generator import ( +from tilelang.rocm.intrinsics import make_mfma_swizzle_layout as make_swizzle_layout +from tilelang.rocm.intrinsics.mfma_macro_generator import ( MatrixCoreIntrinEmitter, ) from tilelang.transform import simplify_prim_func diff --git a/testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py b/testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py index abedd1f19b..864ac58c7b 100644 --- a/testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py +++ b/testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py @@ -4,8 +4,8 @@ import tilelang.testing from tilelang import tvm as tvm import tilelang.language as T -from tilelang.intrinsics import make_mfma_swizzle_layout as make_swizzle_layout -from tilelang.intrinsics.mfma_macro_generator import MatrixCorePreshuffleIntrinEmitter +from tilelang.rocm.intrinsics import make_mfma_swizzle_layout as make_swizzle_layout +from tilelang.rocm.intrinsics.mfma_macro_generator import MatrixCorePreshuffleIntrinEmitter from tilelang.transform import simplify_prim_func from tilelang.utils import determine_fp8_type diff --git a/testing/python/kernel/test_tilelang_kernel_bf16_gemm_mma.py b/testing/python/kernel/test_tilelang_kernel_bf16_gemm_mma.py index 33eef09a56..12ff9c0586 100644 --- a/testing/python/kernel/test_tilelang_kernel_bf16_gemm_mma.py +++ b/testing/python/kernel/test_tilelang_kernel_bf16_gemm_mma.py @@ -4,8 +4,8 @@ import tilelang.testing from tvm import DataType import tilelang.language as T -from tilelang.intrinsics import get_swizzle_layout -from tilelang.intrinsics.mma_macro_generator import ( +from tilelang.cuda.intrinsics import get_swizzle_layout +from tilelang.cuda.intrinsics.macro.mma_macro_generator import ( TensorCoreIntrinEmitter, ) from tilelang.transform import simplify_prim_func diff --git a/testing/python/kernel/test_tilelang_kernel_element_wise_add.py b/testing/python/kernel/test_tilelang_kernel_element_wise_add.py deleted file mode 100644 index 501b38fda8..0000000000 --- a/testing/python/kernel/test_tilelang_kernel_element_wise_add.py +++ /dev/null @@ -1,109 +0,0 @@ -import tilelang.testing -from tilelang import language as T -import torch - - -def elementwise_add( - M, - N, - block_M, - block_N, - in_dtype, - out_dtype, - threads, -): - @T.prim_func - def main( - A: T.Tensor((M, N), in_dtype), - B: T.Tensor((M, N), in_dtype), - C: T.Tensor((M, N), out_dtype), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - start_x = bx * block_N - start_y = by * block_M - - for local_y, local_x in T.Parallel(block_M, block_N): - y = start_y + local_y - x = start_x + local_x - - C[y, x] = A[y, x] + B[y, x] - - return main - - -def run_elementwise_add( - M, - N, - in_dtype, - out_dtype, - block_M, - block_N, - num_threads=128, -): - program = elementwise_add( - M, - N, - block_M, - block_N, - in_dtype, - out_dtype, - num_threads, - ) - - kernel = tilelang.compile(program, out_idx=[2]) - profiler = kernel.get_profiler() - - def ref_program(A, B): - C = torch.add(A, B) - C = C.to(torch.__getattribute__(out_dtype)) - return C - - profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) - - -def test_elementwise_add_f32(): - run_elementwise_add( - 512, - 1024, - T.float32, - T.float32, - 128, - 256, - ) - - -def test_elementwise_add_f16(): - run_elementwise_add( - 512, - 1024, - T.float16, - T.float16, - 128, - 256, - ) - - -def test_elementwise_add_i32(): - run_elementwise_add( - 512, - 1024, - T.int32, - T.int32, - 128, - 256, - ) - - -def test_elementwise_add_f32f16(): - run_elementwise_add( - 512, - 1024, - T.float32, - T.float16, - 128, - 256, - ) - - -if __name__ == "__main__": - tilelang.testing.main() diff --git a/testing/python/kernel/test_tilelang_kernel_fp8_gemm_mma.py b/testing/python/kernel/test_tilelang_kernel_fp8_gemm_mma.py deleted file mode 100644 index f8793ba2e9..0000000000 --- a/testing/python/kernel/test_tilelang_kernel_fp8_gemm_mma.py +++ /dev/null @@ -1,228 +0,0 @@ -import torch -import torch.backends -from tilelang import tvm as tvm -import tilelang.testing -from tvm import DataType -import tilelang.language as T -from tilelang.intrinsics import get_swizzle_layout -from tilelang.intrinsics.mma_macro_generator import ( - TensorCoreIntrinEmitter, -) -from tilelang.transform import simplify_prim_func - -tilelang.testing.set_random_seed(0) - - -def make_swizzle_layout(shared_buf): - dtype = shared_buf.dtype - shape = shared_buf.shape - - can_swizzle = shape[-1] * DataType(dtype).bits == 512 - if not can_swizzle: - return T.Layout(shape, lambda *args: args) - - def transform_func(i, j): - new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype) - return [new_warp_i, new_warp_j] - - return T.Layout(shape, transform_func) - - -@simplify_prim_func -def tl_matmul( - M, - N, - K, - in_dtype, - out_dtype, - accum_dtype, -): - assert in_dtype in [ - T.float16, - T.float8_e4m3fn, - T.float8_e5m2, - T.int8, - ], "Currently only float16 and int8 are supported" - assert out_dtype in [ - T.float16, - T.float32, - T.int32, - ], "Currently only float16, float32 and int32 are supported" - - micro_size_x = micro_size_y = micro_size_k = 16 - - is_float8 = in_dtype in [ - T.float8_e4m3fn, - T.float8_e5m2, - T.float8_e4m3fn, - T.float8_e5m2fnuz, - ] - if out_dtype == T.int32 or is_float8: - micro_size_k = 32 - - # This is a debug config - block_row_warps = 2 - block_col_warps = 2 - warp_row_tiles = 32 - warp_col_tiles = 32 - chunk = 32 if in_dtype == T.float16 else 64 - shared_scope = "shared.dyn" - - # Pipeline Stage - stage = 2 - - block_M = block_row_warps * warp_row_tiles - block_N = block_col_warps * warp_col_tiles - block_K = chunk - - A_shape = (M, K) - B_shape = (N, K) - A_shared_shape = (block_M, block_K) - B_shared_shape = (block_N, block_K) - C_shared_shape = ( - block_M // micro_size_x, - block_N // micro_size_y, - micro_size_x, - micro_size_y, - ) - - warp_size = 32 - threads = warp_size * (block_row_warps * block_col_warps) - local_size_a = (micro_size_x * micro_size_k) // warp_size - local_size_b = (micro_size_y * micro_size_k) // warp_size - local_size_c = (micro_size_x * micro_size_y) // warp_size - warp_rows = warp_row_tiles // micro_size_x - warp_cols = warp_col_tiles // micro_size_y - - # MMA Wrapper to Auto Generate Code for MMA - mma_emitter = TensorCoreIntrinEmitter( - a_dtype=in_dtype, - b_dtype=in_dtype, - accum_dtype=accum_dtype, - a_transposed=False, - b_transposed=True, - block_row_warps=block_row_warps, - block_col_warps=block_col_warps, - warp_row_tiles=warp_row_tiles, - warp_col_tiles=warp_col_tiles, - chunk=chunk, - ) - - @T.prim_func - def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) - B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) - C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) - A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) - B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) - C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) - - T.annotate_layout( - { - A_shared: make_swizzle_layout(A_shared), - B_shared: make_swizzle_layout(B_shared), - } - ) - - # Improve L2 Cache - T.use_swizzle(panel_size=10) - - T.clear(C_local) - - for ko in T.Pipelined((K // block_K), num_stages=stage): - # Load A into shared memory - for i, k in T.Parallel(block_M, block_K): - A_shared[i, k] = A[by * block_M + i, ko * block_K + k] - - # Load B into shared memory - for j, k in T.Parallel(block_N, block_K): - B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] - - for ki in T.serial(0, (block_K // micro_size_k)): - # Load A into fragment - mma_emitter.ldmatrix_a( - A_local, - A_shared, - ki, - ) - - # Load B into fragment - mma_emitter.ldmatrix_b( - B_local, - B_shared, - ki, - ) - - # Perform Matrix Multiplication - mma_emitter.mma(A_local, B_local, C_local) - - # Perform STMatrix - mma_emitter.stmatrix( - C_local, - C_shared, - ) - - # Store shared into global - for i, j in T.Parallel(block_M, block_N): - C[by * block_M + i, bx * block_N + j] = C_shared[ - i // micro_size_x, - j // micro_size_y, - i % micro_size_x, - j % micro_size_y, - ] - - return main - - -def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): - matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) - kernel = tilelang.compile(matmul, out_idx=[2]) - profiler = kernel.get_profiler() - - src_code = kernel.get_kernel_source() - print(src_code) - # src_code is the generated cuda source - assert src_code is not None - - in_dtype = T.dtype(in_dtype).as_torch() - out_dtype = T.dtype(out_dtype).as_torch() - accum_dtype = T.dtype(accum_dtype).as_torch() - - if in_dtype in {torch.int8, torch.int32}: - A = torch.randint(-128, 128, (M, K), dtype=torch.int8).to(in_dtype).cuda() - B = torch.randint(-128, 128, (N, K), dtype=torch.int8).to(in_dtype).cuda() - elif in_dtype in {torch.float8_e4m3fn, torch.float8_e5m2}: - A = torch.randn(M, K).to(in_dtype).cuda() - B = torch.randn(N, K).to(in_dtype).cuda() - else: - A = torch.randn(M, K).to(in_dtype).cuda() - 0.5 - B = torch.randn(N, K).to(in_dtype).cuda() - 0.5 - - C = kernel(A, B) - - latency = profiler.do_bench() - - # Ensure that the latency is not None - assert latency is not None - - # Get Reference Result - ref_c = torch.matmul(A.to(accum_dtype), B.T.to(accum_dtype)).to(out_dtype) - print(C) - print(ref_c) - torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) - - -@tilelang.testing.requires_cuda -@tilelang.testing.requires_cuda_compute_version(8, 9) -def test_assert_tl_matmul(): - assert_tl_matmul_correctness(128, 128, 128, T.float8_e4m3fn, T.float32, T.float32) - assert_tl_matmul_correctness(128, 128, 128, T.float8_e5m2, T.float32, T.float32) - - -if __name__ == "__main__": - tilelang.testing.main() diff --git a/testing/python/kernel/test_tilelang_kernel_gemm_mma_intrinsic.py b/testing/python/kernel/test_tilelang_kernel_gemm_mma_intrinsic.py deleted file mode 100644 index 7f7f36c51d..0000000000 --- a/testing/python/kernel/test_tilelang_kernel_gemm_mma_intrinsic.py +++ /dev/null @@ -1,240 +0,0 @@ -import torch -import torch.backends -from tilelang import tvm as tvm -import tilelang.testing -from tvm import DataType -import tilelang.language as T -from tilelang.intrinsics import get_swizzle_layout -from tilelang.intrinsics.mma_macro_generator import ( - TensorCoreIntrinEmitter, -) -from tilelang.transform import simplify_prim_func - -tilelang.testing.set_random_seed(0) - - -def make_swizzle_layout(shared_buf): - dtype = shared_buf.dtype - shape = shared_buf.shape - - can_swizzle = shape[-1] * DataType(dtype).bits == 512 - if not can_swizzle: - return T.Layout(shape, lambda *args: args) - - def transform_func(i, j): - new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype) - return [new_warp_i, new_warp_j] - - return T.Layout(shape, transform_func) - - -@simplify_prim_func -def tl_matmul( - M, - N, - K, - in_dtype, - out_dtype, - accum_dtype, -): - assert in_dtype in [ - T.float16, - T.bfloat16, - T.float8_e4m3fn, - T.float8_e5m2, - T.int8, - ], "Currently only float16 and int8 are supported" - assert out_dtype in [ - T.float16, - T.float32, - T.int32, - ], "Currently only float16, float32 and int32 are supported" - - micro_size_x = micro_size_y = micro_size_k = 16 - - is_float8 = in_dtype in [ - T.float8_e4m3fn, - T.float8_e5m2, - T.float8_e4m3fn, - T.float8_e5m2fnuz, - ] - if out_dtype == T.int32 or is_float8: - micro_size_k = 32 - - # This is a debug config - block_row_warps = 2 - block_col_warps = 2 - warp_row_tiles = 32 - warp_col_tiles = 32 - chunk = 32 if in_dtype == T.float16 else 64 - shared_scope = "shared.dyn" - - # Pipeline Stage - stage = 2 - - block_M = block_row_warps * warp_row_tiles - block_N = block_col_warps * warp_col_tiles - block_K = chunk - - A_shape = (M, K) - B_shape = (N, K) - A_shared_shape = (block_M, block_K) - B_shared_shape = (block_N, block_K) - C_shared_shape = ( - block_M // micro_size_x, - block_N // micro_size_y, - micro_size_x, - micro_size_y, - ) - - warp_size = 32 - threads = warp_size * (block_row_warps * block_col_warps) - local_size_a = (micro_size_x * micro_size_k) // warp_size - local_size_b = (micro_size_y * micro_size_k) // warp_size - local_size_c = (micro_size_x * micro_size_y) // warp_size - warp_rows = warp_row_tiles // micro_size_x - warp_cols = warp_col_tiles // micro_size_y - - # MMA Wrapper to Auto Generate Code for MMA - mma_emitter = TensorCoreIntrinEmitter( - a_dtype=in_dtype, - b_dtype=in_dtype, - accum_dtype=accum_dtype, - a_transposed=False, - b_transposed=True, - block_row_warps=block_row_warps, - block_col_warps=block_col_warps, - warp_row_tiles=warp_row_tiles, - warp_col_tiles=warp_col_tiles, - chunk=chunk, - ) - - @T.prim_func - def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) - B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) - C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) - A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) - B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) - C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) - - T.annotate_layout( - { - A_shared: make_swizzle_layout(A_shared), - B_shared: make_swizzle_layout(B_shared), - } - ) - - # Improve L2 Cache - T.use_swizzle(panel_size=10) - - T.clear(C_local) - - for ko in T.Pipelined((K // block_K), num_stages=stage): - # Load A into shared memory - for i, k in T.Parallel(block_M, block_K): - A_shared[i, k] = A[by * block_M + i, ko * block_K + k] - - # Load B into shared memory - for j, k in T.Parallel(block_N, block_K): - B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] - - for ki in T.serial(0, (block_K // micro_size_k)): - # Load A into fragment - mma_emitter.ldmatrix_a( - A_local, - A_shared, - ki, - ) - - # Load B into fragment - mma_emitter.ldmatrix_b( - B_local, - B_shared, - ki, - ) - - # Perform Matrix Multiplication - mma_emitter.mma(A_local, B_local, C_local) - - # Perform STMatrix - mma_emitter.stmatrix( - C_local, - C_shared, - ) - - # Store shared into global - for i, j in T.Parallel(block_M, block_N): - C[by * block_M + i, bx * block_N + j] = C_shared[ - i // micro_size_x, - j // micro_size_y, - i % micro_size_x, - j % micro_size_y, - ] - - return main - - -def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): - matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) - kernel = tilelang.compile(matmul, out_idx=[2]) - profiler = kernel.get_profiler() - - src_code = kernel.get_kernel_source() - # src_code is the generated cuda source - assert src_code is not None - - in_dtype = T.dtype(in_dtype).as_torch() - out_dtype = T.dtype(out_dtype).as_torch() - accum_dtype = T.dtype(accum_dtype).as_torch() - - if in_dtype in {torch.int8, torch.int32}: - A = torch.randint(-128, 128, (M, K), dtype=torch.int8).to(in_dtype).cuda() - B = torch.randint(-128, 128, (N, K), dtype=torch.int8).to(in_dtype).cuda() - elif in_dtype in {torch.float8_e4m3fn, torch.float8_e5m2}: - A = torch.randn(M, K).to(in_dtype).cuda() - B = torch.randn(N, K).to(in_dtype).cuda() - else: - A = torch.randn(M, K).to(in_dtype).cuda() - 0.5 - B = torch.randn(N, K).to(in_dtype).cuda() - 0.5 - - C = kernel(A, B) - - latency = profiler.do_bench() - - # Ensure that the latency is not None - assert latency is not None - - # Get Reference Result - ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(out_dtype) - tilelang.testing.torch_assert_close(C, ref_c, rtol=1e-2, atol=1e-2) - - -@tilelang.testing.requires_cuda -@tilelang.testing.requires_cuda_compute_version(8, 0) -def test_assert_tl_matmul(): - assert_tl_matmul_correctness(128, 128, 128, T.float16, T.float16, T.float16) - assert_tl_matmul_correctness(128, 256, 256, T.float16, T.float32, T.float32) - assert_tl_matmul_correctness(128, 256, 256, T.int8, T.int32, T.int32) - - -@tilelang.testing.requires_cuda -@tilelang.testing.requires_cuda_compute_version(8, 0) -def test_assert_tl_matmul_bfloat16(): - assert_tl_matmul_correctness(256, 256, 256, T.bfloat16, T.float32, T.float32) - - -@tilelang.testing.requires_cuda -@tilelang.testing.requires_cuda_compute_version(8, 9) -def test_assert_tl_matmul_fp8(): - assert_tl_matmul_correctness(128, 128, 128, T.float8_e4m3fn, T.float32, T.float32) - assert_tl_matmul_correctness(128, 128, 128, T.float8_e5m2, T.float32, T.float32) - - -if __name__ == "__main__": - tilelang.testing.main() diff --git a/testing/python/kernel/test_tilelang_kernel_gemm_simt.py b/testing/python/kernel/test_tilelang_kernel_gemm_simt.py index 5c52f432d0..dd96e38f1a 100644 --- a/testing/python/kernel/test_tilelang_kernel_gemm_simt.py +++ b/testing/python/kernel/test_tilelang_kernel_gemm_simt.py @@ -4,7 +4,7 @@ from tilelang import tvm as tvm from tvm import DataType import tilelang.language as T -from tilelang.intrinsics import get_swizzle_layout +from tilelang.cuda.intrinsics import get_swizzle_layout from tilelang.transform import simplify_prim_func tilelang.testing.set_random_seed(0) diff --git a/testing/python/language/test_tilelang_language_reshape.py b/testing/python/language/test_tilelang_language_reshape.py index 27388911b7..78e38de6b9 100644 --- a/testing/python/language/test_tilelang_language_reshape.py +++ b/testing/python/language/test_tilelang_language_reshape.py @@ -160,7 +160,7 @@ def test_reshape_fragment(): def reshape_layout_transform_shared(N, M, dtype): - from tilelang.intrinsics.mma_layout import make_mma_swizzle_layout + from tilelang.cuda.intrinsics.layout.mma_layout import make_mma_swizzle_layout @T.prim_func def main( diff --git a/testing/python/language/test_tilelang_language_vectorize.py b/testing/python/language/test_tilelang_language_vectorize.py index f042339d42..7446d73eae 100644 --- a/testing/python/language/test_tilelang_language_vectorize.py +++ b/testing/python/language/test_tilelang_language_vectorize.py @@ -2,7 +2,7 @@ import tilelang.testing import tilelang.language as T -from tilelang.intrinsics import make_mma_swizzle_layout +from tilelang.cuda.intrinsics import make_mma_swizzle_layout import pytest diff --git a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py index 8ffffd8ce0..de7808d9f0 100644 --- a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py +++ b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py @@ -7,7 +7,7 @@ from tilelang.utils.sparse import compress, randn_semi_sparse, randint_semi_sparse from tilelang.layout import make_cutlass_metadata_layout from tilelang.utils.tensor import torch_assert_close -from tilelang.intrinsics.mma_sp_macro_generator import SparseTensorCoreIntrinEmitter +from tilelang.cuda.intrinsics.macro.mma_sp_macro_generator import SparseTensorCoreIntrinEmitter torch.backends.cuda.matmul.allow_tf32 = False diff --git a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py index 32742a005f..921e3b4de2 100644 --- a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py +++ b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py @@ -3,7 +3,7 @@ from tilelang.utils.sparse import compress, randn_semi_sparse, randint_semi_sparse from tilelang.utils.tensor import torch_assert_close from tilelang.layout import make_cutlass_metadata_layout -from tilelang.intrinsics.mma_sp_macro_generator import SparseTensorCoreIntrinEmitter +from tilelang.cuda.intrinsics.macro.mma_sp_macro_generator import SparseTensorCoreIntrinEmitter import tilelang.testing import torch diff --git a/testing/python/transform/test_tilelang_transform_lower_hopper_intrin.py b/testing/python/transform/test_tilelang_transform_lower_hopper_intrin.py index 640e991828..67b063b139 100644 --- a/testing/python/transform/test_tilelang_transform_lower_hopper_intrin.py +++ b/testing/python/transform/test_tilelang_transform_lower_hopper_intrin.py @@ -1,5 +1,6 @@ from tilelang import tvm as tvm import tilelang as tl +from tilelang.cuda import transform as cuda_transform from tilelang.utils.target import determine_target import tilelang.language as T import tilelang.testing @@ -12,7 +13,7 @@ def _check(original, transformed): func = original mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) mod = tvm.tir.transform.BindTarget(auto_target)(mod) - mod = tl.transform.LowerHopperIntrin()(mod) + mod = cuda_transform.LowerHopperIntrin()(mod) mod = tir.transform.LowerOpaqueBlock()(mod) transformed = tvm.IRModule.from_expr(transformed.with_attr("global_symbol", "main")) transformed = tvm.tir.transform.BindTarget(auto_target)(transformed) @@ -90,7 +91,7 @@ def before(): mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) mod = tvm.tir.transform.BindTarget(auto_target)(mod) - mod = tl.transform.LowerHopperIntrin()(mod) + mod = cuda_transform.LowerHopperIntrin()(mod) func = mod["main"] assert not tvm.tir.analysis.undefined_vars(func.body, func.params) diff --git a/tilelang/__init__.py b/tilelang/__init__.py index 034ae3cbfb..f1c9fa9e4f 100644 --- a/tilelang/__init__.py +++ b/tilelang/__init__.py @@ -186,5 +186,8 @@ def _load_tile_lang_lib(): from .math import * # noqa: F403 from . import ir # noqa: F401 from . import tileop # noqa: F401 + from . import cpu as cpu # noqa: F401 + from . import cuda as cuda # noqa: F401 + from . import rocm as rocm # noqa: F401 del _lazy_load_lib diff --git a/tilelang/backend/__init__.py b/tilelang/backend/__init__.py deleted file mode 100644 index cc871934ca..0000000000 --- a/tilelang/backend/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from .gemm import register_gemm_impl, resolve_gemm_impl # noqa: F401 -from .gemm_sp import register_gemm_sp_impl, resolve_gemm_sp_impl # noqa: F401 - -# Import built-in backend packages so their implementations register. -from . import cpu as _cpu # noqa: F401,E402 -from . import cuda as _cuda # noqa: F401,E402 -from . import rocm as _rocm # noqa: F401,E402 diff --git a/tilelang/cpu/__init__.py b/tilelang/cpu/__init__.py new file mode 100644 index 0000000000..e8fb2c24d2 --- /dev/null +++ b/tilelang/cpu/__init__.py @@ -0,0 +1 @@ +from . import op # noqa: F401 diff --git a/tilelang/backend/cpu/__init__.py b/tilelang/cpu/op/__init__.py similarity index 100% rename from tilelang/backend/cpu/__init__.py rename to tilelang/cpu/op/__init__.py diff --git a/tilelang/backend/cpu/gemm.py b/tilelang/cpu/op/gemm/__init__.py similarity index 60% rename from tilelang/backend/cpu/gemm.py rename to tilelang/cpu/op/gemm/__init__.py index affeaa3308..bdd710755d 100644 --- a/tilelang/backend/cpu/gemm.py +++ b/tilelang/cpu/op/gemm/__init__.py @@ -1,7 +1,7 @@ from __future__ import annotations -from tilelang.backend.gemm import register_gemm_impl -from tilelang.tileop.gemm.gemm_scalar import GEMM_INST_SCALAR, GemmScalar +from tilelang.tileop.gemm.registry import register_gemm_impl +from .gemm_scalar import GEMM_INST_SCALAR, GemmScalar def _match_scalar(target) -> bool: diff --git a/tilelang/tileop/gemm/gemm_scalar.py b/tilelang/cpu/op/gemm/gemm_scalar.py similarity index 100% rename from tilelang/tileop/gemm/gemm_scalar.py rename to tilelang/cpu/op/gemm/gemm_scalar.py diff --git a/tilelang/cuda/__init__.py b/tilelang/cuda/__init__.py new file mode 100644 index 0000000000..8ce2aa2507 --- /dev/null +++ b/tilelang/cuda/__init__.py @@ -0,0 +1,3 @@ +from . import intrinsics # noqa: F401 +from . import op # noqa: F401 +from . import transform # noqa: F401 diff --git a/tilelang/cuda/intrinsics/__init__.py b/tilelang/cuda/intrinsics/__init__.py new file mode 100644 index 0000000000..8601d9342e --- /dev/null +++ b/tilelang/cuda/intrinsics/__init__.py @@ -0,0 +1,13 @@ +from .layout.utils import ( # noqa: F401 + mma_store_index_map, + get_ldmatrix_offset, + get_mma_micro_size, +) + +from .macro.mma_macro_generator import ( # noqa: F401 + TensorCoreIntrinEmitter, + TensorCoreIntrinEmitterWithLadderTransform, +) + +from .layout.mma_layout import get_swizzle_layout # noqa: F401 +from .layout.mma_layout import make_mma_swizzle_layout # noqa: F401 diff --git a/tilelang/cuda/intrinsics/layout/__init__.py b/tilelang/cuda/intrinsics/layout/__init__.py new file mode 100644 index 0000000000..ff517fe501 --- /dev/null +++ b/tilelang/cuda/intrinsics/layout/__init__.py @@ -0,0 +1,8 @@ +from .utils import ( # noqa: F401 + mma_store_index_map, + get_ldmatrix_offset, + get_mma_micro_size, +) + +from .mma_layout import get_swizzle_layout # noqa: F401 +from .mma_layout import make_mma_swizzle_layout # noqa: F401 diff --git a/tilelang/intrinsics/mma_layout.py b/tilelang/cuda/intrinsics/layout/mma_layout.py similarity index 100% rename from tilelang/intrinsics/mma_layout.py rename to tilelang/cuda/intrinsics/layout/mma_layout.py diff --git a/tilelang/intrinsics/mma_sm70_layout.py b/tilelang/cuda/intrinsics/layout/mma_sm70_layout.py similarity index 100% rename from tilelang/intrinsics/mma_sm70_layout.py rename to tilelang/cuda/intrinsics/layout/mma_sm70_layout.py diff --git a/tilelang/intrinsics/mma_sp_layout.py b/tilelang/cuda/intrinsics/layout/mma_sp_layout.py similarity index 99% rename from tilelang/intrinsics/mma_sp_layout.py rename to tilelang/cuda/intrinsics/layout/mma_sp_layout.py index 73da1289ab..c814e32307 100644 --- a/tilelang/intrinsics/mma_sp_layout.py +++ b/tilelang/cuda/intrinsics/layout/mma_sp_layout.py @@ -1,7 +1,7 @@ from tvm import DataType from typing import Literal -from tilelang.intrinsics.mma_layout import ( +from tilelang.cuda.intrinsics.layout.mma_layout import ( mma_load_a_32x4_to_shared_16x8_layout, mma_load_a_32x16_to_shared_16x32_layout, mma_load_a_32x8_to_shared_16x16_layout, diff --git a/tilelang/intrinsics/utils.py b/tilelang/cuda/intrinsics/layout/utils.py similarity index 90% rename from tilelang/intrinsics/utils.py rename to tilelang/cuda/intrinsics/layout/utils.py index f65fff1a9b..050ad09327 100644 --- a/tilelang/intrinsics/utils.py +++ b/tilelang/cuda/intrinsics/layout/utils.py @@ -10,11 +10,9 @@ mma_store_32x8_to_shared_16x16_layout, mma_store_32x2_to_shared_8x8_layout_fp64, ) -from .mfma_layout import thread_id_shared_access_64x4_to_16x16_layout_C_n_m, thread_id_shared_access_64x16_to_32x32_layout_C_m_n from .mma_layout import get_swizzle_layout # noqa: F401 from .mma_layout import make_mma_swizzle_layout # noqa: F401 -from .mfma_layout import make_mfma_swizzle_layout # noqa: F401 # the original implementation and insight is from the following code snippet @@ -89,14 +87,6 @@ def mma_store_index_map_fp64(thread_id, local_id): return mma_store_32x2_to_shared_8x8_layout_fp64(thread_id, local_id) -def mfma_store_index_map(thread_id, local_id): - return thread_id_shared_access_64x4_to_16x16_layout_C_n_m(thread_id, local_id) - - -def mfma_store_index_map_32x32(thread_id, local_id): - return thread_id_shared_access_64x16_to_32x32_layout_C_m_n(thread_id, local_id) - - def get_mma_micro_size(dtype: Literal["float16", "int8"]): # TODO(lei): FP8 related precision support. # Basic Tensor Core Matrix Multiply operation Unit diff --git a/tilelang/cuda/intrinsics/macro/__init__.py b/tilelang/cuda/intrinsics/macro/__init__.py new file mode 100644 index 0000000000..658f791220 --- /dev/null +++ b/tilelang/cuda/intrinsics/macro/__init__.py @@ -0,0 +1,6 @@ +from .mma_macro_generator import ( # noqa: F401 + TensorCoreIntrinEmitter, + TensorCoreIntrinEmitterWithLadderTransform, +) + +from .mma_sp_macro_generator import SparseTensorCoreIntrinEmitter # noqa: F401 diff --git a/tilelang/intrinsics/mma_macro_generator.py b/tilelang/cuda/intrinsics/macro/mma_macro_generator.py similarity index 99% rename from tilelang/intrinsics/mma_macro_generator.py rename to tilelang/cuda/intrinsics/macro/mma_macro_generator.py index 26e34e6a51..6461dbd885 100644 --- a/tilelang/intrinsics/mma_macro_generator.py +++ b/tilelang/cuda/intrinsics/macro/mma_macro_generator.py @@ -8,12 +8,12 @@ from tvm.tir import PrimExpr, IndexMap, Buffer, Var, BufferRegion, BufferLoad from tilelang import tvm as tvm from tvm.runtime import convert -from .utils import ( +from ..layout.utils import ( mma_store_index_map, get_ldmatrix_offset, ) from tilelang.utils import is_fragment, get_buffer_region_from_load -from tilelang.intrinsics.mma_layout import ( +from tilelang.cuda.intrinsics.layout.mma_layout import ( shared_16x8_to_mma_32x4_layout_sr_a, shared_16x8_to_mma_32x4_layout_sr_b, shared_16x16_to_mma_32x8_layout_sr_a, @@ -206,7 +206,7 @@ def get_thread_binding(self): return self.thread_var def get_store_index_map(self, inverse: bool = False) -> IndexMap: - from .utils import mma_store_index_map, mma_store_index_map_fp64 + from ..layout.utils import mma_store_index_map, mma_store_index_map_fp64 warp_size, local_size_c = self.WARP_SIZE, self.local_size_out if DataType(self.accum_dtype).bits == 64: diff --git a/tilelang/intrinsics/mma_sm70_macro_generator.py b/tilelang/cuda/intrinsics/macro/mma_sm70_macro_generator.py similarity index 99% rename from tilelang/intrinsics/mma_sm70_macro_generator.py rename to tilelang/cuda/intrinsics/macro/mma_sm70_macro_generator.py index 52679b169a..4fee93087a 100644 --- a/tilelang/intrinsics/mma_sm70_macro_generator.py +++ b/tilelang/cuda/intrinsics/macro/mma_sm70_macro_generator.py @@ -8,7 +8,7 @@ from tilelang import tvm as tvm from tvm.runtime import convert from tilelang.utils import is_fragment, get_buffer_region_from_load -from tilelang.intrinsics.mma_sm70_layout import ( +from tilelang.cuda.intrinsics.layout.mma_sm70_layout import ( shared_16x4_to_mma_a_32x4_layout, shared_4x16_to_mma_b_32x4_layout, shared_16x4_to_mma_b_32x4_layout_trans, diff --git a/tilelang/intrinsics/mma_sp_macro_generator.py b/tilelang/cuda/intrinsics/macro/mma_sp_macro_generator.py similarity index 99% rename from tilelang/intrinsics/mma_sp_macro_generator.py rename to tilelang/cuda/intrinsics/macro/mma_sp_macro_generator.py index 18a37b8e83..826a0f58ec 100644 --- a/tilelang/intrinsics/mma_sp_macro_generator.py +++ b/tilelang/cuda/intrinsics/macro/mma_sp_macro_generator.py @@ -6,13 +6,13 @@ from tvm.tir import PrimExpr, IndexMap, Buffer, Var, BufferRegion, BufferLoad from tvm.ir import Range from tvm.runtime import convert -from .utils import ( +from ..layout.utils import ( mma_store_index_map, get_ldmatrix_offset, ) from tilelang.utils import is_fragment, get_buffer_region_from_load -from tilelang.intrinsics.mma_sp_layout import ( +from tilelang.cuda.intrinsics.layout.mma_sp_layout import ( shared_16x16_to_mma_sp_layout_sr_a, shared_16x16_to_mma_sp_layout_sr_b, shared_16x32_to_mma_sp_layout_sr_a, diff --git a/tilelang/intrinsics/tcgen05_macro_generator.py b/tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py similarity index 100% rename from tilelang/intrinsics/tcgen05_macro_generator.py rename to tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py diff --git a/tilelang/intrinsics/wgmma_macro_generator.py b/tilelang/cuda/intrinsics/macro/wgmma_macro_generator.py similarity index 99% rename from tilelang/intrinsics/wgmma_macro_generator.py rename to tilelang/cuda/intrinsics/macro/wgmma_macro_generator.py index 864420c771..f31c12fb94 100644 --- a/tilelang/intrinsics/wgmma_macro_generator.py +++ b/tilelang/cuda/intrinsics/macro/wgmma_macro_generator.py @@ -15,7 +15,7 @@ make_linear_layout, ) from tvm.runtime import convert -from tilelang.intrinsics.mma_layout import ( +from tilelang.cuda.intrinsics.layout.mma_layout import ( shared_16x8_to_mma_32x4_layout_sr_a, shared_16x16_to_mma_32x8_layout_sr_a, shared_16x32_to_mma_32x16_layout_sr_a, diff --git a/tilelang/backend/cuda/__init__.py b/tilelang/cuda/op/__init__.py similarity index 63% rename from tilelang/backend/cuda/__init__.py rename to tilelang/cuda/op/__init__.py index 5d013cefee..743e9e3c63 100644 --- a/tilelang/backend/cuda/__init__.py +++ b/tilelang/cuda/op/__init__.py @@ -1,2 +1,4 @@ +"""CUDA op registration frontends.""" + from . import gemm # noqa: F401 from . import gemm_sp # noqa: F401 diff --git a/tilelang/backend/cuda/gemm.py b/tilelang/cuda/op/gemm/__init__.py similarity index 69% rename from tilelang/backend/cuda/gemm.py rename to tilelang/cuda/op/gemm/__init__.py index 0072fda1aa..f78878b54a 100644 --- a/tilelang/backend/cuda/gemm.py +++ b/tilelang/cuda/op/gemm/__init__.py @@ -1,10 +1,12 @@ +"""CUDA GEMM op registrations.""" + from __future__ import annotations -from tilelang.backend.gemm import register_gemm_impl -from tilelang.tileop.gemm.gemm_mma import GEMM_INST_MMA, GemmMMA -from tilelang.tileop.gemm.gemm_mma_sm70 import GemmMMASm70 -from tilelang.tileop.gemm.gemm_tcgen05 import GEMM_INST_TCGEN05, GemmTCGEN5 -from tilelang.tileop.gemm.gemm_wgmma import GEMM_INST_WGMMA, GemmWGMMA +from tilelang.tileop.gemm.registry import register_gemm_impl +from .gemm_mma import GEMM_INST_MMA, GemmMMA +from .gemm_mma_sm70 import GemmMMASm70 +from .gemm_tcgen05 import GEMM_INST_TCGEN05, GemmTCGEN5 +from .gemm_wgmma import GEMM_INST_WGMMA, GemmWGMMA from tilelang.utils.target import target_is_cuda, target_is_volta diff --git a/tilelang/tileop/gemm/gemm_mma.py b/tilelang/cuda/op/gemm/gemm_mma.py similarity index 98% rename from tilelang/tileop/gemm/gemm_mma.py rename to tilelang/cuda/op/gemm/gemm_mma.py index 99e4eb4d9c..bd572c4075 100644 --- a/tilelang/tileop/gemm/gemm_mma.py +++ b/tilelang/cuda/op/gemm/gemm_mma.py @@ -1,8 +1,8 @@ from __future__ import annotations -from .gemm_base import GemmBase +from tilelang.tileop.gemm.gemm_base import GemmBase from tilelang.layout import make_swizzled_layout -from tilelang.intrinsics.mma_macro_generator import ( +from tilelang.cuda.intrinsics.macro.mma_macro_generator import ( TensorCoreIntrinEmitter, ) from tilelang.utils.language import is_shared, is_fragment, is_full_region diff --git a/tilelang/tileop/gemm/gemm_mma_sm70.py b/tilelang/cuda/op/gemm/gemm_mma_sm70.py similarity index 98% rename from tilelang/tileop/gemm/gemm_mma_sm70.py rename to tilelang/cuda/op/gemm/gemm_mma_sm70.py index 1d4fd21058..ca5068cbc5 100644 --- a/tilelang/tileop/gemm/gemm_mma_sm70.py +++ b/tilelang/cuda/op/gemm/gemm_mma_sm70.py @@ -1,9 +1,9 @@ from __future__ import annotations # for Volta GPUs, which use legacy MMA instructions -from .gemm_base import GemmBase +from tilelang.tileop.gemm.gemm_base import GemmBase from tilelang.layout import make_volta_swizzled_layout -from tilelang.intrinsics.mma_sm70_macro_generator import ( +from tilelang.cuda.intrinsics.macro.mma_sm70_macro_generator import ( TensorCoreIntrinEmitter, ) from tilelang.utils.language import is_shared, is_fragment, is_full_region diff --git a/tilelang/tileop/gemm/gemm_tcgen05.py b/tilelang/cuda/op/gemm/gemm_tcgen05.py similarity index 98% rename from tilelang/tileop/gemm/gemm_tcgen05.py rename to tilelang/cuda/op/gemm/gemm_tcgen05.py index 28d4c805be..a6107083df 100644 --- a/tilelang/tileop/gemm/gemm_tcgen05.py +++ b/tilelang/cuda/op/gemm/gemm_tcgen05.py @@ -1,6 +1,6 @@ from __future__ import annotations -from .gemm_base import GemmBase +from tilelang.tileop.gemm.gemm_base import GemmBase from tilelang.layout import ( Layout, make_full_bank_swizzled_layout, @@ -8,7 +8,7 @@ make_quarter_bank_swizzled_layout, make_linear_layout, ) -from tilelang.intrinsics.tcgen05_macro_generator import ( +from tilelang.cuda.intrinsics.macro.tcgen05_macro_generator import ( TensorCoreIntrinEmitter, ) from tilelang import language as T diff --git a/tilelang/tileop/gemm/gemm_wgmma.py b/tilelang/cuda/op/gemm/gemm_wgmma.py similarity index 98% rename from tilelang/tileop/gemm/gemm_wgmma.py rename to tilelang/cuda/op/gemm/gemm_wgmma.py index 6618309263..5eabb1b797 100644 --- a/tilelang/tileop/gemm/gemm_wgmma.py +++ b/tilelang/cuda/op/gemm/gemm_wgmma.py @@ -1,6 +1,6 @@ from __future__ import annotations -from .gemm_base import GemmBase +from tilelang.tileop.gemm.gemm_base import GemmBase from tilelang.layout import ( make_full_bank_swizzled_layout, make_half_bank_swizzled_layout, @@ -8,7 +8,7 @@ make_linear_layout, Layout, ) -from tilelang.intrinsics.wgmma_macro_generator import ( +from tilelang.cuda.intrinsics.macro.wgmma_macro_generator import ( TensorCoreIntrinEmitter, ) from tilelang.utils.language import is_shared, is_fragment diff --git a/tilelang/backend/cuda/gemm_sp.py b/tilelang/cuda/op/gemm_sp/__init__.py similarity index 51% rename from tilelang/backend/cuda/gemm_sp.py rename to tilelang/cuda/op/gemm_sp/__init__.py index fead5b3d3d..ed8ade6377 100644 --- a/tilelang/backend/cuda/gemm_sp.py +++ b/tilelang/cuda/op/gemm_sp/__init__.py @@ -1,7 +1,9 @@ +"""CUDA sparse GEMM op registrations.""" + from __future__ import annotations -from tilelang.backend.gemm_sp import register_gemm_sp_impl -from tilelang.tileop.gemm_sp.gemm_sp_mma import GemmSPMMA +from tilelang.tileop.gemm_sp.registry import register_gemm_sp_impl +from .gemm_sp_mma import GemmSPMMA from tilelang.utils.target import target_is_cuda diff --git a/tilelang/tileop/gemm_sp/gemm_sp_mma.py b/tilelang/cuda/op/gemm_sp/gemm_sp_mma.py similarity index 98% rename from tilelang/tileop/gemm_sp/gemm_sp_mma.py rename to tilelang/cuda/op/gemm_sp/gemm_sp_mma.py index 1a7964f9f4..dc381f7047 100644 --- a/tilelang/tileop/gemm_sp/gemm_sp_mma.py +++ b/tilelang/cuda/op/gemm_sp/gemm_sp_mma.py @@ -1,6 +1,6 @@ -from .gemm_sp_base import GemmSPBase +from tilelang.tileop.gemm_sp.gemm_sp_base import GemmSPBase from tilelang.layout import make_swizzled_layout -from tilelang.intrinsics.mma_sp_macro_generator import SparseTensorCoreIntrinEmitter +from tilelang.cuda.intrinsics.macro.mma_sp_macro_generator import SparseTensorCoreIntrinEmitter from tilelang.utils.language import is_shared, is_fragment from tilelang import tvm as tvm from tvm.target import Target diff --git a/tilelang/cuda/transform/__init__.py b/tilelang/cuda/transform/__init__.py new file mode 100644 index 0000000000..e6e4705c6f --- /dev/null +++ b/tilelang/cuda/transform/__init__.py @@ -0,0 +1,27 @@ +"""CUDA-specific transformation frontends.""" + +from tilelang.transform import _ffi_api + + +def LowerHopperIntrin(): + """LowerHopperIntrin""" + if hasattr(_ffi_api, "LowerHopperIntrin"): + return _ffi_api.LowerHopperIntrin() # type: ignore + return lambda f: f + + +def LowerL2Persistent(): + """LowerL2Persistent""" + return _ffi_api.LowerL2Persistent() # type: ignore + + +def PersistThreadblock(): + """PersistThreadblock""" + return _ffi_api.PersistThreadblock() # type: ignore + + +__all__ = [ + "LowerHopperIntrin", + "LowerL2Persistent", + "PersistThreadblock", +] diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index 5563845214..7c5b433dcb 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -204,7 +204,7 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule: # Lower high-level tile operations to low-level operations mod = tilelang.transform.LowerTileOp()(mod) # Lower l2 persistent map - mod = tilelang.transform.LowerL2Persistent()(mod) + mod = tilelang.cuda.transform.LowerL2Persistent()(mod) # Decouple type cast vectorization constraints before vectorization mod = tilelang.transform.DecoupleTypeCast()(mod) # Legalize vectorized loops to ensure they are valid @@ -270,7 +270,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: mod = tir.transform.InferFragment()(mod) mod = tilelang.transform.LowerThreadAllreduce()(mod) mod = tilelang.transform.LowerLDGSTG()(mod) - mod = tilelang.transform.LowerHopperIntrin()(mod) + mod = tilelang.cuda.transform.LowerHopperIntrin()(mod) # Global Barrier Synchronization must be applied before # SplitHostDevice pass, as the global barrier if allow_global_thread_synchronization(): @@ -305,6 +305,6 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: mod = tilelang.transform.LowerDeviceKernelLaunch()(mod) # Transform threadblock to persistent threadblock - mod = tilelang.transform.PersistThreadblock()(mod) + mod = tilelang.cuda.transform.PersistThreadblock()(mod) return mod diff --git a/tilelang/intrinsics/__init__.py b/tilelang/intrinsics/__init__.py index 1b3f106e71..b944ae89d0 100644 --- a/tilelang/intrinsics/__init__.py +++ b/tilelang/intrinsics/__init__.py @@ -1,14 +1,14 @@ -from .utils import ( +from tilelang.cuda.intrinsics.layout.utils import ( mma_store_index_map, # noqa: F401 get_ldmatrix_offset, # noqa: F401 ) -from .mma_macro_generator import ( +from tilelang.cuda.intrinsics.macro.mma_macro_generator import ( TensorCoreIntrinEmitter, # noqa: F401 TensorCoreIntrinEmitterWithLadderTransform, # noqa: F401 ) -from .mma_layout import get_swizzle_layout # noqa: F401 -from .mma_layout import make_mma_swizzle_layout # noqa: F401 +from tilelang.cuda.intrinsics.layout.mma_layout import get_swizzle_layout # noqa: F401 +from tilelang.cuda.intrinsics.layout.mma_layout import make_mma_swizzle_layout # noqa: F401 -from .mfma_layout import make_mfma_swizzle_layout # noqa: F401 +from tilelang.rocm.intrinsics.mfma_layout import make_mfma_swizzle_layout # noqa: F401 diff --git a/tilelang/language/gemm_op.py b/tilelang/language/gemm_op.py index d5ec03728f..1aa8aea17d 100644 --- a/tilelang/language/gemm_op.py +++ b/tilelang/language/gemm_op.py @@ -453,7 +453,7 @@ def make_blockscaled_gemm_layout( Returns: A Layout object for C's TMEM storage. """ - from tilelang.intrinsics.tcgen05_macro_generator import TensorCoreIntrinEmitter + from tilelang.cuda.intrinsics.macro.tcgen05_macro_generator import TensorCoreIntrinEmitter C_region = to_buffer_region(C) A_region = to_buffer_region(A) diff --git a/tilelang/rocm/__init__.py b/tilelang/rocm/__init__.py new file mode 100644 index 0000000000..a3b9cf6b63 --- /dev/null +++ b/tilelang/rocm/__init__.py @@ -0,0 +1,2 @@ +from . import intrinsics # noqa: F401 +from . import op # noqa: F401 diff --git a/tilelang/rocm/intrinsics/__init__.py b/tilelang/rocm/intrinsics/__init__.py new file mode 100644 index 0000000000..a972f683f1 --- /dev/null +++ b/tilelang/rocm/intrinsics/__init__.py @@ -0,0 +1,12 @@ +from .utils import ( # noqa: F401 + mfma_store_index_map, + mfma_store_index_map_32x32, + get_mma_micro_size, +) + +from .mfma_layout import make_mfma_swizzle_layout # noqa: F401 +from .mfma_macro_generator import ( # noqa: F401 + MatrixCoreIntrinEmitter, + MatrixCorePreshuffleIntrinEmitter, +) +from .wmma_macro_generator import WMMAIntrinEmitter # noqa: F401 diff --git a/tilelang/intrinsics/mfma_layout.py b/tilelang/rocm/intrinsics/mfma_layout.py similarity index 100% rename from tilelang/intrinsics/mfma_layout.py rename to tilelang/rocm/intrinsics/mfma_layout.py diff --git a/tilelang/intrinsics/mfma_macro_generator.py b/tilelang/rocm/intrinsics/mfma_macro_generator.py similarity index 100% rename from tilelang/intrinsics/mfma_macro_generator.py rename to tilelang/rocm/intrinsics/mfma_macro_generator.py diff --git a/tilelang/rocm/intrinsics/utils.py b/tilelang/rocm/intrinsics/utils.py new file mode 100644 index 0000000000..b1b4f68f76 --- /dev/null +++ b/tilelang/rocm/intrinsics/utils.py @@ -0,0 +1,23 @@ +from typing import Literal + +from .mfma_layout import ( + thread_id_shared_access_64x4_to_16x16_layout_C_n_m, + thread_id_shared_access_64x16_to_32x32_layout_C_m_n, +) +from .mfma_layout import make_mfma_swizzle_layout # noqa: F401 + + +def mfma_store_index_map(thread_id, local_id): + return thread_id_shared_access_64x4_to_16x16_layout_C_n_m(thread_id, local_id) + + +def mfma_store_index_map_32x32(thread_id, local_id): + return thread_id_shared_access_64x16_to_32x32_layout_C_m_n(thread_id, local_id) + + +def get_mma_micro_size(dtype: Literal["float16", "int8"]): + micro_size_x = micro_size_y = 16 + micro_size_k = 16 + if dtype in {"float8_e4m3", "float8_e5m2", "int8"}: + micro_size_k = 32 + return micro_size_x, micro_size_y, micro_size_k diff --git a/tilelang/intrinsics/wmma_layout.py b/tilelang/rocm/intrinsics/wmma_layout.py similarity index 100% rename from tilelang/intrinsics/wmma_layout.py rename to tilelang/rocm/intrinsics/wmma_layout.py diff --git a/tilelang/intrinsics/wmma_macro_generator.py b/tilelang/rocm/intrinsics/wmma_macro_generator.py similarity index 100% rename from tilelang/intrinsics/wmma_macro_generator.py rename to tilelang/rocm/intrinsics/wmma_macro_generator.py diff --git a/tilelang/backend/rocm/__init__.py b/tilelang/rocm/op/__init__.py similarity index 100% rename from tilelang/backend/rocm/__init__.py rename to tilelang/rocm/op/__init__.py diff --git a/tilelang/backend/rocm/gemm.py b/tilelang/rocm/op/gemm/__init__.py similarity index 65% rename from tilelang/backend/rocm/gemm.py rename to tilelang/rocm/op/gemm/__init__.py index 94e7d17724..c08e949b35 100644 --- a/tilelang/backend/rocm/gemm.py +++ b/tilelang/rocm/op/gemm/__init__.py @@ -1,8 +1,8 @@ from __future__ import annotations -from tilelang.backend.gemm import register_gemm_impl -from tilelang.tileop.gemm.gemm_mfma import GEMM_INST_MFMA, GemmMFMA -from tilelang.tileop.gemm.gemm_wmma import GEMM_INST_WMMA, GemmWMMA +from tilelang.tileop.gemm.registry import register_gemm_impl +from .gemm_mfma import GEMM_INST_MFMA, GemmMFMA +from .gemm_wmma import GEMM_INST_WMMA, GemmWMMA from tilelang.utils.target import target_is_hip diff --git a/tilelang/tileop/gemm/gemm_mfma.py b/tilelang/rocm/op/gemm/gemm_mfma.py similarity index 98% rename from tilelang/tileop/gemm/gemm_mfma.py rename to tilelang/rocm/op/gemm/gemm_mfma.py index 786baba96e..81f53d6eeb 100644 --- a/tilelang/tileop/gemm/gemm_mfma.py +++ b/tilelang/rocm/op/gemm/gemm_mfma.py @@ -1,8 +1,8 @@ from __future__ import annotations -from .gemm_base import GemmBase +from tilelang.tileop.gemm.gemm_base import GemmBase from tilelang.layout import make_swizzled_layout -from tilelang.intrinsics.mfma_macro_generator import ( +from tilelang.rocm.intrinsics.mfma_macro_generator import ( MatrixCoreIntrinEmitter, ) from tilelang.utils.language import is_shared, is_fragment, is_full_region diff --git a/tilelang/tileop/gemm/gemm_wmma.py b/tilelang/rocm/op/gemm/gemm_wmma.py similarity index 98% rename from tilelang/tileop/gemm/gemm_wmma.py rename to tilelang/rocm/op/gemm/gemm_wmma.py index ab4ae6d50a..4e0e0646e6 100644 --- a/tilelang/tileop/gemm/gemm_wmma.py +++ b/tilelang/rocm/op/gemm/gemm_wmma.py @@ -2,9 +2,9 @@ from __future__ import annotations -from .gemm_base import GemmBase +from tilelang.tileop.gemm.gemm_base import GemmBase from tilelang.layout import make_swizzled_layout -from tilelang.intrinsics.wmma_macro_generator import WMMAIntrinEmitter +from tilelang.rocm.intrinsics.wmma_macro_generator import WMMAIntrinEmitter from tilelang.utils.language import is_shared, is_fragment, is_full_region from tilelang import tvm as tvm from tvm.target import Target diff --git a/tilelang/tileop/gemm/__init__.py b/tilelang/tileop/gemm/__init__.py index 22a2c91007..37f7e2d235 100644 --- a/tilelang/tileop/gemm/__init__.py +++ b/tilelang/tileop/gemm/__init__.py @@ -5,7 +5,7 @@ from tvm.ir import Range from tvm.runtime import Scriptable import tvm_ffi -from tilelang.backend.gemm import resolve_gemm_impl +from .registry import resolve_gemm_impl from tilelang import _ffi_api diff --git a/tilelang/backend/gemm.py b/tilelang/tileop/gemm/registry.py similarity index 100% rename from tilelang/backend/gemm.py rename to tilelang/tileop/gemm/registry.py diff --git a/tilelang/tileop/gemm_sp/__init__.py b/tilelang/tileop/gemm_sp/__init__.py index 6e2c4a7d2b..1a49b86ec3 100644 --- a/tilelang/tileop/gemm_sp/__init__.py +++ b/tilelang/tileop/gemm_sp/__init__.py @@ -5,7 +5,7 @@ from tvm.ir import Range from tvm.runtime import Scriptable import tvm_ffi -from tilelang.backend.gemm_sp import resolve_gemm_sp_impl +from .registry import resolve_gemm_sp_impl from tilelang.tileop.base import GemmWarpPolicy diff --git a/tilelang/backend/gemm_sp.py b/tilelang/tileop/gemm_sp/registry.py similarity index 100% rename from tilelang/backend/gemm_sp.py rename to tilelang/tileop/gemm_sp/registry.py diff --git a/tilelang/transform/__init__.py b/tilelang/transform/__init__.py index 677887bd49..599ccef1ad 100644 --- a/tilelang/transform/__init__.py +++ b/tilelang/transform/__init__.py @@ -121,17 +121,6 @@ def VerifyParallelLoop(): return _ffi_api.VerifyParallelLoop() # type: ignore -def LowerHopperIntrin(): - """LowerHopperIntrin - - Returns - ------- - fpass : tvm.transform.Pass - The result pass - """ - return _ffi_api.LowerHopperIntrin() if hasattr(_ffi_api, "LowerHopperIntrin") else lambda f: f # type: ignore - - def ThreadSync(storage_scope: str): """Insert sync between parallel read/write of shared buffers. @@ -424,21 +413,11 @@ def MergeSharedMemoryAllocations(enable_aggressive_merge: bool = False, align_by return _ffi_api.MergeSharedMemoryAllocations(enable_aggressive_merge, align_bytes) # type: ignore -def LowerL2Persistent(): - """LowerL2Persistent""" - return _ffi_api.LowerL2Persistent() # type: ignore - - def MarkCudaSyncCalls(have_pdl: bool = False): """MarkCudaSyncCalls""" return _ffi_api.MarkCudaSyncCalls(have_pdl) # type: ignore -def PersistThreadblock(): - """PersistThreadblock""" - return _ffi_api.PersistThreadblock() # type: ignore - - def LowerSharedBarrier(): """LowerSharedBarrier""" return _ffi_api.LowerSharedBarrier() # type: ignore