
[Metal] route FP8-input T.gemm to scalar fallback #2140

Open
apstenku123 wants to merge 11 commits into tile-ai:main from apstenku123:cppmega/metal-fp8-gemm-software-path

Conversation


apstenku123 commented May 4, 2026

Summary

Routes T.gemm with FP8 input operands on the Metal target to the scalar dequant-multiply-accumulate fallback. Apple Silicon (M1-M4) has no native FP8 ALU, so simdgroup_multiply_accumulate cannot lower FP8 inputs directly; the scalar path correctly emits per-element dequant via __tvm_fp8_e4m3_to_half (or __tvm_fp8_e5m2_to_half) helpers and accumulates in fp32.
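The per-element dequant-accumulate path can be modeled in plain Python. The decode below follows the OCP float8_e4m3fn layout (1 sign, 4 exponent with bias 7, 3 mantissa bits) that the `__tvm_fp8_e4m3_to_half` helper targets; the helper name comes from this PR, but the Python is only an illustration of what the scalar fallback computes, not the emitted TIR or MSL:

```python
def fp8_e4m3_to_float(byte: int) -> float:
    """Decode one float8_e4m3fn byte. Models what a __tvm_fp8_e4m3_to_half-style
    helper computes; the real Metal helper returns half, this returns float."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    man = byte & 0x7
    if exp == 0xF and man == 0x7:
        return float("nan")  # e4m3fn encodes NaN but no infinities
    if exp == 0:
        return sign * (man / 8.0) * 2.0 ** -6  # subnormal
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)


def scalar_fp8_gemm(A, B, M, N, K):
    """Scalar dequant-multiply-accumulate: C[m][n] += dq(A[m][k]) * dq(B[k][n]),
    accumulated in a plain float (modeling the fp32 accumulator)."""
    C = [[0.0] * N for _ in range(M)]
    for m in range(M):
        for n in range(N):
            acc = 0.0
            for k in range(K):
                acc += fp8_e4m3_to_float(A[m][k]) * fp8_e4m3_to_float(B[k][n])
            C[m][n] = acc
    return C
```

The key property being tested by the fallback is exactly this shape: every operand is decoded element-by-element before the multiply, and accumulation never happens in FP8.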

Two changes:

  1. tilelang/tileop/gemm/__init__.py::Gemm._select_gemm_instruction — Metal branch detects FP8 input via new _has_fp8_input_dtype helper and routes to GemmInst.Scalar instead of GemmInst.METAL_SIMDGROUP.
  2. tilelang/transform/metal_fragment_to_simdgroup.py — visitor short-circuits on FP8 inputs so the simdgroup-fragment rewrite skips FP8 GEMMs (they stay scalar).
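A minimal sketch of the routing decision in change 1. The names `_has_fp8_input_dtype`, `GemmInst.Scalar`, and `GemmInst.METAL_SIMDGROUP` come from the PR text; the signature and the selection logic below are simplifying assumptions, not the actual `_select_gemm_instruction` implementation:

```python
from enum import Enum, auto


class GemmInst(Enum):  # mirrors tilelang/tileop/gemm/inst.py per the PR text
    METAL_SIMDGROUP = auto()
    Scalar = auto()


FP8_DTYPES = ("float8_e4m3", "float8_e5m2")


def _has_fp8_input_dtype(a_dtype: str, b_dtype: str) -> bool:
    # Hypothetical helper shape: true if either input operand is FP8.
    return any(d.startswith(FP8_DTYPES) for d in (a_dtype, b_dtype))


def select_gemm_instruction(target: str, a_dtype: str, b_dtype: str) -> GemmInst:
    if target == "metal" and _has_fp8_input_dtype(a_dtype, b_dtype):
        # No native FP8 ALU on M1-M4: simdgroup MMA cannot lower FP8 inputs,
        # so route to the scalar dequant-multiply-accumulate fallback.
        return GemmInst.Scalar
    return GemmInst.METAL_SIMDGROUP
```

Change 2 is the mirror image of this check inside the rewrite pass: the same FP8-input predicate makes the visitor leave the accumulator in `local.fragment`.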

Why

cppmega.mlx sparse-MLA-FP8 attention probes need an FP8 GEMM that lowers correctly on Metal. Without this routing, T.gemm(A_fp8, B_fp8, C_fp32) either produces invalid MSL (simdgroup path) or fails dispatch outright. The scalar fallback, paired with the storage-only FP8 emulation patches (separate PRs), gives a software FP8 GEMM that's slower than CUDA's native FP8 tensor cores but functionally correct on Apple Silicon today. M5/M6 will revisit native FP8.

Stacking topology

This PR is based on jorgecurious/tilelang:metal-gemm-upstream-rebase (PR #2130) at HEAD 971c17b, which itself stacks on top of:

Runtime prereq (NOT applied here): the storage-only FP8 emulation patch that introduces the __tvm_fp8_e4m3_to_half / __tvm_fp8_e5m2_to_half helper functions on Metal. Without that, this dispatcher routing has nothing to dequant through. That patch lives at cppmega.mlx:docs/upstream/tilelang_metal_fp8/ and will be filed once the TileLang/tvm-mirror split is settled.

Test plan

Test the dispatcher path with an FP8-input @T.prim_func:

import tilelang.language as T

@T.prim_func
def fp8_gemm(
    A: T.Tensor((M, K), "float8_e4m3"),
    B: T.Tensor((K, N), "float8_e4m3"),
    C: T.Tensor((M, N), "float32"),
):
    with T.Kernel(...):
        T.gemm(A, B, C)

tilelang.engine.lower.lower(fp8_gemm, target="metal") should succeed (with the storage-only FP8 patch as runtime prereq).

Caveats

  • Performance: this is the scalar fallback. Native FP8 simdgroup_multiply_accumulate doesn't exist on Apple Silicon today.
  • Dispatcher only — no new codegen emitted by this PR. Storage-side helpers come from a separate FP8 emulation patch (filed separately, see "Stacking topology" above).

Attribution

Co-developed with cppmega.mlx for Apple-Silicon Metal MLA kernel ports.

Summary by CodeRabbit

Release Notes

  • New Features

    • Added Metal backend support for GEMM operations on Apple MPS with SIMD group acceleration.
    • Introduced Metal code generation and kernel compilation pipeline.
    • Added benchmarking utilities for Metal matrix multiplication performance.
  • Bug Fixes

    • Improved device selection logic to fallback to MPS when CUDA is unavailable.
  • Tests

    • Added comprehensive Metal functional tests for GEMM, matrix operations, and quantization.
    • Included codegen validation tests for Metal kernel generation across platforms.
  • Documentation

    • Added Metal backend test coverage documentation.

oraluben and others added 11 commits April 30, 2026 01:43
Add T.gemm support for Apple Metal using simdgroup_matrix 8x8 operations
(simdgroup_load/store/multiply_accumulate). Works on all Apple Silicon
(M1-M5) without requiring a TVM fork.

Key changes:
- codegen_metal.cc/h: Fork TVM Metal codegen to tilelang with
  simdgroup intrinsic emission and 128-bit vectorized copy
- gemm_metal.py: GemmMetal tile operator for shared-to-shared GEMM
- metal_macro_generator.py: MPSIntrinEmitter for simdgroup MMA macros
- metal_fragment_to_simdgroup.py: Pass rewrites local.fragment GEMM
  accumulators to metal.simdgroup scope before layout inference
- LowerSIMDGroupCopy in copy.cc for fragment->device simdgroup_store

24 Metal tests (codegen cross-platform + correctness on device).
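The simdgroup path works on 8x8 tiles; the decomposition it relies on can be modeled in plain Python (illustration only, with no relation to the actual MSL the codegen emits via simdgroup_load/multiply_accumulate/store):

```python
def tiled_matmul_8x8(A, B, M, N, K):
    """Matmul decomposed into 8x8 tiles: each (mt, nt) output tile accumulates
    products of 8x8 A- and B-tiles, modeling one simdgroup_multiply_accumulate
    per (mt, nt, kt) triple."""
    T = 8
    assert M % T == 0 and N % T == 0 and K % T == 0, "dims must be multiples of 8"
    C = [[0.0] * N for _ in range(M)]
    for mt in range(0, M, T):
        for nt in range(0, N, T):
            for kt in range(0, K, T):
                # one tile-level MMA: C_tile += A_tile @ B_tile
                for i in range(T):
                    for j in range(T):
                        C[mt + i][nt + j] += sum(
                            A[mt + i][kt + k] * B[kt + k][nt + j]
                            for k in range(T))
    return C
```

This is also why the lowering rejects shapes that are not multiples of 8: the tile grid must cover the output exactly.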
Copilot AI review requested due to automatic review settings May 4, 2026 08:52

github-actions Bot commented May 4, 2026

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀


coderabbitai Bot commented May 4, 2026

📝 Walkthrough

This pull request adds comprehensive Metal SIMD group support for TileLang GEMM operations, including a new Metal code generator, GEMM lowering path, copy and fill operations with SIMD group backing, high-level tile abstraction macros, intrinsic emitters, and extensive tests covering codegen and runtime correctness on Apple MPS hardware.

Changes

Metal SIMD Group GEMM & Infrastructure

  • Type & Instruction Definitions (src/op/gemm.h, src/op/copy.h, tilelang/tileop/gemm/inst.py): Added GemmInst::kMetalSimdgroup and CopyInst::kMetalSIMDGroup enum values, and corresponding string converters for debug output.
  • Metal Target Detection & Utilities (src/op/utils.h, tilelang/utils/language.py): Added IsSIMDGroupBuffer and IsRegisterBuffer helpers; added an is_metal_simdgroup predicate to detect "metal.simdgroup" scope buffers.
  • GEMM Instruction Selection & Warp Tiling (src/op/gemm.cc): Updated getGemmInst to return kMetalSimdgroup for Metal targets; modified computeWarpPartition to set kMPerWarp=8 (vs. the generic 16) for Metal SIMD group tiling.
  • Copy Lowering for SIMD Group Storage (src/op/copy.h, src/op/copy.cc): Added a CheckSIMDGroupCopy validation predicate and a LowerSIMDGroupCopy implementation to emit builtin::simdgroup_store calls with 8x8 warp-tile addressing for 2D buffer transfers.
  • Fill Lowering for SIMD Group Matrices (src/op/fill.cc): Added a SIMD group fast path in FillNode::Lower to emit builtin::make_filled_simdgroup_matrix for constant-size SIMD group buffer fills with 64-element matrix boundary alignment.
  • Metal Code Generation (src/target/codegen_metal.h, src/target/codegen_metal.cc): Implements CodeGenTileLangMetal, a CodeGenC subclass that generates Metal kernel void functions, handles metal.simdgroup 8x8 matrix allocation/access, local.var scalar loads/stores, Metal-specific builtins (simdgroup_*, make_filled_simdgroup_matrix), and optional compile-to-metallib via a TVM callback.
  • Metal Codegen Build & Target Integration (src/backend/metal/CMakeLists.txt, tilelang/engine/lower.py): Added unconditional Metal codegen compilation with an early exit on non-Apple platforms; routed Metal device codegen dispatch to the new target.build.tilelang_metal FFI function.
  • Metal Fragment-to-SIMDGroup IR Rewrite (tilelang/transform/metal_fragment_to_simdgroup.py, tilelang/engine/phase.py): Introduced a MetalFragmentToSimdgroup pass that rewrites GEMM accumulator allocations from local.fragment to metal.simdgroup scope (excluding FP8 cases), inserted into the pass pipeline before layout inference.
  • Target-Aware Layout Inference (src/transform/layout_inference.cc, src/transform/lower_device_storage_access_info.cc, tilelang/transform/decouple_type_cast.py): Made fragment buffer layout presence checks Metal-aware (skipped on Metal); excluded fragment-scoped allocates from device storage access info lowering; extended is_local_buffer to include SIMD group buffers.
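The 64-element boundary in the fill fast path follows directly from 8x8 matrices: a simdgroup-scoped allocation must decompose into whole 8x8 blocks, one make_filled_simdgroup_matrix per block. A sketch of that bookkeeping (hypothetical helper, not the fill.cc code):

```python
def simdgroup_fill_plan(constant_size: int) -> int:
    """Return how many 8x8 simdgroup matrices cover a constant-size buffer,
    mirroring the codegen check that allocations are whole 8x8 blocks."""
    matrix_elements = 64  # 8 * 8 elements per simdgroup matrix
    if constant_size % matrix_elements != 0:
        # corresponds to the "Only 8x8 matrix is supported" ICHECK
        raise ValueError(f"size {constant_size} is not a multiple of 64")
    # one make_filled_simdgroup_matrix call per matrix
    return constant_size // matrix_elements
```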

Metal SIMD Group Abstractions & High-Level Macros

Layer / File(s) Summary
Metal SIMD Group Intrinsic Emitter
tilelang/intrinsics/metal_macro_generator.py
Implements MPSIntrinEmitter to emit Metal TIR macros for 2D tile loads (ldmatrix A/B with optional transpose), warp-level multiply-accumulate via T.simdgroup_multiply_accumulate, and simdgroup-to-memory copies with warp/fragment-grid mapping.
Metal SIMD Group Register Tile Abstraction
tilelang/tileop/metal_simdgroup.py
Defines RegisterTile/MMATile/RowVector dataclasses with layout metadata; provides macros for allocation (alloc_rt), fragment-level operations (fill/load/store/mma on 8x8 tiles), composition (mma_ab/mma_abt over tile grids), and scalar/vector reductions (row_max/row_sum/prefix_block_vector).
Metal GEMM Lowering
tilelang/tileop/gemm/gemm_metal.py, tilelang/tileop/gemm/__init__.py
Implements GemmMetal lowering that validates tiling (M/N/chunk/tile dims multiples of 8), routes non-FP8 inputs to Metal SIMDGROUP via updated _select_gemm_instruction, and generates shared-to-shared GEMM with optional intermediate SIMDGROUP accumulation via MPSIntrinEmitter and simplified TIR.
Quantized Ops & GDN Helpers
tilelang/tileop/metal_quant.py, tilelang/tileop/metal_gdn.py
Added Metal quantization tile sizing (QuantSimdgroupTile, SMALL_TILE/LARGE_TILE), FP8/FP4/E8M0 decoding functions, and GDN macros for KKT scoring, W/U computation, and causal gating via SIMD group accumulation and staged outputs.
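The packed FP4 decoding mentioned for metal_quant.py can be modeled as below. The magnitude table follows the OCP e2m1 encoding (1 sign, 2 exponent, 1 mantissa bit); the actual helper's name, signature, and nibble order in the repo may differ, so treat this as an assumption-laden sketch:

```python
# OCP FP4 (e2m1) magnitudes indexed by the 3 non-sign bits.
_FP4_MAG = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def decode_fp4_pair(byte: int) -> list:
    """Decode one packed uint8 into two e2m1 values.

    Assumes low-nibble-first packing; the real Metal helper may pack the
    other way around.
    """
    out = []
    for nibble in (byte & 0xF, byte >> 4):
        sign = -1.0 if nibble & 0x8 else 1.0
        out.append(sign * _FP4_MAG[nibble & 0x7])
    return out
```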

JIT & Host Integration

Layer / File(s) Summary
Device Selection for MPS Fallback
tilelang/jit/adapter/base.py
Updated get_current_device_functor() to try MPS device when CUDA init raises an exception, falling back to CPU if both CUDA and MPS are unavailable.
Metal Kernel Source Exposure
tilelang/jit/adapter/torch/metal.py
Added get_kernel_source(kernel_only: bool) -> str method to MetalKernelAdapter for retrieving generated Metal kernel source.
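The CUDA→MPS→CPU fallback described above can be sketched as a probe chain. This pure-Python stand-in takes probes as callables so the failure paths stay testable without torch; the real adapter queries torch.cuda and torch.backends.mps:

```python
def pick_device(cuda_probe, mps_probe) -> str:
    """Return 'cuda' if the CUDA probe succeeds and reports availability,
    else 'mps' if available, else 'cpu'."""
    try:
        if cuda_probe():
            return "cuda"
    except Exception:
        pass  # CUDA init can raise on Metal-only hosts
    try:
        if mps_probe():
            return "mps"
    except Exception:
        pass
    return "cpu"
```

The important behavior is that a raising CUDA probe is treated the same as an unavailable one, rather than propagating out of device selection.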

Tests, Benchmarks & Documentation

Layer / File(s) Summary
Codegen Validation Tests
testing/python/metal/test_metal_gemm_v2_linux.py, testing/python/metal/test_metal_local_var.py, testing/python/metal/test_metal_simdgroup_store.py
Platform-independent tests that lower kernels for Metal and assert generated Metal source contains expected simdgroup operations, kernel entry points, and initialization patterns.
Metal Runtime Tests
testing/python/metal/test_metal_gemm_v2.py, testing/python/metal/test_metal_simdgroup_store.py
Hardware-gated tests (requires_metal) that compile GEMM kernels, run on MPS tensors, and validate outputs against PyTorch reference matmul with configurable tolerances.
Internal Scaffolding & Quantization Tests
testing/python/metal/test_metal_internal_scaffolding.py
Comprehensive internal-only test suite validating packed FP8/FP4 decode, quantized matmuls, GDN KKT/W/U component scoring, register tile MMA, and Metal source boundary constraints; includes opt-in benchmark suites with synchronized timing.
Device Selection Tests
testing/python/jit/test_tilelang_jit_adapter_mps.py
Tests for CUDA unavailability/exception scenarios ensuring fallback to MPS or CPU device selection.
Benchmark Utility
benchmark/matmul_metal/benchmark_matmul_metal.py
Standalone Metal GEMM benchmark script that sweeps block configurations, compares TileLang SIMD group kernels against PyTorch MPS matmul, and reports TFLOPS.
Coverage Documentation
testing/python/metal/metal_internal_runtime_coverage.md
Documents internal-only Metal test coverage scope, fail-closed boundaries, known blockers, and verification hooks.

Platform Dependencies

Layer / File(s) Summary
macOS-Specific FFI Constraints
pyproject.toml, requirements.txt, requirements-dev.txt
Added platform selector constraint apache-tvm-ffi<0.1.8; platform_system == 'Darwin' to prevent incompatible FFI versions on macOS alongside existing ~=0.1.0,>=0.1.2 baseline.

Sequence Diagram

sequenceDiagram
    participant User as Application
    participant JIT as TileLang JIT
    participant Lower as Lowering Pipeline
    participant Pass as FragmentToSimdgroup Pass
    participant Codegen as Metal CodeGen
    participant Target as Metal Target

    User->>JIT: Define `@T.prim_func` GEMM<br/>(shared A, shared B, fragment C)
    JIT->>JIT: Compile with tilelang.jit()
    JIT->>Lower: Lower to IR
    Lower->>Pass: Apply MetalFragmentToSimdgroup<br/>Rewrite fragment→simdgroup
    Pass->>Pass: Detect GEMM ops,<br/>exclude FP8 cases,<br/>remap allocations
    Pass-->>Lower: Return transformed mod
    Lower->>Lower: LayoutInference<br/>(no fragment layout needed on Metal)
    Lower->>Codegen: Invoke CodeGenTileLangMetal
    Codegen->>Codegen: Generate kernel::<br/>- Bind thread indices<br/>- Allocate 8x8 simdgroup tiles<br/>- Emit simdgroup ops<br/>- Handle local.var scalars
    Codegen->>Target: Optional: Compile to metallib<br/>via tvm_callback_metal_compile
    Target-->>Codegen: Return Metal binary
    Codegen-->>JIT: Return MetalModule
    JIT-->>User: Return callable kernel
    User->>JIT: Call kernel(A_mps, B_mps, C_mps)
    JIT->>Target: Launch on Apple MPS
    Target->>Target: Execute simdgroup matrix ops,<br/>copy results
    Target-->>User: Return results

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested reviewers

  • LeiWang1999

🐰 Metal rabbit hops with glee,
SIMD groups dance in harmony,
Tile by tile, they multiply,
On Apple's GPU, fast they fly! 🚀


apstenku123 added a commit to DatasunriseOU/cppmega_mlx that referenced this pull request May 4, 2026
Files documenting the actual PRs we just opened upstream:

- PR #1: ml-explore/mlx#3476 — from_dlpack Metal-aware consumer (against main, clean)
- PR #2: apache/tvm#19504 — TVM_METAL_STORAGE_MODE env opt-in (against main, clean)
- PR #3: tile-ai/tilelang#2139 — mixed-dtype T.gemm via scalar fallback (stacks on PR #2130)
- PR #4: tile-ai/tilelang#2140 — FP8-input T.gemm scalar fallback routing (stacks on PR #2130)
- PR #5: tile-ai/tilelang#2141 — T.Pipelined num_stages>1 3D buffer fix (stacks on PR #2130)
- PR #6: tile-ai/tilelang#2142 — T.fp8_scaled_matmul DSL intrinsic (stacks on PR #2130)

Deferred (split into companion PRs needed): tilelang_metal_fp8 and
tilelang_metal_fp8_vector each touch both tilelang supermodule and the
TileLang/tvm vendored submodule. These need 2 PRs each — one to
tile-ai/tilelang, one to TileLang/tvm — separate filing round.

PRs #3-#6 are independent of each other; each branches directly from
jorgecurious/tilelang:metal-gemm-upstream-rebase HEAD 971c17b, so they
can be reviewed in any order. They DO depend on the upstream 4-PR Apple
Metal landing chain (#1869, #2118, #2121, #2130) merging first; if any
of those land separately, ours can be retargeted at main.

Copilot AI left a comment


Pull request overview

This PR improves Metal backend lowering for T.gemm, including routing FP8-input GEMMs away from simdgroup MMA (which can’t legally represent FP8 on Apple Silicon today) and adding the supporting Metal codegen/scaffolding needed for simdgroup accumulators and stores.

Changes:

  • Add Metal-specific GEMM dispatch (GemmInst.METAL_SIMDGROUP) with FP8-input detection that routes Metal FP8 GEMMs to the scalar fallback.
  • Introduce a Metal pass to rewrite GEMM accumulators from local.fragment to metal.simdgroup (and skip that rewrite for FP8-input GEMMs).
  • Add/extend Metal codegen, lowering, and tests to support metal.simdgroup allocation/fill/copy and MPS-oriented runtime coverage.

Reviewed changes

Copilot reviewed 36 out of 37 changed files in this pull request and generated 7 comments.

Show a summary per file
File Description
tilelang/utils/language.py Add is_metal_simdgroup() scope helper.
tilelang/transform/metal_fragment_to_simdgroup.py New Metal pass rewriting GEMM accumulators to metal.simdgroup, skipping FP8-input GEMMs.
tilelang/transform/decouple_type_cast.py Treat metal.simdgroup buffers as “local” for cast-decoupling decisions.
tilelang/tileop/metal_simdgroup.py New internal simdgroup helper macros/types (RegisterTile/RowVector/etc.).
tilelang/tileop/metal_quant.py Add packed-uint8 FP8/FP4/e8m0 decode helpers for Metal kernels.
tilelang/tileop/metal_gdn.py Add internal GDN/attention-style simdgroup tile macros.
tilelang/tileop/gemm/inst.py Extend GemmInst with METAL_SIMDGROUP.
tilelang/tileop/gemm/gemm_metal.py New Metal simdgroup GEMM lowering implementation.
tilelang/tileop/gemm/init.py Metal GEMM instruction selection + FP8-input routing to scalar; hook up GemmMetal.
tilelang/jit/adapter/torch/metal.py Expose Metal kernel source via adapter.
tilelang/jit/adapter/base.py Prefer MPS device selection when CUDA is unavailable/initialization fails.
tilelang/intrinsics/metal_macro_generator.py New MPSIntrinEmitter to generate simdgroup load/store/MMA macros.
tilelang/engine/phase.py Insert Metal fragment→simdgroup rewrite pass before layout inference.
tilelang/engine/lower.py Switch Metal device codegen to target.build.tilelang_metal.
testing/python/metal/test_metal_simdgroup_store.py Tests for simdgroup accumulator path + direct simdgroup_store to device memory.
testing/python/metal/test_metal_local_var.py Tests for local.var scalar lowering on Metal.
testing/python/metal/test_metal_internal_scaffolding.py Broad internal Metal scaffolding/runtime/source-boundary probes.
testing/python/metal/test_metal_gemm_v2_linux.py Cross-platform (no Metal runtime) codegen-only tests for Metal GEMM v2.
testing/python/metal/test_metal_gemm_v2.py Runtime correctness tests for Metal GEMM v2 on Metal hardware.
testing/python/metal/metal_internal_runtime_coverage.md Document internal Metal runtime coverage expectations.
testing/python/jit/test_tilelang_jit_adapter_mps.py Tests for MPS device selection behavior in the JIT adapter.
src/transform/lower_device_storage_access_info.cc Allow .fragment scope tags in storage-access-info lowering.
src/transform/layout_inference.cc Metal-specific relaxation for fragment layout completeness checks.
src/target/codegen_metal.h Add TileLang Metal codegen declaration (CodeGenTileLangMetal).
src/target/codegen_metal.cc Implement TileLang Metal codegen including metal.simdgroup + local.var support; register target.build.tilelang_metal.
src/op/utils.h Add IsSIMDGroupBuffer/IsRegisterBuffer helpers.
src/op/parallel.cc Guard fragment layout access when layout may be absent.
src/op/gemm.h Add kMetalSimdgroup GEMM inst enum value.
src/op/gemm.cc Select Metal simdgroup GEMM inst for Metal targets; adjust warp partition heuristics for Metal.
src/op/fill.cc Add fill lowering for metal.simdgroup buffers via make_filled_simdgroup_matrix.
src/op/copy.h Add Metal simdgroup copy inst + lowering hooks.
src/op/copy.cc Implement simdgroup-store lowering from metal.simdgroup → shared/global.
src/backend/metal/CMakeLists.txt Always build Metal codegen source; gate runtime pieces on Apple/USE_METAL.
requirements.txt Constrain apache-tvm-ffi on Darwin.
requirements-dev.txt Constrain apache-tvm-ffi on Darwin (dev).
pyproject.toml Constrain apache-tvm-ffi on Darwin (wheel/install).
benchmark/matmul_metal/benchmark_matmul_metal.py Add a simple Metal GEMM benchmark script.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +126 to +160
def _rewrite_scope(body, var_map):
    buf_map = {}

    def _pre_order(stmt):
        if isinstance(stmt, tir.Block):
            new_alloc_bufs = []
            changed = False
            for buf in stmt.alloc_buffers:
                new_buf = _remap_buffer(buf, var_map)
                new_alloc_bufs.append(new_buf)
                if not new_buf.same_as(buf):
                    buf_map[buf] = new_buf
                    changed = True
            if changed:
                new_body = tir.stmt_functor.substitute(stmt.body, var_map)
                new_block = tir.Block(
                    stmt.iter_vars,
                    stmt.reads,
                    stmt.writes,
                    stmt.name_hint,
                    new_body,
                    stmt.init,
                    new_alloc_bufs,
                    stmt.match_buffers,
                    stmt.annotations,
                )
                return (
                    tir.BlockRealize(
                        stmt.iter_vars,
                        tir.const(True, "bool"),
                        new_block,
                    )
                    if False
                    else new_block
                )
Comment on lines +68 to +99
def _collect_fragment_gemm_accum_vars(body: tir.Stmt) -> set:
    """Walk the body and return fragment vars safe to rewrite to simdgroup.

    GEMM accumulators backed by ``local.fragment`` are eligible for the
    rewrite to ``metal.simdgroup``, which the Metal simdgroup MMA path
    needs. We exclude FP8-input GEMMs because the dispatcher routes them
    to the scalar fallback (Apple has no native FP8 ALU through M5; the
    per-element T.cast invokes the storage-only decode helpers from the
    FP8 prelude -- see audiohacking fp8_scaled_matmul_kernel for the
    analogous pattern). For those GEMMs the accumulator must stay in
    ``local.fragment`` so the scalar fallback can perform its
    per-element T.cast(..., accum_dtype) arithmetic without tripping the
    Metal codegen's check that ``metal.simdgroup`` allocations are
    scalar 8x8 blocks.
    """
    accum_vars: set = set()
    gemm_ops = _get_gemm_ops()

    def _visitor(stmt):
        if isinstance(stmt, tir.Evaluate) and isinstance(stmt.value, tir.Call):
            call = stmt.value
            if call.op in gemm_ops and len(call.args) >= 3:
                # FP8 inputs (storage-only on Metal) route to the scalar
                # fallback; exclude their accumulators from the simdgroup
                # rewrite so the codegen does not allocate a
                # ``metal.simdgroup`` buffer for them.
                a_buf = _extract_buffer_from_region(call.args[0])
                b_buf = _extract_buffer_from_region(call.args[1])
                fp8_inputs = (a_buf is not None and _is_fp8_dtype(a_buf.dtype)) or \
                    (b_buf is not None and _is_fp8_dtype(b_buf.dtype))
                if fp8_inputs:
                    return
Comment on lines +1 to +10
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
Comment on lines +198 to +202
# helpers. The runtime mapping of GemmInst.Scalar to
# GemmMetalScalar on Metal targets is provided by the
# tilelang_gemm_mixed_dtype companion patch (PR #2118 stack);
# without it the resulting kernel will not lower correctly,
# but the routing decision is the load-bearing change here.
return lambda: torch.device("cuda", current_device())
except Exception:
return lambda: torch.device("cuda", torch.cuda.current_device())
pass
Comment on lines +436 to +445
// Check that all local.fragment buffers have inferred layouts.
// On Metal targets, fragment buffers used as GEMM accumulators are
// lowered to opaque simdgroup matrices, so they have no explicit
// thread-level layout and can be safely skipped.
for (const auto &[buffer, _] : use_list_) {
  if (IsFragmentBuffer(buffer)) {
    if (!TargetIsMetal(target_) && layout_map.count(buffer) == 0) {
      ICHECK(false) << "The layout for fragment " << buffer
                    << " can not be inferred correctly.";
    }

    << "Only float16, float32, and bfloat16 are supported, but got "
    << op->dtype;
ICHECK(constant_size % 64 == 0) << "Only 8x8 matrix is supported, but got "
                                << constant_size << " bytes\n";

coderabbitai Bot left a comment


Actionable comments posted: 5

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/op/parallel.cc (1)

362-383: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Guard the earlier fragment dereference too.

This branch now treats T.layout_map[buffer] as optional, but IsBufferCompletelyReplicated(buffer, T.layout_map) at Lines 362-363 still does as<Fragment>().value(). On the Metal path you described, that means we can still crash before reaching the new frag.has_value() check. Make the helper return false for a missing fragment, or check has_value() before calling it.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/op/parallel.cc` around lines 362 - 383,
IsBufferCompletelyReplicated(buffer, T.layout_map) may dereference a missing
Fragment; update the call site or the helper to avoid crashing: either modify
IsBufferCompletelyReplicated to check T.layout_map[buffer].as<Fragment>() and
return false when no fragment exists, or check that
T.layout_map[buffer].as<Fragment>().has_value() before calling .value() inside
the helper so it never dereferences an absent Fragment. Ensure this change keeps
the later logic that uses frag.has_value() and the
source_buffer/read_source_buffer flow intact.
🧹 Nitpick comments (2)
tilelang/engine/phase.py (1)

200-204: ⚡ Quick win

Gate this rewrite on target.kind.name == "metal".

The comment says this is a Metal-only rewrite, but the pass currently runs for every target. Guarding it here keeps non-Metal pipelines unchanged and avoids depending on MetalFragmentToSimdgroup being a perfect no-op everywhere else.

♻️ Suggested change
-    # On Metal, rewrite local.fragment GEMM accumulators to metal.simdgroup
-    # before layout inference (which would otherwise require a layout for them)
-    from tilelang.transform.metal_fragment_to_simdgroup import MetalFragmentToSimdgroup
-
-    mod = MetalFragmentToSimdgroup(mod)
+    # On Metal, rewrite local.fragment GEMM accumulators to metal.simdgroup
+    # before layout inference (which would otherwise require a layout for them)
+    if target.kind.name == "metal":
+        from tilelang.transform.metal_fragment_to_simdgroup import MetalFragmentToSimdgroup
+
+        mod = MetalFragmentToSimdgroup(mod)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/engine/phase.py` around lines 200 - 204, The Metal-specific rewrite
MetalFragmentToSimdgroup is being applied unconditionally; gate it so it only
runs when the compilation target is Metal by checking target.kind.name ==
"metal" before constructing/applying MetalFragmentToSimdgroup to mod. Locate the
code around the phase where mod is transformed (the
MetalFragmentToSimdgroup(mod) call) and wrap that construction/application in a
conditional using the existing target variable (e.g., if target.kind.name ==
"metal": mod = MetalFragmentToSimdgroup(mod)).
testing/python/metal/test_metal_gemm_v2_linux.py (1)

62-79: ⚡ Quick win

Add an FP8 regression case for the new routing behavior.

This suite only locks down the non-FP8 simdgroup path, but the PR’s actual behavior change is “FP8 inputs lower successfully without simdgroup MMA.” Please add a codegen-only FP8 case here that asserts Metal lowering succeeds and that simdgroup_multiply_accumulate is absent, otherwise the new dispatcher/pass interaction can regress silently.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@testing/python/metal/test_metal_gemm_v2_linux.py` around lines 62 - 79, Add a
codegen-only FP8 regression test in this file that calls
assert_metal_gemm_v2_codegen with FP8 input dtype(s) (matching your FP8 type
symbol) and appropriate shapes (e.g., small 16/32 blocks) to ensure Metal
lowering succeeds; then assert the generated IR/source does not contain the
symbol simdgroup_multiply_accumulate (i.e., verify absence of that call). Name
the new test something like test_metal_gemm_v2_fp8_no_simdgroup and place it
alongside the other test_* functions so it runs in CI, using the existing
assert_metal_gemm_v2_codegen helper to drive codegen-only validation.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/backend/metal/CMakeLists.txt`:
- Around line 3-8: The CMakeLists currently unconditionally appends
src/target/codegen_metal.cc to TILE_LANG_SRCS causing build failures on
non-Apple/Non-Metal hosts because runtime/metal/metal_module.h is missing; fix
by moving the list(APPEND TILE_LANG_SRCS src/target/codegen_metal.cc) so it is
executed only inside the existing guards (USE_METAL and APPLE) in
CMakeLists.txt, or alternatively refactor the codegen_metal.cc file to split
Metal runtime includes from pure codegen logic (e.g., create a
codegen_metal_core.cc without runtime/metal/metal_module.h and keep
runtime-dependent bits in a separate file) and update TILE_LANG_SRCS to only add
the runtime-dependent file under the USE_METAL/APPLE guards so non-Metal builds
no longer attempt to compile runtime headers.

In `@src/op/copy.cc`:
- Around line 802-840: CopyNode::CheckSIMDGroupCopy currently accepts Metal
SIMD-group copies without proving dst_range is fully in-bounds, which allows
LowerSIMDGroupCopy to emit unpredicated simdgroup_store and OOB stores; modify
CheckSIMDGroupCopy to consult the buffer_oob/analyzer-based in-bounds check (the
same predicate used elsewhere) for dst (and if relevant src) and only return
true when the analyzer proves dst_range is fully inside the dst buffer bounds;
thread the analyzer/buffer_oob check into this path (or call the existing helper
used by other copy legality checks) so that CheckSIMDGroupCopy refuses
SIMD-group selection for boundary tiles that might be OOB.

In `@src/op/fill.cc`:
- Around line 183-202: Before emitting consecutive matrix fills, prove the
region maps to a dense whole-matrix span: compute the region's last element
offset (e.g., last_element = sum((region[i]->min + region[i]->extent - 1) *
strides[i])) and use analyzer to check both that FloorMod(element_offset,
matrix_elements) == 0 (already present) and that FloorMod(last_element + 1,
matrix_elements) == 0 and FloorDiv(last_element, matrix_elements) ==
matrix_index_base + IntImm(DataType::Int(32), num_matrices - 1); if any of these
cannot be proven, do not run the consecutive Call loop that uses
matrix_index_base/num_matrices (the stmts loop) and instead fall back to the
non-consecutive/dense-element emission path. Reference variables: region,
strides, element_offset, last_element, matrix_elements, matrix_index_base,
num_matrices, analyzer->CanProveEqual, and the stmts/Call emission block.

In `@testing/python/metal/test_metal_local_var.py`:
- Around line 34-40: The zero-init count assertion is too strict for
_make_local_var_func() because it only guarantees one default-initialized local
(y) while x is explicitly initialized to 3; update the assertion in
test_metal_local_var.py that currently uses "assert
len(re.findall(r\"\\bint\\s+\\w+\\s*=\\s*0;\", src)) >= 2" to accept a single
zero-initialized local (e.g., change >= 2 to >= 1) so the test verifies the
presence of default-initialized locals without failing on valid codegen that
doesn't produce incidental temporaries.

In `@tilelang/jit/adapter/base.py`:
- Around line 82-84: The device-selection try/except currently only wraps
_lazy_init() at construction time, letting exceptions from _cuda_getDevice()
escape when the returned lambda is later invoked; move the CUDA and MPS
device-query logic inside the returned thunk (the lambda returned by the device
functor) so the try/except covers the actual runtime call to
torch._C._cuda_getDevice() and fallback to MPS/CPU; update both the device
functor and get_current_stream_functor to defer queries into their returned
closures and keep _lazy_init() where needed. Also add a regression test that
monkeypatches torch._C._cuda_getDevice to raise when the thunk is executed and
assert the device_functor()/get_current_stream_functor() fall back to MPS or CPU
as expected.

---

Outside diff comments:
In `@src/op/parallel.cc`:
- Around line 362-383: IsBufferCompletelyReplicated(buffer, T.layout_map) may
dereference a missing Fragment; update the call site or the helper to avoid
crashing: either modify IsBufferCompletelyReplicated to check
T.layout_map[buffer].as<Fragment>() and return false when no fragment exists, or
check that T.layout_map[buffer].as<Fragment>().has_value() before calling
.value() inside the helper so it never dereferences an absent Fragment. Ensure
this change keeps the later logic that uses frag.has_value() and the
source_buffer/read_source_buffer flow intact.
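The guarded lookup above can be sketched in Python (the real code is C++ over tvm's `Optional<Fragment>`; the `Fragment` class here is a hypothetical stand-in used only for illustration):

```python
class Fragment:
    """Stand-in for tilelang's Fragment layout (hypothetical simplification)."""

    def __init__(self, fully_replicated: bool):
        self._rep = fully_replicated

    def is_fully_replicated(self) -> bool:
        return self._rep


def is_buffer_completely_replicated(buffer, layout_map) -> bool:
    # Guarded lookup: a missing or non-Fragment entry returns False instead of
    # crashing, mirroring the has_value() check before .value() suggested above.
    entry = layout_map.get(buffer)
    if not isinstance(entry, Fragment):
        return False
    return entry.is_fully_replicated()
```

The key point is that the absent-layout case becomes an ordinary `False` result, so the later `frag.has_value()` logic and the source_buffer/read_source_buffer flow are untouched.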

---

Nitpick comments:
In `@testing/python/metal/test_metal_gemm_v2_linux.py`:
- Around line 62-79: Add a codegen-only FP8 regression test in this file that
calls assert_metal_gemm_v2_codegen with FP8 input dtype(s) (matching your FP8
type symbol) and appropriate shapes (e.g., small 16/32 blocks) to ensure Metal
lowering succeeds; then assert the generated IR/source does not contain the
symbol simdgroup_multiply_accumulate (i.e., verify absence of that call). Name
the new test something like test_metal_gemm_v2_fp8_no_simdgroup and place it
alongside the other test_* functions so it runs in CI, using the existing
assert_metal_gemm_v2_codegen helper to drive codegen-only validation.

In `@tilelang/engine/phase.py`:
- Around line 200-204: The Metal-specific rewrite MetalFragmentToSimdgroup is
being applied unconditionally; gate it so it only runs when the compilation
target is Metal by checking target.kind.name == "metal" before
constructing/applying MetalFragmentToSimdgroup to mod. Locate the code around
the phase where mod is transformed (the MetalFragmentToSimdgroup(mod) call) and
wrap that construction/application in a conditional using the existing target
variable (e.g., if target.kind.name == "metal": mod =
MetalFragmentToSimdgroup(mod)).
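The suggested gating can be sketched with stub objects standing in for the target and the pass (`metal_fragment_to_simdgroup` below is a hypothetical stub, not the real transform):

```python
from types import SimpleNamespace


def metal_fragment_to_simdgroup(mod):
    # Stand-in for the MetalFragmentToSimdgroup pass (hypothetical stub).
    return mod + ["simdgroup-rewritten"]


def run_phase(mod, target):
    # Gate the Metal-only rewrite on the compilation target, as suggested above.
    if target.kind.name == "metal":
        mod = metal_fragment_to_simdgroup(mod)
    return mod


metal = SimpleNamespace(kind=SimpleNamespace(name="metal"))
cuda = SimpleNamespace(kind=SimpleNamespace(name="cuda"))
```

With the guard in place, non-Metal targets pass through the phase with `mod` unchanged.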
ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 81b49bed-fc4f-498a-bbce-19185f4b4703

📥 Commits

Reviewing files that changed from the base of the PR and between d135bd1 and 1962b1d.

📒 Files selected for processing (37)
  • benchmark/matmul_metal/benchmark_matmul_metal.py
  • pyproject.toml
  • requirements-dev.txt
  • requirements.txt
  • src/backend/metal/CMakeLists.txt
  • src/op/copy.cc
  • src/op/copy.h
  • src/op/fill.cc
  • src/op/gemm.cc
  • src/op/gemm.h
  • src/op/parallel.cc
  • src/op/utils.h
  • src/target/codegen_metal.cc
  • src/target/codegen_metal.h
  • src/transform/layout_inference.cc
  • src/transform/lower_device_storage_access_info.cc
  • testing/python/jit/test_tilelang_jit_adapter_mps.py
  • testing/python/metal/metal_internal_runtime_coverage.md
  • testing/python/metal/test_metal_gemm_v2.py
  • testing/python/metal/test_metal_gemm_v2_linux.py
  • testing/python/metal/test_metal_internal_scaffolding.py
  • testing/python/metal/test_metal_local_var.py
  • testing/python/metal/test_metal_simdgroup_store.py
  • tilelang/engine/lower.py
  • tilelang/engine/phase.py
  • tilelang/intrinsics/metal_macro_generator.py
  • tilelang/jit/adapter/base.py
  • tilelang/jit/adapter/torch/metal.py
  • tilelang/tileop/gemm/__init__.py
  • tilelang/tileop/gemm/gemm_metal.py
  • tilelang/tileop/gemm/inst.py
  • tilelang/tileop/metal_gdn.py
  • tilelang/tileop/metal_quant.py
  • tilelang/tileop/metal_simdgroup.py
  • tilelang/transform/decouple_type_cast.py
  • tilelang/transform/metal_fragment_to_simdgroup.py
  • tilelang/utils/language.py

Comment on lines +3 to +8
# Metal codegen is pure C++ and can generate Metal shader source on any
# platform. Always compile it so target.build.tilelang_metal is available for
# cross-compilation and source-level tests on non-Apple hosts.
list(APPEND TILE_LANG_SRCS
  src/target/codegen_metal.cc
)

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== CMake block =="
sed -n '1,40p' src/backend/metal/CMakeLists.txt

echo
echo "== codegen_metal.cc includes =="
sed -n '1,50p' src/target/codegen_metal.cc | grep '^#include'

echo
echo "== metal_module.h presence in this checkout =="
fd -a 'metal_module\.h$' . || true

Repository: tile-ai/tilelang

Length of output: 1331


Unconditionally compiling src/target/codegen_metal.cc breaks non-Metal builds.

The CMakeLists.txt appends src/target/codegen_metal.cc to TILE_LANG_SRCS before the USE_METAL and APPLE guards take effect. However, this file includes runtime/metal/metal_module.h, which is not available on non-Metal or non-Apple platforms. This causes compilation to fail before the early return() statements can skip the Metal-specific configuration.

To fix: either move the source append behind the appropriate guards, or refactor codegen_metal.cc to isolate the Metal runtime dependencies from the codegen-only logic.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/backend/metal/CMakeLists.txt` around lines 3 - 8, The CMakeLists
currently unconditionally appends src/target/codegen_metal.cc to TILE_LANG_SRCS
causing build failures on non-Apple/Non-Metal hosts because
runtime/metal/metal_module.h is missing; fix by moving the list(APPEND
TILE_LANG_SRCS src/target/codegen_metal.cc) so it is executed only inside the
existing guards (USE_METAL and APPLE) in CMakeLists.txt, or alternatively
refactor the codegen_metal.cc file to split Metal runtime includes from pure
codegen logic (e.g., create a codegen_metal_core.cc without
runtime/metal/metal_module.h and keep runtime-dependent bits in a separate file)
and update TILE_LANG_SRCS to only add the runtime-dependent file under the
USE_METAL/APPLE guards so non-Metal builds no longer attempt to compile runtime
headers.

Comment thread src/op/copy.cc
Comment on lines +802 to 840
bool CopyNode::CheckSIMDGroupCopy(Target target) const {
  if (!TargetIsMetal(target) || !IsSIMDGroupBuffer(src)) {
    return false;
  }
  if (!IsSharedBuffer(dst) && !IsGlobalBuffer(dst)) {
    return false;
  }
  if (src->dtype != dst->dtype) {
    return false;
  }
  if (src_range.size() != 2 || dst_range.size() != 2 ||
      dst->shape.size() != 2) {
    return false;
  }

  int total_elements = 1;
  for (auto extent : src->shape) {
    auto imm = extent.as<IntImmNode>();
    if (!imm) {
      return false;
    }
    total_elements *= imm->value;
  }
  if (total_elements % 64 != 0) {
    return false;
  }

  for (int i = 0; i < 2; ++i) {
    auto src_shape = src->shape[i].as<IntImmNode>();
    auto src_min = src_range[i]->min.as<IntImmNode>();
    auto src_extent = src_range[i]->extent.as<IntImmNode>();
    auto dst_extent = dst_range[i]->extent.as<IntImmNode>();
    if (!src_shape || !src_min || src_min->value != 0 || !src_extent ||
        !dst_extent || src_extent->value != src_shape->value ||
        src_extent->value != dst_extent->value || src_extent->value % 8 != 0) {
      return false;
    }
  }
  return true;

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Guard Metal SIMD-group copy against edge-tile OOB stores.

This legality check never proves that dst_range is fully in-bounds, but LowerSIMDGroupCopy later emits unpredicated simdgroup_store calls. For boundary tiles on non-divisible shapes, this can select kMetalSIMDGroup and write past dst where LowerNormalCopy would have kept the bounds predicate. Please thread buffer_oob/analyzer-based in-bounds checks through this path before returning true.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/op/copy.cc` around lines 802 - 840, CopyNode::CheckSIMDGroupCopy
currently accepts Metal SIMD-group copies without proving dst_range is fully
in-bounds, which allows LowerSIMDGroupCopy to emit unpredicated simdgroup_store
and OOB stores; modify CheckSIMDGroupCopy to consult the
buffer_oob/analyzer-based in-bounds check (the same predicate used elsewhere)
for dst (and if relevant src) and only return true when the analyzer proves
dst_range is fully inside the dst buffer bounds; thread the analyzer/buffer_oob
check into this path (or call the existing helper used by other copy legality
checks) so that CheckSIMDGroupCopy refuses SIMD-group selection for boundary
tiles that might be OOB.
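The missing in-bounds predicate reduces to plain interval checks; the sketch below is a simplification (the real check must go through the arith analyzer, since ranges are `PrimExpr`s that may be symbolic), with `dst_range` given as `(min, extent)` pairs per dimension:

```python
def dst_range_in_bounds(dst_shape, dst_range):
    """Sketch of the check CheckSIMDGroupCopy should add before returning
    true: every (min, extent) pair must stay inside the destination buffer,
    otherwise an unpredicated simdgroup_store can write past dst on edge
    tiles of non-divisible shapes."""
    for dim, (lo, extent) in zip(dst_shape, dst_range):
        if lo < 0 or lo + extent > dim:
            return False  # boundary tile spills outside the buffer
    return True
```

When this returns `False` (or the analyzer cannot prove it), the copy should fall back to `LowerNormalCopy`, which keeps the bounds predicate.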

Comment thread src/op/fill.cc
Comment on lines +183 to +202
PrimExpr element_offset = 0;
for (size_t i = 0; i < region.size(); ++i) {
  element_offset += region[i]->min * strides[i];
}
PrimExpr matrix_elements = IntImm(element_offset.dtype(), 64);
ICHECK(
    analyzer->CanProveEqual(FloorMod(element_offset, matrix_elements), 0))
    << "simdgroup fill region must start on an 8x8 matrix boundary";
PrimExpr matrix_index_base = FloorDiv(element_offset, matrix_elements);
Array<Stmt> stmts;
for (int i = 0; i < num_matrices; i++) {
  stmts.push_back(Evaluate(
      Call(DataType::Handle(), builtin::make_filled_simdgroup_matrix(),
           {dst->data, matrix_index_base + IntImm(DataType::Int(32), i),
            fill_value, IntImm(DataType::Int(32), 8),
            IntImm(DataType::Int(32), 8)})));
}
if (stmts.size() == 1)
  return stmts[0];
return SeqStmt(stmts);

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Reject non-contiguous regions before emitting consecutive matrix fills.

This path derives one matrix_index_base from element_offset and then fills num_matrices consecutive 8x8 matrices, but it never proves that region itself is contiguous in matrix order. A strided slice can pass the current % 64 == 0 and alignment checks yet still target non-consecutive elements, so this will overwrite the wrong matrices instead of the requested region. Please add a contiguity check here and fall back when the region is not a dense whole-matrix span.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/op/fill.cc` around lines 183 - 202, Before emitting consecutive matrix
fills, prove the region maps to a dense whole-matrix span: compute the region's
last element offset (e.g., last_element = sum((region[i]->min +
region[i]->extent - 1) * strides[i])) and use analyzer to check both that
FloorMod(element_offset, matrix_elements) == 0 (already present) and that
FloorMod(last_element + 1, matrix_elements) == 0 and FloorDiv(last_element,
matrix_elements) == matrix_index_base + IntImm(DataType::Int(32), num_matrices -
1); if any of these cannot be proven, do not run the consecutive Call loop that
uses matrix_index_base/num_matrices (the stmts loop) and instead fall back to
the non-consecutive/dense-element emission path. Reference variables: region,
strides, element_offset, last_element, matrix_elements, matrix_index_base,
num_matrices, analyzer->CanProveEqual, and the stmts/Call emission block.
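The contiguity proof above can be sketched with concrete integers (a simplification: the real code works on `PrimExpr`s via `analyzer->CanProveEqual`; the names mirror element_offset, last_element, matrix_elements, matrix_index_base, and num_matrices from the prompt):

```python
from math import prod


def is_dense_matrix_span(region, strides, matrix_elements=64):
    """region is [(min, extent), ...] per dim. Returns (ok, base, num); ok is
    False for strided slices that pass a %64 size check yet do not cover a
    consecutive run of 8x8 matrices."""
    element_offset = sum(lo * s for (lo, _), s in zip(region, strides))
    last_element = sum((lo + ext - 1) * s for (lo, ext), s in zip(region, strides))
    total = prod(ext for _, ext in region)
    # Both endpoints must sit on matrix boundaries...
    if element_offset % matrix_elements or (last_element + 1) % matrix_elements:
        return False, 0, 0
    matrix_index_base = element_offset // matrix_elements
    num_matrices = total // matrix_elements
    # ...and the span must contain exactly num_matrices consecutive matrices.
    if last_element // matrix_elements != matrix_index_base + num_matrices - 1:
        return False, 0, 0
    return True, matrix_index_base, num_matrices
```

For example, an 8x8 slice of a 16-wide buffer (`strides=[16, 1]`) spans 120 elements despite containing only 64, so the dense-span check correctly rejects it and the fill must fall back to per-element emission.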

Comment on lines +34 to +40
# local.var should lower to scalar declarations/stores rather than arrays or
# an unsupported storage scope.
assert len(re.findall(r"\bint\s+\w+\s*=\s*0;", src)) >= 2, src
assert re.search(r"\w+\s*=\s*3;", src), src
assert re.search(r"\w+\s*=\s*\(\w+ \+ 4\);", src), src
assert "local.var" not in src
assert "thread int" not in src

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Loosen the zero-init assertion to match the test kernel.

_make_local_var_func() only guarantees one default-initialized local (y); x is explicitly initialized to 3. Requiring >= 2 matches incidental temporaries rather than the local.var behavior you're trying to lock down, so this can fail on valid codegen changes.

Suggested adjustment
-    assert len(re.findall(r"\bint\s+\w+\s*=\s*0;", src)) >= 2, src
+    assert len(re.findall(r"\bint\s+\w+\s*=\s*0;", src)) >= 1, src
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@testing/python/metal/test_metal_local_var.py` around lines 34 - 40, The
zero-init count assertion is too strict for _make_local_var_func() because it
only guarantees one default-initialized local (y) while x is explicitly
initialized to 3; update the assertion in test_metal_local_var.py that currently
uses "assert len(re.findall(r\"\\bint\\s+\\w+\\s*=\\s*0;\", src)) >= 2" to
accept a single zero-initialized local (e.g., change >= 2 to >= 1) so the test
verifies the presence of default-initialized locals without failing on valid
codegen that doesn't produce incidental temporaries.
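The loosened assertion can be demonstrated against a minimal stand-in for the generated source (the three-line `src` string below is an illustration of _make_local_var_func()'s expected output, not real codegen):

```python
import re

# Minimal MSL-like output: one default-initialized local (y) and one local
# explicitly initialized to 3 (x), as the test kernel guarantees.
src = "int x = 3;\nint y = 0;\ny = (x + 4);\n"

zero_inits = re.findall(r"\bint\s+\w+\s*=\s*0;", src)
assert len(zero_inits) >= 1, src    # loosened check: passes on valid codegen
assert not len(zero_inits) >= 2     # the old ">= 2" check would fail here
```

This is exactly the failure mode the review describes: valid codegen with no incidental zero-initialized temporaries trips the stricter count.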

Comment on lines +82 to +84
                pass
        if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
            return lambda: torch.device("mps")

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "Inspecting deferred device lookup in BaseKernelAdapter:"
sed -n '76,86p' tilelang/jit/adapter/base.py

echo
echo "Searching for test coverage of _lazy_init vs _cuda_getDevice failure:"
rg -n -C2 '_lazy_init|_cuda_getDevice' tilelang/jit/adapter/base.py testing/python/jit/test_tilelang_jit_adapter_mps.py

Repository: tile-ai/tilelang

Length of output: 2182


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "=== Checking all tests in test_tilelang_jit_adapter_mps.py ==="
cat testing/python/jit/test_tilelang_jit_adapter_mps.py

echo
echo "=== Checking where device_functor() is called ==="
rg -n "device_functor\(\)" tilelang/jit/adapter/

echo
echo "=== Checking similar pattern in get_current_stream_functor ==="
sed -n '51,66p' tilelang/jit/adapter/base.py

echo
echo "=== Search for _cuda_getDevice usage beyond base.py ==="
rg -n "_cuda_getDevice" --type py

Repository: tile-ai/tilelang

Length of output: 3471


Move CUDA/MPS device-query logic into the returned thunk to ensure fallback handles deferred failures.

The except handler only wraps _lazy_init() during functor construction, but _cuda_getDevice() is only called when the returned lambda executes later. If the device query fails at that time, the exception bypasses the fallback logic entirely.

Suggested fix
     @staticmethod
     def get_current_device_functor() -> Callable[[], torch.device]:
         """Return a callable that yields Torch's current device.

         Similar to the stream functor, we capture a callable that, when called,
         fetches the current device according to PyTorch. On CPU or when CUDA is
         unavailable, returns ``torch.device('cpu')``.
         """
-        if torch.cuda.is_available():
-            try:
-                torch.cuda._lazy_init()
-                current_device = torch._C._cuda_getDevice
-                return lambda: torch.device("cuda", current_device())
-            except Exception:
-                pass
-        if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
-            return lambda: torch.device("mps")
-        # CPU fallback
-        return lambda: torch.device("cpu")
+        def _current_device() -> torch.device:
+            if torch.cuda.is_available():
+                try:
+                    torch.cuda._lazy_init()
+                    return torch.device("cuda", torch._C._cuda_getDevice())
+                except Exception:
+                    pass
+            if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
+                return torch.device("mps")
+            return torch.device("cpu")
+
+        return _current_device

Add a regression test that monkeypatches torch._C._cuda_getDevice to raise when device_functor() is invoked, verifying fallback to MPS/CPU. Note: get_current_stream_functor has an identical structural issue.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/jit/adapter/base.py` around lines 82 - 84, The device-selection
try/except currently only wraps _lazy_init() at construction time, letting
exceptions from _cuda_getDevice() escape when the returned lambda is later
invoked; move the CUDA and MPS device-query logic inside the returned thunk (the
lambda returned by the device functor) so the try/except covers the actual
runtime call to torch._C._cuda_getDevice() and fallback to MPS/CPU; update both
the device functor and get_current_stream_functor to defer queries into their
returned closures and keep _lazy_init() where needed. Also add a regression test
that monkeypatches torch._C._cuda_getDevice to raise when the thunk is executed
and assert the device_functor()/get_current_stream_functor() fall back to MPS or
CPU as expected.
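The thunk-deferral pattern can be shown without torch at all; `query_cuda` below stands in for torch._C._cuda_getDevice, and the tuple return values are illustrative placeholders for torch.device objects:

```python
def make_device_functor(query_cuda, mps_available):
    """Torch-free sketch of the suggested fix: the CUDA query runs inside the
    returned thunk, so a call-time failure still reaches the MPS/CPU fallback
    instead of escaping to the caller."""
    def _current_device():
        try:
            return ("cuda", query_cuda())  # may raise long after construction
        except Exception:
            pass
        return ("mps",) if mps_available else ("cpu",)
    return _current_device


def broken_query():
    raise RuntimeError("CUDA driver gone at call time")


functor = make_device_functor(broken_query, mps_available=True)
device = functor()  # exception is caught inside the thunk; falls back to MPS
```

The regression test the review asks for is this scenario with torch._C._cuda_getDevice monkeypatched to raise: the functor must still return an MPS or CPU device.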

