Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions tests/cpp/operator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@ add_executable(test_operator
test_qdq.cu
test_cast_mxfp8.cu
test_cast_mxfp8_grouped.cu
test_cast_fp8_grouped.cu
test_cast_nvfp4_transpose.cu
test_cast_float8blockwise.cu
test_dequantize_mxfp8.cu
test_dequantize_mxfp8_grouped.cu
test_dequantize_fp8_grouped.cu
test_dequantize_nvfp4.cu
test_transpose.cu
test_cast_transpose.cu
Expand Down
164 changes: 164 additions & 0 deletions tests/cpp/operator/test_cast_fp8_grouped.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>

#include <transformer_engine/cast.h>
#include "../test_common.h"

using namespace transformer_engine;
using namespace test;

namespace {

// Quantizes a group of tensors with nvte_group_quantize and checks the
// quantized data, scale_inv and amax of every tensor against a CPU reference.
//
// Template parameters:
//   InputType  - host element type used to read the input data on the CPU.
//   OutputType - host element type used to read the quantized output on the CPU.
// Parameters:
//   shapes       - one shape per tensor in the group.
//   input_dtype  - dtype the input tensors are created with.
//   output_dtype - dtype the output tensors are created with.
template <typename InputType, typename OutputType>
void test_cast_fp8_grouped_impl(const std::vector<std::vector<size_t>>& shapes,
                                DType input_dtype, DType output_dtype) {
  const size_t num_tensors = shapes.size();

  // Create standard Tensor objects
  std::vector<Tensor> in_tensors;
  std::vector<Tensor> out_tensors;
  std::vector<Tensor*> in_tensor_ptrs;
  std::vector<Tensor*> out_tensor_ptrs;

  // Reserving up front keeps the element addresses stored in the pointer
  // vectors valid: emplace_back below never has to reallocate.
  in_tensors.reserve(num_tensors);
  out_tensors.reserve(num_tensors);
  in_tensor_ptrs.reserve(num_tensors);
  out_tensor_ptrs.reserve(num_tensors);

  for (size_t t = 0; t < num_tensors; ++t) {
    in_tensors.emplace_back("in_" + std::to_string(t), shapes[t], input_dtype);
    // NOTE(review): the two bools are presumably rowwise=true / columnwise=false;
    // confirm against the Tensor constructor in test_common.h.
    out_tensors.emplace_back("out_" + std::to_string(t), shapes[t], output_dtype,
                             true, false, NVTE_DELAYED_TENSOR_SCALING);

    // Initialize inputs with random uniform values
    fillUniform(&in_tensors[t]);
    in_tensors[t].from_cpu();

    // Give every tensor a distinct scale (1.5, 2.0, 2.5, ...) so that mixing up
    // per-tensor scales inside the group would be detected.
    float tensor_scale = 1.5f + static_cast<float>(t) * 0.5f;
    out_tensors[t].set_scale(tensor_scale);
    out_tensors[t].set_scale_inv(0.0f);  // Clear to ensure it's written
    out_tensors[t].set_amax(0.0f);       // Clear amax

    in_tensor_ptrs.push_back(&in_tensors[t]);
    out_tensor_ptrs.push_back(&out_tensors[t]);
  }

  // Build grouped tensors
  GroupedBuffers in_group = build_grouped_tensor(in_tensor_ptrs, NVTE_DELAYED_TENSOR_SCALING);
  GroupedBuffers out_group = build_grouped_tensor(out_tensor_ptrs, NVTE_DELAYED_TENSOR_SCALING);

  // CPU reference computation: scaled values are kept in float here and are
  // rounded to OutputType only at comparison time.
  std::vector<std::vector<float>> ref_outputs(num_tensors);
  std::vector<float> ref_amaxs(num_tensors, 0.0f);
  std::vector<float> ref_scale_invs(num_tensors, 0.0f);

  for (size_t t = 0; t < num_tensors; ++t) {
    size_t size = product(shapes[t]);
    ref_outputs[t].resize(size);
    float scale = out_tensors[t].scale();
    float cur_amax = 0.0f;

    InputType* in_cpu = in_tensors[t].rowwise_cpu_dptr<InputType>();
    for (size_t i = 0; i < size; ++i) {
      float val = static_cast<float>(in_cpu[i]);
      cur_amax = std::max(cur_amax, std::abs(val));
      float scaled_val = val * scale;
      ref_outputs[t][i] = scaled_val;
    }
    ref_amaxs[t] = cur_amax;
    ref_scale_invs[t] = 1.0f / scale;
  }

  // Run GPU grouped quantization
  QuantizationConfigWrapper quant_config;
  nvte_group_quantize(in_group.get_handle(), out_group.get_handle(), quant_config, 0);
  // Check the synchronization result so that an asynchronous failure is
  // reported here instead of surfacing at a later, unrelated CUDA call.
  NVTE_CHECK_CUDA(cudaDeviceSynchronize());

  // Copy results back from grouped buffer to individual output tensors
  for (size_t t = 0; t < num_tensors; ++t) {
    // 1. Copy output data. The offset is stored in elements; convert to bytes.
    size_t offset_bytes = (out_group.offsets_host[t] * typeToNumBits(out_group.dtype)) / 8;
    NVTE_CHECK_CUDA(cudaMemcpy(out_tensors[t].rowwise_dptr(),
                               static_cast<char*>(out_group.get_data()) + offset_bytes,
                               out_group.tensor_bytes[t],
                               cudaMemcpyDeviceToDevice));

    // 2. Copy scale_inv (one float per tensor in the grouped buffer)
    NVTEBasicTensor scale_inv_bt =
        nvte_get_tensor_param(out_tensors[t].data(), kNVTERowwiseScaleInv);
    if (scale_inv_bt.data_ptr != nullptr) {
      NVTE_CHECK_CUDA(cudaMemcpy(scale_inv_bt.data_ptr,
                                 static_cast<float*>(out_group.scale_inv.get()) + t,
                                 sizeof(float),
                                 cudaMemcpyDeviceToDevice));
    }

    // 3. Copy amax (one float per tensor in the grouped buffer)
    NVTEBasicTensor amax_bt = nvte_get_tensor_param(out_tensors[t].data(), kNVTEAmax);
    if (amax_bt.data_ptr != nullptr) {
      NVTE_CHECK_CUDA(cudaMemcpy(amax_bt.data_ptr,
                                 static_cast<float*>(out_group.amax_dev.get()) + t,
                                 sizeof(float),
                                 cudaMemcpyDeviceToDevice));
    }
  }

  // Validate results
  for (size_t t = 0; t < num_tensors; ++t) {
    size_t size = product(shapes[t]);
    out_tensors[t].to_cpu();

    // 1. Compare scale inverse
    float scale_inv_gpu = out_tensors[t].rowwise_scale_inv();
    EXPECT_NEAR(scale_inv_gpu, ref_scale_invs[t], 1e-4) << "scale_inv mismatch, tensor " << t;

    // 2. Compare amax
    float amax_gpu = out_tensors[t].amax();
    EXPECT_NEAR(amax_gpu, ref_amaxs[t], 1e-4) << "amax mismatch, tensor " << t;

    // 3. Compare outputs. The reference is rounded through OutputType so both
    // sides carry the same quantization error.
    OutputType* out_cpu = out_tensors[t].rowwise_cpu_dptr<OutputType>();
    for (size_t i = 0; i < size; ++i) {
      float gpu_val = static_cast<float>(out_cpu[i]);
      float ref_val = static_cast<float>(OutputType(ref_outputs[t][i]));
      EXPECT_NEAR(gpu_val, ref_val, 1e-4) << "output mismatch, tensor " << t << ", element " << i;
    }
  }
}

// Empty GoogleTest fixture; it only provides the suite name for the grouped FP8 cast tests below.
class CastFP8GroupedTestSuite : public ::testing::Test {};

// BF16 -> FP8 E4M3 with three tensors that all share the shape 32x64.
TEST_F(CastFP8GroupedTestSuite, BF16_to_E4M3_Uniform) {
  const std::vector<std::vector<size_t>> tensor_shapes = {{32, 64}, {32, 64}, {32, 64}};
  test_cast_fp8_grouped_impl<bf16, fp8e4m3>(tensor_shapes, DType::kBFloat16,
                                            DType::kFloat8E4M3);
}

// BF16 -> FP8 E4M3 with three tensors of different shapes.
TEST_F(CastFP8GroupedTestSuite, BF16_to_E4M3_Varying) {
  const std::vector<std::vector<size_t>> tensor_shapes = {{16, 32}, {64, 128}, {32, 64}};
  test_cast_fp8_grouped_impl<bf16, fp8e4m3>(tensor_shapes, DType::kBFloat16,
                                            DType::kFloat8E4M3);
}

// FP16 -> FP8 E4M3 with three tensors of different shapes.
TEST_F(CastFP8GroupedTestSuite, FP16_to_E4M3_Varying) {
  const std::vector<std::vector<size_t>> tensor_shapes = {{8, 16}, {128, 64}, {64, 32}};
  test_cast_fp8_grouped_impl<fp16, fp8e4m3>(tensor_shapes, DType::kFloat16,
                                            DType::kFloat8E4M3);
}

// FP32 -> FP8 E5M2 with three tensors of different shapes.
TEST_F(CastFP8GroupedTestSuite, FP32_to_E5M2_Varying) {
  const std::vector<std::vector<size_t>> tensor_shapes = {{32, 32}, {16, 64}, {128, 32}};
  test_cast_fp8_grouped_impl<float, fp8e5m2>(tensor_shapes, DType::kFloat32,
                                             DType::kFloat8E5M2);
}

} // namespace
131 changes: 131 additions & 0 deletions tests/cpp/operator/test_dequantize_fp8_grouped.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>

#include <transformer_engine/cast.h>
#include "../test_common.h"

using namespace transformer_engine;
using namespace test;

namespace {

template <typename InputType, typename OutputType>
void test_dequantize_fp8_grouped_impl(const std::vector<std::vector<size_t>>& shapes,
DType input_dtype, DType output_dtype) {
const size_t num_tensors = shapes.size();

// Create standard Tensor objects
std::vector<Tensor> in_tensors;
std::vector<Tensor> out_tensors;
std::vector<Tensor*> in_tensor_ptrs;
std::vector<Tensor*> out_tensor_ptrs;

in_tensors.reserve(num_tensors);
out_tensors.reserve(num_tensors);
in_tensor_ptrs.reserve(num_tensors);
out_tensor_ptrs.reserve(num_tensors);

for (size_t t = 0; t < num_tensors; ++t) {
// Input is FP8 (with scale_inv)
in_tensors.emplace_back("in_" + std::to_string(t), shapes[t], input_dtype,
true, false, NVTE_DELAYED_TENSOR_SCALING);
// Output is higher precision
out_tensors.emplace_back("out_" + std::to_string(t), shapes[t], output_dtype);

// Initialize inputs with random uniform FP8 values
fillUniform(&in_tensors[t]);
in_tensors[t].from_cpu();

// Set scale_inv
float random_scale_inv = 0.5f + static_cast<float>(t) * 0.25f;
in_tensors[t].set_scale_inv(random_scale_inv);

// Clear output
fillUniform(&out_tensors[t]); // Initialize to some random values
out_tensors[t].from_cpu();

in_tensor_ptrs.push_back(&in_tensors[t]);
out_tensor_ptrs.push_back(&out_tensors[t]);
}

// Build grouped tensors
GroupedBuffers in_group = build_grouped_tensor(in_tensor_ptrs, NVTE_DELAYED_TENSOR_SCALING);
GroupedBuffers out_group = build_grouped_tensor(out_tensor_ptrs, NVTE_DELAYED_TENSOR_SCALING);

// CPU reference computation
std::vector<std::vector<float>> ref_outputs(num_tensors);
for (size_t t = 0; t < num_tensors; ++t) {
size_t size = product(shapes[t]);
ref_outputs[t].resize(size);
float scale_inv = in_tensors[t].rowwise_scale_inv();

InputType* in_cpu = in_tensors[t].rowwise_cpu_dptr<InputType>();
for (size_t i = 0; i < size; ++i) {
float val = static_cast<float>(in_cpu[i]);
ref_outputs[t][i] = val * scale_inv;
}
}

// Run GPU grouped dequantization
nvte_group_dequantize(in_group.get_handle(), out_group.get_handle(), 0);
cudaDeviceSynchronize();

// Copy results back from grouped buffer to individual output tensors
for (size_t t = 0; t < num_tensors; ++t) {
size_t offset_bytes = (out_group.offsets_host[t] * typeToNumBits(out_group.dtype)) / 8;
NVTE_CHECK_CUDA(cudaMemcpy(out_tensors[t].rowwise_dptr(),
static_cast<char*>(out_group.get_data()) + offset_bytes,
out_group.tensor_bytes[t],
cudaMemcpyDeviceToDevice));
}

// Validate results
for (size_t t = 0; t < num_tensors; ++t) {
size_t size = product(shapes[t]);
out_tensors[t].to_cpu();

OutputType* out_cpu = out_tensors[t].rowwise_cpu_dptr<OutputType>();
for (size_t i = 0; i < size; ++i) {
float gpu_val = static_cast<float>(out_cpu[i]);
float ref_val = ref_outputs[t][i];
EXPECT_NEAR(gpu_val, ref_val, 1e-4);
}
}
}

// Empty GoogleTest fixture; it only provides the suite name for the grouped FP8 dequantize tests below.
class DequantizeFP8GroupedTestSuite : public ::testing::Test {};

// FP8 E4M3 -> BF16 with three tensors that all share the shape 32x64.
TEST_F(DequantizeFP8GroupedTestSuite, E4M3_to_BF16_Uniform) {
  const std::vector<std::vector<size_t>> tensor_shapes = {{32, 64}, {32, 64}, {32, 64}};
  test_dequantize_fp8_grouped_impl<fp8e4m3, bf16>(tensor_shapes, DType::kFloat8E4M3,
                                                  DType::kBFloat16);
}

// FP8 E4M3 -> BF16 with three tensors of different shapes.
TEST_F(DequantizeFP8GroupedTestSuite, E4M3_to_BF16_Varying) {
  const std::vector<std::vector<size_t>> tensor_shapes = {{16, 32}, {64, 128}, {32, 64}};
  test_dequantize_fp8_grouped_impl<fp8e4m3, bf16>(tensor_shapes, DType::kFloat8E4M3,
                                                  DType::kBFloat16);
}

// FP8 E4M3 -> FP16 with three tensors of different shapes.
TEST_F(DequantizeFP8GroupedTestSuite, E4M3_to_FP16_Varying) {
  const std::vector<std::vector<size_t>> tensor_shapes = {{8, 16}, {128, 64}, {64, 32}};
  test_dequantize_fp8_grouped_impl<fp8e4m3, fp16>(tensor_shapes, DType::kFloat8E4M3,
                                                  DType::kFloat16);
}

// FP8 E5M2 -> FP32 with three tensors of different shapes.
TEST_F(DequantizeFP8GroupedTestSuite, E5M2_to_FP32_Varying) {
  const std::vector<std::vector<size_t>> tensor_shapes = {{32, 32}, {16, 64}, {128, 32}};
  test_dequantize_fp8_grouped_impl<fp8e5m2, float>(tensor_shapes, DType::kFloat8E5M2,
                                                   DType::kFloat32);
}

} // namespace
55 changes: 40 additions & 15 deletions tests/cpp/test_common.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1108,19 +1108,14 @@ GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
[&](int64_t v) { return v == last_dims[0]; });

std::vector<int64_t> offsets(num_tensors + 1, 0);
std::mt19937 gen(12345);
auto random_padding = [&]() -> int64_t {
// Random padding ensuring 16-byte alignment regardless of element size
// cuBLAS requires aligned pointers for vectorized loads
static std::mt19937 gen(12345);
// We use a constant 64 elements alignment. This guarantees that all
// grouped tensors (input and output) in pointwise operations will
// have identical element offsets, preventing layout misalignment in tests.
std::uniform_int_distribution<int64_t> dist(0, 3);
// Calculate elements needed for 16-byte alignment
size_t align_elements;
if (is_sub_byte) {
// Sub-byte types (e.g. FP4): 16 bytes = 16*8/bits_per_elem elements
align_elements = (16 * 8) / bits_per_elem;
} else {
align_elements = std::max<size_t>(1, (16 + elem_size - 1) / elem_size);
}
size_t align_elements = 64;
return dist(gen) * static_cast<int64_t>(align_elements);
};

Expand Down Expand Up @@ -1263,23 +1258,53 @@ GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
// FP8 tensor scaling: one float scale_inv per tensor
// For delayed scaling, rowwise and columnwise share the same scale
std::vector<float> scale_inv_cpu(num_tensors, 1.f);
std::vector<float> scale_cpu(num_tensors, 1.f);
std::vector<float> amax_cpu(num_tensors, 0.f);
for (size_t i = 0; i < num_tensors; ++i) {
tensors[i]->to_cpu();
if (has_rowwise) {
scale_inv_cpu[i] = tensors[i]->rowwise_cpu_scale_inv_ptr<float>()[0];
} else {
scale_inv_cpu[i] = tensors[i]->columnwise_cpu_scale_inv_ptr<float>()[0];
}
// Gather scale
NVTEBasicTensor scale_bt = nvte_get_tensor_param(tensors[i]->data(), kNVTEScale);
if (scale_bt.data_ptr != nullptr) {
float val;
NVTE_CHECK_CUDA(cudaMemcpy(&val, scale_bt.data_ptr, sizeof(float), cudaMemcpyDeviceToHost));
scale_cpu[i] = val;
}
// Gather amax
NVTEBasicTensor amax_bt = nvte_get_tensor_param(tensors[i]->data(), kNVTEAmax);
if (amax_bt.data_ptr != nullptr) {
float val;
NVTE_CHECK_CUDA(cudaMemcpy(&val, amax_bt.data_ptr, sizeof(float), cudaMemcpyDeviceToHost));
amax_cpu[i] = val;
}
}
grouped.scale_inv = cuda_alloc(sizeof(float) * num_tensors);
NVTE_CHECK_CUDA(cudaMemcpy(grouped.scale_inv.get(), scale_inv_cpu.data(),
sizeof(float) * num_tensors, cudaMemcpyHostToDevice));
NVTEShape scale_shape = nvte_make_shape(&num_tensors, 1);
NVTEBasicTensor scale_tensor{grouped.scale_inv.get(), kNVTEFloat32, scale_shape};
nvte_set_grouped_tensor_param(h, kNVTEGroupedRowwiseScaleInv, &scale_tensor,
sizeof(scale_tensor));
nvte_set_grouped_tensor_param(h, kNVTEGroupedColumnwiseScaleInv, &scale_tensor,
sizeof(scale_tensor));
NVTEBasicTensor scale_inv_tensor{grouped.scale_inv.get(), kNVTEFloat32, scale_shape};
nvte_set_grouped_tensor_param(h, kNVTEGroupedRowwiseScaleInv, &scale_inv_tensor,
sizeof(scale_inv_tensor));
nvte_set_grouped_tensor_param(h, kNVTEGroupedColumnwiseScaleInv, &scale_inv_tensor,
sizeof(scale_inv_tensor));

// Set scale on the grouped tensor
grouped.scale = cuda_alloc(sizeof(float) * num_tensors);
NVTE_CHECK_CUDA(cudaMemcpy(grouped.scale.get(), scale_cpu.data(),
sizeof(float) * num_tensors, cudaMemcpyHostToDevice));
NVTEBasicTensor scale_tensor{grouped.scale.get(), kNVTEFloat32, scale_shape};
nvte_set_grouped_tensor_param(h, kNVTEGroupedScale, &scale_tensor, sizeof(scale_tensor));

// Set amax on the grouped tensor
grouped.amax_dev = cuda_alloc(sizeof(float) * num_tensors);
NVTE_CHECK_CUDA(cudaMemcpy(grouped.amax_dev.get(), amax_cpu.data(),
sizeof(float) * num_tensors, cudaMemcpyHostToDevice));
NVTEBasicTensor amax_tensor{grouped.amax_dev.get(), kNVTEFloat32, scale_shape};
nvte_set_grouped_tensor_param(h, kNVTEGroupedAmax, &amax_tensor, sizeof(amax_tensor));
} else if (scaling_mode == NVTE_MXFP8_1D_SCALING) {
// MXFP8: E8M0 scale_inv per block of 32 elements (1 byte per scale element).
if (has_rowwise) {
Expand Down
Loading