diff --git a/benchmark/matmul_metal/benchmark_matmul_metal.py b/benchmark/matmul_metal/benchmark_matmul_metal.py new file mode 100644 index 0000000000..546c596071 --- /dev/null +++ b/benchmark/matmul_metal/benchmark_matmul_metal.py @@ -0,0 +1,126 @@ +import argparse +import logging +import time + +import torch + +import tilelang +import tilelang.language as T + +logging.getLogger("tilelang").setLevel(logging.WARNING) + +BLOCK_CONFIGS = [ + (16, 16, 16), + (32, 32, 16), + (32, 32, 32), + (64, 64, 32), +] + + +@tilelang.jit +def matmul_simdgroup(M, N, K, block_M=64, block_N=64, block_K=32, dtype=T.float16, accum_dtype=T.float32): + + @T.prim_func + def gemm_kernel( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), accum_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared") + B_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared") + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0): + T.copy(A[by * block_M, ko * block_K], A_shared) + T.copy(B[ko * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return gemm_kernel + + +def _tflops(M, N, K, seconds): + return 2.0 * M * N * K / seconds / 1e12 + + +def _bench(fn, warmup, repeats): + for _ in range(warmup): + fn() + torch.mps.synchronize() + t0 = time.perf_counter() + for _ in range(repeats): + fn() + torch.mps.synchronize() + return (time.perf_counter() - t0) / repeats + + +def bench_torch_mps(M, N, K, warmup, repeats): + a = torch.randn(M, K, dtype=torch.float16, device="mps") + b = torch.randn(K, N, dtype=torch.float16, device="mps") + avg_s = _bench(lambda: torch.mm(a, b), warmup, repeats) + return _tflops(M, N, K, avg_s) + + +def bench_tilelang(M, N, K, block_M, block_N, block_K, warmup, repeats): + kernel = matmul_simdgroup(M, N, K, block_M, block_N, block_K) + a = torch.randn(M, K, dtype=torch.float16, device="mps") + b = torch.randn(K, N, dtype=torch.float16, device="mps") + c = torch.zeros(M, N, dtype=torch.float32, device="mps") + avg_s = _bench(lambda: kernel(a, b, c), warmup, repeats) + return _tflops(M, N, K, avg_s) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Metal GEMM Benchmark (simdgroup)") + parser.add_argument("--m", type=int, default=4096) + parser.add_argument("--n", type=int, default=4096) + parser.add_argument("--k", type=int, default=4096) + parser.add_argument("--warmup", type=int, default=10) + parser.add_argument("--repeats", type=int, default=100) + parser.add_argument("--sweep", action="store_true", help="Sweep all block configs instead of using default (64,64,32)") + args = parser.parse_args() + + M, N, K = args.m, args.n, args.k + for name, value in (("m", M), ("n", N), ("k", K), ("repeats", args.repeats)): + if value <= 0: + raise SystemExit(f"--{name} must be a positive integer, got {value}") + if args.warmup < 0: + raise SystemExit(f"--warmup must be non-negative, got {args.warmup}") + + print(f"torch: {torch.__version__}") + print(f"tilelang: {tilelang.__version__}") + print(f"MPS: {torch.backends.mps.is_available()}") + if not torch.backends.mps.is_available(): + raise SystemExit("Metal GEMM benchmark requires PyTorch MPS support") + print(f"M={M}, N={N}, K={K}, warmup={args.warmup}, repeats={args.repeats}") + print() + + ref_tflops = bench_torch_mps(M, 
N, K, args.warmup, args.repeats) + print(f"PyTorch MPS (torch.mm fp16): {ref_tflops:.1f} TFLOPS") + print() + + configs = BLOCK_CONFIGS if args.sweep else [(64, 64, 32)] + + print(f"{'block (M,N,K)':>16s} | {'TileLang':>14s} | {'Ratio':>6s}") + print("-" * 44) + + best_tflops = 0.0 + best_config = configs[0] + for bM, bN, bK in configs: + try: + tl = bench_tilelang(M, N, K, bM, bN, bK, args.warmup, args.repeats) + ratio = tl / ref_tflops * 100 + tag = "" + if tl > best_tflops: + best_tflops = tl + best_config = (bM, bN, bK) + print(f"{f'({bM},{bN},{bK})':>16s} | {tl:>10.1f} TFLOPS | {ratio:>5.0f}%") + except Exception as e: + print(f"{f'({bM},{bN},{bK})':>16s} | {'FAILED':>14s} | {e}") + + if args.sweep: + print() + print(f"Best config: {best_config}") + print(f"Best TFlops: {best_tflops:.1f}") + print(f"Reference TFlops (PyTorch MPS): {ref_tflops:.1f}") diff --git a/pyproject.toml b/pyproject.toml index 80aba41e32..d2c046df43 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ dependencies = [ # requirement as wide as possible to be compatible with other libraries # pip will try to use latest version whenever possible. "apache-tvm-ffi~=0.1.0,>=0.1.2", + "apache-tvm-ffi<0.1.8; platform_system == 'Darwin'", # torch-c-dlpack-ext provides prebuilt torch extensions. # Without it, TVM FFI may require JIT compilation on first import. "torch-c-dlpack-ext; python_version < '3.14'", diff --git a/requirements-dev.txt b/requirements-dev.txt index f8dccdc871..e4403be050 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,6 +1,7 @@ # Requirements to run local build with `--no-build-isolation` or other developments apache-tvm-ffi~=0.1.0,>=0.1.2 +apache-tvm-ffi<0.1.8; platform_system == 'Darwin' build cmake>=3.26 cython>=3.1.0 diff --git a/requirements.txt b/requirements.txt index 2dbe070d9a..9c841cee5d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ # Runtime requirements apache-tvm-ffi~=0.1.0,>=0.1.2 +apache-tvm-ffi<0.1.8; platform_system == 'Darwin' torch-c-dlpack-ext; python_version < '3.14' cloudpickle ml-dtypes diff --git a/src/backend/metal/CMakeLists.txt b/src/backend/metal/CMakeLists.txt index 9dbf33204a..6b1c789fbb 100644 --- a/src/backend/metal/CMakeLists.txt +++ b/src/backend/metal/CMakeLists.txt @@ -1,4 +1,12 @@ # Metal backend: source files and build configuration. + +# Metal codegen is pure C++ and can generate Metal shader source on any +# platform. Always compile it so target.build.tilelang_metal is available for +# cross-compilation and source-level tests on non-Apple hosts. +list(APPEND TILE_LANG_SRCS + src/target/codegen_metal.cc +) + if(NOT USE_METAL) return() endif() @@ -7,7 +15,7 @@ if(NOT APPLE) # On non-Apple platforms USE_METAL=ON enables only codegen (Metal source # generation) without requiring the Metal/Foundation frameworks. 
message(STATUS "Metal backend on non-Apple: enabling codegen-only mode (no Metal runtime)") - set(USE_METAL OFF) + return() endif() file(GLOB TILE_LANG_METAL_SRCS diff --git a/src/op/copy.cc b/src/op/copy.cc index 93bd0cf70b..6d7d05b7ac 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -517,6 +517,10 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T, return result_map; } + if (copy_inst == CopyInst::kMetalSIMDGroup) { + return {}; + } + // for LDSM/STSM, the layout was deduced from register layout // so we can directly apply the layout of normal copy // Use parallel op to infer the layout @@ -792,8 +796,47 @@ bool CopyNode::CheckCPAsyncCopy(Target target, const LayoutMap &layout_map, if (!CheckCPAsyncCopyPreconditions()) { return false; } - // Skip vectorize size check here because, during the Infer Layout stage, - // the layout is not stable and the vectorized size cannot be determined. + return true; +} + +bool CopyNode::CheckSIMDGroupCopy(Target target) const { + if (!TargetIsMetal(target) || !IsSIMDGroupBuffer(src)) { + return false; + } + if (!IsSharedBuffer(dst) && !IsGlobalBuffer(dst)) { + return false; + } + if (src->dtype != dst->dtype) { + return false; + } + if (src_range.size() != 2 || dst_range.size() != 2 || + dst->shape.size() != 2) { + return false; + } + + int total_elements = 1; + for (auto extent : src->shape) { + auto imm = extent.as(); + if (!imm) { + return false; + } + total_elements *= imm->value; + } + if (total_elements % 64 != 0) { + return false; + } + + for (int i = 0; i < 2; ++i) { + auto src_shape = src->shape[i].as(); + auto src_min = src_range[i]->min.as(); + auto src_extent = src_range[i]->extent.as(); + auto dst_extent = dst_range[i]->extent.as(); + if (!src_shape || !src_min || src_min->value != 0 || !src_extent || + !dst_extent || src_extent->value != src_shape->value || + src_extent->value != dst_extent->value || src_extent->value % 8 != 0) { + return false; + } + } return true; } @@ -864,6 +907,8 @@ CopyInst CopyNode::GetCopyInst(Target target, const LayoutMap &layout_map, return CopyInst::kTMemLoad; } else if (CheckTMemStore(target)) { return CopyInst::kTMemStore; + } else if (CheckSIMDGroupCopy(target)) { + return CopyInst::kMetalSIMDGroup; } else { return CopyInst::kNormal; } @@ -897,6 +942,8 @@ Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { auto cp_async_copy = LowerCPAsyncCopy(T, analyzer); ICHECK(cp_async_copy.defined()) << "Failed to lower cp.async copy"; return cp_async_copy; + } else if (copy_inst == CopyInst::kMetalSIMDGroup) { + return LowerSIMDGroupCopy(T, analyzer); } else if (copy_inst == CopyInst::kNormal) { return LowerNormalCopy(T, analyzer); } else { @@ -982,7 +1029,125 @@ Stmt CopyNode::LowerCPAsyncCopy(const LowerArgs &T, return cp_async_loop; } -// Lowers the copy using standard load/store with loop transformations. 
+Stmt CopyNode::LowerSIMDGroupCopy(const LowerArgs &T,
+                                  arith::Analyzer *analyzer) const {
+  ICHECK(IsSIMDGroupBuffer(src));
+  int total_elements = 1;
+  for (auto s : src->shape) {
+    auto imm = s.as<IntImmNode>();
+    ICHECK(imm) << "simdgroup buffer must have constant shape";
+    total_elements *= imm->value;
+  }
+  ICHECK(total_elements % 64 == 0)
+      << "simdgroup buffer size must be multiple of 64 (8x8), got "
+      << total_elements;
+
+  ICHECK(dst_range.size() == 2)
+      << "Expected 2D destination for simdgroup store";
+  PrimExpr dst_row_base = dst_range[0]->min;
+  PrimExpr dst_col_base = dst_range[1]->min;
+  ICHECK_EQ(dst->shape.size(), 2U)
+      << "simdgroup store currently supports 2D destination buffers";
+  Array<PrimExpr> dst_strides = dst->strides;
+  if (dst_strides.empty()) {
+    PrimExpr stride = 1;
+    dst_strides.resize(dst->shape.size());
+    for (int i = static_cast<int>(dst->shape.size()) - 1; i >= 0; --i) {
+      dst_strides.Set(i, stride);
+      stride *= dst->shape[i];
+    }
+  }
+  if (dst_strides.size() != dst->shape.size()) {
+    return LowerNormalCopy(T, analyzer);
+  }
+  if (!analyzer->CanProveEqual(dst_strides[1], 1)) {
+    return LowerNormalCopy(T, analyzer);
+  }
+  PrimExpr dst_stride = dst_strides[0];
+
+  int warp_size = TargetGetWarpSize(T.target);
+  auto block_extent = T.thread_bounds->extent.as<IntImmNode>();
+  if (!block_extent || warp_size <= 0 || block_extent->value % warp_size != 0) {
+    return LowerNormalCopy(T, analyzer);
+  }
+  int block_size = block_extent->value;
+  int num_warps = block_size / warp_size;
+  if (num_warps <= 0) {
+    return LowerNormalCopy(T, analyzer);
+  }
+  PrimExpr relative_thread = T.thread_var - T.thread_bounds->min;
+  PrimExpr warp_id = FloorDiv(relative_thread, warp_size);
+
+  auto M_imm = src_range[0]->extent.as<IntImmNode>();
+  auto N_imm = src_range[1]->extent.as<IntImmNode>();
+  if (!M_imm || !N_imm) {
+    return LowerNormalCopy(T, analyzer);
+  }
+  int M = M_imm->value;
+  int N = N_imm->value;
+
+  int kMPerWarp = 8;
+  int kNPerWarp = 8;
+  int m_warp = 1, n_warp = num_warps;
+  int max_m = M / kMPerWarp;
+  int max_n = N / kNPerWarp;
+  if (max_m <= 0 || max_n <= 0) {
+    return LowerNormalCopy(T, analyzer);
+  }
+  float ideal = N > 0 ? static_cast<float>(M) / N : 1.f;
+  float best_score = std::numeric_limits<float>::max();
+  for (int m = 1; m <= std::min(num_warps, max_m); ++m) {
+    if (num_warps % m != 0)
+      continue;
+    int n = num_warps / m;
+    if (n > max_n)
+      continue;
+    if (M % (m * kMPerWarp) != 0 || N % (n * kNPerWarp) != 0)
+      continue;
+    float m_per = static_cast<float>(M) / (m * kMPerWarp);
+    float n_per = static_cast<float>(N) / (n * kNPerWarp);
+    float score = std::abs(m_per / n_per - ideal);
+    if (score < best_score) {
+      best_score = score;
+      m_warp = m;
+      n_warp = n;
+    }
+  }
+
+  if (best_score == std::numeric_limits<float>::max() || M < m_warp * 8 ||
+      N < n_warp * 8) {
+    return LowerNormalCopy(T, analyzer);
+  }
+  int warp_row_tiles = M / m_warp / 8;
+  int warp_col_tiles = N / n_warp / 8;
+  if (warp_row_tiles <= 0 || warp_col_tiles <= 0 ||
+      warp_row_tiles * warp_col_tiles * 64 > total_elements) {
+    return LowerNormalCopy(T, analyzer);
+  }
+
+  PrimExpr warp_m = FloorMod(warp_id, m_warp);
+  PrimExpr warp_n = FloorDiv(warp_id, m_warp);
+
+  Array<Stmt> stmts;
+  for (int i = 0; i < warp_row_tiles; i++) {
+    for (int j = 0; j < warp_col_tiles; j++) {
+      int tile_idx = i * warp_col_tiles + j;
+      PrimExpr row = dst_row_base + warp_m * (warp_row_tiles * 8) + i * 8;
+      PrimExpr col = dst_col_base + warp_n * (warp_col_tiles * 8) + j * 8;
+      PrimExpr ptr = Call(DataType::Handle(), builtin::address_of(),
+                          {BufferLoad(dst, {row, col})});
+      stmts.push_back(Evaluate(
+          Call(DataType::Handle(), builtin::simdgroup_store(),
+               {src->data, IntImm(DataType::Int(32), tile_idx), ptr, dst_stride,
+                IntImm(DataType::Int(32), 8), IntImm(DataType::Int(32), 8),
+                Cast(DataType::Bool(), IntImm(DataType::Int(32), 0))})));
+    }
+  }
+  if (stmts.size() == 1)
+    return stmts[0];
+  return SeqStmt(stmts);
+}
+
 Stmt CopyNode::LowerNormalCopy(const LowerArgs &T,
                                arith::Analyzer *analyzer) const {
   bool is_cpu_target = T.target->GetTargetDeviceType() == kDLCPU;
diff --git a/src/op/copy.h b/src/op/copy.h
index d20f519815..4a80b75e46 100644
--- a/src/op/copy.h
+++ b/src/op/copy.h
@@ -24,10 +24,11 @@ enum class CopyInst : uint8_t {
   kCPAsync = 5, // cp.async global->shared copy
   // we should separate the bulk load and store for 1d and multi-dim
   // as they have different memory access patterns
-  kBulkLoad1D = 6,  // utilize tma load 1d
-  kBulkStore1D = 7, // utilize tma store 1d
-  kTMemLoad = 8,    // tcgen05.ld (tensor memory -> register)
-  kTMemStore = 9,   // tcgen05.st (register -> tensor memory)
+  kBulkLoad1D = 6,      // utilize tma load 1d
+  kBulkStore1D = 7,     // utilize tma store 1d
+  kTMemLoad = 8,        // tcgen05.ld (tensor memory -> register)
+  kTMemStore = 9,       // tcgen05.st (register -> tensor memory)
+  kMetalSIMDGroup = 10, // Metal simdgroup load/store
 };
 
 /// Convert CopyInst enum to string for debugging
@@ -53,6 +54,8 @@ inline const char *CopyInstToString(CopyInst inst) {
     return "TMemLoad";
   case CopyInst::kTMemStore:
     return "TMemStore";
+  case CopyInst::kMetalSIMDGroup:
+    return "MetalSIMDGroup";
   default:
     return "Unknown";
   }
@@ -290,6 +293,11 @@ class CopyNode : public TileOperatorNode {
                       arith::Analyzer *analyzer) const;
 
 protected:
+  /*!
+   * \brief Check if copy from Metal simdgroup to shared/global is supported.
+   */
+  bool CheckSIMDGroupCopy(Target target) const;
+
   /*!
    * \brief Get the copy instruction type.
    */
@@ -331,6 +339,11 @@ class CopyNode : public TileOperatorNode {
    */
   Stmt LowerCPAsyncCopy(const LowerArgs &T, arith::Analyzer *analyzer) const;
 
+  /*!
+   * \brief Generate lowering for simdgroup store.
+   */
+  Stmt LowerSIMDGroupCopy(const LowerArgs &T, arith::Analyzer *analyzer) const;
+
   /*!
* \brief Generate SIMT (thread-level) loop for copying. */ diff --git a/src/op/fill.cc b/src/op/fill.cc index be0dd8dc10..6d88563f30 100644 --- a/src/op/fill.cc +++ b/src/op/fill.cc @@ -156,7 +156,51 @@ For FillNode::MakeSIMTLoop(arith::Analyzer *analyzer) const { * @return Stmt The lowered TIR statement implementing the fill. */ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { - if (IsFragmentBuffer(dst)) { + if (IsSIMDGroupBuffer(dst)) { + int region_elements = 1; + for (auto r : region) { + auto imm = r->extent.as(); + ICHECK(imm) << "simdgroup fill region must have constant extents"; + region_elements *= imm->value; + } + int total_elements = region_elements; + ICHECK(total_elements % 64 == 0) + << "simdgroup buffer size must be multiple of 64 (8x8), got " + << total_elements; + int num_matrices = total_elements / 64; + PrimExpr fill_value = Cast(dst->dtype, value); + Array strides = dst->strides; + if (strides.empty()) { + PrimExpr stride = 1; + strides.resize(dst->shape.size()); + for (int i = static_cast(dst->shape.size()) - 1; i >= 0; --i) { + strides.Set(i, stride); + stride *= dst->shape[i]; + } + } + ICHECK_EQ(strides.size(), dst->shape.size()) + << "simdgroup fill requires complete destination strides"; + PrimExpr element_offset = 0; + for (size_t i = 0; i < region.size(); ++i) { + element_offset += region[i]->min * strides[i]; + } + PrimExpr matrix_elements = IntImm(element_offset.dtype(), 64); + ICHECK( + analyzer->CanProveEqual(FloorMod(element_offset, matrix_elements), 0)) + << "simdgroup fill region must start on an 8x8 matrix boundary"; + PrimExpr matrix_index_base = FloorDiv(element_offset, matrix_elements); + Array stmts; + for (int i = 0; i < num_matrices; i++) { + stmts.push_back(Evaluate( + Call(DataType::Handle(), builtin::make_filled_simdgroup_matrix(), + {dst->data, matrix_index_base + IntImm(DataType::Int(32), i), + fill_value, IntImm(DataType::Int(32), 8), + IntImm(DataType::Int(32), 8)}))); + } + if (stmts.size() == 1) + return stmts[0]; + return SeqStmt(stmts); + } else if (IsFragmentBuffer(dst)) { auto par_op = ParallelOp(MakeSIMTLoop(analyzer)); par_op->InferLayout({T.target, T.thread_bounds, diff --git a/src/op/gemm.cc b/src/op/gemm.cc index 5472b33386..1a5bc65a59 100644 --- a/src/op/gemm.cc +++ b/src/op/gemm.cc @@ -183,6 +183,8 @@ GemmInst GemmNode::getGemmInst(int block_size, Target target) const { return GemmInst::kMMA; } else if (TargetIsCPU(target)) { return GemmInst::kScalar; + } else if (TargetIsMetal(target)) { + return GemmInst::kMetalSimdgroup; } else { ICHECK(0) << "Unsupported target for gemm: " << target->str(); return GemmInst::kMMA; @@ -199,8 +201,11 @@ std::pair GemmWarpPolicyNode::computeWarpPartition( } int m_warp = 1, n_warp = 1; - constexpr int kMPerWarp = 16; // Rows processed by a single warp - int kNPerWarp = 8; // Columns processed by a single warp + int kMPerWarp = 16; // Rows processed by a single warp + if (TargetIsMetal(target)) { + kMPerWarp = 8; + } + int kNPerWarp = 8; // Columns processed by a single warp if (TargetIsVolta(target)) { kNPerWarp = 16; } else if (TargetIsCDNA(target)) { diff --git a/src/op/gemm.h b/src/op/gemm.h index 26b6678402..90f85b00ef 100644 --- a/src/op/gemm.h +++ b/src/op/gemm.h @@ -45,7 +45,8 @@ enum class GemmInst : uint8_t { kTCGEN5MMA, kMFMA, kScalar, - kWMMA + kWMMA, + kMetalSimdgroup }; /// Convert GemmInst enum to string for debugging @@ -63,6 +64,8 @@ inline const char *GemmInstToString(GemmInst inst) { return "Scalar"; case GemmInst::kWMMA: return "WMMA"; + case 
GemmInst::kMetalSimdgroup: + return "Metal"; default: return "Unknown"; } diff --git a/src/op/parallel.cc b/src/op/parallel.cc index b93467b54f..e31e94426b 100644 --- a/src/op/parallel.cc +++ b/src/op/parallel.cc @@ -359,7 +359,6 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, info && info.value()->rep == ReducerRepType::ALL) continue; - auto frag = T.layout_map[buffer].as().value(); bool is_fully_replicated = IsBufferCompletelyReplicated(buffer, T.layout_map); @@ -379,7 +378,9 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, // If the buffer is not replicated and shape is equal to the // source_buffer, use it as source_buffer because the layout inference // is more accurate - if (is_one(frag->ReplicateExtent()) && !source_buffer.defined()) { + auto frag = T.layout_map[buffer].as(); + if (frag.has_value() && is_one(frag.value()->ReplicateExtent()) && + !source_buffer.defined()) { source_buffer = buffer; } } diff --git a/src/op/utils.h b/src/op/utils.h index 77e21feda6..8831fa6877 100644 --- a/src/op/utils.h +++ b/src/op/utils.h @@ -53,11 +53,18 @@ TVM_DLL PrimExpr MakeAccessPtrFromRegion(const BufferRegion ®ion, TVM_DLL PrimExpr MakeAccessPtrFromBufferLoad(const BufferLoad &load, int rw_mask); -// Check if a buffer is a fragment buffer (scope == "local.fragment") inline bool IsFragmentBuffer(const Buffer &buffer) { return buffer.defined() && buffer.scope() == "local.fragment"; } +inline bool IsSIMDGroupBuffer(const Buffer &buffer) { + return buffer.defined() && buffer.scope() == "metal.simdgroup"; +} + +inline bool IsRegisterBuffer(const Buffer &buffer) { + return IsFragmentBuffer(buffer) || IsSIMDGroupBuffer(buffer); +} + // Expand a lower-rank layout by prepending the leading dimensions of `buffer` // so that the resulting layout input shape matches `buffer->shape`. // diff --git a/src/target/codegen_metal.cc b/src/target/codegen_metal.cc new file mode 100644 index 0000000000..faa43a3558 --- /dev/null +++ b/src/target/codegen_metal.cc @@ -0,0 +1,586 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file codegen_metal.cc + */ +#include "codegen_metal.h" + +#include +#include + +#include +#include +#include +#include +#include + +#include "../op/builtin.h" +#include "runtime/metal/metal_module.h" +#include "runtime/thread_storage_scope.h" +#include "target/build_common.h" + +namespace tvm { +namespace codegen { + +void CodeGenTileLangMetal::InitFuncState(const PrimFunc &f) { + CodeGenC::InitFuncState(f); + // analyze the data; + for (Var arg : f->params) { + if (arg.dtype().is_handle()) { + alloc_storage_scope_[arg.get()] = "global"; + } + } +} + +CodeGenTileLangMetal::CodeGenTileLangMetal(Target target) : target_(target) { + decl_stream << "#include \n"; + decl_stream << "using namespace metal;\n\n"; + decl_stream << "union __TVMArgUnion {\n" + << " int v_int[2];\n" + << "};\n\n"; +} + +void CodeGenTileLangMetal::AddFunction(const GlobalVar &gvar, + const PrimFunc &func) { + // NOTE: There is no inter-function calls among Metal kernels. + // For now we keep the metal codegen without inter-function call + // process. + // We can switch to follow the flow with inter-function call process + // after the Metal function declaration is properly printed. + // In Metal, for PrimFuncs with signature + // def func(A: Buffer, B: Buffer, x: int, y: float) -> None + // where there are trailing pod parameters, the codegen emits a struct + // struct func_params{ x: int; y: float; } + // for the function. In the flow of inter-function call process, + // the struct will be emitted for every time a function is declared. + // So consequently there are duplicate appearances of a same struct, + // which makes the Metal compiler unable to recognize. + + // clear previous generated state. + this->InitFuncState(func); + // skip the first underscore, so SSA variable starts from _1 + name_supply_->FreshName("v_"); + + // add to alloc buffer type. + auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(global_symbol.has_value()) + << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; + + // Function header. + this->stream << "kernel void " + << static_cast(global_symbol.value()) << "("; + + // Buffer arguments + size_t num_buffer = 0; + size_t limit = + target_->GetAttr("max_function_args").value().IntValue(); + if (func->params.size() > limit) { + LOG(WARNING) << "Probably you won't be able to execute your kernel due to " + "high number of " + "buffers in the kernel"; + } + for (size_t i = 0; i < func->params.size(); ++i, ++num_buffer) { + Var v = func->params[i]; + if (!v.dtype().is_handle()) + break; + this->stream << " "; + std::string vid = AllocVarID(v.get()); + auto it = alloc_storage_scope_.find(v.get()); + if (it != alloc_storage_scope_.end()) { + PrintStorageScope(it->second, this->stream); + } + PrintType(GetType(v), this->stream); + // Register handle data type + // TODO(tvm-team): consider simply keep type info in the + // type annotation(via a normalizing rewriting). + if (auto *ptr = v->type_annotation.as()) { + if (auto *prim = ptr->element_type.as()) { + RegisterHandleType(v.get(), prim->dtype); + } + } + this->stream << ' ' << vid << " [[ buffer(" << i << ") ]],\n"; + } + // Setup normal arguments. 
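+  // All trailing POD parameters are packed into a single constant struct
+  // ("<global_symbol>_args_t") bound to the next free buffer slot: 32-bit
+  // fields are declared as two-element arrays and narrower fields go through
+  // __TVMArgUnion, so every struct field spans the same 8 bytes.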
+ size_t nargs = func->params.size() - num_buffer; + std::string varg = name_supply_->FreshName("arg"); + if (nargs != 0) { + std::string arg_buf_type = + static_cast(global_symbol.value()) + "_args_t"; + this->stream << " constant " << arg_buf_type << "& " << varg + << " [[ buffer(" << num_buffer << ") ]],\n"; + // declare the struct + decl_stream << "struct " << arg_buf_type << " {\n"; + for (size_t i = num_buffer; i < func->params.size(); ++i) { + Var v = func->params[i]; + ICHECK(!v.dtype().is_handle()); + std::string vid = AllocVarID(v.get()); + std::ostringstream vref; + if (v.dtype().bits() == 32) { + decl_stream << " "; + PrintType(v.dtype(), decl_stream); + decl_stream << " " << vid << "[2];\n"; + vref << varg << "." << vid << "[0]"; + } else if (v.dtype().bits() == 64) { + decl_stream << " "; + PrintType(v.dtype(), decl_stream); + decl_stream << " " << vid << ";\n"; + vref << varg << "." << vid; + } else { + // For non 32bit type, ref through arg union. + decl_stream << " __TVMArgUnion " << vid << ";\n"; + vref << varg << "." << vid << ".v_"; + PrintType(v.dtype(), vref); + } + var_idmap_[v.get()] = vref.str(); + } + decl_stream << "};\n\n"; + } + // Setup the thread group info. + ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx"); + ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx"); + int work_dim = 0; + auto launch_params = + func->GetAttr>(tir::attr::kKernelLaunchParams) + .value(); + for (const auto &tag : launch_params) { + if (tag != runtime::launch_param::kUseDynamicSharedMemoryTag) { + runtime::ThreadScope scope = runtime::ThreadScope::Create(tag); + work_dim = std::max(work_dim, scope.dim_index + 1); + } + } + + if (work_dim != 0) { + // use ushort by default for now + stream << " "; + PrintType(DataType::UInt(thread_index_bits_, work_dim), stream); + stream << " blockIdx [[threadgroup_position_in_grid]],\n"; + stream << " "; + PrintType(DataType::UInt(thread_index_bits_, work_dim), stream); + stream << " threadIdx [[thread_position_in_threadgroup]]\n"; + } + thread_work_dim_ = work_dim; + + // the function scope. 
+ stream << ") {\n"; + int func_scope = this->BeginScope(); + this->PrintStmt(func->body); + this->EndScope(func_scope); + this->PrintIndent(); + this->stream << "}\n\n"; +} + +void CodeGenTileLangMetal::BindThreadIndex(const IterVar &iv) { + ICHECK(!var_idmap_.count(iv->var.get())); + // if we only have threadIdx.x + // metal will directly print as threadIdx + std::string vname = iv->thread_tag; + if (thread_work_dim_ <= 1) { + vname = vname.substr(0, iv->thread_tag.length() - 2); + } + var_idmap_[iv->var.get()] = + CastFromTo(vname, DataType::UInt(thread_index_bits_), iv->var.dtype()); +} + +void CodeGenTileLangMetal::PrintType(DataType t, + std::ostream &os) { // NOLINT(*) + int lanes = t.lanes(); + if (t.is_handle()) { + ICHECK_EQ(lanes, 1) << "do not yet support vector types"; + os << "void*"; + return; + } + + if (t.is_void()) { + os << "void"; + return; + } + if (t == DataType::Bool()) { + os << "bool"; + return; + } + if (t.is_float() && t.bits() == 16 && lanes > 4 && lanes <= 8 && + lanes % 2 == 0) { + os << "uint" << lanes / 2; + return; + } + bool fail = false; + if (t.is_float()) { + if (lanes == 3) { + os << "packed_"; + } + switch (t.bits()) { + case 16: + os << "half"; + break; + case 32: + os << "float"; + break; + default: + fail = true; + break; + } + if (!fail && lanes == 1) + return; + if (!fail && (lanes >= 2 && lanes <= 4)) { + os << lanes; + return; + } + } else if (t.is_uint() || t.is_int()) { + if (t.is_uint()) { + os << 'u'; + } + switch (t.bits()) { + case 8: + os << "char"; + break; + case 16: + os << "short"; + break; + case 32: + os << "int"; + break; + case 64: + os << "long"; + break; + case 1: + os << "bool"; + break; + default: + fail = true; + break; + } + if (!fail && lanes == 1) + return; + if (!fail && (lanes >= 2 && lanes <= 4)) { + os << lanes; + return; + } + } else if (t.is_bfloat16()) { + os << "bfloat"; + return; + } + LOG(FATAL) << "Cannot convert type " << t << " to Metal type"; +} + +void CodeGenTileLangMetal::PrintStorageSync(const CallNode *op) { + const std::string &sync = op->args[0].as()->value; + if (sync == "warp") { + this->PrintIndent(); + this->stream << "simdgroup_barrier(mem_flags::mem_threadgroup);\n"; + } else if (sync == "shared") { + this->PrintIndent(); + this->stream << "threadgroup_barrier(mem_flags::mem_threadgroup);\n"; + } else if (sync == "global") { + LOG(FATAL) << "global barrier not supported"; + } +} + +void CodeGenTileLangMetal::PrintVecElemLoad(const std::string &vec, DataType t, + int i, + std::ostream &os) { // NOLINT(*) + if (t.is_float16() && t.lanes() > 4) { + os << "((thread half*)(&" << vec << "))[" << i << "]"; + } else { + os << vec << "[" << i << "]"; + } +} + +void CodeGenTileLangMetal::PrintVecElemStore(const std::string &vec, DataType t, + int i, const std::string &value) { + this->PrintIndent(); + if (t.is_float16() && t.lanes() > 4) { + stream << "((thread half*)(&" << vec << "))[" << i << "] = " << value + << ";\n"; + } else { + stream << vec << "[" << i << "]" + << " = " << value << ";\n"; + } +} + +void CodeGenTileLangMetal::PrintStorageScope(const std::string &scope, + std::ostream &os) { // NOLINT(*) + if (scope == "global") { + os << "device "; + } else if (scope == "shared" || scope == "shared.dyn") { + os << "threadgroup "; + } else if (scope == "local") { + os << "thread "; + } else { + LOG(FATAL) << "Unknown storage scope `" << scope << "`"; + } +} + +void CodeGenTileLangMetal::VisitStmt_(const AllocateNode *op) { + ICHECK(!is_zero(op->condition)); + std::string vid = 
AllocVarID(op->buffer_var.get()); + + this->PrintIndent(); + size_t constant_size = op->ConstantAllocationSize(); + ICHECK_GT(constant_size, 0) + << "Can only handle constant size stack allocation for now"; + + auto scope = GetPtrStorageScope(op->buffer_var); + alloc_storage_scope_[op->buffer_var.get()] = scope; + if (scope == "metal.simdgroup") { + ICHECK(op->dtype == DataType::Float(16) || + op->dtype == DataType::Float(32) || + op->dtype == DataType::BFloat(16)) + << "Only float16, float32, and bfloat16 are supported, but got " + << op->dtype; + ICHECK(constant_size % 64 == 0) << "Only 8x8 matrix is supported, but got " + << constant_size << " bytes\n"; + + std::ostringstream dtype_os; + PrintType(op->dtype, dtype_os); + std::string dtype_str = dtype_os.str(); + simdgroup_dtype_[op->buffer_var.get()] = dtype_str; + stream << "simdgroup_" << dtype_str << "8x8 " << vid << '[' + << constant_size / 64 << "];\n"; + } else if (scope == "local.var") { + ICHECK(op->dtype.is_scalar()) + << "Vector local.var allocation is not supported."; + ICHECK_EQ(constant_size, 1) + << "Only scalar local.var allocation is supported."; + PrimExpr init = tir::make_const(op->dtype, 0); + auto init_it = op->annotations.find(tl::attr::kLocalVarInit); + if (init_it != op->annotations.end()) { + PrimExpr user_init = Downcast((*init_it).second); + if (!user_init.dtype().is_void() && user_init.dtype() != op->dtype) { + user_init = tir::Cast(op->dtype, user_init); + } + init = user_init; + } + PrintType(op->dtype, stream); + stream << ' ' << vid << " = " << PrintExpr(init) << ";\n"; + } else { + PrintStorageScope(scope, stream); + PrintType(op->dtype, stream); + stream << ' ' << vid << '[' << constant_size << "];\n"; + } + + RegisterHandleType(op->buffer_var.get(), op->dtype); + this->PrintStmt(op->body); +} + +void CodeGenTileLangMetal::VisitExpr_(const BufferLoadNode *op, + std::ostream &os) { // NOLINT(*) + std::string scope; + auto it = alloc_storage_scope_.find(op->buffer->data.get()); + if (it != alloc_storage_scope_.end()) { + scope = it->second; + } + if (scope.empty()) { + scope = GetPtrStorageScope(op->buffer->data); + } + if (scope == "local.var") { + ICHECK_EQ(op->indices.size(), 1) + << "Load from non-flat local.var memory not supported."; + ICHECK(op->dtype.is_scalar()) << "Vector local.var load is not supported."; + auto index = op->indices[0].as(); + ICHECK(index && index->value == 0) + << "local.var load requires scalar index 0."; + os << GetVarID(op->buffer->data.get()); + return; + } + CodeGenC::VisitExpr_(op, os); +} + +void CodeGenTileLangMetal::VisitStmt_(const BufferStoreNode *op) { + std::string scope; + auto it = alloc_storage_scope_.find(op->buffer->data.get()); + if (it != alloc_storage_scope_.end()) { + scope = it->second; + } + if (scope.empty()) { + scope = GetPtrStorageScope(op->buffer->data); + } + if (scope == "local.var") { + ICHECK_EQ(op->indices.size(), 1) + << "Store to non-flat local.var memory not supported."; + ICHECK(op->value.dtype().is_scalar()) + << "Vector local.var store is not supported."; + auto index = op->indices[0].as(); + ICHECK(index && index->value == 0) + << "local.var store requires scalar index 0."; + this->PrintIndent(); + stream << GetVarID(op->buffer->data.get()) << " = " << PrintExpr(op->value) + << ";\n"; + return; + } + CodeGenC::VisitStmt_(op); +} + +void CodeGenTileLangMetal::VisitExpr_(const SelectNode *op, + std::ostream &os) { // NOLINT(*) + os << "select(" << PrintExpr(op->false_value) << ", " + << PrintExpr(op->true_value) << ", " << 
PrintExpr(op->condition) << ")"; +} + +void CodeGenTileLangMetal::VisitExpr_(const BroadcastNode *op, + std::ostream &os) { // NOLINT(*) + std::string v = PrintExpr(op->value); + int lanes = op->dtype.lanes(); + if (op->dtype.is_float16() && lanes > 4 && lanes % 2 == 0) { + os << "uint" << lanes / 2 << "("; + for (int i = 0; i < lanes / 2; ++i) { + if (i != 0) + os << ", "; + os << "as_type(half2(" << v << ", " << v << "))"; + } + os << ')'; + } else { + PrintType(op->dtype, os); + os << "("; + for (int i = 0; i < lanes; ++i) { + if (i != 0) + os << ", "; + os << v; + } + os << ')'; + } +} + +void CodeGenTileLangMetal::VisitExpr_(const CallNode *op, + std::ostream &os) { // NOLINT(*) + CHECK(!op->op.as()) + << "CodegenMetal does not support inter-function calls, " + << "but expression " << ffi::GetRef(op) << " calls PrimFunc " + << op->op; + auto f_check_simdgroup_shape = [](PrimExpr col, PrimExpr row) { + ICHECK(col->IsInstance() && row->IsInstance()) + << "Only constant shape is supported for simdgroup matrix, but got " + << col << "x" << row; + int col_val = col.as()->value; + int row_val = row.as()->value; + ICHECK(col_val == 8 && row_val == 8) + << "Only 8x8 matrix is supported, but got " << col_val << "x" + << row_val; + }; + if (op->op.same_as(builtin::make_filled_simdgroup_matrix())) { + ICHECK_EQ(op->args.size(), 5); + Var var = Downcast(op->args[0]); + // Get the data type of the simdgroup matrix + auto it = simdgroup_dtype_.find(var.get()); + ICHECK(it != simdgroup_dtype_.end()) + << "Cannot find variable allocation for simdgroup: " << var; + const std::string &dtype_str = it->second; + f_check_simdgroup_shape(op->args[3], op->args[4]); + os << PrintExpr(var) << "[" << PrintExpr(op->args[1]) + << "] = make_filled_simdgroup_matrix<" << dtype_str << ", " + << PrintExpr(op->args[3]) << ", " << PrintExpr(op->args[4]) << ">(" + << PrintExpr(op->args[2]) << ")"; + } else if (op->op.same_as(builtin::simdgroup_load())) { + ICHECK_EQ(op->args.size(), 7); + f_check_simdgroup_shape(op->args[4], op->args[5]); + os << "simdgroup_load(" << PrintExpr(op->args[0]) << "[" + << PrintExpr(op->args[1]) << "], " << PrintExpr(op->args[2]) << ", " + << PrintExpr(op->args[3]) << ", 0, " << PrintExpr(op->args[6]) << ")"; + } else if (op->op.same_as(builtin::simdgroup_store())) { + ICHECK_EQ(op->args.size(), 7); + f_check_simdgroup_shape(op->args[4], op->args[5]); + os << "simdgroup_store(" << PrintExpr(op->args[0]) << "[" + << PrintExpr(op->args[1]) << "], " << PrintExpr(op->args[2]) << ", " + << PrintExpr(op->args[3]) << ", 0, " << PrintExpr(op->args[6]) << ")"; + } else if (op->op.same_as(builtin::simdgroup_multiply_accumulate())) { + ICHECK_EQ(op->args.size(), 8); + os << "simdgroup_multiply_accumulate(" // + << PrintExpr(op->args[0]) << "[" << PrintExpr(op->args[1]) << "], " // + << PrintExpr(op->args[2]) << "[" << PrintExpr(op->args[3]) << "], " // + << PrintExpr(op->args[4]) << "[" << PrintExpr(op->args[5]) << "], " // + << PrintExpr(op->args[6]) << "[" << PrintExpr(op->args[7]) << "])"; + } else if (op->op.same_as(builtin::reinterpret())) { + // generate as_type(ARG) + os << "(as_type<"; + this->PrintType(op->dtype, os); + os << ">("; + this->PrintExpr(op->args[0], os); + os << "))"; + } else { + CodeGenC::VisitExpr_(op, os); + } +} + +void CodeGenTileLangMetal::VisitExpr_(const FloatImmNode *op, + std::ostream &os) { // NOLINT(*) + std::ostringstream temp; + if (std::isinf(op->value)) { + if (op->value < 0) { + temp << "-"; + } + temp << "INFINITY"; + } else if (std::isnan(op->value)) { + temp 
<< "NAN"; + } else { + temp << std::scientific << op->value; + if (op->dtype.bits() == 32) + temp << 'f'; + else if (op->dtype.bits() == 16) + temp << 'h'; + } + MarkConst(temp.str()); + os << temp.str(); +} + +ffi::Module BuildTileLangMetal(IRModule mod, Target target) { + bool output_ssa = false; + mod = tir::transform::PointerValueTypeRewrite()(std::move(mod)); + + std::ostringstream source_maker; + std::unordered_map smap; + const auto fmetal_compile = + tvm::ffi::Function::GetGlobal("tvm_callback_metal_compile"); + std::string fmt = fmetal_compile ? "metallib" : "metal"; + + for (auto kv : mod->functions) { + ICHECK(kv.second->IsInstance()) + << "CodeGenTileLangMetal: Can only take PrimFunc"; + auto global_symbol = + kv.second->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(global_symbol.has_value()); + std::string func_name = global_symbol.value(); + + source_maker << "// Function: " << func_name << "\n"; + CodeGenTileLangMetal cg(target); + cg.Init(output_ssa); + auto f = Downcast(kv.second); + auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); + ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) + << "CodeGenTileLangMetal: expect calling_conv equals " + "CallingConv::kDeviceKernelLaunch"; + + cg.AddFunction(kv.first, f); + + std::string fsource = cg.Finish(); + source_maker << fsource << "\n"; + if (fmetal_compile) { + fsource = (*fmetal_compile)(fsource, target).cast(); + } + smap[func_name] = fsource; + } + + return MetalModuleCreate(smap, ExtractFuncInfo(mod), fmt, source_maker.str()); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("target.build.tilelang_metal", BuildTileLangMetal); +} +} // namespace codegen +} // namespace tvm diff --git a/src/target/codegen_metal.h b/src/target/codegen_metal.h new file mode 100644 index 0000000000..3a711b4ee4 --- /dev/null +++ b/src/target/codegen_metal.h @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file codegen_metal.h + * \brief Generate Metal device code. + */ +#ifndef TVM_TARGET_SOURCE_CODEGEN_METAL_H_ +#define TVM_TARGET_SOURCE_CODEGEN_METAL_H_ + +#include + +#include +#include + +#include "target/source/codegen_c.h" + +namespace tvm { +namespace codegen { + +class CodeGenTileLangMetal final : public CodeGenC { +public: + explicit CodeGenTileLangMetal(Target target); + // override print thread tag. 
+ void PrintArgUnionDecl(); + void AddFunction(const GlobalVar &gvar, const PrimFunc &func) final; + void InitFuncState(const PrimFunc &f) final; + void PrintStorageScope(const std::string &scope, + std::ostream &os) final; // NOLINT(*) + void PrintStorageSync(const CallNode *op) final; // NOLINT(*) + void PrintType(DataType t, std::ostream &os) final; // NOLINT(*) + void BindThreadIndex(const IterVar &iv) final; // NOLINT(*) + // print load of single element + void PrintVecElemLoad(const std::string &vec, DataType t, int i, + std::ostream &os) final; // NOLINT(*) + // print store of single element. + void PrintVecElemStore(const std::string &vec, DataType t, int i, + const std::string &value) final; + // overload visitor + void VisitStmt_(const AllocateNode *op) final; // NOLINT(*) + void VisitStmt_(const BufferStoreNode *op) final; // NOLINT(*) + void VisitExpr_(const BufferLoadNode *op, + std::ostream &os) final; // NOLINT(*) + void VisitExpr_(const SelectNode *op, std::ostream &os) final; // NOLINT(*) + void VisitExpr_(const BroadcastNode *op, std::ostream &os) final; // NOLINT(*) + void VisitExpr_(const CallNode *op, std::ostream &os) final; // NOLINT(*) + void VisitExpr_(const FloatImmNode *op, std::ostream &os) final; // NOLINT(*) + + // reuse parent's function. + using CodeGenC::PrintType; + +private: + std::unordered_map simdgroup_dtype_; + int thread_index_bits_{32}; + int thread_work_dim_{0}; + Target target_; +}; +} // namespace codegen +} // namespace tvm + +#endif // TVM_TARGET_SOURCE_CODEGEN_METAL_H_ diff --git a/src/transform/layout_inference.cc b/src/transform/layout_inference.cc index 4cfdb6bf82..73baa98208 100644 --- a/src/transform/layout_inference.cc +++ b/src/transform/layout_inference.cc @@ -433,12 +433,16 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { } } - // Check that all local.fragment buffers have inferred layouts + // Check that all local.fragment buffers have inferred layouts. + // On Metal targets, fragment buffers used as GEMM accumulators are + // lowered to opaque simdgroup matrices, so they have no explicit + // thread-level layout and can be safely skipped. 
for (const auto &[buffer, _] : use_list_) { if (IsFragmentBuffer(buffer)) { - ICHECK_NE(layout_map.count(buffer), 0) - << "The layout for fragment " << buffer - << " can not be inferred correctly."; + if (!TargetIsMetal(target_) && layout_map.count(buffer) == 0) { + ICHECK(false) << "The layout for fragment " << buffer + << " can not be inferred correctly."; + } } } diff --git a/src/transform/lower_device_storage_access_info.cc b/src/transform/lower_device_storage_access_info.cc index 1f3077843d..b9c54b724f 100644 --- a/src/transform/lower_device_storage_access_info.cc +++ b/src/transform/lower_device_storage_access_info.cc @@ -46,7 +46,7 @@ class StorageAccessInfoLower : public StmtExprMutator { auto scope = StorageScope::Create(GetPtrStorageScope(op->buffer_var)); if (!scope.tag.empty() && scope.tag != ".dyn" && scope.tag != ".var" && scope.tag != ".barrier" && scope.tag != ".cluster_barrier" && - scope.tag.find(".descriptor") != 0) { + scope.tag != ".fragment" && scope.tag.find(".descriptor") != 0) { auto info = GetMemoryInfo(GetPtrStorageScope(op->buffer_var)); ICHECK(info.defined()) << "Cannot find memory info of " << scope.to_string(); diff --git a/testing/python/jit/test_tilelang_jit_adapter_mps.py b/testing/python/jit/test_tilelang_jit_adapter_mps.py new file mode 100644 index 0000000000..ddd93b705e --- /dev/null +++ b/testing/python/jit/test_tilelang_jit_adapter_mps.py @@ -0,0 +1,47 @@ +"""Focused tests for JIT adapter device selection without CUDA.""" + +from types import SimpleNamespace + +import torch + +from tilelang.jit.adapter.base import BaseKernelAdapter + + +def test_current_device_functor_prefers_mps_when_cuda_unavailable(monkeypatch): + monkeypatch.setattr(torch.cuda, "is_available", lambda: False) + + if getattr(torch.backends, "mps", None) is None: + monkeypatch.setattr(torch.backends, "mps", SimpleNamespace(is_available=lambda: True), raising=False) + else: + monkeypatch.setattr(torch.backends.mps, "is_available", lambda: True) + + device_functor = BaseKernelAdapter.get_current_device_functor() + + assert device_functor() == torch.device("mps") + + +def test_current_device_functor_prefers_mps_when_cuda_init_fails(monkeypatch): + monkeypatch.setattr(torch.cuda, "is_available", lambda: True) + monkeypatch.setattr(torch.cuda, "_lazy_init", lambda: (_ for _ in ()).throw(RuntimeError("cuda init failed"))) + + if getattr(torch.backends, "mps", None) is None: + monkeypatch.setattr(torch.backends, "mps", SimpleNamespace(is_available=lambda: True), raising=False) + else: + monkeypatch.setattr(torch.backends.mps, "is_available", lambda: True) + + device_functor = BaseKernelAdapter.get_current_device_functor() + + assert device_functor() == torch.device("mps") + + +def test_current_device_functor_falls_back_to_cpu_without_cuda_or_mps(monkeypatch): + monkeypatch.setattr(torch.cuda, "is_available", lambda: False) + + if getattr(torch.backends, "mps", None) is None: + monkeypatch.setattr(torch.backends, "mps", SimpleNamespace(is_available=lambda: False), raising=False) + else: + monkeypatch.setattr(torch.backends.mps, "is_available", lambda: False) + + device_functor = BaseKernelAdapter.get_current_device_functor() + + assert device_functor() == torch.device("cpu") diff --git a/testing/python/metal/metal_internal_runtime_coverage.md b/testing/python/metal/metal_internal_runtime_coverage.md new file mode 100644 index 0000000000..464180846c --- /dev/null +++ b/testing/python/metal/metal_internal_runtime_coverage.md @@ -0,0 +1,26 @@ +# Metal Internal Runtime Coverage + +This document 
summarizes internal-only Metal backend coverage for scalar lowering, simdgroup helpers, packed quantization probes, and GDN/attention-style tiled kernels. It does not add public `T.rt`/`T.rv` aliases and does not introduce model checkpoints, production integration, MPP/cooperative lowering, MPSGraph, CUDA, or native fp8/fp4 Metal storage. + +## Runtime-validated coverage + +- Packed quant matmul: `M=16,N=32,K=64`, synthetic packed `uint8` fp8 activations, packed `uint8` fp4 weights, `uint8` e8m0 activation/weight scales, and fp32 output. MPS output is compared against a CPU decode/reference matmul. +- GDN/attention-style staged component probe: `chunk=16,key_dim=16,value_dim=16`, deterministic synthetic fp32 tensors, staged 8x8 KKT score accumulation over two key-dimension slices, scalar gate/causal triangular mask, and tiled W/U accumulation. MPS output is compared against Torch reference. +- Smaller runtime probes remain in `testing/python/metal/test_metal_internal_scaffolding.py` and related focused Metal tests. + +## Source-boundary-only / fail-closed coverage + +- Native Metal fp8/fp4 storage remains intentionally unsupported and fail-closed; component probes keep the packed `uint8` boundary. +- RegisterTile/RowVector helpers remain internal under `tilelang.tileop`; no public language aliases are added. +- Component probes assert that forbidden external backend tokens (`cooperative`, `mpp`, `mpsgraph`, `cuda`, etc.) are absent from generated Metal source. + +## Known blockers and deferrals + +- This coverage is correctness/scaffolding only; it does not optimize component-scale performance. +- Packed quant matmul uses scalar per-output decode/accumulation rather than native fp8/fp4 tensor storage or a production quantized GEMM lowering. +- GDN/attention-style coverage remains synthetic and chunk-local; no checkpoint-bound integration or full production recurrent/chunked scheduler is included. + +## Verification hooks + +- Default focused tests: `python3 -m pytest testing/python/metal/test_metal_internal_scaffolding.py -q` +- Opt-in component timing hook: `TILELANG_RUN_METAL_COMPONENT_BENCH=1 python3 -m pytest testing/python/metal/test_metal_internal_scaffolding.py::test_component_synthetic_runtime_benchmarks_opt_in -q -s` diff --git a/testing/python/metal/test_metal_gemm_v2.py b/testing/python/metal/test_metal_gemm_v2.py new file mode 100644 index 0000000000..cb0ce26e8f --- /dev/null +++ b/testing/python/metal/test_metal_gemm_v2.py @@ -0,0 +1,91 @@ +"""Test Metal gemm_v2 with actual execution on Metal hardware. + +These tests verify correctness of T.gemm (gemm_v2) using simdgroup matrix +operations by comparing results against torch.matmul. 
+""" + +import tilelang +from tilelang import tvm as tvm +import tilelang.testing +import tilelang.language as T +import torch + + +@tilelang.jit +def matmul_gemm_v2(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): + @T.prim_func + def gemm_kernel( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), accum_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared") + B_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared") + C_local = T.alloc_shared((block_M, block_N), accum_dtype, scope="shared") + + T.clear(C_local) + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0): + T.copy(A[by * block_M, ko * block_K], A_shared) + T.copy(B[ko * block_K, bx * block_N], B_shared) + + T.gemm(A_shared, B_shared, C_local) + + T.copy(C_local, C[by * block_M, bx * block_N]) + + return gemm_kernel + + +def assert_gemm_v2( + M, + N, + K, + block_M, + block_N, + block_K, + dtype=T.float16, + accum_dtype=T.float32, + atol=1e-2, +): + jit_kernel = matmul_gemm_v2(M, N, K, block_M, block_N, block_K, dtype=dtype, accum_dtype=accum_dtype) + + torch_dtype = dtype.as_torch() + torch_accum_dtype = accum_dtype.as_torch() + a = torch.randn(M, K, dtype=torch_dtype, device="mps") + b = torch.randn(K, N, dtype=torch_dtype, device="mps") + c = torch.zeros(M, N, dtype=torch_accum_dtype, device="mps") + + jit_kernel(a, b, c) + + ref = a.to(torch_accum_dtype) @ b.to(torch_accum_dtype) + assert torch.allclose(ref, c, atol=atol), ( + f"Result mismatch for M={M}, N={N}, K={K}, " + f"block=({block_M},{block_N},{block_K}), dtype={dtype}\n" + f"max diff: {(ref - c).abs().max().item()}" + ) + + +@tilelang.testing.requires_metal +def test_gemm_v2_16x16x16(): + assert_gemm_v2(128, 128, 128, 16, 16, 16) + + +@tilelang.testing.requires_metal +def test_gemm_v2_16x16x8(): + assert_gemm_v2(128, 128, 128, 16, 16, 8) + + +@tilelang.testing.requires_metal +def test_gemm_v2_large(): + assert_gemm_v2(128, 128, 128, 32, 32, 32) + + +@tilelang.testing.requires_metal +def test_gemm_v2_1024(): + assert_gemm_v2(1024, 1024, 1024, 16, 16, 16, atol=1.0) + + +if __name__ == "__main__": + if torch.mps.is_available(): + tilelang.testing.main() diff --git a/testing/python/metal/test_metal_gemm_v2_linux.py b/testing/python/metal/test_metal_gemm_v2_linux.py new file mode 100644 index 0000000000..2976646fc2 --- /dev/null +++ b/testing/python/metal/test_metal_gemm_v2_linux.py @@ -0,0 +1,82 @@ +"""Test Metal gemm_v2 code generation on any platform (including Linux). + +These tests verify that TileLang can compile kernels using T.gemm (gemm_v2) +down to Metal shader source code with simdgroup matrix operations, +without requiring a Metal runtime or macOS. 
+""" + +import tilelang +from tilelang import tvm as tvm +import tilelang.testing +import tilelang.language as T + + +def matmul_gemm_v2(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), accum_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared") + B_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared") + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + T.clear(C_local) + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0): + T.copy(A[by * block_M, ko * block_K], A_shared, coalesced_width=2) + T.copy(B[ko * block_K, bx * block_N], B_shared, coalesced_width=2) + + T.gemm(A_shared, B_shared, C_local) + + T.copy(C_local, C[by * block_M, bx * block_N], coalesced_width=2) + + return main + + +def assert_metal_gemm_v2_codegen( + M, + N, + K, + block_M, + block_N, + block_K, + dtype=T.float16, + accum_dtype=T.float32, +): + func = matmul_gemm_v2(M, N, K, block_M, block_N, block_K, dtype=dtype, accum_dtype=accum_dtype) + with tvm.transform.PassContext(), tvm.target.Target("metal"): + artifact = tilelang.lower(func, target="metal") + + src_code = artifact.kernel_source + assert src_code is not None + assert "kernel void" in src_code + # Verify simdgroup matrix operations are present + assert "simdgroup_multiply_accumulate" in src_code + assert "simdgroup_load" in src_code + assert "simdgroup_store" in src_code + + +def test_metal_gemm_v2_float16(): + assert_metal_gemm_v2_codegen(64, 64, 64, 16, 16, 16, dtype=T.float16) + + +def test_metal_gemm_v2_float32(): + assert_metal_gemm_v2_codegen(64, 64, 64, 16, 16, 16, dtype=T.float32, accum_dtype=T.float32) + + +def test_metal_gemm_v2_larger(): + assert_metal_gemm_v2_codegen(128, 128, 128, 32, 32, 32, dtype=T.float16) + + +def test_metal_gemm_v2_small_blocks(): + """Test with blocks where warp_rows > 1 and warp_cols > 1, which previously + produced incorrect results due to swizzle padding changing the stride. + """ + assert_metal_gemm_v2_codegen(16, 16, 16, 16, 16, 16, dtype=T.float16) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/testing/python/metal/test_metal_internal_scaffolding.py b/testing/python/metal/test_metal_internal_scaffolding.py new file mode 100644 index 0000000000..3e750f555f --- /dev/null +++ b/testing/python/metal/test_metal_internal_scaffolding.py @@ -0,0 +1,879 @@ +"""Internal-only Metal scaffolding/source-boundary/runtime probes. + +These tests intentionally exercise private helper modules under ``tilelang.tileop`` +without adding public language aliases such as ``T.rt`` or ``T.rv``. + +Coverage notes: +- Runtime-validates packed uint8 fp8/fp4/e8m0 decode on MPS against a CPU + reference using synthetic tensors only. +- Runtime-validates a GDN/attention-style KKT 8x8 score tile on MPS against + ``torch.matmul`` using synthetic tensors only. +- Runtime-validates a small packed fp8 activation x packed fp4 weight matmul + with e8m0 scale bytes on MPS. +- Runtime-validates a small GDN/attention-style staged W/U tile on MPS using + internal RegisterTile helpers plus scalar local.var state. +- Runtime-validates component-scale packed quant matmul (`M=16,N=32,K=64`) and + GDN/attention-style staged KKT/gate/WU probes + (`chunk=16,key_dim=16,value_dim=16`) using synthetic tensors only. 
+- The internal RegisterTile/simdgroup helper is runtime-validated on MPS for + one 8x8 fp32 MMA and remains source-boundary checked for PR #1869 tokens. +- Native Metal fp8/fp4 storage remains fail-closed/source-boundary only. +- Optional timing hooks are disabled by default and require + ``TILELANG_RUN_METAL_SMALL_BENCH=1``, ``TILELANG_RUN_METAL_SCALED_BENCH=1``, + or ``TILELANG_RUN_METAL_COMPONENT_BENCH=1``. +""" + +import os +import subprocess +import sys +import textwrap +import time +from pathlib import Path + +import pytest +import torch + +import tilelang +import tilelang.testing +from tilelang import tvm as tvm +import tilelang.language as T +from tilelang.tileop import metal_gdn, metal_quant, metal_simdgroup as metal_sg + + +_FORBIDDEN_EXTERNAL_TOKENS = ( + "cooperative", + "mpp", + "mpsgraph", + "warpgroup", + "cp.async", + "tcgen", + "tma", + "tl.ptx", + "tl.cuda", +) + + +def _lower_source(func) -> str: + with tvm.transform.PassContext(), tvm.target.Target("metal"): + artifact = tilelang.lower(func, target="metal") + assert artifact.kernel_source is not None + return artifact.kernel_source + + +def _assert_clean_metal_source(src: str) -> None: + lowered = src.lower() + for token in _FORBIDDEN_EXTERNAL_TOKENS: + assert token not in lowered, f"unexpected token {token!r} in generated Metal source:\n{src}" + + +def _make_register_tile_probe(): + @T.prim_func + def register_tile_probe( + A: T.Tensor((8, 8), T.float32), + B: T.Tensor((8, 8), T.float32), + C: T.Tensor((8, 8), T.float32), + ): + with T.Kernel(1, threads=32): + A_rt = metal_sg.alloc_rt(T.float32, 1, 1) + B_rt = metal_sg.alloc_rt(T.float32, 1, 1) + C_rt = metal_sg.alloc_rt(T.float32, 1, 1) + metal_sg.fill_rt(C_rt, T.float32(0.0)) + metal_sg.load_global_to_rt(A_rt, T.float32, A.data, 0, 64, 8) + metal_sg.load_global_to_rt(B_rt, T.float32, B.data, 0, 64, 8) + metal_sg.mma_ab(C_rt, A_rt, B_rt) + metal_sg.store_rt(C_rt, T.float32, C.data, 0, 64, 8) + + return register_tile_probe + + +def _make_row_vector_probe(): + @T.prim_func + def row_vector_probe( + A: T.Tensor((8, 8), T.float32), + B: T.Tensor((8, 8), T.float32), + stats: T.Tensor((8, 2), T.float32), + normalized: T.Tensor((8, 8), T.float32), + ): + with T.Kernel(1, threads=32): + lane = T.get_thread_binding() + A_rt = metal_sg.alloc_rt(T.float32, 1, 1) + Bt_rt = metal_sg.alloc_rt(T.float32, 1, 1, layout=metal_sg.TileLayout.TRANSPOSED) + C_rt = metal_sg.alloc_rt(T.float32, 1, 1) + C_shared = T.alloc_shared((8, 8), T.float32) + row_max_shared = T.alloc_shared((8,), T.float32) + row_sum_shared = T.alloc_shared((8,), T.float32) + row_max = metal_sg.RowVector(row_max_shared, 8, T.float32) + row_sum = metal_sg.RowVector(row_sum_shared, 8, T.float32) + + metal_sg.fill_rt(C_rt, T.float32(0.0)) + metal_sg.load_global_to_rt(A_rt, T.float32, A.data, 0, 64, 8) + metal_sg.load_global_to_rt(Bt_rt, T.float32, B.data, 0, 64, 8, transpose=True) + metal_sg.mma_abt(C_rt, A_rt, Bt_rt) + metal_sg.materialize_rt_to_shared(C_rt, T.float32, C_shared.data, 0, 64, 8) + T.sync_threads() + + metal_sg.row_max(C_shared, row_max, rows=8, cols=8, clear=True) + metal_sg.row_sum(C_shared, row_sum, rows=8, cols=8, clear=True) + metal_sg.div_row(C_shared, row_sum, rows=8, cols=8) + T.sync_threads() + + for linear in T.serial(lane, 8 * 10, step=32): + row = linear // 10 + col = linear - row * 10 + if col == 0: + stats[row, 0] = row_max.values[row] + elif col == 1: + stats[row, 1] = row_sum.values[row] + else: + normalized[row, col - 2] = C_shared[row, col - 2] + + return row_vector_probe + + +def 
_make_deepseek_packed_quant_probe(): + @T.prim_func + def deepseek_packed_quant_probe( + q8: T.Tensor((16,), T.uint8), + q4: T.Tensor((8,), T.uint8), + e8m0_scale: T.Tensor((16,), T.uint8), + out: T.Tensor((16,), T.float32), + ): + with T.Kernel(1, threads=32): + lane = T.get_thread_binding() + for i in T.serial(lane, 16, step=32): + nibble_index = i - (i // 2) * 2 + decoded_fp8 = metal_quant.fp8_e4m3fn_to_float(q8[i]) + decoded_fp4 = metal_quant.fp4_e2m1fn_to_float(q4[i // 2], nibble_index) + scale = metal_quant.e8m0_to_float(e8m0_scale[i]) + out[i] = decoded_fp8 * scale + decoded_fp4 + + return deepseek_packed_quant_probe + + +def _make_deepseek_packed_quant_matmul_probe(): + @T.prim_func + def deepseek_packed_quant_matmul_probe( + q8_act: T.Tensor((8, 16), T.uint8), + q4_weight: T.Tensor((8, 8), T.uint8), + act_scale: T.Tensor((8, 16), T.uint8), + weight_scale: T.Tensor((8, 16), T.uint8), + out: T.Tensor((8, 8), T.float32), + ): + with T.Kernel(1, threads=32): + lane = T.get_thread_binding() + for linear in T.serial(lane, 64, step=32): + m = linear // 8 + n = linear - m * 8 + acc = T.alloc_var(T.float32) + acc = 0.0 + for k in T.serial(16): + nibble_index = k - (k // 2) * 2 + decoded_act = metal_quant.fp8_e4m3fn_to_float(q8_act[m, k]) + decoded_weight = metal_quant.fp4_e2m1fn_to_float(q4_weight[n, k // 2], nibble_index) + scale = metal_quant.e8m0_to_float(act_scale[m, k]) * metal_quant.e8m0_to_float(weight_scale[n, k]) + acc += decoded_act * decoded_weight * scale + out[m, n] = acc + + return deepseek_packed_quant_matmul_probe + + +def _make_deepseek_component_quant_matmul_probe(): + @T.prim_func + def deepseek_component_quant_matmul_probe( + q8_act: T.Tensor((16, 64), T.uint8), + q4_weight: T.Tensor((32, 32), T.uint8), + act_scale: T.Tensor((16, 64), T.uint8), + weight_scale: T.Tensor((32, 64), T.uint8), + out: T.Tensor((16, 32), T.float32), + ): + with T.Kernel(1, threads=32): + lane = T.get_thread_binding() + for linear in T.serial(lane, 16 * 32, step=32): + m = linear // 32 + n = linear - m * 32 + acc = T.alloc_var(T.float32) + acc = 0.0 + for k in T.serial(64): + nibble_index = k - (k // 2) * 2 + decoded_act = metal_quant.fp8_e4m3fn_to_float(q8_act[m, k]) + decoded_weight = metal_quant.fp4_e2m1fn_to_float(q4_weight[n, k // 2], nibble_index) + scale = metal_quant.e8m0_to_float(act_scale[m, k]) * metal_quant.e8m0_to_float(weight_scale[n, k]) + acc += decoded_act * decoded_weight * scale + out[m, n] = acc + + return deepseek_component_quant_matmul_probe + + +def _make_flashqla_gdn_kkt_probe(): + @T.prim_func + def flashqla_gdn_kkt_probe( + row_k: T.Tensor((8, 8), T.float32), + col_k: T.Tensor((8, 8), T.float32), + scores: T.Tensor((8, 8), T.float32), + ): + with T.Kernel(1, threads=32): + lane = T.get_thread_binding() + kkt_bias = T.alloc_var(T.float32) + kkt_bias = 0.0 + row_shared = T.alloc_shared((8, 8), T.float32) + col_shared = T.alloc_shared((8, 8), T.float32) + score_shared = T.alloc_shared((8, 8), T.float32) + for idx in T.serial(lane, 64, step=32): + r = idx // 8 + c = idx - r * 8 + row_shared[r, c] = row_k[r, c] + col_shared[r, c] = col_k[r, c] + T.sync_threads() + metal_gdn.kkt_score_tile(row_shared.data, col_shared.data, score_shared.data, block=8, key_dim=8) + T.sync_threads() + for idx in T.serial(lane, 64, step=32): + r = idx // 8 + c = idx - r * 8 + scores[r, c] = score_shared[r, c] + kkt_bias + + return flashqla_gdn_kkt_probe + + +def _make_flashqla_gdn_wu_probe(): + @T.prim_func + def flashqla_gdn_wu_probe( + a: T.Tensor((8, 8), T.float32), + k: T.Tensor((8, 8), 
T.float32), + v: T.Tensor((8, 8), T.float32), + beta: T.Tensor((8,), T.float32), + g_cum: T.Tensor((8,), T.float32), + w: T.Tensor((8, 8), T.float32), + u: T.Tensor((8, 8), T.float32), + ): + with T.Kernel(1, threads=32): + lane = T.get_thread_binding() + gate_state = T.alloc_var(T.float32) + gate_state = 1.0 + a_shared = T.alloc_shared((8, 8), T.float32) + k_scaled_shared = T.alloc_shared((8, 8), T.float32) + v_scaled_shared = T.alloc_shared((8, 8), T.float32) + w_acc = metal_sg.alloc_rt(T.float32, 1, 1) + u_acc = metal_sg.alloc_rt(T.float32, 1, 1) + metal_sg.fill_rt(w_acc, T.float32(0.0)) + metal_sg.fill_rt(u_acc, T.float32(0.0)) + for idx in T.serial(lane, 64, step=32): + r = idx // 8 + c = idx - r * 8 + a_shared[r, c] = a[r, c] + k_scaled_shared[r, c] = k[r, c] * beta[r] * T.exp(g_cum[r]) * gate_state + v_scaled_shared[r, c] = v[r, c] * beta[r] * gate_state + T.sync_threads() + metal_gdn.wu_score_tiles(a_shared.data, k_scaled_shared.data, v_scaled_shared.data, w_acc, u_acc, block=8) + metal_sg.store_rt(w_acc, T.float32, w.data, 0, 64, 8) + metal_sg.store_rt(u_acc, T.float32, u.data, 0, 64, 8) + + return flashqla_gdn_wu_probe + + +def _make_flashqla_gdn_component_probe(): + @T.prim_func + def flashqla_gdn_component_probe( + k: T.Tensor((16, 16), T.float32), + v: T.Tensor((16, 16), T.float32), + beta: T.Tensor((16,), T.float32), + g_cum: T.Tensor((16,), T.float32), + a_pre: T.Tensor((16, 16), T.float32), + w: T.Tensor((16, 16), T.float32), + u: T.Tensor((16, 16), T.float32), + ): + with T.Kernel(1, threads=32): + lane = T.get_thread_binding() + gate_state = T.alloc_var(T.float32) + gate_state = 1.0 + row_shared = T.alloc_shared((8, 16), T.float32) + col_shared = T.alloc_shared((8, 16), T.float32) + score_shared = T.alloc_shared((8, 8), T.float32) + a_shared = T.alloc_shared((16, 16), T.float32) + k_scaled_shared = T.alloc_shared((16, 16), T.float32) + v_scaled_shared = T.alloc_shared((16, 16), T.float32) + w_acc = metal_sg.alloc_rt(T.float32, 1, 1) + u_acc = metal_sg.alloc_rt(T.float32, 1, 1) + + for idx in T.serial(lane, 16 * 16, step=32): + r = idx // 16 + c = idx - r * 16 + k_scaled_shared[r, c] = k[r, c] * beta[r] * T.exp(g_cum[r]) * gate_state + v_scaled_shared[r, c] = v[r, c] * beta[r] * gate_state + a_shared[r, c] = 0.0 + T.sync_threads() + + for row_block in T.unroll(2, explicit=True): + for col_block in T.unroll(2, explicit=True): + for idx in T.serial(lane, 8 * 16, step=32): + r = idx // 16 + c = idx - r * 16 + row_shared[r, c] = k[row_block * 8 + r, c] + col_shared[r, c] = k[col_block * 8 + r, c] + T.sync_threads() + metal_gdn.kkt_score_tile_accum( + row_shared.data, + col_shared.data, + score_shared.data, + block=8, + key_dim=16, + key_offset=0, + clear=True, + ) + metal_gdn.kkt_score_tile_accum( + row_shared.data, + col_shared.data, + score_shared.data, + block=8, + key_dim=16, + key_offset=8, + clear=False, + ) + T.sync_threads() + for idx in T.serial(lane, 8 * 8, step=32): + local_row = idx // 8 + local_col = idx - local_row * 8 + c = row_block * 8 + local_row + d = col_block * 8 + local_col + gated = T.alloc_var(T.float32) + gated = 0.0 + if d < c: + gated = score_shared[local_row, local_col] * T.exp(g_cum[c] - g_cum[d]) * gate_state + a_pre[c, d] = gated + a_shared[c, d] = gated + T.sync_threads() + + for row_block in T.unroll(2, explicit=True): + for col_block in T.unroll(2, explicit=True): + metal_sg.fill_rt(w_acc, T.float32(0.0)) + metal_sg.fill_rt(u_acc, T.float32(0.0)) + for d_block in T.unroll(2, explicit=True): + metal_gdn.wu_score_tiles_strided( + a_shared.data, 
+ k_scaled_shared.data, + v_scaled_shared.data, + w_acc, + u_acc, + a_offset=row_block * 8 * 16 + d_block * 8, + k_offset=d_block * 8 * 16 + col_block * 8, + v_offset=d_block * 8 * 16 + col_block * 8, + a_stride=16, + kv_stride=16, + block=8, + ) + metal_sg.store_rt( + w_acc, + T.float32, + w.data, + row_block * 8 * 16 + col_block * 8, + 16 * 16, + 16, + ) + metal_sg.store_rt( + u_acc, + T.float32, + u.data, + row_block * 8 * 16 + col_block * 8, + 16 * 16, + 16, + ) + + return flashqla_gdn_component_probe + + +def test_internal_register_tile_helper_emits_pr_simdgroup_tokens_only(): + src = _lower_source(_make_register_tile_probe()) + _assert_clean_metal_source(src) + assert "simdgroup_multiply_accumulate" in src + assert "simdgroup_load" in src + assert "simdgroup_store" in src + assert "simdgroup_float8x8" in src + assert "C_tmp" not in src + + +def test_row_vector_remains_materialized_not_scalar_indexed_simdgroup_fragment(): + src = _lower_source(_make_row_vector_probe()) + _assert_clean_metal_source(src) + assert "simdgroup_multiply_accumulate" in src + assert "threadgroup float" in src + assert "simdgroup_float8x8" in src + # RowVector reductions are over materialized threadgroup storage, not scalar + # indexing into opaque simdgroup_matrix fragments. + assert "rt_fragment[0][" not in src + assert "rt_fragment_1[0][" not in src + assert "rt_fragment_2[0][" not in src + + +def test_no_public_register_tile_or_row_vector_language_aliases(): + assert not hasattr(T, "rt") + assert not hasattr(T, "rv") + assert not hasattr(T, "RegisterTile") + assert not hasattr(T, "RowVector") + + +def test_deepseek_packed_quant_probe_uses_uint8_boundary_not_native_fp8_fp4_storage(): + src = _lower_source(_make_deepseek_packed_quant_probe()) + _assert_clean_metal_source(src) + lowered = src.lower() + assert "device uchar" in lowered + assert "float8" not in lowered + assert "float4" not in lowered + assert "simdgroup_multiply_accumulate" not in lowered + assert metal_quant.use_large_simdgroup_tile(64, 512, mixed_fp4_weight=True) + assert not metal_quant.use_large_simdgroup_tile(64, 256, mixed_fp4_weight=True) + + +def test_flashqla_gdn_kkt_probe_combines_local_var_state_and_simdgroup_boundary(): + src = _lower_source(_make_flashqla_gdn_kkt_probe()) + _assert_clean_metal_source(src) + assert "simdgroup_multiply_accumulate" in src + assert "simdgroup_load" in src + assert "simdgroup_store" in src + assert "threadgroup float" in src + assert "local.var" not in src + assert "float kkt_bias = 0.000000e+00f;" in src + + +def test_scaled_packed_quant_and_gdn_probes_source_boundary_tokens(): + deepseek_src = _lower_source(_make_deepseek_packed_quant_matmul_probe()) + _assert_clean_metal_source(deepseek_src) + deepseek_lowered = deepseek_src.lower() + assert deepseek_lowered.count("device uchar") >= 4 + assert "float8" not in deepseek_lowered + assert "float4" not in deepseek_lowered + assert "simdgroup_multiply_accumulate" not in deepseek_lowered + + gdn_src = _lower_source(_make_flashqla_gdn_wu_probe()) + _assert_clean_metal_source(gdn_src) + assert "simdgroup_multiply_accumulate" in gdn_src + assert gdn_src.count("simdgroup_load") >= 3 + assert gdn_src.count("simdgroup_store") >= 2 + assert "threadgroup float" in gdn_src + assert "local.var" not in gdn_src + assert "float gate_state" in gdn_src + assert "gate_state = 1.000000e+00f;" in gdn_src + + +def test_component_packed_quant_and_gdn_probes_source_boundary_tokens(): + deepseek_src = _lower_source(_make_deepseek_component_quant_matmul_probe()) + 
_assert_clean_metal_source(deepseek_src) + deepseek_lowered = deepseek_src.lower() + assert deepseek_lowered.count("device uchar") >= 4 + assert "float8" not in deepseek_lowered + assert "float4" not in deepseek_lowered + assert "simdgroup_multiply_accumulate" not in deepseek_lowered + + gdn_src = _lower_source(_make_flashqla_gdn_component_probe()) + _assert_clean_metal_source(gdn_src) + assert gdn_src.count("simdgroup_multiply_accumulate") >= 12 + assert gdn_src.count("simdgroup_load") >= 18 + assert gdn_src.count("simdgroup_store") >= 12 + assert "threadgroup float" in gdn_src + assert "local.var" not in gdn_src + assert "float gate_state" in gdn_src + assert "gate_state = 1.000000e+00f;" in gdn_src + + +def _run_native_dtype_probe(tmp_path: Path, dtype_name: str) -> subprocess.CompletedProcess[str]: + script = tmp_path / f"probe_{dtype_name}.py" + script.write_text( + textwrap.dedent( + f""" + import tilelang + import tilelang.language as T + + @T.prim_func + def bad_kernel(A: T.Tensor((32,), T.float32), B: T.Tensor((32,), T.{dtype_name})): + with T.Kernel(1, threads=32) as bx: + for i in T.Parallel(32): + B[bx * 32 + i] = A[bx * 32 + i] + + tilelang.compile(bad_kernel, target="metal") + """ + ) + ) + env = os.environ.copy() + repo_root = str(Path.cwd()) + env["PYTHONPATH"] = repo_root + (os.pathsep + env["PYTHONPATH"] if env.get("PYTHONPATH") else "") + return subprocess.run( + [sys.executable, str(script)], + cwd=repo_root, + env=env, + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + timeout=30, + check=False, + ) + + +@pytest.mark.parametrize("dtype_name", ["float8_e4m3fn", "float4_e2m1fn"]) +def test_native_fp8_fp4_metal_storage_fail_closed_in_subprocess(tmp_path, dtype_name): + result = _run_native_dtype_probe(tmp_path, dtype_name) + combined = result.stdout + result.stderr + assert result.returncode != 0 + assert f"Cannot convert type {dtype_name} to Metal type" in combined + + +def _fp8_e4m3fn_to_float_cpu(bits: int) -> float: + abs_bits = bits & 0x7F + sign = (bits >> 7) & 1 + exp_bits = (bits >> 3) & 0xF + mant_bits = bits & 0x7 + if exp_bits == 0: + value = mant_bits / 512.0 + else: + value = (1.0 + mant_bits / 8.0) * (2.0 ** (exp_bits - 7)) + if abs_bits == 0x7F: + value = 0.0 + return -value if sign else value + + +def _fp4_e2m1fn_to_float_cpu(bits: int, nibble_index: int) -> float: + nibble = (bits >> (nibble_index * 4)) & 0xF + sign = (nibble >> 3) & 1 + mag = nibble & 0x7 + value = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)[mag] + return -value if sign else value + + +def _e8m0_to_float_cpu(bits: int) -> float: + return 0.0 if bits == 255 else 2.0 ** (bits - 127) + + +def _deepseek_synthetic_inputs(): + q8 = torch.tensor( + [0, 1, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 120], + dtype=torch.uint8, + ) + q4 = torch.tensor([0x10, 0x32, 0x54, 0x76, 0x98, 0xBA, 0xDC, 0xFE], dtype=torch.uint8) + e8m0_scale = torch.tensor( + [127, 128, 126, 129, 125, 130, 124, 131, 127, 128, 126, 129, 125, 130, 124, 255], + dtype=torch.uint8, + ) + return q8, q4, e8m0_scale + + +def _deepseek_decode_ref(q8: torch.Tensor, q4: torch.Tensor, e8m0_scale: torch.Tensor) -> torch.Tensor: + values = [] + for i in range(16): + decoded_fp8 = _fp8_e4m3fn_to_float_cpu(int(q8[i])) + decoded_fp4 = _fp4_e2m1fn_to_float_cpu(int(q4[i // 2]), i % 2) + scale = _e8m0_to_float_cpu(int(e8m0_scale[i])) + values.append(decoded_fp8 * scale + decoded_fp4) + return torch.tensor(values, dtype=torch.float32) + + +def _deepseek_matmul_synthetic_inputs(): + q8_act = ((torch.arange(128, 
dtype=torch.int16).reshape(8, 16) * 3 + 5) % 121).to(torch.uint8) + q4_weight = ((torch.arange(64, dtype=torch.int16).reshape(8, 8) * 5 + 1) % 256).to(torch.uint8) + act_scale = (126 + (torch.arange(128, dtype=torch.int16).reshape(8, 16) % 5)).to(torch.uint8) + weight_scale = (125 + ((torch.arange(128, dtype=torch.int16).reshape(8, 16) * 2) % 5)).to(torch.uint8) + return q8_act, q4_weight, act_scale, weight_scale + + +def _deepseek_matmul_ref( + q8_act: torch.Tensor, + q4_weight: torch.Tensor, + act_scale: torch.Tensor, + weight_scale: torch.Tensor, +) -> torch.Tensor: + out = torch.empty((8, 8), dtype=torch.float32) + for m in range(8): + for n in range(8): + acc = 0.0 + for k in range(16): + decoded_act = _fp8_e4m3fn_to_float_cpu(int(q8_act[m, k])) + decoded_weight = _fp4_e2m1fn_to_float_cpu(int(q4_weight[n, k // 2]), k % 2) + scale = _e8m0_to_float_cpu(int(act_scale[m, k])) * _e8m0_to_float_cpu(int(weight_scale[n, k])) + acc += decoded_act * decoded_weight * scale + out[m, n] = acc + return out + + +def _deepseek_component_matmul_synthetic_inputs(): + q8_act = ((torch.arange(16 * 64, dtype=torch.int16).reshape(16, 64) * 7 + 11) % 121).to(torch.uint8) + q4_weight = ((torch.arange(32 * 32, dtype=torch.int16).reshape(32, 32) * 13 + 3) % 256).to(torch.uint8) + act_scale = (124 + ((torch.arange(16 * 64, dtype=torch.int16).reshape(16, 64) * 3) % 7)).to(torch.uint8) + weight_scale = (123 + ((torch.arange(32 * 64, dtype=torch.int16).reshape(32, 64) * 5) % 7)).to(torch.uint8) + return q8_act, q4_weight, act_scale, weight_scale + + +def _deepseek_component_matmul_ref( + q8_act: torch.Tensor, + q4_weight: torch.Tensor, + act_scale: torch.Tensor, + weight_scale: torch.Tensor, +) -> torch.Tensor: + m_size, k_size = q8_act.shape + n_size = q4_weight.shape[0] + out = torch.empty((m_size, n_size), dtype=torch.float32) + for m in range(m_size): + for n in range(n_size): + acc = 0.0 + for k in range(k_size): + decoded_act = _fp8_e4m3fn_to_float_cpu(int(q8_act[m, k])) + decoded_weight = _fp4_e2m1fn_to_float_cpu(int(q4_weight[n, k // 2]), k % 2) + scale = _e8m0_to_float_cpu(int(act_scale[m, k])) * _e8m0_to_float_cpu(int(weight_scale[n, k])) + acc += decoded_act * decoded_weight * scale + out[m, n] = acc + return out + + +def _flashqla_gdn_wu_synthetic_inputs(): + a = torch.tril((torch.arange(64, dtype=torch.float32).reshape(8, 8) - 9.0) / 23.0) + k = (torch.arange(64, dtype=torch.float32).reshape(8, 8).flip(0) - 5.0) / 17.0 + v = (torch.arange(64, dtype=torch.float32).reshape(8, 8).flip(1) + 3.0) / 19.0 + beta = torch.linspace(0.25, 1.125, 8, dtype=torch.float32) + g_cum = torch.linspace(-0.375, 0.5, 8, dtype=torch.float32) + return a, k, v, beta, g_cum + + +def _flashqla_gdn_wu_ref( + a: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + g_cum: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + k_scaled = k * (beta * torch.exp(g_cum)).unsqueeze(1) + v_scaled = v * beta.unsqueeze(1) + return a @ k_scaled, a @ v_scaled + + +def _flashqla_gdn_component_synthetic_inputs(): + k = (torch.arange(16 * 16, dtype=torch.float32).reshape(16, 16) - 31.0) / 37.0 + v = (torch.arange(16 * 16, dtype=torch.float32).reshape(16, 16).flip(1) + 7.0) / 41.0 + beta = torch.linspace(0.125, 1.0625, 16, dtype=torch.float32) + g_cum = torch.linspace(-0.5, 0.625, 16, dtype=torch.float32) + return k, v, beta, g_cum + + +def _flashqla_gdn_component_ref( + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + g_cum: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + scores = k @ 
k.T + row_idx = torch.arange(16).view(16, 1) + col_idx = torch.arange(16).view(1, 16) + causal = col_idx < row_idx + gated = scores * torch.exp(g_cum.view(16, 1) - g_cum.view(1, 16)) + a_pre = torch.where(causal, gated, torch.zeros_like(gated)) + k_scaled = k * (beta * torch.exp(g_cum)).unsqueeze(1) + v_scaled = v * beta.unsqueeze(1) + return a_pre, a_pre @ k_scaled, a_pre @ v_scaled + + +@tilelang.testing.requires_metal +def test_deepseek_packed_decode_runtime_mps_matches_cpu_reference(): + kernel = tilelang.compile(_make_deepseek_packed_quant_probe(), target="metal") + q8, q4, e8m0_scale = _deepseek_synthetic_inputs() + out = torch.empty(16, dtype=torch.float32, device="mps") + + kernel(q8.to("mps"), q4.to("mps"), e8m0_scale.to("mps"), out) + torch.mps.synchronize() + + ref = _deepseek_decode_ref(q8, q4, e8m0_scale) + assert torch.allclose(out.cpu(), ref, atol=1e-6, rtol=1e-6) + + +@tilelang.testing.requires_metal +def test_deepseek_packed_quant_matmul_runtime_mps_matches_cpu_reference(): + kernel = tilelang.compile(_make_deepseek_packed_quant_matmul_probe(), target="metal") + q8_act, q4_weight, act_scale, weight_scale = _deepseek_matmul_synthetic_inputs() + out = torch.empty((8, 8), dtype=torch.float32, device="mps") + + kernel(q8_act.to("mps"), q4_weight.to("mps"), act_scale.to("mps"), weight_scale.to("mps"), out) + torch.mps.synchronize() + + ref = _deepseek_matmul_ref(q8_act, q4_weight, act_scale, weight_scale) + assert torch.allclose(out.cpu(), ref, atol=1e-4, rtol=1e-5) + + +@tilelang.testing.requires_metal +def test_deepseek_component_quant_matmul_runtime_mps_matches_cpu_reference(): + kernel = tilelang.compile(_make_deepseek_component_quant_matmul_probe(), target="metal") + q8_act, q4_weight, act_scale, weight_scale = _deepseek_component_matmul_synthetic_inputs() + out = torch.empty((16, 32), dtype=torch.float32, device="mps") + + kernel(q8_act.to("mps"), q4_weight.to("mps"), act_scale.to("mps"), weight_scale.to("mps"), out) + torch.mps.synchronize() + + ref = _deepseek_component_matmul_ref(q8_act, q4_weight, act_scale, weight_scale) + assert torch.allclose(out.cpu(), ref, atol=1e-3, rtol=1e-5) + + +@tilelang.testing.requires_metal +def test_flashqla_gdn_kkt_runtime_mps_matches_torch_reference(): + kernel = tilelang.compile(_make_flashqla_gdn_kkt_probe(), target="metal") + row_k = torch.arange(64, dtype=torch.float32).reshape(8, 8) / 17.0 + col_k = (torch.arange(64, dtype=torch.float32).reshape(8, 8).flip(1) - 10.0) / 19.0 + scores = torch.empty((8, 8), dtype=torch.float32, device="mps") + + kernel(row_k.to("mps"), col_k.to("mps"), scores) + torch.mps.synchronize() + + ref = row_k @ col_k.T + assert torch.allclose(scores.cpu(), ref, atol=1e-5, rtol=1e-5) + + +@tilelang.testing.requires_metal +def test_register_tile_runtime_mps_matches_torch_reference(): + kernel = tilelang.compile(_make_register_tile_probe(), target="metal") + a = torch.arange(64, dtype=torch.float32).reshape(8, 8) / 13.0 + b = (torch.arange(64, dtype=torch.float32).reshape(8, 8) - 20.0) / 11.0 + c = torch.empty((8, 8), dtype=torch.float32, device="mps") + + kernel(a.to("mps"), b.to("mps"), c) + torch.mps.synchronize() + + assert torch.allclose(c.cpu(), a @ b, atol=1e-6, rtol=1e-6) + + +@tilelang.testing.requires_metal +def test_flashqla_gdn_staged_wu_runtime_mps_matches_torch_reference(): + kernel = tilelang.compile(_make_flashqla_gdn_wu_probe(), target="metal") + a, k, v, beta, g_cum = _flashqla_gdn_wu_synthetic_inputs() + w = torch.empty((8, 8), dtype=torch.float32, device="mps") + u = torch.empty((8, 8), 
dtype=torch.float32, device="mps") + + kernel(a.to("mps"), k.to("mps"), v.to("mps"), beta.to("mps"), g_cum.to("mps"), w, u) + torch.mps.synchronize() + + ref_w, ref_u = _flashqla_gdn_wu_ref(a, k, v, beta, g_cum) + assert torch.allclose(w.cpu(), ref_w, atol=1e-5, rtol=1e-5) + assert torch.allclose(u.cpu(), ref_u, atol=1e-5, rtol=1e-5) + + +@tilelang.testing.requires_metal +def test_flashqla_gdn_component_runtime_mps_matches_torch_reference(): + kernel = tilelang.compile(_make_flashqla_gdn_component_probe(), target="metal") + k, v, beta, g_cum = _flashqla_gdn_component_synthetic_inputs() + a_pre = torch.empty((16, 16), dtype=torch.float32, device="mps") + w = torch.empty((16, 16), dtype=torch.float32, device="mps") + u = torch.empty((16, 16), dtype=torch.float32, device="mps") + + kernel(k.to("mps"), v.to("mps"), beta.to("mps"), g_cum.to("mps"), a_pre, w, u) + torch.mps.synchronize() + + ref_a, ref_w, ref_u = _flashqla_gdn_component_ref(k, v, beta, g_cum) + assert torch.allclose(a_pre.cpu(), ref_a, atol=1e-4, rtol=1e-5) + assert torch.allclose(w.cpu(), ref_w, atol=1e-4, rtol=1e-5) + assert torch.allclose(u.cpu(), ref_u, atol=1e-4, rtol=1e-5) + + +@tilelang.testing.requires_metal +def test_small_synthetic_runtime_benchmarks_opt_in(): + if os.environ.get("TILELANG_RUN_METAL_SMALL_BENCH") != "1": + pytest.skip("set TILELANG_RUN_METAL_SMALL_BENCH=1 to run small Metal benchmark hooks") + + deepseek_kernel = tilelang.compile(_make_deepseek_packed_quant_probe(), target="metal") + gdn_kernel = tilelang.compile(_make_flashqla_gdn_kkt_probe(), target="metal") + q8, q4, e8m0_scale = _deepseek_synthetic_inputs() + q8_mps, q4_mps, e8m0_mps = q8.to("mps"), q4.to("mps"), e8m0_scale.to("mps") + decode_out = torch.empty(16, dtype=torch.float32, device="mps") + # Prepare MPS inputs via CPU expressions; mixing torch MPS arithmetic with + # TVM's Metal command encoder in the same tiny benchmark can leave an active + # encoder and trip Metal's command-buffer assertion. + row_k = (torch.arange(64, dtype=torch.float32).reshape(8, 8) / 17.0).to("mps") + col_k = ((torch.arange(64, dtype=torch.float32).reshape(8, 8).flip(1) - 10.0) / 19.0).to("mps") + scores = torch.empty((8, 8), dtype=torch.float32, device="mps") + torch.mps.synchronize() + + def bench(name, fn, iterations: int = 20): + # The current TVM/Metal runtime can trip Metal's single command-encoder + # assertion if tiny kernels are launched back-to-back without a flush. + # Keep this hook safe/rerunnable by timing synchronized iterations. 
+ for _ in range(3): + fn() + torch.mps.synchronize() + start = time.perf_counter() + for _ in range(iterations): + fn() + torch.mps.synchronize() + elapsed_ms = (time.perf_counter() - start) * 1000.0 / iterations + print(f"metal_small_bench {name}: {elapsed_ms:.4f} ms/iter over {iterations} iterations") + assert elapsed_ms >= 0.0 + + bench("deepseek_packed_decode_16", lambda: deepseek_kernel(q8_mps, q4_mps, e8m0_mps, decode_out)) + bench("flashqla_gdn_kkt_8x8", lambda: gdn_kernel(row_k, col_k, scores)) + + +@tilelang.testing.requires_metal +def test_scaled_synthetic_runtime_benchmarks_opt_in(): + if os.environ.get("TILELANG_RUN_METAL_SCALED_BENCH") != "1": + pytest.skip("set TILELANG_RUN_METAL_SCALED_BENCH=1 to run scaled Metal benchmark hooks") + + deepseek_kernel = tilelang.compile(_make_deepseek_packed_quant_matmul_probe(), target="metal") + gdn_kernel = tilelang.compile(_make_flashqla_gdn_wu_probe(), target="metal") + q8_act, q4_weight, act_scale, weight_scale = _deepseek_matmul_synthetic_inputs() + q8_mps = q8_act.to("mps") + q4_mps = q4_weight.to("mps") + act_scale_mps = act_scale.to("mps") + weight_scale_mps = weight_scale.to("mps") + matmul_out = torch.empty((8, 8), dtype=torch.float32, device="mps") + a, k, v, beta, g_cum = _flashqla_gdn_wu_synthetic_inputs() + a_mps, k_mps, v_mps = a.to("mps"), k.to("mps"), v.to("mps") + beta_mps, g_cum_mps = beta.to("mps"), g_cum.to("mps") + w = torch.empty((8, 8), dtype=torch.float32, device="mps") + u = torch.empty((8, 8), dtype=torch.float32, device="mps") + torch.mps.synchronize() + + def bench(name, fn, iterations: int = 20): + for _ in range(3): + fn() + torch.mps.synchronize() + start = time.perf_counter() + for _ in range(iterations): + fn() + torch.mps.synchronize() + elapsed_ms = (time.perf_counter() - start) * 1000.0 / iterations + print(f"metal_scaled_bench {name}: {elapsed_ms:.4f} ms/iter over {iterations} iterations") + assert elapsed_ms >= 0.0 + + bench( + "deepseek_packed_quant_matmul_m8n8k16", + lambda: deepseek_kernel(q8_mps, q4_mps, act_scale_mps, weight_scale_mps, matmul_out), + ) + bench("flashqla_gdn_wu_8x8", lambda: gdn_kernel(a_mps, k_mps, v_mps, beta_mps, g_cum_mps, w, u)) + + +@tilelang.testing.requires_metal +def test_component_synthetic_runtime_benchmarks_opt_in(): + if os.environ.get("TILELANG_RUN_METAL_COMPONENT_BENCH") != "1": + pytest.skip("set TILELANG_RUN_METAL_COMPONENT_BENCH=1 to run component Metal benchmark hooks") + + deepseek_kernel = tilelang.compile(_make_deepseek_component_quant_matmul_probe(), target="metal") + gdn_kernel = tilelang.compile(_make_flashqla_gdn_component_probe(), target="metal") + q8_act, q4_weight, act_scale, weight_scale = _deepseek_component_matmul_synthetic_inputs() + q8_mps = q8_act.to("mps") + q4_mps = q4_weight.to("mps") + act_scale_mps = act_scale.to("mps") + weight_scale_mps = weight_scale.to("mps") + matmul_out = torch.empty((16, 32), dtype=torch.float32, device="mps") + k, v, beta, g_cum = _flashqla_gdn_component_synthetic_inputs() + k_mps, v_mps = k.to("mps"), v.to("mps") + beta_mps, g_cum_mps = beta.to("mps"), g_cum.to("mps") + a_pre = torch.empty((16, 16), dtype=torch.float32, device="mps") + w = torch.empty((16, 16), dtype=torch.float32, device="mps") + u = torch.empty((16, 16), dtype=torch.float32, device="mps") + torch.mps.synchronize() + + def bench(name, fn, iterations: int = 10): + for _ in range(2): + fn() + torch.mps.synchronize() + start = time.perf_counter() + for _ in range(iterations): + fn() + torch.mps.synchronize() + elapsed_ms = (time.perf_counter() - 
start) * 1000.0 / iterations + print(f"metal_component_bench {name}: {elapsed_ms:.4f} ms/iter over {iterations} iterations") + assert elapsed_ms >= 0.0 + + bench( + "deepseek_packed_quant_matmul_m16n32k64", + lambda: deepseek_kernel(q8_mps, q4_mps, act_scale_mps, weight_scale_mps, matmul_out), + ) + bench("flashqla_gdn_component_chunk16_k16_v16", lambda: gdn_kernel(k_mps, v_mps, beta_mps, g_cum_mps, a_pre, w, u)) diff --git a/testing/python/metal/test_metal_local_var.py b/testing/python/metal/test_metal_local_var.py new file mode 100644 index 0000000000..9b1bce9b82 --- /dev/null +++ b/testing/python/metal/test_metal_local_var.py @@ -0,0 +1,62 @@ +"""Focused Metal support tests for local.var scalar code generation.""" + +import re + +import torch + +import tilelang +from tilelang import tvm as tvm +import tilelang.language as T +import tilelang.testing + + +def _make_local_var_func(): + @T.prim_func + def local_var_kernel(A: T.Tensor((2,), T.int32)): + with T.Kernel(1, threads=1) as _: + x = T.alloc_var(T.int32, init=3) + y = T.alloc_var(T.int32) + y = x + 4 + A[0] = x + A[1] = y + + return local_var_kernel + + +def test_metal_local_var_scalar_codegen_uses_thread_scalars(): + with tvm.transform.PassContext(), tvm.target.Target("metal"): + artifact = tilelang.lower(_make_local_var_func(), target="metal") + + src = artifact.kernel_source + assert src is not None + assert "kernel void" in src + + # local.var should lower to scalar declarations/stores rather than arrays or + # an unsupported storage scope. + assert len(re.findall(r"\bint\s+\w+\s*=\s*0;", src)) >= 2, src + assert re.search(r"\w+\s*=\s*3;", src), src + assert re.search(r"\w+\s*=\s*\(\w+ \+ 4\);", src), src + assert "local.var" not in src + assert "thread int" not in src + + +def test_metal_local_var_codegen_has_scalar_loads_for_outputs(): + with tvm.transform.PassContext(), tvm.target.Target("metal"): + artifact = tilelang.lower(_make_local_var_func(), target="metal") + + src = artifact.kernel_source + assert src is not None + output_lines = [line.strip() for line in src.splitlines() if line.strip().startswith("A[")] + assert len(output_lines) == 2, src + assert all("[0]" not in line.split("=", 1)[1] for line in output_lines), output_lines + + +@tilelang.testing.requires_metal +def test_metal_local_var_runtime_scalar_load_store(): + kernel = tilelang.compile(_make_local_var_func(), target="metal") + out = torch.empty(2, dtype=torch.int32, device="mps") + + kernel(out) + torch.mps.synchronize() + + assert out.cpu().tolist() == [3, 7] diff --git a/testing/python/metal/test_metal_simdgroup_store.py b/testing/python/metal/test_metal_simdgroup_store.py new file mode 100644 index 0000000000..ea8429589c --- /dev/null +++ b/testing/python/metal/test_metal_simdgroup_store.py @@ -0,0 +1,133 @@ +"""Test Metal simdgroup register GEMM with direct simdgroup_store to device memory. + +These tests verify the simdgroup register accumulation path where C is allocated +in metal.simdgroup scope. This eliminates C_simd load/store round-trips through +shared memory on each K iteration. The final T.copy(C_local, C[...]) is lowered +to simdgroup_store directly to device memory via LowerSIMDGroupStore. 
+""" + +import tilelang +from tilelang import tvm as tvm +import tilelang.testing +import tilelang.language as T +import torch + + +def _make_simdgroup_gemm_func(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): + @T.prim_func + def gemm_kernel( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), accum_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared") + B_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared") + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0): + T.copy(A[by * block_M, ko * block_K], A_shared) + T.copy(B[ko * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return gemm_kernel + + +matmul_simdgroup = tilelang.jit(_make_simdgroup_gemm_func) + + +def assert_simdgroup_store_correctness(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32, atol=1e-2): + kernel = matmul_simdgroup(M, N, K, block_M, block_N, block_K, dtype=dtype, accum_dtype=accum_dtype) + + torch_dtype = dtype.as_torch() + torch_accum_dtype = accum_dtype.as_torch() + a = torch.randn(M, K, dtype=torch_dtype, device="mps") + b = torch.randn(K, N, dtype=torch_dtype, device="mps") + c = torch.zeros(M, N, dtype=torch_accum_dtype, device="mps") + + kernel(a, b, c) + + ref = a.to(torch_accum_dtype) @ b.to(torch_accum_dtype) + assert torch.allclose(ref, c, atol=atol), ( + f"Result mismatch for M={M}, N={N}, K={K}, " + f"block=({block_M},{block_N},{block_K}), dtype={dtype}\n" + f"max diff: {(ref - c).abs().max().item()}" + ) + + +def assert_simdgroup_store_codegen(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): + func = _make_simdgroup_gemm_func(M, N, K, block_M, block_N, block_K, dtype=dtype, accum_dtype=accum_dtype) + with tvm.transform.PassContext(), tvm.target.Target("metal"): + artifact = tilelang.lower(func, target="metal") + + src = artifact.kernel_source + assert src is not None + assert "kernel void" in src + assert "simdgroup_multiply_accumulate" in src + assert "make_filled_simdgroup_matrix" in src + + assert "simdgroup_float8x8" in src or "simdgroup_half8x8" in src, "Expected simdgroup_float8x8 or simdgroup_half8x8 for C accumulator" + + store_to_device = src.count("simdgroup_store(C_local") + assert store_to_device > 0, "Expected simdgroup_store of C_local to device memory" + + load_c_from_shared = [line for line in src.split("\n") if "simdgroup_load" in line and "C_local" in line] + assert len(load_c_from_shared) == 0, f"C_local should not be loaded from shared memory, but found: {load_c_from_shared}" + + +# --- Codegen tests (cross-platform) --- + + +def test_codegen_square_small(): + assert_simdgroup_store_codegen(64, 64, 64, 16, 16, 16) + + +def test_codegen_square_large(): + assert_simdgroup_store_codegen(128, 128, 128, 32, 32, 32) + + +def test_codegen_non_square(): + assert_simdgroup_store_codegen(128, 128, 128, 32, 64, 16) + + +def test_codegen_float32_accum(): + assert_simdgroup_store_codegen(64, 64, 64, 16, 16, 16, dtype=T.float32, accum_dtype=T.float32) + + +# --- Correctness tests (require Metal hardware) --- + + +@tilelang.testing.requires_metal +def test_correctness_16x16x16(): + assert_simdgroup_store_correctness(128, 128, 128, 16, 16, 16) + + +@tilelang.testing.requires_metal +def 
test_correctness_32x32x32(): + assert_simdgroup_store_correctness(128, 128, 128, 32, 32, 32) + + +@tilelang.testing.requires_metal +def test_correctness_non_square_block(): + assert_simdgroup_store_correctness(128, 128, 128, 32, 64, 16) + + +@tilelang.testing.requires_metal +def test_correctness_64x64x32(): + assert_simdgroup_store_correctness(128, 128, 128, 64, 64, 32) + + +@tilelang.testing.requires_metal +def test_correctness_large_matrix(): + assert_simdgroup_store_correctness(1024, 1024, 1024, 32, 32, 32, atol=1.0) + + +@tilelang.testing.requires_metal +def test_correctness_non_square_matrix(): + assert_simdgroup_store_correctness(256, 512, 128, 32, 32, 16) + + +if __name__ == "__main__": + if torch.mps.is_available(): + tilelang.testing.main() diff --git a/tilelang/engine/lower.py b/tilelang/engine/lower.py index e900890300..d7c3547a40 100644 --- a/tilelang/engine/lower.py +++ b/tilelang/engine/lower.py @@ -236,7 +236,7 @@ def device_codegen(device_mod: tvm.IRModule, target: Target) -> tvm.IRModule: elif target.kind.name == "hip": device_mod = tvm.ffi.get_global_func("target.build.tilelang_hip")(device_mod, target) elif target.kind.name == "metal": - device_mod = tvm.ffi.get_global_func("target.build.metal")(device_mod, target) + device_mod = tvm.ffi.get_global_func("target.build.tilelang_metal")(device_mod, target) else: raise ValueError(f"Target {target.kind.name} is not supported") @@ -261,7 +261,7 @@ def device_codegen_without_compile(device_mod: tvm.IRModule, target: Target) -> elif target.kind.name == "webgpu": device_mod = tvm.ffi.get_global_func("target.build.webgpu")(device_mod, target) elif target.kind.name == "metal": - device_mod = tvm.ffi.get_global_func("target.build.metal")(device_mod, target) + device_mod = tvm.ffi.get_global_func("target.build.tilelang_metal")(device_mod, target) else: raise ValueError(f"Target {target.kind.name} is not supported") diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index 5563845214..854d9e73c4 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -197,6 +197,11 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule: mod = tilelang.transform.PipelinePlanning()(mod) mod = tilelang.transform.InjectSoftwarePipeline()(mod) mod = tilelang.transform.Simplify()(mod) + # On Metal, rewrite local.fragment GEMM accumulators to metal.simdgroup + # before layout inference (which would otherwise require a layout for them) + from tilelang.transform.metal_fragment_to_simdgroup import MetalFragmentToSimdgroup + + mod = MetalFragmentToSimdgroup(mod) # Infer memory layouts for fragments and shared memory mod = tilelang.transform.LayoutInference()(mod) # Visualize the layout diff --git a/tilelang/intrinsics/metal_macro_generator.py b/tilelang/intrinsics/metal_macro_generator.py new file mode 100644 index 0000000000..9da073152c --- /dev/null +++ b/tilelang/intrinsics/metal_macro_generator.py @@ -0,0 +1,203 @@ +from __future__ import annotations + +import tilelang.language as T +from tvm import tir +from tvm.tir import Buffer, BufferRegion + + +class MPSIntrinEmitter: + WARP_SIZE = 32 + + def __init__( + self, + a_dtype: str = "float16", + b_dtype: str = "float16", + accum_dtype: str = "float32", + a_transposed: bool = False, + b_transposed: bool = False, + block_row_warps: int = 1, + block_col_warps: int = 1, + warp_row_tiles: int = 8, + warp_col_tiles: int = 8, + chunk: int = 32, + thread_var: tir.Var | None = None, + ): + self.a_dtype = a_dtype + self.b_dtype = b_dtype + self.accum_dtype = accum_dtype + 
self.a_transposed = a_transposed + self.b_transposed = b_transposed + self.block_row_warps = block_row_warps + self.block_col_warps = block_col_warps + self.warp_row_tiles = warp_row_tiles + self.warp_col_tiles = warp_col_tiles + self.chunk = chunk + self.thread_var = thread_var + + # Metal simdgroup matrix size (always 8x8) + self.micro_size_x = 8 + self.micro_size_y = 8 + self.micro_size_k = 8 + + # Number of 8x8 tiles per warp + self.warp_rows = warp_row_tiles // self.micro_size_x + self.warp_cols = warp_col_tiles // self.micro_size_y + + def get_thread_binding(self): + if self.thread_var is None: + current_frame = T.KernelLaunchFrame.Current() + assert current_frame is not None, "Must be called in a T.Kernel Frame" + return current_frame.get_thread_binding() + else: + return self.thread_var + + def _get_warp_indices(self): + thread_binding = self.get_thread_binding() + WARP_SIZE = self.WARP_SIZE + block_row_warps = self.block_row_warps + block_col_warps = self.block_col_warps + + warp_m = (thread_binding // WARP_SIZE) % block_row_warps + warp_n = (thread_binding // (WARP_SIZE * block_row_warps)) % block_col_warps + return warp_m, warp_n + + @staticmethod + def _parse_buffer_2d(buf): + """Extract (buffer, row_offset, col_offset, stride) from Buffer or BufferRegion.""" + if isinstance(buf, BufferRegion): + buffer = buf.buffer + off_row = buf.region[-2].min + off_col = buf.region[-1].min + else: + buffer = buf + off_row = 0 + off_col = 0 + stride = buffer.strides[-2] if len(buffer.strides) == len(buffer.shape) else buffer.shape[-1] + return buffer, off_row, off_col, stride + + def ldmatrix_a(self, A_local_buf, A_shared_buf: Buffer | BufferRegion, ki): + warp_rows = self.warp_rows + micro_size_x = self.micro_size_x + micro_size_k = self.micro_size_k + a_transposed = self.a_transposed + + warp_m, _ = self._get_warp_indices() + + buffer, offset_m, offset_k, stride = self._parse_buffer_2d(A_shared_buf) + + @T.macro + def _warp_ldmatrix_a(A_local_buf, buffer, offset_m, offset_k, stride, warp_m, ki): + for i in T.serial(warp_rows): + if a_transposed: + row_idx = offset_k + ki * micro_size_k + col_idx = offset_m + warp_m * (self.warp_row_tiles) + i * micro_size_x + else: + row_idx = offset_m + warp_m * (self.warp_row_tiles) + i * micro_size_x + col_idx = offset_k + ki * micro_size_k + + ptr = T.access_ptr(buffer[row_idx, col_idx], "r") + + T.simdgroup_load( + A_local_buf.data, + i, + ptr, + stride, + micro_size_x, + micro_size_k, + T.bool(a_transposed), + ) + + return _warp_ldmatrix_a(A_local_buf, buffer, offset_m, offset_k, stride, warp_m, ki) + + def ldmatrix_b(self, B_local_buf, B_shared_buf: Buffer | BufferRegion, ki): + warp_cols = self.warp_cols + micro_size_y = self.micro_size_y + micro_size_k = self.micro_size_k + b_transposed = self.b_transposed + + _, warp_n = self._get_warp_indices() + + buffer, offset_k, offset_n, stride = self._parse_buffer_2d(B_shared_buf) + + @T.macro + def _warp_ldmatrix_b(B_local_buf, buffer, offset_k, offset_n, stride, warp_n, ki): + for j in T.serial(warp_cols): + if b_transposed: + row_idx = offset_n + warp_n * (self.warp_col_tiles) + j * micro_size_y + col_idx = offset_k + ki * micro_size_k + else: + row_idx = offset_k + ki * micro_size_k + col_idx = offset_n + warp_n * (self.warp_col_tiles) + j * micro_size_y + + ptr = T.access_ptr(buffer[row_idx, col_idx], "r") + + T.simdgroup_load( + B_local_buf.data, + j, + ptr, + stride, + micro_size_k, + micro_size_y, + T.bool(b_transposed), + ) + + return _warp_ldmatrix_b(B_local_buf, buffer, offset_k, offset_n, 
stride, warp_n, ki) + + def mma(self, A_local_buf, B_local_buf, C_local_buf): + warp_rows = self.warp_rows + warp_cols = self.warp_cols + + @T.macro + def _warp_mma(A_local_buf, B_local_buf, C_local_buf): + for i, j in T.grid(warp_rows, warp_cols): + T.simdgroup_multiply_accumulate( + C_local_buf.data, + i * warp_cols + j, + A_local_buf.data, + i, + B_local_buf.data, + j, + C_local_buf.data, + i * warp_cols + j, + ) + + return _warp_mma(A_local_buf, B_local_buf, C_local_buf) + + def simdgroup_copy(self, C_simd_buf, C_dst, is_store=True): + warp_rows = self.warp_rows + warp_cols = self.warp_cols + micro_size_x = self.micro_size_x + micro_size_y = self.micro_size_y + + warp_m, warp_n = self._get_warp_indices() + + buffer, offset_m, offset_n, stride = self._parse_buffer_2d(C_dst) + + simd_op = T.simdgroup_store if is_store else T.simdgroup_load + access_mode = "w" if is_store else "r" + + @T.macro + def _simdgroup_copy(C_simd_buf, buffer, offset_m, offset_n, stride, warp_m, warp_n): + for i, j in T.grid(warp_rows, warp_cols): + row = offset_m + warp_m * self.warp_row_tiles + i * micro_size_x + col = offset_n + warp_n * self.warp_col_tiles + j * micro_size_y + + index_c = i * warp_cols + j + + simd_op( + C_simd_buf.data, + index_c, + T.access_ptr(buffer[row, col], access_mode), + stride, + micro_size_x, + micro_size_y, + T.bool(False), + ) + + return _simdgroup_copy(C_simd_buf, buffer, offset_m, offset_n, stride, warp_m, warp_n) + + def simd_store(self, C_simd_buf, C_dst): + return self.simdgroup_copy(C_simd_buf, C_dst, is_store=True) + + def simd_load(self, C_simd_buf, C_src): + return self.simdgroup_copy(C_simd_buf, C_src, is_store=False) diff --git a/tilelang/jit/adapter/base.py b/tilelang/jit/adapter/base.py index 3669f9e35c..9cec6ab56c 100644 --- a/tilelang/jit/adapter/base.py +++ b/tilelang/jit/adapter/base.py @@ -79,7 +79,9 @@ def get_current_device_functor() -> Callable[[], torch.device]: current_device = torch._C._cuda_getDevice return lambda: torch.device("cuda", current_device()) except Exception: - return lambda: torch.device("cuda", torch.cuda.current_device()) + pass + if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available(): + return lambda: torch.device("mps") # CPU fallback return lambda: torch.device("cpu") diff --git a/tilelang/jit/adapter/torch/metal.py b/tilelang/jit/adapter/torch/metal.py index 4690cf59bd..dfed68ae81 100644 --- a/tilelang/jit/adapter/torch/metal.py +++ b/tilelang/jit/adapter/torch/metal.py @@ -53,6 +53,9 @@ def __init__( _kernel = None + def get_kernel_source(self, kernel_only: bool = True) -> str: + return self.kernel_global_source or "" + def _convert_torch_func(self) -> Callable: if self._kernel is None: _kernel = getattr(torch.mps.compile_shader(self.kernel_global_source), self.kernel_name) diff --git a/tilelang/tileop/gemm/__init__.py b/tilelang/tileop/gemm/__init__.py index a2290b0191..3f673e7f73 100644 --- a/tilelang/tileop/gemm/__init__.py +++ b/tilelang/tileop/gemm/__init__.py @@ -13,8 +13,9 @@ from .gemm_mfma import GemmMFMA from .gemm_wmma import GemmWMMA from .gemm_scalar import GemmScalar +from .gemm_metal import GemmMetal from tilelang import _ffi_api -from tilelang.utils.target import target_is_volta +from tilelang.utils.target import target_is_volta, target_is_metal @tvm_ffi.register_global_func("tl.gemm.infer_layout") @@ -157,8 +158,31 @@ def _select_gemm_instruction(self, thread_nums: int, target: Target) -> GemmInst 1. TCGEN5MMA for Blackwell architecture 2. 
WGMMA for Hopper architecture with sufficient matrix size and warp count 3. MFMA for CDNA (AMD) architecture - 4. MMA for CUDA architecture - 5. Scalar for CPU target (scalar fallback) + 4. WMMA for RDNA (AMD) architecture + 5. MMA for CUDA architecture + 6. METAL_SIMDGROUP for Metal target (simdgroup_matrix) + 7. Scalar for CPU target (scalar fallback) + + Special-case on Metal: + + - FP8 inputs: Apple Silicon has no native FP8 hardware (M1-M5 + inclusive -- see Apple WWDC 2025 cooperative tensors session). + The TileLang Metal codegen rejects allocating metal.simdgroup + buffers with FP8 dtype (see codegen_metal.cc:454 -- + "Only float16, float32, and bfloat16 are supported"). We route + FP8-input GEMMs to the scalar fallback (GemmInst.Scalar), + which on Metal targets emits per-element T.cast(value, accum_dtype) + reads for both operands. The T.cast from FP8 to a wider dtype is + handled by the storage-only FP8 emulation patch in + codegen_metal.cc::VisitExpr_(CastNode) which expands to + __tvm_fp8_e4m3_to_half / __tvm_fp8_e5m2_to_half + helper calls (see + docs/upstream/tilelang_metal_fp8/0001-metal-fp8-storage-only.patch). + + This mirrors the audiohacking fp8_scaled_matmul_kernel + reference (https://github.com/audiohacking/fp8-mps-metal): a + scalar dequant-multiply-accumulate loop in half / + float rather than any FP8 simdgroup intrinsic. Args: thread_nums: Number of threads in the block @@ -167,8 +191,43 @@ def _select_gemm_instruction(self, thread_nums: int, target: Target) -> GemmInst Returns: GemmInst: The selected GEMM instruction type """ + if target_is_metal(target): + # FP8 (e4m3 / e5m2 / e8m0fnu) inputs: Apple has no native FP8 + # ALU through M5; route to the scalar fallback so the per-element + # T.cast(..., accum_dtype) invokes the storage-only FP8 decode + # helpers. The runtime mapping of GemmInst.Scalar to + # GemmMetalScalar on Metal targets is provided by the + # tilelang_gemm_mixed_dtype companion patch (PR #2118 stack); + # without it the resulting kernel will not lower correctly, + # but the routing decision is the load-bearing change here. + if self._has_fp8_input_dtype(): + return GemmInst.Scalar + return GemmInst.METAL_SIMDGROUP return GemmInst(_ffi_api.GemmGetGemmInst(self, int(thread_nums), target)) + def _has_fp8_input_dtype(self) -> bool: + """Return True if either A or B carries an FP8 dtype. + + Used by the Metal dispatcher to force routing through the scalar + fallback (GemmMetalScalar via the companion mixed-dtype patch) + because Metal has no native FP8 ALU (Apple Silicon M1-M5 inclusive) + and the TileLang codegen rejects allocating metal.simdgroup buffers + with FP8 dtype. The scalar fallback's T.cast(..., accum_dtype) reads + invoke the storage-only FP8 decode helpers emitted by + codegen_metal.cc::VisitExpr_(CastNode). + """ + a = getattr(self, "a", None) + b = getattr(self, "b", None) + for buf in (a, b): + if buf is None: + continue + try: + if str(buf.dtype).startswith("float8"): + return True + except AttributeError: # pragma: no cover - defensive + continue + return False + def _get_implementation_class(self, gemm_inst: GemmInst, target: Target): """Get the appropriate implementation class for the given GEMM instruction. 
@@ -197,5 +256,7 @@ def _get_implementation_class(self, gemm_inst: GemmInst, target: Target): return GemmWMMA elif gemm_inst.is_scalar(): return GemmScalar + elif gemm_inst.is_metal_simdgroup(): + return GemmMetal else: raise ValueError(f"Unsupported GEMM instruction: {gemm_inst}") diff --git a/tilelang/tileop/gemm/gemm_metal.py b/tilelang/tileop/gemm/gemm_metal.py new file mode 100644 index 0000000000..942cfbebfb --- /dev/null +++ b/tilelang/tileop/gemm/gemm_metal.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +from .gemm_base import GemmBase +from .inst import GemmInst +from tilelang.utils.language import is_shared, is_full_region, is_metal_simdgroup, is_fragment +from tilelang import tvm as tvm +from tvm.target import Target +from tvm.ir import Range +from tvm import tir +from tilelang import language as T +from tilelang.transform.simplify import _Simplify + + +class GemmMetal(GemmBase): + def is_gemm_ss(self) -> bool: + return is_shared(self.A) and is_shared(self.B) + + def infer_layout(self, target: Target, thread_nums: int): + return {} + + def lower( + self, layout_map: dict, target: Target, thread_bounds: Range, thread_var: tir.Var, mbar_phase_expr: tir.PrimExpr | None = None + ): + thread_nums = thread_bounds.extent + for name, value in (("M", self.M), ("N", self.N), ("K", self.chunk)): + if value % 8 != 0: + raise ValueError(f"Metal GEMM requires {name} to be a multiple of 8, got {value}") + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, GemmInst.METAL_SIMDGROUP) + if self.M % m_warp != 0 or self.N % n_warp != 0: + raise ValueError(f"Metal GEMM cannot evenly partition {self.M}x{self.N} across {m_warp}x{n_warp} warps") + warp_row_tiles = int(self.M // m_warp) + warp_col_tiles = int(self.N // n_warp) + if warp_row_tiles % 8 != 0 or warp_col_tiles % 8 != 0: + raise ValueError(f"Metal GEMM per-warp tile must be a multiple of 8x8, got {warp_row_tiles}x{warp_col_tiles}") + + from tilelang.intrinsics.metal_macro_generator import MPSIntrinEmitter + + mps_emitter = MPSIntrinEmitter( + a_dtype=self.in_dtype, + b_dtype=self.in_dtype, + accum_dtype=self.accum_dtype, + a_transposed=self.trans_A, + b_transposed=self.trans_B, + block_row_warps=m_warp, + block_col_warps=n_warp, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=self.chunk, + thread_var=thread_var, + ) + + in_dtype = self.in_dtype + accum_dtype = self.accum_dtype + warp_rows = mps_emitter.warp_rows + warp_cols = mps_emitter.warp_cols + num_simd_c = warp_rows * warp_cols + block_K = mps_emitter.chunk + micro_size_k = mps_emitter.micro_size_k + + A_region = self.ARegion + B_region = self.BRegion + C_region = self.CRegion + + C_buf = C_region.buffer + + clear_accum = self.clear_accum + c_in_simdgroup_reg = is_metal_simdgroup(C_buf) or is_fragment(C_buf) + + if block_K < micro_size_k: + raise ValueError(f"Metal GEMM requires block_K ({block_K}) to be >= micro_size_k ({micro_size_k})") + if block_K % micro_size_k != 0: + raise ValueError(f"Metal GEMM requires block_K ({block_K}) to be divisible by micro_size_k ({micro_size_k})") + if not is_full_region(C_region): + raise ValueError(f"Metal GEMM requires full output C region, got {C_region}") + if not c_in_simdgroup_reg and not is_shared(C_buf): + raise ValueError(f"Metal GEMM requires C in local.fragment, metal.simdgroup, or shared scope, got {C_buf.scope()}") + + if self.is_gemm_ss(): + if c_in_simdgroup_reg: + + @T.prim_func + def _gemm_ss_simdgroup() -> None: + A_local = T.alloc_local((warp_rows * 64), 
in_dtype, scope="metal.simdgroup") + B_local = T.alloc_local((warp_cols * 64), in_dtype, scope="metal.simdgroup") + if clear_accum: + for _i in T.serial(num_simd_c): + T.make_filled_simdgroup_matrix(C_buf.data, _i, T.cast(0, accum_dtype)) + for ki in T.serial(0, (block_K // micro_size_k)): + mps_emitter.ldmatrix_a(A_local, A_region, ki) + mps_emitter.ldmatrix_b(B_local, B_region, ki) + mps_emitter.mma(A_local, B_local, C_buf) + + return _Simplify(_gemm_ss_simdgroup, inline_let=True) + else: + + @T.prim_func + def _gemm_ss_shared() -> None: + A_local = T.alloc_local((warp_rows * 64), in_dtype, scope="metal.simdgroup") + B_local = T.alloc_local((warp_cols * 64), in_dtype, scope="metal.simdgroup") + C_simd = T.alloc_local((num_simd_c * 64), accum_dtype, scope="metal.simdgroup") + if clear_accum: + for _i in T.serial(num_simd_c): + T.make_filled_simdgroup_matrix(C_simd.data, _i, T.cast(0, accum_dtype)) + else: + mps_emitter.simd_load(C_simd, C_buf) + for ki in T.serial(0, (block_K // micro_size_k)): + mps_emitter.ldmatrix_a(A_local, A_region, ki) + mps_emitter.ldmatrix_b(B_local, B_region, ki) + mps_emitter.mma(A_local, B_local, C_simd) + + mps_emitter.simd_store(C_simd, C_buf) + + return _Simplify(_gemm_ss_shared, inline_let=True) + else: + raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") diff --git a/tilelang/tileop/gemm/inst.py b/tilelang/tileop/gemm/inst.py index cbfaf016f2..66902a42cc 100644 --- a/tilelang/tileop/gemm/inst.py +++ b/tilelang/tileop/gemm/inst.py @@ -8,7 +8,8 @@ class GemmInst(IntEnum): TCGEN5MMA = 2 MFMA = 3 Scalar = 4 - WMMA = 5 # AMD RDNA WMMA (gfx11/gfx12) + WMMA = 5 + METAL_SIMDGROUP = 6 def is_mma(self) -> bool: return self == GemmInst.MMA @@ -28,5 +29,8 @@ def is_scalar(self) -> bool: def is_wmma(self) -> bool: return self == GemmInst.WMMA + def is_metal_simdgroup(self) -> bool: + return self == GemmInst.METAL_SIMDGROUP + def __repr__(self) -> str: return self.name diff --git a/tilelang/tileop/metal_gdn.py b/tilelang/tileop/metal_gdn.py new file mode 100644 index 0000000000..0ea00b4325 --- /dev/null +++ b/tilelang/tileop/metal_gdn.py @@ -0,0 +1,173 @@ +from __future__ import annotations + +from tilelang import language as T +from tilelang.tileop import metal_simdgroup as metal_sg + + +@T.macro +def kkt_score_tile( + row_k_data, + col_k_data, + scores_data, + *, + block: int = 8, + key_dim: int = 8, +) -> None: + """Compute one 8x8 GDN KKT score tile from staged fp32 key tiles.""" + row_rt = metal_sg.alloc_rt(T.float32, 1, 1) + col_rt = metal_sg.alloc_rt(T.float32, 1, 1, layout=metal_sg.TileLayout.TRANSPOSED) + score_rt = metal_sg.alloc_rt(T.float32, 1, 1) + metal_sg.fill_rt(score_rt, T.float32(0.0)) + metal_sg.load_threadgroup_to_rt(row_rt, T.float32, row_k_data, 0, block * key_dim, key_dim) + metal_sg.load_threadgroup_to_rt( + col_rt, + T.float32, + col_k_data, + 0, + block * key_dim, + key_dim, + transpose=True, + ) + metal_sg.mma_abt(score_rt, row_rt, col_rt) + metal_sg.materialize_rt_to_shared(score_rt, T.float32, scores_data, 0, block * block, block) + + +@T.macro +def kkt_score_tile_accum( + row_k_data, + col_k_data, + scores_data, + *, + block: int = 8, + key_dim: int = 16, + key_offset: int = 0, + clear: bool = True, +) -> None: + """Accumulate one 8-column slice into a staged KKT 8x8 score tile.""" + row_rt = metal_sg.alloc_rt(T.float32, 1, 1) + col_rt = metal_sg.alloc_rt(T.float32, 1, 1, layout=metal_sg.TileLayout.TRANSPOSED) + score_rt = metal_sg.alloc_rt(T.float32, 1, 1) + if clear: + metal_sg.fill_rt(score_rt, 
T.float32(0.0)) + else: + metal_sg.load_threadgroup_to_rt(score_rt, T.float32, scores_data, 0, block * block, block) + metal_sg.load_threadgroup_to_rt(row_rt, T.float32, row_k_data, key_offset, block * key_dim, key_dim) + metal_sg.load_threadgroup_to_rt( + col_rt, + T.float32, + col_k_data, + key_offset, + block * key_dim, + key_dim, + transpose=True, + ) + metal_sg.mma_abt(score_rt, row_rt, col_rt) + metal_sg.materialize_rt_to_shared(score_rt, T.float32, scores_data, 0, block * block, block) + + +@T.macro +def apply_kkt_gate_triangular_tile( + scores, + g_row, + g_col, + a_pre, + head, + row_block, + col_block, + lane, + *, + block: int = 8, + chunk_size: int, + threads: int = 32, +) -> None: + """Apply GDN KKT gate decay and causal triangular mask to one score tile.""" + for linear in T.serial(lane, block * block, step=threads): + local_row = linear // block + local_col = linear - local_row * block + c = row_block * block + local_row + d = col_block * block + local_col + if c < chunk_size and d < chunk_size: + if d < c: + a_pre[c, head, d] = scores[local_row, local_col] * T.exp(g_row[local_row] - g_col[local_col]) + else: + a_pre[c, head, d] = 0.0 + + +@T.macro +def wu_linear_element( + k, + v, + beta, + g_cum, + a, + w, + u, + head, + linear, + *, + chunk_size: int, + key_dim: int, + value_dim: int, +) -> None: + """Compute one scalar W or U output element from solved GDN A.""" + c = linear // (key_dim + value_dim) + rem = linear - c * (key_dim + value_dim) + acc = T.alloc_var(T.float32) + acc = 0.0 + if rem < key_dim: + kk = rem + for d in T.serial(chunk_size): + acc += a[c, head, d] * T.cast(k[d, head, kk], T.float32) * beta[d, head] * T.exp(g_cum[d, head]) + w[c, head, kk] = acc + else: + vv = rem - key_dim + for d in T.serial(chunk_size): + acc += a[c, head, d] * T.cast(v[d, head, vv], T.float32) * beta[d, head] + u[c, head, vv] = acc + + +@T.macro +def wu_score_tiles_strided( + a_data, + k_scaled_data, + v_scaled_data, + w_acc, + u_acc, + *, + a_offset: int = 0, + k_offset: int = 0, + v_offset: int = 0, + a_stride: int = 16, + kv_stride: int = 16, + block: int = 8, +) -> None: + """Accumulate one strided 8x8 A/K/V tile slice into W and U outputs.""" + a_rt = metal_sg.alloc_rt(T.float32, 1, 1) + k_rt = metal_sg.alloc_rt(T.float32, 1, 1) + v_rt = metal_sg.alloc_rt(T.float32, 1, 1) + metal_sg.load_threadgroup_to_rt(a_rt, T.float32, a_data, a_offset, block * a_stride, a_stride) + metal_sg.load_threadgroup_to_rt(k_rt, T.float32, k_scaled_data, k_offset, block * kv_stride, kv_stride) + metal_sg.load_threadgroup_to_rt(v_rt, T.float32, v_scaled_data, v_offset, block * kv_stride, kv_stride) + metal_sg.mma_ab(w_acc, a_rt, k_rt) + metal_sg.mma_ab(u_acc, a_rt, v_rt) + + +@T.macro +def wu_score_tiles( + a_data, + k_scaled_data, + v_scaled_data, + w_acc, + u_acc, + *, + block: int = 8, +) -> None: + """Accumulate one staged 8x8 A tile into W and U RegisterTile outputs.""" + a_rt = metal_sg.alloc_rt(T.float32, 1, 1) + k_rt = metal_sg.alloc_rt(T.float32, 1, 1) + v_rt = metal_sg.alloc_rt(T.float32, 1, 1) + metal_sg.load_threadgroup_to_rt(a_rt, T.float32, a_data, 0, block * block, block) + metal_sg.load_threadgroup_to_rt(k_rt, T.float32, k_scaled_data, 0, block * block, block) + metal_sg.load_threadgroup_to_rt(v_rt, T.float32, v_scaled_data, 0, block * block, block) + metal_sg.mma_ab(w_acc, a_rt, k_rt) + metal_sg.mma_ab(u_acc, a_rt, v_rt) diff --git a/tilelang/tileop/metal_quant.py b/tilelang/tileop/metal_quant.py new file mode 100644 index 0000000000..c6b25d2868 --- /dev/null +++ 
b/tilelang/tileop/metal_quant.py @@ -0,0 +1,113 @@ +from __future__ import annotations + +from dataclasses import dataclass + +from tilelang import language as T + + +FP32 = "float32" + + +@dataclass(frozen=True) +class QuantSimdgroupTile: + block_m: int + block_n: int + block_k: int + wm: int + wn: int + + +SMALL_TILE = QuantSimdgroupTile(block_m=16, block_n=32, block_k=32, wm=1, wn=1) +LARGE_TILE = QuantSimdgroupTile(block_m=32, block_n=32, block_k=32, wm=1, wn=2) + + +def use_large_simdgroup_tile(m: int, n: int, *, mixed_fp4_weight: bool = False) -> bool: + """Shape-only quant contraction selector for Metal packed uint8 probes. + + fp8 x fp8 starts winning once there is enough row and output-column work to + amortize threadgroup staging. The mixed fp8/fp4 path has more decode and + scale traffic, so keep the middle ``N=256`` band on scalar/GEMV schedules + until a better mixed simdgroup tile lands. + """ + if mixed_fp4_weight: + return m >= 64 and (n == 128 or n >= 512) + return m >= 64 and n >= 256 + + +def selected_simdgroup_tile(m: int, n: int, *, mixed_fp4_weight: bool = False) -> QuantSimdgroupTile: + return LARGE_TILE if use_large_simdgroup_tile(m, n, mixed_fp4_weight=mixed_fp4_weight) else SMALL_TILE + + +def use_small_m_gemv(m: int, n: int, *, mixed_fp4_weight: bool = False) -> bool: + """Shape-only selector for promoted small-M packed quant GEMV schedules.""" + if mixed_fp4_weight: + if 1 <= m <= 16: + return n >= 64 + if 17 <= m <= 24: + return n >= 128 + if 25 <= m <= 32: + return n == 256 + if 33 <= m <= 48: + return n >= 128 + return False + if 1 <= m <= 32: + return n >= 128 + return False + + +def fp8_e4m3fn_to_float(bits): + """Decode packed uint8 e4m3fn to fp32 inside TileLang/Metal kernels.""" + bits_u = T.Cast("uint32", bits) + abs_bits = bits_u & T.uint32(0x7F) + sign = (bits_u >> T.uint32(7)) & T.uint32(1) + exp_bits = (bits_u >> T.uint32(3)) & T.uint32(0xF) + mant_bits = bits_u & T.uint32(0x7) + + mant = T.Cast(FP32, mant_bits) + subnormal = mant * T.float32(1.0 / 512.0) + normal = (T.float32(1.0) + mant * T.float32(1.0 / 8.0)) * T.exp2(T.Cast(FP32, T.Cast("int32", exp_bits) - T.int32(7))) + value = T.if_then_else(exp_bits == T.uint32(0), subnormal, normal) + value = T.if_then_else(abs_bits == T.uint32(0x7F), T.float32(0.0), value) + return T.if_then_else(sign != T.uint32(0), -value, value) + + +def fp4_e2m1fn_to_float(bits, nibble_index): + """Decode one e2m1fn nibble from packed uint8 storage.""" + bits_u = T.Cast("uint32", bits) + shift = T.Cast("uint32", nibble_index) * T.uint32(4) + nibble = (bits_u >> shift) & T.uint32(0xF) + sign = (nibble >> T.uint32(3)) & T.uint32(1) + mag = nibble & T.uint32(0x7) + value = T.if_then_else( + mag == T.uint32(0), + T.float32(0.0), + T.if_then_else( + mag == T.uint32(1), + T.float32(0.5), + T.if_then_else( + mag == T.uint32(2), + T.float32(1.0), + T.if_then_else( + mag == T.uint32(3), + T.float32(1.5), + T.if_then_else( + mag == T.uint32(4), + T.float32(2.0), + T.if_then_else( + mag == T.uint32(5), + T.float32(3.0), + T.if_then_else(mag == T.uint32(6), T.float32(4.0), T.float32(6.0)), + ), + ), + ), + ), + ), + ) + return T.if_then_else(sign != T.uint32(0), -value, value) + + +def e8m0_to_float(bits): + """Decode torch.float8_e8m0fnu-compatible scale byte to fp32.""" + bits_i = T.Cast("int32", bits) + value = T.exp2(T.Cast(FP32, bits_i - T.int32(127))) + return T.if_then_else(bits_i == T.int32(255), T.float32(0.0), value) diff --git a/tilelang/tileop/metal_simdgroup.py b/tilelang/tileop/metal_simdgroup.py new file mode 100644 
index 0000000000..7c0ddffa66 --- /dev/null +++ b/tilelang/tileop/metal_simdgroup.py @@ -0,0 +1,524 @@ +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum + +from tilelang import language as T + + +class TileLayout(str, Enum): + """Internal Metal register-tile layout metadata. + + These values describe how higher-level tile code intends to interpret a + tile. The current MSL lowering still uses explicit simdgroup load/store + transpose flags; this metadata is deliberately internal until the layout + contract has survived GEMM, attention, and MoE retargeting. + """ + + ROW_MAJOR = "row_major" + COL_MAJOR = "col_major" + TRANSPOSED = "transposed" + + +@dataclass(frozen=True) +class RegisterTile: + """Opaque array of Metal 8x8 simdgroup register fragments. + + The ``fragment`` object is still the only object passed to TileLang/Metal + intrinsics. This metadata gives compiler-owned lowerings a reusable way to + address arrays of fragments without exposing scalar fragment indexing. + """ + + fragment: object + fragments_m: int + fragments_n: int + rows: int = 8 + cols: int = 8 + layout: TileLayout = TileLayout.ROW_MAJOR + + def __post_init__(self) -> None: + if not isinstance(self.layout, TileLayout): + object.__setattr__(self, "layout", TileLayout(self.layout)) + if self.fragments_m <= 0 or self.fragments_n <= 0: + raise ValueError(f"RegisterTile fragment counts must be positive, got {self.fragments_m}x{self.fragments_n}") + if self.rows != 8 or self.cols != 8: + raise ValueError(f"Metal register tiles are 8x8 fragments, got {self.rows}x{self.cols}") + + @property + def data(self): + return self.fragment.data + + def index(self, tile_m: int, tile_n: int = 0) -> int: + if isinstance(tile_m, int) and not 0 <= tile_m < self.fragments_m: + raise IndexError(f"tile_m {tile_m} out of bounds for {self.fragments_m} register-tile rows") + if isinstance(tile_n, int) and not 0 <= tile_n < self.fragments_n: + raise IndexError(f"tile_n {tile_n} out of bounds for {self.fragments_n} register-tile columns") + return tile_m * self.fragments_n + tile_n + + +@dataclass(frozen=True) +class MMATile(RegisterTile): + """Backward-compatible name for existing internal simdgroup users.""" + + +@dataclass(frozen=True) +class RowVector: + """Internal row vector backed by explicit scalar storage. + + Row vectors intentionally do not index ``metal.simdgroup`` fragments. They + operate on materialized buffers until there is a native register-vector + lowering for row reductions and normalization. 
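+
+    A minimal illustrative pairing with the row helpers defined below
+    (``scores_smem`` and the shapes are hypothetical)::
+
+        row_sums = RowVector(T.alloc_local((8,), T.float32), 8)
+        row_sum(scores_smem, row_sums, rows=8, cols=8)   # per-row sums
+        div_row(scores_smem, row_sums, rows=8, cols=8)   # divide each row by its sum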
+ """ + + values: object + length: int + dtype: object = T.float32 + + @property + def data(self): + return self.values.data + + +def _require_layout(tile: RegisterTile, expected: TileLayout, role: str, op_name: str) -> None: + if tile.layout != expected: + raise ValueError(f"{op_name} requires {role} layout {expected.value}, got {tile.layout.value}") + + +def _require_load_layout(tile: RegisterTile, transpose: bool, op_name: str) -> None: + expected = TileLayout.TRANSPOSED if transpose else TileLayout.ROW_MAJOR + if tile.layout != expected: + mode = "transposed" if transpose else "row-major" + raise ValueError(f"{op_name} {mode} load requires tile layout {expected.value}, got {tile.layout.value}") + + +def _require_store_layout(tile: RegisterTile, transpose: bool, op_name: str) -> None: + expected = TileLayout.TRANSPOSED if transpose else TileLayout.ROW_MAJOR + if tile.layout != expected: + mode = "transposed" if transpose else "row-major" + raise ValueError(f"{op_name} {mode} store requires tile layout {expected.value}, got {tile.layout.value}") + + +@T.macro +def alloc_rt( + dtype, + fragments_m: int, + fragments_n: int = 1, + *, + rows: int = 8, + cols: int = 8, + layout: TileLayout = TileLayout.ROW_MAJOR, +) -> RegisterTile: + """Allocate an internal Metal register tile backed by 8x8 fragments.""" + rt_fragment = T.alloc_fragment((fragments_m * fragments_n, rows, cols), dtype, scope="metal.simdgroup") + return RegisterTile(rt_fragment, fragments_m, fragments_n, rows, cols, layout) + + +@T.macro +def fill(fragment, matrix_index, value, rows: int = 8, cols: int = 8) -> None: + """Fill one opaque Metal simdgroup matrix fragment.""" + T.make_filled_simdgroup_matrix(fragment.data, matrix_index, value, rows, cols) + + +@T.macro +def access_ptr(dtype, data, offset, extent, rw_mask: int): + return T.tvm_access_ptr(T.type_annotation(dtype), data, offset, extent, rw_mask) + + +@T.macro +def load( + fragment, + matrix_index, + dtype, + data, + offset, + extent, + stride, + rows: int = 8, + cols: int = 8, + transpose: bool = False, +) -> None: + T.simdgroup_load( + fragment.data, + matrix_index, + access_ptr(dtype, data, offset, extent, 1), + stride, + rows, + cols, + T.bool(transpose), + ) + + +@T.macro +def store( + fragment, + matrix_index, + dtype, + data, + offset, + extent, + stride, + rows: int = 8, + cols: int = 8, + transpose: bool = False, +) -> None: + T.simdgroup_store( + fragment.data, + matrix_index, + access_ptr(dtype, data, offset, extent, 2), + stride, + rows, + cols, + T.bool(transpose), + ) + + +@T.macro +def mma(acc, a, b, acc_index=0, a_index=0, b_index=0, out_index=None) -> None: + """Accumulate ``a @ b`` into an opaque Metal simdgroup accumulator.""" + if out_index is None: + out_index = acc_index + T.simdgroup_multiply_accumulate( + acc.data, + out_index, + a.data, + a_index, + b.data, + b_index, + acc.data, + acc_index, + ) + + +@T.macro +def fill_tile(tile: MMATile, value) -> None: + for tile_m in T.unroll(tile.fragments_m, explicit=True): + for tile_n in T.unroll(tile.fragments_n, explicit=True): + fill(tile.fragment, tile.index(tile_m, tile_n), value, tile.rows, tile.cols) + + +@T.macro +def fill_rt(tile: RegisterTile, value) -> None: + """Fill every 8x8 fragment in a register tile.""" + for tile_m in T.unroll(tile.fragments_m, explicit=True): + for tile_n in T.unroll(tile.fragments_n, explicit=True): + fill(tile.fragment, tile.index(tile_m, tile_n), value, tile.rows, tile.cols) + + +@T.macro +def load_tile( + tile: MMATile, + dtype, + data, + offset, + extent, + 
stride, + *, + rows: int = 8, + cols: int = 8, + transpose: bool = False, +) -> None: + load( + tile.fragment, + tile.index(0, 0), + dtype, + data, + offset, + extent, + stride, + rows, + cols, + transpose, + ) + + +@T.macro +def load_global_to_rt( + tile: RegisterTile, + dtype, + data, + offset, + extent, + stride, + *, + tile_m: int = 0, + tile_n: int = 0, + rows: int = 8, + cols: int = 8, + transpose: bool = False, +) -> None: + """Load one 8x8 global-memory tile into a register-tile fragment.""" + _require_load_layout(tile, transpose, "load_global_to_rt") + load( + tile.fragment, + tile.index(tile_m, tile_n), + dtype, + data, + offset, + extent, + stride, + rows, + cols, + transpose, + ) + + +@T.macro +def load_threadgroup_to_rt( + tile: RegisterTile, + dtype, + data, + offset, + extent, + stride, + *, + tile_m: int = 0, + tile_n: int = 0, + rows: int = 8, + cols: int = 8, + transpose: bool = False, +) -> None: + """Load one 8x8 threadgroup-memory tile into a register-tile fragment.""" + _require_load_layout(tile, transpose, "load_threadgroup_to_rt") + load( + tile.fragment, + tile.index(tile_m, tile_n), + dtype, + data, + offset, + extent, + stride, + rows, + cols, + transpose, + ) + + +@T.macro +def store_tile( + tile: MMATile, + dtype, + data, + offset, + extent, + stride, + *, + rows: int = 8, + cols: int = 8, + transpose: bool = False, +) -> None: + store( + tile.fragment, + tile.index(0, 0), + dtype, + data, + offset, + extent, + stride, + rows, + cols, + transpose, + ) + + +@T.macro +def store_rt( + tile: RegisterTile, + dtype, + data, + offset, + extent, + stride, + *, + tile_m: int = 0, + tile_n: int = 0, + rows: int = 8, + cols: int = 8, + transpose: bool = False, +) -> None: + """Store one 8x8 register-tile fragment through explicit materialization.""" + _require_store_layout(tile, transpose, "store_rt") + store( + tile.fragment, + tile.index(tile_m, tile_n), + dtype, + data, + offset, + extent, + stride, + rows, + cols, + transpose, + ) + + +@T.macro +def materialize_rt_to_shared( + tile: RegisterTile, + dtype, + data, + offset, + extent, + stride, + *, + tile_m: int = 0, + tile_n: int = 0, + rows: int = 8, + cols: int = 8, + transpose: bool = False, +) -> None: + """Materialize one register-tile fragment into explicit shared storage.""" + store_rt( + tile, + dtype, + data, + offset, + extent, + stride, + tile_m=tile_m, + tile_n=tile_n, + rows=rows, + cols=cols, + transpose=transpose, + ) + + +@T.macro +def mma_tile(acc: MMATile, a: MMATile, b: MMATile) -> None: + for tile_m in T.unroll(acc.fragments_m, explicit=True): + for tile_n in T.unroll(acc.fragments_n, explicit=True): + mma( + acc.fragment, + a.fragment, + b.fragment, + acc.index(tile_m, tile_n), + a.index(tile_m, 0), + b.index(0, tile_n), + ) + + +@T.macro +def mma_ab( + acc: RegisterTile, + a: RegisterTile, + b: RegisterTile, + *, + acc_m: int = 0, + acc_n: int = 0, + a_m: int = 0, + a_n: int = 0, + b_m: int = 0, + b_n: int = 0, +) -> None: + """Accumulate ``A @ B`` into one accumulator tile fragment.""" + _require_layout(acc, TileLayout.ROW_MAJOR, "accumulator", "mma_ab") + _require_layout(a, TileLayout.ROW_MAJOR, "A", "mma_ab") + _require_layout(b, TileLayout.ROW_MAJOR, "B", "mma_ab") + mma( + acc.fragment, + a.fragment, + b.fragment, + acc.index(acc_m, acc_n), + a.index(a_m, a_n), + b.index(b_m, b_n), + ) + + +@T.macro +def mma_abt( + acc: RegisterTile, + a: RegisterTile, + bt: RegisterTile, + *, + acc_m: int = 0, + acc_n: int = 0, + a_m: int = 0, + a_n: int = 0, + b_m: int = 0, + b_n: int = 0, +) -> None: + 
"""Accumulate ``A @ B.T`` after ``B`` has been loaded transposed.""" + _require_layout(acc, TileLayout.ROW_MAJOR, "accumulator", "mma_abt") + _require_layout(a, TileLayout.ROW_MAJOR, "A", "mma_abt") + _require_layout(bt, TileLayout.TRANSPOSED, "B", "mma_abt") + mma( + acc.fragment, + a.fragment, + bt.fragment, + acc.index(acc_m, acc_n), + a.index(a_m, a_n), + bt.index(b_m, b_n), + ) + + +@T.macro +def prefix_block_vector( + src, + head, + block_index, + dst, + *, + block: int, + length: int, + writeback=None, + writeback_guard=True, +) -> None: + """Compute an inclusive block-local prefix vector from a 2D source.""" + block_start = block_index * block + acc = T.alloc_var(T.float32) + acc = 0.0 + for idx in T.serial(block_start): + acc += src[idx, head] + for local_idx in T.serial(block): + token = block_start + local_idx + value = T.alloc_var(T.float32) + value = 0.0 + if token < length: + acc += src[token, head] + value = acc + if writeback is not None and writeback_guard: + writeback[token, head] = value + if writeback is not None and not writeback_guard: + writeback[token, head] = value + dst[local_idx] = value + + +@T.macro +def row_max(src, dst: RowVector, *, rows: int, cols: int, clear: bool = True) -> None: + """Compute per-row maxima over a materialized scalar tile.""" + for row in T.Parallel(rows): + acc = T.alloc_var(dst.dtype) + if clear: + acc = T.cast(-3.4028234663852886e38, dst.dtype) + else: + acc = dst.values[row] + for col in T.serial(cols): + acc = T.max(acc, src[row, col]) + dst.values[row] = acc + + +@T.macro +def row_sum(src, dst: RowVector, *, rows: int, cols: int, clear: bool = True) -> None: + """Compute per-row sums over a materialized scalar tile.""" + for row in T.Parallel(rows): + acc = T.alloc_var(dst.dtype) + if clear: + acc = T.cast(0, dst.dtype) + else: + acc = dst.values[row] + for col in T.serial(cols): + acc += src[row, col] + dst.values[row] = acc + + +@T.macro +def mul_row(src, vec: RowVector, *, rows: int, cols: int) -> None: + """Scale each materialized scalar-tile row by a row-vector value.""" + for row, col in T.Parallel(rows, cols): + src[row, col] *= vec.values[row] + + +@T.macro +def div_row(src, vec: RowVector, *, rows: int, cols: int) -> None: + """Divide each materialized scalar-tile row by a row-vector value.""" + for row, col in T.Parallel(rows, cols): + src[row, col] /= vec.values[row] diff --git a/tilelang/transform/decouple_type_cast.py b/tilelang/transform/decouple_type_cast.py index eafd8b36f5..8c1d20234b 100644 --- a/tilelang/transform/decouple_type_cast.py +++ b/tilelang/transform/decouple_type_cast.py @@ -68,14 +68,14 @@ # Cache the Op for if_then_else to avoid repeated lookups _IF_THEN_ELSE_OP = Op.get("tir.if_then_else") -from tilelang.utils.language import is_fragment, is_global, is_local, is_local_var, is_shared +from tilelang.utils.language import is_fragment, is_global, is_local, is_local_var, is_shared, is_metal_simdgroup def is_local_buffer(buffer: Buffer) -> bool: - """Check if a buffer is local (register-level), including local.var.""" + """Check if a buffer is local (register-level), including local.var and metal.simdgroup.""" if buffer is None: return False - return is_local(buffer) or is_fragment(buffer) or is_local_var(buffer) + return is_local(buffer) or is_fragment(buffer) or is_local_var(buffer) or is_metal_simdgroup(buffer) def is_global_or_shared_buffer(buffer: Buffer) -> bool: diff --git a/tilelang/transform/metal_fragment_to_simdgroup.py b/tilelang/transform/metal_fragment_to_simdgroup.py new file mode 100644 
index 0000000000..7577c619f1 --- /dev/null +++ b/tilelang/transform/metal_fragment_to_simdgroup.py @@ -0,0 +1,191 @@ +"""Rewrite local.fragment → metal.simdgroup for GEMM accumulators on Metal.""" + +from __future__ import annotations + +from tvm import tir, IRModule +from tvm.ir import Op, PointerType +from tvm.tir.transform import prim_func_pass + +_GEMM_OPS = None + + +def _get_gemm_ops(): + global _GEMM_OPS + if _GEMM_OPS is None: + _GEMM_OPS = { + Op.get("tl.tileop.gemm"), + Op.get("tl.tileop.wgmma_gemm"), + Op.get("tl.tileop.tcgen05_gemm"), + } + return _GEMM_OPS + + +def _extract_buffer_var_from_region(region_call): + if not isinstance(region_call, tir.Call): + return None + if len(region_call.args) < 1: + return None + buf_load = region_call.args[0] + if isinstance(buf_load, tir.BufferLoad): + return buf_load.buffer.data + return None + + +def _extract_buffer_from_region(region_call): + """Return the Buffer (not Var) referenced by a tl.region(BufferLoad, ...) call. + + Used to inspect operand dtypes so we can detect FP8-input GEMMs (which + on Metal route to the scalar fallback rather than simdgroup MMA and + therefore must NOT have their accumulator scope rewritten to + metal.simdgroup -- the Metal codegen rejects FP8 in metal.simdgroup + allocations). + """ + if not isinstance(region_call, tir.Call): + return None + if len(region_call.args) < 1: + return None + buf_load = region_call.args[0] + if isinstance(buf_load, tir.BufferLoad): + return buf_load.buffer + return None + + +def _is_fp8_dtype(dt) -> bool: + """Return True if the dtype is one of the FP8 storage variants. + + On Metal the FP8 path is storage-only emulation (see + docs/upstream/tilelang_metal_fp8/0001-metal-fp8-storage-only.patch) + so any GEMM with FP8 inputs must take the scalar fallback. Used here + to keep the C accumulator out of the metal.simdgroup rewrite -- that + scope rejects FP8 allocations in codegen_metal.cc (line ~454). + """ + try: + return str(dt).startswith("float8") + except Exception: # pragma: no cover - defensive + return False + + +def _collect_fragment_gemm_accum_vars(body: tir.Stmt) -> set: + """Walk the body and return fragment vars safe to rewrite to simdgroup. + + GEMM accumulators backed by ``local.fragment`` are eligible for the + rewrite to ``metal.simdgroup``, which the Metal simdgroup MMA path + needs. We exclude FP8-input GEMMs because the dispatcher routes them + to the scalar fallback (Apple has no native FP8 ALU through M5; the + per-element T.cast invokes the storage-only decode helpers from the + FP8 prelude -- see audiohacking fp8_scaled_matmul_kernel for the + analogous pattern). For those GEMMs the accumulator must stay in + ``local.fragment`` so the scalar fallback can perform its + per-element T.cast(..., accum_dtype) arithmetic without tripping the + Metal codegen's check that ``metal.simdgroup`` allocations are + scalar 8x8 blocks. + """ + accum_vars: set = set() + gemm_ops = _get_gemm_ops() + + def _visitor(stmt): + if isinstance(stmt, tir.Evaluate) and isinstance(stmt.value, tir.Call): + call = stmt.value + if call.op in gemm_ops and len(call.args) >= 3: + # FP8 inputs (storage-only on Metal) route to the scalar + # fallback; exclude their accumulators from the simdgroup + # rewrite so the codegen does not allocate a + # ``metal.simdgroup`` buffer for them. 
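+                # For reference: call.args[0] and call.args[1] wrap the A and B
+                # operands, and call.args[2] wraps the C accumulator, each as a
+                # tl.region(BufferLoad(...), ...) call.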
+                a_buf = _extract_buffer_from_region(call.args[0])
+                b_buf = _extract_buffer_from_region(call.args[1])
+                fp8_inputs = (a_buf is not None and _is_fp8_dtype(a_buf.dtype)) or \
+                    (b_buf is not None and _is_fp8_dtype(b_buf.dtype))
+                if fp8_inputs:
+                    return
+                var = _extract_buffer_var_from_region(call.args[2])
+                if var is not None and hasattr(var, "type_annotation"):
+                    ta = var.type_annotation
+                    if ta is not None and hasattr(ta, "storage_scope") and ta.storage_scope == "local.fragment":
+                        accum_vars.add(var)
+
+    tir.stmt_functor.post_order_visit(body, _visitor)
+    return accum_vars
+
+
+def _remap_buffer(buf, var_map):
+    old_data = buf.data
+    new_data = var_map.get(old_data, None)
+    if new_data is None:
+        return buf
+    return tir.decl_buffer(
+        buf.shape,
+        buf.dtype,
+        buf.name,
+        data=new_data,
+        scope="metal.simdgroup",
+        data_alignment=buf.data_alignment,
+        offset_factor=buf.offset_factor,
+    )
+
+
+def _rewrite_scope(body, var_map):
+    buf_map = {}
+
+    def _pre_order(stmt):
+        if isinstance(stmt, tir.Block):
+            new_alloc_bufs = []
+            changed = False
+            for buf in stmt.alloc_buffers:
+                new_buf = _remap_buffer(buf, var_map)
+                new_alloc_bufs.append(new_buf)
+                if not new_buf.same_as(buf):
+                    buf_map[buf] = new_buf
+                    changed = True
+            if changed:
+                new_body = tir.stmt_functor.substitute(stmt.body, var_map)
+                new_block = tir.Block(
+                    stmt.iter_vars,
+                    stmt.reads,
+                    stmt.writes,
+                    stmt.name_hint,
+                    new_body,
+                    stmt.init,
+                    new_alloc_bufs,
+                    stmt.match_buffers,
+                    stmt.annotations,
+                )
+                return new_block
+        elif isinstance(stmt, tir.Allocate):
+            new_var = var_map.get(stmt.buffer_var, None)
+            if new_var is not None:
+                new_body = tir.stmt_functor.substitute(stmt.body, var_map)
+                return tir.Allocate(new_var, stmt.dtype, stmt.extents, stmt.condition, new_body, stmt.annotations)
+        return None
+
+    return tir.stmt_functor.ir_transform(body, _pre_order, None, ["tir.Block", "tir.Allocate"])
+
+
+def _metal_fragment_to_simdgroup(func: tir.PrimFunc, mod: IRModule, ctx) -> tir.PrimFunc:
+    target = func.attrs.get("target", None)
+    if target is None or target.kind.name != "metal":
+        return func
+
+    accum_vars = _collect_fragment_gemm_accum_vars(func.body)
+    if not accum_vars:
+        return func
+
+    var_map: dict = {}
+    for var in accum_vars:
+        ptr_type = var.type_annotation
+        new_ptr = PointerType(ptr_type.element_type, "metal.simdgroup")
+        new_var = tir.Var(var.name, new_ptr)
+        var_map[var] = new_var
+
+    new_body = _rewrite_scope(func.body, var_map)
+    return func.with_body(new_body)
+
+
+MetalFragmentToSimdgroup = prim_func_pass(_metal_fragment_to_simdgroup, opt_level=0, name="tl.MetalFragmentToSimdgroup")
diff --git a/tilelang/utils/language.py b/tilelang/utils/language.py
index 53a5730ca2..0ad52839e4 100644
--- a/tilelang/utils/language.py
+++ b/tilelang/utils/language.py
@@ -118,6 +118,20 @@ def is_fragment(buffer: BufferLikeType) -> bool:
     return buffer.scope().startswith("local.fragment")
 
 
+def is_metal_simdgroup(buffer: BufferLikeType) -> bool:
+    """
+    Check if the buffer is in the Metal simdgroup scope.
+
+    Args:
+        buffer: The TVM buffer, BufferLoad, or BufferRegion to check.
+
+    Returns:
+        bool: True if the buffer is in metal.simdgroup scope, False otherwise.
+    """
+    buffer = _get_buffer(buffer)
+    return buffer.scope() == "metal.simdgroup"
+
+
 def is_local_var(buffer: BufferLikeType) -> bool:
     """
     Check if the buffer is in the local.var memory scope.