[tensilelite] Add fast path reference gemm for MXFP4 #8106
Codecov Report
✅ All modified and coverable lines are covered by tests.
❌ Your project status has failed because the head coverage (77.83%) is below the target coverage (80.00%). You can increase the head coverage or adjust the target coverage.
Additional details and impacted files
@@ Coverage Diff @@
## develop #8106 +/- ##
===========================================
- Coverage 65.18% 65.18% -0.00%
===========================================
Files 2597 2597
Lines 404506 404553 +47
Branches 60271 60274 +3
===========================================
+ Hits 263654 263673 +19
- Misses 121051 121077 +26
- Partials 19801 19803 +2
This pull request uses carry forward flags.
newling
left a comment
The agent review comments mostly come from my discussion with it; as far as I can tell there is not much new of substance in them: https://github.com/newling/agent-public-reviews/blob/main/rocm-libraries/review_pr8106.md
The driver had three MX block-size flags: --mxBlock as a "set both sides" shortcut, plus per-side --mxBlockA / --mxBlockB. The shortcut was purely CLI sugar (immediately splatted into mxBlockA = mxBlockB = mxBlock); neither columnMajorGemm nor gemm_reference_fast consume it. Per PR review #8106, the docstrings on the per-side flags also lied: they claimed to "override --mxBlock" while the driver actually rejected mixing the shortcut with either per-side flag.

Drop --mxBlock entirely. Symmetric MX is now spelled --mxBlockA N --mxBlockB N. This shrinks the CLI surface, removes the splat + conflict-validation block, and lets the docstrings be honest.

Driver:
- Remove the "mxBlock" option declaration.
- Remove parsing, splat, and the "cannot be combined with --mxBlockA or --mxBlockB" check.
- Drop the standalone "mxBlock must be non-negative" check (the per-side check already covers it).
- Update --mxBlockA / --mxBlockB docstrings to drop the "overrides" language.
- Cosmetic cleanup of three comments that referenced "mxBlock".

Tests:
- Convert 19 ctest entries from "--mxBlock N" to "--mxBlockA N --mxBlockB N" (mxBlock values 32, 16, and 6).
- Delete CPUGemm.f4_MX_Conflict_Rejected (the conflict no longer exists once --mxBlock is gone).

Verification: invoke build-client + ctest -R CPUGemm → 240/240 pass (was 241; the deleted conflict test accounts for the -1).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
The terse "(i == storageA - 1 && numA % 2 != 0) ? 0.0f : randomFp4()" ternary obscured what the special case was for.

Refactor:
- Extract the duplicated A/B loops into a single initFp4Operand lambda.
- Name the slots (slot0/slot1) and the condition (isPaddingSlot).
- Add a comment explaining why the padding slot must hold a valid FP4 value: the fast-path ShadowBuffer FP4 decoder reads both slots of every byte unconditionally; only the write-back of the trailing slot is guarded. Leaving the padding slot uninitialized would be UB.

No behavior change. ctest -R CPUGemm still 240/240.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
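The padding-slot rule in the commit above can be illustrated with a minimal sketch. This is not the driver's code: the free-function form, the use of raw 4-bit codes instead of float values, and the nibble order (slot0 in the low nibble, slot1 in the high nibble) are all assumptions made for illustration. Only the names initFp4Operand, slot0, slot1, and isPaddingSlot come from the commit message.

```cpp
#include <cstddef>
#include <cstdint>
#include <random>
#include <vector>

// Hypothetical stand-in for the driver's initFp4Operand lambda.
// Packs numElems random 4-bit FP4 codes, two per byte.
// ASSUMPTION: slot0 is the low nibble and slot1 is the high nibble.
std::vector<std::uint8_t> initFp4Operand(std::size_t numElems, std::mt19937& rng)
{
    // Any 4-bit pattern is a valid E2M1 code, so sample all 16 of them.
    std::uniform_int_distribution<int> code(0, 15);

    const std::size_t storage = (numElems + 1) / 2; // bytes needed, rounded up
    std::vector<std::uint8_t> packed(storage);

    for(std::size_t i = 0; i < storage; ++i)
    {
        // With an odd element count, the second slot of the last byte is padding.
        const bool isPaddingSlot = (i == storage - 1) && (numElems % 2 != 0);

        const auto slot0 = static_cast<std::uint8_t>(code(rng));
        // The padding slot is not part of the logical tensor, but a decoder that
        // reads both nibbles of every byte still loads it, so it gets a defined
        // value (+0) here.
        const auto slot1
            = isPaddingSlot ? std::uint8_t{0} : static_cast<std::uint8_t>(code(rng));

        packed[i] = static_cast<std::uint8_t>(slot0 | (slot1 << 4));
    }
    return packed;
}
```

For example, initFp4Operand(5, rng) yields three bytes, and the high nibble of the third byte is always zero.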
76c8deb to 7b1f738
7b1f738 to bce05db
newling
left a comment
Please add some perf numbers in the PR summary if you can think of any that might be relevant
#ifndef _WIN32
constexpr bool isFP4 = std::is_same_v<InputAT, Float4x2> && std::is_same_v<InputBT, Float4x2>;
#else
constexpr bool isFP4 = false;
So if
std::is_same_v<InputAT, Float4x2> && !std::is_same_v<InputBT, Float4x2>
(i.e. one is fp4 and the other is not) isFP4 is false?
Wondering if some static_asserts would be good here
Yup, good catch. I've added the assert, and also had it not even try the fast path for mixed types.
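A compile-time guard of the kind discussed here might look like the sketch below. It is only an illustration: the Float4x2 struct is a stand-in for the library's packed type, checkFp4Pairing is a hypothetical helper name, and the actual assert added in the PR may be worded or placed differently.

```cpp
#include <type_traits>

// Stand-in for the packed FP4 type; the real Float4x2 comes from the library headers.
struct Float4x2
{
    unsigned char data;
};

// Hypothetical helper: returns true when both inputs are FP4 and refuses to
// compile when exactly one of them is.
template <typename InputAT, typename InputBT>
constexpr bool checkFp4Pairing()
{
    constexpr bool aIsFP4 = std::is_same_v<InputAT, Float4x2>;
    constexpr bool bIsFP4 = std::is_same_v<InputBT, Float4x2>;

    // Mixed inputs would otherwise silently evaluate to "not FP4".
    static_assert(aIsFP4 == bIsFP4, "A and B must both be Float4x2, or neither");

    return aIsFP4 && bIsFP4;
}
```

With this in place, checkFp4Pairing<Float4x2, float>() fails to compile instead of quietly returning false, which is the case the review comment asks about.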
Extends both CPU reference paths to handle mxfp4 problems:
A and/or B is Float4 (E2M1, packed as Float4x2) with per-block UE8M0
scale factors (E8) applied every mxBlock K-elements.
Driver (client/cpu_gemm_driver.cpp)
-----------------------------------
- --type f4 with FP4-packed input/output paths
- --mxBlock N for symmetric block scaling; --mxBlockA N / --mxBlockB N
for asymmetric. One-sided MX (only one of A/B scaled) is rejected.
- --batchCount N with per-batch MX scale slicing
- --computeInputA/B accepts f4 / f64 / tf32
- columnMajorGemm MX inner reduction steps by min(mxBlockA, mxBlockB);
mxBlock must be a power of 2
- setMXScaleA/B called with padScaleTensor=false so the allocated
buffers match columnMajorGemm's tight-packed {m, k/mxBlock} layout
- if constexpr guards on impossible Float4x2 mixed-input template
instantiations
- FP4 init values cover the full E2M1 grid (±0, ±0.5, ±1, ±1.5,
±2, ±3, ±4, ±6); validation tolerance widened to 0.5 for FP4
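The block-scaled reduction described in the driver notes above can be sketched as follows. This is a simplified illustration, not columnMajorGemm itself: it works on unpacked FP4 codes (one code per byte) for a single output element, and the function names and signatures are hypothetical. It assumes both block sizes are powers of two and that K is a multiple of each, so every chunk of min(mxBlockA, mxBlockB) elements lies inside exactly one scale block on each side.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Decode one 4-bit E2M1 code: 1 sign bit, 2 exponent bits, 1 mantissa bit.
inline float decodeE2M1(std::uint8_t code)
{
    static constexpr float kMagnitude[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};
    const float mag = kMagnitude[code & 0x7];
    return (code & 0x8) ? -mag : mag;
}

// Decode a UE8M0 scale: an unsigned power of two with bias 127; 0xFF encodes NaN.
inline float decodeUE8M0(std::uint8_t e)
{
    if(e == 0xFF)
        return std::numeric_limits<float>::quiet_NaN();
    return std::ldexp(1.0f, static_cast<int>(e) - 127);
}

// Hypothetical helper: dot product over K of one row of A with one column of B.
// a and b hold one unpacked E2M1 code per element; scaleA/scaleB hold one
// UE8M0 scale per mxBlockA / mxBlockB elements.
inline float mxDot(const std::vector<std::uint8_t>& a,
                   const std::vector<std::uint8_t>& scaleA,
                   std::size_t                      mxBlockA,
                   const std::vector<std::uint8_t>& b,
                   const std::vector<std::uint8_t>& scaleB,
                   std::size_t                      mxBlockB)
{
    const std::size_t K    = a.size();
    const std::size_t step = std::min(mxBlockA, mxBlockB);

    float acc = 0.0f;
    for(std::size_t k0 = 0; k0 < K; k0 += step)
    {
        // Unscaled partial sum over one chunk; both scales are constant inside it.
        float partial = 0.0f;
        for(std::size_t k = k0; k < k0 + step; ++k)
            partial += decodeE2M1(a[k]) * decodeE2M1(b[k]);

        acc += partial * decodeUE8M0(scaleA[k0 / mxBlockA]) * decodeUE8M0(scaleB[k0 / mxBlockB]);
    }
    return acc;
}
```

Stepping by the smaller block size is what lets asymmetric cases such as A32/B64 reuse one loop: the scale on the coarser side simply stays the same across consecutive chunks.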
Fast reference path (client/src/Reference.cpp)
----------------------------------------------
- ShadowBuffer FP4 decode via __amd_cvt_fp4x2_to_floatx2_scale;
odd-N handled by reading (N+1)/2 words and guarding the trailing
2nd-nibble write. Caller-side invariant: trailing-word upper
nibble must be zero-padded or random (the driver pads).
- isFastPathEligible accepts FP4 + MX when
mxBlockA/B % BLOCK_K == 0 and K % mxBlockA/B == 0, and rejects
one-sided MX to keep the two paths consistent.
- solveCPUFastInF32 MX branch: per-side E8 scale pointers, per-batch
slicing, per-BLOCK_K-tile tilePartial accumulator scaled per MX
block into cReg. assert(sizeK % BLOCK_K == 0) inside the MX
branch documents the invariant that holds transitively from
mxBlock % BLOCK_K == 0 and K % mxBlock == 0.
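The MX-related eligibility conditions listed above can be written as a small predicate. This is a sketch under stated assumptions rather than the real isFastPathEligible: the function name mxFastPathEligible is hypothetical, BLOCK_K = 16 is a placeholder for whatever Reference.cpp actually uses, and treating a block size of 0 as "no MX on this side" is an assumption.

```cpp
#include <cstddef>

// ASSUMPTION: placeholder tile depth; the real BLOCK_K is defined in Reference.cpp.
constexpr std::size_t BLOCK_K = 16;

// Hypothetical predicate covering only the MX conditions.
// ASSUMPTION: a block size of 0 means the side carries no MX scaling.
constexpr bool mxFastPathEligible(std::size_t K, std::size_t mxBlockA, std::size_t mxBlockB)
{
    const bool hasA = mxBlockA != 0;
    const bool hasB = mxBlockB != 0;

    if(hasA != hasB)
        return false; // one-sided MX is rejected, matching the slow path
    if(!hasA)
        return true; // no MX at all: nothing further to check here

    // Each MX block must cover a whole number of BLOCK_K tiles, and K must
    // cover a whole number of MX blocks on both sides.
    return mxBlockA % BLOCK_K == 0 && mxBlockB % BLOCK_K == 0 && K % mxBlockA == 0
           && K % mxBlockB == 0;
}
```

When the predicate holds with MX enabled, K % BLOCK_K == 0 follows transitively, which is the invariant the assert inside the MX branch documents.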
Tests (tests/CMakeLists.txt)
----------------------------
Slow and fast paths across NN/TN/NT/TT, Bias, Relu, ScaleAlphaVec,
ScaleAB (scalar/vector), Beta, Batched (--batchCount 2), asymmetric
MX block (A32/B64, A64/B32), edge sizes (K=mxBlock, M=1, N=1,
alpha=0), plus negative cases (mxBlock not pow2, K unaligned, K too
small, type/mxBlock mismatch, one-sided MX rejected).
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
bce05db to b2fed18
Motivation
Add a fast reference gemm for mxfp4.
Speedup is roughly one order of magnitude.
Technical Details
Depends on #8006.
Test Plan
Add new tests to verify against the slow path.
Test Result
Tests pass.
Submission Checklist