12 changes: 7 additions & 5 deletions src/op/parallel.cc
@@ -82,11 +82,13 @@ void ParallelLoopNestVisitor::VisitStmt_(const ForNode *op) {
IterVar(Range(op->min, op->extent), op->loop_var,
IterVarType::kOrdered));
p->analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
-  auto reducer_info_map =
-      op->annotations.Get(attr::kReducerInfo)->as<Map<Var, ReducerInfo>>();
-  if (reducer_info_map) {
-    for (auto &&[buffer, info] : reducer_info_map.value())
-      p->reducer_info_map_.Set(buffer, info);
+  if (auto reducer_info_anno = op->annotations.Get(attr::kReducerInfo)) {
+    auto reducer_info_map =
+        reducer_info_anno.value().as<Map<Var, ReducerInfo>>();
+    if (reducer_info_map) {
+      for (auto &&[buffer, info] : reducer_info_map.value())
+        p->reducer_info_map_.Set(buffer, info);
+    }
}
StmtExprVisitor::VisitStmt_(op);
}
50 changes: 50 additions & 0 deletions src/tl_templates/hip/common.h
@@ -111,6 +111,56 @@ TL_DEVICE unsigned __pack_bfloat162(const bfloat16_t x, const bfloat16_t y) {
return (v1 << 16) | v0;
}

TL_DEVICE int tl_dp4a_fallback(const int a, const int b, int c) {
#pragma unroll
for (int i = 0; i < 4; ++i) {
const int ai = static_cast<int8_t>((a >> (8 * i)) & 0xff);
const int bi = static_cast<int8_t>((b >> (8 * i)) & 0xff);
c += ai * bi;
}
return c;
}

#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \
defined(__gfx1103__) || defined(__gfx1150__) || defined(__gfx1151__)
#define TL_AMDGPU_HAS_SUDOT4 1
#endif

#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || \
defined(__gfx950__) || defined(__gfx1011__) || defined(__gfx1012__) || \
defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \
defined(__gfx1034__) || defined(__gfx1035__) || \
defined(TL_AMDGPU_HAS_SUDOT4)
#define TL_AMDGPU_HAS_SDOT4 1
#endif

TL_DEVICE int tl_dp4a(const int a, const int b, const int c) {
#if defined(TL_AMDGPU_HAS_SUDOT4)
return __builtin_amdgcn_sudot4(true, a, true, b, c, false);
#elif defined(TL_AMDGPU_HAS_SDOT4)
return __builtin_amdgcn_sdot4(a, b, c, false);
#else
return tl_dp4a_fallback(a, b, c);
#endif
}

template <typename InDatatype, typename OutDatatype>
TL_DEVICE void DP4A(const InDatatype *a, const InDatatype *b, OutDatatype *c) {
static_assert(sizeof(InDatatype) == 1,
"DP4A expects a pointer to packed int8 lanes");
static_assert(sizeof(OutDatatype) == sizeof(int),
"DP4A expects 4-byte accumulator/output type");
int a_int;
int b_int;
int c_int;
__builtin_memcpy(&a_int, a, sizeof(a_int));
__builtin_memcpy(&b_int, b, sizeof(b_int));
__builtin_memcpy(&c_int, c, sizeof(c_int));
const int out = tl_dp4a(a_int, b_int, c_int);
__builtin_memcpy(c, &out, sizeof(out));
}

// __habs overloads for hip_bfloat16 and float16_t to resolve ambiguity on ROCm.
// hip_bfloat16 != __hip_bfloat16, and float16_t != __half, so the standard
// __habs overloads don't match exactly, causing ambiguous overload errors.
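
The fallback above treats each 32-bit operand as four signed 8-bit lanes, multiplies lane-wise, and adds the products to the accumulator; the DP4A template only memcpys 1-byte element pointers onto that packed form before calling tl_dp4a. A host-side reference of the same lane math can serve as an oracle when spot-checking kernel output — the snippet below is a hypothetical helper, not part of this patch:

import struct

def dp4a_reference(packed_a: int, packed_b: int, acc: int) -> int:
    """Signed int8 dot product of two packed 32-bit words, plus accumulator."""
    a_lanes = struct.unpack("4b", struct.pack("<i", packed_a))  # lane 0 = lowest byte
    b_lanes = struct.unpack("4b", struct.pack("<i", packed_b))
    return acc + sum(x * y for x, y in zip(a_lanes, b_lanes))

# Pack four int8 values the same way the kernel's byte-wise extraction sees them.
a = struct.unpack("<i", struct.pack("4b", 1, -2, 3, -4))[0]
b = struct.unpack("<i", struct.pack("4b", 5, 6, -7, 8))[0]
print(dp4a_reference(a, b, 10))  # 10 + (5 - 12 - 21 - 32) = -50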
8 changes: 5 additions & 3 deletions src/transform/layout_reducer.cc
@@ -128,9 +128,11 @@ class ReducerLayoutAnnotator : public IRMutatorWithAnalyzer {
auto result = IRMutatorWithAnalyzer::VisitStmt_(op).as<Block>().value();
// After iterating over the body, set all layout_map to block
auto p_result = result.CopyOnWrite();
-    auto layout_map = p_result->annotations.Get(attr::kLayoutMap)
-                          ->as<Map<Var, Layout>>()
-                          .value_or(Map<Var, Layout>());
+    Map<Var, Layout> layout_map;
+    if (auto opt_layout_map = p_result->annotations.Get(attr::kLayoutMap)) {
+      layout_map =
+          opt_layout_map.value().as<Map<Var, Layout>>().value_or(layout_map);
+    }
for (auto &&[k, v] : new_layout_map_)
layout_map.Set(k, v);
if (!layout_map.empty())
2 changes: 1 addition & 1 deletion testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py
@@ -227,7 +227,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype=T.flo
(128, 256, 256, determine_fp8_type(), T.float32, T.float32, False, False, 2),
],
)
-@tilelang.testing.requires_rocm
+@tilelang.testing.requires_cdna
def test_assert_tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed, k_pack):
assert_tl_matmul_correctness(
M,
2 changes: 1 addition & 1 deletion testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py
@@ -356,7 +356,7 @@ def assert_tl_matmul_correctness(
(256, 256, 512, determine_fp8_type(), T.float32, T.float32, False, True, 1, True, True, False),
],
)
-@tilelang.testing.requires_rocm
+@tilelang.testing.requires_cdna
def test_assert_tl_matmul(
M,
N,
2 changes: 1 addition & 1 deletion testing/python/amd/test_tilelang_gfx950_copy_async.py
@@ -231,7 +231,7 @@ def ref_program(A, B):
(True, False),
],
)
-@tilelang.testing.requires_rocm
+@tilelang.testing.requires_gfx950
def test_gfx950_copy_async_gemm_no_pipeline(trans_A, trans_B):
"""Non-pipelined GEMM (num_stages=0) must also produce correct results."""
prog = _matmul_kernel(
4 changes: 4 additions & 0 deletions testing/python/kernel/test_tilelang_kernel_gemm.py
@@ -153,6 +153,7 @@ def test_gemm_bf16bf16f32_nn():
)


+@tilelang.testing.requires_cuda_or_cdna
def test_gemm_f32f32f32_nn():
run_gemm(
512,
@@ -205,10 +206,12 @@ def test_gemm_f16f16f16_nt():
)


+@tilelang.testing.requires_cuda_or_cdna
def test_gemm_i8i8i32_nt():
run_gemm(512, 1024, 768, False, True, T.int8, T.int8, T.int32, 128, 128, 64)


+@tilelang.testing.requires_cuda_or_cdna
def test_gemm_i8i8i32_tn():
run_gemm(512, 1024, 768, True, False, T.int8, T.int8, T.int32, 128, 128, 64)

@@ -218,6 +221,7 @@ def test_gemm_f64f64f64_nt():
run_gemm(512, 512, 512, False, True, T.float64, T.float64, T.float64, 64, 32, 16)


+@tilelang.testing.requires_cuda_or_cdna
def test_gemm_f32f32f32_nt():
run_gemm(
512,
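
The AMD test updates above replace the blanket requires_rocm gate with family-specific ones: requires_cdna for the MFMA paths, requires_gfx950 for the gfx950 async-copy test, and requires_cuda_or_cdna for GEMM variants that RDNA does not cover yet. Purely as a hedged illustration of what such a gate can look like (the real decorators live in tilelang.testing and may be implemented differently), a pytest skip mark keyed on the reported gfx architecture:

import pytest
import torch

def _current_gfx_arch() -> str:
    # Assumes a ROCm build of PyTorch, which exposes the arch via gcnArchName.
    if torch.version.hip is None or not torch.cuda.is_available():
        return ""
    return torch.cuda.get_device_properties(0).gcnArchName.split(":")[0]

def requires_cdna_sketch(fn):
    # Skip unless the visible GPU is a CDNA-class (gfx9xx) part.
    return pytest.mark.skipif(
        not _current_gfx_arch().startswith("gfx9"),
        reason="test requires a CDNA (gfx9xx) GPU")(fn)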
114 changes: 114 additions & 0 deletions testing/python/target/test_tilelang_rocm_target.py
@@ -0,0 +1,114 @@
import pytest

from tilelang import tvm as tvm
from tvm.target import Target

import tilelang.utils.target as target_utils
from tilelang.utils.target import (
determine_target,
normalize_rocm_arch,
rocm_warp_size_for_arch,
target_get_mcpu,
target_get_rdna_generation,
target_get_warp_size,
target_is_cdna,
target_is_rdna,
)


def test_normalize_rocm_arch_strips_feature_suffix():
assert normalize_rocm_arch("gfx1151:sramecc+:xnack-") == "gfx1151"
assert normalize_rocm_arch("gfx942") == "gfx942"
assert normalize_rocm_arch("") is None
assert normalize_rocm_arch("sm_90") is None
assert rocm_warp_size_for_arch("gfx1151") == 32
assert rocm_warp_size_for_arch("gfx1030") == 32
assert rocm_warp_size_for_arch("gfx1200") == 32
assert rocm_warp_size_for_arch("gfx942") == 64


def test_target_mcpu_helpers():
target = Target("hip -mcpu=gfx1151:sramecc+:xnack-")
assert target_get_mcpu(target) == "gfx1151"


def test_determine_target_adds_rdna_thread_warp_size():
target = determine_target("hip -mcpu=gfx1151", return_object=True)
assert target_get_mcpu(target) == "gfx1151"
assert int(target.attrs["thread_warp_size"]) == 32


def test_determine_target_adds_known_gfx12_thread_warp_size():
target = determine_target("hip -mcpu=gfx1200", return_object=True)
assert target_get_mcpu(target) == "gfx1200"
assert int(target.attrs["thread_warp_size"]) == 32


def test_auto_target_prefers_rocm_pytorch_over_cuda_toolkit(monkeypatch):
monkeypatch.setattr(target_utils.torch.version, "hip", "test", raising=False)
monkeypatch.setattr(target_utils, "check_hip_availability", lambda: True)
monkeypatch.setattr(target_utils, "check_cuda_availability", lambda: True)
monkeypatch.setattr(target_utils, "_detect_torch_rocm_arch", lambda: "gfx1151")

target = determine_target("auto", return_object=True)
assert target.kind.name == "hip"
assert target_get_mcpu(target) == "gfx1151"
assert int(target.attrs["thread_warp_size"]) == 32


def test_rdna_gfx1151_target_classification():
target = Target("hip -mcpu=gfx1151")
assert target_is_rdna(target)
assert not target_is_cdna(target)
assert target_get_rdna_generation(target) == 11
assert target_get_warp_size(target) == 32


def test_carver_routes_rdna_without_instantiating_device(monkeypatch):
import torch

monkeypatch.setattr(torch.version, "hip", None, raising=False)
monkeypatch.setattr(torch.cuda, "is_available", lambda: False)
if hasattr(torch, "mps"):
monkeypatch.setattr(torch.mps, "is_available", lambda: False, raising=False)

import tilelang.carver.arch as arch_mod

def fake_rdna(target):
return ("rdna", target)

monkeypatch.setattr(arch_mod, "RDNA", fake_rdna)
arch = arch_mod.get_arch(Target("hip -mcpu=gfx1151"))
assert arch[0] == "rdna"
assert target_get_mcpu(arch[1]) == "gfx1151"


def test_carver_rejects_unsupported_rdna_generations(monkeypatch):
import tilelang.carver.arch as arch_mod

def fake_cdna(target):
return ("cdna", target)

monkeypatch.setattr(arch_mod, "CDNA", fake_cdna)
with pytest.raises(ValueError, match="gfx11 targets only"):
arch_mod.get_arch(Target("hip -mcpu=gfx1200"))


def test_rdna_device_model_rejects_gfx12_before_device_probe():
from tilelang.carver.arch.rdna import RDNA

with pytest.raises(ValueError, match="gfx11 targets only"):
RDNA(Target("hip -mcpu=gfx1200"))


def test_rdna_tensor_instruction_lookup_is_generation_aware():
from tilelang.carver.arch.rdna import RDNA

arch = RDNA.__new__(RDNA)
arch.rdna_generation = 11
assert arch.get_avaliable_tensorintrin_shapes() == [[16, 16]]
assert isinstance(arch.available_tensor_instructions, list)

arch.rdna_generation = 12
with pytest.raises(ValueError, match="Unsupported RDNA generation"):
arch.get_avaliable_tensorintrin_shapes()
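
The new test file pins down the behaviour of the ROCm target helpers: normalize_rocm_arch strips feature suffixes such as :sramecc+:xnack- and rejects non-gfx strings, while rocm_warp_size_for_arch reports wave32 for RDNA (gfx10/11/12) parts and wave64 for CDNA. A minimal sketch consistent with those assertions — not the library's actual implementation — looks like this:

import re

def normalize_rocm_arch_sketch(arch: str):
    # "gfx1151:sramecc+:xnack-" -> "gfx1151"; "" or "sm_90" -> None
    match = re.match(r"^(gfx[0-9a-f]+)", arch or "")
    return match.group(1) if match else None

def rocm_warp_size_for_arch_sketch(arch: str) -> int:
    # RDNA parts (gfx10xx/gfx11xx/gfx12xx) default to wave32; CDNA is wave64.
    norm = normalize_rocm_arch_sketch(arch)
    return 32 if norm and norm.startswith(("gfx10", "gfx11", "gfx12")) else 64

assert normalize_rocm_arch_sketch("gfx1151:sramecc+:xnack-") == "gfx1151"
assert rocm_warp_size_for_arch_sketch("gfx942") == 64
assert rocm_warp_size_for_arch_sketch("gfx1200") == 32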
2 changes: 1 addition & 1 deletion tilelang/carver/__init__.py
@@ -11,5 +11,5 @@
) # noqa: F401
from .common_schedules import get_block, get_output_blocks, try_inline, try_inline_contiguous_spatial # noqa: F401
from .roller import *
-from .arch import CUDA, CDNA  # noqa: F401
+from .arch import CUDA, CDNA, RDNA  # noqa: F401
from .template import MatmulTemplate, GEMVTemplate, ElementwiseTemplate, GeneralReductionTemplate, FlashAttentionTemplate # noqa: F401
10 changes: 9 additions & 1 deletion tilelang/carver/arch/__init__.py
@@ -4,9 +4,11 @@
from .cuda import *
from .cpu import *
from .cdna import *
+from .rdna import *
from .metal import *
from tvm.target import Target
import torch
+from tilelang.utils.target import determine_target, target_get_rdna_generation, target_is_rdna


def get_arch(target: str | Target = "cuda") -> TileDevice:
@@ -18,6 +20,10 @@ def get_arch(target: str | Target = "cuda") -> TileDevice:
elif target.kind.name == "llvm":
return CPU(target)
elif target.kind.name == "hip":
+        if target_is_rdna(target):
+            if target_get_rdna_generation(target) == 11:
+                return RDNA(target)
+            raise ValueError(f"RDNA device model currently supports gfx11 targets only, got {target}.")
return CDNA(target)
elif target.kind.name == "metal":
return METAL(target)
@@ -29,7 +35,7 @@ def auto_infer_current_arch() -> TileDevice:
# TODO(lei): This is a temporary solution to infer the current architecture
# Can be replaced by a more sophisticated method in the future
if torch.version.hip is not None:
-        return get_arch("hip")
+        return get_arch(determine_target("auto", return_object=True))
if torch.cuda.is_available():
return get_arch("cuda")
elif torch.mps.is_available():
@@ -48,9 +54,11 @@
"is_tensorcore_supported_precision",
"has_mma_support",
"is_cdna_arch",
"is_rdna_arch",
"is_metal_arch",
"CUDA",
"CDNA",
"RDNA",
"METAL",
"CPU",
]
67 changes: 67 additions & 0 deletions tilelang/carver/arch/rdna.py
@@ -0,0 +1,67 @@
from __future__ import annotations
import tvm
from tvm.target import Target
from .arch_base import TileDevice
from .cuda import TensorInstruction
from tilelang.utils.target import target_get_mcpu, target_get_rdna_generation

_RDNA_DEFAULT_LDS_SIZE = 64 * 1024
_RDNA_TENSOR_INSTRUCTIONS = {
11: (TensorInstruction("wmma", [16, 16]),),
}


def _get_tensor_instructions_for_generation(rdna_generation: int) -> tuple[TensorInstruction, ...]:
try:
return _RDNA_TENSOR_INSTRUCTIONS[rdna_generation]
except KeyError as err:
raise ValueError(f"Unsupported RDNA generation for tensor instructions: {rdna_generation}") from err


def is_rdna_arch(arch: TileDevice) -> bool:
return isinstance(arch, RDNA)


class RDNA(TileDevice):
def __init__(self, target: Target | str):
if isinstance(target, str):
target = tvm.target.Target(target)
self.target = target
self.rdna_generation = target_get_rdna_generation(target)
if self.rdna_generation != 11:
arch = target_get_mcpu(target) or str(target)
raise ValueError(f"RDNA device model currently supports gfx11 targets only, got {arch}.")
device = tvm.runtime.rocm(0)
if not device.exist:
raise RuntimeError("Cannot find HIP device 0.")
self.device: tvm.runtime.Device = device
self.platform: str = "RDNA"

reported_smem = device.max_shared_memory_per_block
self.smem_cap = reported_smem if reported_smem > 0 else _RDNA_DEFAULT_LDS_SIZE
self.compute_max_core = device.multi_processor_count
self.warp_size = 32
self.compute_capability = device.compute_version.replace(".", "")
self.reg_cap: int = 32768
self.max_smem_usage: int = 2 * self.smem_cap
self.sm_partition: int = 4
self.l2_cache_size_bytes: int = getattr(target, "l2_cache_size_bytes", 0)
self.transaction_size: list[int] = [32, 128]

# Keep the same units as the existing CUDA/CDNA heuristic. Strix Halo
# is a UMA part, so use a conservative global-memory score seed.
self.bandwidth: list[int] = [750, 12080]
self.available_tensor_instructions: list[TensorInstruction] | None = None

def get_avaliable_tensorintrin_shapes(self):
self.available_tensor_instructions = list(_get_tensor_instructions_for_generation(self.rdna_generation))
return [t.shape for t in self.available_tensor_instructions]

def __repr__(self):
return f"RDNA({self.target})"


__all__ = [
"is_rdna_arch",
"RDNA",
]
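
With the class above registered in get_arch, downstream heuristics can ask the device model for wave size, LDS capacity, and WMMA tile shapes instead of assuming CDNA defaults. A usage sketch, assuming a ROCm-visible gfx11 (RDNA3) GPU so the constructor's device probe succeeds:

from tvm.target import Target
from tilelang.carver.arch import get_arch, is_rdna_arch

arch = get_arch(Target("hip -mcpu=gfx1151"))
if is_rdna_arch(arch):
    print(arch.warp_size)                            # 32 (wave32 on RDNA)
    print(arch.smem_cap)                             # LDS bytes reported by the device
    print(arch.get_avaliable_tensorintrin_shapes())  # [[16, 16]] WMMA tiles on gfx11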