126 changes: 126 additions & 0 deletions benchmark/matmul_metal/benchmark_matmul_metal.py
@@ -0,0 +1,126 @@
import argparse
import logging
import time

import torch

import tilelang
import tilelang.language as T

logging.getLogger("tilelang").setLevel(logging.WARNING)

BLOCK_CONFIGS = [
(16, 16, 16),
(32, 32, 16),
(32, 32, 32),
(64, 64, 32),
]


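# tilelang.jit compiles the kernel on first use and caches one binary per distinct argument tuple.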
@tilelang.jit
def matmul_simdgroup(M, N, K, block_M=64, block_N=64, block_K=32, dtype=T.float16, accum_dtype=T.float32):

@T.prim_func
def gemm_kernel(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), accum_dtype),
):
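# Each threadgroup owns one block_M x block_N tile of C; bx walks tiles along N, by along M.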
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared")
B_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared")
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
T.copy(A[by * block_M, ko * block_K], A_shared)
T.copy(B[ko * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local)
T.copy(C_local, C[by * block_M, bx * block_N])

return gemm_kernel


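# Standard GEMM cost model: 2*M*N*K flops (one multiply plus one add per K step per output element).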
def _tflops(M, N, K, seconds):
return 2.0 * M * N * K / seconds / 1e12


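# torch.mps.synchronize() drains queued GPU work so perf_counter measures completed device time, not just launch overhead.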
def _bench(fn, warmup, repeats):
for _ in range(warmup):
fn()
torch.mps.synchronize()
t0 = time.perf_counter()
for _ in range(repeats):
fn()
torch.mps.synchronize()
return (time.perf_counter() - t0) / repeats


def bench_torch_mps(M, N, K, warmup, repeats):
a = torch.randn(M, K, dtype=torch.float16, device="mps")
b = torch.randn(K, N, dtype=torch.float16, device="mps")
avg_s = _bench(lambda: torch.mm(a, b), warmup, repeats)
return _tflops(M, N, K, avg_s)


def bench_tilelang(M, N, K, block_M, block_N, block_K, warmup, repeats):
kernel = matmul_simdgroup(M, N, K, block_M, block_N, block_K)
a = torch.randn(M, K, dtype=torch.float16, device="mps")
b = torch.randn(K, N, dtype=torch.float16, device="mps")
c = torch.zeros(M, N, dtype=torch.float32, device="mps")
avg_s = _bench(lambda: kernel(a, b, c), warmup, repeats)
return _tflops(M, N, K, avg_s)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Metal GEMM Benchmark (simdgroup)")
parser.add_argument("--m", type=int, default=4096)
parser.add_argument("--n", type=int, default=4096)
parser.add_argument("--k", type=int, default=4096)
parser.add_argument("--warmup", type=int, default=10)
parser.add_argument("--repeats", type=int, default=100)
parser.add_argument("--sweep", action="store_true", help="Sweep all block configs instead of using default (64,64,32)")
args = parser.parse_args()

M, N, K = args.m, args.n, args.k
for name, value in (("m", M), ("n", N), ("k", K), ("repeats", args.repeats)):
if value <= 0:
raise SystemExit(f"--{name} must be a positive integer, got {value}")
if args.warmup < 0:
raise SystemExit(f"--warmup must be non-negative, got {args.warmup}")

print(f"torch: {torch.__version__}")
print(f"tilelang: {tilelang.__version__}")
print(f"MPS: {torch.backends.mps.is_available()}")
if not torch.backends.mps.is_available():
raise SystemExit("Metal GEMM benchmark requires PyTorch MPS support")
print(f"M={M}, N={N}, K={K}, warmup={args.warmup}, repeats={args.repeats}")
print()

ref_tflops = bench_torch_mps(M, N, K, args.warmup, args.repeats)
print(f"PyTorch MPS (torch.mm fp16): {ref_tflops:.1f} TFLOPS")
print()

configs = BLOCK_CONFIGS if args.sweep else [(64, 64, 32)]

print(f"{'block (M,N,K)':>16s} | {'TileLang':>14s} | {'Ratio':>6s}")
print("-" * 44)

best_tflops = 0.0
best_config = configs[0]
for bM, bN, bK in configs:
try:
tl = bench_tilelang(M, N, K, bM, bN, bK, args.warmup, args.repeats)
ratio = tl / ref_tflops * 100
tag = ""
if tl > best_tflops:
best_tflops = tl
best_config = (bM, bN, bK)
print(f"{f'({bM},{bN},{bK})':>16s} | {tl:>10.1f} TFLOPS | {ratio:>5.0f}%")
except Exception as e:
print(f"{f'({bM},{bN},{bK})':>16s} | {'FAILED':>14s} | {e}")

if args.sweep:
print()
print(f"Best config: {best_config}")
print(f"Best TFlops: {best_tflops:.1f}")
print(f"Reference TFlops (PyTorch MPS): {ref_tflops:.1f}")
1 change: 1 addition & 0 deletions pyproject.toml
@@ -31,6 +31,7 @@ dependencies = [
# requirement as wide as possible to be compatible with other libraries
# pip will try to use latest version whenever possible.
"apache-tvm-ffi~=0.1.0,>=0.1.2",
"apache-tvm-ffi<0.1.8; platform_system == 'Darwin'",
# torch-c-dlpack-ext provides prebuilt torch extensions.
# Without it, TVM FFI may require JIT compilation on first import.
"torch-c-dlpack-ext; python_version < '3.14'",
1 change: 1 addition & 0 deletions requirements-dev.txt
@@ -1,6 +1,7 @@
# Requirements to run local build with `--no-build-isolation` or other developments

apache-tvm-ffi~=0.1.0,>=0.1.2
apache-tvm-ffi<0.1.8; platform_system == 'Darwin'
build
cmake>=3.26
cython>=3.1.0
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,6 +1,7 @@
# Runtime requirements

apache-tvm-ffi~=0.1.0,>=0.1.2
apache-tvm-ffi<0.1.8; platform_system == 'Darwin'
torch-c-dlpack-ext; python_version < '3.14'
cloudpickle
ml-dtypes
10 changes: 9 additions & 1 deletion src/backend/metal/CMakeLists.txt
@@ -1,4 +1,12 @@
# Metal backend: source files and build configuration.

# Metal codegen is pure C++ and can generate Metal shader source on any
# platform. Always compile it so target.build.tilelang_metal is available for
# cross-compilation and source-level tests on non-Apple hosts.
list(APPEND TILE_LANG_SRCS
src/target/codegen_metal.cc
)
Comment on lines +3 to +8 (Contributor)

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== CMake block =="
sed -n '1,40p' src/backend/metal/CMakeLists.txt

echo
echo "== codegen_metal.cc includes =="
sed -n '1,50p' src/target/codegen_metal.cc | grep '^#include'

echo
echo "== metal_module.h presence in this checkout =="
fd -a 'metal_module\.h$' . || true

Repository: tile-ai/tilelang



Unconditionally compiling src/target/codegen_metal.cc breaks non-Metal builds.

The CMakeLists.txt appends src/target/codegen_metal.cc to TILE_LANG_SRCS before the USE_METAL and APPLE guards run. However, this file includes runtime/metal/metal_module.h, which is not available on non-Metal or non-Apple platforms, so the build fails: the early return() statements skip only the later Metal-specific configuration, not sources that were already appended.

To fix: either move the source append behind the appropriate guards, or refactor codegen_metal.cc to isolate the Metal runtime dependencies from the codegen-only logic.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/backend/metal/CMakeLists.txt` around lines 3-8: the CMakeLists currently appends src/target/codegen_metal.cc to TILE_LANG_SRCS unconditionally, causing build failures on non-Apple/non-Metal hosts because runtime/metal/metal_module.h is missing. Fix by moving the list(APPEND TILE_LANG_SRCS src/target/codegen_metal.cc) call inside the existing guards (USE_METAL and APPLE), or refactor codegen_metal.cc to split the Metal runtime includes from the pure codegen logic (e.g., create a codegen_metal_core.cc without runtime/metal/metal_module.h and keep the runtime-dependent bits in a separate file), updating TILE_LANG_SRCS so only the runtime-dependent file is added under the USE_METAL/APPLE guards and non-Metal builds no longer compile runtime headers.

if(NOT USE_METAL)
return()
endif()
@@ -7,7 +15,7 @@ if(NOT APPLE)
# On non-Apple platforms USE_METAL=ON enables only codegen (Metal source
# generation) without requiring the Metal/Foundation frameworks.
message(STATUS "Metal backend on non-Apple: enabling codegen-only mode (no Metal runtime)")
set(USE_METAL OFF)
return()
endif()

file(GLOB TILE_LANG_METAL_SRCS
171 changes: 168 additions & 3 deletions src/op/copy.cc
@@ -517,6 +517,10 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T,
return result_map;
}

if (copy_inst == CopyInst::kMetalSIMDGroup) {
return {};
}

// for LDSM/STSM, the layout was deduced from register layout
// so we can directly apply the layout of normal copy
// Use parallel op to infer the layout
@@ -792,8 +796,47 @@ bool CopyNode::CheckCPAsyncCopy(Target target, const LayoutMap &layout_map,
if (!CheckCPAsyncCopyPreconditions()) {
return false;
}
// Skip the vectorized-size check here: during the InferLayout stage the
// layout is not yet stable, so the vectorized size cannot be determined.
return true;
}

bool CopyNode::CheckSIMDGroupCopy(Target target) const {
if (!TargetIsMetal(target) || !IsSIMDGroupBuffer(src)) {
return false;
}
if (!IsSharedBuffer(dst) && !IsGlobalBuffer(dst)) {
return false;
}
if (src->dtype != dst->dtype) {
return false;
}
if (src_range.size() != 2 || dst_range.size() != 2 ||
dst->shape.size() != 2) {
return false;
}

int total_elements = 1;
for (auto extent : src->shape) {
auto imm = extent.as<IntImmNode>();
if (!imm) {
return false;
}
total_elements *= imm->value;
}
if (total_elements % 64 != 0) {
return false;
}

for (int i = 0; i < 2; ++i) {
auto src_shape = src->shape[i].as<IntImmNode>();
auto src_min = src_range[i]->min.as<IntImmNode>();
auto src_extent = src_range[i]->extent.as<IntImmNode>();
auto dst_extent = dst_range[i]->extent.as<IntImmNode>();
if (!src_shape || !src_min || src_min->value != 0 || !src_extent ||
!dst_extent || src_extent->value != src_shape->value ||
src_extent->value != dst_extent->value || src_extent->value % 8 != 0) {
return false;
}
}
return true;
Comment on lines +802 to 840 (Contributor)

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Guard Metal SIMD-group copy against edge-tile OOB stores.

This legality check never proves that dst_range is fully in-bounds, but LowerSIMDGroupCopy later emits unpredicated simdgroup_store calls. For boundary tiles on non-divisible shapes, this can select kMetalSIMDGroup and write past dst where LowerNormalCopy would have kept the bounds predicate. Please thread buffer_oob/analyzer-based in-bounds checks through this path before returning true.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/op/copy.cc` around lines 802-840: CopyNode::CheckSIMDGroupCopy currently accepts Metal SIMD-group copies without proving dst_range is fully in-bounds, which lets LowerSIMDGroupCopy emit unpredicated simdgroup_store calls that can store out of bounds. Modify CheckSIMDGroupCopy to consult the buffer_oob/analyzer-based in-bounds check (the same predicate used elsewhere) for dst (and, if relevant, src), and return true only when the analyzer proves dst_range lies fully inside the dst buffer bounds. Thread the analyzer/buffer_oob check into this path (or call the existing helper used by other copy legality checks) so that CheckSIMDGroupCopy refuses SIMD-group selection for boundary tiles that might be OOB.
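A minimal sketch of such a guard, assuming the check is handed an arith::Analyzer (the current signature takes only Target, so threading one in is part of the fix):

// Hypothetical addition inside CheckSIMDGroupCopy: reject the simdgroup path
// unless every destination dimension is provably in-bounds.
for (size_t i = 0; i < dst_range.size(); ++i) {
  PrimExpr begin = dst_range[i]->min;
  PrimExpr end = begin + dst_range[i]->extent;
  if (!analyzer->CanProve(begin >= 0) ||
      !analyzer->CanProve(end <= dst->shape[i])) {
    return false; // boundary tile may overrun dst; use the predicated path
  }
}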

}

@@ -864,6 +907,8 @@ CopyInst CopyNode::GetCopyInst(Target target, const LayoutMap &layout_map,
return CopyInst::kTMemLoad;
} else if (CheckTMemStore(target)) {
return CopyInst::kTMemStore;
} else if (CheckSIMDGroupCopy(target)) {
return CopyInst::kMetalSIMDGroup;
} else {
return CopyInst::kNormal;
}
@@ -897,6 +942,8 @@ Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
auto cp_async_copy = LowerCPAsyncCopy(T, analyzer);
ICHECK(cp_async_copy.defined()) << "Failed to lower cp.async copy";
return cp_async_copy;
} else if (copy_inst == CopyInst::kMetalSIMDGroup) {
return LowerSIMDGroupCopy(T, analyzer);
} else if (copy_inst == CopyInst::kNormal) {
return LowerNormalCopy(T, analyzer);
} else {
@@ -982,7 +1029,125 @@ Stmt CopyNode::LowerCPAsyncCopy(const LowerArgs &T,
return cp_async_loop;
}

// Lowers a simdgroup-fragment source into per-warp 8x8 simdgroup_store calls.
Stmt CopyNode::LowerSIMDGroupCopy(const LowerArgs &T,
arith::Analyzer *analyzer) const {
ICHECK(IsSIMDGroupBuffer(src));
int total_elements = 1;
for (auto s : src->shape) {
auto imm = s.as<IntImmNode>();
ICHECK(imm) << "simdgroup buffer must have constant shape";
total_elements *= imm->value;
}
ICHECK(total_elements % 64 == 0)
<< "simdgroup buffer size must be multiple of 64 (8x8), got "
<< total_elements;

ICHECK(dst_range.size() == 2)
<< "Expected 2D destination for simdgroup store";
PrimExpr dst_row_base = dst_range[0]->min;
PrimExpr dst_col_base = dst_range[1]->min;
ICHECK_EQ(dst->shape.size(), 2U)
<< "simdgroup store currently supports 2D destination buffers";
Array<PrimExpr> dst_strides = dst->strides;
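// A buffer with no explicit strides is compact; synthesize row-major strides.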
if (dst_strides.empty()) {
PrimExpr stride = 1;
dst_strides.resize(dst->shape.size());
for (int i = static_cast<int>(dst->shape.size()) - 1; i >= 0; --i) {
dst_strides.Set(i, stride);
stride *= dst->shape[i];
}
}
if (dst_strides.size() != dst->shape.size()) {
return LowerNormalCopy(T, analyzer);
}
if (!analyzer->CanProveEqual(dst_strides[1], 1)) {
return LowerNormalCopy(T, analyzer);
}
PrimExpr dst_stride = dst_strides[0];

int warp_size = TargetGetWarpSize(T.target);
auto block_extent = T.thread_bounds->extent.as<IntImmNode>();
if (!block_extent || warp_size <= 0 || block_extent->value % warp_size != 0) {
return LowerNormalCopy(T, analyzer);
}
int block_size = block_extent->value;
int num_warps = block_size / warp_size;
if (num_warps <= 0) {
return LowerNormalCopy(T, analyzer);
}
PrimExpr relative_thread = T.thread_var - T.thread_bounds->min;
PrimExpr warp_id = FloorDiv(relative_thread, warp_size);

auto M_imm = src_range[0]->extent.as<IntImmNode>();
auto N_imm = src_range[1]->extent.as<IntImmNode>();
if (!M_imm || !N_imm) {
return LowerNormalCopy(T, analyzer);
}
int M = M_imm->value;
int N = N_imm->value;

int kMPerWarp = 8;
int kNPerWarp = 8;
int m_warp = 1, n_warp = num_warps;
int max_m = M / kMPerWarp;
int max_n = N / kNPerWarp;
if (max_m <= 0 || max_n <= 0) {
return LowerNormalCopy(T, analyzer);
}
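// Search the factorizations num_warps = m * n and keep the warp grid whose
// per-warp tile aspect ratio (row tiles / col tiles) best matches M/N.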
float ideal = N > 0 ? static_cast<float>(M) / N : 1.f;
float best_score = std::numeric_limits<float>::max();
for (int m = 1; m <= std::min(num_warps, max_m); ++m) {
if (num_warps % m != 0)
continue;
int n = num_warps / m;
if (n > max_n)
continue;
if (M % (m * kMPerWarp) != 0 || N % (n * kNPerWarp) != 0)
continue;
float m_per = static_cast<float>(M) / (m * kMPerWarp);
float n_per = static_cast<float>(N) / (n * kNPerWarp);
float score = std::abs(m_per / n_per - ideal);
if (score < best_score) {
best_score = score;
m_warp = m;
n_warp = n;
}
}

if (best_score == std::numeric_limits<float>::max() || M < m_warp * 8 ||
N < n_warp * 8) {
return LowerNormalCopy(T, analyzer);
}
int warp_row_tiles = M / m_warp / 8;
int warp_col_tiles = N / n_warp / 8;
if (warp_row_tiles <= 0 || warp_col_tiles <= 0 ||
warp_row_tiles * warp_col_tiles * 64 > total_elements) {
return LowerNormalCopy(T, analyzer);
}

PrimExpr warp_m = FloorMod(warp_id, m_warp);
PrimExpr warp_n = FloorDiv(warp_id, m_warp);

Array<Stmt> stmts;
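// Emit one unpredicated 8x8 simdgroup_store per tile owned by this warp.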
for (int i = 0; i < warp_row_tiles; i++) {
for (int j = 0; j < warp_col_tiles; j++) {
int tile_idx = i * warp_col_tiles + j;
PrimExpr row = dst_row_base + warp_m * (warp_row_tiles * 8) + i * 8;
PrimExpr col = dst_col_base + warp_n * (warp_col_tiles * 8) + j * 8;
PrimExpr ptr = Call(DataType::Handle(), builtin::address_of(),
{BufferLoad(dst, {row, col})});
stmts.push_back(Evaluate(
Call(DataType::Handle(), builtin::simdgroup_store(),
{src->data, IntImm(DataType::Int(32), tile_idx), ptr, dst_stride,
IntImm(DataType::Int(32), 8), IntImm(DataType::Int(32), 8),
Cast(DataType::Bool(), IntImm(DataType::Int(32), 0))})));
}
}
if (stmts.size() == 1)
return stmts[0];
return SeqStmt(stmts);
}

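// Lowers the copy using standard load/store with loop transformations.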
Stmt CopyNode::LowerNormalCopy(const LowerArgs &T,
arith::Analyzer *analyzer) const {
bool is_cpu_target = T.target->GetTargetDeviceType() == kDLCPU;