-
Notifications
You must be signed in to change notification settings - Fork 30
add production GEMM tests #590
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
c873e46
c4c2ea5
00da5e6
8eaf06d
77f1c45
76c8d98
db3123f
c6cc59f
b6440e0
1cf0dad
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,6 +5,8 @@ | |
| ************************************************************************/ | ||
| #include <cmath> | ||
| #include <iostream> | ||
| #include <optional> | ||
| #include <set> | ||
| #include <string> | ||
| #include <cuda_bf16.h> | ||
| #include <cuda_runtime.h> | ||
|
|
@@ -33,6 +35,98 @@ std::vector<std::tuple<size_t, size_t, size_t>> test_case_sizes_mxfp8 = { | |
| {768, 3072, 4096}, | ||
| }; | ||
|
|
||
| // ============================================================================ | ||
| // Production LLM shapes for MXFP8 GEMM testing. | ||
| // | ||
| // Each shape is tested with 3 micro-batch sizes (MBS = 1, 2, 4) | ||
| // yielding tokens = 4096, 8192, 16384, and 3 layouts (TN, NN, NT) | ||
| // via ::testing::Combine. | ||
| // | ||
| // GemmPass selects the FP8 type combination: | ||
| // FWD: E4M3 x E4M3 -> BF16 | ||
| // DGRAD: E5M2 x E4M3 -> BF16 | ||
| // WGRAD: E4M3 x E5M2 -> BF16 | ||
| // ============================================================================ | ||
|
|
||
| // Which training GEMM a shape belongs to; selects the FP8 operand types. | ||
| enum class GemmPass { | ||
|   FWD,    // E4M3 x E4M3 -> BF16 | ||
|   DGRAD,  // E5M2 x E4M3 -> BF16 | ||
|   WGRAD   // E4M3 x E5M2 -> BF16 | ||
| }; | ||
|
|
||
| // One production GEMM shape. dim1/dim2 are mapped onto M/N/K according to `pass`. | ||
| struct ShapeDef { | ||
| // Name streamed by operator<< and used as the prefix of the test name. | ||
| const char* label; | ||
| size_t dim1; // FWD/DGRAD: N, WGRAD: M | ||
| size_t dim2; // FWD/DGRAD: K, WGRAD: N | ||
| // Pass this shape is exercised with. | ||
| GemmPass pass; | ||
| }; | ||
|
|
||
| // Streams a ShapeDef as just its label (presumably so test parameters print readably). | ||
| std::ostream& operator<<(std::ostream& os, const ShapeDef& s) { | ||
|   os << s.label; | ||
|   return os; | ||
| } | ||
|
|
||
| // Expands a ShapeDef plus micro-batch size into concrete GEMM dimensions. | ||
| // The token count is mbs * 4096. FWD and DGRAD place it in M (N = dim1, K = dim2); | ||
| // WGRAD places it in K (M = dim1, N = dim2). Results are written to m, k and n. | ||
| static void resolve_mkn(const ShapeDef& s, size_t mbs, | ||
|                         size_t& m, size_t& k, size_t& n) { | ||
|   const size_t token_count = mbs * 4096; | ||
|   switch (s.pass) { | ||
|     case GemmPass::WGRAD: | ||
|       m = s.dim1; | ||
|       n = s.dim2; | ||
|       k = token_count; | ||
|       break; | ||
|     case GemmPass::FWD: | ||
|     case GemmPass::DGRAD: | ||
|       m = token_count; | ||
|       n = s.dim1; | ||
|       k = s.dim2; | ||
|       break; | ||
|   } | ||
| } | ||
|
|
||
| // DeepSeek3 (hidden=7168, MLA, seq=4096, incl. LM Head) | ||
| // Each entry is {label, dim1, dim2, pass}; resolve_mkn() maps dim1/dim2 onto M/N/K. | ||
| static const ShapeDef deepseek3_shapes[] = { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is execution time for all ProdGEMM tests? |
||
| // Forward: M = tokens, dim1 = N, dim2 = K | ||
| {"DeepSeek3_Linear0_fwd", 1536, 7168, GemmPass::FWD}, | ||
| {"DeepSeek3_Linear1_fwd", 576, 7168, GemmPass::FWD}, | ||
| {"DeepSeek3_LNLinear0_fwd", 24576, 1536, GemmPass::FWD}, | ||
| {"DeepSeek3_LNLinear1_fwd", 32768, 512, GemmPass::FWD}, | ||
| {"DeepSeek3_Linear_attn_fwd", 7168, 16384, GemmPass::FWD}, | ||
| {"DeepSeek3_LNMLP_gateup_fwd", 36864, 7168, GemmPass::FWD}, | ||
| {"DeepSeek3_LNMLP_down_fwd", 7168, 18432, GemmPass::FWD}, | ||
| {"DeepSeek3_SharedExp_gu_fwd", 4096, 7168, GemmPass::FWD}, | ||
| {"DeepSeek3_SharedExp_dn_fwd", 7168, 2048, GemmPass::FWD}, | ||
| {"DeepSeek3_TopKRouter_fwd", 256, 7168, GemmPass::FWD}, | ||
| {"DeepSeek3_LMHead_fwd", 129280, 7168, GemmPass::FWD}, | ||
| // Dgrad: M = tokens, dim1 = N, dim2 = K | ||
| {"DeepSeek3_attn_dgrad", 16384, 7168, GemmPass::DGRAD}, | ||
| {"DeepSeek3_LNLinear1_dgrad", 512, 32768, GemmPass::DGRAD}, | ||
| {"DeepSeek3_LNLinear0_dgrad", 1536, 24576, GemmPass::DGRAD}, | ||
| {"DeepSeek3_SharedExp_dn_dgrad", 2048, 7168, GemmPass::DGRAD}, | ||
| {"DeepSeek3_SharedExp_gu_dgrad", 7168, 4096, GemmPass::DGRAD}, | ||
| {"DeepSeek3_TopKRouter_dgrad", 7168, 256, GemmPass::DGRAD}, | ||
| {"DeepSeek3_MLP_post_dgrad", 7168, 14336, GemmPass::DGRAD}, | ||
| {"DeepSeek3_LMHead_dgrad", 7168, 129280, GemmPass::DGRAD}, | ||
| // Wgrad: dim1 = M, dim2 = N, K = tokens | ||
| {"DeepSeek3_attn_wgrad", 16384, 7168, GemmPass::WGRAD}, | ||
| {"DeepSeek3_LNLinear1_wgrad", 512, 32768, GemmPass::WGRAD}, | ||
| {"DeepSeek3_LNLinear0_wgrad", 1536, 24576, GemmPass::WGRAD}, | ||
| {"DeepSeek3_SharedExp_dn_wgrad", 2048, 7168, GemmPass::WGRAD}, | ||
| {"DeepSeek3_SharedExp_gu_wgrad", 7168, 4096, GemmPass::WGRAD}, | ||
| {"DeepSeek3_TopKRouter_wgrad", 7168, 256, GemmPass::WGRAD}, | ||
| {"DeepSeek3_LMHead_wgrad", 7168, 129280, GemmPass::WGRAD}, | ||
| }; | ||
|
|
||
| // Qwen3 (hidden=4096, GQA, seq=4096, incl. LM Head) | ||
| // Each entry is {label, dim1, dim2, pass}; resolve_mkn() maps dim1/dim2 onto M/N/K. | ||
| static const ShapeDef qwen3_shapes[] = { | ||
| // Forward: M = tokens, dim1 = N, dim2 = K | ||
| {"Qwen3_LNLinear_QKV_fwd", 9216, 4096, GemmPass::FWD}, | ||
| {"Qwen3_Linear_attn_fwd", 4096, 8192, GemmPass::FWD}, | ||
| {"Qwen3_Router_fwd", 128, 4096, GemmPass::FWD}, | ||
| {"Qwen3_LMHead_fwd", 151936, 4096, GemmPass::FWD}, | ||
| // Dgrad: M = tokens, dim1 = N, dim2 = K | ||
| {"Qwen3_Router_dgrad", 4096, 128, GemmPass::DGRAD}, | ||
| {"Qwen3_Linear_attn_dgrad", 8192, 4096, GemmPass::DGRAD}, | ||
| {"Qwen3_LNLinear_dgrad", 4096, 9216, GemmPass::DGRAD}, | ||
| {"Qwen3_LMHead_dgrad", 4096, 151936, GemmPass::DGRAD}, | ||
| // Wgrad: dim1 = M, dim2 = N, K = tokens | ||
| {"Qwen3_Router_wgrad", 4096, 128, GemmPass::WGRAD}, | ||
| {"Qwen3_Linear_attn_wgrad", 8192, 4096, GemmPass::WGRAD}, | ||
| {"Qwen3_LNLinear_wgrad", 4096, 9216, GemmPass::WGRAD}, | ||
| {"Qwen3_LMHead_wgrad", 4096, 151936, GemmPass::WGRAD}, | ||
| }; | ||
|
|
||
| // A, B, Bias, Gelu, D | ||
| // Bias type is chosen as bf16 in the use_fp8 case, D_type otherwise | ||
| // Gelu type is the same as Bias_Type | ||
|
|
@@ -559,7 +653,9 @@ void performTest(const TestParams& params) { | |
|
|
||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| template <typename A_Type, typename B_Type, typename D_Type> | ||
| void performDqTest(const TestParams &params) { | ||
| void performDqTest(const TestParams &params, | ||
| std::optional<double> atol_override = std::nullopt, | ||
| std::optional<double> rtol_override = std::nullopt) { | ||
| DType atype = TypeInfo<A_Type>::dtype; | ||
| DType btype = TypeInfo<B_Type>::dtype; | ||
| DType dtype = TypeInfo<D_Type>::dtype; | ||
|
|
@@ -633,6 +729,10 @@ void performDqTest(const TestParams &params) { | |
|
|
||
| //compare results | ||
| auto [atol, rtol] = getTestTolerances(dtype, true, true); | ||
| if (atol_override) | ||
| atol = *atol_override; | ||
| if (rtol_override) | ||
| rtol = *rtol_override; | ||
| compareResults("D", D, D_ref.rowwise_cpu_dptr<D_Type>(), true, atol, rtol); | ||
| } | ||
| #endif // __HIP_PLATFORM_AMD__ | ||
|
|
@@ -751,6 +851,89 @@ INSTANTIATE_TEST_SUITE_P(OperatorTest, DqGEMMTestSuite, | |
| return MKN(std::get<0>(info.param)) + "x" + TN(std::get<3>(info.param)); | ||
| }); | ||
|
|
||
| // ============================================================================ | ||
| // Production GEMM shape instantiations (run with --gtest_filter='ProdGemm*') | ||
| // ============================================================================ | ||
|
|
||
| // Known-failing GEMM shapes on gfx950 | ||
| // Entries are full test names of the form "<label>_mbs<N>_<layout>"; TestMxfp8Dq | ||
| // skips any case whose name appears here. | ||
| // NOTE(review): the lookup is by name only; no GPU-architecture check is visible in | ||
| // this part of the file, so these cases look to be skipped on every device - confirm. | ||
| static const std::set<std::string> kGfx950Skips = { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What are known failures? |
||
| "DeepSeek3_Linear1_fwd_mbs1_NT", | ||
| "DeepSeek3_Linear1_fwd_mbs2_NT", | ||
| "DeepSeek3_Linear1_fwd_mbs4_NT", | ||
| "DeepSeek3_LNLinear0_fwd_mbs4_NN", | ||
| "DeepSeek3_LNLinear0_fwd_mbs4_NT", | ||
| "DeepSeek3_attn_wgrad_mbs1_NN", | ||
| "Qwen3_LMHead_fwd_mbs2_NN", | ||
| "Qwen3_Router_fwd_mbs2_NT", | ||
| "Qwen3_LMHead_fwd_mbs4_TN", | ||
| "Qwen3_LMHead_fwd_mbs4_NN", | ||
| "Qwen3_LMHead_fwd_mbs4_NT", | ||
| }; | ||
|
|
||
| // Production GEMM test suite using ShapeDef x MBS x Layout via testing::Combine. | ||
| // Tuple fields: <0> shape definition, <1> micro-batch size, <2> layout. | ||
| using ProdGemmParam = std::tuple<ShapeDef, size_t, Layout>; | ||
|
|
||
| // Value-parameterized fixture over ProdGemmParam; it adds no state of its own. | ||
| class ProdDqGEMMTestSuite : public ::testing::TestWithParam<ProdGemmParam> {}; | ||
|
|
||
| // Runs one production-shape MXFP8 GEMM through performDqTest. The parameter is | ||
| // (shape, micro-batch size, layout); the shape's pass selects the FP8 operand types. | ||
| TEST_P(ProdDqGEMMTestSuite, TestMxfp8Dq) { | ||
|   const auto& [shape, mbs, layout] = GetParam(); | ||
|
|
||
|   // Same "<label>_mbs<N>_<layout>" form as the entries stored in kGfx950Skips. | ||
|   const std::string name = std::string(shape.label) + "_mbs" + std::to_string(mbs) | ||
|                            + "_" + TN(layout); | ||
|   if (kGfx950Skips.count(name) != 0) { | ||
|     GTEST_SKIP() << "Known gfx950 hipBLASLt failure: " << name; | ||
|   } | ||
|
|
||
|   // Zero-initialised so the dimensions are never indeterminate, even if | ||
|   // resolve_mkn() leaves them unassigned for an unexpected pass value. | ||
|   size_t m = 0; | ||
|   size_t k = 0; | ||
|   size_t n = 0; | ||
|   resolve_mkn(shape, mbs, m, k, n); | ||
|
|
||
|   TestParams params = {.m = m, .k = k, .n = n, | ||
|                        .use_bias = false, .use_gelu = false, | ||
|                        .transa = layout.first, .transb = layout.second, | ||
|                        .scaling_mode = NVTEScalingMode::NVTE_MXFP8_1D_SCALING}; | ||
|
|
||
|   // Production shapes use looser tolerances: the MXFP8 and bf16 reference | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If the increased numerical error is a result of bigger shapes, the tolerance update is better done inside the test. Also, DQ tests compare MXFP8 GEMM against FP16 GEMM; while they can be used to test MXFP8, their main purpose is to catch swizzling errors. Have you considered using performTest, where the MXFP8 GEMM is compared to a reference MXFP8 GEMM? |
||
|   // GEMM use different internal accumulation paths, so results can differ | ||
|   // by a few bf16 ULPs. bf16 has 7 fraction bits, so one ULP is at most | ||
|   // 2^-7 (~0.8%) relative; rtol = 2e-2 therefore allows roughly 2-5 ULPs. | ||
|   const double prod_atol = 1e-3; | ||
|   const double prod_rtol = 2e-2; | ||
|
|
||
|   switch (shape.pass) { | ||
|     case GemmPass::FWD: | ||
|       performDqTest<fp8, fp8, bf16>(params, prod_atol, prod_rtol); | ||
|       break; | ||
|     case GemmPass::DGRAD: | ||
|       performDqTest<bf8, fp8, bf16>(params, prod_atol, prod_rtol); | ||
|       break; | ||
|     case GemmPass::WGRAD: | ||
|       performDqTest<fp8, bf8, bf16>(params, prod_atol, prod_rtol); | ||
|       break; | ||
|   } | ||
| } | ||
|
|
||
| // Name generator for the production suites: "<label>_mbs<N>_" followed by TN(layout). | ||
| static auto prodTestName = [](const testing::TestParamInfo<ProdGemmParam>& info) { | ||
|   const auto& [shape_def, micro_batch, layout] = info.param; | ||
|   std::string case_name(shape_def.label); | ||
|   case_name += "_mbs"; | ||
|   case_name += std::to_string(micro_batch); | ||
|   case_name += "_"; | ||
|   case_name += TN(layout); | ||
|   return case_name; | ||
| }; | ||
|
|
||
| // DeepSeek3 cases: every deepseek3_shapes entry x MBS {1, 2, 4} x every layout in | ||
| // kLayouts, with case names produced by prodTestName. | ||
| INSTANTIATE_TEST_SUITE_P(ProdGemmDeepSeek3, ProdDqGEMMTestSuite, | ||
| ::testing::Combine( | ||
| ::testing::ValuesIn(deepseek3_shapes), | ||
| ::testing::Values(size_t{1}, size_t{2}, size_t{4}), | ||
| ::testing::ValuesIn(kLayouts)), | ||
| prodTestName); | ||
|
|
||
| // Qwen3 cases: every qwen3_shapes entry x MBS {1, 2, 4} x every layout in | ||
| // kLayouts, with case names produced by prodTestName. | ||
| INSTANTIATE_TEST_SUITE_P(ProdGemmQwen3, ProdDqGEMMTestSuite, | ||
| ::testing::Combine( | ||
| ::testing::ValuesIn(qwen3_shapes), | ||
| ::testing::Values(size_t{1}, size_t{2}, size_t{4}), | ||
| ::testing::ValuesIn(kLayouts)), | ||
| prodTestName); | ||
|
|
||
| TEST(InputGenTest, FillUniform_DoesNotGetOverwrittenByFromCpu) { | ||
| const size_t rows = 128; | ||
| const size_t cols = 256; | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not use token count directly?