[tensilelite] Add fast path reference gemm for MXFP4 #8106

Open
Alex-Vasile wants to merge 5 commits into develop from users/alvasile/mxfp4_fast_ref
@Alex-Vasile Alex-Vasile commented Jun 5, 2026

Motivation

Add a fast reference gemm for mxfp4.

Speedup is roughly one order of magnitude (about 14x to 21x):

  • M=1500, N=1500, K=512 GEMM test: 0.12s (fast path) vs 2.53s (slow path)
  • M=512, N=512, K=1024 GEMM test: 0.07s (fast path) vs 1.00s (slow path)

Technical Details

Depends on #8006.

Test Plan

Add new tests to verify against the slow path.

Test Result

Tests pass.

Submission Checklist

codecov-commenter commented Jun 5, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.

❌ Your project status has failed because the head coverage (77.83%) is below the target coverage (80.00%). You can increase the head coverage or adjust the target coverage.

Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #8106      +/-   ##
===========================================
- Coverage    65.18%   65.18%   -0.00%     
===========================================
  Files         2597     2597              
  Lines       404506   404553      +47     
  Branches     60271    60274       +3     
===========================================
+ Hits        263654   263673      +19     
- Misses      121051   121077      +26     
- Partials     19801    19803       +2     
Flag Coverage Δ *Carryforward flag
TensileLite 31.04% <ø> (-<0.01%) ⬇️
hipBLAS 90.65% <ø> (ø) Carriedforward from 594ea78
hipBLASLt 41.25% <ø> (+<0.01%) ⬆️
hipCUB 82.68% <ø> (ø) Carriedforward from 594ea78
hipDNN 86.68% <ø> (ø) Carriedforward from 594ea78
hipFFT 50.97% <ø> (ø) Carriedforward from 594ea78
hipRAND 76.12% <ø> (ø) Carriedforward from 594ea78
hipSOLVER 69.18% <ø> (ø) Carriedforward from 594ea78
hipSPARSE 86.55% <ø> (ø) Carriedforward from 594ea78
rocBLAS 48.08% <ø> (ø) Carriedforward from 594ea78
rocFFT 49.48% <ø> (ø) Carriedforward from 594ea78
rocRAND 57.02% <ø> (ø) Carriedforward from 594ea78
rocSOLVER 77.83% <ø> (ø) Carriedforward from 594ea78
rocSPARSE 72.68% <ø> (ø) Carriedforward from 594ea78
rocThrust 91.34% <ø> (ø) Carriedforward from 594ea78

*This pull request uses carry forward flags. Click here to find out more.
see 2 files with indirect coverage changes

@newling newling left a comment

The agent review comments mostly come from my discussion with it; as far as I can tell, there is not much new of substance in them: https://github.com/newling/agent-public-reviews/blob/main/rocm-libraries/review_pr8106.md

Comment thread projects/hipblaslt/tensilelite/client/cpu_gemm_driver.cpp Outdated
Comment thread projects/hipblaslt/tensilelite/client/src/Reference.cpp
Comment thread projects/hipblaslt/tensilelite/client/cpu_gemm_driver.cpp Outdated
Alex-Vasile added a commit that referenced this pull request Jun 8, 2026
The driver had three MX block-size flags: --mxBlock as a "set both
sides" shortcut, plus per-side --mxBlockA / --mxBlockB. The shortcut
was purely CLI sugar (immediately splatted into mxBlockA = mxBlockB =
mxBlock); neither columnMajorGemm nor gemm_reference_fast consume it.
Per PR review #8106, the docstrings on the per-side flags also lied —
they claimed to "override --mxBlock" while the driver actually rejected
mixing the shortcut with either per-side flag.

Drop --mxBlock entirely. Symmetric MX is now spelled
--mxBlockA N --mxBlockB N. This shrinks the CLI surface, removes the
splat + conflict-validation block, and lets the docstrings be honest.

Driver:
- Remove the "mxBlock" option declaration.
- Remove parsing, splat, and the "cannot be combined with --mxBlockA
  or --mxBlockB" check.
- Drop the standalone "mxBlock must be non-negative" check (the
  per-side check already covers it).
- Update --mxBlockA / --mxBlockB docstrings to drop the "overrides"
  language.
- Cosmetic cleanup of three comments that referenced "mxBlock".

Tests:
- Convert 19 ctest entries from "--mxBlock N" to "--mxBlockA N
  --mxBlockB N" (mxBlock values 32, 16, and 6).
- Delete CPUGemm.f4_MX_Conflict_Rejected (the conflict no longer
  exists once --mxBlock is gone).

Verification: invoke build-client + ctest -R CPUGemm → 240/240 pass
(was 241; the deleted conflict test accounts for the -1).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
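
For readers skimming the commit message above, the remaining per-side checks could look roughly like the sketch below. It is not the driver's code: `validateMxBlocks` and `isPowerOfTwo` are hypothetical names, and treating 0 as "MX disabled on that side" is an assumption. The rules themselves (non-negative, power of 2, one-sided MX rejected) are taken from the commit messages in this PR.

```cpp
#include <cassert>
#include <optional>
#include <string>

// Hypothetical helper: true for 1, 2, 4, 8, ...
inline bool isPowerOfTwo(int v)
{
    return v > 0 && (v & (v - 1)) == 0;
}

// Hypothetical per-side validation; returns an error message, or nullopt if OK.
// Assumption: a block size of 0 means "no MX scaling on this side".
inline std::optional<std::string> validateMxBlocks(int mxBlockA, int mxBlockB)
{
    if(mxBlockA < 0 || mxBlockB < 0)
        return std::string("mxBlockA/mxBlockB must be non-negative");
    if((mxBlockA == 0) != (mxBlockB == 0))
        return std::string("one-sided MX is not supported");
    if(mxBlockA == 0)
        return std::nullopt; // MX disabled on both sides
    if(!isPowerOfTwo(mxBlockA) || !isPowerOfTwo(mxBlockB))
        return std::string("MX block sizes must be powers of 2");
    return std::nullopt;
}
```

With only the per-side flags left, the symmetric case is just `validateMxBlocks(N, N)`, so no separate conflict check is needed.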
Alex-Vasile added a commit that referenced this pull request Jun 8, 2026
The terse "(i == storageA - 1 && numA % 2 != 0) ? 0.0f : randomFp4()"
ternary obscured what the special case was for. Refactor:

- Extract the duplicated A/B loops into a single initFp4Operand lambda.
- Name the slots (slot0/slot1) and the condition (isPaddingSlot).
- Add a comment explaining why the padding slot must hold a valid FP4
  value: the fast-path ShadowBuffer FP4 decoder reads both slots of
  every byte unconditionally; only the write-back of the trailing slot
  is guarded. Leaving the padding slot uninitialized would be UB.

No behavior change. ctest -R CPUGemm still 240/240.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
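
A minimal sketch of the padding-slot idea described above, under stated assumptions: the real `initFp4Operand` is a lambda in the driver and works on the project's `Float4x2` type, whereas this version packs raw 4-bit codes into bytes, takes the random source as a callback, and assumes slot0 is the low nibble. Only the names `initFp4Operand`, `slot0`, `slot1` and `isPaddingSlot` come from the commit message.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <vector>

// Packs numElements 4-bit codes, two per byte. When numElements is odd, the
// last byte has an unused second slot; it is written as 0 (a valid FP4 code)
// so a decoder that reads both slots of every byte never sees garbage.
inline std::vector<std::uint8_t>
    initFp4Operand(std::size_t numElements, const std::function<std::uint8_t()>& randomFp4Code)
{
    const std::size_t         storage = (numElements + 1) / 2;
    std::vector<std::uint8_t> packed(storage);
    for(std::size_t i = 0; i < storage; ++i)
    {
        const bool isPaddingSlot = (i == storage - 1) && (numElements % 2 != 0);
        const auto slot0         = static_cast<std::uint8_t>(randomFp4Code() & 0xF);
        const auto slot1
            = isPaddingSlot ? std::uint8_t{0} : static_cast<std::uint8_t>(randomFp4Code() & 0xF);
        // Assumption: slot0 is stored in the low nibble, slot1 in the high nibble.
        packed[i] = static_cast<std::uint8_t>(slot0 | (slot1 << 4));
    }
    return packed;
}
```

For example, three elements occupy two bytes, and the high nibble of the second byte is the zeroed padding slot.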
@Alex-Vasile Alex-Vasile requested a review from newling June 8, 2026 19:28
Alex-Vasile added a commit that referenced this pull request Jun 12, 2026
Alex-Vasile added a commit that referenced this pull request Jun 12, 2026
@Alex-Vasile Alex-Vasile force-pushed the users/alvasile/mxfp4_fast_ref branch from 76c8deb to 7b1f738 Compare June 12, 2026 20:39
Alex-Vasile added a commit that referenced this pull request Jun 16, 2026
Alex-Vasile added a commit that referenced this pull request Jun 16, 2026
@Alex-Vasile Alex-Vasile force-pushed the users/alvasile/mxfp4_fast_ref branch from 7b1f738 to bce05db Compare June 16, 2026 15:13

@newling newling left a comment

Please add some perf numbers to the PR summary if you can think of any that might be relevant.

#ifndef _WIN32
constexpr bool isFP4 = std::is_same_v<InputAT, Float4x2> && std::is_same_v<InputBT, Float4x2>;
#else
constexpr bool isFP4 = false;

So if

std::is_same_v<InputAT, Float4x2> && !std::is_same_v<InputBT, Float4x2>

(i.e. one is fp4 and the other is not), isFP4 is false?

Wondering if some static_asserts would be good here

Contributor Author

Yup, good catch. I've added the assert, and also had it not even try the fast path for mixed types.
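
One way the check discussed in this thread could be expressed is sketched below. The exact assert that was added is not shown on this page, so this is an illustration only: `Fp4Traits` is a hypothetical wrapper and `Float4x2` here is a stand-in struct, not the project's type.

```cpp
#include <cassert>
#include <type_traits>

// Stand-in for the project's packed FP4 pair type (assumption).
struct Float4x2
{
    unsigned char data;
};

// Hypothetical traits wrapper: instantiating it with exactly one FP4 operand
// fails at compile time instead of silently yielding isFP4 == false.
template <typename InputAT, typename InputBT>
struct Fp4Traits
{
    static constexpr bool aIsFP4 = std::is_same_v<InputAT, Float4x2>;
    static constexpr bool bIsFP4 = std::is_same_v<InputBT, Float4x2>;
    static_assert(aIsFP4 == bIsFP4, "mixed FP4 / non-FP4 inputs are not supported");
    static constexpr bool isFP4 = aIsFP4 && bIsFP4;
};
```

`Fp4Traits<Float4x2, float>::isFP4` would then refuse to compile, which makes the mixed case an explicit error rather than a quiet fallback.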

Alex-Vasile and others added 5 commits June 25, 2026 17:11
Extends both CPU reference paths to handle mxfp4 problems:
A and/or B is Float4 (E2M1, packed as Float4x2) with per-block UE8M0
scale factors (E8) applied every mxBlock K-elements.

Driver (client/cpu_gemm_driver.cpp)
-----------------------------------
- --type f4 with FP4-packed input/output paths
- --mxBlock N for symmetric block scaling; --mxBlockA N / --mxBlockB N
  for asymmetric. One-sided MX (only one of A/B scaled) is rejected.
- --batchCount N with per-batch MX scale slicing
- --computeInputA/B accepts f4 / f64 / tf32
- columnMajorGemm MX inner reduction steps by min(mxBlockA, mxBlockB);
  mxBlock must be a power of 2
- setMXScaleA/B called with padScaleTensor=false so the allocated
  buffers match columnMajorGemm's tight-packed {m, k/mxBlock} layout
- if constexpr guards on impossible Float4x2 mixed-input template
  instantiations
- FP4 init values cover the full E2M1 grid (±0, ±0.5, ±1, ±1.5,
  ±2, ±3, ±4, ±6); validation tolerance widened to 0.5 for FP4
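
The E2M1 grid listed above follows from the 4-bit layout (1 sign bit, 2 exponent bits, 1 mantissa bit). A small decoding sketch, assuming the first element of a packed pair sits in the low nibble; `decodeE2M1` and `unpackFp4x2` are illustrative names, not functions from this PR:

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <utility>

// Decode one 4-bit E2M1 code: bit 3 is the sign, bits 2..0 select the magnitude.
inline float decodeE2M1(std::uint8_t nibble)
{
    // Magnitudes in code order 0b000 .. 0b111.
    static constexpr std::array<float, 8> kMagnitude
        = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};
    assert(nibble < 16);
    const float mag = kMagnitude[nibble & 0x7];
    return (nibble & 0x8) ? -mag : mag;
}

// Unpack a byte holding two FP4 values.
// Assumption: first element in the low nibble, second in the high nibble.
inline std::pair<float, float> unpackFp4x2(std::uint8_t byte)
{
    return {decodeE2M1(static_cast<std::uint8_t>(byte & 0xF)),
            decodeE2M1(static_cast<std::uint8_t>(byte >> 4))};
}
```

The coarse spacing at the top of the grid (3, 4, 6) gives a sense of why a wider validation tolerance is used for FP4.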

Fast reference path (client/src/Reference.cpp)
----------------------------------------------
- ShadowBuffer FP4 decode via __amd_cvt_fp4x2_to_floatx2_scale;
  odd-N handled by reading (N+1)/2 words and guarding the trailing
  2nd-nibble write. Caller-side invariant: trailing-word upper
  nibble must be zero-padded or random (the driver pads).
- isFastPathEligible accepts FP4 + MX when
  mxBlockA/B % BLOCK_K == 0 and K % mxBlockA/B == 0, and rejects
  one-sided MX to keep the two paths consistent.
- solveCPUFastInF32 MX branch: per-side E8 scale pointers, per-batch
  slicing, per-BLOCK_K-tile tilePartial accumulator scaled per MX
  block into cReg. assert(sizeK % BLOCK_K == 0) inside the MX
  branch documents the invariant that holds transitively from
  mxBlock % BLOCK_K == 0 and K % mxBlock == 0.
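
To make the eligibility rule and the per-tile scaling concrete, here is a scalar sketch over already-decoded values. It is not `solveCPUFastInF32`: `decodeUE8M0`, `mxShapeIsEligible` and `mxDot` are hypothetical names, `blockK` stands in for `BLOCK_K` (whose real value is not given here), and a single row/column pair replaces the tiled GEMM loops.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// UE8M0: unsigned 8-bit exponent, no mantissa; value = 2^(e - 127).
// (0xFF encodes NaN in the MX spec; this sketch does not handle it.)
inline float decodeUE8M0(std::uint8_t e)
{
    return std::ldexp(1.0f, static_cast<int>(e) - 127);
}

// Both sides must be MX-scaled, blockK must divide each MX block,
// and each MX block must divide K.
inline bool mxShapeIsEligible(std::size_t k, std::size_t mxBlockA, std::size_t mxBlockB,
                              std::size_t blockK)
{
    if(blockK == 0 || mxBlockA == 0 || mxBlockB == 0)
        return false;
    return mxBlockA % blockK == 0 && mxBlockB % blockK == 0
           && k % mxBlockA == 0 && k % mxBlockB == 0;
}

// Block-scaled dot product: accumulate each blockK tile unscaled, then apply
// the A and B scales that cover that tile.
inline float mxDot(const std::vector<float>& a, const std::vector<float>& b,
                   const std::vector<std::uint8_t>& scaleA,
                   const std::vector<std::uint8_t>& scaleB,
                   std::size_t mxBlockA, std::size_t mxBlockB, std::size_t blockK)
{
    const std::size_t k = a.size();
    assert(b.size() == k);
    assert(mxShapeIsEligible(k, mxBlockA, mxBlockB, blockK));
    float acc = 0.0f;
    for(std::size_t k0 = 0; k0 < k; k0 += blockK)
    {
        float tilePartial = 0.0f;
        for(std::size_t i = k0; i < k0 + blockK; ++i)
            tilePartial += a[i] * b[i];
        // blockK divides both MX block sizes, so a tile never straddles a scale boundary.
        acc += tilePartial * decodeUE8M0(scaleA[k0 / mxBlockA])
               * decodeUE8M0(scaleB[k0 / mxBlockB]);
    }
    return acc;
}
```

The divisibility conditions are what allow one scale lookup per tile; without them, a tile could span two MX blocks and would need per-element scaling.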

Tests (tests/CMakeLists.txt)
----------------------------
Slow and fast paths across NN/TN/NT/TT, Bias, Relu, ScaleAlphaVec,
ScaleAB (scalar/vector), Beta, Batched (--batchCount 2), asymmetric
MX block (A32/B64, A64/B32), edge sizes (K=mxBlock, M=1, N=1,
alpha=0), plus negative cases (mxBlock not pow2, K unaligned, K too
small, type/mxBlock mismatch, one-sided MX rejected).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
The driver had three MX block-size flags: --mxBlock as a "set both
sides" shortcut, plus per-side --mxBlockA / --mxBlockB. The shortcut
was purely CLI sugar (immediately splatted into mxBlockA = mxBlockB =
mxBlock); neither columnMajorGemm nor gemm_reference_fast consume it.
Per PR review #8106, the docstrings on the per-side flags also lied —
they claimed to "override --mxBlock" while the driver actually rejected
mixing the shortcut with either per-side flag.

Drop --mxBlock entirely. Symmetric MX is now spelled
--mxBlockA N --mxBlockB N. This shrinks the CLI surface, removes the
splat + conflict-validation block, and lets the docstrings be honest.

Driver:
- Remove the "mxBlock" option declaration.
- Remove parsing, splat, and the "cannot be combined with --mxBlockA
  or --mxBlockB" check.
- Drop the standalone "mxBlock must be non-negative" check (the
  per-side check already covers it).
- Update --mxBlockA / --mxBlockB docstrings to drop the "overrides"
  language.
- Cosmetic cleanup of three comments that referenced "mxBlock".

Tests:
- Convert 19 ctest entries from "--mxBlock N" to "--mxBlockA N
  --mxBlockB N" (mxBlock values 32, 16, and 6).
- Delete CPUGemm.f4_MX_Conflict_Rejected (the conflict no longer
  exists once --mxBlock is gone).

Verification: invoke build-client + ctest -R CPUGemm → 240/240 pass
(was 241; the deleted conflict test accounts for the -1).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
The terse "(i == storageA - 1 && numA % 2 != 0) ? 0.0f : randomFp4()"
ternary obscured what the special case was for. Refactor:

- Extract the duplicated A/B loops into a single initFp4Operand lambda.
- Name the slots (slot0/slot1) and the condition (isPaddingSlot).
- Add a comment explaining why the padding slot must hold a valid FP4
  value: the fast-path ShadowBuffer FP4 decoder reads both slots of
  every byte unconditionally; only the write-back of the trailing slot
  is guarded. Leaving the padding slot uninitialized would be UB.

No behavior change. ctest -R CPUGemm still 240/240.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@Alex-Vasile Alex-Vasile force-pushed the users/alvasile/mxfp4_fast_ref branch from bce05db to b2fed18 Compare June 25, 2026 20:05
@Alex-Vasile Alex-Vasile requested a review from newling June 25, 2026 20:06