Skip to content
185 changes: 184 additions & 1 deletion tests/cpp/operator/test_cublaslt_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
************************************************************************/
#include <cmath>
#include <iostream>
#include <optional>
#include <set>
#include <string>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
Expand Down Expand Up @@ -33,6 +35,98 @@ std::vector<std::tuple<size_t, size_t, size_t>> test_case_sizes_mxfp8 = {
{768, 3072, 4096},
};

// ============================================================================
// Production LLM shapes for MXFP8 GEMM testing.
//
// Each shape is tested with 3 micro-batch sizes (MBS = 1, 2, 4)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not use token count directly?

// yielding tokens = 4096, 8192, 16384, and 3 layouts (TN, NN, NT)
// via ::testing::Combine.
//
// GemmPass selects the FP8 type combination:
// FWD: E4M3 x E4M3 -> BF16
// DGRAD: E5M2 x E4M3 -> BF16
// WGRAD: E4M3 x E5M2 -> BF16
// ============================================================================

enum class GemmPass { FWD, DGRAD, WGRAD };

struct ShapeDef {
const char* label;
size_t dim1; // FWD/DGRAD: N, WGRAD: M
size_t dim2; // FWD/DGRAD: K, WGRAD: N
GemmPass pass;
};

std::ostream& operator<<(std::ostream& os, const ShapeDef& s) {
return os << s.label;
}

static void resolve_mkn(const ShapeDef& s, size_t mbs,
size_t& m, size_t& k, size_t& n) {
size_t tokens = mbs * 4096;
switch (s.pass) {
case GemmPass::FWD:
case GemmPass::DGRAD:
m = tokens; n = s.dim1; k = s.dim2;
break;
case GemmPass::WGRAD:
m = s.dim1; n = s.dim2; k = tokens;
break;
}
}

// DeepSeek3 (hidden=7168, MLA, seq=4096, incl. LM Head)
static const ShapeDef deepseek3_shapes[] = {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is execution time for all ProdGEMM tests?

// Forward (M=tokens, N, K)
{"DeepSeek3_Linear0_fwd", 1536, 7168, GemmPass::FWD},
{"DeepSeek3_Linear1_fwd", 576, 7168, GemmPass::FWD},
{"DeepSeek3_LNLinear0_fwd", 24576, 1536, GemmPass::FWD},
{"DeepSeek3_LNLinear1_fwd", 32768, 512, GemmPass::FWD},
{"DeepSeek3_Linear_attn_fwd", 7168, 16384, GemmPass::FWD},
{"DeepSeek3_LNMLP_gateup_fwd", 36864, 7168, GemmPass::FWD},
{"DeepSeek3_LNMLP_down_fwd", 7168, 18432, GemmPass::FWD},
{"DeepSeek3_SharedExp_gu_fwd", 4096, 7168, GemmPass::FWD},
{"DeepSeek3_SharedExp_dn_fwd", 7168, 2048, GemmPass::FWD},
{"DeepSeek3_TopKRouter_fwd", 256, 7168, GemmPass::FWD},
{"DeepSeek3_LMHead_fwd", 129280, 7168, GemmPass::FWD},
// Dgrad (M=tokens, N, K)
{"DeepSeek3_attn_dgrad", 16384, 7168, GemmPass::DGRAD},
{"DeepSeek3_LNLinear1_dgrad", 512, 32768, GemmPass::DGRAD},
{"DeepSeek3_LNLinear0_dgrad", 1536, 24576, GemmPass::DGRAD},
{"DeepSeek3_SharedExp_dn_dgrad", 2048, 7168, GemmPass::DGRAD},
{"DeepSeek3_SharedExp_gu_dgrad", 7168, 4096, GemmPass::DGRAD},
{"DeepSeek3_TopKRouter_dgrad", 7168, 256, GemmPass::DGRAD},
{"DeepSeek3_MLP_post_dgrad", 7168, 14336, GemmPass::DGRAD},
{"DeepSeek3_LMHead_dgrad", 7168, 129280, GemmPass::DGRAD},
// Wgrad (M, N, K=tokens)
{"DeepSeek3_attn_wgrad", 16384, 7168, GemmPass::WGRAD},
{"DeepSeek3_LNLinear1_wgrad", 512, 32768, GemmPass::WGRAD},
{"DeepSeek3_LNLinear0_wgrad", 1536, 24576, GemmPass::WGRAD},
{"DeepSeek3_SharedExp_dn_wgrad", 2048, 7168, GemmPass::WGRAD},
{"DeepSeek3_SharedExp_gu_wgrad", 7168, 4096, GemmPass::WGRAD},
{"DeepSeek3_TopKRouter_wgrad", 7168, 256, GemmPass::WGRAD},
{"DeepSeek3_LMHead_wgrad", 7168, 129280, GemmPass::WGRAD},
};

// Qwen3 (hidden=4096, GQA, seq=4096, incl. LM Head)
static const ShapeDef qwen3_shapes[] = {
// Forward (M=tokens, N, K)
{"Qwen3_LNLinear_QKV_fwd", 9216, 4096, GemmPass::FWD},
{"Qwen3_Linear_attn_fwd", 4096, 8192, GemmPass::FWD},
{"Qwen3_Router_fwd", 128, 4096, GemmPass::FWD},
{"Qwen3_LMHead_fwd", 151936, 4096, GemmPass::FWD},
// Dgrad (M=tokens, N, K)
{"Qwen3_Router_dgrad", 4096, 128, GemmPass::DGRAD},
{"Qwen3_Linear_attn_dgrad", 8192, 4096, GemmPass::DGRAD},
{"Qwen3_LNLinear_dgrad", 4096, 9216, GemmPass::DGRAD},
{"Qwen3_LMHead_dgrad", 4096, 151936, GemmPass::DGRAD},
// Wgrad (M, N, K=tokens)
{"Qwen3_Router_wgrad", 4096, 128, GemmPass::WGRAD},
{"Qwen3_Linear_attn_wgrad", 8192, 4096, GemmPass::WGRAD},
{"Qwen3_LNLinear_wgrad", 4096, 9216, GemmPass::WGRAD},
{"Qwen3_LMHead_wgrad", 4096, 151936, GemmPass::WGRAD},
};

// A, B, Bias, Gelu, D
// Bias type choose as bf16 in use_fp8, D_type otherwise
// Gelu type the same as Bias_Type
Expand Down Expand Up @@ -559,7 +653,9 @@ void performTest(const TestParams& params) {

#ifdef __HIP_PLATFORM_AMD__
template <typename A_Type, typename B_Type, typename D_Type>
void performDqTest(const TestParams &params) {
void performDqTest(const TestParams &params,
std::optional<double> atol_override = std::nullopt,
std::optional<double> rtol_override = std::nullopt) {
DType atype = TypeInfo<A_Type>::dtype;
DType btype = TypeInfo<B_Type>::dtype;
DType dtype = TypeInfo<D_Type>::dtype;
Expand Down Expand Up @@ -633,6 +729,10 @@ void performDqTest(const TestParams &params) {

//compare results
auto [atol, rtol] = getTestTolerances(dtype, true, true);
if (atol_override)
atol = *atol_override;
if (rtol_override)
rtol = *rtol_override;
compareResults("D", D, D_ref.rowwise_cpu_dptr<D_Type>(), true, atol, rtol);
}
#endif // __HIP_PLATFORM_AMD__
Expand Down Expand Up @@ -751,6 +851,89 @@ INSTANTIATE_TEST_SUITE_P(OperatorTest, DqGEMMTestSuite,
return MKN(std::get<0>(info.param)) + "x" + TN(std::get<3>(info.param));
});

// ============================================================================
// Production GEMM shape instantiations (run with --gtest_filter='ProdGemm*')
// ============================================================================

// Known-failing GEMM shapes on gfx950
static const std::set<std::string> kGfx950Skips = {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What are known failures?

"DeepSeek3_Linear1_fwd_mbs1_NT",
"DeepSeek3_Linear1_fwd_mbs2_NT",
"DeepSeek3_Linear1_fwd_mbs4_NT",
"DeepSeek3_LNLinear0_fwd_mbs4_NN",
"DeepSeek3_LNLinear0_fwd_mbs4_NT",
"DeepSeek3_attn_wgrad_mbs1_NN",
"Qwen3_LMHead_fwd_mbs2_NN",
"Qwen3_Router_fwd_mbs2_NT",
"Qwen3_LMHead_fwd_mbs4_TN",
"Qwen3_LMHead_fwd_mbs4_NN",
"Qwen3_LMHead_fwd_mbs4_NT",
};

// Production GEMM test suite using ShapeDef x MBS x Layout via testing::Combine.
using ProdGemmParam = std::tuple<ShapeDef, size_t, Layout>;

class ProdDqGEMMTestSuite : public ::testing::TestWithParam<ProdGemmParam> {};

TEST_P(ProdDqGEMMTestSuite, TestMxfp8Dq) {
const auto& shape = std::get<0>(GetParam());
size_t mbs = std::get<1>(GetParam());
const auto& layout = std::get<2>(GetParam());

std::string name = std::string(shape.label) + "_mbs" + std::to_string(mbs)
+ "_" + TN(layout);
if (kGfx950Skips.count(name)) {
GTEST_SKIP() << "Known gfx950 hipBLASLt failure: " << name;
}

size_t m, k, n;
resolve_mkn(shape, mbs, m, k, n);

TestParams params = {.m = m, .k = k, .n = n,
.use_bias = false, .use_gelu = false,
.transa = layout.first, .transb = layout.second,
.scaling_mode = NVTEScalingMode::NVTE_MXFP8_1D_SCALING};

// Production shapes use looser tolerances: the MXFP8 and bf16 reference
Copy link
Copy Markdown
Collaborator

@ipanfilo ipanfilo Jun 4, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If increased numerical error a result of bigger shapes, tolerance update better to be done inside the test. Also, DQ tests are to compare MXFP8 GEMM vs FP16 GEMM, while it can be used to test MXFP8 it's main purpose is to catch swizzling errors. Have you considered using PerformTest where MXFP8 GEMM it compared to reference MXFP8 GEMM?

// GEMM use different internal accumulation paths, so results can differ
// by up to 1 ULP in bf16 (~1.5-2% relative).
const double prod_atol = 1e-3;
const double prod_rtol = 2e-2;

switch (shape.pass) {
case GemmPass::FWD:
performDqTest<fp8, fp8, bf16>(params, prod_atol, prod_rtol);
break;
case GemmPass::DGRAD:
performDqTest<bf8, fp8, bf16>(params, prod_atol, prod_rtol);
break;
case GemmPass::WGRAD:
performDqTest<fp8, bf8, bf16>(params, prod_atol, prod_rtol);
break;
}
}

static auto prodTestName = [](const testing::TestParamInfo<ProdGemmParam>& info) {
const auto& shape = std::get<0>(info.param);
size_t mbs = std::get<1>(info.param);
const auto& layout = std::get<2>(info.param);
return std::string(shape.label) + "_mbs" + std::to_string(mbs) + "_" + TN(layout);
};

INSTANTIATE_TEST_SUITE_P(ProdGemmDeepSeek3, ProdDqGEMMTestSuite,
::testing::Combine(
::testing::ValuesIn(deepseek3_shapes),
::testing::Values(size_t{1}, size_t{2}, size_t{4}),
::testing::ValuesIn(kLayouts)),
prodTestName);

INSTANTIATE_TEST_SUITE_P(ProdGemmQwen3, ProdDqGEMMTestSuite,
::testing::Combine(
::testing::ValuesIn(qwen3_shapes),
::testing::Values(size_t{1}, size_t{2}, size_t{4}),
::testing::ValuesIn(kLayouts)),
prodTestName);

TEST(InputGenTest, FillUniform_DoesNotGetOverwrittenByFromCpu) {
const size_t rows = 128;
const size_t cols = 256;
Expand Down
Loading