Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions tests/cpp/operator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ add_executable(test_operator
test_multi_padding.cu
test_multi_unpadding.cu
test_causal_softmax.cu
test_swizzle.cu #CUDA-only test
test_swizzle.cu
test_swap_first_dims.cu
test_grouped_gemm.cu #CUDA-only test
../test_common.cu)
Expand All @@ -42,7 +42,6 @@ if(USE_ROCM)
# Remove CUDA-only tests and add ROCm specific ones
list(REMOVE_ITEM test_cuda_sources
test_cast_float8blockwise.cu
test_swizzle.cu
test_grouped_gemm.cu)
list(APPEND test_cuda_sources
test_dequantize_nvfp4.cu
Expand Down
112 changes: 95 additions & 17 deletions tests/cpp/operator/test_cublaslt_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <gtest/gtest.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/swizzle.h>
#include <transformer_engine/transformer_engine.h>
#include "../test_common.h"

Expand All @@ -30,7 +31,15 @@ std::vector<std::tuple<size_t, size_t, size_t>> test_case_sizes = {

std::vector<std::tuple<size_t, size_t, size_t>> test_case_sizes_mxfp8 = {
{32, 128, 16},
{64, 128, 32},
{128, 128, 64},
{64, 256, 32},
{128, 384, 64},
{256, 512, 128},
{512, 1024, 256},
{768, 3072, 4096},
{1024, 2048, 128},
{4096, 8192, 64},
};

// A, B, Bias, Gelu, D
Expand Down Expand Up @@ -303,6 +312,40 @@ void cpu_rowwise_to_columnwise(
}
}

// Swizzle MXFP8 scale_inv of a test::Tensor in-place for gfx1250.
static void swizzle_mxfp8_scales(test::Tensor &t, bool rowwise) {
using namespace transformer_engine;
void *scale_ptr = rowwise ? t.rowwise_scale_inv_dptr()
: t.columnwise_scale_inv_dptr();
if (!scale_ptr) return;
const NVTEShape scale_shape = rowwise ? t.rowwise_scale_inv_shape()
: t.columnwise_scale_inv_shape();
const NVTEShape data_shape = rowwise ? t.rowwise_shape()
: t.columnwise_shape();
size_t num_scales = 1;
for (size_t d = 0; d < scale_shape.ndim; d++) num_scales *= scale_shape.data[d];
uint8_t *d_tmp = nullptr;
NVTE_CHECK_CUDA(cudaMalloc(&d_tmp, num_scales));
TensorWrapper input_tw(NVTE_MXFP8_1D_SCALING);
TensorWrapper output_tw(NVTE_MXFP8_1D_SCALING);
output_tw.set_with_gemm_swizzled_scales(true);
if (rowwise) {
input_tw.set_rowwise_data(nullptr, t.dtype(), data_shape);
input_tw.set_rowwise_scale_inv(scale_ptr, DType::kFloat8E8M0, scale_shape);
output_tw.set_rowwise_data(nullptr, t.dtype(), data_shape);
output_tw.set_rowwise_scale_inv(d_tmp, DType::kFloat8E8M0, scale_shape);
} else {
input_tw.set_columnwise_data(nullptr, t.dtype(), data_shape);
input_tw.set_columnwise_scale_inv(scale_ptr, DType::kFloat8E8M0, scale_shape);
output_tw.set_columnwise_data(nullptr, t.dtype(), data_shape);
output_tw.set_columnwise_scale_inv(d_tmp, DType::kFloat8E8M0, scale_shape);
}
nvte_swizzle_scaling_factors(input_tw.data(), output_tw.data(), 0);
NVTE_CHECK_CUDA(cudaDeviceSynchronize());
NVTE_CHECK_CUDA(cudaMemcpy(scale_ptr, d_tmp, num_scales, cudaMemcpyDeviceToDevice));
NVTE_CHECK_CUDA(cudaFree(d_tmp));
}

std::pair<double, double> getTestTolerances(const DType type, bool use_fp8, bool use_mxfp8) {
auto [atol, rtol] = getTolerances(type);

Expand All @@ -318,6 +361,12 @@ std::pair<double, double> getTestTolerances(const DType type, bool use_fp8, bool
else if (use_fp8) {
atol = 1e-3;
rtol = std::max(rtol, 1e-2);
// Relax for gfx1250
cudaDeviceProp prop;
(void)cudaGetDeviceProperties(&prop, 0);
if (prop.major == 12 && type == DType::kBFloat16) {
rtol = std::max(rtol, 5e-2);
}
}
else if (type == DType::kBFloat16) {
//relax for certain prime number TN gemm
Expand Down Expand Up @@ -496,6 +545,31 @@ void performTest(const TestParams& params) {
#endif
Tensor Workspace("Workspace", TShape{ workspace_size }, DType::kByte);

//perform the reference gemm on GPU (before swizzle, which modifies scales in-place)
Tensor RefD("RefD", TShape{ params.n, params.m }, dtype);
Tensor RefPreGeluOut;

if (params.use_gelu) {
RefPreGeluOut = Tensor("RefPreGeluOut", TShape{ params.n, params.m }, gelu_type);
}

run_reference<A_Type, B_Type, Bias_Type, Gelu_Type, D_Type>(
params,
A,
B,
params.use_bias ? &bias : nullptr,
D,
RefD,
params.use_gelu ? &RefPreGeluOut : nullptr);

// On gfx1250, hipBLASLt MXFP8 kernels expect pre-swizzled scales.
if (use_mxfp8 && prop.major == 12) {
if (!a_colwise) swizzle_mxfp8_scales(A, true);
if (a_colwise) swizzle_mxfp8_scales(A, false);
if (!b_colwise) swizzle_mxfp8_scales(B, true);
if (b_colwise) swizzle_mxfp8_scales(B, false);
}

//perform the gemm in GPU
nvte_cublas_gemm(A.data(),
B.data(),
Expand All @@ -517,23 +591,6 @@ void performTest(const TestParams& params) {
pre_gelu_out.to_cpu();
}

//perform the reference gemm on GPU
Tensor RefD("RefD", TShape{ params.n, params.m }, dtype);
Tensor RefPreGeluOut;

if (params.use_gelu) {
RefPreGeluOut = Tensor("RefPreGeluOut", TShape{ params.n, params.m }, gelu_type);
}

run_reference<A_Type, B_Type, Bias_Type, Gelu_Type, D_Type>(
params,
A,
B,
params.use_bias ? &bias : nullptr,
D,
RefD,
params.use_gelu ? &RefPreGeluOut : nullptr);

// check if error message happens in running
(void)cudaDeviceSynchronize();
auto err = cudaGetLastError();
Expand Down Expand Up @@ -582,6 +639,17 @@ void performDqTest(const TestParams &params) {
GTEST_SKIP() << "MXFP8 is not supported in current config";
}

// hipBLASLt on gfx950 produces incorrect results for certain small MXFP8

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there ticket for that?

// GEMMs with non-TN layouts.
if (prop.major == 9 && prop.minor == 5) {
const bool is_NN = !params.transa && !params.transb;
const bool is_NT = !params.transa && params.transb;
if ((is_NN && params.m == 64) ||
(is_NT && params.m > 32 && params.m <= 128 && params.n <= 64)) {
GTEST_SKIP() << "hipBLASLt MXFP8 non-TN GEMM with small M/N is not supported on gfx950";
}
}

DType ref_type = dtype;
TShape a_shape = params.transa ? TShape{params.m, params.k} : TShape{params.k, params.m};
TShape b_shape = params.transb ? TShape{params.k, params.n} : TShape{params.n, params.k};
Expand All @@ -605,6 +673,16 @@ void performDqTest(const TestParams &params) {
nvte_dequantize(A_fp8.data(), A_ref.data(), 0);
nvte_dequantize(B_fp8.data(), B_ref.data(), 0);

// On gfx1250, hipBLASLt MXFP8 kernels expect pre-swizzled scales.
if (prop.major == 12) {
const bool a_colwise = !params.transa;
const bool b_colwise = params.transb;
if (!a_colwise) swizzle_mxfp8_scales(A_fp8, true);
if (a_colwise) swizzle_mxfp8_scales(A_fp8, false);
if (!b_colwise) swizzle_mxfp8_scales(B_fp8, true);
if (b_colwise) swizzle_mxfp8_scales(B_fp8, false);
}

Tensor bias;
Tensor pre_gelu_out;

Expand Down
180 changes: 180 additions & 0 deletions tests/cpp/operator/test_swizzle.cu
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,183 @@ INSTANTIATE_TEST_SUITE_P(
std::to_string(std::get<2>(info.param));
return name;
});

#ifdef __HIP_PLATFORM_AMD__

// MX pre-swizzle test (gfx1250 Tensile 3D layout)
//
// Tensile 3D: {K_scale, M}.reshape({K_scale, padM/4, 4}).permute({1, 0, 2})
// For source (m, k): dst = (m/4) * (K*4) + k*4 + (m%4)

// CPU reference for Tensile 3D MX scale pre-swizzle.
// Row-major input [M, K], output is a flat permuted array.
void compute_ref_mx_swizzle_row(const uint8_t *h_input, uint8_t *h_output,
const int M, const int K,
const int orig_M, const int orig_K) {
constexpr int GROUP = 4;
for (int m = 0; m < M; m++) {
for (int k = 0; k < K; k++) {
uint8_t val = 127; // E8M0 identity: 2^0 = 1.0
if (m < orig_M && k < orig_K) {
val = h_input[m * orig_K + k];
}
int group = k / GROUP;
int within = k % GROUP;
int dst = group * (M * GROUP) + m * GROUP + within;
h_output[dst] = val;
}
}
}

void compute_ref_mx_swizzle_col(const uint8_t *h_input, uint8_t *h_output,
const int M, const int K,
const int orig_M, const int orig_K) {
constexpr int GROUP = 4;
for (int m = 0; m < M; m++) {
for (int k = 0; k < K; k++) {
uint8_t val = 127;
if (m < orig_M && k < orig_K) {
val = h_input[k * orig_M + m];
}
int group = k / GROUP;
int within = k % GROUP;
int dst = group * (M * GROUP) + m * GROUP + within;
h_output[dst] = val;
}
}
}

static size_t roundup_sz(size_t val, size_t mult) {
return ((val + mult - 1) / mult) * mult;
}

class MxSwizzleTestSuite
: public ::testing::TestWithParam<
std::tuple<std::pair<int, int>, bool>> {};

TEST_P(MxSwizzleTestSuite, TestMxSwizzle) {
using namespace transformer_engine;
using namespace test;

cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, 0);
if (prop.major < 12) {
GTEST_SKIP() << "MXFP8 pre-swizzle is only supported on gfx1250";
}

const auto dims = std::get<0>(GetParam());
const bool rowwise = std::get<1>(GetParam());

// Original (unpadded) scale dimensions
const size_t orig_M = dims.first;
const size_t orig_K = dims.second;

// Padded dimensions: K-tiled layout requires K_scale padded to multiple of 4
const size_t M = orig_M;
const size_t K = roundup_sz(orig_K, 4);

// Allocate host input (unpadded) and fill with random data
const size_t input_size = orig_M * orig_K;
std::unique_ptr<uint8_t[]> h_input(new uint8_t[input_size]);
std::mt19937 rng(42);
for (size_t i = 0; i < input_size; i++) {
h_input[i] = static_cast<uint8_t>(rng() % 256);
}

// Allocate device input
uint8_t *d_input = nullptr;
NVTE_CHECK_CUDA(cudaMalloc(&d_input, input_size));
NVTE_CHECK_CUDA(cudaMemcpy(d_input, h_input.get(), input_size, cudaMemcpyHostToDevice));

// Allocate device output (padded size)
const size_t output_size = M * K;
uint8_t *d_output = nullptr;
NVTE_CHECK_CUDA(cudaMalloc(&d_output, output_size));
NVTE_CHECK_CUDA(cudaMemset(d_output, 0, output_size));

// Build TensorWrapper for input and output
TensorWrapper input_tw(NVTE_MXFP8_1D_SCALING);
TensorWrapper output_tw(NVTE_MXFP8_1D_SCALING);
output_tw.set_with_gemm_swizzled_scales(true);

// Data shape must be consistent with scale shape for validation.
// Scale shapes use padded K; data shapes use unpadded dims
// (kernel derives original_M/K from them).
if (rowwise) {
std::vector<size_t> data_shape_in = {orig_M, orig_K * 32};
std::vector<size_t> data_shape_out = {M, K * 32};
std::vector<size_t> scale_shape_in = {M, K};
std::vector<size_t> scale_shape_out = {M, K};
input_tw.set_rowwise_data(nullptr, DType::kFloat8E4M3, data_shape_in);
input_tw.set_rowwise_scale_inv(d_input, DType::kFloat8E8M0, scale_shape_in);
output_tw.set_rowwise_data(nullptr, DType::kFloat8E4M3, data_shape_out);
output_tw.set_rowwise_scale_inv(d_output, DType::kFloat8E8M0, scale_shape_out);
} else {
std::vector<size_t> data_shape_in = {orig_K * 32, orig_M};
std::vector<size_t> data_shape_out = {K * 32, M};
std::vector<size_t> scale_shape_in = {K, M};
std::vector<size_t> scale_shape_out = {K, M};
input_tw.set_columnwise_data(nullptr, DType::kFloat8E4M3, data_shape_in);
input_tw.set_columnwise_scale_inv(d_input, DType::kFloat8E8M0, scale_shape_in);
output_tw.set_columnwise_data(nullptr, DType::kFloat8E4M3, data_shape_out);
output_tw.set_columnwise_scale_inv(d_output, DType::kFloat8E8M0, scale_shape_out);
}

nvte_swizzle_scaling_factors(input_tw.data(), output_tw.data(), 0);

NVTE_CHECK_CUDA(cudaDeviceSynchronize());

// Copy output back to host
std::unique_ptr<uint8_t[]> h_output(new uint8_t[output_size]);
NVTE_CHECK_CUDA(cudaMemcpy(h_output.get(), d_output, output_size, cudaMemcpyDeviceToHost));

// Compute reference
std::unique_ptr<uint8_t[]> h_ref(new uint8_t[output_size]);
memset(h_ref.get(), 0, output_size);
if (rowwise) {
compute_ref_mx_swizzle_row(h_input.get(), h_ref.get(), M, K, orig_M, orig_K);
} else {
compute_ref_mx_swizzle_col(h_input.get(), h_ref.get(), M, K, orig_M, orig_K);
}

// Compare
compareResults("mx_swizzle", h_output.get(), h_ref.get(), output_size);

cudaFree(d_input);
cudaFree(d_output);
}

namespace {

// Scale dimensions (M_scale, K_scale).
// K_scale will be padded to multiple of 4 by the test.
std::vector<std::pair<int, int>> mx_scale_dims = {
{4, 4}, // minimal
{8, 4}, // small
{32, 8}, // medium
{64, 16}, // larger
{96, 8}, // non-power-of-2 M
{128, 32}, // big
{256, 64}, // bigger
{512, 128}, // stress inter-tile
{1024, 256}, // large
{4096, 256}, // max stress
};

} // namespace

INSTANTIATE_TEST_SUITE_P(
OperatorTest,
MxSwizzleTestSuite,
::testing::Combine(
::testing::ValuesIn(mx_scale_dims),
::testing::Values(true, false)
),
[](const testing::TestParamInfo<MxSwizzleTestSuite::ParamType>& info) {
std::string name = "M" + std::to_string(std::get<0>(info.param).first) +
"_K" + std::to_string(std::get<0>(info.param).second) +
(std::get<1>(info.param) ? "_row" : "_col");
return name;
});

#endif // __HIP_PLATFORM_AMD__
Loading
Loading