Add Metal scalar fallback for T.gemm #2118
Conversation
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in the repository settings.
📝 Walkthrough: Adds Metal-specific GEMM scalar dispatch and implementation, Metal-aware tests for GEMM/attention using T.gemm, and the supporting Metal lowering changes.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
👋 Hi! Thank you for contributing to the TileLang project. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@testing/python/metal/test_metal_codegen.py`:
- Around line 132-136: The test currently always uses torch.randn to build
tensors a and b in assert_t_gemm which fails for integer dtypes; update
assert_t_gemm to check the dtype (use the same "int" in dtype check as
assert_gemm or inspect torch_dtype) and: when int, create a and b with
torch.randint(low=0, high=10, size=..., dtype=torch_dtype, device="mps");
otherwise keep torch.randn for floats; leave c as torch.zeros (zeros supports
integer dtype). Ensure you apply this to both a and b creation (use b_shape when
building b) and reuse torch_dtype and transpose_B variables.
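For reference, a minimal sketch of the dtype branch described above, reusing the helper's existing `M`, `N`, `K`, `b_shape`, `dtype`, and `torch_dtype` variables; the `randint` bounds are an assumption, not part of the review comment:

```python
import torch

# Sketch only: torch.randn rejects integer dtypes, so branch on the dtype string.
if "int" in dtype:
    a = torch.randint(0, 10, (M, K), dtype=torch_dtype, device="mps")  # assumed bounds
    b = torch.randint(0, 10, b_shape, dtype=torch_dtype, device="mps")
else:
    a = torch.randn(M, K, dtype=torch_dtype, device="mps")
    b = torch.randn(b_shape, dtype=torch_dtype, device="mps")
# torch.zeros supports integer dtypes, so c can stay unchanged.
c = torch.zeros(M, N, dtype=torch_dtype, device="mps")
```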
📒 Files selected for processing (4)
- src/op/gemm.cc
- testing/python/metal/test_metal_codegen.py
- testing/python/metal/test_metal_codegen_linux.py
- tilelang/tileop/gemm/gemm_scalar.py
Force-pushed from 058d590 to 9f3a3ee
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/op/parallel.cc (1)
589-608: ⚠️ Potential issue | 🟠 Major — Don't bypass write validation for thread-varying replicated stores.
`fragment->IsCompletedReplicated()` only says the fragment layout is replicated. It does not guarantee every thread writes the same element. For a loop like `frag[i] = ...`, skipping `ProveFragmentContains` here lets the per-thread replicas diverge while the buffer still keeps a replicated layout, which can mislead later consumers. Please keep the write-side skip only for loop-invariant indices.

Safer condition:
```diff
- bool replicated_local_write = access.is_write &&
-                               fragment->IsCompletedReplicated() &&
-                               store_shared_global_buffers_.empty();
+ bool replicated_local_write =
+     access.is_write && fragment->IsCompletedReplicated() &&
+     store_shared_global_buffers_.empty() &&
+     std::all_of(access.indices.begin(), access.indices.end(),
+                 [](const PrimExpr &idx) {
+                   return idx.as<IntImmNode>() != nullptr;
+                 });
```
If you want to preserve more cases than constant indices, prove the write indices are thread-invariant before skipping the containment check.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/op/parallel.cc` around lines 589 - 608, The write-side skip currently uses fragment->IsCompletedReplicated() and store_shared_global_buffers_.empty() (the replicated_local_write condition) which can let thread-varying writes bypass ProveFragmentContains; modify the condition so replicated_local_write is true only when the write indices are proven loop-invariant (i.e., do not depend on the T.Parallel loop iteration variables) before skipping ProveFragmentContains: use the analyzer_/existing analysis utilities to check the access.indices for loop-invariance (or add a helper like AreIndicesLoopInvariant(candidate, access.indices, analyzer_)) and require that result in the replicated_local_write predicate where access.is_write is evaluated; leave ProveFragmentContains calls (the ProveFragmentContains(fragment, candidate, ...) path) unchanged for all other cases.
🧹 Nitpick comments (1)
tilelang/tileop/gemm/gemm_metal_scalar.py (1)
59-67: Consider hoisting the clear operation outside the inner loop. The `clear_accum` check is inside the `T.grid(M, N)` loop, which means it's evaluated once per output element (correct), but the conditional branch is repeated M×N times. While the compiler may optimize this, moving the clear to a separate loop before the main computation would be slightly cleaner:

💡 Optional refactor
```diff
 @T.prim_func
 def _gemm_metal_scalar() -> None:
+    if clear_accum:
+        for i, j in T.grid(M, N):
+            C_buf[c0 + i, c1 + j] = T.cast(0, accum_dtype)
     for i, j in T.grid(M, N):
-        if clear_accum:
-            C_buf[c0 + i, c1 + j] = T.cast(0, accum_dtype)
         for k in T.Serial(K):
             C_buf[c0 + i, c1 + j] += T.cast(
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/tileop/gemm/gemm_metal_scalar.py` around lines 59 - 67, The clear_accum conditional currently runs inside the output element loop (for i, j in T.grid(M, N)); hoist it by adding a separate initialization loop that runs when clear_accum is true before the main accumulation loop so you don't branch per element. Specifically, when clear_accum is true, iterate over T.grid(M, N) once to set C_buf[c0 + i, c1 + j] = T.cast(0, accum_dtype), then run the existing nested accumulation (the for i, j in T.grid(M, N) with for k in T.Serial(K) updating C_buf) without the clear_accum check; adjust only the control flow around C_buf initialization and keep the same index expressions and accum_dtype.
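For completeness, a self-contained sketch of the hoisted form; the 32×32×32 shapes, buffer names, and dtypes are stand-ins for the real `M`/`N`/`K`, `c0`/`c1` offsets, and `accum_dtype` in `gemm_metal_scalar.py`, so treat this as an illustration rather than the PR's code:

```python
import tilelang.language as T

@T.prim_func
def gemm_scalar_hoisted(A: T.Buffer((32, 32), "float16"),
                        B: T.Buffer((32, 32), "float16"),
                        C: T.Buffer((32, 32), "float32")):
    # Clear the accumulator in its own pass instead of branching per element.
    for i, j in T.grid(32, 32):
        C[i, j] = T.cast(0, "float32")
    # Plain scalar accumulation: one multiply-add per (i, j, k).
    for i, j in T.grid(32, 32):
        for k in T.serial(32):
            C[i, j] += T.cast(A[i, k], "float32") * T.cast(B[k, j], "float32")
```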
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/op/reduce.cc`:
- Around line 404-429: The Metal fast-path currently writes into dst_buffer
while reading src_buffer and must be guarded against aliasing: update the
condition that enters the TargetIsMetal(...) &&
src_layout->IsCompletedReplicated() branch to also check for potential overlap
and the clear flag (i.e., skip the fast-path when clear == false and src/dst may
alias). Concretely, in the same conditional that uses TargetIsMetal and
src_layout->IsCompletedReplicated, add an alias-safety check (more robust than
simple data-pointer equality) for dst_buffer vs src_buffer and, if they can
alias and clear is false, route to the existing non-Metal/duplicate path (the
logic that previously used need_duplicate) instead of emitting the in-place init
+ reduce sequence built from MakeInitValue(), BufferLoad(), BufferStore(),
MakeReduce(), dst_vars and src_var_compressed loops.
---
Outside diff comments:
In `@src/op/parallel.cc`:
- Around line 589-608: The write-side skip currently uses
fragment->IsCompletedReplicated() and store_shared_global_buffers_.empty() (the
replicated_local_write condition) which can let thread-varying writes bypass
ProveFragmentContains; modify the condition so replicated_local_write is true
only when the write indices are proven loop-invariant (i.e., do not depend on
the T.Parallel loop iteration variables) before skipping ProveFragmentContains:
use the analyzer_/existing analysis utilities to check the access.indices for
loop-invariance (or add a helper like AreIndicesLoopInvariant(candidate,
access.indices, analyzer_)) and require that result in the
replicated_local_write predicate where access.is_write is evaluated; leave
ProveFragmentContains calls (the ProveFragmentContains(fragment, candidate, ...)
path) unchanged for all other cases.
---
Nitpick comments:
In `@tilelang/tileop/gemm/gemm_metal_scalar.py`:
- Around line 59-67: The clear_accum conditional currently runs inside the
output element loop (for i, j in T.grid(M, N)); hoist it by adding a separate
initialization loop that runs when clear_accum is true before the main
accumulation loop so you don't branch per element. Specifically, when
clear_accum is true, iterate over T.grid(M, N) once to set C_buf[c0 + i, c1 + j]
= T.cast(0, accum_dtype), then run the existing nested accumulation (the for i,
j in T.grid(M, N) with for k in T.Serial(K) updating C_buf) without the
clear_accum check; adjust only the control flow around C_buf initialization and
keep the same index expressions and accum_dtype.
📒 Files selected for processing (13)
- src/op/gemm.cc
- src/op/math.cc
- src/op/parallel.cc
- src/op/reduce.cc
- src/transform/layout_inference.cc
- src/transform/legalize_vectorized_loop.cc
- src/transform/lower_tile_op.cc
- src/transform/merge_shared_memory_allocations.cc
- src/transform/vectorize_loop.cc
- testing/python/metal/test_metal_codegen.py
- testing/python/metal/test_metal_codegen_linux.py
- tilelang/tileop/gemm/__init__.py
- tilelang/tileop/gemm/gemm_metal_scalar.py
♻️ Duplicate comments (2)
src/op/reduce.cc (1)
404-428: ⚠️ Potential issue | 🟠 Major — Guard the Metal replicated fast path against src/dst aliasing when `clear == false`. On line 404, this path still performs an in-place write/read on `dst_buffer`/`src_buffer` without an alias safety guard. If they overlap, intermediate stores can corrupt later loads in the same reduction pass.

Possible fix:
```diff
-  if (TargetIsMetal(T.target) && src_layout->IsCompletedReplicated()) {
+  bool can_use_metal_replicated_fast_path =
+      TargetIsMetal(T.target) && src_layout->IsCompletedReplicated() &&
+      !(this->clear == false && src_buffer->data.same_as(dst_buffer->data));
+  if (can_use_metal_replicated_fast_path) {
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/op/reduce.cc` around lines 404 - 428, The Metal replicated fast-path emits in-place loads/stores on dst_buffer and src_buffer (inside the TargetIsMetal && src_layout->IsCompletedReplicated branch) but does not guard against src/dst aliasing when clear == false; add an alias-safety check (e.g., detect if src_buffer may alias dst_buffer or overlapping regions) before taking this fast path and if aliasing is possible fall back to the safe reduction path (or allocate a temp copy of dst before the reduction) so that BufferLoad(BufferStore(...)) sequences created by MakeReduce, BufferLoad, BufferStore, src_buffer and dst_buffer are not corrupted by in-place writes when clear is false.

testing/python/metal/test_metal_codegen.py (1)
132-135: ⚠️ Potential issue | 🟡 Minor — Handle integer dtypes in `assert_t_gemm`.
`torch.randn` does not support integer dtypes, so this helper will fail if it is reused for `dtype=T.int32`, unlike `assert_gemm` above. Branch on `dtype` here and use `torch.randint` for integer tensors.

🩹 Suggested fix:
```diff
-    a = torch.randn(M, K, dtype=torch_dtype, device="mps")
     b_shape = (N, K) if transpose_B else (K, N)
-    b = torch.randn(b_shape, dtype=torch_dtype, device="mps")
+    if "int" in dtype:
+        a = torch.randint(100, (M, K), dtype=torch_dtype, device="mps")
+        b = torch.randint(100, b_shape, dtype=torch_dtype, device="mps")
+    else:
+        a = torch.randn(M, K, dtype=torch_dtype, device="mps")
+        b = torch.randn(b_shape, dtype=torch_dtype, device="mps")
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@testing/python/metal/test_metal_codegen.py` around lines 132 - 135, The helper used in assert_t_gemm incorrectly always calls torch.randn with torch_dtype (which fails for integer dtypes); update the tensor creation to branch on the dtype: if torch_dtype is an integer type (e.g., dtype.is_integer() or compare against T.int32/T.int64), create a and b with torch.randint(low=-10, high=10, size=..., dtype=torch_dtype, device="mps"); otherwise keep using torch.randn for floating types; apply this change for variables a (shape M,K) and b (shape b_shape) where torch_dtype is used.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Duplicate comments:
In `@src/op/reduce.cc`:
- Around line 404-428: The Metal replicated fast-path emits in-place
loads/stores on dst_buffer and src_buffer (inside the TargetIsMetal &&
src_layout->IsCompletedReplicated branch) but does not guard against src/dst
aliasing when clear == false; add an alias-safety check (e.g., detect if
src_buffer may alias dst_buffer or overlapping regions) before taking this fast
path and if aliasing is possible fall back to the safe reduction path (or
allocate a temp copy of dst before the reduction) so that
BufferLoad(BufferStore(...)) sequences created by MakeReduce, BufferLoad,
BufferStore, src_buffer and dst_buffer are not corrupted by in-place writes when
clear is false.
In `@testing/python/metal/test_metal_codegen.py`:
- Around line 132-135: The helper used in assert_t_gemm incorrectly always calls
torch.randn with torch_dtype (which fails for integer dtypes); update the tensor
creation to branch on the dtype: if torch_dtype is an integer type (e.g.,
dtype.is_integer() or compare against T.int32/T.int64), create a and b with
torch.randint(low=-10, high=10, size=..., dtype=torch_dtype, device="mps");
otherwise keep using torch.randn for floating types; apply this change for
variables a (shape M,K) and b (shape b_shape) where torch_dtype is used.
📒 Files selected for processing (13)
- src/op/gemm.cc
- src/op/math.cc
- src/op/parallel.cc
- src/op/reduce.cc
- src/transform/layout_inference.cc
- src/transform/legalize_vectorized_loop.cc
- src/transform/lower_tile_op.cc
- src/transform/merge_shared_memory_allocations.cc
- src/transform/vectorize_loop.cc
- testing/python/metal/test_metal_codegen.py
- testing/python/metal/test_metal_codegen_linux.py
- tilelang/tileop/gemm/__init__.py
- tilelang/tileop/gemm/gemm_metal_scalar.py
🚧 Files skipped from review as they are similar to previous changes (3)
- src/op/math.cc
- src/transform/lower_tile_op.cc
- src/transform/legalize_vectorized_loop.cc
Force-pushed from 9f3a3ee to 68e800f
Updated PR after the local code review pass; the changes were re-validated locally.
Force-pushed from 0b4a03e to bd60bef
Select kScalar for Metal separately from CPU and route Metal scalar GEMM through a dedicated GemmMetalScalar implementation, leaving the CPU scalar path unchanged. Support attention-style Metal lowering by carrying PrimFunc targets into TileLang passes, lowering replicated local reductions without CUDA AllReduce, lowering tl.infinity on Metal, and converting Metal dynamic shared allocations to threadgroup shared storage before TVM Metal codegen. Tests: TILELANG_DISABLE_CACHE=1 python -m pytest testing/python/metal -q; TILELANG_DISABLE_CACHE=1 python -m pytest testing/python/cpu/test_tilelang_cpu_tgemm.py -q; ARLE HD128 paged attention tilelang.lower(..., target='metal').
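As a rough illustration of that routing, a hypothetical dispatch helper is sketched below; the function name and the policy/target checks are assumptions, and only the `GemmMetalScalar`/`gemm_scalar` module names come from the files in this PR:

```python
# Hypothetical sketch: pick the Metal scalar GEMM implementation only when the
# scalar policy is selected and the target is Metal; otherwise keep the
# existing CPU scalar path untouched.
def select_scalar_gemm_impl(target, policy):
    if policy != "kScalar":
        raise NotImplementedError(f"unsupported gemm policy: {policy}")
    if "metal" in str(target):
        from tilelang.tileop.gemm.gemm_metal_scalar import GemmMetalScalar
        return GemmMetalScalar
    from tilelang.tileop.gemm.gemm_scalar import GemmScalar  # assumed class name
    return GemmScalar
```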
Force-pushed from bd60bef to ccc1c66
Do you think #1869 would work for your case? It should provide much higher performance than the scalar path.
Yes, #1869 looks like the right performance path. For this PR I was aiming at a correctness/lowering fallback so the ARLE attention kernel can run locally on Metal. The full kernel is not a pure GEMM: after the QK GEMM it needs elementwise access to the scores for mask/reduce/softmax, and PV currently uses a fragment x shared operand combination. So I think #1869 is the right follow-up once we either materialize the simdgroup accumulator where needed or stage/extend the PV operand path, plus BF16 support.
Summary
Enable TileLang to lower and run ARLE's TileLang HD128 paged-attention kernel locally on macOS Metal.
- A Metal scalar fallback for `T.gemm`; the CPU scalar path stays unchanged.
- Metal lowering support for `tl.infinity`, target-scoped passes, and shared-memory codegen compatibility.
- Tests covering `T.gemm`, `transpose_B`, pipelined shared buffers, single-thread kernels, dynamic/static shared-memory merge, and an attention-style `T.gemm` + reduce kernel.

This is a correctness-first local Mac path, not a tuned Metal GEMM implementation.
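For orientation, a minimal usage sketch of the path this enables; the 32×32 shapes, single-thread kernel, and buffer names are made up for illustration, and only `tilelang.lower(..., target='metal')` is taken from the validation notes below:

```python
import tilelang
import tilelang.language as T

@T.prim_func
def matmul(A: T.Buffer((32, 32), "float32"),
           B: T.Buffer((32, 32), "float32"),
           C: T.Buffer((32, 32), "float32")):
    # One threadgroup, one thread: T.gemm falls back to the Metal scalar path.
    with T.Kernel(1, threads=1) as bx:
        A_s = T.alloc_shared((32, 32), "float32")
        B_s = T.alloc_shared((32, 32), "float32")
        C_l = T.alloc_fragment((32, 32), "float32")
        T.copy(A, A_s)
        T.copy(B, B_s)
        T.clear(C_l)
        T.gemm(A_s, B_s, C_l)
        T.copy(C_l, C)

# Lower for Apple Metal, as exercised by the Metal tests in this PR.
mod = tilelang.lower(matmul, target="metal")
```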
Validation
- `TILELANG_DISABLE_CACHE=1 /tmp/arle-tilelang-mac-venv/bin/python -m pytest testing/python/metal -q` -> 19 passed
- `TILELANG_DISABLE_CACHE=1 /tmp/arle-tilelang-mac-venv/bin/python -m pytest testing/python/cpu/test_tilelang_cpu_tgemm.py -q` -> 11 passed
- kernel_source_len 11868 for the `T.gemm` kernel; `GemmMetalScalar.lower` was called once per compile:
  - 128: 0.933 ms, 4.50 GFLOP/s, max_abs=0
  - 256: 5.258 ms, 6.38 GFLOP/s, max_abs=0
  - 512: 38.234 ms, 7.02 GFLOP/s, max_abs=0
- `pre-commit run --files src/op/reduce.cc` passed
- `cmake --build build -j8`, `compileall`, and `git diff --check` passed

Latest commit validated: 93954575