Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 126 additions & 0 deletions benchmark/matmul_metal/benchmark_matmul_metal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import argparse
import logging
import time

import torch

import tilelang
import tilelang.language as T

logging.getLogger("tilelang").setLevel(logging.WARNING)

# (block_M, block_N, block_K) tile-size candidates tried by --sweep;
# without --sweep only the default (64, 64, 32) is benchmarked.
BLOCK_CONFIGS = [
    (16, 16, 16),
    (32, 32, 16),
    (32, 32, 32),
    (64, 64, 32),
]


@tilelang.jit
def matmul_simdgroup(M, N, K, block_M=64, block_N=64, block_K=32, dtype=T.float16, accum_dtype=T.float32):
    """Build a tiled GEMM kernel C = A @ B with TileLang for the Metal backend.

    Args:
        M, N, K: Global matrix dimensions (A is MxK, B is KxN, C is MxN).
        block_M, block_N, block_K: Shared-memory tile sizes per threadgroup.
        dtype: Element type of the A/B operands.
        accum_dtype: Accumulator / output element type.

    Returns:
        A JIT-compiled kernel taking (A, B, C) tensors.
    """

    @T.prim_func
    def gemm_kernel(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((K, N), dtype),
        C: T.Tensor((M, N), accum_dtype),
    ):
        # One 128-thread threadgroup computes a block_M x block_N tile of C;
        # grid is (ceil(N/block_N), ceil(M/block_M)).
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared")
            B_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared")
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            # Walk K in block_K steps; num_stages=0 disables software
            # pipelining (presumably required on Metal -- TODO confirm).
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
                T.copy(A[by * block_M, ko * block_K], A_shared)
                T.copy(B[ko * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)
            # Write the accumulated tile back to global memory.
            T.copy(C_local, C[by * block_M, bx * block_N])

    return gemm_kernel


def _tflops(M, N, K, seconds):
return 2.0 * M * N * K / seconds / 1e12


def _bench(fn, warmup, repeats):
    """Return the average wall-clock seconds per call of `fn` over `repeats`
    timed iterations, after `warmup` untimed iterations.

    torch.mps.synchronize() is called before starting and stopping the clock
    so queued GPU work is included in the measurement.
    """
    for _ in range(warmup):
        fn()
    torch.mps.synchronize()
    start = time.perf_counter()
    for _ in range(repeats):
        fn()
    torch.mps.synchronize()
    elapsed = time.perf_counter() - start
    return elapsed / repeats


def bench_torch_mps(M, N, K, warmup, repeats):
    """Benchmark fp16 torch.mm on the MPS device; returns achieved TFLOPS."""
    lhs = torch.randn(M, K, dtype=torch.float16, device="mps")
    rhs = torch.randn(K, N, dtype=torch.float16, device="mps")
    avg_seconds = _bench(lambda: torch.mm(lhs, rhs), warmup, repeats)
    return _tflops(M, N, K, avg_seconds)


def bench_tilelang(M, N, K, block_M, block_N, block_K, warmup, repeats):
    """Benchmark the TileLang simdgroup GEMM for one tile config; returns TFLOPS.

    Inputs are fp16 on MPS; the output buffer is fp32 to match the kernel's
    accumulator dtype.
    """
    kernel = matmul_simdgroup(M, N, K, block_M, block_N, block_K)
    a = torch.randn(M, K, dtype=torch.float16, device="mps")
    b = torch.randn(K, N, dtype=torch.float16, device="mps")
    out = torch.zeros(M, N, dtype=torch.float32, device="mps")
    avg_seconds = _bench(lambda: kernel(a, b, out), warmup, repeats)
    return _tflops(M, N, K, avg_seconds)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Metal GEMM Benchmark (simdgroup)")
    parser.add_argument("--m", type=int, default=4096)
    parser.add_argument("--n", type=int, default=4096)
    parser.add_argument("--k", type=int, default=4096)
    parser.add_argument("--warmup", type=int, default=10)
    parser.add_argument("--repeats", type=int, default=100)
    parser.add_argument("--sweep", action="store_true", help="Sweep all block configs instead of using default (64,64,32)")
    args = parser.parse_args()

    M, N, K = args.m, args.n, args.k
    # Validate CLI arguments up front so bad input fails fast and clearly.
    for name, value in (("m", M), ("n", N), ("k", K), ("repeats", args.repeats)):
        if value <= 0:
            raise SystemExit(f"--{name} must be a positive integer, got {value}")
    if args.warmup < 0:
        raise SystemExit(f"--warmup must be non-negative, got {args.warmup}")

    print(f"torch: {torch.__version__}")
    print(f"tilelang: {tilelang.__version__}")
    print(f"MPS: {torch.backends.mps.is_available()}")
    if not torch.backends.mps.is_available():
        raise SystemExit("Metal GEMM benchmark requires PyTorch MPS support")
    print(f"M={M}, N={N}, K={K}, warmup={args.warmup}, repeats={args.repeats}")
    print()

    # Reference throughput from PyTorch's own MPS matmul.
    ref_tflops = bench_torch_mps(M, N, K, args.warmup, args.repeats)
    print(f"PyTorch MPS (torch.mm fp16): {ref_tflops:.1f} TFLOPS")
    print()

    configs = BLOCK_CONFIGS if args.sweep else [(64, 64, 32)]

    print(f"{'block (M,N,K)':>16s} | {'TileLang':>14s} | {'Ratio':>6s}")
    print("-" * 44)

    best_tflops = 0.0
    # None until at least one config completes; prevents reporting a
    # "best" config when every benchmark attempt failed.
    best_config = None
    for bM, bN, bK in configs:
        try:
            tl = bench_tilelang(M, N, K, bM, bN, bK, args.warmup, args.repeats)
            ratio = tl / ref_tflops * 100
            if tl > best_tflops:
                best_tflops = tl
                best_config = (bM, bN, bK)
            print(f"{f'({bM},{bN},{bK})':>16s} | {tl:>10.1f} TFLOPS | {ratio:>5.0f}%")
        except Exception as e:
            # A config may fail to compile or launch; report it and keep sweeping.
            print(f"{f'({bM},{bN},{bK})':>16s} | {'FAILED':>14s} | {e}")

    if args.sweep and best_config is not None:
        print()
        print(f"Best config: {best_config}")
        print(f"Best TFlops: {best_tflops:.1f}")
    elif args.sweep:
        print()
        print("No TileLang configuration completed successfully.")
Comment on lines +108 to +125
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Don't print a winner when every sweep config failed.

Lines 108-125 still report configs[0] as the best config if every benchmark attempt throws. That makes the summary misleading in exactly the case where the per-config error handling is supposed to help.

Suggested fix
-    best_tflops = 0.0
-    best_config = configs[0]
+    best_tflops = 0.0
+    best_config = None
     for bM, bN, bK in configs:
         try:
             tl = bench_tilelang(M, N, K, bM, bN, bK, args.warmup, args.repeats)
             ratio = tl / ref_tflops * 100
-            tag = ""
             if tl > best_tflops:
                 best_tflops = tl
                 best_config = (bM, bN, bK)
             print(f"{f'({bM},{bN},{bK})':>16s} | {tl:>10.1f} TFLOPS | {ratio:>5.0f}%")
         except Exception as e:
             print(f"{f'({bM},{bN},{bK})':>16s} | {'FAILED':>14s} | {e}")

-    if args.sweep:
+    if args.sweep and best_config is not None:
         print()
         print(f"Best config: {best_config}")
         print(f"Best TFlops: {best_tflops:.1f}")
         print(f"Reference TFlops (PyTorch MPS): {ref_tflops:.1f}")
+    elif args.sweep:
+        print()
+        print("No TileLang configuration completed successfully.")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
best_tflops = 0.0
best_config = configs[0]
for bM, bN, bK in configs:
try:
tl = bench_tilelang(M, N, K, bM, bN, bK, args.warmup, args.repeats)
ratio = tl / ref_tflops * 100
tag = ""
if tl > best_tflops:
best_tflops = tl
best_config = (bM, bN, bK)
print(f"{f'({bM},{bN},{bK})':>16s} | {tl:>10.1f} TFLOPS | {ratio:>5.0f}%")
except Exception as e:
print(f"{f'({bM},{bN},{bK})':>16s} | {'FAILED':>14s} | {e}")
if args.sweep:
print()
print(f"Best config: {best_config}")
print(f"Best TFlops: {best_tflops:.1f}")
best_tflops = 0.0
best_config = None
for bM, bN, bK in configs:
try:
tl = bench_tilelang(M, N, K, bM, bN, bK, args.warmup, args.repeats)
ratio = tl / ref_tflops * 100
if tl > best_tflops:
best_tflops = tl
best_config = (bM, bN, bK)
print(f"{f'({bM},{bN},{bK})':>16s} | {tl:>10.1f} TFLOPS | {ratio:>5.0f}%")
except Exception as e:
print(f"{f'({bM},{bN},{bK})':>16s} | {'FAILED':>14s} | {e}")
if args.sweep and best_config is not None:
print()
print(f"Best config: {best_config}")
print(f"Best TFlops: {best_tflops:.1f}")
print(f"Reference TFlops (PyTorch MPS): {ref_tflops:.1f}")
elif args.sweep:
print()
print("No TileLang configuration completed successfully.")
🧰 Tools
🪛 Ruff (0.15.12)

[warning] 119-119: Do not catch blind exception: Exception

(BLE001)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmark/matmul_metal/benchmark_matmul_metal.py` around lines 108 - 125, The
loop currently leaves best_config as configs[0] even if every bench_tilelang
call fails; change best_config to None (or similar sentinel) at initialization
and only assign it inside the try block when a run succeeds (e.g., when tl >
best_tflops or when first success), and when args.sweep is true print the
summary only if best_config is not None (otherwise skip or print "no successful
configs"); update references to best_tflops, best_config, configs and
bench_tilelang accordingly.

print(f"Reference TFlops (PyTorch MPS): {ref_tflops:.1f}")
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ dependencies = [
# requirement as wide as possible to be compatible with other libraries
# pip will try to use latest version whenever possible.
"apache-tvm-ffi~=0.1.0,>=0.1.2",
"apache-tvm-ffi<0.1.8; platform_system == 'Darwin'",
Comment on lines 33 to +34
# torch-c-dlpack-ext provides prebuilt torch extensions.
# Without it, TVM FFI may require JIT compilation on first import.
"torch-c-dlpack-ext; python_version < '3.14'",
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Requirements to run local build with `--no-build-isolation` or other developments

apache-tvm-ffi~=0.1.0,>=0.1.2
apache-tvm-ffi<0.1.8; platform_system == 'Darwin'
build
cmake>=3.26
cython>=3.1.0
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Runtime requirements

apache-tvm-ffi~=0.1.0,>=0.1.2
apache-tvm-ffi<0.1.8; platform_system == 'Darwin'
torch-c-dlpack-ext; python_version < '3.14'
cloudpickle
ml-dtypes
Expand Down
10 changes: 9 additions & 1 deletion src/backend/metal/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
# Metal backend: source files and build configuration.

# Metal codegen is pure C++ and can generate Metal shader source on any
# platform. Always compile it so target.build.tilelang_metal is available for
# cross-compilation and source-level tests on non-Apple hosts.
list(APPEND TILE_LANG_SRCS
src/target/codegen_metal.cc
)

if(NOT USE_METAL)
return()
endif()
Expand All @@ -7,7 +15,7 @@ if(NOT APPLE)
# On non-Apple platforms USE_METAL=ON enables only codegen (Metal source
# generation) without requiring the Metal/Foundation frameworks.
message(STATUS "Metal backend on non-Apple: enabling codegen-only mode (no Metal runtime)")
set(USE_METAL OFF)
return()
endif()

file(GLOB TILE_LANG_METAL_SRCS
Expand Down
171 changes: 168 additions & 3 deletions src/op/copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,10 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T,
return result_map;
}

if (copy_inst == CopyInst::kMetalSIMDGroup) {
return {};
}

// for LDSM/STSM, the layout was deduced from register layout
// so we can directly apply the layout of normal copy
// Use parallel op to infer the layout
Expand Down Expand Up @@ -792,8 +796,47 @@ bool CopyNode::CheckCPAsyncCopy(Target target, const LayoutMap &layout_map,
if (!CheckCPAsyncCopyPreconditions()) {
return false;
}
// Skip vectorize size check here because, during the Infer Layout stage,
// the layout is not stable and the vectorized size cannot be determined.
return true;
}

// Returns true when this copy can be lowered to Metal simdgroup_store:
// the source is a simdgroup-matrix buffer on a Metal target, the destination
// is shared or global memory of the same dtype, and both sides are full 2-D
// copies whose constant extents are multiples of 8 (one 8x8 simdgroup tile)
// with a total element count divisible by 64.
bool CopyNode::CheckSIMDGroupCopy(Target target) const {
  // Only Metal targets with a simdgroup-matrix source qualify.
  if (!TargetIsMetal(target) || !IsSIMDGroupBuffer(src)) {
    return false;
  }
  // Destination must live in shared or global memory.
  if (!IsSharedBuffer(dst) && !IsGlobalBuffer(dst)) {
    return false;
  }
  // simdgroup_store cannot convert element types.
  if (src->dtype != dst->dtype) {
    return false;
  }
  // Both access regions and the destination buffer must be 2-D.
  if (src_range.size() != 2 || dst_range.size() != 2 ||
      dst->shape.size() != 2) {
    return false;
  }

  // Source shape must be fully constant and a whole number of 8x8
  // (= 64-element) simdgroup tiles.
  int total_elements = 1;
  for (auto extent : src->shape) {
    auto imm = extent.as<IntImmNode>();
    if (!imm) {
      return false;
    }
    total_elements *= imm->value;
  }
  if (total_elements % 64 != 0) {
    return false;
  }

  // Per dimension: the copy must cover the entire source buffer (min == 0,
  // extent == shape), match the destination extent exactly, and the extent
  // must be a multiple of the 8-wide simdgroup tile edge.
  for (int i = 0; i < 2; ++i) {
    auto src_shape = src->shape[i].as<IntImmNode>();
    auto src_min = src_range[i]->min.as<IntImmNode>();
    auto src_extent = src_range[i]->extent.as<IntImmNode>();
    auto dst_extent = dst_range[i]->extent.as<IntImmNode>();
    if (!src_shape || !src_min || src_min->value != 0 || !src_extent ||
        !dst_extent || src_extent->value != src_shape->value ||
        src_extent->value != dst_extent->value || src_extent->value % 8 != 0) {
      return false;
    }
  }
  return true;
}

Expand Down Expand Up @@ -864,6 +907,8 @@ CopyInst CopyNode::GetCopyInst(Target target, const LayoutMap &layout_map,
return CopyInst::kTMemLoad;
} else if (CheckTMemStore(target)) {
return CopyInst::kTMemStore;
} else if (CheckSIMDGroupCopy(target)) {
return CopyInst::kMetalSIMDGroup;
} else {
return CopyInst::kNormal;
}
Expand Down Expand Up @@ -897,6 +942,8 @@ Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
auto cp_async_copy = LowerCPAsyncCopy(T, analyzer);
ICHECK(cp_async_copy.defined()) << "Failed to lower cp.async copy";
return cp_async_copy;
} else if (copy_inst == CopyInst::kMetalSIMDGroup) {
return LowerSIMDGroupCopy(T, analyzer);
} else if (copy_inst == CopyInst::kNormal) {
return LowerNormalCopy(T, analyzer);
} else {
Expand Down Expand Up @@ -982,7 +1029,125 @@ Stmt CopyNode::LowerCPAsyncCopy(const LowerArgs &T,
return cp_async_loop;
}

// Lowers the copy using standard load/store with loop transformations.
// Lowers a copy out of a Metal simdgroup-matrix buffer into shared/global
// memory by emitting one builtin::simdgroup_store per 8x8 tile.
//
// The threadgroup's warps are factored into an (m_warp x n_warp) grid chosen
// so the per-warp tile shape best matches the M:N aspect ratio of the copied
// region; each warp then stores its warp_row_tiles x warp_col_tiles sub-grid
// of 8x8 tiles. Any precondition that cannot be proven here (constant
// extents, unit inner stride, divisible warp tiling, ...) falls back to
// LowerNormalCopy.
Stmt CopyNode::LowerSIMDGroupCopy(const LowerArgs &T,
                                  arith::Analyzer *analyzer) const {
  ICHECK(IsSIMDGroupBuffer(src));
  // The simdgroup source must have a constant shape that is a whole number
  // of 8x8 (= 64-element) tiles. Use int64_t so large constant shapes
  // cannot overflow the product.
  int64_t total_elements = 1;
  for (auto s : src->shape) {
    auto imm = s.as<IntImmNode>();
    ICHECK(imm) << "simdgroup buffer must have constant shape";
    total_elements *= imm->value;
  }
  ICHECK(total_elements % 64 == 0)
      << "simdgroup buffer size must be multiple of 64 (8x8), got "
      << total_elements;

  ICHECK(dst_range.size() == 2)
      << "Expected 2D destination for simdgroup store";
  PrimExpr dst_row_base = dst_range[0]->min;
  PrimExpr dst_col_base = dst_range[1]->min;
  ICHECK_EQ(dst->shape.size(), 2U)
      << "simdgroup store currently supports 2D destination buffers";
  // Derive row-major strides when the buffer carries none explicitly.
  Array<PrimExpr> dst_strides = dst->strides;
  if (dst_strides.empty()) {
    PrimExpr stride = 1;
    dst_strides.resize(dst->shape.size());
    for (int i = static_cast<int>(dst->shape.size()) - 1; i >= 0; --i) {
      dst_strides.Set(i, stride);
      stride *= dst->shape[i];
    }
  }
  if (dst_strides.size() != dst->shape.size()) {
    return LowerNormalCopy(T, analyzer);
  }
  // simdgroup_store requires unit stride along the innermost dimension.
  if (!analyzer->CanProveEqual(dst_strides[1], 1)) {
    return LowerNormalCopy(T, analyzer);
  }
  PrimExpr dst_stride = dst_strides[0];

  // The thread extent must be a constant whole number of warps.
  int warp_size = TargetGetWarpSize(T.target);
  auto block_extent = T.thread_bounds->extent.as<IntImmNode>();
  if (!block_extent || warp_size <= 0 || block_extent->value % warp_size != 0) {
    return LowerNormalCopy(T, analyzer);
  }
  int block_size = block_extent->value;
  int num_warps = block_size / warp_size;
  if (num_warps <= 0) {
    return LowerNormalCopy(T, analyzer);
  }
  PrimExpr relative_thread = T.thread_var - T.thread_bounds->min;
  PrimExpr warp_id = FloorDiv(relative_thread, warp_size);

  auto M_imm = src_range[0]->extent.as<IntImmNode>();
  auto N_imm = src_range[1]->extent.as<IntImmNode>();
  if (!M_imm || !N_imm) {
    return LowerNormalCopy(T, analyzer);
  }
  int M = M_imm->value;
  int N = N_imm->value;

  // Search all factorizations num_warps == m * n for the one whose per-warp
  // tile aspect ratio is closest to the region's M:N ratio.
  constexpr int kMPerWarp = 8;
  constexpr int kNPerWarp = 8;
  int m_warp = 1, n_warp = num_warps;
  int max_m = M / kMPerWarp;
  int max_n = N / kNPerWarp;
  if (max_m <= 0 || max_n <= 0) {
    return LowerNormalCopy(T, analyzer);
  }
  float ideal = N > 0 ? static_cast<float>(M) / N : 1.f;
  float best_score = std::numeric_limits<float>::max();
  for (int m = 1; m <= std::min(num_warps, max_m); ++m) {
    if (num_warps % m != 0)
      continue;
    int n = num_warps / m;
    if (n > max_n)
      continue;
    // Both dimensions must split evenly across the warp grid.
    if (M % (m * kMPerWarp) != 0 || N % (n * kNPerWarp) != 0)
      continue;
    float m_per = static_cast<float>(M) / (m * kMPerWarp);
    float n_per = static_cast<float>(N) / (n * kNPerWarp);
    float score = std::abs(m_per / n_per - ideal);
    if (score < best_score) {
      best_score = score;
      m_warp = m;
      n_warp = n;
    }
  }

  // No valid factorization found, or the region is too small for the grid.
  if (best_score == std::numeric_limits<float>::max() || M < m_warp * 8 ||
      N < n_warp * 8) {
    return LowerNormalCopy(T, analyzer);
  }
  int warp_row_tiles = M / m_warp / 8;
  int warp_col_tiles = N / n_warp / 8;
  if (warp_row_tiles <= 0 || warp_col_tiles <= 0 ||
      warp_row_tiles * warp_col_tiles * 64 > total_elements) {
    return LowerNormalCopy(T, analyzer);
  }

  // Decompose warp_id as warp_id = warp_n * m_warp + warp_m.
  PrimExpr warp_m = FloorMod(warp_id, m_warp);
  PrimExpr warp_n = FloorDiv(warp_id, m_warp);

  // Emit one simdgroup_store intrinsic per 8x8 tile owned by this warp.
  Array<Stmt> stmts;
  for (int i = 0; i < warp_row_tiles; i++) {
    for (int j = 0; j < warp_col_tiles; j++) {
      int tile_idx = i * warp_col_tiles + j;
      PrimExpr row = dst_row_base + warp_m * (warp_row_tiles * 8) + i * 8;
      PrimExpr col = dst_col_base + warp_n * (warp_col_tiles * 8) + j * 8;
      PrimExpr ptr = Call(DataType::Handle(), builtin::address_of(),
                          {BufferLoad(dst, {row, col})});
      stmts.push_back(Evaluate(
          Call(DataType::Handle(), builtin::simdgroup_store(),
               {src->data, IntImm(DataType::Int(32), tile_idx), ptr, dst_stride,
                IntImm(DataType::Int(32), 8), IntImm(DataType::Int(32), 8),
                Cast(DataType::Bool(), IntImm(DataType::Int(32), 0))})));
    }
  }
  if (stmts.size() == 1)
    return stmts[0];
  return SeqStmt(stmts);
}

Stmt CopyNode::LowerNormalCopy(const LowerArgs &T,
arith::Analyzer *analyzer) const {
bool is_cpu_target = T.target->GetTargetDeviceType() == kDLCPU;
Expand Down
Loading
Loading