-
Notifications
You must be signed in to change notification settings - Fork 31
add MXFP8 pre-swizzling for gfx1250 GEMM (#568) #605
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
c873e46
c4c2ea5
00da5e6
8eaf06d
77f1c45
76c8d98
db3123f
c6cc59f
b6440e0
e46d6da
8b37f0f
d5a16c9
1cf0dad
a24f739
5f510c1
b540068
1d2f222
bfedb4a
98fb2ff
ce60ce0
b668a2c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -5,12 +5,15 @@ | |||||||||||||
| ************************************************************************/ | ||||||||||||||
| #include <cmath> | ||||||||||||||
| #include <iostream> | ||||||||||||||
| #include <optional> | ||||||||||||||
| #include <set> | ||||||||||||||
| #include <string> | ||||||||||||||
| #include <cuda_bf16.h> | ||||||||||||||
| #include <cuda_runtime.h> | ||||||||||||||
| #include <gtest/gtest.h> | ||||||||||||||
| #include <transformer_engine/cast.h> | ||||||||||||||
| #include <transformer_engine/gemm.h> | ||||||||||||||
| #include <transformer_engine/swizzle.h> | ||||||||||||||
| #include <transformer_engine/transformer_engine.h> | ||||||||||||||
| #include "../test_common.h" | ||||||||||||||
|
|
||||||||||||||
|
|
@@ -30,7 +33,107 @@ std::vector<std::tuple<size_t, size_t, size_t>> test_case_sizes = { | |||||||||||||
|
|
||||||||||||||
| std::vector<std::tuple<size_t, size_t, size_t>> test_case_sizes_mxfp8 = { | ||||||||||||||
| {32, 128, 16}, | ||||||||||||||
| {64, 128, 32}, | ||||||||||||||
| {128, 128, 64}, | ||||||||||||||
| {64, 256, 32}, | ||||||||||||||
| {128, 384, 64}, | ||||||||||||||
| {256, 512, 128}, | ||||||||||||||
| {512, 1024, 256}, | ||||||||||||||
| {768, 3072, 4096}, | ||||||||||||||
| {1024, 2048, 128}, | ||||||||||||||
| {4096, 8192, 64}, | ||||||||||||||
| }; | ||||||||||||||
|
|
||||||||||||||
| // ============================================================================ | ||||||||||||||
| // Production LLM shapes for MXFP8 GEMM testing. | ||||||||||||||
| // | ||||||||||||||
| // Each shape is tested with 3 micro-batch sizes (MBS = 1, 2, 4) | ||||||||||||||
| // yielding tokens = 4096, 8192, 16384, and 3 layouts (TN, NN, NT) | ||||||||||||||
| // via ::testing::Combine. | ||||||||||||||
| // | ||||||||||||||
| // GemmPass selects the FP8 type combination: | ||||||||||||||
| // FWD: E4M3 x E4M3 -> BF16 | ||||||||||||||
| // DGRAD: E5M2 x E4M3 -> BF16 | ||||||||||||||
| // WGRAD: E4M3 x E5M2 -> BF16 | ||||||||||||||
| // ============================================================================ | ||||||||||||||
|
|
||||||||||||||
| enum class GemmPass { FWD, DGRAD, WGRAD }; | ||||||||||||||
|
|
||||||||||||||
| struct ShapeDef { | ||||||||||||||
| const char* label; | ||||||||||||||
| size_t dim1; // FWD/DGRAD: N, WGRAD: M | ||||||||||||||
| size_t dim2; // FWD/DGRAD: K, WGRAD: N | ||||||||||||||
| GemmPass pass; | ||||||||||||||
| }; | ||||||||||||||
|
|
||||||||||||||
| std::ostream& operator<<(std::ostream& os, const ShapeDef& s) { | ||||||||||||||
| return os << s.label; | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| static void resolve_mkn(const ShapeDef& s, size_t mbs, | ||||||||||||||
| size_t& m, size_t& k, size_t& n) { | ||||||||||||||
| size_t tokens = mbs * 4096; | ||||||||||||||
| switch (s.pass) { | ||||||||||||||
| case GemmPass::FWD: | ||||||||||||||
| case GemmPass::DGRAD: | ||||||||||||||
| m = tokens; n = s.dim1; k = s.dim2; | ||||||||||||||
| break; | ||||||||||||||
| case GemmPass::WGRAD: | ||||||||||||||
| m = s.dim1; n = s.dim2; k = tokens; | ||||||||||||||
| break; | ||||||||||||||
| } | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| // DeepSeek3 (hidden=7168, MLA, seq=4096, incl. LM Head) | ||||||||||||||
| static const ShapeDef deepseek3_shapes[] = { | ||||||||||||||
| // Forward (M=tokens, N, K) | ||||||||||||||
| {"DeepSeek3_Linear0_fwd", 1536, 7168, GemmPass::FWD}, | ||||||||||||||
| {"DeepSeek3_Linear1_fwd", 576, 7168, GemmPass::FWD}, | ||||||||||||||
| {"DeepSeek3_LNLinear0_fwd", 24576, 1536, GemmPass::FWD}, | ||||||||||||||
| {"DeepSeek3_LNLinear1_fwd", 32768, 512, GemmPass::FWD}, | ||||||||||||||
| {"DeepSeek3_Linear_attn_fwd", 7168, 16384, GemmPass::FWD}, | ||||||||||||||
| {"DeepSeek3_LNMLP_gateup_fwd", 36864, 7168, GemmPass::FWD}, | ||||||||||||||
| {"DeepSeek3_LNMLP_down_fwd", 7168, 18432, GemmPass::FWD}, | ||||||||||||||
| {"DeepSeek3_SharedExp_gu_fwd", 4096, 7168, GemmPass::FWD}, | ||||||||||||||
| {"DeepSeek3_SharedExp_dn_fwd", 7168, 2048, GemmPass::FWD}, | ||||||||||||||
| {"DeepSeek3_TopKRouter_fwd", 256, 7168, GemmPass::FWD}, | ||||||||||||||
| {"DeepSeek3_LMHead_fwd", 129280, 7168, GemmPass::FWD}, | ||||||||||||||
| // Dgrad (M=tokens, N, K) | ||||||||||||||
| {"DeepSeek3_attn_dgrad", 16384, 7168, GemmPass::DGRAD}, | ||||||||||||||
| {"DeepSeek3_LNLinear1_dgrad", 512, 32768, GemmPass::DGRAD}, | ||||||||||||||
| {"DeepSeek3_LNLinear0_dgrad", 1536, 24576, GemmPass::DGRAD}, | ||||||||||||||
| {"DeepSeek3_SharedExp_dn_dgrad", 2048, 7168, GemmPass::DGRAD}, | ||||||||||||||
| {"DeepSeek3_SharedExp_gu_dgrad", 7168, 4096, GemmPass::DGRAD}, | ||||||||||||||
| {"DeepSeek3_TopKRouter_dgrad", 7168, 256, GemmPass::DGRAD}, | ||||||||||||||
| {"DeepSeek3_MLP_post_dgrad", 7168, 14336, GemmPass::DGRAD}, | ||||||||||||||
| {"DeepSeek3_LMHead_dgrad", 7168, 129280, GemmPass::DGRAD}, | ||||||||||||||
| // Wgrad (M, N, K=tokens) | ||||||||||||||
| {"DeepSeek3_attn_wgrad", 16384, 7168, GemmPass::WGRAD}, | ||||||||||||||
| {"DeepSeek3_LNLinear1_wgrad", 512, 32768, GemmPass::WGRAD}, | ||||||||||||||
| {"DeepSeek3_LNLinear0_wgrad", 1536, 24576, GemmPass::WGRAD}, | ||||||||||||||
| {"DeepSeek3_SharedExp_dn_wgrad", 2048, 7168, GemmPass::WGRAD}, | ||||||||||||||
| {"DeepSeek3_SharedExp_gu_wgrad", 7168, 4096, GemmPass::WGRAD}, | ||||||||||||||
| {"DeepSeek3_TopKRouter_wgrad", 7168, 256, GemmPass::WGRAD}, | ||||||||||||||
| {"DeepSeek3_LMHead_wgrad", 7168, 129280, GemmPass::WGRAD}, | ||||||||||||||
| }; | ||||||||||||||
|
|
||||||||||||||
| // Qwen3 (hidden=4096, GQA, seq=4096, incl. LM Head) | ||||||||||||||
| static const ShapeDef qwen3_shapes[] = { | ||||||||||||||
| // Forward (M=tokens, N, K) | ||||||||||||||
| {"Qwen3_LNLinear_QKV_fwd", 9216, 4096, GemmPass::FWD}, | ||||||||||||||
| {"Qwen3_Linear_attn_fwd", 4096, 8192, GemmPass::FWD}, | ||||||||||||||
| {"Qwen3_Router_fwd", 128, 4096, GemmPass::FWD}, | ||||||||||||||
| {"Qwen3_LMHead_fwd", 151936, 4096, GemmPass::FWD}, | ||||||||||||||
| // Dgrad (M=tokens, N, K) | ||||||||||||||
| {"Qwen3_Router_dgrad", 4096, 128, GemmPass::DGRAD}, | ||||||||||||||
| {"Qwen3_Linear_attn_dgrad", 8192, 4096, GemmPass::DGRAD}, | ||||||||||||||
| {"Qwen3_LNLinear_dgrad", 4096, 9216, GemmPass::DGRAD}, | ||||||||||||||
| {"Qwen3_LMHead_dgrad", 4096, 151936, GemmPass::DGRAD}, | ||||||||||||||
| // Wgrad (M, N, K=tokens) | ||||||||||||||
| {"Qwen3_Router_wgrad", 4096, 128, GemmPass::WGRAD}, | ||||||||||||||
| {"Qwen3_Linear_attn_wgrad", 8192, 4096, GemmPass::WGRAD}, | ||||||||||||||
| {"Qwen3_LNLinear_wgrad", 4096, 9216, GemmPass::WGRAD}, | ||||||||||||||
| {"Qwen3_LMHead_wgrad", 4096, 151936, GemmPass::WGRAD}, | ||||||||||||||
| }; | ||||||||||||||
|
|
||||||||||||||
| // A, B, Bias, Gelu, D | ||||||||||||||
|
|
@@ -303,6 +406,40 @@ void cpu_rowwise_to_columnwise( | |||||||||||||
| } | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| // Swizzle MXFP8 scale_inv of a test::Tensor in-place for gfx1250. | ||||||||||||||
| static void swizzle_mxfp8_scales(test::Tensor &t, bool rowwise) { | ||||||||||||||
| using namespace transformer_engine; | ||||||||||||||
| void *scale_ptr = rowwise ? t.rowwise_scale_inv_dptr() | ||||||||||||||
| : t.columnwise_scale_inv_dptr(); | ||||||||||||||
| if (!scale_ptr) return; | ||||||||||||||
| const NVTEShape scale_shape = rowwise ? t.rowwise_scale_inv_shape() | ||||||||||||||
| : t.columnwise_scale_inv_shape(); | ||||||||||||||
| const NVTEShape data_shape = rowwise ? t.rowwise_shape() | ||||||||||||||
| : t.columnwise_shape(); | ||||||||||||||
| size_t num_scales = 1; | ||||||||||||||
| for (size_t d = 0; d < scale_shape.ndim; d++) num_scales *= scale_shape.data[d]; | ||||||||||||||
| uint8_t *d_tmp = nullptr; | ||||||||||||||
| NVTE_CHECK_CUDA(cudaMalloc(&d_tmp, num_scales)); | ||||||||||||||
| TensorWrapper input_tw(NVTE_MXFP8_1D_SCALING); | ||||||||||||||
| TensorWrapper output_tw(NVTE_MXFP8_1D_SCALING); | ||||||||||||||
| output_tw.set_with_gemm_swizzled_scales(true); | ||||||||||||||
| if (rowwise) { | ||||||||||||||
| input_tw.set_rowwise_data(nullptr, t.dtype(), data_shape); | ||||||||||||||
| input_tw.set_rowwise_scale_inv(scale_ptr, DType::kFloat8E8M0, scale_shape); | ||||||||||||||
| output_tw.set_rowwise_data(nullptr, t.dtype(), data_shape); | ||||||||||||||
| output_tw.set_rowwise_scale_inv(d_tmp, DType::kFloat8E8M0, scale_shape); | ||||||||||||||
| } else { | ||||||||||||||
| input_tw.set_columnwise_data(nullptr, t.dtype(), data_shape); | ||||||||||||||
| input_tw.set_columnwise_scale_inv(scale_ptr, DType::kFloat8E8M0, scale_shape); | ||||||||||||||
| output_tw.set_columnwise_data(nullptr, t.dtype(), data_shape); | ||||||||||||||
| output_tw.set_columnwise_scale_inv(d_tmp, DType::kFloat8E8M0, scale_shape); | ||||||||||||||
| } | ||||||||||||||
| nvte_swizzle_scaling_factors(input_tw.data(), output_tw.data(), 0); | ||||||||||||||
| NVTE_CHECK_CUDA(cudaDeviceSynchronize()); | ||||||||||||||
| NVTE_CHECK_CUDA(cudaMemcpy(scale_ptr, d_tmp, num_scales, cudaMemcpyDeviceToDevice)); | ||||||||||||||
| NVTE_CHECK_CUDA(cudaFree(d_tmp)); | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| std::pair<double, double> getTestTolerances(const DType type, bool use_fp8, bool use_mxfp8) { | ||||||||||||||
| auto [atol, rtol] = getTolerances(type); | ||||||||||||||
|
|
||||||||||||||
|
|
@@ -318,6 +455,12 @@ std::pair<double, double> getTestTolerances(const DType type, bool use_fp8, bool | |||||||||||||
| else if (use_fp8) { | ||||||||||||||
| atol = 1e-3; | ||||||||||||||
| rtol = std::max(rtol, 1e-2); | ||||||||||||||
| // Relax for gfx1250 | ||||||||||||||
| cudaDeviceProp prop; | ||||||||||||||
| (void)cudaGetDeviceProperties(&prop, 0); | ||||||||||||||
| if (prop.major == 12 && type == DType::kBFloat16) { | ||||||||||||||
| rtol = std::max(rtol, 5e-2); | ||||||||||||||
| } | ||||||||||||||
| } | ||||||||||||||
| else if (type == DType::kBFloat16) { | ||||||||||||||
| //relax for certain prime number TN gemm | ||||||||||||||
|
|
@@ -496,6 +639,31 @@ void performTest(const TestParams& params) { | |||||||||||||
| #endif | ||||||||||||||
| Tensor Workspace("Workspace", TShape{ workspace_size }, DType::kByte); | ||||||||||||||
|
|
||||||||||||||
| //perform the reference gemm on GPU (before swizzle, which modifies scales in-place) | ||||||||||||||
| Tensor RefD("RefD", TShape{ params.n, params.m }, dtype); | ||||||||||||||
| Tensor RefPreGeluOut; | ||||||||||||||
|
|
||||||||||||||
| if (params.use_gelu) { | ||||||||||||||
| RefPreGeluOut = Tensor("RefPreGeluOut", TShape{ params.n, params.m }, gelu_type); | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| run_reference<A_Type, B_Type, Bias_Type, Gelu_Type, D_Type>( | ||||||||||||||
| params, | ||||||||||||||
| A, | ||||||||||||||
| B, | ||||||||||||||
| params.use_bias ? &bias : nullptr, | ||||||||||||||
| D, | ||||||||||||||
| RefD, | ||||||||||||||
| params.use_gelu ? &RefPreGeluOut : nullptr); | ||||||||||||||
|
|
||||||||||||||
| // On gfx1250, hipBLASLt MXFP8 kernels expect pre-swizzled scales. | ||||||||||||||
| if (use_mxfp8 && prop.major == 12) { | ||||||||||||||
| if (!a_colwise) swizzle_mxfp8_scales(A, true); | ||||||||||||||
| if (a_colwise) swizzle_mxfp8_scales(A, false); | ||||||||||||||
| if (!b_colwise) swizzle_mxfp8_scales(B, true); | ||||||||||||||
| if (b_colwise) swizzle_mxfp8_scales(B, false); | ||||||||||||||
|
Comment on lines
+661
to
+664
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: each pair of
Suggested change
|
||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| //perform the gemm in GPU | ||||||||||||||
| nvte_cublas_gemm(A.data(), | ||||||||||||||
| B.data(), | ||||||||||||||
|
|
@@ -517,23 +685,6 @@ void performTest(const TestParams& params) { | |||||||||||||
| pre_gelu_out.to_cpu(); | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| //perform the reference gemm on GPU | ||||||||||||||
| Tensor RefD("RefD", TShape{ params.n, params.m }, dtype); | ||||||||||||||
| Tensor RefPreGeluOut; | ||||||||||||||
|
|
||||||||||||||
| if (params.use_gelu) { | ||||||||||||||
| RefPreGeluOut = Tensor("RefPreGeluOut", TShape{ params.n, params.m }, gelu_type); | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| run_reference<A_Type, B_Type, Bias_Type, Gelu_Type, D_Type>( | ||||||||||||||
| params, | ||||||||||||||
| A, | ||||||||||||||
| B, | ||||||||||||||
| params.use_bias ? &bias : nullptr, | ||||||||||||||
| D, | ||||||||||||||
| RefD, | ||||||||||||||
| params.use_gelu ? &RefPreGeluOut : nullptr); | ||||||||||||||
|
|
||||||||||||||
| // check if error message happens in running | ||||||||||||||
| (void)cudaDeviceSynchronize(); | ||||||||||||||
| auto err = cudaGetLastError(); | ||||||||||||||
|
|
@@ -559,7 +710,9 @@ void performTest(const TestParams& params) { | |||||||||||||
|
|
||||||||||||||
| #ifdef __HIP_PLATFORM_AMD__ | ||||||||||||||
| template <typename A_Type, typename B_Type, typename D_Type> | ||||||||||||||
| void performDqTest(const TestParams ¶ms) { | ||||||||||||||
| void performDqTest(const TestParams ¶ms, | ||||||||||||||
| std::optional<double> atol_override = std::nullopt, | ||||||||||||||
| std::optional<double> rtol_override = std::nullopt) { | ||||||||||||||
| DType atype = TypeInfo<A_Type>::dtype; | ||||||||||||||
| DType btype = TypeInfo<B_Type>::dtype; | ||||||||||||||
| DType dtype = TypeInfo<D_Type>::dtype; | ||||||||||||||
|
|
@@ -582,6 +735,17 @@ void performDqTest(const TestParams ¶ms) { | |||||||||||||
| GTEST_SKIP() << "MXFP8 is not supported in current config"; | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| // hipBLASLt on gfx950 produces incorrect results for certain small MXFP8 | ||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there ticket for that?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No, there isn't. |
||||||||||||||
| // GEMMs with non-TN layouts. | ||||||||||||||
| if (prop.major == 9 && prop.minor == 5) { | ||||||||||||||
| const bool is_NN = !params.transa && !params.transb; | ||||||||||||||
| const bool is_NT = !params.transa && params.transb; | ||||||||||||||
| if ((is_NN && params.m == 64) || | ||||||||||||||
| (is_NT && params.m > 32 && params.m <= 128 && params.n <= 64)) { | ||||||||||||||
| GTEST_SKIP() << "hipBLASLt MXFP8 non-TN GEMM with small M/N is not supported on gfx950"; | ||||||||||||||
| } | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| DType ref_type = dtype; | ||||||||||||||
| TShape a_shape = params.transa ? TShape{params.m, params.k} : TShape{params.k, params.m}; | ||||||||||||||
| TShape b_shape = params.transb ? TShape{params.k, params.n} : TShape{params.n, params.k}; | ||||||||||||||
|
|
@@ -605,6 +769,16 @@ void performDqTest(const TestParams ¶ms) { | |||||||||||||
| nvte_dequantize(A_fp8.data(), A_ref.data(), 0); | ||||||||||||||
| nvte_dequantize(B_fp8.data(), B_ref.data(), 0); | ||||||||||||||
|
|
||||||||||||||
| // On gfx1250, hipBLASLt MXFP8 kernels expect pre-swizzled scales. | ||||||||||||||
| if (prop.major == 12) { | ||||||||||||||
| const bool a_colwise = !params.transa; | ||||||||||||||
| const bool b_colwise = params.transb; | ||||||||||||||
| if (!a_colwise) swizzle_mxfp8_scales(A_fp8, true); | ||||||||||||||
| if (a_colwise) swizzle_mxfp8_scales(A_fp8, false); | ||||||||||||||
| if (!b_colwise) swizzle_mxfp8_scales(B_fp8, true); | ||||||||||||||
| if (b_colwise) swizzle_mxfp8_scales(B_fp8, false); | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| Tensor bias; | ||||||||||||||
| Tensor pre_gelu_out; | ||||||||||||||
|
|
||||||||||||||
|
|
@@ -633,6 +807,10 @@ void performDqTest(const TestParams ¶ms) { | |||||||||||||
|
|
||||||||||||||
| //compare results | ||||||||||||||
| auto [atol, rtol] = getTestTolerances(dtype, true, true); | ||||||||||||||
| if (atol_override) | ||||||||||||||
| atol = *atol_override; | ||||||||||||||
| if (rtol_override) | ||||||||||||||
| rtol = *rtol_override; | ||||||||||||||
| compareResults("D", D, D_ref.rowwise_cpu_dptr<D_Type>(), true, atol, rtol); | ||||||||||||||
| } | ||||||||||||||
| #endif // __HIP_PLATFORM_AMD__ | ||||||||||||||
|
|
@@ -751,6 +929,89 @@ INSTANTIATE_TEST_SUITE_P(OperatorTest, DqGEMMTestSuite, | |||||||||||||
| return MKN(std::get<0>(info.param)) + "x" + TN(std::get<3>(info.param)); | ||||||||||||||
| }); | ||||||||||||||
|
|
||||||||||||||
| // ============================================================================ | ||||||||||||||
| // Production GEMM shape instantiations (run with --gtest_filter='ProdGemm*') | ||||||||||||||
| // ============================================================================ | ||||||||||||||
|
|
||||||||||||||
| // Known-failing GEMM shapes on gfx950 | ||||||||||||||
| static const std::set<std::string> kGfx950Skips = { | ||||||||||||||
| "DeepSeek3_Linear1_fwd_mbs1_NT", | ||||||||||||||
| "DeepSeek3_Linear1_fwd_mbs2_NT", | ||||||||||||||
| "DeepSeek3_Linear1_fwd_mbs4_NT", | ||||||||||||||
| "DeepSeek3_LNLinear0_fwd_mbs4_NN", | ||||||||||||||
| "DeepSeek3_LNLinear0_fwd_mbs4_NT", | ||||||||||||||
| "DeepSeek3_attn_wgrad_mbs1_NN", | ||||||||||||||
| "Qwen3_LMHead_fwd_mbs2_NN", | ||||||||||||||
| "Qwen3_Router_fwd_mbs2_NT", | ||||||||||||||
| "Qwen3_LMHead_fwd_mbs4_TN", | ||||||||||||||
| "Qwen3_LMHead_fwd_mbs4_NN", | ||||||||||||||
| "Qwen3_LMHead_fwd_mbs4_NT", | ||||||||||||||
| }; | ||||||||||||||
|
|
||||||||||||||
| // Production GEMM test suite using ShapeDef x MBS x Layout via testing::Combine. | ||||||||||||||
| using ProdGemmParam = std::tuple<ShapeDef, size_t, Layout>; | ||||||||||||||
|
|
||||||||||||||
| class ProdDqGEMMTestSuite : public ::testing::TestWithParam<ProdGemmParam> {}; | ||||||||||||||
|
|
||||||||||||||
| TEST_P(ProdDqGEMMTestSuite, TestMxfp8Dq) { | ||||||||||||||
| const auto& shape = std::get<0>(GetParam()); | ||||||||||||||
| size_t mbs = std::get<1>(GetParam()); | ||||||||||||||
| const auto& layout = std::get<2>(GetParam()); | ||||||||||||||
|
|
||||||||||||||
| std::string name = std::string(shape.label) + "_mbs" + std::to_string(mbs) | ||||||||||||||
| + "_" + TN(layout); | ||||||||||||||
| if (kGfx950Skips.count(name)) { | ||||||||||||||
| GTEST_SKIP() << "Known gfx950 hipBLASLt failure: " << name; | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| size_t m, k, n; | ||||||||||||||
| resolve_mkn(shape, mbs, m, k, n); | ||||||||||||||
|
|
||||||||||||||
| TestParams params = {.m = m, .k = k, .n = n, | ||||||||||||||
| .use_bias = false, .use_gelu = false, | ||||||||||||||
| .transa = layout.first, .transb = layout.second, | ||||||||||||||
| .scaling_mode = NVTEScalingMode::NVTE_MXFP8_1D_SCALING}; | ||||||||||||||
|
|
||||||||||||||
| // Production shapes use looser tolerances: the MXFP8 and bf16 reference | ||||||||||||||
| // GEMM use different internal accumulation paths, so results can differ | ||||||||||||||
| // by up to 1 ULP in bf16 (~1.5-2% relative). | ||||||||||||||
| const double prod_atol = 1e-3; | ||||||||||||||
| const double prod_rtol = 2e-2; | ||||||||||||||
|
|
||||||||||||||
| switch (shape.pass) { | ||||||||||||||
| case GemmPass::FWD: | ||||||||||||||
| performDqTest<fp8, fp8, bf16>(params, prod_atol, prod_rtol); | ||||||||||||||
| break; | ||||||||||||||
| case GemmPass::DGRAD: | ||||||||||||||
| performDqTest<bf8, fp8, bf16>(params, prod_atol, prod_rtol); | ||||||||||||||
| break; | ||||||||||||||
| case GemmPass::WGRAD: | ||||||||||||||
| performDqTest<fp8, bf8, bf16>(params, prod_atol, prod_rtol); | ||||||||||||||
| break; | ||||||||||||||
| } | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| static auto prodTestName = [](const testing::TestParamInfo<ProdGemmParam>& info) { | ||||||||||||||
| const auto& shape = std::get<0>(info.param); | ||||||||||||||
| size_t mbs = std::get<1>(info.param); | ||||||||||||||
| const auto& layout = std::get<2>(info.param); | ||||||||||||||
| return std::string(shape.label) + "_mbs" + std::to_string(mbs) + "_" + TN(layout); | ||||||||||||||
| }; | ||||||||||||||
|
|
||||||||||||||
| INSTANTIATE_TEST_SUITE_P(ProdGemmDeepSeek3, ProdDqGEMMTestSuite, | ||||||||||||||
| ::testing::Combine( | ||||||||||||||
| ::testing::ValuesIn(deepseek3_shapes), | ||||||||||||||
| ::testing::Values(size_t{1}, size_t{2}, size_t{4}), | ||||||||||||||
| ::testing::ValuesIn(kLayouts)), | ||||||||||||||
| prodTestName); | ||||||||||||||
|
|
||||||||||||||
| INSTANTIATE_TEST_SUITE_P(ProdGemmQwen3, ProdDqGEMMTestSuite, | ||||||||||||||
| ::testing::Combine( | ||||||||||||||
| ::testing::ValuesIn(qwen3_shapes), | ||||||||||||||
| ::testing::Values(size_t{1}, size_t{2}, size_t{4}), | ||||||||||||||
| ::testing::ValuesIn(kLayouts)), | ||||||||||||||
| prodTestName); | ||||||||||||||
|
|
||||||||||||||
| TEST(InputGenTest, FillUniform_DoesNotGetOverwrittenByFromCpu) { | ||||||||||||||
| const size_t rows = 128; | ||||||||||||||
| const size_t cols = 256; | ||||||||||||||
|
|
||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This relaxation fires for every FP8 GEMM test on any gfx12 device (tensor-scaling FP8 included), not just MXFP8 on gfx1250. The comment ("Relax for gfx1250") and the PR scope suggest the intent is the gfx1250 MXFP8 path specifically. Consider guarding with
use_mxfp8and/orprop.major == 12 && prop.minor == 5so non-MXFP8 FP8 tests don't silently lose precision coverage on this arch.