diff --git a/src/tl_templates/hip/common.h b/src/tl_templates/hip/common.h
index 49c5b6c1e4..c8ee23ee2a 100644
--- a/src/tl_templates/hip/common.h
+++ b/src/tl_templates/hip/common.h
@@ -111,6 +111,56 @@ TL_DEVICE unsigned __pack_bfloat162(const bfloat16_t x, const bfloat16_t y) {
   return (v1 << 16) | v0;
 }
 
+TL_DEVICE int tl_dp4a_fallback(const int a, const int b, int c) {
+#pragma unroll
+  for (int i = 0; i < 4; ++i) {
+    const int ai = static_cast<int8_t>((a >> (8 * i)) & 0xff);
+    const int bi = static_cast<int8_t>((b >> (8 * i)) & 0xff);
+    c += ai * bi;
+  }
+  return c;
+}
+
+#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \
+    defined(__gfx1103__) || defined(__gfx1150__) || defined(__gfx1151__)
+#define TL_AMDGPU_HAS_SUDOT4 1
+#endif
+
+#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || \
+    defined(__gfx950__) || defined(__gfx1011__) || defined(__gfx1012__) || \
+    defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \
+    defined(__gfx1034__) || defined(__gfx1035__) || \
+    defined(TL_AMDGPU_HAS_SUDOT4)
+#define TL_AMDGPU_HAS_SDOT4 1
+#endif
+
+TL_DEVICE int tl_dp4a(const int a, const int b, const int c) {
+#if defined(TL_AMDGPU_HAS_SUDOT4)
+  return __builtin_amdgcn_sudot4(true, a, true, b, c, false);
+#elif defined(TL_AMDGPU_HAS_SDOT4)
+  return __builtin_amdgcn_sdot4(a, b, c, false);
+#else
+  return tl_dp4a_fallback(a, b, c);
+#endif
+}
+
+template <typename InDatatype, typename OutDatatype>
+TL_DEVICE void DP4A(const InDatatype *a, const InDatatype *b, OutDatatype *c) {
+  static_assert(sizeof(InDatatype) == 1,
+                "DP4A expects a pointer to packed int8 lanes");
+  static_assert(sizeof(OutDatatype) == sizeof(int),
+                "DP4A expects 4-byte accumulator/output type");
+  int a_int;
+  int b_int;
+  int c_int;
+  __builtin_memcpy(&a_int, a, sizeof(a_int));
+  __builtin_memcpy(&b_int, b, sizeof(b_int));
+  __builtin_memcpy(&c_int, c, sizeof(c_int));
+  const int out = tl_dp4a(a_int, b_int, c_int);
+  __builtin_memcpy(c, &out, sizeof(out));
+}
+
 // __habs overloads for hip_bfloat16 and float16_t to resolve ambiguity on ROCm.
 // hip_bfloat16 != __hip_bfloat16, and float16_t != __half, so the standard
 // __habs overloads don't match exactly, causing ambiguous overload errors.
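(Aside, not part of the patch: a minimal host-side sketch of the byte-wise semantics that `tl_dp4a_fallback` and `DP4A` accumulate — four signed int8 lanes multiplied pairwise and added to a 32-bit accumulator. The helper name `dp4a_ref` is hypothetical and exists only for this cross-check.)

```python
# Reference model of the dp4a fallback: unpack four signed int8 lanes from each
# 32-bit operand, multiply pairwise, and add to the accumulator.
import struct


def dp4a_ref(a: int, b: int, c: int) -> int:
    a_lanes = struct.unpack("4b", struct.pack("<i", a))
    b_lanes = struct.unpack("4b", struct.pack("<i", b))
    return c + sum(x * y for x, y in zip(a_lanes, b_lanes))


# 1*1 + 2*1 + 3*1 + 4*1 + 10 == 20
assert dp4a_ref(0x01020304, 0x01010101, 10) == 20
```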
diff --git a/testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py b/testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py
index 00fac1a3a3..e6a545b1ec 100644
--- a/testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py
+++ b/testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py
@@ -227,7 +227,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype=T.flo
         (128, 256, 256, determine_fp8_type(), T.float32, T.float32, False, False, 2),
     ],
 )
-@tilelang.testing.requires_rocm
+@tilelang.testing.requires_cdna
 def test_assert_tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed, k_pack):
     assert_tl_matmul_correctness(
         M,
diff --git a/testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py b/testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py
index 864ac58c7b..81ed8d4883 100644
--- a/testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py
+++ b/testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py
@@ -356,7 +356,7 @@ def assert_tl_matmul_correctness(
         (256, 256, 512, determine_fp8_type(), T.float32, T.float32, False, True, 1, True, True, False),
     ],
 )
-@tilelang.testing.requires_rocm
+@tilelang.testing.requires_cdna
 def test_assert_tl_matmul(
     M,
     N,
diff --git a/testing/python/amd/test_tilelang_gfx950_copy_async.py b/testing/python/amd/test_tilelang_gfx950_copy_async.py
index 18935db938..53368fcae0 100644
--- a/testing/python/amd/test_tilelang_gfx950_copy_async.py
+++ b/testing/python/amd/test_tilelang_gfx950_copy_async.py
@@ -231,7 +231,7 @@ def ref_program(A, B):
         (True, False),
     ],
 )
-@tilelang.testing.requires_rocm
+@tilelang.testing.requires_gfx950
 def test_gfx950_copy_async_gemm_no_pipeline(trans_A, trans_B):
     """Non-pipelined GEMM (num_stages=0) must also produce correct results."""
     prog = _matmul_kernel(
diff --git a/testing/python/kernel/test_tilelang_kernel_gemm.py b/testing/python/kernel/test_tilelang_kernel_gemm.py
index f6a412f147..724f542e7c 100644
--- a/testing/python/kernel/test_tilelang_kernel_gemm.py
+++ b/testing/python/kernel/test_tilelang_kernel_gemm.py
@@ -153,6 +153,7 @@ def test_gemm_bf16bf16f32_nn():
     )
 
 
+@tilelang.testing.requires_cuda_or_cdna
 def test_gemm_f32f32f32_nn():
     run_gemm(
         512,
@@ -205,10 +206,12 @@ def test_gemm_f16f16f16_nt():
     )
 
 
+@tilelang.testing.requires_cuda_or_cdna
 def test_gemm_i8i8i32_nt():
     run_gemm(512, 1024, 768, False, True, T.int8, T.int8, T.int32, 128, 128, 64)
 
 
+@tilelang.testing.requires_cuda_or_cdna
 def test_gemm_i8i8i32_tn():
     run_gemm(512, 1024, 768, True, False, T.int8, T.int8, T.int32, 128, 128, 64)
 
@@ -218,6 +221,7 @@ def test_gemm_f64f64f64_nt():
     run_gemm(512, 512, 512, False, True, T.float64, T.float64, T.float64, 64, 32, 16)
 
 
+@tilelang.testing.requires_cuda_or_cdna
 def test_gemm_f32f32f32_nt():
     run_gemm(
         512,
diff --git a/testing/python/target/test_tilelang_rocm_target.py b/testing/python/target/test_tilelang_rocm_target.py
new file mode 100644
index 0000000000..538b63c3aa
--- /dev/null
+++ b/testing/python/target/test_tilelang_rocm_target.py
@@ -0,0 +1,114 @@
+import pytest
+
+from tilelang import tvm as tvm
+from tvm.target import Target
+
+import tilelang.utils.target as target_utils
+from tilelang.utils.target import (
+    determine_target,
+    normalize_rocm_arch,
+    rocm_warp_size_for_arch,
+    target_get_mcpu,
+    target_get_rdna_generation,
+    target_get_warp_size,
+    target_is_cdna,
+    target_is_rdna,
+)
+
+
+def test_normalize_rocm_arch_strips_feature_suffix():
+    assert normalize_rocm_arch("gfx1151:sramecc+:xnack-") == "gfx1151"
+    assert normalize_rocm_arch("gfx942") == "gfx942"
+    assert normalize_rocm_arch("") is None
normalize_rocm_arch("") is None + assert normalize_rocm_arch("sm_90") is None + assert rocm_warp_size_for_arch("gfx1151") == 32 + assert rocm_warp_size_for_arch("gfx1030") == 32 + assert rocm_warp_size_for_arch("gfx1200") == 32 + assert rocm_warp_size_for_arch("gfx942") == 64 + + +def test_target_mcpu_helpers(): + target = Target("hip -mcpu=gfx1151:sramecc+:xnack-") + assert target_get_mcpu(target) == "gfx1151" + + +def test_determine_target_adds_rdna_thread_warp_size(): + target = determine_target("hip -mcpu=gfx1151", return_object=True) + assert target_get_mcpu(target) == "gfx1151" + assert int(target.attrs["thread_warp_size"]) == 32 + + +def test_determine_target_adds_known_gfx12_thread_warp_size(): + target = determine_target("hip -mcpu=gfx1200", return_object=True) + assert target_get_mcpu(target) == "gfx1200" + assert int(target.attrs["thread_warp_size"]) == 32 + + +def test_auto_target_prefers_rocm_pytorch_over_cuda_toolkit(monkeypatch): + monkeypatch.setattr(target_utils.torch.version, "hip", "test", raising=False) + monkeypatch.setattr(target_utils, "check_hip_availability", lambda: True) + monkeypatch.setattr(target_utils, "check_cuda_availability", lambda: True) + monkeypatch.setattr(target_utils, "_detect_torch_rocm_arch", lambda: "gfx1151") + + target = determine_target("auto", return_object=True) + assert target.kind.name == "hip" + assert target_get_mcpu(target) == "gfx1151" + assert int(target.attrs["thread_warp_size"]) == 32 + + +def test_rdna_gfx1151_target_classification(): + target = Target("hip -mcpu=gfx1151") + assert target_is_rdna(target) + assert not target_is_cdna(target) + assert target_get_rdna_generation(target) == 11 + assert target_get_warp_size(target) == 32 + + +def test_carver_routes_rdna_without_instantiating_device(monkeypatch): + import torch + + monkeypatch.setattr(torch.version, "hip", None, raising=False) + monkeypatch.setattr(torch.cuda, "is_available", lambda: False) + if hasattr(torch, "mps"): + monkeypatch.setattr(torch.mps, "is_available", lambda: False, raising=False) + + import tilelang.carver.arch as arch_mod + + def fake_rdna(target): + return ("rdna", target) + + monkeypatch.setattr(arch_mod, "RDNA", fake_rdna) + arch = arch_mod.get_arch(Target("hip -mcpu=gfx1151")) + assert arch[0] == "rdna" + assert target_get_mcpu(arch[1]) == "gfx1151" + + +def test_carver_rejects_unsupported_rdna_generations(monkeypatch): + import tilelang.carver.arch as arch_mod + + def fake_cdna(target): + return ("cdna", target) + + monkeypatch.setattr(arch_mod, "CDNA", fake_cdna) + with pytest.raises(ValueError, match="gfx11 targets only"): + arch_mod.get_arch(Target("hip -mcpu=gfx1200")) + + +def test_rdna_device_model_rejects_gfx12_before_device_probe(): + from tilelang.carver.arch.rdna import RDNA + + with pytest.raises(ValueError, match="gfx11 targets only"): + RDNA(Target("hip -mcpu=gfx1200")) + + +def test_rdna_tensor_instruction_lookup_is_generation_aware(): + from tilelang.carver.arch.rdna import RDNA + + arch = RDNA.__new__(RDNA) + arch.rdna_generation = 11 + assert arch.get_avaliable_tensorintrin_shapes() == [[16, 16]] + assert isinstance(arch.available_tensor_instructions, list) + + arch.rdna_generation = 12 + with pytest.raises(ValueError, match="Unsupported RDNA generation"): + arch.get_avaliable_tensorintrin_shapes() diff --git a/tilelang/carver/__init__.py b/tilelang/carver/__init__.py index f1dfc5b475..010f02a3f8 100644 --- a/tilelang/carver/__init__.py +++ b/tilelang/carver/__init__.py @@ -11,5 +11,5 @@ ) # noqa: F401 from .common_schedules import 
 from .roller import *
-from .arch import CUDA, CDNA  # noqa: F401
+from .arch import CUDA, CDNA, RDNA  # noqa: F401
 from .template import MatmulTemplate, GEMVTemplate, ElementwiseTemplate, GeneralReductionTemplate, FlashAttentionTemplate  # noqa: F401
diff --git a/tilelang/carver/arch/__init__.py b/tilelang/carver/arch/__init__.py
index b6cb9e72f7..34601011b8 100644
--- a/tilelang/carver/arch/__init__.py
+++ b/tilelang/carver/arch/__init__.py
@@ -4,9 +4,11 @@
 from .cuda import *
 from .cpu import *
 from .cdna import *
+from .rdna import *
 from .metal import *
 from tvm.target import Target
 import torch
+from tilelang.utils.target import determine_target, target_get_rdna_generation, target_is_rdna
 
 
 def get_arch(target: str | Target = "cuda") -> TileDevice:
@@ -18,6 +20,10 @@ def get_arch(target: str | Target = "cuda") -> TileDevice:
     elif target.kind.name == "llvm":
         return CPU(target)
     elif target.kind.name == "hip":
+        if target_is_rdna(target):
+            if target_get_rdna_generation(target) == 11:
+                return RDNA(target)
+            raise ValueError(f"RDNA device model currently supports gfx11 targets only, got {target}.")
         return CDNA(target)
     elif target.kind.name == "metal":
         return METAL(target)
@@ -29,7 +35,7 @@ def auto_infer_current_arch() -> TileDevice:
     # TODO(lei): This is a temporary solution to infer the current architecture
     # Can be replaced by a more sophisticated method in the future
     if torch.version.hip is not None:
-        return get_arch("hip")
+        return get_arch(determine_target("auto", return_object=True))
     if torch.cuda.is_available():
         return get_arch("cuda")
     elif torch.mps.is_available():
@@ -48,9 +54,11 @@ def auto_infer_current_arch() -> TileDevice:
     "is_tensorcore_supported_precision",
     "has_mma_support",
     "is_cdna_arch",
+    "is_rdna_arch",
     "is_metal_arch",
     "CUDA",
     "CDNA",
+    "RDNA",
     "METAL",
     "CPU",
 ]
diff --git a/tilelang/carver/arch/rdna.py b/tilelang/carver/arch/rdna.py
new file mode 100644
index 0000000000..db3971e875
--- /dev/null
+++ b/tilelang/carver/arch/rdna.py
@@ -0,0 +1,67 @@
+from __future__ import annotations
+import tvm
+from tvm.target import Target
+from .arch_base import TileDevice
+from .cuda import TensorInstruction
+from tilelang.utils.target import target_get_mcpu, target_get_rdna_generation
+
+_RDNA_DEFAULT_LDS_SIZE = 64 * 1024
+_RDNA_TENSOR_INSTRUCTIONS = {
+    11: (TensorInstruction("wmma", [16, 16]),),
+}
+
+
+def _get_tensor_instructions_for_generation(rdna_generation: int) -> tuple[TensorInstruction, ...]:
+    try:
+        return _RDNA_TENSOR_INSTRUCTIONS[rdna_generation]
+    except KeyError as err:
+        raise ValueError(f"Unsupported RDNA generation for tensor instructions: {rdna_generation}") from err
+
+
+def is_rdna_arch(arch: TileDevice) -> bool:
+    return isinstance(arch, RDNA)
+
+
+class RDNA(TileDevice):
+    def __init__(self, target: Target | str):
+        if isinstance(target, str):
+            target = tvm.target.Target(target)
+        self.target = target
+        self.rdna_generation = target_get_rdna_generation(target)
+        if self.rdna_generation != 11:
+            arch = target_get_mcpu(target) or str(target)
+            raise ValueError(f"RDNA device model currently supports gfx11 targets only, got {arch}.")
+        device = tvm.runtime.rocm(0)
+        if not device.exist:
+            raise RuntimeError("Cannot find HIP device 0.")
+        self.device: tvm.runtime.Device = device
+        self.platform: str = "RDNA"
+
+        reported_smem = device.max_shared_memory_per_block
+        self.smem_cap = reported_smem if reported_smem > 0 else _RDNA_DEFAULT_LDS_SIZE
+        self.compute_max_core = device.multi_processor_count
+        self.warp_size = 32
+        self.compute_capability = device.compute_version.replace(".", "")
+        self.reg_cap: int = 32768
+        self.max_smem_usage: int = 2 * self.smem_cap
+        self.sm_partition: int = 4
+        self.l2_cache_size_bytes: int = getattr(target, "l2_cache_size_bytes", 0)
+        self.transaction_size: list[int] = [32, 128]
+
+        # Keep the same units as the existing CUDA/CDNA heuristic. Strix Halo
+        # is a UMA part, so use a conservative global-memory score seed.
+        self.bandwidth: list[int] = [750, 12080]
+        self.available_tensor_instructions: list[TensorInstruction] | None = None
+
+    def get_avaliable_tensorintrin_shapes(self):
+        self.available_tensor_instructions = list(_get_tensor_instructions_for_generation(self.rdna_generation))
+        return [t.shape for t in self.available_tensor_instructions]
+
+    def __repr__(self):
+        return f"RDNA({self.target})"
+
+
+__all__ = [
+    "is_rdna_arch",
+    "RDNA",
+]
diff --git a/tilelang/carver/roller/policy/tensorcore.py b/tilelang/carver/roller/policy/tensorcore.py
index 86c79ea732..6bbc85a805 100644
--- a/tilelang/carver/roller/policy/tensorcore.py
+++ b/tilelang/carver/roller/policy/tensorcore.py
@@ -9,6 +9,7 @@
 from .common import coalesced_factor, factorize, get_all_factors
 from .default import DefaultPolicy
 from ..rasterization import NoRasterization, Rasterization2DColumn
+from ...arch import is_rdna_arch
 
 logger = logging.getLogger(__name__)
 
@@ -212,6 +213,13 @@ def check_tile_shape_isvalid(self, td: TileDict):
                 return False
         return super().check_tile_shape_isvalid(td)
 
+    def score_block_size(self, n):
+        base_score = super().score_block_size(n)
+        if is_rdna_arch(self.arch):
+            warps = (n + self.arch.warp_size - 1) // self.arch.warp_size
+            return (0 if warps == 8 else 1, abs(warps - 8), *base_score)
+        return base_score
+
     def _can_implement_layout(self, node: PrimFuncNode, td: TileDict):
         # Not implemented yet
         # This function is used to check whether we can implement swizzling
@@ -256,8 +264,11 @@ def _assign_block_size(self, node: PrimFuncNode, td: TileDict, block_size: int):
             if tile[ax_m] < wmma_tile[ax_m] or tile[ax_n] < wmma_tile[ax_n]:
                 # allow pad, otherwise, we can not get a valid tile shape
                 return None
 
+        space_prod = int(np.prod(space))
+        if space_prod < warps or space_prod % warps != 0:
+            return None
-        factors = factorize(np.prod(space) // warps)
+        factors = factorize(space_prod // warps)
 
         def _score(node, warp_tile):  # small is better
             score = 0
diff --git a/tilelang/carver/template/base.py b/tilelang/carver/template/base.py
index 4a699fbc7d..370d98d7bb 100644
--- a/tilelang/carver/template/base.py
+++ b/tilelang/carver/template/base.py
@@ -6,6 +6,7 @@
     is_volta_arch,
     is_ampere_arch,
     is_cdna_arch,
+    is_rdna_arch,
     auto_infer_current_arch,
 )
 from ..roller.hint import Hint  # Import the Hint class
@@ -95,6 +96,15 @@ def is_cdna_arch(self) -> bool:
         """
         return is_cdna_arch(self._arch) if self._arch is not None else False
 
+    def is_rdna_arch(self) -> bool:
+        """
+        Checks if the current architecture is an RDNA architecture.
+
+        Returns:
+            bool: True if the architecture is RDNA, False otherwise.
+        """
+        return is_rdna_arch(self._arch) if self._arch is not None else False
+
     def equivalent_function(self) -> PrimFunc:
         """
         Returns the function associated with this template.
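(Aside, not part of the patch: a minimal sketch of how the RDNA-specific prefix added in `score_block_size` above orders candidate block sizes — exact multiples of eight 32-wide waves sort first, then closeness to eight waves; `rdna_key` is a hypothetical stand-in for that tuple prefix.)

```python
def rdna_key(n: int, warp_size: int = 32) -> tuple[int, int]:
    # Mirrors the prefix prepended to base_score on RDNA: prefer exactly 8 waves,
    # then minimize the distance to 8 waves; the base score tie-breaks afterwards.
    warps = (n + warp_size - 1) // warp_size
    return (0 if warps == 8 else 1, abs(warps - 8))


# 256 threads == 8 waves of 32, so it sorts ahead of 64, 128, and 512.
assert sorted([64, 128, 256, 512], key=rdna_key)[0] == 256
```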
diff --git a/tilelang/contrib/hipcc.py b/tilelang/contrib/hipcc.py
index 7b7f9f9479..85eb661a22 100644
--- a/tilelang/contrib/hipcc.py
+++ b/tilelang/contrib/hipcc.py
@@ -94,5 +94,7 @@ def compile_hip(code, target_format="hsaco", arch=None, options=None, path_targe
 @tvm_ffi.register_global_func("tilelang_callback_hip_compile", override=True)
 def tilelang_callback_hip_compile(code, target):
     """use hipcc to generate fatbin code for better optimization"""
-    hsaco = compile_hip(code, target_format="hsaco")
+    from tilelang.utils.target import target_get_mcpu
+
+    hsaco = compile_hip(code, target_format="hsaco", arch=target_get_mcpu(target))
     return hsaco
diff --git a/tilelang/engine/lower.py b/tilelang/engine/lower.py
index b5622963f3..96cb8e841a 100644
--- a/tilelang/engine/lower.py
+++ b/tilelang/engine/lower.py
@@ -15,7 +15,7 @@
 from tilelang.transform import PassConfigKey
 from tilelang.transform.metal import MarkHostMetalContext
 from tilelang.engine.param import KernelParam, CompiledArtifact
-from tilelang.utils.target import determine_target
+from tilelang.utils.target import determine_target, target_get_mcpu
 from tilelang.engine.phase import (
     PreLowerSemanticCheck,
     LowerAndLegalize,
@@ -162,9 +162,11 @@ def tilelang_callback_cuda_compile(code, target, pass_config=None):
 
 @tvm_ffi.register_global_func("tilelang_callback_hip_compile", override=True)
 def tilelang_callback_hip_compile(code, target):
+    arch = target_get_mcpu(target)
     hsaco = hipcc.compile_hip(
         code,
         target_format="hsaco",
+        arch=arch,
         options=[
             "-std=c++17",
             "-I" + TILELANG_TEMPLATE_PATH,
diff --git a/tilelang/jit/adapter/libgen.py b/tilelang/jit/adapter/libgen.py
index 85f9466e6f..020ce04f08 100644
--- a/tilelang/jit/adapter/libgen.py
+++ b/tilelang/jit/adapter/libgen.py
@@ -19,6 +19,7 @@
 )
 from tilelang.contrib.rocm import find_rocm_path, get_rocm_arch
 from tilelang.env import TILELANG_TEMPLATE_PATH
+from tilelang.utils.target import target_get_mcpu
 from .utils import is_cpu_target, is_cuda_target, is_hip_target
 
 
@@ -117,7 +118,7 @@ def compile_lib(self, timeout: float = None):
         src = tempfile.NamedTemporaryFile(mode="w", suffix=".cpp", delete=False)  # noqa: SIM115
         libpath = src.name.replace(".cpp", ".so")
         rocm_path = find_rocm_path()
-        arch = get_rocm_arch(rocm_path)
+        arch = target_get_mcpu(target) or get_rocm_arch(rocm_path)
         command = [
             "hipcc",
             "-std=c++17",
diff --git a/tilelang/testing/__init__.py b/tilelang/testing/__init__.py
index 6008bfcf12..542ad09682 100644
--- a/tilelang/testing/__init__.py
+++ b/tilelang/testing/__init__.py
@@ -5,7 +5,7 @@
 import torch
 import numpy as np
 from tilelang.contrib import nvcc
-from tilelang.utils.target import determine_target, target_is_gfx950
+from tilelang.utils.target import determine_target, target_is_cdna, target_is_cuda, target_is_gfx950
 from tvm.testing.utils import requires_cuda, requires_package, requires_llvm, requires_metal, requires_rocm, _compose
 from tilelang.utils.tensor import torch_assert_close as torch_assert_close
 
@@ -17,6 +17,8 @@
     "requires_metal",
     "requires_rocm",
     "requires_llvm",
+    "requires_cdna",
+    "requires_cuda_or_cdna",
     "requires_gfx950",
     "main",
     "requires_cuda_compute_version",
@@ -33,6 +35,47 @@ def _check_is_gfx950() -> bool:
         return False
 
 
+def _check_is_cdna() -> bool:
+    try:
+        target = determine_target("auto", return_object=True)
+        return target_is_cdna(target)
+    except (ValueError, RuntimeError):
+        return False
+
+
+def _check_is_cuda_or_cdna() -> bool:
+    try:
+        target = determine_target("auto", return_object=True)
+        return target_is_cuda(target) or target_is_cdna(target)
+    except (ValueError, RuntimeError):
+        return False
+
+
+def requires_cdna(func):
+    """Skip the test unless the ROCm device is a CDNA GPU."""
+    is_cdna = _check_is_cdna()
+    marks = [
+        pytest.mark.skipif(
+            not is_cdna,
+            reason="Requires CDNA ROCm target",
+        ),
+        *requires_rocm.marks(),
+    ]
+    return _compose([func], marks)
+
+
+def requires_cuda_or_cdna(func):
+    """Skip the test unless the device is CUDA or CDNA ROCm."""
+    is_cuda_or_cdna = _check_is_cuda_or_cdna()
+    marks = [
+        pytest.mark.skipif(
+            not is_cuda_or_cdna,
+            reason="Requires CUDA or CDNA ROCm target",
+        ),
+    ]
+    return _compose([func], marks)
+
+
 def requires_gfx950(func):
     """Skip the test unless the ROCm device is gfx950 (CDNA4 / MI350)."""
     is_gfx950 = _check_is_gfx950()
diff --git a/tilelang/utils/target.py b/tilelang/utils/target.py
index 94252a5f3d..4810bbce75 100644
--- a/tilelang/utils/target.py
+++ b/tilelang/utils/target.py
@@ -23,6 +23,71 @@
     "cutedsl": "CuTe DSL GPU target.",
 }
 
+ROCM_MTRIPLE = "amdgcn-amd-amdhsa-hcc"
+
+
+def normalize_rocm_arch(arch: str | None) -> str | None:
+    if arch is None:
+        return None
+    normalized = str(arch).strip().split(":", maxsplit=1)[0]
+    return normalized if normalized.startswith("gfx") else None
+
+
+def target_get_mcpu(target: str | Target | None) -> str | None:
+    if target is None:
+        return None
+    if isinstance(target, str):
+        target = Target(target)
+    return normalize_rocm_arch(target.attrs.get("mcpu"))
+
+
+def rocm_warp_size_for_arch(arch: str | None) -> int | None:
+    if arch is None:
+        return None
+    if arch.startswith("gfx9"):
+        return 64
+    if arch.startswith(("gfx10", "gfx11", "gfx12")):
+        return 32
+    return None
+
+
+def with_rocm_target_attrs(target: Target) -> Target:
+    if target.kind.name != "hip":
+        return target
+    arch = target_get_mcpu(target)
+    if arch is None:
+        return target
+
+    target_dict = dict(target.export())
+    target_dict.setdefault("mtriple", ROCM_MTRIPLE)
+    warp_size = rocm_warp_size_for_arch(arch)
+    if warp_size is not None:
+        target_dict["thread_warp_size"] = warp_size
+    else:
+        target_dict.pop("thread_warp_size", None)
+    return Target(target_dict)
+
+
+def _detect_torch_rocm_arch() -> str | None:
+    if not torch.cuda.is_available():
+        return None
+    props = torch.cuda.get_device_properties(0)
+    return normalize_rocm_arch(getattr(props, "gcnArchName", None))
+
+
+def _rocm_target_from_arch(arch: str | None) -> Target | str:
+    if arch is None:
+        return "hip"
+    target_dict = {
+        "kind": "hip",
+        "mcpu": arch,
+        "mtriple": ROCM_MTRIPLE,
+    }
+    warp_size = rocm_warp_size_for_arch(arch)
+    if warp_size is not None:
+        target_dict["thread_warp_size"] = warp_size
+    return Target(target_dict)
+
 
 def describe_supported_targets() -> dict[str, str]:
     """
@@ -140,23 +205,29 @@ def determine_target(target: str | Target | Literal["auto"] = "auto", return_obj
     if target == "auto":
         target = tvm.target.Target.current(allow_none=True)
         if target is not None:
-            return target
-        # Check for CUDA and HIP availability
-        is_cuda_available = check_cuda_availability()
-        is_hip_available = check_hip_availability()
-
-        # Determine the target based on availability
-        if is_cuda_available:
-            if torch.cuda.is_available() and (cap := torch.cuda.get_device_capability(0)):
-                return_var = Target({"kind": "cuda", "arch": f"sm_{nvcc.get_target_arch(cap)}"})
-            else:
-                return_var = "cuda"
-        elif is_hip_available:
-            return_var = "hip"
-        elif check_metal_availability():
-            return_var = "metal"
+            return with_rocm_target_attrs(target)
+        # ROCm PyTorch exposes devices through torch.cuda. If CUDA tooling is
+        # also present, prefer HIP so APUs such as gfx1151 are not misread as
+        # CUDA architectures like sm_115a.
+        if torch.version.hip is not None and check_hip_availability():
+            return_var = _rocm_target_from_arch(_detect_torch_rocm_arch())
         else:
-            raise ValueError("No CUDA or HIP or MPS available on this system.")
+            # Check for CUDA and HIP availability
+            is_cuda_available = check_cuda_availability()
+            is_hip_available = check_hip_availability()
+
+            # Determine the target based on availability
+            if is_cuda_available:
+                if torch.cuda.is_available() and (cap := torch.cuda.get_device_capability(0)):
+                    return_var = Target({"kind": "cuda", "arch": f"sm_{nvcc.get_target_arch(cap)}"})
+                else:
+                    return_var = "cuda"
+            elif is_hip_available:
+                return_var = _rocm_target_from_arch(_detect_torch_rocm_arch())
+            elif check_metal_availability():
+                return_var = "metal"
+            else:
+                raise ValueError("No CUDA or HIP or MPS available on this system.")
     else:
         possible_cutedsl_target = normalize_cutedsl_target(target)
@@ -172,20 +243,23 @@ def determine_target(target: str | Target | Literal["auto"] = "auto", return_obj
     else:
         # Validate the target if it's not "auto"
         if isinstance(target, Target):
-            return_var = target
+            return_var = with_rocm_target_attrs(target)
         elif isinstance(target, str):
             normalized_target = target.strip()
             if not normalized_target:
                 raise AssertionError(f"Target {target} is not supported")
             try:
-                Target(normalized_target)
+                parsed_target = Target(normalized_target)
             except Exception as err:
                 examples = ", ".join(f"`{name}`" for name in SUPPORTED_TARGETS)
                 raise AssertionError(
                     f"Target {target} is not supported. Supported targets include: {examples}. "
                     "Pass additional options after the base name, e.g. `cuda -arch=sm_80`."
                 ) from err
-            return_var = normalized_target
+            if parsed_target.kind.name == "hip" and target_get_mcpu(parsed_target) is not None:
+                return_var = with_rocm_target_attrs(parsed_target)
+            else:
+                return_var = normalized_target
         else:
             raise AssertionError(f"Target {target} is not supported")
@@ -234,6 +308,10 @@ def target_is_cdna(target: Target) -> bool:
     return _ffi_api.TargetIsCDNA(target)
 
 
+def target_is_rdna(target: Target) -> bool:
+    return _ffi_api.TargetIsRDNA(target)
+
+
 def target_is_gfx950(target: Target) -> bool:
     return _ffi_api.TargetIsGfx950(target)
@@ -256,3 +334,7 @@ def target_has_bulk_copy(target: Target) -> bool:
 
 def target_get_warp_size(target: Target) -> int:
     return _ffi_api.TargetGetWarpSize(target)
+
+
+def target_get_rdna_generation(target: Target) -> int:
+    return _ffi_api.TargetGetRDNAGeneration(target)
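(Usage note, not part of the patch: the new markers defined in `tilelang/testing/__init__.py` are applied as plain decorators, mirroring the changes to the GEMM tests above; the test name below is hypothetical.)

```python
import tilelang.testing


@tilelang.testing.requires_cuda_or_cdna
def test_int8_gemm_smoke():
    # Collected everywhere, but skipped unless determine_target("auto") resolves
    # to a CUDA target or a CDNA ROCm target (e.g. RDNA or CPU-only hosts skip it).
    ...
```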