tilelang: T.fp8_scaled_matmul DSL intrinsic + Metal lowering #2142

apstenku123 wants to merge 11 commits into tile-ai:main
Conversation
Add T.gemm support for Apple Metal using simdgroup_matrix 8x8 operations (simdgroup_load / simdgroup_store / simdgroup_multiply_accumulate). Works on all Apple Silicon (M1–M5) without requiring a TVM fork.

Key changes:
- codegen_metal.cc/h: fork TVM Metal codegen into tilelang, with simdgroup intrinsic emission and 128-bit vectorized copy
- gemm_metal.py: GemmMetal tile operator for shared×shared GEMM
- metal_macro_generator.py: MPSIntrinEmitter for simdgroup MMA macros
- metal_fragment_to_simdgroup.py: pass that rewrites local.fragment GEMM accumulators to metal.simdgroup scope before layout inference
- LowerSIMDGroupCopy in copy.cc for fragment→device simdgroup_store

24 Metal tests (cross-platform codegen + on-device correctness).
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run the project's formatting and lint checks. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
📝 Walkthrough

This PR introduces complete Metal backend support for TileLang, including a Metal code generator, Metal-specific GEMM and copy operations, a new FP8 scaled matmul intrinsic, and comprehensive testing infrastructure for macOS MPS execution.

Changes

Metal Backend Infrastructure & Operations
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Files documenting the actual PRs we just opened upstream:
- PR #1: ml-explore/mlx#3476 — from_dlpack Metal-aware consumer (against main, clean)
- PR #2: apache/tvm#19504 — TVM_METAL_STORAGE_MODE env opt-in (against main, clean)
- PR #3: tile-ai/tilelang#2139 — mixed-dtype T.gemm via scalar fallback (stacks on PR #2130)
- PR #4: tile-ai/tilelang#2140 — FP8-input T.gemm scalar fallback routing (stacks on PR #2130)
- PR #5: tile-ai/tilelang#2141 — T.Pipelined num_stages>1 3D buffer fix (stacks on PR #2130)
- PR #6: tile-ai/tilelang#2142 — T.fp8_scaled_matmul DSL intrinsic (stacks on PR #2130)

Deferred (companion PRs needed): tilelang_metal_fp8 and tilelang_metal_fp8_vector each touch both the tilelang supermodule and the TileLang/tvm vendored submodule. These need 2 PRs each — one to tile-ai/tilelang, one to TileLang/tvm — in a separate filing round.

PRs #3–#6 are independent of each other; each branches directly from jorgecurious/tilelang:metal-gemm-upstream-rebase HEAD 971c17b, so they can be reviewed in any order. They DO depend on the upstream 4-PR Apple Metal landing chain (#1869, #2118, #2121, #2130) merging first; if any of those land separately, ours can be retargeted at main.
Pull request overview
Adds a new FP8 scaled matmul intrinsic to the TileLang language surface and expands the Metal backend to support simdgroup-matrix GEMM lowering (plus supporting codegen/runtime plumbing and tests).
Changes:
- Introduces `T.fp8_scaled_matmul(A_fp8, A_scale, B_fp8, B_scale, C_out)` as a hygienic macro and re-exports it from `tilelang.language`.
- Adds Metal simdgroup GEMM lowering/codegen support (new Metal codegen target builder, simdgroup copy/store lowering, and a fragment→simdgroup rewrite pass).
- Adds extensive Metal and IR-level tests/coverage docs, plus a Metal GEMM benchmark script and minor JIT adapter behavior changes.
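For orientation, a minimal usage sketch of the new intrinsic, assuming per-tensor fp32 scales; the shapes, dtypes, and the `T.Tensor` annotation style here are illustrative, not taken from the PR's tests:

```python
import tilelang.language as T

@T.prim_func
def scaled_gemm(
    A_fp8: T.Tensor((128, 64), "float8_e4m3"),   # FP8 storage operand
    A_scale: T.Tensor((1,), "float32"),          # per-tensor scale
    B_fp8: T.Tensor((64, 128), "float8_e4m3"),
    B_scale: T.Tensor((1,), "float32"),
    C_out: T.Tensor((128, 128), "float32"),      # fp32 output
):
    # Expands at macro time into an FP8 GEMM with fp32 accumulation,
    # followed by the scale broadcast (per the PR description).
    T.fp8_scaled_matmul(A_fp8, A_scale, B_fp8, B_scale, C_out)
```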
Reviewed changes
Copilot reviewed 40 out of 41 changed files in this pull request and generated 6 comments.
Show a summary per file
| File | Description |
|---|---|
| tilelang/utils/language.py | Adds helper is_metal_simdgroup() for scope checks. |
| tilelang/transform/metal_fragment_to_simdgroup.py | New Metal-only pass rewriting GEMM accumulators from local.fragment to metal.simdgroup. |
| tilelang/transform/decouple_type_cast.py | Treats metal.simdgroup buffers as “local” for cast-decoupling logic. |
| tilelang/tileop/metal_simdgroup.py | Internal simdgroup register-tile utilities/macros. |
| tilelang/tileop/metal_quant.py | Internal packed-quant decode helpers used by Metal scaffolding tests. |
| tilelang/tileop/metal_gdn.py | Internal GDN/attention-style tile macros using simdgroup helper utilities. |
| tilelang/tileop/gemm/inst.py | Adds METAL_SIMDGROUP GEMM instruction enum value. |
| tilelang/tileop/gemm/gemm_metal.py | Implements Metal simdgroup-matrix GEMM lowering path. |
| tilelang/tileop/gemm/__init__.py | Selects Metal simdgroup GEMM impl on Metal targets. |
| tilelang/language/fp8_op.py | Adds fp8_scaled_matmul macro implementation and validation. |
| tilelang/language/__init__.py | Re-exports fp8_scaled_matmul on the public language surface. |
| tilelang/jit/adapter/torch/metal.py | Adds/overrides get_kernel_source() for Metal kernels. |
| tilelang/jit/adapter/base.py | Adjusts device selection to prefer MPS when CUDA init/lookup fails. |
| tilelang/intrinsics/metal_macro_generator.py | Adds MPSIntrinEmitter for simdgroup load/store/MMA TIR macro emission. |
| tilelang/engine/phase.py | Inserts Metal fragment→simdgroup rewrite before layout inference. |
| tilelang/engine/lower.py | Switches Metal build entrypoint to target.build.tilelang_metal. |
| testing/python/metal/test_metal_simdgroup_store.py | New tests for direct simdgroup store to device memory. |
| testing/python/metal/test_metal_local_var.py | New focused tests for local.var scalar codegen/runtime on Metal. |
| testing/python/metal/test_metal_internal_scaffolding.py | Large internal scaffolding suite for Metal source-boundary + runtime probes. |
| testing/python/metal/test_metal_gemm_v2_linux.py | Cross-platform (no-Metal-runtime) Metal GEMM source-generation checks. |
| testing/python/metal/test_metal_gemm_v2.py | On-device Metal GEMM correctness checks. |
| testing/python/metal/test_fp8_scaled_matmul_metal.py | Metal lowering + offline compile + runtime parity tests for fp8_scaled_matmul. |
| testing/python/metal/metal_internal_runtime_coverage.md | Documents internal Metal runtime coverage and constraints. |
| testing/python/jit/test_tilelang_jit_adapter_mps.py | Adds tests covering MPS device preference in the JIT adapter. |
| testing/python/cpu/test_fp8_scaled_matmul_lowering.py | IR/source-level lowering tests for fp8_scaled_matmul (no GPU required). |
| src/transform/lower_device_storage_access_info.cc | Treats .fragment scope as exempt from storage access info lowering. |
| src/transform/layout_inference.cc | Skips fragment-layout requirement on Metal targets for GEMM accumulators. |
| src/target/codegen_metal.h | Adds TileLang-specific Metal codegen class declaration. |
| src/target/codegen_metal.cc | Implements TileLang Metal codegen and registers target.build.tilelang_metal. |
| src/op/utils.h | Adds helpers for metal.simdgroup buffers and a combined “register buffer” predicate. |
| src/op/parallel.cc | Makes fragment layout inference more defensive when layout info is absent. |
| src/op/gemm.h | Adds Metal simdgroup GEMM instruction enum value in C++ core. |
| src/op/gemm.cc | Selects Metal GEMM instruction on Metal targets; tweaks warp policy for Metal. |
| src/op/fill.cc | Adds Fill lowering for metal.simdgroup buffers via make_filled_simdgroup_matrix. |
| src/op/copy.h | Adds Metal simdgroup copy instruction kind and lowering hooks. |
| src/op/copy.cc | Implements simdgroup store lowering from metal.simdgroup to shared/global. |
| src/backend/metal/CMakeLists.txt | Ensures Metal codegen builds in “codegen-only” mode cross-platform. |
| requirements.txt | Adds Darwin-only upper bound for apache-tvm-ffi. |
| requirements-dev.txt | Adds Darwin-only upper bound for apache-tvm-ffi (dev). |
| pyproject.toml | Adds Darwin-only upper bound for apache-tvm-ffi (packaging). |
| benchmark/matmul_metal/benchmark_matmul_metal.py | Adds a Metal GEMM benchmark script for simdgroup GEMM. |
From `tilelang/language/fp8_op.py`:

```python
# Storage-level FP8 dtype tags accepted by this intrinsic. Any other dtype
# in the A / B operands raises a TypeError at parse time. ``float8_e8m0fnu``
# is the block-scale-factor format and is intentionally excluded — it is
# carried by the sf_a / sf_b operands of the block-scaled GEMM, not by A / B.
FP8_DTYPES: tuple[str, ...] = ("float8_e4m3", "float8_e5m2", "float8_e4m3fn", "float8_e4m3fnuz", "float8_e5m2fnuz")


def _is_fp8_dtype(dt) -> bool:
    """Return True if a dtype string / object names an FP8 storage variant."""
    s = str(dt or "")
    return any(s.startswith(t) for t in ("float8", "fp8"))
```

and, later in the macro body:

```python
M_dim, K_dim = A_fp8.shape
K_dim_b, N_dim = B_fp8.shape
sa_size = A_scale.shape[0]
sb_size = B_scale.shape[0]

# The accumulation matches the audiohacking ``fp8_scaled_matmul_kernel``
a_val = T.cast(A_fp8[i, k], "float32")
b_val = T.cast(B_fp8[k, j], "float32")
```
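The scale semantics are easiest to see in a plain NumPy reference. This is a sketch of the per-tensor case only; the function name and the assumption that the FP8 operands have already been decoded to fp32 (as the `T.cast` calls above do) are illustrative, not code from the PR:

```python
import numpy as np

def fp8_scaled_matmul_ref(a_f32, a_scale, b_f32, b_scale):
    """a_f32 / b_f32: FP8 operands already decoded to fp32.

    Accumulate in fp32, then broadcast both per-tensor scales
    after the K-loop, matching the macro's post-loop scale broadcast.
    """
    return (a_f32 @ b_f32) * np.float32(a_scale[0]) * np.float32(b_scale[0])
```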
From `tilelang/jit/adapter/base.py`:

```diff
             return lambda: torch.device("cuda", current_device())
         except Exception:
-            return lambda: torch.device("cuda", torch.cuda.current_device())
+            pass
```
From `tilelang/jit/adapter/torch/metal.py`:

```python
def get_kernel_source(self, kernel_only: bool = True) -> str:
    return self.kernel_global_source or ""
```
From `tilelang/transform/metal_fragment_to_simdgroup.py`:

```python
buf_map = {}


def _pre_order(stmt):
    if isinstance(stmt, tir.Block):
        new_alloc_bufs = []
        changed = False
        for buf in stmt.alloc_buffers:
            new_buf = _remap_buffer(buf, var_map)
            new_alloc_bufs.append(new_buf)
            if not new_buf.same_as(buf):
                buf_map[buf] = new_buf
                changed = True
        if changed:
            new_body = tir.stmt_functor.substitute(stmt.body, var_map)
            new_block = tir.Block(
                stmt.iter_vars,
                stmt.reads,
                stmt.writes,
                stmt.name_hint,
                new_body,
                stmt.init,
                new_alloc_bufs,
                stmt.match_buffers,
                stmt.annotations,
            )
            return (
                tir.BlockRealize(
                    stmt.iter_vars,
                    tir.const(True, "bool"),
                    new_block,
                )
                if False
                else new_block
            )
```
Actionable comments posted: 6
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
tilelang/jit/adapter/base.py (1)
77-86: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

`pass` in the `except` block silently drops the CUDA fallback for partial-init failures.

The single `try` block catches two distinct failure modes:
- `torch.cuda._lazy_init()` raises → CUDA is genuinely unusable → falling through to MPS/CPU is correct.
- `torch._C._cuda_getDevice` attribute access fails → only the fast path is unavailable; CUDA is still usable → the old code returned `lambda: torch.device("cuda", torch.cuda.current_device())` here, which was correct.

With `pass`, Scenario 2 silently routes a fully functional CUDA machine to MPS (if accidentally present) or CPU. That would silently allocate output tensors on the wrong device (see `tvm_ffi.py` line 242), likely causing a device mismatch at kernel launch time rather than an obvious error.

Splitting the `try` preserves the intended fall-through for Scenario 1 while restoring the safe CUDA fallback for Scenario 2:

🛡️ Proposed fix

```diff
 if torch.cuda.is_available():
-    try:
-        torch.cuda._lazy_init()
-        current_device = torch._C._cuda_getDevice
-        return lambda: torch.device("cuda", current_device())
-    except Exception:
-        pass
+    try:
+        torch.cuda._lazy_init()
+    except Exception:
+        pass  # CUDA init failed entirely; fall through to MPS/CPU
+    else:
+        try:
+            current_device = torch._C._cuda_getDevice
+            return lambda: torch.device("cuda", current_device())
+        except Exception:
+            # Fast-path C handle unavailable; fall back to public API
+            return lambda: torch.device("cuda", torch.cuda.current_device())
```

Also, the docstring on line 74 ("On CPU or when CUDA is unavailable, returns torch.device('cpu')") no longer reflects reality — MPS is now also a possible return value.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/jit/adapter/base.py` around lines 77 - 86, The try/except currently swallows attribute-access failures and incorrectly falls back away from CUDA; change the logic so torch.cuda._lazy_init() is attempted in its own try and only if it raises should we fall through to MPS/CPU, while failures when accessing torch._C._cuda_getDevice (or AttributeError) should use the safe fallback used previously (i.e., return a lambda that calls torch.device("cuda", torch.cuda.current_device()) or the retrieved current_device()), and keep other exceptions bubbling or handled appropriately; also update the function docstring (the string around line 74) to reflect that the function may return MPS or CPU when CUDA is unavailable.

src/op/parallel.cc (1)
569-575: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

`ValidateCandidateAgainstFragments` has the same unsafe `.as<Fragment>().value()` pattern that was just fixed at lines 381-383.

`ValidateCandidateAgainstFragments` iterates `indice_map_` with only a `T.layout_map.count(buffer)` pre-check (line 567) — the same setup as the unfixed code that motivated the change at 381-383. If a non-Fragment layout appears for a buffer in `indice_map_`, line 572 aborts unconditionally. This function is called from at least two sites: line 475 and `ChooseBestCandidate` (lines 779-780). The same issue exists at line 726 inside `BuildReplicationGuardsIfNeeded`.

🛡️ Proposed defensive fix for `ValidateCandidateAgainstFragments` (line 572) and `BuildReplicationGuardsIfNeeded` (line 726)

```diff
 // ValidateCandidateAgainstFragments, line 572
-  auto fragment = T.layout_map[buffer].as<Fragment>().value();
+  auto fragment_opt = T.layout_map[buffer].as<Fragment>();
+  if (!fragment_opt.has_value())
+    continue;
+  auto fragment = fragment_opt.value();

 // BuildReplicationGuardsIfNeeded, line 726
-  auto fragment_layout = T.layout_map[fragment].as<Fragment>().value();
+  auto fragment_layout_opt = T.layout_map[fragment].as<Fragment>();
+  if (!fragment_layout_opt.has_value())
+    continue;
+  auto fragment_layout = fragment_layout_opt.value();
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/op/parallel.cc` around lines 569 - 575, ValidateCandidateAgainstFragments (and BuildReplicationGuardsIfNeeded) use the unsafe pattern T.layout_map[buffer].as<Fragment>().value() which will abort if the layout isn't a Fragment; replace this with a defensive check that the layout exists and is a Fragment before accessing it. For example, retrieve the optional layout via auto layout_opt = T.layout_map[buffer]; if (!layout_opt.has_value() || !layout_opt->is<Fragment>()) continue (or otherwise skip/handle non-Fragment cases), then safely extract the Fragment via layout_opt->as<Fragment>().value(); apply the same guard in BuildReplicationGuardsIfNeeded and any other call sites (e.g., ChooseBestCandidate) that assume Fragment layouts.

tilelang/tileop/gemm/__init__.py (1)
154-175: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Docstring priority ordering is misleading — Metal is checked first, not at position 6.

The docstring lists METAL_SIMDGROUP as priority 6 (after MFMA, WMMA, MMA), but the implementation returns early for Metal before any C++ FFI call. A reader following the docstring's priority list would incorrectly assume Metal goes through the same selection chain as CUDA/AMD targets.

📝 Proposed fix: reorder the docstring to match implementation

```diff
-    The selection logic follows this priority:
-    1. TCGEN5MMA for Blackwell architecture
-    2. WGMMA for Hopper architecture with sufficient matrix size and warp count
-    3. MFMA for CDNA (AMD) architecture
-    4. WMMA for RDNA (AMD) architecture
-    5. MMA for CUDA architecture
-    6. METAL_SIMDGROUP for Metal target (simdgroup_matrix)
-    7. Scalar for CPU target (scalar fallback)
+    The selection logic follows this priority:
+    1. METAL_SIMDGROUP for Metal target (short-circuit before C++ FFI dispatch)
+    2. TCGEN5MMA for Blackwell architecture
+    3. WGMMA for Hopper architecture with sufficient matrix size and warp count
+    4. MFMA for CDNA (AMD) architecture
+    5. WMMA for RDNA (AMD) architecture
+    6. MMA for CUDA architecture
+    7. Scalar for CPU target (scalar fallback)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/tileop/gemm/__init__.py` around lines 154 - 175, Update the _select_gemm_instruction docstring to reflect the actual implementation order: move METAL_SIMDGROUP (target_is_metal check returning GemmInst.METAL_SIMDGROUP) to be evaluated before the FFI selection, and clarify that all other targets are resolved via _ffi_api.GemmGetGemmInst(self, int(thread_nums), target); reference the function name _select_gemm_instruction, the target_is_metal check, GemmInst.METAL_SIMDGROUP, and _ffi_api.GemmGetGemmInst so readers see the docstring matches the code path.
🧹 Nitpick comments (8)
testing/python/jit/test_tilelang_jit_adapter_mps.py (1)
10-47: ⚡ Quick win

Tests look correct — LGTM for the three covered scenarios.

All three test functions correctly isolate the CUDA/MPS availability flags and assert the expected device. The `SimpleNamespace` approach for platforms where `torch.backends.mps` may not exist is a clean workaround.

One gap worth adding (relates directly to the regression flagged in `base.py`): there is currently no test for the case where `torch.cuda.is_available()` is `True`, `_lazy_init()` succeeds, but `torch._C._cuda_getDevice` fails. With the proposed split-`try` fix in `base.py`, that path should return a CUDA device (not MPS/CPU), and a test like the following would pin that behavior:

```python
def test_current_device_functor_falls_back_to_cuda_when_c_handle_fails(monkeypatch):
    monkeypatch.setattr(torch.cuda, "is_available", lambda: True)
    monkeypatch.setattr(torch.cuda, "_lazy_init", lambda: None)  # succeeds
    monkeypatch.setattr(torch._C, "_cuda_getDevice", None, raising=False)  # remove handle
    device_functor = BaseKernelAdapter.get_current_device_functor()
    assert device_functor().type == "cuda"
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@testing/python/jit/test_tilelang_jit_adapter_mps.py` around lines 10 - 47, Add a test that covers the regression where torch.cuda.is_available() is True and torch.cuda._lazy_init() succeeds but the CUDA C handle lookup fails; create a new test (e.g., test_current_device_functor_falls_back_to_cuda_when_c_handle_fails) that monkeypatches torch.cuda.is_available to return True, monkeypatches torch.cuda._lazy_init to a no-op (succeeds), and removes or sets torch._C._cuda_getDevice to None/raise (using monkeypatch.setattr with raising=False) then calls BaseKernelAdapter.get_current_device_functor() and asserts the returned device has type "cuda" to ensure the split-try behavior in BaseKernelAdapter.get_current_device_functor is preserved.

tilelang/intrinsics/metal_macro_generator.py (1)
46-52: 💤 Low value

Consider using a more robust assertion for production code.

Line 49 uses a bare `assert`, which can be stripped when Python runs with `-O`. For kernel generation code, consider raising a `ValueError` or using a proper check that survives optimization.

♻️ Proposed fix

```diff
 def get_thread_binding(self):
     if self.thread_var is None:
         current_frame = T.KernelLaunchFrame.Current()
-        assert current_frame is not None, "Must be called in a T.Kernel Frame"
+        if current_frame is None:
+            raise RuntimeError("MPSIntrinEmitter: Must be called within a T.Kernel Frame")
         return current_frame.get_thread_binding()
     else:
         return self.thread_var
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/intrinsics/metal_macro_generator.py` around lines 46 - 52, In get_thread_binding, replace the bare assert on current_frame (from T.KernelLaunchFrame.Current()) with an explicit runtime check that raises a suitable exception (e.g., ValueError or RuntimeError) so the failure isn't stripped with -O; locate the method get_thread_binding and the variables thread_var and current_frame and change the control flow to raise the new exception with a clear message instead of using assert before returning current_frame.get_thread_binding().

tilelang/language/fp8_op.py (1)
261-264: ⚡ Quick win

Prefix unused unpacked variable with underscore.

`K_dim_b` is extracted but never used. Python convention and linters (RUF059) expect unused variables to be prefixed with `_` to signal intentional discard.

♻️ Proposed fix

```diff
     M_dim, K_dim = A_fp8.shape
-    K_dim_b, N_dim = B_fp8.shape
+    _K_dim_b, N_dim = B_fp8.shape
```

Apply the same fix to line 284 in `_fp8_scaled_matmul_macro_trans_b`:

```diff
-    N_dim, K_dim_b = B_fp8.shape
+    N_dim, _K_dim_b = B_fp8.shape
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/language/fp8_op.py` around lines 261 - 264, The unpacked variable K_dim_b in the function _fp8_scaled_matmul_macro_trans_a is unused; rename it to _K_dim_b (prefix with underscore) to satisfy linting (RUF059) and indicate intentional discard, and make the equivalent change in _fp8_scaled_matmul_macro_trans_b (the similar unpack on line ~284) so both places use _K_dim_b instead of K_dim_b.

pyproject.toml (1)
33-34: 💤 Low value

Add a comment explaining the Darwin-specific version constraint.

The `<0.1.8` upper bound on Darwin lacks context. Future maintainers won't know why this limit exists or when it's safe to remove. Consider adding a brief inline comment documenting the incompatibility (e.g., a specific bug or breaking change in 0.1.8+).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@pyproject.toml` around lines 33 - 34, Add an inline comment next to the Darwin-specific constraint for "apache-tvm-ffi<0.1.8; platform_system == 'Darwin'" in pyproject.toml that briefly explains why versions >=0.1.8 are excluded (referencing the specific bug/PR/commit or the observed breaking behavior), include a link or identifier for the upstream issue if available, and note the condition under which the upper bound can be removed (e.g., fixed upstream version or date).

tilelang/language/__init__.py (1)
117-117: 💤 Low value

Consider grouping with other operation imports.

The `fp8_scaled_matmul` import is placed in the middle of `builtin` re-exports (between `ldg256` and `ballot_sync`), breaking the logical grouping. Consider moving it near other operation imports (e.g., after the `gemm_op` imports at line 66 or after `fill_op` at line 68) for better organization.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/language/__init__.py` at line 117, Move the fp8_scaled_matmul re-export so it stays with operation imports: locate the current import "from .fp8_op import fp8_scaled_matmul" (currently between ldg256 and ballot_sync among builtin re-exports) and cut/paste it into the block of other op re-exports (e.g., immediately after the gemm_op imports or after fill_op) so all operation functions (including fp8_scaled_matmul) are grouped together for clearer organization.

tilelang/engine/phase.py (1)
200-204: The `MetalFragmentToSimdgroup` pass IS internally guarded — it checks the bound target attribute and returns early for non-Metal targets, so there's no risk of corrupting CUDA/HIP `local.fragment` buffers.

However, the unconditional invocation at the call site is still worth clarifying. Consider either:
- Adding an explicit `if target.kind.name == "metal"` guard at the call site (see the sketch below), or
- Updating the comment to note that the pass has an internal Metal-only guard

This improves readability since readers shouldn't have to inspect the pass implementation to understand that it only affects Metal targets.
Verify each finding against the current code and only fix it if needed. In `@tilelang/engine/phase.py` around lines 200 - 204, Call site currently unconditionally constructs and applies MetalFragmentToSimdgroup(mod); either add an explicit target-kind guard around that call (e.g., check target.kind.name == "metal" before importing/instantiating MetalFragmentToSimdgroup and assigning mod = MetalFragmentToSimdgroup(mod)) so the intent is obvious at the call site, or update the existing comment above the import to state that MetalFragmentToSimdgroup already checks the bound target and returns early for non-Metal targets; reference the MetalFragmentToSimdgroup class and the mod variable in your change so readers can quickly find the pass and its use.tilelang/transform/metal_fragment_to_simdgroup.py (1)
94-102: 💤 Low value

Dead code: the `if False` branch is never taken.

The ternary expression `tir.BlockRealize(...) if False else new_block` always returns `new_block`. This looks like a debugging remnant that should be cleaned up.

Proposed cleanup

```diff
-            return (
-                tir.BlockRealize(
-                    stmt.iter_vars,
-                    tir.const(True, "bool"),
-                    new_block,
-                )
-                if False
-                else new_block
-            )
+            return new_block
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/transform/metal_fragment_to_simdgroup.py` around lines 94 - 102, Remove the dead ternary that always selects the else branch; replace the expression "tir.BlockRealize(... ) if False else new_block" with just "new_block" so the function returns new_block directly (remove the unused tir.BlockRealize/ tir.const(stmt.iter_vars, ...) debug remnant surrounding new_block).

testing/python/metal/test_metal_gemm_v2.py (1)
84-86: 💤 Low value

Consider documenting or tightening the loose tolerance for 1024×1024×1024.

`atol=1.0` is a very loose absolute tolerance. For fp16 inputs accumulated in fp32, typical differences should be much smaller. This may mask real numerical issues. Consider:
- Adding a comment explaining why this tolerance is needed
- Using relative tolerance (`rtol`) in addition to `atol` (see the sketch below)
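A sketch of a tightened check; the test's own helper is `assert_gemm_v2`, so `torch.testing.assert_close` and the tensor names `c` / `ref` below are stand-ins, not the test's actual code:

```python
# fp16 inputs with fp32 accumulation: allow small rounding drift,
# but catch larger numerical regressions.
torch.testing.assert_close(c, ref, rtol=1e-2, atol=1e-2)
```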
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@testing/python/metal/test_metal_gemm_v2.py` around lines 84 - 86, The test test_gemm_v2_1024 uses a very loose absolute tolerance (atol=1.0); update the test to (1) add a brief inline comment in test_gemm_v2_1024 explaining why a relaxed tolerance is required for 1024×1024×1024 fp16 inputs (e.g., fp16 inputs with fp32 accumulation and known nondeterminism/quantization artifacts), and (2) tighten the check by replacing the single atol with a combination of a smaller atol and an rtol (e.g., rtol=1e-2 and atol=1e-2) when calling assert_gemm_v2 so the test still allows small fp16 rounding differences but will catch larger numerical regressions; refer to the test function name test_gemm_v2_1024 and the assertion helper assert_gemm_v2 when making the change.
📒 Files selected for processing (41)
- benchmark/matmul_metal/benchmark_matmul_metal.py
- pyproject.toml
- requirements-dev.txt
- requirements.txt
- src/backend/metal/CMakeLists.txt
- src/op/copy.cc
- src/op/copy.h
- src/op/fill.cc
- src/op/gemm.cc
- src/op/gemm.h
- src/op/parallel.cc
- src/op/utils.h
- src/target/codegen_metal.cc
- src/target/codegen_metal.h
- src/transform/layout_inference.cc
- src/transform/lower_device_storage_access_info.cc
- testing/python/cpu/test_fp8_scaled_matmul_lowering.py
- testing/python/jit/test_tilelang_jit_adapter_mps.py
- testing/python/metal/metal_internal_runtime_coverage.md
- testing/python/metal/test_fp8_scaled_matmul_metal.py
- testing/python/metal/test_metal_gemm_v2.py
- testing/python/metal/test_metal_gemm_v2_linux.py
- testing/python/metal/test_metal_internal_scaffolding.py
- testing/python/metal/test_metal_local_var.py
- testing/python/metal/test_metal_simdgroup_store.py
- tilelang/engine/lower.py
- tilelang/engine/phase.py
- tilelang/intrinsics/metal_macro_generator.py
- tilelang/jit/adapter/base.py
- tilelang/jit/adapter/torch/metal.py
- tilelang/language/__init__.py
- tilelang/language/fp8_op.py
- tilelang/tileop/gemm/__init__.py
- tilelang/tileop/gemm/gemm_metal.py
- tilelang/tileop/gemm/inst.py
- tilelang/tileop/metal_gdn.py
- tilelang/tileop/metal_quant.py
- tilelang/tileop/metal_simdgroup.py
- tilelang/transform/decouple_type_cast.py
- tilelang/transform/metal_fragment_to_simdgroup.py
- tilelang/utils/language.py
From `benchmark/matmul_metal/benchmark_matmul_metal.py`:

```python
def bench_torch_mps(M, N, K, warmup, repeats):
    a = torch.randn(M, K, dtype=torch.float16, device="mps")
    b = torch.randn(K, N, dtype=torch.float16, device="mps")
    avg_s = _bench(lambda: torch.mm(a, b), warmup, repeats)
    return _tflops(M, N, K, avg_s)


def bench_tilelang(M, N, K, block_M, block_N, block_K, warmup, repeats):
    kernel = matmul_simdgroup(M, N, K, block_M, block_N, block_K)
    a = torch.randn(M, K, dtype=torch.float16, device="mps")
    b = torch.randn(K, N, dtype=torch.float16, device="mps")
    c = torch.zeros(M, N, dtype=torch.float32, device="mps")
    avg_s = _bench(lambda: kernel(a, b, c), warmup, repeats)
    return _tflops(M, N, K, avg_s)
```
Benchmark both paths with the same output-allocation policy.

`bench_torch_mps` measures `torch.mm(a, b)` with a fresh output tensor every iteration, but `bench_tilelang` reuses `c`. That bakes allocator cost into the reference only, so the printed TileLang/PyTorch ratio is artificially optimistic. Preallocate the reference output too, or include allocation cost on both sides.
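A sketch of the symmetric variant, reusing the names from the hunk above (the preallocated `c_ref` is an illustrative name):

```python
def bench_torch_mps(M, N, K, warmup, repeats):
    a = torch.randn(M, K, dtype=torch.float16, device="mps")
    b = torch.randn(K, N, dtype=torch.float16, device="mps")
    # Preallocate the output so torch.mm pays no per-iteration allocator
    # cost, matching how bench_tilelang reuses c.
    c_ref = torch.empty(M, N, dtype=torch.float16, device="mps")
    avg_s = _bench(lambda: torch.mm(a, b, out=c_ref), warmup, repeats)
    return _tflops(M, N, K, avg_s)
```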
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@benchmark/matmul_metal/benchmark_matmul_metal.py` around lines 58 - 71, The
benchmark is inconsistent: bench_torch_mps allocates the output implicitly each
iteration while bench_tilelang reuses c, biasing the ratio; change
bench_torch_mps to preallocate a torch.zeros output (matching dtype/device/shape
of c) and use torch.mm(a, b, out=preallocated_c) or otherwise call the kernel
with that preallocated tensor so both paths reuse the same output allocation
(refer to bench_torch_mps, bench_tilelang, variables a/b/c and function
matmul_simdgroup/_bench).
From `src/op/gemm.cc`:

```cpp
int kMPerWarp = 16; // Rows processed by a single warp
if (TargetIsMetal(target)) {
  kMPerWarp = 8;
}
```
Stale "16" in comments at lines 304 and 334 after introducing kMPerWarp = 8 for Metal.
Setting kMPerWarp = 8 for Metal is correct (simdgroup_matrix uses 8×8 tiles), but the existing comments in the non-WGMMA warp-partition logic still hardcode "16", which now misleads Metal readers:
- Line 304:
// If M cannot be evenly divided by m_warp*16 - Line 334:
// Each warp needs at least 16 elements in M
The code uses the kMPerWarp variable (correct), but the comments are now stale.
📝 Proposed comment fixes

```diff
-  // If M cannot be evenly divided by m_warp*16, try to split remaining warps
+  // If M cannot be evenly divided by m_warp*kMPerWarp, try to split remaining warps
   // to N
-  int max_m_warps =
-      M / kMPerWarp; // Each warp needs at least 16 elements in M
+  int max_m_warps =
+      M / kMPerWarp; // Each warp needs at least kMPerWarp elements in M (16 for CUDA, 8 for Metal)
```
Verify each finding against the current code and only fix it if needed.
In `@src/op/gemm.cc` around lines 204 - 207, Comments referencing a hardcoded "16"
are now stale after introducing kMPerWarp and the TargetIsMetal path; update the
two comments in the non-WGMMA warp-partition logic to refer to kMPerWarp (or to
a neutral phrasing like "m_warp * kMPerWarp" / "kMPerWarp elements") instead of
the literal "16". Locate the block that defines int kMPerWarp = 16; if
(TargetIsMetal(target)) { kMPerWarp = 8; } and replace the comment lines that
read "// If M cannot be evenly divided by m_warp*16" and "// Each warp needs at
least 16 elements in M" with wording that uses kMPerWarp (or "kMPerWarp" spelled
out) or a generic description ("m_warp * kMPerWarp" and "kMPerWarp elements per
warp") so the comments stay correct for both Metal and non-Metal targets.
From `src/target/codegen_metal.cc`:

```cpp
decl_stream << "union __TVMArgUnion {\n"
            << "  int v_int[2];\n"
            << "};\n\n";
```
Handle sub-32-bit scalar kernel args without bogus union fields.

For non-32/64-bit POD params, the emitted accessor becomes `arg.<name>.v_half` / `v_bool` / `v_char` / etc., but `__TVMArgUnion` only declares `v_int[2]`. Any kernel with a half/bool/int8-style launch arg will therefore generate invalid MSL and fail compilation. Either add matching union members or stop emitting `v_<type>` field accesses here.
Proposed fix

```diff
 union __TVMArgUnion {
   int v_int[2];
+  bool v_bool;
+  char v_char;
+  uchar v_uchar;
+  short v_short;
+  ushort v_ushort;
+  half v_half;
+  bfloat v_bfloat;
 };
```
};Also applies to: 145-149
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/target/codegen_metal.cc` around lines 55 - 57, The generated union
__TVMArgUnion only declares v_int[2] but kernel arg accessors emit fields like
v_half, v_bool, v_char, etc.; update the union definition (symbol:
__TVMArgUnion) to include matching members for all sub-32/64-bit POD types used
by the accessor code (e.g., v_half, v_bool, v_char, v_int8, v_int16) or alter
the accessor emission logic so it does not emit v_<type> field accesses for
these types; ensure you make the same change in both places where the union is
emitted and where accessors are generated (the two corresponding emit sites that
produce the union and arg.<name>.v_* accesses).
From `testing/python/metal/test_metal_local_var.py`:

```python
# local.var should lower to scalar declarations/stores rather than arrays or
# an unsupported storage scope.
assert len(re.findall(r"\bint\s+\w+\s*=\s*0;", src)) >= 2, src
```
Assertion may be incorrect: only one variable initializes to 0.

The regex `\bint\s+\w+\s*=\s*0;` expects >= 2 matches, but based on the kernel:
- `x = T.alloc_var(T.int32, init=3)` → emits `int x = 3;`
- `y = T.alloc_var(T.int32)` → emits `int y = 0;` (default)

Only `y` would match the pattern. Consider changing to >= 1 or updating the regex to match any initialization (e.g., `\bint\s+\w+\s*=`).
Proposed fix

```diff
- assert len(re.findall(r"\bint\s+\w+\s*=\s*0;", src)) >= 2, src
+ assert len(re.findall(r"\bint\s+\w+\s*=\s*0;", src)) >= 1, src
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@testing/python/metal/test_metal_local_var.py` at line 36, The test's
assertion uses the regex "\bint\s+\w+\s*=\s*0;" and expects at least 2 matches,
but only one variable (`y`) is initialized to 0; update the assertion in
testing/python/metal/test_metal_local_var.py to reflect the actual output by
either changing the numeric expectation from ">= 2" to ">= 1" or broaden the
regex to "\bint\s+\w+\s*=" so it matches any int initialization; adjust the
assertion line containing the regex to use the chosen fix.
From `tilelang/tileop/metal_simdgroup.py`:

```python
def load_tile(
    tile: MMATile,
    dtype,
    data,
    offset,
    extent,
    stride,
    *,
    rows: int = 8,
    cols: int = 8,
    transpose: bool = False,
) -> None:
    load(
        tile.fragment,
        tile.index(0, 0),
        dtype,
        data,
        offset,
        extent,
        stride,
        rows,
        cols,
        transpose,
    )
```
load_tile / store_tile silently drop every fragment except (0, 0).

Both helpers accept an `MMATile` with `fragments_m`/`fragments_n`, but they hardcode `tile.index(0, 0)`. A multi-fragment tile will therefore be only partially loaded/materialized, which is silent data corruption. Either iterate all fragments, or fail fast unless the tile is 1x1.
Proposed fix

```diff
 @T.macro
 def load_tile(
     tile: MMATile,
@@
 ) -> None:
+    if tile.fragments_m != 1 or tile.fragments_n != 1:
+        raise ValueError("load_tile currently supports only 1x1 MMATile")
     load(
         tile.fragment,
         tile.index(0, 0),
@@
 @T.macro
 def store_tile(
     tile: MMATile,
@@
 ) -> None:
+    if tile.fragments_m != 1 or tile.fragments_n != 1:
+        raise ValueError("store_tile currently supports only 1x1 MMATile")
     store(
         tile.fragment,
         tile.index(0, 0),
```
tile.index(0, 0),Also applies to: 296-320
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tilelang/tileop/metal_simdgroup.py` around lines 208 - 231, The helpers
load_tile and store_tile currently only operate on tile.index(0, 0), silently
dropping other fragments for MMATile instances with fragments_m/fragments_n > 1;
update both functions (load_tile and the similar store_tile block) to either
iterate over all fragment coordinates (loop m in range(tile.fragments_m) and n
in range(tile.fragments_n) and call load/store with tile.index(m, n)) or, if you
prefer a fail-fast behavior, assert or raise an exception when tile.fragments_m
!= 1 or tile.fragments_n != 1 so partial materialization cannot occur; locate
references to tile.index(0, 0) in load_tile and the corresponding store_tile
implementation and apply the chosen fix consistently.
From `tilelang/tileop/metal_simdgroup.py`:

```python
@T.macro
def mma_tile(acc: MMATile, a: MMATile, b: MMATile) -> None:
    for tile_m in T.unroll(acc.fragments_m, explicit=True):
        for tile_n in T.unroll(acc.fragments_n, explicit=True):
            mma(
                acc.fragment,
                a.fragment,
                b.fragment,
                acc.index(tile_m, tile_n),
                a.index(tile_m, 0),
                b.index(0, tile_n),
            )
```
mma_tile misses the K-fragment reduction.

The implementation only multiplies `a.index(tile_m, 0)` by `b.index(0, tile_n)`. If `a.fragments_n` / `b.fragments_m` is greater than 1, every partial product after k=0 is dropped and the result is wrong. Add a `tile_k` loop (plus shape checks), or reject non-unit K tiling here.
Proposed fix

```diff
 @T.macro
 def mma_tile(acc: MMATile, a: MMATile, b: MMATile) -> None:
+    if a.fragments_n != b.fragments_m:
+        raise ValueError(
+            f"incompatible tile K dimensions: {a.fragments_n} vs {b.fragments_m}"
+        )
     for tile_m in T.unroll(acc.fragments_m, explicit=True):
         for tile_n in T.unroll(acc.fragments_n, explicit=True):
-            mma(
-                acc.fragment,
-                a.fragment,
-                b.fragment,
-                acc.index(tile_m, tile_n),
-                a.index(tile_m, 0),
-                b.index(0, tile_n),
-            )
+            for tile_k in T.unroll(a.fragments_n, explicit=True):
+                mma(
+                    acc.fragment,
+                    a.fragment,
+                    b.fragment,
+                    acc.index(tile_m, tile_n),
+                    a.index(tile_m, tile_k),
+                    b.index(tile_k, tile_n),
+                )
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tilelang/tileop/metal_simdgroup.py` around lines 385 - 396, mma_tile
currently only uses a.index(tile_m, 0) and b.index(0, tile_n), dropping
additional K fragments; add a K-fragment loop (or reject non-unit K tiling).
Concretely, inside mma_tile (the T.macro using MMATile and mma) iterate tile_k
with T.unroll over the tile/K fragment count (e.g., acc.fragments_k or the
appropriate MMATile.fragments_k property), call mma(acc.fragment, a.fragment,
b.fragment, acc.index(tile_m,tile_n), a.index(tile_m,tile_k),
b.index(tile_k,tile_n)) for each tile_k to accumulate partial products, and add
shape checks/assertions that a.fragments_n == acc.fragments_k == b.fragments_m
(or throw if fragments_k == 1 is required) so non-unit K tiling is handled
safely.
Summary
Adds the `T.fp8_scaled_matmul(A_fp8, A_scale, B_fp8, B_scale, C_out)` DSL intrinsic — a hygienic `@T.macro` mirroring the audiohacking/fp8-mps-metal scaled-matmul kernel surface (Apache 2.0). Provides Metal lowering via the GemmMetalScalar path, with per-tensor and (single-block) per-row scale dispatch resolved at macro-expansion time.

This is the "real" implementation, not a stub: the macro expands to a working scaled-FP8 GEMM that produces results bit-near-equal to the audiohacking reference MSL kernel. 25/25 IR + Metal e2e parity tests pass on M-series.
Files
- `tilelang/language/__init__.py` — re-exports `T.fp8_scaled_matmul`
- `tilelang/language/fp8_op.py` — 379 LOC, the macro implementation
- `testing/python/cpu/test_fp8_scaled_matmul_lowering.py` — 8 IR-level lowering tests
- `testing/python/metal/test_fp8_scaled_matmul_metal.py` — 17 Metal codegen + xcrun + e2e parity tests vs `mlx.core` ground truth
- FP8 decode exponent-bias fix (`e + 8` → `e + 7`) — required for parity with `torch.float8_e4m3fn` / `mx.from_fp8` / the audiohacking LUT decoder. Without the fix, bytes 0x01..0x07 / 0x81..0x87 dequantize to 2× the correct value. NOTE: this hunk is for the `3rdparty/tvm` vendored fork (TileLang/tvm); the code change is in this PR's history but the actual landing happens via a companion PR to that submodule. Maintainers can either merge this PR with the bugfix included (it lives in the patch text) or split it off.
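To see why the bias fix matters, decode byte 0x01 (an E4M3 denormal: sign=0, exponent field=0, mantissa=1) by hand; the torch check below is one way to confirm the expected ground truth (assumes a torch build with float8 dtypes):

```python
import torch

# E4M3 denormal: value = (mantissa / 8) * 2**(1 - bias), with bias = 7,
# so byte 0x01 -> (1/8) * 2**-6 = 2**-9 = 0.001953125.
# The pre-fix decoder's off-by-one bias (e + 8 instead of e + 7) yields
# 2**-8 instead, i.e. exactly 2x the correct value.
v = torch.tensor([0x01], dtype=torch.uint8).view(torch.float8_e4m3fn).float()
print(v.item())  # 0.001953125
```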
cppmega.mlx sparse-MLA-FP8 attention probes need a fused FP8 scaled GEMM path. Path B (hand-written MSL via mx.fast.metal_kernel calling the audiohacking kernel pre-compiled) works today; this PR adds the equivalent surface in TileLang DSL so
T.fp8_scaled_matmul(...)works inside@T.prim_func.Bench
Local Apple M4 Max:
The gap to audiohacking is expected: the MSL reference is a hand-tuned kernel with
simd_sumreduction for the M=1 case; the DSL macro lowers viaT.gemm+ post-loop scale broadcast. A follow-up scheduler pass that fuses per-load scale into the K-loop would close most of the matmul gap; the vecmat case additionally needs aT.simdgroup_reduce_sumprimitive.Stacking topology
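A rough sketch of that follow-up fusion idea: for per-tensor scales the algebra lets the broadcast fold into each load inside the K-loop. The loop structure and names below are illustrative, not the planned pass output:

```python
# Per-tensor scales commute with the K-reduction, so:
#   C[i, j] = (sum_k A[i, k] * B[k, j]) * sA * sB
#           =  sum_k (A[i, k] * sA) * (B[k, j] * sB)
for k in T.serial(K_dim):
    acc[i, j] += (T.cast(A_fp8[i, k], "float32") * A_scale[0]) * \
                 (T.cast(B_fp8[k, j], "float32") * B_scale[0])
```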
This PR is based on
jorgecurious/tilelang:metal-gemm-upstream-rebase(PR #2130) at HEAD971c17b, which itself stacks on top of:Once #1869+#2118+#2121+#2130 merge into
tile-ai/tilelang:main, this PR can be retargeted to main directly.Test plan
```bash
cd /path/to/tilelang
pytest testing/python/cpu/test_fp8_scaled_matmul_lowering.py testing/python/metal/test_fp8_scaled_matmul_metal.py
```

Expect 25/25 pass.
Limitations / follow-ups
`T.cast(fp8, fp32)` lowers via `__nv_fp8_e4m3_to_half` etc. on CUDA. The tensor-core FP8 path (`T.tcgen05_gemm_blockscaled`) is intentionally not auto-dispatched here because the e8m0 block-scale layout doesn't match this op's per-tensor / per-row semantics.
The audiohacking/fp8-mps-metal MSL kernel (Apache 2.0) is the algorithmic reference. The TileLang DSL macro replicates the per-tensor scale broadcast + K-loop unrolling pattern in TIR. Co-developed with
cppmega.mlx.Summary by CodeRabbit
New Features
T.fp8_scaled_matmulintrinsic for FP8 quantized matrix multiplicationBug Fixes
Tests
Chores
apache-tvm-ffi