From 4b015bfdb84d1e3b6feaa55f18e63dfb1c905b6e Mon Sep 17 00:00:00 2001 From: Yichen Yan Date: Sun, 26 Apr 2026 20:57:39 +0800 Subject: [PATCH 1/3] [Metal] Add Metal GEMM support with simdgroup_matrix MMA Add T.gemm support for Apple Metal using simdgroup_matrix 8x8 operations (simdgroup_load/store/multiply_accumulate). Works on all Apple Silicon (M1-M5) without requiring a TVM fork. Key changes: - codegen_metal.cc/h: Fork TVM Metal codegen to tilelang with simdgroup intrinsic emission and 128-bit vectorized copy - gemm_metal.py: GemmMetal tile operator for sharedxshared GEMM - metal_macro_generator.py: MPSIntrinEmitter for simdgroup MMA macros - metal_fragment_to_simdgroup.py: Pass rewrites local.fragment GEMM accumulators to metal.simdgroup scope before layout inference - LowerSIMDGroupCopy in copy.cc for fragment->device simdgroup_store 24 Metal tests (codegen cross-platform + correctness on device). --- CMakeLists.txt | 7 + .../matmul_metal/benchmark_matmul_metal.py | 119 ++++ pyproject.toml | 1 + requirements-dev.txt | 1 + requirements.txt | 1 + src/op/copy.cc | 100 +++- src/op/copy.h | 21 +- src/op/fill.cc | 25 +- src/op/gemm.cc | 9 +- src/op/gemm.h | 5 +- src/op/parallel.cc | 5 +- src/op/utils.h | 9 +- src/target/codegen_metal.cc | 521 ++++++++++++++++++ src/target/codegen_metal.h | 74 +++ src/transform/layout_inference.cc | 12 +- .../lower_device_storage_access_info.cc | 2 +- testing/python/metal/test_metal_gemm_v2.py | 91 +++ .../python/metal/test_metal_gemm_v2_linux.py | 82 +++ .../metal/test_metal_simdgroup_store.py | 133 +++++ tilelang/engine/lower.py | 4 +- tilelang/engine/phase.py | 5 + tilelang/intrinsics/metal_macro_generator.py | 203 +++++++ tilelang/jit/adapter/torch/metal.py | 3 + tilelang/tileop/gemm/__init__.py | 13 +- tilelang/tileop/gemm/gemm_metal.py | 105 ++++ tilelang/tileop/gemm/inst.py | 6 +- tilelang/transform/decouple_type_cast.py | 6 +- .../transform/metal_fragment_to_simdgroup.py | 133 +++++ tilelang/utils/language.py | 14 + 29 files changed, 1682 insertions(+), 28 deletions(-) create mode 100644 benchmark/matmul_metal/benchmark_matmul_metal.py create mode 100644 src/target/codegen_metal.cc create mode 100644 src/target/codegen_metal.h create mode 100644 testing/python/metal/test_metal_gemm_v2.py create mode 100644 testing/python/metal/test_metal_gemm_v2_linux.py create mode 100644 testing/python/metal/test_metal_simdgroup_store.py create mode 100644 tilelang/intrinsics/metal_macro_generator.py create mode 100644 tilelang/tileop/gemm/gemm_metal.py create mode 100644 tilelang/transform/metal_fragment_to_simdgroup.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 212b1f67ae..f2ca9bc1ec 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -190,6 +190,13 @@ list(APPEND TILE_LANG_SRCS src/runtime/error_helpers.cc ) +# Metal codegen is pure C++ (no Apple frameworks) and can generate Metal shader +# source on any platform. Always compile it so that "target.build.tilelang_metal" +# is available for cross-compilation on Linux/Windows. +list(APPEND TILE_LANG_SRCS + src/target/codegen_metal.cc +) + set(TILELANG_OUTPUT_TARGETS tilelang tvm) # Track if the user explicitly selected a backend via cache options. 
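For context before the diff: the generated kernels follow the standard Metal simdgroup_matrix pattern sketched below. This is a minimal hand-written illustration, not output of this patch; the kernel name, the all-float32 types, and the buffer layout are assumptions for the sketch (the real kernels also use half inputs and tile across multiple simdgroups), and it assumes K is a multiple of 8 and a threadgroup containing at least one 32-thread simdgroup.

    #include <metal_stdlib>
    using namespace metal;

    // One simdgroup (32 threads) cooperatively computes one 8x8 tile of C.
    // A is 8 x K row-major, B is K x 8 row-major, C is 8 x 8.
    kernel void gemm_8x8_sketch(device const float *A [[buffer(0)]],
                                device const float *B [[buffer(1)]],
                                device float *C [[buffer(2)]],
                                constant uint &K [[buffer(3)]]) {
      simdgroup_float8x8 a;
      simdgroup_float8x8 b;
      // Zero-initialized accumulator; corresponds to the
      // make_filled_simdgroup_matrix calls in the emitted code.
      simdgroup_float8x8 c = make_filled_simdgroup_matrix<float, 8, 8>(0.0f);
      for (uint k = 0; k < K; k += 8) {
        simdgroup_load(a, A + k, K);     // 8x8 tile of A at column k; leading dim K
        simdgroup_load(b, B + k * 8, 8); // 8x8 tile of B at row k; leading dim 8
        simdgroup_multiply_accumulate(c, a, b, c);
      }
      simdgroup_store(c, C, 8); // write the accumulated tile back to device memory
    }

LowerSIMDGroupCopy, FillNode::Lower, and MPSIntrinEmitter below emit this same fill/load/multiply-accumulate/store cycle, partitioned across the warps of a threadgroup.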
diff --git a/benchmark/matmul_metal/benchmark_matmul_metal.py b/benchmark/matmul_metal/benchmark_matmul_metal.py new file mode 100644 index 0000000000..20d75ad24b --- /dev/null +++ b/benchmark/matmul_metal/benchmark_matmul_metal.py @@ -0,0 +1,119 @@ +import argparse +import logging +import time + +import torch + +import tilelang +import tilelang.language as T + +logging.getLogger("tilelang").setLevel(logging.WARNING) + +BLOCK_CONFIGS = [ + (16, 16, 16), + (32, 32, 16), + (32, 32, 32), + (64, 64, 32), +] + + +@tilelang.jit +def matmul_simdgroup(M, N, K, block_M=64, block_N=64, block_K=32, dtype=T.float16, accum_dtype=T.float32): + + @T.prim_func + def gemm_kernel( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), accum_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared") + B_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared") + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0): + T.copy(A[by * block_M, ko * block_K], A_shared) + T.copy(B[ko * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return gemm_kernel + + +def _tflops(M, N, K, seconds): + return 2.0 * M * N * K / seconds / 1e12 + + +def _bench(fn, warmup, repeats): + for _ in range(warmup): + fn() + torch.mps.synchronize() + t0 = time.perf_counter() + for _ in range(repeats): + fn() + torch.mps.synchronize() + return (time.perf_counter() - t0) / repeats + + +def bench_torch_mps(M, N, K, warmup, repeats): + a = torch.randn(M, K, dtype=torch.float16, device="mps") + b = torch.randn(K, N, dtype=torch.float16, device="mps") + avg_s = _bench(lambda: torch.mm(a, b), warmup, repeats) + return _tflops(M, N, K, avg_s) + + +def bench_tilelang(M, N, K, block_M, block_N, block_K, warmup, repeats): + kernel = matmul_simdgroup(M, N, K, block_M, block_N, block_K) + a = torch.randn(M, K, dtype=torch.float16, device="mps") + b = torch.randn(K, N, dtype=torch.float16, device="mps") + c = torch.zeros(M, N, dtype=torch.float32, device="mps") + avg_s = _bench(lambda: kernel(a, b, c), warmup, repeats) + return _tflops(M, N, K, avg_s) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Metal GEMM Benchmark (simdgroup)") + parser.add_argument("--m", type=int, default=4096) + parser.add_argument("--n", type=int, default=4096) + parser.add_argument("--k", type=int, default=4096) + parser.add_argument("--warmup", type=int, default=10) + parser.add_argument("--repeats", type=int, default=100) + parser.add_argument("--sweep", action="store_true", help="Sweep all block configs instead of using default (64,64,32)") + args = parser.parse_args() + + M, N, K = args.m, args.n, args.k + + print(f"torch: {torch.__version__}") + print(f"tilelang: {tilelang.__version__}") + print(f"MPS: {torch.backends.mps.is_available()}") + print(f"M={M}, N={N}, K={K}, warmup={args.warmup}, repeats={args.repeats}") + print() + + ref_tflops = bench_torch_mps(M, N, K, args.warmup, args.repeats) + print(f"PyTorch MPS (torch.mm fp16): {ref_tflops:.1f} TFLOPS") + print() + + configs = BLOCK_CONFIGS if args.sweep else [(64, 64, 32)] + + print(f"{'block (M,N,K)':>16s} | {'TileLang':>14s} | {'Ratio':>6s}") + print("-" * 44) + + best_tflops = 0.0 + best_config = configs[0] + for bM, bN, bK in configs: + try: + tl = bench_tilelang(M, N, K, 
bM, bN, bK, args.warmup, args.repeats) + ratio = tl / ref_tflops * 100 + tag = "" + if tl > best_tflops: + best_tflops = tl + best_config = (bM, bN, bK) + print(f"{f'({bM},{bN},{bK})':>16s} | {tl:>10.1f} TFLOPS | {ratio:>5.0f}%") + except Exception as e: + print(f"{f'({bM},{bN},{bK})':>16s} | {'FAILED':>14s} | {e}") + + if args.sweep: + print() + print(f"Best config: {best_config}") + print(f"Best TFlops: {best_tflops:.1f}") + print(f"Reference TFlops (PyTorch MPS): {ref_tflops:.1f}") diff --git a/pyproject.toml b/pyproject.toml index 80aba41e32..d2c046df43 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ dependencies = [ # requirement as wide as possible to be compatible with other libraries # pip will try to use latest version whenever possible. "apache-tvm-ffi~=0.1.0,>=0.1.2", + "apache-tvm-ffi<0.1.8; platform_system == 'Darwin'", # torch-c-dlpack-ext provides prebuilt torch extensions. # Without it, TVM FFI may require JIT compilation on first import. "torch-c-dlpack-ext; python_version < '3.14'", diff --git a/requirements-dev.txt b/requirements-dev.txt index f8dccdc871..e4403be050 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,6 +1,7 @@ # Requirements to run local build with `--no-build-isolation` or other developments apache-tvm-ffi~=0.1.0,>=0.1.2 +apache-tvm-ffi<0.1.8; platform_system == 'Darwin' build cmake>=3.26 cython>=3.1.0 diff --git a/requirements.txt b/requirements.txt index 2dbe070d9a..9c841cee5d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ # Runtime requirements apache-tvm-ffi~=0.1.0,>=0.1.2 +apache-tvm-ffi<0.1.8; platform_system == 'Darwin' torch-c-dlpack-ext; python_version < '3.14' cloudpickle ml-dtypes diff --git a/src/op/copy.cc b/src/op/copy.cc index 93bd0cf70b..6440e424c0 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -517,6 +517,10 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T, return result_map; } + if (copy_inst == CopyInst::kMetalSIMDGroup) { + return {}; + } + // for LDSM/STSM, the layout was deduced from register layout // so we can directly apply the layout of normal copy // Use parallel op to infer the layout @@ -792,11 +796,16 @@ bool CopyNode::CheckCPAsyncCopy(Target target, const LayoutMap &layout_map, if (!CheckCPAsyncCopyPreconditions()) { return false; } - // Skip vectorize size check here because, during the Infer Layout stage, - // the layout is not stable and the vectorized size cannot be determined. return true; } +bool CopyNode::CheckSIMDGroupCopy(Target target) const { + if (TargetIsMetal(target) && IsSIMDGroupBuffer(src)) { + return IsSharedBuffer(dst) || IsGlobalBuffer(dst); + } + return false; +} + // Selects the most specific copy instruction for the given target and buffers. // Priority: BulkLoad1D, BulkStore1D, BulkLoad, BulkStore, LDSM, STSM, // TMemLoad, TMemStore, CPAsync, Normal. 
@@ -864,6 +873,8 @@ CopyInst CopyNode::GetCopyInst(Target target, const LayoutMap &layout_map,
     return CopyInst::kTMemLoad;
   } else if (CheckTMemStore(target)) {
     return CopyInst::kTMemStore;
+  } else if (CheckSIMDGroupCopy(target)) {
+    return CopyInst::kMetalSIMDGroup;
   } else {
     return CopyInst::kNormal;
   }
@@ -897,6 +908,8 @@ Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
     auto cp_async_copy = LowerCPAsyncCopy(T, analyzer);
     ICHECK(cp_async_copy.defined()) << "Failed to lower cp.async copy";
     return cp_async_copy;
+  } else if (copy_inst == CopyInst::kMetalSIMDGroup) {
+    return LowerSIMDGroupCopy(T, analyzer);
   } else if (copy_inst == CopyInst::kNormal) {
     return LowerNormalCopy(T, analyzer);
   } else {
@@ -982,7 +995,88 @@ Stmt CopyNode::LowerCPAsyncCopy(const LowerArgs &T,
   return cp_async_loop;
 }
 
-// Lowers the copy using standard load/store with loop transformations.
+Stmt CopyNode::LowerSIMDGroupCopy(const LowerArgs &T,
+                                  arith::Analyzer *analyzer) const {
+  ICHECK(IsSIMDGroupBuffer(src));
+  int total_elements = 1;
+  for (auto s : src->shape) {
+    auto imm = s.as<IntImmNode>();
+    ICHECK(imm) << "simdgroup buffer must have constant shape";
+    total_elements *= imm->value;
+  }
+  ICHECK(total_elements % 64 == 0)
+      << "simdgroup buffer size must be multiple of 64 (8x8), got "
+      << total_elements;
+
+  ICHECK(dst_range.size() == 2)
+      << "Expected 2D destination for simdgroup store";
+  PrimExpr dst_row_base = dst_range[0]->min;
+  PrimExpr dst_col_base = dst_range[1]->min;
+  PrimExpr dst_stride = dst->shape[dst->shape.size() - 1];
+
+  int warp_size = TargetGetWarpSize(T.target);
+  int block_size = T.thread_bounds->extent.as<IntImmNode>()->value;
+  int num_warps = block_size / warp_size;
+  PrimExpr warp_id = FloorDiv(T.thread_var, warp_size);
+
+  int M = src_range[0]->extent.as<IntImmNode>()->value;
+  int N = src_range[1]->extent.as<IntImmNode>()->value;
+
+  int kMPerWarp = 8;
+  int kNPerWarp = 8;
+  int m_warp = 1, n_warp = num_warps;
+  int max_m = M / kMPerWarp;
+  int max_n = N / kNPerWarp;
+  float ideal = N > 0 ? static_cast<float>(M) / N : 1.f;
+  float best_score = std::numeric_limits<float>::max();
+  for (int m = 1; m <= std::min(num_warps, max_m); ++m) {
+    if (num_warps % m != 0)
+      continue;
+    int n = num_warps / m;
+    if (n > max_n)
+      continue;
+    float m_per = static_cast<float>(M) / (m * kMPerWarp);
+    float n_per = static_cast<float>(N) / (n * kNPerWarp);
+    float score = std::abs(m_per / n_per - ideal);
+    if (score < best_score) {
+      best_score = score;
+      m_warp = m;
+      n_warp = n;
+    }
+  }
+
+  ICHECK(M >= m_warp * 8 && N >= n_warp * 8)
+      << "Cannot partition " << M << "x" << N << " matrix across " << m_warp
+      << "x" << n_warp << " warps with 8x8 simdgroup tiles";
+  int warp_row_tiles = M / m_warp / 8;
+  int warp_col_tiles = N / n_warp / 8;
+  ICHECK(warp_row_tiles > 0 && warp_col_tiles > 0);
+  ICHECK(warp_row_tiles * warp_col_tiles * 64 <= total_elements)
+      << "Warp partition produces more tiles than buffer capacity";
+
+  PrimExpr warp_m = FloorMod(warp_id, m_warp);
+  PrimExpr warp_n = FloorDiv(warp_id, m_warp);
+
+  Array<Stmt> stmts;
+  for (int i = 0; i < warp_row_tiles; i++) {
+    for (int j = 0; j < warp_col_tiles; j++) {
+      int tile_idx = i * warp_col_tiles + j;
+      PrimExpr row = dst_row_base + warp_m * (warp_row_tiles * 8) + i * 8;
+      PrimExpr col = dst_col_base + warp_n * (warp_col_tiles * 8) + j * 8;
+      PrimExpr ptr = Call(DataType::Handle(), builtin::address_of(),
+                          {BufferLoad(dst, {row, col})});
+      stmts.push_back(Evaluate(
+          Call(DataType::Handle(), builtin::simdgroup_store(),
+               {src->data, IntImm(DataType::Int(32), tile_idx), ptr, dst_stride,
+                IntImm(DataType::Int(32), 8), IntImm(DataType::Int(32), 8),
+                Cast(DataType::Bool(), IntImm(DataType::Int(32), 0))})));
+    }
+  }
+  if (stmts.size() == 1)
+    return stmts[0];
+  return SeqStmt(stmts);
+}
+
 Stmt CopyNode::LowerNormalCopy(const LowerArgs &T,
                                arith::Analyzer *analyzer) const {
   bool is_cpu_target = T.target->GetTargetDeviceType() == kDLCPU;
diff --git a/src/op/copy.h b/src/op/copy.h
index d20f519815..4a80b75e46 100644
--- a/src/op/copy.h
+++ b/src/op/copy.h
@@ -24,10 +24,11 @@ enum class CopyInst : uint8_t {
   kCPAsync = 5, // cp.async global->shared copy
   // we should separate the bulk load and store for 1d and multi-dim
   // as they have different memory access patterns
-  kBulkLoad1D = 6,  // utilize tma load 1d
-  kBulkStore1D = 7, // utilize tma store 1d
-  kTMemLoad = 8,    // tcgen05.ld (tensor memory -> register)
-  kTMemStore = 9,   // tcgen05.st (register -> tensor memory)
+  kBulkLoad1D = 6,      // utilize tma load 1d
+  kBulkStore1D = 7,     // utilize tma store 1d
+  kTMemLoad = 8,        // tcgen05.ld (tensor memory -> register)
+  kTMemStore = 9,       // tcgen05.st (register -> tensor memory)
+  kMetalSIMDGroup = 10, // Metal simdgroup load/store
 };
 
 /// Convert CopyInst enum to string for debugging
@@ -53,6 +54,8 @@ inline const char *CopyInstToString(CopyInst inst) {
     return "TMemLoad";
   case CopyInst::kTMemStore:
     return "TMemStore";
+  case CopyInst::kMetalSIMDGroup:
+    return "MetalSIMDGroup";
   default:
     return "Unknown";
   }
@@ -290,6 +293,11 @@ class CopyNode : public TileOperatorNode {
                            arith::Analyzer *analyzer) const;
 
 protected:
+  /*!
+   * \brief Check if copy from Metal simdgroup to shared/global is supported.
+   */
+  bool CheckSIMDGroupCopy(Target target) const;
+
   /*!
   * \brief Get the copy instruction type.
   */
@@ -331,6 +339,11 @@
   */
   Stmt LowerCPAsyncCopy(const LowerArgs &T, arith::Analyzer *analyzer) const;
 
+  /*!
+   * \brief Generate lowering for simdgroup store.
+   */
+  Stmt LowerSIMDGroupCopy(const LowerArgs &T, arith::Analyzer *analyzer) const;
+
  /*!
  * \brief Generate SIMT (thread-level) loop for copying.
  */
diff --git a/src/op/fill.cc b/src/op/fill.cc
index be0dd8dc10..7434dd24c5 100644
--- a/src/op/fill.cc
+++ b/src/op/fill.cc
@@ -156,7 +156,30 @@ For FillNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
  * @return Stmt The lowered TIR statement implementing the fill.
  */
 Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
-  if (IsFragmentBuffer(dst)) {
+  if (IsSIMDGroupBuffer(dst)) {
+    int region_elements = 1;
+    for (auto r : region) {
+      auto imm = r->extent.as<IntImmNode>();
+      ICHECK(imm) << "simdgroup fill region must have constant extents";
+      region_elements *= imm->value;
+    }
+    int total_elements = region_elements;
+    ICHECK(total_elements % 64 == 0)
+        << "simdgroup buffer size must be multiple of 64 (8x8), got "
+        << total_elements;
+    int num_matrices = total_elements / 64;
+    PrimExpr fill_value = Cast(dst->dtype, value);
+    Array<Stmt> stmts;
+    for (int i = 0; i < num_matrices; i++) {
+      stmts.push_back(Evaluate(
+          Call(DataType::Handle(), builtin::make_filled_simdgroup_matrix(),
+               {dst->data, IntImm(DataType::Int(32), i), fill_value,
+                IntImm(DataType::Int(32), 8), IntImm(DataType::Int(32), 8)})));
+    }
+    if (stmts.size() == 1)
+      return stmts[0];
+    return SeqStmt(stmts);
+  } else if (IsFragmentBuffer(dst)) {
     auto par_op = ParallelOp(MakeSIMTLoop(analyzer));
     par_op->InferLayout({T.target, T.thread_bounds,
diff --git a/src/op/gemm.cc b/src/op/gemm.cc
index 5472b33386..1a5bc65a59 100644
--- a/src/op/gemm.cc
+++ b/src/op/gemm.cc
@@ -183,6 +183,8 @@ GemmInst GemmNode::getGemmInst(int block_size, Target target) const {
     return GemmInst::kMMA;
   } else if (TargetIsCPU(target)) {
     return GemmInst::kScalar;
+  } else if (TargetIsMetal(target)) {
+    return GemmInst::kMetalSimdgroup;
   } else {
     ICHECK(0) << "Unsupported target for gemm: " << target->str();
     return GemmInst::kMMA;
@@ -199,8 +201,11 @@ std::pair<int, int> GemmWarpPolicyNode::computeWarpPartition(
   }
 
   int m_warp = 1, n_warp = 1;
-  constexpr int kMPerWarp = 16; // Rows processed by a single warp
-  int kNPerWarp = 8;            // Columns processed by a single warp
+  int kMPerWarp = 16; // Rows processed by a single warp
+  if (TargetIsMetal(target)) {
+    kMPerWarp = 8;
+  }
+  int kNPerWarp = 8; // Columns processed by a single warp
   if (TargetIsVolta(target)) {
     kNPerWarp = 16;
   } else if (TargetIsCDNA(target)) {
diff --git a/src/op/gemm.h b/src/op/gemm.h
index 26b6678402..90f85b00ef 100644
--- a/src/op/gemm.h
+++ b/src/op/gemm.h
@@ -45,7 +45,8 @@ enum class GemmInst : uint8_t {
   kTCGEN5MMA,
   kMFMA,
   kScalar,
-  kWMMA
+  kWMMA,
+  kMetalSimdgroup
 };
 
 /// Convert GemmInst enum to string for debugging
@@ -63,6 +64,8 @@ inline const char *GemmInstToString(GemmInst inst) {
     return "Scalar";
   case GemmInst::kWMMA:
     return "WMMA";
+  case GemmInst::kMetalSimdgroup:
+    return "Metal";
   default:
     return "Unknown";
   }
diff --git a/src/op/parallel.cc b/src/op/parallel.cc
index b93467b54f..e31e94426b 100644
--- a/src/op/parallel.cc
+++ b/src/op/parallel.cc
@@ -359,7 +359,6 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
         info && info.value()->rep == ReducerRepType::ALL)
       continue;
 
-    auto frag = T.layout_map[buffer].as<Fragment>().value();
     bool is_fully_replicated =
         IsBufferCompletelyReplicated(buffer, T.layout_map);
 
@@ -379,7 +378,9 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
       // If the buffer is not replicated and shape is equal to the
       // source_buffer, use it as source_buffer because the layout inference
       // is more accurate
-      if (is_one(frag->ReplicateExtent()) && !source_buffer.defined()) {
+      auto frag = T.layout_map[buffer].as<Fragment>();
+      if (frag.has_value() && is_one(frag.value()->ReplicateExtent()) &&
+          !source_buffer.defined()) {
         source_buffer = buffer;
       }
     }
diff --git a/src/op/utils.h b/src/op/utils.h
index 77e21feda6..8831fa6877 100644
--- a/src/op/utils.h
+++ b/src/op/utils.h
@@ -53,11 +53,18 @@ TVM_DLL PrimExpr MakeAccessPtrFromRegion(const BufferRegion &region,
 TVM_DLL PrimExpr MakeAccessPtrFromBufferLoad(const BufferLoad &load,
                                              int rw_mask);
 
-// Check if a buffer is a fragment buffer (scope == "local.fragment")
 inline bool IsFragmentBuffer(const Buffer &buffer) {
   return buffer.defined() && buffer.scope() == "local.fragment";
 }
 
+inline bool IsSIMDGroupBuffer(const Buffer &buffer) {
+  return buffer.defined() && buffer.scope() == "metal.simdgroup";
+}
+
+inline bool IsRegisterBuffer(const Buffer &buffer) {
+  return IsFragmentBuffer(buffer) || IsSIMDGroupBuffer(buffer);
+}
+
 // Expand a lower-rank layout by prepending the leading dimensions of `buffer`
 // so that the resulting layout input shape matches `buffer->shape`.
 //
diff --git a/src/target/codegen_metal.cc b/src/target/codegen_metal.cc
new file mode 100644
index 0000000000..dd9daf3545
--- /dev/null
+++ b/src/target/codegen_metal.cc
@@ -0,0 +1,521 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file codegen_metal.cc
+ */
+#include "codegen_metal.h"
+
+#include <tvm/ffi/reflection/registry.h>
+#include <tvm/tir/transform.h>
+
+#include <algorithm>
+#include <cmath>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include "runtime/metal/metal_module.h"
+#include "runtime/thread_storage_scope.h"
+#include "target/build_common.h"
+
+namespace tvm {
+namespace codegen {
+
+void CodeGenTileLangMetal::InitFuncState(const PrimFunc &f) {
+  CodeGenC::InitFuncState(f);
+  // analyze the data;
+  for (Var arg : f->params) {
+    if (arg.dtype().is_handle()) {
+      alloc_storage_scope_[arg.get()] = "global";
+    }
+  }
+}
+
+CodeGenTileLangMetal::CodeGenTileLangMetal(Target target) : target_(target) {
+  decl_stream << "#include <metal_stdlib>\n";
+  decl_stream << "using namespace metal;\n\n";
+  decl_stream << "union __TVMArgUnion {\n"
+              << "  int v_int[2];\n"
+              << "};\n\n";
+}
+
+void CodeGenTileLangMetal::AddFunction(const GlobalVar &gvar,
+                                       const PrimFunc &func) {
+  // NOTE: There are no inter-function calls among Metal kernels.
+  // For now we keep the metal codegen without inter-function call
+  // process.
+  // We can switch to follow the flow with inter-function call process
+  // after the Metal function declaration is properly printed.
+  // In Metal, for PrimFuncs with signature
+  //     def func(A: Buffer, B: Buffer, x: int, y: float) -> None
+  // where there are trailing pod parameters, the codegen emits a struct
+  //     struct func_params{ x: int; y: float; }
+  // for the function. In the flow of the inter-function call process,
+  // the struct would be emitted every time the function is declared,
+  // producing duplicate definitions of the same struct that the Metal
+  // compiler cannot resolve.
+
+  // clear previous generated state.
+  this->InitFuncState(func);
+  // skip the first underscore, so SSA variable starts from _1
+  name_supply_->FreshName("v_");
+
+  // add to alloc buffer type.
+  auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+  ICHECK(global_symbol.has_value())
+      << "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
+
+  // Function header.
+  this->stream << "kernel void "
+               << static_cast<std::string>(global_symbol.value()) << "(";
+
+  // Buffer arguments
+  size_t num_buffer = 0;
+  size_t limit =
+      target_->GetAttr<Integer>("max_function_args").value().IntValue();
+  if (func->params.size() > limit) {
+    LOG(WARNING) << "The kernel takes more arguments than the target's "
+                    "max_function_args; it may fail to execute";
+  }
+  for (size_t i = 0; i < func->params.size(); ++i, ++num_buffer) {
+    Var v = func->params[i];
+    if (!v.dtype().is_handle())
+      break;
+    this->stream << "  ";
+    std::string vid = AllocVarID(v.get());
+    auto it = alloc_storage_scope_.find(v.get());
+    if (it != alloc_storage_scope_.end()) {
+      PrintStorageScope(it->second, this->stream);
+    }
+    PrintType(GetType(v), this->stream);
+    // Register handle data type
+    // TODO(tvm-team): consider simply keep type info in the
+    // type annotation(via a normalizing rewriting).
+    if (auto *ptr = v->type_annotation.as<PointerTypeNode>()) {
+      if (auto *prim = ptr->element_type.as<PrimTypeNode>()) {
+        RegisterHandleType(v.get(), prim->dtype);
+      }
+    }
+    this->stream << ' ' << vid << " [[ buffer(" << i << ") ]],\n";
+  }
+  // Setup normal arguments.
+  size_t nargs = func->params.size() - num_buffer;
+  std::string varg = name_supply_->FreshName("arg");
+  if (nargs != 0) {
+    std::string arg_buf_type =
+        static_cast<std::string>(global_symbol.value()) + "_args_t";
+    this->stream << "  constant " << arg_buf_type << "& " << varg
+                 << " [[ buffer(" << num_buffer << ") ]],\n";
+    // declare the struct
+    decl_stream << "struct " << arg_buf_type << " {\n";
+    for (size_t i = num_buffer; i < func->params.size(); ++i) {
+      Var v = func->params[i];
+      ICHECK(!v.dtype().is_handle());
+      std::string vid = AllocVarID(v.get());
+      std::ostringstream vref;
+      if (v.dtype().bits() == 32) {
+        decl_stream << "  ";
+        PrintType(v.dtype(), decl_stream);
+        decl_stream << " " << vid << "[2];\n";
+        vref << varg << "." << vid << "[0]";
+      } else if (v.dtype().bits() == 64) {
+        decl_stream << "  ";
+        PrintType(v.dtype(), decl_stream);
+        decl_stream << " " << vid << ";\n";
+        vref << varg << "." << vid;
+      } else {
+        // For non 32bit type, ref through arg union.
+        decl_stream << "  __TVMArgUnion " << vid << ";\n";
+        vref << varg << "." << vid << ".v_";
+        PrintType(v.dtype(), vref);
+      }
+      var_idmap_[v.get()] = vref.str();
+    }
+    decl_stream << "};\n\n";
+  }
+  // Setup the thread group info.
+  ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx");
+  ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx");
+  int work_dim = 0;
+  auto launch_params =
+      func->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams).value();
+  for (const auto &tag : launch_params) {
+    if (tag != runtime::launch_param::kUseDynamicSharedMemoryTag) {
+      runtime::ThreadScope scope = runtime::ThreadScope::Create(tag);
+      work_dim = std::max(work_dim, scope.dim_index + 1);
+    }
+  }
+
+  if (work_dim != 0) {
+    // use ushort by default for now
+    stream << "  ";
+    PrintType(DataType::UInt(thread_index_bits_, work_dim), stream);
+    stream << " blockIdx [[threadgroup_position_in_grid]],\n";
+    stream << "  ";
+    PrintType(DataType::UInt(thread_index_bits_, work_dim), stream);
+    stream << " threadIdx [[thread_position_in_threadgroup]]\n";
+  }
+  thread_work_dim_ = work_dim;
+
+  // the function scope.
+  stream << ") {\n";
+  int func_scope = this->BeginScope();
+  this->PrintStmt(func->body);
+  this->EndScope(func_scope);
+  this->PrintIndent();
+  this->stream << "}\n\n";
+}
+
+void CodeGenTileLangMetal::BindThreadIndex(const IterVar &iv) {
+  ICHECK(!var_idmap_.count(iv->var.get()));
+  // if we only have threadIdx.x
+  // metal will directly print as threadIdx
+  std::string vname = iv->thread_tag;
+  if (thread_work_dim_ <= 1) {
+    vname = vname.substr(0, iv->thread_tag.length() - 2);
+  }
+  var_idmap_[iv->var.get()] =
+      CastFromTo(vname, DataType::UInt(thread_index_bits_), iv->var.dtype());
+}
+
+void CodeGenTileLangMetal::PrintType(DataType t,
+                                     std::ostream &os) { // NOLINT(*)
+  int lanes = t.lanes();
+  if (t.is_handle()) {
+    ICHECK_EQ(lanes, 1) << "do not yet support vector types";
+    os << "void*";
+    return;
+  }
+
+  if (t.is_void()) {
+    os << "void";
+    return;
+  }
+  if (t == DataType::Bool()) {
+    os << "bool";
+    return;
+  }
+  if (t.is_float() && t.bits() == 16 && lanes > 4 && lanes <= 8 &&
+      lanes % 2 == 0) {
+    os << "uint" << lanes / 2;
+    return;
+  }
+  bool fail = false;
+  if (t.is_float()) {
+    if (lanes == 3) {
+      os << "packed_";
+    }
+    switch (t.bits()) {
+    case 16:
+      os << "half";
+      break;
+    case 32:
+      os << "float";
+      break;
+    default:
+      fail = true;
+      break;
+    }
+    if (!fail && lanes == 1)
+      return;
+    if (!fail && (lanes >= 2 && lanes <= 4)) {
+      os << lanes;
+      return;
+    }
+  } else if (t.is_uint() || t.is_int()) {
+    if (t.is_uint()) {
+      os << 'u';
+    }
+    switch (t.bits()) {
+    case 8:
+      os << "char";
+      break;
+    case 16:
+      os << "short";
+      break;
+    case 32:
+      os << "int";
+      break;
+    case 64:
+      os << "long";
+      break;
+    case 1:
+      os << "bool";
+      break;
+    default:
+      fail = true;
+      break;
+    }
+    if (!fail && lanes == 1)
+      return;
+    if (!fail && (lanes >= 2 && lanes <= 4)) {
+      os << lanes;
+      return;
+    }
+  } else if (t.is_bfloat16()) {
+    os << "bfloat";
+    return;
+  }
+  LOG(FATAL) << "Cannot convert type " << t << " to Metal type";
+}
+
+void CodeGenTileLangMetal::PrintStorageSync(const CallNode *op) {
+  const std::string &sync = op->args[0].as<StringImmNode>()->value;
+  if (sync == "warp") {
+    this->PrintIndent();
+    this->stream << "simdgroup_barrier(mem_flags::mem_threadgroup);\n";
+  } else if (sync == "shared") {
+    this->PrintIndent();
+    this->stream << "threadgroup_barrier(mem_flags::mem_threadgroup);\n";
+  } else if (sync == "global") {
+    LOG(FATAL) << "global barrier not supported";
+  }
+}
+
+void CodeGenTileLangMetal::PrintVecElemLoad(const std::string &vec, DataType t,
+                                            int i,
+                                            std::ostream &os) { // NOLINT(*)
+  if (t.is_float16() && t.lanes() > 4) {
+    os << "((thread half*)(&" << vec << "))[" << i << "]";
+  } else {
i << "]"; + } +} + +void CodeGenTileLangMetal::PrintVecElemStore(const std::string &vec, DataType t, + int i, const std::string &value) { + this->PrintIndent(); + if (t.is_float16() && t.lanes() > 4) { + stream << "((thread half*)(&" << vec << "))[" << i << "] = " << value + << ";\n"; + } else { + stream << vec << "[" << i << "]" + << " = " << value << ";\n"; + } +} + +void CodeGenTileLangMetal::PrintStorageScope(const std::string &scope, + std::ostream &os) { // NOLINT(*) + if (scope == "global") { + os << "device "; + } else if (scope == "shared" || scope == "shared.dyn") { + os << "threadgroup "; + } else if (scope == "local") { + os << "thread "; + } else { + LOG(FATAL) << "Unknown storage scope `" << scope << "`"; + } +} + +void CodeGenTileLangMetal::VisitStmt_(const AllocateNode *op) { + ICHECK(!is_zero(op->condition)); + std::string vid = AllocVarID(op->buffer_var.get()); + + this->PrintIndent(); + size_t constant_size = op->ConstantAllocationSize(); + ICHECK_GT(constant_size, 0) + << "Can only handle constant size stack allocation for now"; + + auto scope = GetPtrStorageScope(op->buffer_var); + alloc_storage_scope_[op->buffer_var.get()] = scope; + if (scope == "metal.simdgroup") { + ICHECK(op->dtype == DataType::Float(16) || + op->dtype == DataType::Float(32) || + op->dtype == DataType::BFloat(16)) + << "Only float16, float32, and bfloat16 are supported, but got " + << op->dtype; + ICHECK(constant_size % 64 == 0) << "Only 8x8 matrix is supported, but got " + << constant_size << " bytes\n"; + + std::ostringstream dtype_os; + PrintType(op->dtype, dtype_os); + std::string dtype_str = dtype_os.str(); + simdgroup_dtype_[op->buffer_var.get()] = dtype_str; + stream << "simdgroup_" << dtype_str << "8x8 " << vid << '[' + << constant_size / 64 << "];\n"; + } else { + PrintStorageScope(scope, stream); + PrintType(op->dtype, stream); + stream << ' ' << vid << '[' << constant_size << "];\n"; + } + + RegisterHandleType(op->buffer_var.get(), op->dtype); + this->PrintStmt(op->body); +} + +void CodeGenTileLangMetal::VisitExpr_(const SelectNode *op, + std::ostream &os) { // NOLINT(*) + os << "select(" << PrintExpr(op->false_value) << ", " + << PrintExpr(op->true_value) << ", " << PrintExpr(op->condition) << ")"; +} + +void CodeGenTileLangMetal::VisitExpr_(const BroadcastNode *op, + std::ostream &os) { // NOLINT(*) + std::string v = PrintExpr(op->value); + int lanes = op->dtype.lanes(); + if (op->dtype.is_float16() && lanes > 4 && lanes % 2 == 0) { + os << "uint" << lanes / 2 << "("; + for (int i = 0; i < lanes / 2; ++i) { + if (i != 0) + os << ", "; + os << "as_type(half2(" << v << ", " << v << "))"; + } + os << ')'; + } else { + PrintType(op->dtype, os); + os << "("; + for (int i = 0; i < lanes; ++i) { + if (i != 0) + os << ", "; + os << v; + } + os << ')'; + } +} + +void CodeGenTileLangMetal::VisitExpr_(const CallNode *op, + std::ostream &os) { // NOLINT(*) + CHECK(!op->op.as()) + << "CodegenMetal does not support inter-function calls, " + << "but expression " << ffi::GetRef(op) << " calls PrimFunc " + << op->op; + auto f_check_simdgroup_shape = [](PrimExpr col, PrimExpr row) { + ICHECK(col->IsInstance() && row->IsInstance()) + << "Only constant shape is supported for simdgroup matrix, but got " + << col << "x" << row; + int col_val = col.as()->value; + int row_val = row.as()->value; + ICHECK(col_val == 8 && row_val == 8) + << "Only 8x8 matrix is supported, but got " << col_val << "x" + << row_val; + }; + if (op->op.same_as(builtin::make_filled_simdgroup_matrix())) { + ICHECK_EQ(op->args.size(), 
5); + Var var = Downcast(op->args[0]); + // Get the data type of the simdgroup matrix + auto it = simdgroup_dtype_.find(var.get()); + ICHECK(it != simdgroup_dtype_.end()) + << "Cannot find variable allocation for simdgroup: " << var; + const std::string &dtype_str = it->second; + f_check_simdgroup_shape(op->args[3], op->args[4]); + os << PrintExpr(var) << "[" << PrintExpr(op->args[1]) + << "] = make_filled_simdgroup_matrix<" << dtype_str << ", " + << PrintExpr(op->args[3]) << ", " << PrintExpr(op->args[4]) << ">(" + << PrintExpr(op->args[2]) << ")"; + } else if (op->op.same_as(builtin::simdgroup_load())) { + ICHECK_EQ(op->args.size(), 7); + f_check_simdgroup_shape(op->args[4], op->args[5]); + os << "simdgroup_load(" << PrintExpr(op->args[0]) << "[" + << PrintExpr(op->args[1]) << "], " << PrintExpr(op->args[2]) << ", " + << PrintExpr(op->args[3]) << ", 0, " << PrintExpr(op->args[6]) << ")"; + } else if (op->op.same_as(builtin::simdgroup_store())) { + ICHECK_EQ(op->args.size(), 7); + f_check_simdgroup_shape(op->args[4], op->args[5]); + os << "simdgroup_store(" << PrintExpr(op->args[0]) << "[" + << PrintExpr(op->args[1]) << "], " << PrintExpr(op->args[2]) << ", " + << PrintExpr(op->args[3]) << ", 0, " << PrintExpr(op->args[6]) << ")"; + } else if (op->op.same_as(builtin::simdgroup_multiply_accumulate())) { + ICHECK_EQ(op->args.size(), 8); + os << "simdgroup_multiply_accumulate(" // + << PrintExpr(op->args[0]) << "[" << PrintExpr(op->args[1]) << "], " // + << PrintExpr(op->args[2]) << "[" << PrintExpr(op->args[3]) << "], " // + << PrintExpr(op->args[4]) << "[" << PrintExpr(op->args[5]) << "], " // + << PrintExpr(op->args[6]) << "[" << PrintExpr(op->args[7]) << "])"; + } else if (op->op.same_as(builtin::reinterpret())) { + // generate as_type(ARG) + os << "(as_type<"; + this->PrintType(op->dtype, os); + os << ">("; + this->PrintExpr(op->args[0], os); + os << "))"; + } else { + CodeGenC::VisitExpr_(op, os); + } +} + +void CodeGenTileLangMetal::VisitExpr_(const FloatImmNode *op, + std::ostream &os) { // NOLINT(*) + std::ostringstream temp; + if (std::isinf(op->value)) { + if (op->value < 0) { + temp << "-"; + } + temp << "INFINITY"; + } else if (std::isnan(op->value)) { + temp << "NAN"; + } else { + temp << std::scientific << op->value; + if (op->dtype.bits() == 32) + temp << 'f'; + else if (op->dtype.bits() == 16) + temp << 'h'; + } + MarkConst(temp.str()); + os << temp.str(); +} + +ffi::Module BuildTileLangMetal(IRModule mod, Target target) { + bool output_ssa = false; + mod = tir::transform::PointerValueTypeRewrite()(std::move(mod)); + + std::ostringstream source_maker; + std::unordered_map smap; + const auto fmetal_compile = + tvm::ffi::Function::GetGlobal("tvm_callback_metal_compile"); + std::string fmt = fmetal_compile ? 
"metallib" : "metal"; + + for (auto kv : mod->functions) { + ICHECK(kv.second->IsInstance()) + << "CodeGenTileLangMetal: Can only take PrimFunc"; + auto global_symbol = + kv.second->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(global_symbol.has_value()); + std::string func_name = global_symbol.value(); + + source_maker << "// Function: " << func_name << "\n"; + CodeGenTileLangMetal cg(target); + cg.Init(output_ssa); + auto f = Downcast(kv.second); + auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); + ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) + << "CodeGenTileLangMetal: expect calling_conv equals " + "CallingConv::kDeviceKernelLaunch"; + + cg.AddFunction(kv.first, f); + + std::string fsource = cg.Finish(); + source_maker << fsource << "\n"; + if (fmetal_compile) { + fsource = (*fmetal_compile)(fsource, target).cast(); + } + smap[func_name] = fsource; + } + + return MetalModuleCreate(smap, ExtractFuncInfo(mod), fmt, source_maker.str()); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("target.build.tilelang_metal", BuildTileLangMetal); +} +} // namespace codegen +} // namespace tvm diff --git a/src/target/codegen_metal.h b/src/target/codegen_metal.h new file mode 100644 index 0000000000..69e137bf7a --- /dev/null +++ b/src/target/codegen_metal.h @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file codegen_metal.h + * \brief Generate Metal device code. + */ +#ifndef TVM_TARGET_SOURCE_CODEGEN_METAL_H_ +#define TVM_TARGET_SOURCE_CODEGEN_METAL_H_ + +#include + +#include +#include + +#include "target/source/codegen_c.h" + +namespace tvm { +namespace codegen { + +class CodeGenTileLangMetal final : public CodeGenC { +public: + explicit CodeGenTileLangMetal(Target target); + // override print thread tag. + void PrintArgUnionDecl(); + void AddFunction(const GlobalVar &gvar, const PrimFunc &func) final; + void InitFuncState(const PrimFunc &f) final; + void PrintStorageScope(const std::string &scope, + std::ostream &os) final; // NOLINT(*) + void PrintStorageSync(const CallNode *op) final; // NOLINT(*) + void PrintType(DataType t, std::ostream &os) final; // NOLINT(*) + void BindThreadIndex(const IterVar &iv) final; // NOLINT(*) + // print load of single element + void PrintVecElemLoad(const std::string &vec, DataType t, int i, + std::ostream &os) final; // NOLINT(*) + // print store of single element. 
+ void PrintVecElemStore(const std::string &vec, DataType t, int i, + const std::string &value) final; + // overload visitor + void VisitStmt_(const AllocateNode *op) final; // NOLINT(*) + void VisitExpr_(const SelectNode *op, std::ostream &os) final; // NOLINT(*) + void VisitExpr_(const BroadcastNode *op, std::ostream &os) final; // NOLINT(*) + void VisitExpr_(const CallNode *op, std::ostream &os) final; // NOLINT(*) + void VisitExpr_(const FloatImmNode *op, std::ostream &os) final; // NOLINT(*) + + // reuse parent's function. + using CodeGenC::PrintType; + +private: + std::unordered_map simdgroup_dtype_; + int thread_index_bits_{32}; + int thread_work_dim_{0}; + Target target_; +}; +} // namespace codegen +} // namespace tvm + +#endif // TVM_TARGET_SOURCE_CODEGEN_METAL_H_ diff --git a/src/transform/layout_inference.cc b/src/transform/layout_inference.cc index 4cfdb6bf82..73baa98208 100644 --- a/src/transform/layout_inference.cc +++ b/src/transform/layout_inference.cc @@ -433,12 +433,16 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { } } - // Check that all local.fragment buffers have inferred layouts + // Check that all local.fragment buffers have inferred layouts. + // On Metal targets, fragment buffers used as GEMM accumulators are + // lowered to opaque simdgroup matrices, so they have no explicit + // thread-level layout and can be safely skipped. for (const auto &[buffer, _] : use_list_) { if (IsFragmentBuffer(buffer)) { - ICHECK_NE(layout_map.count(buffer), 0) - << "The layout for fragment " << buffer - << " can not be inferred correctly."; + if (!TargetIsMetal(target_) && layout_map.count(buffer) == 0) { + ICHECK(false) << "The layout for fragment " << buffer + << " can not be inferred correctly."; + } } } diff --git a/src/transform/lower_device_storage_access_info.cc b/src/transform/lower_device_storage_access_info.cc index 1f3077843d..b9c54b724f 100644 --- a/src/transform/lower_device_storage_access_info.cc +++ b/src/transform/lower_device_storage_access_info.cc @@ -46,7 +46,7 @@ class StorageAccessInfoLower : public StmtExprMutator { auto scope = StorageScope::Create(GetPtrStorageScope(op->buffer_var)); if (!scope.tag.empty() && scope.tag != ".dyn" && scope.tag != ".var" && scope.tag != ".barrier" && scope.tag != ".cluster_barrier" && - scope.tag.find(".descriptor") != 0) { + scope.tag != ".fragment" && scope.tag.find(".descriptor") != 0) { auto info = GetMemoryInfo(GetPtrStorageScope(op->buffer_var)); ICHECK(info.defined()) << "Cannot find memory info of " << scope.to_string(); diff --git a/testing/python/metal/test_metal_gemm_v2.py b/testing/python/metal/test_metal_gemm_v2.py new file mode 100644 index 0000000000..cb0ce26e8f --- /dev/null +++ b/testing/python/metal/test_metal_gemm_v2.py @@ -0,0 +1,91 @@ +"""Test Metal gemm_v2 with actual execution on Metal hardware. + +These tests verify correctness of T.gemm (gemm_v2) using simdgroup matrix +operations by comparing results against torch.matmul. 
+""" + +import tilelang +from tilelang import tvm as tvm +import tilelang.testing +import tilelang.language as T +import torch + + +@tilelang.jit +def matmul_gemm_v2(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): + @T.prim_func + def gemm_kernel( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), accum_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared") + B_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared") + C_local = T.alloc_shared((block_M, block_N), accum_dtype, scope="shared") + + T.clear(C_local) + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0): + T.copy(A[by * block_M, ko * block_K], A_shared) + T.copy(B[ko * block_K, bx * block_N], B_shared) + + T.gemm(A_shared, B_shared, C_local) + + T.copy(C_local, C[by * block_M, bx * block_N]) + + return gemm_kernel + + +def assert_gemm_v2( + M, + N, + K, + block_M, + block_N, + block_K, + dtype=T.float16, + accum_dtype=T.float32, + atol=1e-2, +): + jit_kernel = matmul_gemm_v2(M, N, K, block_M, block_N, block_K, dtype=dtype, accum_dtype=accum_dtype) + + torch_dtype = dtype.as_torch() + torch_accum_dtype = accum_dtype.as_torch() + a = torch.randn(M, K, dtype=torch_dtype, device="mps") + b = torch.randn(K, N, dtype=torch_dtype, device="mps") + c = torch.zeros(M, N, dtype=torch_accum_dtype, device="mps") + + jit_kernel(a, b, c) + + ref = a.to(torch_accum_dtype) @ b.to(torch_accum_dtype) + assert torch.allclose(ref, c, atol=atol), ( + f"Result mismatch for M={M}, N={N}, K={K}, " + f"block=({block_M},{block_N},{block_K}), dtype={dtype}\n" + f"max diff: {(ref - c).abs().max().item()}" + ) + + +@tilelang.testing.requires_metal +def test_gemm_v2_16x16x16(): + assert_gemm_v2(128, 128, 128, 16, 16, 16) + + +@tilelang.testing.requires_metal +def test_gemm_v2_16x16x8(): + assert_gemm_v2(128, 128, 128, 16, 16, 8) + + +@tilelang.testing.requires_metal +def test_gemm_v2_large(): + assert_gemm_v2(128, 128, 128, 32, 32, 32) + + +@tilelang.testing.requires_metal +def test_gemm_v2_1024(): + assert_gemm_v2(1024, 1024, 1024, 16, 16, 16, atol=1.0) + + +if __name__ == "__main__": + if torch.mps.is_available(): + tilelang.testing.main() diff --git a/testing/python/metal/test_metal_gemm_v2_linux.py b/testing/python/metal/test_metal_gemm_v2_linux.py new file mode 100644 index 0000000000..2976646fc2 --- /dev/null +++ b/testing/python/metal/test_metal_gemm_v2_linux.py @@ -0,0 +1,82 @@ +"""Test Metal gemm_v2 code generation on any platform (including Linux). + +These tests verify that TileLang can compile kernels using T.gemm (gemm_v2) +down to Metal shader source code with simdgroup matrix operations, +without requiring a Metal runtime or macOS. 
+""" + +import tilelang +from tilelang import tvm as tvm +import tilelang.testing +import tilelang.language as T + + +def matmul_gemm_v2(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), accum_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared") + B_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared") + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + T.clear(C_local) + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0): + T.copy(A[by * block_M, ko * block_K], A_shared, coalesced_width=2) + T.copy(B[ko * block_K, bx * block_N], B_shared, coalesced_width=2) + + T.gemm(A_shared, B_shared, C_local) + + T.copy(C_local, C[by * block_M, bx * block_N], coalesced_width=2) + + return main + + +def assert_metal_gemm_v2_codegen( + M, + N, + K, + block_M, + block_N, + block_K, + dtype=T.float16, + accum_dtype=T.float32, +): + func = matmul_gemm_v2(M, N, K, block_M, block_N, block_K, dtype=dtype, accum_dtype=accum_dtype) + with tvm.transform.PassContext(), tvm.target.Target("metal"): + artifact = tilelang.lower(func, target="metal") + + src_code = artifact.kernel_source + assert src_code is not None + assert "kernel void" in src_code + # Verify simdgroup matrix operations are present + assert "simdgroup_multiply_accumulate" in src_code + assert "simdgroup_load" in src_code + assert "simdgroup_store" in src_code + + +def test_metal_gemm_v2_float16(): + assert_metal_gemm_v2_codegen(64, 64, 64, 16, 16, 16, dtype=T.float16) + + +def test_metal_gemm_v2_float32(): + assert_metal_gemm_v2_codegen(64, 64, 64, 16, 16, 16, dtype=T.float32, accum_dtype=T.float32) + + +def test_metal_gemm_v2_larger(): + assert_metal_gemm_v2_codegen(128, 128, 128, 32, 32, 32, dtype=T.float16) + + +def test_metal_gemm_v2_small_blocks(): + """Test with blocks where warp_rows > 1 and warp_cols > 1, which previously + produced incorrect results due to swizzle padding changing the stride. + """ + assert_metal_gemm_v2_codegen(16, 16, 16, 16, 16, 16, dtype=T.float16) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/testing/python/metal/test_metal_simdgroup_store.py b/testing/python/metal/test_metal_simdgroup_store.py new file mode 100644 index 0000000000..ea8429589c --- /dev/null +++ b/testing/python/metal/test_metal_simdgroup_store.py @@ -0,0 +1,133 @@ +"""Test Metal simdgroup register GEMM with direct simdgroup_store to device memory. + +These tests verify the simdgroup register accumulation path where C is allocated +in metal.simdgroup scope. This eliminates C_simd load/store round-trips through +shared memory on each K iteration. The final T.copy(C_local, C[...]) is lowered +to simdgroup_store directly to device memory via LowerSIMDGroupStore. 
+""" + +import tilelang +from tilelang import tvm as tvm +import tilelang.testing +import tilelang.language as T +import torch + + +def _make_simdgroup_gemm_func(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): + @T.prim_func + def gemm_kernel( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), accum_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared") + B_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared") + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0): + T.copy(A[by * block_M, ko * block_K], A_shared) + T.copy(B[ko * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return gemm_kernel + + +matmul_simdgroup = tilelang.jit(_make_simdgroup_gemm_func) + + +def assert_simdgroup_store_correctness(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32, atol=1e-2): + kernel = matmul_simdgroup(M, N, K, block_M, block_N, block_K, dtype=dtype, accum_dtype=accum_dtype) + + torch_dtype = dtype.as_torch() + torch_accum_dtype = accum_dtype.as_torch() + a = torch.randn(M, K, dtype=torch_dtype, device="mps") + b = torch.randn(K, N, dtype=torch_dtype, device="mps") + c = torch.zeros(M, N, dtype=torch_accum_dtype, device="mps") + + kernel(a, b, c) + + ref = a.to(torch_accum_dtype) @ b.to(torch_accum_dtype) + assert torch.allclose(ref, c, atol=atol), ( + f"Result mismatch for M={M}, N={N}, K={K}, " + f"block=({block_M},{block_N},{block_K}), dtype={dtype}\n" + f"max diff: {(ref - c).abs().max().item()}" + ) + + +def assert_simdgroup_store_codegen(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): + func = _make_simdgroup_gemm_func(M, N, K, block_M, block_N, block_K, dtype=dtype, accum_dtype=accum_dtype) + with tvm.transform.PassContext(), tvm.target.Target("metal"): + artifact = tilelang.lower(func, target="metal") + + src = artifact.kernel_source + assert src is not None + assert "kernel void" in src + assert "simdgroup_multiply_accumulate" in src + assert "make_filled_simdgroup_matrix" in src + + assert "simdgroup_float8x8" in src or "simdgroup_half8x8" in src, "Expected simdgroup_float8x8 or simdgroup_half8x8 for C accumulator" + + store_to_device = src.count("simdgroup_store(C_local") + assert store_to_device > 0, "Expected simdgroup_store of C_local to device memory" + + load_c_from_shared = [line for line in src.split("\n") if "simdgroup_load" in line and "C_local" in line] + assert len(load_c_from_shared) == 0, f"C_local should not be loaded from shared memory, but found: {load_c_from_shared}" + + +# --- Codegen tests (cross-platform) --- + + +def test_codegen_square_small(): + assert_simdgroup_store_codegen(64, 64, 64, 16, 16, 16) + + +def test_codegen_square_large(): + assert_simdgroup_store_codegen(128, 128, 128, 32, 32, 32) + + +def test_codegen_non_square(): + assert_simdgroup_store_codegen(128, 128, 128, 32, 64, 16) + + +def test_codegen_float32_accum(): + assert_simdgroup_store_codegen(64, 64, 64, 16, 16, 16, dtype=T.float32, accum_dtype=T.float32) + + +# --- Correctness tests (require Metal hardware) --- + + +@tilelang.testing.requires_metal +def test_correctness_16x16x16(): + assert_simdgroup_store_correctness(128, 128, 128, 16, 16, 16) + + +@tilelang.testing.requires_metal +def 
test_correctness_32x32x32(): + assert_simdgroup_store_correctness(128, 128, 128, 32, 32, 32) + + +@tilelang.testing.requires_metal +def test_correctness_non_square_block(): + assert_simdgroup_store_correctness(128, 128, 128, 32, 64, 16) + + +@tilelang.testing.requires_metal +def test_correctness_64x64x32(): + assert_simdgroup_store_correctness(128, 128, 128, 64, 64, 32) + + +@tilelang.testing.requires_metal +def test_correctness_large_matrix(): + assert_simdgroup_store_correctness(1024, 1024, 1024, 32, 32, 32, atol=1.0) + + +@tilelang.testing.requires_metal +def test_correctness_non_square_matrix(): + assert_simdgroup_store_correctness(256, 512, 128, 32, 32, 16) + + +if __name__ == "__main__": + if torch.mps.is_available(): + tilelang.testing.main() diff --git a/tilelang/engine/lower.py b/tilelang/engine/lower.py index e900890300..d7c3547a40 100644 --- a/tilelang/engine/lower.py +++ b/tilelang/engine/lower.py @@ -236,7 +236,7 @@ def device_codegen(device_mod: tvm.IRModule, target: Target) -> tvm.IRModule: elif target.kind.name == "hip": device_mod = tvm.ffi.get_global_func("target.build.tilelang_hip")(device_mod, target) elif target.kind.name == "metal": - device_mod = tvm.ffi.get_global_func("target.build.metal")(device_mod, target) + device_mod = tvm.ffi.get_global_func("target.build.tilelang_metal")(device_mod, target) else: raise ValueError(f"Target {target.kind.name} is not supported") @@ -261,7 +261,7 @@ def device_codegen_without_compile(device_mod: tvm.IRModule, target: Target) -> elif target.kind.name == "webgpu": device_mod = tvm.ffi.get_global_func("target.build.webgpu")(device_mod, target) elif target.kind.name == "metal": - device_mod = tvm.ffi.get_global_func("target.build.metal")(device_mod, target) + device_mod = tvm.ffi.get_global_func("target.build.tilelang_metal")(device_mod, target) else: raise ValueError(f"Target {target.kind.name} is not supported") diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index 5563845214..854d9e73c4 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -197,6 +197,11 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule: mod = tilelang.transform.PipelinePlanning()(mod) mod = tilelang.transform.InjectSoftwarePipeline()(mod) mod = tilelang.transform.Simplify()(mod) + # On Metal, rewrite local.fragment GEMM accumulators to metal.simdgroup + # before layout inference (which would otherwise require a layout for them) + from tilelang.transform.metal_fragment_to_simdgroup import MetalFragmentToSimdgroup + + mod = MetalFragmentToSimdgroup(mod) # Infer memory layouts for fragments and shared memory mod = tilelang.transform.LayoutInference()(mod) # Visualize the layout diff --git a/tilelang/intrinsics/metal_macro_generator.py b/tilelang/intrinsics/metal_macro_generator.py new file mode 100644 index 0000000000..647d474dc6 --- /dev/null +++ b/tilelang/intrinsics/metal_macro_generator.py @@ -0,0 +1,203 @@ +from __future__ import annotations + +import tilelang.language as T +from tvm import tir +from tvm.tir import Buffer, BufferRegion + + +class MPSIntrinEmitter: + WARP_SIZE = 32 + + def __init__( + self, + a_dtype: str = "float16", + b_dtype: str = "float16", + accum_dtype: str = "float32", + a_transposed: bool = False, + b_transposed: bool = False, + block_row_warps: int = 1, + block_col_warps: int = 1, + warp_row_tiles: int = 8, + warp_col_tiles: int = 8, + chunk: int = 32, + thread_var: tir.Var | None = None, + ): + self.a_dtype = a_dtype + self.b_dtype = b_dtype + self.accum_dtype = accum_dtype + 
self.a_transposed = a_transposed + self.b_transposed = b_transposed + self.block_row_warps = block_row_warps + self.block_col_warps = block_col_warps + self.warp_row_tiles = warp_row_tiles + self.warp_col_tiles = warp_col_tiles + self.chunk = chunk + self.thread_var = thread_var + + # Metal simdgroup matrix size (always 8x8) + self.micro_size_x = 8 + self.micro_size_y = 8 + self.micro_size_k = 8 + + # Number of 8x8 tiles per warp + self.warp_rows = warp_row_tiles // self.micro_size_x + self.warp_cols = warp_col_tiles // self.micro_size_y + + def get_thread_binding(self): + if self.thread_var is None: + current_frame = T.KernelLaunchFrame.Current() + assert current_frame is not None, "Must be called in a T.Kernel Frame" + return current_frame.get_thread_binding() + else: + return self.thread_var + + def _get_warp_indices(self): + thread_binding = self.get_thread_binding() + WARP_SIZE = self.WARP_SIZE + block_row_warps = self.block_row_warps + block_col_warps = self.block_col_warps + + warp_m = (thread_binding // WARP_SIZE) % block_row_warps + warp_n = (thread_binding // (WARP_SIZE * block_row_warps)) % block_col_warps + return warp_m, warp_n + + @staticmethod + def _parse_buffer_2d(buf): + """Extract (buffer, row_offset, col_offset, stride) from Buffer or BufferRegion.""" + if isinstance(buf, BufferRegion): + buffer = buf.buffer + off_row = buf.region[-2].min + off_col = buf.region[-1].min + else: + buffer = buf + off_row = 0 + off_col = 0 + stride = buffer.shape[-1] + return buffer, off_row, off_col, stride + + def ldmatrix_a(self, A_local_buf, A_shared_buf: Buffer | BufferRegion, ki): + warp_rows = self.warp_rows + micro_size_x = self.micro_size_x + micro_size_k = self.micro_size_k + a_transposed = self.a_transposed + + warp_m, _ = self._get_warp_indices() + + buffer, offset_m, offset_k, stride = self._parse_buffer_2d(A_shared_buf) + + @T.macro + def _warp_ldmatrix_a(A_local_buf, buffer, offset_m, offset_k, stride, warp_m, ki): + for i in T.serial(warp_rows): + if a_transposed: + row_idx = offset_k + ki * micro_size_k + col_idx = offset_m + warp_m * (self.warp_row_tiles) + i * micro_size_x + else: + row_idx = offset_m + warp_m * (self.warp_row_tiles) + i * micro_size_x + col_idx = offset_k + ki * micro_size_k + + ptr = T.access_ptr(buffer[row_idx, col_idx], "r") + + T.simdgroup_load( + A_local_buf.data, + i, + ptr, + stride, + micro_size_x, + micro_size_k, + T.bool(a_transposed), + ) + + return _warp_ldmatrix_a(A_local_buf, buffer, offset_m, offset_k, stride, warp_m, ki) + + def ldmatrix_b(self, B_local_buf, B_shared_buf: Buffer | BufferRegion, ki): + warp_cols = self.warp_cols + micro_size_y = self.micro_size_y + micro_size_k = self.micro_size_k + b_transposed = self.b_transposed + + _, warp_n = self._get_warp_indices() + + buffer, offset_k, offset_n, stride = self._parse_buffer_2d(B_shared_buf) + + @T.macro + def _warp_ldmatrix_b(B_local_buf, buffer, offset_k, offset_n, stride, warp_n, ki): + for j in T.serial(warp_cols): + if b_transposed: + row_idx = offset_n + warp_n * (self.warp_col_tiles) + j * micro_size_y + col_idx = offset_k + ki * micro_size_k + else: + row_idx = offset_k + ki * micro_size_k + col_idx = offset_n + warp_n * (self.warp_col_tiles) + j * micro_size_y + + ptr = T.access_ptr(buffer[row_idx, col_idx], "r") + + T.simdgroup_load( + B_local_buf.data, + j, + ptr, + stride, + micro_size_k, + micro_size_y, + T.bool(b_transposed), + ) + + return _warp_ldmatrix_b(B_local_buf, buffer, offset_k, offset_n, stride, warp_n, ki) + + def mma(self, A_local_buf, B_local_buf, 
C_local_buf): + warp_rows = self.warp_rows + warp_cols = self.warp_cols + + @T.macro + def _warp_mma(A_local_buf, B_local_buf, C_local_buf): + for i, j in T.grid(warp_rows, warp_cols): + T.simdgroup_multiply_accumulate( + C_local_buf.data, + i * warp_cols + j, + A_local_buf.data, + i, + B_local_buf.data, + j, + C_local_buf.data, + i * warp_cols + j, + ) + + return _warp_mma(A_local_buf, B_local_buf, C_local_buf) + + def simdgroup_copy(self, C_simd_buf, C_dst, is_store=True): + warp_rows = self.warp_rows + warp_cols = self.warp_cols + micro_size_x = self.micro_size_x + micro_size_y = self.micro_size_y + + warp_m, warp_n = self._get_warp_indices() + + buffer, offset_m, offset_n, stride = self._parse_buffer_2d(C_dst) + + simd_op = T.simdgroup_store if is_store else T.simdgroup_load + access_mode = "w" if is_store else "r" + + @T.macro + def _simdgroup_copy(C_simd_buf, buffer, offset_m, offset_n, stride, warp_m, warp_n): + for i, j in T.grid(warp_rows, warp_cols): + row = offset_m + warp_m * self.warp_row_tiles + i * micro_size_x + col = offset_n + warp_n * self.warp_col_tiles + j * micro_size_y + + index_c = i * warp_cols + j + + simd_op( + C_simd_buf.data, + index_c, + T.access_ptr(buffer[row, col], access_mode), + stride, + micro_size_x, + micro_size_y, + T.bool(False), + ) + + return _simdgroup_copy(C_simd_buf, buffer, offset_m, offset_n, stride, warp_m, warp_n) + + def simd_store(self, C_simd_buf, C_dst): + return self.simdgroup_copy(C_simd_buf, C_dst, is_store=True) + + def simd_load(self, C_simd_buf, C_src): + return self.simdgroup_copy(C_simd_buf, C_src, is_store=False) diff --git a/tilelang/jit/adapter/torch/metal.py b/tilelang/jit/adapter/torch/metal.py index 4690cf59bd..841f9b7cd5 100644 --- a/tilelang/jit/adapter/torch/metal.py +++ b/tilelang/jit/adapter/torch/metal.py @@ -53,6 +53,9 @@ def __init__( _kernel = None + def get_kernel_source(self, kernel_only: bool = True) -> str: + return self.kernel_global_source + def _convert_torch_func(self) -> Callable: if self._kernel is None: _kernel = getattr(torch.mps.compile_shader(self.kernel_global_source), self.kernel_name) diff --git a/tilelang/tileop/gemm/__init__.py b/tilelang/tileop/gemm/__init__.py index a2290b0191..8117f83eac 100644 --- a/tilelang/tileop/gemm/__init__.py +++ b/tilelang/tileop/gemm/__init__.py @@ -13,8 +13,9 @@ from .gemm_mfma import GemmMFMA from .gemm_wmma import GemmWMMA from .gemm_scalar import GemmScalar +from .gemm_metal import GemmMetal from tilelang import _ffi_api -from tilelang.utils.target import target_is_volta +from tilelang.utils.target import target_is_volta, target_is_metal @tvm_ffi.register_global_func("tl.gemm.infer_layout") @@ -157,8 +158,10 @@ def _select_gemm_instruction(self, thread_nums: int, target: Target) -> GemmInst 1. TCGEN5MMA for Blackwell architecture 2. WGMMA for Hopper architecture with sufficient matrix size and warp count 3. MFMA for CDNA (AMD) architecture - 4. MMA for CUDA architecture - 5. Scalar for CPU target (scalar fallback) + 4. WMMA for RDNA (AMD) architecture + 5. MMA for CUDA architecture + 6. METAL_SIMDGROUP for Metal target (simdgroup_matrix) + 7. 
Scalar for CPU target (scalar fallback) Args: thread_nums: Number of threads in the block @@ -167,6 +170,8 @@ def _select_gemm_instruction(self, thread_nums: int, target: Target) -> GemmInst Returns: GemmInst: The selected GEMM instruction type """ + if target_is_metal(target): + return GemmInst.METAL_SIMDGROUP return GemmInst(_ffi_api.GemmGetGemmInst(self, int(thread_nums), target)) def _get_implementation_class(self, gemm_inst: GemmInst, target: Target): @@ -197,5 +202,7 @@ def _get_implementation_class(self, gemm_inst: GemmInst, target: Target): return GemmWMMA elif gemm_inst.is_scalar(): return GemmScalar + elif gemm_inst.is_metal_simdgroup(): + return GemmMetal else: raise ValueError(f"Unsupported GEMM instruction: {gemm_inst}") diff --git a/tilelang/tileop/gemm/gemm_metal.py b/tilelang/tileop/gemm/gemm_metal.py new file mode 100644 index 0000000000..fb4bff1f33 --- /dev/null +++ b/tilelang/tileop/gemm/gemm_metal.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +from .gemm_base import GemmBase +from .inst import GemmInst +from tilelang.utils.language import is_shared, is_full_region, is_metal_simdgroup, is_fragment +from tilelang import tvm as tvm +from tvm.target import Target +from tvm.ir import Range +from tvm import tir +from tilelang import language as T +from tilelang.transform.simplify import _Simplify + + +class GemmMetal(GemmBase): + def is_gemm_ss(self) -> bool: + return is_shared(self.A) and is_shared(self.B) + + def infer_layout(self, target: Target, thread_nums: int): + return {} + + def lower( + self, layout_map: dict, target: Target, thread_bounds: Range, thread_var: tir.Var, mbar_phase_expr: tir.PrimExpr | None = None + ): + thread_nums = thread_bounds.extent + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, GemmInst.METAL_SIMDGROUP) + warp_row_tiles = int(self.M // m_warp) + warp_col_tiles = int(self.N // n_warp) + + from tilelang.intrinsics.metal_macro_generator import MPSIntrinEmitter + + mps_emitter = MPSIntrinEmitter( + a_dtype=self.in_dtype, + b_dtype=self.in_dtype, + accum_dtype=self.accum_dtype, + a_transposed=self.trans_A, + b_transposed=self.trans_B, + block_row_warps=m_warp, + block_col_warps=n_warp, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=self.chunk, + thread_var=thread_var, + ) + + in_dtype = self.in_dtype + accum_dtype = self.accum_dtype + warp_rows = mps_emitter.warp_rows + warp_cols = mps_emitter.warp_cols + num_simd_c = warp_rows * warp_cols + block_K = mps_emitter.chunk + micro_size_k = mps_emitter.micro_size_k + + A_region = self.ARegion + B_region = self.BRegion + C_region = self.CRegion + + C_buf = C_region.buffer + + clear_accum = self.clear_accum + c_in_simdgroup_reg = is_metal_simdgroup(C_buf) or is_fragment(C_buf) + + assert block_K >= micro_size_k, f"block_K ({block_K}) must be >= micro_size_k ({micro_size_k})" + assert is_full_region(C_region), "Fragment output C must be a full region" + assert c_in_simdgroup_reg or is_shared(C_buf), ( + f"Metal GEMM requires C in local.fragment, metal.simdgroup, or shared scope, got {C_buf.scope()}" + ) + + if self.is_gemm_ss(): + if c_in_simdgroup_reg: + + @T.prim_func + def _gemm_ss_simdgroup() -> None: + A_local = T.alloc_local((warp_rows * 64), in_dtype, scope="metal.simdgroup") + B_local = T.alloc_local((warp_cols * 64), in_dtype, scope="metal.simdgroup") + if clear_accum: + for _i in T.serial(num_simd_c): + T.make_filled_simdgroup_matrix(C_buf.data, _i, T.cast(0, accum_dtype)) + for ki in T.serial(0, (block_K // 
micro_size_k)): + mps_emitter.ldmatrix_a(A_local, A_region, ki) + mps_emitter.ldmatrix_b(B_local, B_region, ki) + mps_emitter.mma(A_local, B_local, C_buf) + + return _Simplify(_gemm_ss_simdgroup, inline_let=True) + else: + + @T.prim_func + def _gemm_ss_shared() -> None: + A_local = T.alloc_local((warp_rows * 64), in_dtype, scope="metal.simdgroup") + B_local = T.alloc_local((warp_cols * 64), in_dtype, scope="metal.simdgroup") + C_simd = T.alloc_local((num_simd_c * 64), accum_dtype, scope="metal.simdgroup") + if clear_accum: + for _i in T.serial(num_simd_c): + T.make_filled_simdgroup_matrix(C_simd.data, _i, T.cast(0, accum_dtype)) + else: + mps_emitter.simd_load(C_simd, C_buf) + for ki in T.serial(0, (block_K // micro_size_k)): + mps_emitter.ldmatrix_a(A_local, A_region, ki) + mps_emitter.ldmatrix_b(B_local, B_region, ki) + mps_emitter.mma(A_local, B_local, C_simd) + + mps_emitter.simd_store(C_simd, C_buf) + + return _Simplify(_gemm_ss_shared, inline_let=True) + else: + raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") diff --git a/tilelang/tileop/gemm/inst.py b/tilelang/tileop/gemm/inst.py index cbfaf016f2..66902a42cc 100644 --- a/tilelang/tileop/gemm/inst.py +++ b/tilelang/tileop/gemm/inst.py @@ -8,7 +8,8 @@ class GemmInst(IntEnum): TCGEN5MMA = 2 MFMA = 3 Scalar = 4 - WMMA = 5 # AMD RDNA WMMA (gfx11/gfx12) + WMMA = 5 + METAL_SIMDGROUP = 6 def is_mma(self) -> bool: return self == GemmInst.MMA @@ -28,5 +29,8 @@ def is_scalar(self) -> bool: def is_wmma(self) -> bool: return self == GemmInst.WMMA + def is_metal_simdgroup(self) -> bool: + return self == GemmInst.METAL_SIMDGROUP + def __repr__(self) -> str: return self.name diff --git a/tilelang/transform/decouple_type_cast.py b/tilelang/transform/decouple_type_cast.py index eafd8b36f5..8c1d20234b 100644 --- a/tilelang/transform/decouple_type_cast.py +++ b/tilelang/transform/decouple_type_cast.py @@ -68,14 +68,14 @@ # Cache the Op for if_then_else to avoid repeated lookups _IF_THEN_ELSE_OP = Op.get("tir.if_then_else") -from tilelang.utils.language import is_fragment, is_global, is_local, is_local_var, is_shared +from tilelang.utils.language import is_fragment, is_global, is_local, is_local_var, is_shared, is_metal_simdgroup def is_local_buffer(buffer: Buffer) -> bool: - """Check if a buffer is local (register-level), including local.var.""" + """Check if a buffer is local (register-level), including local.var and metal.simdgroup.""" if buffer is None: return False - return is_local(buffer) or is_fragment(buffer) or is_local_var(buffer) + return is_local(buffer) or is_fragment(buffer) or is_local_var(buffer) or is_metal_simdgroup(buffer) def is_global_or_shared_buffer(buffer: Buffer) -> bool: diff --git a/tilelang/transform/metal_fragment_to_simdgroup.py b/tilelang/transform/metal_fragment_to_simdgroup.py new file mode 100644 index 0000000000..0b5dd1bda7 --- /dev/null +++ b/tilelang/transform/metal_fragment_to_simdgroup.py @@ -0,0 +1,133 @@ +"""Rewrite local.fragment → metal.simdgroup for GEMM accumulators on Metal.""" + +from __future__ import annotations + +from tvm import tir, IRModule +from tvm.ir import Op, PointerType +from tvm.tir.transform import prim_func_pass + +_GEMM_OPS = None + + +def _get_gemm_ops(): + global _GEMM_OPS + if _GEMM_OPS is None: + _GEMM_OPS = { + Op.get("tl.tileop.gemm"), + Op.get("tl.tileop.wgmma_gemm"), + Op.get("tl.tileop.tcgen05_gemm"), + } + return _GEMM_OPS + + +def _extract_buffer_var_from_region(region_call): + if not isinstance(region_call, tir.Call): + return None + 
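+    # The tl.tileop.gemm call packs each operand as a tl.region(...) Call whose
+    # first argument is a BufferLoad on the underlying buffer; that BufferLoad's
+    # buffer.data Var is what the scope rewrite below is keyed on.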
if len(region_call.args) < 1: + return None + buf_load = region_call.args[0] + if isinstance(buf_load, tir.BufferLoad): + return buf_load.buffer.data + return None + + +def _collect_fragment_gemm_accum_vars(body: tir.Stmt) -> set: + accum_vars: set = set() + gemm_ops = _get_gemm_ops() + + def _visitor(stmt): + if isinstance(stmt, tir.Evaluate) and isinstance(stmt.value, tir.Call): + call = stmt.value + if call.op in gemm_ops and len(call.args) >= 3: + var = _extract_buffer_var_from_region(call.args[2]) + if var is not None and hasattr(var, "type_annotation"): + ta = var.type_annotation + if ta is not None and hasattr(ta, "storage_scope") and ta.storage_scope == "local.fragment": + accum_vars.add(var) + + tir.stmt_functor.post_order_visit(body, _visitor) + return accum_vars + + +def _remap_buffer(buf, var_map): + old_data = buf.data + new_data = var_map.get(old_data, None) + if new_data is None: + return buf + return tir.decl_buffer( + buf.shape, + buf.dtype, + buf.name, + data=new_data, + scope="metal.simdgroup", + data_alignment=buf.data_alignment, + offset_factor=buf.offset_factor, + ) + + +def _rewrite_scope(body, var_map): + buf_map = {} + + def _pre_order(stmt): + if isinstance(stmt, tir.Block): + new_alloc_bufs = [] + changed = False + for buf in stmt.alloc_buffers: + new_buf = _remap_buffer(buf, var_map) + new_alloc_bufs.append(new_buf) + if not new_buf.same_as(buf): + buf_map[buf] = new_buf + changed = True + if changed: + new_body = tir.stmt_functor.substitute(stmt.body, var_map) + new_block = tir.Block( + stmt.iter_vars, + stmt.reads, + stmt.writes, + stmt.name_hint, + new_body, + stmt.init, + new_alloc_bufs, + stmt.match_buffers, + stmt.annotations, + ) + return ( + tir.BlockRealize( + stmt.iter_vars, + tir.const(True, "bool"), + new_block, + ) + if False + else new_block + ) + elif isinstance(stmt, tir.Allocate): + new_var = var_map.get(stmt.buffer_var, None) + if new_var is not None: + new_body = tir.stmt_functor.substitute(stmt.body, var_map) + return tir.Allocate(new_var, stmt.dtype, stmt.extents, stmt.condition, new_body, stmt.annotations) + return None + + return tir.stmt_functor.ir_transform(body, _pre_order, None, ["tir.Block", "tir.Allocate"]) + + +def _metal_fragment_to_simdgroup(func: tir.PrimFunc, mod: IRModule, ctx) -> tir.PrimFunc: + target = func.attrs.get("target", None) + if target is None or target.kind.name != "metal": + return func + + accum_vars = _collect_fragment_gemm_accum_vars(func.body) + if not accum_vars: + return func + + var_map: dict = {} + for var in accum_vars: + ptr_type = var.type_annotation + new_ptr = PointerType(ptr_type.element_type, "metal.simdgroup") + new_var = tir.Var(var.name, new_ptr) + var_map[var] = new_var + + new_body = _rewrite_scope(func.body, var_map) + return func.with_body(new_body) + + +MetalFragmentToSimdgroup = prim_func_pass(_metal_fragment_to_simdgroup, opt_level=0, name="tl.MetalFragmentToSimdgroup") diff --git a/tilelang/utils/language.py b/tilelang/utils/language.py index 53a5730ca2..0ad52839e4 100644 --- a/tilelang/utils/language.py +++ b/tilelang/utils/language.py @@ -118,6 +118,20 @@ def is_fragment(buffer: BufferLikeType) -> bool: return buffer.scope().startswith("local.fragment") +def is_metal_simdgroup(buffer: BufferLikeType) -> bool: + """ + Check if the buffer is in the Metal simdgroup scope. + + Args: + buffer: The TVM buffer, BufferLoad, or BufferRegion to check. + + Returns: + bool: True if the buffer is in metal.simdgroup scope, False otherwise. 
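+
+    Example (illustrative; mirrors the allocation made in GemmMetal.lower)::
+
+        A_local = T.alloc_local((64,), "float16", scope="metal.simdgroup")
+        assert is_metal_simdgroup(A_local)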
+ """ + buffer = _get_buffer(buffer) + return buffer.scope() == "metal.simdgroup" + + def is_local_var(buffer: BufferLikeType) -> bool: """ Check if the buffer is in the local.var memory scope. From a3fdb11d58ffbae1aa55573d14ce38ce37417a64 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Thu, 7 May 2026 15:32:21 +0800 Subject: [PATCH 2/3] Move Metal buffer helpers to Metal backend --- src/backend/metal/op/copy.cc | 1 + src/backend/metal/op/fill.cc | 1 + src/backend/metal/op/utils.h | 27 +++++++++++++++++++++++++++ src/op/utils.h | 8 -------- 4 files changed, 29 insertions(+), 8 deletions(-) create mode 100644 src/backend/metal/op/utils.h diff --git a/src/backend/metal/op/copy.cc b/src/backend/metal/op/copy.cc index de16c76bd7..cdf5c96451 100644 --- a/src/backend/metal/op/copy.cc +++ b/src/backend/metal/op/copy.cc @@ -5,6 +5,7 @@ #include "op/copy.h" +#include "backend/metal/op/utils.h" #include "op/utils.h" #include "target/utils.h" diff --git a/src/backend/metal/op/fill.cc b/src/backend/metal/op/fill.cc index 92bb48a6e2..ea5e5bcc7d 100644 --- a/src/backend/metal/op/fill.cc +++ b/src/backend/metal/op/fill.cc @@ -5,6 +5,7 @@ #include "op/fill.h" +#include "backend/metal/op/utils.h" #include "op/utils.h" #include "target/utils.h" #include "transform/loop_partition.h" diff --git a/src/backend/metal/op/utils.h b/src/backend/metal/op/utils.h new file mode 100644 index 0000000000..90134f4267 --- /dev/null +++ b/src/backend/metal/op/utils.h @@ -0,0 +1,27 @@ +/*! + * \file tl/backend/metal/op/utils.h + * \brief Metal-specific operator helpers. + */ + +#ifndef TVM_TL_BACKEND_METAL_OP_UTILS_H_ +#define TVM_TL_BACKEND_METAL_OP_UTILS_H_ + +#include "op/utils.h" + +namespace tvm { +namespace tl { +namespace metal { + +inline bool IsSIMDGroupBuffer(const tir::Buffer &buffer) { + return buffer.defined() && buffer.scope() == "metal.simdgroup"; +} + +inline bool IsRegisterBuffer(const tir::Buffer &buffer) { + return IsFragmentBuffer(buffer) || IsSIMDGroupBuffer(buffer); +} + +} // namespace metal +} // namespace tl +} // namespace tvm + +#endif // TVM_TL_BACKEND_METAL_OP_UTILS_H_ diff --git a/src/op/utils.h b/src/op/utils.h index 1e8b6221ca..fb5066ef48 100644 --- a/src/op/utils.h +++ b/src/op/utils.h @@ -57,14 +57,6 @@ inline bool IsFragmentBuffer(const Buffer &buffer) { return buffer.defined() && buffer.scope() == "local.fragment"; } -inline bool IsSIMDGroupBuffer(const Buffer &buffer) { - return buffer.defined() && buffer.scope() == "metal.simdgroup"; -} - -inline bool IsRegisterBuffer(const Buffer &buffer) { - return IsFragmentBuffer(buffer) || IsSIMDGroupBuffer(buffer); -} - // Expand a lower-rank layout by prepending the leading dimensions of `buffer` // so that the resulting layout input shape matches `buffer->shape`. // From 950d0093e50251d5a1ecb62ffd3cb503452f86d9 Mon Sep 17 00:00:00 2001 From: Yichen Yan Date: Mon, 11 May 2026 22:12:57 +0800 Subject: [PATCH 3/3] fix(metal): remove layout_inference bypass, fix build & clarify docs - Remove duplicate BuildTileLangMetal in rt_mod_metal.cc (fixes linker error) - Remove TargetIsMetal bypass in layout_inference.cc. The bypass is unnecessary because MetalFragmentToSimdgroup already converts all GEMM accumulator fragment buffers to metal.simdgroup scope before LayoutInference runs, so IsFragmentBuffer never matches them. - Clarify docstrings for MetalFragmentToSimdgroup pass explaining why it runs before LayoutInference (simdgroup matrices are opaque and have no explicit thread-level layout). 
- Restore is_metal_simdgroup check in GemmMetal.lower() so the GEMM lowering correctly handles buffers already converted to simdgroup scope by the pass. --- src/backend/metal/codegen/rt_mod_metal.cc | 18 ++++++------------ src/transform/layout_inference.cc | 12 ++++-------- tilelang/engine/phase.py | 3 ++- tilelang/tileop/gemm/gemm_metal.py | 10 +++++----- .../transform/metal_fragment_to_simdgroup.py | 18 +++++++----------- 5 files changed, 24 insertions(+), 37 deletions(-) diff --git a/src/backend/metal/codegen/rt_mod_metal.cc b/src/backend/metal/codegen/rt_mod_metal.cc index 00f3098897..6ea00035ed 100644 --- a/src/backend/metal/codegen/rt_mod_metal.cc +++ b/src/backend/metal/codegen/rt_mod_metal.cc @@ -2,10 +2,10 @@ * \file rt_mod_metal.cc * \brief Metal codegen entry point. * - * Metal codegen is handled by CodeGenCHost (target/codegen_c_host.cc), which - * has built-in Metal context support via the is_in_metal_context flag. - * When IR contains AttrStmt with attr_key == "metal_context", the host - * codegen emits Metal-specific dispatch_sync / MTLCommandBuffer code. + * Metal codegen is implemented in target/codegen_metal.cc, which handles + * simdgroup types, intrinsics, and MSL emission. + * This file exists to satisfy the backend/metal/CMakeLists.txt dependency + * but delegates to the main implementation. */ #include "target/codegen_c_host.h" @@ -14,14 +14,8 @@ namespace tvm { namespace codegen { -ffi::Module BuildTileLangMetal(IRModule mod, Target target) { - return tl::BuildTileLangCHost(mod, target); -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("target.build.tilelang_metal", BuildTileLangMetal); -} +// Metal codegen entry point is in target/codegen_metal.cc. +// This backend path is kept for future migration. } // namespace codegen } // namespace tvm diff --git a/src/transform/layout_inference.cc b/src/transform/layout_inference.cc index 73baa98208..4cfdb6bf82 100644 --- a/src/transform/layout_inference.cc +++ b/src/transform/layout_inference.cc @@ -433,16 +433,12 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { } } - // Check that all local.fragment buffers have inferred layouts. - // On Metal targets, fragment buffers used as GEMM accumulators are - // lowered to opaque simdgroup matrices, so they have no explicit - // thread-level layout and can be safely skipped. + // Check that all local.fragment buffers have inferred layouts for (const auto &[buffer, _] : use_list_) { if (IsFragmentBuffer(buffer)) { - if (!TargetIsMetal(target_) && layout_map.count(buffer) == 0) { - ICHECK(false) << "The layout for fragment " << buffer - << " can not be inferred correctly."; - } + ICHECK_NE(layout_map.count(buffer), 0) + << "The layout for fragment " << buffer + << " can not be inferred correctly."; } } diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index 854d9e73c4..c3a3419bd6 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -198,7 +198,8 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule: mod = tilelang.transform.InjectSoftwarePipeline()(mod) mod = tilelang.transform.Simplify()(mod) # On Metal, rewrite local.fragment GEMM accumulators to metal.simdgroup - # before layout inference (which would otherwise require a layout for them) + # before layout inference. simdgroup matrices are opaque and have no + # explicit thread-level layout, so layout inference must not see them. 
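+    # (GemmMetal.lower() then operates on these metal.simdgroup buffers directly
+    # via simdgroup_load/simdgroup_store/simdgroup_multiply_accumulate.)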
from tilelang.transform.metal_fragment_to_simdgroup import MetalFragmentToSimdgroup mod = MetalFragmentToSimdgroup(mod) diff --git a/tilelang/tileop/gemm/gemm_metal.py b/tilelang/tileop/gemm/gemm_metal.py index 3e52e9f531..a3794faa25 100644 --- a/tilelang/tileop/gemm/gemm_metal.py +++ b/tilelang/tileop/gemm/gemm_metal.py @@ -1,7 +1,7 @@ from __future__ import annotations from .gemm_base import GemmBase -from tilelang.utils.language import is_shared, is_full_region, is_metal_simdgroup, is_fragment +from tilelang.utils.language import is_shared, is_full_region, is_fragment, is_metal_simdgroup from tilelang import tvm as tvm from tvm.target import Target from tvm.ir import Range @@ -61,16 +61,16 @@ def lower( C_buf = C_region.buffer clear_accum = self.clear_accum - c_in_simdgroup_reg = is_metal_simdgroup(C_buf) or is_fragment(C_buf) + c_in_register = is_fragment(C_buf) or is_metal_simdgroup(C_buf) assert block_K >= micro_size_k, f"block_K ({block_K}) must be >= micro_size_k ({micro_size_k})" assert is_full_region(C_region), "Fragment output C must be a full region" - assert c_in_simdgroup_reg or is_shared(C_buf), ( - f"Metal GEMM requires C in local.fragment, metal.simdgroup, or shared scope, got {C_buf.scope()}" + assert c_in_register or is_shared(C_buf), ( + f"Metal GEMM requires C in local.fragment, metal.simdgroup or shared scope, got {C_buf.scope()}" ) if self.is_gemm_ss(): - if c_in_simdgroup_reg: + if c_in_register: @T.prim_func def _gemm_ss_simdgroup() -> None: diff --git a/tilelang/transform/metal_fragment_to_simdgroup.py b/tilelang/transform/metal_fragment_to_simdgroup.py index 0b5dd1bda7..4e3f067857 100644 --- a/tilelang/transform/metal_fragment_to_simdgroup.py +++ b/tilelang/transform/metal_fragment_to_simdgroup.py @@ -1,4 +1,9 @@ -"""Rewrite local.fragment → metal.simdgroup for GEMM accumulators on Metal.""" +"""Rewrite local.fragment → metal.simdgroup for GEMM accumulator buffers on Metal. + +This pass runs after pipelining and before LayoutInference, so that +simdgroup matrices (which are hardware-opaque and have no explicit +thread-level layout) are never seen by LayoutInference. +""" from __future__ import annotations @@ -80,7 +85,7 @@ def _pre_order(stmt): changed = True if changed: new_body = tir.stmt_functor.substitute(stmt.body, var_map) - new_block = tir.Block( + return tir.Block( stmt.iter_vars, stmt.reads, stmt.writes, @@ -91,15 +96,6 @@ def _pre_order(stmt): stmt.match_buffers, stmt.annotations, ) - return ( - tir.BlockRealize( - stmt.iter_vars, - tir.const(True, "bool"), - new_block, - ) - if False - else new_block - ) elif isinstance(stmt, tir.Allocate): new_var = var_map.get(stmt.buffer_var, None) if new_var is not None:
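
For readers tracing the pass end to end, a minimal sketch of how it is invoked and what it guarantees. Only the pass object, import path, and scope strings come from this series; the wrapper name is hypothetical:

    # Sketch: applying the Metal accumulator rewrite to an IRModule whose
    # PrimFuncs carry a Metal target attribute. Because it is registered as a
    # prim_func_pass, it is a no-op on non-Metal functions and on functions
    # without local.fragment GEMM accumulators.
    from tilelang.transform.metal_fragment_to_simdgroup import MetalFragmentToSimdgroup

    def apply_metal_rewrite(mod):
        # Before: GEMM accumulators allocated via T.alloc_fragment carry the
        #         "local.fragment" scope and would force LayoutInference to
        #         infer a per-thread layout for them.
        # After:  the same buffers carry "metal.simdgroup", so LayoutInference
        #         skips them and GemmMetal.lower() takes its register
        #         (c_in_register) accumulator path.
        return MetalFragmentToSimdgroup(mod)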