-
Notifications
You must be signed in to change notification settings - Fork 552
[Metal] route FP8-input T.gemm to scalar fallback #2140
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
ee01743
01dff39
1ab387c
d1ccdc4
911a3a2
3ee8b7f
79158bd
d4fb922
7f948ec
971c17b
1962b1d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,126 @@ | ||
| import argparse | ||
| import logging | ||
| import time | ||
|
|
||
| import torch | ||
|
|
||
| import tilelang | ||
| import tilelang.language as T | ||
|
|
||
| logging.getLogger("tilelang").setLevel(logging.WARNING) | ||
|
|
||
# Tile shapes (block_M, block_N, block_K) exercised when --sweep is given.
BLOCK_CONFIGS = [
    (16, 16, 16),
    (32, 32, 16),
    (32, 32, 32),
    (64, 64, 32),
]
|
|
||
|
|
||
@tilelang.jit
def matmul_simdgroup(M, N, K, block_M=64, block_N=64, block_K=32, dtype=T.float16, accum_dtype=T.float32):
    """Build a tiled GEMM kernel computing C = A @ B via shared-memory staging.

    The launch grid is (ceildiv(N, block_N), ceildiv(M, block_M)); each thread
    block (128 threads) produces one block_M x block_N tile of C, accumulating
    in `accum_dtype` before the final copy back to global memory.
    """

    @T.prim_func
    def gemm_kernel(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), accum_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (col_blk, row_blk):
            a_tile = T.alloc_shared((block_M, block_K), dtype, scope="shared")
            b_tile = T.alloc_shared((block_K, block_N), dtype, scope="shared")
            acc = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(acc)
            # Walk the K dimension one block_K slice at a time (no pipelining).
            for k_step in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
                T.copy(A[row_blk * block_M, k_step * block_K], a_tile)
                T.copy(B[k_step * block_K, col_blk * block_N], b_tile)
                T.gemm(a_tile, b_tile, acc)
            T.copy(acc, C[row_blk * block_M, col_blk * block_N])

    return gemm_kernel
|
|
||
|
|
||
| def _tflops(M, N, K, seconds): | ||
| return 2.0 * M * N * K / seconds / 1e12 | ||
|
|
||
|
|
||
def _bench(fn, warmup, repeats):
    """Return the average wall-clock seconds per call of `fn`.

    Runs `warmup` untimed calls first, then synchronizes the MPS queue both
    before starting and after finishing the timed loop so queued GPU work is
    fully accounted for.
    """
    for _ in range(warmup):
        fn()
    torch.mps.synchronize()
    start = time.perf_counter()
    for _ in range(repeats):
        fn()
    torch.mps.synchronize()
    elapsed = time.perf_counter() - start
    return elapsed / repeats
|
|
||
|
|
||
def bench_torch_mps(M, N, K, warmup, repeats):
    """Benchmark torch.mm on fp16 MPS tensors and return achieved TFLOPS."""
    lhs = torch.randn(M, K, dtype=torch.float16, device="mps")
    rhs = torch.randn(K, N, dtype=torch.float16, device="mps")
    seconds = _bench(lambda: torch.mm(lhs, rhs), warmup, repeats)
    return _tflops(M, N, K, seconds)
|
|
||
|
|
||
def bench_tilelang(M, N, K, block_M, block_N, block_K, warmup, repeats):
    """Benchmark the TileLang simdgroup GEMM kernel and return achieved TFLOPS.

    Inputs are fp16; the output buffer is fp32 to match the kernel's
    accumulator dtype.
    """
    kernel = matmul_simdgroup(M, N, K, block_M, block_N, block_K)
    lhs = torch.randn(M, K, dtype=torch.float16, device="mps")
    rhs = torch.randn(K, N, dtype=torch.float16, device="mps")
    out = torch.zeros(M, N, dtype=torch.float32, device="mps")
    seconds = _bench(lambda: kernel(lhs, rhs, out), warmup, repeats)
    return _tflops(M, N, K, seconds)
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| parser = argparse.ArgumentParser(description="Metal GEMM Benchmark (simdgroup)") | ||
| parser.add_argument("--m", type=int, default=4096) | ||
| parser.add_argument("--n", type=int, default=4096) | ||
| parser.add_argument("--k", type=int, default=4096) | ||
| parser.add_argument("--warmup", type=int, default=10) | ||
| parser.add_argument("--repeats", type=int, default=100) | ||
| parser.add_argument("--sweep", action="store_true", help="Sweep all block configs instead of using default (64,64,32)") | ||
| args = parser.parse_args() | ||
|
|
||
| M, N, K = args.m, args.n, args.k | ||
| for name, value in (("m", M), ("n", N), ("k", K), ("repeats", args.repeats)): | ||
| if value <= 0: | ||
| raise SystemExit(f"--{name} must be a positive integer, got {value}") | ||
| if args.warmup < 0: | ||
| raise SystemExit(f"--warmup must be non-negative, got {args.warmup}") | ||
|
|
||
| print(f"torch: {torch.__version__}") | ||
| print(f"tilelang: {tilelang.__version__}") | ||
| print(f"MPS: {torch.backends.mps.is_available()}") | ||
| if not torch.backends.mps.is_available(): | ||
| raise SystemExit("Metal GEMM benchmark requires PyTorch MPS support") | ||
| print(f"M={M}, N={N}, K={K}, warmup={args.warmup}, repeats={args.repeats}") | ||
| print() | ||
|
|
||
| ref_tflops = bench_torch_mps(M, N, K, args.warmup, args.repeats) | ||
| print(f"PyTorch MPS (torch.mm fp16): {ref_tflops:.1f} TFLOPS") | ||
| print() | ||
|
|
||
| configs = BLOCK_CONFIGS if args.sweep else [(64, 64, 32)] | ||
|
|
||
| print(f"{'block (M,N,K)':>16s} | {'TileLang':>14s} | {'Ratio':>6s}") | ||
| print("-" * 44) | ||
|
|
||
| best_tflops = 0.0 | ||
| best_config = configs[0] | ||
| for bM, bN, bK in configs: | ||
| try: | ||
| tl = bench_tilelang(M, N, K, bM, bN, bK, args.warmup, args.repeats) | ||
| ratio = tl / ref_tflops * 100 | ||
| tag = "" | ||
| if tl > best_tflops: | ||
| best_tflops = tl | ||
| best_config = (bM, bN, bK) | ||
| print(f"{f'({bM},{bN},{bK})':>16s} | {tl:>10.1f} TFLOPS | {ratio:>5.0f}%") | ||
| except Exception as e: | ||
| print(f"{f'({bM},{bN},{bK})':>16s} | {'FAILED':>14s} | {e}") | ||
|
|
||
| if args.sweep: | ||
| print() | ||
| print(f"Best config: {best_config}") | ||
| print(f"Best TFlops: {best_tflops:.1f}") | ||
| print(f"Reference TFlops (PyTorch MPS): {ref_tflops:.1f}") |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -517,6 +517,10 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T, | |
| return result_map; | ||
| } | ||
|
|
||
| if (copy_inst == CopyInst::kMetalSIMDGroup) { | ||
| return {}; | ||
| } | ||
|
|
||
| // for LDSM/STSM, the layout was deduced from register layout | ||
| // so we can directly apply the layout of normal copy | ||
| // Use parallel op to infer the layout | ||
|
|
@@ -792,8 +796,47 @@ bool CopyNode::CheckCPAsyncCopy(Target target, const LayoutMap &layout_map, | |
| if (!CheckCPAsyncCopyPreconditions()) { | ||
| return false; | ||
| } | ||
| // Skip vectorize size check here because, during the Infer Layout stage, | ||
| // the layout is not stable and the vectorized size cannot be determined. | ||
| return true; | ||
| } | ||
|
|
||
| bool CopyNode::CheckSIMDGroupCopy(Target target) const { | ||
| if (!TargetIsMetal(target) || !IsSIMDGroupBuffer(src)) { | ||
| return false; | ||
| } | ||
| if (!IsSharedBuffer(dst) && !IsGlobalBuffer(dst)) { | ||
| return false; | ||
| } | ||
| if (src->dtype != dst->dtype) { | ||
| return false; | ||
| } | ||
| if (src_range.size() != 2 || dst_range.size() != 2 || | ||
| dst->shape.size() != 2) { | ||
| return false; | ||
| } | ||
|
|
||
| int total_elements = 1; | ||
| for (auto extent : src->shape) { | ||
| auto imm = extent.as<IntImmNode>(); | ||
| if (!imm) { | ||
| return false; | ||
| } | ||
| total_elements *= imm->value; | ||
| } | ||
| if (total_elements % 64 != 0) { | ||
| return false; | ||
| } | ||
|
|
||
| for (int i = 0; i < 2; ++i) { | ||
| auto src_shape = src->shape[i].as<IntImmNode>(); | ||
| auto src_min = src_range[i]->min.as<IntImmNode>(); | ||
| auto src_extent = src_range[i]->extent.as<IntImmNode>(); | ||
| auto dst_extent = dst_range[i]->extent.as<IntImmNode>(); | ||
| if (!src_shape || !src_min || src_min->value != 0 || !src_extent || | ||
| !dst_extent || src_extent->value != src_shape->value || | ||
| src_extent->value != dst_extent->value || src_extent->value % 8 != 0) { | ||
| return false; | ||
| } | ||
| } | ||
| return true; | ||
|
Comment on lines
+802
to
840
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Guard Metal SIMD-group copy against edge-tile OOB stores. This legality check never proves that 🤖 Prompt for AI Agents |
||
| } | ||
|
|
||
|
|
@@ -864,6 +907,8 @@ CopyInst CopyNode::GetCopyInst(Target target, const LayoutMap &layout_map, | |
| return CopyInst::kTMemLoad; | ||
| } else if (CheckTMemStore(target)) { | ||
| return CopyInst::kTMemStore; | ||
| } else if (CheckSIMDGroupCopy(target)) { | ||
| return CopyInst::kMetalSIMDGroup; | ||
| } else { | ||
| return CopyInst::kNormal; | ||
| } | ||
|
|
@@ -897,6 +942,8 @@ Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { | |
| auto cp_async_copy = LowerCPAsyncCopy(T, analyzer); | ||
| ICHECK(cp_async_copy.defined()) << "Failed to lower cp.async copy"; | ||
| return cp_async_copy; | ||
| } else if (copy_inst == CopyInst::kMetalSIMDGroup) { | ||
| return LowerSIMDGroupCopy(T, analyzer); | ||
| } else if (copy_inst == CopyInst::kNormal) { | ||
| return LowerNormalCopy(T, analyzer); | ||
| } else { | ||
|
|
@@ -982,7 +1029,125 @@ Stmt CopyNode::LowerCPAsyncCopy(const LowerArgs &T, | |
| return cp_async_loop; | ||
| } | ||
|
|
||
| // Lowers the copy using standard load/store with loop transformations. | ||
| Stmt CopyNode::LowerSIMDGroupCopy(const LowerArgs &T, | ||
| arith::Analyzer *analyzer) const { | ||
| ICHECK(IsSIMDGroupBuffer(src)); | ||
| int total_elements = 1; | ||
| for (auto s : src->shape) { | ||
| auto imm = s.as<IntImmNode>(); | ||
| ICHECK(imm) << "simdgroup buffer must have constant shape"; | ||
| total_elements *= imm->value; | ||
| } | ||
| ICHECK(total_elements % 64 == 0) | ||
| << "simdgroup buffer size must be multiple of 64 (8x8), got " | ||
| << total_elements; | ||
|
|
||
| ICHECK(dst_range.size() == 2) | ||
| << "Expected 2D destination for simdgroup store"; | ||
| PrimExpr dst_row_base = dst_range[0]->min; | ||
| PrimExpr dst_col_base = dst_range[1]->min; | ||
| ICHECK_EQ(dst->shape.size(), 2U) | ||
| << "simdgroup store currently supports 2D destination buffers"; | ||
| Array<PrimExpr> dst_strides = dst->strides; | ||
| if (dst_strides.empty()) { | ||
| PrimExpr stride = 1; | ||
| dst_strides.resize(dst->shape.size()); | ||
| for (int i = static_cast<int>(dst->shape.size()) - 1; i >= 0; --i) { | ||
| dst_strides.Set(i, stride); | ||
| stride *= dst->shape[i]; | ||
| } | ||
| } | ||
| if (dst_strides.size() != dst->shape.size()) { | ||
| return LowerNormalCopy(T, analyzer); | ||
| } | ||
| if (!analyzer->CanProveEqual(dst_strides[1], 1)) { | ||
| return LowerNormalCopy(T, analyzer); | ||
| } | ||
| PrimExpr dst_stride = dst_strides[0]; | ||
|
|
||
| int warp_size = TargetGetWarpSize(T.target); | ||
| auto block_extent = T.thread_bounds->extent.as<IntImmNode>(); | ||
| if (!block_extent || warp_size <= 0 || block_extent->value % warp_size != 0) { | ||
| return LowerNormalCopy(T, analyzer); | ||
| } | ||
| int block_size = block_extent->value; | ||
| int num_warps = block_size / warp_size; | ||
| if (num_warps <= 0) { | ||
| return LowerNormalCopy(T, analyzer); | ||
| } | ||
| PrimExpr relative_thread = T.thread_var - T.thread_bounds->min; | ||
| PrimExpr warp_id = FloorDiv(relative_thread, warp_size); | ||
|
|
||
| auto M_imm = src_range[0]->extent.as<IntImmNode>(); | ||
| auto N_imm = src_range[1]->extent.as<IntImmNode>(); | ||
| if (!M_imm || !N_imm) { | ||
| return LowerNormalCopy(T, analyzer); | ||
| } | ||
| int M = M_imm->value; | ||
| int N = N_imm->value; | ||
|
|
||
| int kMPerWarp = 8; | ||
| int kNPerWarp = 8; | ||
| int m_warp = 1, n_warp = num_warps; | ||
| int max_m = M / kMPerWarp; | ||
| int max_n = N / kNPerWarp; | ||
| if (max_m <= 0 || max_n <= 0) { | ||
| return LowerNormalCopy(T, analyzer); | ||
| } | ||
| float ideal = N > 0 ? static_cast<float>(M) / N : 1.f; | ||
| float best_score = std::numeric_limits<float>::max(); | ||
| for (int m = 1; m <= std::min(num_warps, max_m); ++m) { | ||
| if (num_warps % m != 0) | ||
| continue; | ||
| int n = num_warps / m; | ||
| if (n > max_n) | ||
| continue; | ||
| if (M % (m * kMPerWarp) != 0 || N % (n * kNPerWarp) != 0) | ||
| continue; | ||
| float m_per = static_cast<float>(M) / (m * kMPerWarp); | ||
| float n_per = static_cast<float>(N) / (n * kNPerWarp); | ||
| float score = std::abs(m_per / n_per - ideal); | ||
| if (score < best_score) { | ||
| best_score = score; | ||
| m_warp = m; | ||
| n_warp = n; | ||
| } | ||
| } | ||
|
|
||
| if (best_score == std::numeric_limits<float>::max() || M < m_warp * 8 || | ||
| N < n_warp * 8) { | ||
| return LowerNormalCopy(T, analyzer); | ||
| } | ||
| int warp_row_tiles = M / m_warp / 8; | ||
| int warp_col_tiles = N / n_warp / 8; | ||
| if (warp_row_tiles <= 0 || warp_col_tiles <= 0 || | ||
| warp_row_tiles * warp_col_tiles * 64 > total_elements) { | ||
| return LowerNormalCopy(T, analyzer); | ||
| } | ||
|
|
||
| PrimExpr warp_m = FloorMod(warp_id, m_warp); | ||
| PrimExpr warp_n = FloorDiv(warp_id, m_warp); | ||
|
|
||
| Array<Stmt> stmts; | ||
| for (int i = 0; i < warp_row_tiles; i++) { | ||
| for (int j = 0; j < warp_col_tiles; j++) { | ||
| int tile_idx = i * warp_col_tiles + j; | ||
| PrimExpr row = dst_row_base + warp_m * (warp_row_tiles * 8) + i * 8; | ||
| PrimExpr col = dst_col_base + warp_n * (warp_col_tiles * 8) + j * 8; | ||
| PrimExpr ptr = Call(DataType::Handle(), builtin::address_of(), | ||
| {BufferLoad(dst, {row, col})}); | ||
| stmts.push_back(Evaluate( | ||
| Call(DataType::Handle(), builtin::simdgroup_store(), | ||
| {src->data, IntImm(DataType::Int(32), tile_idx), ptr, dst_stride, | ||
| IntImm(DataType::Int(32), 8), IntImm(DataType::Int(32), 8), | ||
| Cast(DataType::Bool(), IntImm(DataType::Int(32), 0))}))); | ||
| } | ||
| } | ||
| if (stmts.size() == 1) | ||
| return stmts[0]; | ||
| return SeqStmt(stmts); | ||
| } | ||
|
|
||
| Stmt CopyNode::LowerNormalCopy(const LowerArgs &T, | ||
| arith::Analyzer *analyzer) const { | ||
| bool is_cpu_target = T.target->GetTargetDeviceType() == kDLCPU; | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
🏁 Script executed:
Repository: tile-ai/tilelang
Length of output: 1331
Unconditionally compiling `src/target/codegen_metal.cc` breaks non-Metal builds.

The CMakeLists.txt appends `src/target/codegen_metal.cc` to `TILE_LANG_SRCS` before the `USE_METAL` and `APPLE` guards take effect. However, this file includes `runtime/metal/metal_module.h`, which is not available on non-Metal or non-Apple platforms. This causes compilation to fail before the early `return()` statements can skip the Metal-specific configuration.

To fix: either move the source append behind the appropriate guards, or refactor `codegen_metal.cc` to isolate the Metal runtime dependencies from the codegen-only logic.