Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,13 @@ list(APPEND TILE_LANG_SRCS
src/runtime/error_helpers.cc
)

# Metal codegen is pure C++ (no Apple frameworks) and can generate Metal shader
# source on any platform. Always compile it so that "target.build.tilelang_metal"
# is available for cross-compilation on Linux/Windows.
list(APPEND TILE_LANG_SRCS
src/target/codegen_metal.cc
)

set(TILELANG_OUTPUT_TARGETS tilelang tvm)

# Track if the user explicitly selected a backend via cache options.
Expand Down
119 changes: 119 additions & 0 deletions benchmark/matmul_metal/benchmark_matmul_metal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import argparse
import logging
import time

import torch

import tilelang
import tilelang.language as T

# Silence tilelang's chatty INFO-level logs so only benchmark output is printed.
logging.getLogger("tilelang").setLevel(logging.WARNING)

# (block_M, block_N, block_K) tile shapes tried when --sweep is passed.
BLOCK_CONFIGS = [
    (16, 16, 16),
    (32, 32, 16),
    (32, 32, 32),
    (64, 64, 32),
]


@tilelang.jit
def matmul_simdgroup(M, N, K, block_M=64, block_N=64, block_K=32, dtype=T.float16, accum_dtype=T.float32):
    """Build a JIT-compiled tiled GEMM kernel: C(MxN) = A(MxK) @ B(KxN).

    A and B are fp16 by default; C accumulates in fp32 (accum_dtype).
    block_M/block_N/block_K select the shared-memory tile sizes.
    Returns the compiled kernel, callable as kernel(A, B, C).
    """

    @T.prim_func
    def gemm_kernel(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), accum_dtype),
    ):
        # One kernel block per (block_N, block_M) output tile, 128 threads each.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared")
            B_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared")
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            # March over the K dimension one block_K slab at a time.
            # num_stages=0: no software pipelining — NOTE(review): presumably
            # required/preferred on the Metal backend; confirm.
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
                T.copy(A[by * block_M, ko * block_K], A_shared)
                T.copy(B[ko * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)
            # Write the accumulated tile back to global C.
            T.copy(C_local, C[by * block_M, bx * block_N])

    return gemm_kernel


def _tflops(M, N, K, seconds):
return 2.0 * M * N * K / seconds / 1e12


def _bench(fn, warmup, repeats):
    """Run `fn` `warmup` times untimed, then return the average seconds per call
    over `repeats` timed calls. Synchronizes the MPS queue around the timed region
    so queued GPU work is not misattributed.
    """
    for _ in range(warmup):
        fn()
    torch.mps.synchronize()
    started = time.perf_counter()
    done = 0
    while done < repeats:
        fn()
        done += 1
    torch.mps.synchronize()
    elapsed = time.perf_counter() - started
    return elapsed / repeats


def bench_torch_mps(M, N, K, warmup, repeats):
    """Benchmark PyTorch's fp16 torch.mm on the MPS device; returns TFLOPS."""
    lhs = torch.randn(M, K, dtype=torch.float16, device="mps")
    rhs = torch.randn(K, N, dtype=torch.float16, device="mps")
    seconds = _bench(lambda: torch.mm(lhs, rhs), warmup, repeats)
    return _tflops(M, N, K, seconds)


def bench_tilelang(M, N, K, block_M, block_N, block_K, warmup, repeats):
    """Compile the tilelang GEMM for the given tile shape and benchmark it on MPS.

    Inputs are fp16, output accumulates to fp32 (matching matmul_simdgroup).
    Returns achieved TFLOPS.
    """
    kernel = matmul_simdgroup(M, N, K, block_M, block_N, block_K)
    lhs = torch.randn(M, K, dtype=torch.float16, device="mps")
    rhs = torch.randn(K, N, dtype=torch.float16, device="mps")
    out = torch.zeros(M, N, dtype=torch.float32, device="mps")
    seconds = _bench(lambda: kernel(lhs, rhs, out), warmup, repeats)
    return _tflops(M, N, K, seconds)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Metal GEMM Benchmark (simdgroup)")
    parser.add_argument("--m", type=int, default=4096)
    parser.add_argument("--n", type=int, default=4096)
    parser.add_argument("--k", type=int, default=4096)
    parser.add_argument("--warmup", type=int, default=10)
    parser.add_argument("--repeats", type=int, default=100)
    parser.add_argument("--sweep", action="store_true", help="Sweep all block configs instead of using default (64,64,32)")
    args = parser.parse_args()

    M, N, K = args.m, args.n, args.k

    print(f"torch: {torch.__version__}")
    print(f"tilelang: {tilelang.__version__}")
    print(f"MPS: {torch.backends.mps.is_available()}")
    print(f"M={M}, N={N}, K={K}, warmup={args.warmup}, repeats={args.repeats}")
    print()

    # Baseline: PyTorch's own fp16 matmul on MPS.
    ref_tflops = bench_torch_mps(M, N, K, args.warmup, args.repeats)
    print(f"PyTorch MPS (torch.mm fp16): {ref_tflops:.1f} TFLOPS")
    print()

    configs = BLOCK_CONFIGS if args.sweep else [(64, 64, 32)]

    print(f"{'block (M,N,K)':>16s} | {'TileLang':>14s} | {'Ratio':>6s}")
    print("-" * 44)

    # Track the best-performing tile shape across the sweep.
    best_tflops = 0.0
    best_config = configs[0]
    for bM, bN, bK in configs:
        try:
            tl = bench_tilelang(M, N, K, bM, bN, bK, args.warmup, args.repeats)
            ratio = tl / ref_tflops * 100
            if tl > best_tflops:
                best_tflops = tl
                best_config = (bM, bN, bK)
            print(f"{f'({bM},{bN},{bK})':>16s} | {tl:>10.1f} TFLOPS | {ratio:>5.0f}%")
        except Exception as e:
            # A config can fail to compile or launch; report it and keep sweeping.
            print(f"{f'({bM},{bN},{bK})':>16s} | {'FAILED':>14s} | {e}")

    if args.sweep:
        print()
        print(f"Best config: {best_config}")
        print(f"Best TFlops: {best_tflops:.1f}")
        print(f"Reference TFlops (PyTorch MPS): {ref_tflops:.1f}")
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ dependencies = [
# requirement as wide as possible to be compatible with other libraries
# pip will try to use latest version whenever possible.
"apache-tvm-ffi~=0.1.0,>=0.1.2",
"apache-tvm-ffi<0.1.8; platform_system == 'Darwin'",
# torch-c-dlpack-ext provides prebuilt torch extensions.
# Without it, TVM FFI may require JIT compilation on first import.
"torch-c-dlpack-ext; python_version < '3.14'",
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Requirements to run local build with `--no-build-isolation` or other developments

apache-tvm-ffi~=0.1.0,>=0.1.2
apache-tvm-ffi<0.1.8; platform_system == 'Darwin'
build
cmake>=3.26
cython>=3.1.0
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Runtime requirements

apache-tvm-ffi~=0.1.0,>=0.1.2
apache-tvm-ffi<0.1.8; platform_system == 'Darwin'
torch-c-dlpack-ext; python_version < '3.14'
cloudpickle
ml-dtypes
Expand Down
117 changes: 117 additions & 0 deletions src/backend/metal/op/copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,129 @@

#include "op/copy.h"

#include "backend/metal/op/utils.h"
#include "op/utils.h"
#include "target/utils.h"

#include <tvm/tir/builtin.h>

#include <algorithm>
#include <cmath>
#include <limits>

namespace tvm {
namespace tl {

using namespace tir;

namespace metal {

namespace {

// True when the copy moves a simdgroup-matrix buffer out to shared or
// global memory, i.e. when it must be lowered via LowerSIMDGroupCopy.
bool CheckSIMDGroupCopy(const CopyNode &op) {
  if (!IsSIMDGroupBuffer(op.src)) {
    return false;
  }
  return IsSharedBuffer(op.dst) || IsGlobalBuffer(op.dst);
}

// Lower a copy out of a simdgroup-matrix buffer into a sequence of
// builtin::simdgroup_store calls, one per 8x8 tile owned by each warp.
// The 2D destination region is split across the block's warps into an
// m_warp x n_warp grid; each warp stores warp_row_tiles x warp_col_tiles
// contiguous 8x8 tiles.
Stmt LowerSIMDGroupCopy(const CopyNode &op, const LowerArgs &T,
                        arith::Analyzer *analyzer) {
  (void)analyzer;  // unused: all extents are required to be constants below
  ICHECK(IsSIMDGroupBuffer(op.src));

  // The source buffer holds whole 8x8 matrices, so its flat element count
  // must be a multiple of 64.
  int total_elements = 1;
  for (auto s : op.src->shape) {
    auto imm = s.as<IntImmNode>();
    ICHECK(imm) << "simdgroup buffer must have constant shape";
    total_elements *= imm->value;
  }
  ICHECK(total_elements % 64 == 0)
      << "simdgroup buffer size must be multiple of 64 (8x8), got "
      << total_elements;

  ICHECK(op.src_range.size() == 2) << "Expected 2D source for simdgroup store";
  ICHECK(op.dst_range.size() == 2)
      << "Expected 2D destination for simdgroup store";
  // Base offsets of the destination region; the row stride is taken from the
  // destination's last shape dimension.
  PrimExpr dst_row_base = op.dst_range[0]->min;
  PrimExpr dst_col_base = op.dst_range[1]->min;
  PrimExpr dst_stride = op.dst->shape[op.dst->shape.size() - 1];

  // Warp geometry: how many warps this block runs, and this thread's warp id.
  int warp_size = TargetGetWarpSize(T.target);
  const auto *block_size_imm = T.thread_bounds->extent.as<IntImmNode>();
  ICHECK(block_size_imm)
      << "simdgroup copy requires constant thread bounds";
  int block_size = block_size_imm->value;
  int num_warps = block_size / warp_size;
  PrimExpr warp_id = FloorDiv(T.thread_var, warp_size);

  // M x N extents of the copied region (compile-time constants required).
  const auto *m_imm = op.src_range[0]->extent.as<IntImmNode>();
  const auto *n_imm = op.src_range[1]->extent.as<IntImmNode>();
  ICHECK(m_imm && n_imm) << "simdgroup copy requires constant extents";
  int M = m_imm->value;
  int N = n_imm->value;

  // Search every factorization num_warps = m * n that fits in the region and
  // pick the one minimizing |(per-warp rows / per-warp cols) - M/N|.
  // Falls back to 1 x num_warps; validity is enforced by the ICHECK below.
  constexpr int kMPerWarp = 8;
  constexpr int kNPerWarp = 8;
  int m_warp = 1, n_warp = num_warps;
  int max_m = M / kMPerWarp;
  int max_n = N / kNPerWarp;
  float ideal = N > 0 ? static_cast<float>(M) / N : 1.f;
  float best_score = std::numeric_limits<float>::max();
  for (int m = 1; m <= std::min(num_warps, max_m); ++m) {
    if (num_warps % m != 0) {
      continue;
    }
    int n = num_warps / m;
    if (n > max_n) {
      continue;
    }
    float m_per = static_cast<float>(M) / (m * kMPerWarp);
    float n_per = static_cast<float>(N) / (n * kNPerWarp);
    float score = std::abs(m_per / n_per - ideal);
    if (score < best_score) {
      best_score = score;
      m_warp = m;
      n_warp = n;
    }
  }

  // Every warp must own at least one 8x8 tile in each dimension.
  ICHECK(M >= m_warp * kMPerWarp && N >= n_warp * kNPerWarp)
      << "Cannot partition " << M << "x" << N << " matrix across " << m_warp
      << "x" << n_warp << " warps with 8x8 simdgroup tiles";
  int warp_row_tiles = M / m_warp / kMPerWarp;
  int warp_col_tiles = N / n_warp / kNPerWarp;
  ICHECK(warp_row_tiles > 0 && warp_col_tiles > 0);
  ICHECK(warp_row_tiles * warp_col_tiles * 64 <= total_elements)
      << "Warp partition produces more tiles than buffer capacity";

  // Warp's coordinates in the m_warp x n_warp grid: the row index varies
  // fastest with warp_id (FloorMod over m_warp).
  PrimExpr warp_m = FloorMod(warp_id, m_warp);
  PrimExpr warp_n = FloorDiv(warp_id, m_warp);

  // Emit one simdgroup_store per tile this warp owns; tiles are indexed
  // row-major within the warp's sub-block. The trailing bool(0) argument is
  // presumably a transpose flag — TODO(review): confirm against the
  // simdgroup_store builtin's signature.
  Array<Stmt> stmts;
  for (int i = 0; i < warp_row_tiles; i++) {
    for (int j = 0; j < warp_col_tiles; j++) {
      int tile_idx = i * warp_col_tiles + j;
      PrimExpr row =
          dst_row_base + warp_m * (warp_row_tiles * kMPerWarp) + i * kMPerWarp;
      PrimExpr col =
          dst_col_base + warp_n * (warp_col_tiles * kNPerWarp) + j * kNPerWarp;
      PrimExpr ptr = Call(DataType::Handle(), builtin::address_of(),
                          {BufferLoad(op.dst, {row, col})});
      stmts.push_back(Evaluate(Call(
          DataType::Handle(), builtin::simdgroup_store(),
          {op.src->data, IntImm(DataType::Int(32), tile_idx), ptr, dst_stride,
           IntImm(DataType::Int(32), kMPerWarp),
           IntImm(DataType::Int(32), kNPerWarp),
           Cast(DataType::Bool(), IntImm(DataType::Int(32), 0))})));
    }
  }
  // Avoid wrapping a single store in a SeqStmt.
  if (stmts.size() == 1) {
    return stmts[0];
  }
  return SeqStmt(stmts);
}

} // namespace

struct Copy {
static LayoutMap InferLayout(const CopyNode &op, const LayoutInferArgs &T,
InferLevel level) {
Expand All @@ -22,6 +136,9 @@ struct Copy {

static Stmt Lower(const CopyNode &op, const LowerArgs &T,
arith::Analyzer *analyzer) {
if (CheckSIMDGroupCopy(op)) {
return LowerSIMDGroupCopy(op, T, analyzer);
}
return LowerNormalCopy(op, T, analyzer);
}
};
Expand Down
31 changes: 31 additions & 0 deletions src/backend/metal/op/fill.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,50 @@

#include "op/fill.h"

#include "backend/metal/op/utils.h"
#include "op/utils.h"
#include "target/utils.h"
#include "transform/loop_partition.h"
#include "transform/loop_vectorize.h"

#include <tvm/tir/builtin.h>

namespace tvm {
namespace tl {

using namespace tir;

namespace metal {

struct Fill {
static Stmt Lower(const FillNode &op, const LowerArgs &T,
arith::Analyzer *analyzer) {
if (IsSIMDGroupBuffer(op.dst)) {
int region_elements = 1;
for (auto r : op.region) {
auto imm = r->extent.as<IntImmNode>();
ICHECK(imm) << "simdgroup fill region must have constant extents";
region_elements *= imm->value;
}
ICHECK(region_elements % 64 == 0)
<< "simdgroup buffer size must be multiple of 64 (8x8), got "
<< region_elements;

int num_matrices = region_elements / 64;
PrimExpr fill_value = Cast(op.dst->dtype, op.value);
Array<Stmt> stmts;
for (int i = 0; i < num_matrices; i++) {
stmts.push_back(Evaluate(Call(
DataType::Handle(), builtin::make_filled_simdgroup_matrix(),
{op.dst->data, IntImm(DataType::Int(32), i), fill_value,
IntImm(DataType::Int(32), 8), IntImm(DataType::Int(32), 8)})));
}
if (stmts.size() == 1) {
return stmts[0];
}
return SeqStmt(stmts);
}

if (IsFragmentBuffer(op.dst)) {
auto par_op = ParallelOp(op.MakeSIMTLoop(analyzer));
par_op->InferLayout({T.target,
Expand Down
Loading
Loading