12 changes: 7 additions & 5 deletions src/op/parallel.cc
@@ -82,11 +82,13 @@ void ParallelLoopNestVisitor::VisitStmt_(const ForNode *op) {
                          IterVar(Range(op->min, op->extent), op->loop_var,
                                  IterVarType::kOrdered));
   p->analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
-  auto reducer_info_map =
-      op->annotations.Get(attr::kReducerInfo)->as<Map<Var, ReducerInfo>>();
-  if (reducer_info_map) {
-    for (auto &&[buffer, info] : reducer_info_map.value())
-      p->reducer_info_map_.Set(buffer, info);
+  if (auto reducer_info_anno = op->annotations.Get(attr::kReducerInfo)) {
+    auto reducer_info_map =
+        reducer_info_anno.value().as<Map<Var, ReducerInfo>>();
+    if (reducer_info_map) {
+      for (auto &&[buffer, info] : reducer_info_map.value())
+        p->reducer_info_map_.Set(buffer, info);
+    }
   }
   StmtExprVisitor::VisitStmt_(op);
 }
50 changes: 50 additions & 0 deletions src/tl_templates/hip/common.h
@@ -111,6 +111,56 @@ TL_DEVICE unsigned __pack_bfloat162(const bfloat16_t x, const bfloat16_t y) {
  return (v1 << 16) | v0;
}

TL_DEVICE int tl_dp4a_fallback(const int a, const int b, int c) {
#pragma unroll
  for (int i = 0; i < 4; ++i) {
    const int ai = static_cast<int8_t>((a >> (8 * i)) & 0xff);
    const int bi = static_cast<int8_t>((b >> (8 * i)) & 0xff);
    c += ai * bi;
  }
  return c;
}

#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) ||   \
    defined(__gfx1103__) || defined(__gfx1150__) || defined(__gfx1151__)
#define TL_AMDGPU_HAS_SUDOT4 1
#endif

#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) ||      \
    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) ||      \
    defined(__gfx950__) || defined(__gfx1011__) || defined(__gfx1012__) ||    \
    defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) ||   \
    defined(__gfx1034__) || defined(__gfx1035__) ||                           \
    defined(TL_AMDGPU_HAS_SUDOT4)
#define TL_AMDGPU_HAS_SDOT4 1
#endif

TL_DEVICE int tl_dp4a(const int a, const int b, const int c) {
#if defined(TL_AMDGPU_HAS_SUDOT4)
  return __builtin_amdgcn_sudot4(true, a, true, b, c, false);
#elif defined(TL_AMDGPU_HAS_SDOT4)
  return __builtin_amdgcn_sdot4(a, b, c, false);
#else
  return tl_dp4a_fallback(a, b, c);
#endif
}

template <typename InDatatype, typename OutDatatype>
TL_DEVICE void DP4A(const InDatatype *a, const InDatatype *b, OutDatatype *c) {
  static_assert(sizeof(InDatatype) == 1,
                "DP4A expects a pointer to packed int8 lanes");
  static_assert(sizeof(OutDatatype) == sizeof(int),
                "DP4A expects 4-byte accumulator/output type");
  int a_int;
  int b_int;
  int c_int;
  __builtin_memcpy(&a_int, a, sizeof(a_int));
  __builtin_memcpy(&b_int, b, sizeof(b_int));
  __builtin_memcpy(&c_int, c, sizeof(c_int));
  const int out = tl_dp4a(a_int, b_int, c_int);
  __builtin_memcpy(c, &out, sizeof(out));
}

// __habs overloads for hip_bfloat16 and float16_t to resolve ambiguity on ROCm.
// hip_bfloat16 != __hip_bfloat16, and float16_t != __half, so the standard
// __habs overloads don't match exactly, causing ambiguous overload errors.
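For intuition: tl_dp4a computes a four-lane signed int8 dot product accumulated into a 32-bit integer, using __builtin_amdgcn_sudot4 or __builtin_amdgcn_sdot4 where the target supports them and the scalar fallback otherwise. A minimal Python sketch of those semantics (illustrative reference only, not part of the diff; assumes the little-endian lane packing used by the fallback above):

def dp4a_reference(a: int, b: int, c: int) -> int:
    # Treat a and b as four packed signed 8-bit lanes; accumulate into c.
    for i in range(4):
        ai = (a >> (8 * i)) & 0xFF
        bi = (b >> (8 * i)) & 0xFF
        ai -= 256 if ai >= 128 else 0  # reinterpret byte as signed int8
        bi -= 256 if bi >= 128 else 0
        c += ai * bi
    return c

assert dp4a_reference(0x01020304, 0x01010101, 0) == 1 + 2 + 3 + 4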
8 changes: 5 additions & 3 deletions src/transform/layout_reducer.cc
@@ -128,9 +128,11 @@ class ReducerLayoutAnnotator : public IRMutatorWithAnalyzer {
   auto result = IRMutatorWithAnalyzer::VisitStmt_(op).as<Block>().value();
   // After iterating over the body, set all layout_map to block
   auto p_result = result.CopyOnWrite();
-  auto layout_map = p_result->annotations.Get(attr::kLayoutMap)
-                        ->as<Map<Var, Layout>>()
-                        .value_or(Map<Var, Layout>());
+  Map<Var, Layout> layout_map;
+  if (auto opt_layout_map = p_result->annotations.Get(attr::kLayoutMap)) {
+    layout_map =
+        opt_layout_map.value().as<Map<Var, Layout>>().value_or(layout_map);
+  }
   for (auto &&[k, v] : new_layout_map_)
     layout_map.Set(k, v);
   if (!layout_map.empty())
2 changes: 1 addition & 1 deletion testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py
@@ -227,7 +227,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype=T.flo
         (128, 256, 256, determine_fp8_type(), T.float32, T.float32, False, False, 2),
     ],
 )
-@tilelang.testing.requires_rocm
+@tilelang.testing.requires_cdna
 def test_assert_tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed, k_pack):
     assert_tl_matmul_correctness(
         M,
2 changes: 1 addition & 1 deletion testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py
@@ -356,7 +356,7 @@ def assert_tl_matmul_correctness(
         (256, 256, 512, determine_fp8_type(), T.float32, T.float32, False, True, 1, True, True, False),
     ],
 )
-@tilelang.testing.requires_rocm
+@tilelang.testing.requires_cdna
 def test_assert_tl_matmul(
     M,
     N,
2 changes: 1 addition & 1 deletion testing/python/amd/test_tilelang_gfx950_copy_async.py
@@ -231,7 +231,7 @@ def ref_program(A, B):
         (True, False),
     ],
 )
-@tilelang.testing.requires_rocm
+@tilelang.testing.requires_gfx950
 def test_gfx950_copy_async_gemm_no_pipeline(trans_A, trans_B):
     """Non-pipelined GEMM (num_stages=0) must also produce correct results."""
     prog = _matmul_kernel(
4 changes: 4 additions & 0 deletions testing/python/kernel/test_tilelang_kernel_gemm.py
@@ -153,6 +153,7 @@ def test_gemm_bf16bf16f32_nn():
 )


+@tilelang.testing.requires_cuda_or_cdna
 def test_gemm_f32f32f32_nn():
     run_gemm(
         512,
@@ -205,10 +206,12 @@ def test_gemm_f16f16f16_nt():
 )


+@tilelang.testing.requires_cuda_or_cdna
 def test_gemm_i8i8i32_nt():
     run_gemm(512, 1024, 768, False, True, T.int8, T.int8, T.int32, 128, 128, 64)


+@tilelang.testing.requires_cuda_or_cdna
 def test_gemm_i8i8i32_tn():
     run_gemm(512, 1024, 768, True, False, T.int8, T.int8, T.int32, 128, 128, 64)

@@ -218,6 +221,7 @@ def test_gemm_f64f64f64_nt():
     run_gemm(512, 512, 512, False, True, T.float64, T.float64, T.float64, 64, 32, 16)


+@tilelang.testing.requires_cuda_or_cdna
 def test_gemm_f32f32f32_nt():
     run_gemm(
         512,
114 changes: 114 additions & 0 deletions testing/python/target/test_tilelang_rocm_target.py
@@ -0,0 +1,114 @@
import pytest

from tilelang import tvm as tvm
from tvm.target import Target

import tilelang.utils.target as target_utils
from tilelang.utils.target import (
    determine_target,
    normalize_rocm_arch,
    rocm_warp_size_for_arch,
    target_get_mcpu,
    target_get_rdna_generation,
    target_get_warp_size,
    target_is_cdna,
    target_is_rdna,
)


def test_normalize_rocm_arch_strips_feature_suffix():
    assert normalize_rocm_arch("gfx1151:sramecc+:xnack-") == "gfx1151"
    assert normalize_rocm_arch("gfx942") == "gfx942"
    assert normalize_rocm_arch("") is None
    assert normalize_rocm_arch("sm_90") is None
    assert rocm_warp_size_for_arch("gfx1151") == 32
    assert rocm_warp_size_for_arch("gfx1030") == 32
    assert rocm_warp_size_for_arch("gfx1200") == 32
    assert rocm_warp_size_for_arch("gfx942") == 64


def test_target_mcpu_helpers():
    target = Target("hip -mcpu=gfx1151:sramecc+:xnack-")
    assert target_get_mcpu(target) == "gfx1151"


def test_determine_target_adds_rdna_thread_warp_size():
    target = determine_target("hip -mcpu=gfx1151", return_object=True)
    assert target_get_mcpu(target) == "gfx1151"
    assert int(target.attrs["thread_warp_size"]) == 32


def test_determine_target_adds_known_gfx12_thread_warp_size():
    target = determine_target("hip -mcpu=gfx1200", return_object=True)
    assert target_get_mcpu(target) == "gfx1200"
    assert int(target.attrs["thread_warp_size"]) == 32


def test_auto_target_prefers_rocm_pytorch_over_cuda_toolkit(monkeypatch):
    monkeypatch.setattr(target_utils.torch.version, "hip", "test", raising=False)
    monkeypatch.setattr(target_utils, "check_hip_availability", lambda: True)
    monkeypatch.setattr(target_utils, "check_cuda_availability", lambda: True)
    monkeypatch.setattr(target_utils, "_detect_torch_rocm_arch", lambda: "gfx1151")

    target = determine_target("auto", return_object=True)
    assert target.kind.name == "hip"
    assert target_get_mcpu(target) == "gfx1151"
    assert int(target.attrs["thread_warp_size"]) == 32


def test_rdna_gfx1151_target_classification():
    target = Target("hip -mcpu=gfx1151")
    assert target_is_rdna(target)
    assert not target_is_cdna(target)
    assert target_get_rdna_generation(target) == 11
    assert target_get_warp_size(target) == 32


def test_carver_routes_rdna_without_instantiating_device(monkeypatch):
    import torch

    monkeypatch.setattr(torch.version, "hip", None, raising=False)
    monkeypatch.setattr(torch.cuda, "is_available", lambda: False)
    if hasattr(torch, "mps"):
        monkeypatch.setattr(torch.mps, "is_available", lambda: False, raising=False)

    import tilelang.carver.arch as arch_mod

    def fake_rdna(target):
        return ("rdna", target)

    monkeypatch.setattr(arch_mod, "RDNA", fake_rdna)
    arch = arch_mod.get_arch(Target("hip -mcpu=gfx1151"))
    assert arch[0] == "rdna"
    assert target_get_mcpu(arch[1]) == "gfx1151"


def test_carver_rejects_unsupported_rdna_generations(monkeypatch):
    import tilelang.carver.arch as arch_mod

    def fake_cdna(target):
        return ("cdna", target)

    monkeypatch.setattr(arch_mod, "CDNA", fake_cdna)
    with pytest.raises(ValueError, match="gfx11 targets only"):
        arch_mod.get_arch(Target("hip -mcpu=gfx1200"))


def test_rdna_device_model_rejects_gfx12_before_device_probe():
    from tilelang.carver.arch.rdna import RDNA

    with pytest.raises(ValueError, match="gfx11 targets only"):
        RDNA(Target("hip -mcpu=gfx1200"))


def test_rdna_tensor_instruction_lookup_is_generation_aware():
    from tilelang.carver.arch.rdna import RDNA

    arch = RDNA.__new__(RDNA)
    arch.rdna_generation = 11
    assert arch.get_avaliable_tensorintrin_shapes() == [[16, 16]]
    assert isinstance(arch.available_tensor_instructions, list)

    arch.rdna_generation = 12
    with pytest.raises(ValueError, match="Unsupported RDNA generation"):
        arch.get_avaliable_tensorintrin_shapes()
2 changes: 1 addition & 1 deletion tilelang/carver/__init__.py
@@ -11,5 +11,5 @@
 )  # noqa: F401
 from .common_schedules import get_block, get_output_blocks, try_inline, try_inline_contiguous_spatial  # noqa: F401
 from .roller import *
-from .arch import CUDA, CDNA  # noqa: F401
+from .arch import CUDA, CDNA, RDNA  # noqa: F401
 from .template import MatmulTemplate, GEMVTemplate, ElementwiseTemplate, GeneralReductionTemplate, FlashAttentionTemplate  # noqa: F401
10 changes: 9 additions & 1 deletion tilelang/carver/arch/__init__.py
@@ -4,9 +4,11 @@
 from .cuda import *
 from .cpu import *
 from .cdna import *
+from .rdna import *
 from .metal import *
 from tvm.target import Target
 import torch
+from tilelang.utils.target import determine_target, target_get_rdna_generation, target_is_rdna


 def get_arch(target: str | Target = "cuda") -> TileDevice:
@@ -18,6 +20,10 @@ def get_arch(target: str | Target = "cuda") -> TileDevice:
elif target.kind.name == "llvm":
return CPU(target)
elif target.kind.name == "hip":
if target_is_rdna(target):
if target_get_rdna_generation(target) == 11:
return RDNA(target)
raise ValueError(f"RDNA device model currently supports gfx11 targets only, got {target}.")
return CDNA(target)
elif target.kind.name == "metal":
return METAL(target)
@@ -29,7 +35,7 @@ def auto_infer_current_arch() -> TileDevice:
     # TODO(lei): This is a temporary solution to infer the current architecture
     # Can be replaced by a more sophisticated method in the future
     if torch.version.hip is not None:
-        return get_arch("hip")
+        return get_arch(determine_target("auto", return_object=True))
     if torch.cuda.is_available():
         return get_arch("cuda")
     elif torch.mps.is_available():
@@ -48,9 +54,11 @@ def auto_infer_current_arch() -> TileDevice:
"is_tensorcore_supported_precision",
"has_mma_support",
"is_cdna_arch",
"is_rdna_arch",
"is_metal_arch",
"CUDA",
"CDNA",
"RDNA",
"METAL",
"CPU",
]
67 changes: 67 additions & 0 deletions tilelang/carver/arch/rdna.py
@@ -0,0 +1,67 @@
from __future__ import annotations
import tvm
from tvm.target import Target
from .arch_base import TileDevice
from .cuda import TensorInstruction
from tilelang.utils.target import target_get_mcpu, target_get_rdna_generation

_RDNA_DEFAULT_LDS_SIZE = 64 * 1024
_RDNA_TENSOR_INSTRUCTIONS = {
    11: (TensorInstruction("wmma", [16, 16]),),
}


def _get_tensor_instructions_for_generation(rdna_generation: int) -> tuple[TensorInstruction, ...]:
    try:
        return _RDNA_TENSOR_INSTRUCTIONS[rdna_generation]
    except KeyError as err:
        raise ValueError(f"Unsupported RDNA generation for tensor instructions: {rdna_generation}") from err


def is_rdna_arch(arch: TileDevice) -> bool:
    return isinstance(arch, RDNA)


class RDNA(TileDevice):
    def __init__(self, target: Target | str):
        if isinstance(target, str):
            target = tvm.target.Target(target)
        self.target = target
        self.rdna_generation = target_get_rdna_generation(target)
        if self.rdna_generation != 11:
            arch = target_get_mcpu(target) or str(target)
            raise ValueError(f"RDNA device model currently supports gfx11 targets only, got {arch}.")
        device = tvm.runtime.rocm(0)
        if not device.exist:
            raise RuntimeError("Cannot find HIP device 0.")
        self.device: tvm.runtime.Device = device
        self.platform: str = "RDNA"

        reported_smem = device.max_shared_memory_per_block
        self.smem_cap = reported_smem if reported_smem > 0 else _RDNA_DEFAULT_LDS_SIZE
        self.compute_max_core = device.multi_processor_count
        self.warp_size = 32
        self.compute_capability = device.compute_version.replace(".", "")
        self.reg_cap: int = 32768
        self.max_smem_usage: int = 2 * self.smem_cap
        self.sm_partition: int = 4
        self.l2_cache_size_bytes: int = getattr(target, "l2_cache_size_bytes", 0)
        self.transaction_size: list[int] = [32, 128]

        # Keep the same units as the existing CUDA/CDNA heuristic. Strix Halo
        # is a UMA part, so use a conservative global-memory score seed.
        self.bandwidth: list[int] = [750, 12080]
        self.available_tensor_instructions: list[TensorInstruction] | None = None

    def get_avaliable_tensorintrin_shapes(self):
        self.available_tensor_instructions = list(_get_tensor_instructions_for_generation(self.rdna_generation))
        return [t.shape for t in self.available_tensor_instructions]

    def __repr__(self):
        return f"RDNA({self.target})"


__all__ = [
    "is_rdna_arch",
    "RDNA",
]
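A minimal usage sketch for the new RDNA path (assumes a ROCm-enabled TVM build with a visible gfx11 device, since RDNA.__init__ probes tvm.runtime.rocm(0)):

from tvm.target import Target
from tilelang.carver.arch import get_arch, is_rdna_arch

arch = get_arch(Target("hip -mcpu=gfx1151"))  # gfx11 routes to RDNA
assert is_rdna_arch(arch)
assert arch.warp_size == 32
print(arch.get_avaliable_tensorintrin_shapes())  # [[16, 16]] for RDNA3 WMMA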