From 595e7e200bbd1f7d7be467c0d97d64b768d060d8 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Thu, 7 May 2026 17:18:59 +0800 Subject: [PATCH 1/4] [Refactor] Move backend-specific GEMM implementations and transforms into backend directories Restructure the codebase so that each backend (cpu, cuda, rocm) owns its GEMM implementations, sparse GEMM implementations, and transform passes under a consistent op/ subdirectory layout. The shared GEMM registry and base classes remain in tileop/ as the platform-agnostic dispatch layer. - Move gemm/gemm_sp registries from backend/ into tileop/ as registry.py - Move CUDA GEMM impls (mma, mma_sm70, wgmma, tcgen05) into backend/cuda/op/gemm/ - Move CUDA sparse GEMM impl into backend/cuda/op/gemm_sp/ - Move CPU GEMM impl (scalar) into backend/cpu/op/gemm/ - Move ROCm GEMM impls (mfma, wmma) into backend/rocm/op/gemm/ - Move CUDA-specific transform passes from src/transform/ into src/backend/cuda/transform/ - Move CUDA runtime sources from src/runtime/ into src/backend/cuda/ - Remove dead backend-importing wrappers from transform/__init__.py - Update phase.py and tests to import CUDA transforms from their canonical location Each backend now has a symmetric op/ directory structure. Adding a new backend no longer requires modifying shared transform or tileop modules. --- src/backend/cuda/CMakeLists.txt | 3 ++- src/{runtime => backend/cuda}/runtime.cc | 2 +- src/{runtime => backend/cuda}/runtime.h | 8 +++--- .../cuda}/transform/lower_hopper_intrin.cc | 6 ++--- .../lower_l2_persistent_annotation.cc | 6 ++--- .../cuda}/transform/persist_threadblock.cc | 8 +++--- ..._tilelang_transform_lower_hopper_intrin.py | 5 ++-- tilelang/backend/__init__.py | 4 +-- tilelang/backend/cpu/__init__.py | 2 +- tilelang/backend/cpu/op/__init__.py | 1 + .../cpu/{gemm.py => op/gemm/__init__.py} | 4 +-- .../cpu/op}/gemm/gemm_scalar.py | 0 tilelang/backend/cuda/__init__.py | 4 +-- tilelang/backend/cuda/op/__init__.py | 4 +++ .../cuda/{gemm.py => op/gemm/__init__.py} | 12 +++++---- .../cuda/op}/gemm/gemm_mma.py | 2 +- .../cuda/op}/gemm/gemm_mma_sm70.py | 2 +- .../cuda/op}/gemm/gemm_tcgen05.py | 2 +- .../cuda/op}/gemm/gemm_wgmma.py | 2 +- .../{gemm_sp.py => op/gemm_sp/__init__.py} | 6 +++-- .../cuda/op}/gemm_sp/gemm_sp_mma.py | 2 +- tilelang/backend/cuda/transform/__init__.py | 27 +++++++++++++++++++ tilelang/backend/rocm/__init__.py | 2 +- tilelang/backend/rocm/op/__init__.py | 1 + .../rocm/{gemm.py => op/gemm/__init__.py} | 6 ++--- .../rocm/op}/gemm/gemm_mfma.py | 2 +- .../rocm/op}/gemm/gemm_wmma.py | 2 +- tilelang/engine/phase.py | 7 ++--- tilelang/tileop/gemm/__init__.py | 2 +- .../gemm.py => tileop/gemm/registry.py} | 0 tilelang/tileop/gemm_sp/__init__.py | 2 +- .../gemm_sp.py => tileop/gemm_sp/registry.py} | 0 tilelang/transform/__init__.py | 21 --------------- 33 files changed, 88 insertions(+), 69 deletions(-) rename src/{runtime => backend/cuda}/runtime.cc (99%) rename src/{runtime => backend/cuda}/runtime.h (80%) rename src/{ => backend/cuda}/transform/lower_hopper_intrin.cc (98%) rename src/{ => backend/cuda}/transform/lower_l2_persistent_annotation.cc (96%) rename src/{ => backend/cuda}/transform/persist_threadblock.cc (90%) create mode 100644 tilelang/backend/cpu/op/__init__.py rename tilelang/backend/cpu/{gemm.py => op/gemm/__init__.py} (60%) rename tilelang/{tileop => backend/cpu/op}/gemm/gemm_scalar.py (100%) create mode 100644 tilelang/backend/cuda/op/__init__.py rename tilelang/backend/cuda/{gemm.py => op/gemm/__init__.py} (69%) rename tilelang/{tileop => backend/cuda/op}/gemm/gemm_mma.py (99%) rename tilelang/{tileop => backend/cuda/op}/gemm/gemm_mma_sm70.py (99%) rename tilelang/{tileop => backend/cuda/op}/gemm/gemm_tcgen05.py (99%) rename tilelang/{tileop => backend/cuda/op}/gemm/gemm_wgmma.py (99%) rename tilelang/backend/cuda/{gemm_sp.py => op/gemm_sp/__init__.py} (51%) rename tilelang/{tileop => backend/cuda/op}/gemm_sp/gemm_sp_mma.py (99%) create mode 100644 tilelang/backend/cuda/transform/__init__.py create mode 100644 tilelang/backend/rocm/op/__init__.py rename tilelang/backend/rocm/{gemm.py => op/gemm/__init__.py} (65%) rename tilelang/{tileop => backend/rocm/op}/gemm/gemm_mfma.py (99%) rename tilelang/{tileop => backend/rocm/op}/gemm/gemm_wmma.py (99%) rename tilelang/{backend/gemm.py => tileop/gemm/registry.py} (100%) rename tilelang/{backend/gemm_sp.py => tileop/gemm_sp/registry.py} (100%) diff --git a/src/backend/cuda/CMakeLists.txt b/src/backend/cuda/CMakeLists.txt index 6eac59fca2..b6b1460a9b 100644 --- a/src/backend/cuda/CMakeLists.txt +++ b/src/backend/cuda/CMakeLists.txt @@ -137,7 +137,7 @@ if(TILELANG_USE_CUDA_STUBS) endif() file(GLOB TILE_LANG_CUDA_SRCS - src/runtime/runtime.cc + src/backend/cuda/runtime.cc src/backend/cuda/codegen/ptx.cc src/backend/cuda/codegen/codegen_cuda.cc src/backend/cuda/codegen/codegen_py.cc @@ -146,6 +146,7 @@ file(GLOB TILE_LANG_CUDA_SRCS src/backend/cuda/codegen/rt_mod_cuda.cc src/backend/cuda/codegen/rt_mod_cutedsl.cc src/backend/cuda/op/*.cc + src/backend/cuda/transform/*.cc ) list(REMOVE_ITEM TILE_LANG_CUDA_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/src/backend/cuda/op/copy_analysis.cc") diff --git a/src/runtime/runtime.cc b/src/backend/cuda/runtime.cc similarity index 99% rename from src/runtime/runtime.cc rename to src/backend/cuda/runtime.cc index c59dbb4299..9756ca124b 100644 --- a/src/runtime/runtime.cc +++ b/src/backend/cuda/runtime.cc @@ -1,5 +1,5 @@ /*! - * \file tl/runtime/runtime.h + * \file tl/backend/cuda/runtime.cc * \brief Runtime functions. * */ diff --git a/src/runtime/runtime.h b/src/backend/cuda/runtime.h similarity index 80% rename from src/runtime/runtime.h rename to src/backend/cuda/runtime.h index 4b389fc03e..90540fd789 100644 --- a/src/runtime/runtime.h +++ b/src/backend/cuda/runtime.h @@ -1,11 +1,11 @@ /*! - * \file tl/runtime/runtime.h + * \file tl/backend/cuda/runtime.h * \brief Runtime functions. * */ -#ifndef TVM_TL_RUNTIME_RUNTIME_H_ -#define TVM_TL_RUNTIME_RUNTIME_H_ +#ifndef TVM_TL_BACKEND_CUDA_RUNTIME_H_ +#define TVM_TL_BACKEND_CUDA_RUNTIME_H_ namespace tvm { namespace tl { @@ -25,4 +25,4 @@ constexpr const char *tvm_cuda_stream_reset_access_policy_window = } // namespace tl } // namespace tvm -#endif // TVM_TL_RUNTIME_RUNTIME_H_ +#endif // TVM_TL_BACKEND_CUDA_RUNTIME_H_ diff --git a/src/transform/lower_hopper_intrin.cc b/src/backend/cuda/transform/lower_hopper_intrin.cc similarity index 98% rename from src/transform/lower_hopper_intrin.cc rename to src/backend/cuda/transform/lower_hopper_intrin.cc index 18b405f2bb..f70f719cf3 100644 --- a/src/transform/lower_hopper_intrin.cc +++ b/src/backend/cuda/transform/lower_hopper_intrin.cc @@ -1,5 +1,5 @@ /*! - * \file lower hopper intrin.cc + * \file tl/backend/cuda/transform/lower_hopper_intrin.cc * \brief Lower Hopper intrinsics cuda GPU(sm90+) */ @@ -13,8 +13,8 @@ #include #include -#include "../op/builtin.h" -#include "../runtime/runtime.h" +#include "backend/cuda/runtime.h" +#include "op/builtin.h" namespace tvm { namespace tl { diff --git a/src/transform/lower_l2_persistent_annotation.cc b/src/backend/cuda/transform/lower_l2_persistent_annotation.cc similarity index 96% rename from src/transform/lower_l2_persistent_annotation.cc rename to src/backend/cuda/transform/lower_l2_persistent_annotation.cc index 1f7be710de..6b3a9b612b 100644 --- a/src/transform/lower_l2_persistent_annotation.cc +++ b/src/backend/cuda/transform/lower_l2_persistent_annotation.cc @@ -1,5 +1,5 @@ /*! - * \file lower_l2_persistent_annotation.cc + * \file tl/backend/cuda/transform/lower_l2_persistent_annotation.cc * \brief Lower L2 persistent annotation */ @@ -9,8 +9,8 @@ #include #include -#include "../op/builtin.h" -#include "../runtime/runtime.h" +#include "backend/cuda/runtime.h" +#include "op/builtin.h" namespace tvm { namespace tl { diff --git a/src/transform/persist_threadblock.cc b/src/backend/cuda/transform/persist_threadblock.cc similarity index 90% rename from src/transform/persist_threadblock.cc rename to src/backend/cuda/transform/persist_threadblock.cc index b64ffdcce8..4a0a09ecc9 100644 --- a/src/transform/persist_threadblock.cc +++ b/src/backend/cuda/transform/persist_threadblock.cc @@ -1,6 +1,6 @@ /*! - * \file lower_l2_persistent_annotation.cc - * \brief Lower L2 persistent annotation + * \file tl/backend/cuda/transform/persist_threadblock.cc + * \brief Persist thread blocks with cooperative groups. */ #include @@ -9,8 +9,8 @@ #include #include -#include "../op/builtin.h" -#include "../runtime/runtime.h" +#include "backend/cuda/runtime.h" +#include "op/builtin.h" namespace tvm { namespace tl { diff --git a/testing/python/transform/test_tilelang_transform_lower_hopper_intrin.py b/testing/python/transform/test_tilelang_transform_lower_hopper_intrin.py index 640e991828..70ea213247 100644 --- a/testing/python/transform/test_tilelang_transform_lower_hopper_intrin.py +++ b/testing/python/transform/test_tilelang_transform_lower_hopper_intrin.py @@ -1,5 +1,6 @@ from tilelang import tvm as tvm import tilelang as tl +from tilelang.backend.cuda import transform as cuda_transform from tilelang.utils.target import determine_target import tilelang.language as T import tilelang.testing @@ -12,7 +13,7 @@ def _check(original, transformed): func = original mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) mod = tvm.tir.transform.BindTarget(auto_target)(mod) - mod = tl.transform.LowerHopperIntrin()(mod) + mod = cuda_transform.LowerHopperIntrin()(mod) mod = tir.transform.LowerOpaqueBlock()(mod) transformed = tvm.IRModule.from_expr(transformed.with_attr("global_symbol", "main")) transformed = tvm.tir.transform.BindTarget(auto_target)(transformed) @@ -90,7 +91,7 @@ def before(): mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) mod = tvm.tir.transform.BindTarget(auto_target)(mod) - mod = tl.transform.LowerHopperIntrin()(mod) + mod = cuda_transform.LowerHopperIntrin()(mod) func = mod["main"] assert not tvm.tir.analysis.undefined_vars(func.body, func.params) diff --git a/tilelang/backend/__init__.py b/tilelang/backend/__init__.py index cc871934ca..7262157505 100644 --- a/tilelang/backend/__init__.py +++ b/tilelang/backend/__init__.py @@ -1,5 +1,5 @@ -from .gemm import register_gemm_impl, resolve_gemm_impl # noqa: F401 -from .gemm_sp import register_gemm_sp_impl, resolve_gemm_sp_impl # noqa: F401 +from tilelang.tileop.gemm.registry import register_gemm_impl, resolve_gemm_impl # noqa: F401 +from tilelang.tileop.gemm_sp.registry import register_gemm_sp_impl, resolve_gemm_sp_impl # noqa: F401 # Import built-in backend packages so their implementations register. from . import cpu as _cpu # noqa: F401,E402 diff --git a/tilelang/backend/cpu/__init__.py b/tilelang/backend/cpu/__init__.py index 3480af512d..e8fb2c24d2 100644 --- a/tilelang/backend/cpu/__init__.py +++ b/tilelang/backend/cpu/__init__.py @@ -1 +1 @@ -from . import gemm # noqa: F401 +from . import op # noqa: F401 diff --git a/tilelang/backend/cpu/op/__init__.py b/tilelang/backend/cpu/op/__init__.py new file mode 100644 index 0000000000..3480af512d --- /dev/null +++ b/tilelang/backend/cpu/op/__init__.py @@ -0,0 +1 @@ +from . import gemm # noqa: F401 diff --git a/tilelang/backend/cpu/gemm.py b/tilelang/backend/cpu/op/gemm/__init__.py similarity index 60% rename from tilelang/backend/cpu/gemm.py rename to tilelang/backend/cpu/op/gemm/__init__.py index affeaa3308..bdd710755d 100644 --- a/tilelang/backend/cpu/gemm.py +++ b/tilelang/backend/cpu/op/gemm/__init__.py @@ -1,7 +1,7 @@ from __future__ import annotations -from tilelang.backend.gemm import register_gemm_impl -from tilelang.tileop.gemm.gemm_scalar import GEMM_INST_SCALAR, GemmScalar +from tilelang.tileop.gemm.registry import register_gemm_impl +from .gemm_scalar import GEMM_INST_SCALAR, GemmScalar def _match_scalar(target) -> bool: diff --git a/tilelang/tileop/gemm/gemm_scalar.py b/tilelang/backend/cpu/op/gemm/gemm_scalar.py similarity index 100% rename from tilelang/tileop/gemm/gemm_scalar.py rename to tilelang/backend/cpu/op/gemm/gemm_scalar.py diff --git a/tilelang/backend/cuda/__init__.py b/tilelang/backend/cuda/__init__.py index 5d013cefee..e4dd4f9e4d 100644 --- a/tilelang/backend/cuda/__init__.py +++ b/tilelang/backend/cuda/__init__.py @@ -1,2 +1,2 @@ -from . import gemm # noqa: F401 -from . import gemm_sp # noqa: F401 +from . import op # noqa: F401 +from . import transform # noqa: F401 diff --git a/tilelang/backend/cuda/op/__init__.py b/tilelang/backend/cuda/op/__init__.py new file mode 100644 index 0000000000..743e9e3c63 --- /dev/null +++ b/tilelang/backend/cuda/op/__init__.py @@ -0,0 +1,4 @@ +"""CUDA op registration frontends.""" + +from . import gemm # noqa: F401 +from . import gemm_sp # noqa: F401 diff --git a/tilelang/backend/cuda/gemm.py b/tilelang/backend/cuda/op/gemm/__init__.py similarity index 69% rename from tilelang/backend/cuda/gemm.py rename to tilelang/backend/cuda/op/gemm/__init__.py index 0072fda1aa..f78878b54a 100644 --- a/tilelang/backend/cuda/gemm.py +++ b/tilelang/backend/cuda/op/gemm/__init__.py @@ -1,10 +1,12 @@ +"""CUDA GEMM op registrations.""" + from __future__ import annotations -from tilelang.backend.gemm import register_gemm_impl -from tilelang.tileop.gemm.gemm_mma import GEMM_INST_MMA, GemmMMA -from tilelang.tileop.gemm.gemm_mma_sm70 import GemmMMASm70 -from tilelang.tileop.gemm.gemm_tcgen05 import GEMM_INST_TCGEN05, GemmTCGEN5 -from tilelang.tileop.gemm.gemm_wgmma import GEMM_INST_WGMMA, GemmWGMMA +from tilelang.tileop.gemm.registry import register_gemm_impl +from .gemm_mma import GEMM_INST_MMA, GemmMMA +from .gemm_mma_sm70 import GemmMMASm70 +from .gemm_tcgen05 import GEMM_INST_TCGEN05, GemmTCGEN5 +from .gemm_wgmma import GEMM_INST_WGMMA, GemmWGMMA from tilelang.utils.target import target_is_cuda, target_is_volta diff --git a/tilelang/tileop/gemm/gemm_mma.py b/tilelang/backend/cuda/op/gemm/gemm_mma.py similarity index 99% rename from tilelang/tileop/gemm/gemm_mma.py rename to tilelang/backend/cuda/op/gemm/gemm_mma.py index 99e4eb4d9c..3baec0ed85 100644 --- a/tilelang/tileop/gemm/gemm_mma.py +++ b/tilelang/backend/cuda/op/gemm/gemm_mma.py @@ -1,6 +1,6 @@ from __future__ import annotations -from .gemm_base import GemmBase +from tilelang.tileop.gemm.gemm_base import GemmBase from tilelang.layout import make_swizzled_layout from tilelang.intrinsics.mma_macro_generator import ( TensorCoreIntrinEmitter, diff --git a/tilelang/tileop/gemm/gemm_mma_sm70.py b/tilelang/backend/cuda/op/gemm/gemm_mma_sm70.py similarity index 99% rename from tilelang/tileop/gemm/gemm_mma_sm70.py rename to tilelang/backend/cuda/op/gemm/gemm_mma_sm70.py index 1d4fd21058..dd66b48f2e 100644 --- a/tilelang/tileop/gemm/gemm_mma_sm70.py +++ b/tilelang/backend/cuda/op/gemm/gemm_mma_sm70.py @@ -1,7 +1,7 @@ from __future__ import annotations # for Volta GPUs, which use legacy MMA instructions -from .gemm_base import GemmBase +from tilelang.tileop.gemm.gemm_base import GemmBase from tilelang.layout import make_volta_swizzled_layout from tilelang.intrinsics.mma_sm70_macro_generator import ( TensorCoreIntrinEmitter, diff --git a/tilelang/tileop/gemm/gemm_tcgen05.py b/tilelang/backend/cuda/op/gemm/gemm_tcgen05.py similarity index 99% rename from tilelang/tileop/gemm/gemm_tcgen05.py rename to tilelang/backend/cuda/op/gemm/gemm_tcgen05.py index 28d4c805be..78f9a24271 100644 --- a/tilelang/tileop/gemm/gemm_tcgen05.py +++ b/tilelang/backend/cuda/op/gemm/gemm_tcgen05.py @@ -1,6 +1,6 @@ from __future__ import annotations -from .gemm_base import GemmBase +from tilelang.tileop.gemm.gemm_base import GemmBase from tilelang.layout import ( Layout, make_full_bank_swizzled_layout, diff --git a/tilelang/tileop/gemm/gemm_wgmma.py b/tilelang/backend/cuda/op/gemm/gemm_wgmma.py similarity index 99% rename from tilelang/tileop/gemm/gemm_wgmma.py rename to tilelang/backend/cuda/op/gemm/gemm_wgmma.py index 6618309263..939c8926fa 100644 --- a/tilelang/tileop/gemm/gemm_wgmma.py +++ b/tilelang/backend/cuda/op/gemm/gemm_wgmma.py @@ -1,6 +1,6 @@ from __future__ import annotations -from .gemm_base import GemmBase +from tilelang.tileop.gemm.gemm_base import GemmBase from tilelang.layout import ( make_full_bank_swizzled_layout, make_half_bank_swizzled_layout, diff --git a/tilelang/backend/cuda/gemm_sp.py b/tilelang/backend/cuda/op/gemm_sp/__init__.py similarity index 51% rename from tilelang/backend/cuda/gemm_sp.py rename to tilelang/backend/cuda/op/gemm_sp/__init__.py index fead5b3d3d..ed8ade6377 100644 --- a/tilelang/backend/cuda/gemm_sp.py +++ b/tilelang/backend/cuda/op/gemm_sp/__init__.py @@ -1,7 +1,9 @@ +"""CUDA sparse GEMM op registrations.""" + from __future__ import annotations -from tilelang.backend.gemm_sp import register_gemm_sp_impl -from tilelang.tileop.gemm_sp.gemm_sp_mma import GemmSPMMA +from tilelang.tileop.gemm_sp.registry import register_gemm_sp_impl +from .gemm_sp_mma import GemmSPMMA from tilelang.utils.target import target_is_cuda diff --git a/tilelang/tileop/gemm_sp/gemm_sp_mma.py b/tilelang/backend/cuda/op/gemm_sp/gemm_sp_mma.py similarity index 99% rename from tilelang/tileop/gemm_sp/gemm_sp_mma.py rename to tilelang/backend/cuda/op/gemm_sp/gemm_sp_mma.py index 1a7964f9f4..6c0461ec64 100644 --- a/tilelang/tileop/gemm_sp/gemm_sp_mma.py +++ b/tilelang/backend/cuda/op/gemm_sp/gemm_sp_mma.py @@ -1,4 +1,4 @@ -from .gemm_sp_base import GemmSPBase +from tilelang.tileop.gemm_sp.gemm_sp_base import GemmSPBase from tilelang.layout import make_swizzled_layout from tilelang.intrinsics.mma_sp_macro_generator import SparseTensorCoreIntrinEmitter from tilelang.utils.language import is_shared, is_fragment diff --git a/tilelang/backend/cuda/transform/__init__.py b/tilelang/backend/cuda/transform/__init__.py new file mode 100644 index 0000000000..e6e4705c6f --- /dev/null +++ b/tilelang/backend/cuda/transform/__init__.py @@ -0,0 +1,27 @@ +"""CUDA-specific transformation frontends.""" + +from tilelang.transform import _ffi_api + + +def LowerHopperIntrin(): + """LowerHopperIntrin""" + if hasattr(_ffi_api, "LowerHopperIntrin"): + return _ffi_api.LowerHopperIntrin() # type: ignore + return lambda f: f + + +def LowerL2Persistent(): + """LowerL2Persistent""" + return _ffi_api.LowerL2Persistent() # type: ignore + + +def PersistThreadblock(): + """PersistThreadblock""" + return _ffi_api.PersistThreadblock() # type: ignore + + +__all__ = [ + "LowerHopperIntrin", + "LowerL2Persistent", + "PersistThreadblock", +] diff --git a/tilelang/backend/rocm/__init__.py b/tilelang/backend/rocm/__init__.py index 3480af512d..e8fb2c24d2 100644 --- a/tilelang/backend/rocm/__init__.py +++ b/tilelang/backend/rocm/__init__.py @@ -1 +1 @@ -from . import gemm # noqa: F401 +from . import op # noqa: F401 diff --git a/tilelang/backend/rocm/op/__init__.py b/tilelang/backend/rocm/op/__init__.py new file mode 100644 index 0000000000..3480af512d --- /dev/null +++ b/tilelang/backend/rocm/op/__init__.py @@ -0,0 +1 @@ +from . import gemm # noqa: F401 diff --git a/tilelang/backend/rocm/gemm.py b/tilelang/backend/rocm/op/gemm/__init__.py similarity index 65% rename from tilelang/backend/rocm/gemm.py rename to tilelang/backend/rocm/op/gemm/__init__.py index 94e7d17724..c08e949b35 100644 --- a/tilelang/backend/rocm/gemm.py +++ b/tilelang/backend/rocm/op/gemm/__init__.py @@ -1,8 +1,8 @@ from __future__ import annotations -from tilelang.backend.gemm import register_gemm_impl -from tilelang.tileop.gemm.gemm_mfma import GEMM_INST_MFMA, GemmMFMA -from tilelang.tileop.gemm.gemm_wmma import GEMM_INST_WMMA, GemmWMMA +from tilelang.tileop.gemm.registry import register_gemm_impl +from .gemm_mfma import GEMM_INST_MFMA, GemmMFMA +from .gemm_wmma import GEMM_INST_WMMA, GemmWMMA from tilelang.utils.target import target_is_hip diff --git a/tilelang/tileop/gemm/gemm_mfma.py b/tilelang/backend/rocm/op/gemm/gemm_mfma.py similarity index 99% rename from tilelang/tileop/gemm/gemm_mfma.py rename to tilelang/backend/rocm/op/gemm/gemm_mfma.py index 786baba96e..5ca8676183 100644 --- a/tilelang/tileop/gemm/gemm_mfma.py +++ b/tilelang/backend/rocm/op/gemm/gemm_mfma.py @@ -1,6 +1,6 @@ from __future__ import annotations -from .gemm_base import GemmBase +from tilelang.tileop.gemm.gemm_base import GemmBase from tilelang.layout import make_swizzled_layout from tilelang.intrinsics.mfma_macro_generator import ( MatrixCoreIntrinEmitter, diff --git a/tilelang/tileop/gemm/gemm_wmma.py b/tilelang/backend/rocm/op/gemm/gemm_wmma.py similarity index 99% rename from tilelang/tileop/gemm/gemm_wmma.py rename to tilelang/backend/rocm/op/gemm/gemm_wmma.py index ab4ae6d50a..c9b1783d38 100644 --- a/tilelang/tileop/gemm/gemm_wmma.py +++ b/tilelang/backend/rocm/op/gemm/gemm_wmma.py @@ -2,7 +2,7 @@ from __future__ import annotations -from .gemm_base import GemmBase +from tilelang.tileop.gemm.gemm_base import GemmBase from tilelang.layout import make_swizzled_layout from tilelang.intrinsics.wmma_macro_generator import WMMAIntrinEmitter from tilelang.utils.language import is_shared, is_fragment, is_full_region diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index 5563845214..4130e195a8 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -2,6 +2,7 @@ from tvm import tir, IRModule from tvm.target import Target import tilelang +from tilelang.backend.cuda import transform as cuda_transform from tilelang.transform import PassContext from tilelang.contrib.nvcc import have_tma, have_pdl @@ -204,7 +205,7 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule: # Lower high-level tile operations to low-level operations mod = tilelang.transform.LowerTileOp()(mod) # Lower l2 persistent map - mod = tilelang.transform.LowerL2Persistent()(mod) + mod = cuda_transform.LowerL2Persistent()(mod) # Decouple type cast vectorization constraints before vectorization mod = tilelang.transform.DecoupleTypeCast()(mod) # Legalize vectorized loops to ensure they are valid @@ -270,7 +271,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: mod = tir.transform.InferFragment()(mod) mod = tilelang.transform.LowerThreadAllreduce()(mod) mod = tilelang.transform.LowerLDGSTG()(mod) - mod = tilelang.transform.LowerHopperIntrin()(mod) + mod = cuda_transform.LowerHopperIntrin()(mod) # Global Barrier Synchronization must be applied before # SplitHostDevice pass, as the global barrier if allow_global_thread_synchronization(): @@ -305,6 +306,6 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: mod = tilelang.transform.LowerDeviceKernelLaunch()(mod) # Transform threadblock to persistent threadblock - mod = tilelang.transform.PersistThreadblock()(mod) + mod = cuda_transform.PersistThreadblock()(mod) return mod diff --git a/tilelang/tileop/gemm/__init__.py b/tilelang/tileop/gemm/__init__.py index 22a2c91007..37f7e2d235 100644 --- a/tilelang/tileop/gemm/__init__.py +++ b/tilelang/tileop/gemm/__init__.py @@ -5,7 +5,7 @@ from tvm.ir import Range from tvm.runtime import Scriptable import tvm_ffi -from tilelang.backend.gemm import resolve_gemm_impl +from .registry import resolve_gemm_impl from tilelang import _ffi_api diff --git a/tilelang/backend/gemm.py b/tilelang/tileop/gemm/registry.py similarity index 100% rename from tilelang/backend/gemm.py rename to tilelang/tileop/gemm/registry.py diff --git a/tilelang/tileop/gemm_sp/__init__.py b/tilelang/tileop/gemm_sp/__init__.py index 6e2c4a7d2b..1a49b86ec3 100644 --- a/tilelang/tileop/gemm_sp/__init__.py +++ b/tilelang/tileop/gemm_sp/__init__.py @@ -5,7 +5,7 @@ from tvm.ir import Range from tvm.runtime import Scriptable import tvm_ffi -from tilelang.backend.gemm_sp import resolve_gemm_sp_impl +from .registry import resolve_gemm_sp_impl from tilelang.tileop.base import GemmWarpPolicy diff --git a/tilelang/backend/gemm_sp.py b/tilelang/tileop/gemm_sp/registry.py similarity index 100% rename from tilelang/backend/gemm_sp.py rename to tilelang/tileop/gemm_sp/registry.py diff --git a/tilelang/transform/__init__.py b/tilelang/transform/__init__.py index 677887bd49..599ccef1ad 100644 --- a/tilelang/transform/__init__.py +++ b/tilelang/transform/__init__.py @@ -121,17 +121,6 @@ def VerifyParallelLoop(): return _ffi_api.VerifyParallelLoop() # type: ignore -def LowerHopperIntrin(): - """LowerHopperIntrin - - Returns - ------- - fpass : tvm.transform.Pass - The result pass - """ - return _ffi_api.LowerHopperIntrin() if hasattr(_ffi_api, "LowerHopperIntrin") else lambda f: f # type: ignore - - def ThreadSync(storage_scope: str): """Insert sync between parallel read/write of shared buffers. @@ -424,21 +413,11 @@ def MergeSharedMemoryAllocations(enable_aggressive_merge: bool = False, align_by return _ffi_api.MergeSharedMemoryAllocations(enable_aggressive_merge, align_bytes) # type: ignore -def LowerL2Persistent(): - """LowerL2Persistent""" - return _ffi_api.LowerL2Persistent() # type: ignore - - def MarkCudaSyncCalls(have_pdl: bool = False): """MarkCudaSyncCalls""" return _ffi_api.MarkCudaSyncCalls(have_pdl) # type: ignore -def PersistThreadblock(): - """PersistThreadblock""" - return _ffi_api.PersistThreadblock() # type: ignore - - def LowerSharedBarrier(): """LowerSharedBarrier""" return _ffi_api.LowerSharedBarrier() # type: ignore From d676d3661a1815b076ff67d3ddd3e92c559c749a Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Fri, 8 May 2026 15:38:51 +0800 Subject: [PATCH 2/4] Refactor Python backend package layout --- .../matmul/benchmark_matmul_intrinsic.py | 4 ++-- docs/deeplearning_operators/matmul.md | 2 +- .../tilelang_bitnet_158_int8xint2_prefill.py | 4 ++-- .../example_dequant_gemm_fine_grained.py | 4 ++-- examples/gemm/README.md | 2 +- examples/gemm/example_gemm_intrinsics.py | 4 ++-- ...xample_tilelang_gemm_amd_fp8_preshuffle.py | 2 +- .../example_tilelang_gemm_fp8_intrinsic.py | 6 ++--- .../hadamard_transform/example_hadamard.py | 2 +- examples/plot_layout/README.md | 4 ++-- examples/plot_layout/fragment_mfma_load_a.py | 4 ++-- examples/plot_layout/fragment_mma_load_a.py | 4 ++-- .../amd/test_tilelang_gemm_mfma_intrinsic.py | 4 ++-- .../amd/test_tilelang_gemm_mfma_preshuffle.py | 4 ++-- .../test_tilelang_kernel_bf16_gemm_mma.py | 4 ++-- .../test_tilelang_kernel_fp8_gemm_mma.py | 4 ++-- ...test_tilelang_kernel_gemm_mma_intrinsic.py | 4 ++-- .../kernel/test_tilelang_kernel_gemm_simt.py | 2 +- .../test_tilelang_language_reshape.py | 2 +- .../test_tilelang_language_vectorize.py | 2 +- .../test_tilelang_tilelibrary_gemm_sp.py | 2 +- .../test_tilelang_tilelibrary_gemm_sp_v2.py | 2 +- ..._tilelang_transform_lower_hopper_intrin.py | 2 +- tilelang/__init__.py | 3 +++ tilelang/backend/__init__.py | 7 ------ tilelang/backend/rocm/__init__.py | 1 - tilelang/{backend => }/cpu/__init__.py | 0 tilelang/{backend => }/cpu/op/__init__.py | 0 .../{backend => }/cpu/op/gemm/__init__.py | 0 .../{backend => }/cpu/op/gemm/gemm_scalar.py | 0 tilelang/{backend => }/cuda/__init__.py | 1 + tilelang/cuda/intrinsics/__init__.py | 13 +++++++++++ tilelang/cuda/intrinsics/layout/__init__.py | 8 +++++++ .../intrinsics/layout}/mma_layout.py | 0 .../intrinsics/layout}/mma_sm70_layout.py | 0 .../intrinsics/layout}/mma_sp_layout.py | 2 +- .../intrinsics/layout}/utils.py | 10 -------- tilelang/cuda/intrinsics/macro/__init__.py | 6 +++++ .../intrinsics/macro}/mma_macro_generator.py | 6 ++--- .../macro}/mma_sm70_macro_generator.py | 2 +- .../macro}/mma_sp_macro_generator.py | 4 ++-- .../macro}/tcgen05_macro_generator.py | 0 .../macro}/wgmma_macro_generator.py | 2 +- tilelang/{backend => }/cuda/op/__init__.py | 0 .../{backend => }/cuda/op/gemm/__init__.py | 0 .../{backend => }/cuda/op/gemm/gemm_mma.py | 2 +- .../cuda/op/gemm/gemm_mma_sm70.py | 2 +- .../cuda/op/gemm/gemm_tcgen05.py | 2 +- .../{backend => }/cuda/op/gemm/gemm_wgmma.py | 2 +- .../{backend => }/cuda/op/gemm_sp/__init__.py | 0 .../cuda/op/gemm_sp/gemm_sp_mma.py | 2 +- .../{backend => }/cuda/transform/__init__.py | 0 tilelang/engine/phase.py | 7 +++--- tilelang/intrinsics/__init__.py | 10 ++++---- tilelang/language/gemm_op.py | 2 +- tilelang/rocm/__init__.py | 2 ++ tilelang/rocm/intrinsics/__init__.py | 12 ++++++++++ tilelang/{ => rocm}/intrinsics/mfma_layout.py | 0 .../intrinsics/mfma_macro_generator.py | 0 tilelang/rocm/intrinsics/utils.py | 23 +++++++++++++++++++ tilelang/{ => rocm}/intrinsics/wmma_layout.py | 0 .../intrinsics/wmma_macro_generator.py | 0 tilelang/{backend => }/rocm/op/__init__.py | 0 .../{backend => }/rocm/op/gemm/__init__.py | 0 .../{backend => }/rocm/op/gemm/gemm_mfma.py | 2 +- .../{backend => }/rocm/op/gemm/gemm_wmma.py | 2 +- 66 files changed, 129 insertions(+), 80 deletions(-) delete mode 100644 tilelang/backend/__init__.py delete mode 100644 tilelang/backend/rocm/__init__.py rename tilelang/{backend => }/cpu/__init__.py (100%) rename tilelang/{backend => }/cpu/op/__init__.py (100%) rename tilelang/{backend => }/cpu/op/gemm/__init__.py (100%) rename tilelang/{backend => }/cpu/op/gemm/gemm_scalar.py (100%) rename tilelang/{backend => }/cuda/__init__.py (63%) create mode 100644 tilelang/cuda/intrinsics/__init__.py create mode 100644 tilelang/cuda/intrinsics/layout/__init__.py rename tilelang/{intrinsics => cuda/intrinsics/layout}/mma_layout.py (100%) rename tilelang/{intrinsics => cuda/intrinsics/layout}/mma_sm70_layout.py (100%) rename tilelang/{intrinsics => cuda/intrinsics/layout}/mma_sp_layout.py (99%) rename tilelang/{intrinsics => cuda/intrinsics/layout}/utils.py (90%) create mode 100644 tilelang/cuda/intrinsics/macro/__init__.py rename tilelang/{intrinsics => cuda/intrinsics/macro}/mma_macro_generator.py (99%) rename tilelang/{intrinsics => cuda/intrinsics/macro}/mma_sm70_macro_generator.py (99%) rename tilelang/{intrinsics => cuda/intrinsics/macro}/mma_sp_macro_generator.py (99%) rename tilelang/{intrinsics => cuda/intrinsics/macro}/tcgen05_macro_generator.py (100%) rename tilelang/{intrinsics => cuda/intrinsics/macro}/wgmma_macro_generator.py (99%) rename tilelang/{backend => }/cuda/op/__init__.py (100%) rename tilelang/{backend => }/cuda/op/gemm/__init__.py (100%) rename tilelang/{backend => }/cuda/op/gemm/gemm_mma.py (99%) rename tilelang/{backend => }/cuda/op/gemm/gemm_mma_sm70.py (98%) rename tilelang/{backend => }/cuda/op/gemm/gemm_tcgen05.py (99%) rename tilelang/{backend => }/cuda/op/gemm/gemm_wgmma.py (99%) rename tilelang/{backend => }/cuda/op/gemm_sp/__init__.py (100%) rename tilelang/{backend => }/cuda/op/gemm_sp/gemm_sp_mma.py (99%) rename tilelang/{backend => }/cuda/transform/__init__.py (100%) create mode 100644 tilelang/rocm/__init__.py create mode 100644 tilelang/rocm/intrinsics/__init__.py rename tilelang/{ => rocm}/intrinsics/mfma_layout.py (100%) rename tilelang/{ => rocm}/intrinsics/mfma_macro_generator.py (100%) create mode 100644 tilelang/rocm/intrinsics/utils.py rename tilelang/{ => rocm}/intrinsics/wmma_layout.py (100%) rename tilelang/{ => rocm}/intrinsics/wmma_macro_generator.py (100%) rename tilelang/{backend => }/rocm/op/__init__.py (100%) rename tilelang/{backend => }/rocm/op/gemm/__init__.py (100%) rename tilelang/{backend => }/rocm/op/gemm/gemm_mfma.py (99%) rename tilelang/{backend => }/rocm/op/gemm/gemm_wmma.py (98%) diff --git a/benchmark/matmul/benchmark_matmul_intrinsic.py b/benchmark/matmul/benchmark_matmul_intrinsic.py index 4ef860c210..bc6b2b8e96 100644 --- a/benchmark/matmul/benchmark_matmul_intrinsic.py +++ b/benchmark/matmul/benchmark_matmul_intrinsic.py @@ -4,8 +4,8 @@ from tvm import DataType import tilelang as tl import tilelang.language as T -from tilelang.intrinsics import get_swizzle_layout -from tilelang.intrinsics.mma_macro_generator import ( +from tilelang.cuda.intrinsics import get_swizzle_layout +from tilelang.cuda.intrinsics.macro.mma_macro_generator import ( TensorCoreIntrinEmitter, ) from tilelang.transform import simplify_prim_func diff --git a/docs/deeplearning_operators/matmul.md b/docs/deeplearning_operators/matmul.md index 12189eb8fa..076be9f0f8 100644 --- a/docs/deeplearning_operators/matmul.md +++ b/docs/deeplearning_operators/matmul.md @@ -62,7 +62,7 @@ Below is a simplified code snippet for a 1024 x 1024 x 1024 matrix multiplicatio ```python import tilelang import tilelang.language as T -from tilelang.intrinsics import make_mma_swizzle_layout +from tilelang.cuda.intrinsics import make_mma_swizzle_layout def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): @T.prim_func diff --git a/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py b/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py index 9d7ebcf88c..031783910d 100644 --- a/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py +++ b/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py @@ -7,12 +7,12 @@ import tilelang.language as T from tilelang import tvm as tvm from tvm import DataType -from tilelang.intrinsics.mma_layout import ( +from tilelang.cuda.intrinsics.layout.mma_layout import ( make_mma_swizzle_layout as make_swizzle_layout, ) import numpy as np -from tilelang.intrinsics.mma_macro_generator import ( +from tilelang.cuda.intrinsics.macro.mma_macro_generator import ( INT4TensorCoreIntrinEmitter, ) from tilelang.transform import simplify_prim_func diff --git a/examples/dequantize_gemm/example_dequant_gemm_fine_grained.py b/examples/dequantize_gemm/example_dequant_gemm_fine_grained.py index a870208083..3343b43267 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_fine_grained.py +++ b/examples/dequantize_gemm/example_dequant_gemm_fine_grained.py @@ -141,8 +141,8 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( accum_dtype, transform_b, ): - from tilelang.intrinsics.mma_layout import make_mma_swizzle_layout as make_swizzle_layout - from tilelang.intrinsics.mma_macro_generator import ( + from tilelang.cuda.intrinsics.layout.mma_layout import make_mma_swizzle_layout as make_swizzle_layout + from tilelang.cuda.intrinsics.macro.mma_macro_generator import ( TensorCoreIntrinEmitterWithLadderTransform, ) diff --git a/examples/gemm/README.md b/examples/gemm/README.md index 9ab7fb6614..fdd919dbf8 100644 --- a/examples/gemm/README.md +++ b/examples/gemm/README.md @@ -174,7 +174,7 @@ Below is a more advanced snippet that showcases how to apply memory layouts, ena import tilelang.language as T # `make_mma_swizzle_layout` is a python-defined layout function # that helps align data for MMA (Matrix Multiply-Accumulate) operations. -from tilelang.intrinsics import make_mma_swizzle_layout as make_swizzle_layout +from tilelang.cuda.intrinsics import make_mma_swizzle_layout as make_swizzle_layout def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float): @T.prim_func diff --git a/examples/gemm/example_gemm_intrinsics.py b/examples/gemm/example_gemm_intrinsics.py index 15e552587e..4c264c0e4f 100644 --- a/examples/gemm/example_gemm_intrinsics.py +++ b/examples/gemm/example_gemm_intrinsics.py @@ -2,8 +2,8 @@ from tvm import DataType import tilelang import tilelang.language as T -from tilelang.intrinsics import get_swizzle_layout -from tilelang.intrinsics.mma_macro_generator import ( +from tilelang.cuda.intrinsics import get_swizzle_layout +from tilelang.cuda.intrinsics.macro.mma_macro_generator import ( TensorCoreIntrinEmitter, ) diff --git a/examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py b/examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py index fc7fb44003..a82cb54084 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py +++ b/examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py @@ -6,7 +6,7 @@ import tilelang.language as T from tilelang.tileop.base import GemmWarpPolicy from tilelang.layout import make_swizzled_layout -from tilelang.intrinsics.mfma_macro_generator import MatrixCorePreshuffleIntrinEmitter +from tilelang.rocm.intrinsics.mfma_macro_generator import MatrixCorePreshuffleIntrinEmitter from tilelang.utils import determine_fp8_type tilelang.testing.set_random_seed(0) diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py b/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py index d9f749d9f2..2085ee8924 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py +++ b/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py @@ -3,9 +3,9 @@ import tilelang.testing from tvm import DataType import tilelang.language as T -from tilelang.intrinsics import get_swizzle_layout -from tilelang.intrinsics.mma_macro_generator import TensorCoreIntrinEmitter -from tilelang.intrinsics.mfma_macro_generator import MatrixCoreIntrinEmitter +from tilelang.cuda.intrinsics import get_swizzle_layout +from tilelang.cuda.intrinsics.macro.mma_macro_generator import TensorCoreIntrinEmitter +from tilelang.rocm.intrinsics.mfma_macro_generator import MatrixCoreIntrinEmitter from tilelang.utils import determine_fp8_type tilelang.testing.set_random_seed(0) diff --git a/examples/hadamard_transform/example_hadamard.py b/examples/hadamard_transform/example_hadamard.py index 65f463b71b..15efbf4467 100644 --- a/examples/hadamard_transform/example_hadamard.py +++ b/examples/hadamard_transform/example_hadamard.py @@ -1,6 +1,6 @@ import tilelang import tilelang.language as T -from tilelang.intrinsics import make_mma_swizzle_layout +from tilelang.cuda.intrinsics import make_mma_swizzle_layout import math import argparse diff --git a/examples/plot_layout/README.md b/examples/plot_layout/README.md index 8204e93d80..c2d3839e97 100644 --- a/examples/plot_layout/README.md +++ b/examples/plot_layout/README.md @@ -7,7 +7,7 @@ import tilelang.language as T from tvm import DataType from tvm.tir import IndexMap from typing import Literal, Callable -from tilelang.intrinsics.utils import get_mma_micro_size +from tilelang.cuda.intrinsics.layout.utils import get_mma_micro_size from tilelang.tools import plot_layout def make_mma_load_base_layout(dtype: str = T.float16, @@ -36,7 +36,7 @@ def make_mma_load_base_layout(dtype: str = T.float16, AssertionError If `local_buf` is not detected to be a fragment buffer. """ - from tilelang.intrinsics.mma_layout import ( + from tilelang.cuda.intrinsics.layout.mma_layout import ( shared_16x16_to_mma_32x8_layout_sr, shared_16x16_to_mma_32x8_layout_rs, shared_16x32_to_mma_32x16_layout, diff --git a/examples/plot_layout/fragment_mfma_load_a.py b/examples/plot_layout/fragment_mfma_load_a.py index d45cc227bc..20a5cbba48 100644 --- a/examples/plot_layout/fragment_mfma_load_a.py +++ b/examples/plot_layout/fragment_mfma_load_a.py @@ -1,9 +1,9 @@ import tilelang.language as T from typing import Literal, Callable from tvm.tir import IndexMap -from tilelang.intrinsics.utils import get_mma_micro_size +from tilelang.rocm.intrinsics.utils import get_mma_micro_size -from tilelang.intrinsics.mfma_layout import ( +from tilelang.rocm.intrinsics.mfma_layout import ( shared_16x4_to_local_64x1_layout_A, shared_16x16_to_local_64x4_layout_A, shared_16x32_to_local_64x8_layout_A, diff --git a/examples/plot_layout/fragment_mma_load_a.py b/examples/plot_layout/fragment_mma_load_a.py index df4a0b8870..7ac6bff30e 100644 --- a/examples/plot_layout/fragment_mma_load_a.py +++ b/examples/plot_layout/fragment_mma_load_a.py @@ -2,7 +2,7 @@ from typing import Literal, Callable from tvm import DataType from tvm.tir import IndexMap -from tilelang.intrinsics.utils import get_mma_micro_size +from tilelang.cuda.intrinsics.layout.utils import get_mma_micro_size def make_mma_load_base_layout(dtype: T.dtype = T.float16, matrix: Literal["A", "B"] = "A", transposed: bool = False) -> T.Fragment: @@ -26,7 +26,7 @@ def make_mma_load_base_layout(dtype: T.dtype = T.float16, matrix: Literal["A", " Describes how threads and indices in fragment are laid out. """ - from tilelang.intrinsics.mma_layout import ( + from tilelang.cuda.intrinsics.layout.mma_layout import ( shared_16x8_to_mma_32x4_layout_sr_a, shared_16x16_to_mma_32x8_layout_sr_a, shared_16x32_to_mma_32x16_layout_sr_a, diff --git a/testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py b/testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py index 3fe33aebf0..00fac1a3a3 100644 --- a/testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py +++ b/testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py @@ -3,8 +3,8 @@ import tilelang.testing from tilelang import tvm as tvm import tilelang.language as T -from tilelang.intrinsics import make_mfma_swizzle_layout as make_swizzle_layout -from tilelang.intrinsics.mfma_macro_generator import ( +from tilelang.rocm.intrinsics import make_mfma_swizzle_layout as make_swizzle_layout +from tilelang.rocm.intrinsics.mfma_macro_generator import ( MatrixCoreIntrinEmitter, ) from tilelang.transform import simplify_prim_func diff --git a/testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py b/testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py index abedd1f19b..864ac58c7b 100644 --- a/testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py +++ b/testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py @@ -4,8 +4,8 @@ import tilelang.testing from tilelang import tvm as tvm import tilelang.language as T -from tilelang.intrinsics import make_mfma_swizzle_layout as make_swizzle_layout -from tilelang.intrinsics.mfma_macro_generator import MatrixCorePreshuffleIntrinEmitter +from tilelang.rocm.intrinsics import make_mfma_swizzle_layout as make_swizzle_layout +from tilelang.rocm.intrinsics.mfma_macro_generator import MatrixCorePreshuffleIntrinEmitter from tilelang.transform import simplify_prim_func from tilelang.utils import determine_fp8_type diff --git a/testing/python/kernel/test_tilelang_kernel_bf16_gemm_mma.py b/testing/python/kernel/test_tilelang_kernel_bf16_gemm_mma.py index 33eef09a56..12ff9c0586 100644 --- a/testing/python/kernel/test_tilelang_kernel_bf16_gemm_mma.py +++ b/testing/python/kernel/test_tilelang_kernel_bf16_gemm_mma.py @@ -4,8 +4,8 @@ import tilelang.testing from tvm import DataType import tilelang.language as T -from tilelang.intrinsics import get_swizzle_layout -from tilelang.intrinsics.mma_macro_generator import ( +from tilelang.cuda.intrinsics import get_swizzle_layout +from tilelang.cuda.intrinsics.macro.mma_macro_generator import ( TensorCoreIntrinEmitter, ) from tilelang.transform import simplify_prim_func diff --git a/testing/python/kernel/test_tilelang_kernel_fp8_gemm_mma.py b/testing/python/kernel/test_tilelang_kernel_fp8_gemm_mma.py index f8793ba2e9..ae728854a0 100644 --- a/testing/python/kernel/test_tilelang_kernel_fp8_gemm_mma.py +++ b/testing/python/kernel/test_tilelang_kernel_fp8_gemm_mma.py @@ -4,8 +4,8 @@ import tilelang.testing from tvm import DataType import tilelang.language as T -from tilelang.intrinsics import get_swizzle_layout -from tilelang.intrinsics.mma_macro_generator import ( +from tilelang.cuda.intrinsics import get_swizzle_layout +from tilelang.cuda.intrinsics.macro.mma_macro_generator import ( TensorCoreIntrinEmitter, ) from tilelang.transform import simplify_prim_func diff --git a/testing/python/kernel/test_tilelang_kernel_gemm_mma_intrinsic.py b/testing/python/kernel/test_tilelang_kernel_gemm_mma_intrinsic.py index 7f7f36c51d..76a6e3d610 100644 --- a/testing/python/kernel/test_tilelang_kernel_gemm_mma_intrinsic.py +++ b/testing/python/kernel/test_tilelang_kernel_gemm_mma_intrinsic.py @@ -4,8 +4,8 @@ import tilelang.testing from tvm import DataType import tilelang.language as T -from tilelang.intrinsics import get_swizzle_layout -from tilelang.intrinsics.mma_macro_generator import ( +from tilelang.cuda.intrinsics import get_swizzle_layout +from tilelang.cuda.intrinsics.macro.mma_macro_generator import ( TensorCoreIntrinEmitter, ) from tilelang.transform import simplify_prim_func diff --git a/testing/python/kernel/test_tilelang_kernel_gemm_simt.py b/testing/python/kernel/test_tilelang_kernel_gemm_simt.py index 5c52f432d0..dd96e38f1a 100644 --- a/testing/python/kernel/test_tilelang_kernel_gemm_simt.py +++ b/testing/python/kernel/test_tilelang_kernel_gemm_simt.py @@ -4,7 +4,7 @@ from tilelang import tvm as tvm from tvm import DataType import tilelang.language as T -from tilelang.intrinsics import get_swizzle_layout +from tilelang.cuda.intrinsics import get_swizzle_layout from tilelang.transform import simplify_prim_func tilelang.testing.set_random_seed(0) diff --git a/testing/python/language/test_tilelang_language_reshape.py b/testing/python/language/test_tilelang_language_reshape.py index 27388911b7..78e38de6b9 100644 --- a/testing/python/language/test_tilelang_language_reshape.py +++ b/testing/python/language/test_tilelang_language_reshape.py @@ -160,7 +160,7 @@ def test_reshape_fragment(): def reshape_layout_transform_shared(N, M, dtype): - from tilelang.intrinsics.mma_layout import make_mma_swizzle_layout + from tilelang.cuda.intrinsics.layout.mma_layout import make_mma_swizzle_layout @T.prim_func def main( diff --git a/testing/python/language/test_tilelang_language_vectorize.py b/testing/python/language/test_tilelang_language_vectorize.py index f042339d42..7446d73eae 100644 --- a/testing/python/language/test_tilelang_language_vectorize.py +++ b/testing/python/language/test_tilelang_language_vectorize.py @@ -2,7 +2,7 @@ import tilelang.testing import tilelang.language as T -from tilelang.intrinsics import make_mma_swizzle_layout +from tilelang.cuda.intrinsics import make_mma_swizzle_layout import pytest diff --git a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py index 8ffffd8ce0..de7808d9f0 100644 --- a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py +++ b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py @@ -7,7 +7,7 @@ from tilelang.utils.sparse import compress, randn_semi_sparse, randint_semi_sparse from tilelang.layout import make_cutlass_metadata_layout from tilelang.utils.tensor import torch_assert_close -from tilelang.intrinsics.mma_sp_macro_generator import SparseTensorCoreIntrinEmitter +from tilelang.cuda.intrinsics.macro.mma_sp_macro_generator import SparseTensorCoreIntrinEmitter torch.backends.cuda.matmul.allow_tf32 = False diff --git a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py index 32742a005f..921e3b4de2 100644 --- a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py +++ b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py @@ -3,7 +3,7 @@ from tilelang.utils.sparse import compress, randn_semi_sparse, randint_semi_sparse from tilelang.utils.tensor import torch_assert_close from tilelang.layout import make_cutlass_metadata_layout -from tilelang.intrinsics.mma_sp_macro_generator import SparseTensorCoreIntrinEmitter +from tilelang.cuda.intrinsics.macro.mma_sp_macro_generator import SparseTensorCoreIntrinEmitter import tilelang.testing import torch diff --git a/testing/python/transform/test_tilelang_transform_lower_hopper_intrin.py b/testing/python/transform/test_tilelang_transform_lower_hopper_intrin.py index 70ea213247..67b063b139 100644 --- a/testing/python/transform/test_tilelang_transform_lower_hopper_intrin.py +++ b/testing/python/transform/test_tilelang_transform_lower_hopper_intrin.py @@ -1,6 +1,6 @@ from tilelang import tvm as tvm import tilelang as tl -from tilelang.backend.cuda import transform as cuda_transform +from tilelang.cuda import transform as cuda_transform from tilelang.utils.target import determine_target import tilelang.language as T import tilelang.testing diff --git a/tilelang/__init__.py b/tilelang/__init__.py index 034ae3cbfb..f1c9fa9e4f 100644 --- a/tilelang/__init__.py +++ b/tilelang/__init__.py @@ -186,5 +186,8 @@ def _load_tile_lang_lib(): from .math import * # noqa: F403 from . import ir # noqa: F401 from . import tileop # noqa: F401 + from . import cpu as cpu # noqa: F401 + from . import cuda as cuda # noqa: F401 + from . import rocm as rocm # noqa: F401 del _lazy_load_lib diff --git a/tilelang/backend/__init__.py b/tilelang/backend/__init__.py deleted file mode 100644 index 7262157505..0000000000 --- a/tilelang/backend/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from tilelang.tileop.gemm.registry import register_gemm_impl, resolve_gemm_impl # noqa: F401 -from tilelang.tileop.gemm_sp.registry import register_gemm_sp_impl, resolve_gemm_sp_impl # noqa: F401 - -# Import built-in backend packages so their implementations register. -from . import cpu as _cpu # noqa: F401,E402 -from . import cuda as _cuda # noqa: F401,E402 -from . import rocm as _rocm # noqa: F401,E402 diff --git a/tilelang/backend/rocm/__init__.py b/tilelang/backend/rocm/__init__.py deleted file mode 100644 index e8fb2c24d2..0000000000 --- a/tilelang/backend/rocm/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from . import op # noqa: F401 diff --git a/tilelang/backend/cpu/__init__.py b/tilelang/cpu/__init__.py similarity index 100% rename from tilelang/backend/cpu/__init__.py rename to tilelang/cpu/__init__.py diff --git a/tilelang/backend/cpu/op/__init__.py b/tilelang/cpu/op/__init__.py similarity index 100% rename from tilelang/backend/cpu/op/__init__.py rename to tilelang/cpu/op/__init__.py diff --git a/tilelang/backend/cpu/op/gemm/__init__.py b/tilelang/cpu/op/gemm/__init__.py similarity index 100% rename from tilelang/backend/cpu/op/gemm/__init__.py rename to tilelang/cpu/op/gemm/__init__.py diff --git a/tilelang/backend/cpu/op/gemm/gemm_scalar.py b/tilelang/cpu/op/gemm/gemm_scalar.py similarity index 100% rename from tilelang/backend/cpu/op/gemm/gemm_scalar.py rename to tilelang/cpu/op/gemm/gemm_scalar.py diff --git a/tilelang/backend/cuda/__init__.py b/tilelang/cuda/__init__.py similarity index 63% rename from tilelang/backend/cuda/__init__.py rename to tilelang/cuda/__init__.py index e4dd4f9e4d..8ce2aa2507 100644 --- a/tilelang/backend/cuda/__init__.py +++ b/tilelang/cuda/__init__.py @@ -1,2 +1,3 @@ +from . import intrinsics # noqa: F401 from . import op # noqa: F401 from . import transform # noqa: F401 diff --git a/tilelang/cuda/intrinsics/__init__.py b/tilelang/cuda/intrinsics/__init__.py new file mode 100644 index 0000000000..8601d9342e --- /dev/null +++ b/tilelang/cuda/intrinsics/__init__.py @@ -0,0 +1,13 @@ +from .layout.utils import ( # noqa: F401 + mma_store_index_map, + get_ldmatrix_offset, + get_mma_micro_size, +) + +from .macro.mma_macro_generator import ( # noqa: F401 + TensorCoreIntrinEmitter, + TensorCoreIntrinEmitterWithLadderTransform, +) + +from .layout.mma_layout import get_swizzle_layout # noqa: F401 +from .layout.mma_layout import make_mma_swizzle_layout # noqa: F401 diff --git a/tilelang/cuda/intrinsics/layout/__init__.py b/tilelang/cuda/intrinsics/layout/__init__.py new file mode 100644 index 0000000000..ff517fe501 --- /dev/null +++ b/tilelang/cuda/intrinsics/layout/__init__.py @@ -0,0 +1,8 @@ +from .utils import ( # noqa: F401 + mma_store_index_map, + get_ldmatrix_offset, + get_mma_micro_size, +) + +from .mma_layout import get_swizzle_layout # noqa: F401 +from .mma_layout import make_mma_swizzle_layout # noqa: F401 diff --git a/tilelang/intrinsics/mma_layout.py b/tilelang/cuda/intrinsics/layout/mma_layout.py similarity index 100% rename from tilelang/intrinsics/mma_layout.py rename to tilelang/cuda/intrinsics/layout/mma_layout.py diff --git a/tilelang/intrinsics/mma_sm70_layout.py b/tilelang/cuda/intrinsics/layout/mma_sm70_layout.py similarity index 100% rename from tilelang/intrinsics/mma_sm70_layout.py rename to tilelang/cuda/intrinsics/layout/mma_sm70_layout.py diff --git a/tilelang/intrinsics/mma_sp_layout.py b/tilelang/cuda/intrinsics/layout/mma_sp_layout.py similarity index 99% rename from tilelang/intrinsics/mma_sp_layout.py rename to tilelang/cuda/intrinsics/layout/mma_sp_layout.py index 73da1289ab..c814e32307 100644 --- a/tilelang/intrinsics/mma_sp_layout.py +++ b/tilelang/cuda/intrinsics/layout/mma_sp_layout.py @@ -1,7 +1,7 @@ from tvm import DataType from typing import Literal -from tilelang.intrinsics.mma_layout import ( +from tilelang.cuda.intrinsics.layout.mma_layout import ( mma_load_a_32x4_to_shared_16x8_layout, mma_load_a_32x16_to_shared_16x32_layout, mma_load_a_32x8_to_shared_16x16_layout, diff --git a/tilelang/intrinsics/utils.py b/tilelang/cuda/intrinsics/layout/utils.py similarity index 90% rename from tilelang/intrinsics/utils.py rename to tilelang/cuda/intrinsics/layout/utils.py index f65fff1a9b..050ad09327 100644 --- a/tilelang/intrinsics/utils.py +++ b/tilelang/cuda/intrinsics/layout/utils.py @@ -10,11 +10,9 @@ mma_store_32x8_to_shared_16x16_layout, mma_store_32x2_to_shared_8x8_layout_fp64, ) -from .mfma_layout import thread_id_shared_access_64x4_to_16x16_layout_C_n_m, thread_id_shared_access_64x16_to_32x32_layout_C_m_n from .mma_layout import get_swizzle_layout # noqa: F401 from .mma_layout import make_mma_swizzle_layout # noqa: F401 -from .mfma_layout import make_mfma_swizzle_layout # noqa: F401 # the original implementation and insight is from the following code snippet @@ -89,14 +87,6 @@ def mma_store_index_map_fp64(thread_id, local_id): return mma_store_32x2_to_shared_8x8_layout_fp64(thread_id, local_id) -def mfma_store_index_map(thread_id, local_id): - return thread_id_shared_access_64x4_to_16x16_layout_C_n_m(thread_id, local_id) - - -def mfma_store_index_map_32x32(thread_id, local_id): - return thread_id_shared_access_64x16_to_32x32_layout_C_m_n(thread_id, local_id) - - def get_mma_micro_size(dtype: Literal["float16", "int8"]): # TODO(lei): FP8 related precision support. # Basic Tensor Core Matrix Multiply operation Unit diff --git a/tilelang/cuda/intrinsics/macro/__init__.py b/tilelang/cuda/intrinsics/macro/__init__.py new file mode 100644 index 0000000000..658f791220 --- /dev/null +++ b/tilelang/cuda/intrinsics/macro/__init__.py @@ -0,0 +1,6 @@ +from .mma_macro_generator import ( # noqa: F401 + TensorCoreIntrinEmitter, + TensorCoreIntrinEmitterWithLadderTransform, +) + +from .mma_sp_macro_generator import SparseTensorCoreIntrinEmitter # noqa: F401 diff --git a/tilelang/intrinsics/mma_macro_generator.py b/tilelang/cuda/intrinsics/macro/mma_macro_generator.py similarity index 99% rename from tilelang/intrinsics/mma_macro_generator.py rename to tilelang/cuda/intrinsics/macro/mma_macro_generator.py index 26e34e6a51..6461dbd885 100644 --- a/tilelang/intrinsics/mma_macro_generator.py +++ b/tilelang/cuda/intrinsics/macro/mma_macro_generator.py @@ -8,12 +8,12 @@ from tvm.tir import PrimExpr, IndexMap, Buffer, Var, BufferRegion, BufferLoad from tilelang import tvm as tvm from tvm.runtime import convert -from .utils import ( +from ..layout.utils import ( mma_store_index_map, get_ldmatrix_offset, ) from tilelang.utils import is_fragment, get_buffer_region_from_load -from tilelang.intrinsics.mma_layout import ( +from tilelang.cuda.intrinsics.layout.mma_layout import ( shared_16x8_to_mma_32x4_layout_sr_a, shared_16x8_to_mma_32x4_layout_sr_b, shared_16x16_to_mma_32x8_layout_sr_a, @@ -206,7 +206,7 @@ def get_thread_binding(self): return self.thread_var def get_store_index_map(self, inverse: bool = False) -> IndexMap: - from .utils import mma_store_index_map, mma_store_index_map_fp64 + from ..layout.utils import mma_store_index_map, mma_store_index_map_fp64 warp_size, local_size_c = self.WARP_SIZE, self.local_size_out if DataType(self.accum_dtype).bits == 64: diff --git a/tilelang/intrinsics/mma_sm70_macro_generator.py b/tilelang/cuda/intrinsics/macro/mma_sm70_macro_generator.py similarity index 99% rename from tilelang/intrinsics/mma_sm70_macro_generator.py rename to tilelang/cuda/intrinsics/macro/mma_sm70_macro_generator.py index 52679b169a..4fee93087a 100644 --- a/tilelang/intrinsics/mma_sm70_macro_generator.py +++ b/tilelang/cuda/intrinsics/macro/mma_sm70_macro_generator.py @@ -8,7 +8,7 @@ from tilelang import tvm as tvm from tvm.runtime import convert from tilelang.utils import is_fragment, get_buffer_region_from_load -from tilelang.intrinsics.mma_sm70_layout import ( +from tilelang.cuda.intrinsics.layout.mma_sm70_layout import ( shared_16x4_to_mma_a_32x4_layout, shared_4x16_to_mma_b_32x4_layout, shared_16x4_to_mma_b_32x4_layout_trans, diff --git a/tilelang/intrinsics/mma_sp_macro_generator.py b/tilelang/cuda/intrinsics/macro/mma_sp_macro_generator.py similarity index 99% rename from tilelang/intrinsics/mma_sp_macro_generator.py rename to tilelang/cuda/intrinsics/macro/mma_sp_macro_generator.py index 18a37b8e83..826a0f58ec 100644 --- a/tilelang/intrinsics/mma_sp_macro_generator.py +++ b/tilelang/cuda/intrinsics/macro/mma_sp_macro_generator.py @@ -6,13 +6,13 @@ from tvm.tir import PrimExpr, IndexMap, Buffer, Var, BufferRegion, BufferLoad from tvm.ir import Range from tvm.runtime import convert -from .utils import ( +from ..layout.utils import ( mma_store_index_map, get_ldmatrix_offset, ) from tilelang.utils import is_fragment, get_buffer_region_from_load -from tilelang.intrinsics.mma_sp_layout import ( +from tilelang.cuda.intrinsics.layout.mma_sp_layout import ( shared_16x16_to_mma_sp_layout_sr_a, shared_16x16_to_mma_sp_layout_sr_b, shared_16x32_to_mma_sp_layout_sr_a, diff --git a/tilelang/intrinsics/tcgen05_macro_generator.py b/tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py similarity index 100% rename from tilelang/intrinsics/tcgen05_macro_generator.py rename to tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py diff --git a/tilelang/intrinsics/wgmma_macro_generator.py b/tilelang/cuda/intrinsics/macro/wgmma_macro_generator.py similarity index 99% rename from tilelang/intrinsics/wgmma_macro_generator.py rename to tilelang/cuda/intrinsics/macro/wgmma_macro_generator.py index 864420c771..f31c12fb94 100644 --- a/tilelang/intrinsics/wgmma_macro_generator.py +++ b/tilelang/cuda/intrinsics/macro/wgmma_macro_generator.py @@ -15,7 +15,7 @@ make_linear_layout, ) from tvm.runtime import convert -from tilelang.intrinsics.mma_layout import ( +from tilelang.cuda.intrinsics.layout.mma_layout import ( shared_16x8_to_mma_32x4_layout_sr_a, shared_16x16_to_mma_32x8_layout_sr_a, shared_16x32_to_mma_32x16_layout_sr_a, diff --git a/tilelang/backend/cuda/op/__init__.py b/tilelang/cuda/op/__init__.py similarity index 100% rename from tilelang/backend/cuda/op/__init__.py rename to tilelang/cuda/op/__init__.py diff --git a/tilelang/backend/cuda/op/gemm/__init__.py b/tilelang/cuda/op/gemm/__init__.py similarity index 100% rename from tilelang/backend/cuda/op/gemm/__init__.py rename to tilelang/cuda/op/gemm/__init__.py diff --git a/tilelang/backend/cuda/op/gemm/gemm_mma.py b/tilelang/cuda/op/gemm/gemm_mma.py similarity index 99% rename from tilelang/backend/cuda/op/gemm/gemm_mma.py rename to tilelang/cuda/op/gemm/gemm_mma.py index 3baec0ed85..bd572c4075 100644 --- a/tilelang/backend/cuda/op/gemm/gemm_mma.py +++ b/tilelang/cuda/op/gemm/gemm_mma.py @@ -2,7 +2,7 @@ from tilelang.tileop.gemm.gemm_base import GemmBase from tilelang.layout import make_swizzled_layout -from tilelang.intrinsics.mma_macro_generator import ( +from tilelang.cuda.intrinsics.macro.mma_macro_generator import ( TensorCoreIntrinEmitter, ) from tilelang.utils.language import is_shared, is_fragment, is_full_region diff --git a/tilelang/backend/cuda/op/gemm/gemm_mma_sm70.py b/tilelang/cuda/op/gemm/gemm_mma_sm70.py similarity index 98% rename from tilelang/backend/cuda/op/gemm/gemm_mma_sm70.py rename to tilelang/cuda/op/gemm/gemm_mma_sm70.py index dd66b48f2e..ca5068cbc5 100644 --- a/tilelang/backend/cuda/op/gemm/gemm_mma_sm70.py +++ b/tilelang/cuda/op/gemm/gemm_mma_sm70.py @@ -3,7 +3,7 @@ # for Volta GPUs, which use legacy MMA instructions from tilelang.tileop.gemm.gemm_base import GemmBase from tilelang.layout import make_volta_swizzled_layout -from tilelang.intrinsics.mma_sm70_macro_generator import ( +from tilelang.cuda.intrinsics.macro.mma_sm70_macro_generator import ( TensorCoreIntrinEmitter, ) from tilelang.utils.language import is_shared, is_fragment, is_full_region diff --git a/tilelang/backend/cuda/op/gemm/gemm_tcgen05.py b/tilelang/cuda/op/gemm/gemm_tcgen05.py similarity index 99% rename from tilelang/backend/cuda/op/gemm/gemm_tcgen05.py rename to tilelang/cuda/op/gemm/gemm_tcgen05.py index 78f9a24271..a6107083df 100644 --- a/tilelang/backend/cuda/op/gemm/gemm_tcgen05.py +++ b/tilelang/cuda/op/gemm/gemm_tcgen05.py @@ -8,7 +8,7 @@ make_quarter_bank_swizzled_layout, make_linear_layout, ) -from tilelang.intrinsics.tcgen05_macro_generator import ( +from tilelang.cuda.intrinsics.macro.tcgen05_macro_generator import ( TensorCoreIntrinEmitter, ) from tilelang import language as T diff --git a/tilelang/backend/cuda/op/gemm/gemm_wgmma.py b/tilelang/cuda/op/gemm/gemm_wgmma.py similarity index 99% rename from tilelang/backend/cuda/op/gemm/gemm_wgmma.py rename to tilelang/cuda/op/gemm/gemm_wgmma.py index 939c8926fa..5eabb1b797 100644 --- a/tilelang/backend/cuda/op/gemm/gemm_wgmma.py +++ b/tilelang/cuda/op/gemm/gemm_wgmma.py @@ -8,7 +8,7 @@ make_linear_layout, Layout, ) -from tilelang.intrinsics.wgmma_macro_generator import ( +from tilelang.cuda.intrinsics.macro.wgmma_macro_generator import ( TensorCoreIntrinEmitter, ) from tilelang.utils.language import is_shared, is_fragment diff --git a/tilelang/backend/cuda/op/gemm_sp/__init__.py b/tilelang/cuda/op/gemm_sp/__init__.py similarity index 100% rename from tilelang/backend/cuda/op/gemm_sp/__init__.py rename to tilelang/cuda/op/gemm_sp/__init__.py diff --git a/tilelang/backend/cuda/op/gemm_sp/gemm_sp_mma.py b/tilelang/cuda/op/gemm_sp/gemm_sp_mma.py similarity index 99% rename from tilelang/backend/cuda/op/gemm_sp/gemm_sp_mma.py rename to tilelang/cuda/op/gemm_sp/gemm_sp_mma.py index 6c0461ec64..dc381f7047 100644 --- a/tilelang/backend/cuda/op/gemm_sp/gemm_sp_mma.py +++ b/tilelang/cuda/op/gemm_sp/gemm_sp_mma.py @@ -1,6 +1,6 @@ from tilelang.tileop.gemm_sp.gemm_sp_base import GemmSPBase from tilelang.layout import make_swizzled_layout -from tilelang.intrinsics.mma_sp_macro_generator import SparseTensorCoreIntrinEmitter +from tilelang.cuda.intrinsics.macro.mma_sp_macro_generator import SparseTensorCoreIntrinEmitter from tilelang.utils.language import is_shared, is_fragment from tilelang import tvm as tvm from tvm.target import Target diff --git a/tilelang/backend/cuda/transform/__init__.py b/tilelang/cuda/transform/__init__.py similarity index 100% rename from tilelang/backend/cuda/transform/__init__.py rename to tilelang/cuda/transform/__init__.py diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index 4130e195a8..7c5b433dcb 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -2,7 +2,6 @@ from tvm import tir, IRModule from tvm.target import Target import tilelang -from tilelang.backend.cuda import transform as cuda_transform from tilelang.transform import PassContext from tilelang.contrib.nvcc import have_tma, have_pdl @@ -205,7 +204,7 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule: # Lower high-level tile operations to low-level operations mod = tilelang.transform.LowerTileOp()(mod) # Lower l2 persistent map - mod = cuda_transform.LowerL2Persistent()(mod) + mod = tilelang.cuda.transform.LowerL2Persistent()(mod) # Decouple type cast vectorization constraints before vectorization mod = tilelang.transform.DecoupleTypeCast()(mod) # Legalize vectorized loops to ensure they are valid @@ -271,7 +270,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: mod = tir.transform.InferFragment()(mod) mod = tilelang.transform.LowerThreadAllreduce()(mod) mod = tilelang.transform.LowerLDGSTG()(mod) - mod = cuda_transform.LowerHopperIntrin()(mod) + mod = tilelang.cuda.transform.LowerHopperIntrin()(mod) # Global Barrier Synchronization must be applied before # SplitHostDevice pass, as the global barrier if allow_global_thread_synchronization(): @@ -306,6 +305,6 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: mod = tilelang.transform.LowerDeviceKernelLaunch()(mod) # Transform threadblock to persistent threadblock - mod = cuda_transform.PersistThreadblock()(mod) + mod = tilelang.cuda.transform.PersistThreadblock()(mod) return mod diff --git a/tilelang/intrinsics/__init__.py b/tilelang/intrinsics/__init__.py index 1b3f106e71..b944ae89d0 100644 --- a/tilelang/intrinsics/__init__.py +++ b/tilelang/intrinsics/__init__.py @@ -1,14 +1,14 @@ -from .utils import ( +from tilelang.cuda.intrinsics.layout.utils import ( mma_store_index_map, # noqa: F401 get_ldmatrix_offset, # noqa: F401 ) -from .mma_macro_generator import ( +from tilelang.cuda.intrinsics.macro.mma_macro_generator import ( TensorCoreIntrinEmitter, # noqa: F401 TensorCoreIntrinEmitterWithLadderTransform, # noqa: F401 ) -from .mma_layout import get_swizzle_layout # noqa: F401 -from .mma_layout import make_mma_swizzle_layout # noqa: F401 +from tilelang.cuda.intrinsics.layout.mma_layout import get_swizzle_layout # noqa: F401 +from tilelang.cuda.intrinsics.layout.mma_layout import make_mma_swizzle_layout # noqa: F401 -from .mfma_layout import make_mfma_swizzle_layout # noqa: F401 +from tilelang.rocm.intrinsics.mfma_layout import make_mfma_swizzle_layout # noqa: F401 diff --git a/tilelang/language/gemm_op.py b/tilelang/language/gemm_op.py index d5ec03728f..1aa8aea17d 100644 --- a/tilelang/language/gemm_op.py +++ b/tilelang/language/gemm_op.py @@ -453,7 +453,7 @@ def make_blockscaled_gemm_layout( Returns: A Layout object for C's TMEM storage. """ - from tilelang.intrinsics.tcgen05_macro_generator import TensorCoreIntrinEmitter + from tilelang.cuda.intrinsics.macro.tcgen05_macro_generator import TensorCoreIntrinEmitter C_region = to_buffer_region(C) A_region = to_buffer_region(A) diff --git a/tilelang/rocm/__init__.py b/tilelang/rocm/__init__.py new file mode 100644 index 0000000000..a3b9cf6b63 --- /dev/null +++ b/tilelang/rocm/__init__.py @@ -0,0 +1,2 @@ +from . import intrinsics # noqa: F401 +from . import op # noqa: F401 diff --git a/tilelang/rocm/intrinsics/__init__.py b/tilelang/rocm/intrinsics/__init__.py new file mode 100644 index 0000000000..a972f683f1 --- /dev/null +++ b/tilelang/rocm/intrinsics/__init__.py @@ -0,0 +1,12 @@ +from .utils import ( # noqa: F401 + mfma_store_index_map, + mfma_store_index_map_32x32, + get_mma_micro_size, +) + +from .mfma_layout import make_mfma_swizzle_layout # noqa: F401 +from .mfma_macro_generator import ( # noqa: F401 + MatrixCoreIntrinEmitter, + MatrixCorePreshuffleIntrinEmitter, +) +from .wmma_macro_generator import WMMAIntrinEmitter # noqa: F401 diff --git a/tilelang/intrinsics/mfma_layout.py b/tilelang/rocm/intrinsics/mfma_layout.py similarity index 100% rename from tilelang/intrinsics/mfma_layout.py rename to tilelang/rocm/intrinsics/mfma_layout.py diff --git a/tilelang/intrinsics/mfma_macro_generator.py b/tilelang/rocm/intrinsics/mfma_macro_generator.py similarity index 100% rename from tilelang/intrinsics/mfma_macro_generator.py rename to tilelang/rocm/intrinsics/mfma_macro_generator.py diff --git a/tilelang/rocm/intrinsics/utils.py b/tilelang/rocm/intrinsics/utils.py new file mode 100644 index 0000000000..b1b4f68f76 --- /dev/null +++ b/tilelang/rocm/intrinsics/utils.py @@ -0,0 +1,23 @@ +from typing import Literal + +from .mfma_layout import ( + thread_id_shared_access_64x4_to_16x16_layout_C_n_m, + thread_id_shared_access_64x16_to_32x32_layout_C_m_n, +) +from .mfma_layout import make_mfma_swizzle_layout # noqa: F401 + + +def mfma_store_index_map(thread_id, local_id): + return thread_id_shared_access_64x4_to_16x16_layout_C_n_m(thread_id, local_id) + + +def mfma_store_index_map_32x32(thread_id, local_id): + return thread_id_shared_access_64x16_to_32x32_layout_C_m_n(thread_id, local_id) + + +def get_mma_micro_size(dtype: Literal["float16", "int8"]): + micro_size_x = micro_size_y = 16 + micro_size_k = 16 + if dtype in {"float8_e4m3", "float8_e5m2", "int8"}: + micro_size_k = 32 + return micro_size_x, micro_size_y, micro_size_k diff --git a/tilelang/intrinsics/wmma_layout.py b/tilelang/rocm/intrinsics/wmma_layout.py similarity index 100% rename from tilelang/intrinsics/wmma_layout.py rename to tilelang/rocm/intrinsics/wmma_layout.py diff --git a/tilelang/intrinsics/wmma_macro_generator.py b/tilelang/rocm/intrinsics/wmma_macro_generator.py similarity index 100% rename from tilelang/intrinsics/wmma_macro_generator.py rename to tilelang/rocm/intrinsics/wmma_macro_generator.py diff --git a/tilelang/backend/rocm/op/__init__.py b/tilelang/rocm/op/__init__.py similarity index 100% rename from tilelang/backend/rocm/op/__init__.py rename to tilelang/rocm/op/__init__.py diff --git a/tilelang/backend/rocm/op/gemm/__init__.py b/tilelang/rocm/op/gemm/__init__.py similarity index 100% rename from tilelang/backend/rocm/op/gemm/__init__.py rename to tilelang/rocm/op/gemm/__init__.py diff --git a/tilelang/backend/rocm/op/gemm/gemm_mfma.py b/tilelang/rocm/op/gemm/gemm_mfma.py similarity index 99% rename from tilelang/backend/rocm/op/gemm/gemm_mfma.py rename to tilelang/rocm/op/gemm/gemm_mfma.py index 5ca8676183..81f53d6eeb 100644 --- a/tilelang/backend/rocm/op/gemm/gemm_mfma.py +++ b/tilelang/rocm/op/gemm/gemm_mfma.py @@ -2,7 +2,7 @@ from tilelang.tileop.gemm.gemm_base import GemmBase from tilelang.layout import make_swizzled_layout -from tilelang.intrinsics.mfma_macro_generator import ( +from tilelang.rocm.intrinsics.mfma_macro_generator import ( MatrixCoreIntrinEmitter, ) from tilelang.utils.language import is_shared, is_fragment, is_full_region diff --git a/tilelang/backend/rocm/op/gemm/gemm_wmma.py b/tilelang/rocm/op/gemm/gemm_wmma.py similarity index 98% rename from tilelang/backend/rocm/op/gemm/gemm_wmma.py rename to tilelang/rocm/op/gemm/gemm_wmma.py index c9b1783d38..4e0e0646e6 100644 --- a/tilelang/backend/rocm/op/gemm/gemm_wmma.py +++ b/tilelang/rocm/op/gemm/gemm_wmma.py @@ -4,7 +4,7 @@ from tilelang.tileop.gemm.gemm_base import GemmBase from tilelang.layout import make_swizzled_layout -from tilelang.intrinsics.wmma_macro_generator import WMMAIntrinEmitter +from tilelang.rocm.intrinsics.wmma_macro_generator import WMMAIntrinEmitter from tilelang.utils.language import is_shared, is_fragment, is_full_region from tilelang import tvm as tvm from tvm.target import Target From b2f64fedd0b269125768ccce95ca2140dadfb497 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Fri, 8 May 2026 15:47:46 +0800 Subject: [PATCH 3/4] Remove deprecated intrinsic implementations and related tests for matrix multiplication and element-wise addition. This includes the deletion of files for `benchmark_matmul_intrinsic.py`, `example_tilelang_gemm_fp8_intrinsic.py`, and associated test files, streamlining the codebase by eliminating unused components. --- .../matmul/benchmark_matmul_intrinsic.py | 316 ------------------ .../example_tilelang_gemm_fp8_intrinsic.py | 248 -------------- .../gemm_fp8/regression_example_gemm_fp8.py | 5 - examples/gemm_fp8/test_example_gemm_fp8.py | 5 - .../test_tilelang_kernel_element_wise_add.py | 109 ------ .../test_tilelang_kernel_fp8_gemm_mma.py | 228 ------------- ...test_tilelang_kernel_gemm_mma_intrinsic.py | 240 ------------- 7 files changed, 1151 deletions(-) delete mode 100644 benchmark/matmul/benchmark_matmul_intrinsic.py delete mode 100644 examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py delete mode 100644 testing/python/kernel/test_tilelang_kernel_element_wise_add.py delete mode 100644 testing/python/kernel/test_tilelang_kernel_fp8_gemm_mma.py delete mode 100644 testing/python/kernel/test_tilelang_kernel_gemm_mma_intrinsic.py diff --git a/benchmark/matmul/benchmark_matmul_intrinsic.py b/benchmark/matmul/benchmark_matmul_intrinsic.py deleted file mode 100644 index bc6b2b8e96..0000000000 --- a/benchmark/matmul/benchmark_matmul_intrinsic.py +++ /dev/null @@ -1,316 +0,0 @@ -import argparse -import logging -from tilelang import tvm as tvm -from tvm import DataType -import tilelang as tl -import tilelang.language as T -from tilelang.cuda.intrinsics import get_swizzle_layout -from tilelang.cuda.intrinsics.macro.mma_macro_generator import ( - TensorCoreIntrinEmitter, -) -from tilelang.transform import simplify_prim_func -from tilelang.autotuner import autotune -import itertools - -# Configure logger -logger = logging.getLogger(__name__) -logger.setLevel(logging.DEBUG) - - -def make_swizzle_layout(shared_buf): - dtype = shared_buf.dtype - shape = shared_buf.shape - - can_swizzle = shape[-1] * DataType(dtype).bits == 512 - if not can_swizzle: - return T.Layout(shape, lambda *args: args) - - def transform_func(i, j): - new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype) - return [new_warp_i, new_warp_j] - - return T.Layout(shape, transform_func) - - -@simplify_prim_func -def tl_matmul( - M, - N, - K, - in_dtype, - out_dtype, - accum_dtype, - block_row_warps=1, - block_col_warps=1, - warp_row_tiles=16, - warp_col_tiles=16, - chunk=32, - stage=2, - enable_rasteration=False, -): - assert in_dtype in [ - T.float16, - T.int8, - ], "Currently only float16 and int8 are supported" - assert out_dtype in [ - T.float16, - T.float32, - T.int32, - ], "Currently only float16, float32 and int32 are supported" - - micro_size_x = micro_size_y = micro_size_k = 16 - - if out_dtype == T.int32: - micro_size_k = 32 - - # This is a debug config - # chunk = 32 if in_dtype == T.float16 else 64 - shared_scope = "shared.dyn" - - block_M = block_row_warps * warp_row_tiles - block_N = block_col_warps * warp_col_tiles - block_K = chunk - - A_shape = (M, K) - B_shape = (N, K) - A_shared_shape = (block_M, block_K) - B_shared_shape = (block_N, block_K) - C_shared_shape = ( - block_M, - block_N, - ) - - warp_size = 32 - threads = warp_size * (block_row_warps * block_col_warps) - local_size_a = (micro_size_x * micro_size_k) // warp_size - local_size_b = (micro_size_y * micro_size_k) // warp_size - local_size_c = (micro_size_x * micro_size_y) // warp_size - warp_rows = warp_row_tiles // micro_size_x - warp_cols = warp_col_tiles // micro_size_y - - # MMA Wrapper to Auto Generate Code for MMA - mma_emitter = TensorCoreIntrinEmitter( - a_dtype=in_dtype, - b_dtype=in_dtype, - accum_dtype=accum_dtype, - a_transposed=False, - b_transposed=True, - block_row_warps=block_row_warps, - block_col_warps=block_col_warps, - warp_row_tiles=warp_row_tiles, - warp_col_tiles=warp_col_tiles, - chunk=chunk, - ) - - @T.prim_func - def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) - B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) - C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) - A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) - B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) - C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) - - T.annotate_layout( - { - A_shared: make_swizzle_layout(A_shared), - B_shared: make_swizzle_layout(B_shared), - } - ) - - # Improve L2 Cache - T.use_swizzle(panel_size=10, enable=enable_rasteration) - - T.clear(C_local) - - for ko in T.Pipelined((K // block_K), num_stages=stage): - # Load A into shared memory - for i, k in T.Parallel(block_M, block_K): - A_shared[i, k] = A[by * block_M + i, ko * block_K + k] - - # Load B into shared memory - for j, k in T.Parallel(block_N, block_K): - B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] - - for ki in T.serial(0, (block_K // micro_size_k)): - # Load A into fragment - mma_emitter.ldmatrix_a(A_local, A_shared, ki) - - # Load B into fragment - mma_emitter.ldmatrix_b(B_local, B_shared, ki) - - # Perform Matrix Multiplication - mma_emitter.mma(A_local, B_local, C_local) - - # Perform STMatrix - mma_emitter.stmatrix(C_local, C_shared) - - # Store shared into global - for i, j in T.Parallel(block_M, block_N): - C[by * block_M + i, bx * block_N + j] = C_shared[i, j] - - return main - - -def ref_program(A, B): - """Reference matrix multiplication program.""" - return A @ B.T - - -def get_configs(args, kwargs): - """ - Generate a list of configuration dictionaries that will be used for tuning. - - Parameters - ---------- - with_roller : bool - Whether to enable bitblas roller to deduce search spaces - - Returns - ------- - list of dict - Each configuration dict includes various block sizes, pipeline stages, - thread numbers, and other parameters to explore during autotuning. - """ - M, N, K = args[:3] - with_roller = args[6] - - if with_roller: - from tilelang.carver.template import MatmulTemplate - from tilelang.carver.arch import CUDA - from tilelang.carver.arch import CDNA - from tilelang.carver.roller.rasterization import NoRasterization - import torch - - arch = CUDA("cuda") if torch.version.hip is None else CDNA("hip") - topk = 10 - - carve_template = MatmulTemplate( - M=M, - N=N, - K=K, - in_dtype=T.float16, - out_dtype=T.float16, - accum_dtype=T.float16, - ).with_arch(arch) - - func = carve_template.equivalent_function() - assert func is not None, "Function is None" - - roller_hints = carve_template.recommend_hints(topk=topk) - - if roller_hints is None: - raise ValueError("No Roller Hints Found for TensorCore Scheduling") - - configs = [] - for hint in roller_hints: - config = {} - block_m, block_n = hint.block - warp_m, warp_n = hint.warp - config["block_row_warps"] = block_m // warp_m - config["block_col_warps"] = block_n // warp_n - config["warp_row_tiles"] = warp_m - config["warp_col_tiles"] = warp_n - config["chunk"] = hint.rstep[0] - config["stage"] = hint.pipeline_stage - config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization - configs.append(config) - for config in configs: - print(config) - else: - iter_params = dict( - block_row_warps=[1, 2, 4], - block_col_warps=[1, 2, 4], - warp_row_tiles=[16, 32, 64, 128], - warp_col_tiles=[16, 32, 64, 128], - chunk=[32, 64, 128, 256], - stage=[0, 2], - enable_rasteration=[True, False], - ) - return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())] - - return configs - - -@autotune( - configs=get_configs, - warmup=3, - rep=5, - ref_prog=ref_program, - skip_check=True, -) -@tl.jit( - out_idx=[2], -) -def matmul( - M, - N, - K, - in_dtype=T.float16, - out_dtype=T.float16, - accum_dtype=T.float16, - with_roller=False, - block_row_warps=None, - block_col_warps=None, - warp_row_tiles=None, - warp_col_tiles=None, - chunk=None, - stage=None, - enable_rasteration=None, -): - """Create an autotuned tensor core matrix multiplication kernel.""" - - def kernel(): - return tl_matmul( - M, - N, - K, - in_dtype=in_dtype, - out_dtype=out_dtype, - accum_dtype=accum_dtype, - block_row_warps=block_row_warps, - block_col_warps=block_col_warps, - warp_row_tiles=warp_row_tiles, - warp_col_tiles=warp_col_tiles, - chunk=chunk, - stage=stage, - enable_rasteration=enable_rasteration, - ) - - return kernel() - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Autotuned TensorCore MatMul Benchmark") - parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M") - parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N") - parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K") - parser.add_argument("--with_roller", type=bool, default=False, help="Whether to use roller to deduce search spaces") - parser.add_argument("--dtype", type=str, default="float16", choices=["float16", "int8"], help="Input data type") - args = parser.parse_args() - - M, N, K = args.m, args.n, args.k - in_dtype = T.dtype(args.dtype) - out_dtype = T.float32 if in_dtype == T.int8 else T.float16 - accum_dtype = T.float32 if in_dtype == T.int8 else T.float16 - with_roller = args.with_roller - with_roller = True - # Compute total floating-point operations - total_flops = 2 * M * N * K - - # Run autotuning - best_result = matmul(M, N, K, in_dtype, out_dtype, accum_dtype, with_roller) - best_latency = best_result.latency - best_config = best_result.config - ref_latency = best_result.ref_latency - - # Print benchmark results - print(f"Best latency (s): {best_latency}") - print(f"Best TFlops: {total_flops / best_latency * 1e-9:.3f}") - print(f"Best config: {best_config}") - print(f"Reference TFlops: {total_flops / ref_latency * 1e-9:.3f}") diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py b/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py deleted file mode 100644 index 2085ee8924..0000000000 --- a/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py +++ /dev/null @@ -1,248 +0,0 @@ -import torch -from tilelang import tvm as tvm -import tilelang.testing -from tvm import DataType -import tilelang.language as T -from tilelang.cuda.intrinsics import get_swizzle_layout -from tilelang.cuda.intrinsics.macro.mma_macro_generator import TensorCoreIntrinEmitter -from tilelang.rocm.intrinsics.mfma_macro_generator import MatrixCoreIntrinEmitter -from tilelang.utils import determine_fp8_type - -tilelang.testing.set_random_seed(0) - - -def make_swizzle_layout(shared_buf): - dtype = shared_buf.dtype - shape = shared_buf.shape - - can_swizzle = shape[-1] * DataType(dtype).bits == 512 - if not can_swizzle: - return T.Layout(shape, lambda *args: args) - - def transform_func(i, j): - new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype) - return [new_warp_i, new_warp_j] - - return T.Layout(shape, transform_func) - - -@tilelang.jit(out_idx=[2]) -def tl_matmul( - M, - N, - K, - in_dtype, - out_dtype, - accum_dtype, -): - assert in_dtype in [ - T.float16, - T.float8_e4m3fn, - T.float8_e4m3fnuz, - T.float8_e5m2, - T.float8_e5m2fnuz, - T.int8, - ], "Currently only float16, float8, and int8 are supported" - assert out_dtype in [ - T.float16, - T.float32, - T.int32, - ], "Currently only float16, float32 and int32 are supported" - - # This is a debug config - block_row_warps = 2 - block_col_warps = 2 - warp_row_tiles = 32 - warp_col_tiles = 32 - chunk = 32 if in_dtype == T.float16 else 64 - shared_scope = "shared.dyn" - - # Pipeline Stage - stage = 2 - - block_M = block_row_warps * warp_row_tiles - block_N = block_col_warps * warp_col_tiles - block_K = chunk - - A_shape = (M, K) - B_shape = (N, K) - A_shared_shape = (block_M, block_K) - B_shared_shape = (block_N, block_K) - is_hip = torch.version.hip is not None - # MMA Wrapper to Auto Generate Code for MMA/MFMA - if is_hip: - mma_emitter = MatrixCoreIntrinEmitter( - a_dtype=in_dtype, - b_dtype=in_dtype, - accum_dtype=accum_dtype, - a_transposed=False, - b_transposed=True, - block_row_warps=block_row_warps, - block_col_warps=block_col_warps, - warp_row_tiles=warp_row_tiles, - warp_col_tiles=warp_col_tiles, - chunk=chunk, - ) - else: - mma_emitter = TensorCoreIntrinEmitter( - a_dtype=in_dtype, - b_dtype=in_dtype, - accum_dtype=accum_dtype, - a_transposed=False, - b_transposed=True, - block_row_warps=block_row_warps, - block_col_warps=block_col_warps, - warp_row_tiles=warp_row_tiles, - warp_col_tiles=warp_col_tiles, - chunk=chunk, - ) - - micro_size_x = mma_emitter.M_DIM - micro_size_y = getattr(mma_emitter, "n_dim", getattr(mma_emitter, "N_DIM", micro_size_x)) - micro_size_k = mma_emitter.k_dim - C_shared_shape = ( - block_M // micro_size_x, - block_N // micro_size_y, - micro_size_x, - micro_size_y, - ) - - threads = mma_emitter.threads - local_size_a = mma_emitter.local_size_a - local_size_b = mma_emitter.local_size_b - local_size_c = mma_emitter.local_size_out - warp_rows = mma_emitter.warp_rows - warp_cols = mma_emitter.warp_cols - - @T.prim_func - def gemm_fp8_intrinsic( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) - B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) - C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) - A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) - B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) - C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) - - T.annotate_layout( - { - A_shared: make_swizzle_layout(A_shared), - B_shared: make_swizzle_layout(B_shared), - } - ) - - # Improve L2 Cache - T.use_swizzle(panel_size=10) - - T.clear(C_local) - - for ko in T.Pipelined((K // block_K), num_stages=stage): - # Load A into shared memory - for i, k in T.Parallel(block_M, block_K): - A_shared[i, k] = A[by * block_M + i, ko * block_K + k] - - # Load B into shared memory - for j, k in T.Parallel(block_N, block_K): - B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] - - for ki in T.serial(0, (block_K // micro_size_k)): - # Load A into fragment - mma_emitter.ldmatrix_a( - A_local, - A_shared, - ki, - ) - - # Load B into fragment - mma_emitter.ldmatrix_b( - B_local, - B_shared, - ki, - ) - - # Perform Matrix Multiplication - if is_hip: - mma_emitter.mfma(A_local, B_local, C_local, ki) - else: - mma_emitter.mma(A_local, B_local, C_local) - - # Perform STMatrix - mma_emitter.stmatrix( - C_local, - C_shared, - ) - - # Store shared into global - for i, j in T.Parallel(block_M, block_N): - C[by * block_M + i, bx * block_N + j] = C_shared[ - i // micro_size_x, - j // micro_size_y, - i % micro_size_x, - j % micro_size_y, - ] - - return gemm_fp8_intrinsic - - -def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): - kernel = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) - src_code = kernel.get_kernel_source() - # src_code is the generated cuda source - assert src_code is not None - - in_dtype = in_dtype.as_torch() - out_dtype = out_dtype.as_torch() - accum_dtype = accum_dtype.as_torch() - - if in_dtype in {torch.int8, torch.int32}: - A = torch.randint(-128, 128, (M, K), dtype=torch.int8).to(in_dtype).cuda() - B = torch.randint(-128, 128, (N, K), dtype=torch.int8).to(in_dtype).cuda() - elif in_dtype in {torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz}: - A = torch.randn(M, K).to(in_dtype).cuda() - B = torch.randn(N, K).to(in_dtype).cuda() - else: - A = torch.randn(M, K).to(in_dtype).cuda() - 0.5 - B = torch.randn(N, K).to(in_dtype).cuda() - 0.5 - - C = torch.zeros(M, N, device="cuda", dtype=accum_dtype) - - profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer) - - C = profiler(A, B) - - latency = profiler.do_bench(warmup=25) - - # Ensure that the latency is not None - assert latency is not None - - # Get Reference Result - ref_c = torch.matmul(A.to(accum_dtype), B.T.to(accum_dtype)).to(out_dtype) - torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) - - -def main(): - e4m3_dtype = determine_fp8_type() - assert_tl_matmul_correctness(128, 128, 128, e4m3_dtype, T.float32, T.float32) - e5m2_dtype = determine_fp8_type("e5m2") - assert_tl_matmul_correctness(128, 128, 128, e5m2_dtype, T.float32, T.float32) - - -def run_regression_perf(): - M, N, K = 4096, 4096, 4096 - out_dtype, accum_dtype = T.float32, T.float32 - in_dtype = determine_fp8_type() - kernel_e4m3 = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) - profiler_e4m3 = kernel_e4m3.get_profiler(tilelang.TensorSupplyType.Integer) - if torch.version.hip is None: - latency_e4m3 = profiler_e4m3.do_bench(backend="cupti") - else: - latency_e4m3 = profiler_e4m3.do_bench() - return latency_e4m3 - - -if __name__ == "__main__": - main() diff --git a/examples/gemm_fp8/regression_example_gemm_fp8.py b/examples/gemm_fp8/regression_example_gemm_fp8.py index 3ba2f4f274..5bf0c80505 100644 --- a/examples/gemm_fp8/regression_example_gemm_fp8.py +++ b/examples/gemm_fp8/regression_example_gemm_fp8.py @@ -1,17 +1,12 @@ import tilelang.testing import example_tilelang_gemm_fp8 import example_tilelang_gemm_fp8_2xAcc -import example_tilelang_gemm_fp8_intrinsic def regression_example_tilelang_gemm_fp8_2xAcc(): tilelang.testing.process_func(example_tilelang_gemm_fp8_2xAcc.run_regression_perf) -def regression_example_tilelang_gemm_fp8_intrinsic(): - tilelang.testing.process_func(example_tilelang_gemm_fp8_intrinsic.run_regression_perf) - - def regression_example_tilelang_gemm_fp8(): tilelang.testing.process_func(example_tilelang_gemm_fp8.run_regression_perf) diff --git a/examples/gemm_fp8/test_example_gemm_fp8.py b/examples/gemm_fp8/test_example_gemm_fp8.py index 19a9ee00a7..3b657d72ae 100644 --- a/examples/gemm_fp8/test_example_gemm_fp8.py +++ b/examples/gemm_fp8/test_example_gemm_fp8.py @@ -1,6 +1,5 @@ import tilelang.testing import example_tilelang_gemm_fp8_2xAcc -import example_tilelang_gemm_fp8_intrinsic import example_tilelang_gemm_fp8 @@ -8,10 +7,6 @@ def test_example_tilelang_gemm_fp8_2xAcc(): example_tilelang_gemm_fp8_2xAcc.main() -def test_example_tilelang_gemm_fp8_intrinsic(): - example_tilelang_gemm_fp8_intrinsic.main() - - def test_example_tilelang_gemm_fp8(): example_tilelang_gemm_fp8.main() diff --git a/testing/python/kernel/test_tilelang_kernel_element_wise_add.py b/testing/python/kernel/test_tilelang_kernel_element_wise_add.py deleted file mode 100644 index 501b38fda8..0000000000 --- a/testing/python/kernel/test_tilelang_kernel_element_wise_add.py +++ /dev/null @@ -1,109 +0,0 @@ -import tilelang.testing -from tilelang import language as T -import torch - - -def elementwise_add( - M, - N, - block_M, - block_N, - in_dtype, - out_dtype, - threads, -): - @T.prim_func - def main( - A: T.Tensor((M, N), in_dtype), - B: T.Tensor((M, N), in_dtype), - C: T.Tensor((M, N), out_dtype), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - start_x = bx * block_N - start_y = by * block_M - - for local_y, local_x in T.Parallel(block_M, block_N): - y = start_y + local_y - x = start_x + local_x - - C[y, x] = A[y, x] + B[y, x] - - return main - - -def run_elementwise_add( - M, - N, - in_dtype, - out_dtype, - block_M, - block_N, - num_threads=128, -): - program = elementwise_add( - M, - N, - block_M, - block_N, - in_dtype, - out_dtype, - num_threads, - ) - - kernel = tilelang.compile(program, out_idx=[2]) - profiler = kernel.get_profiler() - - def ref_program(A, B): - C = torch.add(A, B) - C = C.to(torch.__getattribute__(out_dtype)) - return C - - profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) - - -def test_elementwise_add_f32(): - run_elementwise_add( - 512, - 1024, - T.float32, - T.float32, - 128, - 256, - ) - - -def test_elementwise_add_f16(): - run_elementwise_add( - 512, - 1024, - T.float16, - T.float16, - 128, - 256, - ) - - -def test_elementwise_add_i32(): - run_elementwise_add( - 512, - 1024, - T.int32, - T.int32, - 128, - 256, - ) - - -def test_elementwise_add_f32f16(): - run_elementwise_add( - 512, - 1024, - T.float32, - T.float16, - 128, - 256, - ) - - -if __name__ == "__main__": - tilelang.testing.main() diff --git a/testing/python/kernel/test_tilelang_kernel_fp8_gemm_mma.py b/testing/python/kernel/test_tilelang_kernel_fp8_gemm_mma.py deleted file mode 100644 index ae728854a0..0000000000 --- a/testing/python/kernel/test_tilelang_kernel_fp8_gemm_mma.py +++ /dev/null @@ -1,228 +0,0 @@ -import torch -import torch.backends -from tilelang import tvm as tvm -import tilelang.testing -from tvm import DataType -import tilelang.language as T -from tilelang.cuda.intrinsics import get_swizzle_layout -from tilelang.cuda.intrinsics.macro.mma_macro_generator import ( - TensorCoreIntrinEmitter, -) -from tilelang.transform import simplify_prim_func - -tilelang.testing.set_random_seed(0) - - -def make_swizzle_layout(shared_buf): - dtype = shared_buf.dtype - shape = shared_buf.shape - - can_swizzle = shape[-1] * DataType(dtype).bits == 512 - if not can_swizzle: - return T.Layout(shape, lambda *args: args) - - def transform_func(i, j): - new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype) - return [new_warp_i, new_warp_j] - - return T.Layout(shape, transform_func) - - -@simplify_prim_func -def tl_matmul( - M, - N, - K, - in_dtype, - out_dtype, - accum_dtype, -): - assert in_dtype in [ - T.float16, - T.float8_e4m3fn, - T.float8_e5m2, - T.int8, - ], "Currently only float16 and int8 are supported" - assert out_dtype in [ - T.float16, - T.float32, - T.int32, - ], "Currently only float16, float32 and int32 are supported" - - micro_size_x = micro_size_y = micro_size_k = 16 - - is_float8 = in_dtype in [ - T.float8_e4m3fn, - T.float8_e5m2, - T.float8_e4m3fn, - T.float8_e5m2fnuz, - ] - if out_dtype == T.int32 or is_float8: - micro_size_k = 32 - - # This is a debug config - block_row_warps = 2 - block_col_warps = 2 - warp_row_tiles = 32 - warp_col_tiles = 32 - chunk = 32 if in_dtype == T.float16 else 64 - shared_scope = "shared.dyn" - - # Pipeline Stage - stage = 2 - - block_M = block_row_warps * warp_row_tiles - block_N = block_col_warps * warp_col_tiles - block_K = chunk - - A_shape = (M, K) - B_shape = (N, K) - A_shared_shape = (block_M, block_K) - B_shared_shape = (block_N, block_K) - C_shared_shape = ( - block_M // micro_size_x, - block_N // micro_size_y, - micro_size_x, - micro_size_y, - ) - - warp_size = 32 - threads = warp_size * (block_row_warps * block_col_warps) - local_size_a = (micro_size_x * micro_size_k) // warp_size - local_size_b = (micro_size_y * micro_size_k) // warp_size - local_size_c = (micro_size_x * micro_size_y) // warp_size - warp_rows = warp_row_tiles // micro_size_x - warp_cols = warp_col_tiles // micro_size_y - - # MMA Wrapper to Auto Generate Code for MMA - mma_emitter = TensorCoreIntrinEmitter( - a_dtype=in_dtype, - b_dtype=in_dtype, - accum_dtype=accum_dtype, - a_transposed=False, - b_transposed=True, - block_row_warps=block_row_warps, - block_col_warps=block_col_warps, - warp_row_tiles=warp_row_tiles, - warp_col_tiles=warp_col_tiles, - chunk=chunk, - ) - - @T.prim_func - def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) - B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) - C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) - A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) - B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) - C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) - - T.annotate_layout( - { - A_shared: make_swizzle_layout(A_shared), - B_shared: make_swizzle_layout(B_shared), - } - ) - - # Improve L2 Cache - T.use_swizzle(panel_size=10) - - T.clear(C_local) - - for ko in T.Pipelined((K // block_K), num_stages=stage): - # Load A into shared memory - for i, k in T.Parallel(block_M, block_K): - A_shared[i, k] = A[by * block_M + i, ko * block_K + k] - - # Load B into shared memory - for j, k in T.Parallel(block_N, block_K): - B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] - - for ki in T.serial(0, (block_K // micro_size_k)): - # Load A into fragment - mma_emitter.ldmatrix_a( - A_local, - A_shared, - ki, - ) - - # Load B into fragment - mma_emitter.ldmatrix_b( - B_local, - B_shared, - ki, - ) - - # Perform Matrix Multiplication - mma_emitter.mma(A_local, B_local, C_local) - - # Perform STMatrix - mma_emitter.stmatrix( - C_local, - C_shared, - ) - - # Store shared into global - for i, j in T.Parallel(block_M, block_N): - C[by * block_M + i, bx * block_N + j] = C_shared[ - i // micro_size_x, - j // micro_size_y, - i % micro_size_x, - j % micro_size_y, - ] - - return main - - -def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): - matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) - kernel = tilelang.compile(matmul, out_idx=[2]) - profiler = kernel.get_profiler() - - src_code = kernel.get_kernel_source() - print(src_code) - # src_code is the generated cuda source - assert src_code is not None - - in_dtype = T.dtype(in_dtype).as_torch() - out_dtype = T.dtype(out_dtype).as_torch() - accum_dtype = T.dtype(accum_dtype).as_torch() - - if in_dtype in {torch.int8, torch.int32}: - A = torch.randint(-128, 128, (M, K), dtype=torch.int8).to(in_dtype).cuda() - B = torch.randint(-128, 128, (N, K), dtype=torch.int8).to(in_dtype).cuda() - elif in_dtype in {torch.float8_e4m3fn, torch.float8_e5m2}: - A = torch.randn(M, K).to(in_dtype).cuda() - B = torch.randn(N, K).to(in_dtype).cuda() - else: - A = torch.randn(M, K).to(in_dtype).cuda() - 0.5 - B = torch.randn(N, K).to(in_dtype).cuda() - 0.5 - - C = kernel(A, B) - - latency = profiler.do_bench() - - # Ensure that the latency is not None - assert latency is not None - - # Get Reference Result - ref_c = torch.matmul(A.to(accum_dtype), B.T.to(accum_dtype)).to(out_dtype) - print(C) - print(ref_c) - torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) - - -@tilelang.testing.requires_cuda -@tilelang.testing.requires_cuda_compute_version(8, 9) -def test_assert_tl_matmul(): - assert_tl_matmul_correctness(128, 128, 128, T.float8_e4m3fn, T.float32, T.float32) - assert_tl_matmul_correctness(128, 128, 128, T.float8_e5m2, T.float32, T.float32) - - -if __name__ == "__main__": - tilelang.testing.main() diff --git a/testing/python/kernel/test_tilelang_kernel_gemm_mma_intrinsic.py b/testing/python/kernel/test_tilelang_kernel_gemm_mma_intrinsic.py deleted file mode 100644 index 76a6e3d610..0000000000 --- a/testing/python/kernel/test_tilelang_kernel_gemm_mma_intrinsic.py +++ /dev/null @@ -1,240 +0,0 @@ -import torch -import torch.backends -from tilelang import tvm as tvm -import tilelang.testing -from tvm import DataType -import tilelang.language as T -from tilelang.cuda.intrinsics import get_swizzle_layout -from tilelang.cuda.intrinsics.macro.mma_macro_generator import ( - TensorCoreIntrinEmitter, -) -from tilelang.transform import simplify_prim_func - -tilelang.testing.set_random_seed(0) - - -def make_swizzle_layout(shared_buf): - dtype = shared_buf.dtype - shape = shared_buf.shape - - can_swizzle = shape[-1] * DataType(dtype).bits == 512 - if not can_swizzle: - return T.Layout(shape, lambda *args: args) - - def transform_func(i, j): - new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype) - return [new_warp_i, new_warp_j] - - return T.Layout(shape, transform_func) - - -@simplify_prim_func -def tl_matmul( - M, - N, - K, - in_dtype, - out_dtype, - accum_dtype, -): - assert in_dtype in [ - T.float16, - T.bfloat16, - T.float8_e4m3fn, - T.float8_e5m2, - T.int8, - ], "Currently only float16 and int8 are supported" - assert out_dtype in [ - T.float16, - T.float32, - T.int32, - ], "Currently only float16, float32 and int32 are supported" - - micro_size_x = micro_size_y = micro_size_k = 16 - - is_float8 = in_dtype in [ - T.float8_e4m3fn, - T.float8_e5m2, - T.float8_e4m3fn, - T.float8_e5m2fnuz, - ] - if out_dtype == T.int32 or is_float8: - micro_size_k = 32 - - # This is a debug config - block_row_warps = 2 - block_col_warps = 2 - warp_row_tiles = 32 - warp_col_tiles = 32 - chunk = 32 if in_dtype == T.float16 else 64 - shared_scope = "shared.dyn" - - # Pipeline Stage - stage = 2 - - block_M = block_row_warps * warp_row_tiles - block_N = block_col_warps * warp_col_tiles - block_K = chunk - - A_shape = (M, K) - B_shape = (N, K) - A_shared_shape = (block_M, block_K) - B_shared_shape = (block_N, block_K) - C_shared_shape = ( - block_M // micro_size_x, - block_N // micro_size_y, - micro_size_x, - micro_size_y, - ) - - warp_size = 32 - threads = warp_size * (block_row_warps * block_col_warps) - local_size_a = (micro_size_x * micro_size_k) // warp_size - local_size_b = (micro_size_y * micro_size_k) // warp_size - local_size_c = (micro_size_x * micro_size_y) // warp_size - warp_rows = warp_row_tiles // micro_size_x - warp_cols = warp_col_tiles // micro_size_y - - # MMA Wrapper to Auto Generate Code for MMA - mma_emitter = TensorCoreIntrinEmitter( - a_dtype=in_dtype, - b_dtype=in_dtype, - accum_dtype=accum_dtype, - a_transposed=False, - b_transposed=True, - block_row_warps=block_row_warps, - block_col_warps=block_col_warps, - warp_row_tiles=warp_row_tiles, - warp_col_tiles=warp_col_tiles, - chunk=chunk, - ) - - @T.prim_func - def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) - B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) - C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) - A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) - B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) - C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) - - T.annotate_layout( - { - A_shared: make_swizzle_layout(A_shared), - B_shared: make_swizzle_layout(B_shared), - } - ) - - # Improve L2 Cache - T.use_swizzle(panel_size=10) - - T.clear(C_local) - - for ko in T.Pipelined((K // block_K), num_stages=stage): - # Load A into shared memory - for i, k in T.Parallel(block_M, block_K): - A_shared[i, k] = A[by * block_M + i, ko * block_K + k] - - # Load B into shared memory - for j, k in T.Parallel(block_N, block_K): - B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] - - for ki in T.serial(0, (block_K // micro_size_k)): - # Load A into fragment - mma_emitter.ldmatrix_a( - A_local, - A_shared, - ki, - ) - - # Load B into fragment - mma_emitter.ldmatrix_b( - B_local, - B_shared, - ki, - ) - - # Perform Matrix Multiplication - mma_emitter.mma(A_local, B_local, C_local) - - # Perform STMatrix - mma_emitter.stmatrix( - C_local, - C_shared, - ) - - # Store shared into global - for i, j in T.Parallel(block_M, block_N): - C[by * block_M + i, bx * block_N + j] = C_shared[ - i // micro_size_x, - j // micro_size_y, - i % micro_size_x, - j % micro_size_y, - ] - - return main - - -def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): - matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) - kernel = tilelang.compile(matmul, out_idx=[2]) - profiler = kernel.get_profiler() - - src_code = kernel.get_kernel_source() - # src_code is the generated cuda source - assert src_code is not None - - in_dtype = T.dtype(in_dtype).as_torch() - out_dtype = T.dtype(out_dtype).as_torch() - accum_dtype = T.dtype(accum_dtype).as_torch() - - if in_dtype in {torch.int8, torch.int32}: - A = torch.randint(-128, 128, (M, K), dtype=torch.int8).to(in_dtype).cuda() - B = torch.randint(-128, 128, (N, K), dtype=torch.int8).to(in_dtype).cuda() - elif in_dtype in {torch.float8_e4m3fn, torch.float8_e5m2}: - A = torch.randn(M, K).to(in_dtype).cuda() - B = torch.randn(N, K).to(in_dtype).cuda() - else: - A = torch.randn(M, K).to(in_dtype).cuda() - 0.5 - B = torch.randn(N, K).to(in_dtype).cuda() - 0.5 - - C = kernel(A, B) - - latency = profiler.do_bench() - - # Ensure that the latency is not None - assert latency is not None - - # Get Reference Result - ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(out_dtype) - tilelang.testing.torch_assert_close(C, ref_c, rtol=1e-2, atol=1e-2) - - -@tilelang.testing.requires_cuda -@tilelang.testing.requires_cuda_compute_version(8, 0) -def test_assert_tl_matmul(): - assert_tl_matmul_correctness(128, 128, 128, T.float16, T.float16, T.float16) - assert_tl_matmul_correctness(128, 256, 256, T.float16, T.float32, T.float32) - assert_tl_matmul_correctness(128, 256, 256, T.int8, T.int32, T.int32) - - -@tilelang.testing.requires_cuda -@tilelang.testing.requires_cuda_compute_version(8, 0) -def test_assert_tl_matmul_bfloat16(): - assert_tl_matmul_correctness(256, 256, 256, T.bfloat16, T.float32, T.float32) - - -@tilelang.testing.requires_cuda -@tilelang.testing.requires_cuda_compute_version(8, 9) -def test_assert_tl_matmul_fp8(): - assert_tl_matmul_correctness(128, 128, 128, T.float8_e4m3fn, T.float32, T.float32) - assert_tl_matmul_correctness(128, 128, 128, T.float8_e5m2, T.float32, T.float32) - - -if __name__ == "__main__": - tilelang.testing.main() From 9752f55b143c22b050242266ebf307f9c5523b14 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Fri, 8 May 2026 16:28:12 +0800 Subject: [PATCH 4/4] Move CUDA transform passes back to common transform --- src/backend/cuda/CMakeLists.txt | 1 - src/{backend/cuda => }/transform/lower_hopper_intrin.cc | 2 +- .../cuda => }/transform/lower_l2_persistent_annotation.cc | 3 +-- src/{backend/cuda => }/transform/persist_threadblock.cc | 3 +-- 4 files changed, 3 insertions(+), 6 deletions(-) rename src/{backend/cuda => }/transform/lower_hopper_intrin.cc (99%) rename src/{backend/cuda => }/transform/lower_l2_persistent_annotation.cc (97%) rename src/{backend/cuda => }/transform/persist_threadblock.cc (94%) diff --git a/src/backend/cuda/CMakeLists.txt b/src/backend/cuda/CMakeLists.txt index 8868d37669..5918282457 100644 --- a/src/backend/cuda/CMakeLists.txt +++ b/src/backend/cuda/CMakeLists.txt @@ -146,7 +146,6 @@ file(GLOB TILE_LANG_CUDA_SRCS src/backend/cuda/codegen/rt_mod_cuda.cc src/backend/cuda/codegen/rt_mod_cutedsl.cc src/backend/cuda/op/*.cc - src/backend/cuda/transform/*.cc ) list(REMOVE_ITEM TILE_LANG_CUDA_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/src/backend/cuda/op/copy_analysis.cc") diff --git a/src/backend/cuda/transform/lower_hopper_intrin.cc b/src/transform/lower_hopper_intrin.cc similarity index 99% rename from src/backend/cuda/transform/lower_hopper_intrin.cc rename to src/transform/lower_hopper_intrin.cc index f70f719cf3..e9ea2cdbc4 100644 --- a/src/backend/cuda/transform/lower_hopper_intrin.cc +++ b/src/transform/lower_hopper_intrin.cc @@ -1,5 +1,5 @@ /*! - * \file tl/backend/cuda/transform/lower_hopper_intrin.cc + * \file tl/transform/lower_hopper_intrin.cc * \brief Lower Hopper intrinsics cuda GPU(sm90+) */ diff --git a/src/backend/cuda/transform/lower_l2_persistent_annotation.cc b/src/transform/lower_l2_persistent_annotation.cc similarity index 97% rename from src/backend/cuda/transform/lower_l2_persistent_annotation.cc rename to src/transform/lower_l2_persistent_annotation.cc index 6b3a9b612b..5f9f44a5c2 100644 --- a/src/backend/cuda/transform/lower_l2_persistent_annotation.cc +++ b/src/transform/lower_l2_persistent_annotation.cc @@ -1,5 +1,5 @@ /*! - * \file tl/backend/cuda/transform/lower_l2_persistent_annotation.cc + * \file tl/transform/lower_l2_persistent_annotation.cc * \brief Lower L2 persistent annotation */ @@ -9,7 +9,6 @@ #include #include -#include "backend/cuda/runtime.h" #include "op/builtin.h" namespace tvm { diff --git a/src/backend/cuda/transform/persist_threadblock.cc b/src/transform/persist_threadblock.cc similarity index 94% rename from src/backend/cuda/transform/persist_threadblock.cc rename to src/transform/persist_threadblock.cc index 4a0a09ecc9..d9183d1e2b 100644 --- a/src/backend/cuda/transform/persist_threadblock.cc +++ b/src/transform/persist_threadblock.cc @@ -1,5 +1,5 @@ /*! - * \file tl/backend/cuda/transform/persist_threadblock.cc + * \file tl/transform/persist_threadblock.cc * \brief Persist thread blocks with cooperative groups. */ @@ -9,7 +9,6 @@ #include #include -#include "backend/cuda/runtime.h" #include "op/builtin.h" namespace tvm {