[AMD][CDNA4] Add MXFP4 (FP4 E2M1) support for gfx950 #2132

Open
zhangnju wants to merge 2 commits into tile-ai:main from zhangnju:gfx950_mxfp4

Conversation

@zhangnju (Collaborator) commented Apr 30, 2026

This PR adds end-to-end MXFP4 (FP4 E2M1) support for AMD gfx950 (CDNA4 / MI350) in the HIP backend.

HIP codegen (codegen_hip.cc/.h)

  • Added GetFP4Type() to map float4_e2m1fn DataType to HIP C types (fp4_e2_t, fp4_e2_2_t, ..., fp4_e2_32_t).
  • Implemented vectorized FP4 ↔ float16 / float32 / bfloat16 / float64 cast codegen, processing two lanes at a time via pairwise helpers (__tl_cvt_fp4x2_to_half2, __tl_cvt_fp4x2_to_bfloat162, etc.).
  • Added enable_fp4_ flag: when a FP4 type is encountered, Finish() automatically includes hip_fp4.h.
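The pairwise scheme (two FP4 E2M1 lanes packed per byte) can be sketched host-side in plain C++. This is an illustrative sketch of the conversion semantics, not the `__tl_cvt_*` device helpers themselves; the low-nibble-first lane order is an assumption of this sketch.

```cpp
#include <cassert>
#include <cstdint>

// Decode one FP4 E2M1 nibble: 1 sign bit, 2 exponent bits (bias 1), 1 mantissa bit.
static float fp4_e2m1_to_float(uint8_t nib) {
  float sign = (nib & 0x8) ? -1.0f : 1.0f;
  uint32_t exp = (nib >> 1) & 0x3;
  uint32_t man = nib & 0x1;
  if (exp == 0)
    return sign * 0.5f * (float)man;  // denormal: 0 or +/-0.5
  return sign * (float)(1u << (exp - 1)) * (1.0f + 0.5f * man);
}

// Pairwise conversion: one packed byte -> two floats
// (low nibble = even lane, an assumption of this sketch).
static void fp4x2_to_float2(uint8_t byte, float* lo, float* hi) {
  *lo = fp4_e2m1_to_float(byte & 0xF);
  *hi = fp4_e2m1_to_float(byte >> 4);
}
```

A vectorized cast over N lanes then reduces to N/2 such byte-wise conversions, which is what the pairwise codegen path emits.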

New header src/tl_templates/hip/hip_fp4.h

  • Defines FP4 scalar and vector types for gfx950.
  • Provides conversion intrinsic wrappers, guarded by #if defined(gfx950).
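The walkthrough below also mentions packed buffer element access routines. The nibble-addressing idea can be sketched host-side as follows; the names `fp4_packed_load`/`fp4_packed_store` and their signatures are illustrative, not the header's actual API.

```cpp
#include <cassert>
#include <cstdint>

// Byte-packed FP4 buffer: logical element i lives in the low (i even) or
// high (i odd) nibble of byte i/2.
static uint8_t fp4_packed_load(const uint8_t* buf, int i) {
  return (i & 1) ? (uint8_t)(buf[i >> 1] >> 4) : (uint8_t)(buf[i >> 1] & 0xF);
}

static void fp4_packed_store(uint8_t* buf, int i, uint8_t nib) {
  uint8_t b = buf[i >> 1];
  buf[i >> 1] = (i & 1) ? (uint8_t)((b & 0x0F) | ((nib & 0xF) << 4))
                        : (uint8_t)((b & 0xF0) | (nib & 0xF));
}
```

Note the read-modify-write in the store: each physical byte holds two logical elements, which is exactly why logical and physical indexing must stay consistent (see the packing review comment further down).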

Dequantization kernels (tilelang/quantize/mxfp.py)

  • decode_f4_to_bf16_twiddling_hip: ports the CUDA PTX bit-twiddling algorithm to portable HIP C++ (no inline PTX), numerically equivalent to the CUDA reference.
  • decode_f4_to_bf16_simple_hip: a static-LUT fallback path for non-twiddling dequantization.
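The LUT fallback path can be sketched in plain C++ as follows. This is a hedged sketch: float stands in for bf16, the low-nibble-first packing and the per-block multiplicative scale are assumptions, and `block_size` models the MX shared-scale granularity (32 elements per block in the MX formats).

```cpp
#include <cassert>
#include <cstdint>

// Static LUT: index = E2M1 nibble, value = decoded float.
static const float kFp4Lut[16] = {
    0.0f,  0.5f,  1.0f,  1.5f,  2.0f,  3.0f,  4.0f,  6.0f,
   -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f};

// Dequantize n_bytes packed bytes (two values each) with one scale per
// block_size logical elements.
static void dequant_fp4_lut(const uint8_t* packed, int n_bytes,
                            const float* scales, int block_size, float* out) {
  for (int i = 0; i < 2 * n_bytes; ++i) {
    uint8_t byte = packed[i / 2];
    uint8_t nib = (i & 1) ? (uint8_t)(byte >> 4) : (uint8_t)(byte & 0xF);
    out[i] = kFp4Lut[nib] * scales[i / block_size];
  }
}
```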

MFMA layout (tilelang/intrinsics/mfma_layout.py)

  • Extended MFMA matrix layout to support FP4 input operands on gfx950.

Tests & examples

  • testing/python/amd/test_tilelang_mxfp4_gfx950.py: covers FP4 copy, vectorized casts, and MXFP4 dequant-GEMM (both twiddling and simple paths); guarded by @requires_gfx950, silently skipped on other targets.
  • examples/: added an end-to-end BF16-output MXFP4 GEMM example (example_dequant_gemm_bf16_mxfp4_cdna4.py).
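The end-to-end semantics of such a dequant-GEMM can be sketched as a plain-C++ reference. This is not the TileLang kernel itself: float stands in for bf16, row-major layouts, low-nibble-first packing, and block_k dividing K are assumptions of this sketch.

```cpp
#include <cassert>
#include <cstdint>

static const float kLut16[16] = {
    0.0f,  0.5f,  1.0f,  1.5f,  2.0f,  3.0f,  4.0f,  6.0f,
   -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f};

// Reference MXFP4 dequant-GEMM: A is row-major M x K packed FP4 (2 values per
// byte), with one float scale per block_k elements along K.
static void mxfp4_gemm_ref(const uint8_t* A, const float* scaleA, int block_k,
                           const float* B, const float* bias, float* C,
                           int M, int N, int K) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = bias ? bias[n] : 0.0f;
      for (int k = 0; k < K; ++k) {
        uint8_t byte = A[(m * K + k) / 2];
        uint8_t nib = (k & 1) ? (uint8_t)(byte >> 4) : (uint8_t)(byte & 0xF);
        float a = kLut16[nib] * scaleA[(m * K + k) / block_k];
        acc += a * B[k * N + n];
      }
      C[m * N + n] = acc;
    }
  }
}
```

A reference like this is what the PyTorch implementations in the example compare the kernel output against.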

CI Test:

  • MI300 CI test: Pass
  • MI350 CI test: Pass

Summary by CodeRabbit

  • New Features

    • Added gfx950-exclusive MXFP4 dequantize-GEMM example with bias support
    • Introduced FP4 support for AMD gfx950 GPUs with vectorized conversions
    • Extended quantization framework to support gfx950 targets
    • Enhanced MFMA layout register mapping for 32x32 operations
  • Tests

    • Added comprehensive test suite for MXFP4 dequantization and FP4 operations

@coderabbitai Bot (Contributor) commented Apr 30, 2026

📝 Walkthrough

Introduces FP4 (4-bit floating point) support for AMD gfx950 in TileLang, comprising a dequantize-GEMM example kernel, HIP code generation support for FP4 type handling and conversions, packed FP4 buffer management, intrinsic conversion routines, gfx950-specific tests, and updates to MFMA layout register mappings.

Changes

Changes by cohort:

  • Example Kernel (examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_cdna4.py): New gfx950 MXFP4 dequantize-GEMM TileLang kernel supporting both fast intrinsic and simple dequantization paths, with PyTorch reference implementations, optional bias, and benchmark reporting.
  • HIP Code Generation (src/target/codegen_hip.cc, src/target/codegen_hip.h): Adds FP4 type support to HIP codegen with specialized handling for FP4↔float16/bfloat16 vectorized casting, packed scalar FP4 buffer allocation, and packed load/store helpers; introduces the enable_fp4_ gate and fp4_packed_buffers_ mapping.
  • FP4 Intrinsics & Templates (src/tl_templates/hip/hip_fp4.h): New HIP header providing FP4 E2M1 types, float↔FP4 bit-manipulation conversions, packed FP4↔float2/half2/double2/bfloat162 transformations, and packed buffer element access routines.
  • Quantization Framework (tilelang/quantize/mxfp.py): Extends get_mxfp_intrin_group to detect the gfx950 target and return HIP C++ source with two new dequantization paths (twiddling and LUT-based) for fp4→bf16 conversion.
  • Register Mapping Updates (tilelang/intrinsics/mfma_layout.py): Updates the 32x32 MFMA C-layout thread-id to (row, column) mapping using tid_high grouping and an intermediate k computation for improved register indexing.
  • Functional Tests (testing/python/amd/test_tilelang_mxfp4_gfx950.py): Adds a gfx950-gated test suite for FP4 round-trip copy, cast validation, MXFP4 dequantize-GEMM correctness (simple and twiddling paths), and source-generation assertions.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested reviewers

  • Gongen-Ali
  • LeiWang1999

Poem

🐰 Four bits per float, a nibble so small,
gfx950's new champion answers the call!
From FP4 to BF16, conversions do dance,
With intrinsics and kernels, we give quantization a chance!
Pack, dequant, and multiply—efficiency's might. 🚀

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 47.27%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (4 passed)

  • Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check: ✅ Passed. The title '[AMD][CDNA4] Add MXFP4 (FP4 E2M1) support for gfx950' accurately and specifically describes the main objective of the PR: adding FP4 E2M1 support for AMD's gfx950 architecture.
  • Linked Issues Check: ✅ Passed. Check skipped because no linked issues were found for this pull request.
  • Out of Scope Changes Check: ✅ Passed. Check skipped because no linked issues were found for this pull request.


@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai Bot (Contributor) left a comment


Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/target/codegen_hip.cc (1)

1727-1757: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Scalar FP4 packing is only applied to local, but the accessors assume it everywhere.

GetBufferRef() (Line 1079 onward) and the new BufferStoreNode path (Line 1789 onward) always use tl_fp4_packed_load/store for scalar FP4 buffers, but this allocator only packs scope == "local". shared and local.var allocations still have fp4_e2_t layout, so their logical indexing no longer matches the physical storage and neighboring elements can alias/corrupt each other. Either pack every scalar-FP4 allocation that goes through the nibble helpers, or gate those helpers to buffers recorded in fp4_packed_buffers_.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/target/codegen_hip.cc` around lines 1727 - 1757, The bug: GetBufferRef
and the BufferStoreNode path assume scalar-FP4 buffers are packed but the
allocator only packs when is_fp4_scalar_local (scope == "local"), causing
mis-indexing for shared/local.var buffers; to fix, either (A) consistently emit
packed storage for every scalar FP4 allocation that uses the nibble helpers
(change the allocation path that creates fp4_e2_2_t in the code that sets
fp4_packed_buffers_), or (B) safer minimal change—gate all uses of
tl_fp4_packed_load/tl_fp4_packed_store to only buffers recorded in
fp4_packed_buffers_: modify GetBufferRef(...) and the BufferStoreNode handling
to check fp4_packed_buffers_.count(op->buffer_var.get()) (or equivalent) before
emitting packed helpers, and fall back to fp4_e2_t accessors for others; update
any code that constructs vid_packed (where fp4_e2_2_t is emitted) to ensure
fp4_packed_buffers_ is the single source-of-truth.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 529c506e-c76e-4ad7-82a3-d5c03629a45b

📥 Commits

Reviewing files that changed from the base of the PR and between 936ae92 and fb1d0c3.

📒 Files selected for processing (7)
  • examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_cdna4.py
  • src/target/codegen_hip.cc
  • src/target/codegen_hip.h
  • src/tl_templates/hip/hip_fp4.h
  • testing/python/amd/test_tilelang_mxfp4_gfx950.py
  • tilelang/intrinsics/mfma_layout.py
  • tilelang/quantize/mxfp.py

Comment thread src/target/codegen_hip.cc
Comment on lines +806 to +808
int fp4_lanes = from_ty.lanes();
bool fp4_pair_cast = (fp4_lanes == 2 || fp4_lanes == 4 || fp4_lanes == 8);


⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Reject FP4x16/FP4x32 casts instead of falling through to generic vector code.

The new FP4 types are exposed up to 32 lanes, but fp4_pair_cast only handles 2/4/8. For x16/x32 this drops into the generic vector cast path, and PrintVecElemLoad/Store don't know how to index fp4_e2_* aggregates, so codegen will emit invalid accessors instead of a real cast. Please either extend the pairwise lowering to those lane counts or ICHECK them as unsupported here.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/target/codegen_hip.cc` around lines 806 - 808, The fp4 handling currently
sets fp4_pair_cast based on fp4_lanes (variable fp4_lanes) but only allows
2/4/8, which lets x16/x32 fall through to generic vector casting and causes
PrintVecElemLoad/Store to emit invalid fp4_e2_* accesses; either extend the
pairwise lowering to support 16 and 32 lanes (update fp4_pair_cast to include 16
and 32 and implement corresponding pairwise lowering logic) or explicitly reject
unsupported fp4 lane widths by adding an ICHECK (or CHECK) where fp4_pair_cast
is computed to assert fp4_lanes is one of the supported values; update the same
area in src/target/codegen_hip.cc so casts for FP4x16/FP4x32 never fall through
to PrintVecElemLoad/Store.

Comment thread src/target/codegen_hip.cc
Comment on lines +818 to +827
for (int i = 0; i < fp4_lanes; i += 2) {
std::ostringstream val;
val << "__tl_cvt_fp4x2_to_half2(((uint8_t*)&(" << src << "))[" << i / 2
<< "])";
// Store both elements of the half2
std::ostringstream v0, v1;
v0 << "((half_t*)(&(" << val.str() << ")))[0]";
v1 << "((half_t*)(&(" << val.str() << ")))[1]";
PrintVecElemStore(sret, target_ty, i, v0.str());
PrintVecElemStore(sret, target_ty, i + 1, v1.str());

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🌐 Web query:

In C++ (and HIP/Clang), is taking the address of a temporary/prvalue returned from a function expression like &(f()) valid, and what is the recommended way to read packed fields from such a return value?

💡 Result:

No, taking the address of a temporary/prvalue returned from a function using the built-in & operator, as in &(f), is invalid in standard C++. The unary & operator requires an lvalue operand ([expr.unary.op]), and f (a prvalue) is not an lvalue. This is ill-formed; compilers may diagnose it as an error or warning, but it leads to undefined behavior if accepted as an extension. std::addressof(f) works because it takes an lvalue reference (T&) and uses reinterpret_cast to obtain the address, bypassing operator& overloads. However, since C++17 (LWG 2598), std::addressof has a deleted overload for const T&& to prevent use on temporaries, though it binds to non-const T&. The pointer is valid only until the end of the full-expression containing the call (no lifetime extension without direct reference binding). This behavior is standard in C++ (including Clang); no HIP-specific differences found in Clang/HIP docs. For reading packed fields from such a return value (e.g., struct with #pragma pack(1) or attribute((packed))): - Direct member access on the temporary (e.g., f.field) is valid and safe within the expression, as the compiler handles potential unaligned access. - Avoid storing the address or pointer for later use (dangling). - Recommended safe way: Copy the temporary to a local variable (extends lifetime) or use memcpy for fields: auto tmp = f; // Copies/moves, lifetime to end of scope auto field = tmp.packed_field; // Safe aligned access on local Or for immediate use without storage: int field; memcpy(&field, &f.packed_field, sizeof(field)); // Use std::addressof(f.packed_field) if needed This avoids unaligned access penalties/crashes (e.g., on ARM) while respecting lifetimes. Packed structs are for serialization; parse into aligned host structs for general use.



Materialize the half2 temporary before taking its address.

Lines 824–825 generate code that takes &(__tl_cvt_fp4x2_to_half2(...)), which attempts to take the address of a prvalue (temporary), a violation in C++. The float32 (lines 845–849) and float64 (lines 868–872) paths already demonstrate the correct pattern: materialize the return value into a local variable first, then access members by name. Apply the same approach here: create a half2 temporary, store the conversion result into it, and use member access instead of pointer arithmetic.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/target/codegen_hip.cc` around lines 818 - 827, The code currently takes
the address of a prvalue from __tl_cvt_fp4x2_to_half2, which is invalid in C++;
instead, materialize the conversion into a local half2 temporary (e.g. declare a
variable like half2 tmp_half2 =
__tl_cvt_fp4x2_to_half2(((uint8_t*)&(src))[i/2]);) and then build v0/v1 to
reference its components (prefer member access like tmp_half2.x and tmp_half2.y)
before calling PrintVecElemStore(sret, target_ty, i, ...) and
PrintVecElemStore(sret, target_ty, i+1, ...). Ensure the temporary name is
unique per loop iteration to avoid collisions.
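The materialization pattern the comment asks for can be shown host-side with a stand-in struct (plain floats model HIP's half2, whose members are also x and y; the converter body is a hypothetical LUT, not the real device intrinsic):

```cpp
#include <cassert>
#include <cstdint>

struct half2_t { float x, y; };  // host-side stand-in for HIP's half2

// Hypothetical pairwise converter returning by value (a prvalue at the call site).
static half2_t cvt_fp4x2_to_half2(uint8_t byte) {
  static const float lut[16] = {0.0f,  0.5f,  1.0f,  1.5f,  2.0f,  3.0f,
                                4.0f,  6.0f, -0.0f, -0.5f, -1.0f, -1.5f,
                               -2.0f, -3.0f, -4.0f, -6.0f};
  return {lut[byte & 0xF], lut[byte >> 4]};
}

static void unpack_fp4x2(uint8_t b, float* v0, float* v1) {
  // Ill-formed: float* p = (float*)&cvt_fp4x2_to_half2(b);  // address of a prvalue
  half2_t tmp = cvt_fp4x2_to_half2(b);  // materialize into a named local first
  *v0 = tmp.x;                          // then use member access
  *v1 = tmp.y;
}
```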

Comment on lines +35 to +38
if (exp == 0u) {
// Denormal: value = (-1)^s * 2^(-1) * (0 + m*0.5) = (-1)^s * m * 0.25
result = mant ? 0.25f : 0.0f;
} else {

⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

The FP4 decoder is inconsistent with the encoder and LUTs.

Here exp == 0 && mant == 1 decodes to 0.25f, but __tl_float_to_fp4() below encodes 0.5f as nibble 1, and the new LUTs in tilelang/quantize/mxfp.py / testing/python/amd/test_tilelang_mxfp4_gfx950.py also decode nibble 1 as 0.5. That means FP4 values won't round-trip and every FP4→float/half/bf16 cast for raw value 1/9 is off by 2x. Either the decoder needs to return 0.5f here, or the encoder/LUTs need to be updated to the same encoding.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/tl_templates/hip/hip_fp4.h` around lines 35 - 38, The FP4 decoder in
hip_fp4.h currently maps exp==0 and mant==1 to 0.25f, which conflicts with
__tl_float_to_fp4() and the new LUTs that treat nibble 1 as 0.5; update the
denormal handling in the decoder (the exp==0 branch that sets result based on
mant) to return 0.5f for mant==1 (instead of 0.25f) so decoding matches the
encoder and the mxfp LUTs (or alternatively make the encoder/LUTs consistent
with the decoder, but prefer updating the decoder to match __tl_float_to_fp4()
and tilelang/quantize/mxfp.py).
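The invariant this comment asks for is that decoder, encoder, and LUTs agree, so every nibble round-trips. A consistent pair can be sketched as follows; the encoder's nearest-magnitude rounding is an assumption of this sketch, not necessarily what __tl_float_to_fp4() does on ties.

```cpp
#include <cassert>
#include <cstdint>

// Decoder consistent with E2M1: the denormal nibble 1 decodes to 0.5, not 0.25.
static float fp4_decode(uint8_t n) {
  float s = (n & 0x8) ? -1.0f : 1.0f;
  uint32_t e = (n >> 1) & 0x3, m = n & 0x1;
  if (e == 0) return s * 0.5f * (float)m;
  return s * (float)(1u << (e - 1)) * (1.0f + 0.5f * m);
}

// Encoder: nearest representable magnitude, sign handled separately.
static uint8_t fp4_encode(float x) {
  static const float mag[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};
  uint8_t s = (x < 0.0f) ? 0x8 : 0x0;
  float a = (x < 0.0f) ? -x : x;
  int best = 0;
  for (int i = 1; i < 8; ++i) {
    float d0 = a > mag[best] ? a - mag[best] : mag[best] - a;
    float d1 = a > mag[i] ? a - mag[i] : mag[i] - a;
    if (d1 < d0) best = i;
  }
  return s | (uint8_t)best;
}
```

With this pairing, encode(decode(n)) == n for every nibble except negative zero, which re-encodes as +0.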

Comment thread tilelang/quantize/mxfp.py
Comment on lines +201 to +207
if target is not None:
try:
from tilelang.utils.target import target_is_gfx950

_is_gfx950 = target_is_gfx950(target)
except Exception:
pass

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Don't silently fall back to PTX when gfx950 detection fails.

If target_is_gfx950(target) raises here, the function quietly returns the CUDA/PTX implementation for a HIP target. That turns a target-detection bug into a much harder-to-debug backend mismatch later in compilation. Catch only the expected import/availability failure, or re-raise with context instead of pass.

🧰 Tools
🪛 Ruff (0.15.12)

[error] 206-207: try-except-pass detected, consider logging the exception

(S110)


[warning] 206-206: Do not catch blind exception: Exception

(BLE001)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tilelang/quantize/mxfp.py` around lines 201 - 207, The current try/except
around importing and calling target_is_gfx950 silently swallows all exceptions
causing a HIP target to fall back to the CUDA/PTX path; update the block that
references target_is_gfx950 and _is_gfx950 to only catch expected
import/availability errors (e.g.,
ImportError/ModuleNotFoundError/AttributeError) or, if any other exception is
raised by target_is_gfx950, re-raise it with additional context (e.g., include
the provided target) instead of pass so target-detection failures do not
silently change backend selection.
