Skip to content
47 changes: 47 additions & 0 deletions src/tl_templates/hip/common.h
Comment thread
lhl marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,53 @@ TL_DEVICE unsigned __pack_bfloat162(const bfloat16_t x, const bfloat16_t y) {
return (v1 << 16) | v0;
}

// Portable scalar emulation of dp4a: treats `a` and `b` as four packed
// signed 8-bit lanes, multiplies lane-wise, and accumulates into `c`.
// Used on targets that expose neither sudot4 nor sdot4 builtins.
TL_DEVICE int tl_dp4a_fallback(const int a, const int b, int c) {
  int acc = c;
#pragma unroll
  for (int lane = 0; lane < 4; ++lane) {
    const int shift = 8 * lane;
    // static_cast<int8_t> sign-extends the extracted byte.
    const int lhs = static_cast<int8_t>((a >> shift) & 0xff);
    const int rhs = static_cast<int8_t>((b >> shift) & 0xff);
    acc += lhs * rhs;
  }
  return acc;
}

// RDNA3 / RDNA3.5 / RDNA4 targets: these expose the mixed-signedness dot
// builtin (__builtin_amdgcn_sudot4). NOTE(review): list assembled from gfx
// IDs — confirm against compiler feature macros when new targets land.
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \
    defined(__gfx1103__) || defined(__gfx1150__) || defined(__gfx1151__) || \
    defined(__gfx1200__) || defined(__gfx1201__)
#define TL_AMDGPU_HAS_SUDOT4 1
#endif

// Older CDNA/RDNA targets (plus everything that already has sudot4): these
// provide the signed-only dot builtin (__builtin_amdgcn_sdot4).
#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || \
    defined(__gfx950__) || defined(__gfx1010__) || defined(__gfx1011__) || \
    defined(__gfx1012__) || defined(__gfx1030__) || defined(__gfx1031__) || \
    defined(__gfx1032__) || defined(__gfx1034__) || defined(__gfx1035__) || \
    defined(TL_AMDGPU_HAS_SUDOT4)
#define TL_AMDGPU_HAS_SDOT4 1
#endif
Comment thread
coderabbitai[bot] marked this conversation as resolved.

// Four-lane dot product of packed signed 8-bit integers with 32-bit
// accumulate: returns c + sum_i(a.s8[i] * b.s8[i]). Dispatches at compile
// time to the best available hardware builtin, else the scalar fallback.
TL_DEVICE int tl_dp4a(const int a, const int b, const int c) {
#if defined(TL_AMDGPU_HAS_SUDOT4)
  // sudot4: per-operand signedness flags; both `true` here selects
  // signed*signed, final `false` disables clamping.
  return __builtin_amdgcn_sudot4(true, a, true, b, c, false);
#elif defined(TL_AMDGPU_HAS_SDOT4)
  // sdot4(a, b, c, clamp=false): signed-only variant.
  return __builtin_amdgcn_sdot4(a, b, c, false);
#else
  // No dot-product instruction on this target; emulate per byte.
  return tl_dp4a_fallback(a, b, c);
#endif
}

// Generic dp4a wrapper: reinterprets the first sizeof(int) bytes behind
// `a`, `b`, and `c` as packed 32-bit values, computes the four-lane int8
// dot product with accumulate, and stores the 32-bit result back through c.
// NOTE(review): assumes a/b/c each point to at least 4 accessible bytes
// (e.g. int8x4 vectors and a 32-bit accumulator) — confirm at call sites.
template <typename InDatatype, typename OutDatatype>
TL_DEVICE void DP4A(InDatatype *a, InDatatype *b, OutDatatype *c) {
  int a_int;
  int b_int;
  int c_int;
  // memcpy is the alias-safe way to reinterpret raw bytes as int.
  __builtin_memcpy(&a_int, a, sizeof(a_int));
  __builtin_memcpy(&b_int, b, sizeof(b_int));
  __builtin_memcpy(&c_int, c, sizeof(c_int));
  const int out = tl_dp4a(a_int, b_int, c_int);
  __builtin_memcpy(c, &out, sizeof(out));
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

// __habs overloads for hip_bfloat16 and float16_t to resolve ambiguity on ROCm.
// hip_bfloat16 != __hip_bfloat16, and float16_t != __half, so the standard
// __habs overloads don't match exactly, causing ambiguous overload errors.
Expand Down
76 changes: 76 additions & 0 deletions testing/python/target/test_tilelang_rocm_target.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from tilelang import tvm as tvm
from tvm.target import Target

from tilelang.contrib.hipcc import _target_mcpu
import tilelang.utils.target as target_utils
from tilelang.utils.target import (
determine_target,
normalize_rocm_arch,
rocm_warp_size_for_arch,
target_get_mcpu,
target_get_rdna_generation,
target_get_warp_size,
target_is_cdna,
target_is_rdna,
)


def test_normalize_rocm_arch_strips_feature_suffix():
    """normalize_rocm_arch drops ROCm feature suffixes and rejects non-gfx
    strings; rocm_warp_size_for_arch maps arch families to wave widths."""
    for raw, expected in (
        ("gfx1151:sramecc+:xnack-", "gfx1151"),
        ("gfx942", "gfx942"),
    ):
        assert normalize_rocm_arch(raw) == expected
    for invalid in ("", "sm_90"):
        assert normalize_rocm_arch(invalid) is None
    for arch, warp in (("gfx1151", 32), ("gfx1030", 32), ("gfx942", 64)):
        assert rocm_warp_size_for_arch(arch) == warp


def test_target_mcpu_helpers():
    """Both mcpu-extraction helpers strip the feature suffix from the
    target's -mcpu attribute."""
    target = Target("hip -mcpu=gfx1151:sramecc+:xnack-")
    for extract in (target_get_mcpu, _target_mcpu):
        assert extract(target) == "gfx1151"


def test_determine_target_adds_rdna_thread_warp_size():
    """An explicit RDNA hip target gains thread_warp_size=32 attribute."""
    tgt = determine_target("hip -mcpu=gfx1151", return_object=True)
    assert target_get_mcpu(tgt) == "gfx1151"
    assert int(tgt.attrs["thread_warp_size"]) == 32


def test_auto_target_prefers_rocm_pytorch_over_cuda_toolkit(monkeypatch):
    """When both HIP and CUDA report as available, 'auto' must resolve to the
    hip target (a ROCm torch build implies an AMD GPU) and carry the detected
    arch plus the 32-wide RDNA warp size."""
    # Simulate a ROCm build of torch with a CUDA toolkit also on the system.
    monkeypatch.setattr(target_utils.torch.version, "hip", "test", raising=False)
    monkeypatch.setattr(target_utils, "check_hip_availability", lambda: True)
    monkeypatch.setattr(target_utils, "check_cuda_availability", lambda: True)
    monkeypatch.setattr(target_utils, "_detect_torch_rocm_arch", lambda: "gfx1151")

    target = determine_target("auto", return_object=True)
    assert target.kind.name == "hip"
    assert target_get_mcpu(target) == "gfx1151"
    assert int(target.attrs["thread_warp_size"]) == 32


def test_rdna_gfx1151_target_classification():
    """gfx1151 classifies as RDNA (not CDNA), generation 11, 32-wide waves."""
    target = Target("hip -mcpu=gfx1151")
    assert target_is_rdna(target) and not target_is_cdna(target)
    assert target_get_rdna_generation(target) == 11
    assert target_get_warp_size(target) == 32


def test_carver_routes_rdna_without_instantiating_device(monkeypatch):
    """get_arch must route an RDNA hip target to the RDNA class without
    touching real hardware (RDNA.__init__ queries device 0, so the class is
    replaced by a stub here)."""
    import torch

    # Make every torch backend look unavailable so nothing probes a device.
    monkeypatch.setattr(torch.version, "hip", None, raising=False)
    monkeypatch.setattr(torch.cuda, "is_available", lambda: False)
    if hasattr(torch, "mps"):
        monkeypatch.setattr(torch.mps, "is_available", lambda: False, raising=False)

    import tilelang.carver.arch as arch_mod

    def fake_rdna(target):
        # Stand-in for RDNA: records the routed target instead of opening device 0.
        return ("rdna", target)

    monkeypatch.setattr(arch_mod, "RDNA", fake_rdna)
    arch = arch_mod.get_arch(Target("hip -mcpu=gfx1151"))
    assert arch[0] == "rdna"
    assert target_get_mcpu(arch[1]) == "gfx1151"
2 changes: 1 addition & 1 deletion tilelang/carver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,5 @@
) # noqa: F401
from .common_schedules import get_block, get_output_blocks, try_inline, try_inline_contiguous_spatial # noqa: F401
from .roller import *
from .arch import CUDA, CDNA # noqa: F401
from .arch import CUDA, CDNA, RDNA # noqa: F401
from .template import MatmulTemplate, GEMVTemplate, ElementwiseTemplate, GeneralReductionTemplate, FlashAttentionTemplate # noqa: F401
8 changes: 7 additions & 1 deletion tilelang/carver/arch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
from .cuda import *
from .cpu import *
from .cdna import *
from .rdna import *
from .metal import *
from tvm.target import Target
import torch
from tilelang.utils.target import determine_target, target_is_rdna


def get_arch(target: str | Target = "cuda") -> TileDevice:
Expand All @@ -18,6 +20,8 @@ def get_arch(target: str | Target = "cuda") -> TileDevice:
elif target.kind.name == "llvm":
return CPU(target)
elif target.kind.name == "hip":
if target_is_rdna(target):
return RDNA(target)
return CDNA(target)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
elif target.kind.name == "metal":
return METAL(target)
Expand All @@ -29,7 +33,7 @@ def auto_infer_current_arch() -> TileDevice:
# TODO(lei): This is a temporary solution to infer the current architecture
# Can be replaced by a more sophisticated method in the future
if torch.version.hip is not None:
return get_arch("hip")
return get_arch(determine_target("auto", return_object=True))
if torch.cuda.is_available():
return get_arch("cuda")
elif torch.mps.is_available():
Expand All @@ -48,9 +52,11 @@ def auto_infer_current_arch() -> TileDevice:
"is_tensorcore_supported_precision",
"has_mma_support",
"is_cdna_arch",
"is_rdna_arch",
"is_metal_arch",
"CUDA",
"CDNA",
"RDNA",
"METAL",
"CPU",
]
52 changes: 52 additions & 0 deletions tilelang/carver/arch/rdna.py
Comment thread
lhl marked this conversation as resolved.
Comment thread
lhl marked this conversation as resolved.
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from __future__ import annotations
import tvm
from tvm.target import Target
from .arch_base import TileDevice
from .cuda import TensorInstruction

_RDNA_DEFAULT_LDS_SIZE = 64 * 1024


def is_rdna_arch(arch: TileDevice) -> bool:
    """Return True when ``arch`` is an RDNA tile-device instance."""
    return isinstance(arch, RDNA)


class RDNA(TileDevice):
    """Carver tile-device description for AMD RDNA GPUs.

    Queries HIP device 0 through the TVM ROCm runtime and fills in the
    heuristic constants consumed by the roller policies (LDS capacity,
    register budget, warp size, bandwidth score seeds).

    Raises:
        RuntimeError: if HIP device 0 does not exist.
    """

    def __init__(self, target: Target | str):
        if isinstance(target, str):
            target = tvm.target.Target(target)
        self.target = target
        device = tvm.runtime.rocm(0)
        if not device.exist:
            raise RuntimeError("Cannot find HIP device 0.")
        self.device: tvm.runtime.Device = device
        self.platform: str = "RDNA"

        # Some runtimes report 0 for max LDS; fall back to the 64 KiB default.
        reported_smem = device.max_shared_memory_per_block
        self.smem_cap = reported_smem if reported_smem > 0 else _RDNA_DEFAULT_LDS_SIZE
        self.compute_max_core = device.multi_processor_count
        # RDNA wavefronts are 32 lanes wide (vs. 64 on CDNA).
        self.warp_size = 32
        self.compute_capability = device.compute_version.replace(".", "")
        self.reg_cap: int = 32768
        self.max_smem_usage: int = 2 * self.smem_cap
        self.sm_partition: int = 4
        self.l2_cache_size_bytes: int = getattr(target, "l2_cache_size_bytes", 0)
        self.transaction_size: list[int] = [32, 128]

        # Keep the same units as the existing CUDA/CDNA heuristic. Strix Halo
        # is a UMA part, so use a conservative global-memory score seed.
        self.bandwidth: list[int] = [750, 12080]
        self.available_tensor_instructions: list[TensorInstruction] | None = None

    def get_avaliable_tensorintrin_shapes(self):
        """Return the supported tensor-intrinsic shapes (RDNA WMMA is 16x16).

        Stores a list (not a tuple) so the cached value matches the declared
        ``list[TensorInstruction] | None`` attribute type.
        """
        self.available_tensor_instructions = [TensorInstruction("wmma", [16, 16])]
        return [t.shape for t in self.available_tensor_instructions]

    def __repr__(self):
        return f"RDNA({self.target})"


__all__ = [
"is_rdna_arch",
"RDNA",
]
9 changes: 9 additions & 0 deletions tilelang/carver/roller/policy/tensorcore.py
Comment thread
lhl marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,13 @@ def check_tile_shape_isvalid(self, td: TileDict):
return False
return super().check_tile_shape_isvalid(td)

def score_block_size(self, n):
    """Score candidate block size ``n`` for tile selection.

    On RDNA, bias toward 8 wavefronts per workgroup (8 * warp_size = 256
    threads): the leading 0/1 element makes an exact 8-warp match sort
    first, then closeness to 8 warps, then the base score as tie-breaker
    (lower tuples appear to be preferred — matches the 0-for-exact
    encoding). Other platforms keep the base scoring unchanged.
    """
    base_score = super().score_block_size(n)
    if getattr(self.arch, "platform", None) == "RDNA":
        # Ceil-divide the thread count into wavefronts of warp_size lanes.
        warps = (n + self.arch.warp_size - 1) // self.arch.warp_size
        return (0 if warps == 8 else 1, abs(warps - 8), *base_score)
    return base_score

def _can_implement_layout(self, node: PrimFuncNode, td: TileDict):
# Not implemented yet
# This function is used to check whether we can implement swizzling
Expand Down Expand Up @@ -256,6 +263,8 @@ def _assign_block_size(self, node: PrimFuncNode, td: TileDict, block_size: int):
if tile[ax_m] < wmma_tile[ax_m] or tile[ax_n] < wmma_tile[ax_n]:
# allow pad, otherwise, we can not get a valid tile shape
return None
if np.prod(space) < warps:
return None

factors = factorize(np.prod(space) // warps)

Expand Down
10 changes: 10 additions & 0 deletions tilelang/carver/template/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
is_volta_arch,
is_ampere_arch,
is_cdna_arch,
is_rdna_arch,
auto_infer_current_arch,
)
from ..roller.hint import Hint # Import the Hint class
Expand Down Expand Up @@ -95,6 +96,15 @@ def is_cdna_arch(self) -> bool:
"""
return is_cdna_arch(self._arch) if self._arch is not None else False

def is_rdna_arch(self) -> bool:
    """Report whether this template targets an RDNA architecture.

    Returns:
        bool: True when the bound arch is RDNA; False otherwise, including
        when no architecture has been bound yet.
    """
    if self._arch is None:
        return False
    return is_rdna_arch(self._arch)

def equivalent_function(self) -> PrimFunc:
"""
Returns the function associated with this template.
Expand Down
13 changes: 12 additions & 1 deletion tilelang/contrib/hipcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,17 @@
from tvm.contrib.rocm import get_rocm_arch, find_rocm_path


def _target_mcpu(target):
try:
mcpu = target.attrs.get("mcpu")
except AttributeError:
return None
if mcpu is None:
return None
arch = str(mcpu).strip().split(":", maxsplit=1)[0]
return arch if arch.startswith("gfx") else None


def compile_hip(code, target_format="hsaco", arch=None, options=None, path_target=None, verbose=False):
"""Compile HIP code with hipcc.

Expand Down Expand Up @@ -94,5 +105,5 @@ def compile_hip(code, target_format="hsaco", arch=None, options=None, path_targe
@tvm_ffi.register_global_func("tilelang_callback_hip_compile", override=True)
def tilelang_callback_hip_compile(code, target):
    """use hipcc to generate fatbin code for better optimization"""
    # Forward the target's mcpu so hipcc compiles for the requested gfx arch.
    return compile_hip(code, target_format="hsaco", arch=_target_mcpu(target))
4 changes: 3 additions & 1 deletion tilelang/engine/lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from tilelang.transform import PassConfigKey
from tilelang.transform.metal import MarkHostMetalContext
from tilelang.engine.param import KernelParam, CompiledArtifact
from tilelang.utils.target import determine_target
from tilelang.utils.target import determine_target, target_get_mcpu
from tilelang.engine.phase import (
PreLowerSemanticCheck,
LowerAndLegalize,
Expand Down Expand Up @@ -160,9 +160,11 @@ def tilelang_callback_cuda_compile(code, target, pass_config=None):

@tvm_ffi.register_global_func("tilelang_callback_hip_compile", override=True)
def tilelang_callback_hip_compile(code, target):
arch = target_get_mcpu(target)
hsaco = hipcc.compile_hip(
code,
target_format="hsaco",
arch=arch,
options=[
"-std=c++17",
"-I" + TILELANG_TEMPLATE_PATH,
Expand Down
3 changes: 2 additions & 1 deletion tilelang/jit/adapter/libgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from tilelang.contrib.nvcc import get_nvcc_compiler, get_target_arch, get_target_compute_version
from tilelang.contrib.rocm import find_rocm_path, get_rocm_arch
from tilelang.env import TILELANG_TEMPLATE_PATH
from tilelang.utils.target import target_get_mcpu

from .utils import is_cpu_target, is_cuda_target, is_hip_target

Expand Down Expand Up @@ -97,7 +98,7 @@ def compile_lib(self, timeout: float = None):
src = tempfile.NamedTemporaryFile(mode="w", suffix=".cpp", delete=False) # noqa: SIM115
libpath = src.name.replace(".cpp", ".so")
rocm_path = find_rocm_path()
arch = get_rocm_arch(rocm_path)
arch = target_get_mcpu(target) or get_rocm_arch(rocm_path)
command = [
"hipcc",
"-std=c++17",
Expand Down
Loading