Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 121 additions & 0 deletions benchmark/matmul_metal/benchmark_matmul_metal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import argparse
import logging
import time

import torch

import tilelang
import tilelang.language as T

logging.getLogger("tilelang").setLevel(logging.WARNING)

# (block_M, block_N, block_K) tile shapes exercised by --sweep; the last
# entry (64, 64, 32) matches the default single-run configuration below.
BLOCK_CONFIGS = [
    (16, 16, 16),
    (32, 32, 16),
    (32, 32, 32),
    (64, 64, 32),
]


@tilelang.jit
def matmul_simdgroup(M, N, K, block_M=64, block_N=64, block_K=32, dtype=T.float16, accum_dtype=T.float32):
    """Build a JIT-compiled tiled GEMM kernel computing C = A @ B.

    A is (M, K) and B is (K, N) in ``dtype`` (fp16 by default); C is (M, N)
    accumulated in ``accum_dtype`` (fp32 by default). Each threadgroup of 128
    threads produces one (block_M, block_N) output tile, staging A/B tiles
    through shared memory and accumulating in a register fragment.
    """

    @T.prim_func
    def gemm_kernel(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((K, N), dtype),
        C: T.Tensor((M, N), accum_dtype),
    ):
        # Grid: bx over N tiles, by over M tiles; 128 threads per block.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared")
            B_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared")
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            # num_stages=0: presumably disables software pipelining of the
            # shared-memory copies — TODO confirm against tilelang docs.
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
                T.copy(A[by * block_M, ko * block_K], A_shared)
                T.copy(B[ko * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)
            # Write the accumulated tile back to global memory.
            T.copy(C_local, C[by * block_M, bx * block_N])

    return gemm_kernel


def _tflops(M, N, K, seconds):
return 2.0 * M * N * K / seconds / 1e12


def _bench(fn, warmup, repeats):
    """Return average wall-clock seconds per call of ``fn`` on MPS.

    Runs ``warmup`` untimed calls first, then times ``repeats`` calls with a
    device synchronization before starting and after finishing the timed run.
    """
    for _ in range(warmup):
        fn()
    torch.mps.synchronize()
    start = time.perf_counter()
    for _ in range(repeats):
        fn()
    torch.mps.synchronize()
    elapsed = time.perf_counter() - start
    return elapsed / repeats


def bench_torch_mps(M, N, K, warmup, repeats):
    """Measure fp16 ``torch.mm`` throughput on the MPS device, in TFLOPS."""
    lhs = torch.randn(M, K, dtype=torch.float16, device="mps")
    rhs = torch.randn(K, N, dtype=torch.float16, device="mps")
    seconds = _bench(lambda: torch.mm(lhs, rhs), warmup, repeats)
    return _tflops(M, N, K, seconds)


def bench_tilelang(M, N, K, block_M, block_N, block_K, warmup, repeats):
    """Measure the TileLang simdgroup GEMM's throughput, in TFLOPS."""
    kernel = matmul_simdgroup(M, N, K, block_M, block_N, block_K)
    lhs = torch.randn(M, K, dtype=torch.float16, device="mps")
    rhs = torch.randn(K, N, dtype=torch.float16, device="mps")
    out = torch.zeros(M, N, dtype=torch.float32, device="mps")
    seconds = _bench(lambda: kernel(lhs, rhs, out), warmup, repeats)
    return _tflops(M, N, K, seconds)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Metal GEMM Benchmark (simdgroup)")
    parser.add_argument("--m", type=int, default=4096)
    parser.add_argument("--n", type=int, default=4096)
    parser.add_argument("--k", type=int, default=4096)
    parser.add_argument("--warmup", type=int, default=10)
    parser.add_argument("--repeats", type=int, default=100)
    parser.add_argument("--sweep", action="store_true", help="Sweep all block configs instead of using default (64,64,32)")
    args = parser.parse_args()

    M, N, K = args.m, args.n, args.k

    print(f"torch: {torch.__version__}")
    print(f"tilelang: {tilelang.__version__}")
    print(f"MPS: {torch.backends.mps.is_available()}")
    # Fail fast: every benchmark below allocates tensors on the MPS device.
    if not torch.backends.mps.is_available():
        raise SystemExit("Metal GEMM benchmark requires PyTorch MPS support")
    print(f"M={M}, N={N}, K={K}, warmup={args.warmup}, repeats={args.repeats}")
    print()

    # Reference throughput from PyTorch's own fp16 GEMM on MPS.
    ref_tflops = bench_torch_mps(M, N, K, args.warmup, args.repeats)
    print(f"PyTorch MPS (torch.mm fp16): {ref_tflops:.1f} TFLOPS")
    print()

    configs = BLOCK_CONFIGS if args.sweep else [(64, 64, 32)]

    print(f"{'block (M,N,K)':>16s} | {'TileLang':>14s} | {'Ratio':>6s}")
    print("-" * 44)

    best_tflops = 0.0
    best_config = configs[0]
    for bM, bN, bK in configs:
        try:
            tl = bench_tilelang(M, N, K, bM, bN, bK, args.warmup, args.repeats)
            ratio = tl / ref_tflops * 100
            if tl > best_tflops:
                best_tflops = tl
                best_config = (bM, bN, bK)
            print(f"{f'({bM},{bN},{bK})':>16s} | {tl:>10.1f} TFLOPS | {ratio:>5.0f}%")
        except Exception as e:
            # A config may fail to compile or launch; report it and keep sweeping.
            print(f"{f'({bM},{bN},{bK})':>16s} | {'FAILED':>14s} | {e}")

    if args.sweep:
        print()
        print(f"Best config: {best_config}")
        print(f"Best TFlops: {best_tflops:.1f}")
        print(f"Reference TFlops (PyTorch MPS): {ref_tflops:.1f}")
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ dependencies = [
# requirement as wide as possible to be compatible with other libraries
# pip will try to use latest version whenever possible.
"apache-tvm-ffi~=0.1.0,>=0.1.2",
"apache-tvm-ffi<0.1.8; platform_system == 'Darwin'",
# torch-c-dlpack-ext provides prebuilt torch extensions.
# Without it, TVM FFI may require JIT compilation on first import.
"torch-c-dlpack-ext; python_version < '3.14'",
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Requirements to run local build with `--no-build-isolation` or other developments

apache-tvm-ffi~=0.1.0,>=0.1.2
apache-tvm-ffi<0.1.8; platform_system == 'Darwin'
build
cmake>=3.26
cython>=3.1.0
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Runtime requirements

apache-tvm-ffi~=0.1.0,>=0.1.2
apache-tvm-ffi<0.1.8; platform_system == 'Darwin'
torch-c-dlpack-ext; python_version < '3.14'
cloudpickle
ml-dtypes
Expand Down
8 changes: 8 additions & 0 deletions src/backend/metal/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
# Metal backend: source files and build configuration.

# Metal codegen is pure C++ and can generate Metal shader source on any
# platform. Always compile it so target.build.tilelang_metal is available for
# cross-compilation and source-level tests on non-Apple hosts.
list(APPEND TILE_LANG_SRCS
src/target/codegen_metal.cc
)

if(NOT USE_METAL)
return()
endif()
Expand Down
147 changes: 144 additions & 3 deletions src/op/copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,10 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T,
return result_map;
}

if (copy_inst == CopyInst::kMetalSIMDGroup) {
return {};
}

// for LDSM/STSM, the layout was deduced from register layout
// so we can directly apply the layout of normal copy
// Use parallel op to infer the layout
Expand Down Expand Up @@ -792,8 +796,44 @@ bool CopyNode::CheckCPAsyncCopy(Target target, const LayoutMap &layout_map,
if (!CheckCPAsyncCopyPreconditions()) {
return false;
}
// Skip vectorize size check here because, during the Infer Layout stage,
// the layout is not stable and the vectorized size cannot be determined.
return true;
}

// Returns true when this copy can be lowered to Metal simdgroup_store:
// Metal target, simdgroup-fragment source, shared/global 2-D destination,
// matching dtypes, fully static source shape, and per-axis extents that
// match on both sides and are multiples of the 8-wide simdgroup tile.
bool CopyNode::CheckSIMDGroupCopy(Target target) const {
  if (!TargetIsMetal(target) || !IsSIMDGroupBuffer(src)) {
    return false;
  }
  if (!IsSharedBuffer(dst) && !IsGlobalBuffer(dst)) {
    return false;
  }
  if (src->dtype != dst->dtype) {
    return false;
  }
  if (src_range.size() != 2 || dst_range.size() != 2 ||
      dst->shape.size() != 2) {
    return false;
  }

  // Accumulate in 64 bits: IntImmNode::value is int64_t, and the previous
  // `int` accumulator could silently narrow/overflow for large static shapes.
  int64_t total_elements = 1;
  for (auto extent : src->shape) {
    auto imm = extent.as<IntImmNode>();
    if (!imm) {
      return false;
    }
    total_elements *= imm->value;
  }
  // The fragment must decompose into whole 8x8 (64-element) tiles.
  if (total_elements % 64 != 0) {
    return false;
  }

  for (int i = 0; i < 2; ++i) {
    auto src_extent = src_range[i]->extent.as<IntImmNode>();
    auto dst_extent = dst_range[i]->extent.as<IntImmNode>();
    if (!src_extent || !dst_extent || src_extent->value != dst_extent->value ||
        src_extent->value % 8 != 0) {
      return false;
    }
  }
  return true;
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

Expand Down Expand Up @@ -864,6 +904,8 @@ CopyInst CopyNode::GetCopyInst(Target target, const LayoutMap &layout_map,
return CopyInst::kTMemLoad;
} else if (CheckTMemStore(target)) {
return CopyInst::kTMemStore;
} else if (CheckSIMDGroupCopy(target)) {
return CopyInst::kMetalSIMDGroup;
} else {
return CopyInst::kNormal;
}
Expand Down Expand Up @@ -897,6 +939,8 @@ Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
auto cp_async_copy = LowerCPAsyncCopy(T, analyzer);
ICHECK(cp_async_copy.defined()) << "Failed to lower cp.async copy";
return cp_async_copy;
} else if (copy_inst == CopyInst::kMetalSIMDGroup) {
return LowerSIMDGroupCopy(T, analyzer);
} else if (copy_inst == CopyInst::kNormal) {
return LowerNormalCopy(T, analyzer);
} else {
Expand Down Expand Up @@ -982,7 +1026,104 @@ Stmt CopyNode::LowerCPAsyncCopy(const LowerArgs &T,
return cp_async_loop;
}

// Lowers a copy out of a Metal simdgroup-matrix buffer into shared/global
// memory by emitting one builtin::simdgroup_store() per 8x8 tile, with the
// tiles partitioned across the threadgroup's warps (simdgroups).
Stmt CopyNode::LowerSIMDGroupCopy(const LowerArgs &T,
                                  arith::Analyzer *analyzer) const {
  ICHECK(IsSIMDGroupBuffer(src));
  // The simdgroup source must be statically shaped, and its element count
  // must be a whole number of 8x8 (64-element) tiles.
  int total_elements = 1;
  for (auto s : src->shape) {
    auto imm = s.as<IntImmNode>();
    ICHECK(imm) << "simdgroup buffer must have constant shape";
    total_elements *= imm->value;
  }
  ICHECK(total_elements % 64 == 0)
      << "simdgroup buffer size must be multiple of 64 (8x8), got "
      << total_elements;

  ICHECK(dst_range.size() == 2)
      << "Expected 2D destination for simdgroup store";
  PrimExpr dst_row_base = dst_range[0]->min;
  PrimExpr dst_col_base = dst_range[1]->min;
  ICHECK_EQ(dst->shape.size(), 2U)
      << "simdgroup store currently supports 2D destination buffers";
  // Derive row-major strides when the buffer carries no explicit strides.
  Array<PrimExpr> dst_strides = dst->strides;
  if (dst_strides.empty()) {
    PrimExpr stride = 1;
    dst_strides.resize(dst->shape.size());
    for (int i = static_cast<int>(dst->shape.size()) - 1; i >= 0; --i) {
      dst_strides.Set(i, stride);
      stride *= dst->shape[i];
    }
  }
  ICHECK_EQ(dst_strides.size(), dst->shape.size())
      << "simdgroup store requires complete destination strides";
  // simdgroup_store writes contiguous 8-wide rows, so the innermost
  // destination stride must be provably 1.
  ICHECK(analyzer->CanProveEqual(dst_strides[1], 1))
      << "simdgroup store requires contiguous destination columns, got stride "
      << dst_strides[1];
  PrimExpr dst_stride = dst_strides[0];

  // NOTE(review): assumes thread_bounds->extent is a compile-time constant;
  // the as<IntImmNode>() result is dereferenced unchecked and would crash
  // on a dynamic thread extent — confirm upstream guarantees this.
  int warp_size = TargetGetWarpSize(T.target);
  int block_size = T.thread_bounds->extent.as<IntImmNode>()->value;
  int num_warps = block_size / warp_size;
  PrimExpr warp_id = FloorDiv(T.thread_var, warp_size);

  int M = src_range[0]->extent.as<IntImmNode>()->value;
  int N = src_range[1]->extent.as<IntImmNode>()->value;

  // Pick an m_warp x n_warp factorization of num_warps whose per-warp tile
  // aspect ratio best matches the copy's M:N aspect ratio.
  int kMPerWarp = 8;
  int kNPerWarp = 8;
  int m_warp = 1, n_warp = num_warps;
  int max_m = M / kMPerWarp;
  int max_n = N / kNPerWarp;
  float ideal = N > 0 ? static_cast<float>(M) / N : 1.f;
  float best_score = std::numeric_limits<float>::max();
  for (int m = 1; m <= std::min(num_warps, max_m); ++m) {
    if (num_warps % m != 0)
      continue;
    int n = num_warps / m;
    if (n > max_n)
      continue;
    float m_per = static_cast<float>(M) / (m * kMPerWarp);
    float n_per = static_cast<float>(N) / (n * kNPerWarp);
    float score = std::abs(m_per / n_per - ideal);
    if (score < best_score) {
      best_score = score;
      m_warp = m;
      n_warp = n;
    }
  }

  ICHECK(M >= m_warp * 8 && N >= n_warp * 8)
      << "Cannot partition " << M << "x" << N << " matrix across " << m_warp
      << "x" << n_warp << " warps with 8x8 simdgroup tiles";
  // NOTE(review): integer division truncates — if M (resp. N) is not a
  // multiple of m_warp*8 (resp. n_warp*8), the trailing remainder is not
  // stored. CheckSIMDGroupCopy only enforces multiples of 8, not of
  // m_warp*8; confirm the remainder case cannot reach this point.
  int warp_row_tiles = M / m_warp / 8;
  int warp_col_tiles = N / n_warp / 8;
  ICHECK(warp_row_tiles > 0 && warp_col_tiles > 0);
  ICHECK(warp_row_tiles * warp_col_tiles * 64 <= total_elements)
      << "Warp partition produces more tiles than buffer capacity";

  // Warps are laid out column-major over the (m_warp, n_warp) grid.
  PrimExpr warp_m = FloorMod(warp_id, m_warp);
  PrimExpr warp_n = FloorDiv(warp_id, m_warp);

  // One simdgroup_store call per 8x8 tile owned by this warp. Call args:
  // (fragment buffer, tile index, dst pointer, leading stride, rows=8,
  //  cols=8, transpose=false).
  Array<Stmt> stmts;
  for (int i = 0; i < warp_row_tiles; i++) {
    for (int j = 0; j < warp_col_tiles; j++) {
      int tile_idx = i * warp_col_tiles + j;
      PrimExpr row = dst_row_base + warp_m * (warp_row_tiles * 8) + i * 8;
      PrimExpr col = dst_col_base + warp_n * (warp_col_tiles * 8) + j * 8;
      PrimExpr ptr = Call(DataType::Handle(), builtin::address_of(),
                          {BufferLoad(dst, {row, col})});
      stmts.push_back(Evaluate(
          Call(DataType::Handle(), builtin::simdgroup_store(),
               {src->data, IntImm(DataType::Int(32), tile_idx), ptr, dst_stride,
                IntImm(DataType::Int(32), 8), IntImm(DataType::Int(32), 8),
                Cast(DataType::Bool(), IntImm(DataType::Int(32), 0))})));
    }
  }
  // Avoid wrapping a single store in a needless SeqStmt.
  if (stmts.size() == 1)
    return stmts[0];
  return SeqStmt(stmts);
}

Stmt CopyNode::LowerNormalCopy(const LowerArgs &T,
arith::Analyzer *analyzer) const {
bool is_cpu_target = T.target->GetTargetDeviceType() == kDLCPU;
Expand Down
21 changes: 17 additions & 4 deletions src/op/copy.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@ enum class CopyInst : uint8_t {
kCPAsync = 5, // cp.async global->shared copy
// we should separate the bulk load and store for 1d and multi-dim
// as they have different memory access patterns
kBulkLoad1D = 6, // utilize tma load 1d
kBulkStore1D = 7, // utilize tma store 1d
kTMemLoad = 8, // tcgen05.ld (tensor memory -> register)
kTMemStore = 9, // tcgen05.st (register -> tensor memory)
kBulkLoad1D = 6, // utilize tma load 1d
kBulkStore1D = 7, // utilize tma store 1d
kTMemLoad = 8, // tcgen05.ld (tensor memory -> register)
kTMemStore = 9, // tcgen05.st (register -> tensor memory)
kMetalSIMDGroup = 10, // Metal simdgroup load/store
};

/// Convert CopyInst enum to string for debugging
Expand All @@ -53,6 +54,8 @@ inline const char *CopyInstToString(CopyInst inst) {
return "TMemLoad";
case CopyInst::kTMemStore:
return "TMemStore";
case CopyInst::kMetalSIMDGroup:
return "MetalSIMDGroup";
default:
return "Unknown";
}
Expand Down Expand Up @@ -290,6 +293,11 @@ class CopyNode : public TileOperatorNode {
arith::Analyzer *analyzer) const;

protected:
/*!
* \brief Check if copy from Metal simdgroup to shared/global is supported.
*/
bool CheckSIMDGroupCopy(Target target) const;

/*!
* \brief Get the copy instruction type.
*/
Expand Down Expand Up @@ -331,6 +339,11 @@ class CopyNode : public TileOperatorNode {
*/
Stmt LowerCPAsyncCopy(const LowerArgs &T, arith::Analyzer *analyzer) const;

/*!
* \brief Generate lowering for simdgroup store.
*/
Stmt LowerSIMDGroupCopy(const LowerArgs &T, arith::Analyzer *analyzer) const;

/*!
* \brief Generate SIMT (thread-level) loop for copying.
*/
Expand Down
Loading