Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/backend/cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ if(TILELANG_USE_CUDA_STUBS)
endif()

file(GLOB TILE_LANG_CUDA_SRCS
src/runtime/runtime.cc
src/backend/cuda/runtime.cc
src/backend/cuda/codegen/ptx.cc
src/backend/cuda/codegen/codegen_cuda.cc
src/backend/cuda/codegen/codegen_py.cc
Expand All @@ -146,6 +146,7 @@ file(GLOB TILE_LANG_CUDA_SRCS
src/backend/cuda/codegen/rt_mod_cuda.cc
src/backend/cuda/codegen/rt_mod_cutedsl.cc
src/backend/cuda/op/*.cc
src/backend/cuda/transform/*.cc
)
list(REMOVE_ITEM TILE_LANG_CUDA_SRCS
"${CMAKE_CURRENT_SOURCE_DIR}/src/backend/cuda/op/copy_analysis.cc")
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/runtime.cc → src/backend/cuda/runtime.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*!
* \file tl/runtime/runtime.h
* \file tl/backend/cuda/runtime.cc
* \brief Runtime functions.
*
*/
Expand Down
8 changes: 4 additions & 4 deletions src/runtime/runtime.h → src/backend/cuda/runtime.h
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
/*!
* \file tl/runtime/runtime.h
* \file tl/backend/cuda/runtime.h
* \brief Runtime functions.
*
*/

#ifndef TVM_TL_RUNTIME_RUNTIME_H_
#define TVM_TL_RUNTIME_RUNTIME_H_
#ifndef TVM_TL_BACKEND_CUDA_RUNTIME_H_
#define TVM_TL_BACKEND_CUDA_RUNTIME_H_

namespace tvm {
namespace tl {
Expand All @@ -25,4 +25,4 @@ constexpr const char *tvm_cuda_stream_reset_access_policy_window =
} // namespace tl
} // namespace tvm

#endif // TVM_TL_RUNTIME_RUNTIME_H_
#endif // TVM_TL_BACKEND_CUDA_RUNTIME_H_
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*!
* \file lower hopper intrin.cc
* \file tl/backend/cuda/transform/lower_hopper_intrin.cc
* \brief Lower Hopper intrinsics cuda GPU(sm90+)
*/

Expand All @@ -13,8 +13,8 @@
#include <unordered_map>
#include <vector>

#include "../op/builtin.h"
#include "../runtime/runtime.h"
#include "backend/cuda/runtime.h"
#include "op/builtin.h"

namespace tvm {
namespace tl {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*!
* \file lower_l2_persistent_annotation.cc
* \file tl/backend/cuda/transform/lower_l2_persistent_annotation.cc
* \brief Lower L2 persistent annotation
*/

Expand All @@ -9,8 +9,8 @@
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include "../op/builtin.h"
#include "../runtime/runtime.h"
#include "backend/cuda/runtime.h"
#include "op/builtin.h"

namespace tvm {
namespace tl {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*!
* \file lower_l2_persistent_annotation.cc
* \brief Lower L2 persistent annotation
* \file tl/backend/cuda/transform/persist_threadblock.cc
* \brief Persist thread blocks with cooperative groups.
*/

#include <tvm/ffi/reflection/registry.h>
Expand All @@ -9,8 +9,8 @@
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include "../op/builtin.h"
#include "../runtime/runtime.h"
#include "backend/cuda/runtime.h"
#include "op/builtin.h"

namespace tvm {
namespace tl {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from tilelang import tvm as tvm
import tilelang as tl
from tilelang.backend.cuda import transform as cuda_transform
from tilelang.utils.target import determine_target
import tilelang.language as T
import tilelang.testing
Expand All @@ -12,7 +13,7 @@ def _check(original, transformed):
func = original
mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main"))
mod = tvm.tir.transform.BindTarget(auto_target)(mod)
mod = tl.transform.LowerHopperIntrin()(mod)
mod = cuda_transform.LowerHopperIntrin()(mod)
mod = tir.transform.LowerOpaqueBlock()(mod)
transformed = tvm.IRModule.from_expr(transformed.with_attr("global_symbol", "main"))
transformed = tvm.tir.transform.BindTarget(auto_target)(transformed)
Expand Down Expand Up @@ -90,7 +91,7 @@ def before():

mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main"))
mod = tvm.tir.transform.BindTarget(auto_target)(mod)
mod = tl.transform.LowerHopperIntrin()(mod)
mod = cuda_transform.LowerHopperIntrin()(mod)
func = mod["main"]

assert not tvm.tir.analysis.undefined_vars(func.body, func.params)
Expand Down
4 changes: 2 additions & 2 deletions tilelang/backend/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .gemm import register_gemm_impl, resolve_gemm_impl # noqa: F401
from .gemm_sp import register_gemm_sp_impl, resolve_gemm_sp_impl # noqa: F401
from tilelang.tileop.gemm.registry import register_gemm_impl, resolve_gemm_impl # noqa: F401
from tilelang.tileop.gemm_sp.registry import register_gemm_sp_impl, resolve_gemm_sp_impl # noqa: F401

# Import built-in backend packages so their implementations register.
from . import cpu as _cpu # noqa: F401,E402
Expand Down
2 changes: 1 addition & 1 deletion tilelang/backend/cpu/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from . import gemm # noqa: F401
from . import op # noqa: F401
1 change: 1 addition & 0 deletions tilelang/backend/cpu/op/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from . import gemm # noqa: F401
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from tilelang.backend.gemm import register_gemm_impl
from tilelang.tileop.gemm.gemm_scalar import GEMM_INST_SCALAR, GemmScalar
from tilelang.tileop.gemm.registry import register_gemm_impl
from .gemm_scalar import GEMM_INST_SCALAR, GemmScalar


def _match_scalar(target) -> bool:
Expand Down
4 changes: 2 additions & 2 deletions tilelang/backend/cuda/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from . import gemm # noqa: F401
from . import gemm_sp # noqa: F401
from . import op # noqa: F401
from . import transform # noqa: F401
4 changes: 4 additions & 0 deletions tilelang/backend/cuda/op/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
"""CUDA op registration frontends."""

from . import gemm # noqa: F401
from . import gemm_sp # noqa: F401
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""CUDA GEMM op registrations."""

from __future__ import annotations

from tilelang.backend.gemm import register_gemm_impl
from tilelang.tileop.gemm.gemm_mma import GEMM_INST_MMA, GemmMMA
from tilelang.tileop.gemm.gemm_mma_sm70 import GemmMMASm70
from tilelang.tileop.gemm.gemm_tcgen05 import GEMM_INST_TCGEN05, GemmTCGEN5
from tilelang.tileop.gemm.gemm_wgmma import GEMM_INST_WGMMA, GemmWGMMA
from tilelang.tileop.gemm.registry import register_gemm_impl
from .gemm_mma import GEMM_INST_MMA, GemmMMA
from .gemm_mma_sm70 import GemmMMASm70
from .gemm_tcgen05 import GEMM_INST_TCGEN05, GemmTCGEN5
from .gemm_wgmma import GEMM_INST_WGMMA, GemmWGMMA
from tilelang.utils.target import target_is_cuda, target_is_volta


Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from .gemm_base import GemmBase
from tilelang.tileop.gemm.gemm_base import GemmBase
from tilelang.layout import make_swizzled_layout
from tilelang.intrinsics.mma_macro_generator import (
TensorCoreIntrinEmitter,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

# for Volta GPUs, which use legacy MMA instructions
from .gemm_base import GemmBase
from tilelang.tileop.gemm.gemm_base import GemmBase
from tilelang.layout import make_volta_swizzled_layout
from tilelang.intrinsics.mma_sm70_macro_generator import (
TensorCoreIntrinEmitter,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from .gemm_base import GemmBase
from tilelang.tileop.gemm.gemm_base import GemmBase
from tilelang.layout import (
Layout,
make_full_bank_swizzled_layout,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from .gemm_base import GemmBase
from tilelang.tileop.gemm.gemm_base import GemmBase
from tilelang.layout import (
make_full_bank_swizzled_layout,
make_half_bank_swizzled_layout,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""CUDA sparse GEMM op registrations."""

from __future__ import annotations

from tilelang.backend.gemm_sp import register_gemm_sp_impl
from tilelang.tileop.gemm_sp.gemm_sp_mma import GemmSPMMA
from tilelang.tileop.gemm_sp.registry import register_gemm_sp_impl
from .gemm_sp_mma import GemmSPMMA
from tilelang.utils.target import target_is_cuda


Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .gemm_sp_base import GemmSPBase
from tilelang.tileop.gemm_sp.gemm_sp_base import GemmSPBase
from tilelang.layout import make_swizzled_layout
from tilelang.intrinsics.mma_sp_macro_generator import SparseTensorCoreIntrinEmitter
from tilelang.utils.language import is_shared, is_fragment
Expand Down
27 changes: 27 additions & 0 deletions tilelang/backend/cuda/transform/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""CUDA-specific transformation frontends."""

from tilelang.transform import _ffi_api


def LowerHopperIntrin():
"""LowerHopperIntrin"""
if hasattr(_ffi_api, "LowerHopperIntrin"):
return _ffi_api.LowerHopperIntrin() # type: ignore
return lambda f: f


def LowerL2Persistent():
"""LowerL2Persistent"""
return _ffi_api.LowerL2Persistent() # type: ignore


def PersistThreadblock():
"""PersistThreadblock"""
return _ffi_api.PersistThreadblock() # type: ignore


__all__ = [
"LowerHopperIntrin",
"LowerL2Persistent",
"PersistThreadblock",
]
2 changes: 1 addition & 1 deletion tilelang/backend/rocm/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from . import gemm # noqa: F401
from . import op # noqa: F401
1 change: 1 addition & 0 deletions tilelang/backend/rocm/op/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from . import gemm # noqa: F401
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from __future__ import annotations

from tilelang.backend.gemm import register_gemm_impl
from tilelang.tileop.gemm.gemm_mfma import GEMM_INST_MFMA, GemmMFMA
from tilelang.tileop.gemm.gemm_wmma import GEMM_INST_WMMA, GemmWMMA
from tilelang.tileop.gemm.registry import register_gemm_impl
from .gemm_mfma import GEMM_INST_MFMA, GemmMFMA
from .gemm_wmma import GEMM_INST_WMMA, GemmWMMA
from tilelang.utils.target import target_is_hip


Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from .gemm_base import GemmBase
from tilelang.tileop.gemm.gemm_base import GemmBase
from tilelang.layout import make_swizzled_layout
from tilelang.intrinsics.mfma_macro_generator import (
MatrixCoreIntrinEmitter,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from .gemm_base import GemmBase
from tilelang.tileop.gemm.gemm_base import GemmBase
from tilelang.layout import make_swizzled_layout
from tilelang.intrinsics.wmma_macro_generator import WMMAIntrinEmitter
from tilelang.utils.language import is_shared, is_fragment, is_full_region
Expand Down
7 changes: 4 additions & 3 deletions tilelang/engine/phase.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from tvm import tir, IRModule
from tvm.target import Target
import tilelang
from tilelang.backend.cuda import transform as cuda_transform
from tilelang.transform import PassContext
from tilelang.contrib.nvcc import have_tma, have_pdl

Expand Down Expand Up @@ -204,7 +205,7 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
# Lower high-level tile operations to low-level operations
mod = tilelang.transform.LowerTileOp()(mod)
# Lower l2 persistent map
mod = tilelang.transform.LowerL2Persistent()(mod)
mod = cuda_transform.LowerL2Persistent()(mod)
# Decouple type cast vectorization constraints before vectorization
mod = tilelang.transform.DecoupleTypeCast()(mod)
# Legalize vectorized loops to ensure they are valid
Expand Down Expand Up @@ -270,7 +271,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tir.transform.InferFragment()(mod)
mod = tilelang.transform.LowerThreadAllreduce()(mod)
mod = tilelang.transform.LowerLDGSTG()(mod)
mod = tilelang.transform.LowerHopperIntrin()(mod)
mod = cuda_transform.LowerHopperIntrin()(mod)
# Global Barrier Synchronization must be applied before
# SplitHostDevice pass, as the global barrier
if allow_global_thread_synchronization():
Expand Down Expand Up @@ -305,6 +306,6 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.LowerDeviceKernelLaunch()(mod)

# Transform threadblock to persistent threadblock
mod = tilelang.transform.PersistThreadblock()(mod)
mod = cuda_transform.PersistThreadblock()(mod)

return mod
2 changes: 1 addition & 1 deletion tilelang/tileop/gemm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from tvm.ir import Range
from tvm.runtime import Scriptable
import tvm_ffi
from tilelang.backend.gemm import resolve_gemm_impl
from .registry import resolve_gemm_impl
from tilelang import _ffi_api


Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion tilelang/tileop/gemm_sp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from tvm.ir import Range
from tvm.runtime import Scriptable
import tvm_ffi
from tilelang.backend.gemm_sp import resolve_gemm_sp_impl
from .registry import resolve_gemm_sp_impl
from tilelang.tileop.base import GemmWarpPolicy


Expand Down
File renamed without changes.
21 changes: 0 additions & 21 deletions tilelang/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,17 +121,6 @@ def VerifyParallelLoop():
return _ffi_api.VerifyParallelLoop() # type: ignore


def LowerHopperIntrin():
"""LowerHopperIntrin

Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.LowerHopperIntrin() if hasattr(_ffi_api, "LowerHopperIntrin") else lambda f: f # type: ignore


def ThreadSync(storage_scope: str):
"""Insert sync between parallel read/write of shared buffers.

Expand Down Expand Up @@ -424,21 +413,11 @@ def MergeSharedMemoryAllocations(enable_aggressive_merge: bool = False, align_by
return _ffi_api.MergeSharedMemoryAllocations(enable_aggressive_merge, align_bytes) # type: ignore


def LowerL2Persistent():
"""LowerL2Persistent"""
return _ffi_api.LowerL2Persistent() # type: ignore


def MarkCudaSyncCalls(have_pdl: bool = False):
"""MarkCudaSyncCalls"""
return _ffi_api.MarkCudaSyncCalls(have_pdl) # type: ignore


def PersistThreadblock():
"""PersistThreadblock"""
return _ffi_api.PersistThreadblock() # type: ignore


def LowerSharedBarrier():
"""LowerSharedBarrier"""
return _ffi_api.LowerSharedBarrier() # type: ignore
Expand Down