add MXFP8 pre-swizzling for gfx1250 GEMM (#568)#605
Conversation
|
Manually tested on gfx1250, should be ready to go from my perspective. |
| GTEST_SKIP() << "MXFP8 is not supported in current config"; | ||
| } | ||
|
|
||
| // hipBLASLt on gfx950 produces incorrect results for certain small MXFP8 |
There was a problem hiding this comment.
Is there ticket for that?
There was a problem hiding this comment.
No, there isn't.
| if (is_nvfp4_scaling(config.scaling_mode)) { | ||
| if (is_nvfp4_scaling(config.scaling_mode) | ||
| #ifdef USE_ROCM | ||
| || (config.scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING |
There was a problem hiding this comment.
I think there should be corresponding update fort workspace size calculation: scale sizes should be added to it
Claude WalkthroughIntent. Add MXFP8 GEMM support on gfx1250 by teaching TE to pre-swizzle MXFP8 scale tensors into the K-tiled "Tensile 3D" layout that hipBLASLt requires there, and to canonicalize operand layouts to TN (the only layout the gfx1250 MXFP8 kernels accept). Cherry-picked from #568. Key changes.
Walkthrough.
Testing.
Notes for reviewers.
Generated by Claude. To request a code review, comment |
|
|
||
| uint8_t val; | ||
| if constexpr (kRowwise) { | ||
| val = input[idx]; // == input[m * orig_K + k] |
There was a problem hiding this comment.
Wrong input stride when padded_K_scale > orig_K_scale.
The kernel reads input[idx] = input[m * orig_K + k], treating the input as if its row stride were orig_K. But the actual scale_inv buffer is allocated with the padded scale shape [padded_M, padded_K] (see MXFP8Quantizer::get_scale_shape and mxfp8_tensor.py), so its row stride is padded_K (= the m parameter that's not passed — only padded_M is). For row m > 0 this reads from the wrong memory location.
Concrete example: data shape [3, 96] → orig_M=3, orig_K=3, padded_M=4, padded_K=4. Buffer stride is 4, but the kernel reads input[1] for logical (m=1, k=0) instead of input[4].
This is masked in practice because:
- All
test_swizzle.cucases useorig_Kalready a multiple of 4 (sopadded_K == orig_K), and the test allocates the input buffer with sizeorig_M * orig_Krather thanpadded_M * padded_K— matching the kernel's incorrect stride assumption rather than the production layout. - All production LLM shapes here have
K_dataas a multiple of 128, soK_scale % 4 == 0and no K-side padding occurs.
But the kernel will misbehave for K_data ∈ {32, 64, 96, 160, ...} (any K that's a multiple of 32 but not 128). Suggested fix: also pass the padded scale K (= k in the caller) as a stride parameter, e.g.:
swizzle_scaling_mx_kernel(input, output, padded_M, padded_K, orig_M, orig_K);
...
if constexpr (kRowwise) {
val = input[m * padded_K + k];
} else {
val = input[k * padded_M + m]; // padded_M for columnwise (currently OK because K_data%32==0 ⇒ padded_M==orig_M, but better to be explicit)
}And update the test to allocate d_input of size padded_M * padded_K and exercise non-multiple-of-4 orig_K (e.g. {3,3}, {8,5}, {32,7}) to lock this down.
|
|
||
| // Scale dimensions (M_scale, K_scale). | ||
| // K_scale will be padded to multiple of 4 by the test. | ||
| std::vector<std::pair<int, int>> mx_scale_dims = { |
There was a problem hiding this comment.
Coverage gap: no padded case is exercised.
Every entry in mx_scale_dims has orig_K already a multiple of 4, so K = roundup_sz(orig_K, 4) == orig_K. The test also allocates d_input of size orig_M * orig_K (line 248) rather than the padded M * K, so the test only validates the kernel against a tightly-packed (unpadded) input layout — not the padded [padded_M, padded_K] layout used by MXFP8Quantizer::get_scale_shape in production. Likewise no case has orig_M % 4 != 0.
Please add cases with orig_K not a multiple of 4 (e.g. {4,3}, {8,5}, {32,7}) and at least one with orig_M % 4 != 0 (e.g. {3,4}), allocating the input with the padded stride matching production. This would have caught the stride bug flagged in swizzle.cu:427.
| // Relax for gfx1250 | ||
| cudaDeviceProp prop; | ||
| (void)cudaGetDeviceProperties(&prop, 0); | ||
| if (prop.major == 12 && type == DType::kBFloat16) { |
There was a problem hiding this comment.
This relaxation fires for every FP8 GEMM test on any gfx12 device (tensor-scaling FP8 included), not just MXFP8 on gfx1250. The comment ("Relax for gfx1250") and the PR scope suggest the intent is the gfx1250 MXFP8 path specifically. Consider guarding with use_mxfp8 and/or prop.major == 12 && prop.minor == 5 so non-MXFP8 FP8 tests don't silently lose precision coverage on this arch.
| for (size_t i = 0; i < input.size(); i++) { | ||
| if (is_mxfp8_scaling(input[i]->scaling_mode)) { | ||
| any_mxfp8 = true; | ||
| } |
There was a problem hiding this comment.
If any_mxfp8 is true, every tensor in the batch is dispatched to swizzle_scaling_factors_mx, which asserts scaling_mode == NVTE_MXFP8_1D_SCALING. The single-tensor multi_tensor_swizzle_scaling_factors contract (a few lines below) only requires each tensor to be (fp8 && mxfp8) || (fp4 && nvfp4) — i.e. a batch can in principle mix MXFP8 and NVFP4 tensors. On gfx1250 such a mix would crash inside the MX helper.
Probably theoretical today (NVFP4 isn't supported on gfx1250 yet), but the safer form is either to assert all_mxfp8 here, or to dispatch per-tensor through swizzle_scaling_factors (which already routes MXFP8→MX on gfx1250 and leaves NVFP4 on the existing path).
| if (!a_colwise) swizzle_mxfp8_scales(A, true); | ||
| if (a_colwise) swizzle_mxfp8_scales(A, false); | ||
| if (!b_colwise) swizzle_mxfp8_scales(B, true); | ||
| if (b_colwise) swizzle_mxfp8_scales(B, false); |
There was a problem hiding this comment.
Nit: each pair of if (!x) ...; if (x) ...; lines unconditionally calls swizzle_mxfp8_scales with the negation of x. The same lines exist again at 776-779.
| if (!a_colwise) swizzle_mxfp8_scales(A, true); | |
| if (a_colwise) swizzle_mxfp8_scales(A, false); | |
| if (!b_colwise) swizzle_mxfp8_scales(B, true); | |
| if (b_colwise) swizzle_mxfp8_scales(B, false); | |
| swizzle_mxfp8_scales(A, !a_colwise); | |
| swizzle_mxfp8_scales(B, !b_colwise); |
Claude reviewReviewed the gfx1250 MXFP8 pre-swizzle plumbing (GEMM path + scale padding + new Findings (see inline comments):
Copyright headers: OK (all 12 changed files have correct AMD |
Description
Cherry-picked from #568 (same code)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: