
[feature][Blackwell] Add SM120 float4_e2m1fn FP4 GEMM support.#2171

Open
TerminusAkivili wants to merge 1 commit into tile-ai:main from TerminusAkivili:sm120-fp4-a8w4-clean-pr

Conversation

Contributor

@TerminusAkivili TerminusAkivili commented May 8, 2026

Summary

This PR adds SM120 fragment-MMA GEMM support for `T.float4_e2m1fn`, including plain FP4 GEMM and explicit FP8 e4m3 / FP4 mixed GEMM.

Supported combinations:

  • `T.float4_e2m1fn` × `T.float4_e2m1fn` → `T.float32`
  • `T.float8_e4m3fn` × `T.float4_e2m1fn` → `T.float32`
  • `T.float4_e2m1fn` × `T.float8_e4m3fn` → `T.float32`

The TileLang-facing API stays dtype-semantic: kernels declare FP4 tensors as `T.float4_e2m1fn`. Packed byte storage is handled by lowering/codegen and by host-side example setup, not by using `uint8` as a GEMM dtype.
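To make the storage convention concrete, here is a hypothetical host-side sketch of the packed `uint8` representation the examples use for setup: two e2m1 nibbles per byte, decoded through the standard e2m1 value table. The helper names are illustrative and not part of this PR:

```python
# Illustrative host-side FP4 e2m1 pack/unpack helpers (hypothetical names,
# not the PR's actual API). E2M1: 1 sign bit, 2 exponent bits, 1 mantissa
# bit; the eight non-negative representable magnitudes are:
E2M1_LUT = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def encode_e2m1(x: float) -> int:
    """Round a float to the nearest e2m1 code (a 4-bit nibble)."""
    sign = 0x8 if x < 0 else 0x0
    mag = abs(x)
    return sign | min(range(8), key=lambda i: abs(E2M1_LUT[i] - mag))

def decode_e2m1(nibble: int) -> float:
    """Decode a 4-bit e2m1 code back to a float."""
    sign = -1.0 if nibble & 0x8 else 1.0
    return sign * E2M1_LUT[nibble & 0x7]

def pack_fp4(values):
    """Pack pairs of FP4 values into bytes, low nibble = even index."""
    assert len(values) % 2 == 0
    out = bytearray()
    for lo, hi in zip(values[0::2], values[1::2]):
        out.append(encode_e2m1(lo) | (encode_e2m1(hi) << 4))
    return bytes(out)

def unpack_fp4(data):
    """Inverse of pack_fp4: one byte yields two FP4 values."""
    out = []
    for b in data:
        out.append(decode_e2m1(b & 0xF))
        out.append(decode_e2m1(b >> 4))
    return out
```

The kernel never sees these bytes as `uint8` operands; they only exist on the host side of the examples.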

Guide

The SM120 FP4 path has three contracts that need to line up:

| Contract | Where it is implemented |
| --- | --- |
| Shared FP4 operands are loaded with the SM120 `b4x16_p64` ldmatrix | `ldsm.h`, `codegen_cuda.cc`, `copy.cc`, `utils.py`, `mma_layout.py` |
| Local FP4 fragments remain semantic FP4 objects for MMA operands | `codegen_cuda.cc`, `codegen_cuda.h` |
| GEMM emits SM120 `m16n8k32` MMA for explicit FP4/FP8 dtype pairs | `instruction/mma.h`, `gemm_mma.h`, `mma_macro_generator.py`, `gemm_mma.py` |

The main implementation detail is the shared-memory layout: SM120 `b4x16_p64` uses packed FP4 bytes with a padded shared row layout. The copy and ldmatrix lowering paths therefore compute packed global offsets and padded shared offsets separately, while local fragments keep their declared names and types.
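A hypothetical sketch of the two offset computations this paragraph describes. The exact `b4x16_p64` chunk geometry used here (16 FP4 elements of payload padded out to a 16-byte shared chunk) is an assumption inferred from the 8-byte copy segments and 16-element groups mentioned elsewhere in this PR, not the PR's actual code:

```python
# Illustrative only: how packed global and padded shared byte offsets can
# diverge for the same logical FP4 element index.

DATA_BYTES = 8     # assumed: 16 FP4 elements * 4 bits = 64 bits of payload
CHUNK_BYTES = 16   # assumed: payload + 64 bits of padding per shared chunk

def packed_global_byte(i: int) -> int:
    """Global storage is densely packed: two FP4 elements per byte."""
    return i // 2

def padded_shared_byte(i: int) -> int:
    """Shared storage pads every 16-element group out to a full chunk."""
    chunk, within = divmod(i, 16)
    return chunk * CHUNK_BYTES + within // 2
```

For example, logical element 16 starts a new group: it sits at byte 8 in packed global memory but at byte 16 in the padded shared layout, which is why the two offsets must be computed separately.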

Main changes

CUDA templates

| File | Change |
| --- | --- |
| `src/tl_templates/cuda/ldsm.h` | Add SM120 `ptx_ldmatrix_b4x16_x{1,2,4}` helpers with an architecture guard |
| `src/tl_templates/cuda/instruction/mma.h` | Add SM120 `cute::SM120_16x8x32_TN` dispatch for FP4×FP4, FP8×FP4, and FP4×FP8 to FP32 |
| `src/tl_templates/cuda/instruction/mma.h` | Apply the FP4 operand register shift only to FP4 operands before calling the CuTe atom |
| `src/tl_templates/cuda/gemm_mma.h` | Register FP4 and mixed FP8/FP4 GEMM template dispatch |
| `src/tl_templates/cuda/cuda_fp4.h` | Bridge the TileLang FP4 template type to the CuTe FP4 type while keeping existing packed helper types |

CUDA lowering

| File | Change |
| --- | --- |
| `src/backend/cuda/codegen/codegen_cuda.cc` | Select `ptx_ldmatrix_b4x16_x{1,2,4}` for explicit `float4_e2m1fn` ldmatrix loads |
| `src/backend/cuda/codegen/codegen_cuda.cc` | Distinguish packed global FP4, padded shared FP4, and semantic local FP4 storage |
| `src/backend/cuda/codegen/codegen_cuda.cc` | Keep local FP4 fragments under their declared names instead of introducing `_packed` aliases |
| `src/backend/cuda/op/copy.cc` | Extend shared-to-fragment LDSM lowering for SM120 FP4 |
| `src/backend/cuda/op/copy.cc` | Route SM120 FP4 async copy through the existing cp.async lowering with padded shared-copy handling enabled |
| `src/transform/lower_ptx_async_copy.cc` | Emit 8-byte FP4 global-to-shared async-copy segments for padded shared storage |
| `src/backend/cuda/op/copy_analysis.cc`, `src/transform/ptx_async_copy_injector.h` | Carry the metadata/flag needed by the FP4 padded shared-copy path |

Python lowering

| File | Change |
| --- | --- |
| `tilelang/cuda/intrinsics/layout/utils.py` | Add FP4-specific ldmatrix offset handling, gated on `float4_e2m1fn` |
| `tilelang/cuda/intrinsics/layout/mma_layout.py` | Add FP4 ldmatrix logical layouts |
| `tilelang/cuda/intrinsics/macro/mma_macro_generator.py` | Use SM120 FP4 `m16n8k32` MMA granularity |
| `tilelang/cuda/op/gemm/gemm_mma.py` | Validate mixed A/B dtypes explicitly for A8W4/W4A8 |

Examples

| File | Change |
| --- | --- |
| `examples/gemm_fp4/example_gemm_fp4_sm120.py` | Minimal SM120 FP4 GEMM example with a numerical check |
| `examples/gemm_fp4/example_gemm_a8w4_sm120.py` | Minimal SM120 A8W4 GEMM example with a numerical check |

Notes

  • Mixed A8W4/W4A8 dispatch is selected from explicit FP8/FP4 dtype pairs.
  • `uint8` is only a host storage/interoperability detail in the examples.
  • FP4 ldmatrix layout handling is gated on `float4_e2m1fn`; existing int4/uint4 ldmatrix offset behavior stays on the existing path.
  • Example kernels use semantic TileLang signatures; packing is limited to input generation and reference checking.
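The numerical-check approach in the examples can be sketched in plain Python: snap host inputs to the e2m1 grid first, so an exact float32-style matmul over the quantized values is a valid reference for the FP4 kernel output. Names here are illustrative, not the examples' actual helpers:

```python
# Hypothetical sketch of the reference-check strategy (not the PR's code):
# quantize to the e2m1 grid on the host, then compare the kernel output
# against an exact reference GEMM over the quantized values.

E2M1_POS = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
E2M1_GRID = [-v for v in E2M1_POS] + E2M1_POS

def snap_to_e2m1(x: float) -> float:
    """Round to the nearest representable e2m1 value (saturating at ±6)."""
    return min(E2M1_GRID, key=lambda g: abs(g - x))

def matmul_ref(a, b):
    """Plain reference GEMM over already-quantized inputs (list-of-lists)."""
    m, k, n = len(a), len(b), len(b[0])
    return [[sum(a[i][t] * b[t][j] for t in range(k)) for j in range(n)]
            for i in range(m)]
```

Because every input is exactly representable in e2m1 after snapping, any mismatch against the reference points at the kernel rather than at quantization error.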

Summary by CodeRabbit

  • New Features

    • FP4 (4-bit E2M1) GEMM examples and host-side unpack/validation for CUDA SM120+ (A8W4 and FP4-only).
    • Mixed-precision MMA support allowing FP4 weights to pair with FP8/float activations on SM120 tensor cores.
  • Improvements

    • SM120-optimized memory/layout and tensor-core paths (b4x16 ldmatrix, padded shared layout, cp.async) for efficient FP4 transfers.
    • Deterministic test harnesses and numerical-accuracy checks for FP4 kernels.

@coderabbitai
Contributor

coderabbitai Bot commented May 8, 2026

📝 Walkthrough

This PR implements SM120 (compute capability 12.0) FP4 (float4_e2m1fn) GEMM support across CUDA codegen, lowering, intrinsics, MMA dispatch, and layout/macro generation, and provides two runnable examples with host-side FP4 unpacking and validation.

Changes

SM120 FP4 GEMM Support

| Layer | File(s) | Summary |
| --- | --- | --- |
| Examples & Host Helpers | `examples/gemm_fp4/...` | Adds FP4 LUT constant, `unpack_fp4_storage_to_float`, `require_sm120()`, TileLang kernel generators (`matmul_a8w4` / `matmul_fp4`), `main()` harnesses, deterministic/random inputs, zero-input checks, float32 reference comparisons, and error assertions. |
| CUDA Codegen: Buffer / Vector / Scalar Access | `src/backend/cuda/codegen/codegen_cuda.cc`, `src/backend/cuda/codegen/codegen_cuda.h` | Adds FP4 storage classifiers and `GetFp4PaddedSharedIndex`; applies padded-shared index remapping and packed-byte divisor logic; implements FP4-aware vector/scalar load-store codegen and cp.async/ldmatrix emission paths. |
| Copy Lowering & PTX Async Injection | `src/transform/lower_ptx_async_copy.cc`, `src/transform/ptx_async_copy_injector.h`, `src/backend/cuda/op/copy.cc` | Introduces the `fp4_padded_shared_copy` flag, an FP4-padded cp.async specialization splitting transfers into 16-FP4-element segments, threads FP4 mode into parallel-loop/vectorization lowering, and forwards the flag into `InjectPTXAsyncCopy`. |
| LDSM/STSM Lowering & Analysis | `src/backend/cuda/op/copy.cc`, `src/backend/cuda/op/copy_analysis.cc` | Adds FP4 guards in `LowerLDSM` (require SM120, disallow transposed FP4 ldmatrix), computes `elems_per_reg`/`elems_per_inst` for 4-bit types, updates vectorizability and `access_ptr` extent math, and gates LDSM/STSM eligibility for FP4. |
| FP4 Type Definitions & Compile-Time Support | `src/tl_templates/cuda/cuda_fp4.h` | Expands the compile-time arch guard (`__CUDA_ARCH_LIST__ >= 1200`) and adds a `make_fp4_e2_64_t` variadic helper for 64-element FP4 packing. |
| MMA Dispatch & Instruction Support | `src/tl_templates/cuda/gemm_mma.h`, `src/tl_templates/cuda/instruction/mma.h` | Maps `fp4_e2_t` to CuTe `float_e2m1_t`, registers SM120 `SM120_16x8x32_TN` dispatchers for FP4×FP4 and mixed FP8/FP4 combinations, and updates `tl::mma_sync` to left-shift FP4 operands before dispatcher invocation. |
| Shared-Memory LDMATRIX Helpers | `src/tl_templates/cuda/ldsm.h` | Adds SM120-only `ptx_ldmatrix_b4x16_x1/x2/x4` helpers to emit b4x16 ldmatrix inline PTX for FP4 loads. |
| Layout & Offset Computation | `tilelang/cuda/intrinsics/layout/mma_layout.py`, `tilelang/cuda/intrinsics/layout/utils.py` | Adds FP4-specific layout mapping helpers and extends `get_ldmatrix_offset` to accept `float4_e2m1fn` and apply FP4 layout transforms for supported cases. |
| Macro Generation & K-Dim/Extent Configuration | `tilelang/cuda/intrinsics/macro/mma_macro_generator.py` | Special-cases `float4_e2m1fn` to `k_dim=32`, caches dtype bitwidths for layout selection, computes FP4-dependent access extents (`4*num`), and treats 4-/8-bit dtypes via `shared_16x32` layout transforms. |
| High-Level GEMM Integration | `tilelang/cuda/op/gemm/gemm_mma.py` | Adds FP8/FP4 dtype predicates and `_validate_mma_dtypes()` to enforce allowed mixed operand pairs (FP8+FP4), and allocates local fragments per operand dtype during lowering. |
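The dtype-driven k-dim special-casing described for `mma_macro_generator.py` can be sketched as follows. `select_k_dim` and its clamp by `chunk` are illustrative, not the file's actual code; the clamp mirrors the behavior requested in the review comments on this PR:

```python
# Illustrative dtype -> k_dim selection for MMA shape choice (hypothetical
# helper, not the PR's code). float4_e2m1fn and 8-bit dtypes use m16n8k32
# (k=32); 16-bit dtypes use m16n8k16 (k=16). k_dim never exceeds the
# K-chunk tiled per iteration.

def select_k_dim(in_dtype: str, chunk: int) -> int:
    if in_dtype in ("float4_e2m1fn", "float8_e4m3fn", "int8"):
        return min(32, chunk)  # m16n8k32 granularity, clamped by chunk
    return min(16, chunk)      # e.g. fp16/bf16 m16n8k16
```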

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes


Suggested reviewers

  • LJC00118
  • SiriusNEO
  • LeiWang1999

Poem

🐰 I nibble nibbles, pack them tight,
SM120 hums — kernels take flight.
Rows get padded, cp.async sings,
TileLang hops and tensor springs.
Rabbits cheer as GEMM takes flight.

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 14.29%, which is below the required 80.00% threshold. | Write docstrings for the functions missing them to satisfy the coverage threshold. |
✅ Passed checks (4 passed)
| Check name | Status | Explanation |
| --- | --- | --- |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Description Check | ✅ Passed | Check skipped: CodeRabbit's high-level summary is enabled. |
| Title check | ✅ Passed | The title '[feature][Blackwell] Add SM120 float4_e2m1fn FP4 GEMM support.' clearly describes the main feature: adding FP4 GEMM support for the SM120 (Blackwell) architecture, which aligns with the core objective and changes in the PR. |



@github-actions

github-actions Bot commented May 8, 2026

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run `pre-commit run --all-files` in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

🧹 Nitpick comments (1)
src/tl_templates/cuda/cuda_fp4.h (1)

Lines 166-187: ⚡ Quick win

Verify register allocation for fp4_e2_t values[64] in device code.

The 64-element local array is constant-indexed throughout (values[0]…values[63]), so nvcc at -O2+ should scalar-replace it into registers. However, unlike the explicitly-parameterized make_fp4_e2_32_t, which guarantees register-only arguments, register spilling to local memory is possible at lower optimisation levels or under larger surrounding register pressure. Consider adding a __forceinline__ annotation to maximise inlining and scalar replacement at call sites.

Proposed annotation:

```diff
-template <typename... Args>
-TL_DEVICE fp4_e2_64_t make_fp4_e2_64_t(Args... args) {
+template <typename... Args>
+TL_DEVICE __forceinline__ fp4_e2_64_t make_fp4_e2_64_t(Args... args) {
```
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/tl_templates/cuda/cuda_fp4.h` around lines 166 - 187, The local array
fp4_e2_t values[64] in make_fp4_e2_64_t may be spilled under some compile
conditions; annotate the function to force inlining (e.g., add a
__forceinline__/always-inline device inline attribute to make_fp4_e2_64_t) so
nvcc can scalar-replace values[0]..values[63] into registers and inline the
make_fp4_e2_32_t calls; update the function declaration for make_fp4_e2_64_t
accordingly (keeping fp4_e2_t values[64] and the existing make_fp4_e2_32_t
usages unchanged).
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@src/backend/cuda/codegen/codegen_cuda.cc`:
- Around line 1973-2003: The FP4 padded shared-memory vector path
(IsFp4PaddedSharedStorage + code using GetFp4PaddedSharedIndex and the
byte_offset lambda when constructing the reinterpret cast for t.lanes()) can
incorrectly span the padded 16-element row boundary; add a guard or split logic:
either assert the logical base alignment (e.g., Ensure base % 16 == 0 for the
requested load/store) or detect when the access crosses a 16-element row by
computing the start and end logical indices (base + offset and base + offset +
t.lanes()-1) and comparing their 16-element row indices (truncdiv(..., 16)); if
it crosses, split the operation into two row-aligned fragments (like the
existing t.lanes()==32 two-fragment approach) and merge them, otherwise keep the
current single contiguous byte reinterpretation; apply the same fix to the other
similar blocks identified (around the other ranges mentioned).
- Around line 4428-4444: The allocator treats only scope == "local" as the path
that emits local backing arrays but FP4 fragments use the semantic storage name
"local.fragment", so allocations for these still hit the unsupported-scope
branch; update the scope checks used around is_int4_scalar_local, the FP4
alignas(16) branch, and the place that prints/omits the storage scope to treat
"local.fragment" as equivalent to "local" (either normalize scope to "local"
earlier or change conditions from scope == "local" to (scope == "local" || scope
== "local.fragment")), ensuring PrintStorageScope/PrintType and the
backing-array emission path handle FP4 fragments the same as regular local
allocations (references: is_int4_scalar_local, op->dtype.is_float4_e2m1fn(),
PrintStorageScope, PrintType, and the "local.fragment" semantic storage).

In `@tilelang/cuda/intrinsics/macro/mma_macro_generator.py`:
- Around line 121-124: The FP4 fast-path in mma_macro_generator.py sets
self.k_dim = 32 without respecting self.chunk, causing micro_size_k to exceed
chunk when chunk < 32; update the FP4 branch in the initializer (the block
setting self.k_dim) to clamp k_dim by self.chunk (e.g., self.k_dim = min(32,
self.chunk)) and add the same clamp/guard in the subclass override (the code
around lines 873–877) so both places respect chunk; optionally emit a clear
ValueError or assertion if chunk < required minimum to fail early with a helpful
message referencing the dtype and chunk size.

---

Nitpick comments:
In `@src/tl_templates/cuda/cuda_fp4.h`:
- Around line 166-187: The local array fp4_e2_t values[64] in make_fp4_e2_64_t
may be spilled under some compile conditions; annotate the function to force
inlining (e.g., add a __forceinline__/always-inline device inline attribute to
make_fp4_e2_64_t) so nvcc can scalar-replace values[0]..values[63] into
registers and inline the make_fp4_e2_32_t calls; update the function declaration
for make_fp4_e2_64_t accordingly (keeping fp4_e2_t values[64] and the existing
make_fp4_e2_32_t usages unchanged).

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: a09f3145-ce2d-4b0d-bb75-d916a099b2be

📥 Commits

Reviewing files that changed from the base of the PR and between a797e51 and 140f774.

📒 Files selected for processing (16)
  • examples/gemm_fp4/example_gemm_a8w4_sm120.py
  • examples/gemm_fp4/example_gemm_fp4_sm120.py
  • src/backend/cuda/codegen/codegen_cuda.cc
  • src/backend/cuda/codegen/codegen_cuda.h
  • src/backend/cuda/op/copy.cc
  • src/backend/cuda/op/copy_analysis.cc
  • src/tl_templates/cuda/cuda_fp4.h
  • src/tl_templates/cuda/gemm_mma.h
  • src/tl_templates/cuda/instruction/mma.h
  • src/tl_templates/cuda/ldsm.h
  • src/transform/lower_ptx_async_copy.cc
  • src/transform/ptx_async_copy_injector.h
  • tilelang/cuda/intrinsics/layout/mma_layout.py
  • tilelang/cuda/intrinsics/layout/utils.py
  • tilelang/cuda/intrinsics/macro/mma_macro_generator.py
  • tilelang/cuda/op/gemm/gemm_mma.py

TerminusAkivili force-pushed the sm120-fp4-a8w4-clean-pr branch 2 times, most recently from a255e60 to 3e5823d on May 8, 2026 at 15:53.
TerminusAkivili force-pushed the sm120-fp4-a8w4-clean-pr branch from 3e5823d to 7f254a9 on May 8, 2026 at 16:39.
TerminusAkivili changed the title from "[feature][Blackwell] Add SM120 FP4 and A8W4 GEMM support" to "[feature][Blackwell] Add SM120 float4_e2m1fn FP4 GEMM support." on May 8, 2026.