diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 9b67c09f34..343653b43d 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -12,10 +12,12 @@ add_executable(test_operator test_qdq.cu test_cast_mxfp8.cu test_cast_mxfp8_grouped.cu + test_cast_fp8_grouped.cu test_cast_nvfp4_transpose.cu test_cast_float8blockwise.cu test_dequantize_mxfp8.cu test_dequantize_mxfp8_grouped.cu + test_dequantize_fp8_grouped.cu test_dequantize_nvfp4.cu test_transpose.cu test_cast_transpose.cu diff --git a/tests/cpp/operator/test_cast_fp8_grouped.cu b/tests/cpp/operator/test_cast_fp8_grouped.cu new file mode 100644 index 0000000000..63cfb3bd1a --- /dev/null +++ b/tests/cpp/operator/test_cast_fp8_grouped.cu @@ -0,0 +1,164 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include + +#include +#include "../test_common.h" + +using namespace transformer_engine; +using namespace test; + +namespace { + +template +void test_cast_fp8_grouped_impl(const std::vector>& shapes, + DType input_dtype, DType output_dtype) { + const size_t num_tensors = shapes.size(); + + // Create standard Tensor objects + std::vector in_tensors; + std::vector out_tensors; + std::vector in_tensor_ptrs; + std::vector out_tensor_ptrs; + + in_tensors.reserve(num_tensors); + out_tensors.reserve(num_tensors); + in_tensor_ptrs.reserve(num_tensors); + out_tensor_ptrs.reserve(num_tensors); + + for (size_t t = 0; t < num_tensors; ++t) { + in_tensors.emplace_back("in_" + std::to_string(t), shapes[t], input_dtype); + out_tensors.emplace_back("out_" + std::to_string(t), shapes[t], output_dtype, + true, false, NVTE_DELAYED_TENSOR_SCALING); + + // Initialize inputs with random uniform values + fillUniform(&in_tensors[t]); + in_tensors[t].from_cpu(); + + // Initialize scales with random scaling factors + float random_scale = 1.5f + static_cast(t) * 0.5f; + out_tensors[t].set_scale(random_scale); + out_tensors[t].set_scale_inv(0.0f); // Clear to ensure it's written + out_tensors[t].set_amax(0.0f); // Clear amax + + in_tensor_ptrs.push_back(&in_tensors[t]); + out_tensor_ptrs.push_back(&out_tensors[t]); + } + + // Build grouped tensors + GroupedBuffers in_group = build_grouped_tensor(in_tensor_ptrs, NVTE_DELAYED_TENSOR_SCALING); + GroupedBuffers out_group = build_grouped_tensor(out_tensor_ptrs, NVTE_DELAYED_TENSOR_SCALING); + + // CPU reference computation + std::vector> ref_outputs(num_tensors); + std::vector ref_amaxs(num_tensors, 0.0f); + std::vector ref_scale_invs(num_tensors, 0.0f); + + for (size_t t = 0; t < num_tensors; ++t) { + size_t size = product(shapes[t]); + ref_outputs[t].resize(size); + float scale = out_tensors[t].scale(); + float cur_amax = 0.0f; + + InputType* in_cpu = in_tensors[t].rowwise_cpu_dptr(); + for (size_t i = 0; i < size; ++i) { + float val = static_cast(in_cpu[i]); + cur_amax = std::max(cur_amax, std::abs(val)); + float scaled_val = val * scale; + ref_outputs[t][i] = scaled_val; + } + ref_amaxs[t] = cur_amax; + ref_scale_invs[t] = 1.0f / scale; + } + + // Run GPU grouped quantization + QuantizationConfigWrapper quant_config; + nvte_group_quantize(in_group.get_handle(), out_group.get_handle(), quant_config, 0); + cudaDeviceSynchronize(); + + // Copy results back from grouped buffer to individual output tensors + for (size_t t = 0; t < num_tensors; ++t) { + // 1. Copy output data + size_t offset_bytes = (out_group.offsets_host[t] * typeToNumBits(out_group.dtype)) / 8; + NVTE_CHECK_CUDA(cudaMemcpy(out_tensors[t].rowwise_dptr(), + static_cast(out_group.get_data()) + offset_bytes, + out_group.tensor_bytes[t], + cudaMemcpyDeviceToDevice)); + + // 2. Copy scale_inv + NVTEBasicTensor scale_inv_bt = nvte_get_tensor_param(out_tensors[t].data(), kNVTERowwiseScaleInv); + if (scale_inv_bt.data_ptr != nullptr) { + NVTE_CHECK_CUDA(cudaMemcpy(scale_inv_bt.data_ptr, + static_cast(out_group.scale_inv.get()) + t, + sizeof(float), + cudaMemcpyDeviceToDevice)); + } + + // 3. Copy amax + NVTEBasicTensor amax_bt = nvte_get_tensor_param(out_tensors[t].data(), kNVTEAmax); + if (amax_bt.data_ptr != nullptr) { + NVTE_CHECK_CUDA(cudaMemcpy(amax_bt.data_ptr, + static_cast(out_group.amax_dev.get()) + t, + sizeof(float), + cudaMemcpyDeviceToDevice)); + } + } + + // Validate results + for (size_t t = 0; t < num_tensors; ++t) { + size_t size = product(shapes[t]); + out_tensors[t].to_cpu(); + + // 1. Compare scale inverse + float scale_inv_gpu = out_tensors[t].rowwise_scale_inv(); + EXPECT_NEAR(scale_inv_gpu, ref_scale_invs[t], 1e-4); + + // 2. Compare amax + float amax_gpu = out_tensors[t].amax(); + EXPECT_NEAR(amax_gpu, ref_amaxs[t], 1e-4); + + // 3. Compare outputs + OutputType* out_cpu = out_tensors[t].rowwise_cpu_dptr(); + for (size_t i = 0; i < size; ++i) { + float gpu_val = static_cast(out_cpu[i]); + float ref_val = static_cast(OutputType(ref_outputs[t][i])); + EXPECT_NEAR(gpu_val, ref_val, 1e-4); + } + } +} + +class CastFP8GroupedTestSuite : public ::testing::Test {}; + +TEST_F(CastFP8GroupedTestSuite, BF16_to_E4M3_Uniform) { + test_cast_fp8_grouped_impl( + {{32, 64}, {32, 64}, {32, 64}}, + DType::kBFloat16, DType::kFloat8E4M3); +} + +TEST_F(CastFP8GroupedTestSuite, BF16_to_E4M3_Varying) { + test_cast_fp8_grouped_impl( + {{16, 32}, {64, 128}, {32, 64}}, + DType::kBFloat16, DType::kFloat8E4M3); +} + +TEST_F(CastFP8GroupedTestSuite, FP16_to_E4M3_Varying) { + test_cast_fp8_grouped_impl( + {{8, 16}, {128, 64}, {64, 32}}, + DType::kFloat16, DType::kFloat8E4M3); +} + +TEST_F(CastFP8GroupedTestSuite, FP32_to_E5M2_Varying) { + test_cast_fp8_grouped_impl( + {{32, 32}, {16, 64}, {128, 32}}, + DType::kFloat32, DType::kFloat8E5M2); +} + +} // namespace diff --git a/tests/cpp/operator/test_dequantize_fp8_grouped.cu b/tests/cpp/operator/test_dequantize_fp8_grouped.cu new file mode 100644 index 0000000000..378a345abb --- /dev/null +++ b/tests/cpp/operator/test_dequantize_fp8_grouped.cu @@ -0,0 +1,131 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include + +#include +#include "../test_common.h" + +using namespace transformer_engine; +using namespace test; + +namespace { + +template +void test_dequantize_fp8_grouped_impl(const std::vector>& shapes, + DType input_dtype, DType output_dtype) { + const size_t num_tensors = shapes.size(); + + // Create standard Tensor objects + std::vector in_tensors; + std::vector out_tensors; + std::vector in_tensor_ptrs; + std::vector out_tensor_ptrs; + + in_tensors.reserve(num_tensors); + out_tensors.reserve(num_tensors); + in_tensor_ptrs.reserve(num_tensors); + out_tensor_ptrs.reserve(num_tensors); + + for (size_t t = 0; t < num_tensors; ++t) { + // Input is FP8 (with scale_inv) + in_tensors.emplace_back("in_" + std::to_string(t), shapes[t], input_dtype, + true, false, NVTE_DELAYED_TENSOR_SCALING); + // Output is higher precision + out_tensors.emplace_back("out_" + std::to_string(t), shapes[t], output_dtype); + + // Initialize inputs with random uniform FP8 values + fillUniform(&in_tensors[t]); + in_tensors[t].from_cpu(); + + // Set scale_inv + float random_scale_inv = 0.5f + static_cast(t) * 0.25f; + in_tensors[t].set_scale_inv(random_scale_inv); + + // Clear output + fillUniform(&out_tensors[t]); // Initialize to some random values + out_tensors[t].from_cpu(); + + in_tensor_ptrs.push_back(&in_tensors[t]); + out_tensor_ptrs.push_back(&out_tensors[t]); + } + + // Build grouped tensors + GroupedBuffers in_group = build_grouped_tensor(in_tensor_ptrs, NVTE_DELAYED_TENSOR_SCALING); + GroupedBuffers out_group = build_grouped_tensor(out_tensor_ptrs, NVTE_DELAYED_TENSOR_SCALING); + + // CPU reference computation + std::vector> ref_outputs(num_tensors); + for (size_t t = 0; t < num_tensors; ++t) { + size_t size = product(shapes[t]); + ref_outputs[t].resize(size); + float scale_inv = in_tensors[t].rowwise_scale_inv(); + + InputType* in_cpu = in_tensors[t].rowwise_cpu_dptr(); + for (size_t i = 0; i < size; ++i) { + float val = static_cast(in_cpu[i]); + ref_outputs[t][i] = val * scale_inv; + } + } + + // Run GPU grouped dequantization + nvte_group_dequantize(in_group.get_handle(), out_group.get_handle(), 0); + cudaDeviceSynchronize(); + + // Copy results back from grouped buffer to individual output tensors + for (size_t t = 0; t < num_tensors; ++t) { + size_t offset_bytes = (out_group.offsets_host[t] * typeToNumBits(out_group.dtype)) / 8; + NVTE_CHECK_CUDA(cudaMemcpy(out_tensors[t].rowwise_dptr(), + static_cast(out_group.get_data()) + offset_bytes, + out_group.tensor_bytes[t], + cudaMemcpyDeviceToDevice)); + } + + // Validate results + for (size_t t = 0; t < num_tensors; ++t) { + size_t size = product(shapes[t]); + out_tensors[t].to_cpu(); + + OutputType* out_cpu = out_tensors[t].rowwise_cpu_dptr(); + for (size_t i = 0; i < size; ++i) { + float gpu_val = static_cast(out_cpu[i]); + float ref_val = ref_outputs[t][i]; + EXPECT_NEAR(gpu_val, ref_val, 1e-4); + } + } +} + +class DequantizeFP8GroupedTestSuite : public ::testing::Test {}; + +TEST_F(DequantizeFP8GroupedTestSuite, E4M3_to_BF16_Uniform) { + test_dequantize_fp8_grouped_impl( + {{32, 64}, {32, 64}, {32, 64}}, + DType::kFloat8E4M3, DType::kBFloat16); +} + +TEST_F(DequantizeFP8GroupedTestSuite, E4M3_to_BF16_Varying) { + test_dequantize_fp8_grouped_impl( + {{16, 32}, {64, 128}, {32, 64}}, + DType::kFloat8E4M3, DType::kBFloat16); +} + +TEST_F(DequantizeFP8GroupedTestSuite, E4M3_to_FP16_Varying) { + test_dequantize_fp8_grouped_impl( + {{8, 16}, {128, 64}, {64, 32}}, + DType::kFloat8E4M3, DType::kFloat16); +} + +TEST_F(DequantizeFP8GroupedTestSuite, E5M2_to_FP32_Varying) { + test_dequantize_fp8_grouped_impl( + {{32, 32}, {16, 64}, {128, 32}}, + DType::kFloat8E5M2, DType::kFloat32); +} + +} // namespace diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index fc41d44720..c56a156680 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -1108,19 +1108,14 @@ GroupedBuffers build_grouped_tensor(const std::vector& tensors, [&](int64_t v) { return v == last_dims[0]; }); std::vector offsets(num_tensors + 1, 0); + std::mt19937 gen(12345); auto random_padding = [&]() -> int64_t { // Random padding ensuring 16-byte alignment regardless of element size - // cuBLAS requires aligned pointers for vectorized loads - static std::mt19937 gen(12345); + // We use a constant 64 elements alignment. This guarantees that all + // grouped tensors (input and output) in pointwise operations will + // have identical element offsets, preventing layout misalignment in tests. std::uniform_int_distribution dist(0, 3); - // Calculate elements needed for 16-byte alignment - size_t align_elements; - if (is_sub_byte) { - // Sub-byte types (e.g. FP4): 16 bytes = 16*8/bits_per_elem elements - align_elements = (16 * 8) / bits_per_elem; - } else { - align_elements = std::max(1, (16 + elem_size - 1) / elem_size); - } + size_t align_elements = 64; return dist(gen) * static_cast(align_elements); }; @@ -1263,6 +1258,8 @@ GroupedBuffers build_grouped_tensor(const std::vector& tensors, // FP8 tensor scaling: one float scale_inv per tensor // For delayed scaling, rowwise and columnwise share the same scale std::vector scale_inv_cpu(num_tensors, 1.f); + std::vector scale_cpu(num_tensors, 1.f); + std::vector amax_cpu(num_tensors, 0.f); for (size_t i = 0; i < num_tensors; ++i) { tensors[i]->to_cpu(); if (has_rowwise) { @@ -1270,16 +1267,44 @@ GroupedBuffers build_grouped_tensor(const std::vector& tensors, } else { scale_inv_cpu[i] = tensors[i]->columnwise_cpu_scale_inv_ptr()[0]; } + // Gather scale + NVTEBasicTensor scale_bt = nvte_get_tensor_param(tensors[i]->data(), kNVTEScale); + if (scale_bt.data_ptr != nullptr) { + float val; + NVTE_CHECK_CUDA(cudaMemcpy(&val, scale_bt.data_ptr, sizeof(float), cudaMemcpyDeviceToHost)); + scale_cpu[i] = val; + } + // Gather amax + NVTEBasicTensor amax_bt = nvte_get_tensor_param(tensors[i]->data(), kNVTEAmax); + if (amax_bt.data_ptr != nullptr) { + float val; + NVTE_CHECK_CUDA(cudaMemcpy(&val, amax_bt.data_ptr, sizeof(float), cudaMemcpyDeviceToHost)); + amax_cpu[i] = val; + } } grouped.scale_inv = cuda_alloc(sizeof(float) * num_tensors); NVTE_CHECK_CUDA(cudaMemcpy(grouped.scale_inv.get(), scale_inv_cpu.data(), sizeof(float) * num_tensors, cudaMemcpyHostToDevice)); NVTEShape scale_shape = nvte_make_shape(&num_tensors, 1); - NVTEBasicTensor scale_tensor{grouped.scale_inv.get(), kNVTEFloat32, scale_shape}; - nvte_set_grouped_tensor_param(h, kNVTEGroupedRowwiseScaleInv, &scale_tensor, - sizeof(scale_tensor)); - nvte_set_grouped_tensor_param(h, kNVTEGroupedColumnwiseScaleInv, &scale_tensor, - sizeof(scale_tensor)); + NVTEBasicTensor scale_inv_tensor{grouped.scale_inv.get(), kNVTEFloat32, scale_shape}; + nvte_set_grouped_tensor_param(h, kNVTEGroupedRowwiseScaleInv, &scale_inv_tensor, + sizeof(scale_inv_tensor)); + nvte_set_grouped_tensor_param(h, kNVTEGroupedColumnwiseScaleInv, &scale_inv_tensor, + sizeof(scale_inv_tensor)); + + // Set scale on the grouped tensor + grouped.scale = cuda_alloc(sizeof(float) * num_tensors); + NVTE_CHECK_CUDA(cudaMemcpy(grouped.scale.get(), scale_cpu.data(), + sizeof(float) * num_tensors, cudaMemcpyHostToDevice)); + NVTEBasicTensor scale_tensor{grouped.scale.get(), kNVTEFloat32, scale_shape}; + nvte_set_grouped_tensor_param(h, kNVTEGroupedScale, &scale_tensor, sizeof(scale_tensor)); + + // Set amax on the grouped tensor + grouped.amax_dev = cuda_alloc(sizeof(float) * num_tensors); + NVTE_CHECK_CUDA(cudaMemcpy(grouped.amax_dev.get(), amax_cpu.data(), + sizeof(float) * num_tensors, cudaMemcpyHostToDevice)); + NVTEBasicTensor amax_tensor{grouped.amax_dev.get(), kNVTEFloat32, scale_shape}; + nvte_set_grouped_tensor_param(h, kNVTEGroupedAmax, &amax_tensor, sizeof(amax_tensor)); } else if (scaling_mode == NVTE_MXFP8_1D_SCALING) { // MXFP8: E8M0 scale_inv per block of 32 elements (1 byte per scale element). if (has_rowwise) { diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 11d96c2e60..d01cba594b 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -594,6 +594,7 @@ struct GroupedBuffers { CudaPtr<> data; CudaPtr<> scale_inv; CudaPtr<> columnwise_scale_inv; + CudaPtr<> scale; CudaPtr first_dims_dev; CudaPtr last_dims_dev; CudaPtr offsets_dev; diff --git a/transformer_engine/common/cast/dispatch/dequantize.cuh b/transformer_engine/common/cast/dispatch/dequantize.cuh index 63c1b046ff..ebfef5b485 100644 --- a/transformer_engine/common/cast/dispatch/dequantize.cuh +++ b/transformer_engine/common/cast/dispatch/dequantize.cuh @@ -61,6 +61,10 @@ inline void group_dequantize_helper(const GroupedTensor &input, GroupedTensor *o CheckOutputGroupedTensor(*output, "group_dequantize_output"); switch (input.scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + fp8::group_dequantize(input, output, stream); + break; + } case NVTE_MXFP8_1D_SCALING: { if (is_supported_by_CC_100()) { mxfp8::group_dequantize(&input, output, stream); diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index bad53a03c6..009a5d0d20 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -450,6 +450,10 @@ void group_quantize_fwd_helper(const NVTEGroupedTensor input, NVTEGroupedTensor // Dispatch to quantization kernel depending on data format switch (scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + fp8::group_quantize(*input_tensor, noop_tensor, output_tensor, stream); + break; + } case NVTE_MXFP8_1D_SCALING: { mxfp8::group_quantize( input_tensor, activations_tensor, noop_tensor, output_tensor, dbias_tensor, @@ -491,6 +495,12 @@ void group_quantize_bwd_helper(const NVTEGroupedTensor grad, const NVTEGroupedTe // Dispatch to quantization kernel depending on data format switch (scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + fp8::group_quantize( + *grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, + stream); + break; + } case NVTE_MXFP8_1D_SCALING: { mxfp8::group_quantize( grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, diff --git a/transformer_engine/common/cast/fp8/dequantize_fp8.cuh b/transformer_engine/common/cast/fp8/dequantize_fp8.cuh index 6a0eaf94fb..b637f792bc 100644 --- a/transformer_engine/common/cast/fp8/dequantize_fp8.cuh +++ b/transformer_engine/common/cast/fp8/dequantize_fp8.cuh @@ -47,6 +47,32 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) stream);); // NOLINT(*) ); // NOLINT(*) } +struct GroupDequantizeParam {}; + +__device__ inline float group_dequantize_func(float value, const GroupDequantizeParam &) { + return value; +} + +inline void group_dequantize(const GroupedTensor &input, GroupedTensor *output, + cudaStream_t stream) { + const size_t N = product(input.data.shape); + const size_t scale_inv_numel = product(input.scale_inv.shape); + + const int64_t *const offsets = reinterpret_cast(input.tensor_offsets.dptr); + const int64_t *const first_dims = reinterpret_cast(input.first_dims.dptr); + const int64_t *const last_dims = reinterpret_cast(input.last_dims.dptr); + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + output->dtype(), OType, constexpr int nvec = 32 / sizeof(OType); GroupDequantizeParam p; + VectorizedUnaryKernelLauncher( + reinterpret_cast(input.data.dptr), nullptr, + reinterpret_cast(output->data.dptr), nullptr, nullptr, + const_cast(reinterpret_cast(input.scale_inv.dptr)), N, p, + stream, offsets, first_dims, last_dims, input.num_tensors, 1, scale_inv_numel, 1););); +} + } // namespace fp8 } // namespace dispatch } // namespace transformer_engine diff --git a/transformer_engine/common/cast/fp8/quantize_fp8.cuh b/transformer_engine/common/cast/fp8/quantize_fp8.cuh index bad10c954e..aa76dc0167 100644 --- a/transformer_engine/common/cast/fp8/quantize_fp8.cuh +++ b/transformer_engine/common/cast/fp8/quantize_fp8.cuh @@ -572,6 +572,65 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, } } +template +void group_quantize(const GroupedTensor &input, const Tensor *noop, GroupedTensor *output, + cudaStream_t stream) { + constexpr float (*UnaryOP)(float, const ParamOP &) = (OP == nullptr) ? detail::identity : OP; + const size_t N = product(input.data.shape); + const size_t scale_numel = product(output->scale.shape); + const size_t scale_inv_numel = product(output->scale_inv.shape); + const size_t amax_numel = product(output->amax.shape); + + const int64_t *const offsets = reinterpret_cast(input.tensor_offsets.dptr); + const int64_t *const first_dims = reinterpret_cast(input.first_dims.dptr); + const int64_t *const last_dims = reinterpret_cast(input.last_dims.dptr); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + output->dtype(), OType, constexpr int nvec = 32 / sizeof(IType); + VectorizedUnaryKernelLauncher( + reinterpret_cast(input.data.dptr), + reinterpret_cast(noop->data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), N, {}, stream, offsets, first_dims, + last_dims, input.num_tensors, scale_numel, scale_inv_numel, amax_numel););); +} + +template +void group_quantize(const GroupedTensor &grad, const GroupedTensor *input, const Tensor *noop, + GroupedTensor *output, GroupedTensor *dbias, Tensor *workspace, + cudaStream_t stream) { + NVTE_CHECK(!IS_DBIAS && !IS_DACT, + "Gated or DBias fusions are not supported in FP8 Grouped Quantization."); + + constexpr float (*UnaryOP)(float, const ParamOP &) = (OP == nullptr) ? detail::identity : OP; + const size_t N = product(input->data.shape); + const size_t scale_numel = product(output->scale.shape); + const size_t scale_inv_numel = product(output->scale_inv.shape); + const size_t amax_numel = product(output->amax.shape); + + const int64_t *const offsets = reinterpret_cast(input->tensor_offsets.dptr); + const int64_t *const first_dims = reinterpret_cast(input->first_dims.dptr); + const int64_t *const last_dims = reinterpret_cast(input->last_dims.dptr); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input->dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + output->dtype(), OType, constexpr int nvec = 32 / sizeof(IType); + VectorizedUnaryGradKernelLauncher( + reinterpret_cast(grad.data.dptr), + reinterpret_cast(input->data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), N, {}, stream, offsets, first_dims, + last_dims, input->num_tensors, scale_numel, scale_inv_numel, amax_numel););); +} + } // namespace fp8 } // namespace dispatch } // namespace transformer_engine diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index b3179d38fd..73a3cfbe8d 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -317,13 +317,13 @@ void CheckGroupedTensorShapeArrays(const GroupedTensor &t, std::string_view name // Validate data size matches logical_shape size_t expected_numel = t.logical_shape.data[0] * t.logical_shape.data[1]; if (t.has_data()) { - NVTE_CHECK(t.data.numel() == expected_numel, "Grouped tensor ", name, " data size (", - t.data.numel(), ") must match logical_shape size (", expected_numel, ")"); + NVTE_CHECK(t.data.numel() >= expected_numel, "Grouped tensor ", name, " data size (", + t.data.numel(), ") must be at least logical_shape size (", expected_numel, ")"); } if (t.has_columnwise_data()) { - NVTE_CHECK(t.columnwise_data.numel() == expected_numel, "Grouped tensor ", name, + NVTE_CHECK(t.columnwise_data.numel() >= expected_numel, "Grouped tensor ", name, " columnwise_data size (", t.columnwise_data.numel(), - ") must match logical_shape size (", expected_numel, ")"); + ") must be at least logical_shape size (", expected_numel, ")"); } } diff --git a/transformer_engine/common/util/vectorized_pointwise.h b/transformer_engine/common/util/vectorized_pointwise.h index 7707c68a08..fc2ef8d5b9 100644 --- a/transformer_engine/common/util/vectorized_pointwise.h +++ b/transformer_engine/common/util/vectorized_pointwise.h @@ -171,57 +171,121 @@ class VectorizedStorer : public VectorizedAccessor { constexpr int unary_kernel_threads = 512; +__device__ __forceinline__ size_t find_tensor_id(const int64_t *const offsets, + const size_t num_tensors, + const size_t global_offset) { + size_t low = 0; + size_t high = num_tensors; + while (low < high) { + size_t mid = low + (high - low) / 2; + if (static_cast(offsets[mid]) <= global_offset) { + low = mid + 1; + } else { + high = mid; + } + } + return low - 1; +} + template __launch_bounds__(unary_kernel_threads) __global__ void unary_kernel(const InputType *input, const ComputeType *noop, OutputType *output, const ComputeType *scale, ComputeType *amax, ComputeType *scale_inv, Param p, - const size_t N, const size_t num_aligned_elements) { + const size_t N, const size_t num_aligned_elements, + const int64_t *offsets = nullptr, const int64_t *first_dims = nullptr, + const int64_t *last_dims = nullptr, size_t num_tensors = 1, + size_t scale_numel = 1, size_t scale_inv_numel = 1, size_t amax_numel = 1) { if (noop != nullptr && noop[0] == 1.0f) return; - VectorizedLoader loader(input, N); - VectorizedStorer storer(output, N); + size_t start = 0; + size_t size = N; + size_t tensor_id = 0; + if (gridDim.y > 1) { + tensor_id = blockIdx.y; + if (offsets != nullptr) { + start = offsets[tensor_id]; + size = first_dims[tensor_id] * last_dims[tensor_id]; + } else if (num_tensors > 1) { + size = N / num_tensors; + start = tensor_id * size; + } + } + + VectorizedLoader loader(input + start, size); + VectorizedStorer storer(output + start, size); ComputeType max = 0; ComputeType s = 1; const bool requires_amax = (amax != nullptr); if constexpr (is_fp8::value) { - if (scale != nullptr) s = *scale; + if (scale != nullptr && offsets == nullptr && gridDim.y == 1) s = *scale; } const int warp_id = threadIdx.x / THREADS_PER_WARP; - const size_t M = num_aligned_elements; + const size_t M = loader.num_aligned_elements(); for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) { - loader.load(tid, N); + loader.load(tid, size); #pragma unroll for (int i = 0; i < nvec; ++i) { - const ComputeType val = static_cast(loader.separate()[i]); + const size_t global_idx = + (aligned ? (tid * nvec + i) : (tid * nvec + i - loader.alignment())); + if (global_idx >= size) continue; + + ComputeType val = static_cast(loader.separate()[i]); + if constexpr (is_fp8::value) { + if (scale_inv != nullptr) { + val = val * ((scale_inv_numel == num_tensors) ? scale_inv[tensor_id] : scale_inv[0]); + } + } ComputeType temp = OP(val, p); if (requires_amax) { __builtin_assume(max >= 0); max = fmaxf(fabsf(temp), max); } if constexpr (is_fp8::value) { - temp = temp * s; + float current_scale = 1.0f; + if (scale != nullptr) { + current_scale = (scale_numel == num_tensors) ? scale[tensor_id] : scale[0]; + } else if (gridDim.y == 1) { + current_scale = s; + } + temp = temp * current_scale; } storer.separate()[i] = static_cast(temp); + + // Update scale-inverse for quantization if requested + const bool is_start = (global_idx == 0); + if (is_start && scale_inv != nullptr && !is_fp8::value) { + float current_scale = 1.0f; + if (scale != nullptr) { + current_scale = (scale_numel == num_tensors) ? scale[tensor_id] : scale[0]; + } else if (gridDim.y == 1) { + current_scale = s; + } + size_t scale_inv_idx = (scale_inv_numel == num_tensors) ? tensor_id : 0; + reciprocal(&scale_inv[scale_inv_idx], current_scale); + } } - storer.store(tid, N); + storer.store(tid, size); } // Reduce amax over block if (requires_amax) { max = reduce_max(max, warp_id); - if (threadIdx.x == 0) { + if (threadIdx.x == 0 && max > 0.0f) { + size_t amax_idx = (amax_numel == num_tensors) ? tensor_id : 0; static_assert(std::is_same::value); - atomicMaxFloat(amax, max); + atomicMaxFloat(&amax[amax_idx], max); } } - if constexpr (is_fp8::value) { - // Update scale-inverse - if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) { - reciprocal(scale_inv, s); + if (offsets == nullptr && num_tensors == 1) { + if constexpr (is_fp8::value) { + // Update scale-inverse for single-tensor path + if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) { + reciprocal(scale_inv, s); + } } } } @@ -232,54 +296,102 @@ template loader(input, N); - VectorizedLoader grad_loader(grad, N); - VectorizedStorer storer(output, N); + Param p, const size_t N, const size_t num_aligned_elements, + const int64_t *offsets = nullptr, const int64_t *first_dims = nullptr, + const int64_t *last_dims = nullptr, size_t num_tensors = 1, + size_t scale_numel = 1, size_t scale_inv_numel = 1, + size_t amax_numel = 1) { + size_t start = 0; + size_t size = N; + size_t tensor_id = 0; + if (gridDim.y > 1) { + tensor_id = blockIdx.y; + if (offsets != nullptr) { + start = offsets[tensor_id]; + size = first_dims[tensor_id] * last_dims[tensor_id]; + } else if (num_tensors > 1) { + size = N / num_tensors; + start = tensor_id * size; + } + } + + VectorizedLoader loader(input + start, size); + VectorizedLoader grad_loader(grad + start, size); + VectorizedStorer storer(output + start, size); ComputeType max = 0; ComputeType s = 1; const bool requires_amax = (amax != nullptr); if constexpr (is_fp8::value) { - if (scale != nullptr) s = *scale; + if (scale != nullptr && offsets == nullptr && gridDim.y == 1) s = *scale; } const int warp_id = threadIdx.x / THREADS_PER_WARP; - const size_t M = num_aligned_elements; + const size_t M = loader.num_aligned_elements(); for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) { - loader.load(tid, N); - grad_loader.load(tid, N); + loader.load(tid, size); + grad_loader.load(tid, size); #pragma unroll for (int i = 0; i < nvec; ++i) { - const ComputeType val = static_cast(loader.separate()[i]); + const size_t global_idx = + (aligned ? (tid * nvec + i) : (tid * nvec + i - loader.alignment())); + if (global_idx >= size) continue; + + ComputeType val = static_cast(loader.separate()[i]); const ComputeType g = static_cast(grad_loader.separate()[i]); + if constexpr (is_fp8::value) { + if (scale_inv != nullptr) { + val = val * ((scale_inv_numel == num_tensors) ? scale_inv[tensor_id] : scale_inv[0]); + } + } ComputeType temp = OP(val, p) * g; if (requires_amax) { __builtin_assume(max >= 0); max = fmaxf(fabsf(temp), max); } if constexpr (is_fp8::value) { - temp = temp * s; + float current_scale = 1.0f; + if (scale != nullptr) { + current_scale = (scale_numel == num_tensors) ? scale[tensor_id] : scale[0]; + } else if (gridDim.y == 1) { + current_scale = s; + } + temp = temp * current_scale; } - storer.separate()[i] = static_cast(temp); + + // Update scale-inverse for quantization if requested + const bool is_start = (global_idx == 0); + if (is_start && scale_inv != nullptr && !is_fp8::value) { + float current_scale = 1.0f; + if (scale != nullptr) { + current_scale = (scale_numel == num_tensors) ? scale[tensor_id] : scale[0]; + } else if (gridDim.y == 1) { + current_scale = s; + } + size_t scale_inv_idx = (scale_inv_numel == num_tensors) ? tensor_id : 0; + reciprocal(&scale_inv[scale_inv_idx], current_scale); + } } - storer.store(tid, N); + storer.store(tid, size); } // Reduce amax over block if (requires_amax) { max = reduce_max(max, warp_id); - if (threadIdx.x == 0) { + if (threadIdx.x == 0 && max > 0.0f) { + size_t amax_idx = (amax_numel == num_tensors) ? tensor_id : 0; static_assert(std::is_same::value); - atomicMaxFloat(amax, max); + atomicMaxFloat(&amax[amax_idx], max); } } - if constexpr (is_fp8::value) { - // Update scale-inverse - if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) { - reciprocal(scale_inv, s); + if (offsets == nullptr && num_tensors == 1) { + if constexpr (is_fp8::value) { + // Update scale-inverse for single-tensor path + if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) { + reciprocal(scale_inv, s); + } } } } @@ -309,7 +421,7 @@ inline int CalcAlignment(const void *ptr, const int size) { \param other_dim The size of the other dimensions of the tensors. \param nvec Length of the vector. \param ptrs Inputs and Outputs to the operator. -*/ + */ template Alignment CheckAlignment(const size_t lead_dim, const int nvec, const T... ptrs) { std::vector alignments; @@ -338,29 +450,37 @@ template void VectorizedUnaryKernelLauncher(const InputType *input, const fp32 *noop, OutputType *output, const fp32 *scale, fp32 *amax, fp32 *scale_inv, const size_t N, - const Param ¶ms, cudaStream_t stream) { + const Param ¶ms, cudaStream_t stream, + const int64_t *offsets = nullptr, + const int64_t *first_dims = nullptr, + const int64_t *last_dims = nullptr, size_t num_tensors = 1, + size_t scale_numel = 1, size_t scale_inv_numel = 1, + size_t amax_numel = 1) { if (N != 0) { auto align = CheckAlignment(N, nvec, input, output); size_t num_aligned_elements = get_num_aligned_elements(input, N, nvec, sizeof(InputType)); constexpr size_t threads = unary_kernel_threads; - size_t num_blocks = DIVUP(num_aligned_elements, threads); constexpr size_t max_blocks = 65535; - num_blocks = std::min(num_blocks, max_blocks); + const size_t num_blocks = std::min(DIVUP(num_aligned_elements, threads), max_blocks); + const dim3 grid(num_blocks, num_tensors); switch (align) { case Alignment::SAME_ALIGNED: - unary_kernel<<>>( - input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements); + unary_kernel<<>>( + input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements, offsets, + first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); break; case Alignment::SAME_UNALIGNED: - unary_kernel<<>>( - input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements); + unary_kernel<<>>( + input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements, offsets, + first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); break; case Alignment::DIFFERENT: { // If the pointers are aligned differently we cannot vectorize - unary_kernel<1, true, fp32, Param, OP><<>>( - input, noop, output, scale, amax, scale_inv, params, N, N); + unary_kernel<1, true, fp32, Param, OP><<>>( + input, noop, output, scale, amax, scale_inv, params, N, N, offsets, first_dims, + last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); break; } } @@ -373,29 +493,36 @@ template <<>>( - grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements); + unary_grad_kernel<<>>( + grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements, offsets, + first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); break; case Alignment::SAME_UNALIGNED: - unary_grad_kernel<<>>( - grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements); + unary_grad_kernel<<>>( + grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements, offsets, + first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); break; case Alignment::DIFFERENT: { // If the pointers are aligned differently we cannot vectorize - unary_grad_kernel<1, true, fp32, Param, OP><<>>( - grad, input, output, scale, amax, scale_inv, params, N, N); + unary_grad_kernel<1, true, fp32, Param, OP><<>>( + grad, input, output, scale, amax, scale_inv, params, N, N, offsets, first_dims, + last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); break; } }