From 8f12aea3495a6cd19ace58dd7e20bdd514d1df03 Mon Sep 17 00:00:00 2001 From: VitalyR Date: Tue, 12 May 2026 13:11:28 +0800 Subject: [PATCH 1/4] [MetaXGPU] Add MACA BSM intrinsic lowering --- src/op/builtin.cc | 25 +++++ src/op/builtin.h | 17 ++++ src/target/codegen_maca.cc | 80 +++++++++++++++- src/tl_templates/maca/common.h | 1 + ...st_tilelang_language_access_ptr_codegen.py | 38 ++++++++ tilelang/contrib/mxcc.py | 5 + tilelang/language/tir/ir.py | 7 ++ tilelang/language/tir/ir.pyi | 29 ++++++ tilelang/language/tir/op.py | 93 +++++++++++++++++++ 9 files changed, 294 insertions(+), 1 deletion(-) diff --git a/src/op/builtin.cc b/src/op/builtin.cc index ab518924..91614d51 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -306,6 +306,31 @@ TIR_DEFINE_TL_BUILTIN(ptx_cp_async) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(maca_ldg_b128_bsm_predicator) + .set_num_inputs(10) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(maca_ldg_b64_bsm_predicator) + .set_num_inputs(10) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(maca_arrive_gvmcnt) + .set_num_inputs(1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(maca_arrive_bsmcnt) + .set_num_inputs(1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(maca_barrier_inst) + .set_num_inputs(0) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + TIR_DEFINE_TL_BUILTIN(fence_proxy_async) .set_num_inputs(0) .set_attr("TCallEffectKind", diff --git a/src/op/builtin.h b/src/op/builtin.h index e48f55fe..39131d80 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -516,6 +516,23 @@ TVM_DLL const Op &ptx_cp_async_barrier_noinc(); */ TVM_DLL const Op &ptx_cp_async(); +/*! + * \brief MACA BSM direct global-to-shared load intrinsics. + * + * These ops expose the low-level C500 BSM load builtins used by handwritten + * MACA HGEMM kernels. They are intentionally target-specific and lower only in + * the MACA code generator. + */ +TVM_DLL const Op &maca_ldg_b128_bsm_predicator(); +TVM_DLL const Op &maca_ldg_b64_bsm_predicator(); + +/*! + * \brief MACA BSM/global-memory wait and barrier intrinsics. + */ +TVM_DLL const Op &maca_arrive_gvmcnt(); +TVM_DLL const Op &maca_arrive_bsmcnt(); +TVM_DLL const Op &maca_barrier_inst(); + /*! * \brief Pack two b16 value into a b32 value * diff --git a/src/target/codegen_maca.cc b/src/target/codegen_maca.cc index fe9beefd..ff4ae2b4 100644 --- a/src/target/codegen_maca.cc +++ b/src/target/codegen_maca.cc @@ -1581,7 +1581,47 @@ void CodeGenTileLangMACA::VisitExpr_(const CallNode *op, std::ostream &os) { this->stream << ss.str(); this->stream << ");\n"; }; - if (op->op.same_as(tl::maca_memcpy_async())) { + if (op->op.same_as(builtin::ptx_cp_async())) { + // args[0] = dst_access_ptr, args[1] = src_access_ptr, args[2] = bytes, + // args[3] = predicate (optional) + ICHECK(op->args.size() == 3 || op->args.size() == 4) + << "ptx_cp_async expects 3 or 4 arguments (dst_access_ptr, " + "src_access_ptr, bytes, [predicate])"; + + std::string dst = this->PrintExpr(op->args[0]); + std::string src = this->PrintExpr(op->args[1]); + std::string size = this->PrintExpr(op->args[2]); + + this->PrintIndent(); + if (op->args.size() == 3) { + this->stream << "tl::cp_async_gs<" << size << ">(" << dst << ", " << src + << ");\n"; + } else { + std::string condition = this->PrintExpr(op->args[3]); + this->stream << "tl::cp_async_gs_conditional<" << size << ">(" << dst + << ", " << src << ", " << condition << ");\n"; + } + } else if (op->op.same_as(tl::ptx_cp_async())) { + // TileLang version: args[0] = dst_access_ptr, args[1] = src_access_ptr, + // args[2] = bytes, args[3] = predicate (optional) + ICHECK(op->args.size() == 3 || op->args.size() == 4) + << "tl::ptx_cp_async expects 3 or 4 arguments (dst_access_ptr, " + "src_access_ptr, bytes, [predicate])"; + + std::string dst = this->PrintExpr(op->args[0]); + std::string src = this->PrintExpr(op->args[1]); + std::string size = this->PrintExpr(op->args[2]); + + this->PrintIndent(); + if (op->args.size() == 3) { + this->stream << "tl::cp_async_gs<" << size << ">(" << dst << ", " << src + << ");\n"; + } else { + std::string condition = this->PrintExpr(op->args[3]); + this->stream << "tl::cp_async_gs_conditional<" << size << ">(" << dst + << ", " << src << ", " << condition << ");\n"; + } + } else if (op->op.same_as(tl::maca_memcpy_async())) { // args[0] = dst_access_ptr, args[1] = src_access_ptr, args[2] = bytes, // args[3] = barrier ICHECK(op->args.size() == 4) @@ -1603,6 +1643,44 @@ void CodeGenTileLangMACA::VisitExpr_(const CallNode *op, std::ostream &os) { << "maca_barrier_arrive_and_wait expects 1 argument (bar)"; std::string dummyRet = this->PrintExpr(op->args[0]); this->stream << "barrier_arrive_and_wait(" << dummyRet << ");\n"; + } else if (op->op.same_as(tl::maca_ldg_b128_bsm_predicator()) || + op->op.same_as(tl::maca_ldg_b64_bsm_predicator())) { + ICHECK_EQ(op->args.size(), 10U) + << "maca_ldg_b{64,128}_bsm_predicator expects 10 arguments " + ""; + std::string builtin_name = op->op.same_as(tl::maca_ldg_b128_bsm_predicator()) + ? "__builtin_mxc_ldg_b128_bsm_predicator" + : "__builtin_mxc_ldg_b64_bsm_predicator"; + this->PrintIndent(); + this->stream << builtin_name << "("; + for (size_t i = 0; i < 9; ++i) { + if (i > 0) { + this->stream << ", "; + } + this->stream << this->PrintExpr(op->args[i]); + } + this->stream << ", " << Downcast(op->args[9])->value << ");\n"; + } else if (op->op.same_as(tl::maca_arrive_gvmcnt())) { + ICHECK_EQ(op->args.size(), 1U); + this->PrintIndent(); + this->stream << "__builtin_mxc_arrive_gvmcnt(" << this->PrintExpr(op->args[0]) + << ");\n"; + } else if (op->op.same_as(tl::maca_arrive_bsmcnt())) { + ICHECK_EQ(op->args.size(), 1U); + this->PrintIndent(); + this->stream << "__builtin_mxc_arrive_bsmcnt(" << this->PrintExpr(op->args[0]) + << ");\n"; + } else if (op->op.same_as(tl::maca_barrier_inst())) { + ICHECK_EQ(op->args.size(), 0U); + this->PrintIndent(); + this->stream << "__builtin_mxc_barrier_inst();\n"; + } else if (op->op.same_as(builtin::ptx_commit_group())) { + print_extern_call_stmt("tl::cp_async_commit"); + } else if (op->op.same_as(builtin::ptx_wait_group())) { + int n = Downcast(op->args[0])->value; + std::string func_name = "tl::cp_async_wait<" + std::to_string(n) + ">"; + print_extern_call_stmt(func_name, 1); } else if (op->op.same_as(builtin::create_barriers())) { this->PrintIndent(); int barrier_count = Downcast(op->args[0])->value; diff --git a/src/tl_templates/maca/common.h b/src/tl_templates/maca/common.h index ce48e4f0..492178e2 100644 --- a/src/tl_templates/maca/common.h +++ b/src/tl_templates/maca/common.h @@ -8,6 +8,7 @@ #include #include #include +#include #include #include diff --git a/testing/maca/language/test_tilelang_language_access_ptr_codegen.py b/testing/maca/language/test_tilelang_language_access_ptr_codegen.py index 578d5255..c2eb99e6 100644 --- a/testing/maca/language/test_tilelang_language_access_ptr_codegen.py +++ b/testing/maca/language/test_tilelang_language_access_ptr_codegen.py @@ -165,5 +165,43 @@ def main( assert "tl::cp_async_wait<0>" not in src, "Did not expect async_copy lowering to auto-emit wait" +def test_maca_bsm_intrinsics_codegen(): + """Smoke-test codegen for the MACA BSM builtin wrappers.""" + + @T.prim_func + def main( + A: T.Tensor((64,), T.uint8), + B: T.Tensor((64,), T.uint8), + ): + with T.Kernel(1, threads=32): + S = T.alloc_shared((64,), T.uint8) + T.maca_ldg_b128_bsm_predicator( + T.address_of(S[0]), + T.address_of(A[0]), + 0, + True, + True, + False, + True, + 1, + 1, + "MACA_ICMP_EQ", + ) + T.maca_arrive_gvmcnt(4) + T.maca_arrive_bsmcnt(2) + T.maca_barrier_inst() + B[0] = S[0] + + kernel = tilelang.compile(main, out_idx=[1], target="maca") + src = kernel.get_kernel_source() + print("=== MACA BSM builtin codegen ===") + print(src) + assert "__builtin_mxc_ldg_b128_bsm_predicator" in src + assert "__builtin_mxc_arrive_gvmcnt(4)" in src + assert "__builtin_mxc_arrive_bsmcnt(2)" in src + assert "__builtin_mxc_barrier_inst();" in src + assert '"MACA_ICMP_EQ"' not in src + + if __name__ == "__main__": tilelang.testing.main() diff --git a/tilelang/contrib/mxcc.py b/tilelang/contrib/mxcc.py index f6729778..795cac56 100644 --- a/tilelang/contrib/mxcc.py +++ b/tilelang/contrib/mxcc.py @@ -6,6 +6,7 @@ import re import os +import shlex import subprocess from tilelang.env import MACA_HOME, TILELANG_TEMPLATE_PATH import tvm_ffi @@ -88,6 +89,10 @@ def compile_maca(code, target_format="mcbin", arch=None, options=None, path_targ else: raise ValueError("options must be str or list of str") + extra_env_flags = os.environ.get("TILELANG_MXCC_FLAGS") + if extra_env_flags: + cmd += shlex.split(extra_env_flags) + cmd += ["-D__FAST_HALF_CVT__"] cmd += ["-o", file_target] cmd += [temp_code] diff --git a/tilelang/language/tir/ir.py b/tilelang/language/tir/ir.py index ebef8591..0dfbae91 100644 --- a/tilelang/language/tir/ir.py +++ b/tilelang/language/tir/ir.py @@ -297,6 +297,13 @@ def wrapped(*args, **kwargs): ptx_tcgen05_mma_blockscaled_ss = _dtype_forward(_tir_op.ptx_tcgen05_mma_blockscaled_ss) ptx_ldmatrix = _tir_op.ptx_ldmatrix ptx_cp_async = _dtype_forward(_tir_op.ptx_cp_async) +maca_ldg_b128_bsm_predicator = _tir_op.maca_ldg_b128_bsm_predicator +maca_ldg_b64_bsm_predicator = _tir_op.maca_ldg_b64_bsm_predicator +maca_ldg_b128_bsm = _tir_op.maca_ldg_b128_bsm +maca_ldg_b64_bsm = _tir_op.maca_ldg_b64_bsm +maca_arrive_gvmcnt = _tir_op.maca_arrive_gvmcnt +maca_arrive_bsmcnt = _tir_op.maca_arrive_bsmcnt +maca_barrier_inst = _tir_op.maca_barrier_inst ptx_cp_async_bulk = _dtype_forward(_tir_op.ptx_cp_async_bulk) mma_store = _dtype_forward(_tir_op.mma_store) mma_fill = _dtype_forward(_tir_op.mma_fill) diff --git a/tilelang/language/tir/ir.pyi b/tilelang/language/tir/ir.pyi index e5401688..5d463f97 100644 --- a/tilelang/language/tir/ir.pyi +++ b/tilelang/language/tir/ir.pyi @@ -134,6 +134,35 @@ def tvm_store_matrix_sync( def ptx_wait_group(num: int) -> PrimExpr: ... def ptx_commit_group() -> _T: ... def ptx_cp_async_barrier(barrier_id: int) -> PrimExpr: ... +def maca_ldg_b128_bsm_predicator( + dst_addr: PrimExpr, + src_addr: PrimExpr, + offset: _T = 0, + cache_global: _T = True, + cache_shared: _T = True, + evict: _T = False, + wait: _T = True, + cmp_lhs: _T = 1, + cmp_rhs: _T = 1, + cmp_type: str = "MACA_ICMP_EQ", +) -> PrimExpr: ... +def maca_ldg_b64_bsm_predicator( + dst_addr: PrimExpr, + src_addr: PrimExpr, + offset: _T = 0, + cache_global: _T = True, + cache_shared: _T = True, + evict: _T = False, + wait: _T = True, + cmp_lhs: _T = 1, + cmp_rhs: _T = 1, + cmp_type: str = "MACA_ICMP_EQ", +) -> PrimExpr: ... +def maca_ldg_b128_bsm(dst_addr: PrimExpr, src_addr: PrimExpr, offset: _T = 0) -> PrimExpr: ... +def maca_ldg_b64_bsm(dst_addr: PrimExpr, src_addr: PrimExpr, offset: _T = 0) -> PrimExpr: ... +def maca_arrive_gvmcnt(num: _T) -> PrimExpr: ... +def maca_arrive_bsmcnt(num: _T) -> PrimExpr: ... +def maca_barrier_inst() -> PrimExpr: ... def ptx_init_barrier_thread_count(barrier_id: int, thread_count: int) -> PrimExpr: ... def ptx_arrive_barrier(barrier_id: int) -> PrimExpr: ... def ptx_arrive_barrier_expect_tx(barrier_id: int, byte_count: int) -> PrimExpr: ... diff --git a/tilelang/language/tir/op.py b/tilelang/language/tir/op.py index ffd240cd..e1aa4f20 100644 --- a/tilelang/language/tir/op.py +++ b/tilelang/language/tir/op.py @@ -1580,6 +1580,99 @@ def ptx_cp_async(dst_access_ptr, src_access_ptr, num_elems, predicate=None): return tirx.call_intrin("", tirx.op.Op.get("tl.ptx_cp_async"), dst_access_ptr, src_access_ptr, num_elems, predicate) +def _maca_call_intrin(op_name, *args): + from tvm import tirx + + return tirx.call_intrin("handle", tirx.op.Op.get(op_name), *args) + + +def maca_ldg_b128_bsm_predicator( + dst_addr, + src_addr, + offset=0, + cache_global=True, + cache_shared=True, + evict=False, + wait=True, + cmp_lhs=1, + cmp_rhs=1, + cmp_type="MACA_ICMP_EQ", +): + """MACA BSM 128-bit global-to-shared load with hardware predication. + + This is a low-level MetaX/MACA primitive for hand-scheduled kernels. The + destination/source arguments should be raw addresses such as + ``T.address_of(S[0])`` or ``T.access_ptr(...)``. ``cmp_type`` is emitted as + a MACA enum token such as ``MACA_ICMP_EQ`` or ``MACA_ICMP_SLT``. + """ + return _maca_call_intrin( + "tl.maca_ldg_b128_bsm_predicator", + dst_addr, + src_addr, + offset, + cache_global, + cache_shared, + evict, + wait, + cmp_lhs, + cmp_rhs, + cmp_type, + ) + + +def maca_ldg_b64_bsm_predicator( + dst_addr, + src_addr, + offset=0, + cache_global=True, + cache_shared=True, + evict=False, + wait=True, + cmp_lhs=1, + cmp_rhs=1, + cmp_type="MACA_ICMP_EQ", +): + """MACA BSM 64-bit global-to-shared load with hardware predication.""" + return _maca_call_intrin( + "tl.maca_ldg_b64_bsm_predicator", + dst_addr, + src_addr, + offset, + cache_global, + cache_shared, + evict, + wait, + cmp_lhs, + cmp_rhs, + cmp_type, + ) + + +def maca_ldg_b128_bsm(dst_addr, src_addr, offset=0): + """Convenience wrapper for an unconditional MACA BSM 128-bit load.""" + return maca_ldg_b128_bsm_predicator(dst_addr, src_addr, offset) + + +def maca_ldg_b64_bsm(dst_addr, src_addr, offset=0): + """Convenience wrapper for an unconditional MACA BSM 64-bit load.""" + return maca_ldg_b64_bsm_predicator(dst_addr, src_addr, offset) + + +def maca_arrive_gvmcnt(num): + """Wait for MACA global-memory operations using ``__builtin_mxc_arrive_gvmcnt``.""" + return _maca_call_intrin("tl.maca_arrive_gvmcnt", num) + + +def maca_arrive_bsmcnt(num): + """Wait for MACA BSM operations using ``__builtin_mxc_arrive_bsmcnt``.""" + return _maca_call_intrin("tl.maca_arrive_bsmcnt", num) + + +def maca_barrier_inst(): + """Emit MACA ``__builtin_mxc_barrier_inst``.""" + return _maca_call_intrin("tl.maca_barrier_inst") + + def ptx_cp_async_bulk(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes, barrier_id): """TVM intrinsic for ptx async copy from global to shared memory using cp.async.bulk https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk From fbf9ab781318a2d5b30e6ab6a7479a797b778b4c Mon Sep 17 00:00:00 2001 From: VitalyR Date: Mon, 18 May 2026 20:29:18 +0800 Subject: [PATCH 2/4] [MetaXGPU] Add MACA GEMM compiler-path layout support Expose MACA GEMM A/B and C fragment layouts to Python and wire the dense TileLang GEMM template path through MACA-specific layouts. Teach MACA codegen to emit tl_gemm calls and derive tl.ptx_cp_async byte widths from access-pointer element types, then add the template header pieces needed by the compiler path. Validation: git diff --cached --check; ./.venv/bin/python -m py_compile tilelang/layout/__init__.py tilelang/layout/swizzle.py tilelang/tileop/gemm/gemm_maca_mma.py testing/maca/tilelibrary/test_tilelang_maca_gemm_template_contract.py testing/maca/language/test_tilelang_language_access_ptr_codegen.py --- src/backend/common/op/reduce.h | 25 +- src/backend/cuda/op/reduce.cc | 16 + src/backend/maca/codegen/codegen_maca.cc | 48 +++ src/backend/maca/op/reduce.cc | 11 + src/backend/rocm/op/reduce.cc | 12 + src/layout/gemm_layouts.cc | 31 +- src/layout/layout.cc | 10 + src/layout/layout.h | 1 + src/op/builtin.cc | 5 + src/op/builtin.h | 12 + src/target/codegen_maca.cc | 149 ++++++++- src/tl_templates/cuda/reduce.h | 6 + src/tl_templates/hip/reduce.h | 6 + src/tl_templates/maca/barrier.h | 167 ++++++++++ src/tl_templates/maca/copy.h | 22 ++ src/tl_templates/maca/gemm.h | 39 ++- src/tl_templates/maca/gemm_wsm.h | 158 +++++++++ src/tl_templates/maca/reduce.h | 6 + src/transform/lower_tile_op.cc | 111 ++++--- ...st_tilelang_language_access_ptr_codegen.py | 47 +++ ...st_tilelang_maca_gemm_template_contract.py | 72 ++++ tilelang/contrib/mxcc.py | 11 +- tilelang/language/gemm_op.py | 4 + tilelang/layout/__init__.py | 2 + tilelang/layout/swizzle.py | 16 + tilelang/maca/intrinsics/layout/mma_layout.py | 4 +- .../intrinsics/macro/mma_macro_generator.py | 2 +- tilelang/maca/op/gemm/gemm_mma.py | 310 ++++++++++++++++-- tilelang/tileop/gemm/gemm_base.py | 7 + 29 files changed, 1208 insertions(+), 102 deletions(-) create mode 100644 src/tl_templates/maca/barrier.h create mode 100644 src/tl_templates/maca/gemm_wsm.h create mode 100644 testing/maca/tilelibrary/test_tilelang_maca_gemm_template_contract.py diff --git a/src/backend/common/op/reduce.h b/src/backend/common/op/reduce.h index 28fa4be2..9a19a88d 100644 --- a/src/backend/common/op/reduce.h +++ b/src/backend/common/op/reduce.h @@ -569,9 +569,16 @@ template struct ReduceLowerer { std::string reducer = reduce::MakeCodegenReducer(op, can_batch_pack ? vsize : 1) .value(); - std::string allreduce = Impl::MakeBatchAllReduce( - reducer, reducing_threads, *scale, thread_offset, - T.thread_bounds->extent, eff_batch, reducing_threads, T.target); + std::string allreduce = + can_batch_pack + ? Impl::MakeBatchAllReduce(reducer, reducing_threads, *scale, + thread_offset, + T.thread_bounds->extent, eff_batch, + reducing_threads, T.target) + : Impl::MakeBatchAllReduceOffset( + reducer, reducing_threads, *scale, thread_offset, + T.thread_bounds->extent, batch, reducing_threads, + T.target); DataType ws_dtype = can_batch_pack ? clear_buffer->dtype.with_lanes(vsize) @@ -674,16 +681,8 @@ template struct ReduceLowerer { } else { for (int chunk = 0; chunk < num_chunks; chunk++) { int64_t flat_offset = static_cast(chunk) * batch; - Array chunk_indices; - for (int d = 0; d < buf_ndim; d++) { - int64_t idx = - (flat_offset / buf_strides[d]) % buf_shape_vals[d]; - chunk_indices.push_back(Integer(idx)); - } - PrimExpr ptr = Call(DataType::Handle(), builtin::address_of(), - {BufferLoad(clear_buffer, chunk_indices)}); - - Array args = {StringImm(allreduce), ptr}; + Array args = {StringImm(allreduce), clear_buffer->data, + Integer(flat_offset)}; if (need_workspace) { args.push_back(workspace); } diff --git a/src/backend/cuda/op/reduce.cc b/src/backend/cuda/op/reduce.cc index 04fc9d22..9a7d556a 100644 --- a/src/backend/cuda/op/reduce.cc +++ b/src/backend/cuda/op/reduce.cc @@ -45,6 +45,22 @@ struct Reduce : backend::ReduceLowerer { return ss.str(); } + static std::string + MakeBatchAllReduceOffset(std::string reducer, int reducing_threads, int scale, + PrimExpr thread_offset, PrimExpr all_threads, + int batch, int workspace_stride, Target target) { + std::stringstream ss; + ss << "tl::AllReduce<" << reducer << ", " << reducing_threads << ", " + << scale << ", " << thread_offset; + if (TargetHasSMVersionGE(target, 90)) { + ss << ", tl::NamedBarrier<" << all_threads << ">"; + } else { + ss << ", tl::SyncThreadsBarrier"; + } + ss << ", " << batch << ", " << workspace_stride << ">::run_batch_offset"; + return ss.str(); + } + static std::string MakeScalarAllReduce(std::string reducer, int reducing_threads, int scale, PrimExpr thread_offset, diff --git a/src/backend/maca/codegen/codegen_maca.cc b/src/backend/maca/codegen/codegen_maca.cc index e34521c1..6ee22be4 100644 --- a/src/backend/maca/codegen/codegen_maca.cc +++ b/src/backend/maca/codegen/codegen_maca.cc @@ -1578,6 +1578,39 @@ void CodeGenTileLangMACA::VisitExpr_(const CallNode *op, std::ostream &os) { << "maca_barrier_arrive_and_wait expects 1 argument (bar)"; std::string dummyRet = this->PrintExpr(op->args[0]); this->stream << "barrier_arrive_and_wait(" << dummyRet << ");\n"; + } else if (op->op.same_as(tl::maca_ldg_b128_bsm_predicator()) || + op->op.same_as(tl::maca_ldg_b64_bsm_predicator())) { + ICHECK_EQ(op->args.size(), 10U) + << "maca_ldg_b{64,128}_bsm_predicator expects 10 arguments " + ""; + std::string builtin_name = + op->op.same_as(tl::maca_ldg_b128_bsm_predicator()) + ? "__builtin_mxc_ldg_b128_bsm_predicator" + : "__builtin_mxc_ldg_b64_bsm_predicator"; + this->PrintIndent(); + this->stream << builtin_name << "("; + for (size_t i = 0; i < 9; ++i) { + if (i > 0) { + this->stream << ", "; + } + this->stream << this->PrintExpr(op->args[i]); + } + this->stream << ", " << Downcast(op->args[9])->value << ");\n"; + } else if (op->op.same_as(tl::maca_arrive_gvmcnt())) { + ICHECK_EQ(op->args.size(), 1U); + this->PrintIndent(); + this->stream << "__builtin_mxc_arrive_gvmcnt(" + << this->PrintExpr(op->args[0]) << ");\n"; + } else if (op->op.same_as(tl::maca_arrive_bsmcnt())) { + ICHECK_EQ(op->args.size(), 1U); + this->PrintIndent(); + this->stream << "__builtin_mxc_arrive_bsmcnt(" + << this->PrintExpr(op->args[0]) << ");\n"; + } else if (op->op.same_as(tl::maca_barrier_inst())) { + ICHECK_EQ(op->args.size(), 0U); + this->PrintIndent(); + this->stream << "__builtin_mxc_barrier_inst();\n"; } else if (op->op.same_as(builtin::create_barriers())) { this->PrintIndent(); int barrier_count = Downcast(op->args[0])->value; @@ -2016,6 +2049,21 @@ void CodeGenTileLangMACA::VisitExpr_(const CallNode *op, std::ostream &os) { os << ")"; } else if (op->op.same_as(builtin::thread_return())) { os << "return"; + } else if (op->op.same_as(tl::tl_gemm())) { + ICHECK(op->args.size() == 4) << "tl_gemm expects 4 arguments , but got " + << op->args.size(); + auto op_instance = Downcast(op->args[0]); + this->PrintCallExtern(GetType(tvm::ffi::GetRef(op)), + op_instance->value, op->args, true, os); + } else if (op->op.same_as(tl::tl_gemm_wsm())) { + ICHECK(op->args.size() == 7) + << "tl_gemm_wsm expects 7 arguments , but got " + << op->args.size(); + auto op_instance = Downcast(op->args[0]); + this->PrintCallExtern(GetType(tvm::ffi::GetRef(op)), + op_instance->value, op->args, true, os); } else if (op->op.same_as(tl::tl_gemm_sp())) { ICHECK(op->args.size() == 5) << "tl_gemm_sp expects 5 arguments { return ss.str(); } + static std::string + MakeBatchAllReduceOffset(std::string reducer, int reducing_threads, int scale, + PrimExpr thread_offset, PrimExpr all_threads, + int batch, int workspace_stride, Target target) { + std::stringstream ss; + ss << "tl::AllReduce<" << reducer << ", " << reducing_threads << ", " + << scale << ", " << thread_offset << ", tl::SyncThreadsBarrier" + << ", " << batch << ", " << workspace_stride << ">::run_batch_offset"; + return ss.str(); + } + static std::string MakeScalarAllReduce(std::string reducer, int reducing_threads, int scale, PrimExpr thread_offset, diff --git a/src/backend/rocm/op/reduce.cc b/src/backend/rocm/op/reduce.cc index 41d85476..98ce4255 100644 --- a/src/backend/rocm/op/reduce.cc +++ b/src/backend/rocm/op/reduce.cc @@ -33,6 +33,18 @@ struct Reduce : backend::ReduceLowerer { return ss.str(); } + static std::string MakeBatchAllReduceOffset(std::string reducer, + int reducing_threads, int scale, + PrimExpr thread_offset, PrimExpr, + int batch, int workspace_stride, + Target) { + std::stringstream ss; + ss << "tl::AllReduce<" << reducer << ", " << reducing_threads << ", " + << scale << ", " << thread_offset << ", " << batch << ", " + << workspace_stride << ">::run_batch_offset"; + return ss.str(); + } + static std::string MakeScalarAllReduce(std::string reducer, int reducing_threads, int scale, PrimExpr thread_offset, PrimExpr, diff --git a/src/layout/gemm_layouts.cc b/src/layout/gemm_layouts.cc index fb097406..064a0b80 100644 --- a/src/layout/gemm_layouts.cc +++ b/src/layout/gemm_layouts.cc @@ -213,10 +213,12 @@ Fragment makeGemmFragmentCMACA(const int block_m, const int block_n, PrimExpr forward_thread = 16 * FloorDiv(j->var, 4) + i; PrimExpr index = FloorMod(j->var, 4); auto base_layout = Fragment({i, j}, {index}, forward_thread, rep); - auto warp_layout = + // Match CUTE partition_shape_C accumulator storage: the thread-group + // repeat must be applied before the per-thread 16x16 tile repeats. + auto thread_layout = base_layout->Repeat({block_m / warp_m, block_n / warp_n}, true, false); auto block_layout = - warp_layout->Repeat({warp_m / 16, warp_n / 16}, false, false); + thread_layout->Repeat({warp_m / 16, warp_n / 16}, false, false); return block_layout; } @@ -1034,6 +1036,31 @@ Layout makeGemmABLayoutMACA(int mat_stride, int mat_continuous, int continuity, } } +Layout makeMacaGemmABLayout(const Buffer &buffer, int kfactor) { + auto info = GetSwizzleShapeInfoChecked(buffer); + auto mat_stride = static_cast(info.stride); + auto mat_continuous = static_cast(info.continuous); + Layout base; + if (info.element_size == 16 && kfactor == 1 && mat_continuous % 64 == 0 && + mat_stride % 64 != 32) { + Var i = InputPlaceholder(0); + Var j = InputPlaceholder(1); + constexpr int vector_size = 4; + PrimExpr ts = FloorDiv(i, 64); + PrimExpr s = FloorMod(FloorDiv(i, vector_size), 16); + PrimExpr tc = FloorDiv(j, 16); + PrimExpr c = FloorMod(j, 16); + PrimExpr vec = FloorMod(i, vector_size); + PrimExpr s_swizzle = xor16x16(s, c); + PrimExpr index = vec + (s_swizzle + c * 16) * vector_size; + base = Layout(Array{mat_stride, mat_continuous}, {tc, ts, index}); + } else { + base = makeGemmABLayoutMACA(mat_stride, mat_continuous, mat_continuous, + info.element_size, kfactor); + } + return ExpandLayout2D(base, buffer); +} + Layout makeSwizzledLayout(const Buffer &buffer, bool k_inner, bool allow_pad) { auto info = GetSwizzleShapeInfoChecked(buffer); Layout base; diff --git a/src/layout/layout.cc b/src/layout/layout.cc index 4e9c555f..819dd6c3 100644 --- a/src/layout/layout.cc +++ b/src/layout/layout.cc @@ -1154,6 +1154,16 @@ TVM_FFI_STATIC_INIT_BLOCK() { }) .def("tl.make_linear_layout", [](Array shape) { return makeLinearLayout(shape); }) + .def("tl.make_maca_gemm_ab_layout", + [](const Buffer &buffer, int kfactor) { + return makeMacaGemmABLayout(buffer, kfactor); + }) + .def("tl.make_maca_gemm_fragment_c", + [](int block_m, int block_n, int warp_m, int warp_n, + int element_size) { + return makeGemmFragmentCMACA(block_m, block_n, warp_m, warp_n, + element_size); + }) .def("tl.make_gemm_fragment_8x8", []() { return makeGemmFragment8x8(); }) .def("tl.make_gemm_fragment_8x8_transposed", []() { return makeGemmFragment8x8Transposed(); }) diff --git a/src/layout/layout.h b/src/layout/layout.h index 381f9569..7feea4f6 100644 --- a/src/layout/layout.h +++ b/src/layout/layout.h @@ -268,6 +268,7 @@ Layout makeGemmABLayoutCDNA(int stride, int continuous, int element_size, int kPack); Layout makeGemmABLayoutMACA(int mat_stride, int mat_continuous, int continuity, int element_size, int kfactor); +Layout makeMacaGemmABLayout(const Buffer &buffer, int kfactor); Fragment makeGemmVoltaFragmentC(const int block_m, const int block_n, const int warp_m, const int warp_n, diff --git a/src/op/builtin.cc b/src/op/builtin.cc index 91614d51..ec998027 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -568,6 +568,11 @@ TIR_DEFINE_TL_BUILTIN(loop_break) TIR_DEFINE_TL_BUILTIN(tl_gemm).set_num_inputs(4).set_attr( "TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(tl_gemm_wsm) + .set_num_inputs(7) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + TIR_DEFINE_TL_BUILTIN(tl_gemm_sp) .set_num_inputs(5) .set_attr("TCallEffectKind", diff --git a/src/op/builtin.h b/src/op/builtin.h index 39131d80..4db4ad0e 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -1005,6 +1005,18 @@ TVM_DLL const Op &tvm_rdna_wmma_store(); */ TVM_DLL const Op &tl_gemm(); +/*! + * \brief tilelang intrinsic for MACA GEMM variants that need a wider generated + * source ABI than tl_gemm. + * + * The first use is the HGEMM WSM-aware compiler-path probe. It preserves the + * template-call style of tl_gemm while adding pointer arguments for generated + * shared/WSM storage plus the original global A/B source operands: + * T.call_intrin("handle", "tl.tl_gemm_wsm", op_instance_str, + * A_ptr, B_ptr, C_ptr, WSM_ptr, A_source_ptr, B_source_ptr) + */ +TVM_DLL const Op &tl_gemm_wsm(); + /*! * \brief tilelang intrinsic for sparse matrix multiplication (GEMM with * sparsity). diff --git a/src/target/codegen_maca.cc b/src/target/codegen_maca.cc index ff4ae2b4..df76cb93 100644 --- a/src/target/codegen_maca.cc +++ b/src/target/codegen_maca.cc @@ -12,6 +12,9 @@ #include #include +#include +#include +#include #include #include #include @@ -27,6 +30,94 @@ namespace codegen { using namespace tvm::tl::codegen; using namespace ffi; +namespace { + +bool IsValidCPAsyncTransferBytes(int64_t bytes) { + return bytes == 4 || bytes == 8 || bytes == 16; +} + +int GetCPAsyncBytesImm(const PrimExpr &bytes_expr, const char *op_name) { + const auto *bytes_imm = bytes_expr.as(); + ICHECK(bytes_imm) << op_name << " byte count must be IntImm, but got " + << bytes_expr; + ICHECK_GT(bytes_imm->value, 0); + ICHECK(IsValidCPAsyncTransferBytes(bytes_imm->value)) + << op_name << " byte count must be one of {4, 8, 16}, but got " + << bytes_imm->value; + return static_cast(bytes_imm->value); +} + +bool IsEnabledEnv(const char *name) { + const char *value = std::getenv(name); + if (value == nullptr) { + return false; + } + std::string text(value); + return text == "1" || text == "true" || text == "TRUE" || text == "on" || + text == "ON" || text == "yes" || text == "YES"; +} + +std::optional GetAccessPtrElementType(const PrimExpr &expr) { + const auto *ptr_call = expr.as(); + if (ptr_call == nullptr) { + return std::nullopt; + } + if (ptr_call->op.same_as(builtin::address_of())) { + const auto *buffer_load = ptr_call->args[0].as(); + ICHECK(buffer_load) << "address_of arg must be BufferLoad"; + return buffer_load->buffer->dtype; + } + if (ptr_call->op.same_as(builtin::tvm_access_ptr())) { + ICHECK(!ptr_call->args.empty()); + return ptr_call->args[0].dtype(); + } + if (ptr_call->op.same_as(tl::access_ptr())) { + ICHECK_EQ(ptr_call->args.size(), 3U) + << "tl.access_ptr expects 3 args: (BufferLoad, extent, rw_mask)"; + const auto *buffer_load = ptr_call->args[0].as(); + ICHECK(buffer_load) << "tl.access_ptr arg0 must be BufferLoad"; + return buffer_load->buffer->dtype; + } + return std::nullopt; +} + +int GetTileLangCPAsyncTransferBytes(const CallNode *op) { + ICHECK(op->args.size() == 3 || op->args.size() == 4) + << "tl::ptx_cp_async expects 3 or 4 arguments (dst_access_ptr, " + "src_access_ptr, num_elems, [predicate])"; + const auto *num_elems_imm = op->args[2].as(); + ICHECK(num_elems_imm) << "tl::ptx_cp_async num_elems must be IntImm, but got " + << op->args[2]; + int64_t num_elems = num_elems_imm->value; + ICHECK_GT(num_elems, 0); + + auto dst_elem_type = GetAccessPtrElementType(op->args[0]); + auto src_elem_type = GetAccessPtrElementType(op->args[1]); + ICHECK(dst_elem_type.has_value() && src_elem_type.has_value()) + << "tl::ptx_cp_async expects address_of, tl.access_ptr, or " + "tvm_access_ptr operands"; + + int64_t dst_total_bits = + num_elems * dst_elem_type.value().bits() * dst_elem_type.value().lanes(); + int64_t src_total_bits = + num_elems * src_elem_type.value().bits() * src_elem_type.value().lanes(); + ICHECK_EQ(dst_total_bits, src_total_bits) + << "tl::ptx_cp_async requires src/dst transfer widths to match, but got " + << dst_total_bits << " vs " << src_total_bits << " bits"; + ICHECK_EQ(dst_total_bits % 8, 0) + << "tl::ptx_cp_async requires byte-aligned transfers, but got " + << dst_total_bits << " bits"; + + int64_t total_bytes = dst_total_bits / 8; + ICHECK(IsValidCPAsyncTransferBytes(total_bytes)) + << "tl::ptx_cp_async requires a final PTX byte width in {4, 8, 16}, but " + "got " + << total_bytes; + return static_cast(total_bytes); +} + +} // namespace + struct MACAMath { std::string operator()(DataType t, std::string name) const { if (t.is_float()) { @@ -313,6 +404,7 @@ std::string CodeGenTileLangMACA::Finish() { } decl_stream << "#include \n"; + decl_stream << "#include \n"; if (enable_sparse_gemm_) { decl_stream << "#include \n"; } @@ -1519,6 +1611,21 @@ void CodeGenTileLangMACA::PrintVecStore(const BufferNode *buffer, DataType t, scope = GetPtrStorageScope(buffer->data); } + if (IsEnabledEnv("TILELANG_MACA_REWRITE_PACKED_B_COPY_TO_BSM") && + scope == "shared.dyn" && t.is_float16() && t.lanes() == 4 && + value.find("*(uint2*)(b +") != std::string::npos) { + auto buffer_ref = this->GetBufferRef(t, buffer, base); + this->PrintIndent(); + this->stream << "__builtin_mxc_ldg_b64_bsm_predicator((unsigned char*)(&(" + << buffer_ref << ")), (unsigned char*)(&(" << value + << ")), 0, true, true, false, true, 1, 1, MACA_ICMP_EQ);\n"; + this->PrintIndent(); + this->stream << "__builtin_mxc_arrive_gvmcnt(0);\n"; + this->PrintIndent(); + this->stream << "__builtin_mxc_barrier_inst();\n"; + return; + } + if (scope != "global" || t.bits() * t.lanes() <= 128) { this->CodeGenC::PrintVecStore(buffer, t, base, value); return; @@ -1590,7 +1697,8 @@ void CodeGenTileLangMACA::VisitExpr_(const CallNode *op, std::ostream &os) { std::string dst = this->PrintExpr(op->args[0]); std::string src = this->PrintExpr(op->args[1]); - std::string size = this->PrintExpr(op->args[2]); + int total_bytes = GetCPAsyncBytesImm(op->args[2], "ptx_cp_async"); + std::string size = std::to_string(total_bytes); this->PrintIndent(); if (op->args.size() == 3) { @@ -1603,14 +1711,11 @@ void CodeGenTileLangMACA::VisitExpr_(const CallNode *op, std::ostream &os) { } } else if (op->op.same_as(tl::ptx_cp_async())) { // TileLang version: args[0] = dst_access_ptr, args[1] = src_access_ptr, - // args[2] = bytes, args[3] = predicate (optional) - ICHECK(op->args.size() == 3 || op->args.size() == 4) - << "tl::ptx_cp_async expects 3 or 4 arguments (dst_access_ptr, " - "src_access_ptr, bytes, [predicate])"; - + // args[2] = num_elems, args[3] = predicate (optional) + int total_bytes = GetTileLangCPAsyncTransferBytes(op); std::string dst = this->PrintExpr(op->args[0]); std::string src = this->PrintExpr(op->args[1]); - std::string size = this->PrintExpr(op->args[2]); + std::string size = std::to_string(total_bytes); this->PrintIndent(); if (op->args.size() == 3) { @@ -1649,9 +1754,10 @@ void CodeGenTileLangMACA::VisitExpr_(const CallNode *op, std::ostream &os) { << "maca_ldg_b{64,128}_bsm_predicator expects 10 arguments " ""; - std::string builtin_name = op->op.same_as(tl::maca_ldg_b128_bsm_predicator()) - ? "__builtin_mxc_ldg_b128_bsm_predicator" - : "__builtin_mxc_ldg_b64_bsm_predicator"; + std::string builtin_name = + op->op.same_as(tl::maca_ldg_b128_bsm_predicator()) + ? "__builtin_mxc_ldg_b128_bsm_predicator" + : "__builtin_mxc_ldg_b64_bsm_predicator"; this->PrintIndent(); this->stream << builtin_name << "("; for (size_t i = 0; i < 9; ++i) { @@ -1664,13 +1770,13 @@ void CodeGenTileLangMACA::VisitExpr_(const CallNode *op, std::ostream &os) { } else if (op->op.same_as(tl::maca_arrive_gvmcnt())) { ICHECK_EQ(op->args.size(), 1U); this->PrintIndent(); - this->stream << "__builtin_mxc_arrive_gvmcnt(" << this->PrintExpr(op->args[0]) - << ");\n"; + this->stream << "__builtin_mxc_arrive_gvmcnt(" + << this->PrintExpr(op->args[0]) << ");\n"; } else if (op->op.same_as(tl::maca_arrive_bsmcnt())) { ICHECK_EQ(op->args.size(), 1U); this->PrintIndent(); - this->stream << "__builtin_mxc_arrive_bsmcnt(" << this->PrintExpr(op->args[0]) - << ");\n"; + this->stream << "__builtin_mxc_arrive_bsmcnt(" + << this->PrintExpr(op->args[0]) << ");\n"; } else if (op->op.same_as(tl::maca_barrier_inst())) { ICHECK_EQ(op->args.size(), 0U); this->PrintIndent(); @@ -2376,6 +2482,21 @@ void CodeGenTileLangMACA::VisitExpr_(const CallNode *op, std::ostream &os) { os << ")"; } else if (op->op.same_as(builtin::thread_return())) { os << "return"; + } else if (op->op.same_as(tl::tl_gemm())) { + ICHECK(op->args.size() == 4) << "tl_gemm expects 4 arguments , but got " + << op->args.size(); + auto op_instance = Downcast(op->args[0]); + this->PrintCallExtern(GetType(tvm::ffi::GetRef(op)), + op_instance->value, op->args, true, os); + } else if (op->op.same_as(tl::tl_gemm_wsm())) { + ICHECK(op->args.size() == 7) + << "tl_gemm_wsm expects 7 arguments , but got " + << op->args.size(); + auto op_instance = Downcast(op->args[0]); + this->PrintCallExtern(GetType(tvm::ffi::GetRef(op)), + op_instance->value, op->args, true, os); } else if (op->op.same_as(tl::tl_gemm_sp())) { ICHECK(op->args.size() == 5) << "tl_gemm_sp expects 5 arguments + static TL_DEVICE void run_batch_offset(T *x, int offset, + T *red_buf = nullptr) { + run_batch(x + offset, red_buf); + } + private: using Next = AllReduce; diff --git a/src/tl_templates/hip/reduce.h b/src/tl_templates/hip/reduce.h index dd12a1b8..0f39fe3c 100644 --- a/src/tl_templates/hip/reduce.h +++ b/src/tl_templates/hip/reduce.h @@ -145,6 +145,12 @@ struct AllReduce { workspace_stride>::run_batch(x, red_buf); } } + + template + static __device__ void run_batch_offset(T *x, int base_offset, + T *red_buf = nullptr) { + run_batch(x + base_offset, red_buf); + } }; template struct CumSum1D { diff --git a/src/tl_templates/maca/barrier.h b/src/tl_templates/maca/barrier.h new file mode 100644 index 00000000..b24f5db8 --- /dev/null +++ b/src/tl_templates/maca/barrier.h @@ -0,0 +1,167 @@ +// Copyright (c) 2025 MetaX Integrated Circuits (Shanghai) Co., Ltd. All rights +// reserved. + +#pragma once + +#include "common.h" + +namespace tl { +TL_DEVICE void mbarrier_init(uint64_t &smem_barrier, uint32_t arrive_count); +TL_DEVICE uint32_t mbarrier_try_wait(uint64_t &smem_barrier, int phase_bit); +TL_DEVICE void mbarrier_wait(uint64_t &smem_barrier, int phase_bit); +TL_DEVICE void mbarrier_arrive(uint64_t &smem_barrier); +TL_DEVICE void mbarrier_arrive(uint64_t &smem_barrier, int cta_id, + uint32_t pred); +TL_DEVICE void mbarrier_expect_tx(uint64_t &smem_barrier, + uint32_t transaction_bytes); +TL_DEVICE void mbarrier_arrive_expect_tx(uint64_t &smem_barrier, + uint32_t transaction_bytes); +template +TL_DEVICE void mbarrier_cp_async_arrive(BarrierType &smem_mbar); +template +TL_DEVICE void mbarrier_cp_async_arrive_noinc(BarrierType &smem_mbar); +TL_DEVICE void fence_proxy_async(); +TL_DEVICE void fence_barrier_init(); +} // namespace tl + +struct Barrier { + uint64_t value; + + TL_DEVICE void init(uint32_t arrive_count) { + tl::mbarrier_init(value, arrive_count); + } + TL_DEVICE void arrive() { tl::mbarrier_arrive(value); } + TL_DEVICE void arrive(int cta_id, uint32_t pred) { + tl::mbarrier_arrive(value, cta_id, pred); + } + TL_DEVICE void wait(int phase_bit) { tl::mbarrier_wait(value, phase_bit); } + TL_DEVICE void expect_transaction(uint32_t transaction_bytes) { + tl::mbarrier_expect_tx(value, transaction_bytes); + } + TL_DEVICE void arrive_and_expect_tx(uint32_t transaction_bytes) { + tl::mbarrier_arrive_expect_tx(value, transaction_bytes); + } +}; + +namespace tl { + +TL_DEVICE void mbarrier_init(uint64_t &smem_barrier, uint32_t arrive_count) { + uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier); + asm volatile("mbarrier.init.shared.b64 [%1], %0;" + : + : "r"(arrive_count), "r"(smem_int_ptr) + : "memory"); +} + +TL_DEVICE uint32_t mbarrier_try_wait(uint64_t &smem_barrier, int phase_bit) { + uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier); + uint32_t waitComplete; + asm volatile("{\n\t" + ".reg .pred P1; \n\t" + "mbarrier.try_wait.parity.shared.b64 P1, [%1], %2; \n\t" + "selp.b32 %0, 1, 0, P1; \n\t" + "}" + : "=r"(waitComplete) + : "r"(smem_int_ptr), "r"(phase_bit) + : "memory"); + return waitComplete; +} + +TL_DEVICE void mbarrier_wait(uint64_t &smem_barrier, int phase_bit) { + if (mbarrier_try_wait(smem_barrier, phase_bit) == 0) { + uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier); + uint32_t ticks = 0x989680; + asm volatile("{\n\t" + ".reg .pred P1; \n\t" + "LAB_WAIT_%=: \n\t" + "mbarrier.try_wait.parity.shared.b64 P1, [%0], %1, %2; \n\t" + "@P1 bra DONE_%=; \n\t" + "bra LAB_WAIT_%=; \n\t" + "DONE_%=: \n\t" + "}" + : + : "r"(smem_int_ptr), "r"(phase_bit), "r"(ticks) + : "memory"); + } +} + +TL_DEVICE void mbarrier_arrive(uint64_t &smem_barrier) { + uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier); + asm volatile("mbarrier.arrive.shared.b64 _, [%0];" + : + : "r"(smem_int_ptr) + : "memory"); +} + +TL_DEVICE void mbarrier_arrive(uint64_t &smem_barrier, int cta_id, + uint32_t pred) { + uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier); + if (pred) { + asm volatile("{\n\t" + ".reg .b32 remAddr32;\n\t" + "mapa.shared::cluster.u32 remAddr32, %0, %1;\n\t" + "mbarrier.arrive.shared::cluster.b64 _, [remAddr32];\n\t" + "}" + : + : "r"(smem_int_ptr), "r"(cta_id) + : "memory"); + } +} + +TL_DEVICE void mbarrier_expect_tx(uint64_t &smem_barrier, + uint32_t transaction_bytes) { + uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier); + asm volatile("mbarrier.expect_tx.shared.b64 [%1], %0;" + : + : "r"(transaction_bytes), "r"(smem_int_ptr) + : "memory"); +} + +TL_DEVICE void mbarrier_arrive_expect_tx(uint64_t &smem_barrier, + uint32_t transaction_bytes) { + uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier); + asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%1], %0;" + : + : "r"(transaction_bytes), "r"(smem_int_ptr) + : "memory"); +} + +template +TL_DEVICE void mbarrier_cp_async_arrive(BarrierType &smem_mbar) { + uint32_t smem_int_mbar; + if constexpr (std::is_pointer_v) { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(smem_mbar)); + } else { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(&smem_mbar)); + } + asm volatile("cp.async.mbarrier.arrive.shared.b64 [%0];" + : + : "r"(smem_int_mbar) + : "memory"); +} + +template +TL_DEVICE void mbarrier_cp_async_arrive_noinc(BarrierType &smem_mbar) { + uint32_t smem_int_mbar; + if constexpr (std::is_pointer_v) { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(smem_mbar)); + } else { + smem_int_mbar = smem_ptr_to_uint(reinterpret_cast(&smem_mbar)); + } + asm volatile("{\n\t" + "cp.async.mbarrier.arrive.noinc.shared::cta.b64 [%0];\n\t" + "}" + : + : "r"(smem_int_mbar) + : "memory"); +} + +TL_DEVICE void fence_proxy_async() { + asm volatile("fence.proxy.async.shared::cta;" : : : "memory"); +} + +TL_DEVICE void fence_barrier_init() { + asm volatile("fence.mbarrier_init.release.cluster;" : : : "memory"); +} + +} // namespace tl diff --git a/src/tl_templates/maca/copy.h b/src/tl_templates/maca/copy.h index e592b6dd..17247226 100644 --- a/src/tl_templates/maca/copy.h +++ b/src/tl_templates/maca/copy.h @@ -1,8 +1,30 @@ // Copyright (c) 2025 MetaX Integrated Circuits (Shanghai) Co., Ltd. All rights // reserved. #pragma once +#include "barrier.h" #include "common.h" +#include namespace tl { +TL_DEVICE void cp_async_commit() { asm volatile("" ::: "memory"); } + +template TL_DEVICE void cp_async_wait() { + mctlass::arch::maca_cp_async_wait(); + __syncthreadshared(); +} + +template +TL_DEVICE void cp_async_gs(void *lds_base_ptr, void const *global_base_ptr) { + static_assert(N == 16 || N == 8 || N == 4); + mctlass::arch::maca_cp_async_zfill(lds_base_ptr, global_base_ptr); +} + +template +TL_DEVICE void cp_async_gs_conditional(void *lds_base_ptr, + void const *global_base_ptr, bool cond) { + static_assert(N == 16 || N == 8 || N == 4); + mctlass::arch::maca_cp_async_zfill(lds_base_ptr, global_base_ptr, cond); +} + // Global memory load intrinsics with explicit vector widths // MACA-compatible implementation using standard pointer casts // load_global_32: Load 32 bits, return uint32_t diff --git a/src/tl_templates/maca/gemm.h b/src/tl_templates/maca/gemm.h index 681922ba..577d1f03 100644 --- a/src/tl_templates/maca/gemm.h +++ b/src/tl_templates/maca/gemm.h @@ -3,6 +3,7 @@ #pragma once +#include "barrier.h" #include "common.h" #include #include @@ -16,6 +17,10 @@ template <> struct DispatchInstruction { using MMA = MMA_Atom>; }; +template <> struct DispatchInstruction<__half, __half, float> { + using MMA = MMA_Atom>; +}; + template struct OperandTraits; @@ -82,6 +87,20 @@ class GemmTensorOp { Layout, Int, _1>>, Layout>>; + template + static CUTE_DEVICE auto remove_swizzle(Layout const &layout) { + return layout; + } + + template + static CUTE_DEVICE auto + remove_swizzle(ComposedLayout const &layout) { + if constexpr (sizeof(B_type) == 2) { + return layout.layout_b(); + } + return layout; + } + CUTE_DEVICE static void body(A_type_raw *pA, B_type_raw *pB, C_type_raw *pC) { const int tid = threadIdx.x; Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast(pA)), @@ -107,15 +126,20 @@ class GemmTensorOp { Tensor acc = make_tensor(make_rmem_ptr(reinterpret_cast(pC)), partition_shape_C(tiled_mma, Shape, Int>{})); + auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout())); if constexpr (clear_accum) { clear(acc); } + copy(tiled_copy_B, tCsB(_, _, 0), tCrB_copy_view(_, _, 0)); + CUTE_UNROLL for (int k = 0; k < size<2>(tCrA); ++k) { copy(tiled_copy_A, tCsA(_, _, k), tCrA_copy_view(_, _, k)); - copy(tiled_copy_B, tCsB(_, _, k), tCrB_copy_view(_, _, k)); - gemm(tiled_mma, tCrA(_, _, k), tCrB(_, _, k), acc); + if (k < size<2>(tCrA) - 1) { + copy(tiled_copy_B, tCsB(_, _, k + 1), tCrB_copy_view(_, _, k + 1)); + } + gemm(tiled_mma, tCrA(_, _, k), tCrB_view(_, _, k), acc); } } @@ -138,6 +162,7 @@ class GemmTensorOp { Tensor acc = make_tensor(make_rmem_ptr(reinterpret_cast(pC)), partition_shape_C(tiled_mma, Shape, Int>{})); + auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout())); Tensor tCrA = make_tensor(make_rmem_ptr(reinterpret_cast(pA)), partition_shape_A(tiled_mma, Shape, Int>{})); @@ -146,9 +171,13 @@ class GemmTensorOp { clear(acc); } + copy(tiled_copy_B, tCsB(_, _, 0), tCrB_copy_view(_, _, 0)); + CUTE_UNROLL for (int k = 0; k < size<2>(tCrA); ++k) { - copy(tiled_copy_B, tCsB(_, _, k), tCrB_copy_view(_, _, k)); - gemm(tiled_mma, tCrA(_, _, k), tCrB(_, _, k), acc); + if (k < size<2>(tCrA) - 1) { + copy(tiled_copy_B, tCsB(_, _, k + 1), tCrB_copy_view(_, _, k + 1)); + } + gemm(tiled_mma, tCrA(_, _, k), tCrB_view(_, _, k), acc); } } }; @@ -175,3 +204,5 @@ TL_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) { MMA::body_rs(pA, pB, accum); } } // namespace tl + +#include "gemm_wsm.h" diff --git a/src/tl_templates/maca/gemm_wsm.h b/src/tl_templates/maca/gemm_wsm.h new file mode 100644 index 00000000..ca2ca389 --- /dev/null +++ b/src/tl_templates/maca/gemm_wsm.h @@ -0,0 +1,158 @@ +#pragma once + +#include "gemm.h" + +namespace tl { + +using WsmFloat4 = __NATIVE_VECTOR__(4, float); +using WsmUint2 = __NATIVE_VECTOR__(2, uint); +using WsmUint4 = __NATIVE_VECTOR__(4, uint); + +template +MCTLASS_DEVICE void gemm_ss_wsm(A_type *pA, B_type *pB, C_type *accum, + WSM_type *wsm, A_source_type *global_A, + B_source_type *global_B) { + constexpr int Stage = 4; + static_assert(!trans_A && !trans_B, + "gemm_ss_wsm only supports non-transposed A/B operands"); + static_assert(num_warp_m == 1 && num_warp_n == 1, + "gemm_ss_wsm currently uses a fixed single-warp layout"); + static_assert(kPack == 8, + "gemm_ss_wsm expects kPack=8 for its K/8 BSM predicate"); + static_assert(AStrideElements % 8 == 0, + "gemm_ss_wsm requires AStrideElements divisible by 8"); + static_assert(sizeof(A_type) == 2 && sizeof(B_type) == 2, + "gemm_ss_wsm expects fp16 A/B operands"); + static_assert(sizeof(C_type) == 4, + "gemm_ss_wsm expects fp32 accumulator fragments"); + static_assert(Stage == 4, "gemm_ss_wsm hardcodes a 4-stage schedule"); + uchar *WSM = reinterpret_cast(wsm); + uchar *APtr = const_cast(reinterpret_cast(global_A)); + uchar *BPtr = const_cast(reinterpret_cast(global_B)); + const int tid = threadIdx.x; + const int slot = __builtin_mxc_readfirstlane(tid / 64); + const int lane = tid & 63; + uchar *WSM_Ldg = WSM + slot * 0x400; + uchar *WSM_lds = WSM; + int ALdsOffset[4]; + int BLdsOffset[4]; + int ALdgOffset[2][Stage]; + int BLdgOffset[2][Stage]; + const int lda_vec = AStrideElements / 8; + const int B_row_offset = lane + slot * 64 * (N / 16); + const int lds_row = tid & 15; + int lds_col[4]; + +#pragma unroll + for (int stage_i = 0; stage_i < Stage; ++stage_i) { + ALdgOffset[0][stage_i] = (tid + 16 * lda_vec * stage_i) * 16; + ALdgOffset[1][stage_i] = (tid + 16 * lda_vec * (4 + stage_i)) * 16; + BLdgOffset[0][stage_i] = (B_row_offset + 64 * stage_i) * 16; + BLdgOffset[1][stage_i] = (B_row_offset + 64 * (4 + stage_i)) * 16; + } + +#pragma unroll + for (int i = 0; i < 4; ++i) { + lds_col[i] = (4 * i + lane / 16) ^ lds_row; + } + +#pragma unroll + for (int i = 0; i < 4; ++i) { + const int tmp = lds_row * 16 + lds_col[i]; + ALdsOffset[i] = (tmp + (slot / 2) * 0x1000 / 16) * 16; + BLdsOffset[i] = (tmp + 0x2000 / 16 + (slot & 1) * 0x1000 / 16) * 16; + } + +#pragma unroll + for (int stage_i = 0; stage_i < Stage; ++stage_i) { + __builtin_mxc_ldg_b128_bsm_predicator( + WSM_Ldg + 0x4000 * stage_i + 0x0000, APtr + ALdgOffset[0][stage_i], 0, + true, true, false, true, 0, K / 8, MACA_ICMP_SLT); + __builtin_mxc_ldg_b128_bsm_predicator( + WSM_Ldg + 0x4000 * stage_i + 0x1000, APtr + ALdgOffset[1][stage_i], 0, + true, true, false, true, 0, K / 8, MACA_ICMP_SLT); + __builtin_mxc_ldg_b128_bsm_predicator( + WSM_Ldg + 0x4000 * stage_i + 0x2000, BPtr + BLdgOffset[0][stage_i], 0, + true, true, false, true, stage_i * 16, N, MACA_ICMP_SLT); + __builtin_mxc_ldg_b128_bsm_predicator( + WSM_Ldg + 0x4000 * stage_i + 0x3000, BPtr + BLdgOffset[1][stage_i], 0, + true, true, false, true, stage_i * 16 + 64, N, MACA_ICMP_SLT); + } + + __builtin_mxc_arrive_gvmcnt(4 * (Stage - 1)); + __builtin_mxc_barrier_inst(); + + WsmFloat4 C_f32[4][4]; +#pragma unroll + for (int row = 0; row < 4; ++row) { +#pragma unroll + for (int col = 0; col < 4; ++col) { + if constexpr (clear_accum) { + C_f32[row][col] = WsmFloat4{0.0f, 0.0f, 0.0f, 0.0f}; + } else { + C_f32[row][col] = reinterpret_cast(accum)[row * 4 + col]; + } + } + } + +#pragma unroll + for (int stage_i = 0; stage_i < Stage; ++stage_i) { + WsmUint4 a_frag0 = *reinterpret_cast( + WSM_lds + 0x4000 * stage_i + ALdsOffset[0]); + WsmUint4 a_frag1 = *reinterpret_cast( + WSM_lds + 0x4000 * stage_i + ALdsOffset[1]); + WsmUint4 a_frag2 = *reinterpret_cast( + WSM_lds + 0x4000 * stage_i + ALdsOffset[2]); + WsmUint4 a_frag3 = *reinterpret_cast( + WSM_lds + 0x4000 * stage_i + ALdsOffset[3]); + WsmUint4 b_frag0 = *reinterpret_cast( + WSM_lds + 0x4000 * stage_i + BLdsOffset[0]); + WsmUint4 b_frag1 = *reinterpret_cast( + WSM_lds + 0x4000 * stage_i + BLdsOffset[1]); + WsmUint4 b_frag2 = *reinterpret_cast( + WSM_lds + 0x4000 * stage_i + BLdsOffset[2]); + WsmUint4 b_frag3 = *reinterpret_cast( + WSM_lds + 0x4000 * stage_i + BLdsOffset[3]); + WsmUint2 mma_a[4] = {{a_frag0[0], a_frag0[1]}, + {a_frag1[0], a_frag1[1]}, + {a_frag2[0], a_frag2[1]}, + {a_frag3[0], a_frag3[1]}}; + WsmUint2 mma_b[4] = {{b_frag0[0], b_frag0[1]}, + {b_frag1[0], b_frag1[1]}, + {b_frag2[0], b_frag2[1]}, + {b_frag3[0], b_frag3[1]}}; + +#pragma unroll + for (int row = 0; row < 4; ++row) { +#pragma unroll + for (int col = 0; col < 4; ++col) { + C_f32[row][col] = __builtin_mxc_mma_16x16x16f16(mma_a[row], mma_b[col], + C_f32[row][col]); + } + } + } + + reinterpret_cast(accum)[0] = C_f32[0][0]; + reinterpret_cast(accum)[1] = C_f32[0][1]; + reinterpret_cast(accum)[2] = C_f32[0][2]; + reinterpret_cast(accum)[3] = C_f32[0][3]; + reinterpret_cast(accum)[4] = C_f32[1][0]; + reinterpret_cast(accum)[5] = C_f32[1][1]; + reinterpret_cast(accum)[6] = C_f32[1][2]; + reinterpret_cast(accum)[7] = C_f32[1][3]; + reinterpret_cast(accum)[8] = C_f32[2][0]; + reinterpret_cast(accum)[9] = C_f32[2][1]; + reinterpret_cast(accum)[10] = C_f32[2][2]; + reinterpret_cast(accum)[11] = C_f32[2][3]; + reinterpret_cast(accum)[12] = C_f32[3][0]; + reinterpret_cast(accum)[13] = C_f32[3][1]; + reinterpret_cast(accum)[14] = C_f32[3][2]; + reinterpret_cast(accum)[15] = C_f32[3][3]; + (void)pA; + (void)pB; +} + +} // namespace tl diff --git a/src/tl_templates/maca/reduce.h b/src/tl_templates/maca/reduce.h index 42227992..1b4c90bb 100644 --- a/src/tl_templates/maca/reduce.h +++ b/src/tl_templates/maca/reduce.h @@ -113,6 +113,12 @@ struct AllReduce { } } + template + static TL_DEVICE void run_batch_offset(T *x, int offset, + T *red_buf = nullptr) { + run_batch(x + offset, red_buf); + } + private: using Next = AllReduce; diff --git a/src/transform/lower_tile_op.cc b/src/transform/lower_tile_op.cc index 8756a712..b9d1ae7a 100644 --- a/src/transform/lower_tile_op.cc +++ b/src/transform/lower_tile_op.cc @@ -421,9 +421,9 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { } int CheckAndGetBufferRowSize(const Buffer &buffer) { - ICHECK(buffer->shape.size() >= 2) + ICHECK(buffer->shape.size() >= 1) << "The dimension of Buffer \"" << buffer->name << "\" with shape " - << buffer->shape << " should be at least 2"; + << buffer->shape << " should be at least 1"; auto dim = buffer->shape.size(); auto buffer_row_size = buffer->shape[dim - 1].as()->value; @@ -435,6 +435,59 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { bool rewritten{false}; }; + PrimExpr FlattenPhysicalIndices(const Array &indices, + const Array &shape) { + ICHECK_EQ(indices.size(), shape.size()) + << "Indices size and shape size must match for physical buffer access " + << "but got indices size: " << indices.size() + << " and shape size: " << shape.size(); + + PrimExpr offset = 0; + PrimExpr stride = 1; + for (int i = static_cast(shape.size()) - 1; i >= 0; --i) { + offset += indices[i] * stride; + stride *= shape[i]; + } + return analyzer_->Simplify(offset); + } + + Array LinearOffsetToIndices(const PrimExpr &offset, + const Array &shape) { + Array indices; + PrimExpr remaining = offset; + for (int i = static_cast(shape.size()) - 1; i >= 0; --i) { + indices.insert(indices.begin(), floormod(remaining, shape[i])); + remaining = floordiv(remaining, shape[i]); + } + return indices; + } + + Array BuildPhysicalIndices(const Array &forward_indices, + const Array &target_shape, + const PrimExpr &linear_offset) { + if (target_shape.size() == forward_indices.size() + 1) { + PrimExpr layout_extent = 1; + for (size_t i = 1; i < target_shape.size(); ++i) { + layout_extent *= target_shape[i]; + } + + Array indices; + indices.push_back( + analyzer_->Simplify(floordiv(linear_offset, layout_extent))); + for (const auto &forward_index : forward_indices) { + indices.push_back(forward_index); + } + return indices; + } + + ICHECK_EQ(target_shape.size(), forward_indices.size()) + << "Remapped access pointer rank mismatch: forward indices size " + << forward_indices.size() << ", target shape size " + << target_shape.size(); + PrimExpr new_offset = FlattenPhysicalIndices(forward_indices, target_shape); + return LinearOffsetToIndices(new_offset, target_shape); + } + AccessPtrResult HandleAccessPtrAndOffset(const PrimExpr &access_ptr, const Optional &offset = std::nullopt, @@ -510,20 +563,14 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { } // Apply layout transformation auto forward_indices = layout->Forward(multi_dim_indices); - PrimExpr new_offset = 0; - PrimExpr stride_offset = 1; - for (int i = static_cast(new_shape.size()) - 1; i >= 0; --i) { - new_offset += forward_indices[i] * stride_offset; - stride_offset *= new_shape[i]; - } - new_offset = analyzer_->Simplify(new_offset); - Array new_indices; + Array new_indices = + BuildPhysicalIndices(forward_indices, new_shape, elem_offset); layout_remap_.Set(new_buffer, layout); // Build new tvm_access_ptr call with new buffer and offset Array new_args = access_ptr_call->args; new_args.Set(1, new_buffer->data); // Replace data var - new_args.Set(2, new_offset); // Replace offset + new_args.Set(2, FlattenPhysicalIndices(new_indices, new_shape)); result.rewritten = true; result.expr = Call(access_ptr_call->dtype, access_ptr_call->op, new_args, access_ptr_call->annotations, access_ptr_call->span); @@ -602,20 +649,8 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { } auto forward_indices = layout.value()->Forward(multi_dim_indices); - PrimExpr new_offset = 0; - PrimExpr stride_offset = 1; - for (int i = static_cast(new_shape.size()) - 1; i >= 0; --i) { - new_offset += forward_indices[i] * stride_offset; - stride_offset *= new_shape[i]; - } - new_offset = analyzer_->Simplify(new_offset); - - Array new_indices; - for (int i = static_cast(new_shape.size()) - 1; i >= 0; --i) { - new_indices.insert(new_indices.begin(), - floormod(new_offset, new_shape[i])); - new_offset = floordiv(new_offset, new_shape[i]); - } + Array new_indices = + BuildPhysicalIndices(forward_indices, new_shape, smem_offset); Array new_args = {BufferLoad(new_buffer, new_indices)}; if (buffer_remap_.count(remap_key)) { @@ -705,20 +740,8 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { } auto forward_indices = layout.value()->Forward(multi_dim_indices); - PrimExpr new_offset = 0; - PrimExpr stride_offset = 1; - for (int i = static_cast(new_shape.size()) - 1; i >= 0; --i) { - new_offset += forward_indices[i] * stride_offset; - stride_offset *= new_shape[i]; - } - new_offset = analyzer_->Simplify(new_offset); - - Array new_indices; - for (int i = static_cast(new_shape.size()) - 1; i >= 0; --i) { - new_indices.insert(new_indices.begin(), - floormod(new_offset, new_shape[i])); - new_offset = floordiv(new_offset, new_shape[i]); - } + Array new_indices = + BuildPhysicalIndices(forward_indices, new_shape, smem_offset); Array new_args = {BufferLoad(new_buffer, new_indices), extent, rw_mask}; @@ -1001,6 +1024,16 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { return call; } + if (op->op.same_as(builtin::address_of()) || + op->op.same_as(tl::access_ptr())) { + auto access_ptr = tvm::ffi::GetRef(op); + auto new_access_ptr = + HandleAccessPtrAndOffset(access_ptr, std::nullopt, op->dtype); + if (new_access_ptr.rewritten) { + return new_access_ptr.expr; + } + } + // Default: visit normally auto call = Downcast(IRMutatorWithAnalyzer::VisitExpr_(op)); return call; diff --git a/testing/maca/language/test_tilelang_language_access_ptr_codegen.py b/testing/maca/language/test_tilelang_language_access_ptr_codegen.py index c2eb99e6..9ead9b30 100644 --- a/testing/maca/language/test_tilelang_language_access_ptr_codegen.py +++ b/testing/maca/language/test_tilelang_language_access_ptr_codegen.py @@ -3,6 +3,10 @@ import tilelang.testing import pytest from tilelang import tvm +from tilelang.utils.target import check_maca_availability + + +requires_maca = pytest.mark.skipif(not check_maca_availability(), reason="Requires MACA") @tilelang.testing.requires_cuda @@ -165,6 +169,7 @@ def main( assert "tl::cp_async_wait<0>" not in src, "Did not expect async_copy lowering to auto-emit wait" +@requires_maca def test_maca_bsm_intrinsics_codegen(): """Smoke-test codegen for the MACA BSM builtin wrappers.""" @@ -203,5 +208,47 @@ def main( assert '"MACA_ICMP_EQ"' not in src +@requires_maca +def test_maca_bsm_byte_view_feeds_gemm_codegen(): + """BSM byte staging can alias a half view consumed by MACA GEMM lowering.""" + + @T.prim_func + def main( + A: T.Tensor((128, 64), T.float16), + B: T.Tensor((128, 64), T.float16), + C: T.Tensor((128, 128), T.float16), + ): + with T.Kernel(1, 1, threads=256): + A_shared = T.alloc_shared((128, 64), T.float16) + B_storage = T.alloc_shared((128, 128), T.uint8) + B_shared = T.view(B_storage, (128, 64), dtype=T.float16) + C_local = T.alloc_fragment((128, 128), T.float32) + T.copy(A, A_shared) + T.clear(C_local) + T.maca_ldg_b128_bsm_predicator( + T.address_of(B_storage[0, 0]), + T.address_of(B[0, 0]), + 0, + True, + True, + False, + True, + 1, + 1, + "MACA_ICMP_EQ", + ) + T.maca_arrive_gvmcnt(0) + T.maca_barrier_inst() + T.gemm(A_shared, B_shared, C_local, False, True) + T.copy(C_local, C) + + kernel = tilelang.compile(main, out_idx=[2], target="maca") + src = kernel.get_kernel_source() + assert "__builtin_mxc_ldg_b128_bsm_predicator" in src + assert "__builtin_mxc_arrive_gvmcnt(0)" in src + assert "__builtin_mxc_barrier_inst();" in src + assert "__builtin_mxc_mma_16x16x16f16" in src + + if __name__ == "__main__": tilelang.testing.main() diff --git a/testing/maca/tilelibrary/test_tilelang_maca_gemm_template_contract.py b/testing/maca/tilelibrary/test_tilelang_maca_gemm_template_contract.py new file mode 100644 index 00000000..7c5f2ec7 --- /dev/null +++ b/testing/maca/tilelibrary/test_tilelang_maca_gemm_template_contract.py @@ -0,0 +1,72 @@ +from pathlib import Path + + +def test_maca_dense_template_normalizes_partitioned_fragments_before_gemm(): + repo_root = Path(__file__).resolve().parents[3] + gemm_header = repo_root / "src" / "tl_templates" / "maca" / "gemm.h" + + source = gemm_header.read_text() + normalized = " ".join(source.split()) + + assert "static CUTE_DEVICE auto remove_swizzle(Layout const &layout)" in normalized + assert "static CUTE_DEVICE auto remove_swizzle(ComposedLayout const &layout)" in normalized + assert source.count("CUTE_UNROLL") >= 2 + assert source.count("auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));") >= 2 + assert source.count("gemm(tiled_mma, tCrA(_, _, k), tCrB_view(_, _, k), acc);") >= 2 + assert "Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB);" in source + assert source.count("copy(tiled_copy_B, tCsB(_, _, 0), tCrB_copy_view(_, _, 0));") >= 2 + assert source.count("copy(tiled_copy_B, tCsB(_, _, k + 1), tCrB_copy_view(_, _, k + 1));") >= 2 + + +def test_maca_dense_template_fragment_c_matches_cute_partition_order(): + repo_root = Path(__file__).resolve().parents[3] + layout_source = repo_root / "src" / "layout" / "gemm_layouts.cc" + + source = layout_source.read_text() + maca_fragment = source.split("Fragment makeGemmFragmentCMACA", 1)[1].split("Fragment makeGemmFragmentCHopper", 1)[0] + + assert "base_layout->Repeat({block_m / warp_m, block_n / warp_n}, true, false)" in maca_fragment + assert "thread_layout->Repeat({warp_m / 16, warp_n / 16}, false, false)" in maca_fragment + assert "base_layout->Repeat({warp_m / 16, warp_n / 16}, false, true)" not in maca_fragment + + +def test_maca_wsm_template_declares_supported_contract(): + repo_root = Path(__file__).resolve().parents[3] + wsm_header = repo_root / "src" / "tl_templates" / "maca" / "gemm_wsm.h" + + source = wsm_header.read_text() + + assert "static_assert(!trans_A && !trans_B" in source + assert "static_assert(num_warp_m == 1 && num_warp_n == 1" in source + assert "static_assert(kPack == 8" in source + assert "static_assert(AStrideElements % 8 == 0" in source + assert "static_assert(sizeof(A_type) == 2 && sizeof(B_type) == 2" in source + assert "static_assert(sizeof(C_type) == 4" in source + + +def test_maca_wsm_lowering_names_workspace_size(): + repo_root = Path(__file__).resolve().parents[3] + gemm_mma = repo_root / "tilelang" / "maca" / "op" / "gemm" / "gemm_mma.py" + + source = gemm_mma.read_text() + + assert "MACA_WSM_STAGE_BYTES = 0x4000" in source + assert "MACA_WSM_STAGE_COUNT = 4" in source + assert "MACA_WSM_WORKSPACE_BYTES = MACA_WSM_STAGE_BYTES * MACA_WSM_STAGE_COUNT" in source + assert "T.alloc_shared((MACA_WSM_WORKSPACE_BYTES,)" in source + + +def test_maca_wsm_lowering_falls_back_for_unsupported_contracts(): + repo_root = Path(__file__).resolve().parents[3] + gemm_mma = repo_root / "tilelang" / "maca" / "op" / "gemm" / "gemm_mma.py" + + source = gemm_mma.read_text() + + assert "def _can_use_maca_gemm_wsm(" in source + assert "not trans_a" in source + assert "not trans_b" in source + assert "num_warp_m == 1" in source + assert "num_warp_n == 1" in source + assert "k_pack == 8" in source + assert "a_source_stride % 8 == 0" in source + assert 'consumer_surface = "direct_tl_gemm_ss"' in source diff --git a/tilelang/contrib/mxcc.py b/tilelang/contrib/mxcc.py index 795cac56..dce168c3 100644 --- a/tilelang/contrib/mxcc.py +++ b/tilelang/contrib/mxcc.py @@ -37,6 +37,12 @@ def compile_maca(code, target_format="mcbin", arch=None, options=None, path_targ path_target : str, optional Output file. + Environment Variables + --------------------- + TILELANG_MXCC_FLAGS : str, optional + Extra MXCC command-line flags parsed with shell-like syntax. These + flags are appended after explicit options and before output specs. + Return ------ cubin : bytearray @@ -91,7 +97,10 @@ def compile_maca(code, target_format="mcbin", arch=None, options=None, path_targ extra_env_flags = os.environ.get("TILELANG_MXCC_FLAGS") if extra_env_flags: - cmd += shlex.split(extra_env_flags) + try: + cmd += shlex.split(extra_env_flags) + except ValueError as exc: + raise ValueError(f"malformed TILELANG_MXCC_FLAGS={extra_env_flags!r}: {exc}") from exc cmd += ["-D__FAST_HALF_CVT__"] cmd += ["-o", file_target] diff --git a/tilelang/language/gemm_op.py b/tilelang/language/gemm_op.py index 81ae93cf..0f412776 100644 --- a/tilelang/language/gemm_op.py +++ b/tilelang/language/gemm_op.py @@ -156,6 +156,7 @@ def gemm( clear_accum: bool = False, k_pack: int = 1, mbar: BarrierType | None = None, + annotations: dict | None = None, ) -> tirx.PrimExpr: """TileLang GEMM operator. @@ -179,6 +180,8 @@ def gemm( k_pack (int): Numbers of packed matrix cores, for ROCm only. Defaults to 1. mbar (BarrierType, i.e. Buffer | BufferLoad, or Var, optional): Mbarrier in Blackwell. Required when this GEMM lowers to TCGEN5MMA. Defaults to None. + annotations (dict, optional): Backend-specific metadata consumed by target lowering. + Defaults to None. Returns: tirx.Call: A handle to the GEMM operation. @@ -195,6 +198,7 @@ def gemm( k_pack, 0, mbar, + annotations=annotations, ) diff --git a/tilelang/layout/__init__.py b/tilelang/layout/__init__.py index ae50e86c..0460611f 100644 --- a/tilelang/layout/__init__.py +++ b/tilelang/layout/__init__.py @@ -12,6 +12,8 @@ make_half_bank_swizzled_layout, # noqa: F401 make_quarter_bank_swizzled_layout, # noqa: F401 make_linear_layout, # noqa: F401 + make_maca_gemm_ab_layout, # noqa: F401 + make_maca_gemm_fragment_c, # noqa: F401 make_gemm_fragment_8x8, # noqa: F401 make_gemm_fragment_8x8_transposed, # noqa: F401 make_fully_replicated_layout_fragment, # noqa: F401 diff --git a/tilelang/layout/swizzle.py b/tilelang/layout/swizzle.py index 07ce0912..ca2b7e39 100644 --- a/tilelang/layout/swizzle.py +++ b/tilelang/layout/swizzle.py @@ -140,6 +140,22 @@ def make_linear_layout(buffer_or_load_or_region: BufferLikeType): return _ffi_api.make_linear_layout(list(shape)) +def make_maca_gemm_ab_layout(buffer: BufferLikeType, kfactor: int): + """ + Create the MACA GEMM shared-memory layout used by the compiler-path + template backend. + """ + buf, _, _ = _get_buffer_info(buffer) + return _ffi_api.make_maca_gemm_ab_layout(buf, kfactor) + + +def make_maca_gemm_fragment_c(block_m: int, block_n: int, warp_m: int, warp_n: int, element_size: int): + """ + Create the MACA GEMM accumulator fragment layout. + """ + return _ffi_api.make_maca_gemm_fragment_c(block_m, block_n, warp_m, warp_n, element_size) + + def make_gemm_fragment_8x8(): """ Create a standard 8x8 GEMM fragment layout for ldmatrix/stmatrix. diff --git a/tilelang/maca/intrinsics/layout/mma_layout.py b/tilelang/maca/intrinsics/layout/mma_layout.py index 1cadf7b8..4f803d5c 100644 --- a/tilelang/maca/intrinsics/layout/mma_layout.py +++ b/tilelang/maca/intrinsics/layout/mma_layout.py @@ -147,8 +147,8 @@ def thread_id_shared_access_64x16_to_16x64_layout_B(thread_id, local_id): def shared_16x64_to_local_64x16_layout_B(i, j): - thread_id = i + 16 * (j // 16) - local = j % 16 + thread_id = j + (i // 16) * 16 + local = i % 16 return thread_id, local diff --git a/tilelang/maca/intrinsics/macro/mma_macro_generator.py b/tilelang/maca/intrinsics/macro/mma_macro_generator.py index f6f03f72..69d35d84 100644 --- a/tilelang/maca/intrinsics/macro/mma_macro_generator.py +++ b/tilelang/maca/intrinsics/macro/mma_macro_generator.py @@ -248,7 +248,7 @@ def get_ldmatrix_index_map(self, is_b=False): thread_id_shared_access_64x8_to_16x32_layout_A if transposed else thread_id_shared_access_64x8_to_16x32_layout_B ) else: - raise ValueError(f"k_dim must be 16 currently but got {k_dim}") + raise ValueError(f"k_dim must be 4, 8, 16, or 32 but got {k_dim}") return index_map, reverse_index_map diff --git a/tilelang/maca/op/gemm/gemm_mma.py b/tilelang/maca/op/gemm/gemm_mma.py index 2322ca55..8a92fbb9 100644 --- a/tilelang/maca/op/gemm/gemm_mma.py +++ b/tilelang/maca/op/gemm/gemm_mma.py @@ -1,28 +1,175 @@ from __future__ import annotations -from tilelang.tileop.gemm.gemm_base import GemmBase -from tilelang.layout import make_swizzled_layout -from tilelang.maca.intrinsics.macro.mma_macro_generator import ( - TensorCoreIntrinEmitter, +import os + +from tilelang.layout import ( + make_full_bank_swizzled_layout, + make_half_bank_swizzled_layout, + make_quarter_bank_swizzled_layout, + make_linear_layout, + make_maca_gemm_ab_layout, + make_maca_gemm_fragment_c, + make_swizzled_layout, ) from tilelang.utils.language import is_shared, is_fragment, is_full_region +from tilelang.tileop.gemm.gemm_base import GemmBase from tilelang import tvm as tvm from tvm.target import Target from tvm.ir import Range from tvm import tirx from tilelang import language as T from tilelang.transform.simplify import _Simplify +from tilelang.maca.intrinsics.macro.mma_macro_generator import ( + TensorCoreIntrinEmitter, +) + + +def _get_maca_gemm_k_pack(default: int = 1) -> int: + value = os.environ.get("TILELANG_MACA_GEMM_K_PACK") + if value is None: + return default + try: + k_pack = int(value) + except ValueError as exc: + raise ValueError(f"TILELANG_MACA_GEMM_K_PACK must be an integer, got {value!r}") from exc + if k_pack < 1: + raise ValueError(f"TILELANG_MACA_GEMM_K_PACK must be >= 1, got {k_pack}") + return k_pack + + +def _get_maca_gemm_use_template(default: bool = False) -> bool: + value = os.environ.get("TILELANG_MACA_GEMM_USE_TEMPLATE") + if value is None: + return default + return value.strip().lower() not in {"0", "false", "f", "no", "n", ""} + + +def _get_maca_gemm_consumer_surface(default: str = "direct_tl_gemm_ss") -> str: + value = os.environ.get("TILELANG_MACA_GEMM_CONSUMER_SURFACE") + if value is None or not value.strip(): + return default + surface = value.strip().lower() + if surface == "not_direct_tl_gemm_ss": + return "wsm_aware" + if surface not in {"direct_tl_gemm_ss", "wsm_aware"}: + valid = "direct_tl_gemm_ss, wsm_aware, not_direct_tl_gemm_ss" + raise ValueError(f"TILELANG_MACA_GEMM_CONSUMER_SURFACE must be one of {valid}, got {value!r}") + return surface + + +def _make_maca_gemm_emitter(**kwargs): + return TensorCoreIntrinEmitter(**kwargs) + + +def _resolve_maca_gemm_shared_layout(value: str, env_key: str): + layout_name = value.strip().lower() + layout_name = { + "default": "swizzled", + "auto": "swizzled", + }.get(layout_name, layout_name) + + layout_factories = { + "swizzled": make_swizzled_layout, + "quarter": make_quarter_bank_swizzled_layout, + "half": make_half_bank_swizzled_layout, + "full": make_full_bank_swizzled_layout, + "linear": make_linear_layout, + } + if layout_name not in layout_factories: + valid = ", ".join(sorted(layout_factories)) + raise ValueError(f"{env_key} must be one of {valid}, got {value!r}") + return layout_factories[layout_name] + + +def _get_maca_gemm_shared_layout(): + value = os.environ.get("TILELANG_MACA_GEMM_SHARED_LAYOUT") + if value is None or not value.strip(): + return make_swizzled_layout + return _resolve_maca_gemm_shared_layout(value, "TILELANG_MACA_GEMM_SHARED_LAYOUT") + + +def _get_maca_gemm_shared_layout_for_operand(operand: str): + operand = operand.upper() + specific_key = f"TILELANG_MACA_GEMM_SHARED_LAYOUT_{operand}" + value = os.environ.get(specific_key) + if value is None or not value.strip(): + return _get_maca_gemm_shared_layout() + return _resolve_maca_gemm_shared_layout(value, specific_key) + + +def _format_maca_gemm_bool(value: bool) -> str: + return "true" if bool(value) else "false" + + +def _make_maca_gemm_template_name( + kind: str, + block_m: int, + block_n: int, + block_k: int, + num_warp_m: int, + num_warp_n: int, + trans_a: bool, + trans_b: bool, + clear_accum: bool, + k_pack: int, + extra_template_args: tuple[int, ...] = (), +) -> str: + template_args = [ + str(block_m), + str(block_n), + str(block_k), + str(num_warp_m), + str(num_warp_n), + _format_maca_gemm_bool(trans_a), + _format_maca_gemm_bool(trans_b), + _format_maca_gemm_bool(clear_accum), + str(k_pack), + *(str(value) for value in extra_template_args), + ] + return f"tl::gemm_{kind}<" + ", ".join(template_args) + ">" GEMM_INST_MMA = "maca.mma" +MACA_WSM_STAGE_BYTES = 0x4000 +MACA_WSM_STAGE_COUNT = 4 +MACA_WSM_WORKSPACE_BYTES = MACA_WSM_STAGE_BYTES * MACA_WSM_STAGE_COUNT + + +def _maca_dtype_name(dtype) -> str: + return str(dtype).lower() + + +def _can_use_maca_gemm_wsm( + *, + trans_a: bool, + trans_b: bool, + num_warp_m: int, + num_warp_n: int, + k_pack: int, + a_source_stride: int, + in_dtype, + accum_dtype, +) -> bool: + return ( + not trans_a + and not trans_b + and num_warp_m == 1 + and num_warp_n == 1 + and k_pack == 8 + and a_source_stride % 8 == 0 + and _maca_dtype_name(in_dtype) in {"float16", "half"} + and _maca_dtype_name(accum_dtype) in {"float", "float32"} + ) class GemmMMA(GemmBase): - def _make_mma_emitter(self, target: Target, thread_nums: int, thread_var: tirx.Var | None = None): + def infer_layout(self, target: Target, thread_nums: int): m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, GEMM_INST_MMA) warp_row_tiles = int(self.M // m_warp) warp_col_tiles = int(self.N // n_warp) - emitter = TensorCoreIntrinEmitter( + k_pack = _get_maca_gemm_k_pack(self.k_pack) + use_template = _get_maca_gemm_use_template(default=False) + mma_emitter = _make_maca_gemm_emitter( a_dtype=self.in_dtype, b_dtype=self.in_dtype, accum_dtype=self.accum_dtype, @@ -33,21 +180,36 @@ def _make_mma_emitter(self, target: Target, thread_nums: int, thread_var: tirx.V warp_row_tiles=warp_row_tiles, warp_col_tiles=warp_col_tiles, chunk=self.chunk, - thread_var=thread_var, + k_pack=k_pack, ) - return emitter + if use_template and self.is_gemm_ss(): - def infer_layout(self, target: Target, thread_nums: int): - mma_emitter = self._make_mma_emitter(target, thread_nums) + def shared_layout_a(buf): + return make_maca_gemm_ab_layout(buf, 1 if self.trans_A else 2) + + def shared_layout_b(buf): + return make_maca_gemm_ab_layout(buf, 2 if self.trans_B else 1) + + c_layout = make_maca_gemm_fragment_c( + int(self.M), + int(self.N), + int(warp_row_tiles), + int(warp_col_tiles), + self.C.dtype.bits, + ) + else: + shared_layout_a = _get_maca_gemm_shared_layout_for_operand("A") + shared_layout_b = _get_maca_gemm_shared_layout_for_operand("B") + c_layout = None if self.is_gemm_ss(): return { - self.A: make_swizzled_layout(self.A), - self.B: make_swizzled_layout(self.B), - self.C: mma_emitter.make_mma_store_layout(self.C), + self.A: shared_layout_a(self.A), + self.B: shared_layout_b(self.B), + self.C: c_layout if c_layout is not None else mma_emitter.make_mma_store_layout(self.C), } elif self.is_gemm_sr(): return { - self.A: make_swizzled_layout(self.A), + self.A: shared_layout_a(self.A), self.B: mma_emitter.make_mma_load_layout(self.B, matrix="B"), self.C: mma_emitter.make_mma_store_layout(self.C), } @@ -75,7 +237,24 @@ def lower( mbar_phase_expr: tirx.PrimExpr | None = None, ): thread_nums = thread_bounds.extent - mma_emitter = self._make_mma_emitter(target, thread_nums, thread_var=thread_var) + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, GEMM_INST_MMA) + warp_row_tiles = int(self.M // m_warp) + warp_col_tiles = int(self.N // n_warp) + k_pack = _get_maca_gemm_k_pack(self.k_pack) + mma_emitter = _make_maca_gemm_emitter( + a_dtype=self.in_dtype, + b_dtype=self.in_dtype, + accum_dtype=self.accum_dtype, + a_transposed=self.trans_A, + b_transposed=self.trans_B, + block_row_warps=m_warp, + block_col_warps=n_warp, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=self.chunk, + k_pack=k_pack, + thread_var=thread_var, + ) in_dtype = self.in_dtype warp_rows = mma_emitter.warp_rows @@ -84,6 +263,8 @@ def lower( local_size_b = mma_emitter.local_size_b block_K = mma_emitter.chunk micro_size_k = mma_emitter.micro_size_k + k_pack = mma_emitter.k_pack + macro_size_k = micro_size_k * k_pack # We use region for memory input to support strided gemm # T.gemm(A_shared[0:128, :], B_shared, C_local) A_region = self.ARegion @@ -95,12 +276,87 @@ def lower( C_buf = C_region.buffer clear_accum = self.clear_accum + use_template = _get_maca_gemm_use_template(default=False) + consumer_surface = _get_maca_gemm_consumer_surface() - assert block_K >= micro_size_k, f"block_K ({block_K}) must be >= micro_size_k ({micro_size_k})" + assert block_K >= macro_size_k, f"block_K ({block_K}) must be >= macro_size_k ({macro_size_k})" + assert block_K % macro_size_k == 0, f"block_K ({block_K}) must be divisible by macro_size_k ({macro_size_k})" assert is_full_region(C_region), "Fragment output C must be a full region" if self.is_gemm_ss(): + if use_template: + extra_template_args: tuple[int, ...] = () + if consumer_surface == "wsm_aware": + a_source_stride = self.annotations.get("maca_wsm_a_stride", None) + if a_source_stride is None: + a_source_stride = int(self.K) + a_source_stride = int(a_source_stride) + if _can_use_maca_gemm_wsm( + trans_a=bool(self.trans_A), + trans_b=bool(self.trans_B), + num_warp_m=int(m_warp), + num_warp_n=int(n_warp), + k_pack=int(k_pack), + a_source_stride=a_source_stride, + in_dtype=self.in_dtype, + accum_dtype=self.accum_dtype, + ): + extra_template_args = (a_source_stride,) + else: + consumer_surface = "direct_tl_gemm_ss" + gemm_kind = "ss_wsm" if consumer_surface == "wsm_aware" else "ss" + op_instance = _make_maca_gemm_template_name( + gemm_kind, + int(self.M), + int(self.N), + int(self.K), + int(m_warp), + int(n_warp), + bool(self.trans_A), + bool(self.trans_B), + bool(clear_accum), + int(k_pack), + extra_template_args=extra_template_args, + ) + + if consumer_surface == "wsm_aware": + a_source_ptr = self.annotations.get("maca_wsm_a_source_ptr") + b_source_ptr = self.annotations.get("maca_wsm_b_source_ptr") + if a_source_ptr is None: + a_source_ptr = T.access_ptr(A_region, "r") + if b_source_ptr is None: + b_source_ptr = T.access_ptr(B_region, "r") + + @T.prim_func + def _gemm_ss_wsm_template() -> None: + WSM = T.alloc_shared((MACA_WSM_WORKSPACE_BYTES,), T.uint8, scope="shared") + T.call_intrin( + "handle", + tirx.op.Op.get("tl.tl_gemm_wsm"), + op_instance, + T.access_ptr(A_region, "r"), + T.access_ptr(B_region, "r"), + T.access_ptr(C_region, "rw"), + T.address_of(WSM[0]), + a_source_ptr, + b_source_ptr, + ) + + return _Simplify(_gemm_ss_wsm_template, inline_let=True) + + @T.prim_func + def _gemm_ss_template() -> None: + T.call_intrin( + "handle", + tirx.op.Op.get("tl.tl_gemm"), + op_instance, + T.access_ptr(A_region, "r"), + T.access_ptr(B_region, "r"), + T.access_ptr(C_region, "rw"), + ) + + return _Simplify(_gemm_ss_template, inline_let=True) @T.prim_func def _gemm_ssr() -> None: @@ -109,13 +365,15 @@ def _gemm_ssr() -> None: B_shared into local fragments, then issues Tensor Core mma ops, accumulating into C_local. """ - A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) - B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + A_local = T.alloc_local((warp_rows * local_size_a * k_pack), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b * k_pack), in_dtype) if clear_accum: T.clear(C_buf) if self.mbar is not None: T.maca_barrier_arrive_and_wait(self.mbar) - for ki in T.serial(0, (block_K // micro_size_k)): + num_iters = block_K // macro_size_k + pipeline_stages = 4 if num_iters >= 4 else 0 + for ki in T.Pipelined(num_iters, num_stages=pipeline_stages): # Load A into fragment mma_emitter.ldmatrix_a( A_local, @@ -146,11 +404,11 @@ def _gemm_srr() -> None: B_shared into local fragments, then issues Tensor Core mma ops, accumulating into C_local. """ - A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + A_local = T.alloc_local((warp_rows * local_size_a * k_pack), in_dtype) - for ki in T.serial(0, (block_K // micro_size_k)): - if clear_accum: - T.clear(C_buf) + if clear_accum: + T.clear(C_buf) + for ki in T.serial(0, (block_K // macro_size_k)): # Load A into fragment mma_emitter.ldmatrix_a( A_local, @@ -176,10 +434,10 @@ def _gemm_rsr() -> None: B_shared into local fragments, then issues Tensor Core mma ops, accumulating into C_local. """ - B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b * k_pack), in_dtype) if clear_accum: T.clear(C_buf) - for ki in T.serial(0, (block_K // micro_size_k)): + for ki in T.serial(0, (block_K // macro_size_k)): # Load B into fragment mma_emitter.ldmatrix_b( B_local, @@ -205,7 +463,7 @@ def _gemm_rrr() -> None: accumulating into C_local. """ - for ki in T.serial(0, (block_K // micro_size_k)): + for ki in T.serial(0, (block_K // macro_size_k)): # Perform Matrix Multiplication mma_emitter.mma(A_buf, B_buf, C_buf, ki) diff --git a/tilelang/tileop/gemm/gemm_base.py b/tilelang/tileop/gemm/gemm_base.py index 5408467b..387e7d51 100644 --- a/tilelang/tileop/gemm/gemm_base.py +++ b/tilelang/tileop/gemm/gemm_base.py @@ -165,6 +165,13 @@ def is_tcgen05(self) -> bool: def policy(self) -> GemmWarpPolicy: return getattr(self.gemm_node, "policy", None) + @property + def annotations(self): + annotations = getattr(self.gemm_node, "annotations", None) + if not annotations: + return {} + return annotations if isinstance(annotations, dict) else dict(annotations) + @property def mbarptr(self) -> PrimExpr: return getattr(self.gemm_node, "mbarPtr", tvm.tirx.const(0, T.uint32)) From be876e5f9dc0b294f96f26c19455ab7ad014480d Mon Sep 17 00:00:00 2001 From: VitalyR Date: Thu, 21 May 2026 17:47:49 +0800 Subject: [PATCH 3/4] Fix MACA reduce batch source assertion --- testing/python/language/test_tilelang_language_reduce.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/testing/python/language/test_tilelang_language_reduce.py b/testing/python/language/test_tilelang_language_reduce.py index 73107111..f5b178ea 100644 --- a/testing/python/language/test_tilelang_language_reduce.py +++ b/testing/python/language/test_tilelang_language_reduce.py @@ -71,7 +71,7 @@ def _reduce_op(T, op, src, dst, dim, batch=1): ("abssum", T.int64, 128, 128, "fragment", "fragment", 64, 1), ("absmax", T.float32, 128, 128, "fragment", "fragment", 32, 1), ("absmax", T.int64, 128, 128, "fragment", "fragment", 64, 1), - # batch > 1: verify run_batch codegen and correctness together + # batch > 1: verify batched reduce codegen and correctness together ("sum", T.float32, 128, 64, "shared", "fragment", 256, 2), ("sum", T.float32, 128, 64, "shared", "fragment", 256, 4), ("sum", T.float16, 64, 128, "fragment", "fragment", 256, 4), @@ -120,8 +120,9 @@ def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M,), dtype)): if batch > 1: src = jit_kernel.get_kernel_source() - m = re.search(r",\s*(\d+)\s*,\s*\d+\s*>::run_batch\(", src) - assert m is not None, f"Expected run_batch in generated source.\n{src}" + m = re.search(r",\s*(\d+)\s*,\s*\d+\s*>::run_batch(?:_offset)?\(", src) + assert m is not None, f"Expected batched reduce in generated source.\n{src}" + assert int(m.group(1)) > 1, f"Expected batch_size > 1, got {m.group(1)}.\n{src}" A = _make_input(M, N, dtype) B = jit_kernel(A) From db04e6fe42a0a556dfef94bdae0a5f7a0c6ab4ce Mon Sep 17 00:00:00 2001 From: VitalyR Date: Sat, 6 Jun 2026 03:14:29 +0800 Subject: [PATCH 4/4] [MetaxGPU] Preserve global access_ptr indices --- src/transform/lower_tile_op.cc | 8 ++++ ...st_tilelang_language_access_ptr_codegen.py | 43 +++++++++++++++++++ 2 files changed, 51 insertions(+) diff --git a/src/transform/lower_tile_op.cc b/src/transform/lower_tile_op.cc index b9d1ae7a..b427764d 100644 --- a/src/transform/lower_tile_op.cc +++ b/src/transform/lower_tile_op.cc @@ -1026,6 +1026,14 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { if (op->op.same_as(builtin::address_of()) || op->op.same_as(tl::access_ptr())) { + Optional resolved = ResolveBufferLoad(op->args[0]); + if (resolved.defined()) { + if (const auto *load = resolved.value().as()) { + if (IsGlobalBuffer(load->buffer)) { + return Downcast(IRMutatorWithAnalyzer::VisitExpr_(op)); + } + } + } auto access_ptr = tvm::ffi::GetRef(op); auto new_access_ptr = HandleAccessPtrAndOffset(access_ptr, std::nullopt, op->dtype); diff --git a/testing/maca/language/test_tilelang_language_access_ptr_codegen.py b/testing/maca/language/test_tilelang_language_access_ptr_codegen.py index 9ead9b30..f8355b06 100644 --- a/testing/maca/language/test_tilelang_language_access_ptr_codegen.py +++ b/testing/maca/language/test_tilelang_language_access_ptr_codegen.py @@ -169,6 +169,49 @@ def main( assert "tl::cp_async_wait<0>" not in src, "Did not expect async_copy lowering to auto-emit wait" +@requires_maca +def test_maca_global_atomic_add_preserves_logical_layout_indices_codegen(): + def make_dq_layout(dQ): + return T.Layout( + dQ.shape, + lambda b, h, l, d: [ + b, + h, + l // 8, + d // 8, + (d % 2), + 4 * (l % 8) + (d % 8) // 2, + ], + ) + + @T.prim_func + def main( + A: T.Tensor((64, 32), T.float16), + B: T.Tensor((64, 128), T.float16), + dQ: T.Tensor((1, 32, 512, 128), T.float32), + ): + with T.Kernel(32, 1, 1, threads=256) as (bx, by, bz): + A_shared = T.alloc_shared((64, 32), T.float16) + B_shared = T.alloc_shared((64, 128), T.float16) + dq = T.alloc_fragment((32, 128), T.float32) + T.annotate_layout({dQ: make_dq_layout(dQ)}) + T.copy(A, A_shared) + T.copy(B, B_shared) + T.clear(dq) + for k in range(16): + T.gemm(A_shared, B_shared, dq, transpose_A=True) + T.atomic_add(dQ[bz, bx, k * 32 : (k + 1) * 32, :], dq) + + kernel = tilelang.compile(main, out_idx=None, target="maca") + src = kernel.get_kernel_source() + + assert "AtomicAdd" in src + assert "(k >> 4) + ((int)blockIdx.x)" not in src + assert "((k & 15) * 4096)" not in src + assert "((int)blockIdx.x) * 65536" in src + assert "(k * 4096)" in src + + @requires_maca def test_maca_bsm_intrinsics_codegen(): """Smoke-test codegen for the MACA BSM builtin wrappers."""