Skip to content
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
bc363fa
add MX scale pre-swizzling for gfx1250
matthiasdiener Apr 27, 2026
a6ca3af
switch to mxfp4
matthiasdiener Apr 27, 2026
d1ee5bd
tensile-like implementation
matthiasdiener Apr 28, 2026
d1647ee
Merge remote-tracking branch 'upstream/dev' into mdiener/mxfp8-swizzle
matthiasdiener Apr 29, 2026
1fff6d9
Merge remote-tracking branch 'origin/dev' into mdiener/mxfp8-swizzle
matthiasdiener May 1, 2026
d714038
gfx1250 swizzle_xor changes for FP4
matthiasdiener May 1, 2026
76ca4b1
change line endings to unix, trim trailing whitespace
matthiasdiener May 1, 2026
81a0a27
Merge branch 'mdiener/swizzle_xor-1250' into mdiener/mxfp8-swizzle
matthiasdiener May 1, 2026
2991bcf
fix arch
matthiasdiener May 1, 2026
8ceb89c
[WIP] e2e gemm test, not working yet
matthiasdiener May 1, 2026
167d2eb
fix for gfx1250
matthiasdiener May 3, 2026
5d46537
k-tile
matthiasdiener May 3, 2026
313a6b7
extend tests
matthiasdiener May 3, 2026
2a8eeb5
remove ifdef
matthiasdiener May 3, 2026
c37a781
undo BLK32_UE8M0_32_8_EXT
matthiasdiener May 4, 2026
5d2d38f
Merge remote-tracking branch 'upstream/dev' into mdiener/mxfp8-swizzle
matthiasdiener May 5, 2026
f093f64
Revert "change line endings to unix, trim trailing whitespace"
matthiasdiener May 5, 2026
ecbffea
Revert "gfx1250 swizzle_xor changes for FP4"
matthiasdiener May 5, 2026
6855218
Claude PR review use OIDC-free method (#560)
Micky774 May 7, 2026
a0b88f4
gfx1250 swizzle_xor changes for FP4 (#571)
matthiasdiener May 9, 2026
27f4acd
NVFP4: Work around intermittent incorrect results for backward GEMMs …
matthiasdiener May 13, 2026
33fca6e
Merge remote-tracking branch 'origin/dev' into mdiener/mxfp8-swizzle
matthiasdiener May 13, 2026
b55a538
address review comments
matthiasdiener May 13, 2026
398cc3c
cleanups
matthiasdiener May 13, 2026
384d590
re-add scale swizzle hooks in GEMM paths for gfx1250
matthiasdiener May 13, 2026
5c5a902
cleanups
matthiasdiener May 13, 2026
2c05ec5
arch fixes
matthiasdiener May 14, 2026
5552b09
more test fixes gfx1250
matthiasdiener May 18, 2026
5cb098b
RMS Norm Optimization (#583)
aris134 May 18, 2026
bdee033
Merge remote-tracking branch 'origin/dev' into mdiener/mxfp8-swizzle
matthiasdiener May 19, 2026
90db6f4
address review comments
matthiasdiener May 19, 2026
2a6302d
additional padding
matthiasdiener May 19, 2026
03e33b1
Revert "Claude PR review use OIDC-free method (#560)"
matthiasdiener May 21, 2026
96254fa
Revert "RMS Norm Optimization (#583)"
matthiasdiener May 21, 2026
b83a2d9
revert unnecessary changes for gfx1250
matthiasdiener May 21, 2026
bea6b18
remove extra guards
matthiasdiener May 21, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
257 changes: 218 additions & 39 deletions .github/workflows/claude-pr-action.yml

Large diffs are not rendered by default.

119 changes: 0 additions & 119 deletions .github/workflows/claude-pr-trigger.yml

This file was deleted.

4 changes: 2 additions & 2 deletions tests/cpp/operator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ list(APPEND test_cuda_sources
test_multi_unpadding.cu
test_causal_softmax.cu
test_swap_first_dims.cu
test_swizzle.cu
../test_common.cu)
if(USE_CUDA)
list(APPEND test_cuda_sources
test_cast_float8blockwise.cu
test_swizzle.cu)
test_cast_float8blockwise.cu)
else()
list(APPEND test_cuda_sources
test_cublaslt_gemm.cu
Expand Down
225 changes: 225 additions & 0 deletions tests/cpp/operator/test_cublaslt_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <gtest/gtest.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/swizzle.h>
#include <transformer_engine/transformer_engine.h>
#include "../test_common.h"

Expand Down Expand Up @@ -793,4 +794,228 @@ TEST(InputGenTest, FillUniform_DoesNotGetOverwrittenByFromCpu) {
}
}

// ============================================================================
// End-to-end MXFP8 GEMM test with pre-swizzled scales
//
// Verifies that the full pipeline works:
// 1. Create MXFP8 FP8 tensors with random data + scales
// 2. Run a reference GEMM (using un-swizzled scales)
// 3. Swizzle the scales via nvte_swizzle_scaling_factors
// 4. Run the actual hipBLASlt GEMM
// 5. Compare results
// ============================================================================

// Helper: swizzle the MXFP8 scale_inv of a test::Tensor in-place.
// Allocates a temp device buffer, swizzles into it, copies back.
static void swizzle_tensor_scales(test::Tensor &t, bool rowwise) {
using namespace transformer_engine;

void *scale_ptr = rowwise ? t.rowwise_scale_inv_dptr()
: t.columnwise_scale_inv_dptr();
if (!scale_ptr)
return;

const NVTEShape scale_shape = rowwise ? t.rowwise_scale_inv_shape()
: t.columnwise_scale_inv_shape();
const NVTEShape data_shape = rowwise ? t.rowwise_shape()
: t.columnwise_shape();

size_t num_scales = 1;
for (size_t d = 0; d < scale_shape.ndim; d++) {
num_scales *= scale_shape.data[d];
}

// Allocate temp buffer for swizzled output
uint8_t *d_tmp = nullptr;
NVTE_CHECK_CUDA(cudaMalloc(&d_tmp, num_scales));

// Build TensorWrapper pair for the swizzle call
TensorWrapper input_tw(NVTE_MXFP8_1D_SCALING);
TensorWrapper output_tw(NVTE_MXFP8_1D_SCALING);
output_tw.set_with_gemm_swizzled_scales(true);

if (rowwise) {
input_tw.set_rowwise_data(nullptr, t.dtype(), data_shape);
input_tw.set_rowwise_scale_inv(scale_ptr, DType::kFloat8E8M0, scale_shape);
output_tw.set_rowwise_data(nullptr, t.dtype(), data_shape);
output_tw.set_rowwise_scale_inv(d_tmp, DType::kFloat8E8M0, scale_shape);
} else {
input_tw.set_columnwise_data(nullptr, t.dtype(), data_shape);
input_tw.set_columnwise_scale_inv(scale_ptr, DType::kFloat8E8M0, scale_shape);
output_tw.set_columnwise_data(nullptr, t.dtype(), data_shape);
output_tw.set_columnwise_scale_inv(d_tmp, DType::kFloat8E8M0, scale_shape);
}

nvte_swizzle_scaling_factors(input_tw.data(), output_tw.data(), 0);
NVTE_CHECK_CUDA(cudaDeviceSynchronize());

// Copy swizzled scales back over the original
NVTE_CHECK_CUDA(cudaMemcpy(scale_ptr, d_tmp, num_scales, cudaMemcpyDeviceToDevice));
NVTE_CHECK_CUDA(cudaFree(d_tmp));

// Mark tensor as having swizzled scales
t.set_with_gemm_swizzled_scales(true);
}

// Simple GPU reference kernel for MXFP8 GEMM: D = A * B^T (TN layout)
// A is [M, K] row-major, B is [N, K] row-major, D is [M, N] column-major
// Scales are E8M0, one per group of 32 elements along K.
__global__ void mxfp8_gemm_ref_kernel(
Comment thread
alextmagro marked this conversation as resolved.
Outdated
const test::fp8e4m3 *a_data, const uint8_t *a_scale, size_t a_scale_ld,
const test::fp8e4m3 *b_data, const uint8_t *b_scale, size_t b_scale_ld,
test::bf16 *d_data,
size_t M, size_t K, size_t N) {
const size_t i = blockIdx.y * blockDim.y + threadIdx.y;
const size_t j = blockIdx.x * blockDim.x + threadIdx.x;

if (i >= M || j >= N)
return;

float acc = 0.0f;

for (size_t kk = 0; kk < K; kk++) {
size_t kc = kk / 32;
float a_sinv = exp2f(static_cast<float>(a_scale[i * a_scale_ld + kc]) - 127.0f);
float b_sinv = exp2f(static_cast<float>(b_scale[j * b_scale_ld + kc]) - 127.0f);
float a_val = static_cast<float>(a_data[i * K + kk]);
float b_val = static_cast<float>(b_data[j * K + kk]);
acc += a_sinv * a_val * b_sinv * b_val;
}

d_data[i + j * M] = static_cast<test::bf16>(acc);
}

struct MxGemmParams {
size_t m, k, n;
};

class MxGemmSwizzleGfx1250TestSuite
: public ::testing::TestWithParam<MxGemmParams> {};

TEST_P(MxGemmSwizzleGfx1250TestSuite, TestMxfp8GemmE2E) {
Comment thread
alextmagro marked this conversation as resolved.
Outdated
using namespace transformer_engine;
using namespace test;

const auto &p = GetParam();
const size_t M = p.m;
const size_t K = p.k;
const size_t N = p.n;

cudaDeviceProp prop;
NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, 0));

// This test validates the MX scale pre-swizzle -> GEMM pipeline on gfx1250+.
// Non-swizzle MXFP8 GEMMs are already covered by GEMMTestSuite.
if (prop.major < 12) {
GTEST_SKIP() << "MX scale pre-swizzle GEMM requires gfx1250+";
}

// TN layout: A is [M, K], B is [N, K]
const bool transa = true;
const bool transb = false;

Tensor A("A", std::vector<size_t>{M, K}, DType::kFloat8E4M3, true, false, NVTE_MXFP8_1D_SCALING);
Tensor B("B", std::vector<size_t>{N, K}, DType::kFloat8E4M3, true, false, NVTE_MXFP8_1D_SCALING);
Tensor D("D", std::vector<size_t>{N, M}, DType::kBFloat16);
Tensor RefD("RefD", std::vector<size_t>{N, M}, DType::kBFloat16);
Tensor bias;
Tensor pre_gelu_out;

fillUniform(&A);
fillUniform(&B);

// Override scales with values in [120,127] so layout errors are detectable.
// Default random [0,127] produces mostly tiny scales (2^(-127)..2^0),
// making the test insensitive to permutation errors.
{
auto fill_discriminating_scales = [](void *scale_ptr, size_t count) {
std::vector<uint8_t> h(count);
std::mt19937 rng(42);
std::uniform_int_distribution<uint8_t> dist(120, 127);
for (size_t i = 0; i < count; i++)
h[i] = dist(rng);
NVTE_CHECK_CUDA(cudaMemcpy(scale_ptr, h.data(), count, cudaMemcpyHostToDevice));
};
auto a_sh = A.rowwise_scale_inv_shape();
auto b_sh = B.rowwise_scale_inv_shape();
fill_discriminating_scales(A.rowwise_scale_inv_dptr(), a_sh.data[0] * a_sh.data[1]);
fill_discriminating_scales(B.rowwise_scale_inv_dptr(), b_sh.data[0] * b_sh.data[1]);
}

// GPU reference with un-swizzled (compact) scales
const auto a_scale_shape = A.rowwise_scale_inv_shape();
const auto b_scale_shape = B.rowwise_scale_inv_shape();

std::cout << " A_scale shape: [" << a_scale_shape.data[0] << ", " << a_scale_shape.data[1]
<< "], B_scale shape: [" << b_scale_shape.data[0] << ", " << b_scale_shape.data[1]
<< "]" << std::endl;

{
dim3 block(16, 16);
dim3 grid((N + block.x - 1) / block.x, (M + block.y - 1) / block.y);
mxfp8_gemm_ref_kernel<<<grid, block>>>(
static_cast<const fp8e4m3 *>(A.rowwise_dptr()),
static_cast<const uint8_t *>(A.rowwise_scale_inv_dptr()),
a_scale_shape.data[1],
static_cast<const fp8e4m3 *>(B.rowwise_dptr()),
static_cast<const uint8_t *>(B.rowwise_scale_inv_dptr()),
b_scale_shape.data[1],
static_cast<bf16 *>(RefD.rowwise_dptr()),
M, K, N);
NVTE_CHECK_CUDA(cudaDeviceSynchronize());
}

// Swizzle scales to K-tiled layout for hipBLASlt BLK32_UE8M0_32_8_EXT on gfx1250.
// Layout: {M, K_scale}.reshape({M, K_scale/4, 4}).permute({1,0,2})
// dst(m,k) = (k/4)*M*4 + m*4 + (k%4)
swizzle_tensor_scales(A, true);
swizzle_tensor_scales(B, true);

// Run actual GEMM
size_t workspace_size = 134217728; // 128MB
Tensor Workspace("Workspace", std::vector<size_t>{workspace_size}, DType::kByte);

nvte_cublas_gemm(A.data(), B.data(), D.data(),
bias.data(), pre_gelu_out.data(),
transa, transb,
/*grad=*/false,
Workspace.data(),
/*accumulate=*/false,
/*use_split_accumulator=*/false,
prop.multiProcessorCount,
0);

NVTE_CHECK_CUDA(cudaDeviceSynchronize());

// Compare
D.to_cpu();
RefD.to_cpu();

// MXFP8 accumulation errors grow with K due to different reduction orders
// between hardware and reference kernels.
const double atol = 5e-2 + K * 2e-4;
const double rtol = 1.5e-2;
compareResults("D", D, RefD.rowwise_cpu_dptr<bf16>(), true, atol, rtol);
}

INSTANTIATE_TEST_SUITE_P(
OperatorTest,
MxGemmSwizzleGfx1250TestSuite,
::testing::Values(
MxGemmParams{32, 128, 16},
MxGemmParams{64, 128, 32},
MxGemmParams{128, 128, 64},
MxGemmParams{64, 256, 32},
MxGemmParams{128, 384, 64},
MxGemmParams{256, 512, 128},
MxGemmParams{512, 1024, 256},
MxGemmParams{1024, 2048, 128},
MxGemmParams{4096, 8192, 64}
),
[](const testing::TestParamInfo<MxGemmSwizzleGfx1250TestSuite::ParamType> &info) {
return "M" + std::to_string(info.param.m) +
"_K" + std::to_string(info.param.k) +
"_N" + std::to_string(info.param.n);
});

#endif // __HIP_PLATFORM_AMD__
Loading
Loading