[AMD][CDNA4] Add MXFP4 (FP4 E2M1) support for gfx950 #2132
zhangnju wants to merge 2 commits into tile-ai:main
📝 Walkthrough
Introduces FP4 (4-bit floating point) support for AMD gfx950 in TileLang, comprising a dequantize-GEMM example kernel, HIP code generation support for FP4 type handling and conversions, packed FP4 buffer management, intrinsic conversion routines, gfx950-specific tests, and updates to MFMA layout register mappings.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/target/codegen_hip.cc (1)
1727-1757: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win
Scalar FP4 packing is only applied to `local`, but the accessors assume it everywhere.
`GetBufferRef()` (Line 1079 onward) and the new `BufferStoreNode` path (Line 1789 onward) always use `tl_fp4_packed_load/store` for scalar FP4 buffers, but this allocator only packs `scope == "local"`. `shared` and `local.var` allocations still have `fp4_e2_t` layout, so their logical indexing no longer matches the physical storage and neighboring elements can alias/corrupt each other. Either pack every scalar-FP4 allocation that goes through the nibble helpers, or gate those helpers to buffers recorded in `fp4_packed_buffers_`.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/target/codegen_hip.cc` around lines 1727 - 1757, The bug: GetBufferRef and the BufferStoreNode path assume scalar-FP4 buffers are packed but the allocator only packs when is_fp4_scalar_local (scope == "local"), causing mis-indexing for shared/local.var buffers; to fix, either (A) consistently emit packed storage for every scalar FP4 allocation that uses the nibble helpers (change the allocation path that creates fp4_e2_2_t in the code that sets fp4_packed_buffers_), or (B) safer minimal change—gate all uses of tl_fp4_packed_load/tl_fp4_packed_store to only buffers recorded in fp4_packed_buffers_: modify GetBufferRef(...) and the BufferStoreNode handling to check fp4_packed_buffers_.count(op->buffer_var.get()) (or equivalent) before emitting packed helpers, and fall back to fp4_e2_t accessors for others; update any code that constructs vid_packed (where fp4_e2_2_t is emitted) to ensure fp4_packed_buffers_ is the single source-of-truth.
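The indexing mismatch described in this comment can be modelled in a few lines of Python. This is only an illustrative sketch: the helper names `packed_load` and `packed_store` are hypothetical stand-ins for the nibble helpers, and the nibble order (even element in the low nibble) is an assumption, not something taken from the PR.

```python
# Illustrative model only. Assumption: two FP4 codes per byte, with the
# even-indexed element in the low nibble. Helper names are hypothetical.

def packed_load(buf: bytearray, i: int) -> int:
    """Read logical element i from a nibble-packed buffer."""
    byte = buf[i // 2]
    return (byte >> 4) & 0xF if i % 2 else byte & 0xF

def packed_store(buf: bytearray, i: int, code: int) -> None:
    """Write logical element i into a nibble-packed buffer."""
    shift = 4 * (i % 2)
    cleared = buf[i // 2] & ~(0xF << shift) & 0xFF
    buf[i // 2] = cleared | ((code & 0xF) << shift)

# Packed allocation: 8 logical elements occupy 4 bytes, accessors agree.
packed = bytearray(4)
for i in range(8):
    packed_store(packed, i, i + 1)
packed_roundtrip = [packed_load(packed, i) for i in range(8)]

# Unpacked allocation: one byte per element, filled with plain indexing.
unpacked = bytearray(8)
for i in range(8):
    unpacked[i] = i + 1

# Reading the unpacked buffer through the packed accessor returns the
# wrong elements, because logical index i no longer maps to byte i.
mismatched = [packed_load(unpacked, i) for i in range(8)]

print(packed_roundtrip)  # [1, 2, 3, 4, 5, 6, 7, 8]
print(mismatched)        # [1, 0, 2, 0, 3, 0, 4, 0]
```

Either remedy in the comment removes the mismatch in this model: allocate every buffer the way `packed` is allocated, or only call the packed accessors on buffers known to be packed.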
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 529c506e-c76e-4ad7-82a3-d5c03629a45b
📒 Files selected for processing (7)
- examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_cdna4.py
- src/target/codegen_hip.cc
- src/target/codegen_hip.h
- src/tl_templates/hip/hip_fp4.h
- testing/python/amd/test_tilelang_mxfp4_gfx950.py
- tilelang/intrinsics/mfma_layout.py
- tilelang/quantize/mxfp.py
    int fp4_lanes = from_ty.lanes();
    bool fp4_pair_cast = (fp4_lanes == 2 || fp4_lanes == 4 || fp4_lanes == 8);
Reject FP4x16/FP4x32 casts instead of falling through to generic vector code.
The new FP4 types are exposed up to 32 lanes, but fp4_pair_cast only handles 2/4/8. For x16/x32 this drops into the generic vector cast path, and PrintVecElemLoad/Store don't know how to index fp4_e2_* aggregates, so codegen will emit invalid accessors instead of a real cast. Please either extend the pairwise lowering to those lane counts or ICHECK them as unsupported here.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/target/codegen_hip.cc` around lines 806 - 808, The fp4 handling currently
sets fp4_pair_cast based on fp4_lanes (variable fp4_lanes) but only allows
2/4/8, which lets x16/x32 fall through to generic vector casting and causes
PrintVecElemLoad/Store to emit invalid fp4_e2_* accesses; either extend the
pairwise lowering to support 16 and 32 lanes (update fp4_pair_cast to include 16
and 32 and implement corresponding pairwise lowering logic) or explicitly reject
unsupported fp4 lane widths by adding an ICHECK (or CHECK) where fp4_pair_cast
is computed to assert fp4_lanes is one of the supported values; update the same
area in src/target/codegen_hip.cc so casts for FP4x16/FP4x32 never fall through
to PrintVecElemLoad/Store.
    for (int i = 0; i < fp4_lanes; i += 2) {
      std::ostringstream val;
      val << "__tl_cvt_fp4x2_to_half2(((uint8_t*)&(" << src << "))[" << i / 2
          << "])";
      // Store both elements of the half2
      std::ostringstream v0, v1;
      v0 << "((half_t*)(&(" << val.str() << ")))[0]";
      v1 << "((half_t*)(&(" << val.str() << ")))[1]";
      PrintVecElemStore(sret, target_ty, i, v0.str());
      PrintVecElemStore(sret, target_ty, i + 1, v1.str());
🧩 Analysis chain
🌐 Web query:
In C++ (and HIP/Clang), is taking the address of a temporary/prvalue returned from a function expression like &(f()) valid, and what is the recommended way to read packed fields from such a return value?
💡 Result:
No. Taking the address of a prvalue returned from a function with the built-in `&` operator, as in `&(f())`, is ill-formed in standard C++: the unary `&` operator requires an lvalue operand ([expr.unary.op]), and the result of a function that returns by value is a prvalue. Conforming compilers diagnose this, although some accept it as a nonstandard extension.
`std::addressof(f())` does not help either. `std::addressof` takes its argument by lvalue reference, and since C++17 (LWG 2598) its `const T&&` overload is deleted specifically so that it rejects rvalues. No HIP-specific difference was found in the Clang/HIP documentation.
For reading fields from such a return value:
- Direct member access on the temporary (e.g. `f().field`) is valid within the full-expression.
- Do not keep a pointer or reference to the temporary for later use; it dangles once the full-expression ends.
- Recommended: copy the result into a local variable first (`auto tmp = f();`), then read `tmp.field`. The local lives until the end of its scope.
- If the bytes of a packed field must be reinterpreted, copy them out of the local with `memcpy` instead of casting a pointer, which avoids unaligned-access problems on targets that penalize or trap on them.
Citations:
- 1: https://stackoverflow.com/questions/30032827/c-taking-the-address-of-a-temporary-object
- 2: https://stackoverflow.com/questions/2280688/taking-the-address-of-a-temporary-object
- 3: https://stackoverflow.com/questions/4301179/why-is-taking-the-address-of-a-temporary-illegal
- 4: https://timsong-cpp.github.io/cppwp/std20/expr.unary.op
- 5: https://docs.cppreference.com/w/cpp/memory/addressof.html
- 6: https://timsong-cpp.github.io/lwg-issues/2598
- 7: https://clang.llvm.org/docs/HIPSupport.html
- 8: clang::Expr::ClassifyImpl asserts on isPRValue() llvm/llvm-project#53958
- 9: https://docs.huihoo.com/doxygen/clang/r222231/classclang_1_1MaterializeTemporaryExpr.html
- 10: compilation error with hip-clang ROCm/rocBLAS#473
- 11: https://thelinuxcode.com/how-i-pack-structs-in-c-without-getting-burned-by-alignment-abi-and-performance/
- 12: https://stackoverflow.com/questions/70980038/how-to-properly-access-packed-struct-members
🏁 Script executed:
    cat -n src/target/codegen_hip.cc | sed -n '810,835p'
Repository: tile-ai/tilelang
Length of output: 1368
🏁 Script executed:
    # Search for BF16 cast handling to see the correct pattern
    rg -n "bf16\|BF16\|bfloat" src/target/codegen_hip.cc -A 8 -B 2
Repository: tile-ai/tilelang
Length of output: 42
🏁 Script executed:
    # Search for float32 and float64 cast paths to see the correct pattern
    rg -n "from_ty.is_float4_e2m1fn.*float" src/target/codegen_hip.cc -A 15 | head -60
Repository: tile-ai/tilelang
Length of output: 3076
🏁 Script executed:
    # Search more broadly for how temporaries are materialized in similar code
    rg -n "ostringstream\|PrintVecElemStore" src/target/codegen_hip.cc -B 3 -A 3 | head -80
Repository: tile-ai/tilelang
Length of output: 42
Materialize the half2 temporary before taking its address.
Lines 824–825 generate code that takes &(__tl_cvt_fp4x2_to_half2(...)), which attempts to take the address of a prvalue (temporary), a violation in C++. The float32 (lines 845–849) and float64 (lines 868–872) paths already demonstrate the correct pattern: materialize the return value into a local variable first, then access members by name. Apply the same approach here: create a half2 temporary, store the conversion result into it, and use member access instead of pointer arithmetic.
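As a rough illustration of the suggested shape, the Python sketch below builds the kind of source text the loop could emit once the conversion result is stored in a named temporary. The emitter function, the temporary naming scheme `__fp4_h2_N`, and the `dst[i] = ...` store form are all hypothetical; the real code goes through `PrintVecElemStore`, and the `.x` / `.y` member access follows the review's suggestion rather than a verified fix.

```python
# Hypothetical emitter that mirrors the shape of the reviewed loop.
# It is not the real codegen API; names and store syntax are assumptions.

def emit_fp4_to_half_stores(src: str, dst: str, lanes: int) -> list:
    lines = []
    for i in range(0, lanes, 2):
        # One uniquely named temporary per FP4 pair avoids collisions.
        tmp = f"__fp4_h2_{i // 2}"
        lines.append(
            f"half2 {tmp} = "
            f"__tl_cvt_fp4x2_to_half2(((uint8_t*)&({src}))[{i // 2}]);"
        )
        # Read the two halves by member name instead of casting the
        # address of the call expression.
        lines.append(f"{dst}[{i}] = {tmp}.x;")
        lines.append(f"{dst}[{i + 1}] = {tmp}.y;")
    return lines

emitted = emit_fp4_to_half_stores("v_src", "v_dst", 4)
print("\n".join(emitted))
```

The only address taken in the emitted text is that of `v_src`, which is an lvalue; the call result is never the operand of `&`.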
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/target/codegen_hip.cc` around lines 818 - 827, The code currently takes
the address of a prvalue from __tl_cvt_fp4x2_to_half2, which is invalid in C++;
instead, materialize the conversion into a local half2 temporary (e.g. declare a
variable like half2 tmp_half2 =
__tl_cvt_fp4x2_to_half2(((uint8_t*)&(src))[i/2]);) and then build v0/v1 to
reference its components (prefer member access like tmp_half2.x and tmp_half2.y)
before calling PrintVecElemStore(sret, target_ty, i, ...) and
PrintVecElemStore(sret, target_ty, i+1, ...). Ensure the temporary name is
unique per loop iteration to avoid collisions.
    if (exp == 0u) {
      // Denormal: value = (-1)^s * 2^(-1) * (0 + m*0.5) = (-1)^s * m * 0.25
      result = mant ? 0.25f : 0.0f;
    } else {
The FP4 decoder is inconsistent with the encoder and LUTs.
Here exp == 0 && mant == 1 decodes to 0.25f, but __tl_float_to_fp4() below encodes 0.5f as nibble 1, and the new LUTs in tilelang/quantize/mxfp.py / testing/python/amd/test_tilelang_mxfp4_gfx950.py also decode nibble 1 as 0.5. That means FP4 values won't round-trip and every FP4→float/half/bf16 cast for raw value 1/9 is off by 2x. Either the decoder needs to return 0.5f here, or the encoder/LUTs need to be updated to the same encoding.
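For reference, a minimal Python decoder for E2M1 is sketched below, assuming the usual layout of 1 sign bit, 2 exponent bits with bias 1, and 1 mantissa bit. Under that assumption a subnormal scales by 2^(1 - bias) = 1, so nibble 1 decodes to 0.5, which matches the LUT values the comment cites. The function name `decode_e2m1` is illustrative and does not appear in the PR.

```python
def decode_e2m1(nibble: int) -> float:
    """Decode a 4-bit E2M1 code (assumed: 1 sign, 2 exponent, 1 mantissa, bias 1)."""
    sign = -1.0 if nibble & 0x8 else 1.0
    exp = (nibble >> 1) & 0x3
    mant = nibble & 0x1
    if exp == 0:
        # Subnormal: 2^(1 - bias) * (mant / 2) with bias = 1, i.e. mant * 0.5
        magnitude = 0.5 * mant
    else:
        # Normal: 2^(exp - bias) * (1 + mant / 2)
        magnitude = 2.0 ** (exp - 1) * (1.0 + 0.5 * mant)
    return sign * magnitude

table = [decode_e2m1(n) for n in range(16)]
print(table[:8])  # [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
```

Nibbles 8 to 15 mirror the first eight values with the sign flipped, so raw values 1 and 9 decode to 0.5 and -0.5.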
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/tl_templates/hip/hip_fp4.h` around lines 35 - 38, The FP4 decoder in
hip_fp4.h currently maps exp==0 and mant==1 to 0.25f, which conflicts with
__tl_float_to_fp4() and the new LUTs that treat nibble 1 as 0.5; update the
denormal handling in the decoder (the exp==0 branch that sets result based on
mant) to return 0.5f for mant==1 (instead of 0.25f) so decoding matches the
encoder and the mxfp LUTs (or alternatively make the encoder/LUTs consistent
with the decoder, but prefer updating the decoder to match __tl_float_to_fp4()
and tilelang/quantize/mxfp.py).
    if target is not None:
        try:
            from tilelang.utils.target import target_is_gfx950

            _is_gfx950 = target_is_gfx950(target)
        except Exception:
            pass
Don't silently fall back to PTX when gfx950 detection fails.
If target_is_gfx950(target) raises here, the function quietly returns the CUDA/PTX implementation for a HIP target. That turns a target-detection bug into a much harder-to-debug backend mismatch later in compilation. Catch only the expected import/availability failure, or re-raise with context instead of pass.
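One possible shape for the narrower handling is sketched below. It treats only a missing helper as "not gfx950" and re-raises anything the helper itself throws, with the target attached. The function `detect_gfx950`, its `load_helper` parameter, and the stand-in loaders are hypothetical and exist only so the sketch runs without TileLang installed; only the import path `tilelang.utils.target.target_is_gfx950` comes from the reviewed code.

```python
def detect_gfx950(target, load_helper) -> bool:
    """Return whether `target` is gfx950; fail loudly if detection itself breaks."""
    if target is None:
        return False
    try:
        helper = load_helper()
    except (ImportError, AttributeError):
        # Expected case: the helper is not available in this build.
        return False
    try:
        return bool(helper(target))
    except Exception as err:
        # Unexpected case: surface it with context instead of falling back.
        raise RuntimeError(f"gfx950 detection failed for target {target!r}") from err

def load_tilelang_helper():
    # Import path taken from the reviewed snippet.
    from tilelang.utils.target import target_is_gfx950
    return target_is_gfx950

# Stand-in loaders (hypothetical) to exercise both paths without TileLang.
def missing_loader():
    raise ImportError("helper not available")

def broken_loader():
    def helper(target):
        raise ValueError("cannot parse target")
    return helper

print(detect_gfx950("hip", missing_loader))  # False
```

With `broken_loader`, the call raises `RuntimeError` chained to the original `ValueError`, so a detection bug shows up at the point of failure rather than as a backend mismatch later in compilation.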
🧰 Tools
🪛 Ruff (0.15.12)
[error] 206-207: try-except-pass detected, consider logging the exception
(S110)
[warning] 206-206: Do not catch blind exception: Exception
(BLE001)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tilelang/quantize/mxfp.py` around lines 201 - 207, The current try/except
around importing and calling target_is_gfx950 silently swallows all exceptions
causing a HIP target to fall back to the CUDA/PTX path; update the block that
references target_is_gfx950 and _is_gfx950 to only catch expected
import/availability errors (e.g.,
ImportError/ModuleNotFoundError/AttributeError) or, if any other exception is
raised by target_is_gfx950, re-raise it with additional context (e.g., include
the provided target) instead of pass so target-detection failures do not
silently change backend selection.
This PR adds end-to-end MXFP4 (FP4 E2M1) support for AMD gfx950 (CDNA4 / MI350) in the HIP backend.
HIP codegen (codegen_hip.cc/.h)
New header src/tl_templates/hip/hip_fp4.h
Dequantization kernels (tilelang/quantize/mxfp.py)
MFMA layout (tilelang/intrinsics/mfma_layout.py)
Tests & examples