Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions src/op/parallel.cc
Comment thread
lhl marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,13 @@ void ParallelLoopNestVisitor::VisitStmt_(const ForNode *op) {
IterVar(Range(op->min, op->extent), op->loop_var,
IterVarType::kOrdered));
p->analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
auto reducer_info_map =
op->annotations.Get(attr::kReducerInfo)->as<Map<Var, ReducerInfo>>();
if (reducer_info_map) {
for (auto &&[buffer, info] : reducer_info_map.value())
p->reducer_info_map_.Set(buffer, info);
if (auto reducer_info_anno = op->annotations.Get(attr::kReducerInfo)) {
auto reducer_info_map =
reducer_info_anno.value().as<Map<Var, ReducerInfo>>();
if (reducer_info_map) {
for (auto &&[buffer, info] : reducer_info_map.value())
p->reducer_info_map_.Set(buffer, info);
}
}
StmtExprVisitor::VisitStmt_(op);
}
Expand Down
51 changes: 51 additions & 0 deletions src/tl_templates/hip/common.h
Comment thread
lhl marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,57 @@ TL_DEVICE unsigned __pack_bfloat162(const bfloat16_t x, const bfloat16_t y) {
return (v1 << 16) | v0;
}

// Software dot product of four packed int8 lanes: each 32-bit operand is
// treated as four signed 8-bit values; the lane-wise products are added to
// the accumulator `c`.
TL_DEVICE int tl_dp4a_fallback(const int a, const int b, int c) {
  int acc = c;
#pragma unroll
  for (int lane = 0; lane < 4; ++lane) {
    const int shift = 8 * lane;
    // Casting the extracted byte through int8_t sign-extends it.
    acc += static_cast<int8_t>((a >> shift) & 0xff) *
           static_cast<int8_t>((b >> shift) & 0xff);
  }
  return acc;
}

#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \
defined(__gfx1103__) || defined(__gfx1150__) || defined(__gfx1151__) || \
defined(__gfx1200__) || defined(__gfx1201__)
#define TL_AMDGPU_HAS_SUDOT4 1
#endif

#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || \
defined(__gfx950__) || defined(__gfx1010__) || defined(__gfx1011__) || \
defined(__gfx1012__) || defined(__gfx1030__) || defined(__gfx1031__) || \
defined(__gfx1032__) || defined(__gfx1034__) || defined(__gfx1035__) || \
defined(TL_AMDGPU_HAS_SUDOT4)
#define TL_AMDGPU_HAS_SDOT4 1
#endif
Comment thread
coderabbitai[bot] marked this conversation as resolved.

// 4x int8 dot product with 32-bit accumulate, dispatched at compile time to
// the best instruction the target gfx architecture provides:
//   - TL_AMDGPU_HAS_SUDOT4: mixed-sign dot4 builtin, used here with both
//     operands flagged signed (the two `true` arguments) and clamping off.
//   - TL_AMDGPU_HAS_SDOT4: signed int8 dot4 builtin; the trailing `false`
//     presumably disables saturating accumulation -- matches sudot4 usage.
//   - otherwise: the scalar software fallback above.
TL_DEVICE int tl_dp4a(const int a, const int b, const int c) {
#if defined(TL_AMDGPU_HAS_SUDOT4)
  return __builtin_amdgcn_sudot4(true, a, true, b, c, false);
#elif defined(TL_AMDGPU_HAS_SDOT4)
  return __builtin_amdgcn_sdot4(a, b, c, false);
#else
  return tl_dp4a_fallback(a, b, c);
#endif
}

// Type-erased front end for tl_dp4a: `a` and `b` point at four packed 8-bit
// lanes, `c` at a 4-byte accumulator which is updated in place.
template <typename InDatatype, typename OutDatatype>
TL_DEVICE void DP4A(const InDatatype *a, const InDatatype *b, OutDatatype *c) {
  static_assert(sizeof(InDatatype) == 1,
                "DP4A expects a pointer to packed int8 lanes");
  static_assert(sizeof(OutDatatype) == sizeof(int),
                "DP4A expects 4-byte accumulator/output type");
  // memcpy is the well-defined way to reinterpret raw bytes as int
  // (no strict-aliasing violation, unlike a pointer cast).
  int packed_a;
  __builtin_memcpy(&packed_a, a, sizeof(packed_a));
  int packed_b;
  __builtin_memcpy(&packed_b, b, sizeof(packed_b));
  int accum;
  __builtin_memcpy(&accum, c, sizeof(accum));
  accum = tl_dp4a(packed_a, packed_b, accum);
  __builtin_memcpy(c, &accum, sizeof(accum));
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

// __habs overloads for hip_bfloat16 and float16_t to resolve ambiguity on ROCm.
// hip_bfloat16 != __hip_bfloat16, and float16_t != __half, so the standard
// __habs overloads don't match exactly, causing ambiguous overload errors.
Expand Down
8 changes: 5 additions & 3 deletions src/transform/layout_reducer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,11 @@ class ReducerLayoutAnnotator : public IRMutatorWithAnalyzer {
auto result = IRMutatorWithAnalyzer::VisitStmt_(op).as<Block>().value();
// After iterating over the body, set all layout_map to block
auto p_result = result.CopyOnWrite();
auto layout_map = p_result->annotations.Get(attr::kLayoutMap)
->as<Map<Var, Layout>>()
.value_or(Map<Var, Layout>());
Map<Var, Layout> layout_map;
if (auto opt_layout_map = p_result->annotations.Get(attr::kLayoutMap)) {
layout_map =
opt_layout_map.value().as<Map<Var, Layout>>().value_or(layout_map);
}
for (auto &&[k, v] : new_layout_map_)
layout_map.Set(k, v);
if (!layout_map.empty())
Expand Down
2 changes: 1 addition & 1 deletion testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype=T.flo
(128, 256, 256, determine_fp8_type(), T.float32, T.float32, False, False, 2),
],
)
@tilelang.testing.requires_rocm
@tilelang.testing.requires_cdna
def test_assert_tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed, k_pack):
assert_tl_matmul_correctness(
M,
Expand Down
2 changes: 1 addition & 1 deletion testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ def assert_tl_matmul_correctness(
(256, 256, 512, determine_fp8_type(), T.float32, T.float32, False, True, 1, True, True, False),
],
)
@tilelang.testing.requires_rocm
@tilelang.testing.requires_cdna
def test_assert_tl_matmul(
M,
N,
Expand Down
2 changes: 1 addition & 1 deletion testing/python/amd/test_tilelang_gfx950_copy_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def ref_program(A, B):
(True, False),
],
)
@tilelang.testing.requires_rocm
@tilelang.testing.requires_gfx950
def test_gfx950_copy_async_gemm_no_pipeline(trans_A, trans_B):
"""Non-pipelined GEMM (num_stages=0) must also produce correct results."""
prog = _matmul_kernel(
Expand Down
4 changes: 4 additions & 0 deletions testing/python/kernel/test_tilelang_kernel_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def test_gemm_bf16bf16f32_nn():
)


@tilelang.testing.requires_cuda_or_cdna
def test_gemm_f32f32f32_nn():
run_gemm(
512,
Expand Down Expand Up @@ -205,10 +206,12 @@ def test_gemm_f16f16f16_nt():
)


@tilelang.testing.requires_cuda_or_cdna
def test_gemm_i8i8i32_nt():
run_gemm(512, 1024, 768, False, True, T.int8, T.int8, T.int32, 128, 128, 64)


@tilelang.testing.requires_cuda_or_cdna
def test_gemm_i8i8i32_tn():
run_gemm(512, 1024, 768, True, False, T.int8, T.int8, T.int32, 128, 128, 64)

Expand All @@ -218,6 +221,7 @@ def test_gemm_f64f64f64_nt():
run_gemm(512, 512, 512, False, True, T.float64, T.float64, T.float64, 64, 32, 16)


@tilelang.testing.requires_cuda_or_cdna
def test_gemm_f32f32f32_nt():
run_gemm(
512,
Expand Down
74 changes: 74 additions & 0 deletions testing/python/target/test_tilelang_rocm_target.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from tilelang import tvm as tvm
from tvm.target import Target

import tilelang.utils.target as target_utils
from tilelang.utils.target import (
determine_target,
normalize_rocm_arch,
rocm_warp_size_for_arch,
target_get_mcpu,
target_get_rdna_generation,
target_get_warp_size,
target_is_cdna,
target_is_rdna,
)


def test_normalize_rocm_arch_strips_feature_suffix():
    """normalize_rocm_arch trims ``:feature`` suffixes and rejects non-gfx
    names; warp size is 32 for RDNA-family archs and 64 for CDNA."""
    normalize_cases = {
        "gfx1151:sramecc+:xnack-": "gfx1151",
        "gfx942": "gfx942",
        "": None,
        "sm_90": None,
    }
    for raw, expected in normalize_cases.items():
        assert normalize_rocm_arch(raw) == expected
    warp_cases = {"gfx1151": 32, "gfx1030": 32, "gfx942": 64}
    for arch_name, warp in warp_cases.items():
        assert rocm_warp_size_for_arch(arch_name) == warp


def test_target_mcpu_helpers():
    """target_get_mcpu strips ROCm feature flags from the target's -mcpu."""
    assert target_get_mcpu(Target("hip -mcpu=gfx1151:sramecc+:xnack-")) == "gfx1151"


def test_determine_target_adds_rdna_thread_warp_size():
    """determine_target annotates an RDNA hip target with a 32-lane warp."""
    resolved = determine_target("hip -mcpu=gfx1151", return_object=True)
    assert target_get_mcpu(resolved) == "gfx1151"
    assert int(resolved.attrs["thread_warp_size"]) == 32


def test_auto_target_prefers_rocm_pytorch_over_cuda_toolkit(monkeypatch):
    """When both a ROCm torch build and a CUDA toolkit look available,
    ``determine_target("auto")`` should prefer the ROCm (hip) target."""
    monkeypatch.setattr(target_utils.torch.version, "hip", "test", raising=False)
    for helper_name, stub in (
        ("check_hip_availability", lambda: True),
        ("check_cuda_availability", lambda: True),
        ("_detect_torch_rocm_arch", lambda: "gfx1151"),
    ):
        monkeypatch.setattr(target_utils, helper_name, stub)

    resolved = determine_target("auto", return_object=True)
    assert resolved.kind.name == "hip"
    assert target_get_mcpu(resolved) == "gfx1151"
    assert int(resolved.attrs["thread_warp_size"]) == 32


def test_rdna_gfx1151_target_classification():
    """gfx1151 classifies as RDNA generation 11 with a 32-lane warp."""
    gfx1151 = Target("hip -mcpu=gfx1151")
    assert target_is_rdna(gfx1151)
    assert not target_is_cdna(gfx1151)
    assert target_get_rdna_generation(gfx1151) == 11
    assert target_get_warp_size(gfx1151) == 32


def test_carver_routes_rdna_without_instantiating_device(monkeypatch):
    """get_arch must route RDNA targets to the RDNA class without probing
    any real device (no torch backend, no tvm.runtime device)."""
    import torch

    # Make every torch backend look unavailable so nothing touches hardware.
    monkeypatch.setattr(torch.version, "hip", None, raising=False)
    monkeypatch.setattr(torch.cuda, "is_available", lambda: False)
    if hasattr(torch, "mps"):
        monkeypatch.setattr(torch.mps, "is_available", lambda: False, raising=False)

    import tilelang.carver.arch as arch_mod

    # Stub RDNA so constructing it never opens a HIP device.
    monkeypatch.setattr(arch_mod, "RDNA", lambda target: ("rdna", target))
    routed = arch_mod.get_arch(Target("hip -mcpu=gfx1151"))
    tag, routed_target = routed
    assert tag == "rdna"
    assert target_get_mcpu(routed_target) == "gfx1151"
2 changes: 1 addition & 1 deletion tilelang/carver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,5 @@
) # noqa: F401
from .common_schedules import get_block, get_output_blocks, try_inline, try_inline_contiguous_spatial # noqa: F401
from .roller import *
from .arch import CUDA, CDNA # noqa: F401
from .arch import CUDA, CDNA, RDNA # noqa: F401
from .template import MatmulTemplate, GEMVTemplate, ElementwiseTemplate, GeneralReductionTemplate, FlashAttentionTemplate # noqa: F401
8 changes: 7 additions & 1 deletion tilelang/carver/arch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
from .cuda import *
from .cpu import *
from .cdna import *
from .rdna import *
from .metal import *
from tvm.target import Target
import torch
from tilelang.utils.target import determine_target, target_is_rdna


def get_arch(target: str | Target = "cuda") -> TileDevice:
Expand All @@ -18,6 +20,8 @@ def get_arch(target: str | Target = "cuda") -> TileDevice:
elif target.kind.name == "llvm":
return CPU(target)
elif target.kind.name == "hip":
if target_is_rdna(target):
return RDNA(target)
return CDNA(target)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
elif target.kind.name == "metal":
return METAL(target)
Expand All @@ -29,7 +33,7 @@ def auto_infer_current_arch() -> TileDevice:
# TODO(lei): This is a temporary solution to infer the current architecture
# Can be replaced by a more sophisticated method in the future
if torch.version.hip is not None:
return get_arch("hip")
return get_arch(determine_target("auto", return_object=True))
if torch.cuda.is_available():
return get_arch("cuda")
elif torch.mps.is_available():
Expand All @@ -48,9 +52,11 @@ def auto_infer_current_arch() -> TileDevice:
"is_tensorcore_supported_precision",
"has_mma_support",
"is_cdna_arch",
"is_rdna_arch",
"is_metal_arch",
"CUDA",
"CDNA",
"RDNA",
"METAL",
"CPU",
]
52 changes: 52 additions & 0 deletions tilelang/carver/arch/rdna.py
Comment thread
lhl marked this conversation as resolved.
Comment thread
lhl marked this conversation as resolved.
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from __future__ import annotations
import tvm
from tvm.target import Target
from .arch_base import TileDevice
from .cuda import TensorInstruction

_RDNA_DEFAULT_LDS_SIZE = 64 * 1024


def is_rdna_arch(arch: TileDevice) -> bool:
    """Return True if ``arch`` is an :class:`RDNA` tile-device instance."""
    return isinstance(arch, RDNA)


class RDNA(TileDevice):
    """Tile-device description for AMD RDNA (consumer / APU) GPUs.

    Queries HIP device 0 for its actual limits and falls back to
    conservative RDNA defaults where the runtime reports nothing.

    Raises:
        RuntimeError: if no HIP device 0 is present.
    """

    def __init__(self, target: Target | str):
        if isinstance(target, str):
            target = tvm.target.Target(target)
        self.target = target
        device = tvm.runtime.rocm(0)
        if not device.exist:
            raise RuntimeError("Cannot find HIP device 0.")
        self.device: tvm.runtime.Device = device
        self.platform: str = "RDNA"

        # Some runtimes report 0 shared memory; fall back to 64 KiB LDS.
        reported_smem = device.max_shared_memory_per_block
        self.smem_cap = reported_smem if reported_smem > 0 else _RDNA_DEFAULT_LDS_SIZE
        self.compute_max_core = device.multi_processor_count
        # Fixed 32-lane wavefront for the RDNA family.
        self.warp_size = 32
        self.compute_capability = device.compute_version.replace(".", "")
        self.reg_cap: int = 32768
        self.max_smem_usage: int = 2 * self.smem_cap
        self.sm_partition: int = 4
        # NOTE(review): assumes Target exposes l2_cache_size_bytes as a plain
        # attribute; defaults to 0 otherwise -- confirm against tvm.Target.
        self.l2_cache_size_bytes: int = getattr(target, "l2_cache_size_bytes", 0)
        self.transaction_size: list[int] = [32, 128]

        # Keep the same units as the existing CUDA/CDNA heuristic. Strix Halo
        # is a UMA part, so use a conservative global-memory score seed.
        self.bandwidth: list[int] = [750, 12080]
        self.available_tensor_instructions: list[TensorInstruction] | None = None

    def get_avaliable_tensorintrin_shapes(self):
        """Record and return the wmma tensor-intrinsic shapes for RDNA.

        Returns:
            list: the supported intrinsic shapes (16x16 wmma).
        """
        # Store a list to match the declared type of
        # ``available_tensor_instructions`` (previously a tuple was stored).
        self.available_tensor_instructions = [TensorInstruction("wmma", [16, 16])]
        return [t.shape for t in self.available_tensor_instructions]

    def __repr__(self):
        return f"RDNA({self.target})"


__all__ = [
"is_rdna_arch",
"RDNA",
]
12 changes: 11 additions & 1 deletion tilelang/carver/roller/policy/tensorcore.py
Comment thread
lhl marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,13 @@ def check_tile_shape_isvalid(self, td: TileDict):
return False
return super().check_tile_shape_isvalid(td)

def score_block_size(self, n):
    """Score a candidate block size of ``n`` threads.

    On RDNA, bias the choice toward blocks of exactly 8 wavefronts by
    prepending two keys to the base score: a 0/1 "is exactly 8 warps"
    flag and the distance from 8 warps. Presumably lower tuples rank
    better, consistent with the 0-for-exact-match encoding -- confirm
    against the parent policy's comparison. Non-RDNA platforms use the
    parent scoring unchanged.
    """
    base_score = super().score_block_size(n)
    if getattr(self.arch, "platform", None) == "RDNA":
        # Ceil-divide the thread count by warp size to count wavefronts.
        warps = (n + self.arch.warp_size - 1) // self.arch.warp_size
        return (0 if warps == 8 else 1, abs(warps - 8), *base_score)
    return base_score

def _can_implement_layout(self, node: PrimFuncNode, td: TileDict):
# Not implemented yet
# This function is used to check whether we can implement swizzling
Expand Down Expand Up @@ -256,8 +263,11 @@ def _assign_block_size(self, node: PrimFuncNode, td: TileDict, block_size: int):
if tile[ax_m] < wmma_tile[ax_m] or tile[ax_n] < wmma_tile[ax_n]:
# allow pad, otherwise, we can not get a valid tile shape
return None
space_prod = int(np.prod(space))
if space_prod < warps or space_prod % warps != 0:
return None

factors = factorize(np.prod(space) // warps)
factors = factorize(space_prod // warps)
Comment thread
coderabbitai[bot] marked this conversation as resolved.

def _score(node, warp_tile): # small is better
score = 0
Expand Down
10 changes: 10 additions & 0 deletions tilelang/carver/template/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
is_volta_arch,
is_ampere_arch,
is_cdna_arch,
is_rdna_arch,
auto_infer_current_arch,
)
from ..roller.hint import Hint # Import the Hint class
Expand Down Expand Up @@ -95,6 +96,15 @@ def is_cdna_arch(self) -> bool:
"""
return is_cdna_arch(self._arch) if self._arch is not None else False

def is_rdna_arch(self) -> bool:
"""
Checks if the current architecture is an RDNA architecture.

Returns:
bool: True if the architecture is RDNA, False otherwise.
"""
return is_rdna_arch(self._arch) if self._arch is not None else False

def equivalent_function(self) -> PrimFunc:
"""
Returns the function associated with this template.
Expand Down
4 changes: 3 additions & 1 deletion tilelang/contrib/hipcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,5 +94,7 @@ def compile_hip(code, target_format="hsaco", arch=None, options=None, path_targe
@tvm_ffi.register_global_func("tilelang_callback_hip_compile", override=True)
def tilelang_callback_hip_compile(code, target):
    """Compile generated HIP source to an hsaco binary with hipcc, pinning
    the architecture to the target's -mcpu so codegen matches the device."""
    # Local import -- presumably avoids a circular import between
    # tilelang.contrib and tilelang.utils.target; confirm before hoisting.
    from tilelang.utils.target import target_get_mcpu

    hsaco = compile_hip(code, target_format="hsaco", arch=target_get_mcpu(target))
    return hsaco
Loading
Loading