
Add RDNA gfx1151 ROCm target support#2127

Open
lhl wants to merge 13 commits into tile-ai:main from lhl:rdna3-gfx1151

Conversation


@lhl lhl commented Apr 29, 2026

I saw there was a recent PR for RDNA3/3.5, including Strix Halo, but I had some problems running things with it, and it turns out I had to fix a few things to get it working.

  • Normalize ROCm arch strings, including feature suffixes such as gfx1151:sramecc+:xnack-.
  • Prefer HIP target selection when running under ROCm PyTorch, even if CUDA tooling is also discoverable.
  • Set ROCm target attrs consistently:
    • mcpu=gfx1151
    • mtriple=amdgcn-amd-amdhsa-hcc
    • thread_warp_size=32 for gfx10/gfx11/gfx12
    • thread_warp_size=64 for gfx9
  • Pass mcpu through HIP compilation paths so generated kernels compile for the actual ROCm target arch.
  • Add an RDNA carver device model and route RDNA HIP targets away from the CDNA path.
  • Add RDNA-aware tensorcore schedule scoring, biased toward 8-wave blocks.
  • Add HIP DP4A support:
    • uses __builtin_amdgcn_sudot4 for RDNA3/RDNA3.5/RDNA4 targets including gfx1151
    • uses __builtin_amdgcn_sdot4 on supported older AMD targets
    • keeps a scalar fallback for other HIP targets
  • Add focused ROCm target tests for gfx1151 detection, wave size, RDNA classification, carver routing, and auto-target behavior.
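
The arch-normalization and warp-size rules above can be sketched as follows. This is an illustrative reimplementation of the behavior described for the helpers this PR adds (normalize_rocm_arch, rocm_warp_size_for_arch), not the exact code:

```python
def normalize_rocm_arch(arch):
    """Strip ROCm feature suffixes such as ':sramecc+:xnack-' from an arch string."""
    if not arch:
        return None
    base = str(arch).strip().split(":", 1)[0]
    return base if base.startswith("gfx") else None

def rocm_warp_size_for_arch(arch):
    """RDNA generations (gfx10/gfx11/gfx12) run wave32; gfx9 (CDNA/GCN) runs wave64."""
    base = normalize_rocm_arch(arch)
    if base and base[:5] in ("gfx10", "gfx11", "gfx12"):
        return 32
    return 64
```

So a PyTorch-reported arch like `gfx1151:sramecc+:xnack-` normalizes to `gfx1151` and maps to a warp size of 32.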

Tested on:

  • Hardware: Strix Halo / Radeon 8060S Graphics
  • GPU arch reported by PyTorch: gfx1151
  • OS: Linux x86_64
  • Python: 3.12.11 in the therock conda/mamba environment
  • TileLang commit under test: 12412659
  • Baseline parent commit: 4639c279
  • TVM: 0.23.dev0 from the TileLang submodule build
  • PyTorch: 2.10.0+rocm7.13.0a20260417
  • torch.version.hip: 7.13.26154
  • HIP compiler: HIP 7.13.26154-ca4b97ef2c, AMD clang 23.0.0git
  • Build mode: editable source build with ROCm enabled

Smoke Test:

  • Kernel: TileLang int8 SIMT matmul, M=N=K=512, int8 x int8 -> int32
  • Target: auto-detected HIP / gfx1151

Results after this patch:

  • run 0: 0.058230 ms, 4.610 TOPS
  • run 1: 0.055946 ms, 4.798 TOPS
  • run 2: 0.055425 ms, 4.843 TOPS
  • median: 0.055946 ms, 4.798 TOPS

Slow, but it works! (Before this patch, compilation simply fails.)

Tests Passed:

  • pre-commit install --install-hooks
  • pre-commit run --all-files
  • python3 -m pytest testing/python/target/test_tilelang_rocm_target.py -q
    • 6 passed
  • python -m py_compile ...
  • ruff check ...
  • ruff format --check ...
  • git diff --check
  • Direct HIP syntax probe for __builtin_amdgcn_sudot4 with --offload-arch=gfx1151

Fails (CDNA tests improperly scoped for RDNA):

  • Existing AMD MFMA tests compile CDNA-style MFMA builtins and fail on gfx1151:
    __builtin_amdgcn_mfma_f32_16x16x16f16 needs target feature mai-insts
  • The run later aborted inside:
    testing/python/debug/test_tilelang_debug_print.py::test_debug_print_buffer

I updated the tests that improperly ran on RDNA to be gated to CDNA or, in one case, gfx950 only. They now pass for the architectures they're meant for.

Summary by CodeRabbit

  • New Features

    • RDNA architecture support for AMD GPUs (device model, generation-aware tensor-instruction lookup, and RDNA-aware routing/tuning).
    • DP4A dot-product support for HIP/AMD with a safe software fallback and template wrapper.
    • Improved HIP target detection and explicit architecture selection during HIP compilation.
  • Bug Fixes

    • Safer handling when optional loop/layout annotations are absent.
  • Tests

    • Expanded ROCm/CDNA tests: arch normalization, warp sizes, target selection/routing, and new CDNA/CUDA-or-CDNA test decorators.

@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀


coderabbitai Bot commented Apr 29, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Walkthrough

Adds RDNA device support and RDNA-aware target utilities, introduces int8 DP4A intrinsics with an AMD builtin fallback, forwards target-derived mcpu to HIP compilation, expands test gating for RDNA/CDNA/CUDA, and guards optional IR annotations.

Changes

RDNA Architecture

  • Data / Export (tilelang/carver/arch/rdna.py, tilelang/carver/arch/__init__.py, tilelang/carver/__init__.py): Adds the RDNA class, the is_rdna_arch predicate, and re-exports RDNA from tilelang.carver.
  • Construction / Validation (tilelang/carver/arch/rdna.py): RDNA.__init__ normalizes the Target, enforces RDNA generation == 11, binds HIP device 0, and initializes hardware/capability fields.
  • Tensor Intrinsics (tilelang/carver/arch/rdna.py): Adds _get_tensor_instructions_for_generation and get_avaliable_tensorintrin_shapes() to return generation-specific tensor-intrin shapes and raise on unsupported generations.
  • Arch Selection (tilelang/carver/arch/__init__.py): get_arch now returns RDNA(target) for RDNA targets (gen 11), else CDNA(target); auto_infer_current_arch uses determine_target(..., return_object=True) for HIP resolution.
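
The routing logic can be sketched like this; the helper names rdna_generation and route_arch are illustrative stand-ins for the utilities described above, under the assumption that the generation is parsed from the gfx number:

```python
def rdna_generation(arch):
    # 'gfx1151' -> 11, 'gfx1030' -> 10; CDNA arches like 'gfx942' -> None
    if arch.startswith("gfx1") and len(arch) >= 6:
        gen = int(arch[3:5])
        return gen if gen in (10, 11, 12) else None
    return None

def route_arch(arch):
    # Only RDNA3/3.5 (generation 11) routes to the RDNA device model;
    # everything else stays on the existing CDNA path.
    return "RDNA" if rdna_generation(arch) == 11 else "CDNA"
```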

ROCm Target Utilities & Tests

  • Helpers / Constants (tilelang/utils/target.py): Adds ROCM_MTRIPLE, normalize_rocm_arch, target_get_mcpu, rocm_warp_size_for_arch, with_rocm_target_attrs, and public wrappers target_is_rdna, target_get_rdna_generation.
  • Target Determination (tilelang/utils/target.py): determine_target("auto") enriches the current TVM target with ROCm attributes and prefers a gfx-derived HIP Target when PyTorch HIP is detected; non-auto HIP Targets/strings are enriched when mcpu normalizes.
  • Tests (testing/python/target/test_tilelang_rocm_target.py): Adds tests for normalize_rocm_arch, target_get_mcpu, warp-size mapping, determine_target(..., return_object=True) enrichment, auto-target preference, RDNA classification, carver routing, and generation-aware tensor-intrin lookup.
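
The target enrichment step can be approximated as below; this is a sketch of the described behavior over target strings, not the actual tilelang.utils.target implementation (which operates on TVM Target objects):

```python
ROCM_MTRIPLE = "amdgcn-amd-amdhsa-hcc"

def with_rocm_target_attrs(target_str, arch):
    """Append mcpu, mtriple, and thread_warp_size to a HIP target string."""
    warp = 32 if arch[:5] in ("gfx10", "gfx11", "gfx12") else 64
    return (f"{target_str} -mcpu={arch} -mtriple={ROCM_MTRIPLE} "
            f"-thread_warp_size={warp}")
```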

HIP/AMD DP4A Intrinsics

  • Fallback / Builtins (src/tl_templates/hip/common.h): Adds tl_dp4a_fallback, which sign-extends four int8 lanes and accumulates dot-products into an int accumulator.
  • Dispatcher / Template (src/tl_templates/hip/common.h): Adds a tl_dp4a dispatcher selecting the AMD sudot4/sdot4 builtins when available, and a DP4A<InDatatype, OutDatatype> template that validates sizes and performs safe memcpy-based loads/stores.
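
The semantics of the scalar fallback can be modeled in a few lines; this is an illustrative Python model of the described int8 lane arithmetic, not the HIP C++ template itself:

```python
def dp4a_fallback(a, b, c):
    """Model the scalar DP4A fallback: treat the 32-bit ints a and b as four
    packed signed int8 lanes, multiply lane-wise, and accumulate into c."""
    def lanes(x):
        # sign-extend each 8-bit lane of a 32-bit word
        return [(((x >> (8 * i)) & 0xFF) ^ 0x80) - 0x80 for i in range(4)]
    return c + sum(la * lb for la, lb in zip(lanes(a), lanes(b)))
```

For example, packing (1, 2, 3, 4) and (1, 1, 1, 1) with accumulator 5 yields 1+2+3+4+5 = 15.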

Compilation / Tooling Integration

  • HIP compile wiring (tilelang/contrib/hipcc.py, tilelang/engine/lower.py, tilelang/jit/adapter/libgen.py): HIP compile paths now derive the arch from target_get_mcpu(target) and forward it to compile_hip/--offload-arch, falling back to prior ROCm detection when mcpu is unavailable.

Carver Template & Policies

  • Template predicate (tilelang/carver/template/base.py): Adds BaseTemplate.is_rdna_arch(), delegating to the new predicate.
  • Policy scoring (tilelang/carver/roller/policy/tensorcore.py): Adds an RDNA-specific score_block_size preferring 8-warp configs and updates _assign_block_size to precompute space_prod and require divisibility by warps.
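
The two policy changes can be sketched as follows; these are illustrative reimplementations of the described behavior (including the floor-division and divisibility fixes raised in review), with hypothetical standalone function names:

```python
from math import prod

def score_block_size_rdna(n, warp_size=32):
    """Prefer exactly 8 warps per block: lower tuples sort first, so an
    8-warp configuration wins, and near-8 beats far-from-8."""
    warps = n // warp_size  # floor division: underfilled blocks don't count as full warps
    return (0 if warps == 8 else 1, abs(warps - 8))

def feasible_sub_warp_space(space, warps):
    """Feasibility gate: the thread-space product must be at least `warps`
    and exactly divisible by it before factorization; otherwise reject."""
    space_prod = prod(space)
    if space_prod < warps or space_prod % warps != 0:
        return None
    return space_prod // warps
```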

IR/Transform Robustness

  • Reducer handling (src/op/parallel.cc): Guard-checks the presence of attr::kReducerInfo before casting/using the reducer map.
  • Layout map handling (src/transform/layout_reducer.cc): Only populates layout_map from attr::kLayoutMap when the annotation exists; otherwise keeps the default and merges safely.

Testing Decorators & Test Scope

  • Test utilities (tilelang/testing/__init__.py): Adds requires_cdna and requires_cuda_or_cdna decorators and target helpers for CDNA/CUDA-or-CDNA detection.
  • Test updates (testing/python/amd/*, testing/python/kernel/*): Updates several test decorators to gate on CDNA/gfx-specific targets or CUDA-or-CDNA.
  • Specific test additions (testing/python/target/test_tilelang_rocm_target.py): New unit tests validating ROCm normalization, warp sizes, auto-target selection, RDNA classification, carver routing, and generation-aware tensor-intrin lookup.

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant TargetUtil as tilelang.utils.target
    participant CarverArch as tilelang.carver.arch
    participant RDNAClass as RDNA
    participant ROCmDevice as ROCm Device

    User->>TargetUtil: determine_target("auto" / Target)
    TargetUtil->>TargetUtil: normalize_rocm_arch, extract mcpu, add mtriple/thread_warp_size
    TargetUtil-->>CarverArch: enriched Target
    User->>CarverArch: get_arch(target)
    CarverArch->>CarverArch: target_is_rdna(target)?
    alt RDNA
        CarverArch->>RDNAClass: instantiate RDNA(target)
        RDNAClass->>ROCmDevice: bind device_0 via tvm.runtime.rocm(0)
        ROCmDevice-->>RDNAClass: device handle
        RDNAClass-->>CarverArch: RDNA instance
    else
        CarverArch-->>User: CDNA/CUDA arch
    end
    CarverArch-->>User: arch object
sequenceDiagram
    participant Builder as tilelang.engine.lower / libgen
    participant TargetUtil as tilelang.utils.target
    participant HIPcc as tilelang.contrib.hipcc
    participant HIPImpl as HIP Compiler / runtime

    Builder->>TargetUtil: target_get_mcpu(target)
    alt mcpu available
        Builder->>HIPcc: compile_hip(source, arch=mcpu)
        HIPcc->>HIPImpl: invoke compile with --offload-arch=mcpu
    else
        Builder->>HIPcc: compile_hip(source, arch=fallback_rocm_arch)
        HIPcc->>HIPImpl: invoke compile with detected arch
    end
    HIPImpl-->>HIPcc: compiled binary
    HIPcc-->>Builder: result

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested reviewers

  • LeiWang1999
  • zhangnju
  • Gongen-Ali

Poem

🐰
I hopped through headers, lanes of four,
sign-extended dots and RDNA lore.
mcpu whispered into the build,
tests aligned, the carver filled,
a tiny rabbit cheers — hip, explore! 🥕

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage (⚠️ Warning): Docstring coverage is 20.31%, which is insufficient; the required threshold is 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (4 passed)

  • Description Check: skipped because CodeRabbit's high-level summary is enabled.
  • Title Check: The PR title 'Add RDNA gfx1151 ROCm target support' clearly and concisely summarizes the main change: introducing support for the RDNA gfx1151 architecture within the ROCm/HIP ecosystem.
  • Linked Issues Check: skipped because no linked issues were found for this pull request.
  • Out of Scope Changes Check: skipped because no linked issues were found for this pull request.




Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (2)
tilelang/contrib/hipcc.py (1)

19-27: Use the canonical ROCm arch parser instead of duplicating it.

This helper reimplements logic already centralized in tilelang/utils/target.py (target_get_mcpu / normalize_rocm_arch). Keeping two parsers risks drift.

♻️ Suggested simplification
+from tilelang.utils.target import target_get_mcpu
-
-def _target_mcpu(target):
-    try:
-        mcpu = target.attrs.get("mcpu")
-    except AttributeError:
-        return None
-    if mcpu is None:
-        return None
-    arch = str(mcpu).strip().split(":", maxsplit=1)[0]
-    return arch if arch.startswith("gfx") else None

And update the callback usage:

-    hsaco = compile_hip(code, target_format="hsaco", arch=_target_mcpu(target))
+    hsaco = compile_hip(code, target_format="hsaco", arch=target_get_mcpu(target))
tilelang/carver/roller/policy/tensorcore.py (1)

266-270: Minor cleanup: reuse np.prod(space) once.

Good feasibility check. You can avoid duplicate computation by storing the product once.

Refactor sketch
-        if np.prod(space) < warps:
+        space_prod = int(np.prod(space))
+        if space_prod < warps:
             return None

-        factors = factorize(np.prod(space) // warps)
+        factors = factorize(space_prod // warps)
Inline comment on src/tl_templates/hip/common.h (lines 149-159): the DP4A template is unsafe for non-4-byte types because it uses fixed 4-byte __builtin_memcpy into int temporaries; add static_asserts that sizeof(InDatatype) == 4 and sizeof(OutDatatype) == 4 (and optionally std::is_trivially_copyable checks) so only 4-byte, trivially copyable operand/result types can be instantiated, failing fast at compile time.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: bb907b2d-7064-4071-adf6-72f2346e742f

📥 Commits

Reviewing files that changed from the base of the PR and between 4639c27 and 1241265.

📒 Files selected for processing (11)
  • src/tl_templates/hip/common.h
  • testing/python/target/test_tilelang_rocm_target.py
  • tilelang/carver/__init__.py
  • tilelang/carver/arch/__init__.py
  • tilelang/carver/arch/rdna.py
  • tilelang/carver/roller/policy/tensorcore.py
  • tilelang/carver/template/base.py
  • tilelang/contrib/hipcc.py
  • tilelang/engine/lower.py
  • tilelang/jit/adapter/libgen.py
  • tilelang/utils/target.py

Comment thread src/tl_templates/hip/common.h
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (1)
src/tl_templates/hip/common.h (1)

149-163: ⚠️ Potential issue | 🔴 Critical

Fix the DP4A carrier-size check.

__builtin_memcpy(&a_int, a, sizeof(a_int)) and the matching b/c copies read 4 bytes, so sizeof(InDatatype) == 1 lets this template read past the pointed-to object. Require a 4-byte packed carrier here, or shrink the memcpy width to match the asserted size.

🔧 Proposed fix
 template <typename InDatatype, typename OutDatatype>
 TL_DEVICE void DP4A(const InDatatype *a, const InDatatype *b, OutDatatype *c) {
-  static_assert(sizeof(InDatatype) == 1,
-                "DP4A expects a pointer to packed int8 lanes");
+  static_assert(sizeof(InDatatype) == sizeof(int),
+                "DP4A expects a 4-byte packed carrier");
   static_assert(sizeof(OutDatatype) == sizeof(int),
                 "DP4A expects 4-byte accumulator/output type");

Run the following script to inspect DP4A call sites and confirm the carrier types they use:

#!/bin/bash
set -euo pipefail
rg -nP -C 2 '\bDP4A\s*(<|\()' src testing
Inline comment on tilelang/carver/roller/policy/tensorcore.py (lines 266-270): factorize(space_prod // warps) truncates remainders; first require exact divisibility (space_prod % warps == 0, returning None otherwise), then factorize the exact quotient.


📥 Commits

Reviewing files that changed from the base of the PR and between 1241265 and 51acb3c.

📒 Files selected for processing (6)
  • src/op/parallel.cc
  • src/tl_templates/hip/common.h
  • src/transform/layout_reducer.cc
  • testing/python/target/test_tilelang_rocm_target.py
  • tilelang/carver/roller/policy/tensorcore.py
  • tilelang/contrib/hipcc.py

Comment thread tilelang/carver/roller/policy/tensorcore.py
@LeiWang1999 LeiWang1999 requested a review from zhangnju April 30, 2026 05:47
@lhl
Author

lhl commented Apr 30, 2026

Took a look at the failing tests and fixed them with separate commits:

  • 462e4a0b Scope MFMA tests to CDNA targets

    • Adds requires_cdna
    • Skips MFMA intrinsic/preshuffle tests on RDNA
  • 35d53bf8 Scope gfx950 copy async GEMM test

    • Marks the non-pipelined gfx950 copy.async GEMM test as requires_gfx950
  • d77e26d1 Skip CUDA/CDNA GEMM dtype tests on RDNA

    • Adds requires_cuda_or_cdna
    • Skips generic GEMM int8/float32 matrix-core dtype tests on RDNA, where current WMMA support is f16/bf16
  • 355ffe28 Reject non-divisible warp layouts

    • Addresses CodeRabbit feedback by rejecting tensorcore layouts where space_prod % warps != 0 before factorization

Validation:

  • MFMA/gfx950 scoping behaves as expected on RDNA
  • Kernel GEMM file: 3 passed, 13 skipped
  • The runtime-looking "ndim expected 2, got ..." repros pass when therock has a usable host compiler path

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

Inline comment on tilelang/carver/roller/policy/tensorcore.py (lines 215-219): score_block_size computes the warp count with ceiling division, which overestimates warps for non-aligned n; use floor division (warps = n // self.arch.warp_size) in the RDNA branch so underfilled blocks do not score as full 8-warp occupancy, keeping the rest of the return tuple unchanged.


📥 Commits

Reviewing files that changed from the base of the PR and between 51acb3c and d77e26d.

📒 Files selected for processing (6)
  • testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py
  • testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py
  • testing/python/amd/test_tilelang_gfx950_copy_async.py
  • testing/python/kernel/test_tilelang_kernel_gemm.py
  • tilelang/carver/roller/policy/tensorcore.py
  • tilelang/testing/__init__.py
✅ Files skipped from review due to trivial changes (1)
  • testing/python/amd/test_tilelang_gfx950_copy_async.py

Comment thread tilelang/carver/roller/policy/tensorcore.py
zhangnju previously approved these changes May 4, 2026
Collaborator

@zhangnju zhangnju left a comment


Thanks for your contribution.

Collaborator

@zhangnju zhangnju left a comment


Thanks for your contributions to AMD TileLang. This PR adds first-class support for AMD RDNA GPUs (specifically gfx1151/Strix Halo) in TileLang. The work is well structured across the layers: target detection, device modeling, and code generation. But because some of this PR's code will also be used by RDNA4, we may need to check that it works well for RDNA4 too, e.g. some of the hardcoding.

Comment thread src/op/parallel.cc
Comment thread src/tl_templates/hip/common.h
Comment thread tilelang/carver/arch/rdna.py
Comment thread tilelang/carver/roller/policy/tensorcore.py
Comment thread tilelang/carver/arch/rdna.py
@LeiWang1999
Member

@lhl Thanks for your contributions! Would you mind resolving those comments?

@lhl
Author

lhl commented May 5, 2026

@lhl Thanks for your contributions! Would you mind resolving those comments?

I'll take a look soon. I can add guards to make sure it only supports the GPUs I have (gfx1100, gfx1151), but note: I don't have access to RDNA4, so beyond checking that it's not affected, I won't be doing any RDNA4 optimization and will leave that to those who want to do it. I'll scope specifically to RDNA3/3.5.

@lhl
Author

lhl commented May 5, 2026

@LeiWang1999 scope is guarded to gfx11 and all feedback is resolved now

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (2)
testing/python/target/test_tilelang_rocm_target.py (2)

67-83: 💤 Low value

The torch monkeypatches in test_carver_routes_rdna_without_instantiating_device don't affect get_arch.

torch.version.hip, torch.cuda.is_available, and torch.mps.is_available are only consulted by auto_infer_current_arch, not by get_arch(Target(...)) which is called directly on line 81. These three patches are dead in this test and could be removed, leaving only the arch_mod.RDNA monkeypatch that actually prevents the device probe.

🧹 Proposed cleanup
 def test_carver_routes_rdna_without_instantiating_device(monkeypatch):
-    import torch
-
-    monkeypatch.setattr(torch.version, "hip", None, raising=False)
-    monkeypatch.setattr(torch.cuda, "is_available", lambda: False)
-    if hasattr(torch, "mps"):
-        monkeypatch.setattr(torch.mps, "is_available", lambda: False, raising=False)
-
     import tilelang.carver.arch as arch_mod

     def fake_rdna(target):
         return ("rdna", target)

     monkeypatch.setattr(arch_mod, "RDNA", fake_rdna)
     arch = arch_mod.get_arch(Target("hip -mcpu=gfx1151"))

19-27: 💤 Low value

test_normalize_rocm_arch_strips_feature_suffix mixes two distinct API surfaces.

The function tests both normalize_rocm_arch and rocm_warp_size_for_arch in the same test case. Consider splitting into separate test functions so failures are easier to diagnose. This is a low-priority style concern.

Inline comments:

  • src/tl_templates/hip/common.h (lines 129-136): the TL_AMDGPU_HAS_SDOT4 feature guard incorrectly includes __gfx1010__, which lacks dot1-insts; remove that entry so gfx1010 falls through to the fallback implementation.
  • tilelang/carver/arch/__init__.py (lines 22-25): the HIP branch silently maps any RDNA target whose generation is not 11 (e.g. gfx1030/gfx1200) to CDNA, producing wrong defaults; raise an explicit error for unsupported RDNA generations instead of returning CDNA(target).
  • tilelang/carver/arch/rdna.py (lines 45-47): get_avaliable_tensorintrin_shapes assigns a tuple to available_tensor_instructions, contradicting its declared type list[TensorInstruction] | None; assign a list instead.


📥 Commits

Reviewing files that changed from the base of the PR and between d77e26d and 09d25c9.

📒 Files selected for processing (5)
  • src/tl_templates/hip/common.h
  • testing/python/target/test_tilelang_rocm_target.py
  • tilelang/carver/arch/__init__.py
  • tilelang/carver/arch/rdna.py
  • tilelang/utils/target.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tilelang/utils/target.py

Comment thread src/tl_templates/hip/common.h
Comment thread tilelang/carver/arch/__init__.py
Comment thread tilelang/carver/arch/rdna.py
@lhl
Author

lhl commented May 6, 2026

@LeiWang1999 @zhangnju FYI all manual and automated issues are reviewed/resolved. This patch is tightly scoped to gfx11.
