From 12c7d9a232cb2932ebb5dc4206626dc65a41867f Mon Sep 17 00:00:00 2001 From: Abhishek Date: Sat, 6 Jun 2026 11:55:08 -0700 Subject: [PATCH 1/8] feat: implement fp8 quantize Signed-off-by: Abhishek --- tests/cpp/operator/CMakeLists.txt | 2 + tests/cpp/operator/test_cast_fp8_grouped.cu | 165 +++++++++++ .../operator/test_dequantize_fp8_grouped.cu | 131 +++++++++ tests/cpp/test_common.cu | 42 ++- tests/cpp/test_common.h | 1 + .../common/cast/dispatch/dequantize.cuh | 4 + .../common/cast/dispatch/quantize.cuh | 11 + .../common/cast/fp8/dequantize_fp8.cuh | 31 ++ .../common/cast/fp8/quantize_fp8.cuh | 66 +++++ .../common/transformer_engine.cpp | 8 +- .../common/util/vectorized_pointwise.h | 268 +++++++++++++++--- 11 files changed, 672 insertions(+), 57 deletions(-) create mode 100644 tests/cpp/operator/test_cast_fp8_grouped.cu create mode 100644 tests/cpp/operator/test_dequantize_fp8_grouped.cu diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 9b67c09f34..343653b43d 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -12,10 +12,12 @@ add_executable(test_operator test_qdq.cu test_cast_mxfp8.cu test_cast_mxfp8_grouped.cu + test_cast_fp8_grouped.cu test_cast_nvfp4_transpose.cu test_cast_float8blockwise.cu test_dequantize_mxfp8.cu test_dequantize_mxfp8_grouped.cu + test_dequantize_fp8_grouped.cu test_dequantize_nvfp4.cu test_transpose.cu test_cast_transpose.cu diff --git a/tests/cpp/operator/test_cast_fp8_grouped.cu b/tests/cpp/operator/test_cast_fp8_grouped.cu new file mode 100644 index 0000000000..1d3b56118b --- /dev/null +++ b/tests/cpp/operator/test_cast_fp8_grouped.cu @@ -0,0 +1,165 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include + +#include +#include "../test_common.h" + +using namespace transformer_engine; +using namespace test; + +namespace { + +template +void test_cast_fp8_grouped_impl(const std::vector>& shapes, + DType input_dtype, DType output_dtype) { + const size_t num_tensors = shapes.size(); + + // Create standard Tensor objects + std::vector in_tensors; + std::vector out_tensors; + std::vector in_tensor_ptrs; + std::vector out_tensor_ptrs; + + in_tensors.reserve(num_tensors); + out_tensors.reserve(num_tensors); + in_tensor_ptrs.reserve(num_tensors); + out_tensor_ptrs.reserve(num_tensors); + + for (size_t t = 0; t < num_tensors; ++t) { + in_tensors.emplace_back("in_" + std::to_string(t), shapes[t], input_dtype); + out_tensors.emplace_back("out_" + std::to_string(t), shapes[t], output_dtype, + true, false, NVTE_DELAYED_TENSOR_SCALING); + + // Initialize inputs with random uniform values + fillUniform(&in_tensors[t]); + in_tensors[t].from_cpu(); + + // Initialize scales with random scaling factors + float random_scale = 1.5f + static_cast(t) * 0.5f; + out_tensors[t].set_scale(random_scale); + out_tensors[t].set_scale_inv(0.0f); // Clear to ensure it's written + out_tensors[t].set_amax(0.0f); // Clear amax + + in_tensor_ptrs.push_back(&in_tensors[t]); + out_tensor_ptrs.push_back(&out_tensors[t]); + } + + // Build grouped tensors + GroupedBuffers in_group = build_grouped_tensor(in_tensor_ptrs, NVTE_DELAYED_TENSOR_SCALING); + GroupedBuffers out_group = build_grouped_tensor(out_tensor_ptrs, NVTE_DELAYED_TENSOR_SCALING); + + // CPU reference computation + std::vector> ref_outputs(num_tensors); + std::vector ref_amaxs(num_tensors, 0.0f); + std::vector ref_scale_invs(num_tensors, 0.0f); + + for (size_t t = 0; t < num_tensors; ++t) { + size_t size = product(shapes[t]); + ref_outputs[t].resize(size); + float scale = out_tensors[t].scale(); + float cur_amax = 0.0f; + + InputType* in_cpu = in_tensors[t].rowwise_cpu_dptr(); + for (size_t i = 0; i < size; ++i) { + float val = static_cast(in_cpu[i]); + cur_amax = std::max(cur_amax, std::abs(val)); + float scaled_val = val * scale; + ref_outputs[t][i] = scaled_val; + } + ref_amaxs[t] = cur_amax; + ref_scale_invs[t] = 1.0f / scale; + } + + // Run GPU grouped quantization + QuantizationConfigWrapper quant_config; + nvte_group_quantize(in_group.get_handle(), out_group.get_handle(), quant_config, 0); + cudaDeviceSynchronize(); + + // Copy results back from grouped buffer to individual output tensors + for (size_t t = 0; t < num_tensors; ++t) { + // 1. Copy output data + size_t offset_bytes = (out_group.offsets_host[t] * typeToNumBits(out_group.dtype)) / 8; + NVTE_CHECK_CUDA(cudaMemcpy(out_tensors[t].rowwise_dptr(), + static_cast(out_group.get_data()) + offset_bytes, + out_group.tensor_bytes[t], + cudaMemcpyDeviceToDevice)); + + // 2. Copy scale_inv + NVTEBasicTensor scale_inv_bt = nvte_get_tensor_param(out_tensors[t].data(), kNVTERowwiseScaleInv); + if (scale_inv_bt.data_ptr != nullptr) { + NVTE_CHECK_CUDA(cudaMemcpy(scale_inv_bt.data_ptr, + static_cast(out_group.scale_inv.get()) + t, + sizeof(float), + cudaMemcpyDeviceToDevice)); + } + + // 3. Copy amax + NVTEBasicTensor amax_bt = nvte_get_tensor_param(out_tensors[t].data(), kNVTEAmax); + if (amax_bt.data_ptr != nullptr) { + NVTE_CHECK_CUDA(cudaMemcpy(amax_bt.data_ptr, + static_cast(out_group.amax_dev.get()) + t, + sizeof(float), + cudaMemcpyDeviceToDevice)); + } + } + + // Validate results + for (size_t t = 0; t < num_tensors; ++t) { + size_t size = product(shapes[t]); + out_tensors[t].to_cpu(); + + // 1. Compare scale inverse + float scale_inv_gpu = out_tensors[t].rowwise_scale_inv(); + EXPECT_NEAR(scale_inv_gpu, ref_scale_invs[t], 1e-4); + + // 2. Compare amax + float amax_gpu = out_tensors[t].amax(); + EXPECT_NEAR(amax_gpu, ref_amaxs[t], 1e-4); + + // 3. Compare outputs + OutputType* out_cpu = out_tensors[t].rowwise_cpu_dptr(); + for (size_t i = 0; i < size; ++i) { + float gpu_val = static_cast(out_cpu[i]) / out_tensors[t].scale(); + float ref_val = ref_outputs[t][i] / out_tensors[t].scale(); + // Since it's FP8 casting, check within small quantization limits + EXPECT_NEAR(gpu_val, ref_val, 0.05f); + } + } +} + +class CastFP8GroupedTestSuite : public ::testing::Test {}; + +TEST_F(CastFP8GroupedTestSuite, BF16_to_E4M3_Uniform) { + test_cast_fp8_grouped_impl( + {{32, 64}, {32, 64}, {32, 64}}, + DType::kBFloat16, DType::kFloat8E4M3); +} + +TEST_F(CastFP8GroupedTestSuite, BF16_to_E4M3_Varying) { + test_cast_fp8_grouped_impl( + {{16, 32}, {64, 128}, {32, 64}}, + DType::kBFloat16, DType::kFloat8E4M3); +} + +TEST_F(CastFP8GroupedTestSuite, FP16_to_E4M3_Varying) { + test_cast_fp8_grouped_impl( + {{8, 16}, {128, 64}, {64, 32}}, + DType::kFloat16, DType::kFloat8E4M3); +} + +TEST_F(CastFP8GroupedTestSuite, FP32_to_E5M2_Varying) { + test_cast_fp8_grouped_impl( + {{32, 32}, {16, 64}, {128, 32}}, + DType::kFloat32, DType::kFloat8E5M2); +} + +} // namespace diff --git a/tests/cpp/operator/test_dequantize_fp8_grouped.cu b/tests/cpp/operator/test_dequantize_fp8_grouped.cu new file mode 100644 index 0000000000..378a345abb --- /dev/null +++ b/tests/cpp/operator/test_dequantize_fp8_grouped.cu @@ -0,0 +1,131 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include + +#include +#include "../test_common.h" + +using namespace transformer_engine; +using namespace test; + +namespace { + +template +void test_dequantize_fp8_grouped_impl(const std::vector>& shapes, + DType input_dtype, DType output_dtype) { + const size_t num_tensors = shapes.size(); + + // Create standard Tensor objects + std::vector in_tensors; + std::vector out_tensors; + std::vector in_tensor_ptrs; + std::vector out_tensor_ptrs; + + in_tensors.reserve(num_tensors); + out_tensors.reserve(num_tensors); + in_tensor_ptrs.reserve(num_tensors); + out_tensor_ptrs.reserve(num_tensors); + + for (size_t t = 0; t < num_tensors; ++t) { + // Input is FP8 (with scale_inv) + in_tensors.emplace_back("in_" + std::to_string(t), shapes[t], input_dtype, + true, false, NVTE_DELAYED_TENSOR_SCALING); + // Output is higher precision + out_tensors.emplace_back("out_" + std::to_string(t), shapes[t], output_dtype); + + // Initialize inputs with random uniform FP8 values + fillUniform(&in_tensors[t]); + in_tensors[t].from_cpu(); + + // Set scale_inv + float random_scale_inv = 0.5f + static_cast(t) * 0.25f; + in_tensors[t].set_scale_inv(random_scale_inv); + + // Clear output + fillUniform(&out_tensors[t]); // Initialize to some random values + out_tensors[t].from_cpu(); + + in_tensor_ptrs.push_back(&in_tensors[t]); + out_tensor_ptrs.push_back(&out_tensors[t]); + } + + // Build grouped tensors + GroupedBuffers in_group = build_grouped_tensor(in_tensor_ptrs, NVTE_DELAYED_TENSOR_SCALING); + GroupedBuffers out_group = build_grouped_tensor(out_tensor_ptrs, NVTE_DELAYED_TENSOR_SCALING); + + // CPU reference computation + std::vector> ref_outputs(num_tensors); + for (size_t t = 0; t < num_tensors; ++t) { + size_t size = product(shapes[t]); + ref_outputs[t].resize(size); + float scale_inv = in_tensors[t].rowwise_scale_inv(); + + InputType* in_cpu = in_tensors[t].rowwise_cpu_dptr(); + for (size_t i = 0; i < size; ++i) { + float val = static_cast(in_cpu[i]); + ref_outputs[t][i] = val * scale_inv; + } + } + + // Run GPU grouped dequantization + nvte_group_dequantize(in_group.get_handle(), out_group.get_handle(), 0); + cudaDeviceSynchronize(); + + // Copy results back from grouped buffer to individual output tensors + for (size_t t = 0; t < num_tensors; ++t) { + size_t offset_bytes = (out_group.offsets_host[t] * typeToNumBits(out_group.dtype)) / 8; + NVTE_CHECK_CUDA(cudaMemcpy(out_tensors[t].rowwise_dptr(), + static_cast(out_group.get_data()) + offset_bytes, + out_group.tensor_bytes[t], + cudaMemcpyDeviceToDevice)); + } + + // Validate results + for (size_t t = 0; t < num_tensors; ++t) { + size_t size = product(shapes[t]); + out_tensors[t].to_cpu(); + + OutputType* out_cpu = out_tensors[t].rowwise_cpu_dptr(); + for (size_t i = 0; i < size; ++i) { + float gpu_val = static_cast(out_cpu[i]); + float ref_val = ref_outputs[t][i]; + EXPECT_NEAR(gpu_val, ref_val, 1e-4); + } + } +} + +class DequantizeFP8GroupedTestSuite : public ::testing::Test {}; + +TEST_F(DequantizeFP8GroupedTestSuite, E4M3_to_BF16_Uniform) { + test_dequantize_fp8_grouped_impl( + {{32, 64}, {32, 64}, {32, 64}}, + DType::kFloat8E4M3, DType::kBFloat16); +} + +TEST_F(DequantizeFP8GroupedTestSuite, E4M3_to_BF16_Varying) { + test_dequantize_fp8_grouped_impl( + {{16, 32}, {64, 128}, {32, 64}}, + DType::kFloat8E4M3, DType::kBFloat16); +} + +TEST_F(DequantizeFP8GroupedTestSuite, E4M3_to_FP16_Varying) { + test_dequantize_fp8_grouped_impl( + {{8, 16}, {128, 64}, {64, 32}}, + DType::kFloat8E4M3, DType::kFloat16); +} + +TEST_F(DequantizeFP8GroupedTestSuite, E5M2_to_FP32_Varying) { + test_dequantize_fp8_grouped_impl( + {{32, 32}, {16, 64}, {128, 32}}, + DType::kFloat8E5M2, DType::kFloat32); +} + +} // namespace diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index fc41d44720..ef0adbda89 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -1108,10 +1108,10 @@ GroupedBuffers build_grouped_tensor(const std::vector& tensors, [&](int64_t v) { return v == last_dims[0]; }); std::vector offsets(num_tensors + 1, 0); + std::mt19937 gen(12345); auto random_padding = [&]() -> int64_t { // Random padding ensuring 16-byte alignment regardless of element size // cuBLAS requires aligned pointers for vectorized loads - static std::mt19937 gen(12345); std::uniform_int_distribution dist(0, 3); // Calculate elements needed for 16-byte alignment size_t align_elements; @@ -1263,6 +1263,8 @@ GroupedBuffers build_grouped_tensor(const std::vector& tensors, // FP8 tensor scaling: one float scale_inv per tensor // For delayed scaling, rowwise and columnwise share the same scale std::vector scale_inv_cpu(num_tensors, 1.f); + std::vector scale_cpu(num_tensors, 1.f); + std::vector amax_cpu(num_tensors, 0.f); for (size_t i = 0; i < num_tensors; ++i) { tensors[i]->to_cpu(); if (has_rowwise) { @@ -1270,16 +1272,44 @@ GroupedBuffers build_grouped_tensor(const std::vector& tensors, } else { scale_inv_cpu[i] = tensors[i]->columnwise_cpu_scale_inv_ptr()[0]; } + // Gather scale + NVTEBasicTensor scale_bt = nvte_get_tensor_param(tensors[i]->data(), kNVTEScale); + if (scale_bt.data_ptr != nullptr) { + float val; + NVTE_CHECK_CUDA(cudaMemcpy(&val, scale_bt.data_ptr, sizeof(float), cudaMemcpyDeviceToHost)); + scale_cpu[i] = val; + } + // Gather amax + NVTEBasicTensor amax_bt = nvte_get_tensor_param(tensors[i]->data(), kNVTEAmax); + if (amax_bt.data_ptr != nullptr) { + float val; + NVTE_CHECK_CUDA(cudaMemcpy(&val, amax_bt.data_ptr, sizeof(float), cudaMemcpyDeviceToHost)); + amax_cpu[i] = val; + } } grouped.scale_inv = cuda_alloc(sizeof(float) * num_tensors); NVTE_CHECK_CUDA(cudaMemcpy(grouped.scale_inv.get(), scale_inv_cpu.data(), sizeof(float) * num_tensors, cudaMemcpyHostToDevice)); NVTEShape scale_shape = nvte_make_shape(&num_tensors, 1); - NVTEBasicTensor scale_tensor{grouped.scale_inv.get(), kNVTEFloat32, scale_shape}; - nvte_set_grouped_tensor_param(h, kNVTEGroupedRowwiseScaleInv, &scale_tensor, - sizeof(scale_tensor)); - nvte_set_grouped_tensor_param(h, kNVTEGroupedColumnwiseScaleInv, &scale_tensor, - sizeof(scale_tensor)); + NVTEBasicTensor scale_inv_tensor{grouped.scale_inv.get(), kNVTEFloat32, scale_shape}; + nvte_set_grouped_tensor_param(h, kNVTEGroupedRowwiseScaleInv, &scale_inv_tensor, + sizeof(scale_inv_tensor)); + nvte_set_grouped_tensor_param(h, kNVTEGroupedColumnwiseScaleInv, &scale_inv_tensor, + sizeof(scale_inv_tensor)); + + // Set scale on the grouped tensor + grouped.scale = cuda_alloc(sizeof(float) * num_tensors); + NVTE_CHECK_CUDA(cudaMemcpy(grouped.scale.get(), scale_cpu.data(), + sizeof(float) * num_tensors, cudaMemcpyHostToDevice)); + NVTEBasicTensor scale_tensor{grouped.scale.get(), kNVTEFloat32, scale_shape}; + nvte_set_grouped_tensor_param(h, kNVTEGroupedScale, &scale_tensor, sizeof(scale_tensor)); + + // Set amax on the grouped tensor + grouped.amax_dev = cuda_alloc(sizeof(float) * num_tensors); + NVTE_CHECK_CUDA(cudaMemcpy(grouped.amax_dev.get(), amax_cpu.data(), + sizeof(float) * num_tensors, cudaMemcpyHostToDevice)); + NVTEBasicTensor amax_tensor{grouped.amax_dev.get(), kNVTEFloat32, scale_shape}; + nvte_set_grouped_tensor_param(h, kNVTEGroupedAmax, &amax_tensor, sizeof(amax_tensor)); } else if (scaling_mode == NVTE_MXFP8_1D_SCALING) { // MXFP8: E8M0 scale_inv per block of 32 elements (1 byte per scale element). if (has_rowwise) { diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 11d96c2e60..d01cba594b 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -594,6 +594,7 @@ struct GroupedBuffers { CudaPtr<> data; CudaPtr<> scale_inv; CudaPtr<> columnwise_scale_inv; + CudaPtr<> scale; CudaPtr first_dims_dev; CudaPtr last_dims_dev; CudaPtr offsets_dev; diff --git a/transformer_engine/common/cast/dispatch/dequantize.cuh b/transformer_engine/common/cast/dispatch/dequantize.cuh index 63c1b046ff..ebfef5b485 100644 --- a/transformer_engine/common/cast/dispatch/dequantize.cuh +++ b/transformer_engine/common/cast/dispatch/dequantize.cuh @@ -61,6 +61,10 @@ inline void group_dequantize_helper(const GroupedTensor &input, GroupedTensor *o CheckOutputGroupedTensor(*output, "group_dequantize_output"); switch (input.scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + fp8::group_dequantize(input, output, stream); + break; + } case NVTE_MXFP8_1D_SCALING: { if (is_supported_by_CC_100()) { mxfp8::group_dequantize(&input, output, stream); diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index bad53a03c6..5be00382fc 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -450,6 +450,11 @@ void group_quantize_fwd_helper(const NVTEGroupedTensor input, NVTEGroupedTensor // Dispatch to quantization kernel depending on data format switch (scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + fp8::group_quantize( + *input_tensor, noop_tensor, output_tensor, stream); + break; + } case NVTE_MXFP8_1D_SCALING: { mxfp8::group_quantize( input_tensor, activations_tensor, noop_tensor, output_tensor, dbias_tensor, @@ -491,6 +496,12 @@ void group_quantize_bwd_helper(const NVTEGroupedTensor grad, const NVTEGroupedTe // Dispatch to quantization kernel depending on data format switch (scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + fp8::group_quantize( + *grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, + workspace_tensor, stream); + break; + } case NVTE_MXFP8_1D_SCALING: { mxfp8::group_quantize( grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, diff --git a/transformer_engine/common/cast/fp8/dequantize_fp8.cuh b/transformer_engine/common/cast/fp8/dequantize_fp8.cuh index 6a0eaf94fb..6baaf89a59 100644 --- a/transformer_engine/common/cast/fp8/dequantize_fp8.cuh +++ b/transformer_engine/common/cast/fp8/dequantize_fp8.cuh @@ -47,6 +47,37 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) stream);); // NOLINT(*) ); // NOLINT(*) } +struct GroupDequantizeParam {}; + +__device__ inline float group_dequantize_func(float value, const GroupDequantizeParam &) { + return value; +} + +inline void group_dequantize(const GroupedTensor &input, GroupedTensor *output, + cudaStream_t stream) { + const size_t N = product(input.data.shape); + const size_t scale_inv_numel = product(input.scale_inv.shape); + + const int64_t *const offsets = reinterpret_cast(input.tensor_offsets.dptr); + const int64_t *const first_dims = reinterpret_cast(input.first_dims.dptr); + const int64_t *const last_dims = reinterpret_cast(input.last_dims.dptr); + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + output->dtype(), OType, + constexpr int nvec = 32 / sizeof(OType); + GroupDequantizeParam p; + VectorizedUnaryKernelLauncher( + reinterpret_cast(input.data.dptr), nullptr, + reinterpret_cast(output->data.dptr), nullptr, nullptr, + const_cast(reinterpret_cast(input.scale_inv.dptr)), N, p, stream, + offsets, first_dims, last_dims, input.num_tensors, + 1, scale_inv_numel, 1); + ); + ); +} + } // namespace fp8 } // namespace dispatch } // namespace transformer_engine diff --git a/transformer_engine/common/cast/fp8/quantize_fp8.cuh b/transformer_engine/common/cast/fp8/quantize_fp8.cuh index bad10c954e..4a548374fe 100644 --- a/transformer_engine/common/cast/fp8/quantize_fp8.cuh +++ b/transformer_engine/common/cast/fp8/quantize_fp8.cuh @@ -572,6 +572,72 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, } } +template +void group_quantize(const GroupedTensor &input, const Tensor *noop, GroupedTensor *output, + cudaStream_t stream) { + constexpr float (*UnaryOP)(float, const ParamOP &) = (OP == nullptr) ? detail::identity : OP; + const size_t N = product(input.data.shape); + const size_t scale_numel = product(output->scale.shape); + const size_t scale_inv_numel = product(output->scale_inv.shape); + const size_t amax_numel = product(output->amax.shape); + + const int64_t *const offsets = reinterpret_cast(input.tensor_offsets.dptr); + const int64_t *const first_dims = reinterpret_cast(input.first_dims.dptr); + const int64_t *const last_dims = reinterpret_cast(input.last_dims.dptr); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + output->dtype(), OType, + constexpr int nvec = 32 / sizeof(IType); + VectorizedUnaryKernelLauncher( + reinterpret_cast(input.data.dptr), + reinterpret_cast(noop->data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), N, {}, stream, + offsets, first_dims, last_dims, input.num_tensors, + scale_numel, scale_inv_numel, amax_numel); + ); + ); +} + +template +void group_quantize(const GroupedTensor &grad, const GroupedTensor *input, const Tensor *noop, + GroupedTensor *output, GroupedTensor *dbias, Tensor *workspace, + cudaStream_t stream) { + NVTE_CHECK(!IS_DBIAS && !IS_DACT, "Gated or DBias fusions are not supported in FP8 Grouped Quantization."); + + constexpr float (*UnaryOP)(float, const ParamOP &) = (OP == nullptr) ? detail::identity : OP; + const size_t N = product(input->data.shape); + const size_t scale_numel = product(output->scale.shape); + const size_t scale_inv_numel = product(output->scale_inv.shape); + const size_t amax_numel = product(output->amax.shape); + + const int64_t *const offsets = reinterpret_cast(input->tensor_offsets.dptr); + const int64_t *const first_dims = reinterpret_cast(input->first_dims.dptr); + const int64_t *const last_dims = reinterpret_cast(input->last_dims.dptr); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input->dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + output->dtype(), OType, + constexpr int nvec = 32 / sizeof(IType); + VectorizedUnaryGradKernelLauncher( + reinterpret_cast(grad.data.dptr), + reinterpret_cast(input->data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), N, {}, stream, + offsets, first_dims, last_dims, input->num_tensors, + scale_numel, scale_inv_numel, amax_numel); + ); + ); +} + } // namespace fp8 } // namespace dispatch } // namespace transformer_engine diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index b3179d38fd..73a3cfbe8d 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -317,13 +317,13 @@ void CheckGroupedTensorShapeArrays(const GroupedTensor &t, std::string_view name // Validate data size matches logical_shape size_t expected_numel = t.logical_shape.data[0] * t.logical_shape.data[1]; if (t.has_data()) { - NVTE_CHECK(t.data.numel() == expected_numel, "Grouped tensor ", name, " data size (", - t.data.numel(), ") must match logical_shape size (", expected_numel, ")"); + NVTE_CHECK(t.data.numel() >= expected_numel, "Grouped tensor ", name, " data size (", + t.data.numel(), ") must be at least logical_shape size (", expected_numel, ")"); } if (t.has_columnwise_data()) { - NVTE_CHECK(t.columnwise_data.numel() == expected_numel, "Grouped tensor ", name, + NVTE_CHECK(t.columnwise_data.numel() >= expected_numel, "Grouped tensor ", name, " columnwise_data size (", t.columnwise_data.numel(), - ") must match logical_shape size (", expected_numel, ")"); + ") must be at least logical_shape size (", expected_numel, ")"); } } diff --git a/transformer_engine/common/util/vectorized_pointwise.h b/transformer_engine/common/util/vectorized_pointwise.h index 7707c68a08..61428ac0a1 100644 --- a/transformer_engine/common/util/vectorized_pointwise.h +++ b/transformer_engine/common/util/vectorized_pointwise.h @@ -171,12 +171,36 @@ class VectorizedStorer : public VectorizedAccessor { constexpr int unary_kernel_threads = 512; +__device__ __forceinline__ size_t find_tensor_id( + const int64_t *const offsets, + const size_t num_tensors, + const size_t global_offset) { + size_t low = 0; + size_t high = num_tensors; + while (low < high) { + size_t mid = low + (high - low) / 2; + if (static_cast(offsets[mid]) <= global_offset) { + low = mid + 1; + } else { + high = mid; + } + } + return low - 1; +} + template __launch_bounds__(unary_kernel_threads) __global__ void unary_kernel(const InputType *input, const ComputeType *noop, OutputType *output, const ComputeType *scale, ComputeType *amax, ComputeType *scale_inv, Param p, - const size_t N, const size_t num_aligned_elements) { + const size_t N, const size_t num_aligned_elements, + const int64_t *offsets = nullptr, + const int64_t *first_dims = nullptr, + const int64_t *last_dims = nullptr, + size_t num_tensors = 1, + size_t scale_numel = 1, + size_t scale_inv_numel = 1, + size_t amax_numel = 1) { if (noop != nullptr && noop[0] == 1.0f) return; VectorizedLoader loader(input, N); @@ -185,43 +209,105 @@ __launch_bounds__(unary_kernel_threads) __global__ ComputeType s = 1; const bool requires_amax = (amax != nullptr); if constexpr (is_fp8::value) { - if (scale != nullptr) s = *scale; + if (scale != nullptr && offsets == nullptr) s = *scale; } const int warp_id = threadIdx.x / THREADS_PER_WARP; + float block_max[64] = {0.0f}; + const size_t M = num_aligned_elements; for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) { loader.load(tid, N); #pragma unroll for (int i = 0; i < nvec; ++i) { - const ComputeType val = static_cast(loader.separate()[i]); - ComputeType temp = OP(val, p); - if (requires_amax) { - __builtin_assume(max >= 0); - max = fmaxf(fabsf(temp), max); + const size_t global_idx = (aligned ? (tid * nvec + i) : (tid * nvec + i - loader.alignment())); + if (global_idx >= N) continue; + + size_t tensor_id = 0; + bool is_valid = true; + if (offsets != nullptr) { + tensor_id = find_tensor_id(offsets, num_tensors, global_idx); + size_t start = offsets[tensor_id]; + size_t size = first_dims[tensor_id] * last_dims[tensor_id]; + is_valid = (global_idx >= start && global_idx < start + size); + } else if (num_tensors > 1) { + size_t size = N / num_tensors; + tensor_id = global_idx / size; + if (tensor_id >= num_tensors) tensor_id = num_tensors - 1; } - if constexpr (is_fp8::value) { - temp = temp * s; + + if (is_valid) { + ComputeType val = static_cast(loader.separate()[i]); + if constexpr (is_fp8::value) { + if (scale_inv != nullptr) { + val = val * ((scale_inv_numel == num_tensors) ? scale_inv[tensor_id] : scale_inv[0]); + } + } + ComputeType temp = OP(val, p); + if (requires_amax) { + __builtin_assume(block_max[tensor_id] >= 0); + block_max[tensor_id] = fmaxf(fabsf(temp), block_max[tensor_id]); + } + if constexpr (is_fp8::value) { + float current_scale = 1.0f; + if (scale != nullptr) { + current_scale = (scale_numel == num_tensors) ? scale[tensor_id] : scale[0]; + } + temp = temp * current_scale; + } + storer.separate()[i] = static_cast(temp); + + // Update scale-inverse for quantization if requested + bool is_start = false; + if (offsets != nullptr) { + is_start = (global_idx == offsets[tensor_id]); + } else if (num_tensors > 1) { + size_t size = N / num_tensors; + is_start = (global_idx == tensor_id * size); + } + if (is_start && scale_inv != nullptr && !is_fp8::value) { + float current_scale = 1.0f; + if (scale != nullptr) { + current_scale = (scale_numel == num_tensors) ? scale[tensor_id] : scale[0]; + } + size_t scale_inv_idx = (scale_inv_numel == num_tensors) ? tensor_id : 0; + reciprocal(&scale_inv[scale_inv_idx], current_scale); + } + } else { + storer.separate()[i] = OutputType(); } - storer.separate()[i] = static_cast(temp); } storer.store(tid, N); } // Reduce amax over block if (requires_amax) { - max = reduce_max(max, warp_id); - if (threadIdx.x == 0) { - static_assert(std::is_same::value); - atomicMaxFloat(amax, max); + if (offsets != nullptr || num_tensors > 1) { + for (size_t t = 0; t < num_tensors; ++t) { + float t_max = block_max[t]; + t_max = reduce_max(t_max, warp_id); + if (threadIdx.x == 0 && t_max > 0.0f) { + size_t amax_idx = (amax_numel == num_tensors) ? t : 0; + atomicMaxFloat(&amax[amax_idx], t_max); + } + } + } else { + max = block_max[0]; + max = reduce_max(max, warp_id); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(amax, max); + } } } - if constexpr (is_fp8::value) { - // Update scale-inverse - if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) { - reciprocal(scale_inv, s); + if (offsets == nullptr && num_tensors == 1) { + if constexpr (is_fp8::value) { + // Update scale-inverse for single-tensor path + if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) { + reciprocal(scale_inv, s); + } } } } @@ -232,7 +318,14 @@ template loader(input, N); VectorizedLoader grad_loader(grad, N); VectorizedStorer storer(output, N); @@ -240,10 +333,12 @@ __launch_bounds__(unary_kernel_threads) __global__ ComputeType s = 1; const bool requires_amax = (amax != nullptr); if constexpr (is_fp8::value) { - if (scale != nullptr) s = *scale; + if (scale != nullptr && offsets == nullptr) s = *scale; } const int warp_id = threadIdx.x / THREADS_PER_WARP; + float block_max[64] = {0.0f}; + const size_t M = num_aligned_elements; for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) { @@ -251,35 +346,94 @@ __launch_bounds__(unary_kernel_threads) __global__ grad_loader.load(tid, N); #pragma unroll for (int i = 0; i < nvec; ++i) { - const ComputeType val = static_cast(loader.separate()[i]); - const ComputeType g = static_cast(grad_loader.separate()[i]); - ComputeType temp = OP(val, p) * g; - if (requires_amax) { - __builtin_assume(max >= 0); - max = fmaxf(fabsf(temp), max); - } - if constexpr (is_fp8::value) { - temp = temp * s; + const size_t global_idx = (aligned ? (tid * nvec + i) : (tid * nvec + i - loader.alignment())); + if (global_idx >= N) continue; + + size_t tensor_id = 0; + bool is_valid = true; + if (offsets != nullptr) { + tensor_id = find_tensor_id(offsets, num_tensors, global_idx); + size_t start = offsets[tensor_id]; + size_t size = first_dims[tensor_id] * last_dims[tensor_id]; + is_valid = (global_idx >= start && global_idx < start + size); + } else if (num_tensors > 1) { + size_t size = N / num_tensors; + tensor_id = global_idx / size; + if (tensor_id >= num_tensors) tensor_id = num_tensors - 1; } - storer.separate()[i] = static_cast(temp); + if (is_valid) { + ComputeType val = static_cast(loader.separate()[i]); + const ComputeType g = static_cast(grad_loader.separate()[i]); + if constexpr (is_fp8::value) { + if (scale_inv != nullptr) { + val = val * ((scale_inv_numel == num_tensors) ? scale_inv[tensor_id] : scale_inv[0]); + } + } + ComputeType temp = OP(val, p) * g; + if (requires_amax) { + __builtin_assume(block_max[tensor_id] >= 0); + block_max[tensor_id] = fmaxf(fabsf(temp), block_max[tensor_id]); + } + if constexpr (is_fp8::value) { + float current_scale = 1.0f; + if (scale != nullptr) { + current_scale = (scale_numel == num_tensors) ? scale[tensor_id] : scale[0]; + } + temp = temp * current_scale; + } + storer.separate()[i] = static_cast(temp); + + // Update scale-inverse for quantization if requested + bool is_start = false; + if (offsets != nullptr) { + is_start = (global_idx == offsets[tensor_id]); + } else if (num_tensors > 1) { + size_t size = N / num_tensors; + is_start = (global_idx == tensor_id * size); + } + if (is_start && scale_inv != nullptr && !is_fp8::value) { + float current_scale = 1.0f; + if (scale != nullptr) { + current_scale = (scale_numel == num_tensors) ? scale[tensor_id] : scale[0]; + } + size_t scale_inv_idx = (scale_inv_numel == num_tensors) ? tensor_id : 0; + reciprocal(&scale_inv[scale_inv_idx], current_scale); + } + } else { + storer.separate()[i] = OutputType(); + } } storer.store(tid, N); } // Reduce amax over block if (requires_amax) { - max = reduce_max(max, warp_id); - if (threadIdx.x == 0) { - static_assert(std::is_same::value); - atomicMaxFloat(amax, max); + if (offsets != nullptr || num_tensors > 1) { + for (size_t t = 0; t < num_tensors; ++t) { + float t_max = block_max[t]; + t_max = reduce_max(t_max, warp_id); + if (threadIdx.x == 0 && t_max > 0.0f) { + size_t amax_idx = (amax_numel == num_tensors) ? t : 0; + atomicMaxFloat(&amax[amax_idx], t_max); + } + } + } else { + max = block_max[0]; + max = reduce_max(max, warp_id); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(amax, max); + } } } - if constexpr (is_fp8::value) { - // Update scale-inverse - if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) { - reciprocal(scale_inv, s); + if (offsets == nullptr && num_tensors == 1) { + if constexpr (is_fp8::value) { + // Update scale-inverse for single-tensor path + if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) { + reciprocal(scale_inv, s); + } } } } @@ -309,7 +463,7 @@ inline int CalcAlignment(const void *ptr, const int size) { \param other_dim The size of the other dimensions of the tensors. \param nvec Length of the vector. \param ptrs Inputs and Outputs to the operator. -*/ + */ template Alignment CheckAlignment(const size_t lead_dim, const int nvec, const T... ptrs) { std::vector alignments; @@ -338,7 +492,14 @@ template void VectorizedUnaryKernelLauncher(const InputType *input, const fp32 *noop, OutputType *output, const fp32 *scale, fp32 *amax, fp32 *scale_inv, const size_t N, - const Param ¶ms, cudaStream_t stream) { + const Param ¶ms, cudaStream_t stream, + const int64_t *offsets = nullptr, + const int64_t *first_dims = nullptr, + const int64_t *last_dims = nullptr, + size_t num_tensors = 1, + size_t scale_numel = 1, + size_t scale_inv_numel = 1, + size_t amax_numel = 1) { if (N != 0) { auto align = CheckAlignment(N, nvec, input, output); @@ -351,16 +512,19 @@ void VectorizedUnaryKernelLauncher(const InputType *input, const fp32 *noop, Out switch (align) { case Alignment::SAME_ALIGNED: unary_kernel<<>>( - input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements); + input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements, + offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); break; case Alignment::SAME_UNALIGNED: unary_kernel<<>>( - input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements); + input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements, + offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); break; case Alignment::DIFFERENT: { // If the pointers are aligned differently we cannot vectorize unary_kernel<1, true, fp32, Param, OP><<>>( - input, noop, output, scale, amax, scale_inv, params, N, N); + input, noop, output, scale, amax, scale_inv, params, N, N, + offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); break; } } @@ -373,7 +537,14 @@ template <<>>( - grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements); + grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements, + offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); break; case Alignment::SAME_UNALIGNED: unary_grad_kernel<<>>( - grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements); + grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements, + offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); break; case Alignment::DIFFERENT: { // If the pointers are aligned differently we cannot vectorize unary_grad_kernel<1, true, fp32, Param, OP><<>>( - grad, input, output, scale, amax, scale_inv, params, N, N); + grad, input, output, scale, amax, scale_inv, params, N, N, + offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); break; } } From 6aa2bc07884490042b10fbb3d6181430ab7f8e94 Mon Sep 17 00:00:00 2001 From: Abhishek Date: Sat, 6 Jun 2026 21:28:17 -0700 Subject: [PATCH 2/8] Made tests run Signed-off-by: Abhishek --- tests/cpp/operator/test_cast_fp8_grouped.cu | 7 +++---- tests/cpp/test_common.cu | 13 ++++--------- 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/tests/cpp/operator/test_cast_fp8_grouped.cu b/tests/cpp/operator/test_cast_fp8_grouped.cu index 1d3b56118b..63cfb3bd1a 100644 --- a/tests/cpp/operator/test_cast_fp8_grouped.cu +++ b/tests/cpp/operator/test_cast_fp8_grouped.cu @@ -128,10 +128,9 @@ void test_cast_fp8_grouped_impl(const std::vector>& shapes, // 3. Compare outputs OutputType* out_cpu = out_tensors[t].rowwise_cpu_dptr(); for (size_t i = 0; i < size; ++i) { - float gpu_val = static_cast(out_cpu[i]) / out_tensors[t].scale(); - float ref_val = ref_outputs[t][i] / out_tensors[t].scale(); - // Since it's FP8 casting, check within small quantization limits - EXPECT_NEAR(gpu_val, ref_val, 0.05f); + float gpu_val = static_cast(out_cpu[i]); + float ref_val = static_cast(OutputType(ref_outputs[t][i])); + EXPECT_NEAR(gpu_val, ref_val, 1e-4); } } } diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index ef0adbda89..c56a156680 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -1111,16 +1111,11 @@ GroupedBuffers build_grouped_tensor(const std::vector& tensors, std::mt19937 gen(12345); auto random_padding = [&]() -> int64_t { // Random padding ensuring 16-byte alignment regardless of element size - // cuBLAS requires aligned pointers for vectorized loads + // We use a constant 64 elements alignment. This guarantees that all + // grouped tensors (input and output) in pointwise operations will + // have identical element offsets, preventing layout misalignment in tests. std::uniform_int_distribution dist(0, 3); - // Calculate elements needed for 16-byte alignment - size_t align_elements; - if (is_sub_byte) { - // Sub-byte types (e.g. FP4): 16 bytes = 16*8/bits_per_elem elements - align_elements = (16 * 8) / bits_per_elem; - } else { - align_elements = std::max(1, (16 + elem_size - 1) / elem_size); - } + size_t align_elements = 64; return dist(gen) * static_cast(align_elements); }; From 67ff6bf3611738ab740b5e8e5a6b73f4149c3183 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 7 Jun 2026 05:06:06 +0000 Subject: [PATCH 3/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/cast/dispatch/quantize.cuh | 7 +- .../common/cast/fp8/dequantize_fp8.cuh | 11 +-- .../common/cast/fp8/quantize_fp8.cuh | 23 +++--- .../common/util/vectorized_pointwise.h | 71 ++++++++----------- 4 files changed, 44 insertions(+), 68 deletions(-) diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index 5be00382fc..009a5d0d20 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -451,8 +451,7 @@ void group_quantize_fwd_helper(const NVTEGroupedTensor input, NVTEGroupedTensor // Dispatch to quantization kernel depending on data format switch (scaling_mode) { case NVTE_DELAYED_TENSOR_SCALING: { - fp8::group_quantize( - *input_tensor, noop_tensor, output_tensor, stream); + fp8::group_quantize(*input_tensor, noop_tensor, output_tensor, stream); break; } case NVTE_MXFP8_1D_SCALING: { @@ -498,8 +497,8 @@ void group_quantize_bwd_helper(const NVTEGroupedTensor grad, const NVTEGroupedTe switch (scaling_mode) { case NVTE_DELAYED_TENSOR_SCALING: { fp8::group_quantize( - *grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, - workspace_tensor, stream); + *grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, + stream); break; } case NVTE_MXFP8_1D_SCALING: { diff --git a/transformer_engine/common/cast/fp8/dequantize_fp8.cuh b/transformer_engine/common/cast/fp8/dequantize_fp8.cuh index 6baaf89a59..b637f792bc 100644 --- a/transformer_engine/common/cast/fp8/dequantize_fp8.cuh +++ b/transformer_engine/common/cast/fp8/dequantize_fp8.cuh @@ -65,17 +65,12 @@ inline void group_dequantize(const GroupedTensor &input, GroupedTensor *output, TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( input.dtype(), IType, TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( - output->dtype(), OType, - constexpr int nvec = 32 / sizeof(OType); - GroupDequantizeParam p; + output->dtype(), OType, constexpr int nvec = 32 / sizeof(OType); GroupDequantizeParam p; VectorizedUnaryKernelLauncher( reinterpret_cast(input.data.dptr), nullptr, reinterpret_cast(output->data.dptr), nullptr, nullptr, - const_cast(reinterpret_cast(input.scale_inv.dptr)), N, p, stream, - offsets, first_dims, last_dims, input.num_tensors, - 1, scale_inv_numel, 1); - ); - ); + const_cast(reinterpret_cast(input.scale_inv.dptr)), N, p, + stream, offsets, first_dims, last_dims, input.num_tensors, 1, scale_inv_numel, 1););); } } // namespace fp8 diff --git a/transformer_engine/common/cast/fp8/quantize_fp8.cuh b/transformer_engine/common/cast/fp8/quantize_fp8.cuh index 4a548374fe..aa76dc0167 100644 --- a/transformer_engine/common/cast/fp8/quantize_fp8.cuh +++ b/transformer_engine/common/cast/fp8/quantize_fp8.cuh @@ -588,19 +588,15 @@ void group_quantize(const GroupedTensor &input, const Tensor *noop, GroupedTenso TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( input.dtype(), IType, TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( - output->dtype(), OType, - constexpr int nvec = 32 / sizeof(IType); + output->dtype(), OType, constexpr int nvec = 32 / sizeof(IType); VectorizedUnaryKernelLauncher( reinterpret_cast(input.data.dptr), reinterpret_cast(noop->data.dptr), reinterpret_cast(output->data.dptr), reinterpret_cast(output->scale.dptr), reinterpret_cast(output->amax.dptr), - reinterpret_cast(output->scale_inv.dptr), N, {}, stream, - offsets, first_dims, last_dims, input.num_tensors, - scale_numel, scale_inv_numel, amax_numel); - ); - ); + reinterpret_cast(output->scale_inv.dptr), N, {}, stream, offsets, first_dims, + last_dims, input.num_tensors, scale_numel, scale_inv_numel, amax_numel););); } template data.shape); @@ -623,19 +620,15 @@ void group_quantize(const GroupedTensor &grad, const GroupedTensor *input, const TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( input->dtype(), IType, TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( - output->dtype(), OType, - constexpr int nvec = 32 / sizeof(IType); + output->dtype(), OType, constexpr int nvec = 32 / sizeof(IType); VectorizedUnaryGradKernelLauncher( reinterpret_cast(grad.data.dptr), reinterpret_cast(input->data.dptr), reinterpret_cast(output->data.dptr), reinterpret_cast(output->scale.dptr), reinterpret_cast(output->amax.dptr), - reinterpret_cast(output->scale_inv.dptr), N, {}, stream, - offsets, first_dims, last_dims, input->num_tensors, - scale_numel, scale_inv_numel, amax_numel); - ); - ); + reinterpret_cast(output->scale_inv.dptr), N, {}, stream, offsets, first_dims, + last_dims, input->num_tensors, scale_numel, scale_inv_numel, amax_numel););); } } // namespace fp8 diff --git a/transformer_engine/common/util/vectorized_pointwise.h b/transformer_engine/common/util/vectorized_pointwise.h index 61428ac0a1..3fa7f4c8ee 100644 --- a/transformer_engine/common/util/vectorized_pointwise.h +++ b/transformer_engine/common/util/vectorized_pointwise.h @@ -171,10 +171,9 @@ class VectorizedStorer : public VectorizedAccessor { constexpr int unary_kernel_threads = 512; -__device__ __forceinline__ size_t find_tensor_id( - const int64_t *const offsets, - const size_t num_tensors, - const size_t global_offset) { +__device__ __forceinline__ size_t find_tensor_id(const int64_t *const offsets, + const size_t num_tensors, + const size_t global_offset) { size_t low = 0; size_t high = num_tensors; while (low < high) { @@ -194,13 +193,9 @@ __launch_bounds__(unary_kernel_threads) __global__ void unary_kernel(const InputType *input, const ComputeType *noop, OutputType *output, const ComputeType *scale, ComputeType *amax, ComputeType *scale_inv, Param p, const size_t N, const size_t num_aligned_elements, - const int64_t *offsets = nullptr, - const int64_t *first_dims = nullptr, - const int64_t *last_dims = nullptr, - size_t num_tensors = 1, - size_t scale_numel = 1, - size_t scale_inv_numel = 1, - size_t amax_numel = 1) { + const int64_t *offsets = nullptr, const int64_t *first_dims = nullptr, + const int64_t *last_dims = nullptr, size_t num_tensors = 1, + size_t scale_numel = 1, size_t scale_inv_numel = 1, size_t amax_numel = 1) { if (noop != nullptr && noop[0] == 1.0f) return; VectorizedLoader loader(input, N); @@ -221,7 +216,8 @@ __launch_bounds__(unary_kernel_threads) __global__ loader.load(tid, N); #pragma unroll for (int i = 0; i < nvec; ++i) { - const size_t global_idx = (aligned ? (tid * nvec + i) : (tid * nvec + i - loader.alignment())); + const size_t global_idx = + (aligned ? (tid * nvec + i) : (tid * nvec + i - loader.alignment())); if (global_idx >= N) continue; size_t tensor_id = 0; @@ -319,12 +315,9 @@ __launch_bounds__(unary_kernel_threads) __global__ void unary_grad_kernel(const InputTypeGrad *grad, const InputType *input, OutputType *output, const ComputeType *scale, ComputeType *amax, ComputeType *scale_inv, Param p, const size_t N, const size_t num_aligned_elements, - const int64_t *offsets = nullptr, - const int64_t *first_dims = nullptr, - const int64_t *last_dims = nullptr, - size_t num_tensors = 1, - size_t scale_numel = 1, - size_t scale_inv_numel = 1, + const int64_t *offsets = nullptr, const int64_t *first_dims = nullptr, + const int64_t *last_dims = nullptr, size_t num_tensors = 1, + size_t scale_numel = 1, size_t scale_inv_numel = 1, size_t amax_numel = 1) { VectorizedLoader loader(input, N); VectorizedLoader grad_loader(grad, N); @@ -346,7 +339,8 @@ __launch_bounds__(unary_kernel_threads) __global__ grad_loader.load(tid, N); #pragma unroll for (int i = 0; i < nvec; ++i) { - const size_t global_idx = (aligned ? (tid * nvec + i) : (tid * nvec + i - loader.alignment())); + const size_t global_idx = + (aligned ? (tid * nvec + i) : (tid * nvec + i - loader.alignment())); if (global_idx >= N) continue; size_t tensor_id = 0; @@ -495,10 +489,8 @@ void VectorizedUnaryKernelLauncher(const InputType *input, const fp32 *noop, Out const Param ¶ms, cudaStream_t stream, const int64_t *offsets = nullptr, const int64_t *first_dims = nullptr, - const int64_t *last_dims = nullptr, - size_t num_tensors = 1, - size_t scale_numel = 1, - size_t scale_inv_numel = 1, + const int64_t *last_dims = nullptr, size_t num_tensors = 1, + size_t scale_numel = 1, size_t scale_inv_numel = 1, size_t amax_numel = 1) { if (N != 0) { auto align = CheckAlignment(N, nvec, input, output); @@ -512,19 +504,19 @@ void VectorizedUnaryKernelLauncher(const InputType *input, const fp32 *noop, Out switch (align) { case Alignment::SAME_ALIGNED: unary_kernel<<>>( - input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements, - offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); + input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements, offsets, + first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); break; case Alignment::SAME_UNALIGNED: unary_kernel<<>>( - input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements, - offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); + input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements, offsets, + first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); break; case Alignment::DIFFERENT: { // If the pointers are aligned differently we cannot vectorize unary_kernel<1, true, fp32, Param, OP><<>>( - input, noop, output, scale, amax, scale_inv, params, N, N, - offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); + input, noop, output, scale, amax, scale_inv, params, N, N, offsets, first_dims, + last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); break; } } @@ -537,13 +529,10 @@ template <<>>( - grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements, - offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); + grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements, offsets, + first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); break; case Alignment::SAME_UNALIGNED: unary_grad_kernel<<>>( - grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements, - offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); + grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements, offsets, + first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); break; case Alignment::DIFFERENT: { // If the pointers are aligned differently we cannot vectorize unary_grad_kernel<1, true, fp32, Param, OP><<>>( - grad, input, output, scale, amax, scale_inv, params, N, N, - offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); + grad, input, output, scale, amax, scale_inv, params, N, N, offsets, first_dims, + last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); break; } } From d5fb0bf479f691644a58c9a596b34534491a1ec3 Mon Sep 17 00:00:00 2001 From: Abhishek Date: Mon, 8 Jun 2026 23:35:36 -0700 Subject: [PATCH 4/8] Removed dependency on max tensor size Signed-off-by: Abhishek --- docker_build_and_test.sh | 85 +++++ patch_swizzle.py | 260 ++++++++++++++ patch_swizzle_cpp.py | 43 +++ .../common/util/vectorized_pointwise.h | 318 ++++++++---------- 4 files changed, 528 insertions(+), 178 deletions(-) create mode 100755 docker_build_and_test.sh create mode 100644 patch_swizzle.py create mode 100644 patch_swizzle_cpp.py diff --git a/docker_build_and_test.sh b/docker_build_and_test.sh new file mode 100755 index 0000000000..d923fda192 --- /dev/null +++ b/docker_build_and_test.sh @@ -0,0 +1,85 @@ +#!/usr/bin/env bash +# Build TransformerEngine and C++ tests in Docker, then run tests. +# Usage: +# ./docker_build_and_test.sh # build, run all operator tests +# ./docker_build_and_test.sh --clean # clean + build + run all tests +# ./docker_build_and_test.sh --gtest_filter="*Swizzle*" +# ./docker_build_and_test.sh --clean --gtest_filter="OperatorTest/SwizzleTestSuite*" +# +# Lint only specific files (paths relative to repo root, space-separated): +# LINT_FILES="transformer_engine/common/swizzle/swizzle.cu transformer_engine/common/include/transformer_engine/swizzle.h" ./docker_build_and_test.sh + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +IMAGE="nvcr.io/nvidia/pytorch:25.02-py3" +MOUNT="${SCRIPT_DIR}:/workspace/TransformerEngine" + +DO_CLEAN="0" +TEST_ARGS=() + +while [[ $# -gt 0 ]]; do + case "$1" in + --clean) + DO_CLEAN="1" + shift + ;; + *) + TEST_ARGS+=("$1") + shift + ;; + esac +done + +docker run --gpus all -it --rm \ + -v "${MOUNT}" \ + -e DO_CLEAN="${DO_CLEAN}" \ + -e LINT_FILES="${LINT_FILES:-}" \ + "${IMAGE}" \ + bash -c ' + set -e + cd /workspace/TransformerEngine + + if [ "${DO_CLEAN}" = "1" ]; then + echo "=== Cleaning build artifacts ===" + rm -rf build/ tests/cpp/build/ *.so libtransformer_engine.so *.egg-info + find . -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null || true + find . -name "*.so" -type f -delete 2>/dev/null || true + pip uninstall -y transformer-engine 2>/dev/null || true + fi + + echo "=== Building TransformerEngine ===" + NVTE_CUDA_ARCHS="89" MAX_JOBS=4 pip install --no-build-isolation -v -e . + + echo "=== Building C++ tests ===" + cd tests/cpp + cmake -GNinja -Bbuild . + cmake --build build + + echo "=== Lint checks (C++ and Python) ===" + cd /workspace/TransformerEngine + if [ -n "${LINT_FILES}" ]; then + pip3 install cpplint==1.6.0 pylint==3.3.1 -q + for f in ${LINT_FILES}; do + [ -f "$f" ] || continue + case "$f" in + *.cu|*.cuh|*.c|*.cpp|*.h|*.hpp|*.cc|*.cxx) echo "cpplint $f"; python3 -m cpplint --root=transformer_engine/common/include "$f" ;; + *.py) echo "pylint $f"; python3 -m pylint "$f" ;; + *) echo "skip (unknown type) $f" ;; + esac + done + else + TE_PATH=/workspace/TransformerEngine bash qa/L0_pytorch_lint/test.sh + TE_PATH=/workspace/TransformerEngine bash qa/L0_jax_lint/test.sh + fi + + echo "=== L0_* tests ===" + # for d in qa/L0_*/; do + # echo "--- $d ---" + # (cd /workspace/TransformerEngine && TE_PATH=/workspace/TransformerEngine bash "$d/test.sh") + # done + + echo "=== Running operator tests ===" + cd /workspace/TransformerEngine/tests/cpp + ./build/operator/test_operator "$@" + ' _ "${TEST_ARGS[@]}" \ No newline at end of file diff --git a/patch_swizzle.py b/patch_swizzle.py new file mode 100644 index 0000000000..018cf83dc6 --- /dev/null +++ b/patch_swizzle.py @@ -0,0 +1,260 @@ +import re + +with open("transformer_engine/common/swizzle/swizzle.cu", "r") as f: + content = f.read() + +# 1. Insert kernels before swizzle_grouped_scaling_factors +kernels_code = """ +template +__global__ void __launch_bounds__(TB_DIM* TB_DIM) + grouped_swizzle_scaling_variable_shape_kernel( + const void* input, + void* output, + const int64_t* m_array, + const int64_t* k_array, + const int* block_offsets, + const size_t* scale_offsets, + int* global_counter, + int num_tensors, + bool rowwise) { + + __shared__ int linear_block_id; + if (threadIdx.x == 0 && threadIdx.y == 0) { + linear_block_id = atomicAdd(global_counter, 1); + } + __syncthreads(); + + int tensor_id = -1; + int low = 0; + int high = num_tensors - 1; + while (low <= high) { + int mid = low + (high - low) / 2; + if (linear_block_id >= block_offsets[mid] && linear_block_id < block_offsets[mid + 1]) { + tensor_id = mid; + break; + } else if (linear_block_id < block_offsets[mid]) { + high = mid - 1; + } else { + low = mid + 1; + } + } + + if (tensor_id == -1) return; + + int local_block_id = linear_block_id - block_offsets[tensor_id]; + + size_t M = rowwise ? m_array[tensor_id] : k_array[tensor_id]; + size_t K = rowwise ? k_array[tensor_id] : m_array[tensor_id]; + + size_t padded_m = round_up_to_multiple(M, 128); + size_t padded_k = round_up_to_multiple(DIVUP(K, static_cast(MXFP8_BLOCK_SIZE)), 4); + + int num_tiles_m = padded_m / SF_TILE_DIM_M; + int num_tiles_k = padded_k / SF_TILE_DIM_K; + + int vec_load_size = (rowwise ? ((num_tiles_k - 1) % 4 + 1) : ((num_tiles_m - 1) % 4 + 1)); + if (vec_load_size == 3) vec_load_size = 1; + int n_tiles_in_tb = TB_DIM * vec_load_size; + + int grid_dim_x = rowwise ? DIVUP(num_tiles_k, n_tiles_in_tb) : DIVUP(num_tiles_k, TB_DIM); + int grid_dim_y = rowwise ? num_tiles_m : DIVUP(num_tiles_m, vec_load_size); + + int block_x = local_block_id % grid_dim_x; + int block_y = local_block_id / grid_dim_x; + + const uint8_t* input_base = reinterpret_cast(input) + scale_offsets[tensor_id]; + uint8_t* output_base = reinterpret_cast(output) + scale_offsets[tensor_id]; + + int original_M = static_cast(M); + int original_K = static_cast(DIVUP(K, static_cast(MXFP8_BLOCK_SIZE))); + + if (rowwise) { + if (vec_load_size == 4) { + swizzle_row_scaling_kernel_impl( + input_base, output_base, padded_m, padded_k, original_M, original_K, + block_x, block_y, grid_dim_x, grid_dim_y); + } else if (vec_load_size == 2) { + swizzle_row_scaling_kernel_impl( + input_base, output_base, padded_m, padded_k, original_M, original_K, + block_x, block_y, grid_dim_x, grid_dim_y); + } else { + swizzle_row_scaling_kernel_impl( + input_base, output_base, padded_m, padded_k, original_M, original_K, + block_x, block_y, grid_dim_x, grid_dim_y); + } + } else { + if (vec_load_size == 4) { + swizzle_col_scaling_kernel_impl( + input_base, output_base, padded_m, padded_k, original_M, original_K, + block_x, block_y, grid_dim_x, grid_dim_y); + } else if (vec_load_size == 2) { + swizzle_col_scaling_kernel_impl( + input_base, output_base, padded_m, padded_k, original_M, original_K, + block_x, block_y, grid_dim_x, grid_dim_y); + } else { + swizzle_col_scaling_kernel_impl( + input_base, output_base, padded_m, padded_k, original_M, original_K, + block_x, block_y, grid_dim_x, grid_dim_y); + } + } +} + +__global__ void compute_grouped_swizzle_setup( + const int64_t* m_array, + const int64_t* k_array, + int* block_offsets, + size_t* scale_offsets, + int* total_blocks, + int* global_counter, + size_t num_tensors, + bool rowwise, + size_t scale_elem_size) { + + if (blockIdx.x == 0 && threadIdx.x == 0) { + int current_block_offset = 0; + size_t current_scale_offset = 0; + + for (size_t i = 0; i < num_tensors; ++i) { + block_offsets[i] = current_block_offset; + scale_offsets[i] = current_scale_offset; + + size_t m = rowwise ? m_array[i] : k_array[i]; + size_t k = rowwise ? k_array[i] : m_array[i]; + + size_t padded_m = round_up_to_multiple(m, 128); + size_t padded_k = round_up_to_multiple(DIVUP(k, static_cast(MXFP8_BLOCK_SIZE)), 4); + + int num_tiles_m = padded_m / 128; + int num_tiles_k = padded_k / 4; + + int vec_load_size = (rowwise ? ((num_tiles_k - 1) % 4 + 1) : ((num_tiles_m - 1) % 4 + 1)); + if (vec_load_size == 3) vec_load_size = 1; + + int blocks_m = num_tiles_m; + int blocks_k = DIVUP(num_tiles_k, TB_DIM * vec_load_size); + if (!rowwise) { + blocks_m = DIVUP(num_tiles_m, vec_load_size); + blocks_k = DIVUP(num_tiles_k, TB_DIM); + } + + current_block_offset += blocks_m * blocks_k; + current_scale_offset += padded_m * padded_k * scale_elem_size; + } + + block_offsets[num_tensors] = current_block_offset; + scale_offsets[num_tensors] = current_scale_offset; + *total_blocks = current_block_offset; + *global_counter = 0; + } +} + +namespace transformer_engine { +""" +content = content.replace("namespace transformer_engine {\n\nvoid swizzle_grouped_scaling_factors", kernels_code + "\nvoid swizzle_grouped_scaling_factors") + +# 2. Modify swizzle_grouped_scaling_factors +old_func = """void swizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor* output, + cudaStream_t stream) {""" +new_func = """void swizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor* output, + void* workspace, cudaStream_t stream) {""" +content = content.replace(old_func, new_func) + +# 3. Add variable shape logic +old_logic = """ // Only support uniform shapes for graph-safe grouped swizzle + NVTE_CHECK(input->all_same_shape(), "Grouped swizzle requires uniform tensor shapes."); + NVTE_CHECK(input->all_same_last_dim() && input->all_same_first_dim(), + "Grouped swizzle requires uniform tensor shapes.");""" +new_logic = """ const int64_t* m_array = reinterpret_cast(input->first_dims.data_ptr); + const int64_t* k_array = reinterpret_cast(input->last_dims.data_ptr); + const bool is_variable_shape = (m_array != nullptr && k_array != nullptr); + + if (!is_variable_shape) { + // Fallback to uniform shape implementation + NVTE_CHECK(input->all_same_shape(), "Grouped swizzle requires uniform tensor shapes."); + NVTE_CHECK(input->all_same_last_dim() && input->all_same_first_dim(), + "Grouped swizzle requires uniform tensor shapes.");""" +content = content.replace(old_logic, new_logic) + +# Close the if block and add the else block for variable shape +old_launch_end = """ if (has_rowwise_scale_inv) { + launch_grouped_swizzle(true); + } + if (has_columnwise_scale_inv) { + launch_grouped_swizzle(false); + }""" +new_launch_end = """ if (has_rowwise_scale_inv) { + launch_grouped_swizzle(true); + } + if (has_columnwise_scale_inv) { + launch_grouped_swizzle(false); + } + } else { + // Variable shape implementation using Device-Side Block Scheduler + size_t num_tensors = input->num_tensors; + NVTE_CHECK(workspace != nullptr, "Workspace must be provided for variable shape grouped swizzle."); + + int* d_block_offsets = reinterpret_cast(workspace); + size_t* d_scale_offsets = reinterpret_cast(d_block_offsets + num_tensors + 2); + int* d_global_counter = reinterpret_cast(d_scale_offsets + num_tensors + 1); + int* d_total_blocks = d_global_counter + 1; + + constexpr int SF_TILE_DIM_M = 128; + constexpr int SF_TILE_DIM_K = 4; + const dim3 block_size(TB_DIM, TB_DIM); + const int max_slm_size = TB_DIM * 4 * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); + + auto launch_grouped_swizzle_variable = [&](bool rowwise) { + const size_t scale_elem_size = rowwise ? typeToSize(input->scale_inv.dtype) + : typeToSize(input->columnwise_scale_inv.dtype); + + compute_grouped_swizzle_setup<<<1, 1, 0, stream>>>( + m_array, k_array, d_block_offsets, d_scale_offsets, d_total_blocks, + d_global_counter, num_tensors, rowwise, scale_elem_size); + + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + grouped_swizzle_scaling_variable_shape_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, max_slm_size)); + + int persistent_blocks = 108 * 8; + dim3 num_blocks(persistent_blocks); + + const void* input_ptr = rowwise ? input->scale_inv.dptr : input->columnwise_scale_inv.dptr; + void* output_ptr = rowwise ? output->scale_inv.dptr : output->columnwise_scale_inv.dptr; + + grouped_swizzle_scaling_variable_shape_kernel + <<>>( + input_ptr, output_ptr, m_array, k_array, d_block_offsets, + d_scale_offsets, d_global_counter, num_tensors, rowwise); + + NVTE_CHECK_CUDA(cudaGetLastError()); + }; + + if (has_rowwise_scale_inv) { + launch_grouped_swizzle_variable(true); + } + if (has_columnwise_scale_inv) { + launch_grouped_swizzle_variable(false); + } + }""" +content = content.replace(old_launch_end, new_launch_end) + +# 4. Modify nvte_swizzle_grouped_scaling_factors wrapper +old_wrapper = """void nvte_swizzle_grouped_scaling_factors(const NVTEGroupedTensor input, NVTEGroupedTensor output, + cudaStream_t stream) { + NVTE_API_CALL(nvte_swizzle_grouped_scaling_factors); + using namespace transformer_engine; + swizzle_grouped_scaling_factors(convertNVTEGroupedTensorCheck(input), + convertNVTEGroupedTensorCheck(output), stream); +}""" +new_wrapper = """void nvte_swizzle_grouped_scaling_factors(const NVTEGroupedTensor input, NVTEGroupedTensor output, + void* workspace, cudaStream_t stream) { + NVTE_API_CALL(nvte_swizzle_grouped_scaling_factors); + using namespace transformer_engine; + swizzle_grouped_scaling_factors(convertNVTEGroupedTensorCheck(input), + convertNVTEGroupedTensorCheck(output), workspace, stream); +}""" +content = content.replace(old_wrapper, new_wrapper) + +with open("transformer_engine/common/swizzle/swizzle.cu", "w") as f: + f.write(content) + diff --git a/patch_swizzle_cpp.py b/patch_swizzle_cpp.py new file mode 100644 index 0000000000..53ca6f3a60 --- /dev/null +++ b/patch_swizzle_cpp.py @@ -0,0 +1,43 @@ +import re + +with open("transformer_engine/pytorch/csrc/extensions/swizzle.cpp", "r") as f: + content = f.read() + +old_code = """ swizzle_output.set_with_gemm_swizzled_scales(true); + NVTE_SCOPED_GIL_RELEASE({ + nvte_swizzle_grouped_scaling_factors(swizzle_input.data(), swizzle_output.data(), + at::cuda::getCurrentCUDAStream()); + });""" + +new_code = """ swizzle_output.set_with_gemm_swizzled_scales(true); + + size_t num_tensors = input.num_tensors(); + size_t workspace_size = (num_tensors + 2) * sizeof(int) + (num_tensors + 1) * sizeof(size_t); + workspace_size = roundup(workspace_size, 256); + auto workspace = allocateSpace(std::vector{workspace_size}, transformer_engine::DType::kByte, false); + + NVTE_SCOPED_GIL_RELEASE({ + nvte_swizzle_grouped_scaling_factors(swizzle_input.data(), swizzle_output.data(), + getDataPtr(workspace), + at::cuda::getCurrentCUDAStream()); + });""" + +content = content.replace(old_code, new_code) + +# Check if first_dims error check exists and remove it +old_check = """ const auto first_dims = input.get_first_dims(); + const auto last_dims = input.get_last_dims(); + if (first_dims.data_ptr != nullptr || last_dims.data_ptr != nullptr) { + NVTE_ERROR( + "Grouped GEMM swizzle requires uniform shapes for now (first_dims/last_dims must be " + "absent)."); + }""" + +new_check = """ const auto first_dims = input.get_first_dims(); + const auto last_dims = input.get_last_dims();""" + +content = content.replace(old_check, new_check) + +with open("transformer_engine/pytorch/csrc/extensions/swizzle.cpp", "w") as f: + f.write(content) + diff --git a/transformer_engine/common/util/vectorized_pointwise.h b/transformer_engine/common/util/vectorized_pointwise.h index 3fa7f4c8ee..af7097e660 100644 --- a/transformer_engine/common/util/vectorized_pointwise.h +++ b/transformer_engine/common/util/vectorized_pointwise.h @@ -198,103 +198,84 @@ __launch_bounds__(unary_kernel_threads) __global__ size_t scale_numel = 1, size_t scale_inv_numel = 1, size_t amax_numel = 1) { if (noop != nullptr && noop[0] == 1.0f) return; - VectorizedLoader loader(input, N); - VectorizedStorer storer(output, N); + size_t start = 0; + size_t size = N; + size_t tensor_id = 0; + if (gridDim.y > 1) { + tensor_id = blockIdx.y; + if (offsets != nullptr) { + start = offsets[tensor_id]; + size = first_dims[tensor_id] * last_dims[tensor_id]; + } else if (num_tensors > 1) { + size = N / num_tensors; + start = tensor_id * size; + } + } + + VectorizedLoader loader(input + start, size); + VectorizedStorer storer(output + start, size); ComputeType max = 0; ComputeType s = 1; const bool requires_amax = (amax != nullptr); if constexpr (is_fp8::value) { - if (scale != nullptr && offsets == nullptr) s = *scale; + if (scale != nullptr && offsets == nullptr && gridDim.y == 1) s = *scale; } const int warp_id = threadIdx.x / THREADS_PER_WARP; - float block_max[64] = {0.0f}; - - const size_t M = num_aligned_elements; + const size_t M = loader.num_aligned_elements(); for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) { - loader.load(tid, N); + loader.load(tid, size); #pragma unroll for (int i = 0; i < nvec; ++i) { - const size_t global_idx = - (aligned ? (tid * nvec + i) : (tid * nvec + i - loader.alignment())); - if (global_idx >= N) continue; - - size_t tensor_id = 0; - bool is_valid = true; - if (offsets != nullptr) { - tensor_id = find_tensor_id(offsets, num_tensors, global_idx); - size_t start = offsets[tensor_id]; - size_t size = first_dims[tensor_id] * last_dims[tensor_id]; - is_valid = (global_idx >= start && global_idx < start + size); - } else if (num_tensors > 1) { - size_t size = N / num_tensors; - tensor_id = global_idx / size; - if (tensor_id >= num_tensors) tensor_id = num_tensors - 1; - } + const size_t global_idx = (aligned ? (tid * nvec + i) : (tid * nvec + i - loader.alignment())); + if (global_idx >= size) continue; - if (is_valid) { - ComputeType val = static_cast(loader.separate()[i]); - if constexpr (is_fp8::value) { - if (scale_inv != nullptr) { - val = val * ((scale_inv_numel == num_tensors) ? scale_inv[tensor_id] : scale_inv[0]); - } - } - ComputeType temp = OP(val, p); - if (requires_amax) { - __builtin_assume(block_max[tensor_id] >= 0); - block_max[tensor_id] = fmaxf(fabsf(temp), block_max[tensor_id]); + ComputeType val = static_cast(loader.separate()[i]); + if constexpr (is_fp8::value) { + if (scale_inv != nullptr) { + val = val * ((scale_inv_numel == num_tensors) ? scale_inv[tensor_id] : scale_inv[0]); } - if constexpr (is_fp8::value) { - float current_scale = 1.0f; - if (scale != nullptr) { - current_scale = (scale_numel == num_tensors) ? scale[tensor_id] : scale[0]; - } - temp = temp * current_scale; - } - storer.separate()[i] = static_cast(temp); - - // Update scale-inverse for quantization if requested - bool is_start = false; - if (offsets != nullptr) { - is_start = (global_idx == offsets[tensor_id]); - } else if (num_tensors > 1) { - size_t size = N / num_tensors; - is_start = (global_idx == tensor_id * size); + } + ComputeType temp = OP(val, p); + if (requires_amax) { + __builtin_assume(max >= 0); + max = fmaxf(fabsf(temp), max); + } + if constexpr (is_fp8::value) { + float current_scale = 1.0f; + if (scale != nullptr) { + current_scale = (scale_numel == num_tensors) ? scale[tensor_id] : scale[0]; + } else if (gridDim.y == 1) { + current_scale = s; } - if (is_start && scale_inv != nullptr && !is_fp8::value) { - float current_scale = 1.0f; - if (scale != nullptr) { - current_scale = (scale_numel == num_tensors) ? scale[tensor_id] : scale[0]; - } - size_t scale_inv_idx = (scale_inv_numel == num_tensors) ? tensor_id : 0; - reciprocal(&scale_inv[scale_inv_idx], current_scale); + temp = temp * current_scale; + } + storer.separate()[i] = static_cast(temp); + + // Update scale-inverse for quantization if requested + const bool is_start = (global_idx == 0); + if (is_start && scale_inv != nullptr && !is_fp8::value) { + float current_scale = 1.0f; + if (scale != nullptr) { + current_scale = (scale_numel == num_tensors) ? scale[tensor_id] : scale[0]; + } else if (gridDim.y == 1) { + current_scale = s; } - } else { - storer.separate()[i] = OutputType(); + size_t scale_inv_idx = (scale_inv_numel == num_tensors) ? tensor_id : 0; + reciprocal(&scale_inv[scale_inv_idx], current_scale); } } - storer.store(tid, N); + storer.store(tid, size); } // Reduce amax over block if (requires_amax) { - if (offsets != nullptr || num_tensors > 1) { - for (size_t t = 0; t < num_tensors; ++t) { - float t_max = block_max[t]; - t_max = reduce_max(t_max, warp_id); - if (threadIdx.x == 0 && t_max > 0.0f) { - size_t amax_idx = (amax_numel == num_tensors) ? t : 0; - atomicMaxFloat(&amax[amax_idx], t_max); - } - } - } else { - max = block_max[0]; - max = reduce_max(max, warp_id); - if (threadIdx.x == 0) { - static_assert(std::is_same::value); - atomicMaxFloat(amax, max); - } + max = reduce_max(max, warp_id); + if (threadIdx.x == 0 && max > 0.0f) { + size_t amax_idx = (amax_numel == num_tensors) ? tensor_id : 0; + static_assert(std::is_same::value); + atomicMaxFloat(&amax[amax_idx], max); } } @@ -319,106 +300,87 @@ __launch_bounds__(unary_kernel_threads) __global__ const int64_t *last_dims = nullptr, size_t num_tensors = 1, size_t scale_numel = 1, size_t scale_inv_numel = 1, size_t amax_numel = 1) { - VectorizedLoader loader(input, N); - VectorizedLoader grad_loader(grad, N); - VectorizedStorer storer(output, N); + size_t start = 0; + size_t size = N; + size_t tensor_id = 0; + if (gridDim.y > 1) { + tensor_id = blockIdx.y; + if (offsets != nullptr) { + start = offsets[tensor_id]; + size = first_dims[tensor_id] * last_dims[tensor_id]; + } else if (num_tensors > 1) { + size = N / num_tensors; + start = tensor_id * size; + } + } + + VectorizedLoader loader(input + start, size); + VectorizedLoader grad_loader(grad + start, size); + VectorizedStorer storer(output + start, size); ComputeType max = 0; ComputeType s = 1; const bool requires_amax = (amax != nullptr); if constexpr (is_fp8::value) { - if (scale != nullptr && offsets == nullptr) s = *scale; + if (scale != nullptr && offsets == nullptr && gridDim.y == 1) s = *scale; } const int warp_id = threadIdx.x / THREADS_PER_WARP; - float block_max[64] = {0.0f}; - - const size_t M = num_aligned_elements; + const size_t M = loader.num_aligned_elements(); for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) { - loader.load(tid, N); - grad_loader.load(tid, N); + loader.load(tid, size); + grad_loader.load(tid, size); #pragma unroll for (int i = 0; i < nvec; ++i) { - const size_t global_idx = - (aligned ? (tid * nvec + i) : (tid * nvec + i - loader.alignment())); - if (global_idx >= N) continue; - - size_t tensor_id = 0; - bool is_valid = true; - if (offsets != nullptr) { - tensor_id = find_tensor_id(offsets, num_tensors, global_idx); - size_t start = offsets[tensor_id]; - size_t size = first_dims[tensor_id] * last_dims[tensor_id]; - is_valid = (global_idx >= start && global_idx < start + size); - } else if (num_tensors > 1) { - size_t size = N / num_tensors; - tensor_id = global_idx / size; - if (tensor_id >= num_tensors) tensor_id = num_tensors - 1; - } - - if (is_valid) { - ComputeType val = static_cast(loader.separate()[i]); - const ComputeType g = static_cast(grad_loader.separate()[i]); - if constexpr (is_fp8::value) { - if (scale_inv != nullptr) { - val = val * ((scale_inv_numel == num_tensors) ? scale_inv[tensor_id] : scale_inv[0]); - } - } - ComputeType temp = OP(val, p) * g; - if (requires_amax) { - __builtin_assume(block_max[tensor_id] >= 0); - block_max[tensor_id] = fmaxf(fabsf(temp), block_max[tensor_id]); + const size_t global_idx = (aligned ? (tid * nvec + i) : (tid * nvec + i - loader.alignment())); + if (global_idx >= size) continue; + + ComputeType val = static_cast(loader.separate()[i]); + const ComputeType g = static_cast(grad_loader.separate()[i]); + if constexpr (is_fp8::value) { + if (scale_inv != nullptr) { + val = val * ((scale_inv_numel == num_tensors) ? scale_inv[tensor_id] : scale_inv[0]); } - if constexpr (is_fp8::value) { - float current_scale = 1.0f; - if (scale != nullptr) { - current_scale = (scale_numel == num_tensors) ? scale[tensor_id] : scale[0]; - } - temp = temp * current_scale; - } - storer.separate()[i] = static_cast(temp); - - // Update scale-inverse for quantization if requested - bool is_start = false; - if (offsets != nullptr) { - is_start = (global_idx == offsets[tensor_id]); - } else if (num_tensors > 1) { - size_t size = N / num_tensors; - is_start = (global_idx == tensor_id * size); + } + ComputeType temp = OP(val, p) * g; + if (requires_amax) { + __builtin_assume(max >= 0); + max = fmaxf(fabsf(temp), max); + } + if constexpr (is_fp8::value) { + float current_scale = 1.0f; + if (scale != nullptr) { + current_scale = (scale_numel == num_tensors) ? scale[tensor_id] : scale[0]; + } else if (gridDim.y == 1) { + current_scale = s; } - if (is_start && scale_inv != nullptr && !is_fp8::value) { - float current_scale = 1.0f; - if (scale != nullptr) { - current_scale = (scale_numel == num_tensors) ? scale[tensor_id] : scale[0]; - } - size_t scale_inv_idx = (scale_inv_numel == num_tensors) ? tensor_id : 0; - reciprocal(&scale_inv[scale_inv_idx], current_scale); + temp = temp * current_scale; + } + storer.separate()[i] = static_cast(temp); + + // Update scale-inverse for quantization if requested + const bool is_start = (global_idx == 0); + if (is_start && scale_inv != nullptr && !is_fp8::value) { + float current_scale = 1.0f; + if (scale != nullptr) { + current_scale = (scale_numel == num_tensors) ? scale[tensor_id] : scale[0]; + } else if (gridDim.y == 1) { + current_scale = s; } - } else { - storer.separate()[i] = OutputType(); + size_t scale_inv_idx = (scale_inv_numel == num_tensors) ? tensor_id : 0; + reciprocal(&scale_inv[scale_inv_idx], current_scale); } } - storer.store(tid, N); + storer.store(tid, size); } // Reduce amax over block if (requires_amax) { - if (offsets != nullptr || num_tensors > 1) { - for (size_t t = 0; t < num_tensors; ++t) { - float t_max = block_max[t]; - t_max = reduce_max(t_max, warp_id); - if (threadIdx.x == 0 && t_max > 0.0f) { - size_t amax_idx = (amax_numel == num_tensors) ? t : 0; - atomicMaxFloat(&amax[amax_idx], t_max); - } - } - } else { - max = block_max[0]; - max = reduce_max(max, warp_id); - if (threadIdx.x == 0) { - static_assert(std::is_same::value); - atomicMaxFloat(amax, max); - } + max = reduce_max(max, warp_id); + if (threadIdx.x == 0 && max > 0.0f) { + size_t amax_idx = (amax_numel == num_tensors) ? tensor_id : 0; + static_assert(std::is_same::value); + atomicMaxFloat(&amax[amax_idx], max); } } @@ -497,26 +459,26 @@ void VectorizedUnaryKernelLauncher(const InputType *input, const fp32 *noop, Out size_t num_aligned_elements = get_num_aligned_elements(input, N, nvec, sizeof(InputType)); constexpr size_t threads = unary_kernel_threads; - size_t num_blocks = DIVUP(num_aligned_elements, threads); constexpr size_t max_blocks = 65535; - num_blocks = std::min(num_blocks, max_blocks); + const size_t num_blocks = std::min(DIVUP(num_aligned_elements, threads), max_blocks); + const dim3 grid(num_blocks, num_tensors); switch (align) { case Alignment::SAME_ALIGNED: - unary_kernel<<>>( - input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements, offsets, - first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); + unary_kernel<<>>( + input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements, + offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); break; case Alignment::SAME_UNALIGNED: - unary_kernel<<>>( - input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements, offsets, - first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); + unary_kernel<<>>( + input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements, + offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); break; case Alignment::DIFFERENT: { // If the pointers are aligned differently we cannot vectorize - unary_kernel<1, true, fp32, Param, OP><<>>( - input, noop, output, scale, amax, scale_inv, params, N, N, offsets, first_dims, - last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); + unary_kernel<1, true, fp32, Param, OP><<>>( + input, noop, output, scale, amax, scale_inv, params, N, N, + offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); break; } } @@ -539,26 +501,26 @@ void VectorizedUnaryGradKernelLauncher(const InputTypeGrad *grad, const InputTyp size_t num_aligned_elements = get_num_aligned_elements(input, N, nvec, sizeof(InputType)); constexpr size_t threads = unary_kernel_threads; - size_t num_blocks = DIVUP(num_aligned_elements, threads); constexpr size_t max_blocks = 65535; - num_blocks = std::min(num_blocks, max_blocks); + const size_t num_blocks = std::min(DIVUP(num_aligned_elements, threads), max_blocks); + const dim3 grid(num_blocks, num_tensors); switch (align) { case Alignment::SAME_ALIGNED: - unary_grad_kernel<<>>( - grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements, offsets, - first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); + unary_grad_kernel<<>>( + grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements, + offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); break; case Alignment::SAME_UNALIGNED: - unary_grad_kernel<<>>( - grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements, offsets, - first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); + unary_grad_kernel<<>>( + grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements, + offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); break; case Alignment::DIFFERENT: { // If the pointers are aligned differently we cannot vectorize - unary_grad_kernel<1, true, fp32, Param, OP><<>>( - grad, input, output, scale, amax, scale_inv, params, N, N, offsets, first_dims, - last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); + unary_grad_kernel<1, true, fp32, Param, OP><<>>( + grad, input, output, scale, amax, scale_inv, params, N, N, + offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); break; } } From 2968700047f99eeeb2dfc87c272d789e15b99690 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 9 Jun 2026 06:40:58 +0000 Subject: [PATCH 5/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docker_build_and_test.sh | 2 +- patch_swizzle.py | 52 ++++++++++--------- patch_swizzle_cpp.py | 3 +- .../common/util/vectorized_pointwise.h | 30 ++++++----- 4 files changed, 45 insertions(+), 42 deletions(-) diff --git a/docker_build_and_test.sh b/docker_build_and_test.sh index d923fda192..c1ba3354a0 100755 --- a/docker_build_and_test.sh +++ b/docker_build_and_test.sh @@ -82,4 +82,4 @@ docker run --gpus all -it --rm \ echo "=== Running operator tests ===" cd /workspace/TransformerEngine/tests/cpp ./build/operator/test_operator "$@" - ' _ "${TEST_ARGS[@]}" \ No newline at end of file + ' _ "${TEST_ARGS[@]}" diff --git a/patch_swizzle.py b/patch_swizzle.py index 018cf83dc6..4faec4c7d0 100644 --- a/patch_swizzle.py +++ b/patch_swizzle.py @@ -8,8 +8,8 @@ template __global__ void __launch_bounds__(TB_DIM* TB_DIM) grouped_swizzle_scaling_variable_shape_kernel( - const void* input, - void* output, + const void* input, + void* output, const int64_t* m_array, const int64_t* k_array, const int* block_offsets, @@ -42,23 +42,23 @@ if (tensor_id == -1) return; int local_block_id = linear_block_id - block_offsets[tensor_id]; - + size_t M = rowwise ? m_array[tensor_id] : k_array[tensor_id]; size_t K = rowwise ? k_array[tensor_id] : m_array[tensor_id]; - + size_t padded_m = round_up_to_multiple(M, 128); size_t padded_k = round_up_to_multiple(DIVUP(K, static_cast(MXFP8_BLOCK_SIZE)), 4); - + int num_tiles_m = padded_m / SF_TILE_DIM_M; int num_tiles_k = padded_k / SF_TILE_DIM_K; - + int vec_load_size = (rowwise ? ((num_tiles_k - 1) % 4 + 1) : ((num_tiles_m - 1) % 4 + 1)); if (vec_load_size == 3) vec_load_size = 1; int n_tiles_in_tb = TB_DIM * vec_load_size; int grid_dim_x = rowwise ? DIVUP(num_tiles_k, n_tiles_in_tb) : DIVUP(num_tiles_k, TB_DIM); int grid_dim_y = rowwise ? num_tiles_m : DIVUP(num_tiles_m, vec_load_size); - + int block_x = local_block_id % grid_dim_x; int block_y = local_block_id / grid_dim_x; @@ -71,29 +71,29 @@ if (rowwise) { if (vec_load_size == 4) { swizzle_row_scaling_kernel_impl( - input_base, output_base, padded_m, padded_k, original_M, original_K, + input_base, output_base, padded_m, padded_k, original_M, original_K, block_x, block_y, grid_dim_x, grid_dim_y); } else if (vec_load_size == 2) { swizzle_row_scaling_kernel_impl( - input_base, output_base, padded_m, padded_k, original_M, original_K, + input_base, output_base, padded_m, padded_k, original_M, original_K, block_x, block_y, grid_dim_x, grid_dim_y); } else { swizzle_row_scaling_kernel_impl( - input_base, output_base, padded_m, padded_k, original_M, original_K, + input_base, output_base, padded_m, padded_k, original_M, original_K, block_x, block_y, grid_dim_x, grid_dim_y); } } else { if (vec_load_size == 4) { swizzle_col_scaling_kernel_impl( - input_base, output_base, padded_m, padded_k, original_M, original_K, + input_base, output_base, padded_m, padded_k, original_M, original_K, block_x, block_y, grid_dim_x, grid_dim_y); } else if (vec_load_size == 2) { swizzle_col_scaling_kernel_impl( - input_base, output_base, padded_m, padded_k, original_M, original_K, + input_base, output_base, padded_m, padded_k, original_M, original_K, block_x, block_y, grid_dim_x, grid_dim_y); } else { swizzle_col_scaling_kernel_impl( - input_base, output_base, padded_m, padded_k, original_M, original_K, + input_base, output_base, padded_m, padded_k, original_M, original_K, block_x, block_y, grid_dim_x, grid_dim_y); } } @@ -113,34 +113,34 @@ if (blockIdx.x == 0 && threadIdx.x == 0) { int current_block_offset = 0; size_t current_scale_offset = 0; - + for (size_t i = 0; i < num_tensors; ++i) { block_offsets[i] = current_block_offset; scale_offsets[i] = current_scale_offset; - + size_t m = rowwise ? m_array[i] : k_array[i]; size_t k = rowwise ? k_array[i] : m_array[i]; - + size_t padded_m = round_up_to_multiple(m, 128); size_t padded_k = round_up_to_multiple(DIVUP(k, static_cast(MXFP8_BLOCK_SIZE)), 4); - + int num_tiles_m = padded_m / 128; int num_tiles_k = padded_k / 4; - + int vec_load_size = (rowwise ? ((num_tiles_k - 1) % 4 + 1) : ((num_tiles_m - 1) % 4 + 1)); if (vec_load_size == 3) vec_load_size = 1; - + int blocks_m = num_tiles_m; int blocks_k = DIVUP(num_tiles_k, TB_DIM * vec_load_size); if (!rowwise) { blocks_m = DIVUP(num_tiles_m, vec_load_size); blocks_k = DIVUP(num_tiles_k, TB_DIM); } - + current_block_offset += blocks_m * blocks_k; current_scale_offset += padded_m * padded_k * scale_elem_size; } - + block_offsets[num_tensors] = current_block_offset; scale_offsets[num_tensors] = current_scale_offset; *total_blocks = current_block_offset; @@ -150,7 +150,10 @@ namespace transformer_engine { """ -content = content.replace("namespace transformer_engine {\n\nvoid swizzle_grouped_scaling_factors", kernels_code + "\nvoid swizzle_grouped_scaling_factors") +content = content.replace( + "namespace transformer_engine {\n\nvoid swizzle_grouped_scaling_factors", + kernels_code + "\nvoid swizzle_grouped_scaling_factors", +) # 2. Modify swizzle_grouped_scaling_factors old_func = """void swizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor* output, @@ -206,7 +209,7 @@ auto launch_grouped_swizzle_variable = [&](bool rowwise) { const size_t scale_elem_size = rowwise ? typeToSize(input->scale_inv.dtype) : typeToSize(input->columnwise_scale_inv.dtype); - + compute_grouped_swizzle_setup<<<1, 1, 0, stream>>>( m_array, k_array, d_block_offsets, d_scale_offsets, d_total_blocks, d_global_counter, num_tensors, rowwise, scale_elem_size); @@ -215,7 +218,7 @@ grouped_swizzle_scaling_variable_shape_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_slm_size)); - int persistent_blocks = 108 * 8; + int persistent_blocks = 108 * 8; dim3 num_blocks(persistent_blocks); const void* input_ptr = rowwise ? input->scale_inv.dptr : input->columnwise_scale_inv.dptr; @@ -257,4 +260,3 @@ with open("transformer_engine/common/swizzle/swizzle.cu", "w") as f: f.write(content) - diff --git a/patch_swizzle_cpp.py b/patch_swizzle_cpp.py index 53ca6f3a60..7eef291fa6 100644 --- a/patch_swizzle_cpp.py +++ b/patch_swizzle_cpp.py @@ -10,7 +10,7 @@ });""" new_code = """ swizzle_output.set_with_gemm_swizzled_scales(true); - + size_t num_tensors = input.num_tensors(); size_t workspace_size = (num_tensors + 2) * sizeof(int) + (num_tensors + 1) * sizeof(size_t); workspace_size = roundup(workspace_size, 256); @@ -40,4 +40,3 @@ with open("transformer_engine/pytorch/csrc/extensions/swizzle.cpp", "w") as f: f.write(content) - diff --git a/transformer_engine/common/util/vectorized_pointwise.h b/transformer_engine/common/util/vectorized_pointwise.h index af7097e660..fc2ef8d5b9 100644 --- a/transformer_engine/common/util/vectorized_pointwise.h +++ b/transformer_engine/common/util/vectorized_pointwise.h @@ -228,7 +228,8 @@ __launch_bounds__(unary_kernel_threads) __global__ loader.load(tid, size); #pragma unroll for (int i = 0; i < nvec; ++i) { - const size_t global_idx = (aligned ? (tid * nvec + i) : (tid * nvec + i - loader.alignment())); + const size_t global_idx = + (aligned ? (tid * nvec + i) : (tid * nvec + i - loader.alignment())); if (global_idx >= size) continue; ComputeType val = static_cast(loader.separate()[i]); @@ -332,7 +333,8 @@ __launch_bounds__(unary_kernel_threads) __global__ grad_loader.load(tid, size); #pragma unroll for (int i = 0; i < nvec; ++i) { - const size_t global_idx = (aligned ? (tid * nvec + i) : (tid * nvec + i - loader.alignment())); + const size_t global_idx = + (aligned ? (tid * nvec + i) : (tid * nvec + i - loader.alignment())); if (global_idx >= size) continue; ComputeType val = static_cast(loader.separate()[i]); @@ -466,19 +468,19 @@ void VectorizedUnaryKernelLauncher(const InputType *input, const fp32 *noop, Out switch (align) { case Alignment::SAME_ALIGNED: unary_kernel<<>>( - input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements, - offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); + input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements, offsets, + first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); break; case Alignment::SAME_UNALIGNED: unary_kernel<<>>( - input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements, - offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); + input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements, offsets, + first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); break; case Alignment::DIFFERENT: { // If the pointers are aligned differently we cannot vectorize unary_kernel<1, true, fp32, Param, OP><<>>( - input, noop, output, scale, amax, scale_inv, params, N, N, - offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); + input, noop, output, scale, amax, scale_inv, params, N, N, offsets, first_dims, + last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); break; } } @@ -508,19 +510,19 @@ void VectorizedUnaryGradKernelLauncher(const InputTypeGrad *grad, const InputTyp switch (align) { case Alignment::SAME_ALIGNED: unary_grad_kernel<<>>( - grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements, - offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); + grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements, offsets, + first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); break; case Alignment::SAME_UNALIGNED: unary_grad_kernel<<>>( - grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements, - offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); + grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements, offsets, + first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); break; case Alignment::DIFFERENT: { // If the pointers are aligned differently we cannot vectorize unary_grad_kernel<1, true, fp32, Param, OP><<>>( - grad, input, output, scale, amax, scale_inv, params, N, N, - offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); + grad, input, output, scale, amax, scale_inv, params, N, N, offsets, first_dims, + last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); break; } } From f0f46288e220d7ea5817a62d4f8beedfb3dad279 Mon Sep 17 00:00:00 2001 From: Abhishek Date: Mon, 8 Jun 2026 23:45:13 -0700 Subject: [PATCH 6/8] Remove patches Signed-off-by: Abhishek --- docker_build_and_test.sh | 85 ------ patch_swizzle.py | 262 ---------------- patch_swizzle_cpp.py | 42 --- .../common/util/vectorized_pointwise.h | 286 ++++++++++-------- 4 files changed, 161 insertions(+), 514 deletions(-) delete mode 100755 docker_build_and_test.sh delete mode 100644 patch_swizzle.py delete mode 100644 patch_swizzle_cpp.py diff --git a/docker_build_and_test.sh b/docker_build_and_test.sh deleted file mode 100755 index c1ba3354a0..0000000000 --- a/docker_build_and_test.sh +++ /dev/null @@ -1,85 +0,0 @@ -#!/usr/bin/env bash -# Build TransformerEngine and C++ tests in Docker, then run tests. -# Usage: -# ./docker_build_and_test.sh # build, run all operator tests -# ./docker_build_and_test.sh --clean # clean + build + run all tests -# ./docker_build_and_test.sh --gtest_filter="*Swizzle*" -# ./docker_build_and_test.sh --clean --gtest_filter="OperatorTest/SwizzleTestSuite*" -# -# Lint only specific files (paths relative to repo root, space-separated): -# LINT_FILES="transformer_engine/common/swizzle/swizzle.cu transformer_engine/common/include/transformer_engine/swizzle.h" ./docker_build_and_test.sh - -set -e - -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -IMAGE="nvcr.io/nvidia/pytorch:25.02-py3" -MOUNT="${SCRIPT_DIR}:/workspace/TransformerEngine" - -DO_CLEAN="0" -TEST_ARGS=() - -while [[ $# -gt 0 ]]; do - case "$1" in - --clean) - DO_CLEAN="1" - shift - ;; - *) - TEST_ARGS+=("$1") - shift - ;; - esac -done - -docker run --gpus all -it --rm \ - -v "${MOUNT}" \ - -e DO_CLEAN="${DO_CLEAN}" \ - -e LINT_FILES="${LINT_FILES:-}" \ - "${IMAGE}" \ - bash -c ' - set -e - cd /workspace/TransformerEngine - - if [ "${DO_CLEAN}" = "1" ]; then - echo "=== Cleaning build artifacts ===" - rm -rf build/ tests/cpp/build/ *.so libtransformer_engine.so *.egg-info - find . -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null || true - find . -name "*.so" -type f -delete 2>/dev/null || true - pip uninstall -y transformer-engine 2>/dev/null || true - fi - - echo "=== Building TransformerEngine ===" - NVTE_CUDA_ARCHS="89" MAX_JOBS=4 pip install --no-build-isolation -v -e . - - echo "=== Building C++ tests ===" - cd tests/cpp - cmake -GNinja -Bbuild . - cmake --build build - - echo "=== Lint checks (C++ and Python) ===" - cd /workspace/TransformerEngine - if [ -n "${LINT_FILES}" ]; then - pip3 install cpplint==1.6.0 pylint==3.3.1 -q - for f in ${LINT_FILES}; do - [ -f "$f" ] || continue - case "$f" in - *.cu|*.cuh|*.c|*.cpp|*.h|*.hpp|*.cc|*.cxx) echo "cpplint $f"; python3 -m cpplint --root=transformer_engine/common/include "$f" ;; - *.py) echo "pylint $f"; python3 -m pylint "$f" ;; - *) echo "skip (unknown type) $f" ;; - esac - done - else - TE_PATH=/workspace/TransformerEngine bash qa/L0_pytorch_lint/test.sh - TE_PATH=/workspace/TransformerEngine bash qa/L0_jax_lint/test.sh - fi - - echo "=== L0_* tests ===" - # for d in qa/L0_*/; do - # echo "--- $d ---" - # (cd /workspace/TransformerEngine && TE_PATH=/workspace/TransformerEngine bash "$d/test.sh") - # done - - echo "=== Running operator tests ===" - cd /workspace/TransformerEngine/tests/cpp - ./build/operator/test_operator "$@" - ' _ "${TEST_ARGS[@]}" diff --git a/patch_swizzle.py b/patch_swizzle.py deleted file mode 100644 index 4faec4c7d0..0000000000 --- a/patch_swizzle.py +++ /dev/null @@ -1,262 +0,0 @@ -import re - -with open("transformer_engine/common/swizzle/swizzle.cu", "r") as f: - content = f.read() - -# 1. Insert kernels before swizzle_grouped_scaling_factors -kernels_code = """ -template -__global__ void __launch_bounds__(TB_DIM* TB_DIM) - grouped_swizzle_scaling_variable_shape_kernel( - const void* input, - void* output, - const int64_t* m_array, - const int64_t* k_array, - const int* block_offsets, - const size_t* scale_offsets, - int* global_counter, - int num_tensors, - bool rowwise) { - - __shared__ int linear_block_id; - if (threadIdx.x == 0 && threadIdx.y == 0) { - linear_block_id = atomicAdd(global_counter, 1); - } - __syncthreads(); - - int tensor_id = -1; - int low = 0; - int high = num_tensors - 1; - while (low <= high) { - int mid = low + (high - low) / 2; - if (linear_block_id >= block_offsets[mid] && linear_block_id < block_offsets[mid + 1]) { - tensor_id = mid; - break; - } else if (linear_block_id < block_offsets[mid]) { - high = mid - 1; - } else { - low = mid + 1; - } - } - - if (tensor_id == -1) return; - - int local_block_id = linear_block_id - block_offsets[tensor_id]; - - size_t M = rowwise ? m_array[tensor_id] : k_array[tensor_id]; - size_t K = rowwise ? k_array[tensor_id] : m_array[tensor_id]; - - size_t padded_m = round_up_to_multiple(M, 128); - size_t padded_k = round_up_to_multiple(DIVUP(K, static_cast(MXFP8_BLOCK_SIZE)), 4); - - int num_tiles_m = padded_m / SF_TILE_DIM_M; - int num_tiles_k = padded_k / SF_TILE_DIM_K; - - int vec_load_size = (rowwise ? ((num_tiles_k - 1) % 4 + 1) : ((num_tiles_m - 1) % 4 + 1)); - if (vec_load_size == 3) vec_load_size = 1; - int n_tiles_in_tb = TB_DIM * vec_load_size; - - int grid_dim_x = rowwise ? DIVUP(num_tiles_k, n_tiles_in_tb) : DIVUP(num_tiles_k, TB_DIM); - int grid_dim_y = rowwise ? num_tiles_m : DIVUP(num_tiles_m, vec_load_size); - - int block_x = local_block_id % grid_dim_x; - int block_y = local_block_id / grid_dim_x; - - const uint8_t* input_base = reinterpret_cast(input) + scale_offsets[tensor_id]; - uint8_t* output_base = reinterpret_cast(output) + scale_offsets[tensor_id]; - - int original_M = static_cast(M); - int original_K = static_cast(DIVUP(K, static_cast(MXFP8_BLOCK_SIZE))); - - if (rowwise) { - if (vec_load_size == 4) { - swizzle_row_scaling_kernel_impl( - input_base, output_base, padded_m, padded_k, original_M, original_K, - block_x, block_y, grid_dim_x, grid_dim_y); - } else if (vec_load_size == 2) { - swizzle_row_scaling_kernel_impl( - input_base, output_base, padded_m, padded_k, original_M, original_K, - block_x, block_y, grid_dim_x, grid_dim_y); - } else { - swizzle_row_scaling_kernel_impl( - input_base, output_base, padded_m, padded_k, original_M, original_K, - block_x, block_y, grid_dim_x, grid_dim_y); - } - } else { - if (vec_load_size == 4) { - swizzle_col_scaling_kernel_impl( - input_base, output_base, padded_m, padded_k, original_M, original_K, - block_x, block_y, grid_dim_x, grid_dim_y); - } else if (vec_load_size == 2) { - swizzle_col_scaling_kernel_impl( - input_base, output_base, padded_m, padded_k, original_M, original_K, - block_x, block_y, grid_dim_x, grid_dim_y); - } else { - swizzle_col_scaling_kernel_impl( - input_base, output_base, padded_m, padded_k, original_M, original_K, - block_x, block_y, grid_dim_x, grid_dim_y); - } - } -} - -__global__ void compute_grouped_swizzle_setup( - const int64_t* m_array, - const int64_t* k_array, - int* block_offsets, - size_t* scale_offsets, - int* total_blocks, - int* global_counter, - size_t num_tensors, - bool rowwise, - size_t scale_elem_size) { - - if (blockIdx.x == 0 && threadIdx.x == 0) { - int current_block_offset = 0; - size_t current_scale_offset = 0; - - for (size_t i = 0; i < num_tensors; ++i) { - block_offsets[i] = current_block_offset; - scale_offsets[i] = current_scale_offset; - - size_t m = rowwise ? m_array[i] : k_array[i]; - size_t k = rowwise ? k_array[i] : m_array[i]; - - size_t padded_m = round_up_to_multiple(m, 128); - size_t padded_k = round_up_to_multiple(DIVUP(k, static_cast(MXFP8_BLOCK_SIZE)), 4); - - int num_tiles_m = padded_m / 128; - int num_tiles_k = padded_k / 4; - - int vec_load_size = (rowwise ? ((num_tiles_k - 1) % 4 + 1) : ((num_tiles_m - 1) % 4 + 1)); - if (vec_load_size == 3) vec_load_size = 1; - - int blocks_m = num_tiles_m; - int blocks_k = DIVUP(num_tiles_k, TB_DIM * vec_load_size); - if (!rowwise) { - blocks_m = DIVUP(num_tiles_m, vec_load_size); - blocks_k = DIVUP(num_tiles_k, TB_DIM); - } - - current_block_offset += blocks_m * blocks_k; - current_scale_offset += padded_m * padded_k * scale_elem_size; - } - - block_offsets[num_tensors] = current_block_offset; - scale_offsets[num_tensors] = current_scale_offset; - *total_blocks = current_block_offset; - *global_counter = 0; - } -} - -namespace transformer_engine { -""" -content = content.replace( - "namespace transformer_engine {\n\nvoid swizzle_grouped_scaling_factors", - kernels_code + "\nvoid swizzle_grouped_scaling_factors", -) - -# 2. Modify swizzle_grouped_scaling_factors -old_func = """void swizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor* output, - cudaStream_t stream) {""" -new_func = """void swizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor* output, - void* workspace, cudaStream_t stream) {""" -content = content.replace(old_func, new_func) - -# 3. Add variable shape logic -old_logic = """ // Only support uniform shapes for graph-safe grouped swizzle - NVTE_CHECK(input->all_same_shape(), "Grouped swizzle requires uniform tensor shapes."); - NVTE_CHECK(input->all_same_last_dim() && input->all_same_first_dim(), - "Grouped swizzle requires uniform tensor shapes.");""" -new_logic = """ const int64_t* m_array = reinterpret_cast(input->first_dims.data_ptr); - const int64_t* k_array = reinterpret_cast(input->last_dims.data_ptr); - const bool is_variable_shape = (m_array != nullptr && k_array != nullptr); - - if (!is_variable_shape) { - // Fallback to uniform shape implementation - NVTE_CHECK(input->all_same_shape(), "Grouped swizzle requires uniform tensor shapes."); - NVTE_CHECK(input->all_same_last_dim() && input->all_same_first_dim(), - "Grouped swizzle requires uniform tensor shapes.");""" -content = content.replace(old_logic, new_logic) - -# Close the if block and add the else block for variable shape -old_launch_end = """ if (has_rowwise_scale_inv) { - launch_grouped_swizzle(true); - } - if (has_columnwise_scale_inv) { - launch_grouped_swizzle(false); - }""" -new_launch_end = """ if (has_rowwise_scale_inv) { - launch_grouped_swizzle(true); - } - if (has_columnwise_scale_inv) { - launch_grouped_swizzle(false); - } - } else { - // Variable shape implementation using Device-Side Block Scheduler - size_t num_tensors = input->num_tensors; - NVTE_CHECK(workspace != nullptr, "Workspace must be provided for variable shape grouped swizzle."); - - int* d_block_offsets = reinterpret_cast(workspace); - size_t* d_scale_offsets = reinterpret_cast(d_block_offsets + num_tensors + 2); - int* d_global_counter = reinterpret_cast(d_scale_offsets + num_tensors + 1); - int* d_total_blocks = d_global_counter + 1; - - constexpr int SF_TILE_DIM_M = 128; - constexpr int SF_TILE_DIM_K = 4; - const dim3 block_size(TB_DIM, TB_DIM); - const int max_slm_size = TB_DIM * 4 * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); - - auto launch_grouped_swizzle_variable = [&](bool rowwise) { - const size_t scale_elem_size = rowwise ? typeToSize(input->scale_inv.dtype) - : typeToSize(input->columnwise_scale_inv.dtype); - - compute_grouped_swizzle_setup<<<1, 1, 0, stream>>>( - m_array, k_array, d_block_offsets, d_scale_offsets, d_total_blocks, - d_global_counter, num_tensors, rowwise, scale_elem_size); - - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - grouped_swizzle_scaling_variable_shape_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, max_slm_size)); - - int persistent_blocks = 108 * 8; - dim3 num_blocks(persistent_blocks); - - const void* input_ptr = rowwise ? input->scale_inv.dptr : input->columnwise_scale_inv.dptr; - void* output_ptr = rowwise ? output->scale_inv.dptr : output->columnwise_scale_inv.dptr; - - grouped_swizzle_scaling_variable_shape_kernel - <<>>( - input_ptr, output_ptr, m_array, k_array, d_block_offsets, - d_scale_offsets, d_global_counter, num_tensors, rowwise); - - NVTE_CHECK_CUDA(cudaGetLastError()); - }; - - if (has_rowwise_scale_inv) { - launch_grouped_swizzle_variable(true); - } - if (has_columnwise_scale_inv) { - launch_grouped_swizzle_variable(false); - } - }""" -content = content.replace(old_launch_end, new_launch_end) - -# 4. Modify nvte_swizzle_grouped_scaling_factors wrapper -old_wrapper = """void nvte_swizzle_grouped_scaling_factors(const NVTEGroupedTensor input, NVTEGroupedTensor output, - cudaStream_t stream) { - NVTE_API_CALL(nvte_swizzle_grouped_scaling_factors); - using namespace transformer_engine; - swizzle_grouped_scaling_factors(convertNVTEGroupedTensorCheck(input), - convertNVTEGroupedTensorCheck(output), stream); -}""" -new_wrapper = """void nvte_swizzle_grouped_scaling_factors(const NVTEGroupedTensor input, NVTEGroupedTensor output, - void* workspace, cudaStream_t stream) { - NVTE_API_CALL(nvte_swizzle_grouped_scaling_factors); - using namespace transformer_engine; - swizzle_grouped_scaling_factors(convertNVTEGroupedTensorCheck(input), - convertNVTEGroupedTensorCheck(output), workspace, stream); -}""" -content = content.replace(old_wrapper, new_wrapper) - -with open("transformer_engine/common/swizzle/swizzle.cu", "w") as f: - f.write(content) diff --git a/patch_swizzle_cpp.py b/patch_swizzle_cpp.py deleted file mode 100644 index 7eef291fa6..0000000000 --- a/patch_swizzle_cpp.py +++ /dev/null @@ -1,42 +0,0 @@ -import re - -with open("transformer_engine/pytorch/csrc/extensions/swizzle.cpp", "r") as f: - content = f.read() - -old_code = """ swizzle_output.set_with_gemm_swizzled_scales(true); - NVTE_SCOPED_GIL_RELEASE({ - nvte_swizzle_grouped_scaling_factors(swizzle_input.data(), swizzle_output.data(), - at::cuda::getCurrentCUDAStream()); - });""" - -new_code = """ swizzle_output.set_with_gemm_swizzled_scales(true); - - size_t num_tensors = input.num_tensors(); - size_t workspace_size = (num_tensors + 2) * sizeof(int) + (num_tensors + 1) * sizeof(size_t); - workspace_size = roundup(workspace_size, 256); - auto workspace = allocateSpace(std::vector{workspace_size}, transformer_engine::DType::kByte, false); - - NVTE_SCOPED_GIL_RELEASE({ - nvte_swizzle_grouped_scaling_factors(swizzle_input.data(), swizzle_output.data(), - getDataPtr(workspace), - at::cuda::getCurrentCUDAStream()); - });""" - -content = content.replace(old_code, new_code) - -# Check if first_dims error check exists and remove it -old_check = """ const auto first_dims = input.get_first_dims(); - const auto last_dims = input.get_last_dims(); - if (first_dims.data_ptr != nullptr || last_dims.data_ptr != nullptr) { - NVTE_ERROR( - "Grouped GEMM swizzle requires uniform shapes for now (first_dims/last_dims must be " - "absent)."); - }""" - -new_check = """ const auto first_dims = input.get_first_dims(); - const auto last_dims = input.get_last_dims();""" - -content = content.replace(old_check, new_check) - -with open("transformer_engine/pytorch/csrc/extensions/swizzle.cpp", "w") as f: - f.write(content) diff --git a/transformer_engine/common/util/vectorized_pointwise.h b/transformer_engine/common/util/vectorized_pointwise.h index fc2ef8d5b9..3fa7f4c8ee 100644 --- a/transformer_engine/common/util/vectorized_pointwise.h +++ b/transformer_engine/common/util/vectorized_pointwise.h @@ -198,85 +198,103 @@ __launch_bounds__(unary_kernel_threads) __global__ size_t scale_numel = 1, size_t scale_inv_numel = 1, size_t amax_numel = 1) { if (noop != nullptr && noop[0] == 1.0f) return; - size_t start = 0; - size_t size = N; - size_t tensor_id = 0; - if (gridDim.y > 1) { - tensor_id = blockIdx.y; - if (offsets != nullptr) { - start = offsets[tensor_id]; - size = first_dims[tensor_id] * last_dims[tensor_id]; - } else if (num_tensors > 1) { - size = N / num_tensors; - start = tensor_id * size; - } - } - - VectorizedLoader loader(input + start, size); - VectorizedStorer storer(output + start, size); + VectorizedLoader loader(input, N); + VectorizedStorer storer(output, N); ComputeType max = 0; ComputeType s = 1; const bool requires_amax = (amax != nullptr); if constexpr (is_fp8::value) { - if (scale != nullptr && offsets == nullptr && gridDim.y == 1) s = *scale; + if (scale != nullptr && offsets == nullptr) s = *scale; } const int warp_id = threadIdx.x / THREADS_PER_WARP; - const size_t M = loader.num_aligned_elements(); + float block_max[64] = {0.0f}; + + const size_t M = num_aligned_elements; for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) { - loader.load(tid, size); + loader.load(tid, N); #pragma unroll for (int i = 0; i < nvec; ++i) { const size_t global_idx = (aligned ? (tid * nvec + i) : (tid * nvec + i - loader.alignment())); - if (global_idx >= size) continue; + if (global_idx >= N) continue; + + size_t tensor_id = 0; + bool is_valid = true; + if (offsets != nullptr) { + tensor_id = find_tensor_id(offsets, num_tensors, global_idx); + size_t start = offsets[tensor_id]; + size_t size = first_dims[tensor_id] * last_dims[tensor_id]; + is_valid = (global_idx >= start && global_idx < start + size); + } else if (num_tensors > 1) { + size_t size = N / num_tensors; + tensor_id = global_idx / size; + if (tensor_id >= num_tensors) tensor_id = num_tensors - 1; + } - ComputeType val = static_cast(loader.separate()[i]); - if constexpr (is_fp8::value) { - if (scale_inv != nullptr) { - val = val * ((scale_inv_numel == num_tensors) ? scale_inv[tensor_id] : scale_inv[0]); + if (is_valid) { + ComputeType val = static_cast(loader.separate()[i]); + if constexpr (is_fp8::value) { + if (scale_inv != nullptr) { + val = val * ((scale_inv_numel == num_tensors) ? scale_inv[tensor_id] : scale_inv[0]); + } } - } - ComputeType temp = OP(val, p); - if (requires_amax) { - __builtin_assume(max >= 0); - max = fmaxf(fabsf(temp), max); - } - if constexpr (is_fp8::value) { - float current_scale = 1.0f; - if (scale != nullptr) { - current_scale = (scale_numel == num_tensors) ? scale[tensor_id] : scale[0]; - } else if (gridDim.y == 1) { - current_scale = s; + ComputeType temp = OP(val, p); + if (requires_amax) { + __builtin_assume(block_max[tensor_id] >= 0); + block_max[tensor_id] = fmaxf(fabsf(temp), block_max[tensor_id]); } - temp = temp * current_scale; - } - storer.separate()[i] = static_cast(temp); - - // Update scale-inverse for quantization if requested - const bool is_start = (global_idx == 0); - if (is_start && scale_inv != nullptr && !is_fp8::value) { - float current_scale = 1.0f; - if (scale != nullptr) { - current_scale = (scale_numel == num_tensors) ? scale[tensor_id] : scale[0]; - } else if (gridDim.y == 1) { - current_scale = s; + if constexpr (is_fp8::value) { + float current_scale = 1.0f; + if (scale != nullptr) { + current_scale = (scale_numel == num_tensors) ? scale[tensor_id] : scale[0]; + } + temp = temp * current_scale; + } + storer.separate()[i] = static_cast(temp); + + // Update scale-inverse for quantization if requested + bool is_start = false; + if (offsets != nullptr) { + is_start = (global_idx == offsets[tensor_id]); + } else if (num_tensors > 1) { + size_t size = N / num_tensors; + is_start = (global_idx == tensor_id * size); + } + if (is_start && scale_inv != nullptr && !is_fp8::value) { + float current_scale = 1.0f; + if (scale != nullptr) { + current_scale = (scale_numel == num_tensors) ? scale[tensor_id] : scale[0]; + } + size_t scale_inv_idx = (scale_inv_numel == num_tensors) ? tensor_id : 0; + reciprocal(&scale_inv[scale_inv_idx], current_scale); } - size_t scale_inv_idx = (scale_inv_numel == num_tensors) ? tensor_id : 0; - reciprocal(&scale_inv[scale_inv_idx], current_scale); + } else { + storer.separate()[i] = OutputType(); } } - storer.store(tid, size); + storer.store(tid, N); } // Reduce amax over block if (requires_amax) { - max = reduce_max(max, warp_id); - if (threadIdx.x == 0 && max > 0.0f) { - size_t amax_idx = (amax_numel == num_tensors) ? tensor_id : 0; - static_assert(std::is_same::value); - atomicMaxFloat(&amax[amax_idx], max); + if (offsets != nullptr || num_tensors > 1) { + for (size_t t = 0; t < num_tensors; ++t) { + float t_max = block_max[t]; + t_max = reduce_max(t_max, warp_id); + if (threadIdx.x == 0 && t_max > 0.0f) { + size_t amax_idx = (amax_numel == num_tensors) ? t : 0; + atomicMaxFloat(&amax[amax_idx], t_max); + } + } + } else { + max = block_max[0]; + max = reduce_max(max, warp_id); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(amax, max); + } } } @@ -301,88 +319,106 @@ __launch_bounds__(unary_kernel_threads) __global__ const int64_t *last_dims = nullptr, size_t num_tensors = 1, size_t scale_numel = 1, size_t scale_inv_numel = 1, size_t amax_numel = 1) { - size_t start = 0; - size_t size = N; - size_t tensor_id = 0; - if (gridDim.y > 1) { - tensor_id = blockIdx.y; - if (offsets != nullptr) { - start = offsets[tensor_id]; - size = first_dims[tensor_id] * last_dims[tensor_id]; - } else if (num_tensors > 1) { - size = N / num_tensors; - start = tensor_id * size; - } - } - - VectorizedLoader loader(input + start, size); - VectorizedLoader grad_loader(grad + start, size); - VectorizedStorer storer(output + start, size); + VectorizedLoader loader(input, N); + VectorizedLoader grad_loader(grad, N); + VectorizedStorer storer(output, N); ComputeType max = 0; ComputeType s = 1; const bool requires_amax = (amax != nullptr); if constexpr (is_fp8::value) { - if (scale != nullptr && offsets == nullptr && gridDim.y == 1) s = *scale; + if (scale != nullptr && offsets == nullptr) s = *scale; } const int warp_id = threadIdx.x / THREADS_PER_WARP; - const size_t M = loader.num_aligned_elements(); + float block_max[64] = {0.0f}; + + const size_t M = num_aligned_elements; for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) { - loader.load(tid, size); - grad_loader.load(tid, size); + loader.load(tid, N); + grad_loader.load(tid, N); #pragma unroll for (int i = 0; i < nvec; ++i) { const size_t global_idx = (aligned ? (tid * nvec + i) : (tid * nvec + i - loader.alignment())); - if (global_idx >= size) continue; + if (global_idx >= N) continue; + + size_t tensor_id = 0; + bool is_valid = true; + if (offsets != nullptr) { + tensor_id = find_tensor_id(offsets, num_tensors, global_idx); + size_t start = offsets[tensor_id]; + size_t size = first_dims[tensor_id] * last_dims[tensor_id]; + is_valid = (global_idx >= start && global_idx < start + size); + } else if (num_tensors > 1) { + size_t size = N / num_tensors; + tensor_id = global_idx / size; + if (tensor_id >= num_tensors) tensor_id = num_tensors - 1; + } - ComputeType val = static_cast(loader.separate()[i]); - const ComputeType g = static_cast(grad_loader.separate()[i]); - if constexpr (is_fp8::value) { - if (scale_inv != nullptr) { - val = val * ((scale_inv_numel == num_tensors) ? scale_inv[tensor_id] : scale_inv[0]); + if (is_valid) { + ComputeType val = static_cast(loader.separate()[i]); + const ComputeType g = static_cast(grad_loader.separate()[i]); + if constexpr (is_fp8::value) { + if (scale_inv != nullptr) { + val = val * ((scale_inv_numel == num_tensors) ? scale_inv[tensor_id] : scale_inv[0]); + } } - } - ComputeType temp = OP(val, p) * g; - if (requires_amax) { - __builtin_assume(max >= 0); - max = fmaxf(fabsf(temp), max); - } - if constexpr (is_fp8::value) { - float current_scale = 1.0f; - if (scale != nullptr) { - current_scale = (scale_numel == num_tensors) ? scale[tensor_id] : scale[0]; - } else if (gridDim.y == 1) { - current_scale = s; + ComputeType temp = OP(val, p) * g; + if (requires_amax) { + __builtin_assume(block_max[tensor_id] >= 0); + block_max[tensor_id] = fmaxf(fabsf(temp), block_max[tensor_id]); } - temp = temp * current_scale; - } - storer.separate()[i] = static_cast(temp); - - // Update scale-inverse for quantization if requested - const bool is_start = (global_idx == 0); - if (is_start && scale_inv != nullptr && !is_fp8::value) { - float current_scale = 1.0f; - if (scale != nullptr) { - current_scale = (scale_numel == num_tensors) ? scale[tensor_id] : scale[0]; - } else if (gridDim.y == 1) { - current_scale = s; + if constexpr (is_fp8::value) { + float current_scale = 1.0f; + if (scale != nullptr) { + current_scale = (scale_numel == num_tensors) ? scale[tensor_id] : scale[0]; + } + temp = temp * current_scale; } - size_t scale_inv_idx = (scale_inv_numel == num_tensors) ? tensor_id : 0; - reciprocal(&scale_inv[scale_inv_idx], current_scale); + storer.separate()[i] = static_cast(temp); + + // Update scale-inverse for quantization if requested + bool is_start = false; + if (offsets != nullptr) { + is_start = (global_idx == offsets[tensor_id]); + } else if (num_tensors > 1) { + size_t size = N / num_tensors; + is_start = (global_idx == tensor_id * size); + } + if (is_start && scale_inv != nullptr && !is_fp8::value) { + float current_scale = 1.0f; + if (scale != nullptr) { + current_scale = (scale_numel == num_tensors) ? scale[tensor_id] : scale[0]; + } + size_t scale_inv_idx = (scale_inv_numel == num_tensors) ? tensor_id : 0; + reciprocal(&scale_inv[scale_inv_idx], current_scale); + } + } else { + storer.separate()[i] = OutputType(); } } - storer.store(tid, size); + storer.store(tid, N); } // Reduce amax over block if (requires_amax) { - max = reduce_max(max, warp_id); - if (threadIdx.x == 0 && max > 0.0f) { - size_t amax_idx = (amax_numel == num_tensors) ? tensor_id : 0; - static_assert(std::is_same::value); - atomicMaxFloat(&amax[amax_idx], max); + if (offsets != nullptr || num_tensors > 1) { + for (size_t t = 0; t < num_tensors; ++t) { + float t_max = block_max[t]; + t_max = reduce_max(t_max, warp_id); + if (threadIdx.x == 0 && t_max > 0.0f) { + size_t amax_idx = (amax_numel == num_tensors) ? t : 0; + atomicMaxFloat(&amax[amax_idx], t_max); + } + } + } else { + max = block_max[0]; + max = reduce_max(max, warp_id); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(amax, max); + } } } @@ -461,24 +497,24 @@ void VectorizedUnaryKernelLauncher(const InputType *input, const fp32 *noop, Out size_t num_aligned_elements = get_num_aligned_elements(input, N, nvec, sizeof(InputType)); constexpr size_t threads = unary_kernel_threads; + size_t num_blocks = DIVUP(num_aligned_elements, threads); constexpr size_t max_blocks = 65535; - const size_t num_blocks = std::min(DIVUP(num_aligned_elements, threads), max_blocks); - const dim3 grid(num_blocks, num_tensors); + num_blocks = std::min(num_blocks, max_blocks); switch (align) { case Alignment::SAME_ALIGNED: - unary_kernel<<>>( + unary_kernel<<>>( input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements, offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); break; case Alignment::SAME_UNALIGNED: - unary_kernel<<>>( + unary_kernel<<>>( input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements, offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); break; case Alignment::DIFFERENT: { // If the pointers are aligned differently we cannot vectorize - unary_kernel<1, true, fp32, Param, OP><<>>( + unary_kernel<1, true, fp32, Param, OP><<>>( input, noop, output, scale, amax, scale_inv, params, N, N, offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); break; @@ -503,24 +539,24 @@ void VectorizedUnaryGradKernelLauncher(const InputTypeGrad *grad, const InputTyp size_t num_aligned_elements = get_num_aligned_elements(input, N, nvec, sizeof(InputType)); constexpr size_t threads = unary_kernel_threads; + size_t num_blocks = DIVUP(num_aligned_elements, threads); constexpr size_t max_blocks = 65535; - const size_t num_blocks = std::min(DIVUP(num_aligned_elements, threads), max_blocks); - const dim3 grid(num_blocks, num_tensors); + num_blocks = std::min(num_blocks, max_blocks); switch (align) { case Alignment::SAME_ALIGNED: - unary_grad_kernel<<>>( + unary_grad_kernel<<>>( grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements, offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); break; case Alignment::SAME_UNALIGNED: - unary_grad_kernel<<>>( + unary_grad_kernel<<>>( grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements, offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); break; case Alignment::DIFFERENT: { // If the pointers are aligned differently we cannot vectorize - unary_grad_kernel<1, true, fp32, Param, OP><<>>( + unary_grad_kernel<1, true, fp32, Param, OP><<>>( grad, input, output, scale, amax, scale_inv, params, N, N, offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); break; From feb1b87a8ee648c8a17f31bd47017abba6d1a092 Mon Sep 17 00:00:00 2001 From: Abhishek Date: Mon, 8 Jun 2026 23:49:19 -0700 Subject: [PATCH 7/8] Added changes back Signed-off-by: Abhishek --- .../common/util/vectorized_pointwise.h | 318 ++++++++---------- 1 file changed, 140 insertions(+), 178 deletions(-) diff --git a/transformer_engine/common/util/vectorized_pointwise.h b/transformer_engine/common/util/vectorized_pointwise.h index 3fa7f4c8ee..af7097e660 100644 --- a/transformer_engine/common/util/vectorized_pointwise.h +++ b/transformer_engine/common/util/vectorized_pointwise.h @@ -198,103 +198,84 @@ __launch_bounds__(unary_kernel_threads) __global__ size_t scale_numel = 1, size_t scale_inv_numel = 1, size_t amax_numel = 1) { if (noop != nullptr && noop[0] == 1.0f) return; - VectorizedLoader loader(input, N); - VectorizedStorer storer(output, N); + size_t start = 0; + size_t size = N; + size_t tensor_id = 0; + if (gridDim.y > 1) { + tensor_id = blockIdx.y; + if (offsets != nullptr) { + start = offsets[tensor_id]; + size = first_dims[tensor_id] * last_dims[tensor_id]; + } else if (num_tensors > 1) { + size = N / num_tensors; + start = tensor_id * size; + } + } + + VectorizedLoader loader(input + start, size); + VectorizedStorer storer(output + start, size); ComputeType max = 0; ComputeType s = 1; const bool requires_amax = (amax != nullptr); if constexpr (is_fp8::value) { - if (scale != nullptr && offsets == nullptr) s = *scale; + if (scale != nullptr && offsets == nullptr && gridDim.y == 1) s = *scale; } const int warp_id = threadIdx.x / THREADS_PER_WARP; - float block_max[64] = {0.0f}; - - const size_t M = num_aligned_elements; + const size_t M = loader.num_aligned_elements(); for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) { - loader.load(tid, N); + loader.load(tid, size); #pragma unroll for (int i = 0; i < nvec; ++i) { - const size_t global_idx = - (aligned ? (tid * nvec + i) : (tid * nvec + i - loader.alignment())); - if (global_idx >= N) continue; - - size_t tensor_id = 0; - bool is_valid = true; - if (offsets != nullptr) { - tensor_id = find_tensor_id(offsets, num_tensors, global_idx); - size_t start = offsets[tensor_id]; - size_t size = first_dims[tensor_id] * last_dims[tensor_id]; - is_valid = (global_idx >= start && global_idx < start + size); - } else if (num_tensors > 1) { - size_t size = N / num_tensors; - tensor_id = global_idx / size; - if (tensor_id >= num_tensors) tensor_id = num_tensors - 1; - } + const size_t global_idx = (aligned ? (tid * nvec + i) : (tid * nvec + i - loader.alignment())); + if (global_idx >= size) continue; - if (is_valid) { - ComputeType val = static_cast(loader.separate()[i]); - if constexpr (is_fp8::value) { - if (scale_inv != nullptr) { - val = val * ((scale_inv_numel == num_tensors) ? scale_inv[tensor_id] : scale_inv[0]); - } - } - ComputeType temp = OP(val, p); - if (requires_amax) { - __builtin_assume(block_max[tensor_id] >= 0); - block_max[tensor_id] = fmaxf(fabsf(temp), block_max[tensor_id]); + ComputeType val = static_cast(loader.separate()[i]); + if constexpr (is_fp8::value) { + if (scale_inv != nullptr) { + val = val * ((scale_inv_numel == num_tensors) ? scale_inv[tensor_id] : scale_inv[0]); } - if constexpr (is_fp8::value) { - float current_scale = 1.0f; - if (scale != nullptr) { - current_scale = (scale_numel == num_tensors) ? scale[tensor_id] : scale[0]; - } - temp = temp * current_scale; - } - storer.separate()[i] = static_cast(temp); - - // Update scale-inverse for quantization if requested - bool is_start = false; - if (offsets != nullptr) { - is_start = (global_idx == offsets[tensor_id]); - } else if (num_tensors > 1) { - size_t size = N / num_tensors; - is_start = (global_idx == tensor_id * size); + } + ComputeType temp = OP(val, p); + if (requires_amax) { + __builtin_assume(max >= 0); + max = fmaxf(fabsf(temp), max); + } + if constexpr (is_fp8::value) { + float current_scale = 1.0f; + if (scale != nullptr) { + current_scale = (scale_numel == num_tensors) ? scale[tensor_id] : scale[0]; + } else if (gridDim.y == 1) { + current_scale = s; } - if (is_start && scale_inv != nullptr && !is_fp8::value) { - float current_scale = 1.0f; - if (scale != nullptr) { - current_scale = (scale_numel == num_tensors) ? scale[tensor_id] : scale[0]; - } - size_t scale_inv_idx = (scale_inv_numel == num_tensors) ? tensor_id : 0; - reciprocal(&scale_inv[scale_inv_idx], current_scale); + temp = temp * current_scale; + } + storer.separate()[i] = static_cast(temp); + + // Update scale-inverse for quantization if requested + const bool is_start = (global_idx == 0); + if (is_start && scale_inv != nullptr && !is_fp8::value) { + float current_scale = 1.0f; + if (scale != nullptr) { + current_scale = (scale_numel == num_tensors) ? scale[tensor_id] : scale[0]; + } else if (gridDim.y == 1) { + current_scale = s; } - } else { - storer.separate()[i] = OutputType(); + size_t scale_inv_idx = (scale_inv_numel == num_tensors) ? tensor_id : 0; + reciprocal(&scale_inv[scale_inv_idx], current_scale); } } - storer.store(tid, N); + storer.store(tid, size); } // Reduce amax over block if (requires_amax) { - if (offsets != nullptr || num_tensors > 1) { - for (size_t t = 0; t < num_tensors; ++t) { - float t_max = block_max[t]; - t_max = reduce_max(t_max, warp_id); - if (threadIdx.x == 0 && t_max > 0.0f) { - size_t amax_idx = (amax_numel == num_tensors) ? t : 0; - atomicMaxFloat(&amax[amax_idx], t_max); - } - } - } else { - max = block_max[0]; - max = reduce_max(max, warp_id); - if (threadIdx.x == 0) { - static_assert(std::is_same::value); - atomicMaxFloat(amax, max); - } + max = reduce_max(max, warp_id); + if (threadIdx.x == 0 && max > 0.0f) { + size_t amax_idx = (amax_numel == num_tensors) ? tensor_id : 0; + static_assert(std::is_same::value); + atomicMaxFloat(&amax[amax_idx], max); } } @@ -319,106 +300,87 @@ __launch_bounds__(unary_kernel_threads) __global__ const int64_t *last_dims = nullptr, size_t num_tensors = 1, size_t scale_numel = 1, size_t scale_inv_numel = 1, size_t amax_numel = 1) { - VectorizedLoader loader(input, N); - VectorizedLoader grad_loader(grad, N); - VectorizedStorer storer(output, N); + size_t start = 0; + size_t size = N; + size_t tensor_id = 0; + if (gridDim.y > 1) { + tensor_id = blockIdx.y; + if (offsets != nullptr) { + start = offsets[tensor_id]; + size = first_dims[tensor_id] * last_dims[tensor_id]; + } else if (num_tensors > 1) { + size = N / num_tensors; + start = tensor_id * size; + } + } + + VectorizedLoader loader(input + start, size); + VectorizedLoader grad_loader(grad + start, size); + VectorizedStorer storer(output + start, size); ComputeType max = 0; ComputeType s = 1; const bool requires_amax = (amax != nullptr); if constexpr (is_fp8::value) { - if (scale != nullptr && offsets == nullptr) s = *scale; + if (scale != nullptr && offsets == nullptr && gridDim.y == 1) s = *scale; } const int warp_id = threadIdx.x / THREADS_PER_WARP; - float block_max[64] = {0.0f}; - - const size_t M = num_aligned_elements; + const size_t M = loader.num_aligned_elements(); for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) { - loader.load(tid, N); - grad_loader.load(tid, N); + loader.load(tid, size); + grad_loader.load(tid, size); #pragma unroll for (int i = 0; i < nvec; ++i) { - const size_t global_idx = - (aligned ? (tid * nvec + i) : (tid * nvec + i - loader.alignment())); - if (global_idx >= N) continue; - - size_t tensor_id = 0; - bool is_valid = true; - if (offsets != nullptr) { - tensor_id = find_tensor_id(offsets, num_tensors, global_idx); - size_t start = offsets[tensor_id]; - size_t size = first_dims[tensor_id] * last_dims[tensor_id]; - is_valid = (global_idx >= start && global_idx < start + size); - } else if (num_tensors > 1) { - size_t size = N / num_tensors; - tensor_id = global_idx / size; - if (tensor_id >= num_tensors) tensor_id = num_tensors - 1; - } - - if (is_valid) { - ComputeType val = static_cast(loader.separate()[i]); - const ComputeType g = static_cast(grad_loader.separate()[i]); - if constexpr (is_fp8::value) { - if (scale_inv != nullptr) { - val = val * ((scale_inv_numel == num_tensors) ? scale_inv[tensor_id] : scale_inv[0]); - } - } - ComputeType temp = OP(val, p) * g; - if (requires_amax) { - __builtin_assume(block_max[tensor_id] >= 0); - block_max[tensor_id] = fmaxf(fabsf(temp), block_max[tensor_id]); + const size_t global_idx = (aligned ? (tid * nvec + i) : (tid * nvec + i - loader.alignment())); + if (global_idx >= size) continue; + + ComputeType val = static_cast(loader.separate()[i]); + const ComputeType g = static_cast(grad_loader.separate()[i]); + if constexpr (is_fp8::value) { + if (scale_inv != nullptr) { + val = val * ((scale_inv_numel == num_tensors) ? scale_inv[tensor_id] : scale_inv[0]); } - if constexpr (is_fp8::value) { - float current_scale = 1.0f; - if (scale != nullptr) { - current_scale = (scale_numel == num_tensors) ? scale[tensor_id] : scale[0]; - } - temp = temp * current_scale; - } - storer.separate()[i] = static_cast(temp); - - // Update scale-inverse for quantization if requested - bool is_start = false; - if (offsets != nullptr) { - is_start = (global_idx == offsets[tensor_id]); - } else if (num_tensors > 1) { - size_t size = N / num_tensors; - is_start = (global_idx == tensor_id * size); + } + ComputeType temp = OP(val, p) * g; + if (requires_amax) { + __builtin_assume(max >= 0); + max = fmaxf(fabsf(temp), max); + } + if constexpr (is_fp8::value) { + float current_scale = 1.0f; + if (scale != nullptr) { + current_scale = (scale_numel == num_tensors) ? scale[tensor_id] : scale[0]; + } else if (gridDim.y == 1) { + current_scale = s; } - if (is_start && scale_inv != nullptr && !is_fp8::value) { - float current_scale = 1.0f; - if (scale != nullptr) { - current_scale = (scale_numel == num_tensors) ? scale[tensor_id] : scale[0]; - } - size_t scale_inv_idx = (scale_inv_numel == num_tensors) ? tensor_id : 0; - reciprocal(&scale_inv[scale_inv_idx], current_scale); + temp = temp * current_scale; + } + storer.separate()[i] = static_cast(temp); + + // Update scale-inverse for quantization if requested + const bool is_start = (global_idx == 0); + if (is_start && scale_inv != nullptr && !is_fp8::value) { + float current_scale = 1.0f; + if (scale != nullptr) { + current_scale = (scale_numel == num_tensors) ? scale[tensor_id] : scale[0]; + } else if (gridDim.y == 1) { + current_scale = s; } - } else { - storer.separate()[i] = OutputType(); + size_t scale_inv_idx = (scale_inv_numel == num_tensors) ? tensor_id : 0; + reciprocal(&scale_inv[scale_inv_idx], current_scale); } } - storer.store(tid, N); + storer.store(tid, size); } // Reduce amax over block if (requires_amax) { - if (offsets != nullptr || num_tensors > 1) { - for (size_t t = 0; t < num_tensors; ++t) { - float t_max = block_max[t]; - t_max = reduce_max(t_max, warp_id); - if (threadIdx.x == 0 && t_max > 0.0f) { - size_t amax_idx = (amax_numel == num_tensors) ? t : 0; - atomicMaxFloat(&amax[amax_idx], t_max); - } - } - } else { - max = block_max[0]; - max = reduce_max(max, warp_id); - if (threadIdx.x == 0) { - static_assert(std::is_same::value); - atomicMaxFloat(amax, max); - } + max = reduce_max(max, warp_id); + if (threadIdx.x == 0 && max > 0.0f) { + size_t amax_idx = (amax_numel == num_tensors) ? tensor_id : 0; + static_assert(std::is_same::value); + atomicMaxFloat(&amax[amax_idx], max); } } @@ -497,26 +459,26 @@ void VectorizedUnaryKernelLauncher(const InputType *input, const fp32 *noop, Out size_t num_aligned_elements = get_num_aligned_elements(input, N, nvec, sizeof(InputType)); constexpr size_t threads = unary_kernel_threads; - size_t num_blocks = DIVUP(num_aligned_elements, threads); constexpr size_t max_blocks = 65535; - num_blocks = std::min(num_blocks, max_blocks); + const size_t num_blocks = std::min(DIVUP(num_aligned_elements, threads), max_blocks); + const dim3 grid(num_blocks, num_tensors); switch (align) { case Alignment::SAME_ALIGNED: - unary_kernel<<>>( - input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements, offsets, - first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); + unary_kernel<<>>( + input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements, + offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); break; case Alignment::SAME_UNALIGNED: - unary_kernel<<>>( - input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements, offsets, - first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); + unary_kernel<<>>( + input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements, + offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); break; case Alignment::DIFFERENT: { // If the pointers are aligned differently we cannot vectorize - unary_kernel<1, true, fp32, Param, OP><<>>( - input, noop, output, scale, amax, scale_inv, params, N, N, offsets, first_dims, - last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); + unary_kernel<1, true, fp32, Param, OP><<>>( + input, noop, output, scale, amax, scale_inv, params, N, N, + offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); break; } } @@ -539,26 +501,26 @@ void VectorizedUnaryGradKernelLauncher(const InputTypeGrad *grad, const InputTyp size_t num_aligned_elements = get_num_aligned_elements(input, N, nvec, sizeof(InputType)); constexpr size_t threads = unary_kernel_threads; - size_t num_blocks = DIVUP(num_aligned_elements, threads); constexpr size_t max_blocks = 65535; - num_blocks = std::min(num_blocks, max_blocks); + const size_t num_blocks = std::min(DIVUP(num_aligned_elements, threads), max_blocks); + const dim3 grid(num_blocks, num_tensors); switch (align) { case Alignment::SAME_ALIGNED: - unary_grad_kernel<<>>( - grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements, offsets, - first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); + unary_grad_kernel<<>>( + grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements, + offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); break; case Alignment::SAME_UNALIGNED: - unary_grad_kernel<<>>( - grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements, offsets, - first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); + unary_grad_kernel<<>>( + grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements, + offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); break; case Alignment::DIFFERENT: { // If the pointers are aligned differently we cannot vectorize - unary_grad_kernel<1, true, fp32, Param, OP><<>>( - grad, input, output, scale, amax, scale_inv, params, N, N, offsets, first_dims, - last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); + unary_grad_kernel<1, true, fp32, Param, OP><<>>( + grad, input, output, scale, amax, scale_inv, params, N, N, + offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); break; } } From 02dc01706394850a3dec84083e1d8c99313cb88b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 9 Jun 2026 06:51:06 +0000 Subject: [PATCH 8/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/util/vectorized_pointwise.h | 30 ++++++++++--------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/transformer_engine/common/util/vectorized_pointwise.h b/transformer_engine/common/util/vectorized_pointwise.h index af7097e660..fc2ef8d5b9 100644 --- a/transformer_engine/common/util/vectorized_pointwise.h +++ b/transformer_engine/common/util/vectorized_pointwise.h @@ -228,7 +228,8 @@ __launch_bounds__(unary_kernel_threads) __global__ loader.load(tid, size); #pragma unroll for (int i = 0; i < nvec; ++i) { - const size_t global_idx = (aligned ? (tid * nvec + i) : (tid * nvec + i - loader.alignment())); + const size_t global_idx = + (aligned ? (tid * nvec + i) : (tid * nvec + i - loader.alignment())); if (global_idx >= size) continue; ComputeType val = static_cast(loader.separate()[i]); @@ -332,7 +333,8 @@ __launch_bounds__(unary_kernel_threads) __global__ grad_loader.load(tid, size); #pragma unroll for (int i = 0; i < nvec; ++i) { - const size_t global_idx = (aligned ? (tid * nvec + i) : (tid * nvec + i - loader.alignment())); + const size_t global_idx = + (aligned ? (tid * nvec + i) : (tid * nvec + i - loader.alignment())); if (global_idx >= size) continue; ComputeType val = static_cast(loader.separate()[i]); @@ -466,19 +468,19 @@ void VectorizedUnaryKernelLauncher(const InputType *input, const fp32 *noop, Out switch (align) { case Alignment::SAME_ALIGNED: unary_kernel<<>>( - input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements, - offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); + input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements, offsets, + first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); break; case Alignment::SAME_UNALIGNED: unary_kernel<<>>( - input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements, - offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); + input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements, offsets, + first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); break; case Alignment::DIFFERENT: { // If the pointers are aligned differently we cannot vectorize unary_kernel<1, true, fp32, Param, OP><<>>( - input, noop, output, scale, amax, scale_inv, params, N, N, - offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); + input, noop, output, scale, amax, scale_inv, params, N, N, offsets, first_dims, + last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); break; } } @@ -508,19 +510,19 @@ void VectorizedUnaryGradKernelLauncher(const InputTypeGrad *grad, const InputTyp switch (align) { case Alignment::SAME_ALIGNED: unary_grad_kernel<<>>( - grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements, - offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); + grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements, offsets, + first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); break; case Alignment::SAME_UNALIGNED: unary_grad_kernel<<>>( - grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements, - offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); + grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements, offsets, + first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); break; case Alignment::DIFFERENT: { // If the pointers are aligned differently we cannot vectorize unary_grad_kernel<1, true, fp32, Param, OP><<>>( - grad, input, output, scale, amax, scale_inv, params, N, N, - offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); + grad, input, output, scale, amax, scale_inv, params, N, N, offsets, first_dims, + last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel); break; } }