From 18d4d2c5c32519861d8faa122581d00a874b3426 Mon Sep 17 00:00:00 2001 From: zhongboz Date: Fri, 12 Jun 2026 17:28:54 -0700 Subject: [PATCH 1/6] support scaled swiglu, scaled srelu and scaled clamp swiglu Signed-off-by: zhongboz --- tests/cpp/operator/CMakeLists.txt | 1 + tests/cpp/operator/test_scaled_activation.cu | 328 ++++++++++ transformer_engine/common/CMakeLists.txt | 2 + .../common/activation/scaled_activation.cu | 567 ++++++++++++++++++ .../include/transformer_engine/activation.h | 103 ++++ 5 files changed, 1001 insertions(+) create mode 100644 tests/cpp/operator/test_scaled_activation.cu create mode 100644 transformer_engine/common/activation/scaled_activation.cu diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 9b67c09f34..d5c446fb48 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -24,6 +24,7 @@ add_executable(test_operator test_cast_transpose_dbias_dgelu.cu test_cast_transpose_dgeglu.cu test_act.cu + test_scaled_activation.cu test_normalization.cu test_normalization_mxfp8.cu test_memset.cu diff --git a/tests/cpp/operator/test_scaled_activation.cu b/tests/cpp/operator/test_scaled_activation.cu new file mode 100644 index 0000000000..1cb630a0bc --- /dev/null +++ b/tests/cpp/operator/test_scaled_activation.cu @@ -0,0 +1,328 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include + +#include +#include + +#include + +#include "../test_common.h" + +using namespace transformer_engine; + +namespace { + +enum class ScaledActivationCase { + kSwiGLU, + kClampedSwiGLU, + kSReLU, +}; + +constexpr float kClampedLimit = 1.3f; +constexpr float kClampedAlpha = 1.702f; +constexpr float kClampedLinearOffset = 0.5f; + +const char *activation_name(ScaledActivationCase activation) { + switch (activation) { + case ScaledActivationCase::kSwiGLU: + return "scaled_swiglu"; + case ScaledActivationCase::kClampedSwiGLU: + return "scaled_clamped_swiglu"; + case ScaledActivationCase::kSReLU: + return "scaled_srelu"; + } + return "unknown"; +} + +inline float sigmoid(const float x) { return 1.0f / (1.0f + expf(-x)); } + +inline float qgelu_alpha(const float x, const float alpha) { return x * sigmoid(alpha * x); } + +inline float dqgelu_alpha(const float x, const float alpha) { + const float sig = sigmoid(alpha * x); + return alpha * x * sig * (1.0f - sig) + sig; +} + +inline float silu_ref(const float x) { return x * sigmoid(x); } + +inline float dsilu_ref(const float x) { + const float sig = sigmoid(x); + return x * sig * (1.0f - sig) + sig; +} + +inline float srelu_ref(const float x) { return x > 0.0f ? x * x : 0.0f; } + +inline float dsrelu_ref(const float x) { return fmaxf(0.0f, 2.0f * x); } + +inline void glu_indices(const size_t row, const size_t col, const size_t hidden, + const int64_t interleave, size_t *act_idx, size_t *linear_idx) { + if (interleave > 0) { + const size_t block = col / static_cast(interleave); + const size_t lane = col % static_cast(interleave); + const size_t base = row * hidden * 2 + block * static_cast(interleave) * 2 + lane; + *act_idx = base; + *linear_idx = base + static_cast(interleave); + } else { + const size_t base = row * hidden * 2; + *act_idx = base + col; + *linear_idx = base + hidden + col; + } +} + +inline float gated_unscaled(const ScaledActivationCase activation, const float act_in, + const float linear_in) { + switch (activation) { + case ScaledActivationCase::kSwiGLU: + return silu_ref(act_in) * linear_in; + case ScaledActivationCase::kClampedSwiGLU: { + const float act = qgelu_alpha(fminf(kClampedLimit, act_in), kClampedAlpha); + const float linear = + fminf(fmaxf(-kClampedLimit, linear_in), kClampedLimit) + kClampedLinearOffset; + return act * linear; + } + case ScaledActivationCase::kSReLU: + return srelu_ref(act_in); + } + return 0.0f; +} + +inline void gated_grads(const ScaledActivationCase activation, const float act_in, + const float linear_in, float *dact, float *dlinear, float *unscaled) { + switch (activation) { + case ScaledActivationCase::kSwiGLU: { + const float act = silu_ref(act_in); + *unscaled = act * linear_in; + *dact = dsilu_ref(act_in) * linear_in; + *dlinear = act; + return; + } + case ScaledActivationCase::kClampedSwiGLU: { + const bool dlinear_mask = linear_in <= kClampedLimit && linear_in >= -kClampedLimit; + const float act = qgelu_alpha(fminf(kClampedLimit, act_in), kClampedAlpha); + const float dact_base = + act_in <= kClampedLimit ? dqgelu_alpha(fminf(kClampedLimit, act_in), kClampedAlpha) + : 0.0f; + const float linear = + fminf(fmaxf(-kClampedLimit, linear_in), kClampedLimit) + kClampedLinearOffset; + *unscaled = act * linear; + *dact = dact_base * linear; + *dlinear = dlinear_mask ? act : 0.0f; + return; + } + case ScaledActivationCase::kSReLU: + *unscaled = srelu_ref(act_in); + *dact = dsrelu_ref(act_in); + *dlinear = 0.0f; + return; + } +} + +template +void compute_reference(ScaledActivationCase activation, const DataT *input, const ScaleT *scales, + const DataT *grad_output, DataT *output, DataT *grad_input, + DataT *grad_scales, const size_t rows, const size_t hidden, + const int64_t interleave, const bool compute_grad_scales) { + const bool is_gated = activation != ScaledActivationCase::kSReLU; + const size_t input_cols = is_gated ? hidden * 2 : hidden; + std::fill(grad_input, grad_input + rows * input_cols, static_cast(0.0f)); + + for (size_t row = 0; row < rows; ++row) { + const float scale = static_cast(scales[row]); + float scale_grad = 0.0f; + for (size_t col = 0; col < hidden; ++col) { + const size_t out_idx = row * hidden + col; + float unscaled = 0.0f; + float dact = 0.0f; + float dlinear = 0.0f; + if (is_gated) { + size_t act_idx = 0; + size_t linear_idx = 0; + glu_indices(row, col, hidden, interleave, &act_idx, &linear_idx); + const float act_in = static_cast(input[act_idx]); + const float linear_in = static_cast(input[linear_idx]); + unscaled = gated_unscaled(activation, act_in, linear_in); + gated_grads(activation, act_in, linear_in, &dact, &dlinear, &unscaled); + + const float scaled_grad = static_cast(grad_output[out_idx]) * scale; + grad_input[act_idx] = static_cast(scaled_grad * dact); + grad_input[linear_idx] = static_cast(scaled_grad * dlinear); + } else { + const float x = static_cast(input[out_idx]); + unscaled = srelu_ref(x); + const float scaled_grad = static_cast(grad_output[out_idx]) * scale; + grad_input[out_idx] = static_cast(scaled_grad * dsrelu_ref(x)); + } + + output[out_idx] = static_cast(unscaled * scale); + scale_grad += static_cast(grad_output[out_idx]) * unscaled; + } + if (compute_grad_scales) { + grad_scales[row] = static_cast(scale_grad); + } + } +} + +template +void run_scaled_activation_test(ScaledActivationCase activation, const size_t rows, + const size_t hidden, const int64_t interleave, + const bool compute_grad_scales) { + using namespace test; + const DType data_type = TypeInfo::dtype; + const DType scale_type = TypeInfo::dtype; + const bool is_gated = activation != ScaledActivationCase::kSReLU; + const size_t input_cols = is_gated ? hidden * 2 : hidden; + + Tensor input("input", std::vector{rows, input_cols}, data_type); + Tensor scales("act_scales", std::vector{rows}, scale_type); + Tensor output("output", std::vector{rows, hidden}, data_type); + Tensor grad_output("grad_output", std::vector{rows, hidden}, data_type); + Tensor grad_input("grad_input", std::vector{rows, input_cols}, data_type); + Tensor grad_scales("grad_scales", std::vector{rows}, data_type); + + fillUniform(&input); + fillUniform(&scales); + fillUniform(&grad_output); + + std::unique_ptr ref_output = std::make_unique(rows * hidden); + std::unique_ptr ref_grad_input = std::make_unique(rows * input_cols); + std::unique_ptr ref_grad_scales = std::make_unique(rows); + + compute_reference(activation, input.rowwise_cpu_dptr(), scales.rowwise_cpu_dptr(), + grad_output.rowwise_cpu_dptr(), ref_output.get(), + ref_grad_input.get(), ref_grad_scales.get(), rows, hidden, interleave, + compute_grad_scales); + + switch (activation) { + case ScaledActivationCase::kSwiGLU: + nvte_scaled_swiglu(input.data(), scales.data(), output.data(), interleave, 0); + nvte_scaled_dswiglu(grad_output.data(), input.data(), scales.data(), grad_input.data(), + compute_grad_scales ? grad_scales.data() : nullptr, interleave, 0); + break; + case ScaledActivationCase::kClampedSwiGLU: + nvte_scaled_clamped_swiglu(input.data(), scales.data(), output.data(), kClampedLimit, + kClampedAlpha, kClampedLinearOffset, interleave, 0); + nvte_scaled_clamped_dswiglu( + grad_output.data(), input.data(), scales.data(), grad_input.data(), + compute_grad_scales ? grad_scales.data() : nullptr, kClampedLimit, kClampedAlpha, + kClampedLinearOffset, interleave, 0); + break; + case ScaledActivationCase::kSReLU: + nvte_scaled_srelu(input.data(), scales.data(), output.data(), 0); + nvte_scaled_dsrelu(grad_output.data(), input.data(), scales.data(), grad_input.data(), + compute_grad_scales ? grad_scales.data() : nullptr, 0); + break; + } + + NVTE_CHECK_CUDA(cudaDeviceSynchronize()); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + auto [atol, rtol] = getTolerances(data_type); + if (data_type == DType::kFloat32) { + atol = 5e-5; + rtol = 5e-5; + } + compareResults("scaled_activation_output", output, ref_output.get(), atol, rtol); + compareResults("scaled_activation_grad_input", grad_input, ref_grad_input.get(), atol, rtol); + if (compute_grad_scales) { + compareResults("scaled_activation_grad_scales", grad_scales, ref_grad_scales.get(), atol, rtol); + } +} + +class ScaledActivationTest + : public ::testing::TestWithParam< + std::tuple, int64_t, + bool>> { +}; + +std::string test_name_generator( + const testing::TestParamInfo &info) { + const auto activation = std::get<0>(info.param); + const auto data_type = std::get<1>(info.param); + const auto scale_type = std::get<2>(info.param); + const auto shape = std::get<3>(info.param); + const auto interleave = std::get<4>(info.param); + const auto compute_grad_scales = std::get<5>(info.param); + return std::string(activation_name(activation)) + "_data_" + test::typeName(data_type) + + "_scale_" + test::typeName(scale_type) + "_m_" + std::to_string(shape.first) + "_h_" + + std::to_string(shape.second) + "_interleave_" + std::to_string(interleave) + + (compute_grad_scales ? "_with_scale_grad" : "_no_scale_grad"); +} + +} // namespace + +TEST_P(ScaledActivationTest, ForwardBackward) { + const auto activation = std::get<0>(GetParam()); + const auto data_type = std::get<1>(GetParam()); + const auto scale_type = std::get<2>(GetParam()); + const auto shape = std::get<3>(GetParam()); + const auto interleave = std::get<4>(GetParam()); + const auto compute_grad_scales = std::get<5>(GetParam()); + + if (activation == ScaledActivationCase::kSReLU && interleave != 0) { + GTEST_SKIP() << "SReLU is not a GLU activation."; + } + if (activation != ScaledActivationCase::kSReLU && interleave > 0 && + shape.second % static_cast(interleave) != 0) { + GTEST_SKIP() << "Hidden size must be divisible by GLU interleave."; + } + + using namespace test; + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(data_type, DataT, { + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(scale_type, ScaleT, { + run_scaled_activation_test(activation, shape.first, shape.second, interleave, + compute_grad_scales); + }); + }); +} + +// Test axes (the six tuple elements consumed by ScaledActivationTest): +// 1. Activation : SwiGLU and ClampedSwiGLU are gated (input is [M, 2H]); +// SReLU is unary (input is [M, H], no gate split). +// 2. Data dtype : dtype of the activation input/output tensors. +// 3. Scale dtype : dtype of act_scales / grad_act_scales. +// 4. Shape {rows, hidden}: rows = M (tokens), hidden = H (output width; gated input is 2H). +// 5. GLU interleave : 0 = contiguous [a | b]; 32 = interleaved a/b blocks. Only valid +// for gated activations with hidden % 32 == 0; SReLU skips != 0. +// 6. compute_grad_scales : whether the backward also reduces grad_act_scales. + +// Regular shapes: hidden is a multiple of 32, so the interleaved (32) layout is exercised +// alongside the contiguous (0) layout. +// Regular shapes (hidden % 32 == 0) and weird/irregular shapes (tiny, prime, non-32-aligned) +// share one instantiation. Interleave is swept over {0, 32}; invalid combinations -- SReLU with +// any nonzero interleave, or a gated activation whose hidden is not divisible by the interleave -- +// are skipped at runtime by the GTEST_SKIP guards in the test body. +INSTANTIATE_TEST_SUITE_P( + OperatorTest_ScaledActivation, ScaledActivationTest, + ::testing::Combine( + ::testing::Values(ScaledActivationCase::kSwiGLU, ScaledActivationCase::kClampedSwiGLU, + ScaledActivationCase::kSReLU), + ::testing::Values(DType::kFloat32, DType::kBFloat16), // data dtype + ::testing::Values(DType::kFloat32, DType::kBFloat16), // scale dtype + ::testing::Values(std::pair{17, 64}, // odd rows, aligned hidden + std::pair{8, 96}, // 96 = 3 * 32 + std::pair{32, 32}, // minimal aligned square + std::pair{128, 128}, // square + std::pair{64, 256}, // wide hidden + std::pair{256, 64}, // many rows, narrow hidden + std::pair{128, 512}, // FFN-ish width + std::pair{1, 1}, // single element + std::pair{1, 96}, // single row + std::pair{96, 1}, // single hidden column + std::pair{3, 7}, // tiny primes + std::pair{13, 100}, // non-power-of-two + std::pair{7, 257}, // prime, odd hidden + std::pair{33, 65}, // odd dims + std::pair{129, 31}), // odd rows, hidden < 32 + ::testing::Values(0, 32), // contiguous + interleaved + ::testing::Values(false, true)), // grad_act_scales off / on + test_name_generator); diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 8f96432ed8..b4ba17e048 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -255,6 +255,7 @@ list(APPEND transformer_engine_cuda_arch_specific_sources activation/relu_dbias.cu activation/relu_grouped.cu activation/relu_grouped_dbias.cu + activation/scaled_activation.cu activation/swiglu.cu activation/swiglu_dbias.cu activation/swiglu_grouped.cu @@ -513,6 +514,7 @@ if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH) activation/relu_dbias.cu activation/relu_grouped.cu activation/relu_grouped_dbias.cu + activation/scaled_activation.cu activation/swiglu.cu activation/swiglu_dbias.cu activation/swiglu_grouped.cu diff --git a/transformer_engine/common/activation/scaled_activation.cu b/transformer_engine/common/activation/scaled_activation.cu new file mode 100644 index 0000000000..176253edc2 --- /dev/null +++ b/transformer_engine/common/activation/scaled_activation.cu @@ -0,0 +1,567 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/* Scaled activations: apply an activation, multiply by a per-row scale + * (act_scales[row]), do all math in fp32, and cast once at the store. The + * backward path optionally also reduces the gradient of the per-row scale. + * + * The six __global__ kernels below: + * + * # | Kernel | Activation | Dir | grad_act_scales | Launch + * ---+-----------------------------------------------+------------------------+-----+-----------------+-------------------- + * 1 | scaled_gated_forward_kernel | SwiGLU / ClampedSwiGLU | fwd | -- | flat element grid + * 2 | scaled_srelu_forward_kernel | SReLU (unary) | fwd | -- | flat element grid + * 3 | scaled_gated_backward_kernel | SwiGLU / ClampedSwiGLU | bwd | no | flat element grid + * 4 | scaled_srelu_backward_kernel | SReLU | bwd | no | flat element grid + * 5 | scaled_gated_backward_with_scale_grad_kernel | SwiGLU / ClampedSwiGLU | bwd | yes | one block per row + * 6 | scaled_srelu_backward_with_scale_grad_kernel | SReLU | bwd | yes | one block per row + * + * The "with scale grad" variants compute grad_act_scales[row] = sum_j dY * unscaled, + * a per-row reduction that requires the one-block-per-row launch; when + * grad_act_scales is null the cheaper flat element-wise grid is used instead. + */ + +#include + +#include + +#include "../common.h" +#include "../util/math.h" + +namespace transformer_engine { +namespace { + +enum class ScaledActivation { + kSwiGLU, + kClampedSwiGLU, + kSReLU, +}; + +__device__ __forceinline__ void glu_input_indices(const size_t row, const size_t col, + const size_t hidden, + const int64_t glu_interleave_size, + size_t *act_idx, size_t *linear_idx) { + if (glu_interleave_size > 0) { + const size_t interleave = static_cast(glu_interleave_size); + const size_t block = col / interleave; + const size_t lane = col % interleave; + const size_t base = row * hidden * 2 + block * interleave * 2 + lane; + *act_idx = base; + *linear_idx = base + interleave; + } else { + const size_t base = row * hidden * 2; + *act_idx = base + col; + *linear_idx = base + hidden + col; + } +} + +template +__device__ __forceinline__ float gated_forward_value(const float act_in, const float linear_in, + const ClampedSwiGLUParam ¶m) { + if constexpr (Act == ScaledActivation::kSwiGLU) { + Empty empty = {}; + return silu(act_in, empty) * linear_in; + } else { + const float linear = + fminf(fmaxf(-param.limit, linear_in), param.limit) + param.glu_linear_offset; + return clamped_silu(act_in, param) * linear; + } +} + +template +__device__ __forceinline__ void gated_backward_values(const float act_in, const float linear_in, + const ClampedSwiGLUParam ¶m, + float *dact, float *dlinear, + float *unscaled) { + if constexpr (Act == ScaledActivation::kSwiGLU) { + Empty empty = {}; + const float act = silu(act_in, empty); + *unscaled = act * linear_in; + *dact = dsilu(act_in, empty) * linear_in; + *dlinear = act; + } else { + const bool dlinear_mask = linear_in <= param.limit && linear_in >= -param.limit; + const float linear = + fminf(fmaxf(-param.limit, linear_in), param.limit) + param.glu_linear_offset; + const float act = clamped_silu(act_in, param); + *unscaled = act * linear; + *dact = clamped_dsilu(act_in, param) * linear; + *dlinear = dlinear_mask ? act : 0.0f; + } +} + +template +__global__ void scaled_gated_forward_kernel(const InputT *input, const ScaleT *act_scales, + OutputT *output, const size_t rows, + const size_t hidden, + const int64_t glu_interleave_size, + const ClampedSwiGLUParam param) { + const size_t total = rows * hidden; + for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total; + idx += gridDim.x * blockDim.x) { + const size_t row = idx / hidden; + const size_t col = idx % hidden; + size_t act_idx = 0; + size_t linear_idx = 0; + glu_input_indices(row, col, hidden, glu_interleave_size, &act_idx, &linear_idx); + + const float unscaled = gated_forward_value(static_cast(input[act_idx]), + static_cast(input[linear_idx]), param); + const float scale = static_cast(act_scales[row]); + output[idx] = static_cast(unscaled * scale); + } +} + +template +__global__ void scaled_srelu_forward_kernel(const InputT *input, const ScaleT *act_scales, + OutputT *output, const size_t rows, + const size_t hidden) { + const size_t total = rows * hidden; + Empty empty = {}; + for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total; + idx += gridDim.x * blockDim.x) { + const size_t row = idx / hidden; + const float unscaled = srelu(static_cast(input[idx]), empty); + const float scale = static_cast(act_scales[row]); + output[idx] = static_cast(unscaled * scale); + } +} + +template +__global__ void scaled_gated_backward_kernel(const GradT *grad_output, const InputT *input, + const ScaleT *act_scales, OutputT *grad_input, + const size_t rows, const size_t hidden, + const int64_t glu_interleave_size, + const ClampedSwiGLUParam param) { + const size_t total = rows * hidden; + for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total; + idx += gridDim.x * blockDim.x) { + const size_t row = idx / hidden; + const size_t col = idx % hidden; + size_t act_idx = 0; + size_t linear_idx = 0; + glu_input_indices(row, col, hidden, glu_interleave_size, &act_idx, &linear_idx); + + float dact = 0.0f; + float dlinear = 0.0f; + float unscaled = 0.0f; + gated_backward_values(static_cast(input[act_idx]), + static_cast(input[linear_idx]), param, &dact, &dlinear, + &unscaled); + (void)unscaled; + const float scale = static_cast(act_scales[row]); + const float grad = static_cast(grad_output[idx]) * scale; + grad_input[act_idx] = static_cast(grad * dact); + grad_input[linear_idx] = static_cast(grad * dlinear); + } +} + +template +__global__ void scaled_srelu_backward_kernel(const GradT *grad_output, const InputT *input, + const ScaleT *act_scales, OutputT *grad_input, + const size_t rows, const size_t hidden) { + const size_t total = rows * hidden; + Empty empty = {}; + for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total; + idx += gridDim.x * blockDim.x) { + const size_t row = idx / hidden; + const float scale = static_cast(act_scales[row]); + const float grad = static_cast(grad_output[idx]) * scale; + grad_input[idx] = + static_cast(grad * dsrelu(static_cast(input[idx]), empty)); + } +} + +template +__global__ void scaled_gated_backward_with_scale_grad_kernel( + const GradT *grad_output, const InputT *input, const ScaleT *act_scales, OutputT *grad_input, + GradScaleT *grad_act_scales, const size_t rows, const size_t hidden, + const int64_t glu_interleave_size, const ClampedSwiGLUParam param) { + constexpr int kThreads = 256; + __shared__ float smem[kThreads]; + const size_t row = blockIdx.x; + float scale_grad = 0.0f; + + for (size_t col = threadIdx.x; col < hidden; col += blockDim.x) { + const size_t grad_idx = row * hidden + col; + size_t act_idx = 0; + size_t linear_idx = 0; + glu_input_indices(row, col, hidden, glu_interleave_size, &act_idx, &linear_idx); + + float dact = 0.0f; + float dlinear = 0.0f; + float unscaled = 0.0f; + gated_backward_values(static_cast(input[act_idx]), + static_cast(input[linear_idx]), param, &dact, &dlinear, + &unscaled); + const float grad = static_cast(grad_output[grad_idx]); + scale_grad += grad * unscaled; + + const float scale = static_cast(act_scales[row]); + const float scaled_grad = grad * scale; + grad_input[act_idx] = static_cast(scaled_grad * dact); + grad_input[linear_idx] = static_cast(scaled_grad * dlinear); + } + + smem[threadIdx.x] = scale_grad; + __syncthreads(); + for (int offset = kThreads / 2; offset > 0; offset >>= 1) { + if (threadIdx.x < offset) { + smem[threadIdx.x] += smem[threadIdx.x + offset]; + } + __syncthreads(); + } + if (threadIdx.x == 0) { + grad_act_scales[row] = static_cast(smem[0]); + } +} + +template +__global__ void scaled_srelu_backward_with_scale_grad_kernel( + const GradT *grad_output, const InputT *input, const ScaleT *act_scales, OutputT *grad_input, + GradScaleT *grad_act_scales, const size_t rows, const size_t hidden) { + constexpr int kThreads = 256; + __shared__ float smem[kThreads]; + const size_t row = blockIdx.x; + float scale_grad = 0.0f; + Empty empty = {}; + + for (size_t col = threadIdx.x; col < hidden; col += blockDim.x) { + const size_t idx = row * hidden + col; + const float unscaled = srelu(static_cast(input[idx]), empty); + const float grad = static_cast(grad_output[idx]); + scale_grad += grad * unscaled; + + const float scale = static_cast(act_scales[row]); + const float scaled_grad = grad * scale; + const float dact = dsrelu(static_cast(input[idx]), empty); + grad_input[idx] = static_cast(scaled_grad * dact); + } + + smem[threadIdx.x] = scale_grad; + __syncthreads(); + for (int offset = kThreads / 2; offset > 0; offset >>= 1) { + if (threadIdx.x < offset) { + smem[threadIdx.x] += smem[threadIdx.x + offset]; + } + __syncthreads(); + } + if (threadIdx.x == 0) { + grad_act_scales[row] = static_cast(smem[0]); + } +} + +void check_scale_tensor(const Tensor *act_scales, const size_t rows, const char *api_name) { + NVTE_CHECK(act_scales->numel() == rows, api_name, ": act_scales must have one value per row."); +} + +void check_gated_forward_tensors(const Tensor *input, const Tensor *act_scales, + const Tensor *output, const int64_t glu_interleave_size, + const char *api_name, size_t *rows, size_t *hidden) { + const auto input_dims = input->flat_2d_dims(); + const auto output_dims = output->flat_2d_dims(); + NVTE_CHECK(input_dims[0] == output_dims[0], api_name, ": input/output row mismatch."); + NVTE_CHECK(input_dims[1] == output_dims[1] * 2, api_name, + ": gated input last dimension must be twice output last dimension."); + NVTE_CHECK(glu_interleave_size >= 0, api_name, ": glu_interleave_size must be non-negative."); + if (glu_interleave_size > 0) { + NVTE_CHECK(output_dims[1] % static_cast(glu_interleave_size) == 0, api_name, + ": output last dimension must be divisible by glu_interleave_size."); + } + check_scale_tensor(act_scales, input_dims[0], api_name); + *rows = input_dims[0]; + *hidden = output_dims[1]; +} + +void check_unary_forward_tensors(const Tensor *input, const Tensor *act_scales, + const Tensor *output, const char *api_name, size_t *rows, + size_t *hidden) { + const auto input_dims = input->flat_2d_dims(); + const auto output_dims = output->flat_2d_dims(); + NVTE_CHECK(input_dims[0] == output_dims[0] && input_dims[1] == output_dims[1], api_name, + ": input/output shapes must match."); + check_scale_tensor(act_scales, input_dims[0], api_name); + *rows = input_dims[0]; + *hidden = output_dims[1]; +} + +void check_grad_scale_tensor(const Tensor *grad_act_scales, const size_t rows, + const char *api_name) { + if (grad_act_scales != nullptr) { + NVTE_CHECK(grad_act_scales->numel() == rows, api_name, + ": grad_act_scales must have one value per row."); + } +} + +void check_gated_backward_tensors(const Tensor *grad_output, const Tensor *input, + const Tensor *act_scales, const Tensor *grad_input, + const Tensor *grad_act_scales, + const int64_t glu_interleave_size, const char *api_name, + size_t *rows, size_t *hidden) { + const auto grad_dims = grad_output->flat_2d_dims(); + const auto input_dims = input->flat_2d_dims(); + const auto grad_input_dims = grad_input->flat_2d_dims(); + NVTE_CHECK(grad_dims[0] == input_dims[0] && input_dims[0] == grad_input_dims[0], api_name, + ": input/grad row mismatch."); + NVTE_CHECK(input_dims[1] == grad_dims[1] * 2 && grad_input_dims[1] == input_dims[1], api_name, + ": gated backward dimensions are inconsistent."); + NVTE_CHECK(glu_interleave_size >= 0, api_name, ": glu_interleave_size must be non-negative."); + if (glu_interleave_size > 0) { + NVTE_CHECK(grad_dims[1] % static_cast(glu_interleave_size) == 0, api_name, + ": grad last dimension must be divisible by glu_interleave_size."); + } + check_scale_tensor(act_scales, input_dims[0], api_name); + check_grad_scale_tensor(grad_act_scales, input_dims[0], api_name); + *rows = input_dims[0]; + *hidden = grad_dims[1]; +} + +void check_unary_backward_tensors(const Tensor *grad_output, const Tensor *input, + const Tensor *act_scales, const Tensor *grad_input, + const Tensor *grad_act_scales, const char *api_name, + size_t *rows, size_t *hidden) { + const auto grad_dims = grad_output->flat_2d_dims(); + const auto input_dims = input->flat_2d_dims(); + const auto grad_input_dims = grad_input->flat_2d_dims(); + NVTE_CHECK(grad_dims[0] == input_dims[0] && input_dims[0] == grad_input_dims[0], api_name, + ": input/grad row mismatch."); + NVTE_CHECK(grad_dims[1] == input_dims[1] && input_dims[1] == grad_input_dims[1], api_name, + ": unary backward dimensions are inconsistent."); + check_scale_tensor(act_scales, input_dims[0], api_name); + check_grad_scale_tensor(grad_act_scales, input_dims[0], api_name); + *rows = input_dims[0]; + *hidden = grad_dims[1]; +} + +template +void launch_scaled_gated_forward(const NVTETensor nvte_input, const NVTETensor nvte_act_scales, + NVTETensor nvte_output, const int64_t glu_interleave_size, + const ClampedSwiGLUParam param, cudaStream_t stream, + const char *api_name) { + const Tensor *input = convertNVTETensorCheck(nvte_input); + const Tensor *act_scales = convertNVTETensorCheck(nvte_act_scales); + Tensor *output = convertNVTETensorCheck(nvte_output); + size_t rows = 0; + size_t hidden = 0; + check_gated_forward_tensors(input, act_scales, output, glu_interleave_size, api_name, &rows, + &hidden); + if (rows == 0 || hidden == 0) return; + + constexpr int threads = 256; + const int blocks = static_cast(std::min(DIVUP(rows * hidden, static_cast(threads)), 65535)); + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(input->data.dtype, InputT, { + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(act_scales->data.dtype, ScaleT, { + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(output->data.dtype, OutputT, { + scaled_gated_forward_kernel + <<>>( + reinterpret_cast(input->data.dptr), + reinterpret_cast(act_scales->data.dptr), + reinterpret_cast(output->data.dptr), rows, hidden, glu_interleave_size, + param); + }); + }); + }); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +void launch_scaled_srelu_forward(const NVTETensor nvte_input, const NVTETensor nvte_act_scales, + NVTETensor nvte_output, cudaStream_t stream, + const char *api_name) { + const Tensor *input = convertNVTETensorCheck(nvte_input); + const Tensor *act_scales = convertNVTETensorCheck(nvte_act_scales); + Tensor *output = convertNVTETensorCheck(nvte_output); + size_t rows = 0; + size_t hidden = 0; + check_unary_forward_tensors(input, act_scales, output, api_name, &rows, &hidden); + if (rows == 0 || hidden == 0) return; + + constexpr int threads = 256; + const int blocks = static_cast(std::min(DIVUP(rows * hidden, static_cast(threads)), 65535)); + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(input->data.dtype, InputT, { + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(act_scales->data.dtype, ScaleT, { + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(output->data.dtype, OutputT, { + scaled_srelu_forward_kernel + <<>>( + reinterpret_cast(input->data.dptr), + reinterpret_cast(act_scales->data.dptr), + reinterpret_cast(output->data.dptr), rows, hidden); + }); + }); + }); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +template +void launch_scaled_gated_backward(const NVTETensor nvte_grad_output, const NVTETensor nvte_input, + const NVTETensor nvte_act_scales, NVTETensor nvte_grad_input, + NVTETensor nvte_grad_act_scales, + const int64_t glu_interleave_size, + const ClampedSwiGLUParam param, cudaStream_t stream, + const char *api_name) { + const Tensor *grad_output = convertNVTETensorCheck(nvte_grad_output); + const Tensor *input = convertNVTETensorCheck(nvte_input); + const Tensor *act_scales = convertNVTETensorCheck(nvte_act_scales); + Tensor *grad_input = convertNVTETensorCheck(nvte_grad_input); + Tensor *grad_act_scales = + nvte_grad_act_scales == nullptr ? nullptr : convertNVTETensorCheck(nvte_grad_act_scales); + size_t rows = 0; + size_t hidden = 0; + check_gated_backward_tensors(grad_output, input, act_scales, grad_input, grad_act_scales, + glu_interleave_size, api_name, &rows, &hidden); + if (rows == 0 || hidden == 0) return; + + constexpr int threads = 256; + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(grad_output->data.dtype, GradT, { + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(input->data.dtype, InputT, { + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(act_scales->data.dtype, ScaleT, { + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(grad_input->data.dtype, OutputT, { + if (grad_act_scales == nullptr) { + const int blocks = + static_cast(std::min(DIVUP(rows * hidden, static_cast(threads)), 65535)); + scaled_gated_backward_kernel + <<>>( + reinterpret_cast(grad_output->data.dptr), + reinterpret_cast(input->data.dptr), + reinterpret_cast(act_scales->data.dptr), + reinterpret_cast(grad_input->data.dptr), rows, hidden, + glu_interleave_size, param); + } else { + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(grad_act_scales->data.dtype, GradScaleT, { + scaled_gated_backward_with_scale_grad_kernel + <<(rows), threads, 0, stream>>>( + reinterpret_cast(grad_output->data.dptr), + reinterpret_cast(input->data.dptr), + reinterpret_cast(act_scales->data.dptr), + reinterpret_cast(grad_input->data.dptr), + reinterpret_cast(grad_act_scales->data.dptr), rows, hidden, + glu_interleave_size, param); + }); + } + }); + }); + }); + }); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +void launch_scaled_srelu_backward(const NVTETensor nvte_grad_output, const NVTETensor nvte_input, + const NVTETensor nvte_act_scales, NVTETensor nvte_grad_input, + NVTETensor nvte_grad_act_scales, cudaStream_t stream, + const char *api_name) { + const Tensor *grad_output = convertNVTETensorCheck(nvte_grad_output); + const Tensor *input = convertNVTETensorCheck(nvte_input); + const Tensor *act_scales = convertNVTETensorCheck(nvte_act_scales); + Tensor *grad_input = convertNVTETensorCheck(nvte_grad_input); + Tensor *grad_act_scales = + nvte_grad_act_scales == nullptr ? nullptr : convertNVTETensorCheck(nvte_grad_act_scales); + size_t rows = 0; + size_t hidden = 0; + check_unary_backward_tensors(grad_output, input, act_scales, grad_input, grad_act_scales, + api_name, &rows, &hidden); + if (rows == 0 || hidden == 0) return; + + constexpr int threads = 256; + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(grad_output->data.dtype, GradT, { + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(input->data.dtype, InputT, { + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(act_scales->data.dtype, ScaleT, { + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(grad_input->data.dtype, OutputT, { + if (grad_act_scales == nullptr) { + const int blocks = + static_cast(std::min(DIVUP(rows * hidden, static_cast(threads)), 65535)); + scaled_srelu_backward_kernel + <<>>( + reinterpret_cast(grad_output->data.dptr), + reinterpret_cast(input->data.dptr), + reinterpret_cast(act_scales->data.dptr), + reinterpret_cast(grad_input->data.dptr), rows, hidden); + } else { + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(grad_act_scales->data.dtype, GradScaleT, { + scaled_srelu_backward_with_scale_grad_kernel + <<(rows), threads, 0, stream>>>( + reinterpret_cast(grad_output->data.dptr), + reinterpret_cast(input->data.dptr), + reinterpret_cast(act_scales->data.dptr), + reinterpret_cast(grad_input->data.dptr), + reinterpret_cast(grad_act_scales->data.dptr), rows, hidden); + }); + } + }); + }); + }); + }); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +} // namespace +} // namespace transformer_engine + +void nvte_scaled_swiglu(const NVTETensor input, const NVTETensor act_scales, NVTETensor output, + int64_t glu_interleave_size, cudaStream_t stream) { + NVTE_API_CALL(nvte_scaled_swiglu); + using namespace transformer_engine; + Empty empty = {}; + (void)empty; + ClampedSwiGLUParam param = {}; + launch_scaled_gated_forward( + input, act_scales, output, glu_interleave_size, param, stream, "nvte_scaled_swiglu"); +} + +void nvte_scaled_dswiglu(const NVTETensor grad, const NVTETensor input, + const NVTETensor act_scales, NVTETensor grad_input, + NVTETensor grad_act_scales, int64_t glu_interleave_size, + cudaStream_t stream) { + NVTE_API_CALL(nvte_scaled_dswiglu); + using namespace transformer_engine; + ClampedSwiGLUParam param = {}; + launch_scaled_gated_backward( + grad, input, act_scales, grad_input, grad_act_scales, glu_interleave_size, param, stream, + "nvte_scaled_dswiglu"); +} + +void nvte_scaled_clamped_swiglu(const NVTETensor input, const NVTETensor act_scales, + NVTETensor output, float limit, float alpha, + float glu_linear_offset, int64_t glu_interleave_size, + cudaStream_t stream) { + NVTE_API_CALL(nvte_scaled_clamped_swiglu); + using namespace transformer_engine; + ClampedSwiGLUParam param = {limit, alpha, glu_linear_offset}; + launch_scaled_gated_forward( + input, act_scales, output, glu_interleave_size, param, stream, + "nvte_scaled_clamped_swiglu"); +} + +void nvte_scaled_clamped_dswiglu(const NVTETensor grad, const NVTETensor input, + const NVTETensor act_scales, NVTETensor grad_input, + NVTETensor grad_act_scales, float limit, float alpha, + float glu_linear_offset, int64_t glu_interleave_size, + cudaStream_t stream) { + NVTE_API_CALL(nvte_scaled_clamped_dswiglu); + using namespace transformer_engine; + ClampedSwiGLUParam param = {limit, alpha, glu_linear_offset}; + launch_scaled_gated_backward( + grad, input, act_scales, grad_input, grad_act_scales, glu_interleave_size, param, stream, + "nvte_scaled_clamped_dswiglu"); +} + +void nvte_scaled_srelu(const NVTETensor input, const NVTETensor act_scales, NVTETensor output, + cudaStream_t stream) { + NVTE_API_CALL(nvte_scaled_srelu); + using namespace transformer_engine; + launch_scaled_srelu_forward(input, act_scales, output, stream, "nvte_scaled_srelu"); +} + +void nvte_scaled_dsrelu(const NVTETensor grad, const NVTETensor input, + const NVTETensor act_scales, NVTETensor grad_input, + NVTETensor grad_act_scales, cudaStream_t stream) { + NVTE_API_CALL(nvte_scaled_dsrelu); + using namespace transformer_engine; + launch_scaled_srelu_backward(grad, input, act_scales, grad_input, grad_act_scales, stream, + "nvte_scaled_dsrelu"); +} diff --git a/transformer_engine/common/include/transformer_engine/activation.h b/transformer_engine/common/include/transformer_engine/activation.h index 4ed083740d..f1485057ec 100644 --- a/transformer_engine/common/include/transformer_engine/activation.h +++ b/transformer_engine/common/include/transformer_engine/activation.h @@ -368,6 +368,41 @@ void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit, void nvte_clamped_swiglu_v2(const NVTETensor input, NVTETensor output, float limit, float alpha, float glu_linear_offset, cudaStream_t stream); +/*! \brief Computes ScaledSwiGLU without materializing GLU deinterleave. + * + * Computes output = SwiGLU(input) * act_scales[:, None]. + * If glu_interleave_size > 0, input is interpreted as interleaved + * [activation_block, linear_block] chunks of that size. + * + * \param[in] input Input tensor of shape [N, H * 2]. + * \param[in] act_scales Row-wise activation scales of shape [N]. + * \param[in,out] output Output tensor of shape [N, H]. + * \param[in] glu_interleave_size GLU interleave chunk size, or 0 for non-interleaved layout. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_scaled_swiglu(const NVTETensor input, const NVTETensor act_scales, NVTETensor output, + int64_t glu_interleave_size, cudaStream_t stream); + +/*! \brief Computes ScaledClampedSwiGLU without materializing GLU deinterleave. + * + * Computes output = ClampedSwiGLU(input) * act_scales[:, None]. + * This uses the same clamping, alpha, and linear-offset semantics as + * nvte_clamped_swiglu_v2. + * + * \param[in] input Input tensor of shape [N, H * 2]. + * \param[in] act_scales Row-wise activation scales of shape [N]. + * \param[in,out] output Output tensor of shape [N, H]. + * \param[in] limit Clipping limit. + * \param[in] alpha Activation sigmoid alpha. + * \param[in] glu_linear_offset Offset added to linear component after clamping. + * \param[in] glu_interleave_size GLU interleave chunk size, or 0 for non-interleaved layout. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_scaled_clamped_swiglu(const NVTETensor input, const NVTETensor act_scales, + NVTETensor output, float limit, float alpha, + float glu_linear_offset, int64_t glu_interleave_size, + cudaStream_t stream); + /*! \brief Computes the gated ReLU activation of the input. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. @@ -473,6 +508,46 @@ void nvte_clamped_dswiglu_v2(const NVTETensor grad, const NVTETensor input, NVTE float limit, float alpha, float glu_linear_offset, cudaStream_t stream); +/*! \brief Computes ScaledSwiGLU backward without materializing GLU deinterleave. + * + * The optional grad_act_scales tensor may be null. When present, it receives + * sum(dY * SwiGLU(input), dim=-1). + * + * \param[in] grad Incoming gradient of shape [N, H]. + * \param[in] input Forward input tensor of shape [N, H * 2]. + * \param[in] act_scales Row-wise activation scales of shape [N]. + * \param[in,out] grad_input Outgoing gradient of shape [N, H * 2]. + * \param[in,out] grad_act_scales Optional row-wise scale gradient of shape [N], or null. + * \param[in] glu_interleave_size GLU interleave chunk size, or 0 for non-interleaved layout. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_scaled_dswiglu(const NVTETensor grad, const NVTETensor input, + const NVTETensor act_scales, NVTETensor grad_input, + NVTETensor grad_act_scales, int64_t glu_interleave_size, + cudaStream_t stream); + +/*! \brief Computes ScaledClampedSwiGLU backward without materializing GLU deinterleave. + * + * The optional grad_act_scales tensor may be null. When present, it receives + * sum(dY * ClampedSwiGLU(input), dim=-1). + * + * \param[in] grad Incoming gradient of shape [N, H]. + * \param[in] input Forward input tensor of shape [N, H * 2]. + * \param[in] act_scales Row-wise activation scales of shape [N]. + * \param[in,out] grad_input Outgoing gradient of shape [N, H * 2]. + * \param[in,out] grad_act_scales Optional row-wise scale gradient of shape [N], or null. + * \param[in] limit Clipping limit. + * \param[in] alpha Activation sigmoid alpha. + * \param[in] glu_linear_offset Offset added to linear component after clamping. + * \param[in] glu_interleave_size GLU interleave chunk size, or 0 for non-interleaved layout. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_scaled_clamped_dswiglu(const NVTETensor grad, const NVTETensor input, + const NVTETensor act_scales, NVTETensor grad_input, + NVTETensor grad_act_scales, float limit, float alpha, + float glu_linear_offset, int64_t glu_interleave_size, + cudaStream_t stream); + /*! \brief Computes the gated ReLU activation gradient. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. @@ -509,6 +584,34 @@ void nvte_dqgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor outp void nvte_dsreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes ScaledSReLU. + * + * Computes output = SReLU(input) * act_scales[:, None]. + * + * \param[in] input Input tensor for activation. + * \param[in] act_scales Row-wise activation scales of shape [N]. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_scaled_srelu(const NVTETensor input, const NVTETensor act_scales, NVTETensor output, + cudaStream_t stream); + +/*! \brief Computes ScaledSReLU backward. + * + * The optional grad_act_scales tensor may be null. When present, it receives + * sum(dY * SReLU(input), dim=-1). + * + * \param[in] grad Incoming gradient. + * \param[in] input Forward input tensor. + * \param[in] act_scales Row-wise activation scales of shape [N]. + * \param[in,out] grad_input Outgoing input gradient. + * \param[in,out] grad_act_scales Optional row-wise scale gradient of shape [N], or null. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_scaled_dsrelu(const NVTETensor grad, const NVTETensor input, + const NVTETensor act_scales, NVTETensor grad_input, + NVTETensor grad_act_scales, cudaStream_t stream); + #ifdef __cplusplus } // extern "C" #endif From 953c46903c47ac65e8e99cd2520761c2ae593240 Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Mon, 15 Jun 2026 17:34:59 -0700 Subject: [PATCH 2/6] vectorized loading improvement Signed-off-by: Zhongbo Zhu --- .../common/activation/scaled_activation.cu | 652 +++++++++++++----- 1 file changed, 463 insertions(+), 189 deletions(-) diff --git a/transformer_engine/common/activation/scaled_activation.cu b/transformer_engine/common/activation/scaled_activation.cu index 176253edc2..7053e06241 100644 --- a/transformer_engine/common/activation/scaled_activation.cu +++ b/transformer_engine/common/activation/scaled_activation.cu @@ -12,16 +12,47 @@ * * # | Kernel | Activation | Dir | grad_act_scales | Launch * ---+-----------------------------------------------+------------------------+-----+-----------------+-------------------- - * 1 | scaled_gated_forward_kernel | SwiGLU / ClampedSwiGLU | fwd | -- | flat element grid - * 2 | scaled_srelu_forward_kernel | SReLU (unary) | fwd | -- | flat element grid - * 3 | scaled_gated_backward_kernel | SwiGLU / ClampedSwiGLU | bwd | no | flat element grid - * 4 | scaled_srelu_backward_kernel | SReLU | bwd | no | flat element grid - * 5 | scaled_gated_backward_with_scale_grad_kernel | SwiGLU / ClampedSwiGLU | bwd | yes | one block per row - * 6 | scaled_srelu_backward_with_scale_grad_kernel | SReLU | bwd | yes | one block per row + * 1 | scaled_gated_forward_kernel | SwiGLU / ClampedSwiGLU | fwd | -- | vectorized row segments + * 2 | scaled_srelu_forward_kernel | SReLU (unary) | fwd | -- | vectorized flat grid + * 3 | scaled_gated_backward_kernel | SwiGLU / ClampedSwiGLU | bwd | no | vectorized row segments + * 4 | scaled_srelu_backward_kernel | SReLU | bwd | no | vectorized flat grid + * 5 | scaled_gated_backward_with_scale_grad_kernel | SwiGLU / ClampedSwiGLU | bwd | yes | vectorized, one block per row + * 6 | scaled_srelu_backward_with_scale_grad_kernel | SReLU | bwd | yes | vectorized, one block per row * * The "with scale grad" variants compute grad_act_scales[row] = sum_j dY * unscaled, * a per-row reduction that requires the one-block-per-row launch; when * grad_act_scales is null the cheaper flat element-wise grid is used instead. + * + * Vectorization model: + * + * Gated activations consume two FC1 streams per row: an activation stream and a + * gate stream. With no GLU interleave, the row is laid out as: + * + * [ act[0:H] | gate[0:H] ] + * + * With GLU interleave, e.g. interleave=32, the row is laid out as independent + * act/gate segments: + * + * [ act[0:32] | gate[0:32] | act[32:64] | gate[32:64] | ... ] + * + * Vector loads: + * + * interleave=0: + * input [ act0 | act1 | ... | actN | gate0 | gate1 | ... | gateN ] + * | | + * v v + * load act vector i gate vector i + * store output vector i = activation(act vector i) * gate vector i * scale[row] + * + * interleave=32: + * input [ act0 | gate0 | act1 | gate1 | ... | actN | gateN ] + * | | | | + * v v v v + * load act0 gate0 act1 gate1 + * store output vector i = activation(act vector i) * gate vector i * scale[row] + * + * Only fully aligned segments use vector loads. Everything else uses the same + * kernels with nvec=1, i.e. regular elementwise loads/stores. */ #include @@ -30,6 +61,7 @@ #include "../common.h" #include "../util/math.h" +#include "../util/vectorized_pointwise.h" namespace transformer_engine { namespace { @@ -40,172 +72,306 @@ enum class ScaledActivation { kSReLU, }; -__device__ __forceinline__ void glu_input_indices(const size_t row, const size_t col, - const size_t hidden, - const int64_t glu_interleave_size, - size_t *act_idx, size_t *linear_idx) { - if (glu_interleave_size > 0) { - const size_t interleave = static_cast(glu_interleave_size); - const size_t block = col / interleave; - const size_t lane = col % interleave; - const size_t base = row * hidden * 2 + block * interleave * 2 + lane; - *act_idx = base; - *linear_idx = base + interleave; - } else { - const size_t base = row * hidden * 2; - *act_idx = base + col; - *linear_idx = base + hidden + col; - } -} - template -__device__ __forceinline__ float gated_forward_value(const float act_in, const float linear_in, +__device__ __forceinline__ float gated_forward_value(const float act_in, const float gate_in, const ClampedSwiGLUParam ¶m) { if constexpr (Act == ScaledActivation::kSwiGLU) { Empty empty = {}; - return silu(act_in, empty) * linear_in; + return silu(act_in, empty) * gate_in; } else { - const float linear = - fminf(fmaxf(-param.limit, linear_in), param.limit) + param.glu_linear_offset; - return clamped_silu(act_in, param) * linear; + const float gate = fminf(fmaxf(-param.limit, gate_in), param.limit) + param.glu_linear_offset; + return clamped_silu(act_in, param) * gate; } } template -__device__ __forceinline__ void gated_backward_values(const float act_in, const float linear_in, - const ClampedSwiGLUParam ¶m, - float *dact, float *dlinear, +__device__ __forceinline__ void gated_backward_values(const float act_in, const float gate_in, + const ClampedSwiGLUParam ¶m, float *dact, + float *dgate, float *unscaled) { if constexpr (Act == ScaledActivation::kSwiGLU) { Empty empty = {}; const float act = silu(act_in, empty); - *unscaled = act * linear_in; - *dact = dsilu(act_in, empty) * linear_in; - *dlinear = act; + *unscaled = act * gate_in; + *dact = dsilu(act_in, empty) * gate_in; + *dgate = act; } else { - const bool dlinear_mask = linear_in <= param.limit && linear_in >= -param.limit; - const float linear = - fminf(fmaxf(-param.limit, linear_in), param.limit) + param.glu_linear_offset; + const bool dgate_mask = gate_in <= param.limit && gate_in >= -param.limit; + const float gate = fminf(fmaxf(-param.limit, gate_in), param.limit) + param.glu_linear_offset; const float act = clamped_silu(act_in, param); - *unscaled = act * linear; - *dact = clamped_dsilu(act_in, param) * linear; - *dlinear = dlinear_mask ? act : 0.0f; + *unscaled = act * gate; + *dact = clamped_dsilu(act_in, param) * gate; + *dgate = dgate_mask ? act : 0.0f; + } +} + +constexpr int kThreads = unary_kernel_threads; + +template +constexpr int vector_width() { + return 32 / static_cast(sizeof(T)); +} + +inline int launch_blocks(const size_t work_items) { + return static_cast( + std::min(DIVUP(work_items, static_cast(kThreads)), 65535)); +} + +template +Alignment row_vector_alignment(const size_t lead_dim, const int nvec, const Ptrs... ptrs) { + if (nvec == 1) { + return Alignment::SAME_ALIGNED; + } + // GLU interleave is handled as independent row-local segments. Keep the scalar + // fallback for odd segment widths or unaligned pointers so vector stores never + // cross from an activation segment into its paired gate segment. + if (lead_dim % static_cast(nvec) != 0) { + return Alignment::DIFFERENT; + } + const auto align = CheckAlignment(lead_dim, nvec, ptrs...); + return align == Alignment::SAME_ALIGNED ? Alignment::SAME_ALIGNED : Alignment::DIFFERENT; +} + +template +__device__ __forceinline__ bool vector_lane_index(const size_t vector_idx, const int lane, + const int alignment, const size_t length, + size_t *index) { + size_t idx = vector_idx * static_cast(nvec) + static_cast(lane); + if constexpr (!aligned) { + if (idx < static_cast(alignment)) { + return false; + } + idx -= static_cast(alignment); + } + if (idx >= length) { + return false; } + *index = idx; + return true; } -template +template __global__ void scaled_gated_forward_kernel(const InputT *input, const ScaleT *act_scales, OutputT *output, const size_t rows, - const size_t hidden, - const int64_t glu_interleave_size, + const size_t hidden, const size_t segment_size, + const size_t num_segments, + const size_t num_vectors_per_segment, const ClampedSwiGLUParam param) { - const size_t total = rows * hidden; - for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total; - idx += gridDim.x * blockDim.x) { - const size_t row = idx / hidden; - const size_t col = idx % hidden; - size_t act_idx = 0; - size_t linear_idx = 0; - glu_input_indices(row, col, hidden, glu_interleave_size, &act_idx, &linear_idx); - - const float unscaled = gated_forward_value(static_cast(input[act_idx]), - static_cast(input[linear_idx]), param); + const size_t total_vectors = rows * num_segments * num_vectors_per_segment; + for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < total_vectors; + tid += gridDim.x * blockDim.x) { + const size_t vector_idx = tid % num_vectors_per_segment; + const size_t segment = (tid / num_vectors_per_segment) % num_segments; + const size_t row = tid / (num_vectors_per_segment * num_segments); + const size_t input_segment_offset = row * hidden * 2 + segment * segment_size * 2; + const size_t output_segment_offset = row * hidden + segment * segment_size; + + VectorizedLoader act_loader(input + input_segment_offset, + segment_size); + VectorizedLoader gate_loader( + input + input_segment_offset + segment_size, segment_size); + VectorizedStorer output_storer(output + output_segment_offset, + segment_size); + if (vector_idx >= act_loader.num_aligned_elements()) { + continue; + } + + act_loader.load(vector_idx, segment_size); + gate_loader.load(vector_idx, segment_size); const float scale = static_cast(act_scales[row]); - output[idx] = static_cast(unscaled * scale); +#pragma unroll + for (int lane = 0; lane < nvec; ++lane) { + size_t col = 0; + if (vector_lane_index(vector_idx, lane, act_loader.alignment(), + segment_size, &col)) { + const float unscaled = + gated_forward_value(static_cast(act_loader.separate()[lane]), + static_cast(gate_loader.separate()[lane]), param); + output_storer.separate()[lane] = static_cast(unscaled * scale); + } + } + output_storer.store(vector_idx, segment_size); } } -template +template __global__ void scaled_srelu_forward_kernel(const InputT *input, const ScaleT *act_scales, - OutputT *output, const size_t rows, - const size_t hidden) { - const size_t total = rows * hidden; + OutputT *output, const size_t total, + const size_t hidden, + const size_t num_vectors) { Empty empty = {}; - for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total; - idx += gridDim.x * blockDim.x) { - const size_t row = idx / hidden; - const float unscaled = srelu(static_cast(input[idx]), empty); - const float scale = static_cast(act_scales[row]); - output[idx] = static_cast(unscaled * scale); + VectorizedLoader input_loader(input, total); + VectorizedStorer output_storer(output, total); + for (size_t vector_idx = blockIdx.x * blockDim.x + threadIdx.x; vector_idx < num_vectors; + vector_idx += gridDim.x * blockDim.x) { + if (vector_idx >= input_loader.num_aligned_elements()) { + continue; + } + input_loader.load(vector_idx, total); +#pragma unroll + for (int lane = 0; lane < nvec; ++lane) { + size_t idx = 0; + if (vector_lane_index(vector_idx, lane, input_loader.alignment(), total, + &idx)) { + const size_t row = idx / hidden; + const float unscaled = srelu(static_cast(input_loader.separate()[lane]), + empty); + const float scale = static_cast(act_scales[row]); + output_storer.separate()[lane] = static_cast(unscaled * scale); + } + } + output_storer.store(vector_idx, total); } } -template -__global__ void scaled_gated_backward_kernel(const GradT *grad_output, const InputT *input, - const ScaleT *act_scales, OutputT *grad_input, - const size_t rows, const size_t hidden, - const int64_t glu_interleave_size, - const ClampedSwiGLUParam param) { - const size_t total = rows * hidden; - for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total; - idx += gridDim.x * blockDim.x) { - const size_t row = idx / hidden; - const size_t col = idx % hidden; - size_t act_idx = 0; - size_t linear_idx = 0; - glu_input_indices(row, col, hidden, glu_interleave_size, &act_idx, &linear_idx); - - float dact = 0.0f; - float dlinear = 0.0f; - float unscaled = 0.0f; - gated_backward_values(static_cast(input[act_idx]), - static_cast(input[linear_idx]), param, &dact, &dlinear, - &unscaled); - (void)unscaled; +template +__global__ void scaled_gated_backward_kernel( + const GradT *grad_output, const InputT *input, const ScaleT *act_scales, OutputT *grad_input, + const size_t rows, const size_t hidden, const size_t segment_size, const size_t num_segments, + const size_t num_vectors_per_segment, const ClampedSwiGLUParam param) { + const size_t total_vectors = rows * num_segments * num_vectors_per_segment; + for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < total_vectors; + tid += gridDim.x * blockDim.x) { + const size_t vector_idx = tid % num_vectors_per_segment; + const size_t segment = (tid / num_vectors_per_segment) % num_segments; + const size_t row = tid / (num_vectors_per_segment * num_segments); + const size_t input_segment_offset = row * hidden * 2 + segment * segment_size * 2; + const size_t output_segment_offset = row * hidden + segment * segment_size; + + VectorizedLoader grad_loader(grad_output + output_segment_offset, + segment_size); + VectorizedLoader act_loader(input + input_segment_offset, + segment_size); + VectorizedLoader gate_loader( + input + input_segment_offset + segment_size, segment_size); + VectorizedStorer act_storer(grad_input + input_segment_offset, + segment_size); + VectorizedStorer gate_storer( + grad_input + input_segment_offset + segment_size, segment_size); + if (vector_idx >= act_loader.num_aligned_elements()) { + continue; + } + + grad_loader.load(vector_idx, segment_size); + act_loader.load(vector_idx, segment_size); + gate_loader.load(vector_idx, segment_size); const float scale = static_cast(act_scales[row]); - const float grad = static_cast(grad_output[idx]) * scale; - grad_input[act_idx] = static_cast(grad * dact); - grad_input[linear_idx] = static_cast(grad * dlinear); +#pragma unroll + for (int lane = 0; lane < nvec; ++lane) { + size_t col = 0; + if (vector_lane_index(vector_idx, lane, act_loader.alignment(), + segment_size, &col)) { + float dact = 0.0f; + float dgate = 0.0f; + float unscaled = 0.0f; + gated_backward_values(static_cast(act_loader.separate()[lane]), + static_cast(gate_loader.separate()[lane]), param, &dact, + &dgate, &unscaled); + (void)unscaled; + const float grad = static_cast(grad_loader.separate()[lane]) * scale; + act_storer.separate()[lane] = static_cast(grad * dact); + gate_storer.separate()[lane] = static_cast(grad * dgate); + } + } + act_storer.store(vector_idx, segment_size); + gate_storer.store(vector_idx, segment_size); } } -template +template __global__ void scaled_srelu_backward_kernel(const GradT *grad_output, const InputT *input, const ScaleT *act_scales, OutputT *grad_input, - const size_t rows, const size_t hidden) { - const size_t total = rows * hidden; + const size_t total, const size_t hidden, + const size_t num_vectors) { Empty empty = {}; - for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total; - idx += gridDim.x * blockDim.x) { - const size_t row = idx / hidden; - const float scale = static_cast(act_scales[row]); - const float grad = static_cast(grad_output[idx]) * scale; - grad_input[idx] = - static_cast(grad * dsrelu(static_cast(input[idx]), empty)); + VectorizedLoader grad_loader(grad_output, total); + VectorizedLoader input_loader(input, total); + VectorizedStorer grad_input_storer(grad_input, total); + for (size_t vector_idx = blockIdx.x * blockDim.x + threadIdx.x; vector_idx < num_vectors; + vector_idx += gridDim.x * blockDim.x) { + if (vector_idx >= input_loader.num_aligned_elements()) { + continue; + } + grad_loader.load(vector_idx, total); + input_loader.load(vector_idx, total); +#pragma unroll + for (int lane = 0; lane < nvec; ++lane) { + size_t idx = 0; + if (vector_lane_index(vector_idx, lane, input_loader.alignment(), total, + &idx)) { + const size_t row = idx / hidden; + const float scale = static_cast(act_scales[row]); + const float grad = static_cast(grad_loader.separate()[lane]) * scale; + grad_input_storer.separate()[lane] = + static_cast(grad * dsrelu( + static_cast(input_loader.separate()[lane]), + empty)); + } + } + grad_input_storer.store(vector_idx, total); } } -template +template __global__ void scaled_gated_backward_with_scale_grad_kernel( const GradT *grad_output, const InputT *input, const ScaleT *act_scales, OutputT *grad_input, GradScaleT *grad_act_scales, const size_t rows, const size_t hidden, - const int64_t glu_interleave_size, const ClampedSwiGLUParam param) { - constexpr int kThreads = 256; + const size_t segment_size, const size_t num_segments, const size_t num_vectors_per_segment, + const ClampedSwiGLUParam param) { __shared__ float smem[kThreads]; const size_t row = blockIdx.x; + (void)rows; float scale_grad = 0.0f; - for (size_t col = threadIdx.x; col < hidden; col += blockDim.x) { - const size_t grad_idx = row * hidden + col; - size_t act_idx = 0; - size_t linear_idx = 0; - glu_input_indices(row, col, hidden, glu_interleave_size, &act_idx, &linear_idx); - - float dact = 0.0f; - float dlinear = 0.0f; - float unscaled = 0.0f; - gated_backward_values(static_cast(input[act_idx]), - static_cast(input[linear_idx]), param, &dact, &dlinear, - &unscaled); - const float grad = static_cast(grad_output[grad_idx]); - scale_grad += grad * unscaled; - - const float scale = static_cast(act_scales[row]); - const float scaled_grad = grad * scale; - grad_input[act_idx] = static_cast(scaled_grad * dact); - grad_input[linear_idx] = static_cast(scaled_grad * dlinear); + for (size_t segment = 0; segment < num_segments; ++segment) { + const size_t input_segment_offset = row * hidden * 2 + segment * segment_size * 2; + const size_t output_segment_offset = row * hidden + segment * segment_size; + VectorizedLoader grad_loader(grad_output + output_segment_offset, + segment_size); + VectorizedLoader act_loader(input + input_segment_offset, + segment_size); + VectorizedLoader gate_loader( + input + input_segment_offset + segment_size, segment_size); + VectorizedStorer act_storer(grad_input + input_segment_offset, + segment_size); + VectorizedStorer gate_storer( + grad_input + input_segment_offset + segment_size, segment_size); + + for (size_t vector_idx = threadIdx.x; vector_idx < num_vectors_per_segment; + vector_idx += blockDim.x) { + if (vector_idx >= act_loader.num_aligned_elements()) { + continue; + } + grad_loader.load(vector_idx, segment_size); + act_loader.load(vector_idx, segment_size); + gate_loader.load(vector_idx, segment_size); +#pragma unroll + for (int lane = 0; lane < nvec; ++lane) { + size_t col = 0; + if (vector_lane_index(vector_idx, lane, act_loader.alignment(), + segment_size, &col)) { + float dact = 0.0f; + float dgate = 0.0f; + float unscaled = 0.0f; + gated_backward_values(static_cast(act_loader.separate()[lane]), + static_cast(gate_loader.separate()[lane]), param, &dact, + &dgate, &unscaled); + const float grad = static_cast(grad_loader.separate()[lane]); + scale_grad += grad * unscaled; + + const float scale = static_cast(act_scales[row]); + const float scaled_grad = grad * scale; + act_storer.separate()[lane] = static_cast(scaled_grad * dact); + gate_storer.separate()[lane] = static_cast(scaled_grad * dgate); + } + } + act_storer.store(vector_idx, segment_size); + gate_storer.store(vector_idx, segment_size); + } } smem[threadIdx.x] = scale_grad; @@ -221,26 +387,46 @@ __global__ void scaled_gated_backward_with_scale_grad_kernel( } } -template +template __global__ void scaled_srelu_backward_with_scale_grad_kernel( const GradT *grad_output, const InputT *input, const ScaleT *act_scales, OutputT *grad_input, - GradScaleT *grad_act_scales, const size_t rows, const size_t hidden) { - constexpr int kThreads = 256; + GradScaleT *grad_act_scales, const size_t rows, const size_t hidden, + const size_t num_vectors_per_row) { __shared__ float smem[kThreads]; const size_t row = blockIdx.x; + (void)rows; float scale_grad = 0.0f; Empty empty = {}; - for (size_t col = threadIdx.x; col < hidden; col += blockDim.x) { - const size_t idx = row * hidden + col; - const float unscaled = srelu(static_cast(input[idx]), empty); - const float grad = static_cast(grad_output[idx]); - scale_grad += grad * unscaled; - - const float scale = static_cast(act_scales[row]); - const float scaled_grad = grad * scale; - const float dact = dsrelu(static_cast(input[idx]), empty); - grad_input[idx] = static_cast(scaled_grad * dact); + VectorizedLoader grad_loader(grad_output + row * hidden, hidden); + VectorizedLoader input_loader(input + row * hidden, hidden); + VectorizedStorer grad_input_storer(grad_input + row * hidden, hidden); + for (size_t vector_idx = threadIdx.x; vector_idx < num_vectors_per_row; + vector_idx += blockDim.x) { + if (vector_idx >= input_loader.num_aligned_elements()) { + continue; + } + grad_loader.load(vector_idx, hidden); + input_loader.load(vector_idx, hidden); +#pragma unroll + for (int lane = 0; lane < nvec; ++lane) { + size_t col = 0; + if (vector_lane_index(vector_idx, lane, input_loader.alignment(), hidden, + &col)) { + const float unscaled = + srelu(static_cast(input_loader.separate()[lane]), empty); + const float grad = static_cast(grad_loader.separate()[lane]); + scale_grad += grad * unscaled; + + const float scale = static_cast(act_scales[row]); + const float scaled_grad = grad * scale; + const float dact = + dsrelu(static_cast(input_loader.separate()[lane]), empty); + grad_input_storer.separate()[lane] = static_cast(scaled_grad * dact); + } + } + grad_input_storer.store(vector_idx, hidden); } smem[threadIdx.x] = scale_grad; @@ -270,6 +456,8 @@ void check_gated_forward_tensors(const Tensor *input, const Tensor *act_scales, ": gated input last dimension must be twice output last dimension."); NVTE_CHECK(glu_interleave_size >= 0, api_name, ": glu_interleave_size must be non-negative."); if (glu_interleave_size > 0) { + NVTE_CHECK(glu_interleave_size % 32 == 0, api_name, + ": nonzero glu_interleave_size must be a multiple of 32."); NVTE_CHECK(output_dims[1] % static_cast(glu_interleave_size) == 0, api_name, ": output last dimension must be divisible by glu_interleave_size."); } @@ -312,6 +500,8 @@ void check_gated_backward_tensors(const Tensor *grad_output, const Tensor *input ": gated backward dimensions are inconsistent."); NVTE_CHECK(glu_interleave_size >= 0, api_name, ": glu_interleave_size must be non-negative."); if (glu_interleave_size > 0) { + NVTE_CHECK(glu_interleave_size % 32 == 0, api_name, + ": nonzero glu_interleave_size must be a multiple of 32."); NVTE_CHECK(grad_dims[1] % static_cast(glu_interleave_size) == 0, api_name, ": grad last dimension must be divisible by glu_interleave_size."); } @@ -352,17 +542,34 @@ void launch_scaled_gated_forward(const NVTETensor nvte_input, const NVTETensor n &hidden); if (rows == 0 || hidden == 0) return; - constexpr int threads = 256; - const int blocks = static_cast(std::min(DIVUP(rows * hidden, static_cast(threads)), 65535)); TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(input->data.dtype, InputT, { TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(act_scales->data.dtype, ScaleT, { TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(output->data.dtype, OutputT, { - scaled_gated_forward_kernel - <<>>( - reinterpret_cast(input->data.dptr), - reinterpret_cast(act_scales->data.dptr), - reinterpret_cast(output->data.dptr), rows, hidden, glu_interleave_size, - param); + constexpr int nvec = + sizeof(InputT) == sizeof(OutputT) ? vector_width() : 1; + const auto input_ptr = reinterpret_cast(input->data.dptr); + const auto scale_ptr = reinterpret_cast(act_scales->data.dptr); + auto output_ptr = reinterpret_cast(output->data.dptr); + const size_t segment_size = + glu_interleave_size > 0 ? static_cast(glu_interleave_size) : hidden; + const size_t num_segments = glu_interleave_size > 0 ? hidden / segment_size : 1; + const auto align = + row_vector_alignment(segment_size, nvec, input_ptr, input_ptr + segment_size, + output_ptr); + const bool use_vector = align == Alignment::SAME_ALIGNED; + const size_t num_vectors = + use_vector ? get_num_aligned_elements(input_ptr, segment_size, nvec, sizeof(InputT)) + : segment_size; + const int blocks = launch_blocks(rows * num_segments * num_vectors); + if (use_vector) { + scaled_gated_forward_kernel + <<>>(input_ptr, scale_ptr, output_ptr, rows, hidden, + segment_size, num_segments, num_vectors, param); + } else { + scaled_gated_forward_kernel<1, true, InputT, ScaleT, OutputT, Act> + <<>>(input_ptr, scale_ptr, output_ptr, rows, hidden, + segment_size, num_segments, segment_size, param); + } }); }); }); @@ -380,16 +587,29 @@ void launch_scaled_srelu_forward(const NVTETensor nvte_input, const NVTETensor n check_unary_forward_tensors(input, act_scales, output, api_name, &rows, &hidden); if (rows == 0 || hidden == 0) return; - constexpr int threads = 256; - const int blocks = static_cast(std::min(DIVUP(rows * hidden, static_cast(threads)), 65535)); TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(input->data.dtype, InputT, { TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(act_scales->data.dtype, ScaleT, { TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(output->data.dtype, OutputT, { - scaled_srelu_forward_kernel - <<>>( - reinterpret_cast(input->data.dptr), - reinterpret_cast(act_scales->data.dptr), - reinterpret_cast(output->data.dptr), rows, hidden); + constexpr int nvec = + sizeof(InputT) == sizeof(OutputT) ? vector_width() : 1; + const auto input_ptr = reinterpret_cast(input->data.dptr); + const auto scale_ptr = reinterpret_cast(act_scales->data.dptr); + auto output_ptr = reinterpret_cast(output->data.dptr); + const size_t total = rows * hidden; + const auto align = CheckAlignment(total, nvec, input_ptr, output_ptr); + const bool use_vector = align == Alignment::SAME_ALIGNED; + const size_t num_vectors = + use_vector ? get_num_aligned_elements(input_ptr, total, nvec, sizeof(InputT)) : total; + const int blocks = launch_blocks(num_vectors); + if (use_vector) { + scaled_srelu_forward_kernel + <<>>(input_ptr, scale_ptr, output_ptr, total, hidden, + num_vectors); + } else { + scaled_srelu_forward_kernel<1, true, InputT, ScaleT, OutputT> + <<>>(input_ptr, scale_ptr, output_ptr, total, hidden, + total); + } }); }); }); @@ -415,32 +635,58 @@ void launch_scaled_gated_backward(const NVTETensor nvte_grad_output, const NVTET glu_interleave_size, api_name, &rows, &hidden); if (rows == 0 || hidden == 0) return; - constexpr int threads = 256; TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(grad_output->data.dtype, GradT, { TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(input->data.dtype, InputT, { TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(act_scales->data.dtype, ScaleT, { TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(grad_input->data.dtype, OutputT, { + constexpr int nvec = sizeof(GradT) == sizeof(InputT) && + sizeof(InputT) == sizeof(OutputT) + ? vector_width() + : 1; + const auto grad_ptr = reinterpret_cast(grad_output->data.dptr); + const auto input_ptr = reinterpret_cast(input->data.dptr); + const auto scale_ptr = reinterpret_cast(act_scales->data.dptr); + auto grad_input_ptr = reinterpret_cast(grad_input->data.dptr); + const size_t segment_size = + glu_interleave_size > 0 ? static_cast(glu_interleave_size) : hidden; + const size_t num_segments = glu_interleave_size > 0 ? hidden / segment_size : 1; + const auto align = row_vector_alignment( + segment_size, nvec, grad_ptr, input_ptr, input_ptr + segment_size, grad_input_ptr, + grad_input_ptr + segment_size); + const bool use_vector = align == Alignment::SAME_ALIGNED; + const size_t num_vectors = + use_vector ? get_num_aligned_elements(input_ptr, segment_size, nvec, sizeof(InputT)) + : segment_size; if (grad_act_scales == nullptr) { - const int blocks = - static_cast(std::min(DIVUP(rows * hidden, static_cast(threads)), 65535)); - scaled_gated_backward_kernel - <<>>( - reinterpret_cast(grad_output->data.dptr), - reinterpret_cast(input->data.dptr), - reinterpret_cast(act_scales->data.dptr), - reinterpret_cast(grad_input->data.dptr), rows, hidden, - glu_interleave_size, param); + const int blocks = launch_blocks(rows * num_segments * num_vectors); + if (use_vector) { + scaled_gated_backward_kernel + <<>>(grad_ptr, input_ptr, scale_ptr, grad_input_ptr, + rows, hidden, segment_size, num_segments, + num_vectors, param); + } else { + scaled_gated_backward_kernel<1, true, GradT, InputT, ScaleT, OutputT, Act> + <<>>(grad_ptr, input_ptr, scale_ptr, grad_input_ptr, + rows, hidden, segment_size, num_segments, + segment_size, param); + } } else { TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(grad_act_scales->data.dtype, GradScaleT, { - scaled_gated_backward_with_scale_grad_kernel - <<(rows), threads, 0, stream>>>( - reinterpret_cast(grad_output->data.dptr), - reinterpret_cast(input->data.dptr), - reinterpret_cast(act_scales->data.dptr), - reinterpret_cast(grad_input->data.dptr), - reinterpret_cast(grad_act_scales->data.dptr), rows, hidden, - glu_interleave_size, param); + auto grad_act_scales_ptr = + reinterpret_cast(grad_act_scales->data.dptr); + if (use_vector) { + scaled_gated_backward_with_scale_grad_kernel< + nvec, true, GradT, InputT, ScaleT, OutputT, GradScaleT, Act> + <<(rows), kThreads, 0, stream>>>( + grad_ptr, input_ptr, scale_ptr, grad_input_ptr, grad_act_scales_ptr, rows, + hidden, segment_size, num_segments, num_vectors, param); + } else { + scaled_gated_backward_with_scale_grad_kernel< + 1, true, GradT, InputT, ScaleT, OutputT, GradScaleT, Act> + <<(rows), kThreads, 0, stream>>>( + grad_ptr, input_ptr, scale_ptr, grad_input_ptr, grad_act_scales_ptr, rows, + hidden, segment_size, num_segments, segment_size, param); + } }); } }); @@ -466,30 +712,58 @@ void launch_scaled_srelu_backward(const NVTETensor nvte_grad_output, const NVTET api_name, &rows, &hidden); if (rows == 0 || hidden == 0) return; - constexpr int threads = 256; TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(grad_output->data.dtype, GradT, { TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(input->data.dtype, InputT, { TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(act_scales->data.dtype, ScaleT, { TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(grad_input->data.dtype, OutputT, { + constexpr int nvec = sizeof(GradT) == sizeof(InputT) && + sizeof(InputT) == sizeof(OutputT) + ? vector_width() + : 1; + const auto grad_ptr = reinterpret_cast(grad_output->data.dptr); + const auto input_ptr = reinterpret_cast(input->data.dptr); + const auto scale_ptr = reinterpret_cast(act_scales->data.dptr); + auto grad_input_ptr = reinterpret_cast(grad_input->data.dptr); if (grad_act_scales == nullptr) { - const int blocks = - static_cast(std::min(DIVUP(rows * hidden, static_cast(threads)), 65535)); - scaled_srelu_backward_kernel - <<>>( - reinterpret_cast(grad_output->data.dptr), - reinterpret_cast(input->data.dptr), - reinterpret_cast(act_scales->data.dptr), - reinterpret_cast(grad_input->data.dptr), rows, hidden); + const size_t total = rows * hidden; + const auto align = CheckAlignment(total, nvec, grad_ptr, input_ptr, grad_input_ptr); + const bool use_vector = align == Alignment::SAME_ALIGNED; + const size_t num_vectors = + use_vector ? get_num_aligned_elements(input_ptr, total, nvec, sizeof(InputT)) + : total; + const int blocks = launch_blocks(num_vectors); + if (use_vector) { + scaled_srelu_backward_kernel + <<>>(grad_ptr, input_ptr, scale_ptr, grad_input_ptr, + total, hidden, num_vectors); + } else { + scaled_srelu_backward_kernel<1, true, GradT, InputT, ScaleT, OutputT> + <<>>(grad_ptr, input_ptr, scale_ptr, grad_input_ptr, + total, hidden, total); + } } else { + const auto align = row_vector_alignment(hidden, nvec, grad_ptr, input_ptr, + grad_input_ptr); + const bool use_vector = align == Alignment::SAME_ALIGNED; + const size_t num_vectors = + use_vector ? get_num_aligned_elements(input_ptr, hidden, nvec, sizeof(InputT)) + : hidden; TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(grad_act_scales->data.dtype, GradScaleT, { - scaled_srelu_backward_with_scale_grad_kernel - <<(rows), threads, 0, stream>>>( - reinterpret_cast(grad_output->data.dptr), - reinterpret_cast(input->data.dptr), - reinterpret_cast(act_scales->data.dptr), - reinterpret_cast(grad_input->data.dptr), - reinterpret_cast(grad_act_scales->data.dptr), rows, hidden); + auto grad_act_scales_ptr = + reinterpret_cast(grad_act_scales->data.dptr); + if (use_vector) { + scaled_srelu_backward_with_scale_grad_kernel< + nvec, true, GradT, InputT, ScaleT, OutputT, GradScaleT> + <<(rows), kThreads, 0, stream>>>( + grad_ptr, input_ptr, scale_ptr, grad_input_ptr, grad_act_scales_ptr, rows, + hidden, num_vectors); + } else { + scaled_srelu_backward_with_scale_grad_kernel< + 1, true, GradT, InputT, ScaleT, OutputT, GradScaleT> + <<(rows), kThreads, 0, stream>>>( + grad_ptr, input_ptr, scale_ptr, grad_input_ptr, grad_act_scales_ptr, rows, + hidden, hidden); + } }); } }); From e3ae293ef8a8996dfee03bc110bf942e20287427 Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Mon, 15 Jun 2026 19:56:53 -0700 Subject: [PATCH 3/6] fix bug for backward kernel Signed-off-by: Zhongbo Zhu --- .../common/activation/scaled_activation.cu | 78 ++++++++++--------- 1 file changed, 41 insertions(+), 37 deletions(-) diff --git a/transformer_engine/common/activation/scaled_activation.cu b/transformer_engine/common/activation/scaled_activation.cu index 7053e06241..ad854f3e4e 100644 --- a/transformer_engine/common/activation/scaled_activation.cu +++ b/transformer_engine/common/activation/scaled_activation.cu @@ -106,6 +106,7 @@ __device__ __forceinline__ void gated_backward_values(const float act_in, const } constexpr int kThreads = unary_kernel_threads; +constexpr int kReductionThreads = 256; template constexpr int vector_width() { @@ -322,12 +323,18 @@ __global__ void scaled_gated_backward_with_scale_grad_kernel( GradScaleT *grad_act_scales, const size_t rows, const size_t hidden, const size_t segment_size, const size_t num_segments, const size_t num_vectors_per_segment, const ClampedSwiGLUParam param) { - __shared__ float smem[kThreads]; + __shared__ float smem[kReductionThreads]; const size_t row = blockIdx.x; (void)rows; float scale_grad = 0.0f; - for (size_t segment = 0; segment < num_segments; ++segment) { + // Flatten (segment, vector) so interleave=32 distributes all row work across + // the block instead of using only a few threads per small act/gate segment. + const size_t row_vectors = num_segments * num_vectors_per_segment; + for (size_t row_vector_idx = threadIdx.x; row_vector_idx < row_vectors; + row_vector_idx += blockDim.x) { + const size_t segment = row_vector_idx / num_vectors_per_segment; + const size_t vector_idx = row_vector_idx % num_vectors_per_segment; const size_t input_segment_offset = row * hidden * 2 + segment * segment_size * 2; const size_t output_segment_offset = row * hidden + segment * segment_size; VectorizedLoader grad_loader(grad_output + output_segment_offset, @@ -341,42 +348,39 @@ __global__ void scaled_gated_backward_with_scale_grad_kernel( VectorizedStorer gate_storer( grad_input + input_segment_offset + segment_size, segment_size); - for (size_t vector_idx = threadIdx.x; vector_idx < num_vectors_per_segment; - vector_idx += blockDim.x) { - if (vector_idx >= act_loader.num_aligned_elements()) { - continue; - } - grad_loader.load(vector_idx, segment_size); - act_loader.load(vector_idx, segment_size); - gate_loader.load(vector_idx, segment_size); + if (vector_idx >= act_loader.num_aligned_elements()) { + continue; + } + grad_loader.load(vector_idx, segment_size); + act_loader.load(vector_idx, segment_size); + gate_loader.load(vector_idx, segment_size); #pragma unroll - for (int lane = 0; lane < nvec; ++lane) { - size_t col = 0; - if (vector_lane_index(vector_idx, lane, act_loader.alignment(), - segment_size, &col)) { - float dact = 0.0f; - float dgate = 0.0f; - float unscaled = 0.0f; - gated_backward_values(static_cast(act_loader.separate()[lane]), - static_cast(gate_loader.separate()[lane]), param, &dact, - &dgate, &unscaled); - const float grad = static_cast(grad_loader.separate()[lane]); - scale_grad += grad * unscaled; - - const float scale = static_cast(act_scales[row]); - const float scaled_grad = grad * scale; - act_storer.separate()[lane] = static_cast(scaled_grad * dact); - gate_storer.separate()[lane] = static_cast(scaled_grad * dgate); - } + for (int lane = 0; lane < nvec; ++lane) { + size_t col = 0; + if (vector_lane_index(vector_idx, lane, act_loader.alignment(), segment_size, + &col)) { + float dact = 0.0f; + float dgate = 0.0f; + float unscaled = 0.0f; + gated_backward_values(static_cast(act_loader.separate()[lane]), + static_cast(gate_loader.separate()[lane]), param, &dact, + &dgate, &unscaled); + const float grad = static_cast(grad_loader.separate()[lane]); + scale_grad += grad * unscaled; + + const float scale = static_cast(act_scales[row]); + const float scaled_grad = grad * scale; + act_storer.separate()[lane] = static_cast(scaled_grad * dact); + gate_storer.separate()[lane] = static_cast(scaled_grad * dgate); } - act_storer.store(vector_idx, segment_size); - gate_storer.store(vector_idx, segment_size); } + act_storer.store(vector_idx, segment_size); + gate_storer.store(vector_idx, segment_size); } smem[threadIdx.x] = scale_grad; __syncthreads(); - for (int offset = kThreads / 2; offset > 0; offset >>= 1) { + for (int offset = kReductionThreads / 2; offset > 0; offset >>= 1) { if (threadIdx.x < offset) { smem[threadIdx.x] += smem[threadIdx.x + offset]; } @@ -393,7 +397,7 @@ __global__ void scaled_srelu_backward_with_scale_grad_kernel( const GradT *grad_output, const InputT *input, const ScaleT *act_scales, OutputT *grad_input, GradScaleT *grad_act_scales, const size_t rows, const size_t hidden, const size_t num_vectors_per_row) { - __shared__ float smem[kThreads]; + __shared__ float smem[kReductionThreads]; const size_t row = blockIdx.x; (void)rows; float scale_grad = 0.0f; @@ -431,7 +435,7 @@ __global__ void scaled_srelu_backward_with_scale_grad_kernel( smem[threadIdx.x] = scale_grad; __syncthreads(); - for (int offset = kThreads / 2; offset > 0; offset >>= 1) { + for (int offset = kReductionThreads / 2; offset > 0; offset >>= 1) { if (threadIdx.x < offset) { smem[threadIdx.x] += smem[threadIdx.x + offset]; } @@ -677,13 +681,13 @@ void launch_scaled_gated_backward(const NVTETensor nvte_grad_output, const NVTET if (use_vector) { scaled_gated_backward_with_scale_grad_kernel< nvec, true, GradT, InputT, ScaleT, OutputT, GradScaleT, Act> - <<(rows), kThreads, 0, stream>>>( + <<(rows), kReductionThreads, 0, stream>>>( grad_ptr, input_ptr, scale_ptr, grad_input_ptr, grad_act_scales_ptr, rows, hidden, segment_size, num_segments, num_vectors, param); } else { scaled_gated_backward_with_scale_grad_kernel< 1, true, GradT, InputT, ScaleT, OutputT, GradScaleT, Act> - <<(rows), kThreads, 0, stream>>>( + <<(rows), kReductionThreads, 0, stream>>>( grad_ptr, input_ptr, scale_ptr, grad_input_ptr, grad_act_scales_ptr, rows, hidden, segment_size, num_segments, segment_size, param); } @@ -754,13 +758,13 @@ void launch_scaled_srelu_backward(const NVTETensor nvte_grad_output, const NVTET if (use_vector) { scaled_srelu_backward_with_scale_grad_kernel< nvec, true, GradT, InputT, ScaleT, OutputT, GradScaleT> - <<(rows), kThreads, 0, stream>>>( + <<(rows), kReductionThreads, 0, stream>>>( grad_ptr, input_ptr, scale_ptr, grad_input_ptr, grad_act_scales_ptr, rows, hidden, num_vectors); } else { scaled_srelu_backward_with_scale_grad_kernel< 1, true, GradT, InputT, ScaleT, OutputT, GradScaleT> - <<(rows), kThreads, 0, stream>>>( + <<(rows), kReductionThreads, 0, stream>>>( grad_ptr, input_ptr, scale_ptr, grad_input_ptr, grad_act_scales_ptr, rows, hidden, hidden); } From 84cbdec2e9088082abe99184347433c87178011d Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Mon, 15 Jun 2026 22:14:15 -0700 Subject: [PATCH 4/6] optimize Signed-off-by: Zhongbo Zhu --- tests/cpp/operator/test_scaled_activation.cu | 19 +- .../common/activation/scaled_activation.cu | 401 ++++++++---------- 2 files changed, 181 insertions(+), 239 deletions(-) diff --git a/tests/cpp/operator/test_scaled_activation.cu b/tests/cpp/operator/test_scaled_activation.cu index 1cb630a0bc..80641e3c5f 100644 --- a/tests/cpp/operator/test_scaled_activation.cu +++ b/tests/cpp/operator/test_scaled_activation.cu @@ -295,12 +295,9 @@ TEST_P(ScaledActivationTest, ForwardBackward) { // for gated activations with hidden % 32 == 0; SReLU skips != 0. // 6. compute_grad_scales : whether the backward also reduces grad_act_scales. -// Regular shapes: hidden is a multiple of 32, so the interleaved (32) layout is exercised -// alongside the contiguous (0) layout. -// Regular shapes (hidden % 32 == 0) and weird/irregular shapes (tiny, prime, non-32-aligned) -// share one instantiation. Interleave is swept over {0, 32}; invalid combinations -- SReLU with -// any nonzero interleave, or a gated activation whose hidden is not divisible by the interleave -- -// are skipped at runtime by the GTEST_SKIP guards in the test body. +// Interleave is swept over {0, 32}; invalid combinations -- SReLU with any nonzero interleave, or +// a gated activation whose hidden is not divisible by the interleave -- are skipped at runtime by +// the GTEST_SKIP guards in the test body. INSTANTIATE_TEST_SUITE_P( OperatorTest_ScaledActivation, ScaledActivationTest, ::testing::Combine( @@ -309,20 +306,14 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(DType::kFloat32, DType::kBFloat16), // data dtype ::testing::Values(DType::kFloat32, DType::kBFloat16), // scale dtype ::testing::Values(std::pair{17, 64}, // odd rows, aligned hidden - std::pair{8, 96}, // 96 = 3 * 32 std::pair{32, 32}, // minimal aligned square std::pair{128, 128}, // square - std::pair{64, 256}, // wide hidden std::pair{256, 64}, // many rows, narrow hidden - std::pair{128, 512}, // FFN-ish width + std::pair{1024, 2048}, // large FFN-ish width std::pair{1, 1}, // single element std::pair{1, 96}, // single row std::pair{96, 1}, // single hidden column - std::pair{3, 7}, // tiny primes - std::pair{13, 100}, // non-power-of-two - std::pair{7, 257}, // prime, odd hidden - std::pair{33, 65}, // odd dims - std::pair{129, 31}), // odd rows, hidden < 32 + std::pair{13, 100}), // non-power-of-two ::testing::Values(0, 32), // contiguous + interleaved ::testing::Values(false, true)), // grad_act_scales off / on test_name_generator); diff --git a/transformer_engine/common/activation/scaled_activation.cu b/transformer_engine/common/activation/scaled_activation.cu index ad854f3e4e..d056966281 100644 --- a/transformer_engine/common/activation/scaled_activation.cu +++ b/transformer_engine/common/activation/scaled_activation.cu @@ -13,9 +13,9 @@ * # | Kernel | Activation | Dir | grad_act_scales | Launch * ---+-----------------------------------------------+------------------------+-----+-----------------+-------------------- * 1 | scaled_gated_forward_kernel | SwiGLU / ClampedSwiGLU | fwd | -- | vectorized row segments - * 2 | scaled_srelu_forward_kernel | SReLU (unary) | fwd | -- | vectorized flat grid + * 2 | scaled_srelu_forward_kernel | SReLU (unary) | fwd | -- | vectorized row grid * 3 | scaled_gated_backward_kernel | SwiGLU / ClampedSwiGLU | bwd | no | vectorized row segments - * 4 | scaled_srelu_backward_kernel | SReLU | bwd | no | vectorized flat grid + * 4 | scaled_srelu_backward_kernel | SReLU | bwd | no | vectorized row grid * 5 | scaled_gated_backward_with_scale_grad_kernel | SwiGLU / ClampedSwiGLU | bwd | yes | vectorized, one block per row * 6 | scaled_srelu_backward_with_scale_grad_kernel | SReLU | bwd | yes | vectorized, one block per row * @@ -72,6 +72,10 @@ enum class ScaledActivation { kSReLU, }; +__device__ __forceinline__ float sigmoid_from_float(const float x) { + return 1.0f / (1.0f + expf(-x)); +} + template __device__ __forceinline__ float gated_forward_value(const float act_in, const float gate_in, const ClampedSwiGLUParam ¶m) { @@ -90,23 +94,55 @@ __device__ __forceinline__ void gated_backward_values(const float act_in, const float *dgate, float *unscaled) { if constexpr (Act == ScaledActivation::kSwiGLU) { - Empty empty = {}; - const float act = silu(act_in, empty); + const float sigmoid = sigmoid_from_float(act_in); + const float act = act_in * sigmoid; + const float dact_base = sigmoid + act_in * sigmoid * (1.0f - sigmoid); *unscaled = act * gate_in; - *dact = dsilu(act_in, empty) * gate_in; + *dact = dact_base * gate_in; *dgate = act; } else { const bool dgate_mask = gate_in <= param.limit && gate_in >= -param.limit; const float gate = fminf(fmaxf(-param.limit, gate_in), param.limit) + param.glu_linear_offset; - const float act = clamped_silu(act_in, param); + const bool dact_mask = act_in <= param.limit; + const float clamped_act_in = fminf(act_in, param.limit); + const float sigmoid = sigmoid_from_float(param.alpha * clamped_act_in); + const float act = clamped_act_in * sigmoid; + const float dact_base = + dact_mask ? sigmoid + param.alpha * clamped_act_in * sigmoid * (1.0f - sigmoid) : 0.0f; *unscaled = act * gate; - *dact = clamped_dsilu(act_in, param) * gate; + *dact = dact_base * gate; *dgate = dgate_mask ? act : 0.0f; } } constexpr int kThreads = unary_kernel_threads; constexpr int kReductionThreads = 256; +constexpr int kReductionWarps = kReductionThreads / THREADS_PER_WARP; + +__device__ __forceinline__ float warp_reduce_sum(float value) { +#pragma unroll + for (int offset = THREADS_PER_WARP / 2; offset > 0; offset >>= 1) { + value += __shfl_down_sync(0xffffffff, value, offset); + } + return value; +} + +__device__ __forceinline__ float block_reduce_sum(float value, float *smem) { + const int lane = threadIdx.x % THREADS_PER_WARP; + const int warp = threadIdx.x / THREADS_PER_WARP; + + value = warp_reduce_sum(value); + if (lane == 0) { + smem[warp] = value; + } + __syncthreads(); + + value = threadIdx.x < kReductionWarps ? smem[lane] : 0.0f; + if (warp == 0) { + value = warp_reduce_sum(value); + } + return value; +} template constexpr int vector_width() { @@ -133,26 +169,7 @@ Alignment row_vector_alignment(const size_t lead_dim, const int nvec, const Ptrs return align == Alignment::SAME_ALIGNED ? Alignment::SAME_ALIGNED : Alignment::DIFFERENT; } -template -__device__ __forceinline__ bool vector_lane_index(const size_t vector_idx, const int lane, - const int alignment, const size_t length, - size_t *index) { - size_t idx = vector_idx * static_cast(nvec) + static_cast(lane); - if constexpr (!aligned) { - if (idx < static_cast(alignment)) { - return false; - } - idx -= static_cast(alignment); - } - if (idx >= length) { - return false; - } - *index = idx; - return true; -} - -template +template __global__ void scaled_gated_forward_kernel(const InputT *input, const ScaleT *act_scales, OutputT *output, const size_t rows, const size_t hidden, const size_t segment_size, @@ -168,66 +185,52 @@ __global__ void scaled_gated_forward_kernel(const InputT *input, const ScaleT *a const size_t input_segment_offset = row * hidden * 2 + segment * segment_size * 2; const size_t output_segment_offset = row * hidden + segment * segment_size; - VectorizedLoader act_loader(input + input_segment_offset, - segment_size); - VectorizedLoader gate_loader( + VectorizedLoader act_loader(input + input_segment_offset, segment_size); + VectorizedLoader gate_loader( input + input_segment_offset + segment_size, segment_size); - VectorizedStorer output_storer(output + output_segment_offset, - segment_size); - if (vector_idx >= act_loader.num_aligned_elements()) { - continue; - } - + VectorizedStorer output_storer(output + output_segment_offset, + segment_size); act_loader.load(vector_idx, segment_size); gate_loader.load(vector_idx, segment_size); const float scale = static_cast(act_scales[row]); #pragma unroll for (int lane = 0; lane < nvec; ++lane) { - size_t col = 0; - if (vector_lane_index(vector_idx, lane, act_loader.alignment(), - segment_size, &col)) { - const float unscaled = - gated_forward_value(static_cast(act_loader.separate()[lane]), - static_cast(gate_loader.separate()[lane]), param); - output_storer.separate()[lane] = static_cast(unscaled * scale); - } + const float unscaled = + gated_forward_value(static_cast(act_loader.separate()[lane]), + static_cast(gate_loader.separate()[lane]), param); + output_storer.separate()[lane] = static_cast(unscaled * scale); } output_storer.store(vector_idx, segment_size); } } -template +template __global__ void scaled_srelu_forward_kernel(const InputT *input, const ScaleT *act_scales, - OutputT *output, const size_t total, + OutputT *output, const size_t rows, const size_t hidden, - const size_t num_vectors) { + const size_t num_vectors_per_row) { Empty empty = {}; - VectorizedLoader input_loader(input, total); - VectorizedStorer output_storer(output, total); - for (size_t vector_idx = blockIdx.x * blockDim.x + threadIdx.x; vector_idx < num_vectors; - vector_idx += gridDim.x * blockDim.x) { - if (vector_idx >= input_loader.num_aligned_elements()) { - continue; - } - input_loader.load(vector_idx, total); + const size_t total_vectors = rows * num_vectors_per_row; + for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < total_vectors; + tid += gridDim.x * blockDim.x) { + const size_t vector_idx = tid % num_vectors_per_row; + const size_t row = tid / num_vectors_per_row; + VectorizedLoader input_loader(input + row * hidden, hidden); + VectorizedStorer output_storer(output + row * hidden, hidden); + input_loader.load(vector_idx, hidden); + const float scale = static_cast(act_scales[row]); #pragma unroll for (int lane = 0; lane < nvec; ++lane) { - size_t idx = 0; - if (vector_lane_index(vector_idx, lane, input_loader.alignment(), total, - &idx)) { - const size_t row = idx / hidden; - const float unscaled = srelu(static_cast(input_loader.separate()[lane]), - empty); - const float scale = static_cast(act_scales[row]); - output_storer.separate()[lane] = static_cast(unscaled * scale); - } + const float unscaled = + srelu(static_cast(input_loader.separate()[lane]), empty); + output_storer.separate()[lane] = static_cast(unscaled * scale); } - output_storer.store(vector_idx, total); + output_storer.store(vector_idx, hidden); } } -template +template __global__ void scaled_gated_backward_kernel( const GradT *grad_output, const InputT *input, const ScaleT *act_scales, OutputT *grad_input, const size_t rows, const size_t hidden, const size_t segment_size, const size_t num_segments, @@ -241,92 +244,77 @@ __global__ void scaled_gated_backward_kernel( const size_t input_segment_offset = row * hidden * 2 + segment * segment_size * 2; const size_t output_segment_offset = row * hidden + segment * segment_size; - VectorizedLoader grad_loader(grad_output + output_segment_offset, - segment_size); - VectorizedLoader act_loader(input + input_segment_offset, - segment_size); - VectorizedLoader gate_loader( + VectorizedLoader grad_loader(grad_output + output_segment_offset, + segment_size); + VectorizedLoader act_loader(input + input_segment_offset, segment_size); + VectorizedLoader gate_loader( input + input_segment_offset + segment_size, segment_size); - VectorizedStorer act_storer(grad_input + input_segment_offset, - segment_size); - VectorizedStorer gate_storer( + VectorizedStorer act_storer(grad_input + input_segment_offset, + segment_size); + VectorizedStorer gate_storer( grad_input + input_segment_offset + segment_size, segment_size); - if (vector_idx >= act_loader.num_aligned_elements()) { - continue; - } - grad_loader.load(vector_idx, segment_size); act_loader.load(vector_idx, segment_size); gate_loader.load(vector_idx, segment_size); const float scale = static_cast(act_scales[row]); #pragma unroll for (int lane = 0; lane < nvec; ++lane) { - size_t col = 0; - if (vector_lane_index(vector_idx, lane, act_loader.alignment(), - segment_size, &col)) { - float dact = 0.0f; - float dgate = 0.0f; - float unscaled = 0.0f; - gated_backward_values(static_cast(act_loader.separate()[lane]), - static_cast(gate_loader.separate()[lane]), param, &dact, - &dgate, &unscaled); - (void)unscaled; - const float grad = static_cast(grad_loader.separate()[lane]) * scale; - act_storer.separate()[lane] = static_cast(grad * dact); - gate_storer.separate()[lane] = static_cast(grad * dgate); - } + float dact = 0.0f; + float dgate = 0.0f; + float unscaled = 0.0f; + gated_backward_values(static_cast(act_loader.separate()[lane]), + static_cast(gate_loader.separate()[lane]), param, &dact, + &dgate, &unscaled); + (void)unscaled; + const float grad = static_cast(grad_loader.separate()[lane]) * scale; + act_storer.separate()[lane] = static_cast(grad * dact); + gate_storer.separate()[lane] = static_cast(grad * dgate); } act_storer.store(vector_idx, segment_size); gate_storer.store(vector_idx, segment_size); } } -template +template __global__ void scaled_srelu_backward_kernel(const GradT *grad_output, const InputT *input, const ScaleT *act_scales, OutputT *grad_input, - const size_t total, const size_t hidden, - const size_t num_vectors) { + const size_t rows, const size_t hidden, + const size_t num_vectors_per_row) { Empty empty = {}; - VectorizedLoader grad_loader(grad_output, total); - VectorizedLoader input_loader(input, total); - VectorizedStorer grad_input_storer(grad_input, total); - for (size_t vector_idx = blockIdx.x * blockDim.x + threadIdx.x; vector_idx < num_vectors; - vector_idx += gridDim.x * blockDim.x) { - if (vector_idx >= input_loader.num_aligned_elements()) { - continue; - } - grad_loader.load(vector_idx, total); - input_loader.load(vector_idx, total); + const size_t total_vectors = rows * num_vectors_per_row; + for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < total_vectors; + tid += gridDim.x * blockDim.x) { + const size_t vector_idx = tid % num_vectors_per_row; + const size_t row = tid / num_vectors_per_row; + VectorizedLoader grad_loader(grad_output + row * hidden, hidden); + VectorizedLoader input_loader(input + row * hidden, hidden); + VectorizedStorer grad_input_storer(grad_input + row * hidden, hidden); + grad_loader.load(vector_idx, hidden); + input_loader.load(vector_idx, hidden); + const float scale = static_cast(act_scales[row]); #pragma unroll for (int lane = 0; lane < nvec; ++lane) { - size_t idx = 0; - if (vector_lane_index(vector_idx, lane, input_loader.alignment(), total, - &idx)) { - const size_t row = idx / hidden; - const float scale = static_cast(act_scales[row]); - const float grad = static_cast(grad_loader.separate()[lane]) * scale; - grad_input_storer.separate()[lane] = - static_cast(grad * dsrelu( - static_cast(input_loader.separate()[lane]), - empty)); - } + const float grad = static_cast(grad_loader.separate()[lane]) * scale; + grad_input_storer.separate()[lane] = + static_cast( + grad * dsrelu(static_cast(input_loader.separate()[lane]), empty)); } - grad_input_storer.store(vector_idx, total); + grad_input_storer.store(vector_idx, hidden); } } -template +template __global__ void scaled_gated_backward_with_scale_grad_kernel( const GradT *grad_output, const InputT *input, const ScaleT *act_scales, OutputT *grad_input, GradScaleT *grad_act_scales, const size_t rows, const size_t hidden, const size_t segment_size, const size_t num_segments, const size_t num_vectors_per_segment, const ClampedSwiGLUParam param) { - __shared__ float smem[kReductionThreads]; + __shared__ float smem[kReductionWarps]; const size_t row = blockIdx.x; (void)rows; float scale_grad = 0.0f; + const float scale = static_cast(act_scales[row]); // Flatten (segment, vector) so interleave=32 distributes all row work across // the block instead of using only a few threads per small act/gate segment. @@ -337,112 +325,82 @@ __global__ void scaled_gated_backward_with_scale_grad_kernel( const size_t vector_idx = row_vector_idx % num_vectors_per_segment; const size_t input_segment_offset = row * hidden * 2 + segment * segment_size * 2; const size_t output_segment_offset = row * hidden + segment * segment_size; - VectorizedLoader grad_loader(grad_output + output_segment_offset, - segment_size); - VectorizedLoader act_loader(input + input_segment_offset, - segment_size); - VectorizedLoader gate_loader( + VectorizedLoader grad_loader(grad_output + output_segment_offset, + segment_size); + VectorizedLoader act_loader(input + input_segment_offset, segment_size); + VectorizedLoader gate_loader( input + input_segment_offset + segment_size, segment_size); - VectorizedStorer act_storer(grad_input + input_segment_offset, - segment_size); - VectorizedStorer gate_storer( + VectorizedStorer act_storer(grad_input + input_segment_offset, + segment_size); + VectorizedStorer gate_storer( grad_input + input_segment_offset + segment_size, segment_size); - if (vector_idx >= act_loader.num_aligned_elements()) { - continue; - } grad_loader.load(vector_idx, segment_size); act_loader.load(vector_idx, segment_size); gate_loader.load(vector_idx, segment_size); #pragma unroll for (int lane = 0; lane < nvec; ++lane) { - size_t col = 0; - if (vector_lane_index(vector_idx, lane, act_loader.alignment(), segment_size, - &col)) { - float dact = 0.0f; - float dgate = 0.0f; - float unscaled = 0.0f; - gated_backward_values(static_cast(act_loader.separate()[lane]), - static_cast(gate_loader.separate()[lane]), param, &dact, - &dgate, &unscaled); - const float grad = static_cast(grad_loader.separate()[lane]); - scale_grad += grad * unscaled; - - const float scale = static_cast(act_scales[row]); - const float scaled_grad = grad * scale; - act_storer.separate()[lane] = static_cast(scaled_grad * dact); - gate_storer.separate()[lane] = static_cast(scaled_grad * dgate); - } + float dact = 0.0f; + float dgate = 0.0f; + float unscaled = 0.0f; + gated_backward_values(static_cast(act_loader.separate()[lane]), + static_cast(gate_loader.separate()[lane]), param, &dact, + &dgate, &unscaled); + const float grad = static_cast(grad_loader.separate()[lane]); + scale_grad += grad * unscaled; + + const float scaled_grad = grad * scale; + act_storer.separate()[lane] = static_cast(scaled_grad * dact); + gate_storer.separate()[lane] = static_cast(scaled_grad * dgate); } act_storer.store(vector_idx, segment_size); gate_storer.store(vector_idx, segment_size); } - smem[threadIdx.x] = scale_grad; - __syncthreads(); - for (int offset = kReductionThreads / 2; offset > 0; offset >>= 1) { - if (threadIdx.x < offset) { - smem[threadIdx.x] += smem[threadIdx.x + offset]; - } - __syncthreads(); - } + scale_grad = block_reduce_sum(scale_grad, smem); if (threadIdx.x == 0) { - grad_act_scales[row] = static_cast(smem[0]); + grad_act_scales[row] = static_cast(scale_grad); } } -template +template __global__ void scaled_srelu_backward_with_scale_grad_kernel( const GradT *grad_output, const InputT *input, const ScaleT *act_scales, OutputT *grad_input, GradScaleT *grad_act_scales, const size_t rows, const size_t hidden, const size_t num_vectors_per_row) { - __shared__ float smem[kReductionThreads]; + __shared__ float smem[kReductionWarps]; const size_t row = blockIdx.x; (void)rows; float scale_grad = 0.0f; Empty empty = {}; + const float scale = static_cast(act_scales[row]); - VectorizedLoader grad_loader(grad_output + row * hidden, hidden); - VectorizedLoader input_loader(input + row * hidden, hidden); - VectorizedStorer grad_input_storer(grad_input + row * hidden, hidden); + VectorizedLoader grad_loader(grad_output + row * hidden, hidden); + VectorizedLoader input_loader(input + row * hidden, hidden); + VectorizedStorer grad_input_storer(grad_input + row * hidden, hidden); for (size_t vector_idx = threadIdx.x; vector_idx < num_vectors_per_row; vector_idx += blockDim.x) { - if (vector_idx >= input_loader.num_aligned_elements()) { - continue; - } grad_loader.load(vector_idx, hidden); input_loader.load(vector_idx, hidden); #pragma unroll for (int lane = 0; lane < nvec; ++lane) { - size_t col = 0; - if (vector_lane_index(vector_idx, lane, input_loader.alignment(), hidden, - &col)) { - const float unscaled = - srelu(static_cast(input_loader.separate()[lane]), empty); - const float grad = static_cast(grad_loader.separate()[lane]); - scale_grad += grad * unscaled; - - const float scale = static_cast(act_scales[row]); - const float scaled_grad = grad * scale; - const float dact = - dsrelu(static_cast(input_loader.separate()[lane]), empty); - grad_input_storer.separate()[lane] = static_cast(scaled_grad * dact); - } + const float unscaled = + srelu(static_cast(input_loader.separate()[lane]), empty); + const float grad = static_cast(grad_loader.separate()[lane]); + scale_grad += grad * unscaled; + + const float scaled_grad = grad * scale; + const float dact = + dsrelu(static_cast(input_loader.separate()[lane]), empty); + grad_input_storer.separate()[lane] = static_cast(scaled_grad * dact); } grad_input_storer.store(vector_idx, hidden); } - smem[threadIdx.x] = scale_grad; - __syncthreads(); - for (int offset = kReductionThreads / 2; offset > 0; offset >>= 1) { - if (threadIdx.x < offset) { - smem[threadIdx.x] += smem[threadIdx.x + offset]; - } - __syncthreads(); - } + scale_grad = block_reduce_sum(scale_grad, smem); if (threadIdx.x == 0) { - grad_act_scales[row] = static_cast(smem[0]); + grad_act_scales[row] = static_cast(scale_grad); } } @@ -566,11 +524,11 @@ void launch_scaled_gated_forward(const NVTETensor nvte_input, const NVTETensor n : segment_size; const int blocks = launch_blocks(rows * num_segments * num_vectors); if (use_vector) { - scaled_gated_forward_kernel + scaled_gated_forward_kernel <<>>(input_ptr, scale_ptr, output_ptr, rows, hidden, segment_size, num_segments, num_vectors, param); } else { - scaled_gated_forward_kernel<1, true, InputT, ScaleT, OutputT, Act> + scaled_gated_forward_kernel<1, InputT, ScaleT, OutputT, Act> <<>>(input_ptr, scale_ptr, output_ptr, rows, hidden, segment_size, num_segments, segment_size, param); } @@ -599,20 +557,19 @@ void launch_scaled_srelu_forward(const NVTETensor nvte_input, const NVTETensor n const auto input_ptr = reinterpret_cast(input->data.dptr); const auto scale_ptr = reinterpret_cast(act_scales->data.dptr); auto output_ptr = reinterpret_cast(output->data.dptr); - const size_t total = rows * hidden; - const auto align = CheckAlignment(total, nvec, input_ptr, output_ptr); + const auto align = row_vector_alignment(hidden, nvec, input_ptr, output_ptr); const bool use_vector = align == Alignment::SAME_ALIGNED; const size_t num_vectors = - use_vector ? get_num_aligned_elements(input_ptr, total, nvec, sizeof(InputT)) : total; - const int blocks = launch_blocks(num_vectors); + use_vector ? get_num_aligned_elements(input_ptr, hidden, nvec, sizeof(InputT)) : hidden; + const int blocks = launch_blocks(rows * num_vectors); if (use_vector) { - scaled_srelu_forward_kernel - <<>>(input_ptr, scale_ptr, output_ptr, total, hidden, + scaled_srelu_forward_kernel + <<>>(input_ptr, scale_ptr, output_ptr, rows, hidden, num_vectors); } else { - scaled_srelu_forward_kernel<1, true, InputT, ScaleT, OutputT> - <<>>(input_ptr, scale_ptr, output_ptr, total, hidden, - total); + scaled_srelu_forward_kernel<1, InputT, ScaleT, OutputT> + <<>>(input_ptr, scale_ptr, output_ptr, rows, hidden, + hidden); } }); }); @@ -664,12 +621,12 @@ void launch_scaled_gated_backward(const NVTETensor nvte_grad_output, const NVTET if (grad_act_scales == nullptr) { const int blocks = launch_blocks(rows * num_segments * num_vectors); if (use_vector) { - scaled_gated_backward_kernel + scaled_gated_backward_kernel <<>>(grad_ptr, input_ptr, scale_ptr, grad_input_ptr, rows, hidden, segment_size, num_segments, num_vectors, param); } else { - scaled_gated_backward_kernel<1, true, GradT, InputT, ScaleT, OutputT, Act> + scaled_gated_backward_kernel<1, GradT, InputT, ScaleT, OutputT, Act> <<>>(grad_ptr, input_ptr, scale_ptr, grad_input_ptr, rows, hidden, segment_size, num_segments, segment_size, param); @@ -680,13 +637,13 @@ void launch_scaled_gated_backward(const NVTETensor nvte_grad_output, const NVTET reinterpret_cast(grad_act_scales->data.dptr); if (use_vector) { scaled_gated_backward_with_scale_grad_kernel< - nvec, true, GradT, InputT, ScaleT, OutputT, GradScaleT, Act> + nvec, GradT, InputT, ScaleT, OutputT, GradScaleT, Act> <<(rows), kReductionThreads, 0, stream>>>( grad_ptr, input_ptr, scale_ptr, grad_input_ptr, grad_act_scales_ptr, rows, hidden, segment_size, num_segments, num_vectors, param); } else { scaled_gated_backward_with_scale_grad_kernel< - 1, true, GradT, InputT, ScaleT, OutputT, GradScaleT, Act> + 1, GradT, InputT, ScaleT, OutputT, GradScaleT, Act> <<(rows), kReductionThreads, 0, stream>>>( grad_ptr, input_ptr, scale_ptr, grad_input_ptr, grad_act_scales_ptr, rows, hidden, segment_size, num_segments, segment_size, param); @@ -728,42 +685,36 @@ void launch_scaled_srelu_backward(const NVTETensor nvte_grad_output, const NVTET const auto input_ptr = reinterpret_cast(input->data.dptr); const auto scale_ptr = reinterpret_cast(act_scales->data.dptr); auto grad_input_ptr = reinterpret_cast(grad_input->data.dptr); + const auto align = row_vector_alignment(hidden, nvec, grad_ptr, input_ptr, + grad_input_ptr); + const bool use_vector = align == Alignment::SAME_ALIGNED; + const size_t num_vectors = + use_vector ? get_num_aligned_elements(input_ptr, hidden, nvec, sizeof(InputT)) + : hidden; if (grad_act_scales == nullptr) { - const size_t total = rows * hidden; - const auto align = CheckAlignment(total, nvec, grad_ptr, input_ptr, grad_input_ptr); - const bool use_vector = align == Alignment::SAME_ALIGNED; - const size_t num_vectors = - use_vector ? get_num_aligned_elements(input_ptr, total, nvec, sizeof(InputT)) - : total; - const int blocks = launch_blocks(num_vectors); + const int blocks = launch_blocks(rows * num_vectors); if (use_vector) { - scaled_srelu_backward_kernel + scaled_srelu_backward_kernel <<>>(grad_ptr, input_ptr, scale_ptr, grad_input_ptr, - total, hidden, num_vectors); + rows, hidden, num_vectors); } else { - scaled_srelu_backward_kernel<1, true, GradT, InputT, ScaleT, OutputT> + scaled_srelu_backward_kernel<1, GradT, InputT, ScaleT, OutputT> <<>>(grad_ptr, input_ptr, scale_ptr, grad_input_ptr, - total, hidden, total); + rows, hidden, hidden); } } else { - const auto align = row_vector_alignment(hidden, nvec, grad_ptr, input_ptr, - grad_input_ptr); - const bool use_vector = align == Alignment::SAME_ALIGNED; - const size_t num_vectors = - use_vector ? get_num_aligned_elements(input_ptr, hidden, nvec, sizeof(InputT)) - : hidden; TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(grad_act_scales->data.dtype, GradScaleT, { auto grad_act_scales_ptr = reinterpret_cast(grad_act_scales->data.dptr); if (use_vector) { scaled_srelu_backward_with_scale_grad_kernel< - nvec, true, GradT, InputT, ScaleT, OutputT, GradScaleT> + nvec, GradT, InputT, ScaleT, OutputT, GradScaleT> <<(rows), kReductionThreads, 0, stream>>>( grad_ptr, input_ptr, scale_ptr, grad_input_ptr, grad_act_scales_ptr, rows, hidden, num_vectors); } else { scaled_srelu_backward_with_scale_grad_kernel< - 1, true, GradT, InputT, ScaleT, OutputT, GradScaleT> + 1, GradT, InputT, ScaleT, OutputT, GradScaleT> <<(rows), kReductionThreads, 0, stream>>>( grad_ptr, input_ptr, scale_ptr, grad_input_ptr, grad_act_scales_ptr, rows, hidden, hidden); From c73c8ea1c91db6b81c5ff2e8421155eaf52f6dd7 Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Mon, 15 Jun 2026 22:31:36 -0700 Subject: [PATCH 5/6] fix unit test failure Signed-off-by: Zhongbo Zhu --- tests/cpp/operator/test_scaled_activation.cu | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/cpp/operator/test_scaled_activation.cu b/tests/cpp/operator/test_scaled_activation.cu index 80641e3c5f..72a64a3c04 100644 --- a/tests/cpp/operator/test_scaled_activation.cu +++ b/tests/cpp/operator/test_scaled_activation.cu @@ -231,10 +231,12 @@ void run_scaled_activation_test(ScaledActivationCase activation, const size_t ro atol = 5e-5; rtol = 5e-5; } - compareResults("scaled_activation_output", output, ref_output.get(), atol, rtol); - compareResults("scaled_activation_grad_input", grad_input, ref_grad_input.get(), atol, rtol); + compareResults("scaled_activation_output", output, ref_output.get(), true, atol, rtol); + compareResults("scaled_activation_grad_input", grad_input, ref_grad_input.get(), true, atol, + rtol); if (compute_grad_scales) { - compareResults("scaled_activation_grad_scales", grad_scales, ref_grad_scales.get(), atol, rtol); + compareResults("scaled_activation_grad_scales", grad_scales, ref_grad_scales.get(), true, atol, + rtol); } } From 3eb18a6e05b84cba0983a1e754e572ab2d4d1761 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 16 Jun 2026 07:16:09 +0000 Subject: [PATCH 6/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/activation/scaled_activation.cu | 131 ++++++++---------- .../include/transformer_engine/activation.h | 12 +- 2 files changed, 63 insertions(+), 80 deletions(-) diff --git a/transformer_engine/common/activation/scaled_activation.cu b/transformer_engine/common/activation/scaled_activation.cu index d056966281..73df92338c 100644 --- a/transformer_engine/common/activation/scaled_activation.cu +++ b/transformer_engine/common/activation/scaled_activation.cu @@ -91,8 +91,7 @@ __device__ __forceinline__ float gated_forward_value(const float act_in, const f template __device__ __forceinline__ void gated_backward_values(const float act_in, const float gate_in, const ClampedSwiGLUParam ¶m, float *dact, - float *dgate, - float *unscaled) { + float *dgate, float *unscaled) { if constexpr (Act == ScaledActivation::kSwiGLU) { const float sigmoid = sigmoid_from_float(act_in); const float act = act_in * sigmoid; @@ -171,9 +170,8 @@ Alignment row_vector_alignment(const size_t lead_dim, const int nvec, const Ptrs template __global__ void scaled_gated_forward_kernel(const InputT *input, const ScaleT *act_scales, - OutputT *output, const size_t rows, - const size_t hidden, const size_t segment_size, - const size_t num_segments, + OutputT *output, const size_t rows, const size_t hidden, + const size_t segment_size, const size_t num_segments, const size_t num_vectors_per_segment, const ClampedSwiGLUParam param) { const size_t total_vectors = rows * num_segments * num_vectors_per_segment; @@ -186,8 +184,8 @@ __global__ void scaled_gated_forward_kernel(const InputT *input, const ScaleT *a const size_t output_segment_offset = row * hidden + segment * segment_size; VectorizedLoader act_loader(input + input_segment_offset, segment_size); - VectorizedLoader gate_loader( - input + input_segment_offset + segment_size, segment_size); + VectorizedLoader gate_loader(input + input_segment_offset + segment_size, + segment_size); VectorizedStorer output_storer(output + output_segment_offset, segment_size); act_loader.load(vector_idx, segment_size); @@ -206,8 +204,7 @@ __global__ void scaled_gated_forward_kernel(const InputT *input, const ScaleT *a template __global__ void scaled_srelu_forward_kernel(const InputT *input, const ScaleT *act_scales, - OutputT *output, const size_t rows, - const size_t hidden, + OutputT *output, const size_t rows, const size_t hidden, const size_t num_vectors_per_row) { Empty empty = {}; const size_t total_vectors = rows * num_vectors_per_row; @@ -231,10 +228,12 @@ __global__ void scaled_srelu_forward_kernel(const InputT *input, const ScaleT *a template -__global__ void scaled_gated_backward_kernel( - const GradT *grad_output, const InputT *input, const ScaleT *act_scales, OutputT *grad_input, - const size_t rows, const size_t hidden, const size_t segment_size, const size_t num_segments, - const size_t num_vectors_per_segment, const ClampedSwiGLUParam param) { +__global__ void scaled_gated_backward_kernel(const GradT *grad_output, const InputT *input, + const ScaleT *act_scales, OutputT *grad_input, + const size_t rows, const size_t hidden, + const size_t segment_size, const size_t num_segments, + const size_t num_vectors_per_segment, + const ClampedSwiGLUParam param) { const size_t total_vectors = rows * num_segments * num_vectors_per_segment; for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < total_vectors; tid += gridDim.x * blockDim.x) { @@ -247,8 +246,8 @@ __global__ void scaled_gated_backward_kernel( VectorizedLoader grad_loader(grad_output + output_segment_offset, segment_size); VectorizedLoader act_loader(input + input_segment_offset, segment_size); - VectorizedLoader gate_loader( - input + input_segment_offset + segment_size, segment_size); + VectorizedLoader gate_loader(input + input_segment_offset + segment_size, + segment_size); VectorizedStorer act_storer(grad_input + input_segment_offset, segment_size); VectorizedStorer gate_storer( @@ -295,9 +294,8 @@ __global__ void scaled_srelu_backward_kernel(const GradT *grad_output, const Inp #pragma unroll for (int lane = 0; lane < nvec; ++lane) { const float grad = static_cast(grad_loader.separate()[lane]) * scale; - grad_input_storer.separate()[lane] = - static_cast( - grad * dsrelu(static_cast(input_loader.separate()[lane]), empty)); + grad_input_storer.separate()[lane] = static_cast( + grad * dsrelu(static_cast(input_loader.separate()[lane]), empty)); } grad_input_storer.store(vector_idx, hidden); } @@ -307,8 +305,8 @@ template __global__ void scaled_gated_backward_with_scale_grad_kernel( const GradT *grad_output, const InputT *input, const ScaleT *act_scales, OutputT *grad_input, - GradScaleT *grad_act_scales, const size_t rows, const size_t hidden, - const size_t segment_size, const size_t num_segments, const size_t num_vectors_per_segment, + GradScaleT *grad_act_scales, const size_t rows, const size_t hidden, const size_t segment_size, + const size_t num_segments, const size_t num_vectors_per_segment, const ClampedSwiGLUParam param) { __shared__ float smem[kReductionWarps]; const size_t row = blockIdx.x; @@ -328,8 +326,8 @@ __global__ void scaled_gated_backward_with_scale_grad_kernel( VectorizedLoader grad_loader(grad_output + output_segment_offset, segment_size); VectorizedLoader act_loader(input + input_segment_offset, segment_size); - VectorizedLoader gate_loader( - input + input_segment_offset + segment_size, segment_size); + VectorizedLoader gate_loader(input + input_segment_offset + segment_size, + segment_size); VectorizedStorer act_storer(grad_input + input_segment_offset, segment_size); VectorizedStorer gate_storer( @@ -450,9 +448,8 @@ void check_grad_scale_tensor(const Tensor *grad_act_scales, const size_t rows, void check_gated_backward_tensors(const Tensor *grad_output, const Tensor *input, const Tensor *act_scales, const Tensor *grad_input, - const Tensor *grad_act_scales, - const int64_t glu_interleave_size, const char *api_name, - size_t *rows, size_t *hidden) { + const Tensor *grad_act_scales, const int64_t glu_interleave_size, + const char *api_name, size_t *rows, size_t *hidden) { const auto grad_dims = grad_output->flat_2d_dims(); const auto input_dims = input->flat_2d_dims(); const auto grad_input_dims = grad_input->flat_2d_dims(); @@ -475,8 +472,8 @@ void check_gated_backward_tensors(const Tensor *grad_output, const Tensor *input void check_unary_backward_tensors(const Tensor *grad_output, const Tensor *input, const Tensor *act_scales, const Tensor *grad_input, - const Tensor *grad_act_scales, const char *api_name, - size_t *rows, size_t *hidden) { + const Tensor *grad_act_scales, const char *api_name, size_t *rows, + size_t *hidden) { const auto grad_dims = grad_output->flat_2d_dims(); const auto input_dims = input->flat_2d_dims(); const auto grad_input_dims = grad_input->flat_2d_dims(); @@ -507,17 +504,15 @@ void launch_scaled_gated_forward(const NVTETensor nvte_input, const NVTETensor n TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(input->data.dtype, InputT, { TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(act_scales->data.dtype, ScaleT, { TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(output->data.dtype, OutputT, { - constexpr int nvec = - sizeof(InputT) == sizeof(OutputT) ? vector_width() : 1; + constexpr int nvec = sizeof(InputT) == sizeof(OutputT) ? vector_width() : 1; const auto input_ptr = reinterpret_cast(input->data.dptr); const auto scale_ptr = reinterpret_cast(act_scales->data.dptr); auto output_ptr = reinterpret_cast(output->data.dptr); const size_t segment_size = glu_interleave_size > 0 ? static_cast(glu_interleave_size) : hidden; const size_t num_segments = glu_interleave_size > 0 ? hidden / segment_size : 1; - const auto align = - row_vector_alignment(segment_size, nvec, input_ptr, input_ptr + segment_size, - output_ptr); + const auto align = row_vector_alignment(segment_size, nvec, input_ptr, + input_ptr + segment_size, output_ptr); const bool use_vector = align == Alignment::SAME_ALIGNED; const size_t num_vectors = use_vector ? get_num_aligned_elements(input_ptr, segment_size, nvec, sizeof(InputT)) @@ -552,8 +547,7 @@ void launch_scaled_srelu_forward(const NVTETensor nvte_input, const NVTETensor n TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(input->data.dtype, InputT, { TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(act_scales->data.dtype, ScaleT, { TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(output->data.dtype, OutputT, { - constexpr int nvec = - sizeof(InputT) == sizeof(OutputT) ? vector_width() : 1; + constexpr int nvec = sizeof(InputT) == sizeof(OutputT) ? vector_width() : 1; const auto input_ptr = reinterpret_cast(input->data.dptr); const auto scale_ptr = reinterpret_cast(act_scales->data.dptr); auto output_ptr = reinterpret_cast(output->data.dptr); @@ -567,9 +561,8 @@ void launch_scaled_srelu_forward(const NVTETensor nvte_input, const NVTETensor n <<>>(input_ptr, scale_ptr, output_ptr, rows, hidden, num_vectors); } else { - scaled_srelu_forward_kernel<1, InputT, ScaleT, OutputT> - <<>>(input_ptr, scale_ptr, output_ptr, rows, hidden, - hidden); + scaled_srelu_forward_kernel<1, InputT, ScaleT, OutputT><<>>( + input_ptr, scale_ptr, output_ptr, rows, hidden, hidden); } }); }); @@ -581,9 +574,8 @@ template void launch_scaled_gated_backward(const NVTETensor nvte_grad_output, const NVTETensor nvte_input, const NVTETensor nvte_act_scales, NVTETensor nvte_grad_input, NVTETensor nvte_grad_act_scales, - const int64_t glu_interleave_size, - const ClampedSwiGLUParam param, cudaStream_t stream, - const char *api_name) { + const int64_t glu_interleave_size, const ClampedSwiGLUParam param, + cudaStream_t stream, const char *api_name) { const Tensor *grad_output = convertNVTETensorCheck(nvte_grad_output); const Tensor *input = convertNVTETensorCheck(nvte_input); const Tensor *act_scales = convertNVTETensorCheck(nvte_act_scales); @@ -600,8 +592,7 @@ void launch_scaled_gated_backward(const NVTETensor nvte_grad_output, const NVTET TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(input->data.dtype, InputT, { TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(act_scales->data.dtype, ScaleT, { TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(grad_input->data.dtype, OutputT, { - constexpr int nvec = sizeof(GradT) == sizeof(InputT) && - sizeof(InputT) == sizeof(OutputT) + constexpr int nvec = sizeof(GradT) == sizeof(InputT) && sizeof(InputT) == sizeof(OutputT) ? vector_width() : 1; const auto grad_ptr = reinterpret_cast(grad_output->data.dptr); @@ -611,9 +602,9 @@ void launch_scaled_gated_backward(const NVTETensor nvte_grad_output, const NVTET const size_t segment_size = glu_interleave_size > 0 ? static_cast(glu_interleave_size) : hidden; const size_t num_segments = glu_interleave_size > 0 ? hidden / segment_size : 1; - const auto align = row_vector_alignment( - segment_size, nvec, grad_ptr, input_ptr, input_ptr + segment_size, grad_input_ptr, - grad_input_ptr + segment_size); + const auto align = row_vector_alignment(segment_size, nvec, grad_ptr, input_ptr, + input_ptr + segment_size, grad_input_ptr, + grad_input_ptr + segment_size); const bool use_vector = align == Alignment::SAME_ALIGNED; const size_t num_vectors = use_vector ? get_num_aligned_elements(input_ptr, segment_size, nvec, sizeof(InputT)) @@ -633,17 +624,16 @@ void launch_scaled_gated_backward(const NVTETensor nvte_grad_output, const NVTET } } else { TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(grad_act_scales->data.dtype, GradScaleT, { - auto grad_act_scales_ptr = - reinterpret_cast(grad_act_scales->data.dptr); + auto grad_act_scales_ptr = reinterpret_cast(grad_act_scales->data.dptr); if (use_vector) { - scaled_gated_backward_with_scale_grad_kernel< - nvec, GradT, InputT, ScaleT, OutputT, GradScaleT, Act> + scaled_gated_backward_with_scale_grad_kernel <<(rows), kReductionThreads, 0, stream>>>( grad_ptr, input_ptr, scale_ptr, grad_input_ptr, grad_act_scales_ptr, rows, hidden, segment_size, num_segments, num_vectors, param); } else { - scaled_gated_backward_with_scale_grad_kernel< - 1, GradT, InputT, ScaleT, OutputT, GradScaleT, Act> + scaled_gated_backward_with_scale_grad_kernel<1, GradT, InputT, ScaleT, OutputT, + GradScaleT, Act> <<(rows), kReductionThreads, 0, stream>>>( grad_ptr, input_ptr, scale_ptr, grad_input_ptr, grad_act_scales_ptr, rows, hidden, segment_size, num_segments, segment_size, param); @@ -677,16 +667,15 @@ void launch_scaled_srelu_backward(const NVTETensor nvte_grad_output, const NVTET TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(input->data.dtype, InputT, { TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(act_scales->data.dtype, ScaleT, { TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(grad_input->data.dtype, OutputT, { - constexpr int nvec = sizeof(GradT) == sizeof(InputT) && - sizeof(InputT) == sizeof(OutputT) + constexpr int nvec = sizeof(GradT) == sizeof(InputT) && sizeof(InputT) == sizeof(OutputT) ? vector_width() : 1; const auto grad_ptr = reinterpret_cast(grad_output->data.dptr); const auto input_ptr = reinterpret_cast(input->data.dptr); const auto scale_ptr = reinterpret_cast(act_scales->data.dptr); auto grad_input_ptr = reinterpret_cast(grad_input->data.dptr); - const auto align = row_vector_alignment(hidden, nvec, grad_ptr, input_ptr, - grad_input_ptr); + const auto align = + row_vector_alignment(hidden, nvec, grad_ptr, input_ptr, grad_input_ptr); const bool use_vector = align == Alignment::SAME_ALIGNED; const size_t num_vectors = use_vector ? get_num_aligned_elements(input_ptr, hidden, nvec, sizeof(InputT)) @@ -704,17 +693,16 @@ void launch_scaled_srelu_backward(const NVTETensor nvte_grad_output, const NVTET } } else { TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(grad_act_scales->data.dtype, GradScaleT, { - auto grad_act_scales_ptr = - reinterpret_cast(grad_act_scales->data.dptr); + auto grad_act_scales_ptr = reinterpret_cast(grad_act_scales->data.dptr); if (use_vector) { - scaled_srelu_backward_with_scale_grad_kernel< - nvec, GradT, InputT, ScaleT, OutputT, GradScaleT> + scaled_srelu_backward_with_scale_grad_kernel <<(rows), kReductionThreads, 0, stream>>>( grad_ptr, input_ptr, scale_ptr, grad_input_ptr, grad_act_scales_ptr, rows, hidden, num_vectors); } else { - scaled_srelu_backward_with_scale_grad_kernel< - 1, GradT, InputT, ScaleT, OutputT, GradScaleT> + scaled_srelu_backward_with_scale_grad_kernel<1, GradT, InputT, ScaleT, OutputT, + GradScaleT> <<(rows), kReductionThreads, 0, stream>>>( grad_ptr, input_ptr, scale_ptr, grad_input_ptr, grad_act_scales_ptr, rows, hidden, hidden); @@ -742,16 +730,15 @@ void nvte_scaled_swiglu(const NVTETensor input, const NVTETensor act_scales, NVT input, act_scales, output, glu_interleave_size, param, stream, "nvte_scaled_swiglu"); } -void nvte_scaled_dswiglu(const NVTETensor grad, const NVTETensor input, - const NVTETensor act_scales, NVTETensor grad_input, - NVTETensor grad_act_scales, int64_t glu_interleave_size, - cudaStream_t stream) { +void nvte_scaled_dswiglu(const NVTETensor grad, const NVTETensor input, const NVTETensor act_scales, + NVTETensor grad_input, NVTETensor grad_act_scales, + int64_t glu_interleave_size, cudaStream_t stream) { NVTE_API_CALL(nvte_scaled_dswiglu); using namespace transformer_engine; ClampedSwiGLUParam param = {}; - launch_scaled_gated_backward( - grad, input, act_scales, grad_input, grad_act_scales, glu_interleave_size, param, stream, - "nvte_scaled_dswiglu"); + launch_scaled_gated_backward(grad, input, act_scales, grad_input, + grad_act_scales, glu_interleave_size, + param, stream, "nvte_scaled_dswiglu"); } void nvte_scaled_clamped_swiglu(const NVTETensor input, const NVTETensor act_scales, @@ -762,8 +749,7 @@ void nvte_scaled_clamped_swiglu(const NVTETensor input, const NVTETensor act_sca using namespace transformer_engine; ClampedSwiGLUParam param = {limit, alpha, glu_linear_offset}; launch_scaled_gated_forward( - input, act_scales, output, glu_interleave_size, param, stream, - "nvte_scaled_clamped_swiglu"); + input, act_scales, output, glu_interleave_size, param, stream, "nvte_scaled_clamped_swiglu"); } void nvte_scaled_clamped_dswiglu(const NVTETensor grad, const NVTETensor input, @@ -786,9 +772,8 @@ void nvte_scaled_srelu(const NVTETensor input, const NVTETensor act_scales, NVTE launch_scaled_srelu_forward(input, act_scales, output, stream, "nvte_scaled_srelu"); } -void nvte_scaled_dsrelu(const NVTETensor grad, const NVTETensor input, - const NVTETensor act_scales, NVTETensor grad_input, - NVTETensor grad_act_scales, cudaStream_t stream) { +void nvte_scaled_dsrelu(const NVTETensor grad, const NVTETensor input, const NVTETensor act_scales, + NVTETensor grad_input, NVTETensor grad_act_scales, cudaStream_t stream) { NVTE_API_CALL(nvte_scaled_dsrelu); using namespace transformer_engine; launch_scaled_srelu_backward(grad, input, act_scales, grad_input, grad_act_scales, stream, diff --git a/transformer_engine/common/include/transformer_engine/activation.h b/transformer_engine/common/include/transformer_engine/activation.h index f1485057ec..ed90428f8c 100644 --- a/transformer_engine/common/include/transformer_engine/activation.h +++ b/transformer_engine/common/include/transformer_engine/activation.h @@ -521,10 +521,9 @@ void nvte_clamped_dswiglu_v2(const NVTETensor grad, const NVTETensor input, NVTE * \param[in] glu_interleave_size GLU interleave chunk size, or 0 for non-interleaved layout. * \param[in] stream CUDA stream used for the operation. */ -void nvte_scaled_dswiglu(const NVTETensor grad, const NVTETensor input, - const NVTETensor act_scales, NVTETensor grad_input, - NVTETensor grad_act_scales, int64_t glu_interleave_size, - cudaStream_t stream); +void nvte_scaled_dswiglu(const NVTETensor grad, const NVTETensor input, const NVTETensor act_scales, + NVTETensor grad_input, NVTETensor grad_act_scales, + int64_t glu_interleave_size, cudaStream_t stream); /*! \brief Computes ScaledClampedSwiGLU backward without materializing GLU deinterleave. * @@ -608,9 +607,8 @@ void nvte_scaled_srelu(const NVTETensor input, const NVTETensor act_scales, NVTE * \param[in,out] grad_act_scales Optional row-wise scale gradient of shape [N], or null. * \param[in] stream CUDA stream used for the operation. */ -void nvte_scaled_dsrelu(const NVTETensor grad, const NVTETensor input, - const NVTETensor act_scales, NVTETensor grad_input, - NVTETensor grad_act_scales, cudaStream_t stream); +void nvte_scaled_dsrelu(const NVTETensor grad, const NVTETensor input, const NVTETensor act_scales, + NVTETensor grad_input, NVTETensor grad_act_scales, cudaStream_t stream); #ifdef __cplusplus } // extern "C"