From d0dbe66628a890b937f7c046d5865313199f7b39 Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Wed, 26 Nov 2025 12:26:10 -0800 Subject: [PATCH 01/36] rowwise colwise RHT group quant v1 Signed-off-by: Zhongbo Zhu --- benchmarks/linear/benchmark_grouped_linear.py | 25 +- transformer_engine/common/CMakeLists.txt | 1 + transformer_engine/common/cast/cast.cu | 15 + .../common/cast/dispatch/quantize.cuh | 65 + .../nvfp4/group_quantize_transpose_nvfp4.cuh | 902 ++++++++++++++ .../group_hadamard_transform_cast_fusion.cu | 1063 +++++++++++++++++ .../hadamard_transform_cast_fusion.cu | 1 - .../common/include/transformer_engine/cast.h | 14 + .../transformer_engine/hadamard_transform.h | 18 + .../pytorch/csrc/extensions/cast.cpp | 117 +- 10 files changed, 2162 insertions(+), 59 deletions(-) create mode 100644 transformer_engine/common/cast/nvfp4/group_quantize_transpose_nvfp4.cuh create mode 100644 transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu diff --git a/benchmarks/linear/benchmark_grouped_linear.py b/benchmarks/linear/benchmark_grouped_linear.py index 02e2bcf4b9..464219a7fe 100644 --- a/benchmarks/linear/benchmark_grouped_linear.py +++ b/benchmarks/linear/benchmark_grouped_linear.py @@ -173,7 +173,9 @@ def benchmark_linear( return timing_ms -def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None): +def run_benchmark_linear( + mkns, recipe_name, use_bias, num_gemms=4, m_splits_provided=None, fwd_only=False +): data = [] assert not use_bias, "Bias is not supported for GroupedLinear benchmark" @@ -182,14 +184,14 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None device = "cuda" x = torch.randn((m, k), dtype=torch.bfloat16, device=device, requires_grad=True) ws = [torch.randn((n, k), dtype=torch.bfloat16, device=device) for _ in range(num_gemms)] - assert m % num_gemms == 0 - m_splits = [m // num_gemms] * num_gemms if m_splits is None else m_splits + m_splits = [m // num_gemms] * num_gemms if m_splits_provided is None else m_splits_provided # Bias is not supported for GroupedLinear benchmark bias = None # Run the benchmark print(f"fwd_m={m}, fwd_k={k}, fwd_n={n}") print(f"m_splits: {m_splits}") + print(f"fwd_only: {fwd_only}") grouped_fwd_bwd_timing_ms = benchmark_linear( x, @@ -197,7 +199,7 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None m_splits, bias, recipe_name, - mode="fwd_bwd", + mode="fwd_only" if fwd_only else "fwd_bwd", num_gemms=num_gemms, ) @@ -213,6 +215,8 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None ] ) + timing_notation = "grouped_fwd_time_ms" if fwd_only else "grouped_fwd_bwd_time_ms" + df = pd.DataFrame( data=data, columns=[ @@ -221,7 +225,7 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None "n", "recipe", "num_gemms", - "grouped_fwd_bwd_time_ms", + timing_notation, ], ) @@ -266,6 +270,12 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None default=2048, help="Output dimension to use, default is 2048", ) + parser.add_argument( + "--fwd-only", + action="store_true", + default=False, + help="Run forward pass only, default is both forward and backward passes", + ) args = parser.parse_args() jagged_input_splits = None @@ -297,7 +307,7 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None if jagged_input_splits is not None: num_gemms_list = [len(jagged_input_splits)] - token_dim_list = [65536] + token_dim_list = [16384, 32768, 65536, 98304] hidden_dim_list = [7168] output_dim_list = [2048] @@ -371,7 +381,8 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None recipe_name, use_bias, num_gemms=num_gemms, - m_splits=jagged_input_splits, + m_splits_provided=jagged_input_splits, + fwd_only=args.fwd_only, ) df_linears = pd.concat([df_linears, df]) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 264f7f9a78..34a99d10e5 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -174,6 +174,7 @@ list(APPEND transformer_engine_cuda_arch_specific_sources hadamard_transform/group_hadamard_transform.cu hadamard_transform/hadamard_transform.cu hadamard_transform/hadamard_transform_cast_fusion.cu + hadamard_transform/group_hadamard_transform_cast_fusion.cu multi_tensor/compute_scale.cu recipe/mxfp8_scaling.cu transpose/quantize_transpose_square_blockwise.cu diff --git a/transformer_engine/common/cast/cast.cu b/transformer_engine/common/cast/cast.cu index 1ed46a3359..73467d7275 100644 --- a/transformer_engine/common/cast/cast.cu +++ b/transformer_engine/common/cast/cast.cu @@ -100,3 +100,18 @@ void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs, NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, detail::get_compute_stream_event(s))); } } + +// Group quantize assumes contiguous inputs and outputs in memory allocation +// TODO (zhongbo): find a better way to make it a more generalized API +void nvte_group_nvfp4_quantize_with_amax(const NVTETensor input, NVTETensor *outputs, + const size_t *split_sections, const size_t num_tensors, + const NVTEQuantizationConfig quant_config, + cudaStream_t stream) { + NVTE_API_CALL(nvte_group_nvfp4_quantize_with_amax); + using namespace transformer_engine; + + constexpr bool IS_ACT = false; + + dispatch::group_quantize_fwd_helper(input, outputs, split_sections, + num_tensors, quant_config, stream); +} diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index 9f7a4a9b01..6d4454402c 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -19,6 +19,7 @@ #include "../core/common.cuh" #include "../fp8/quantize_fp8.cuh" #include "../mxfp8/quantize_mxfp8.cuh" +#include "../nvfp4/group_quantize_transpose_nvfp4.cuh" #include "../nvfp4/quantize_nvfp4.cuh" #include "../nvfp4/quantize_transpose_nvfp4.cuh" @@ -320,6 +321,70 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens } } +template +void group_quantize_fwd_helper(const NVTETensor input, NVTETensor *outputs, + const size_t *split_sections, const size_t num_tensors, + const NVTEQuantizationConfig quant_config, cudaStream_t stream) { + using namespace detail; + + const Tensor *input_tensor = convertNVTETensorCheck(input); + std::vector output_tensors; + for (size_t i = 0; i < num_tensors; ++i) { + output_tensors.push_back(convertNVTETensorCheck(outputs[i])); + } + + // Quantization config + QuantizationConfig quant_config_cpp; + if (quant_config != nullptr) { + quant_config_cpp = *reinterpret_cast(quant_config); + } + + // Noop flag + Tensor dummy_tensor; + Tensor *noop_tensor = &dummy_tensor; + if (quant_config_cpp.noop_tensor != nullptr) { + noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); + } + + // Check for unsupported options + if (quant_config_cpp.stochastic_rounding) { + NVTE_CHECK(output_tensors[0]->scaling_mode == NVTE_NVFP4_1D_SCALING, + "Stochastic rounding is only supported for NVFP4 quantization."); + } + + // Take the scaling mode of the first output tensor + auto scaling_mode = output_tensors[0]->scaling_mode; + + // Dispatch to quantization kernel depending on data format + switch (scaling_mode) { + case NVTE_NVFP4_1D_SCALING: { + NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING"); + + // Check tensors + CheckNoopTensor(*noop_tensor, "cast_noop"); + CheckInputTensor(*input_tensor, "input"); + // Skip checking output tensor list + // output list here is allowed to have empty tensor + + // Choose kernel + int32_t rows = input_tensor->flat_first_dim(); + int32_t cols = input_tensor->flat_last_dim(); + auto dtype = input_tensor->dtype(); + + NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, + "2D quantization is not supported for group quantize."); + + // Launch NVFP4 group quantize kernel + nvfp4::group_quantize_transpose( + *input_tensor, noop_tensor, output_tensors, split_sections, num_tensors, + &quant_config_cpp, stream); + break; + } + default: + NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + "."); + } +} + } // namespace dispatch } // namespace transformer_engine diff --git a/transformer_engine/common/cast/nvfp4/group_quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/group_quantize_transpose_nvfp4.cuh new file mode 100644 index 0000000000..b8eec22945 --- /dev/null +++ b/transformer_engine/common/cast/nvfp4/group_quantize_transpose_nvfp4.cuh @@ -0,0 +1,902 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file quantize_transpose_nvfp4.cuh + * \brief CUDA kernels to cast to NVFP4 and transpose. + */ + +#ifndef TRANSFORMER_ENGINE_GROUP_QUANTIZE_TRANSPOSE_NVFP4_CUH_ +#define TRANSFORMER_ENGINE_GROUP_QUANTIZE_TRANSPOSE_NVFP4_CUH_ + +#include +#include +#include +#include + +#include "../../common.h" +#include "../../util/math.h" +#include "../../util/ptx.cuh" +#include "../../utils.cuh" +#include "core_nvfp4.cuh" + +namespace transformer_engine { +namespace dispatch { +namespace nvfp4 { + +namespace group_quantize_transpose_kernel { + +using namespace quantization_and_transposition_SF; +using namespace core; +using namespace ptx; + +#if FP4_TYPE_SUPPORTED + +constexpr int kMaxTensorsPerKernel = 64; // Args must be <4 KB, expand 64 if needed +struct MultiAmaxCastTransposeFusionArgs { + // Amax buffer for rowwise scaling + void *rowwise_amax_list[kMaxTensorsPerKernel]; + // Rowwise scale pointers with 128x4 padding included for rowwise scaling + void *output_rowwise_scale_inv_list[kMaxTensorsPerKernel]; + // (Unused for rowwise only scaling) Amax buffer for colwise scaling + void *colwise_amax_list[kMaxTensorsPerKernel]; + // (Unused for rowwise only scaling) output data pointers for fp4 transposed output + void *output_colwise_data_list[kMaxTensorsPerKernel]; + // (Unused for rowwise only scaling) output scale inverse pointers for each tensor + void *output_colwise_scale_inv_list[kMaxTensorsPerKernel]; + // (Unused for rowwise only scaling) output scale stride for colwise scaling + int output_colwise_scale_stride[kMaxTensorsPerKernel]; + // Prefix sum (with leading zero) of split_sections of each tensor of input + int split_sections_range[kMaxTensorsPerKernel + 1]; + // Number of tensors (splits) being processed by kernel + int num_tensors; +}; + +__device__ __forceinline__ int GetTensorId(MultiAmaxCastTransposeFusionArgs *kernel_args_ptr, + int offset) { + // check the kernel args and get the corresponding id + int tensor_id = 0; + while (kernel_args_ptr->split_sections_range[tensor_id + 1] <= offset) { + ++tensor_id; + } + return tensor_id; +} + +// Helper to get tensor id at offset, and also whether [offset_start, offset_end) crosses a split boundary. +__device__ __forceinline__ int GetTensorIdAndBoundary( + MultiAmaxCastTransposeFusionArgs *kernel_args_ptr, int offset_start, int offset_end, + bool *cross_boundary) { + int tensor_id_start = 0; + while (kernel_args_ptr->split_sections_range[tensor_id_start + 1] <= offset_start) { + ++tensor_id_start; + } + int tensor_id_end = tensor_id_start; + if (offset_end != offset_start) { + if (kernel_args_ptr->split_sections_range[tensor_id_start + 1] < offset_end) { + tensor_id_end = tensor_id_start + 1; + } + } + if (cross_boundary) { + *cross_boundary = (tensor_id_start != tensor_id_end); + } + return tensor_id_start; +} + +__device__ __forceinline__ void UpdateEncodeDecodeScaleFP32(float *amax_ptr, float *s_enc_ptr, + float *s_dec_ptr) { + float s_env_value = + (amax_ptr == nullptr) ? 1.0f : compute_global_encode_scaling_factor_FP4(*amax_ptr); + float s_dec_value = 1.0 / s_env_value; + *s_enc_ptr = s_env_value; + *s_dec_ptr = s_dec_value; + return; +} + +constexpr size_t SCALE_DIM = 16; // NVFP4 block (x16 elts) + +constexpr size_t CHUNK_DIM_Y = 128; +constexpr size_t CHUNK_DIM_X = 128; +constexpr size_t THREADS_NUM = 128; + +constexpr size_t SCALES_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM; +constexpr size_t SCALES_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM; + +constexpr size_t SCALES_PER_THREAD = 2 * (CHUNK_DIM_Y * CHUNK_DIM_X) / SCALE_DIM / THREADS_NUM; + +// Each call generates 4x uint32_t random numbers +constexpr size_t RNG_GENS_PER_THREAD = SCALES_PER_THREAD / 4; + +constexpr size_t TILE_DIM_Y = 32; +constexpr size_t TILE_DIM_X = 128; + +// SHould this be SCALE_DIM or BLOCK_DIM? Both are 16, should work for both 1D and 2D +constexpr size_t SCALES_PER_TILE_Y = TILE_DIM_Y / SCALE_DIM; +constexpr size_t SCALES_PER_TILE_X = TILE_DIM_X / SCALE_DIM; // 128 / 16 = 8 + +constexpr size_t TILES_Y = CHUNK_DIM_Y / TILE_DIM_Y; +constexpr size_t TILES_X = CHUNK_DIM_X / TILE_DIM_X; +constexpr size_t STAGES = TILES_Y * TILES_X; + +constexpr size_t BUFFS_NUM = 2; +constexpr size_t BUFF_DIM_Y = TILE_DIM_Y; +constexpr size_t BUFF_DIM_X = TILE_DIM_X; +constexpr size_t BUFF_SIZE = BUFF_DIM_Y * BUFF_DIM_X; +constexpr size_t BUFF_SIZE_TOTAL = BUFF_SIZE * BUFFS_NUM; + +// Input buffer (BF16) +constexpr size_t BUFF_IN_DIM_Y = BUFF_DIM_Y; +constexpr size_t BUFF_IN_DIM_X = BUFF_DIM_X; +constexpr size_t BUFF_IN_SIZE = BUFF_IN_DIM_Y * BUFF_IN_DIM_X; + +// Output buffer (NVFP4) +constexpr size_t BUFF_OUT_DIM_Y = BUFF_DIM_Y; +constexpr size_t BUFF_OUT_DIM_X = (BUFF_DIM_X * 4) / 8; +constexpr size_t BUFF_OUT_SIZE = BUFF_OUT_DIM_Y * BUFF_OUT_DIM_X; + +// Output transpose buffer (NVFP4) +constexpr size_t BUFF_OUT_T_DIM_Y = BUFF_DIM_X; +constexpr size_t BUFF_OUT_T_DIM_X = (BUFF_DIM_Y * 4) / 8; +constexpr size_t BUFF_OUT_T_SIZE = BUFF_OUT_T_DIM_Y * BUFF_OUT_T_DIM_X; + +// Manual swizzling parameters to reduce SHMEM bank conflicts +constexpr size_t PACK_SIZE = 8; +constexpr size_t WAVES = SCALE_DIM / PACK_SIZE; + +constexpr size_t SCALING_FACTORS_PER_TILE_X = TILE_DIM_X / SCALE_DIM; +constexpr size_t THREADS_X_ROWWISE = SCALING_FACTORS_PER_TILE_X; // 128 / 16 = 8 +constexpr size_t THREADS_Y_ROWWISE = THREADS_NUM / THREADS_X_ROWWISE; // 128 / 8 = 16 + +constexpr size_t ITERATIONS_NORMAL = BUFF_DIM_Y / THREADS_Y_ROWWISE; // 32/ 16 = 2 +constexpr size_t ITERATIONS_TRANSPOSE = BUFF_IN_DIM_Y / SCALE_DIM; +constexpr size_t BUFF_OUT_IT_OFFSET = BUFF_OUT_T_DIM_X / ITERATIONS_TRANSPOSE; + +static_assert(BUFF_DIM_Y >= SCALE_DIM && + "Number of buffer rows must be greater or equal to the size of the columwise " + "scaling block\0"); +static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y); +static_assert(BUFF_DIM_Y >= THREADS_Y_ROWWISE && + "Number of buffer rows must be greater or equal to the number of rowwise " + "processing threads in Y dimension\0"); + +// Number of 4-bit elements that span 32 banks (4-byte each) of shared memory +constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256 + +// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory +constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM; // 8 = 128 / 16 + +template +__global__ void __launch_bounds__(THREADS_NUM) + group_quantize_transpose_nvfp4_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_output, + nvfp4_scale_t *const scales_ptr, const float *noop, + const size_t rows, const size_t cols, + const size_t scale_stride, const size_t *rng_state, + MultiAmaxCastTransposeFusionArgs kernel_args) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + constexpr bool NO_ACTIVATIONS_NOT_FP32_INPUT = + (!COMPUTE_ACTIVATIONS) && (!std::is_same_v); + + using IType2 = typename ptx::FPx2; + + if constexpr (!COMPUTE_ACTIVATIONS) { + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + } + + const size_t rng_sequence = + threadIdx.x + blockIdx.x * THREADS_NUM + blockIdx.y * gridDim.x * THREADS_NUM; + const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0; + const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0; + transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; + rng.init(rng_seed, rng_sequence, rng_offset); + uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? rng.generate4() : uint4{0, 0, 0, 0}; + // Index of the random number. It increments each time when used and resets to 0 if reaches 4x + int rnd_idx = 0; + + constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS; + + const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X; + + // TODO(zhongbo): add back when transpose is supported + // const size_t block_offset_Y_t = blockIdx.x * CHUNK_DIM_X; + // const size_t block_offset_X_t = blockIdx.y * CHUNK_DIM_Y; + + const size_t chunk_rows = rows - block_offset_Y; + + const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; + const size_t scales_block_offset_X_rowwise = blockIdx.x * SCALES_PER_CHUNK_X; + // TODO(zhongbo): add back when transpose is supported + // const size_t scales_block_offset_Y_t = blockIdx.x * CHUNK_DIM_X; + // const size_t scales_block_offset_X_t = blockIdx.y * SCALES_PER_CHUNK_Y; + + const size_t tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; + const size_t tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; + const size_t tid_X_colwise = threadIdx.x; + const size_t tid_Y_t = tid_X_colwise; + // const size_t tid_X_t = 0; + + const size_t thread_offset_Y_rowwise = tid_Y_rowwise; + const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM; + const size_t thread_offset_X_colwise = tid_X_colwise; + + const size_t row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise; + const size_t row_base_colwise = block_offset_Y; + const size_t col_base_colwise = block_offset_X + thread_offset_X_colwise; + + const bool col_out_of_bounds_colwise = (col_base_colwise >= cols); + + const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; + const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; + // TODO(zhongbo): add back when transpose is supported + // const size_t scales_offset_Y_t = scales_block_offset_Y_t + tid_Y_t; + // const size_t scales_offset_X_t = scales_block_offset_X_t; + + const size_t SFs_per_row = cols / SCALE_DIM; + + const bool rowwise_scale_is_within_bounds_X = scales_offset_X_rowwise < SFs_per_row; + + // TODO(zhongbo): add back when transpose is supported + // const bool colwise_scale_is_within_bounds_Y = scales_offset_Y_t < cols; + + // Helps resolving bank conflicts in shmem + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + const int bank_group = thread_lane / THREADS_PER_BANK; + + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT); + + constexpr size_t in_mem = buff_size_aligned_in; + + constexpr size_t out_mem_rowwise_data = buff_size_aligned_out; + constexpr size_t out_mem_colwise_data = buff_size_aligned_out; + constexpr size_t out_mem_rowwise_scales = 0; + + extern __shared__ char dynamic_shmem[]; + uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); + // Manually align dynamic SHMEM per TMA requirements using padding + // __align__(128) Does not guarantee the pointer to be aligned! + uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & + ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + IType *in_sh = reinterpret_cast(dshmem); + fp4e2m1x2 *out_data_sh = reinterpret_cast(dshmem + in_mem); + fp4e2m1x2 *out_t_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise_data); + + nvfp4_scale_t *out_rowwise_scales_sh = reinterpret_cast( + dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data); + nvfp4_scale_t *out_colwise_scales_sh = reinterpret_cast( + dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales); + IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer + + constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; + + const bool is_master_thread = (threadIdx.x == 0); + + // TODO (zhongbo): finish this + float *amax_rowwise_ptr = nullptr; + float *amax_colwise_ptr = nullptr; + nvfp4_scale_t *split_rowwise_scale_ptr = nullptr; + + // suppose the amax is fixed for the current 128x128 tile (need 128 padding) + bool need_update_tensor_id = true; + int tensor_id = GetTensorIdAndBoundary(&kernel_args, block_offset_Y, block_offset_Y + CHUNK_DIM_Y, + &need_update_tensor_id); + size_t split_start = kernel_args.split_sections_range[tensor_id]; + size_t split_end = kernel_args.split_sections_range[tensor_id + 1]; + amax_rowwise_ptr = reinterpret_cast(kernel_args.rowwise_amax_list[tensor_id]); + split_rowwise_scale_ptr = + reinterpret_cast(kernel_args.output_rowwise_scale_inv_list[tensor_id]); + + float S_enc_rowwise = 1.0f; + float S_dec_rowwise = 1.0f; + UpdateEncodeDecodeScaleFP32(amax_rowwise_ptr, &S_enc_rowwise, &S_dec_rowwise); + + // TODO (zhongbo): colwise scaling disabled for now because of transpose + float S_enc_colwise = 1.0f; + float S_dec_colwise = 1.0f; + if (amax_colwise_ptr != nullptr) { + UpdateEncodeDecodeScaleFP32(amax_colwise_ptr, &S_enc_colwise, &S_dec_colwise); + } else { + S_enc_colwise = S_enc_rowwise; + S_dec_colwise = S_dec_rowwise; + } + + float thread_amax = 0.0f; + +// Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[STAGES]; + + initialize_barriers(mbar, is_master_thread); + + copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size, + &mbar[0], is_master_thread); + +#pragma unroll + for (size_t stage = 0; stage < STAGES; ++stage) { + const size_t buff = stage % BUFFS_NUM; + const size_t next_stage = stage + 1; + const size_t stage_offset_Y = stage * BUFF_DIM_Y; + + const size_t buff_offset_in = buff * BUFF_IN_SIZE; + const size_t buff_offset_out = buff * BUFF_OUT_SIZE; + const size_t buff_offset_out_t = buff * BUFF_OUT_T_SIZE; + + // for stages from 1 to STAGES - 1, we need to update the tensor id + // skip updating tensor id if it's the last CTA, and some stages will be out of bounds + if (need_update_tensor_id && stage > 0 && (block_offset_Y + stage_offset_Y < rows)) { + int new_tensor_id = GetTensorId(&kernel_args, block_offset_Y + stage_offset_Y); + if (new_tensor_id != tensor_id) { + tensor_id = new_tensor_id; + split_start = kernel_args.split_sections_range[tensor_id]; + split_end = kernel_args.split_sections_range[tensor_id + 1]; + amax_rowwise_ptr = reinterpret_cast(kernel_args.rowwise_amax_list[tensor_id]); + UpdateEncodeDecodeScaleFP32(amax_rowwise_ptr, &S_enc_rowwise, &S_dec_rowwise); + split_rowwise_scale_ptr = + reinterpret_cast(kernel_args.output_rowwise_scale_inv_list[tensor_id]); + // TODO (zhongbo): colwise scaling disabled for now because of transpose + // Skip fetching colwise amax pointer and scaling factor updates + } + } + + if (next_stage < STAGES) { + // Wait for TMA transfer to have finished reading shared memory. + // I.e. the buffer is ready to be written to + ptx::cp_async_bulk_wait_group_read<1>(); + + const size_t next_buff = next_stage % BUFFS_NUM; + const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y; + const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y; + const size_t global_offset_X = block_offset_X; + const size_t next_buff_offset = next_buff * BUFF_IN_SIZE; + + copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, + global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread); + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[stage], 0); + + float block_amax = 0.0f; + + // COLWISE scaling + if constexpr (RETURN_TRANSPOSE) { +#pragma unroll + for (size_t it = 0; it < ITERATIONS_TRANSPOSE; ++it) { + const size_t in_thread_offset_Y = 0 + it * SCALE_DIM; + const size_t in_thread_offset_X = thread_offset_X_colwise; + + const size_t out_t_thread_offset_Y = thread_offset_X_colwise; + const size_t out_t_thread_offset_X = 0 + it * BUFF_OUT_IT_OFFSET; + + const size_t shmem_offset_base_colwise_in = + buff_offset_in + in_thread_offset_Y * BUFF_IN_DIM_X + in_thread_offset_X; + const size_t shmem_offset_base_colwise_out_t = + buff_offset_out_t + out_t_thread_offset_Y * BUFF_OUT_T_DIM_X + out_t_thread_offset_X; + + block_amax = 0.0f; + float in_compute_colwise[SCALE_DIM]; + IType in_colwise_IType[SCALE_DIM]; + // 1. Read/Compute elements. Find NVFP4-block AMAX + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + IType block_amax_f16 = static_cast(0.0f); +#pragma unroll + for (int i = 0; i < SCALE_DIM; ++i) { + const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X; + in_colwise_IType[i] = in_sh[shmem_offset_colwise]; + block_amax_f16 = __hmax(block_amax_f16, __habs(in_colwise_IType[i])); + } + block_amax = static_cast(block_amax_f16); + } else { +#pragma unroll + for (int i = 0; i < SCALE_DIM; ++i) { + const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X; + float elt = static_cast(in_sh[shmem_offset_colwise]); + if constexpr (COMPUTE_ACTIVATIONS) { + elt = OP(elt, {}); + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + // Cache computed activations to avoid computing them again in the 2nd pass along another dimension + if constexpr (IS_CACHED_ACT_OP) { + cached_act_sh[shmem_offset_colwise] = static_cast(elt); + } + if constexpr (COMPUTE_ACTIVATIONS) { + const bool row_out_of_bounds_colwise = + (row_base_colwise + stage_offset_Y + i >= rows); + const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise); + if (!out_of_bounds) { + block_amax = fmaxf(block_amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + block_amax = fmaxf(block_amax, fabsf(elt)); + } + in_compute_colwise[i] = elt; + } + } + // 2. Compute E4M3 scaling factor + const nvfp4_scale_t S_dec_b_fp8 = + compute_decoding_scaling_factor(block_amax, S_enc_colwise); + + // Store scaling factors through SHMEM + const size_t scale_idx_sh = + tid_Y_t * SCALES_PER_CHUNK_Y + stage * ITERATIONS_TRANSPOSE + it; + out_colwise_scales_sh[scale_idx_sh] = S_dec_b_fp8; + + // Compute "correct" per-block encoding scaling factor + constexpr float float_max = detail::TypeExtrema::max; + const float block_scale_inverse = fminf( + 1.0f / (static_cast(S_dec_b_fp8) * S_dec_colwise), float_max); // S_enc_b_fp8 + const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse}; + + // 3. Scale elements + fp4e2m1x4 regs[SCALE_DIM / 4]; + +#pragma unroll + for (int e = 0; e < SCALE_DIM / 4; ++e) { + const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx); + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + const uint64_t elts = *reinterpret_cast(&in_colwise_IType[4 * e]); + regs[e] = ptx::mul_cvt_bf16_to_fp4_4x( + elts, block_scale_inverse_2x, rbits); + } else { + const float2 in01 = *reinterpret_cast(&in_compute_colwise[4 * e]); + const float2 in23 = *reinterpret_cast(&in_compute_colwise[4 * e + 2]); + regs[e] = ptx::mul_cvt_fp32_to_fp4_4x( + in01, in23, block_scale_inverse_2x, rbits); + } + } + + const int group = thread_lane / 16; + uint32_t val[2]; + uint32_t *regs_4x = reinterpret_cast(regs); + + // Helps reducing bank conflicts + switch (group) { + case 0: + val[0] = regs_4x[0]; + val[1] = regs_4x[1]; + break; + case 1: + val[0] = regs_4x[1]; + val[1] = regs_4x[0]; + + break; + } + uint32_t *out_t_data_sh_as_uint32_t = + reinterpret_cast(&out_t_data_sh[shmem_offset_base_colwise_out_t]); + out_t_data_sh_as_uint32_t[group] = val[0]; // idx1 = (group + 0) % 2; + out_t_data_sh_as_uint32_t[(group + 1) & 1] = val[1]; // idx2 = (group + 1) % 2; + } + } + + // ROWWISE scaling + { + const size_t stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y; +#pragma unroll + for (size_t it = 0; it < ITERATIONS_NORMAL; ++it) { + const size_t it_thread_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE; + + const size_t shmem_offset_base_rowwise_in = + buff_offset_in + it_thread_offset_Y_rowwise * BUFF_IN_DIM_X; + const size_t shmem_offset_base_rowwise_out = + buff_offset_out + it_thread_offset_Y_rowwise * BUFF_OUT_DIM_X; + + const size_t it_offset_Y = stage_offset_Y + it * THREADS_Y_ROWWISE; + + block_amax = 0.0f; + float in_compute_rowwise[SCALE_DIM]; + Vec in_cached[WAVES]; + + // used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY + Vec in_IType[WAVES]; + + // 1. Read/Compute elements. Find NVFP4-block AMAX + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + // Load elements + in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); + } + } + block_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } else if constexpr (IS_CACHED_ACT_OP) { + // ensures that all writes to cache made in the section above are visible to all threads + __syncthreads(); + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + + const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + + // Load cached elements + in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); + // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) + // only single check (w.r.t. column direction) is sufficient to be sure the entire wave is inside the boundaries + if (!out_of_bounds) { + if constexpr (std::is_same_v) { +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + block_amax = fmaxf(block_amax, fabsf(in_cached[w].data.elt[e])); + } + } else { +#pragma unroll + for (int e = 0; e < PACK_SIZE; e += 2) { + const IType2 in_cached_2x = {in_cached[w].data.elt[e], + in_cached[w].data.elt[e + 1]}; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); + } + } + } + } + if constexpr (!std::is_same_v) { + block_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } + } else { +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + + Vec in; + Vec act_in; + + in.load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const size_t j = w * PACK_SIZE + e; + // Compute element + float elt = static_cast(in.data.elt[e]); + if constexpr (COMPUTE_ACTIVATIONS) { + elt = OP(elt, {}); + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + if constexpr (COMPUTE_ACTIVATIONS) { + const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = + (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = + (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + if (!out_of_bounds) { + block_amax = fmaxf(block_amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + block_amax = fmaxf(block_amax, fabsf(elt)); + } + in_compute_rowwise[j] = elt; + } + } + } + + // 2. Compute E4M3 scaling factor + const nvfp4_scale_t S_dec_b_fp8 = + compute_decoding_scaling_factor(block_amax, S_enc_rowwise); + + // Check boundaries + const size_t scales_offset_Y = + scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; + const size_t scales_offset_X = scales_offset_X_rowwise; + + const bool rowwise_scale_is_within_bounds_Y = + (stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE) < chunk_rows; + + // TODO(zhongbo): depending on input padding multiple (whether 128 or 64), use either scale_ptr or split_rowwise_scale_ptr + // const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X; + // if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) { + // scales_ptr[scale_idx_global] = S_dec_b_fp8; + // } + + // Map to local split coordinates + const size_t split_rows = split_end - split_start; + const size_t local_scale_row = scales_offset_Y - split_start; + + // Local bounds: 0 <= local_scale_row < split_rows + const bool local_rowwise_scale_is_within_bounds_Y = local_scale_row < split_rows; + + // Index inside this split’s scale buffer + const size_t scale_idx_local = local_scale_row * scale_stride + scales_offset_X; + + if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y && + local_rowwise_scale_is_within_bounds_Y) { + split_rowwise_scale_ptr[scale_idx_local] = S_dec_b_fp8; + } + + // Compute "correct" per-block encoding scaling factor + constexpr float float_max = detail::TypeExtrema::max; + const float block_scale_inverse = fminf( + 1.0f / (static_cast(S_dec_b_fp8) * S_dec_rowwise), float_max); // S_enc_b_fp8 + const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse}; + +// 3. Scale elements +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + Vec out; +#pragma unroll + for (int e = 0; e < PACK_SIZE / 4; ++e) { + const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx); + IType2 in01; + IType2 in23; + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + const uint64_t elts = *reinterpret_cast(&in_IType[w].data.elt[2 * e]); + out.data.elt[e] = ptx::mul_cvt_bf16_to_fp4_4x( + elts, block_scale_inverse_2x, rbits); + } else if constexpr (IS_CACHED_ACT_OP) { + const uint64_t elts = *reinterpret_cast(&in_cached[w].data.elt[4 * e]); + out.data.elt[e] = ptx::mul_cvt_bf16_to_fp4_4x( + elts, block_scale_inverse_2x, rbits); + } else { + const int j = w * PACK_SIZE + 4 * e; + const float2 in01 = make_float2(in_compute_rowwise[j], in_compute_rowwise[j + 1]); + const float2 in23 = make_float2(in_compute_rowwise[j + 2], in_compute_rowwise[j + 3]); + out.data.elt[e] = ptx::mul_cvt_fp32_to_fp4_4x( + in01, in23, block_scale_inverse_2x, rbits); + } + } + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2; + out.store_to(&out_data_sh[shmem_offset_rowwise]); + } + } + } + + __builtin_assume(thread_amax >= 0); + thread_amax = fmaxf(thread_amax, block_amax); + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const size_t global_offset_Y = block_offset_Y + stage_offset_Y; + const size_t global_offset_X = block_offset_X; + + // TODO(zhongbo): add back when transpose is supported + // const size_t global_offset_Y_t = block_offset_Y_t; + // const size_t global_offset_X_t = block_offset_X_t + stage_offset_Y; + + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output), global_offset_X, global_offset_Y, + reinterpret_cast(&out_data_sh[buff_offset_out])); + + // TODO(zhongbo): add back when transpose is supported + // if constexpr (RETURN_TRANSPOSE) { + // ptx::cp_async_bulk_tensor_2d_shared_to_global( + // reinterpret_cast(&tensor_map_output_t), global_offset_X_t, + // global_offset_Y_t, reinterpret_cast(&out_t_data_sh[buff_offset_out_t])); + // } + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + } + } // end of stages + + // TODO(zhongbo): add back when transpose is supported + // Vectorized store scaling factors through SHMEM + // if (RETURN_TRANSPOSE && colwise_scale_is_within_bounds_Y) { + // using ScalesVec = Vec; + // const size_t scale_idx_sh = tid_Y_t * SCALES_PER_CHUNK_Y; + // ScalesVec &scales_vec = *reinterpret_cast(&out_colwise_scales_sh[scale_idx_sh]); + // const size_t scale_idx_global = scales_offset_Y_t * scale_stride_t + scales_offset_X_t; + // const size_t count = // number of scales in Y dimension of this chunk + // (chunk_rows >= CHUNK_DIM_Y) ? SCALES_PER_CHUNK_Y : (chunk_rows / SCALE_DIM); + // nvfp4_scale_t *dst = &scales_t_ptr[scale_idx_global]; + // constexpr size_t vec_bytes = SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t); + // if (count == SCALES_PER_CHUNK_Y && (reinterpret_cast(dst) % vec_bytes == 0)) { + // // Fast path: vectorized store when destination is properly aligned + // scales_vec.store_to(dst); + // } else { + // // Safe path: element-wise store for tails or unaligned destinations + // scales_vec.store_to_elts(dst, 0, count); + // } + // } + + destroy_barriers(mbar, is_master_thread); +#else + NVTE_DEVICE_ERROR("sm_100 or higher is required."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +#endif // FP4_TYPE_SUPPORTED +} // namespace group_quantize_transpose_kernel + +template +void group_quantize_transpose(const Tensor &input, const Tensor *noop, + std::vector &output_list, const size_t *split_sections, + size_t num_tensors, const QuantizationConfig *quant_config, + cudaStream_t stream) { +#if FP4_TYPE_SUPPORTED + using namespace group_quantize_transpose_kernel; + using namespace ptx; + bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false; + + NVTE_CHECK(num_tensors == output_list.size(), + "Number of output tensors should match number of tensors."); + NVTE_CHECK(num_tensors <= kMaxTensorsPerKernel, + "Number of tensors should be less than or equal to ", kMaxTensorsPerKernel); + + Tensor *output = nullptr; + // loop over the list to find the first non-empty tensor + for (size_t i = 0; i < num_tensors; ++i) { + if (output_list[i]->has_data()) { + output = output_list[i]; + break; + } + } + NVTE_CHECK(output != nullptr, "No output tensor found."); + // also check that the output has not null data pointer + NVTE_CHECK(output->data.dptr != nullptr, "Output data pointer is null."); + + // If transposed output is allocated, return the transposed data. Otherwise, it's not necesary to + // return the transposed data. + bool return_transpose = output->has_columnwise_data(); + // forbid return transpose for now because group quantize transpose is not supported yet + NVTE_CHECK(!return_transpose, "Return transpose is not supported for group quantize transpose."); + + // output_List is contiguous in memory, so take the first tensor as the contiguous output + auto output_contiguous = output->data; + + constexpr bool COMPUTE_ACTIVATIONS = false; + using ParamOP = Empty; + constexpr float (*OP)(float, const ParamOP &) = nullptr; + + checkCuDriverContext(stream); + CheckNoopTensor(*noop, "cast_noop"); + CheckInputTensor(input, "input"); + + NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + + NVTE_CHECK(rows % 32 == 0, + "Number of tensor rows must be a multiple of 32"); // 16B alignment for TMA + NVTE_CHECK(cols % 32 == 0, + "Number of tensor cols must be a multiple of 32"); // 16B alignment for TMA + + // process the output list and produce the multi-tensor args for grouped kernel + MultiAmaxCastTransposeFusionArgs kernel_args; + kernel_args.num_tensors = 0; + kernel_args.split_sections_range[0] = 0; + for (size_t i = 0; i < num_tensors; ++i) { + if (split_sections[i] == 0) { + continue; + } + kernel_args.rowwise_amax_list[kernel_args.num_tensors] = + reinterpret_cast(output_list[i]->amax.dptr); + kernel_args.output_rowwise_scale_inv_list[kernel_args.num_tensors] = + reinterpret_cast(output_list[i]->scale_inv.dptr); + // kernel_args.split_sections[kernel_args.num_tensors] = split_sections[i]; + kernel_args.split_sections_range[kernel_args.num_tensors + 1] = + kernel_args.split_sections_range[kernel_args.num_tensors] + split_sections[i]; + // check overflow + NVTE_CHECK(kernel_args.split_sections_range[kernel_args.num_tensors + 1] >= 0, + "split_sections_range overflow the int32_t"); + kernel_args.num_tensors++; + } + + const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); + const dim3 grid(blocks_X, blocks_Y); + const size_t block_size = THREADS_NUM; + + // Note (zhongbo): for group quantize of [x1, x2, ..., xn] + // for the rowwise sclaing, scaling factor stride is shared between all tensors + // for the colwise scaling, scaling factor stride is different for each tensor because of transpose + // since transpose puts token dimension splits in the last dimension of the tensor + const size_t scale_stride = output->scale_inv.shape[1]; + // const size_t scale_stride_transpose = + // return_transpose ? output->columnwise_scale_inv.shape[1] : 0; + + nvfp4_scale_t *const scales_ptr = reinterpret_cast(output->scale_inv.dptr); + + const float *noop_ptr = reinterpret_cast(noop->data.dptr); + + const NVTETensor rng_state_tensor = (quant_config != nullptr) ? quant_config->rng_state : nullptr; + const size_t *rng_state = nullptr; + if (rng_state_tensor != nullptr) { + Tensor &rng_state_te_tensor = *convertNVTETensor(rng_state_tensor); + NVTE_CHECK(rng_state_te_tensor.dtype() == DType::kInt64, + "RNG state should contain 2 64-bit values."); + NVTE_CHECK(rng_state_te_tensor.data.shape == std::vector{2}, + "Shape of the RNG state should be [2], but got ", rng_state_te_tensor.data.shape); + rng_state = reinterpret_cast(rng_state_te_tensor.data.dptr); + } + + using IType = bf16; + + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_output{}; + // alignas(64) CUtensorMap tensor_map_output_transpose{}; + + create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0, + sizeof(IType) * 8); + + create_2D_tensor_map(tensor_map_output, output_contiguous, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, + cols, 0, 4); + // if (return_transpose) { + // create_2D_tensor_map(tensor_map_output_transpose, output->columnwise_data, cols, rows, + // BUFF_DIM_X, BUFF_DIM_Y, rows, 0, 4); + // } + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_scales = (CHUNK_DIM_Y * CHUNK_DIM_X) / 16 * sizeof(nvfp4_scale_t); + + constexpr size_t in_mem = buff_size_aligned_in; + + constexpr size_t out_data_mem = buff_size_aligned_out; + constexpr size_t out_data_transpose_mem = buff_size_aligned_out; + constexpr size_t out_scales_transpose_mem = buff_size_scales; + + constexpr size_t out_mem = out_data_mem + out_data_transpose_mem; + + constexpr size_t dshmem_size = in_mem + out_mem + out_scales_transpose_mem + TMA_SHMEM_ALIGNMENT; + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_stochastic_rounding, USE_STOCHASTIC_ROUNDING, + + TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { + auto kernel = + group_quantize_transpose_nvfp4_kernel; + + if constexpr (use_2d_quantization) { + NVTE_ERROR("2D quantization is not supported for group quantize transpose."); + } + + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + kernel<<>>(tensor_map_input, tensor_map_output, + scales_ptr, noop_ptr, rows, cols, + scale_stride, rng_state, kernel_args); + });); +#else + NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); +#endif // FP4_TYPE_SUPPORTED +} + +} // namespace nvfp4 +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_GROUP_QUANTIZE_TRANSPOSE_NVFP4_CUH_ diff --git a/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu new file mode 100644 index 0000000000..2eb7e1c72f --- /dev/null +++ b/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu @@ -0,0 +1,1063 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "common/common.h" +#include "common/util/cuda_runtime.h" +#include "common/util/curanddx.hpp" +#include "common/util/ptx.cuh" +#include "common/utils.cuh" +#include "cutlass/arch/barrier.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/collective/builders/sm100_common.inl" +#include "cutlass/numeric_conversion.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/util/GPU_Clock.hpp" +#include "cutlass/util/command_line.h" +#include "cutlass/util/print_error.hpp" + +// clang-format off + +namespace transformer_engine { +namespace detail { +namespace { + +using namespace cute; +using cute::Tensor; // Ensure unqualified Tensor refers to cute::Tensor, not transformer_engine::Tensor + +using Stride2D = cute::Stride>; + +constexpr int kMaxTensorsPerKernel = 64; // Args must be <4 KB, expand 64 if needed +struct MultiAmaxHadamardCastFusionArgs { + // (output) Amax buffer for pre-RHT amax buffer + void* global_amax_list[kMaxTensorsPerKernel]; + // output C pointers for each tensor + void* output_colwise_list[kMaxTensorsPerKernel]; + // output scale inverse pointers for each tensor + void* output_colwise_scale_inv_list[kMaxTensorsPerKernel]; + // split sections of each tensor of input + int split_sections[kMaxTensorsPerKernel]; + // Prefix sum (with leading zero) of split_sections of each tensor of input + int split_sections_range[kMaxTensorsPerKernel + 1]; + // stride 2D struct for CUTE + Stride2D output_stride2d_list[kMaxTensorsPerKernel]; + // Number of tensors (splits) being processed by kernel + int num_tensors; +}; + + +__device__ __forceinline__ float* GetGlobalAmaxPtrByTensorId(MultiAmaxHadamardCastFusionArgs* kernel_args_ptr, int tensor_id){ + // directly returns the global amax pointer by tensor id + if (tensor_id < 0 || tensor_id >= kernel_args_ptr->num_tensors) { + return nullptr; + } + return reinterpret_cast(kernel_args_ptr->global_amax_list[tensor_id]); +} + +__device__ __forceinline__ int GetTensorId(MultiAmaxHadamardCastFusionArgs* kernel_args_ptr, int offset){ + // check the kernel args and get the corresponding id + int tensor_id = 0; + while (kernel_args_ptr->split_sections_range[tensor_id + 1] <= offset) { + ++tensor_id; + } + return tensor_id; +} + +// calculate the global encode scale factor for a given global amax. +__device__ __forceinline__ float ComputeGlobalEncodeScaleFP4(const float global_amax) { + constexpr float kFP8E4M3Max = 448.0f; + constexpr float kFP4E2M1Max = 6.0f; + // If scale is infinity, return max value of float32 + float global_encode_scale = cutlass::minimum_with_nan_propagation{}( + kFP8E4M3Max * kFP4E2M1Max / global_amax, cutlass::platform::numeric_limits::max()); + // If global amax is 0 or infinity, return 1 + return (global_amax == 0.f || global_encode_scale == 0.f) ? 1.f : global_encode_scale; +} + +template +struct SharedStorage { + static constexpr int AccumulatorPipelineStageCount = 16; + using AtomThrShapeMNK = cute::Shape<_1, _1, _1>; + + using AccumulatorPipeline = cutlass::PipelineUmmaAsync; + using AccumulatorPipelineStorage = typename AccumulatorPipeline::SharedStorage; + + static constexpr int MainloopPipelineStageCount = size<3>(ASmemLayout{}); + using MainloopPipeline = cutlass::PipelineTmaUmmaAsync< + MainloopPipelineStageCount, + Shape<_1,_1,_1>, + AtomThrShapeMNK>; + using MainloopPipelineStorage = typename MainloopPipeline::SharedStorage; + + alignas(16) AccumulatorPipelineStorage accumulator; + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) cute::uint64_t tma_barrier[1]; + uint32_t tmem_base_ptr; + + struct TensorStorage : cute::aligned_struct<128, _1> { + // cute::array_aligned> smem_A; + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + } tensors; + +}; + +CUTLASS_DEVICE +cutlass::Array +StochasticNumericConverterBase(cutlass::Array const &input, cutlass::Array const &rbits) { + using result_type = cutlass::Array; + result_type output; + constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING; + if constexpr (has_rs) { + auto output_ptr = reinterpret_cast(&output); + asm volatile( \ + "{\n" \ + "cvt.rs.satfinite.e2m1x4.f32 %0, {%5, %4, %3, %2}, %10;\n" \ + "cvt.rs.satfinite.e2m1x4.f32 %1, {%9, %8, %7, %6}, %11;\n" \ + "}" \ + : "=h"(output_ptr[0]), + "=h"(output_ptr[1]) + : "f"(input[0]), "f"(input[1]), "f"(input[2]), "f"(input[3]), + "f"(input[4]), "f"(input[5]), "f"(input[6]), "f"(input[7]), + "r"(rbits[0]), "r"(rbits[1])); + } else { + NVTE_DEVICE_ERROR("FP4 cvt PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); + } + return output; +} + +CUTLASS_DEVICE +cutlass::Array +StochasticNumericConverter(cutlass::Array const &input, cutlass::Array const *rbits) { + using result_type = cutlass::Array; + result_type output; + cutlass::Array *result_ptr = reinterpret_cast *>(&output); + cutlass::Array const *source_ptr = reinterpret_cast const *>(&input); + cutlass::Array const *rbits_ptr = reinterpret_cast const *>(rbits); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 2; i++) { + result_ptr[i] = StochasticNumericConverterBase(source_ptr[i], rbits_ptr[i]); + } + return output; +} + +template +__global__ static +void +group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile, + TA const* A, AStride dA, ASmemLayout sAlayout, CUTE_GRID_CONSTANT TmaLoadA const tma_load_a, + TB const* B, BStride dB, BSmemLayout sBlayout, CUTE_GRID_CONSTANT TmaLoadB const tma_load_b, + CSmemLayout, + TiledMMA mma, + MultiAmaxHadamardCastFusionArgs kernel_args, + const size_t* rng_state) +{ + using namespace cute; + using X = Underscore; + // static constexpr bool kApplyStochasticRounding = true; + using ElementAccumulator = float; + static constexpr int K_PIPE_MAX = size<3>(ASmemLayout{}); + using AtomThrShapeMNK = Shape(typename TiledMMA::ThrLayoutVMNK{})), _1, _1>; + static constexpr uint32_t kTmaTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(ASmemLayout{})) * cute::sizeof_bits_v); + + static constexpr int kTmaRhtTensorTransactionBytes = + cutlass::bits_to_bytes(16 * 16 * cute::sizeof_bits_v); + static constexpr int AccumulatorPipelineStageCount = 16; + + static constexpr int MainloopPipelineStageCount = size<3>(ASmemLayout{}); + using MainloopPipeline = cutlass::PipelineTmaUmmaAsync< + MainloopPipelineStageCount, + Shape<_1,_1,_1>, + AtomThrShapeMNK>; + using MainloopPipelineState = typename MainloopPipeline::PipelineState; + + using TmemAllocator = cute::TMEM::Allocator1Sm; + static constexpr int VectorSize = 16; + const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0; + const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0; + // Preconditions + CUTE_STATIC_ASSERT(is_static::value); + CUTE_STATIC_ASSERT(is_static::value); + CUTE_STATIC_ASSERT(is_static::value); + + // Represent the full tensors + Tensor mA = tma_load_a.get_tma_tensor(make_shape(M,N)); + Tensor mB = tma_load_b.get_tma_tensor(make_shape(16,16)); + // Tensor mC = make_tensor(cute::subbyte_iterator(C), make_shape(M,N), dC); // (M,N) + + using TensorC = decltype( + make_tensor( + subbyte_iterator((TC*)nullptr), // engine + make_shape(int{}, int{}), // (M, N_i) + Stride2D{} // stride (dM, dN) + ) + ); + + // make an array of mCs with capacity kNumTensorsPow2, + // but kernel_args.num_tensors is its real size + TensorC mCs[kNumTensorsPow2]; + + for (size_t i = 0; i < kernel_args.num_tensors; ++i) { + auto* output_C_i = reinterpret_cast(kernel_args.output_colwise_list[i]); + int output_C_n_i = kernel_args.split_sections[i]; + Stride2D stride2d_C_i = kernel_args.output_stride2d_list[i]; + + mCs[i] = cute::make_tensor( + cute::subbyte_iterator(output_C_i), + cute::make_shape(static_cast(M), output_C_n_i), // (M, N_i) + stride2d_C_i + ); + } + + using TensorSFC = decltype( + make_tensor( + make_gmem_ptr((TSFC*)nullptr), + make_layout( + make_shape( + int{}, // M + make_shape( make_shape(Int<16>{}, _4{}), // (16, 4) + int{} ) // n_tiles = split / 64 + ), + make_stride( + int{}, // dM = (split / 16) + make_stride( make_stride(_0{}, _1{}), // inner (16,4) layout + _4{} ) // tiles stride + ) + ) + ) + ); + + // Array of SFC tensors, capacity = kNumTensorsPow2, but we only use + // [0 .. kernel_args.num_tensors) in practice. + TensorSFC mSFCs[kNumTensorsPow2]; + + for (size_t i = 0; i < kernel_args.num_tensors; ++i) { + // Per-tensor split length along the original N dimension + int split_N_i = kernel_args.split_sections[i]; // must be multiple of 64 + int n_tiles_i = split_N_i / 64; // # of 64-wide tiles in N + int sfc_row_stride_i = split_N_i / 16; // # SFC elements per row + + // Base pointer for this tensor’s SFC buffer + auto* SFC_i = reinterpret_cast( + kernel_args.output_colwise_scale_inv_list[i]); + + // Shape and stride for this tensor’s SFC + auto sfc_shape_i = make_shape( + M, + make_shape( make_shape(Int<16>{}, _4{}), n_tiles_i ) + ); + + auto sfc_stride_i = make_stride( + sfc_row_stride_i, + make_stride( make_stride(_0{}, _1{}), _4{} ) + ); + + auto sfc_layout_i = make_layout(sfc_shape_i, sfc_stride_i); + + // Final tensor for this SFC slice + mSFCs[i] = make_tensor( + make_gmem_ptr(SFC_i), + sfc_layout_i + ); + } + + auto cluster_shape = Shape< _1, _1, _1>{}; + + // Get the appropriate blocks for this Cluster + dim3 cluster_coord_in_grid = cluster_id_in_grid(); + + // Total number of k-tiles + const int K_TILE_MAX = min(N, K) / 64; + uint32_t tiles_in_m = (M + size<0>(cluster_tile) - 1) / size<0>(cluster_tile); + uint32_t tiles_in_n = (N + 64 - 1) / 64; + uint32_t linear_tile_idx = blockIdx.x; + uint32_t tile_idx_m = linear_tile_idx % tiles_in_m; + uint32_t tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX; + + auto mainloop_tiler = Shape<_128,_16,_64>{}; + auto epilogue_tiler = Shape<_128,_64,_64>{}; + Tensor gA_mk = local_tile(mA, mainloop_tiler, make_coord(_,_, _), Step<_1, X,_1>{}); + Tensor gB_nk = local_tile(mB, cluster_tile, make_coord(_,_, _), Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) + // Tensor gC_mn = local_tile(mC, epilogue_tiler, make_coord(_,_, _), Step<_1,_1, X>{}); // (BLK_M,BLK_N) + + using TensorGC = decltype( + local_tile( + std::declval(), + decltype(epilogue_tiler){}, + make_coord(_, _, _), + Step<_1, _1, X>{} + ) + ); + + TensorGC gCs_mn[kNumTensorsPow2]; + + for (size_t i = 0; i < kernel_args.num_tensors; ++i) { + gCs_mn[i] = local_tile( + mCs[i], + epilogue_tiler, + make_coord(_, _, _), + Step<_1, _1, X>{} // (BLK_M, BLK_N) + ); + } + + // Tensor gSFC_mn = local_tile(mSFC, epilogue_tiler, make_coord(_,_, _), Step<_1,_1, X>{}); // (BLK_M,BLK_N) + using TensorGSFC = decltype( + local_tile( + std::declval(), + decltype(epilogue_tiler){}, + make_coord(_, _, _), + Step<_1, _1, X>{} + ) + ); + + // One tiled SFC view per split + TensorGSFC gSFCs_mn[kNumTensorsPow2]; + + for (size_t i = 0; i < kernel_args.num_tensors; ++i) { + gSFCs_mn[i] = local_tile( + mSFCs[i], + epilogue_tiler, + make_coord(_, _, _), + Step<_1, _1, X>{} // (BLK_M, BLK_N-like) + ); + } + + // Allocate SMEM + extern __shared__ char shared_memory[]; + using SharedStorage = SharedStorage; + SharedStorage& shared_storage = *reinterpret_cast(shared_memory); + Tensor tCsA = make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), sAlayout); // (MMA,MMA_M,MMA_N,PIPE) + Tensor tCsB = make_tensor(make_smem_ptr(shared_storage.tensors.smem_B.data()), sBlayout); // (MMA,MMA_N,MMA_K,PIPE) + + + // + // MMA: Define C accumulators and A/B partitioning + // + + int block_rank_in_cluster = cute::block_rank_in_cluster(); + ThrMMA thr_mma = mma.get_slice(block_rank_in_cluster); // blk idx + Tensor tCgB = thr_mma.partition_B(gB_nk); // (MMA,MMA_N,MMA_K,k) + + auto mma_epilogue = make_tiled_mma(SM100_MMA_F16BF16_SS{}, + Layout>{}); + ThrMMA thr_mma_epilogue = mma_epilogue.get_slice(block_rank_in_cluster); + + + using TiledMmaEpilogue = decltype(mma_epilogue); + Tensor tCgA = thr_mma.partition_A(gA_mk); + // Allocate "fragments" -- these are actually umma smem descriptors + Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_M,MMA_K,PIPE) + + auto acc_shape_mma = partition_shape_C(TiledMMA{}, take<0,2>(ClusterTileShape{})); + auto acc_shape_epilogue = partition_shape_C(TiledMmaEpilogue{}, take<0,2>(epilogue_tiler)); + + auto bulk_tmem_mma = TiledMMA::make_fragment_C(append(acc_shape_mma, + Int{})); + + auto bulk_tmem_epilogue = TiledMmaEpilogue::make_fragment_C(append(acc_shape_epilogue, + Int{})); + + TmemAllocator tmem_allocator{}; + cutlass::arch::NamedBarrier tmem_allocation_result_barrier(32 + 128, cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier); + + Layout cta_layout_mnk = make_layout(cluster_shape); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMMA::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster); + + auto [tAgA, tAsA] = tma_partition(tma_load_a, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(tCsA), group_modes<0,3>(tCgA)); + + auto [tBgB, tBsB] = tma_partition(tma_load_b, + get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), + group_modes<0,3>(tCsB), group_modes<0,3>(tCgB)); + + uint16_t tma_mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t tma_mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + + int warp_idx = cutlass::canonical_warp_idx_sync(); + + bool is_mma_warp = (warp_idx == 0); + bool is_dma_warp = (warp_idx == 1); + bool is_epilogue_warp = (warp_idx >= 4 && warp_idx <= 7); + + // if (is_epilogue_warp && elect_one_sync()) { + // // prefetch to make the global amax in cache + // for (size_t i = 0; i < kernel_args.num_tensors; ++i) { + // cute::prefetch(raw_pointer_cast(kernel_args.global_amax_list[i])); + // } + // } + + typename MainloopPipeline::Params mainloop_pipeline_params; + if (is_dma_warp) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + } + if (is_mma_warp) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + mainloop_pipeline_params.is_leader = cute::elect_one_sync() && is_dma_warp; + mainloop_pipeline_params.transaction_bytes = kTmaTransactionBytes; + mainloop_pipeline_params.initializing_warp = 0; + MainloopPipeline mainloop_pipeline(shared_storage.mainloop, + mainloop_pipeline_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::true_type{}); // Delay mask calculation + + MainloopPipelineState mainloop_pipe_consumer_state; + MainloopPipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); + + + + using AccumulatorPipeline = cutlass::PipelineUmmaAsync; + using AccumulatorPipelineState = typename AccumulatorPipeline::PipelineState; + + AccumulatorPipelineState accumulator_pipe_consumer_state; + AccumulatorPipelineState accumulator_pipe_producer_state = cutlass::make_producer_start_state(); + + typename AccumulatorPipeline::Params accumulator_pipeline_params; + if (is_mma_warp) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Producer; + } + if (is_epilogue_warp) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Consumer; + } + // Only one producer thread arrives on this barrier. + accumulator_pipeline_params.producer_arv_count = 1; + accumulator_pipeline_params.consumer_arv_count = size(AtomThrShapeMNK{}) * 128; + accumulator_pipeline_params.initializing_warp = 1; + AccumulatorPipeline accumulator_pipeline(shared_storage.accumulator, + accumulator_pipeline_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::true_type{}); // Delay mask calculation + + if (warp_idx == 2 && elect_one_sync()) { + cute::initialize_barrier(shared_storage.tma_barrier[0], /* num_threads */ 1); + } + __syncthreads(); + using TMEM_LOAD_NEW = cute::SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x; + + if (is_dma_warp) { + if (elect_one_sync()) { + cute::set_barrier_transaction_bytes(shared_storage.tma_barrier[0], kTmaRhtTensorTransactionBytes); + copy(tma_load_b.with(shared_storage.tma_barrier[0], tma_mcast_mask_b), tBgB(_,0,0), tBsB(_,0)); + } + cute::wait_barrier(shared_storage.tma_barrier[0], 0 /*tma_phase_bit*/); + do { + bool is_first_wave = linear_tile_idx == blockIdx.x; + uint32_t skip_wait = is_first_wave; + auto tAgA_mk = tAgA(_,tile_idx_m,_); + int k_tile = 0; + auto barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state, skip_wait); + + + CUTE_NO_UNROLL + while (k_tile < K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n) { + int k_tile_idx_n = tile_idx_n + k_tile; + ++k_tile; + skip_wait = (is_first_wave && k_tile < MainloopPipelineStageCount); + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token); + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state); + int write_stage = mainloop_pipe_producer_state.index(); + ++mainloop_pipe_producer_state; + barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state, skip_wait); + if (cute::elect_one_sync()) { + copy(tma_load_a.with(*tma_barrier, tma_mcast_mask_a), tAgA_mk(_,k_tile_idx_n), tAsA(_,write_stage)); + } + } + linear_tile_idx += gridDim.x; + tile_idx_m = linear_tile_idx % tiles_in_m; + tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX; + } while (tile_idx_m < tiles_in_m && tile_idx_n < tiles_in_n); + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } else if (is_mma_warp) { + mma.accumulate_ = UMMA::ScaleOut::Zero; + + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + tmem_allocation_result_barrier.arrive(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + bulk_tmem_mma.data() = tmem_base_ptr; + + do { + uint32_t skip_wait = K_TILE_MAX <= 0; + auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + CUTE_NO_UNROLL + for (int k_tile = 0; k_tile < K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n; ) + { + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + int read_stage = mainloop_pipe_consumer_state.index(); + auto tCrA_mk = tCrA(_,_,_,read_stage); + auto tCrB_nk = tCrB(_,_,0,0); + CUTE_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA) / 4; ++k_block) + { + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + CUTE_UNROLL + for (int i = 0; i < 4; i++) { + auto accumulators = bulk_tmem_mma(_,_,_,accumulator_pipe_producer_state.index() * 4 + i); + gemm(mma, tCrA_mk(_,_,k_block * 4 + i), tCrB_nk, accumulators); + } + + accumulator_pipeline.producer_commit(accumulator_pipe_producer_state); + ++accumulator_pipe_producer_state; + } + auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; + ++mainloop_pipe_consumer_state; + ++k_tile; + skip_wait = k_tile >= K_TILE_MAX; + barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state); + } + + linear_tile_idx += gridDim.x; + tile_idx_m = linear_tile_idx % tiles_in_m; + tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX; + } while (tile_idx_m < tiles_in_m && tile_idx_n < tiles_in_n); + tmem_allocator.release_allocation_lock(); + accumulator_pipeline.producer_tail(accumulator_pipe_producer_state); + tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } else if (is_epilogue_warp) { + static constexpr int FragmentSize = 256 / sizeof_bits_v; + + tmem_allocation_result_barrier.arrive_and_wait(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + bulk_tmem_epilogue.data() = tmem_base_ptr; + int thread_idx = threadIdx.x % 128; + + // Tensor tCgC = thr_mma_epilogue.partition_C(gC_mn); // (MMA,MMA_M,MMA_N) + auto tiled_t2r = make_tmem_copy(TMEM_LOAD_NEW{}, bulk_tmem_epilogue(_,_,_,_0{})); + auto tiled_r2g = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); + auto thr_t2r = tiled_t2r.get_slice(thread_idx); + auto thr_r2g = tiled_r2g.get_slice(thread_idx); + + // NVFP4 non-E8 recipe constants and global scales + static constexpr float fp4_max = 6.0f; + + // get global amax pointer + // float* global_amax_ptr = GetGlobalAmaxPtr(&kernel_args, tile_idx_n * 64); + int tensor_id = GetTensorId(&kernel_args, tile_idx_n * 64); + float* global_amax_ptr = GetGlobalAmaxPtrByTensorId(&kernel_args, tensor_id); + Tensor tCgC = thr_mma_epilogue.partition_C(gCs_mn[tensor_id]); + Tensor gSFC_mn = gSFCs_mn[tensor_id]; + + float global_amax_val = *global_amax_ptr; + float global_encode_scale = ComputeGlobalEncodeScaleFP4(global_amax_val); + float global_decode_scale = 1.0f / global_encode_scale; + + auto sfd_converter = cutlass::NumericConverter{}; + + do { + for (int k_tile = 0; k_tile < K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n; ++k_tile) { + // get the starting index of current k-tile in global tensor, to query the correct global amax + int cur_k_tile_global_elem_idx = (tile_idx_n + k_tile) * 64; + int new_tensor_id = GetTensorId(&kernel_args, cur_k_tile_global_elem_idx); + // float* new_global_amax_ptr = GetGlobalAmaxPtr(&kernel_args, cur_k_tile_global_elem_idx); + global_amax_ptr = GetGlobalAmaxPtrByTensorId(&kernel_args, new_tensor_id); + // update the scaling factors when it's no longer the same amax pointer + // TODO(zhongbo): the math operations are very expensive + // since the kernel is persistent, we can have a cache for all the possible scaling factors + if (tensor_id != new_tensor_id) { + global_amax_val = *global_amax_ptr; + global_encode_scale = ComputeGlobalEncodeScaleFP4(global_amax_val); + global_decode_scale = 1.0f / global_encode_scale; + tCgC = thr_mma_epilogue.partition_C(gCs_mn[new_tensor_id]); + gSFC_mn = gSFCs_mn[new_tensor_id]; + tensor_id = new_tensor_id; + } + // maybe udpated to the new tensor id + int tensor_start_elem = kernel_args.split_sections_range[tensor_id]; + int local_tile_idx_n = (cur_k_tile_global_elem_idx - tensor_start_elem) / 64; + + Tensor tCgC_mn = tCgC(_,_,_,tile_idx_m,local_tile_idx_n); + Tensor tCgSFC_mn = gSFC_mn(_,_,tile_idx_m,local_tile_idx_n); + + accumulator_pipeline.consumer_wait(accumulator_pipe_consumer_state); + + auto tCtC = bulk_tmem_epilogue(_,_,_,accumulator_pipe_consumer_state.index()); + Tensor tDtC = thr_t2r.partition_S(tCtC); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + Tensor tDgC = thr_t2r.partition_D(tCgC_mn); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + + Tensor tTR_rAcc = make_tensor(shape(tDgC)); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + Tensor tDrC = make_tensor(shape(tDgC)); + Tensor tTR_rAcc_frag = recast>(coalesce(tTR_rAcc)); + Tensor tDrC_frag = recast>(coalesce(tDrC)); + + Tensor src = thr_r2g.retile_S(tDrC); + Tensor dst = thr_r2g.retile_D(tDgC); + + Tensor tCgSFC = make_tensor(tCgSFC_mn.data(), make_layout( + make_shape(shape(tCgSFC_mn), Int<1>{}, Int<1>{}), + make_stride(stride(tCgSFC_mn), Int<0>{}, Int<0>{}) + )); + + Tensor tDgSFC = filter(thr_t2r.partition_D(tCgSFC)); + Tensor tDrSFC = make_tensor(shape(tDgSFC)); + + static constexpr int NumVecs = size(tDgC) / VectorSize; + Tensor tC_rRowSFD_frg = recast>(tDrSFC); + + cutlass::maximum_absolute_value_reduction, true> amax_reduction; + cutlass::Array vec_maxs; + cutlass::Array pvscales; + // TMEM_LOAD + copy(tiled_t2r, tDtC, tTR_rAcc); + cutlass::arch::fence_view_async_tmem_load(); + + accumulator_pipeline.consumer_release(accumulator_pipe_consumer_state); + + ++accumulator_pipe_consumer_state; + + // Cast data from FP32 to BF16 to FP32. + auto convert_accum_to_bf16 = cutlass::NumericArrayConverter{}; + auto convert_bf16_to_accum = cutlass::NumericArrayConverter{}; + tTR_rAcc_frag(_0{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_0{}))); + + auto compute_frgs = reinterpret_cast *>(tTR_rAcc_frag.data()); + auto output_frgs = reinterpret_cast *>(tDrC_frag.data()); + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < NumVecs; v++) { + vec_maxs[v] = amax_reduction(ElementAccumulator(0), compute_frgs[v]); + } + + pvscales = cutlass::divides>{}(vec_maxs, fp4_max); + pvscales = cutlass::multiplies>{}(pvscales, global_encode_scale); + auto pvscales_cvted = cutlass::NumericArrayConverter{}(pvscales); + + tC_rRowSFD_frg(_0{}) = pvscales_cvted; + auto qpvscale_ups = cutlass::NumericArrayConverter{}(tC_rRowSFD_frg(_0{})); + auto qpvscale_scaled = cutlass::multiplies>{}(qpvscale_ups, global_decode_scale); + auto acc_scales = cutlass::divides>{}(1.0, qpvscale_scaled); + + // Initialize RNG for tile + const size_t rng_sequence + = thread_idx + k_tile * 256 + linear_tile_idx * K_TILE_MAX * 256; + + transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; + rng.init(rng_seed, rng_sequence, rng_offset); + uint4 random_uint4 = uint4{0, 0, 0, 0}; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < NumVecs; v++) { + auto acc_scale = cutlass::minimum_with_nan_propagation{}(acc_scales[v], cutlass::platform::numeric_limits::max()); + // auto acc_scale = acc_scales[v]; + if constexpr (kEnableStochasticRounding) { + random_uint4 = rng.generate4(); + output_frgs[v] = StochasticNumericConverter( + cutlass::multiplies>{}( + compute_frgs[v], + acc_scale + ), + reinterpret_cast*>(&random_uint4)); + } else { + output_frgs[v] = cutlass::NumericArrayConverter{}(cutlass::multiplies>{}(compute_frgs[v], acc_scale)); + } + } + + copy(tiled_r2g, src, dst); + + // copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tDrC, tDgC); + + copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tDrSFC, tDgSFC); + + } + linear_tile_idx += gridDim.x; + tile_idx_m = linear_tile_idx % tiles_in_m; + tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX; + } while (tile_idx_m < tiles_in_m && tile_idx_n < tiles_in_n); + } +} + +// this function computes RHT-GEMM for +// A: m x n: col-major +// B: 16 x 16: row-major +// C: m x n: row-major +// SFC: m x (n/16): row-major +template +void +group_rht_gemm_ntt_w_sfc(int m, int n, + TA const* A, + TB const* B, + MultiAmaxHadamardCastFusionArgs* kernel_args_ptr, + const size_t* rng_state, + uint32_t sm_count, + cudaStream_t stream, + int k_tile_size = 2048) +{ + using namespace cute; + + + // Define shapes (dynamic) + auto M = static_cast(m); + auto N = static_cast(n); + + // Define strides (mixed) + auto dA = make_stride(Int<1>{}, m); // (dM,dK) + auto dB = make_stride(Int<1>{}, 16); // (dN,dK) + for (size_t i = 0; i < kernel_args_ptr->num_tensors; ++i) { + kernel_args_ptr->output_stride2d_list[i] = make_stride(kernel_args_ptr->split_sections[i], Int<1>{}); + } + + auto cga_shape = Shape< _1, _1, _1>{}; + auto cga_tile_shape = Shape<_128,_16,_16>{}; + auto cluster_tile_mainloop = Shape<_128,_16,_64>{}; + + // Construct the MMA + auto mma = make_tiled_mma(SM100_MMA_F16BF16_SS{}, + Layout>{}); + + // MMA in CGA Layout XXX: Need to generalize synchro? {$nv-release-never} + + // Assert that the TiledMMA uses all CTAs in the CGA. + CUTE_STATIC_ASSERT_V(size(cga_shape) == size(mma)); + CUTE_STATIC_ASSERT_V(evenly_divides(cga_tile_shape, tile_shape(mma))); + + // Determine the A and B shapes + auto mma_shape_B = partition_shape_B(mma, make_shape(size<1>(cga_tile_shape), size<2>(cga_tile_shape))); + + using TiledMma = decltype(mma); + using AtomThrID = typename TiledMma::AtomThrID; + + using SmemShape_M = decltype(shape_div(shape<0>(cga_tile_shape), shape_div(shape<0>(cga_tile_shape), size<0>(cga_tile_shape) / size(AtomThrID{})))); + using SmemShape_N = decltype(shape_div(shape<1>(cga_tile_shape), shape_div(shape<1>(cga_tile_shape), size<1>(cga_tile_shape) / size(AtomThrID{})))); + using SmemShape_K = decltype(cute::get<2>(cga_tile_shape)); + + using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + cute::UMMA::Major::MN, TB, SmemShape_N, SmemShape_K>()); + + auto mma_shape_A = partition_shape_A(mma, make_shape(size<0>(cluster_tile_mainloop), size<2>(cluster_tile_mainloop))); + using SmemShape_M_A = decltype(shape_div(shape<0>(cluster_tile_mainloop), shape_div(shape<0>(cluster_tile_mainloop), size<0>(cluster_tile_mainloop) / size(AtomThrID{})))); + using SmemShape_K_A = decltype(cute::get<2>(cluster_tile_mainloop)); + using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + cute::UMMA::Major::MN, TA, SmemShape_M_A, SmemShape_K_A>()); + + // Define the smem layouts (static) + // Calculate max pipeline stages based on Blackwell SM100's 232KB shared memory + constexpr int kBlackwellSmemSize = 232448; // 232KB in bytes + constexpr int kBytesPerStage = cute::size(mma_shape_A) * sizeof(TA) + cute::size(mma_shape_B) * sizeof(TB); + constexpr int kReservedBytes = 256; // Reserve for barriers and other uses + constexpr int kMaxStages = (kBlackwellSmemSize - kReservedBytes) / kBytesPerStage; + auto sP = Int{}; // SMEM pipelines + auto sA = UMMA::tile_to_mma_shape(SmemLayoutAtomA{}, append(mma_shape_A, sP)); // (MMA,MMA_M,MMA_K,PIPE) + auto sB = UMMA::tile_to_mma_shape(SmemLayoutAtomB{}, append(mma_shape_B, sP)); // (MMA,MMA_N,MMA_K,PIPE) + auto sC = Layout<_1>{}; // XXX Dummy + + // Create GMEM tensors + Tensor tensorA = make_tensor(A, make_layout(make_shape(M,N), dA)); // (M,N) + Tensor tensorB = make_tensor(B, make_layout(make_shape(16,16), dB)); // (16,16) + + // Create the TiledCopy + + auto tma_load_a = make_tma_copy_A_sm100( + SM90_TMA_LOAD{}, + tensorA, + sA(_,_,_,0), + cluster_tile_mainloop, + mma); + auto tma_load_b = make_tma_copy_B_sm100( + SM90_TMA_LOAD{}, + tensorB, + sB(_,_,_,0), + cga_tile_shape, + mma); + + // Assert checks on tile sizes -- no predication + NVTE_CHECK(M % size<0>(cga_tile_shape) == 0, + "Inner dimension must be divisible by ", static_cast(size<0>(cga_tile_shape)), " but got ", M, "."); + NVTE_CHECK(N % (4 * size<1>(cga_tile_shape)) == 0, + "Outer dimension must be divisible by ", 4 * static_cast(size<1>(cga_tile_shape)), + " but got ", N, "."); + + uint32_t tiles = size(ceil_div(M, get<0>(cga_tile_shape))) * size(ceil_div(N, k_tile_size)); + + tiles = (tiles < sm_count) ? tiles : sm_count; + + dim3 dimBlock(256); + dim3 dimCluster(size<0>(cga_shape), size<1>(cga_shape), size<2>(cga_shape)); + dim3 dimGrid(tiles, 1, 1); + + int smem_size = sizeof(SharedStorage); + auto* kernel_ptr = &group_rht_gemm_device< + decltype(M), decltype(N), decltype(k_tile_size), decltype(cga_tile_shape), + TA, decltype(dA), decltype(sA), decltype(tma_load_a), + TB, decltype(dB), decltype(sB), decltype(tma_load_b), + TC, Stride2D, decltype(sC), + TSFC, + decltype(mma), + kNumTensorsPow2, + kEnableStochasticRounding>; + + bool status = cudaFuncSetAttribute(*kernel_ptr, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (status != cudaSuccess) { + std::cerr << "Error: Failed to set Shared Memory size." << std::endl; + return; + } + (*kernel_ptr) + <<< dimGrid, dimBlock, smem_size, stream >>> + (M, N, k_tile_size, cga_tile_shape, + A, dA, sA, tma_load_a, + B, dB, sB, tma_load_b, + sC, mma, + *kernel_args_ptr, + rng_state); +} + +// this function is used to wrap the group_rht_gemm_ntt_w_sfc function +// to transpose the input tensor A +template +void +group_rht_gemm_ttt_wrapper(int m, int n, + TA const* A, + TB const* B, + MultiAmaxHadamardCastFusionArgs* kernel_args_ptr, + const size_t* rng_state, + uint32_t sm_count, + cudaStream_t stream, + int k_tile_size = 1024) +{ + // in addition to transpose the input tensor A + // we also need to reshape m, n to at best + // ultilize as many SMs as possible while keeping + // a relatively large contiguous dimension. + // for example, after swapping m, n for transpose purposes, + // the input / output tensor shapes for RHT-GEMM are: + // A: n x m: col-major + // B: 16 x 16: row-major + // C: n x m: row-major + // SFC: n x (m/16): row-major + group_rht_gemm_ntt_w_sfc( + n, m, + A, B, + kernel_args_ptr, + rng_state, + sm_count, stream, + k_tile_size); +} + +} // namespace +} // namespace detail + +// clang-format on + +void group_hadamard_transform_cast_fusion_columnwise( + const Tensor &input_, std::vector &output_list, const size_t *split_sections, + size_t num_tensors, const Tensor &hadamard_matrix_, QuantizationConfig &quant_config, + cudaStream_t stream) { + NVTE_API_CALL(group_hadamard_transform_cast_fusion_columnwise); + + using transformer_engine::detail::kMaxTensorsPerKernel; + using transformer_engine::detail::MultiAmaxHadamardCastFusionArgs; + + NVTE_CHECK(input_.dtype() == transformer_engine::DType::kBFloat16, + "Input tensor must be BF16 tensor, but dtype is ", to_string(input_.dtype()), "."); + NVTE_CHECK(input_.dim() >= 2, "Input must be a 2D tensor."); + const SimpleTensor &input = input_.data; + + NVTE_CHECK(output_list.size() == num_tensors, + "Number of output tensors should match number of tensors."); + + NVTE_CHECK(num_tensors <= kMaxTensorsPerKernel, + "Number of tensors should be less than or equal to ", kMaxTensorsPerKernel); + + // construct the multi-tensor args + MultiAmaxHadamardCastFusionArgs kernel_args; + kernel_args.num_tensors = 0; + kernel_args.split_sections_range[0] = 0; + for (size_t i = 0; i < num_tensors; ++i) { + NVTE_CHECK(split_sections[i] % 64 == 0, "component ", i, + " of split_sections should be 64 multiple"); + if (split_sections[i] == 0) { + continue; + } + kernel_args.global_amax_list[kernel_args.num_tensors] = + reinterpret_cast(output_list[i]->amax.dptr); + // TODO(zhongbo): should we change API assumption to use columnwise_data instead of data? + kernel_args.output_colwise_list[kernel_args.num_tensors] = + reinterpret_cast(output_list[i]->data.dptr); + kernel_args.output_colwise_scale_inv_list[kernel_args.num_tensors] = + reinterpret_cast(output_list[i]->scale_inv.dptr); + kernel_args.split_sections[kernel_args.num_tensors] = split_sections[i]; + kernel_args.split_sections_range[kernel_args.num_tensors + 1] = + kernel_args.split_sections_range[kernel_args.num_tensors] + split_sections[i]; + kernel_args.num_tensors++; + } + + // this value already excludes the zero split sections + int num_tensors_to_process = kernel_args.num_tensors; + + // Stochastic rounding config + const bool use_stochastic_rounding = quant_config.stochastic_rounding; + const size_t *rng_state = nullptr; + if (quant_config.rng_state != nullptr) { + Tensor &rng_state_tensor = *convertNVTETensor(quant_config.rng_state); + NVTE_CHECK(rng_state_tensor.dtype() == DType::kInt64, + "RNG state should contain 2 64-bit values."); + NVTE_CHECK(rng_state_tensor.data.shape == std::vector{2}, + "Shape of the RNG state should be [2], but got ", rng_state_tensor.data.shape); + rng_state = reinterpret_cast(rng_state_tensor.data.dptr); + } + + // Template arguments + using TA = cute::bfloat16_t; + using TB = cute::bfloat16_t; + using TC = cutlass::float_e2m1_t; + using TSFC = cutlass::float_ue4m3_t; + + checkCuDriverContext(stream); + + // Check Hadamard matrix + constexpr int kHadamardDimension = 16; + + NVTE_CHECK(hadamard_matrix_.dtype() == transformer_engine::DType::kBFloat16, + "Hadamard matrix must be BF16 tensor, but dtype is ", + to_string(hadamard_matrix_.dtype()), "."); + const SimpleTensor &hadamard_matrix = hadamard_matrix_.data; + NVTE_CHECK( + (hadamard_matrix_.shape() == std::vector{kHadamardDimension, kHadamardDimension}), + "Hadamard matrix must have shape=", + std::vector{kHadamardDimension, kHadamardDimension}, + ", but got shape=", hadamard_matrix_.shape(), "."); + const size_t hadamard_dimension = hadamard_matrix.shape[0]; + + const size_t ndim = input.shape.size(); + const size_t n = input.shape[ndim - 1]; + size_t m = 1; + for (size_t i = 0; i < ndim - 1; ++i) { + m *= input.shape[i]; + } + + auto sm_count = transformer_engine::cuda::sm_count(); + + NVTE_CHECK(n % hadamard_dimension == 0, "row_length must be divisible by hadamard_dimension."); + + NVTE_CHECK(m % hadamard_dimension == 0, "num_rows must be divisible by hadamard_dimension"); + + int k_tile_size = 1024; + + if (m == 8192 && n == 5120) { + k_tile_size = 512; + } else if (m == 8192 && n == 10240) { + k_tile_size = 1024; + } else if (m == 8192 && n == 2560) { + k_tile_size = 1280; + } else if (m == 8192 && n == 11328) { + k_tile_size = 1024; + } else if (m == 8192 && n == 512) { + k_tile_size = 256; + } else if (m == 8192 && n == 3584) { + k_tile_size = 512; + } else if (m == 11328 && n == 8192) { + k_tile_size = 1024; + } else if (m == 5120 && n == 8192) { + k_tile_size = 512; + } else if (m == 10240 && n == 8192) { + k_tile_size = 1024; + } else if (m == 2560 && n == 8192) { + k_tile_size = 1280; + } else if (m == 512 && n == 8192) { + k_tile_size = 256; + } else if (m == 3584 && n == 8192) { + k_tile_size = 512; + } else if (m < 1024 || n < 1024) { + k_tile_size = 512; + } + + // Find the next larger power of 2 for num_tensors_to_process + int pow2_num_tensors_to_process = 1; + while (pow2_num_tensors_to_process < static_cast(num_tensors_to_process)) { + pow2_num_tensors_to_process <<= 1; + } + + switch (pow2_num_tensors_to_process) { +#define CALL_WRAPPER(kNumTensorsPow2) \ + case kNumTensorsPow2: \ + TRANSFORMER_ENGINE_SWITCH_CONDITION( \ + use_stochastic_rounding, kUseStochasticRounding, \ + detail::group_rht_gemm_ttt_wrapper( \ + /*m=*/m, /*n=*/n, /*A=*/reinterpret_cast(input.dptr), \ + /*B=*/reinterpret_cast(hadamard_matrix.dptr), \ + /*kernel_args_ptr=*/&kernel_args, /*rng_state=*/rng_state, /*sm_count=*/sm_count, \ + /*stream=*/stream, /*k_tile_size=*/k_tile_size);); \ + break; + + CALL_WRAPPER(1) + CALL_WRAPPER(2) + CALL_WRAPPER(4) + CALL_WRAPPER(8) + CALL_WRAPPER(16) + CALL_WRAPPER(32) + CALL_WRAPPER(64) +#undef CALL_WRAPPER + default: + NVTE_CHECK(false, + "num_tensors_to_process unsupported value, add to the CALL_WRAPPER macro for your " + "workload: ", + num_tensors_to_process); + break; + } +} + +} // namespace transformer_engine + +void nvte_group_hadamard_transform_cast_fusion_columnwise( + const NVTETensor input, NVTETensor *outputs, const NVTETensor hadamard_matrix, + const size_t *split_sections, const size_t num_tensors, + const NVTEQuantizationConfig quant_config, cudaStream_t stream) { + NVTE_API_CALL(nvte_multi_hadamard_transform_cast_fusion_columnwise); + using namespace transformer_engine; + NVTE_CHECK(num_tensors > 0, "Number of tensors should be greater than 0."); + + Tensor *input_tensor = convertNVTETensorCheck(input); + std::vector output_list(num_tensors); + for (size_t i = 0; i < num_tensors; ++i) { + output_list[i] = convertNVTETensorCheck(outputs[i]); + } + + QuantizationConfig quant_config_cpp; + if (quant_config != nullptr) { + quant_config_cpp = *reinterpret_cast(quant_config); + } + + // Call the multi-tensor Hadamard transform amax implementation. + group_hadamard_transform_cast_fusion_columnwise( + *input_tensor, output_list, split_sections, num_tensors, + *convertNVTETensorCheck(hadamard_matrix), quant_config_cpp, stream); +} diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu index 12f02dba6b..b77ed9ea41 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu @@ -29,7 +29,6 @@ #include "cutlass/pipeline/pipeline.hpp" #include "cutlass/util/GPU_Clock.hpp" #include "cutlass/util/command_line.h" -#include "cutlass/util/helper_cuda.hpp" #include "cutlass/util/print_error.hpp" // clang-format off diff --git a/transformer_engine/common/include/transformer_engine/cast.h b/transformer_engine/common/include/transformer_engine/cast.h index a3235e84f1..19fbe431aa 100644 --- a/transformer_engine/common/include/transformer_engine/cast.h +++ b/transformer_engine/common/include/transformer_engine/cast.h @@ -270,6 +270,20 @@ void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs, const NVTEQuantizationConfig quant_config, const size_t num_tensors, cudaStream_t stream); +/*! \brief Casts grouped input tensor to quantized output tensors. + * + * \param[in] input Input tensor to be cast. + * \param[in,out] outputs Output quantized tensors. + * \param[in] split_sections Split sections of the input tensor. + * \param[in] num_tensors Number of output tensors. + * \param[in] quant_config (Optional) Quantization configurations. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_nvfp4_quantize_with_amax(const NVTETensor input, NVTETensor *outputs, + const size_t *split_sections, size_t num_tensors, + const NVTEQuantizationConfig quant_config, + cudaStream_t stream); + #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/include/transformer_engine/hadamard_transform.h b/transformer_engine/common/include/transformer_engine/hadamard_transform.h index 05541fe30c..e7cece03ae 100644 --- a/transformer_engine/common/include/transformer_engine/hadamard_transform.h +++ b/transformer_engine/common/include/transformer_engine/hadamard_transform.h @@ -86,6 +86,24 @@ void nvte_group_hadamard_transform_amax(const NVTETensor input, NVTETensor* outp int random_sign_mask, int random_sign_mask_t, cudaStream_t stream); +/*! + * \brief Perform the grouped-tensor columnwise Hadamard transform cast fusion operation. + * + * This function is experimental and the API is not stable. Group_ prefix means contiguous input concatenated + * + * \param[in] input Input tensor to apply Hadamard transform. + * \param[in,out] outputs Array of output tensors. + * \param[in] hadamard_matrix Hadamard matrix to use for transformation. + * \param[in] split_sections Array specifying splits in dimension 0 for each output tensor. + * \param[in] num_tensors Number of output tensors, must be > 0. + * \param[in] quant_config Quantization configuration. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_hadamard_transform_cast_fusion_columnwise( + const NVTETensor input, NVTETensor* outputs, const NVTETensor hadamard_matrix, + const size_t* split_sections, size_t num_tensors, const NVTEQuantizationConfig quant_config, + cudaStream_t stream); + #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index ac541435c7..9a4a415e2c 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -839,67 +839,82 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input, // Quantize tensors individually NVTE_SCOPED_GIL_RELEASE({ - for (size_t i = 0; i < num_tensors; i++) { - if (input_list[i].numel() == 0) { - continue; // Skip tensors with no elements - } - - // Direct NVFP4 quantization for row-wise data - if (quantizer.rowwise_usage) { - auto out_rowwise_data = output_list[i].get_rowwise_data(); - auto out_rowwise_scale_inv = output_list[i].get_rowwise_scale_inv(); - auto out_rowwise_amax = output_list[i].get_amax(); - TensorWrapper out_rowwise(output_list[i].scaling_mode()); - out_rowwise.set_rowwise_data(out_rowwise_data.data_ptr, - static_cast(out_rowwise_data.dtype), - out_rowwise_data.shape); - out_rowwise.set_rowwise_scale_inv(out_rowwise_scale_inv.data_ptr, - static_cast(out_rowwise_scale_inv.dtype), - out_rowwise_scale_inv.shape); - out_rowwise.set_amax(out_rowwise_amax.data_ptr, - static_cast(out_rowwise_amax.dtype), out_rowwise_amax.shape); - nvte_quantize_v2(input_list[i].data(), out_rowwise.data(), quant_config_list[i], stream); + // Fuse the rowwise and colwise into one when the kernel is ready + // rowwise quantization fusion with grouped version + if (quantizer.rowwise_usage) { + std::vector out_identity_list; + std::vector nvte_tensor_out_identity_list; + for (size_t i = 0; i < num_tensors; i++) { + // skip this round if input is empty + bool is_empty_split = input_list[i].numel() == 0; + TensorWrapper out_identity(output_list[i].scaling_mode()); + auto out_identity_data = output_list[i].get_rowwise_data(); + auto out_identity_scale_inv = output_list[i].get_rowwise_scale_inv(); + auto out_identity_amax = output_list[i].get_amax(); + if (!is_empty_split) { + out_identity.set_rowwise_data(out_identity_data.data_ptr, + static_cast(out_identity_data.dtype), + out_identity_data.shape); + out_identity.set_rowwise_scale_inv(out_identity_scale_inv.data_ptr, + static_cast(out_identity_scale_inv.dtype), + out_identity_scale_inv.shape); + out_identity.set_amax(out_identity_amax.data_ptr, + static_cast(out_identity_amax.dtype), + out_identity_amax.shape); + } + out_identity_list.emplace_back(std::move(out_identity)); + nvte_tensor_out_identity_list.push_back(out_identity_list.back().data()); } - - // RHT + NVFP4 quantize for column-wise data - if (quantizer.columnwise_usage) { - // Get the output column-wise data, scale_inv, and amax + nvte_group_nvfp4_quantize_with_amax(input.data(), nvte_tensor_out_identity_list.data(), + split_sections.data(), num_tensors, + quant_config_list[0], stream); + } + // columnwise RHT quantization fusion with grouped version + if (quantizer.columnwise_usage) { + // setup the output list for the grouped kernel + std::vector out_transpose_list; + std::vector nvte_tensor_out_transpose_list; + // TODO(zhongbo): can we make this less verbose? + for (size_t i = 0; i < num_tensors; i++) { + // group kernel expects the output list to have the same length with split_sections + // so we still need to pass a place holder tensor for empty splits + bool is_empty_split = input_list[i].numel() == 0; auto out_columnwise_data = output_list[i].get_columnwise_data(); auto out_columnwise_scale_inv = output_list[i].get_columnwise_scale_inv(); auto out_columnwise_amax = output_list[i].get_columnwise_amax(); - // Flatten column-wise data to 2D - auto colwise_data_shape = out_columnwise_data.shape; - std::vector colwise_data_shape_2d; - colwise_data_shape_2d.push_back(colwise_data_shape.data[0]); - size_t last_dim = 1; - for (size_t i = 1; i < colwise_data_shape.ndim; ++i) { - last_dim *= colwise_data_shape.data[i]; - } - colwise_data_shape_2d.push_back(last_dim); - // Create a wrapper for the columnwise output, as the rowwise output. // The reason is due to the input `rht_output_t` is already in the transposed layout. // Thus, we only need a rowwise quantization to generate the columnwise output. TensorWrapper out_transpose(output_list[i].scaling_mode()); - out_transpose.set_rowwise_data(out_columnwise_data.data_ptr, - static_cast(out_columnwise_data.dtype), - colwise_data_shape_2d); - out_transpose.set_rowwise_scale_inv(out_columnwise_scale_inv.data_ptr, - static_cast(out_columnwise_scale_inv.dtype), - out_columnwise_scale_inv.shape); - out_transpose.set_amax(out_columnwise_amax.data_ptr, - static_cast(out_columnwise_amax.dtype), - out_columnwise_amax.shape); - - // RHT + NVFP4 quantize kernel - // Use separate RNG state for columnwise to ensure different random numbers than rowwise - auto &columnwise_quant_config = - need_separate_columnwise_rng ? quant_config_columnwise_list[i] : quant_config_list[i]; - nvte_hadamard_transform_cast_fusion_columnwise(input_list[i].data(), out_transpose.data(), - rht_matrix_nvte.data(), - columnwise_quant_config, stream); + if (!is_empty_split) { + auto colwise_data_shape = out_columnwise_data.shape; + std::vector colwise_data_shape_2d; + colwise_data_shape_2d.push_back(colwise_data_shape.data[0]); + size_t last_dim = 1; + for (size_t i = 1; i < colwise_data_shape.ndim; ++i) { + last_dim *= colwise_data_shape.data[i]; + } + colwise_data_shape_2d.push_back(last_dim); + + out_transpose.set_rowwise_data(out_columnwise_data.data_ptr, + static_cast(out_columnwise_data.dtype), + colwise_data_shape_2d); + out_transpose.set_rowwise_scale_inv(out_columnwise_scale_inv.data_ptr, + static_cast(out_columnwise_scale_inv.dtype), + out_columnwise_scale_inv.shape); + out_transpose.set_amax(out_columnwise_amax.data_ptr, + static_cast(out_columnwise_amax.dtype), + out_columnwise_amax.shape); + } + out_transpose_list.emplace_back(std::move(out_transpose)); + nvte_tensor_out_transpose_list.push_back(out_transpose_list.back().data()); } + // call the grouped kernel + nvte_group_hadamard_transform_cast_fusion_columnwise( + input.data(), reinterpret_cast(nvte_tensor_out_transpose_list.data()), + rht_matrix_nvte.data(), split_sections.data(), num_tensors, quant_config_list[0], + stream); } }); From b345534c8127039cd3ee5350308a6c5bbd71d33a Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Mon, 1 Dec 2025 21:43:31 -0800 Subject: [PATCH 02/36] remove local array RW Signed-off-by: Zhongbo Zhu --- .../group_hadamard_transform_cast_fusion.cu | 219 ++++++++---------- 1 file changed, 95 insertions(+), 124 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu index 2eb7e1c72f..5c399657b9 100644 --- a/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu @@ -166,7 +166,6 @@ template __global__ static void @@ -210,7 +209,6 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til // Represent the full tensors Tensor mA = tma_load_a.get_tma_tensor(make_shape(M,N)); Tensor mB = tma_load_b.get_tma_tensor(make_shape(16,16)); - // Tensor mC = make_tensor(cute::subbyte_iterator(C), make_shape(M,N), dC); // (M,N) using TensorC = decltype( make_tensor( @@ -220,22 +218,6 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til ) ); - // make an array of mCs with capacity kNumTensorsPow2, - // but kernel_args.num_tensors is its real size - TensorC mCs[kNumTensorsPow2]; - - for (size_t i = 0; i < kernel_args.num_tensors; ++i) { - auto* output_C_i = reinterpret_cast(kernel_args.output_colwise_list[i]); - int output_C_n_i = kernel_args.split_sections[i]; - Stride2D stride2d_C_i = kernel_args.output_stride2d_list[i]; - - mCs[i] = cute::make_tensor( - cute::subbyte_iterator(output_C_i), - cute::make_shape(static_cast(M), output_C_n_i), // (M, N_i) - stride2d_C_i - ); - } - using TensorSFC = decltype( make_tensor( make_gmem_ptr((TSFC*)nullptr), @@ -254,40 +236,6 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til ) ); - // Array of SFC tensors, capacity = kNumTensorsPow2, but we only use - // [0 .. kernel_args.num_tensors) in practice. - TensorSFC mSFCs[kNumTensorsPow2]; - - for (size_t i = 0; i < kernel_args.num_tensors; ++i) { - // Per-tensor split length along the original N dimension - int split_N_i = kernel_args.split_sections[i]; // must be multiple of 64 - int n_tiles_i = split_N_i / 64; // # of 64-wide tiles in N - int sfc_row_stride_i = split_N_i / 16; // # SFC elements per row - - // Base pointer for this tensor’s SFC buffer - auto* SFC_i = reinterpret_cast( - kernel_args.output_colwise_scale_inv_list[i]); - - // Shape and stride for this tensor’s SFC - auto sfc_shape_i = make_shape( - M, - make_shape( make_shape(Int<16>{}, _4{}), n_tiles_i ) - ); - - auto sfc_stride_i = make_stride( - sfc_row_stride_i, - make_stride( make_stride(_0{}, _1{}), _4{} ) - ); - - auto sfc_layout_i = make_layout(sfc_shape_i, sfc_stride_i); - - // Final tensor for this SFC slice - mSFCs[i] = make_tensor( - make_gmem_ptr(SFC_i), - sfc_layout_i - ); - } - auto cluster_shape = Shape< _1, _1, _1>{}; // Get the appropriate blocks for this Cluster @@ -316,18 +264,6 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til ) ); - TensorGC gCs_mn[kNumTensorsPow2]; - - for (size_t i = 0; i < kernel_args.num_tensors; ++i) { - gCs_mn[i] = local_tile( - mCs[i], - epilogue_tiler, - make_coord(_, _, _), - Step<_1, _1, X>{} // (BLK_M, BLK_N) - ); - } - - // Tensor gSFC_mn = local_tile(mSFC, epilogue_tiler, make_coord(_,_, _), Step<_1,_1, X>{}); // (BLK_M,BLK_N) using TensorGSFC = decltype( local_tile( std::declval(), @@ -337,18 +273,6 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til ) ); - // One tiled SFC view per split - TensorGSFC gSFCs_mn[kNumTensorsPow2]; - - for (size_t i = 0; i < kernel_args.num_tensors; ++i) { - gSFCs_mn[i] = local_tile( - mSFCs[i], - epilogue_tiler, - make_coord(_, _, _), - Step<_1, _1, X>{} // (BLK_M, BLK_N-like) - ); - } - // Allocate SMEM extern __shared__ char shared_memory[]; using SharedStorage = SharedStorage; @@ -557,7 +481,6 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til bulk_tmem_epilogue.data() = tmem_base_ptr; int thread_idx = threadIdx.x % 128; - // Tensor tCgC = thr_mma_epilogue.partition_C(gC_mn); // (MMA,MMA_M,MMA_N) auto tiled_t2r = make_tmem_copy(TMEM_LOAD_NEW{}, bulk_tmem_epilogue(_,_,_,_0{})); auto tiled_r2g = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); auto thr_t2r = tiled_t2r.get_slice(thread_idx); @@ -567,11 +490,50 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til static constexpr float fp4_max = 6.0f; // get global amax pointer - // float* global_amax_ptr = GetGlobalAmaxPtr(&kernel_args, tile_idx_n * 64); int tensor_id = GetTensorId(&kernel_args, tile_idx_n * 64); float* global_amax_ptr = GetGlobalAmaxPtrByTensorId(&kernel_args, tensor_id); - Tensor tCgC = thr_mma_epilogue.partition_C(gCs_mn[tensor_id]); - Tensor gSFC_mn = gSFCs_mn[tensor_id]; + + TC* cur_output_colwise_ptr = reinterpret_cast(kernel_args.output_colwise_list[tensor_id]); + TSFC* cur_output_colwise_scale_inv_ptr = reinterpret_cast(kernel_args.output_colwise_scale_inv_list[tensor_id]); + int cur_output_colwise_n = kernel_args.split_sections[tensor_id]; + + TensorC cur_mC = cute::make_tensor( + cute::subbyte_iterator(cur_output_colwise_ptr), + cute::make_shape(static_cast(M), cur_output_colwise_n), // (M, N_i) + kernel_args.output_stride2d_list[tensor_id] + ); + + auto cur_sfc_shape = make_shape( + M, + make_shape( make_shape(Int<16>{}, _4{}), cur_output_colwise_n / 64) + ); + + auto cur_sfc_stride = make_stride( + cur_output_colwise_n / 16, + make_stride( make_stride(_0{}, _1{}), _4{} ) + ); + + + TensorSFC cur_mSFC = cute::make_tensor( + make_gmem_ptr(cur_output_colwise_scale_inv_ptr), + make_layout(cur_sfc_shape, cur_sfc_stride) + ); + + TensorGC cur_gC_mn = local_tile( + cur_mC, + epilogue_tiler, + make_coord(_, _, _), + Step<_1, _1, X>{} // (BLK_M, BLK_N) + ); + + TensorGSFC cur_gSFC_mn = local_tile( + cur_mSFC, + epilogue_tiler, + make_coord(_, _, _), + Step<_1, _1, X>{} // (BLK_M, BLK_N-like) + ); + + Tensor tCgC = thr_mma_epilogue.partition_C(cur_gC_mn); float global_amax_val = *global_amax_ptr; float global_encode_scale = ComputeGlobalEncodeScaleFP4(global_amax_val); @@ -593,16 +555,56 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til global_amax_val = *global_amax_ptr; global_encode_scale = ComputeGlobalEncodeScaleFP4(global_amax_val); global_decode_scale = 1.0f / global_encode_scale; - tCgC = thr_mma_epilogue.partition_C(gCs_mn[new_tensor_id]); - gSFC_mn = gSFCs_mn[new_tensor_id]; tensor_id = new_tensor_id; + // went through the cute operations to update the local tensors + cur_output_colwise_ptr = reinterpret_cast(kernel_args.output_colwise_list[tensor_id]); + cur_output_colwise_scale_inv_ptr = reinterpret_cast(kernel_args.output_colwise_scale_inv_list[tensor_id]); + cur_output_colwise_n = kernel_args.split_sections[tensor_id]; + + cur_mC = cute::make_tensor( + cute::subbyte_iterator(cur_output_colwise_ptr), + cute::make_shape(static_cast(M), cur_output_colwise_n), // (M, N_i) + kernel_args.output_stride2d_list[tensor_id] + ); + + cur_sfc_shape = make_shape( + M, + make_shape( make_shape(Int<16>{}, _4{}), cur_output_colwise_n / 64) + ); + + cur_sfc_stride = make_stride( + cur_output_colwise_n / 16, + make_stride( make_stride(_0{}, _1{}), _4{} ) + ); + + + cur_mSFC = cute::make_tensor( + make_gmem_ptr(cur_output_colwise_scale_inv_ptr), + make_layout(cur_sfc_shape, cur_sfc_stride) + ); + + cur_gC_mn = local_tile( + cur_mC, + epilogue_tiler, + make_coord(_, _, _), + Step<_1, _1, X>{} // (BLK_M, BLK_N) + ); + + cur_gSFC_mn = local_tile( + cur_mSFC, + epilogue_tiler, + make_coord(_, _, _), + Step<_1, _1, X>{} // (BLK_M, BLK_N-like) + ); + + tCgC = thr_mma_epilogue.partition_C(cur_gC_mn); } // maybe udpated to the new tensor id int tensor_start_elem = kernel_args.split_sections_range[tensor_id]; int local_tile_idx_n = (cur_k_tile_global_elem_idx - tensor_start_elem) / 64; Tensor tCgC_mn = tCgC(_,_,_,tile_idx_m,local_tile_idx_n); - Tensor tCgSFC_mn = gSFC_mn(_,_,tile_idx_m,local_tile_idx_n); + Tensor tCgSFC_mn = cur_gSFC_mn(_,_,tile_idx_m,local_tile_idx_n); accumulator_pipeline.consumer_wait(accumulator_pipe_consumer_state); @@ -705,7 +707,7 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til // B: 16 x 16: row-major // C: m x n: row-major // SFC: m x (n/16): row-major -template +template void group_rht_gemm_ntt_w_sfc(int m, int n, TA const* A, @@ -818,7 +820,6 @@ group_rht_gemm_ntt_w_sfc(int m, int n, TC, Stride2D, decltype(sC), TSFC, decltype(mma), - kNumTensorsPow2, kEnableStochasticRounding>; bool status = cudaFuncSetAttribute(*kernel_ptr, @@ -841,7 +842,7 @@ group_rht_gemm_ntt_w_sfc(int m, int n, // this function is used to wrap the group_rht_gemm_ntt_w_sfc function // to transpose the input tensor A -template +template void group_rht_gemm_ttt_wrapper(int m, int n, TA const* A, @@ -862,7 +863,7 @@ group_rht_gemm_ttt_wrapper(int m, int n, // B: 16 x 16: row-major // C: n x m: row-major // SFC: n x (m/16): row-major - group_rht_gemm_ntt_w_sfc( + group_rht_gemm_ntt_w_sfc( n, m, A, B, kernel_args_ptr, @@ -919,9 +920,6 @@ void group_hadamard_transform_cast_fusion_columnwise( kernel_args.num_tensors++; } - // this value already excludes the zero split sections - int num_tensors_to_process = kernel_args.num_tensors; - // Stochastic rounding config const bool use_stochastic_rounding = quant_config.stochastic_rounding; const size_t *rng_state = nullptr; @@ -999,40 +997,13 @@ void group_hadamard_transform_cast_fusion_columnwise( k_tile_size = 512; } - // Find the next larger power of 2 for num_tensors_to_process - int pow2_num_tensors_to_process = 1; - while (pow2_num_tensors_to_process < static_cast(num_tensors_to_process)) { - pow2_num_tensors_to_process <<= 1; - } - - switch (pow2_num_tensors_to_process) { -#define CALL_WRAPPER(kNumTensorsPow2) \ - case kNumTensorsPow2: \ - TRANSFORMER_ENGINE_SWITCH_CONDITION( \ - use_stochastic_rounding, kUseStochasticRounding, \ - detail::group_rht_gemm_ttt_wrapper( \ - /*m=*/m, /*n=*/n, /*A=*/reinterpret_cast(input.dptr), \ - /*B=*/reinterpret_cast(hadamard_matrix.dptr), \ - /*kernel_args_ptr=*/&kernel_args, /*rng_state=*/rng_state, /*sm_count=*/sm_count, \ - /*stream=*/stream, /*k_tile_size=*/k_tile_size);); \ - break; - - CALL_WRAPPER(1) - CALL_WRAPPER(2) - CALL_WRAPPER(4) - CALL_WRAPPER(8) - CALL_WRAPPER(16) - CALL_WRAPPER(32) - CALL_WRAPPER(64) -#undef CALL_WRAPPER - default: - NVTE_CHECK(false, - "num_tensors_to_process unsupported value, add to the CALL_WRAPPER macro for your " - "workload: ", - num_tensors_to_process); - break; - } + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_stochastic_rounding, kUseStochasticRounding, + detail::group_rht_gemm_ttt_wrapper( + /*m=*/m, /*n=*/n, /*A=*/reinterpret_cast(input.dptr), + /*B=*/reinterpret_cast(hadamard_matrix.dptr), + /*kernel_args_ptr=*/&kernel_args, /*rng_state=*/rng_state, /*sm_count=*/sm_count, + /*stream=*/stream, /*k_tile_size=*/k_tile_size);); } } // namespace transformer_engine From 2eb23b392acee4c1072d46d8650b0062b5282522 Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Mon, 1 Dec 2025 23:47:23 -0800 Subject: [PATCH 03/36] change wait_barrier Signed-off-by: Zhongbo Zhu --- .../group_hadamard_transform_cast_fusion.cu | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu index 5c399657b9..f5324a7773 100644 --- a/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu @@ -397,7 +397,7 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til cute::set_barrier_transaction_bytes(shared_storage.tma_barrier[0], kTmaRhtTensorTransactionBytes); copy(tma_load_b.with(shared_storage.tma_barrier[0], tma_mcast_mask_b), tBgB(_,0,0), tBsB(_,0)); } - cute::wait_barrier(shared_storage.tma_barrier[0], 0 /*tma_phase_bit*/); + do { bool is_first_wave = linear_tile_idx == blockIdx.x; uint32_t skip_wait = is_first_wave; @@ -435,6 +435,7 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; bulk_tmem_mma.data() = tmem_base_ptr; + cute::wait_barrier(shared_storage.tma_barrier[0], 0 /*tma_phase_bit*/); do { uint32_t skip_wait = K_TILE_MAX <= 0; auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); @@ -488,6 +489,7 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til // NVFP4 non-E8 recipe constants and global scales static constexpr float fp4_max = 6.0f; + // static constexpr float fp4_max_inv = 1.0f / fp4_max; // get global amax pointer int tensor_id = GetTensorId(&kernel_args, tile_idx_n * 64); @@ -537,6 +539,7 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til float global_amax_val = *global_amax_ptr; float global_encode_scale = ComputeGlobalEncodeScaleFP4(global_amax_val); + // float global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; float global_decode_scale = 1.0f / global_encode_scale; auto sfd_converter = cutlass::NumericConverter{}; @@ -554,6 +557,7 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til if (tensor_id != new_tensor_id) { global_amax_val = *global_amax_ptr; global_encode_scale = ComputeGlobalEncodeScaleFP4(global_amax_val); + // global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; global_decode_scale = 1.0f / global_encode_scale; tensor_id = new_tensor_id; // went through the cute operations to update the local tensors @@ -642,7 +646,7 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til ++accumulator_pipe_consumer_state; - // Cast data from FP32 to BF16 to FP32. + // TODO(zhongbo): Maybe remove it for better perf. Cast data from FP32 to BF16 to FP32. auto convert_accum_to_bf16 = cutlass::NumericArrayConverter{}; auto convert_bf16_to_accum = cutlass::NumericArrayConverter{}; tTR_rAcc_frag(_0{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_0{}))); @@ -656,12 +660,14 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til pvscales = cutlass::divides>{}(vec_maxs, fp4_max); pvscales = cutlass::multiplies>{}(pvscales, global_encode_scale); + // pvscales = cutlass::multiplies>{}(vec_maxs, global_encode_scale_multiplier); auto pvscales_cvted = cutlass::NumericArrayConverter{}(pvscales); tC_rRowSFD_frg(_0{}) = pvscales_cvted; auto qpvscale_ups = cutlass::NumericArrayConverter{}(tC_rRowSFD_frg(_0{})); auto qpvscale_scaled = cutlass::multiplies>{}(qpvscale_ups, global_decode_scale); auto acc_scales = cutlass::divides>{}(1.0, qpvscale_scaled); + // auto acc_scales = cutlass::reciprocal_approximate_ftz{}(qpvscale_scaled); // Initialize RNG for tile const size_t rng_sequence From 004e5299b78a3599647ca86f54f260730e3072ca Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Tue, 2 Dec 2025 17:25:00 -0800 Subject: [PATCH 04/36] fast math options Signed-off-by: Zhongbo Zhu --- .../group_hadamard_transform_cast_fusion.cu | 9 ++++++++- .../hadamard_transform_cast_fusion.cu | 12 +++++++++++- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu index f5324a7773..5943435539 100644 --- a/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu @@ -489,6 +489,7 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til // NVFP4 non-E8 recipe constants and global scales static constexpr float fp4_max = 6.0f; + // (optional) path for faster math, use multiply to repalce div // static constexpr float fp4_max_inv = 1.0f / fp4_max; // get global amax pointer @@ -539,6 +540,7 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til float global_amax_val = *global_amax_ptr; float global_encode_scale = ComputeGlobalEncodeScaleFP4(global_amax_val); + // (optional) path for faster math, use multiply to repalce div // float global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; float global_decode_scale = 1.0f / global_encode_scale; @@ -557,6 +559,7 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til if (tensor_id != new_tensor_id) { global_amax_val = *global_amax_ptr; global_encode_scale = ComputeGlobalEncodeScaleFP4(global_amax_val); + // (optional) path for faster math, use multiply to repalce div // global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; global_decode_scale = 1.0f / global_encode_scale; tensor_id = new_tensor_id; @@ -646,7 +649,7 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til ++accumulator_pipe_consumer_state; - // TODO(zhongbo): Maybe remove it for better perf. Cast data from FP32 to BF16 to FP32. + // TODO(zhongbo): (optional) Maybe remove it for better perf. Cast data from FP32 to BF16 to FP32. auto convert_accum_to_bf16 = cutlass::NumericArrayConverter{}; auto convert_bf16_to_accum = cutlass::NumericArrayConverter{}; tTR_rAcc_frag(_0{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_0{}))); @@ -658,15 +661,19 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til vec_maxs[v] = amax_reduction(ElementAccumulator(0), compute_frgs[v]); } + // regular path for slower math, use divide pvscales = cutlass::divides>{}(vec_maxs, fp4_max); pvscales = cutlass::multiplies>{}(pvscales, global_encode_scale); + // (optional) path for faster math, use multiply to repalce div // pvscales = cutlass::multiplies>{}(vec_maxs, global_encode_scale_multiplier); auto pvscales_cvted = cutlass::NumericArrayConverter{}(pvscales); tC_rRowSFD_frg(_0{}) = pvscales_cvted; auto qpvscale_ups = cutlass::NumericArrayConverter{}(tC_rRowSFD_frg(_0{})); auto qpvscale_scaled = cutlass::multiplies>{}(qpvscale_ups, global_decode_scale); + // regular path for slower math, use divide to repalce div auto acc_scales = cutlass::divides>{}(1.0, qpvscale_scaled); + // (optional) path for faster math, use fast math reciprocal approximate to repalce div // auto acc_scales = cutlass::reciprocal_approximate_ftz{}(qpvscale_scaled); // Initialize RNG for tile diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu index b77ed9ea41..904e41854f 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu @@ -422,8 +422,12 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile, // NVFP4 non-E8 recipe constants and global scales static constexpr float fp4_max = 6.0f; + // (optional) path for faster math, use multiply to repalce div + // static constexpr float fp4_max_inv = 1.0f / fp4_max; const float global_encode_scale = ComputeGlobalEncodeScaleFP4(global_amax_val); + // (optional) path for faster math, use multiply to repalce div + // const float global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; const float global_decode_scale = 1.0f / global_encode_scale; auto sfd_converter = cutlass::NumericConverter{}; @@ -468,7 +472,7 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile, ++accumulator_pipe_consumer_state; - // Cast data from FP32 to BF16 to FP32. + // TODO(zhongbo): (optional) Maybe remove it for better perf. Cast data from FP32 to BF16 to FP32. auto convert_accum_to_bf16 = cutlass::NumericArrayConverter{}; auto convert_bf16_to_accum = cutlass::NumericArrayConverter{}; tTR_rAcc_frag(_0{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_0{}))); @@ -480,14 +484,20 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile, vec_maxs[v] = amax_reduction(ElementAccumulator(0), compute_frgs[v]); } + // regular path for slower math, use divide pvscales = cutlass::divides>{}(vec_maxs, fp4_max); pvscales = cutlass::multiplies>{}(pvscales, global_encode_scale); + // (optional) path for faster math, use multiply to repalce div + // pvscales = cutlass::multiplies>{}(vec_maxs, global_encode_scale_multiplier); auto pvscales_cvted = cutlass::NumericArrayConverter{}(pvscales); tC_rRowSFD_frg(_0{}) = pvscales_cvted; auto qpvscale_ups = cutlass::NumericArrayConverter{}(tC_rRowSFD_frg(_0{})); auto qpvscale_scaled = cutlass::multiplies>{}(qpvscale_ups, global_decode_scale); + // regular path for slower math, use divide auto acc_scales = cutlass::divides>{}(1.0, qpvscale_scaled); + // (optional) path for faster math, use fast math reciprocal approximate to repalce div + // auto acc_scales = cutlass::reciprocal_approximate_ftz{}(qpvscale_scaled); // Initialize RNG for tile const size_t rng_sequence From d9a6c244ae9d7a4b6aff32a3f243a7ecfd1587fe Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Tue, 2 Dec 2025 21:28:52 -0800 Subject: [PATCH 05/36] use mult to replace div Signed-off-by: Zhongbo Zhu --- .../group_hadamard_transform_cast_fusion.cu | 5 +++-- .../hadamard_transform/hadamard_transform_cast_fusion.cu | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu index 5943435539..13e6d58a8c 100644 --- a/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu @@ -490,7 +490,7 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til // NVFP4 non-E8 recipe constants and global scales static constexpr float fp4_max = 6.0f; // (optional) path for faster math, use multiply to repalce div - // static constexpr float fp4_max_inv = 1.0f / fp4_max; + static constexpr float fp4_max_inv = 1.0f / fp4_max; // get global amax pointer int tensor_id = GetTensorId(&kernel_args, tile_idx_n * 64); @@ -662,7 +662,8 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til } // regular path for slower math, use divide - pvscales = cutlass::divides>{}(vec_maxs, fp4_max); + pvscales = cutlass::multiplies>{}(vec_maxs, fp4_max_inv); + // pvscales = cutlass::divides>{}(vec_maxs, fp4_max); pvscales = cutlass::multiplies>{}(pvscales, global_encode_scale); // (optional) path for faster math, use multiply to repalce div // pvscales = cutlass::multiplies>{}(vec_maxs, global_encode_scale_multiplier); diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu index 904e41854f..d8ae88f91d 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu @@ -423,7 +423,7 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile, // NVFP4 non-E8 recipe constants and global scales static constexpr float fp4_max = 6.0f; // (optional) path for faster math, use multiply to repalce div - // static constexpr float fp4_max_inv = 1.0f / fp4_max; + static constexpr float fp4_max_inv = 1.0f / fp4_max; const float global_encode_scale = ComputeGlobalEncodeScaleFP4(global_amax_val); // (optional) path for faster math, use multiply to repalce div @@ -485,7 +485,8 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile, } // regular path for slower math, use divide - pvscales = cutlass::divides>{}(vec_maxs, fp4_max); + pvscales = cutlass::multiplies>{}(vec_maxs, fp4_max_inv); + // pvscales = cutlass::divides>{}(vec_maxs, fp4_max); pvscales = cutlass::multiplies>{}(pvscales, global_encode_scale); // (optional) path for faster math, use multiply to repalce div // pvscales = cutlass::multiplies>{}(vec_maxs, global_encode_scale_multiplier); From 9b9efb8a9fe066a7697eb881e84a2617d53fd834 Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Tue, 2 Dec 2025 21:36:13 -0800 Subject: [PATCH 06/36] format Signed-off-by: Zhongbo Zhu --- .../group_hadamard_transform_cast_fusion.cu | 12 ++++++------ .../hadamard_transform_cast_fusion.cu | 8 ++++---- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu index 13e6d58a8c..a18d1bf341 100644 --- a/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu @@ -489,7 +489,7 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til // NVFP4 non-E8 recipe constants and global scales static constexpr float fp4_max = 6.0f; - // (optional) path for faster math, use multiply to repalce div + // (optional) path for faster math, use multiply to repalce div static constexpr float fp4_max_inv = 1.0f / fp4_max; // get global amax pointer @@ -540,7 +540,7 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til float global_amax_val = *global_amax_ptr; float global_encode_scale = ComputeGlobalEncodeScaleFP4(global_amax_val); - // (optional) path for faster math, use multiply to repalce div + // (optional) path for faster math, use multiply to repalce div // float global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; float global_decode_scale = 1.0f / global_encode_scale; @@ -559,7 +559,7 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til if (tensor_id != new_tensor_id) { global_amax_val = *global_amax_ptr; global_encode_scale = ComputeGlobalEncodeScaleFP4(global_amax_val); - // (optional) path for faster math, use multiply to repalce div + // (optional) path for faster math, use multiply to repalce div // global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; global_decode_scale = 1.0f / global_encode_scale; tensor_id = new_tensor_id; @@ -665,16 +665,16 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til pvscales = cutlass::multiplies>{}(vec_maxs, fp4_max_inv); // pvscales = cutlass::divides>{}(vec_maxs, fp4_max); pvscales = cutlass::multiplies>{}(pvscales, global_encode_scale); - // (optional) path for faster math, use multiply to repalce div + // (optional) path for faster math, use multiply to repalce div // pvscales = cutlass::multiplies>{}(vec_maxs, global_encode_scale_multiplier); auto pvscales_cvted = cutlass::NumericArrayConverter{}(pvscales); tC_rRowSFD_frg(_0{}) = pvscales_cvted; auto qpvscale_ups = cutlass::NumericArrayConverter{}(tC_rRowSFD_frg(_0{})); auto qpvscale_scaled = cutlass::multiplies>{}(qpvscale_ups, global_decode_scale); - // regular path for slower math, use divide to repalce div + // regular path for slower math, use divide to repalce div auto acc_scales = cutlass::divides>{}(1.0, qpvscale_scaled); - // (optional) path for faster math, use fast math reciprocal approximate to repalce div + // (optional) path for faster math, use fast math reciprocal approximate to repalce div // auto acc_scales = cutlass::reciprocal_approximate_ftz{}(qpvscale_scaled); // Initialize RNG for tile diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu index d8ae88f91d..ce77c5194d 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu @@ -422,11 +422,11 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile, // NVFP4 non-E8 recipe constants and global scales static constexpr float fp4_max = 6.0f; - // (optional) path for faster math, use multiply to repalce div + // (optional) path for faster math, use multiply to repalce div static constexpr float fp4_max_inv = 1.0f / fp4_max; const float global_encode_scale = ComputeGlobalEncodeScaleFP4(global_amax_val); - // (optional) path for faster math, use multiply to repalce div + // (optional) path for faster math, use multiply to repalce div // const float global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; const float global_decode_scale = 1.0f / global_encode_scale; auto sfd_converter = cutlass::NumericConverter{}; @@ -488,7 +488,7 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile, pvscales = cutlass::multiplies>{}(vec_maxs, fp4_max_inv); // pvscales = cutlass::divides>{}(vec_maxs, fp4_max); pvscales = cutlass::multiplies>{}(pvscales, global_encode_scale); - // (optional) path for faster math, use multiply to repalce div + // (optional) path for faster math, use multiply to repalce div // pvscales = cutlass::multiplies>{}(vec_maxs, global_encode_scale_multiplier); auto pvscales_cvted = cutlass::NumericArrayConverter{}(pvscales); @@ -497,7 +497,7 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile, auto qpvscale_scaled = cutlass::multiplies>{}(qpvscale_ups, global_decode_scale); // regular path for slower math, use divide auto acc_scales = cutlass::divides>{}(1.0, qpvscale_scaled); - // (optional) path for faster math, use fast math reciprocal approximate to repalce div + // (optional) path for faster math, use fast math reciprocal approximate to repalce div // auto acc_scales = cutlass::reciprocal_approximate_ftz{}(qpvscale_scaled); // Initialize RNG for tile From a9d0fc5a62e54455677708d50bcce6b4d61f25c8 Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Wed, 3 Dec 2025 11:25:18 -0800 Subject: [PATCH 07/36] bulk move random states Signed-off-by: Zhongbo Zhu --- .../pytorch/csrc/extensions/cast.cpp | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 9a4a415e2c..f4d00184b5 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -774,8 +774,9 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input, if (quantizer.stochastic_rounding) { // TODO(zhongbo): remove the for loop of generating rng states with a single call // with rng_elts_per_thread = 1024 * num_tensors - // Change to the bulk generate rng states api when grouped quantize is available - const size_t rng_elts_per_thread = 1024; // Wild guess, probably can be tightened + // RHT path is fully grouped kernels, which we can be optimized + bool with_bulk_generate_rng_states = quantizer.with_rht; + const size_t rng_elts_per_thread = with_bulk_generate_rng_states ? 1024 * num_tensors : 1024; auto opts = at::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA); rng_states_tensor = torch::empty({static_cast(2 * num_tensors)}, opts); @@ -797,17 +798,9 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input, static_cast(rng_state_ptr), std::vector{2}, DType::kInt64)); quant_config_list[i].set_rng_state(te_rng_state_list[i].data()); quant_config_list[i].set_stochastic_rounding(true); - - // Generate separate RNG state for columnwise quantization - if (need_separate_columnwise_rng) { - at::PhiloxCudaState philox_args_columnwise = init_philox_state(gen, rng_elts_per_thread); - int64_t *rng_state_columnwise_ptr = - static_cast(rng_states_columnwise_tensor.data_ptr()) + i * 2; - philox_unpack(philox_args_columnwise, rng_state_columnwise_ptr); - te_rng_state_columnwise_list.push_back(makeTransformerEngineTensor( - static_cast(rng_state_columnwise_ptr), std::vector{2}, DType::kInt64)); - quant_config_columnwise_list[i].set_rng_state(te_rng_state_columnwise_list[i].data()); - quant_config_columnwise_list[i].set_stochastic_rounding(true); + // break the loop if we are using bulk generate rng states + if (with_bulk_generate_rng_states) { + break; } } } From 1af82af058b64a44a081387dd3515aa9aa4e964e Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Wed, 3 Dec 2025 11:46:55 -0800 Subject: [PATCH 08/36] greptile Signed-off-by: Zhongbo Zhu --- benchmarks/linear/benchmark_grouped_linear.py | 2 +- tests/pytorch/nvfp4/test_nvfp4_group_quantize.py | 4 ++-- .../common/hadamard_transform/group_hadamard_transform.cu | 2 +- transformer_engine/pytorch/csrc/extensions/cast.cpp | 7 +++++++ transformer_engine/pytorch/csrc/quantizer.cpp | 2 +- 5 files changed, 12 insertions(+), 5 deletions(-) diff --git a/benchmarks/linear/benchmark_grouped_linear.py b/benchmarks/linear/benchmark_grouped_linear.py index 464219a7fe..3560858d6d 100644 --- a/benchmarks/linear/benchmark_grouped_linear.py +++ b/benchmarks/linear/benchmark_grouped_linear.py @@ -53,7 +53,7 @@ --set=full \ --kernel-name "GroupHadamardAmaxTmaKernel" \ -s 5 -c 5 \ - python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe nvfp4 --profile + python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe nvfp4 """ diff --git a/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py b/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py index 10aa3eb505..c785e46548 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py +++ b/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py @@ -198,7 +198,7 @@ def check_group_quantization_nvfp4_versus_reference( for i in range(len(x_qx)): if split_sections[i] == 0: - # then just assert the same same and dtype because the buffer won't be zero out + # then just assert the same shape and dtype because the buffer won't be zero out assert_same_shape_and_dtype(x_amax_rowwise[i], x_amax_rowwise_ref[i]) assert_same_shape_and_dtype(x_qx[i], x_qx_ref[i]) assert_same_shape_and_dtype(x_sx[i], x_sx_ref[i]) @@ -221,7 +221,7 @@ def check_group_quantization_nvfp4_versus_reference( # assert with zero tolerance for i in range(len(x_qx_t)): if split_sections[i] == 0: - # then just assert the same same and dtype because the buffer won't be zero out + # then just assert the same shape and dtype because the buffer won't be zero out assert_same_shape_and_dtype(x_amax_colwise[i], x_amax_colwise_ref[i]) assert_same_shape_and_dtype(x_qx_t[i], x_qx_t_ref[i]) assert_same_shape_and_dtype(x_sx_t[i], x_sx_t_ref[i]) diff --git a/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu b/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu index 84eb6bb5c3..ea5e22bbfb 100644 --- a/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu @@ -459,7 +459,7 @@ void group_hadamard_transform_amax(const Tensor& input_, std::vector& o } // Multi zero out multiple amaxes if needed - // Curretly don't support multi-launch when num_tensors is larger than kMaxTensorsPerKernel + // Currently don't support multi-launch when num_tensors is larger than kMaxTensorsPerKernel // let the number of threads equal to number of tensors, use 1 block, kMaxTensorsPerKernel threads per block dim3 block_setup_amax(kMaxTensorsPerKernel); dim3 grid_setup_amax(1); diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index f4d00184b5..ee236a38b7 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -723,6 +723,13 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input, NVTE_CHECK(quantizers.size() == num_tensors, "Expected ", num_tensors, " NVFP4 quantizers, but got ", quantizers.size(), "."); + // sanity check all the quantizers have the same scaling mode + bool all_same_scaling_mode = + std::all_of(quantizers.begin(), quantizers.end(), [&](const NVFP4Quantizer *quantizer) { + return quantizer->get_scaling_mode() == quantizers.front()->get_scaling_mode(); + }); + NVTE_CHECK(all_same_scaling_mode, "All quantizers must have the same scaling mode"); + // Trivial cases if (num_tensors == 0) { return; diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index c73c09b317..fd748d1b21 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1501,7 +1501,7 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou } } - // Restriction for the RHT cast fusion kernel. + // Restriction for the RHT cast fusion kernel because we are using MMA hardware for computing RHT bool eligible_for_rht_cast_fusion = input.dtype() == DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0; From b4515d2bc5cc23c11d4574c6ca32e06a65b91192 Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Wed, 3 Dec 2025 15:39:13 -0800 Subject: [PATCH 09/36] lint Signed-off-by: Zhongbo Zhu --- .../group_hadamard_transform_cast_fusion.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu index a18d1bf341..e5a0720005 100644 --- a/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu @@ -212,7 +212,7 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til using TensorC = decltype( make_tensor( - subbyte_iterator((TC*)nullptr), // engine + subbyte_iterator(recast_ptr(nullptr)), // engine make_shape(int{}, int{}), // (M, N_i) Stride2D{} // stride (dM, dN) ) @@ -220,7 +220,7 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til using TensorSFC = decltype( make_tensor( - make_gmem_ptr((TSFC*)nullptr), + make_gmem_ptr(recast_ptr(nullptr)), make_layout( make_shape( int{}, // M From 626e3fe9a1b2204b5521f4f24b49140c4d227fb5 Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Thu, 4 Dec 2025 14:47:55 -0800 Subject: [PATCH 10/36] revert to use divides Signed-off-by: Zhongbo Zhu --- .../group_hadamard_transform_cast_fusion.cu | 4 ++-- .../hadamard_transform/hadamard_transform_cast_fusion.cu | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu index e5a0720005..632690d708 100644 --- a/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu @@ -662,8 +662,8 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til } // regular path for slower math, use divide - pvscales = cutlass::multiplies>{}(vec_maxs, fp4_max_inv); - // pvscales = cutlass::divides>{}(vec_maxs, fp4_max); + // pvscales = cutlass::multiplies>{}(vec_maxs, fp4_max_inv); + pvscales = cutlass::divides>{}(vec_maxs, fp4_max); pvscales = cutlass::multiplies>{}(pvscales, global_encode_scale); // (optional) path for faster math, use multiply to repalce div // pvscales = cutlass::multiplies>{}(vec_maxs, global_encode_scale_multiplier); diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu index ce77c5194d..46c5edf80e 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu @@ -485,8 +485,8 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile, } // regular path for slower math, use divide - pvscales = cutlass::multiplies>{}(vec_maxs, fp4_max_inv); - // pvscales = cutlass::divides>{}(vec_maxs, fp4_max); + // pvscales = cutlass::multiplies>{}(vec_maxs, fp4_max_inv); + pvscales = cutlass::divides>{}(vec_maxs, fp4_max); pvscales = cutlass::multiplies>{}(pvscales, global_encode_scale); // (optional) path for faster math, use multiply to repalce div // pvscales = cutlass::multiplies>{}(vec_maxs, global_encode_scale_multiplier); From fc6f7f275a8770ed819679f3b41d021f36c0ed3a Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Thu, 4 Dec 2025 14:48:47 -0800 Subject: [PATCH 11/36] avoid fp32 bf16 round-trip in RHT cast fusion Signed-off-by: Zhongbo Zhu --- .../group_hadamard_transform_cast_fusion.cu | 8 ++-- .../hadamard_transform_cast_fusion.cu | 8 ++-- .../custom_recipes/quantization_nvfp4.py | 39 ++++++++++++++++++- 3 files changed, 45 insertions(+), 10 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu index 632690d708..75065f6acd 100644 --- a/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu @@ -649,10 +649,10 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til ++accumulator_pipe_consumer_state; - // TODO(zhongbo): (optional) Maybe remove it for better perf. Cast data from FP32 to BF16 to FP32. - auto convert_accum_to_bf16 = cutlass::NumericArrayConverter{}; - auto convert_bf16_to_accum = cutlass::NumericArrayConverter{}; - tTR_rAcc_frag(_0{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_0{}))); + // Cast data from FP32 to BF16 to FP32. + // auto convert_accum_to_bf16 = cutlass::NumericArrayConverter{}; + // auto convert_bf16_to_accum = cutlass::NumericArrayConverter{}; + // tTR_rAcc_frag(_0{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_0{}))); auto compute_frgs = reinterpret_cast *>(tTR_rAcc_frag.data()); auto output_frgs = reinterpret_cast *>(tDrC_frag.data()); diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu index 46c5edf80e..b265d7b9a2 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu @@ -472,10 +472,10 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile, ++accumulator_pipe_consumer_state; - // TODO(zhongbo): (optional) Maybe remove it for better perf. Cast data from FP32 to BF16 to FP32. - auto convert_accum_to_bf16 = cutlass::NumericArrayConverter{}; - auto convert_bf16_to_accum = cutlass::NumericArrayConverter{}; - tTR_rAcc_frag(_0{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_0{}))); + // Cast data from FP32 to BF16 to FP32. + // auto convert_accum_to_bf16 = cutlass::NumericArrayConverter{}; + // auto convert_bf16_to_accum = cutlass::NumericArrayConverter{}; + // tTR_rAcc_frag(_0{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_0{}))); auto compute_frgs = reinterpret_cast *>(tTR_rAcc_frag.data()); auto output_frgs = reinterpret_cast *>(tDrC_frag.data()); diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py index b371ca4842..49f6a90037 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py @@ -391,7 +391,27 @@ def _build_hadamard_matrix( h = sign_mat @ h return h.to(dtype) - def _apply_rht(self, x: torch.Tensor) -> torch.Tensor: + def _supports_rht_cast_fusion(self, x: torch.Tensor) -> bool: + """ + Check if RHT cast fusion is supported for the input tensor. + + When RHT cast fusion is supported, there is no intermediate bf16 tensor for RHT(x) results, + which means that we can directly cast from FP32 to FP4 without any intermediate bf16 tensor. + + For example, if x.shape is (128, 128), then RHT cast fusion is supported. + If x.shape is (128, 127), then RHT cast fusion is not supported. + + This function is to simulate this behavior in the reference implementation for numerical correctness. + + Args: + x: The input tensor. + + Returns: + True if RHT cast fusion is supported, False otherwise. + """ + return x.dtype == torch.bfloat16 and x.shape[0] % 64 == 0 and x.shape[1] % 128 == 0 + + def _apply_rht(self, x: torch.Tensor, with_rht_cast_fusion: bool = False) -> torch.Tensor: """Apply randomized Hadamard transform without random signs (reference path). This matches the reference used in tests: x_reshaped @ (H * (1/sqrt(g))). @@ -415,6 +435,12 @@ def _apply_rht(self, x: torch.Tensor) -> torch.Tensor: x_mat = x.contiguous().view(-1, rht_dim) # Random sign matrix is identity in this reference (no sign flipping) transform = H * scale + + # If RHT cast fusion is supported, we can directly cast from FP32 to FP4 without any intermediate bf16 tensor. + if with_rht_cast_fusion: + transform = transform.float() + x_mat = x_mat.float() + out = x_mat @ transform return out.view(original_shape) @@ -599,9 +625,10 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ ), "NVFP4 only supports 1x16 or 16x16 tile shape." # Prepare inputs once so we can reuse for both amax and quantization # Row-input will always be the original input. + with_rht_cast_fusion = self._supports_rht_cast_fusion(tensor) row_input = tensor col_input = ( - self._apply_rht(tensor.t().contiguous()) + self._apply_rht(tensor.t().contiguous(), with_rht_cast_fusion) if self.with_rht else tensor.t().contiguous() ) @@ -612,6 +639,14 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ if self.columnwise_usage else global_amax_row ) + # TODO: this is a horrible hack to pass zero tolerance test + # currently the amax of RHT transform still has a fp32 -> bf16 -> fp32 round-trip + # this is to simulate that behaviour and we will remove this once the amax of RHT transform is fixed + if self.columnwise_usage and with_rht_cast_fusion: + # global_amax_col = global_amax_col.to(torch.bfloat16).float() + global_amax_col = ( + torch.max(torch.abs(col_input.bfloat16())).to(torch.float32).view(1) + ) transpose_scales = False From 48e5d7544501651ef1049f1424ac9cbdb404275a Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Fri, 5 Dec 2025 13:07:45 -0800 Subject: [PATCH 12/36] trigger fastmath by toggle NVTE_RHT_CAST_FUSION_USE_FAST_MATH Signed-off-by: Zhongbo Zhu --- .../group_hadamard_transform_cast_fusion.cu | 72 +++++++++++------ .../hadamard_transform_cast_fusion.cu | 80 ++++++++++++------- 2 files changed, 99 insertions(+), 53 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu index 75065f6acd..8c9e16ebdb 100644 --- a/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu @@ -31,6 +31,9 @@ #include "cutlass/util/command_line.h" #include "cutlass/util/print_error.hpp" +// include utils for get system env +#include "../util/system.h" + // clang-format off namespace transformer_engine { @@ -166,7 +169,8 @@ template + bool kEnableStochasticRounding = false, + bool kEnableFastMath = false> __global__ static void group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile, @@ -540,8 +544,11 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til float global_amax_val = *global_amax_ptr; float global_encode_scale = ComputeGlobalEncodeScaleFP4(global_amax_val); - // (optional) path for faster math, use multiply to repalce div - // float global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + // will be used in fast math path if enabled + float global_encode_scale_multiplier = 1.0f; + if constexpr (kEnableFastMath) { + global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + } float global_decode_scale = 1.0f / global_encode_scale; auto sfd_converter = cutlass::NumericConverter{}; @@ -559,8 +566,10 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til if (tensor_id != new_tensor_id) { global_amax_val = *global_amax_ptr; global_encode_scale = ComputeGlobalEncodeScaleFP4(global_amax_val); - // (optional) path for faster math, use multiply to repalce div - // global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + // will be used in fast math path if enabled + if constexpr (kEnableFastMath) { + global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + } global_decode_scale = 1.0f / global_encode_scale; tensor_id = new_tensor_id; // went through the cute operations to update the local tensors @@ -661,21 +670,27 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til vec_maxs[v] = amax_reduction(ElementAccumulator(0), compute_frgs[v]); } - // regular path for slower math, use divide - // pvscales = cutlass::multiplies>{}(vec_maxs, fp4_max_inv); - pvscales = cutlass::divides>{}(vec_maxs, fp4_max); - pvscales = cutlass::multiplies>{}(pvscales, global_encode_scale); - // (optional) path for faster math, use multiply to repalce div - // pvscales = cutlass::multiplies>{}(vec_maxs, global_encode_scale_multiplier); + if constexpr (kEnableFastMath) { + // path for faster math, use multiply to repalce div + pvscales = cutlass::multiplies>{}(vec_maxs, global_encode_scale_multiplier); + } else { + // regular path for slower math, use divide + pvscales = cutlass::divides>{}(vec_maxs, fp4_max); + pvscales = cutlass::multiplies>{}(pvscales, global_encode_scale); + } auto pvscales_cvted = cutlass::NumericArrayConverter{}(pvscales); tC_rRowSFD_frg(_0{}) = pvscales_cvted; auto qpvscale_ups = cutlass::NumericArrayConverter{}(tC_rRowSFD_frg(_0{})); auto qpvscale_scaled = cutlass::multiplies>{}(qpvscale_ups, global_decode_scale); - // regular path for slower math, use divide to repalce div - auto acc_scales = cutlass::divides>{}(1.0, qpvscale_scaled); - // (optional) path for faster math, use fast math reciprocal approximate to repalce div - // auto acc_scales = cutlass::reciprocal_approximate_ftz{}(qpvscale_scaled); + cutlass::Array acc_scales; + if constexpr (kEnableFastMath) { + // fast math: use reciprocal approximate to replace div + acc_scales = cutlass::reciprocal_approximate_ftz{}(qpvscale_scaled); + } else { + // regular path for slower math, use divide to replace div + acc_scales = cutlass::divides>{}(1.0, qpvscale_scaled); + } // Initialize RNG for tile const size_t rng_sequence @@ -721,7 +736,7 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til // B: 16 x 16: row-major // C: m x n: row-major // SFC: m x (n/16): row-major -template +template void group_rht_gemm_ntt_w_sfc(int m, int n, TA const* A, @@ -834,7 +849,8 @@ group_rht_gemm_ntt_w_sfc(int m, int n, TC, Stride2D, decltype(sC), TSFC, decltype(mma), - kEnableStochasticRounding>; + kEnableStochasticRounding, + kEnableFastMath>; bool status = cudaFuncSetAttribute(*kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, @@ -856,7 +872,7 @@ group_rht_gemm_ntt_w_sfc(int m, int n, // this function is used to wrap the group_rht_gemm_ntt_w_sfc function // to transpose the input tensor A -template +template void group_rht_gemm_ttt_wrapper(int m, int n, TA const* A, @@ -877,7 +893,7 @@ group_rht_gemm_ttt_wrapper(int m, int n, // B: 16 x 16: row-major // C: n x m: row-major // SFC: n x (m/16): row-major - group_rht_gemm_ntt_w_sfc( + group_rht_gemm_ntt_w_sfc( n, m, A, B, kernel_args_ptr, @@ -1011,13 +1027,21 @@ void group_hadamard_transform_cast_fusion_columnwise( k_tile_size = 512; } + // TODO: haven't decided whether to expose this as a API option or not + // use fast math if there is a ENV var NVTE_RHT_CAST_FUSION_USE_FAST_MATH, default to false + static const bool use_fast_math = + transformer_engine::getenv("NVTE_RHT_CAST_FUSION_USE_FAST_MATH", false); + TRANSFORMER_ENGINE_SWITCH_CONDITION( use_stochastic_rounding, kUseStochasticRounding, - detail::group_rht_gemm_ttt_wrapper( - /*m=*/m, /*n=*/n, /*A=*/reinterpret_cast(input.dptr), - /*B=*/reinterpret_cast(hadamard_matrix.dptr), - /*kernel_args_ptr=*/&kernel_args, /*rng_state=*/rng_state, /*sm_count=*/sm_count, - /*stream=*/stream, /*k_tile_size=*/k_tile_size);); + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_fast_math, kEnableFastMath, + detail::group_rht_gemm_ttt_wrapper( + /*m=*/m, /*n=*/n, /*A=*/reinterpret_cast(input.dptr), + /*B=*/reinterpret_cast(hadamard_matrix.dptr), + /*kernel_args_ptr=*/&kernel_args, /*rng_state=*/rng_state, /*sm_count=*/sm_count, + /*stream=*/stream, /*k_tile_size=*/k_tile_size););); } } // namespace transformer_engine diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu index b265d7b9a2..10cbbb5157 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu @@ -31,6 +31,9 @@ #include "cutlass/util/command_line.h" #include "cutlass/util/print_error.hpp" +// include utils for get system env +#include "../util/system.h" + // clang-format off namespace transformer_engine { @@ -128,7 +131,8 @@ template + bool kEnableStochasticRounding = false, + bool kEnableFastMath = false> __global__ static void rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile, @@ -426,8 +430,11 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile, static constexpr float fp4_max_inv = 1.0f / fp4_max; const float global_encode_scale = ComputeGlobalEncodeScaleFP4(global_amax_val); - // (optional) path for faster math, use multiply to repalce div - // const float global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + // will be used in fast math path if enabled + float global_encode_scale_multiplier = 1.0f; + if constexpr (kEnableFastMath) { + global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + } const float global_decode_scale = 1.0f / global_encode_scale; auto sfd_converter = cutlass::NumericConverter{}; @@ -484,21 +491,27 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile, vec_maxs[v] = amax_reduction(ElementAccumulator(0), compute_frgs[v]); } - // regular path for slower math, use divide - // pvscales = cutlass::multiplies>{}(vec_maxs, fp4_max_inv); - pvscales = cutlass::divides>{}(vec_maxs, fp4_max); - pvscales = cutlass::multiplies>{}(pvscales, global_encode_scale); - // (optional) path for faster math, use multiply to repalce div - // pvscales = cutlass::multiplies>{}(vec_maxs, global_encode_scale_multiplier); + if constexpr (kEnableFastMath) { + // path for faster math, use multiply to repalce div + pvscales = cutlass::multiplies>{}(vec_maxs, global_encode_scale_multiplier); + } else { + // regular path for slower math, use divide + pvscales = cutlass::divides>{}(vec_maxs, fp4_max); + pvscales = cutlass::multiplies>{}(pvscales, global_encode_scale); + } auto pvscales_cvted = cutlass::NumericArrayConverter{}(pvscales); tC_rRowSFD_frg(_0{}) = pvscales_cvted; auto qpvscale_ups = cutlass::NumericArrayConverter{}(tC_rRowSFD_frg(_0{})); auto qpvscale_scaled = cutlass::multiplies>{}(qpvscale_ups, global_decode_scale); - // regular path for slower math, use divide - auto acc_scales = cutlass::divides>{}(1.0, qpvscale_scaled); - // (optional) path for faster math, use fast math reciprocal approximate to repalce div - // auto acc_scales = cutlass::reciprocal_approximate_ftz{}(qpvscale_scaled); + cutlass::Array acc_scales; + if constexpr (kEnableFastMath) { + // fast math: use reciprocal approximate to replace div + acc_scales = cutlass::reciprocal_approximate_ftz{}(qpvscale_scaled); + } else { + // regular path for slower math, use divide to replace div + acc_scales = cutlass::divides>{}(1.0, qpvscale_scaled); + } // Initialize RNG for tile const size_t rng_sequence @@ -542,7 +555,7 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile, // B: 16 x 16: row-major // C: m x n: row-major // SFC: m x (n/16): row-major -template +template void rht_gemm_ntt_w_sfc(int m, int n, TA const* A, @@ -654,7 +667,8 @@ rht_gemm_ntt_w_sfc(int m, int n, TC, decltype(dC), decltype(sC), TSFC, decltype(mma), - kEnableStochasticRounding>; + kEnableStochasticRounding, + kEnableFastMath>; bool status = cudaFuncSetAttribute(*kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, @@ -677,7 +691,7 @@ rht_gemm_ntt_w_sfc(int m, int n, // this function is used to wrap the rht_gemm_ntt_w_sfc function //to transpose the input tensor A -template +template void rht_gemm_ttt_wrapper(int m, int n, TA const* A, @@ -700,7 +714,7 @@ rht_gemm_ttt_wrapper(int m, int n, // B: 16 x 16: row-major // C: n x m: row-major // SFC: n x (m/16): row-major - rht_gemm_ntt_w_sfc( + rht_gemm_ntt_w_sfc( n, m, A, B, C, SFC, global_amax, @@ -810,20 +824,28 @@ void hadamard_transform_cast_fusion_columnwise(const Tensor &input_, Tensor &out } else if (m < 1024 || n < 1024) { k_tile_size = 512; } + + // TODO: haven't decided whether to expose this as a API option or not + // use fast math if there is a ENV var NVTE_RHT_CAST_FUSION_USE_FAST_MATH, default to false + static const bool use_fast_math = + transformer_engine::getenv("NVTE_RHT_CAST_FUSION_USE_FAST_MATH", false); + TRANSFORMER_ENGINE_SWITCH_CONDITION( use_stochastic_rounding, kUseStochasticRounding, - detail::rht_gemm_ttt_wrapper( - /*m=*/m, - /*n=*/n, - /*A=*/reinterpret_cast(input.dptr), - /*B=*/reinterpret_cast(hadamard_matrix.dptr), - /*C=*/reinterpret_cast(output_t.dptr), - /*SFC=*/reinterpret_cast(scale_inv_t.dptr), - /*global_amax=*/reinterpret_cast(global_amax.dptr), - /*rng_state=*/rng_state, - /*sm_count=*/sm_count, - /*stream=*/stream, - /*k_tile_size=*/k_tile_size);); + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_fast_math, kEnableFastMath, + detail::rht_gemm_ttt_wrapper( + /*m=*/m, + /*n=*/n, + /*A=*/reinterpret_cast(input.dptr), + /*B=*/reinterpret_cast(hadamard_matrix.dptr), + /*C=*/reinterpret_cast(output_t.dptr), + /*SFC=*/reinterpret_cast(scale_inv_t.dptr), + /*global_amax=*/reinterpret_cast(global_amax.dptr), + /*rng_state=*/rng_state, + /*sm_count=*/sm_count, + /*stream=*/stream, + /*k_tile_size=*/k_tile_size););); } } // namespace transformer_engine From 3d07a9b411ae94cec545dc1aac3631be030b735a Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Fri, 12 Dec 2025 01:59:25 -0800 Subject: [PATCH 13/36] integrate row col rht fusion, functional Signed-off-by: Zhongbo Zhu --- transformer_engine/common/CMakeLists.txt | 2 + .../customized_pipeline.hpp | 245 +++ ...cast_col_hadamard_transform_cast_fusion.cu | 1429 +++++++++++++++++ .../transformer_engine/hadamard_transform.h | 19 + .../pytorch/csrc/extensions/cast.cpp | 319 ++-- 5 files changed, 1872 insertions(+), 142 deletions(-) create mode 100644 transformer_engine/common/hadamard_transform/customized_pipeline.hpp create mode 100644 transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 34a99d10e5..decc86dc75 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -175,6 +175,7 @@ list(APPEND transformer_engine_cuda_arch_specific_sources hadamard_transform/hadamard_transform.cu hadamard_transform/hadamard_transform_cast_fusion.cu hadamard_transform/group_hadamard_transform_cast_fusion.cu + hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu multi_tensor/compute_scale.cu recipe/mxfp8_scaling.cu transpose/quantize_transpose_square_blockwise.cu @@ -352,6 +353,7 @@ endforeach() set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3") +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -lineinfo") # Number of parallel build jobs if($ENV{MAX_JOBS}) diff --git a/transformer_engine/common/hadamard_transform/customized_pipeline.hpp b/transformer_engine/common/hadamard_transform/customized_pipeline.hpp new file mode 100644 index 0000000000..85c76c4080 --- /dev/null +++ b/transformer_engine/common/hadamard_transform/customized_pipeline.hpp @@ -0,0 +1,245 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/pipeline/sm100_pipeline.hpp" + +namespace cutlass { + +using namespace cute; +namespace detail { +// Producer-consumer pipeline implementation +// for UMMA producer. In this case, UMMA barrier arrives are used +// by producer_commit. Use case, accumulator generation as +// the result of MMA instructions. +template , + class AtomThrShape_MNK_ = Shape<_1, _1, _1> > +class CustomizedPipelineTmaUmmaAsync { + public: + static constexpr uint32_t Stages = Stages_; + using AtomThrShape_MNK = AtomThrShape_MNK_; + + private: + using Impl = PipelineTmaAsync; + + public: + using FullBarrier = typename Impl::FullBarrier; + using EmptyBarrier = typename Impl::EmptyBarrier; + using ProducerBarrierType = typename Impl::ProducerBarrierType; + using ConsumerBarrierType = typename Impl::ConsumerBarrierType; + using PipelineState = typename Impl::PipelineState; + using SharedStorage = typename Impl::SharedStorage; + using ThreadCategory = typename Impl::ThreadCategory; + using Params = typename Impl::Params; + + using McastDirection = McastDirection; + + // Helper function to initialize barriers + static CUTLASS_DEVICE void init_barriers(SharedStorage& storage, Params params, + ClusterShape cluster_shape) { + int warp_idx = canonical_warp_idx_sync(); + if (warp_idx == params.initializing_warp) { + // Barrier FULL and EMPTY init + constexpr int producer_arv_cnt = 1; + auto atom_thr_shape = AtomThrShape_MNK{}; + + uint32_t multicast_consumer_arrival_count = params.num_consumers; // If cluster_size is 1 + if (cute::size(cluster_shape) > 1) { + multicast_consumer_arrival_count = + ((cute::size<0>(cluster_shape) / cute::size<0>(atom_thr_shape)) + + (cute::size<1>(cluster_shape) / cute::size<1>(atom_thr_shape)) - 1) * + params.num_consumers; + } + CUTLASS_ASSERT(multicast_consumer_arrival_count > 0 && + "Multicast consumer arrival count must be non-zero"); + CUTLASS_ASSERT(producer_arv_cnt > 0 && "Producer arrival count must be non-zero"); + cutlass::arch::detail::initialize_barrier_array_pair_aligned< + decltype(storage.full_barrier_), decltype(storage.empty_barrier_), Stages>( + storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, + multicast_consumer_arrival_count); + } + cutlass::arch::fence_barrier_init(); + } + + CUTLASS_DEVICE + void init_masks(ClusterShape cluster_shape, + dim3 block_id_in_cluster = cute::block_id_in_cluster()) { + // Calculate consumer mask + if (params_.role == ThreadCategory::Consumer) { + auto cluster_layout = make_layout(cluster_shape); + block_id_mask_ = detail::calculate_multicast_mask( + cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster); + } + } + + CUTLASS_DEVICE + void init_masks(ClusterShape cluster_shape, McastDirection mcast_direction) { + // Calculate consumer mask + dim3 block_id_in_cluster = cute::block_id_in_cluster(); + auto cluster_layout = make_layout(cluster_shape); + if (mcast_direction == McastDirection::kRow) { + block_id_mask_ = detail::calculate_multicast_mask( + cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster); + } else { + block_id_mask_ = detail::calculate_multicast_mask( + cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster); + } + } + + // Constructor by default initializes barriers and calculates masks. + // These operations can be explicity deferred by specifying InitBarriers and InitMasks. + // If deferred, user code needs to guarantee init_masks and/or init_barriers is/are called. + template + CUTLASS_DEVICE CustomizedPipelineTmaUmmaAsync(SharedStorage& storage, Params params, + ClusterShape cluster_shape, InitBarriers = {}, + InitMasks = {}) + : impl_(storage, params, cluster_shape, cute::false_type{}, InitMasks{}), + params_(params), + empty_barrier_ptr_(&storage.empty_barrier_[0]), + full_barrier_ptr_(&storage.full_barrier_[0]) { + static_assert(cute::is_same_v || + cute::is_same_v); + if constexpr (cute::is_same_v) { + init_barriers(storage, params_, cluster_shape); + } + + static_assert(cute::is_same_v || + cute::is_same_v); + if constexpr (cute::is_same_v) { + init_masks(cluster_shape); + } + } + + //////////////////// + // Producer APIs + //////////////////// + // Four member functions are always used in pairs: + // + // * producer_try_acquire and producer_acquire, and + // * consumer_try_wait and consumer_wait. + // + // The two functions with "try" in their names are called "try" functions, + // and the other two are conceptually "finalize" functions. + // The "try" function in each pair starts the process of waiting on the barrier to flip. + // It opportunistically waits for an implementation-dependent timeout. + // Whether or not the barrier has flipped yet, the try function will return a token. + // If the token indicates that the barrier has not flipped, + // then the token must be passed into the corresponding "finalize" function. + // The finalize function will then block until the barrier has flipped. + // If the token indicates that the barrier _has_ flipped, + // then it is still correct to pass it into the finalize function. + // The finalize function will return immediately in that case. + CUTLASS_DEVICE + ProducerToken producer_try_acquire(PipelineState state, uint32_t skip_wait = false) { + return impl_.producer_try_acquire(state, skip_wait); + } + + CUTLASS_DEVICE + void producer_acquire(PipelineState state, + ProducerToken barrier_token = {BarrierStatus::WaitAgain}) { + impl_.producer_acquire(state, barrier_token); + } + + CUTLASS_DEVICE + void producer_expect_transaction(PipelineState state, uint32_t transaction_bytes) { + impl_.producer_expect_transaction(state, transaction_bytes); + } + + // NOP for TMA based mainloop + CUTLASS_DEVICE + void producer_commit(PipelineState state, uint32_t bytes) { impl_.producer_commit(state, bytes); } + + // Prevents early exit of producer blocks in Cluster. + // This should be called once before kernel exits. + CUTLASS_DEVICE + void producer_tail(PipelineState state) { impl_.producer_tail(state); } + + CUTLASS_DEVICE + ProducerBarrierType* producer_get_barrier(PipelineState state) { + return impl_.producer_get_barrier(state); + } + + //////////////////// + // Consumer APIs + //////////////////// + CUTLASS_DEVICE + ConsumerToken consumer_try_wait(PipelineState state, uint32_t skip_wait = false) { + return impl_.consumer_try_wait(state, skip_wait); + } + + CUTLASS_DEVICE + void consumer_wait(PipelineState state, + ConsumerToken barrier_token = {BarrierStatus::WaitAgain}) { + impl_.consumer_wait(state, barrier_token); + } + + CUTLASS_DEVICE + void umma_consumer_release(PipelineState state) { umma_consumer_release(state.index(), false); } + CUTLASS_DEVICE + void consumer_release(PipelineState state) { impl_.consumer_release(state); } + + private: + Impl impl_; + Params params_; + EmptyBarrier* empty_barrier_ptr_; + FullBarrier* full_barrier_ptr_; + uint16_t block_id_mask_ = 0; + static constexpr bool is_2sm_mma = size(AtomThrShape_MNK{}) > 1; + + // Consumer signalling Producer of completion + // Ensures all blocks in the Same Row and Column get notifed. + CUTLASS_DEVICE + void umma_consumer_release(uint32_t stage, uint32_t skip) { + detail::pipeline_check_is_consumer(params_.role); + uint64_t* smem_ptr = reinterpret_cast(&empty_barrier_ptr_[stage]); + // {$nv-release-never begin} + // TODO: Needs to be updated once Blackwell specialized pipeline is implemented. + // XMMA style bar_peek will be tested. We will need to revisit skip interface and + // what skip means when we have bar_peek functionality. + // A separate MR will implement MMA_2x1SM specialized pipeline. + // {$nv-release-never end} + if constexpr (is_2sm_mma) { // Mma cluster shape is 2x1 + if (!skip) { + cutlass::arch::umma_arrive_multicast_2x1SM(smem_ptr, block_id_mask_); + } + } else { + if (!skip) { + if constexpr (cute::is_static_v and size(ClusterShape{}) == 1) { + cutlass::arch::umma_arrive(smem_ptr); + } else { + cutlass::arch::umma_arrive_multicast(smem_ptr, block_id_mask_); + } + } + } + } +}; +} // namespace detail +} // namespace cutlass diff --git a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu new file mode 100644 index 0000000000..9b438aae1d --- /dev/null +++ b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu @@ -0,0 +1,1429 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "common/common.h" +#include "common/util/cuda_runtime.h" +#include "common/util/curanddx.hpp" +#include "common/util/ptx.cuh" +#include "common/utils.cuh" + +#include "cutlass/arch/barrier.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/collective/builders/sm100_common.inl" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/numeric_conversion.h" +#include "cutlass/float_subbyte.h" +#include "cutlass/numeric_types.h" +#include "cutlass/platform/platform.h" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/fast_math.h" +#include "cutlass/float8.h" +#include "cutlass/cluster_launch.hpp" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" +#include "cutlass/util/GPU_Clock.hpp" +#include "cutlass/util/command_line.h" +#include "cutlass/util/print_error.hpp" +#include "customized_pipeline.hpp" + +// include utils for get system env +#include "../util/system.h" + +namespace transformer_engine { +namespace detail { +namespace { + +using namespace cute; +using cute::Tensor; // Ensure unqualified Tensor refers to cute::Tensor, not transformer_engine::Tensor + +struct CLCResponse { + uint32_t data[4] = {0}; +}; + +constexpr int kMaxTensorsPerKernel = 64; + +struct MultiAmaxHadamardCastFusionArgs { + // (output) Amax buffer for input A amax buffer + void *global_a_amax_list[kMaxTensorsPerKernel]; + // (output) Amax buffer for pre-RHT amax buffer + void *global_d_amax_list[kMaxTensorsPerKernel]; + // output D pointers for each tensor + void *output_colwise_list[kMaxTensorsPerKernel]; + // output SFD inverse pointers for each tensor + void *output_colwise_scale_inv_list[kMaxTensorsPerKernel]; + // split sections of each tensor of input + int split_sections[kMaxTensorsPerKernel]; + // Prefix sum (with leading zero) of split_sections of each tensor of input + int split_sections_range[kMaxTensorsPerKernel + 1]; + + // Number of tensors (splits) being processed by kernel + int num_tensors; +}; + +__device__ __forceinline__ int GetGroupIdx(MultiAmaxHadamardCastFusionArgs *kernel_args_ptr, + int offset) { + // check the kernel args and get the corresponding id + int group_idx = 0; + int num_tensors = kernel_args_ptr->num_tensors; + int boundary = kernel_args_ptr->split_sections_range[num_tensors]; + if (offset >= boundary) { + return num_tensors - 1; + } + while (kernel_args_ptr->split_sections_range[group_idx + 1] <= offset) { + ++group_idx; + } + return group_idx; +} + +CUTLASS_DEVICE +cutlass::Array StochasticNumericConverterBase( + cutlass::Array const &input, cutlass::Array const &rbits) { + using result_type = cutlass::Array; + result_type output; + auto output_ptr = reinterpret_cast(&output); + asm volatile( + "{\n" + "cvt.rs.satfinite.e2m1x4.f32 %0, {%5, %4, %3, %2}, %10;\n" + "cvt.rs.satfinite.e2m1x4.f32 %1, {%9, %8, %7, %6}, %11;\n" + "}" + : "=h"(output_ptr[0]), "=h"(output_ptr[1]) + : "f"(input[0]), "f"(input[1]), "f"(input[2]), "f"(input[3]), "f"(input[4]), "f"(input[5]), + "f"(input[6]), "f"(input[7]), "r"(rbits[0]), "r"(rbits[1])); + return output; +} + +CUTLASS_DEVICE +cutlass::Array StochasticNumericConverter( + cutlass::Array const &input, cutlass::Array const &rbits) { + using result_type = cutlass::Array; + result_type output; + cutlass::Array *result_ptr = + reinterpret_cast *>(&output); + cutlass::Array const *source_ptr = + reinterpret_cast const *>(&input); + cutlass::Array const *rbits_ptr = + reinterpret_cast const *>(&rbits); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 2; i++) { + result_ptr[i] = StochasticNumericConverterBase(source_ptr[i], rbits_ptr[i]); + } + return output; +} + +template +struct SharedStorage { + static int constexpr AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_; + static int constexpr EpilogueUnrollFactor = EpilogueUnrollFactor_; + using AtomThrShapeMNK = cute::Shape<_1, _1, _1>; + + using AccumulatorPipeline = + cutlass::PipelineUmmaAsync; + using AccumulatorPipelineStorage = typename AccumulatorPipeline::SharedStorage; + + static int constexpr MainloopPipelineStageCount = size<3>(ASmemLayout{}); + using MainloopPipeline = + cutlass::detail::CustomizedPipelineTmaUmmaAsync, + AtomThrShapeMNK>; + using MainloopPipelineStorage = typename MainloopPipeline::SharedStorage; + using CLCPipeline = cutlass::PipelineCLCFetchAsync; + using CLCPipelineStorage = typename CLCPipeline::SharedStorage; + using CLCThrottlePipeline = cutlass::PipelineAsync; + using CLCThrottlePipelineStorage = typename CLCThrottlePipeline::SharedStorage; + + struct TensorStorage : cute::aligned_struct<128, _1> { + // cute::array_aligned> smem_A; + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + } tensors; + + alignas(16) AccumulatorPipelineStorage accumulator; + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) cute::uint64_t tma_barrier[1]; + alignas(16) CLCPipelineStorage clc; + alignas(16) CLCThrottlePipelineStorage clc_throttle; + alignas(16) CLCResponse clc_response[SchedulerPipelineStageCount_]; + alignas(16) float global_a_amax[kMaxTensorsPerKernel]; + alignas(16) float global_d_amax[kMaxTensorsPerKernel]; + uint32_t tmem_base_ptr; +}; + +template +__launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( + MShape M, NShape packed_N, KShape K, ClusterShape cluster_shape, ClusterTileShape cluster_tile, + TA const *A, AStride dA, ASmemLayout sAlayout, CUTE_GRID_CONSTANT TmaLoadA const tma_load_a, + TB const *B, BStride dB, BSmemLayout sBlayout, CUTE_GRID_CONSTANT TmaLoadB const tma_load_b, + TQA *QA, QAStride dQA, TSFA *SFA, TSFALayout sfa_layout, MultiAmaxHadamardCastFusionArgs args, + TiledMMA mma, + // float const* a_global_amax, + // float const* c_global_amax, + const size_t *rng_state) { + using namespace cute; + using X = Underscore; + // static constexpr bool kApplyStochasticRounding = true; + using ElementAccumulator = float; + static int constexpr K_PIPE_MAX = size<3>(ASmemLayout{}); + using AtomThrShapeMNK = Shape(typename TiledMMA::ThrLayoutVMNK{})), _1, _1>; + static uint32_t constexpr kTmaTransactionBytes = cutlass::bits_to_bytes( + size(AtomThrShapeMNK{}) * cosize(take<0, 3>(ASmemLayout{})) * cute::sizeof_bits_v); + static constexpr bool kEnableStochasticRounding = kEnableStochasticRounding_; + static constexpr bool kEnableRHTColQuant = kEnableRHTColQuant_; + static constexpr bool kEnableRowQuant = kEnableRowQuant_; + static constexpr bool kEnableSwizzleSFOutput = kEnableSwizzleSFOutput_; + static constexpr bool kEnableFastMath = kEnableFastMath_; + static int constexpr RhtTensorSize = 16; + static int constexpr kTmaRhtTensorTransactionBytes = + cutlass::bits_to_bytes(RhtTensorSize * RhtTensorSize * cute::sizeof_bits_v); + static int constexpr AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_; + static int constexpr SchedulerPipelineStageCount = SchedulerPipelineStageCount_; + + static int constexpr MainloopPipelineStageCount = size<3>(ASmemLayout{}); + static int constexpr SFVecSize = 16; + using SwizzledSFALayoutAtom = + cutlass::detail::Sm1xxBlockScaledOutputConfig::SfAtom; + using SwizzledSFDLayoutAtom = + cutlass::detail::Sm1xxBlockScaledOutputConfig::SfAtom; + + using MainloopPipeline = + cutlass::detail::CustomizedPipelineTmaUmmaAsync; + using MainloopPipelineState = typename MainloopPipeline::PipelineState; + using CLCPipeline = cutlass::PipelineCLCFetchAsync; + using CLCPipelineState = typename CLCPipeline::PipelineState; + using CLCThrottlePipeline = cutlass::PipelineAsync; + using CLCThrottlePipelineState = typename CLCThrottlePipeline::PipelineState; + + static_assert(ClusterShape{} == Shape<_1, _1, _1>{}, "ClusterShape must be Shape<_1,_1,_1>"); + + using TmemAllocator = cute::TMEM::Allocator1Sm; + static int constexpr VectorSize = RhtTensorSize; + // Preconditions + CUTE_STATIC_ASSERT(is_static::value); + CUTE_STATIC_ASSERT(is_static::value); + // CUTE_STATIC_ASSERT(is_static::value); + + auto cluster_size = size<0>(cluster_shape); + auto mainloop_tiler = Shape<_128, _16, _128>{}; + auto epilogue_tiler = Shape<_128, _128, _128>{}; + + static int constexpr EpilogueUnrollFactor = size<2>(epilogue_tiler) / size<2>(cluster_tile); + + // Get the appropriate blocks for this Cluster + dim3 cluster_coord_in_grid = cluster_id_in_grid(); + + // Total number of k-tiles + int const K_TILE_MAX = min(packed_N, K) / size<2>(epilogue_tiler); + + struct TileScheduler { + struct WorkTileInfo { + uint32_t m_idx = 0; + uint32_t n_idx = 0; + uint32_t l_idx = 0; + bool is_valid_tile = false; + }; + uint32_t tiles_in_m = 0; + uint32_t tiles_in_n = 0; + + int k_tile_max = 0; + + int wave_cnt = 0; + WorkTileInfo work_tile_info; + WorkTileInfo next_work_tile_info; + CLCResponse *clc_response_ptr_; + CUTLASS_DEVICE TileScheduler(uint32_t tiles_m, uint32_t tiles_n, int kmax, + CLCResponse *clc_response_ptr) + : tiles_in_m(tiles_m), + tiles_in_n(tiles_n), + + k_tile_max(kmax), + work_tile_info( + {blockIdx.x, blockIdx.y, blockIdx.z, blockIdx.x < tiles_m && blockIdx.y < tiles_n}), + next_work_tile_info( + {blockIdx.x, blockIdx.y, blockIdx.z, blockIdx.x < tiles_m && blockIdx.y < tiles_n}), + clc_response_ptr_(clc_response_ptr) {} + + CUTLASS_DEVICE uint32_t tile_m() const { return work_tile_info.m_idx; } + CUTLASS_DEVICE uint32_t tile_n_base() const { + return work_tile_info.n_idx * uint32_t(k_tile_max); + } + + CUTLASS_DEVICE uint32_t tiles_m() const { return tiles_in_m; } + CUTLASS_DEVICE uint32_t tiles_n() const { return tiles_in_n; } + CUTLASS_DEVICE bool is_valid() const { + return cute::elem_less(cute::make_coord(work_tile_info.m_idx, work_tile_info.n_idx), + cute::make_coord(tiles_in_m, tiles_in_n)) && + work_tile_info.is_valid_tile; + } + CUTLASS_DEVICE bool is_first_wave() const { return wave_cnt == 0; } + CUTLASS_DEVICE auto advance_to_next_work(CLCPipeline &clc_pipeline, + CLCPipelineState clc_pipe_producer_state) { + uint32_t mbarrier_addr = clc_pipeline.producer_get_barrier(clc_pipe_producer_state); + // Wait for clcID buffer to become empty with a flipped phase + clc_pipeline.producer_acquire(clc_pipe_producer_state); + + if (cute::elect_one_sync()) { + issue_clc_query(clc_pipe_producer_state, mbarrier_addr, clc_response_ptr_); + } + + ++clc_pipe_producer_state; + return clc_pipe_producer_state; + } + + CUTLASS_DEVICE auto fetch_next_work(CLCPipeline &clc_pipeline, + CLCPipelineState clc_pipe_producer_state) { + clc_pipeline.consumer_wait(clc_pipe_producer_state); + uint32_t smem_addr = + cute::cast_smem_ptr_to_uint(&clc_response_ptr_[clc_pipe_producer_state.index()]); + next_work_tile_info = work_tile_info_from_clc_response(smem_addr); + clc_pipeline.consumer_release(clc_pipe_producer_state); + wave_cnt++; + return; + } + + CUTLASS_DEVICE auto update_work_tile_info() { + work_tile_info = next_work_tile_info; + return; + } + + CUTLASS_DEVICE uint32_t get_linear_tile_idx() const { + return work_tile_info.m_idx + work_tile_info.n_idx * tiles_in_m; + } + + CUTLASS_HOST_DEVICE + static void issue_clc_query(CLCPipelineState state, uint32_t mbarrier_addr, + CLCResponse *clc_response_ptr) { +#if defined(CUTLASS_ARCH_CLC_ENABLED) + uint32_t result_addr = cute::cast_smem_ptr_to_uint( + reinterpret_cast(&clc_response_ptr[state.index()])); + asm volatile( + "{\n\t" + "clusterlaunchcontrol.try_cancel.async.shared::cta.mbarrier::complete_tx::bytes." + "multicast::cluster::all.b128 [%0], [%1];\n\t" + "}\n" + : + : "r"(result_addr), "r"(mbarrier_addr)); +#else + CUTLASS_NOT_IMPLEMENTED(); +#endif + } + CUTLASS_DEVICE + static WorkTileInfo work_tile_info_from_clc_response(uint32_t result_addr) { + WorkTileInfo work_tile_info; + uint32_t valid = 0; +#if defined(CUTLASS_ARCH_CLC_ENABLED) + asm volatile( + "{\n" + ".reg .pred p1;\n\t" + ".reg .b128 clc_result;\n\t" + "ld.shared.b128 clc_result, [%4];\n\t" + "clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 p1, clc_result;\n\t" + "selp.u32 %3, 1, 0, p1;\n\t" + "@p1 clusterlaunchcontrol.query_cancel.get_first_ctaid.v4.b32.b128 {%0, %1, %2, _}, " + "clc_result;\n\t" + "}\n" + : "=r"(work_tile_info.m_idx), "=r"(work_tile_info.n_idx), "=r"(work_tile_info.l_idx), + "=r"(valid) + : "r"(result_addr) + : "memory"); + + cutlass::arch::fence_view_async_shared(); +#else + CUTLASS_NOT_IMPLEMENTED(); +#endif + work_tile_info.is_valid_tile = (valid == 1); + return work_tile_info; + } + }; + + // Allocate SMEMork + extern __shared__ char shared_memory[]; + using SharedStorage = + SharedStorage; + SharedStorage &shared_storage = *reinterpret_cast(shared_memory); + uint32_t tiles_in_m = uint32_t(size(ceil_div(M, size<0>(cluster_tile)))); + uint32_t tiles_in_n = uint32_t(size(ceil_div(packed_N, size<2>(epilogue_tiler)))); + TileScheduler scheduler(tiles_in_m, tiles_in_n, K_TILE_MAX, shared_storage.clc_response); + + int block_rank_in_cluster = cute::block_rank_in_cluster(); + auto acc_shape_mma = make_shape(take<0, 2>(mainloop_tiler), _1{}, _1{}); + auto acc_shape_epilogue = make_shape(take<0, 2>(epilogue_tiler), _1{}, _1{}); + + auto acc_mainloop_pipelined_shape = append(acc_shape_mma, Int{}); + auto bulk_tmem_mma = TiledMMA::make_fragment_C(acc_mainloop_pipelined_shape); + + static int constexpr NumEpilogueColQuantThreadCount = kEnableRHTColQuant ? 128 : 0; + static int constexpr NumEpilogueRowQuantThreadCount = kEnableRowQuant ? 256 : 0; + static int constexpr NumMmaThreadCount = kEnableRHTColQuant ? 32 : 0; + static int constexpr NumMmaIssueThreadCount = kEnableRHTColQuant ? 1 : 0; + static int constexpr NumSchedThreads = 32; + static int constexpr NumMainloopLoadThreads = 32; + static int constexpr NumEpilogueThreads = + NumEpilogueColQuantThreadCount + NumEpilogueRowQuantThreadCount; + + TmemAllocator tmem_allocator{}; + cutlass::arch::NamedBarrier tmem_allocation_result_barrier( + NumMmaThreadCount + NumEpilogueColQuantThreadCount, + cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier); + + int warp_idx = cutlass::canonical_warp_idx_sync(); + + // warp assignment + bool is_mma_warp = (warp_idx == 0); + bool is_dma_warp = (warp_idx == 1); + bool is_sched_warp = (warp_idx == 2); + bool is_epilogue_col_quant_warp = (warp_idx >= 4 && warp_idx <= 7); + bool is_epilogue_row_quant_warp = (warp_idx >= 8 && warp_idx <= 15); + + // if (is_epilogue_col_quant_warp && elect_one_sync()) { + // cute::prefetch(raw_pointer_cast(c_global_amax)); + // } + // if (is_epilogue_row_quant_warp && elect_one_sync()) { + // cute::prefetch(raw_pointer_cast(a_global_amax)); + // } + + typename MainloopPipeline::Params mainloop_pipeline_params; + if (is_dma_warp) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + } + if (is_mma_warp) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + mainloop_pipeline_params.is_leader = cute::elect_one_sync() && is_dma_warp; + mainloop_pipeline_params.transaction_bytes = kTmaTransactionBytes; + mainloop_pipeline_params.initializing_warp = 0; + mainloop_pipeline_params.num_consumers = NumEpilogueRowQuantThreadCount + NumMmaIssueThreadCount; + + MainloopPipeline mainloop_pipeline(shared_storage.mainloop, mainloop_pipeline_params, + cluster_shape, cute::true_type{}, // Perform barrier init + cute::true_type{}); // Delay mask calculation + + MainloopPipelineState mainloop_pipe_consumer_state; + MainloopPipelineState mainloop_pipe_producer_state = + cutlass::make_producer_start_state(); + + using AccumulatorPipeline = + cutlass::PipelineUmmaAsync; + using AccumulatorPipelineState = typename AccumulatorPipeline::PipelineState; + + AccumulatorPipelineState accumulator_pipe_consumer_state; + AccumulatorPipelineState accumulator_pipe_producer_state = + cutlass::make_producer_start_state(); + + typename AccumulatorPipeline::Params accumulator_pipeline_params; + if (is_mma_warp) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Producer; + } + if (is_epilogue_col_quant_warp) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Consumer; + } + // Only one producer thread arrives on this barrier. + accumulator_pipeline_params.producer_arv_count = 1; + accumulator_pipeline_params.consumer_arv_count = + size(AtomThrShapeMNK{}) * NumEpilogueColQuantThreadCount; + accumulator_pipeline_params.initializing_warp = 1; + AccumulatorPipeline accumulator_pipeline(shared_storage.accumulator, accumulator_pipeline_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::true_type{}); // Delay mask calculation + // CLC pipeline + typename CLCPipeline::Params clc_pipeline_params; + if (is_sched_warp) { + clc_pipeline_params.role = CLCPipeline::ThreadCategory::ProducerConsumer; + } else { + clc_pipeline_params.role = CLCPipeline::ThreadCategory::Consumer; + } + clc_pipeline_params.producer_blockid = 0; + clc_pipeline_params.producer_arv_count = 1; + clc_pipeline_params.consumer_arv_count = + NumSchedThreads + + cluster_size * (NumMainloopLoadThreads + NumEpilogueThreads + NumMmaThreadCount); + clc_pipeline_params.transaction_bytes = sizeof(CLCResponse); + clc_pipeline_params.initializing_warp = 3; + CLCPipeline clc_pipeline(shared_storage.clc, clc_pipeline_params, cluster_shape); + CLCPipelineState clc_pipeline_consumer_state; + CLCPipelineState clc_pipeline_producer_state = cutlass::make_producer_start_state(); + + // CLC throttle pipeline + typename CLCThrottlePipeline::Params clc_throttle_pipeline_params; + if (is_dma_warp) { + clc_throttle_pipeline_params.role = CLCThrottlePipeline::ThreadCategory::Producer; + } + if (is_sched_warp) { + clc_throttle_pipeline_params.role = CLCThrottlePipeline::ThreadCategory::Consumer; + } + clc_throttle_pipeline_params.producer_arv_count = NumMainloopLoadThreads; + clc_throttle_pipeline_params.consumer_arv_count = NumSchedThreads; + clc_throttle_pipeline_params.dst_blockid = 0; + clc_throttle_pipeline_params.initializing_warp = 4; + + CLCThrottlePipeline clc_throttle_pipeline(shared_storage.clc_throttle, + clc_throttle_pipeline_params); + CLCThrottlePipelineState clc_pipe_throttle_consumer_state; + CLCThrottlePipelineState clc_pipe_throttle_producer_state = + cutlass::make_producer_start_state(); + + if (warp_idx == 2 && elect_one_sync()) { + cute::initialize_barrier(shared_storage.tma_barrier[0], /* num_threads */ 1); + } + __syncthreads(); + + if (is_dma_warp) { + cutlass::arch::warpgroup_reg_dealloc<32>(); + Tensor mA = tma_load_a.get_tma_tensor(make_shape(M, packed_N)); + Tensor mB = tma_load_b.get_tma_tensor(make_shape(RhtTensorSize, RhtTensorSize)); + + Tensor gA_mk = local_tile(mA, mainloop_tiler, make_coord(_, _, _), Step<_1, X, _1>{}); + Tensor gB_nk = + local_tile(mB, cluster_tile, make_coord(_, _, _), Step{}); // (BLK_N,BLK_K,k) + + Tensor tCsA = make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), + sAlayout); // (MMA,MMA_M,MMA_N,PIPE) + Tensor tCsB = make_tensor(make_smem_ptr(shared_storage.tensors.smem_B.data()), + sBlayout); // (MMA,MMA_N,MMA_K,PIPE) + + int block_rank_in_cluster = cute::block_rank_in_cluster(); + ThrMMA thr_mma = mma.get_slice(block_rank_in_cluster); // blk idx + Tensor tCgA = thr_mma.partition_A(gA_mk); // (MMA,MMA_M,MMA_K,k) + Tensor tCgB = thr_mma.partition_B(gB_nk); // (MMA,MMA_N,MMA_K,k) + + Layout cta_layout_mnk = make_layout(cluster_shape); + Layout cta_layout_vmnk = + tiled_divide(cta_layout_mnk, make_tile(typename TiledMMA::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster); + + auto [tAgA, tAsA] = + tma_partition(tma_load_a, get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0, 3>(tCsA), group_modes<0, 3>(tCgA)); + + auto [tBgB, tBsB] = + tma_partition(tma_load_b, get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), + group_modes<0, 3>(tCsB), group_modes<0, 3>(tCgB)); + + uint16_t tma_mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t tma_mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + if constexpr (kEnableRHTColQuant) { + if (elect_one_sync()) { + cute::set_barrier_transaction_bytes(shared_storage.tma_barrier[0], + kTmaRhtTensorTransactionBytes); + copy(tma_load_b.with(shared_storage.tma_barrier[0], tma_mcast_mask_b), tBgB(_, 0, 0), + tBsB(_, 0)); + } + } + + do { + bool is_first_wave = scheduler.is_first_wave(); + uint32_t skip_wait = is_first_wave; + auto tAgA_mk = tAgA(_, scheduler.tile_m(), _); + int k_tile = 0; + + // Throttle CLC producer + clc_throttle_pipeline.producer_acquire(clc_pipe_throttle_producer_state); + clc_throttle_pipeline.producer_commit(clc_pipe_throttle_producer_state); + ++clc_pipe_throttle_producer_state; + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n()) { + int k_tile_idx_n = scheduler.tile_n_base() + k_tile; + ++k_tile; + skip_wait = (is_first_wave && k_tile < MainloopPipelineStageCount); + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state); + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType *tma_barrier = + mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state); + int write_stage = mainloop_pipe_producer_state.index(); + ++mainloop_pipe_producer_state; + if (cute::elect_one_sync()) { + copy(tma_load_a.with(*tma_barrier, tma_mcast_mask_a), tAgA_mk(_, k_tile_idx_n), + tAsA(_, write_stage)); + } + } + scheduler.fetch_next_work(clc_pipeline, clc_pipeline_consumer_state); + ++clc_pipeline_consumer_state; + scheduler.update_work_tile_info(); + } while (scheduler.is_valid()); + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } + + else if (is_mma_warp) { + cutlass::arch::warpgroup_reg_dealloc<32>(); + if constexpr (kEnableRHTColQuant) { + Tensor tCsA = make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), + sAlayout); // (MMA,MMA_M,MMA_N,PIPE) + Tensor tCsB = make_tensor(make_smem_ptr(shared_storage.tensors.smem_B.data()), + sBlayout); // (MMA,MMA_N,MMA_K,PIPE) + + int block_rank_in_cluster = cute::block_rank_in_cluster(); + ThrMMA thr_mma = mma.get_slice(block_rank_in_cluster); // blk idx + // Allocate "fragments" -- these are actually umma smem descriptors + Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_M,MMA_K,PIPE) + + mma.accumulate_ = UMMA::ScaleOut::Zero; + + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, + &shared_storage.tmem_base_ptr); + __syncwarp(); + tmem_allocation_result_barrier.arrive(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + bulk_tmem_mma.data() = tmem_base_ptr; + cute::wait_barrier(shared_storage.tma_barrier[0], 0 /*tma_phase_bit*/); + do { + uint32_t skip_wait = K_TILE_MAX <= 0; + + auto barrier_token = + mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + scheduler.fetch_next_work(clc_pipeline, clc_pipeline_consumer_state); + ++clc_pipeline_consumer_state; + CUTLASS_PRAGMA_NO_UNROLL + for (int k_tile = 0; + k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n();) { + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + int read_stage = mainloop_pipe_consumer_state.index(); + auto tCrA_mk = tCrA(_, _, _, read_stage); + auto tCrB_nk = tCrB(_, _, 0, 0); + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA) / EpilogueUnrollFactor; ++k_block) { + int accumulator_k_block = + accumulator_pipe_producer_state.index() * EpilogueUnrollFactor; + int tCrA_k_block = k_block * EpilogueUnrollFactor; + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < EpilogueUnrollFactor; i++) { + auto accumulators = bulk_tmem_mma(_, _, _, accumulator_k_block + i); + gemm(mma, tCrA_mk(_, _, tCrA_k_block + i), tCrB_nk, accumulators); + } + + accumulator_pipeline.producer_commit(accumulator_pipe_producer_state); + ++accumulator_pipe_producer_state; + } + auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; + ++mainloop_pipe_consumer_state; + ++k_tile; + skip_wait = k_tile >= K_TILE_MAX; + mainloop_pipeline.umma_consumer_release(curr_mainloop_pipe_consumer_state); + barrier_token = + mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + } + scheduler.update_work_tile_info(); + } while (scheduler.is_valid()); + tmem_allocator.release_allocation_lock(); + accumulator_pipeline.producer_tail(accumulator_pipe_producer_state); + tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + } else if (is_sched_warp) { + cutlass::arch::warpgroup_reg_dealloc<32>(); + do { + clc_throttle_pipeline.consumer_wait(clc_pipe_throttle_consumer_state); + clc_throttle_pipeline.consumer_release(clc_pipe_throttle_consumer_state); + ++clc_pipe_throttle_consumer_state; + clc_pipeline_producer_state = + scheduler.advance_to_next_work(clc_pipeline, clc_pipeline_producer_state); + scheduler.fetch_next_work(clc_pipeline, clc_pipeline_consumer_state); + ++clc_pipeline_consumer_state; + scheduler.update_work_tile_info(); + } while (scheduler.is_valid()); + } else if (is_epilogue_col_quant_warp) { + cutlass::arch::warpgroup_reg_alloc<192>(); + if constexpr (kEnableRHTColQuant) { + using TMEM_LOAD_NEW = cute::SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x; + + auto acc_epilogue_pipelined_shape = + append(acc_shape_epilogue, Int{}); + auto bulk_tmem_epilogue_layout = make_layout( + acc_epilogue_pipelined_shape, + make_stride(stride<0>(bulk_tmem_mma), Int<0>{}, Int<0>{}, size<1>(epilogue_tiler))); + auto bulk_tmem_epilogue = make_tensor(make_tmem_ptr(), bulk_tmem_epilogue_layout); + + // leveraging 256-bit writes to global memory + static int constexpr FragmentSize = 256 / sizeof_bits_v; + + tmem_allocation_result_barrier.arrive_and_wait(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + bulk_tmem_epilogue.data() = tmem_base_ptr; + int global_thread_idx = threadIdx.x; + int local_thread_idx = global_thread_idx % cutlass::NumThreadsPerWarpGroup; + // g2s load all global_d_amax + CUTLASS_PRAGMA_NO_UNROLL + for (int g = local_thread_idx; g < args.num_tensors; g += NumEpilogueColQuantThreadCount) { + shared_storage.global_d_amax[g] = + __ldg(reinterpret_cast(args.global_d_amax_list[g])); + } + + size_t rng_seed = 0; + size_t rng_offset = 0; + if constexpr (kEnableStochasticRounding) { + rng_seed = rng_state != nullptr ? __ldg(rng_state) : 0; + rng_offset = rng_state != nullptr ? __ldg(rng_state + 1) : 0; + } + int group_idx = GetGroupIdx(&args, scheduler.tile_n_base() * size<1>(epilogue_tiler)); + + TSFDLayout sfd_layout; + int cur_N = args.split_sections[group_idx]; + if constexpr (kEnableSwizzleSFOutput) { + sfd_layout = tile_to_shape(SwizzledSFDLayoutAtom{}, make_shape(M, cur_N), Step<_2, _1>{}); + } else { + sfd_layout = make_layout(make_shape(M, make_shape(Int{}, cur_N / SFVecSize)), + make_stride(cur_N / SFVecSize, make_stride(_0{}, _1{}))); + } + Tensor mD = make_tensor( + cute::subbyte_iterator(reinterpret_cast(args.output_colwise_list[group_idx])), + make_shape(M, cur_N), DStride{}); // (M,packed_N) + Tensor gD_mn = + local_tile(mD, epilogue_tiler, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N) + + Tensor mSFD = make_tensor(make_gmem_ptr(reinterpret_cast( + args.output_colwise_scale_inv_list[group_idx])), + sfd_layout); + Tensor gSFD_mn = local_tile(mSFD, epilogue_tiler, make_coord(_, _, _), + Step<_1, _1, X>{}); // (BLK_M,BLK_N) + + Tensor gD_mn_view = tiled_divide(gD_mn, take<0, 2>(epilogue_tiler)); + + auto tiled_t2r = make_tmem_copy(TMEM_LOAD_NEW{}, bulk_tmem_epilogue(_, _, _, _0{})); + auto tiled_r2g = + make_tiled_copy_D(Copy_Atom{}, tiled_t2r); + auto thr_t2r = tiled_t2r.get_slice(local_thread_idx); + auto thr_r2g = tiled_r2g.get_slice(local_thread_idx); + + cutlass::arch::NamedBarrier::sync(NumEpilogueColQuantThreadCount, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + // Aligning with TensorEngine's recipe to generate scale factors // {$nv-internal-release} + static constexpr float fp4_max = 6.0f; + static constexpr float fp8_max = 448.0f; + float const fp4_max_inv = 1.0f / fp4_max; + float c_global_amax_val = shared_storage.global_d_amax[group_idx]; + float global_encode_scale = c_global_amax_val > 0.0f + ? cutlass::minimum_with_nan_propagation{}( + (fp8_max * fp4_max) / c_global_amax_val, + cutlass::platform::numeric_limits::max()) + : 1.0f; + + float global_decode_scale = 1.0f / global_encode_scale; + float global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + auto sfc_converter = cutlass::NumericConverter{}; + + do { + scheduler.fetch_next_work(clc_pipeline, clc_pipeline_consumer_state); + ++clc_pipeline_consumer_state; + CUTLASS_PRAGMA_NO_UNROLL + for (int k_tile = 0; + k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n(); + ++k_tile) { + int global_tile_n_offset = (scheduler.tile_n_base() + k_tile) * size<1>(epilogue_tiler); + + int cur_group_idx = GetGroupIdx(&args, global_tile_n_offset); + if (cur_group_idx != group_idx) { + group_idx = cur_group_idx; + c_global_amax_val = shared_storage.global_d_amax[group_idx]; + // update amax + global_encode_scale = c_global_amax_val > 0.0f + ? cutlass::minimum_with_nan_propagation{}( + (fp8_max * fp4_max) / c_global_amax_val, + cutlass::platform::numeric_limits::max()) + : 1.0f; + global_decode_scale = 1.0f / global_encode_scale; + global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + cur_N = args.split_sections[group_idx]; + if constexpr (kEnableSwizzleSFOutput) { + sfd_layout = + tile_to_shape(SwizzledSFDLayoutAtom{}, make_shape(M, cur_N), Step<_2, _1>{}); + } else { + sfd_layout = + make_layout(make_shape(M, make_shape(Int{}, cur_N / SFVecSize)), + make_stride(cur_N / SFVecSize, make_stride(_0{}, _1{}))); + } + // update tensor + mD = make_tensor(cute::subbyte_iterator( + reinterpret_cast(args.output_colwise_list[group_idx])), + make_shape(M, cur_N), DStride{}); + gD_mn = local_tile(mD, epilogue_tiler, make_coord(_, _, _), + Step<_1, _1, X>{}); // (BLK_M,BLK_N) + mSFD = make_tensor(make_gmem_ptr(reinterpret_cast( + args.output_colwise_scale_inv_list[group_idx])), + sfd_layout); + gSFD_mn = local_tile(mSFD, epilogue_tiler, make_coord(_, _, _), + Step<_1, _1, X>{}); // (BLK_M,BLK_N) + + gD_mn_view = tiled_divide(gD_mn, take<0, 2>(epilogue_tiler)); + } + int group_start_offset = args.split_sections_range[group_idx]; + int local_tile_n_idx = + (global_tile_n_offset - group_start_offset) / size<1>(epilogue_tiler); + Tensor tDgD_mn = gD_mn_view(_, _, _, scheduler.tile_m(), local_tile_n_idx); + + Tensor tDgSFD_mn = gSFD_mn(_, _, scheduler.tile_m(), local_tile_n_idx); + accumulator_pipeline.consumer_wait(accumulator_pipe_consumer_state); + + auto Acc = bulk_tmem_epilogue(_, _, _, accumulator_pipe_consumer_state.index()); + Tensor tDtAcc = thr_t2r.partition_S(Acc); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + Tensor tDgD = thr_t2r.partition_D(tDgD_mn); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + + Tensor tTR_rAcc = + make_tensor(shape(tDgD)); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + Tensor tDrD = make_tensor(shape(tDgD)); + Tensor tTR_rAcc_frag = + recast>(coalesce(tTR_rAcc)); + Tensor tDrD_frag = recast>(coalesce(tDrD)); + + Tensor src = thr_r2g.retile_S(tDrD); + Tensor dst = thr_r2g.retile_D(tDgD); + + Tensor tDgSFD_view = make_tensor( + tDgSFD_mn.data(), make_layout(make_shape(shape(tDgSFD_mn), Int<1>{}, Int<1>{}), + make_stride(stride(tDgSFD_mn), Int<0>{}, Int<0>{}))); + Tensor tDgSFD = filter(thr_t2r.partition_D(tDgSFD_view)); + Tensor tDrSFD = make_tensor(shape(tDgSFD)); + + static int constexpr NumVecs = size(tDgD) / VectorSize; + Tensor tD_rRowSFD_frg = recast>(tDrSFD); + + cutlass::maximum_absolute_value_reduction, + true> + amax_reduction; + cutlass::Array vec_maxs; + cutlass::Array pvscales; + // TMEM_LOAD + copy(tiled_t2r, tDtAcc, tTR_rAcc); + cutlass::arch::fence_view_async_tmem_load(); + accumulator_pipeline.consumer_release(accumulator_pipe_consumer_state); + ++accumulator_pipe_consumer_state; + + auto compute_frgs = reinterpret_cast *>( + tTR_rAcc_frag.data()); + auto output_frgs = reinterpret_cast *>(tDrD_frag.data()); + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < NumVecs; v++) { + vec_maxs[v] = amax_reduction(ElementAccumulator(0), compute_frgs[v]); + } + + pvscales = cutlass::multiplies>{}( + vec_maxs, global_encode_scale_multiplier); + auto pvscales_cvted = + cutlass::NumericArrayConverter{}(pvscales); + + tD_rRowSFD_frg(_0{}) = pvscales_cvted; + auto qpvscale_ups = cutlass::NumericArrayConverter{}( + tD_rRowSFD_frg(_0{})); + auto qpvscale_scaled = cutlass::multiplies>{}( + qpvscale_ups, global_decode_scale); + cutlass::Array acc_scales; + if constexpr (kEnableFastMath) { + // fast math: use reciprocal approximate to replace div + acc_scales = + cutlass::reciprocal_approximate_ftz{}(qpvscale_scaled); + } else { + // regular path for slower math, use divide to replace div + acc_scales = cutlass::divides>{}( + 1.0, qpvscale_scaled); + } + + uint4 random_uint4 = uint4{0, 0, 0, 0}; + transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; + // "Prefetch" a stochastic rounding state for the first tile + if constexpr (kEnableStochasticRounding) { + const size_t rng_sequence = global_thread_idx + k_tile * 512 + + scheduler.get_linear_tile_idx() * K_TILE_MAX * 512; + rng.init(rng_seed, rng_sequence, rng_offset); + } + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < NumVecs; v++) { + auto acc_scale = cutlass::minimum_with_nan_propagation{}( + acc_scales[v], cutlass::platform::numeric_limits::max()); + if constexpr (kEnableStochasticRounding) { + random_uint4 = rng.generate4(); + output_frgs[v] = StochasticNumericConverter( + cutlass::multiplies>{}( + compute_frgs[v], acc_scale), + *reinterpret_cast *>(&random_uint4)); + } else { + output_frgs[v] = cutlass::NumericArrayConverter{}( + cutlass::multiplies>{}( + compute_frgs[v], acc_scale)); + } + } + + copy(tiled_r2g, src, dst); + copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tDrSFD, tDgSFD); + } + scheduler.update_work_tile_info(); + } while (scheduler.is_valid()); + } + } else if (is_epilogue_row_quant_warp) { + cutlass::arch::warpgroup_reg_alloc<136>(); + if constexpr (kEnableRowQuant) { + using S2RVectorType = uint128_t; + + int global_thread_idx = threadIdx.x; + int local_thread_idx = global_thread_idx % 256; + size_t rng_seed = 0; + size_t rng_offset = 0; + // g2s load all global_d_amax + CUTLASS_PRAGMA_NO_UNROLL + for (int g = local_thread_idx; g < args.num_tensors; g += NumEpilogueRowQuantThreadCount) { + shared_storage.global_a_amax[g] = + __ldg(reinterpret_cast(args.global_a_amax_list[g])); + } + if constexpr (kEnableStochasticRounding) { + rng_seed = rng_state != nullptr ? __ldg(rng_state) : 0; + rng_offset = rng_state != nullptr ? __ldg(rng_state + 1) : 0; + } + Tensor mQA = + make_tensor(cute::subbyte_iterator(QA), make_layout(make_shape(M, packed_N), dQA)); + Tensor gQA_mn = local_tile(mQA, epilogue_tiler, make_coord(_, _, _), Step<_1, X, _1>{}); + Tensor mSFA = make_tensor(make_gmem_ptr(SFA), sfa_layout); + + Tensor gSFA_mn = local_tile(mSFA, epilogue_tiler, make_coord(_, _, _), + Step<_1, X, _1>{}); // (BLK_M,BLK_N) + Tensor sA = as_position_independent_swizzle_tensor(group_modes<0, 2>( + coalesce(make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), + sAlayout)))); // (BLOCK_M, BLOCK_M,PIPE) + + using S2RWarpLayout = Layout>; + using WarpGroupLayout = Layout>; + using S2RThreadLayout = decltype(blocked_product(S2RWarpLayout{}, WarpGroupLayout{})); + using S2RValLayout = Layout, _1>>; + using S2RAtomA = Copy_Atom; + using R2GAtomQA = Copy_Atom; + using R2GAtomSFA = Copy_Atom; + auto tiled_s2r = make_tiled_copy(S2RAtomA{}, S2RThreadLayout{}, S2RValLayout{}); + auto tiled_r2g_QA = make_tiled_copy(R2GAtomQA{}, S2RThreadLayout{}, S2RValLayout{}); + auto tiled_r2g_SFA = make_tiled_copy(R2GAtomSFA{}, S2RThreadLayout{}, S2RValLayout{}); + + auto thr_s2r = tiled_s2r.get_slice(local_thread_idx); + auto thr_r2g_QA = tiled_r2g_QA.get_slice(local_thread_idx); + auto thr_r2g_SFA = tiled_r2g_SFA.get_slice(local_thread_idx); + Tensor tQAsA = thr_s2r.partition_S(sA); // (Copy, Copy_M, Copy_N, PIPE) + + Tensor tQArA = make_tensor_like( + make_layout(tQAsA(_, _, _, _0{}).shape())); // (Copy, Copy_M, Copy_N) + Tensor tQAgQA = thr_r2g_QA.partition_S(gQA_mn); + Tensor tQArQA = make_tensor_like(tQAgQA(_, _, _, _0{}, _0{})); + + Tensor tQAgSFA = thr_r2g_SFA.partition_S(gSFA_mn); + Tensor tQArSFA = make_tensor_like(tQAgSFA(_, _, _, _0{}, _0{})); + + int row_quant_barrier_id = 10; + cutlass::arch::NamedBarrier::sync(NumEpilogueRowQuantThreadCount, row_quant_barrier_id); + + int group_idx = GetGroupIdx(&args, scheduler.tile_n_base() * size<1>(epilogue_tiler)); + float a_global_amax_val = shared_storage.global_a_amax[group_idx]; + // Aligning with TensorEngine's recipe to generate scale factors // {$nv-internal-release} + static constexpr float fp4_max = 6.0f; + static constexpr float fp8_max = 448.0f; + float const fp4_max_inv = 1.0f / fp4_max; + float global_encode_scale = a_global_amax_val > 0.0f + ? cutlass::minimum_with_nan_propagation{}( + (fp8_max * fp4_max) / a_global_amax_val, + cutlass::platform::numeric_limits::max()) + : 1.0f; + + float global_decode_scale = 1.0f / global_encode_scale; + float global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + auto sfa_converter = cutlass::NumericConverter{}; + do { + CUTLASS_PRAGMA_NO_UNROLL + for (int k_tile = 0; + k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n();) { + int global_tile_n_offset = (scheduler.tile_n_base() + k_tile) * size<1>(epilogue_tiler); + + int cur_group_idx = GetGroupIdx(&args, global_tile_n_offset); + if (cur_group_idx != group_idx) { + group_idx = cur_group_idx; + a_global_amax_val = shared_storage.global_a_amax[group_idx]; + // update amax + global_encode_scale = a_global_amax_val > 0.0f + ? cutlass::minimum_with_nan_propagation{}( + (fp8_max * fp4_max) / a_global_amax_val, + cutlass::platform::numeric_limits::max()) + : 1.0f; + global_decode_scale = 1.0f / global_encode_scale; + global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + } + + auto tQAgSFA_mn = tQAgSFA(_, _, _, scheduler.tile_m(), scheduler.tile_n_base() + k_tile); + auto tQAgQA_mn = tQAgQA(_, _, _, scheduler.tile_m(), scheduler.tile_n_base() + k_tile); + auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state); + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + copy(tiled_s2r, tQAsA(_, _, _, mainloop_pipe_consumer_state.index()), tQArA); + cutlass::arch::fence_view_async_shared(); + mainloop_pipeline.consumer_release(mainloop_pipe_consumer_state); + ++mainloop_pipe_consumer_state; + ++k_tile; + + // static int constexpr NumVecs = size(tQArA) / VectorSize; + cutlass::maximum_absolute_value_reduction, + true> + amax_reduction; + auto compute_frgs = reinterpret_cast *>(tQArA.data()); + auto output_frgs = + reinterpret_cast *>(raw_pointer_cast(tQArQA.data())); + Tensor amax = + make_tensor(prepend(take<1, rank(tQArA)>(tQArA.shape()), _1{})); + Tensor pvscales = make_tensor_like(amax); + transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; + if constexpr (kEnableStochasticRounding) { + const size_t rng_sequence = global_thread_idx + k_tile * 512 + + scheduler.get_linear_tile_idx() * K_TILE_MAX * 512; + rng.init(rng_seed, rng_sequence, rng_offset); + } + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < size<1>(group_modes<1, rank(tQArA)>(tQArA)); v++) { + auto amax_view = group_modes<1, rank(amax)>(amax); + auto pvscales_view = group_modes<1, rank(pvscales)>(pvscales); + auto compute_frgs_up = + cutlass::NumericArrayConverter{}( + compute_frgs[v]); + amax_view(_0{}, v) = amax_reduction(ElementAccumulator(0), compute_frgs_up); + pvscales_view(_0{}, v) = cutlass::multiplies{}( + amax_view(_0{}, v), global_encode_scale_multiplier); + filter(tQArSFA)(v) = sfa_converter(pvscales_view(_0{}, v)); + auto qpvscale_ups = + cutlass::NumericConverter{}(filter(tQArSFA)(v)); + auto qpvscale_scaled = + cutlass::multiplies{}(qpvscale_ups, global_decode_scale); + ElementAccumulator acc_scales; + if constexpr (kEnableFastMath) { + // fast math: use reciprocal approximate to replace div + acc_scales = + cutlass::reciprocal_approximate_ftz{}(qpvscale_scaled); + } else { + // regular path for slower math, use divide to replace div + acc_scales = cutlass::divides{}(1.0, qpvscale_scaled); + } + auto acc_scale = cutlass::minimum_with_nan_propagation{}( + acc_scales, cutlass::platform::numeric_limits::max()); + uint4 random_uint4 = uint4{0, 0, 0, 0}; + if constexpr (kEnableStochasticRounding) { + random_uint4 = rng.generate4(); + output_frgs[v] = StochasticNumericConverter( + cutlass::multiplies>{}( + compute_frgs_up, acc_scale), + *reinterpret_cast *>(&random_uint4)); + } else { + output_frgs[v] = + cutlass::NumericArrayConverter{}( + cutlass::multiplies>{}( + compute_frgs_up, acc_scale)); + } + } + copy(tiled_r2g_QA, tQArQA, tQAgQA_mn); + copy(tiled_r2g_SFA, filter(tQArSFA), filter(tQAgSFA_mn)); + } + scheduler.fetch_next_work(clc_pipeline, clc_pipeline_consumer_state); + ++clc_pipeline_consumer_state; + scheduler.update_work_tile_info(); + } while (scheduler.is_valid()); + } + } else { + cutlass::arch::warpgroup_reg_dealloc<32>(); + } +} + +template +void group_row_col_rht_gemm_ntt_w_sfc(int packed_sequence_length, int hidden_size, TA const *A, + TB const *B, TQA *QA, TSFA *SFA, + MultiAmaxHadamardCastFusionArgs &args, + const size_t *rng_state, uint32_t sm_count, + cudaStream_t stream, + int k_tile_size = 1024) { + using namespace cute; + static int constexpr SFVecSize = 16; + static int constexpr RhtTensorSize = 16; + + static_assert(RhtTensorSize == 16, "RhtTensorSize must be 16"); + using LinearSFALayout = decltype(make_layout(make_shape(make_shape(Int{}, 0), 0), + make_stride(make_stride(_0{}, _1{}), 0))); + using LinearSFDLayout = decltype(make_layout(make_shape(0, make_shape(Int{}, 0)), + make_stride(0, make_stride(_0{}, _1{})))); + + using SwizzledSFALayoutAtom = + cutlass::detail::Sm1xxBlockScaledOutputConfig::SfAtom; + using SwizzledSFDLayoutAtom = + cutlass::detail::Sm1xxBlockScaledOutputConfig::SfAtom; + using SwizzledSFALayout = decltype(tile_to_shape( + SwizzledSFALayoutAtom{}, make_shape(hidden_size, packed_sequence_length), Step<_1, _2>{})); + using SwizzledSFDLayout = decltype(tile_to_shape( + SwizzledSFDLayoutAtom{}, make_shape(hidden_size, packed_sequence_length), Step<_2, _1>{})); + + using SFALayout = cute::conditional_t; + using SFDLayout = cute::conditional_t; + SFALayout sfa_layout; + SFDLayout sfd_layout; + + if constexpr (kEnableSwizzleSFOutput) { + sfa_layout = tile_to_shape(SwizzledSFALayoutAtom{}, + make_shape(hidden_size, packed_sequence_length), Step<_1, _2>{}); + sfd_layout = tile_to_shape(SwizzledSFDLayoutAtom{}, + make_shape(hidden_size, packed_sequence_length), Step<_2, _1>{}); + } else { + sfa_layout = make_layout( + make_shape(make_shape(Int{}, hidden_size / SFVecSize), packed_sequence_length), + make_stride(make_stride(_0{}, _1{}), hidden_size / SFVecSize)); + sfd_layout = make_layout( + make_shape(hidden_size, make_shape(Int{}, packed_sequence_length / SFVecSize)), + make_stride(packed_sequence_length / SFVecSize, make_stride(_0{}, _1{}))); + } + + // Define shapes (dynamic) + auto M = hidden_size; + auto N = packed_sequence_length; + Tensor tensorA = make_tensor(A, make_shape(hidden_size, packed_sequence_length), LayoutLeft{}); + Tensor tensorB = make_tensor(B, make_shape(RhtTensorSize, RhtTensorSize), LayoutLeft{}); + Tensor tensorQA = make_tensor(QA, make_shape(hidden_size, packed_sequence_length), LayoutLeft{}); + Tensor tensorSFA = make_tensor(SFA, sfa_layout); + + // Define strides (from tensors) + auto dA = stride(tensorA); // (dM,dK) + auto dB = stride(tensorB); // (dN,dK) + auto dD = LayoutRight{}; // (dM,dN) + auto dQA = stride(tensorQA); // (dM,dK) + using ClusterShape = Shape<_1, _1, _1>; + auto cga_shape = ClusterShape{}; + auto cga_tile_shape = Shape<_128, Int, Int>{}; + auto cluster_tile_mainloop = Shape<_128, Int, _128>{}; + + // Each mainloop / epilogue loads 128 x 64 tiles while each MMA proceeds with 128 x 16 tiles + static int constexpr EpilogueUnrollFactor = + size<2>(cluster_tile_mainloop) / size<2>(cga_tile_shape); + // Construct the MMA + auto mma = make_tiled_mma( + SM100_MMA_F16BF16_SS(cga_tile_shape), size<1>(cga_tile_shape), + UMMA::Major::MN, UMMA::Major::MN>{}, + Layout>{}); + + // Assert that the TiledMMA uses all CTAs in the CGA. + CUTE_STATIC_ASSERT_V(size(cga_shape) == size(mma)); + CUTE_STATIC_ASSERT_V(evenly_divides(cga_tile_shape, tile_shape(mma))); + + // Determine the A and B shapes + auto mma_shape_B = + partition_shape_B(mma, make_shape(size<1>(cga_tile_shape), size<2>(cga_tile_shape))); + + using TiledMma = decltype(mma); + using AtomThrID = typename TiledMma::AtomThrID; + + using SmemShape_M = decltype(shape_div( + shape<0>(cga_tile_shape), + shape_div(shape<0>(cga_tile_shape), size<0>(cga_tile_shape) / size(AtomThrID{})))); + using SmemShape_N = decltype(shape_div( + shape<1>(cga_tile_shape), + shape_div(shape<1>(cga_tile_shape), size<1>(cga_tile_shape) / size(AtomThrID{})))); + using SmemShape_K = decltype(cute::get<2>(cga_tile_shape)); + + using SmemLayoutAtomB = + decltype(cutlass::gemm::collective::detail::sm100_smem_selector()); + + auto mma_shape_A = partition_shape_A( + mma, make_shape(size<0>(cluster_tile_mainloop), size<2>(cluster_tile_mainloop))); + using SmemShape_M_A = + decltype(shape_div(shape<0>(cluster_tile_mainloop), + shape_div(shape<0>(cluster_tile_mainloop), + size<0>(cluster_tile_mainloop) / size(AtomThrID{})))); + using SmemShape_K_A = decltype(cute::get<2>(cluster_tile_mainloop)); + using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + cute::UMMA::Major::MN, TA, SmemShape_M_A, SmemShape_K_A>()); + + static uint32_t constexpr TotalTmemRows = 128; + static uint32_t constexpr Sm100TmemCapacityColumns = 512; + static uint32_t constexpr TotalTmem = TotalTmemRows * Sm100TmemCapacityColumns; + static uint32_t constexpr AccumulatorPipelineStageCount = + TotalTmem / (cute::size<0>(cga_tile_shape) * cute::size<1>(cga_tile_shape)); + + // Define the smem layouts (static) + // Calculate max pipeline stages based on Blackwell SM100's 232KB shared memory + constexpr int SchedulerPipelineStageCount = 6; + static int constexpr MainloopPipelineBytes = sizeof( + typename cutlass::detail::CustomizedPipelineTmaUmmaAsync<1, Shape<_1, _1, _1>, + Shape<_1, _1, _1>>::SharedStorage); + + static int constexpr ClcResponseBytes = sizeof(CLCResponse) * SchedulerPipelineStageCount; + static int constexpr CLCThrottlePipelineBytes = + sizeof(typename cutlass::PipelineAsync::SharedStorage); + static int constexpr CLCPipelineBytes = + sizeof(typename cutlass::PipelineCLCFetchAsync::SharedStorage); + static int constexpr TmemDeallocBytes = sizeof(cutlass::arch::ClusterBarrier); + static int constexpr BTensorBytes = cute::size(mma_shape_B) * sizeof(TB); + static int constexpr AccPipelineBytes = sizeof( + typename cutlass::PipelineUmmaAsync>::SharedStorage); + static int constexpr TmemBasePtrsBytes = sizeof(uint32_t); + static int constexpr kBlackwellSmemSize = 232448; // 232KB in bytes + static int constexpr kBytesPerStage = + cute::size(mma_shape_A) * sizeof(TA) + MainloopPipelineBytes; + static int constexpr kReservedBytes = + ClcResponseBytes + CLCThrottlePipelineBytes + TmemBasePtrsBytes + CLCPipelineBytes + + TmemDeallocBytes + BTensorBytes + AccPipelineBytes; // Reserve for barriers and other uses + static int constexpr kMaxStages = (kBlackwellSmemSize - kReservedBytes) / kBytesPerStage; + auto sP = Int{}; // SMEM pipelines + // printf("\nmax stages: %d\n", int(kMaxStages)); + // printf("\nreserved bytes: %d\n", int(kReservedBytes)); + // printf("\nbytes per stage: %d\n", int(kBytesPerStage)); + // printf("\nremaining bytes: %d\n", int((kBlackwellSmemSize - kReservedBytes) % kBytesPerStage)); + auto sA = UMMA::tile_to_mma_shape(SmemLayoutAtomA{}, append(mma_shape_A, sP), + Step<_2, _1, _3>{}); // (MMA,MMA_M,MMA_K,PIPE) + auto sB = UMMA::tile_to_mma_shape(SmemLayoutAtomB{}, + append(mma_shape_B, _1{})); // (MMA,MMA_N,MMA_K, _1) + auto sD = Layout<_1>{}; // XXX Dummy + + auto tma_load_a = + make_tma_copy_A_sm100(SM90_TMA_LOAD{}, tensorA, sA(_, _, _, 0), cluster_tile_mainloop, mma); + auto tma_load_b = + make_tma_copy_B_sm100(SM90_TMA_LOAD{}, tensorB, sB(_, _, _, 0), cga_tile_shape, mma); + + // Assert checks on tile sizes -- no predication + assert(M % size<0>(cga_tile_shape) == 0); + assert(N % size<1>(cga_tile_shape) == 0); + + uint32_t tiles_in_m = uint32_t(size(ceil_div(M, size<0>(cga_tile_shape)))); + uint32_t tiles_in_n = uint32_t(size(ceil_div(N, k_tile_size))); + uint32_t tiles = tiles_in_m * tiles_in_n; + + dim3 dimBlock(512); + dim3 dimCluster(size<0>(cga_shape), size<1>(cga_shape), size<2>(cga_shape)); + dim3 dimGrid(tiles_in_m, tiles_in_n, 1); + + int smem_size = sizeof( + SharedStorage); + + auto *kernel_ptr = &group_row_col_rht_gemm_device< + decltype(M), decltype(N), decltype(k_tile_size), decltype(cga_shape), + decltype(cga_tile_shape), TA, decltype(dA), decltype(sA), decltype(tma_load_a), TB, + decltype(dB), decltype(sB), decltype(tma_load_b), TD, decltype(dD), decltype(sD), TSFD, + decltype(sfd_layout), TQA, decltype(dQA), TSFA, decltype(sfa_layout), decltype(mma), + AccumulatorPipelineStageCount, SchedulerPipelineStageCount, kEnableStochasticRounding, + kEnableRHTColQuant, kEnableRowQuant, kEnableSwizzleSFOutput, kEnableFastMath>; + + bool status_set_attr = cudaFuncSetAttribute(*kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + if (status_set_attr != cudaSuccess) { + std::cerr << "Error: Failed to set Shared Memory size." << std::endl; + return; + } + + cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smem_size}; + cutlass::Status status = cutlass::launch_kernel_on_cluster( + params, (void const*)kernel_ptr, + M, N, k_tile_size, cga_shape, cga_tile_shape, A, dA, sA, tma_load_a, B, dB, sB, tma_load_b, + QA, dQA, SFA, sfa_layout, args, mma, rng_state); + CUTE_CHECK_LAST(); + + if (status != cutlass::Status::kSuccess) { + std::cerr << "Error: Failed at kernel Launch" << std::endl; + return; + } +} + +} // namespace +} // namespace detail + +void group_hadamard_transform_cast_fusion(const Tensor &input_, std::vector &output_list, + const size_t *split_sections, size_t num_tensors, + const Tensor &hadamard_matrix_, + QuantizationConfig &quant_config, cudaStream_t stream) { + NVTE_API_CALL(group_hadamard_transform_cast_fusion); + + using transformer_engine::detail::kMaxTensorsPerKernel; + using transformer_engine::detail::MultiAmaxHadamardCastFusionArgs; + + NVTE_CHECK(input_.dtype() == transformer_engine::DType::kBFloat16, + "Input tensor must be BF16 tensor, but dtype is ", to_string(input_.dtype()), "."); + NVTE_CHECK(input_.dim() >= 2, "Input must be a 2D tensor."); + const SimpleTensor &input = input_.data; + + NVTE_CHECK(output_list.size() == num_tensors, + "Number of output tensors should match number of tensors."); + + NVTE_CHECK(num_tensors <= kMaxTensorsPerKernel, + "Number of tensors should be less than or equal to ", kMaxTensorsPerKernel); + + // construct the multi-tensor args + MultiAmaxHadamardCastFusionArgs kernel_args; + kernel_args.num_tensors = 0; + kernel_args.split_sections_range[0] = 0; + bool all_has_row_quant = true; + bool all_has_col_quant = true; + void *rowwise_data_base_ptr = nullptr; + void *rowwise_scale_inv_base_ptr = nullptr; + for (size_t i = 0; i < num_tensors; ++i) { + NVTE_CHECK(split_sections[i] % 128 == 0, "component ", i, + " of split_sections should be 128 multiple"); + if (split_sections[i] == 0) { + continue; + } + bool has_row_quant = output_list[i]->data.dptr != nullptr; + bool has_col_quant = output_list[i]->columnwise_data.dptr != nullptr; + all_has_row_quant &= has_row_quant; + all_has_col_quant &= has_col_quant; + // sanity check, the two bool flags cannot be both false + NVTE_CHECK(has_row_quant || has_col_quant, + "At least one of the output tensors must have row or column quant."); + void *amax_rowwise_ptr = + has_row_quant ? reinterpret_cast(output_list[i]->amax.dptr) : nullptr; + void *amax_colwise_ptr = + has_col_quant ? reinterpret_cast(output_list[i]->columnwise_amax.dptr) : nullptr; + void *rowwise_data_ptr = + has_row_quant ? reinterpret_cast(output_list[i]->data.dptr) : nullptr; + void *rowwise_scale_inv_ptr = + has_row_quant ? reinterpret_cast(output_list[i]->scale_inv.dptr) : nullptr; + if (all_has_row_quant && + (rowwise_data_base_ptr == nullptr || rowwise_scale_inv_base_ptr == nullptr)) { + rowwise_data_base_ptr = rowwise_data_ptr; + rowwise_scale_inv_base_ptr = rowwise_scale_inv_ptr; + } + void *output_colwise_ptr = + has_col_quant ? reinterpret_cast(output_list[i]->columnwise_data.dptr) : nullptr; + void *output_colwise_scale_inv_ptr = + has_col_quant ? reinterpret_cast(output_list[i]->columnwise_scale_inv.dptr) : nullptr; + kernel_args.global_a_amax_list[kernel_args.num_tensors] = amax_rowwise_ptr; + kernel_args.global_d_amax_list[kernel_args.num_tensors] = amax_colwise_ptr; + kernel_args.output_colwise_list[kernel_args.num_tensors] = output_colwise_ptr; + kernel_args.output_colwise_scale_inv_list[kernel_args.num_tensors] = + output_colwise_scale_inv_ptr; + kernel_args.split_sections[kernel_args.num_tensors] = split_sections[i]; + kernel_args.split_sections_range[kernel_args.num_tensors + 1] = + kernel_args.split_sections_range[kernel_args.num_tensors] + split_sections[i]; + kernel_args.num_tensors++; + } + + // Stochastic rounding config + const bool use_stochastic_rounding = quant_config.stochastic_rounding; + const size_t *rng_state = nullptr; + if (quant_config.rng_state != nullptr) { + Tensor &rng_state_tensor = *convertNVTETensor(quant_config.rng_state); + NVTE_CHECK(rng_state_tensor.dtype() == DType::kInt64, + "RNG state should contain 2 64-bit values."); + NVTE_CHECK(rng_state_tensor.data.shape == std::vector{2}, + "Shape of the RNG state should be [2], but got ", rng_state_tensor.data.shape); + rng_state = reinterpret_cast(rng_state_tensor.data.dptr); + } + + // Template arguments + using TA = cute::bfloat16_t; + using TB = cute::bfloat16_t; + using TD = cutlass::float_e2m1_t; + using TSFD = cutlass::float_ue4m3_t; + using TQA = TD; + using TSFA = TSFD; + + checkCuDriverContext(stream); + + // Check Hadamard matrix + constexpr int kHadamardDimension = 16; + + NVTE_CHECK(hadamard_matrix_.dtype() == transformer_engine::DType::kBFloat16, + "Hadamard matrix must be BF16 tensor, but dtype is ", + to_string(hadamard_matrix_.dtype()), "."); + const SimpleTensor &hadamard_matrix = hadamard_matrix_.data; + NVTE_CHECK( + (hadamard_matrix_.shape() == std::vector{kHadamardDimension, kHadamardDimension}), + "Hadamard matrix must have shape=", + std::vector{kHadamardDimension, kHadamardDimension}, + ", but got shape=", hadamard_matrix_.shape(), "."); + const size_t hadamard_dimension = hadamard_matrix.shape[0]; + + const size_t ndim = input.shape.size(); + const size_t n = input.shape[ndim - 1]; + size_t m = 1; + for (size_t i = 0; i < ndim - 1; ++i) { + m *= input.shape[i]; + } + + auto sm_count = transformer_engine::cuda::sm_count(); + + NVTE_CHECK(n % hadamard_dimension == 0, "row_length must be divisible by hadamard_dimension."); + + NVTE_CHECK(m % hadamard_dimension == 0, "num_rows must be divisible by hadamard_dimension"); + + int k_tile_size = 1024; + + // TODO: haven't decided whether to expose this as a API option or not + // use fast math if there is a ENV var NVTE_RHT_CAST_FUSION_USE_FAST_MATH, default to false + static const bool use_fast_math = + transformer_engine::getenv("NVTE_RHT_CAST_FUSION_USE_FAST_MATH", false); + + const bool use_swizzle_sf_output = false; + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_stochastic_rounding, kEnableStochasticRounding, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + all_has_col_quant, kEnableRhtColQuant, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + all_has_row_quant, kEnableRowQuant, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_swizzle_sf_output, kEnableSwizzleSFOutput, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_fast_math, kEnableFastMath, + + detail::group_row_col_rht_gemm_ntt_w_sfc< + kEnableStochasticRounding, kEnableRhtColQuant, kEnableRowQuant, + kEnableSwizzleSFOutput, TA, TB, TQA, TSFA, TD, TSFD, kEnableFastMath>( + /*packed_sequence_length=*/m, /*hidden_size=*/n, + /*A=*/reinterpret_cast(input.dptr), + /*B=*/reinterpret_cast(hadamard_matrix.dptr), + /*QA=*/reinterpret_cast(rowwise_data_base_ptr), + /*SFA=*/reinterpret_cast(rowwise_scale_inv_base_ptr), + /*args=*/kernel_args, + /*rng_state=*/rng_state, /*sm_count=*/sm_count, + /*stream=*/stream, /*k_tile_size=*/k_tile_size); + + ););););); +} + +} // namespace transformer_engine + +void nvte_group_hadamard_transform_cast_fusion(const NVTETensor input, NVTETensor *outputs, + const NVTETensor hadamard_matrix, + const size_t *split_sections, + const size_t num_tensors, + const NVTEQuantizationConfig quant_config, + cudaStream_t stream) { + NVTE_API_CALL(nvte_group_hadamard_transform_cast_fusion); + using namespace transformer_engine; + NVTE_CHECK(num_tensors > 0, "Number of tensors should be greater than 0."); + + Tensor *input_tensor = convertNVTETensorCheck(input); + std::vector output_list(num_tensors); + for (size_t i = 0; i < num_tensors; ++i) { + output_list[i] = convertNVTETensorCheck(outputs[i]); + } + + QuantizationConfig quant_config_cpp; + if (quant_config != nullptr) { + quant_config_cpp = *reinterpret_cast(quant_config); + } + + // Call the multi-tensor Hadamard transform amax implementation. + group_hadamard_transform_cast_fusion(*input_tensor, output_list, split_sections, num_tensors, + *convertNVTETensorCheck(hadamard_matrix), quant_config_cpp, + stream); +} diff --git a/transformer_engine/common/include/transformer_engine/hadamard_transform.h b/transformer_engine/common/include/transformer_engine/hadamard_transform.h index e7cece03ae..112cb9b54d 100644 --- a/transformer_engine/common/include/transformer_engine/hadamard_transform.h +++ b/transformer_engine/common/include/transformer_engine/hadamard_transform.h @@ -104,6 +104,25 @@ void nvte_group_hadamard_transform_cast_fusion_columnwise( const size_t* split_sections, size_t num_tensors, const NVTEQuantizationConfig quant_config, cudaStream_t stream); +/*! + * \brief Perform the grouped-tensor row quantize (without Hadamard) and columnwise Hadamard transform cast fusion operation. + * + * This function is experimental and the API is not stable. Group_ prefix means contiguous input concatenated + * + * \param[in] input Input tensor to apply Hadamard transform. + * \param[in,out] outputs Array of output tensors. + * \param[in] hadamard_matrix Hadamard matrix to use for transformation. + * \param[in] split_sections Array specifying splits in dimension 0 for each output tensor. + * \param[in] num_tensors Number of output tensors, must be > 0. + * \param[in] quant_config Quantization configuration. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_hadamard_transform_cast_fusion(const NVTETensor input, NVTETensor* outputs, + const NVTETensor hadamard_matrix, + const size_t* split_sections, size_t num_tensors, + const NVTEQuantizationConfig quant_config, + cudaStream_t stream); + #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index ee236a38b7..879c5c7c45 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -709,6 +709,174 @@ std::tuple, std::vector, bool> bulk_alloc return retval; } +// Implements split-quantize NVFP4 with Row/Column-wise Hadamard Transform (RHT) +void split_quantize_nvfp4_impl_with_rht_helper( + const TensorWrapper &input, + const std::vector &input_list, + std::vector &output_list, + const std::vector &split_sections, + const std::vector &quantizers, + const std::vector &quant_config_list, + cudaStream_t stream +) { + const size_t num_tensors = split_sections.size(); + const auto &quantizer = *quantizers.front(); + + std::vector nvte_tensor_input_list; + std::vector nvte_tensor_output_list; + for (size_t i = 0; i < num_tensors; ++i) { + nvte_tensor_input_list.push_back(input_list[i].data()); + nvte_tensor_output_list.push_back(output_list[i].data()); + } + + // Compute amaxes + if (quantizer.with_post_rht_amax) { + // We need: + // 1. Rowwise amax = amax for input + // 2. Columnwise amax = amax for RHT(input.t) + nvte_group_hadamard_transform_amax( + input.data(), reinterpret_cast(nvte_tensor_output_list.data()), + split_sections.data(), num_tensors, 0, quantizer.rht_matrix_random_sign_mask_t, stream); + } else { + // RHT is enabled, but amax is pre-RHT amax + NVTE_ERROR("NVFP4 split-quantize does not yet support pre-RHT amax"); + } + + // Check that RHT matrix is available + NVTE_CHECK(quantizer.rht_matrix.defined() && quantizer.rht_matrix.numel() > 0, + "RHT matrix is not available."); + auto rht_matrix_nvte = makeTransformerEngineTensor(quantizer.rht_matrix); + + // trigger the row-col fusion when the split-sections shapes are all 128 aligned for max performance + bool all_aligned_token_dim = std::all_of(split_sections.begin(), split_sections.end(), [](size_t split_section) { + return split_section % 128 == 0; + }); + + if (all_aligned_token_dim){ + // call the fully-fused grouped kernel for rowwise quantization & colwise RHT quantization transpose + nvte_group_hadamard_transform_cast_fusion( + input.data(), reinterpret_cast(nvte_tensor_output_list.data()), + rht_matrix_nvte.data(), split_sections.data(), num_tensors, quant_config_list[0], stream); + } else { + // Separate quantization for rowwise usage and columnwise usage + // Rowwise quantization fusion with grouped version + if (quantizer.rowwise_usage) { + std::vector out_identity_list; + std::vector nvte_tensor_out_identity_list; + for (size_t i = 0; i < num_tensors; i++) { + bool is_empty_split = input_list[i].numel() == 0; + TensorWrapper out_identity(output_list[i].scaling_mode()); + auto out_identity_data = output_list[i].get_rowwise_data(); + auto out_identity_scale_inv = output_list[i].get_rowwise_scale_inv(); + auto out_identity_amax = output_list[i].get_amax(); + if (!is_empty_split) { + out_identity.set_rowwise_data(out_identity_data.data_ptr, + static_cast(out_identity_data.dtype), + out_identity_data.shape); + out_identity.set_rowwise_scale_inv(out_identity_scale_inv.data_ptr, + static_cast(out_identity_scale_inv.dtype), + out_identity_scale_inv.shape); + out_identity.set_amax(out_identity_amax.data_ptr, + static_cast(out_identity_amax.dtype), + out_identity_amax.shape); + } + out_identity_list.emplace_back(std::move(out_identity)); + nvte_tensor_out_identity_list.push_back(out_identity_list.back().data()); + } + nvte_group_nvfp4_quantize_with_amax( + input.data(), nvte_tensor_out_identity_list.data(), + split_sections.data(), num_tensors, quant_config_list[0], stream); + } + + // Columnwise RHT quantization fusion with grouped version + if (quantizer.columnwise_usage) { + std::vector out_transpose_list; + std::vector nvte_tensor_out_transpose_list; + for (size_t i = 0; i < num_tensors; i++) { + bool is_empty_split = input_list[i].numel() == 0; + auto out_columnwise_data = output_list[i].get_columnwise_data(); + auto out_columnwise_scale_inv = output_list[i].get_columnwise_scale_inv(); + auto out_columnwise_amax = output_list[i].get_columnwise_amax(); + + // Create a wrapper for the columnwise output, as the rowwise output. Input is in transposed layout. + TensorWrapper out_transpose(output_list[i].scaling_mode()); + if (!is_empty_split) { + auto colwise_data_shape = out_columnwise_data.shape; + std::vector colwise_data_shape_2d; + colwise_data_shape_2d.push_back(colwise_data_shape.data[0]); + size_t last_dim = 1; + for (size_t j = 1; j < colwise_data_shape.ndim; ++j) { + last_dim *= colwise_data_shape.data[j]; + } + colwise_data_shape_2d.push_back(last_dim); + + out_transpose.set_rowwise_data(out_columnwise_data.data_ptr, + static_cast(out_columnwise_data.dtype), + colwise_data_shape_2d); + out_transpose.set_rowwise_scale_inv(out_columnwise_scale_inv.data_ptr, + static_cast(out_columnwise_scale_inv.dtype), + out_columnwise_scale_inv.shape); + out_transpose.set_amax(out_columnwise_amax.data_ptr, + static_cast(out_columnwise_amax.dtype), + out_columnwise_amax.shape); + } + out_transpose_list.emplace_back(std::move(out_transpose)); + nvte_tensor_out_transpose_list.push_back(out_transpose_list.back().data()); + } + nvte_group_hadamard_transform_cast_fusion_columnwise( + input.data(), reinterpret_cast(nvte_tensor_out_transpose_list.data()), + rht_matrix_nvte.data(), split_sections.data(), num_tensors, quant_config_list[0], stream); + } + } +} + +void split_quantize_nvfp4_impl_helper( + const TensorWrapper &input, + const std::vector &input_list, + std::vector &output_list, + const std::vector &split_sections, + const std::vector &quantizers, + const std::vector &quant_config_list, + cudaStream_t stream +) { + const size_t num_tensors = input_list.size(); + std::vector nvte_tensor_input_list; + std::vector nvte_tensor_output_list; + for (size_t i = 0; i < num_tensors; ++i) { + nvte_tensor_input_list.push_back(input_list[i].data()); + nvte_tensor_output_list.push_back(output_list[i].data()); + } + // We need: + // 1. Rowwise amax = amax for input + // 2. Columnwise amax = amax for input too + // Columnwise amax will be filled with a fused D2D copy from rowwise amax + // Note that the multi compute amax API expects rowwise amax pointer to be not null + // So we need to set the pointer accordingly to make colwise-only quantization work + std::vector orig_amax_ptr_list; + for (size_t i = 0; i < num_tensors; i++) { + auto rowwise_amax_ptr = output_list[i].get_amax().data_ptr; + orig_amax_ptr_list.push_back(rowwise_amax_ptr); + auto columnwise_amax_ptr = output_list[i].get_columnwise_amax().data_ptr; + void *amax_ptr = rowwise_amax_ptr != nullptr ? rowwise_amax_ptr : columnwise_amax_ptr; + NVTE_CHECK(amax_ptr != nullptr, "Could not find amax pointer"); + output_list[i].set_amax(amax_ptr, DType::kFloat32, std::vector{1}); + } + nvte_group_amax(input.data(), reinterpret_cast(nvte_tensor_output_list.data()), + split_sections.data(), num_tensors, stream); + for (size_t i = 0; i < num_tensors; i++) { + output_list[i].set_amax(orig_amax_ptr_list[i], DType::kFloat32, std::vector{1}); + } + + // Quantize tensors individually + for (size_t i = 0; i < num_tensors; i++) { + // skip this round if input is empty + if (input_list[i].numel() == 0) { + continue; + } + nvte_quantize_v2(input_list[i].data(), output_list[i].data(), quant_config_list[i], stream); + } +} + void split_quantize_nvfp4_impl(const TensorWrapper &input, const std::vector &input_list, std::vector &output_list, @@ -758,12 +926,8 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input, auto stream = at::cuda::getCurrentCUDAStream(); // Objects for TE C API - std::vector nvte_tensor_input_list; - std::vector nvte_tensor_output_list; std::vector quant_config_list; for (size_t i = 0; i < num_tensors; ++i) { - nvte_tensor_input_list.push_back(input_list[i].data()); - nvte_tensor_output_list.push_back(output_list[i].data()); quant_config_list.emplace_back(QuantizationConfigWrapper()); } @@ -813,146 +977,17 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input, } // Perform multi-tensor quantization - if (quantizer.with_rht) { // Quantize row-wise data, RHT+quantize column-wise data - // Check that config is supported - NVTE_CHECK(input.dtype() == DType::kBFloat16, "RHT is only supported for bfloat16 input"); - - // Compute amaxes - if (quantizer.with_post_rht_amax) { - // We need: - // 1. Rowwise amax = amax for input - // 2. Columnwise amax = amax for RHT(input.t) - NVTE_SCOPED_GIL_RELEASE({ - nvte_group_hadamard_transform_amax( - input.data(), reinterpret_cast(nvte_tensor_output_list.data()), - split_sections.data(), num_tensors, 0, quantizer.rht_matrix_random_sign_mask_t, stream); - }); - } else { - // RHT is enabled, but amax is pre-RHT amax - NVTE_ERROR("NVFP4 split-quantize does not yet support pre-RHT amax"); - } - - // Check that RHT matrix is available - NVTE_CHECK(quantizer.rht_matrix.defined() && quantizer.rht_matrix.numel() > 0, - "RHT matrix is not available."); - auto rht_matrix_nvte = makeTransformerEngineTensor(quantizer.rht_matrix); - - // Quantize tensors individually - NVTE_SCOPED_GIL_RELEASE({ + NVTE_SCOPED_GIL_RELEASE({ + if (quantizer.with_rht) { // Quantize row-wise data, RHT+quantize column-wise data + // Check that config is supported + NVTE_CHECK(input.dtype() == DType::kBFloat16, "RHT is only supported for bfloat16 input"); // Fuse the rowwise and colwise into one when the kernel is ready - // rowwise quantization fusion with grouped version - if (quantizer.rowwise_usage) { - std::vector out_identity_list; - std::vector nvte_tensor_out_identity_list; - for (size_t i = 0; i < num_tensors; i++) { - // skip this round if input is empty - bool is_empty_split = input_list[i].numel() == 0; - TensorWrapper out_identity(output_list[i].scaling_mode()); - auto out_identity_data = output_list[i].get_rowwise_data(); - auto out_identity_scale_inv = output_list[i].get_rowwise_scale_inv(); - auto out_identity_amax = output_list[i].get_amax(); - if (!is_empty_split) { - out_identity.set_rowwise_data(out_identity_data.data_ptr, - static_cast(out_identity_data.dtype), - out_identity_data.shape); - out_identity.set_rowwise_scale_inv(out_identity_scale_inv.data_ptr, - static_cast(out_identity_scale_inv.dtype), - out_identity_scale_inv.shape); - out_identity.set_amax(out_identity_amax.data_ptr, - static_cast(out_identity_amax.dtype), - out_identity_amax.shape); - } - out_identity_list.emplace_back(std::move(out_identity)); - nvte_tensor_out_identity_list.push_back(out_identity_list.back().data()); - } - nvte_group_nvfp4_quantize_with_amax(input.data(), nvte_tensor_out_identity_list.data(), - split_sections.data(), num_tensors, - quant_config_list[0], stream); - } - // columnwise RHT quantization fusion with grouped version - if (quantizer.columnwise_usage) { - // setup the output list for the grouped kernel - std::vector out_transpose_list; - std::vector nvte_tensor_out_transpose_list; - // TODO(zhongbo): can we make this less verbose? - for (size_t i = 0; i < num_tensors; i++) { - // group kernel expects the output list to have the same length with split_sections - // so we still need to pass a place holder tensor for empty splits - bool is_empty_split = input_list[i].numel() == 0; - auto out_columnwise_data = output_list[i].get_columnwise_data(); - auto out_columnwise_scale_inv = output_list[i].get_columnwise_scale_inv(); - auto out_columnwise_amax = output_list[i].get_columnwise_amax(); - - // Create a wrapper for the columnwise output, as the rowwise output. - // The reason is due to the input `rht_output_t` is already in the transposed layout. - // Thus, we only need a rowwise quantization to generate the columnwise output. - TensorWrapper out_transpose(output_list[i].scaling_mode()); - if (!is_empty_split) { - auto colwise_data_shape = out_columnwise_data.shape; - std::vector colwise_data_shape_2d; - colwise_data_shape_2d.push_back(colwise_data_shape.data[0]); - size_t last_dim = 1; - for (size_t i = 1; i < colwise_data_shape.ndim; ++i) { - last_dim *= colwise_data_shape.data[i]; - } - colwise_data_shape_2d.push_back(last_dim); - - out_transpose.set_rowwise_data(out_columnwise_data.data_ptr, - static_cast(out_columnwise_data.dtype), - colwise_data_shape_2d); - out_transpose.set_rowwise_scale_inv(out_columnwise_scale_inv.data_ptr, - static_cast(out_columnwise_scale_inv.dtype), - out_columnwise_scale_inv.shape); - out_transpose.set_amax(out_columnwise_amax.data_ptr, - static_cast(out_columnwise_amax.dtype), - out_columnwise_amax.shape); - } - out_transpose_list.emplace_back(std::move(out_transpose)); - nvte_tensor_out_transpose_list.push_back(out_transpose_list.back().data()); - } - // call the grouped kernel - nvte_group_hadamard_transform_cast_fusion_columnwise( - input.data(), reinterpret_cast(nvte_tensor_out_transpose_list.data()), - rht_matrix_nvte.data(), split_sections.data(), num_tensors, quant_config_list[0], - stream); - } - }); - - } else { // NVFP4 quantize - // We need: - // 1. Rowwise amax = amax for input - // 2. Columnwise amax = amax for input too - // Columnwise amax will be filled with a fused D2D copy from rowwise amax - // Note that the multi compute amax API expects rowwise amax pointer to be not null - // So we need to set the pointer accordingly to make colwise-only quantization work - std::vector orig_amax_ptr_list; - for (size_t i = 0; i < num_tensors; i++) { - auto rowwise_amax_ptr = output_list[i].get_amax().data_ptr; - orig_amax_ptr_list.push_back(rowwise_amax_ptr); - auto columnwise_amax_ptr = output_list[i].get_columnwise_amax().data_ptr; - void *amax_ptr = rowwise_amax_ptr != nullptr ? rowwise_amax_ptr : columnwise_amax_ptr; - NVTE_CHECK(amax_ptr != nullptr, "Could not find amax pointer"); - output_list[i].set_amax(amax_ptr, DType::kFloat32, std::vector{1}); - } - NVTE_SCOPED_GIL_RELEASE({ - nvte_group_amax(input.data(), reinterpret_cast(nvte_tensor_output_list.data()), - split_sections.data(), num_tensors, stream); - }); - for (size_t i = 0; i < num_tensors; i++) { - output_list[i].set_amax(orig_amax_ptr_list[i], DType::kFloat32, std::vector{1}); + split_quantize_nvfp4_impl_with_rht_helper(input, input_list, output_list, split_sections, quantizers, quant_config_list, stream); + } else { // NVFP4 quantize + // Fuse the rowwise and colwise into one when the kernel is ready + split_quantize_nvfp4_impl_helper(input, input_list, output_list, split_sections, quantizers, quant_config_list, stream); } - - // Quantize tensors individually - NVTE_SCOPED_GIL_RELEASE({ - for (size_t i = 0; i < num_tensors; i++) { - // skip this round if input is empty - if (input_list[i].numel() == 0) { - continue; - } - nvte_quantize_v2(input_list[i].data(), output_list[i].data(), quant_config_list[i], stream); - } - }); - } + }); } } // namespace From 70523c835660ec8667f3ac0c3d0ef00d28f9b932 Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Fri, 12 Dec 2025 15:29:27 -0800 Subject: [PATCH 14/36] numerics aligned Signed-off-by: Zhongbo Zhu --- ...cast_col_hadamard_transform_cast_fusion.cu | 41 ++++++++++++++----- 1 file changed, 31 insertions(+), 10 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu index 9b438aae1d..8941fcdf11 100644 --- a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu @@ -724,7 +724,10 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( : 1.0f; float global_decode_scale = 1.0f / global_encode_scale; - float global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + float global_encode_scale_multiplier = 1.0f; + if constexpr (kEnableFastMath) { + global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + } auto sfc_converter = cutlass::NumericConverter{}; do { @@ -747,7 +750,9 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( cutlass::platform::numeric_limits::max()) : 1.0f; global_decode_scale = 1.0f / global_encode_scale; - global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + if constexpr (kEnableFastMath) { + global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + } cur_N = args.split_sections[group_idx]; if constexpr (kEnableSwizzleSFOutput) { sfd_layout = @@ -821,10 +826,14 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( vec_maxs[v] = amax_reduction(ElementAccumulator(0), compute_frgs[v]); } - pvscales = cutlass::multiplies>{}( - vec_maxs, global_encode_scale_multiplier); - auto pvscales_cvted = - cutlass::NumericArrayConverter{}(pvscales); + if constexpr (kEnableFastMath) { + pvscales = cutlass::multiplies>{}( + vec_maxs, global_encode_scale_multiplier); + } else { + pvscales = cutlass::divides>{}(vec_maxs, fp4_max); + pvscales = cutlass::multiplies>{}(pvscales, global_encode_scale); + } + auto pvscales_cvted = cutlass::NumericArrayConverter{}(pvscales); tD_rRowSFD_frg(_0{}) = pvscales_cvted; auto qpvscale_ups = cutlass::NumericArrayConverter{}( @@ -943,7 +952,10 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( : 1.0f; float global_decode_scale = 1.0f / global_encode_scale; - float global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + float global_encode_scale_multiplier = 1.0f; + if constexpr (kEnableFastMath) { + global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + } auto sfa_converter = cutlass::NumericConverter{}; do { CUTLASS_PRAGMA_NO_UNROLL @@ -962,7 +974,9 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( cutlass::platform::numeric_limits::max()) : 1.0f; global_decode_scale = 1.0f / global_encode_scale; - global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + if constexpr (kEnableFastMath) { + global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + } } auto tQAgSFA_mn = tQAgSFA(_, _, _, scheduler.tile_m(), scheduler.tile_n_base() + k_tile); @@ -999,8 +1013,15 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( cutlass::NumericArrayConverter{}( compute_frgs[v]); amax_view(_0{}, v) = amax_reduction(ElementAccumulator(0), compute_frgs_up); - pvscales_view(_0{}, v) = cutlass::multiplies{}( - amax_view(_0{}, v), global_encode_scale_multiplier); + if constexpr (kEnableFastMath) { + pvscales_view(_0{}, v) = cutlass::multiplies{}( + amax_view(_0{}, v), global_encode_scale_multiplier); + } else { + pvscales_view(_0{}, v) = cutlass::divides{}( + amax_view(_0{}, v), fp4_max); + pvscales_view(_0{}, v) = cutlass::multiplies{}( + pvscales_view(_0{}, v), global_encode_scale); + } filter(tQArSFA)(v) = sfa_converter(pvscales_view(_0{}, v)); auto qpvscale_ups = cutlass::NumericConverter{}(filter(tQArSFA)(v)); From 0388466e9a9b229221be4687616303cf37c4b7f4 Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Fri, 12 Dec 2025 15:31:48 -0800 Subject: [PATCH 15/36] style Signed-off-by: Zhongbo Zhu --- benchmarks/linear/benchmark_grouped_linear.py | 2 +- .../nvfp4/test_nvfp4_group_quantize.py | 1 + ...cast_col_hadamard_transform_cast_fusion.cu | 54 +++++++------ .../pytorch/csrc/extensions/cast.cpp | 76 +++++++++---------- 4 files changed, 66 insertions(+), 67 deletions(-) diff --git a/benchmarks/linear/benchmark_grouped_linear.py b/benchmarks/linear/benchmark_grouped_linear.py index 3560858d6d..f559928f8c 100644 --- a/benchmarks/linear/benchmark_grouped_linear.py +++ b/benchmarks/linear/benchmark_grouped_linear.py @@ -238,7 +238,7 @@ def run_benchmark_linear( parser = argparse.ArgumentParser() parser.add_argument("--profile", action="store_true", help="Enable profiling mode") parser.add_argument( - "--output_dir", + "--output-dir", type=str, default="benchmark_output/", help="output path for report", diff --git a/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py b/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py index c785e46548..a29dcb4279 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py +++ b/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py @@ -247,6 +247,7 @@ def check_group_quantization_nvfp4_versus_reference( (1024, 256), # larger sizes (8192, 1024), + (16384, 8192), (16384, 16384), ], ) diff --git a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu index 8941fcdf11..025a53f378 100644 --- a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu @@ -9,9 +9,10 @@ #include #include #include -#include #include #include + +#include #include #include #include @@ -21,24 +22,23 @@ #include "common/util/curanddx.hpp" #include "common/util/ptx.cuh" #include "common/utils.cuh" - +#include "customized_pipeline.hpp" #include "cutlass/arch/barrier.h" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/cluster_launch.hpp" #include "cutlass/cutlass.h" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" +#include "cutlass/fast_math.h" +#include "cutlass/float8.h" +#include "cutlass/float_subbyte.h" #include "cutlass/gemm/collective/builders/sm100_common.inl" -#include "cutlass/pipeline/pipeline.hpp" #include "cutlass/numeric_conversion.h" -#include "cutlass/float_subbyte.h" #include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" #include "cutlass/platform/platform.h" -#include "cutlass/arch/reg_reconfig.h" -#include "cutlass/fast_math.h" -#include "cutlass/float8.h" -#include "cutlass/cluster_launch.hpp" -#include "cutlass/detail/sm100_blockscaled_layout.hpp" #include "cutlass/util/GPU_Clock.hpp" #include "cutlass/util/command_line.h" #include "cutlass/util/print_error.hpp" -#include "customized_pipeline.hpp" // include utils for get system env #include "../util/system.h" @@ -48,7 +48,8 @@ namespace detail { namespace { using namespace cute; -using cute::Tensor; // Ensure unqualified Tensor refers to cute::Tensor, not transformer_engine::Tensor +using cute:: + Tensor; // Ensure unqualified Tensor refers to cute::Tensor, not transformer_engine::Tensor struct CLCResponse { uint32_t data[4] = {0}; @@ -727,7 +728,7 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( float global_encode_scale_multiplier = 1.0f; if constexpr (kEnableFastMath) { global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; - } + } auto sfc_converter = cutlass::NumericConverter{}; do { @@ -830,10 +831,13 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( pvscales = cutlass::multiplies>{}( vec_maxs, global_encode_scale_multiplier); } else { - pvscales = cutlass::divides>{}(vec_maxs, fp4_max); - pvscales = cutlass::multiplies>{}(pvscales, global_encode_scale); + pvscales = + cutlass::divides>{}(vec_maxs, fp4_max); + pvscales = cutlass::multiplies>{}( + pvscales, global_encode_scale); } - auto pvscales_cvted = cutlass::NumericArrayConverter{}(pvscales); + auto pvscales_cvted = + cutlass::NumericArrayConverter{}(pvscales); tD_rRowSFD_frg(_0{}) = pvscales_cvted; auto qpvscale_ups = cutlass::NumericArrayConverter{}( @@ -955,7 +959,7 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( float global_encode_scale_multiplier = 1.0f; if constexpr (kEnableFastMath) { global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; - } + } auto sfa_converter = cutlass::NumericConverter{}; do { CUTLASS_PRAGMA_NO_UNROLL @@ -1017,8 +1021,8 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( pvscales_view(_0{}, v) = cutlass::multiplies{}( amax_view(_0{}, v), global_encode_scale_multiplier); } else { - pvscales_view(_0{}, v) = cutlass::divides{}( - amax_view(_0{}, v), fp4_max); + pvscales_view(_0{}, v) = + cutlass::divides{}(amax_view(_0{}, v), fp4_max); pvscales_view(_0{}, v) = cutlass::multiplies{}( pvscales_view(_0{}, v), global_encode_scale); } @@ -1072,8 +1076,7 @@ void group_row_col_rht_gemm_ntt_w_sfc(int packed_sequence_length, int hidden_siz TB const *B, TQA *QA, TSFA *SFA, MultiAmaxHadamardCastFusionArgs &args, const size_t *rng_state, uint32_t sm_count, - cudaStream_t stream, - int k_tile_size = 1024) { + cudaStream_t stream, int k_tile_size = 1024) { using namespace cute; static int constexpr SFVecSize = 16; static int constexpr RhtTensorSize = 16; @@ -1244,7 +1247,8 @@ void group_row_col_rht_gemm_ntt_w_sfc(int packed_sequence_length, int hidden_siz AccumulatorPipelineStageCount, SchedulerPipelineStageCount, kEnableStochasticRounding, kEnableRHTColQuant, kEnableRowQuant, kEnableSwizzleSFOutput, kEnableFastMath>; - bool status_set_attr = cudaFuncSetAttribute(*kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + bool status_set_attr = + cudaFuncSetAttribute(*kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); if (status_set_attr != cudaSuccess) { std::cerr << "Error: Failed to set Shared Memory size." << std::endl; @@ -1253,9 +1257,8 @@ void group_row_col_rht_gemm_ntt_w_sfc(int packed_sequence_length, int hidden_siz cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smem_size}; cutlass::Status status = cutlass::launch_kernel_on_cluster( - params, (void const*)kernel_ptr, - M, N, k_tile_size, cga_shape, cga_tile_shape, A, dA, sA, tma_load_a, B, dB, sB, tma_load_b, - QA, dQA, SFA, sfa_layout, args, mma, rng_state); + params, (void const *)kernel_ptr, M, N, k_tile_size, cga_shape, cga_tile_shape, A, dA, sA, + tma_load_a, B, dB, sB, tma_load_b, QA, dQA, SFA, sfa_layout, args, mma, rng_state); CUTE_CHECK_LAST(); if (status != cutlass::Status::kSuccess) { @@ -1324,7 +1327,8 @@ void group_hadamard_transform_cast_fusion(const Tensor &input_, std::vector(output_list[i]->columnwise_data.dptr) : nullptr; void *output_colwise_scale_inv_ptr = - has_col_quant ? reinterpret_cast(output_list[i]->columnwise_scale_inv.dptr) : nullptr; + has_col_quant ? reinterpret_cast(output_list[i]->columnwise_scale_inv.dptr) + : nullptr; kernel_args.global_a_amax_list[kernel_args.num_tensors] = amax_rowwise_ptr; kernel_args.global_d_amax_list[kernel_args.num_tensors] = amax_colwise_ptr; kernel_args.output_colwise_list[kernel_args.num_tensors] = output_colwise_ptr; diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 879c5c7c45..82d2107535 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -711,14 +711,10 @@ std::tuple, std::vector, bool> bulk_alloc // Implements split-quantize NVFP4 with Row/Column-wise Hadamard Transform (RHT) void split_quantize_nvfp4_impl_with_rht_helper( - const TensorWrapper &input, - const std::vector &input_list, - std::vector &output_list, - const std::vector &split_sections, + const TensorWrapper &input, const std::vector &input_list, + std::vector &output_list, const std::vector &split_sections, const std::vector &quantizers, - const std::vector &quant_config_list, - cudaStream_t stream -) { + const std::vector &quant_config_list, cudaStream_t stream) { const size_t num_tensors = split_sections.size(); const auto &quantizer = *quantizers.front(); @@ -744,19 +740,19 @@ void split_quantize_nvfp4_impl_with_rht_helper( // Check that RHT matrix is available NVTE_CHECK(quantizer.rht_matrix.defined() && quantizer.rht_matrix.numel() > 0, - "RHT matrix is not available."); + "RHT matrix is not available."); auto rht_matrix_nvte = makeTransformerEngineTensor(quantizer.rht_matrix); // trigger the row-col fusion when the split-sections shapes are all 128 aligned for max performance - bool all_aligned_token_dim = std::all_of(split_sections.begin(), split_sections.end(), [](size_t split_section) { - return split_section % 128 == 0; - }); + bool all_aligned_token_dim = + std::all_of(split_sections.begin(), split_sections.end(), + [](size_t split_section) { return split_section % 128 == 0; }); - if (all_aligned_token_dim){ - // call the fully-fused grouped kernel for rowwise quantization & colwise RHT quantization transpose + if (all_aligned_token_dim) { + // call the fully-fused grouped kernel for rowwise quantization & colwise RHT quantization transpose nvte_group_hadamard_transform_cast_fusion( - input.data(), reinterpret_cast(nvte_tensor_output_list.data()), - rht_matrix_nvte.data(), split_sections.data(), num_tensors, quant_config_list[0], stream); + input.data(), reinterpret_cast(nvte_tensor_output_list.data()), + rht_matrix_nvte.data(), split_sections.data(), num_tensors, quant_config_list[0], stream); } else { // Separate quantization for rowwise usage and columnwise usage // Rowwise quantization fusion with grouped version @@ -774,8 +770,8 @@ void split_quantize_nvfp4_impl_with_rht_helper( static_cast(out_identity_data.dtype), out_identity_data.shape); out_identity.set_rowwise_scale_inv(out_identity_scale_inv.data_ptr, - static_cast(out_identity_scale_inv.dtype), - out_identity_scale_inv.shape); + static_cast(out_identity_scale_inv.dtype), + out_identity_scale_inv.shape); out_identity.set_amax(out_identity_amax.data_ptr, static_cast(out_identity_amax.dtype), out_identity_amax.shape); @@ -783,9 +779,9 @@ void split_quantize_nvfp4_impl_with_rht_helper( out_identity_list.emplace_back(std::move(out_identity)); nvte_tensor_out_identity_list.push_back(out_identity_list.back().data()); } - nvte_group_nvfp4_quantize_with_amax( - input.data(), nvte_tensor_out_identity_list.data(), - split_sections.data(), num_tensors, quant_config_list[0], stream); + nvte_group_nvfp4_quantize_with_amax(input.data(), nvte_tensor_out_identity_list.data(), + split_sections.data(), num_tensors, quant_config_list[0], + stream); } // Columnwise RHT quantization fusion with grouped version @@ -811,14 +807,14 @@ void split_quantize_nvfp4_impl_with_rht_helper( colwise_data_shape_2d.push_back(last_dim); out_transpose.set_rowwise_data(out_columnwise_data.data_ptr, - static_cast(out_columnwise_data.dtype), - colwise_data_shape_2d); + static_cast(out_columnwise_data.dtype), + colwise_data_shape_2d); out_transpose.set_rowwise_scale_inv(out_columnwise_scale_inv.data_ptr, static_cast(out_columnwise_scale_inv.dtype), out_columnwise_scale_inv.shape); out_transpose.set_amax(out_columnwise_amax.data_ptr, - static_cast(out_columnwise_amax.dtype), - out_columnwise_amax.shape); + static_cast(out_columnwise_amax.dtype), + out_columnwise_amax.shape); } out_transpose_list.emplace_back(std::move(out_transpose)); nvte_tensor_out_transpose_list.push_back(out_transpose_list.back().data()); @@ -831,14 +827,10 @@ void split_quantize_nvfp4_impl_with_rht_helper( } void split_quantize_nvfp4_impl_helper( - const TensorWrapper &input, - const std::vector &input_list, - std::vector &output_list, - const std::vector &split_sections, + const TensorWrapper &input, const std::vector &input_list, + std::vector &output_list, const std::vector &split_sections, const std::vector &quantizers, - const std::vector &quant_config_list, - cudaStream_t stream -) { + const std::vector &quant_config_list, cudaStream_t stream) { const size_t num_tensors = input_list.size(); std::vector nvte_tensor_input_list; std::vector nvte_tensor_output_list; @@ -854,24 +846,24 @@ void split_quantize_nvfp4_impl_helper( // So we need to set the pointer accordingly to make colwise-only quantization work std::vector orig_amax_ptr_list; for (size_t i = 0; i < num_tensors; i++) { - auto rowwise_amax_ptr = output_list[i].get_amax().data_ptr; - orig_amax_ptr_list.push_back(rowwise_amax_ptr); - auto columnwise_amax_ptr = output_list[i].get_columnwise_amax().data_ptr; - void *amax_ptr = rowwise_amax_ptr != nullptr ? rowwise_amax_ptr : columnwise_amax_ptr; - NVTE_CHECK(amax_ptr != nullptr, "Could not find amax pointer"); - output_list[i].set_amax(amax_ptr, DType::kFloat32, std::vector{1}); + auto rowwise_amax_ptr = output_list[i].get_amax().data_ptr; + orig_amax_ptr_list.push_back(rowwise_amax_ptr); + auto columnwise_amax_ptr = output_list[i].get_columnwise_amax().data_ptr; + void *amax_ptr = rowwise_amax_ptr != nullptr ? rowwise_amax_ptr : columnwise_amax_ptr; + NVTE_CHECK(amax_ptr != nullptr, "Could not find amax pointer"); + output_list[i].set_amax(amax_ptr, DType::kFloat32, std::vector{1}); } nvte_group_amax(input.data(), reinterpret_cast(nvte_tensor_output_list.data()), split_sections.data(), num_tensors, stream); for (size_t i = 0; i < num_tensors; i++) { - output_list[i].set_amax(orig_amax_ptr_list[i], DType::kFloat32, std::vector{1}); + output_list[i].set_amax(orig_amax_ptr_list[i], DType::kFloat32, std::vector{1}); } // Quantize tensors individually for (size_t i = 0; i < num_tensors; i++) { // skip this round if input is empty if (input_list[i].numel() == 0) { - continue; + continue; } nvte_quantize_v2(input_list[i].data(), output_list[i].data(), quant_config_list[i], stream); } @@ -982,10 +974,12 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input, // Check that config is supported NVTE_CHECK(input.dtype() == DType::kBFloat16, "RHT is only supported for bfloat16 input"); // Fuse the rowwise and colwise into one when the kernel is ready - split_quantize_nvfp4_impl_with_rht_helper(input, input_list, output_list, split_sections, quantizers, quant_config_list, stream); + split_quantize_nvfp4_impl_with_rht_helper(input, input_list, output_list, split_sections, + quantizers, quant_config_list, stream); } else { // NVFP4 quantize // Fuse the rowwise and colwise into one when the kernel is ready - split_quantize_nvfp4_impl_helper(input, input_list, output_list, split_sections, quantizers, quant_config_list, stream); + split_quantize_nvfp4_impl_helper(input, input_list, output_list, split_sections, quantizers, + quant_config_list, stream); } }); } From 27f10478e1ad1952b9d70be78b32c01d36aad951 Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Fri, 12 Dec 2025 16:18:17 -0800 Subject: [PATCH 16/36] remove device sync Signed-off-by: Zhongbo Zhu --- .../group_row_cast_col_hadamard_transform_cast_fusion.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu index 025a53f378..aa24243544 100644 --- a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu @@ -1259,7 +1259,7 @@ void group_row_col_rht_gemm_ntt_w_sfc(int packed_sequence_length, int hidden_siz cutlass::Status status = cutlass::launch_kernel_on_cluster( params, (void const *)kernel_ptr, M, N, k_tile_size, cga_shape, cga_tile_shape, A, dA, sA, tma_load_a, B, dB, sB, tma_load_b, QA, dQA, SFA, sfa_layout, args, mma, rng_state); - CUTE_CHECK_LAST(); + // CUTE_CHECK_LAST(); if (status != cutlass::Status::kSuccess) { std::cerr << "Error: Failed at kernel Launch" << std::endl; From 380a1162e68d572f7a903399a14cf4f8adcc908a Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Fri, 12 Dec 2025 16:37:33 -0800 Subject: [PATCH 17/36] 128 padding Signed-off-by: Zhongbo Zhu --- transformer_engine/pytorch/quantization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index fbe2ee6d1c..3f5995230c 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -120,7 +120,7 @@ def get_align_size_for_quantization(recipe: Recipe) -> int: if recipe.mxfp8(): return 32 if recipe.nvfp4(): - return 64 + return 128 return 16 From f61979a8f8aaf754b412a79861c7ac174bb16419 Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Mon, 15 Dec 2025 17:28:23 -0800 Subject: [PATCH 18/36] revert colwise rng state creation because of row-col fused kernel Signed-off-by: Zhongbo Zhu --- .../pytorch/csrc/extensions/cast.cpp | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 82d2107535..b6e25ffd50 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -784,6 +784,11 @@ void split_quantize_nvfp4_impl_with_rht_helper( stream); } + // TODO(zhongbo): move the rng states if we enable stochastic rounding + // so that rowwise and colwise will have different random numbers + // this is not needed for all_aligned_token_dim path, because row & col are both fused + // into one kernel, the kernel itself generates two different random numbers for row & col + // Columnwise RHT quantization fusion with grouped version if (quantizer.columnwise_usage) { std::vector out_transpose_list; @@ -924,15 +929,8 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input, } // Stochastic rounding - // When both rowwise and columnwise quantization are used, - // we need separate RNG states for each to ensure they use different random numbers. std::vector te_rng_state_list; - std::vector te_rng_state_columnwise_list; - std::vector quant_config_columnwise_list; at::Tensor rng_states_tensor; - at::Tensor rng_states_columnwise_tensor; - const bool need_separate_columnwise_rng = - quantizer.stochastic_rounding && quantizer.with_rht && quantizer.columnwise_usage; if (quantizer.stochastic_rounding) { // TODO(zhongbo): remove the for loop of generating rng states with a single call @@ -943,13 +941,6 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input, auto opts = at::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA); rng_states_tensor = torch::empty({static_cast(2 * num_tensors)}, opts); - // Allocate columnwise RNG resources when separate RNG is needed - if (need_separate_columnwise_rng) { - rng_states_columnwise_tensor = torch::empty({static_cast(2 * num_tensors)}, opts); - for (size_t i = 0; i < num_tensors; ++i) { - quant_config_columnwise_list.emplace_back(QuantizationConfigWrapper()); - } - } for (size_t i = 0; i < num_tensors; ++i) { auto gen = at::get_generator_or_default( std::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); From 6f38c78473211ccb77828931ed754cb10d2df4c7 Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Mon, 15 Dec 2025 20:26:33 -0800 Subject: [PATCH 19/36] fix CI, linter Signed-off-by: Zhongbo Zhu --- transformer_engine/common/CMakeLists.txt | 1 - .../customized_pipeline.hpp | 2 +- ...cast_col_hadamard_transform_cast_fusion.cu | 106 ++++++++++++++---- .../custom_recipes/quantization_nvfp4.py | 1 - 4 files changed, 86 insertions(+), 24 deletions(-) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index decc86dc75..79948e28f7 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -353,7 +353,6 @@ endforeach() set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3") -set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -lineinfo") # Number of parallel build jobs if($ENV{MAX_JOBS}) diff --git a/transformer_engine/common/hadamard_transform/customized_pipeline.hpp b/transformer_engine/common/hadamard_transform/customized_pipeline.hpp index 85c76c4080..967927f8c0 100644 --- a/transformer_engine/common/hadamard_transform/customized_pipeline.hpp +++ b/transformer_engine/common/hadamard_transform/customized_pipeline.hpp @@ -232,7 +232,7 @@ class CustomizedPipelineTmaUmmaAsync { } } else { if (!skip) { - if constexpr (cute::is_static_v and size(ClusterShape{}) == 1) { + if constexpr (cute::is_static_v && size(ClusterShape{}) == 1) { cutlass::arch::umma_arrive(smem_ptr); } else { cutlass::arch::umma_arrive_multicast(smem_ptr, block_id_mask_); diff --git a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu index aa24243544..88526f2a82 100644 --- a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu @@ -43,6 +43,8 @@ // include utils for get system env #include "../util/system.h" +// clang-format off + namespace transformer_engine { namespace detail { namespace { @@ -96,14 +98,20 @@ cutlass::Array StochasticNumericConverterBase( using result_type = cutlass::Array; result_type output; auto output_ptr = reinterpret_cast(&output); - asm volatile( - "{\n" - "cvt.rs.satfinite.e2m1x4.f32 %0, {%5, %4, %3, %2}, %10;\n" - "cvt.rs.satfinite.e2m1x4.f32 %1, {%9, %8, %7, %6}, %11;\n" - "}" - : "=h"(output_ptr[0]), "=h"(output_ptr[1]) - : "f"(input[0]), "f"(input[1]), "f"(input[2]), "f"(input[3]), "f"(input[4]), "f"(input[5]), - "f"(input[6]), "f"(input[7]), "r"(rbits[0]), "r"(rbits[1])); + constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING; + if constexpr (has_rs) { + asm volatile( + "{\n" + "cvt.rs.satfinite.e2m1x4.f32 %0, {%5, %4, %3, %2}, %10;\n" + "cvt.rs.satfinite.e2m1x4.f32 %1, {%9, %8, %7, %6}, %11;\n" + "}" + : "=h"(output_ptr[0]), "=h"(output_ptr[1]) + : "f"(input[0]), "f"(input[1]), "f"(input[2]), "f"(input[3]), "f"(input[4]), "f"(input[5]), + "f"(input[6]), "f"(input[7]), "r"(rbits[0]), "r"(rbits[1])); + } else { + NVTE_DEVICE_ERROR("FP4 cvt PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); + } return output; } @@ -165,6 +173,7 @@ struct SharedStorage { uint32_t tmem_base_ptr; }; +// Main RHT GEMM kernel entry -- highly templated for flexible architecture/config support template (ASmemLayout{}); using AtomThrShapeMNK = Shape(typename TiledMMA::ThrLayoutVMNK{})), _1, _1>; @@ -195,19 +204,26 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( static constexpr bool kEnableRowQuant = kEnableRowQuant_; static constexpr bool kEnableSwizzleSFOutput = kEnableSwizzleSFOutput_; static constexpr bool kEnableFastMath = kEnableFastMath_; + + // Constant for RHT tensor processing (tile size etc) static int constexpr RhtTensorSize = 16; + + // Transaction bytes for TMA transfer on RHT tensor blocks static int constexpr kTmaRhtTensorTransactionBytes = cutlass::bits_to_bytes(RhtTensorSize * RhtTensorSize * cute::sizeof_bits_v); static int constexpr AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_; static int constexpr SchedulerPipelineStageCount = SchedulerPipelineStageCount_; + // Mainloop pipeline stage calculation, vectorization parameters for scaling factors static int constexpr MainloopPipelineStageCount = size<3>(ASmemLayout{}); static int constexpr SFVecSize = 16; + // Swizzle output layout for scaling factor arrays using SwizzledSFALayoutAtom = cutlass::detail::Sm1xxBlockScaledOutputConfig::SfAtom; using SwizzledSFDLayoutAtom = cutlass::detail::Sm1xxBlockScaledOutputConfig::SfAtom; + // Mainloop pipeline types for TMA async execution and epilogue cluster scheduling using MainloopPipeline = cutlass::detail::CustomizedPipelineTmaUmmaAsync; @@ -221,7 +237,8 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( using TmemAllocator = cute::TMEM::Allocator1Sm; static int constexpr VectorSize = RhtTensorSize; - // Preconditions + + // Compile-time safety: static shapes required for shared memory layouts CUTE_STATIC_ASSERT(is_static::value); CUTE_STATIC_ASSERT(is_static::value); // CUTE_STATIC_ASSERT(is_static::value); @@ -238,7 +255,9 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( // Total number of k-tiles int const K_TILE_MAX = min(packed_N, K) / size<2>(epilogue_tiler); + // Dynamic scheduler for SM100 architecture to support flexible kernel tiling and scheduling. struct TileScheduler { + // Structure to represent a single work tile's identification and state. struct WorkTileInfo { uint32_t m_idx = 0; uint32_t n_idx = 0; @@ -293,6 +312,7 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( return clc_pipe_producer_state; } + // Consumer method: fetch information for the next tile of work CUTLASS_DEVICE auto fetch_next_work(CLCPipeline &clc_pipeline, CLCPipelineState clc_pipe_producer_state) { clc_pipeline.consumer_wait(clc_pipe_producer_state); @@ -304,15 +324,18 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( return; } + // Updates the current work tile state to the next tile CUTLASS_DEVICE auto update_work_tile_info() { work_tile_info = next_work_tile_info; return; } + // Computes the linear index for the current tile based on m/n indices CUTLASS_DEVICE uint32_t get_linear_tile_idx() const { return work_tile_info.m_idx + work_tile_info.n_idx * tiles_in_m; } + // Issues a CLC (Cluster Launch Control) query to advance scheduling state machine. CUTLASS_HOST_DEVICE static void issue_clc_query(CLCPipelineState state, uint32_t mbarrier_addr, CLCResponse *clc_response_ptr) { @@ -330,6 +353,8 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( CUTLASS_NOT_IMPLEMENTED(); #endif } + + // Loads CLC response from shared memory and parses WorkTileInfo from result. CUTLASS_DEVICE static WorkTileInfo work_tile_info_from_clc_response(uint32_t result_addr) { WorkTileInfo work_tile_info; @@ -359,23 +384,30 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( } }; - // Allocate SMEMork + // Allocate and alias shared memory to the kernel's shared storage type extern __shared__ char shared_memory[]; using SharedStorage = SharedStorage; SharedStorage &shared_storage = *reinterpret_cast(shared_memory); + + // Compute the number of tiles in M and N after tiling and assign scheduler uint32_t tiles_in_m = uint32_t(size(ceil_div(M, size<0>(cluster_tile)))); uint32_t tiles_in_n = uint32_t(size(ceil_div(packed_N, size<2>(epilogue_tiler)))); TileScheduler scheduler(tiles_in_m, tiles_in_n, K_TILE_MAX, shared_storage.clc_response); + // Get this block's rank within the cluster int block_rank_in_cluster = cute::block_rank_in_cluster(); + + // Shapes for accumulated tiles in mainloop and epilogue auto acc_shape_mma = make_shape(take<0, 2>(mainloop_tiler), _1{}, _1{}); auto acc_shape_epilogue = make_shape(take<0, 2>(epilogue_tiler), _1{}, _1{}); + // Shape of the accumulator fragment for the main loop pipeline, with pipeline stages appended auto acc_mainloop_pipelined_shape = append(acc_shape_mma, Int{}); auto bulk_tmem_mma = TiledMMA::make_fragment_C(acc_mainloop_pipelined_shape); + // Number of threads assigned for various epilogue roles depending on quantization settings static int constexpr NumEpilogueColQuantThreadCount = kEnableRHTColQuant ? 128 : 0; static int constexpr NumEpilogueRowQuantThreadCount = kEnableRowQuant ? 256 : 0; static int constexpr NumMmaThreadCount = kEnableRHTColQuant ? 32 : 0; @@ -493,22 +525,29 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( } __syncthreads(); + // Warp group roles: DMA (global->shared copy), MMA (tensor core gemm), scheduler, column quantizer, row quantizer if (is_dma_warp) { + // Warp responsible for loading input from global to shared memory using TMA (Tensor Memory Access). cutlass::arch::warpgroup_reg_dealloc<32>(); + // Get TMA tensors for input matrix A and B (Hadamard/transform matrix) from global memory. Tensor mA = tma_load_a.get_tma_tensor(make_shape(M, packed_N)); Tensor mB = tma_load_b.get_tma_tensor(make_shape(RhtTensorSize, RhtTensorSize)); + // Partition tensors for tiling according to the mainloop and cluster tilers. Tensor gA_mk = local_tile(mA, mainloop_tiler, make_coord(_, _, _), Step<_1, X, _1>{}); Tensor gB_nk = local_tile(mB, cluster_tile, make_coord(_, _, _), Step{}); // (BLK_N,BLK_K,k) + // Shared memory tensors for pipeline Tensor tCsA = make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), sAlayout); // (MMA,MMA_M,MMA_N,PIPE) Tensor tCsB = make_tensor(make_smem_ptr(shared_storage.tensors.smem_B.data()), sBlayout); // (MMA,MMA_N,MMA_K,PIPE) + // Determine warp/tile positioning int block_rank_in_cluster = cute::block_rank_in_cluster(); ThrMMA thr_mma = mma.get_slice(block_rank_in_cluster); // blk idx + // Partition global to local fragments for A and B Tensor tCgA = thr_mma.partition_A(gA_mk); // (MMA,MMA_M,MMA_K,k) Tensor tCgB = thr_mma.partition_B(gB_nk); // (MMA,MMA_N,MMA_K,k) @@ -537,6 +576,7 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( } do { + // is_first_wave indicates whether this scheduler wave is the first among a group. bool is_first_wave = scheduler.is_first_wave(); uint32_t skip_wait = is_first_wave; auto tAgA_mk = tAgA(_, scheduler.tile_m(), _); @@ -563,16 +603,17 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( tAsA(_, write_stage)); } } + // Synchronize using work scheduler. scheduler.fetch_next_work(clc_pipeline, clc_pipeline_consumer_state); ++clc_pipeline_consumer_state; scheduler.update_work_tile_info(); } while (scheduler.is_valid()); mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); - } - - else if (is_mma_warp) { + } else if (is_mma_warp) { + // This warp executes the main tensor core matrix-multiply-accumulate for the Hadamard transform. cutlass::arch::warpgroup_reg_dealloc<32>(); if constexpr (kEnableRHTColQuant) { + // Setup shared memory fragments for A and B tiles. Tensor tCsA = make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), sAlayout); // (MMA,MMA_M,MMA_N,PIPE) Tensor tCsB = make_tensor(make_smem_ptr(shared_storage.tensors.smem_B.data()), @@ -592,6 +633,7 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( tmem_allocation_result_barrier.arrive(); uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; bulk_tmem_mma.data() = tmem_base_ptr; + // Wait until the B (Hadamard) tensor copy is complete cute::wait_barrier(shared_storage.tma_barrier[0], 0 /*tma_phase_bit*/); do { uint32_t skip_wait = K_TILE_MAX <= 0; @@ -637,6 +679,7 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); } } else if (is_sched_warp) { + // Scheduler warp manages tile assignment and pipeline progress for warps cutlass::arch::warpgroup_reg_dealloc<32>(); do { clc_throttle_pipeline.consumer_wait(clc_pipe_throttle_consumer_state); @@ -649,6 +692,8 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( scheduler.update_work_tile_info(); } while (scheduler.is_valid()); } else if (is_epilogue_col_quant_warp) { + // Warp responsible for quantizing output of Hadamard transform to FP4 for columnwise usage, + // and writing result tensors/scales to global memory. cutlass::arch::warpgroup_reg_alloc<192>(); if constexpr (kEnableRHTColQuant) { using TMEM_LOAD_NEW = cute::SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x; @@ -660,9 +705,10 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( make_stride(stride<0>(bulk_tmem_mma), Int<0>{}, Int<0>{}, size<1>(epilogue_tiler))); auto bulk_tmem_epilogue = make_tensor(make_tmem_ptr(), bulk_tmem_epilogue_layout); - // leveraging 256-bit writes to global memory + // Use 256-bit fragments for aligned bulk stores static int constexpr FragmentSize = 256 / sizeof_bits_v; + // Wait for TMEM allocation for this pipeline to finish tmem_allocation_result_barrier.arrive_and_wait(); uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; bulk_tmem_epilogue.data() = tmem_base_ptr; @@ -677,12 +723,14 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( size_t rng_seed = 0; size_t rng_offset = 0; + // Setup RNG for stochastic rounding if constexpr (kEnableStochasticRounding) { rng_seed = rng_state != nullptr ? __ldg(rng_state) : 0; rng_offset = rng_state != nullptr ? __ldg(rng_state + 1) : 0; } int group_idx = GetGroupIdx(&args, scheduler.tile_n_base() * size<1>(epilogue_tiler)); + // Determine quantization scale factor layouts/output splits for this group TSFDLayout sfd_layout; int cur_N = args.split_sections[group_idx]; if constexpr (kEnableSwizzleSFOutput) { @@ -691,6 +739,7 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( sfd_layout = make_layout(make_shape(M, make_shape(Int{}, cur_N / SFVecSize)), make_stride(cur_N / SFVecSize, make_stride(_0{}, _1{}))); } + // Build output tensors for columns and their quant scales Tensor mD = make_tensor( cute::subbyte_iterator(reinterpret_cast(args.output_colwise_list[group_idx])), make_shape(M, cur_N), DStride{}); // (M,packed_N) @@ -705,6 +754,7 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( Tensor gD_mn_view = tiled_divide(gD_mn, take<0, 2>(epilogue_tiler)); + // Setup tile-level TMEM (t2r) and global memory (r2g) copy descriptors auto tiled_t2r = make_tmem_copy(TMEM_LOAD_NEW{}, bulk_tmem_epilogue(_, _, _, _0{})); auto tiled_r2g = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); @@ -808,12 +858,13 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( static int constexpr NumVecs = size(tDgD) / VectorSize; Tensor tD_rRowSFD_frg = recast>(tDrSFD); + // Compute amax and quantization scales for this tile cutlass::maximum_absolute_value_reduction, true> amax_reduction; cutlass::Array vec_maxs; cutlass::Array pvscales; - // TMEM_LOAD + // Copy from TMEM to registers copy(tiled_t2r, tDtAcc, tTR_rAcc); cutlass::arch::fence_view_async_tmem_load(); accumulator_pipeline.consumer_release(accumulator_pipe_consumer_state); @@ -827,6 +878,7 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( vec_maxs[v] = amax_reduction(ElementAccumulator(0), compute_frgs[v]); } + // Scale values for quantization depending on fast-math flag if constexpr (kEnableFastMath) { pvscales = cutlass::multiplies>{}( vec_maxs, global_encode_scale_multiplier); @@ -846,15 +898,16 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( qpvscale_ups, global_decode_scale); cutlass::Array acc_scales; if constexpr (kEnableFastMath) { - // fast math: use reciprocal approximate to replace div + // Fast-math: approximate compute reciprocal instead of divide. acc_scales = cutlass::reciprocal_approximate_ftz{}(qpvscale_scaled); } else { - // regular path for slower math, use divide to replace div + // Regular path: computes using division acc_scales = cutlass::divides>{}( 1.0, qpvscale_scaled); } + // Prepare stochastic rounding random state if enabled uint4 random_uint4 = uint4{0, 0, 0, 0}; transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; // "Prefetch" a stochastic rounding state for the first tile @@ -864,6 +917,7 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( rng.init(rng_seed, rng_sequence, rng_offset); } CUTLASS_PRAGMA_UNROLL + // Apply round/quantize to each fragment, with or without stochastic rounding for (int v = 0; v < NumVecs; v++) { auto acc_scale = cutlass::minimum_with_nan_propagation{}( acc_scales[v], cutlass::platform::numeric_limits::max()); @@ -880,6 +934,7 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( } } + // Write quantized FP4 tile and dequant scale to gmem copy(tiled_r2g, src, dst); copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tDrSFD, tDgSFD); } @@ -887,6 +942,7 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( } while (scheduler.is_valid()); } } else if (is_epilogue_row_quant_warp) { + // Warp responsible for quantizing the input (before Hadamard transform) to FP4 for row-wise usage. cutlass::arch::warpgroup_reg_alloc<136>(); if constexpr (kEnableRowQuant) { using S2RVectorType = uint128_t; @@ -895,16 +951,18 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( int local_thread_idx = global_thread_idx % 256; size_t rng_seed = 0; size_t rng_offset = 0; - // g2s load all global_d_amax + // g2s load all global_a_amax for all groups/tensors CUTLASS_PRAGMA_NO_UNROLL for (int g = local_thread_idx; g < args.num_tensors; g += NumEpilogueRowQuantThreadCount) { shared_storage.global_a_amax[g] = __ldg(reinterpret_cast(args.global_a_amax_list[g])); } + // RNG for stochastic rounding if constexpr (kEnableStochasticRounding) { rng_seed = rng_state != nullptr ? __ldg(rng_state) : 0; rng_offset = rng_state != nullptr ? __ldg(rng_state + 1) : 0; } + // Input/output tensors/partitions for row quant warp Tensor mQA = make_tensor(cute::subbyte_iterator(QA), make_layout(make_shape(M, packed_N), dQA)); Tensor gQA_mn = local_tile(mQA, epilogue_tiler, make_coord(_, _, _), Step<_1, X, _1>{}); @@ -912,10 +970,12 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( Tensor gSFA_mn = local_tile(mSFA, epilogue_tiler, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_N) + // Swizzled shared memory A tile, with layout Tensor sA = as_position_independent_swizzle_tensor(group_modes<0, 2>( coalesce(make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), sAlayout)))); // (BLOCK_M, BLOCK_M,PIPE) + // Set up layouts for partitioning – tile-by-warp, with vector granularity using S2RWarpLayout = Layout>; using WarpGroupLayout = Layout>; using S2RThreadLayout = decltype(blocked_product(S2RWarpLayout{}, WarpGroupLayout{})); @@ -932,6 +992,7 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( auto thr_r2g_SFA = tiled_r2g_SFA.get_slice(local_thread_idx); Tensor tQAsA = thr_s2r.partition_S(sA); // (Copy, Copy_M, Copy_N, PIPE) + // Allocate temporary register tensors for copying quantization => output Tensor tQArA = make_tensor_like( make_layout(tQAsA(_, _, _, _0{}).shape())); // (Copy, Copy_M, Copy_N) Tensor tQAgQA = thr_r2g_QA.partition_S(gQA_mn); @@ -971,7 +1032,7 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( if (cur_group_idx != group_idx) { group_idx = cur_group_idx; a_global_amax_val = shared_storage.global_a_amax[group_idx]; - // update amax + // Update group quantization parameters/scaling global_encode_scale = a_global_amax_val > 0.0f ? cutlass::minimum_with_nan_propagation{}( (fp8_max * fp4_max) / a_global_amax_val, @@ -1056,18 +1117,21 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( compute_frgs_up, acc_scale)); } } + // Copy the output quantized data and scaling factors into global memory for later use copy(tiled_r2g_QA, tQArQA, tQAgQA_mn); copy(tiled_r2g_SFA, filter(tQArSFA), filter(tQAgSFA_mn)); } + // Move to next work tile for row quant warp scheduler.fetch_next_work(clc_pipeline, clc_pipeline_consumer_state); ++clc_pipeline_consumer_state; scheduler.update_work_tile_info(); } while (scheduler.is_valid()); } } else { + // Any extra warpgroup that is not used for anything above gets deallocated here cutlass::arch::warpgroup_reg_dealloc<32>(); } -} +} // NOLINT(readability/fn_size) template Tuple[ if self.columnwise_usage else global_amax_row ) - # TODO: this is a horrible hack to pass zero tolerance test # currently the amax of RHT transform still has a fp32 -> bf16 -> fp32 round-trip # this is to simulate that behaviour and we will remove this once the amax of RHT transform is fixed if self.columnwise_usage and with_rht_cast_fusion: From badcf7493d4e43cfe6f04fc058d8e2e17ff1e751 Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Tue, 16 Dec 2025 15:45:58 -0800 Subject: [PATCH 20/36] refactor RS for generating two random values Signed-off-by: Zhongbo Zhu --- .../pytorch/csrc/extensions/cast.cpp | 216 +++++++++++++----- 1 file changed, 156 insertions(+), 60 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index b6e25ffd50..abd8a2f9aa 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -709,12 +709,96 @@ std::tuple, std::vector, bool> bulk_alloc return retval; } +// Owns all allocations/wrappers backing quant_config_list[*].set_rng_state(...). +struct StochasticRngStateResources { + at::Tensor rng_states_tensor; // [2 * num_tensors], int64, CUDA + at::Tensor rng_states_tensor_colwise; // optional, same shape/dtype/device + std::vector te_rng_state_list; + std::vector te_rng_state_list_colwise; + + bool enabled{false}; + bool need_separate_rng_states{false}; + bool with_bulk_generate_rng_states{false}; +}; + +// Populates quant_config_list (+ optional colwise list) with rng_state pointers and stochastic flag. +static StochasticRngStateResources setup_stochastic_rounding_rng_states_helper( + size_t num_tensors, bool stochastic_rounding, bool with_bulk_generate_rng_states, + bool need_separate_rng_states, + std::vector &quant_config_list_rowwise, + std::vector &quant_config_list_colwise) { + // the return object will be used to keep rng states alive + StochasticRngStateResources res; + res.enabled = stochastic_rounding; + res.need_separate_rng_states = need_separate_rng_states; + res.with_bulk_generate_rng_states = with_bulk_generate_rng_states; + + if (!stochastic_rounding) return res; + + // Basic sanity: caller usually pre-sizes these to num_tensors. + TORCH_CHECK(quant_config_list_rowwise.size() == num_tensors, + "quant_config_list_rowwise must be sized to num_tensors"); + if (need_separate_rng_states) { + TORCH_CHECK(quant_config_list_colwise.size() == num_tensors, + "quant_config_list_colwise must be sized to num_tensors when " + "need_separate_rng_states=true"); + } + + const size_t rng_elts_per_thread = + res.with_bulk_generate_rng_states ? (1024 * num_tensors) : 1024; + + auto opts = at::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA); + res.rng_states_tensor = torch::empty({static_cast(2 * num_tensors)}, opts); + if (need_separate_rng_states) { + res.rng_states_tensor_colwise = torch::empty({static_cast(2 * num_tensors)}, opts); + } + + res.te_rng_state_list.reserve(num_tensors); + if (need_separate_rng_states) res.te_rng_state_list_colwise.reserve(num_tensors); + + for (size_t i = 0; i < num_tensors; ++i) { + auto gen = at::get_generator_or_default( + std::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); + + // Rowwise RNG state + at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread); + int64_t *rng_state_ptr = static_cast(res.rng_states_tensor.data_ptr()) + i * 2; + philox_unpack(philox_args, rng_state_ptr); + + res.te_rng_state_list.push_back(makeTransformerEngineTensor( + static_cast(rng_state_ptr), std::vector{2}, DType::kInt64)); + quant_config_list_rowwise[i].set_rng_state(res.te_rng_state_list[i].data()); + quant_config_list_rowwise[i].set_stochastic_rounding(true); + + // Colwise RNG state (only if you truly need a different sequence) + if (need_separate_rng_states) { + // re-initialize philox_args for colwise RNG state + at::PhiloxCudaState philox_args_col = init_philox_state(gen, rng_elts_per_thread); + int64_t *rng_state_ptr_colwise = + static_cast(res.rng_states_tensor_colwise.data_ptr()) + i * 2; + + philox_unpack(philox_args_col, rng_state_ptr_colwise); + + res.te_rng_state_list_colwise.push_back(makeTransformerEngineTensor( + static_cast(rng_state_ptr_colwise), std::vector{2}, DType::kInt64)); + quant_config_list_colwise[i].set_rng_state(res.te_rng_state_list_colwise[i].data()); + quant_config_list_colwise[i].set_stochastic_rounding(true); + } + + // break the loop if we are using bulk generate rng states + if (res.with_bulk_generate_rng_states) break; + } + + return res; +} + // Implements split-quantize NVFP4 with Row/Column-wise Hadamard Transform (RHT) -void split_quantize_nvfp4_impl_with_rht_helper( - const TensorWrapper &input, const std::vector &input_list, - std::vector &output_list, const std::vector &split_sections, - const std::vector &quantizers, - const std::vector &quant_config_list, cudaStream_t stream) { +void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input, + const std::vector &input_list, + std::vector &output_list, + const std::vector &split_sections, + const std::vector &quantizers, + cudaStream_t stream) { const size_t num_tensors = split_sections.size(); const auto &quantizer = *quantizers.front(); @@ -725,6 +809,36 @@ void split_quantize_nvfp4_impl_with_rht_helper( nvte_tensor_output_list.push_back(output_list[i].data()); } + // trigger the row-col fusion when the split-sections shapes are all 128 aligned for max performance + bool all_aligned_token_dim = + std::all_of(split_sections.begin(), split_sections.end(), + [](size_t split_section) { return split_section % 128 == 0; }); + + // in the case when rowwise and colwise cannot be fused, we have to generate the RNG states twice + // so that rowwise and colwise will have different random numbers + bool need_separate_rng_states = + (!all_aligned_token_dim) && quantizer.rowwise_usage && quantizer.columnwise_usage; + + // Objects for TE C API + std::vector quant_config_list; + std::vector quant_config_list_colwise; + for (size_t i = 0; i < num_tensors; ++i) { + quant_config_list.emplace_back(QuantizationConfigWrapper()); + quant_config_list_colwise.emplace_back(QuantizationConfigWrapper()); + } + + // this is true because we have already built grouped kernels for rowwise and colwise quantization with RHT + bool with_bulk_generate_rng_states = true; + + bool need_stochastic_rounding = quantizer.stochastic_rounding; + + auto stochastic_rng_state_resources = setup_stochastic_rounding_rng_states_helper( + num_tensors, need_stochastic_rounding, with_bulk_generate_rng_states, + need_separate_rng_states, quant_config_list, quant_config_list_colwise); + + auto &quant_config_list_colwise_to_use = + need_separate_rng_states ? quant_config_list_colwise : quant_config_list; + // Compute amaxes if (quantizer.with_post_rht_amax) { // We need: @@ -743,11 +857,6 @@ void split_quantize_nvfp4_impl_with_rht_helper( "RHT matrix is not available."); auto rht_matrix_nvte = makeTransformerEngineTensor(quantizer.rht_matrix); - // trigger the row-col fusion when the split-sections shapes are all 128 aligned for max performance - bool all_aligned_token_dim = - std::all_of(split_sections.begin(), split_sections.end(), - [](size_t split_section) { return split_section % 128 == 0; }); - if (all_aligned_token_dim) { // call the fully-fused grouped kernel for rowwise quantization & colwise RHT quantization transpose nvte_group_hadamard_transform_cast_fusion( @@ -784,11 +893,6 @@ void split_quantize_nvfp4_impl_with_rht_helper( stream); } - // TODO(zhongbo): move the rng states if we enable stochastic rounding - // so that rowwise and colwise will have different random numbers - // this is not needed for all_aligned_token_dim path, because row & col are both fused - // into one kernel, the kernel itself generates two different random numbers for row & col - // Columnwise RHT quantization fusion with grouped version if (quantizer.columnwise_usage) { std::vector out_transpose_list; @@ -826,23 +930,52 @@ void split_quantize_nvfp4_impl_with_rht_helper( } nvte_group_hadamard_transform_cast_fusion_columnwise( input.data(), reinterpret_cast(nvte_tensor_out_transpose_list.data()), - rht_matrix_nvte.data(), split_sections.data(), num_tensors, quant_config_list[0], stream); + rht_matrix_nvte.data(), split_sections.data(), num_tensors, + quant_config_list_colwise_to_use[0], stream); } } } -void split_quantize_nvfp4_impl_helper( - const TensorWrapper &input, const std::vector &input_list, - std::vector &output_list, const std::vector &split_sections, - const std::vector &quantizers, - const std::vector &quant_config_list, cudaStream_t stream) { +void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, + const std::vector &input_list, + std::vector &output_list, + const std::vector &split_sections, + const std::vector &quantizers, + cudaStream_t stream) { const size_t num_tensors = input_list.size(); + const auto &quantizer = *quantizers.front(); + std::vector nvte_tensor_input_list; std::vector nvte_tensor_output_list; for (size_t i = 0; i < num_tensors; ++i) { nvte_tensor_input_list.push_back(input_list[i].data()); nvte_tensor_output_list.push_back(output_list[i].data()); } + + // In this case without RHT, the rowwise and colwise quantization are fused + // we don't need separate rng states for rowwise and colwise + bool need_separate_rng_states = false; + + // Objects for TE C API + std::vector quant_config_list; + for (size_t i = 0; i < num_tensors; ++i) { + quant_config_list.emplace_back(QuantizationConfigWrapper()); + } + + // TODO: this is only true because the non-RHT path doesn't have grouped kernels yet, which we can be optimized + // so that we can generate all rng states at once + bool with_bulk_generate_rng_states = false; + + bool need_stochastic_rounding = quantizer.stochastic_rounding; + + // place holder for colwise rng states, which are not needed in this case + std::vector dummy_quant_config_list_colwise; + + auto stochastic_rng_state_resources = setup_stochastic_rounding_rng_states_helper( + num_tensors, need_stochastic_rounding, with_bulk_generate_rng_states, + need_separate_rng_states, quant_config_list, + dummy_quant_config_list_colwise); // colwise rng states are not needed in this case + // We need: // 1. Rowwise amax = amax for input // 2. Columnwise amax = amax for input too @@ -922,43 +1055,6 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input, // CUDA stream auto stream = at::cuda::getCurrentCUDAStream(); - // Objects for TE C API - std::vector quant_config_list; - for (size_t i = 0; i < num_tensors; ++i) { - quant_config_list.emplace_back(QuantizationConfigWrapper()); - } - - // Stochastic rounding - std::vector te_rng_state_list; - at::Tensor rng_states_tensor; - - if (quantizer.stochastic_rounding) { - // TODO(zhongbo): remove the for loop of generating rng states with a single call - // with rng_elts_per_thread = 1024 * num_tensors - // RHT path is fully grouped kernels, which we can be optimized - bool with_bulk_generate_rng_states = quantizer.with_rht; - const size_t rng_elts_per_thread = with_bulk_generate_rng_states ? 1024 * num_tensors : 1024; - auto opts = at::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA); - rng_states_tensor = torch::empty({static_cast(2 * num_tensors)}, opts); - - for (size_t i = 0; i < num_tensors; ++i) { - auto gen = at::get_generator_or_default( - std::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); - // Generate RNG state for rowwise quantization - at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread); - int64_t *rng_state_ptr = static_cast(rng_states_tensor.data_ptr()) + i * 2; - philox_unpack(philox_args, rng_state_ptr); - te_rng_state_list.push_back(makeTransformerEngineTensor( - static_cast(rng_state_ptr), std::vector{2}, DType::kInt64)); - quant_config_list[i].set_rng_state(te_rng_state_list[i].data()); - quant_config_list[i].set_stochastic_rounding(true); - // break the loop if we are using bulk generate rng states - if (with_bulk_generate_rng_states) { - break; - } - } - } - // Perform multi-tensor quantization NVTE_SCOPED_GIL_RELEASE({ if (quantizer.with_rht) { // Quantize row-wise data, RHT+quantize column-wise data @@ -966,11 +1062,11 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input, NVTE_CHECK(input.dtype() == DType::kBFloat16, "RHT is only supported for bfloat16 input"); // Fuse the rowwise and colwise into one when the kernel is ready split_quantize_nvfp4_impl_with_rht_helper(input, input_list, output_list, split_sections, - quantizers, quant_config_list, stream); + quantizers, stream); } else { // NVFP4 quantize // Fuse the rowwise and colwise into one when the kernel is ready split_quantize_nvfp4_impl_helper(input, input_list, output_list, split_sections, quantizers, - quant_config_list, stream); + stream); } }); } From 0d245aeb4abb9525b576e30a218d9603969ff41e Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Thu, 18 Dec 2025 07:35:15 +0000 Subject: [PATCH 21/36] Avoid invalid configs with templated kernel Signed-off-by: Tim Moon --- ...cast_col_hadamard_transform_cast_fusion.cu | 66 ++++++++++++------- 1 file changed, 43 insertions(+), 23 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu index 88526f2a82..7ad88706b5 100644 --- a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu @@ -142,7 +142,7 @@ struct SharedStorage { using AtomThrShapeMNK = cute::Shape<_1, _1, _1>; using AccumulatorPipeline = - cutlass::PipelineUmmaAsync; using AccumulatorPipelineStorage = typename AccumulatorPipeline::SharedStorage; @@ -192,6 +192,28 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( // float const* c_global_amax, const size_t *rng_state) { using namespace cute; + + // Abort immediately if compilation is not supported + constexpr bool is_blackwell_arch = ARCH_BLACKWELL_FAMILY; + if constexpr (!is_blackwell_arch) { + NVTE_DEVICE_ERROR("group_row_col_rht_gemm_device is only supported on Blackwell " + "with architecture-specific compilation. " + "Try recompiling with sm_100a or similar."); + return; + } + if constexpr (!kEnableRHTColQuant_) { + // With kEnableRHTColQuant=false, we might configure + // mainloop_pipeline and accumulator_pipeline with zero consumers, + // which causes internal problems in CUTLASS and esoteric + // compile-time errors ("ptxas fatal: internal compiler error"). + NVTE_DEVICE_ERROR("group_row_col_rht_gemm_device requires column-wise quantization."); + return; + } +#if !defined(CUTLASS_ARCH_CLC_ENABLED) + CUTLASS_NOT_IMPLEMENTED(); + return; +#endif + using X = Underscore; // Accumulator data type for main computation using ElementAccumulator = float; @@ -1293,7 +1315,6 @@ void group_row_col_rht_gemm_ntt_w_sfc(int packed_sequence_length, int hidden_siz uint32_t tiles_in_m = uint32_t(size(ceil_div(M, size<0>(cga_tile_shape)))); uint32_t tiles_in_n = uint32_t(size(ceil_div(N, k_tile_size))); - uint32_t tiles = tiles_in_m * tiles_in_n; dim3 dimBlock(512); dim3 dimCluster(size<0>(cga_shape), size<1>(cga_shape), size<2>(cga_shape)); @@ -1370,8 +1391,8 @@ void group_hadamard_transform_cast_fusion(const Tensor &input_, std::vectordata.dptr != nullptr; bool has_col_quant = output_list[i]->columnwise_data.dptr != nullptr; - all_has_row_quant &= has_row_quant; - all_has_col_quant &= has_col_quant; + all_has_row_quant = all_has_row_quant && has_row_quant; + all_has_col_quant = all_has_col_quant && has_col_quant; // sanity check, the two bool flags cannot be both false NVTE_CHECK(has_row_quant || has_col_quant, "At least one of the output tensors must have row or column quant."); @@ -1404,6 +1425,9 @@ void group_hadamard_transform_cast_fusion(const Tensor &input_, std::vector( - /*packed_sequence_length=*/m, /*hidden_size=*/n, - /*A=*/reinterpret_cast(input.dptr), - /*B=*/reinterpret_cast(hadamard_matrix.dptr), - /*QA=*/reinterpret_cast(rowwise_data_base_ptr), - /*SFA=*/reinterpret_cast(rowwise_scale_inv_base_ptr), - /*args=*/kernel_args, - /*rng_state=*/rng_state, /*sm_count=*/sm_count, - /*stream=*/stream, /*k_tile_size=*/k_tile_size); - - ););););); + use_fast_math, kEnableFastMath, + detail::group_row_col_rht_gemm_ntt_w_sfc< + kEnableStochasticRounding, /*kEnableRhtColQuant=*/true, kEnableRowQuant, + kEnableSwizzleSFOutput, TA, TB, TQA, TSFA, TD, TSFD, kEnableFastMath>( + /*packed_sequence_length=*/m, /*hidden_size=*/n, + /*A=*/reinterpret_cast(input.dptr), + /*B=*/reinterpret_cast(hadamard_matrix.dptr), + /*QA=*/reinterpret_cast(rowwise_data_base_ptr), + /*SFA=*/reinterpret_cast(rowwise_scale_inv_base_ptr), + /*args=*/kernel_args, + /*rng_state=*/rng_state, /*sm_count=*/sm_count, + /*stream=*/stream, /*k_tile_size=*/k_tile_size); + );););); } } // namespace transformer_engine From 83e7bf2aa3b32480c52add1bd2f1e8817d865d64 Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Thu, 18 Dec 2025 11:12:00 -0800 Subject: [PATCH 22/36] fix acc pipeline init with 0 arrival count Signed-off-by: Zhongbo Zhu --- .../group_row_cast_col_hadamard_transform_cast_fusion.cu | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu index 7ad88706b5..9f2969e811 100644 --- a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu @@ -501,9 +501,10 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( accumulator_pipeline_params.consumer_arv_count = size(AtomThrShapeMNK{}) * NumEpilogueColQuantThreadCount; accumulator_pipeline_params.initializing_warp = 1; + using IsInitAccumulatorPipeline = cute::conditional_t; AccumulatorPipeline accumulator_pipeline(shared_storage.accumulator, accumulator_pipeline_params, cluster_shape, - cute::true_type{}, // Perform barrier init + IsInitAccumulatorPipeline{}, // Perform barrier init cute::true_type{}); // Delay mask calculation // CLC pipeline typename CLCPipeline::Params clc_pipeline_params; From b554bef50146b7bc0ad54f8e12340c2bd6b49fcb Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Thu, 18 Dec 2025 11:22:56 -0800 Subject: [PATCH 23/36] restore rowwise-only mode Signed-off-by: Zhongbo Zhu --- ...w_cast_col_hadamard_transform_cast_fusion.cu | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu index 9f2969e811..2be50c7f5b 100644 --- a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu @@ -201,14 +201,6 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( "Try recompiling with sm_100a or similar."); return; } - if constexpr (!kEnableRHTColQuant_) { - // With kEnableRHTColQuant=false, we might configure - // mainloop_pipeline and accumulator_pipeline with zero consumers, - // which causes internal problems in CUTLASS and esoteric - // compile-time errors ("ptxas fatal: internal compiler error"). - NVTE_DEVICE_ERROR("group_row_col_rht_gemm_device requires column-wise quantization."); - return; - } #if !defined(CUTLASS_ARCH_CLC_ENABLED) CUTLASS_NOT_IMPLEMENTED(); return; @@ -1426,9 +1418,6 @@ void group_hadamard_transform_cast_fusion(const Tensor &input_, std::vector( /*packed_sequence_length=*/m, /*hidden_size=*/n, /*A=*/reinterpret_cast(input.dptr), @@ -1506,7 +1497,7 @@ void group_hadamard_transform_cast_fusion(const Tensor &input_, std::vector Date: Thu, 18 Dec 2025 16:03:55 -0800 Subject: [PATCH 24/36] switch to dynamic atomic scheduler Signed-off-by: Zhongbo Zhu --- ...cast_col_hadamard_transform_cast_fusion.cu | 469 ++++++++---------- 1 file changed, 206 insertions(+), 263 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu index 2be50c7f5b..7691999bd8 100644 --- a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu @@ -50,12 +50,10 @@ namespace detail { namespace { using namespace cute; -using cute:: - Tensor; // Ensure unqualified Tensor refers to cute::Tensor, not transformer_engine::Tensor -struct CLCResponse { - uint32_t data[4] = {0}; -}; +// Ensure Tensor refers to cute::Tensor, not transformer_engine::Tensor +using cute::Tensor; + constexpr int kMaxTensorsPerKernel = 64; @@ -151,13 +149,12 @@ struct SharedStorage { cutlass::detail::CustomizedPipelineTmaUmmaAsync, AtomThrShapeMNK>; using MainloopPipelineStorage = typename MainloopPipeline::SharedStorage; - using CLCPipeline = cutlass::PipelineCLCFetchAsync; - using CLCPipelineStorage = typename CLCPipeline::SharedStorage; - using CLCThrottlePipeline = cutlass::PipelineAsync; - using CLCThrottlePipelineStorage = typename CLCThrottlePipeline::SharedStorage; + using SchedPipeline = cutlass::PipelineCLCFetchAsync; + using SchedPipelineStorage = typename SchedPipeline::SharedStorage; + using SchedThrottlePipeline = cutlass::PipelineAsync; + using SchedThrottlePipelineStorage = typename SchedThrottlePipeline::SharedStorage; struct TensorStorage : cute::aligned_struct<128, _1> { - // cute::array_aligned> smem_A; cute::array_aligned> smem_A; cute::array_aligned> smem_B; } tensors; @@ -165,11 +162,12 @@ struct SharedStorage { alignas(16) AccumulatorPipelineStorage accumulator; alignas(16) MainloopPipelineStorage mainloop; alignas(16) cute::uint64_t tma_barrier[1]; - alignas(16) CLCPipelineStorage clc; - alignas(16) CLCThrottlePipelineStorage clc_throttle; - alignas(16) CLCResponse clc_response[SchedulerPipelineStageCount_]; + alignas(16) SchedPipelineStorage sched; + alignas(16) SchedThrottlePipelineStorage sched_throttle; + alignas(16) int32_t atomic_tile_id[SchedulerPipelineStageCount_]; alignas(16) float global_a_amax[kMaxTensorsPerKernel]; alignas(16) float global_d_amax[kMaxTensorsPerKernel]; + uint32_t atomic_tile_counter[SchedulerPipelineStageCount_]; uint32_t tmem_base_ptr; }; @@ -187,9 +185,8 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( TA const *A, AStride dA, ASmemLayout sAlayout, CUTE_GRID_CONSTANT TmaLoadA const tma_load_a, TB const *B, BStride dB, BSmemLayout sBlayout, CUTE_GRID_CONSTANT TmaLoadB const tma_load_b, TQA *QA, QAStride dQA, TSFA *SFA, TSFALayout sfa_layout, MultiAmaxHadamardCastFusionArgs args, + uint32_t* tile_scheduler_workspace, TiledMMA mma, - // float const* a_global_amax, - // float const* c_global_amax, const size_t *rng_state) { using namespace cute; @@ -242,12 +239,12 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( cutlass::detail::CustomizedPipelineTmaUmmaAsync; using MainloopPipelineState = typename MainloopPipeline::PipelineState; - using CLCPipeline = cutlass::PipelineCLCFetchAsync; - using CLCPipelineState = typename CLCPipeline::PipelineState; - using CLCThrottlePipeline = cutlass::PipelineAsync; - using CLCThrottlePipelineState = typename CLCThrottlePipeline::PipelineState; + using SchedPipeline = cutlass::PipelineCLCFetchAsync; + using SchedPipelineState = typename SchedPipeline::PipelineState; + using SchedThrottlePipeline = cutlass::PipelineAsync; + using SchedThrottlePipelineState = typename SchedThrottlePipeline::PipelineState; - static_assert(ClusterShape{} == Shape<_1, _1, _1>{}, "ClusterShape must be Shape<_1,_1,_1>"); + static_assert(ClusterShape{} == Shape<_1,_1,_1>{}, "ClusterShape must be Shape<_1,_1,_1>"); using TmemAllocator = cute::TMEM::Allocator1Sm; static int constexpr VectorSize = RhtTensorSize; @@ -269,132 +266,98 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( // Total number of k-tiles int const K_TILE_MAX = min(packed_N, K) / size<2>(epilogue_tiler); - // Dynamic scheduler for SM100 architecture to support flexible kernel tiling and scheduling. struct TileScheduler { - // Structure to represent a single work tile's identification and state. - struct WorkTileInfo { - uint32_t m_idx = 0; - uint32_t n_idx = 0; - uint32_t l_idx = 0; - bool is_valid_tile = false; - }; uint32_t tiles_in_m = 0; uint32_t tiles_in_n = 0; - + uint32_t linear_idx = 0; + uint32_t next_linear_idx = 0; + uint32_t start_idx = 0; + uint32_t tile_m_idx = 0; + uint32_t tile_n_idx = 0; int k_tile_max = 0; - - int wave_cnt = 0; - WorkTileInfo work_tile_info; - WorkTileInfo next_work_tile_info; - CLCResponse *clc_response_ptr_; - CUTLASS_DEVICE TileScheduler(uint32_t tiles_m, uint32_t tiles_n, int kmax, - CLCResponse *clc_response_ptr) - : tiles_in_m(tiles_m), - tiles_in_n(tiles_n), - - k_tile_max(kmax), - work_tile_info( - {blockIdx.x, blockIdx.y, blockIdx.z, blockIdx.x < tiles_m && blockIdx.y < tiles_n}), - next_work_tile_info( - {blockIdx.x, blockIdx.y, blockIdx.z, blockIdx.x < tiles_m && blockIdx.y < tiles_n}), - clc_response_ptr_(clc_response_ptr) {} - - CUTLASS_DEVICE uint32_t tile_m() const { return work_tile_info.m_idx; } + uint32_t* atomic_tile_index_; + uint32_t* smem_tile_counter; + uint32_t atomic_offset; + cutlass::FastDivmodU64 divmod_tiles_in_m; + + CUTLASS_DEVICE TileScheduler(uint32_t tiles_m, uint32_t tiles_n, int kmax, uint32_t* atomic_tile_index_, uint32_t* smem_tile_counter) + : tiles_in_m(tiles_m), + tiles_in_n(tiles_n), + linear_idx(blockIdx.x), + next_linear_idx(blockIdx.x), + start_idx(blockIdx.x), + k_tile_max(kmax), + atomic_tile_index_(atomic_tile_index_), + smem_tile_counter(smem_tile_counter), + atomic_offset(gridDim.x), + divmod_tiles_in_m(uint64_t(tiles_m)) { + update_tile_idx(); + } + CUTLASS_DEVICE void update_tile_idx() { + uint64_t q, r; + divmod_tiles_in_m(q, r, uint64_t(linear_idx)); + tile_m_idx = static_cast(r); + tile_n_idx = static_cast(q) * uint32_t(k_tile_max); + } + CUTLASS_DEVICE uint32_t tile_m() const { + return tile_m_idx; + } CUTLASS_DEVICE uint32_t tile_n_base() const { - return work_tile_info.n_idx * uint32_t(k_tile_max); + return tile_n_idx; } - CUTLASS_DEVICE uint32_t tiles_m() const { return tiles_in_m; } + CUTLASS_DEVICE uint32_t tiles_n() const { return tiles_in_n; } + CUTLASS_DEVICE bool is_valid() const { - return cute::elem_less(cute::make_coord(work_tile_info.m_idx, work_tile_info.n_idx), - cute::make_coord(tiles_in_m, tiles_in_n)) && - work_tile_info.is_valid_tile; + return cute::elem_less(cute::make_coord(tile_m(), tile_n_base()), cute::make_coord(tiles_in_m, tiles_in_n)); } - CUTLASS_DEVICE bool is_first_wave() const { return wave_cnt == 0; } - CUTLASS_DEVICE auto advance_to_next_work(CLCPipeline &clc_pipeline, - CLCPipelineState clc_pipe_producer_state) { - uint32_t mbarrier_addr = clc_pipeline.producer_get_barrier(clc_pipe_producer_state); - // Wait for clcID buffer to become empty with a flipped phase - clc_pipeline.producer_acquire(clc_pipe_producer_state); - if (cute::elect_one_sync()) { - issue_clc_query(clc_pipe_producer_state, mbarrier_addr, clc_response_ptr_); - } + CUTLASS_DEVICE bool is_first_wave() const { return linear_idx == start_idx; } - ++clc_pipe_producer_state; - return clc_pipe_producer_state; - } + CUTLASS_DEVICE uint32_t get_linear_tile_idx() const { return linear_idx; } - // Consumer method: fetch information for the next tile of work - CUTLASS_DEVICE auto fetch_next_work(CLCPipeline &clc_pipeline, - CLCPipelineState clc_pipe_producer_state) { - clc_pipeline.consumer_wait(clc_pipe_producer_state); - uint32_t smem_addr = - cute::cast_smem_ptr_to_uint(&clc_response_ptr_[clc_pipe_producer_state.index()]); - next_work_tile_info = work_tile_info_from_clc_response(smem_addr); - clc_pipeline.consumer_release(clc_pipe_producer_state); - wave_cnt++; - return; + // Fetch a new tile_id using atomics. + CUTLASS_DEVICE uint32_t fetch_tile_id_counter( int pred ) { + uint32_t tile_id_counter = 0; + asm volatile( "{\n\t" + ".reg .pred p;\n\t" + "setp.eq.u32 p, %2, 1;\n\t" + "@p atom.global.add.u32 %0, [%1], 1; \n\t" + "}" + : "=r"( tile_id_counter ) + : "l"( atomic_tile_index_ ), "r"( pred ) ); + + return tile_id_counter; } - // Updates the current work tile state to the next tile - CUTLASS_DEVICE auto update_work_tile_info() { - work_tile_info = next_work_tile_info; + CUTLASS_DEVICE auto fetch_next_work(SchedPipeline& sched_pipeline, SchedPipelineState sched_pipeline_consumer_state) { + sched_pipeline.consumer_wait(sched_pipeline_consumer_state); + next_linear_idx = smem_tile_counter[sched_pipeline_consumer_state.index()]; + cutlass::arch::fence_view_async_shared(); + sched_pipeline.consumer_release(sched_pipeline_consumer_state); return; } - // Computes the linear index for the current tile based on m/n indices - CUTLASS_DEVICE uint32_t get_linear_tile_idx() const { - return work_tile_info.m_idx + work_tile_info.n_idx * tiles_in_m; - } + CUTLASS_DEVICE auto advance_to_next_work(SchedPipeline& sched_pipeline, SchedPipelineState sched_pipeline_producer_state) { + uint32_t mbarrier_addr = sched_pipeline.producer_get_barrier(sched_pipeline_producer_state); + // Wait for clcID buffer to become empty with a flipped phase + sched_pipeline.producer_acquire(sched_pipeline_producer_state); + auto is_leading_thread = cute::elect_one_sync(); + uint32_t tile_id_counter = fetch_tile_id_counter(is_leading_thread) + atomic_offset; + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(&smem_tile_counter[sched_pipeline_producer_state.index()]); + if (is_leading_thread) { + cute::store_shared_remote(tile_id_counter, smem_addr, mbarrier_addr, 0); + } - // Issues a CLC (Cluster Launch Control) query to advance scheduling state machine. - CUTLASS_HOST_DEVICE - static void issue_clc_query(CLCPipelineState state, uint32_t mbarrier_addr, - CLCResponse *clc_response_ptr) { -#if defined(CUTLASS_ARCH_CLC_ENABLED) - uint32_t result_addr = cute::cast_smem_ptr_to_uint( - reinterpret_cast(&clc_response_ptr[state.index()])); - asm volatile( - "{\n\t" - "clusterlaunchcontrol.try_cancel.async.shared::cta.mbarrier::complete_tx::bytes." - "multicast::cluster::all.b128 [%0], [%1];\n\t" - "}\n" - : - : "r"(result_addr), "r"(mbarrier_addr)); -#else - CUTLASS_NOT_IMPLEMENTED(); -#endif + ++sched_pipeline_producer_state; + return sched_pipeline_producer_state; } - // Loads CLC response from shared memory and parses WorkTileInfo from result. - CUTLASS_DEVICE - static WorkTileInfo work_tile_info_from_clc_response(uint32_t result_addr) { - WorkTileInfo work_tile_info; - uint32_t valid = 0; -#if defined(CUTLASS_ARCH_CLC_ENABLED) - asm volatile( - "{\n" - ".reg .pred p1;\n\t" - ".reg .b128 clc_result;\n\t" - "ld.shared.b128 clc_result, [%4];\n\t" - "clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 p1, clc_result;\n\t" - "selp.u32 %3, 1, 0, p1;\n\t" - "@p1 clusterlaunchcontrol.query_cancel.get_first_ctaid.v4.b32.b128 {%0, %1, %2, _}, " - "clc_result;\n\t" - "}\n" - : "=r"(work_tile_info.m_idx), "=r"(work_tile_info.n_idx), "=r"(work_tile_info.l_idx), - "=r"(valid) - : "r"(result_addr) - : "memory"); - - cutlass::arch::fence_view_async_shared(); -#else - CUTLASS_NOT_IMPLEMENTED(); -#endif - work_tile_info.is_valid_tile = (valid == 1); - return work_tile_info; + CUTLASS_DEVICE auto update_work_tile_info() { + linear_idx = next_linear_idx; + update_tile_idx(); + return; } }; @@ -407,10 +370,10 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( // Compute the number of tiles in M and N after tiling and assign scheduler uint32_t tiles_in_m = uint32_t(size(ceil_div(M, size<0>(cluster_tile)))); - uint32_t tiles_in_n = uint32_t(size(ceil_div(packed_N, size<2>(epilogue_tiler)))); - TileScheduler scheduler(tiles_in_m, tiles_in_n, K_TILE_MAX, shared_storage.clc_response); + uint32_t tiles_in_n = uint32_t(size(ceil_div(args.split_sections_range[args.num_tensors], size<2>(epilogue_tiler)))); + + TileScheduler scheduler(tiles_in_m, tiles_in_n, K_TILE_MAX, tile_scheduler_workspace, shared_storage.atomic_tile_counter); - // Get this block's rank within the cluster int block_rank_in_cluster = cute::block_rank_in_cluster(); // Shapes for accumulated tiles in mainloop and epilogue @@ -445,13 +408,6 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( bool is_epilogue_col_quant_warp = (warp_idx >= 4 && warp_idx <= 7); bool is_epilogue_row_quant_warp = (warp_idx >= 8 && warp_idx <= 15); - // if (is_epilogue_col_quant_warp && elect_one_sync()) { - // cute::prefetch(raw_pointer_cast(c_global_amax)); - // } - // if (is_epilogue_row_quant_warp && elect_one_sync()) { - // cute::prefetch(raw_pointer_cast(a_global_amax)); - // } - typename MainloopPipeline::Params mainloop_pipeline_params; if (is_dma_warp) { mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; @@ -498,42 +454,38 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( cluster_shape, IsInitAccumulatorPipeline{}, // Perform barrier init cute::true_type{}); // Delay mask calculation - // CLC pipeline - typename CLCPipeline::Params clc_pipeline_params; + typename SchedPipeline::Params sched_pipeline_params; if (is_sched_warp) { - clc_pipeline_params.role = CLCPipeline::ThreadCategory::ProducerConsumer; - } else { - clc_pipeline_params.role = CLCPipeline::ThreadCategory::Consumer; + sched_pipeline_params.role = SchedPipeline::ThreadCategory::ProducerConsumer; + } + else { + sched_pipeline_params.role = SchedPipeline::ThreadCategory::Consumer; } - clc_pipeline_params.producer_blockid = 0; - clc_pipeline_params.producer_arv_count = 1; - clc_pipeline_params.consumer_arv_count = - NumSchedThreads + - cluster_size * (NumMainloopLoadThreads + NumEpilogueThreads + NumMmaThreadCount); - clc_pipeline_params.transaction_bytes = sizeof(CLCResponse); - clc_pipeline_params.initializing_warp = 3; - CLCPipeline clc_pipeline(shared_storage.clc, clc_pipeline_params, cluster_shape); - CLCPipelineState clc_pipeline_consumer_state; - CLCPipelineState clc_pipeline_producer_state = cutlass::make_producer_start_state(); - - // CLC throttle pipeline - typename CLCThrottlePipeline::Params clc_throttle_pipeline_params; + sched_pipeline_params.producer_blockid = 0; + sched_pipeline_params.producer_arv_count = 1; + sched_pipeline_params.consumer_arv_count = NumSchedThreads + cluster_size * + (NumMainloopLoadThreads + NumEpilogueThreads + NumMmaThreadCount); + sched_pipeline_params.transaction_bytes = sizeof(uint32_t); + sched_pipeline_params.initializing_warp = 3; + SchedPipeline sched_pipeline(shared_storage.sched, sched_pipeline_params, cluster_shape); + SchedPipelineState sched_pipeline_consumer_state; + SchedPipelineState sched_pipeline_producer_state = cutlass::make_producer_start_state(); + + typename SchedThrottlePipeline::Params sched_throttle_pipeline_params; if (is_dma_warp) { - clc_throttle_pipeline_params.role = CLCThrottlePipeline::ThreadCategory::Producer; + sched_throttle_pipeline_params.role = SchedThrottlePipeline::ThreadCategory::Producer; } if (is_sched_warp) { - clc_throttle_pipeline_params.role = CLCThrottlePipeline::ThreadCategory::Consumer; + sched_throttle_pipeline_params.role = SchedThrottlePipeline::ThreadCategory::Consumer; } - clc_throttle_pipeline_params.producer_arv_count = NumMainloopLoadThreads; - clc_throttle_pipeline_params.consumer_arv_count = NumSchedThreads; - clc_throttle_pipeline_params.dst_blockid = 0; - clc_throttle_pipeline_params.initializing_warp = 4; + sched_throttle_pipeline_params.producer_arv_count = NumMainloopLoadThreads; + sched_throttle_pipeline_params.consumer_arv_count = NumSchedThreads; + sched_throttle_pipeline_params.dst_blockid = 0; + sched_throttle_pipeline_params.initializing_warp = 4; - CLCThrottlePipeline clc_throttle_pipeline(shared_storage.clc_throttle, - clc_throttle_pipeline_params); - CLCThrottlePipelineState clc_pipe_throttle_consumer_state; - CLCThrottlePipelineState clc_pipe_throttle_producer_state = - cutlass::make_producer_start_state(); + SchedThrottlePipeline sched_throttle_pipeline(shared_storage.sched_throttle, sched_throttle_pipeline_params); + SchedThrottlePipelineState sched_pipeline_throttle_consumer_state; + SchedThrottlePipelineState sched_pipeline_throttle_producer_state = cutlass::make_producer_start_state(); if (warp_idx == 2 && elect_one_sync()) { cute::initialize_barrier(shared_storage.tma_barrier[0], /* num_threads */ 1); @@ -597,11 +549,9 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( auto tAgA_mk = tAgA(_, scheduler.tile_m(), _); int k_tile = 0; - // Throttle CLC producer - clc_throttle_pipeline.producer_acquire(clc_pipe_throttle_producer_state); - clc_throttle_pipeline.producer_commit(clc_pipe_throttle_producer_state); - ++clc_pipe_throttle_producer_state; - + sched_throttle_pipeline.producer_acquire(sched_pipeline_throttle_producer_state); + sched_throttle_pipeline.producer_commit(sched_pipeline_throttle_producer_state); + ++sched_pipeline_throttle_producer_state; CUTLASS_PRAGMA_NO_UNROLL while (k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n()) { int k_tile_idx_n = scheduler.tile_n_base() + k_tile; @@ -618,10 +568,10 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( tAsA(_, write_stage)); } } - // Synchronize using work scheduler. - scheduler.fetch_next_work(clc_pipeline, clc_pipeline_consumer_state); - ++clc_pipeline_consumer_state; + scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state); + ++sched_pipeline_consumer_state; scheduler.update_work_tile_info(); + // scheduler.advance(); } while (scheduler.is_valid()); mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); } else if (is_mma_warp) { @@ -653,13 +603,13 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( do { uint32_t skip_wait = K_TILE_MAX <= 0; - auto barrier_token = - mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); - scheduler.fetch_next_work(clc_pipeline, clc_pipeline_consumer_state); - ++clc_pipeline_consumer_state; + auto barrier_token = mainloop_pipeline.consumer_try_wait( + mainloop_pipe_consumer_state, + skip_wait); + scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state); + ++sched_pipeline_consumer_state; CUTLASS_PRAGMA_NO_UNROLL - for (int k_tile = 0; - k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n();) { + for (int k_tile = 0; k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n(); ) { mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); int read_stage = mainloop_pipe_consumer_state.index(); auto tCrA_mk = tCrA(_, _, _, read_stage); @@ -697,13 +647,12 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( // Scheduler warp manages tile assignment and pipeline progress for warps cutlass::arch::warpgroup_reg_dealloc<32>(); do { - clc_throttle_pipeline.consumer_wait(clc_pipe_throttle_consumer_state); - clc_throttle_pipeline.consumer_release(clc_pipe_throttle_consumer_state); - ++clc_pipe_throttle_consumer_state; - clc_pipeline_producer_state = - scheduler.advance_to_next_work(clc_pipeline, clc_pipeline_producer_state); - scheduler.fetch_next_work(clc_pipeline, clc_pipeline_consumer_state); - ++clc_pipeline_consumer_state; + sched_throttle_pipeline.consumer_wait(sched_pipeline_throttle_consumer_state); + sched_throttle_pipeline.consumer_release(sched_pipeline_throttle_consumer_state); + ++sched_pipeline_throttle_consumer_state; + sched_pipeline_producer_state = scheduler.advance_to_next_work(sched_pipeline, sched_pipeline_producer_state); + scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state); + ++sched_pipeline_consumer_state; scheduler.update_work_tile_info(); } while (scheduler.is_valid()); } else if (is_epilogue_col_quant_warp) { @@ -797,15 +746,14 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( auto sfc_converter = cutlass::NumericConverter{}; do { - scheduler.fetch_next_work(clc_pipeline, clc_pipeline_consumer_state); - ++clc_pipeline_consumer_state; + scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state); + ++sched_pipeline_consumer_state; CUTLASS_PRAGMA_NO_UNROLL - for (int k_tile = 0; - k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n(); - ++k_tile) { - int global_tile_n_offset = (scheduler.tile_n_base() + k_tile) * size<1>(epilogue_tiler); + for (int k_tile = 0; k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n(); ++k_tile) { + int global_tile_n_offset = (scheduler.tile_n_base() +k_tile) * size<1>(epilogue_tiler); int cur_group_idx = GetGroupIdx(&args, global_tile_n_offset); + if (cur_group_idx != group_idx) { group_idx = cur_group_idx; c_global_amax_val = shared_storage.global_d_amax[group_idx]; @@ -1132,18 +1080,17 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( compute_frgs_up, acc_scale)); } } - // Copy the output quantized data and scaling factors into global memory for later use copy(tiled_r2g_QA, tQArQA, tQAgQA_mn); copy(tiled_r2g_SFA, filter(tQArSFA), filter(tQAgSFA_mn)); } - // Move to next work tile for row quant warp - scheduler.fetch_next_work(clc_pipeline, clc_pipeline_consumer_state); - ++clc_pipeline_consumer_state; + // scheduler.advance(); + scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state); + ++sched_pipeline_consumer_state; scheduler.update_work_tile_info(); - } while (scheduler.is_valid()); + }while (scheduler.is_valid()); } + } else { - // Any extra warpgroup that is not used for anything above gets deallocated here cutlass::arch::warpgroup_reg_dealloc<32>(); } } // NOLINT(readability/fn_size) @@ -1208,37 +1155,32 @@ void group_row_col_rht_gemm_ntt_w_sfc(int packed_sequence_length, int hidden_siz auto dD = LayoutRight{}; // (dM,dN) auto dQA = stride(tensorQA); // (dM,dK) using ClusterShape = Shape<_1, _1, _1>; - auto cga_shape = ClusterShape{}; - auto cga_tile_shape = Shape<_128, Int, Int>{}; + auto cluster_shape = ClusterShape{}; + auto cluster_tile_shape = Shape<_128,Int,Int>{}; auto cluster_tile_mainloop = Shape<_128, Int, _128>{}; // Each mainloop / epilogue loads 128 x 64 tiles while each MMA proceeds with 128 x 16 tiles - static int constexpr EpilogueUnrollFactor = - size<2>(cluster_tile_mainloop) / size<2>(cga_tile_shape); + static int constexpr EpilogueUnrollFactor = size<2>(cluster_tile_mainloop) / size<2>(cluster_tile_shape); // Construct the MMA - auto mma = make_tiled_mma( - SM100_MMA_F16BF16_SS(cga_tile_shape), size<1>(cga_tile_shape), - UMMA::Major::MN, UMMA::Major::MN>{}, - Layout>{}); + auto mma = make_tiled_mma(SM100_MMA_F16BF16_SS(cluster_tile_shape), size<1>(cluster_tile_shape), + UMMA::Major::MN, UMMA::Major::MN>{}, + Layout>{}); + // Assert that the TiledMMA uses all CTAs in the CGA. - CUTE_STATIC_ASSERT_V(size(cga_shape) == size(mma)); - CUTE_STATIC_ASSERT_V(evenly_divides(cga_tile_shape, tile_shape(mma))); + CUTE_STATIC_ASSERT_V(size(cluster_shape) == size(mma)); + CUTE_STATIC_ASSERT_V(evenly_divides(cluster_tile_shape, tile_shape(mma))); // Determine the A and B shapes - auto mma_shape_B = - partition_shape_B(mma, make_shape(size<1>(cga_tile_shape), size<2>(cga_tile_shape))); + auto mma_shape_B = partition_shape_B(mma, make_shape(size<1>(cluster_tile_shape), size<2>(cluster_tile_shape))); using TiledMma = decltype(mma); using AtomThrID = typename TiledMma::AtomThrID; - using SmemShape_M = decltype(shape_div( - shape<0>(cga_tile_shape), - shape_div(shape<0>(cga_tile_shape), size<0>(cga_tile_shape) / size(AtomThrID{})))); - using SmemShape_N = decltype(shape_div( - shape<1>(cga_tile_shape), - shape_div(shape<1>(cga_tile_shape), size<1>(cga_tile_shape) / size(AtomThrID{})))); - using SmemShape_K = decltype(cute::get<2>(cga_tile_shape)); + using SmemShape_M = decltype(shape_div(shape<0>(cluster_tile_shape), shape_div(shape<0>(cluster_tile_shape), size<0>(cluster_tile_shape) / size(AtomThrID{})))); + using SmemShape_N = decltype(shape_div(shape<1>(cluster_tile_shape), shape_div(shape<1>(cluster_tile_shape), size<1>(cluster_tile_shape) / size(AtomThrID{})))); + using SmemShape_K = decltype(cute::get<2>(cluster_tile_shape)); using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::sm100_smem_selector(cga_tile_shape) * cute::size<1>(cga_tile_shape)); + TotalTmem / (cute::size<0>(cluster_tile_shape) * cute::size<1>(cluster_tile_shape)); // Define the smem layouts (static) // Calculate max pipeline stages based on Blackwell SM100's 232KB shared memory - constexpr int SchedulerPipelineStageCount = 6; + constexpr int SchedulerPipelineStageCount = 4; static int constexpr MainloopPipelineBytes = sizeof( typename cutlass::detail::CustomizedPipelineTmaUmmaAsync<1, Shape<_1, _1, _1>, Shape<_1, _1, _1>>::SharedStorage); - static int constexpr ClcResponseBytes = sizeof(CLCResponse) * SchedulerPipelineStageCount; - static int constexpr CLCThrottlePipelineBytes = - sizeof(typename cutlass::PipelineAsync::SharedStorage); - static int constexpr CLCPipelineBytes = - sizeof(typename cutlass::PipelineCLCFetchAsync::SharedStorage); + + static int constexpr SchedulerWorkspaceBytes = sizeof(int) * SchedulerPipelineStageCount; + static int constexpr SchedulerThrottlePipelineBytes = sizeof(typename cutlass::PipelineAsync::SharedStorage); + static int constexpr SchedulerPipelineBytes = sizeof(typename cutlass::PipelineCLCFetchAsync::SharedStorage); + static int constexpr TmemDeallocBytes = sizeof(cutlass::arch::ClusterBarrier); static int constexpr BTensorBytes = cute::size(mma_shape_B) * sizeof(TB); static int constexpr AccPipelineBytes = sizeof( @@ -1282,15 +1223,11 @@ void group_row_col_rht_gemm_ntt_w_sfc(int packed_sequence_length, int hidden_siz static int constexpr kBlackwellSmemSize = 232448; // 232KB in bytes static int constexpr kBytesPerStage = cute::size(mma_shape_A) * sizeof(TA) + MainloopPipelineBytes; - static int constexpr kReservedBytes = - ClcResponseBytes + CLCThrottlePipelineBytes + TmemBasePtrsBytes + CLCPipelineBytes + - TmemDeallocBytes + BTensorBytes + AccPipelineBytes; // Reserve for barriers and other uses + static int constexpr kReservedBytes = SchedulerWorkspaceBytes + SchedulerThrottlePipelineBytes + SchedulerPipelineBytes + + TmemBasePtrsBytes + TmemDeallocBytes+BTensorBytes + AccPipelineBytes; // Reserve for barriers and other uses static int constexpr kMaxStages = (kBlackwellSmemSize - kReservedBytes) / kBytesPerStage; auto sP = Int{}; // SMEM pipelines - // printf("\nmax stages: %d\n", int(kMaxStages)); - // printf("\nreserved bytes: %d\n", int(kReservedBytes)); - // printf("\nbytes per stage: %d\n", int(kBytesPerStage)); - // printf("\nremaining bytes: %d\n", int((kBlackwellSmemSize - kReservedBytes) % kBytesPerStage)); + auto sA = UMMA::tile_to_mma_shape(SmemLayoutAtomA{}, append(mma_shape_A, sP), Step<_2, _1, _3>{}); // (MMA,MMA_M,MMA_K,PIPE) auto sB = UMMA::tile_to_mma_shape(SmemLayoutAtomB{}, @@ -1300,26 +1237,24 @@ void group_row_col_rht_gemm_ntt_w_sfc(int packed_sequence_length, int hidden_siz auto tma_load_a = make_tma_copy_A_sm100(SM90_TMA_LOAD{}, tensorA, sA(_, _, _, 0), cluster_tile_mainloop, mma); auto tma_load_b = - make_tma_copy_B_sm100(SM90_TMA_LOAD{}, tensorB, sB(_, _, _, 0), cga_tile_shape, mma); + make_tma_copy_B_sm100(SM90_TMA_LOAD{}, tensorB, sB(_, _, _, 0), cluster_tile_shape, mma); // Assert checks on tile sizes -- no predication - assert(M % size<0>(cga_tile_shape) == 0); - assert(N % size<1>(cga_tile_shape) == 0); + assert(M % size<0>(cluster_tile_shape) == 0); + assert(N % size<1>(cluster_tile_shape) == 0); - uint32_t tiles_in_m = uint32_t(size(ceil_div(M, size<0>(cga_tile_shape)))); - uint32_t tiles_in_n = uint32_t(size(ceil_div(N, k_tile_size))); dim3 dimBlock(512); - dim3 dimCluster(size<0>(cga_shape), size<1>(cga_shape), size<2>(cga_shape)); - dim3 dimGrid(tiles_in_m, tiles_in_n, 1); + dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape), size<2>(cluster_shape)); + dim3 dimGrid(sm_count, 1, 1); int smem_size = sizeof( SharedStorage); auto *kernel_ptr = &group_row_col_rht_gemm_device< - decltype(M), decltype(N), decltype(k_tile_size), decltype(cga_shape), - decltype(cga_tile_shape), TA, decltype(dA), decltype(sA), decltype(tma_load_a), TB, + decltype(M), decltype(N), decltype(k_tile_size), decltype(cluster_shape), + decltype(cluster_tile_shape), TA, decltype(dA), decltype(sA), decltype(tma_load_a), TB, decltype(dB), decltype(sB), decltype(tma_load_b), TD, decltype(dD), decltype(sD), TSFD, decltype(sfd_layout), TQA, decltype(dQA), TSFA, decltype(sfa_layout), decltype(mma), AccumulatorPipelineStageCount, SchedulerPipelineStageCount, kEnableStochasticRounding, @@ -1333,12 +1268,18 @@ void group_row_col_rht_gemm_ntt_w_sfc(int packed_sequence_length, int hidden_siz return; } + void *tile_scheduler_workspace = nullptr; + cudaMallocAsync(&tile_scheduler_workspace, sizeof(uint32_t), stream); + // reset the tile_scheduler_workspace to 0 + cudaMemsetAsync(tile_scheduler_workspace, 0, sizeof(uint32_t), stream); cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smem_size}; cutlass::Status status = cutlass::launch_kernel_on_cluster( - params, (void const *)kernel_ptr, M, N, k_tile_size, cga_shape, cga_tile_shape, A, dA, sA, - tma_load_a, B, dB, sB, tma_load_b, QA, dQA, SFA, sfa_layout, args, mma, rng_state); + params, (void const *)kernel_ptr, M, N, k_tile_size, cluster_shape, cluster_tile_shape, A, dA, sA, + tma_load_a, B, dB, sB, tma_load_b, QA, dQA, SFA, sfa_layout, args, tile_scheduler_workspace, mma, rng_state); // CUTE_CHECK_LAST(); + cudaFreeAsync(tile_scheduler_workspace, stream); + if (status != cutlass::Status::kSuccess) { std::cerr << "Error: Failed at kernel Launch" << std::endl; return; @@ -1477,27 +1418,29 @@ void group_hadamard_transform_cast_fusion(const Tensor &input_, std::vector( - /*packed_sequence_length=*/m, /*hidden_size=*/n, - /*A=*/reinterpret_cast(input.dptr), - /*B=*/reinterpret_cast(hadamard_matrix.dptr), - /*QA=*/reinterpret_cast(rowwise_data_base_ptr), - /*SFA=*/reinterpret_cast(rowwise_scale_inv_base_ptr), - /*args=*/kernel_args, - /*rng_state=*/rng_state, /*sm_count=*/sm_count, - /*stream=*/stream, /*k_tile_size=*/k_tile_size); - ););););); + use_stochastic_rounding, kEnableStochasticRounding, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + all_has_col_quant, kEnableRhtColQuant, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + all_has_row_quant, kEnableRowQuant, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_swizzle_sf_output, kEnableSwizzleSFOutput, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_fast_math, kEnableFastMath, + + detail::group_row_col_rht_gemm_ntt_w_sfc< + kEnableStochasticRounding, kEnableRhtColQuant, kEnableRowQuant, + kEnableSwizzleSFOutput, TA, TB, TQA, TSFA, TD, TSFD, kEnableFastMath>( + /*packed_sequence_length=*/m, /*hidden_size=*/n, + /*A=*/reinterpret_cast(input.dptr), + /*B=*/reinterpret_cast(hadamard_matrix.dptr), + /*QA=*/reinterpret_cast(rowwise_data_base_ptr), + /*SFA=*/reinterpret_cast(rowwise_scale_inv_base_ptr), + /*args=*/kernel_args, + /*rng_state=*/rng_state, /*sm_count=*/sm_count, + /*stream=*/stream, /*k_tile_size=*/k_tile_size); + + ););););); } } // namespace transformer_engine From 4df34cefbcf720d24ea24b298bc1bc739a65ddf6 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Fri, 19 Dec 2025 00:45:25 +0000 Subject: [PATCH 25/36] Avoid instantiating group RHT+cast kernel without row-wise or col-wise output Signed-off-by: Tim Moon --- ...cast_col_hadamard_transform_cast_fusion.cu | 56 +++++++++++-------- 1 file changed, 34 insertions(+), 22 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu index 7691999bd8..fb916bdce9 100644 --- a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu @@ -198,6 +198,9 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( "Try recompiling with sm_100a or similar."); return; } + static_assert(kEnableRHTColQuant_ || kEnableRowQuant_, + "group_row_col_rht_gemm_device must generate row-wise " + "and/or column-wise output."); #if !defined(CUTLASS_ARCH_CLC_ENABLED) CUTLASS_NOT_IMPLEMENTED(); return; @@ -432,6 +435,7 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( cutlass::PipelineUmmaAsync; using AccumulatorPipelineState = typename AccumulatorPipeline::PipelineState; + using AccumulatorPipelineInitBarriers = cute::bool_constant; AccumulatorPipelineState accumulator_pipe_consumer_state; AccumulatorPipelineState accumulator_pipe_producer_state = @@ -449,10 +453,9 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( accumulator_pipeline_params.consumer_arv_count = size(AtomThrShapeMNK{}) * NumEpilogueColQuantThreadCount; accumulator_pipeline_params.initializing_warp = 1; - using IsInitAccumulatorPipeline = cute::conditional_t; AccumulatorPipeline accumulator_pipeline(shared_storage.accumulator, accumulator_pipeline_params, cluster_shape, - IsInitAccumulatorPipeline{}, // Perform barrier init + AccumulatorPipelineInitBarriers{}, cute::true_type{}); // Delay mask calculation typename SchedPipeline::Params sched_pipeline_params; if (is_sched_warp) { @@ -1420,27 +1423,36 @@ void group_hadamard_transform_cast_fusion(const Tensor &input_, std::vector( - /*packed_sequence_length=*/m, /*hidden_size=*/n, - /*A=*/reinterpret_cast(input.dptr), - /*B=*/reinterpret_cast(hadamard_matrix.dptr), - /*QA=*/reinterpret_cast(rowwise_data_base_ptr), - /*SFA=*/reinterpret_cast(rowwise_scale_inv_base_ptr), - /*args=*/kernel_args, - /*rng_state=*/rng_state, /*sm_count=*/sm_count, - /*stream=*/stream, /*k_tile_size=*/k_tile_size); - - ););););); + use_swizzle_sf_output, kEnableSwizzleSFOutput, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_fast_math, kEnableFastMath, + + if constexpr (kEnableRhtColQuant || kEnableRowQuant) { + detail::group_row_col_rht_gemm_ntt_w_sfc< + kEnableStochasticRounding, + kEnableRhtColQuant, kEnableRowQuant, + kEnableSwizzleSFOutput, + TA, TB, TQA, TSFA, TD, TSFD, + kEnableFastMath>( + /*packed_sequence_length=*/m, /*hidden_size=*/n, + /*A=*/reinterpret_cast(input.dptr), + /*B=*/reinterpret_cast(hadamard_matrix.dptr), + /*QA=*/reinterpret_cast(rowwise_data_base_ptr), + /*SFA=*/reinterpret_cast(rowwise_scale_inv_base_ptr), + /*args=*/kernel_args, + /*rng_state=*/rng_state, /*sm_count=*/sm_count, + /*stream=*/stream, /*k_tile_size=*/k_tile_size); + } else { + NVTE_ERROR("Invalid kernel configuration (kEnableRHTColQuant=", + kEnableRhtColQuant, ", kEnableRowQuant=", + kEnableRowQuant, ")."); + } + + ););););); } } // namespace transformer_engine From cbdda209cfd5f00dc1a1f4ae632af1a9fa6d96b1 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Fri, 19 Dec 2025 02:43:24 +0000 Subject: [PATCH 26/36] Include fast math option in quantization config Signed-off-by: Tim Moon --- transformer_engine/common/common.h | 4 +- .../group_hadamard_transform_cast_fusion.cu | 41 ++++++------- ...cast_col_hadamard_transform_cast_fusion.cu | 60 +++++++++---------- .../hadamard_transform_cast_fusion.cu | 42 ++++++------- .../transformer_engine/transformer_engine.h | 12 ++++ .../common/transformer_engine.cpp | 20 ++++++- .../pytorch/csrc/extensions/cast.cpp | 16 ++++- 7 files changed, 114 insertions(+), 81 deletions(-) diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 38b437b994..0e264eaae3 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -394,6 +394,7 @@ struct QuantizationConfig { NVTETensor rng_state = nullptr; bool nvfp4_2d_quantization = false; bool stochastic_rounding = false; + bool use_fast_math = false; static constexpr size_t attr_sizes[] = { sizeof(bool), // force_pow_2_scales @@ -402,7 +403,8 @@ struct QuantizationConfig { sizeof(Float8BlockScaleTensorFormat), // float8_block_scale_tensor_format sizeof(NVTETensor), // rng_seed and offset sizeof(bool), // nvfp4_2d_quantization - sizeof(bool) // stochastic_rounding + sizeof(bool), // stochastic_rounding + sizeof(bool) // use_fast_math }; }; diff --git a/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu index 8c9e16ebdb..afa8335316 100644 --- a/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu @@ -170,7 +170,7 @@ template + bool kUseFastMath = false> __global__ static void group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile, @@ -493,7 +493,6 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til // NVFP4 non-E8 recipe constants and global scales static constexpr float fp4_max = 6.0f; - // (optional) path for faster math, use multiply to repalce div static constexpr float fp4_max_inv = 1.0f / fp4_max; // get global amax pointer @@ -544,11 +543,13 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til float global_amax_val = *global_amax_ptr; float global_encode_scale = ComputeGlobalEncodeScaleFP4(global_amax_val); - // will be used in fast math path if enabled + + // Scaling factor for fast math path float global_encode_scale_multiplier = 1.0f; - if constexpr (kEnableFastMath) { + if constexpr (kUseFastMath) { global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; } + float global_decode_scale = 1.0f / global_encode_scale; auto sfd_converter = cutlass::NumericConverter{}; @@ -566,8 +567,7 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til if (tensor_id != new_tensor_id) { global_amax_val = *global_amax_ptr; global_encode_scale = ComputeGlobalEncodeScaleFP4(global_amax_val); - // will be used in fast math path if enabled - if constexpr (kEnableFastMath) { + if constexpr (kUseFastMath) { global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; } global_decode_scale = 1.0f / global_encode_scale; @@ -670,11 +670,11 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til vec_maxs[v] = amax_reduction(ElementAccumulator(0), compute_frgs[v]); } - if constexpr (kEnableFastMath) { - // path for faster math, use multiply to repalce div + if constexpr (kUseFastMath) { + // Fast math: multiply with precomputed reciprocal pvscales = cutlass::multiplies>{}(vec_maxs, global_encode_scale_multiplier); } else { - // regular path for slower math, use divide + // Accurate math: perform division pvscales = cutlass::divides>{}(vec_maxs, fp4_max); pvscales = cutlass::multiplies>{}(pvscales, global_encode_scale); } @@ -684,11 +684,11 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til auto qpvscale_ups = cutlass::NumericArrayConverter{}(tC_rRowSFD_frg(_0{})); auto qpvscale_scaled = cutlass::multiplies>{}(qpvscale_ups, global_decode_scale); cutlass::Array acc_scales; - if constexpr (kEnableFastMath) { - // fast math: use reciprocal approximate to replace div + if constexpr (kUseFastMath) { + // Fast math: compute approximate reciprocal acc_scales = cutlass::reciprocal_approximate_ftz{}(qpvscale_scaled); } else { - // regular path for slower math, use divide to replace div + // Accurate math: compute reciprocal with division acc_scales = cutlass::divides>{}(1.0, qpvscale_scaled); } @@ -736,7 +736,7 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til // B: 16 x 16: row-major // C: m x n: row-major // SFC: m x (n/16): row-major -template +template void group_rht_gemm_ntt_w_sfc(int m, int n, TA const* A, @@ -850,7 +850,7 @@ group_rht_gemm_ntt_w_sfc(int m, int n, TSFC, decltype(mma), kEnableStochasticRounding, - kEnableFastMath>; + kUseFastMath>; bool status = cudaFuncSetAttribute(*kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, @@ -872,7 +872,7 @@ group_rht_gemm_ntt_w_sfc(int m, int n, // this function is used to wrap the group_rht_gemm_ntt_w_sfc function // to transpose the input tensor A -template +template void group_rht_gemm_ttt_wrapper(int m, int n, TA const* A, @@ -893,7 +893,7 @@ group_rht_gemm_ttt_wrapper(int m, int n, // B: 16 x 16: row-major // C: n x m: row-major // SFC: n x (m/16): row-major - group_rht_gemm_ntt_w_sfc( + group_rht_gemm_ntt_w_sfc( n, m, A, B, kernel_args_ptr, @@ -1027,17 +1027,12 @@ void group_hadamard_transform_cast_fusion_columnwise( k_tile_size = 512; } - // TODO: haven't decided whether to expose this as a API option or not - // use fast math if there is a ENV var NVTE_RHT_CAST_FUSION_USE_FAST_MATH, default to false - static const bool use_fast_math = - transformer_engine::getenv("NVTE_RHT_CAST_FUSION_USE_FAST_MATH", false); - TRANSFORMER_ENGINE_SWITCH_CONDITION( use_stochastic_rounding, kUseStochasticRounding, TRANSFORMER_ENGINE_SWITCH_CONDITION( - use_fast_math, kEnableFastMath, + quant_config.use_fast_math, kUseFastMath, detail::group_rht_gemm_ttt_wrapper( + kUseFastMath>( /*m=*/m, /*n=*/n, /*A=*/reinterpret_cast(input.dptr), /*B=*/reinterpret_cast(hadamard_matrix.dptr), /*kernel_args_ptr=*/&kernel_args, /*rng_state=*/rng_state, /*sm_count=*/sm_count, diff --git a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu index fb916bdce9..26afaf8ae7 100644 --- a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu @@ -179,7 +179,7 @@ template + bool kUseFastMath_ = false> __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( MShape M, NShape packed_N, KShape K, ClusterShape cluster_shape, ClusterTileShape cluster_tile, TA const *A, AStride dA, ASmemLayout sAlayout, CUTE_GRID_CONSTANT TmaLoadA const tma_load_a, @@ -217,7 +217,7 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( static constexpr bool kEnableRHTColQuant = kEnableRHTColQuant_; static constexpr bool kEnableRowQuant = kEnableRowQuant_; static constexpr bool kEnableSwizzleSFOutput = kEnableSwizzleSFOutput_; - static constexpr bool kEnableFastMath = kEnableFastMath_; + static constexpr bool kUseFastMath = kUseFastMath_; // Constant for RHT tensor processing (tile size etc) static int constexpr RhtTensorSize = 16; @@ -733,20 +733,20 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( // Aligning with TensorEngine's recipe to generate scale factors // {$nv-internal-release} static constexpr float fp4_max = 6.0f; static constexpr float fp8_max = 448.0f; - float const fp4_max_inv = 1.0f / fp4_max; + static constexpr float fp4_max_inv = 1.0f / fp4_max; float c_global_amax_val = shared_storage.global_d_amax[group_idx]; float global_encode_scale = c_global_amax_val > 0.0f ? cutlass::minimum_with_nan_propagation{}( (fp8_max * fp4_max) / c_global_amax_val, cutlass::platform::numeric_limits::max()) : 1.0f; - float global_decode_scale = 1.0f / global_encode_scale; + + // Scaling factor for fast math path float global_encode_scale_multiplier = 1.0f; - if constexpr (kEnableFastMath) { + if constexpr (kUseFastMath) { global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; } - auto sfc_converter = cutlass::NumericConverter{}; do { scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state); @@ -767,7 +767,7 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( cutlass::platform::numeric_limits::max()) : 1.0f; global_decode_scale = 1.0f / global_encode_scale; - if constexpr (kEnableFastMath) { + if constexpr (kUseFastMath) { global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; } cur_N = args.split_sections[group_idx]; @@ -844,11 +844,12 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( vec_maxs[v] = amax_reduction(ElementAccumulator(0), compute_frgs[v]); } - // Scale values for quantization depending on fast-math flag - if constexpr (kEnableFastMath) { + if constexpr (kUseFastMath) { + // Fast math: multiply with precomputed reciprocal pvscales = cutlass::multiplies>{}( vec_maxs, global_encode_scale_multiplier); } else { + // Accurate math: perform division pvscales = cutlass::divides>{}(vec_maxs, fp4_max); pvscales = cutlass::multiplies>{}( @@ -863,12 +864,12 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( auto qpvscale_scaled = cutlass::multiplies>{}( qpvscale_ups, global_decode_scale); cutlass::Array acc_scales; - if constexpr (kEnableFastMath) { - // Fast-math: approximate compute reciprocal instead of divide. + if constexpr (kUseFastMath) { + // Fast math: compute approximate reciprocal acc_scales = cutlass::reciprocal_approximate_ftz{}(qpvscale_scaled); } else { - // Regular path: computes using division + // Accurate math: compute reciprocal with division acc_scales = cutlass::divides>{}( 1.0, qpvscale_scaled); } @@ -975,7 +976,7 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( // Aligning with TensorEngine's recipe to generate scale factors // {$nv-internal-release} static constexpr float fp4_max = 6.0f; static constexpr float fp8_max = 448.0f; - float const fp4_max_inv = 1.0f / fp4_max; + static constexpr float fp4_max_inv = 1.0f / fp4_max; float global_encode_scale = a_global_amax_val > 0.0f ? cutlass::minimum_with_nan_propagation{}( (fp8_max * fp4_max) / a_global_amax_val, @@ -984,7 +985,7 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( float global_decode_scale = 1.0f / global_encode_scale; float global_encode_scale_multiplier = 1.0f; - if constexpr (kEnableFastMath) { + if constexpr (kUseFastMath) { global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; } auto sfa_converter = cutlass::NumericConverter{}; @@ -1005,7 +1006,7 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( cutlass::platform::numeric_limits::max()) : 1.0f; global_decode_scale = 1.0f / global_encode_scale; - if constexpr (kEnableFastMath) { + if constexpr (kUseFastMath) { global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; } } @@ -1044,10 +1045,12 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( cutlass::NumericArrayConverter{}( compute_frgs[v]); amax_view(_0{}, v) = amax_reduction(ElementAccumulator(0), compute_frgs_up); - if constexpr (kEnableFastMath) { + if constexpr (kUseFastMath) { + // Fast math: multiply with precomputed reciprocal pvscales_view(_0{}, v) = cutlass::multiplies{}( amax_view(_0{}, v), global_encode_scale_multiplier); } else { + // Accurate math: perform division pvscales_view(_0{}, v) = cutlass::divides{}(amax_view(_0{}, v), fp4_max); pvscales_view(_0{}, v) = cutlass::multiplies{}( @@ -1059,12 +1062,12 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( auto qpvscale_scaled = cutlass::multiplies{}(qpvscale_ups, global_decode_scale); ElementAccumulator acc_scales; - if constexpr (kEnableFastMath) { - // fast math: use reciprocal approximate to replace div + if constexpr (kUseFastMath) { + // Fast math: compute approximate reciprocal acc_scales = cutlass::reciprocal_approximate_ftz{}(qpvscale_scaled); } else { - // regular path for slower math, use divide to replace div + // Accurate math: compute reciprocal with division acc_scales = cutlass::divides{}(1.0, qpvscale_scaled); } auto acc_scale = cutlass::minimum_with_nan_propagation{}( @@ -1100,7 +1103,7 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( template + class TSFD = TSFA, bool kUseFastMath = false> void group_row_col_rht_gemm_ntt_w_sfc(int packed_sequence_length, int hidden_size, TA const *A, TB const *B, TQA *QA, TSFA *SFA, MultiAmaxHadamardCastFusionArgs &args, @@ -1261,7 +1264,7 @@ void group_row_col_rht_gemm_ntt_w_sfc(int packed_sequence_length, int hidden_siz decltype(dB), decltype(sB), decltype(tma_load_b), TD, decltype(dD), decltype(sD), TSFD, decltype(sfd_layout), TQA, decltype(dQA), TSFA, decltype(sfa_layout), decltype(mma), AccumulatorPipelineStageCount, SchedulerPipelineStageCount, kEnableStochasticRounding, - kEnableRHTColQuant, kEnableRowQuant, kEnableSwizzleSFOutput, kEnableFastMath>; + kEnableRHTColQuant, kEnableRowQuant, kEnableSwizzleSFOutput, kUseFastMath>; bool status_set_attr = cudaFuncSetAttribute(*kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); @@ -1365,8 +1368,10 @@ void group_hadamard_transform_cast_fusion(const Tensor &input_, std::vector{2}, @@ -1413,11 +1418,6 @@ void group_hadamard_transform_cast_fusion(const Tensor &input_, std::vector("NVTE_RHT_CAST_FUSION_USE_FAST_MATH", false); - const bool use_swizzle_sf_output = false; TRANSFORMER_ENGINE_SWITCH_CONDITION( @@ -1429,7 +1429,7 @@ void group_hadamard_transform_cast_fusion(const Tensor &input_, std::vector( + kUseFastMath>( /*packed_sequence_length=*/m, /*hidden_size=*/n, /*A=*/reinterpret_cast(input.dptr), /*B=*/reinterpret_cast(hadamard_matrix.dptr), diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu index 10cbbb5157..2a5fa53b8a 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu @@ -132,7 +132,7 @@ template + bool kUseFastMath = false> __global__ static void rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile, @@ -426,17 +426,16 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile, // NVFP4 non-E8 recipe constants and global scales static constexpr float fp4_max = 6.0f; - // (optional) path for faster math, use multiply to repalce div - static constexpr float fp4_max_inv = 1.0f / fp4_max; const float global_encode_scale = ComputeGlobalEncodeScaleFP4(global_amax_val); - // will be used in fast math path if enabled + const float global_decode_scale = 1.0f / global_encode_scale; + + // Scaling factor for fast math path float global_encode_scale_multiplier = 1.0f; - if constexpr (kEnableFastMath) { + if constexpr (kUseFastMath) { + static constexpr float fp4_max_inv = 1.0f / fp4_max; global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; } - const float global_decode_scale = 1.0f / global_encode_scale; - auto sfd_converter = cutlass::NumericConverter{}; do { for (int k_tile = 0; k_tile < K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n; ++k_tile) { @@ -491,11 +490,11 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile, vec_maxs[v] = amax_reduction(ElementAccumulator(0), compute_frgs[v]); } - if constexpr (kEnableFastMath) { - // path for faster math, use multiply to repalce div + if constexpr (kUseFastMath) { + // Fast math: multiply with precomputed reciprocal pvscales = cutlass::multiplies>{}(vec_maxs, global_encode_scale_multiplier); } else { - // regular path for slower math, use divide + // Accurate math: perform division pvscales = cutlass::divides>{}(vec_maxs, fp4_max); pvscales = cutlass::multiplies>{}(pvscales, global_encode_scale); } @@ -505,11 +504,11 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile, auto qpvscale_ups = cutlass::NumericArrayConverter{}(tC_rRowSFD_frg(_0{})); auto qpvscale_scaled = cutlass::multiplies>{}(qpvscale_ups, global_decode_scale); cutlass::Array acc_scales; - if constexpr (kEnableFastMath) { - // fast math: use reciprocal approximate to replace div + if constexpr (kUseFastMath) { + // Fast math: compute approximate reciprocal acc_scales = cutlass::reciprocal_approximate_ftz{}(qpvscale_scaled); } else { - // regular path for slower math, use divide to replace div + // Accurate math: compute reciprocal with division acc_scales = cutlass::divides>{}(1.0, qpvscale_scaled); } @@ -555,7 +554,7 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile, // B: 16 x 16: row-major // C: m x n: row-major // SFC: m x (n/16): row-major -template +template void rht_gemm_ntt_w_sfc(int m, int n, TA const* A, @@ -668,7 +667,7 @@ rht_gemm_ntt_w_sfc(int m, int n, TSFC, decltype(mma), kEnableStochasticRounding, - kEnableFastMath>; + kUseFastMath>; bool status = cudaFuncSetAttribute(*kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, @@ -691,7 +690,7 @@ rht_gemm_ntt_w_sfc(int m, int n, // this function is used to wrap the rht_gemm_ntt_w_sfc function //to transpose the input tensor A -template +template void rht_gemm_ttt_wrapper(int m, int n, TA const* A, @@ -714,7 +713,7 @@ rht_gemm_ttt_wrapper(int m, int n, // B: 16 x 16: row-major // C: n x m: row-major // SFC: n x (m/16): row-major - rht_gemm_ntt_w_sfc( + rht_gemm_ntt_w_sfc( n, m, A, B, C, SFC, global_amax, @@ -825,16 +824,11 @@ void hadamard_transform_cast_fusion_columnwise(const Tensor &input_, Tensor &out k_tile_size = 512; } - // TODO: haven't decided whether to expose this as a API option or not - // use fast math if there is a ENV var NVTE_RHT_CAST_FUSION_USE_FAST_MATH, default to false - static const bool use_fast_math = - transformer_engine::getenv("NVTE_RHT_CAST_FUSION_USE_FAST_MATH", false); - TRANSFORMER_ENGINE_SWITCH_CONDITION( use_stochastic_rounding, kUseStochasticRounding, TRANSFORMER_ENGINE_SWITCH_CONDITION( - use_fast_math, kEnableFastMath, - detail::rht_gemm_ttt_wrapper( + quant_config.use_fast_math, kUseFastMath, + detail::rht_gemm_ttt_wrapper( /*m=*/m, /*n=*/n, /*A=*/reinterpret_cast(input.dptr), diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index b2e04ba69f..19cb646be2 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -337,6 +337,12 @@ enum NVTEQuantizationConfigAttribute { kNVTEQuantizationConfigNVFP42DQuantization = 5, /*! Whether to enable stochastic rounding */ kNVTEQuantizationConfigStochasticRounding = 6, + /*! Whether to enable fast math operations with reduced accuracy. + * + * Optimizations are kernel-specific and they may be applied + * inconsistently between kernels. + */ + kNVTEQuantizationConfigUseFastMath = 7, kNVTEQuantizationConfigNumAttributes }; @@ -997,6 +1003,12 @@ class QuantizationConfigWrapper { &stochastic_rounding, sizeof(bool)); } + /*! \brief Set whether to enable fast math operations */ + void set_use_fast_math(bool use_fast_math) { + nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigUseFastMath, + &use_fast_math, sizeof(bool)); + } + private: /*! \brief Wrapped NVTEQuantizationConfig. */ NVTEQuantizationConfig config_ = nullptr; diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 8d9563b789..4a140b4376 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -857,9 +857,10 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config, // Write attribute size NVTE_CHECK(attr < kNVTEQuantizationConfigNumAttributes, "Invalid NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); - NVTE_CHECK(size_written != nullptr, "Invalid size_written (got NULL)"); const auto &attr_size = transformer_engine::QuantizationConfig::attr_sizes[attr]; - *size_written = attr_size; + if (size_written != nullptr) { + *size_written = attr_size; + } // Return immediately if buffer is not provided if (buf == nullptr) { @@ -889,6 +890,18 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config, case kNVTEQuantizationConfigFloat8BlockScaleTensorFormat: std::memcpy(buf, &config_.float8_block_scale_tensor_format, attr_size); break; + case kNVTEQuantizationConfigRNGState: + std::memcpy(buf, &config_.rng_state, attr_size); + break; + case kNVTEQuantizationConfigNVFP42DQuantization: + std::memcpy(buf, &config_.nvfp4_2d_quantization, attr_size); + break; + case kNVTEQuantizationConfigStochasticRounding: + std::memcpy(buf, &config_.stochastic_rounding, attr_size); + break; + case kNVTEQuantizationConfigUseFastMath: + std::memcpy(buf, &config_.use_fast_math, attr_size); + break; default: NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); } @@ -933,6 +946,9 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config, case kNVTEQuantizationConfigStochasticRounding: std::memcpy(&config_.stochastic_rounding, buf, attr_size); break; + case kNVTEQuantizationConfigUseFastMath: + std::memcpy(&config_.use_fast_math, buf, attr_size); + break; default: NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); } diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index abd8a2f9aa..6927f8031c 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -16,6 +16,7 @@ #include "../extensions.h" #include "common.h" +#include "common/util/system.h" #include "pybind.h" #include "transformer_engine/transformer_engine.h" @@ -830,12 +831,25 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input, // this is true because we have already built grouped kernels for rowwise and colwise quantization with RHT bool with_bulk_generate_rng_states = true; + // Stochastic rounding bool need_stochastic_rounding = quantizer.stochastic_rounding; - auto stochastic_rng_state_resources = setup_stochastic_rounding_rng_states_helper( num_tensors, need_stochastic_rounding, with_bulk_generate_rng_states, need_separate_rng_states, quant_config_list, quant_config_list_colwise); + // Enable NVFP4 kernels to use math operations that sacrifice + // accuracy for performance. These optimizations are experimental + // and inconsistently implemented. + const auto use_fast_math = transformer_engine::getenv("NVTE_USE_FAST_MATH"); + if (use_fast_math) { + for (auto &config: quant_config_list) { + config.set_use_fast_math(true); + } + for (auto &config: quant_config_list_colwise) { + config.set_use_fast_math(true); + } + } + auto &quant_config_list_colwise_to_use = need_separate_rng_states ? quant_config_list_colwise : quant_config_list; From 0ac4d74e9cf13e5a4cc0746a3e0372d5aec2b6a4 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Fri, 19 Dec 2025 03:23:18 +0000 Subject: [PATCH 27/36] Fix linter warnings and review nits Signed-off-by: Tim Moon --- ...ed_pipeline.hpp => customized_pipeline.cuh} | 4 +--- .../group_hadamard_transform_cast_fusion.cu | 11 +++++------ ..._cast_col_hadamard_transform_cast_fusion.cu | 18 ++++++++---------- .../hadamard_transform_cast_fusion.cu | 5 ----- 4 files changed, 14 insertions(+), 24 deletions(-) rename transformer_engine/common/hadamard_transform/{customized_pipeline.hpp => customized_pipeline.cuh} (98%) diff --git a/transformer_engine/common/hadamard_transform/customized_pipeline.hpp b/transformer_engine/common/hadamard_transform/customized_pipeline.cuh similarity index 98% rename from transformer_engine/common/hadamard_transform/customized_pipeline.hpp rename to transformer_engine/common/hadamard_transform/customized_pipeline.cuh index 967927f8c0..c5d9e63138 100644 --- a/transformer_engine/common/hadamard_transform/customized_pipeline.hpp +++ b/transformer_engine/common/hadamard_transform/customized_pipeline.cuh @@ -94,7 +94,6 @@ class CustomizedPipelineTmaUmmaAsync { dim3 block_id_in_cluster = cute::block_id_in_cluster()) { // Calculate consumer mask if (params_.role == ThreadCategory::Consumer) { - auto cluster_layout = make_layout(cluster_shape); block_id_mask_ = detail::calculate_multicast_mask( cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster); } @@ -104,7 +103,6 @@ class CustomizedPipelineTmaUmmaAsync { void init_masks(ClusterShape cluster_shape, McastDirection mcast_direction) { // Calculate consumer mask dim3 block_id_in_cluster = cute::block_id_in_cluster(); - auto cluster_layout = make_layout(cluster_shape); if (mcast_direction == McastDirection::kRow) { block_id_mask_ = detail::calculate_multicast_mask( cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster); @@ -215,7 +213,7 @@ class CustomizedPipelineTmaUmmaAsync { static constexpr bool is_2sm_mma = size(AtomThrShape_MNK{}) > 1; // Consumer signalling Producer of completion - // Ensures all blocks in the Same Row and Column get notifed. + // Ensures all blocks in the Same Row and Column get notified. CUTLASS_DEVICE void umma_consumer_release(uint32_t stage, uint32_t skip) { detail::pipeline_check_is_consumer(params_.role); diff --git a/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu index afa8335316..f7e4ab9b53 100644 --- a/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu @@ -73,7 +73,11 @@ __device__ __forceinline__ float* GetGlobalAmaxPtrByTensorId(MultiAmaxHadamardCa } __device__ __forceinline__ int GetTensorId(MultiAmaxHadamardCastFusionArgs* kernel_args_ptr, int offset){ - // check the kernel args and get the corresponding id + // Check the kernel args and get the corresponding id + const int num_tensors = kernel_args_ptr->num_tensors; + if (offset >= kernel_args_ptr->split_sections_range[num_tensors]) { + return num_tensors - 1; + } int tensor_id = 0; while (kernel_args_ptr->split_sections_range[tensor_id + 1] <= offset) { ++tensor_id; @@ -658,11 +662,6 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til ++accumulator_pipe_consumer_state; - // Cast data from FP32 to BF16 to FP32. - // auto convert_accum_to_bf16 = cutlass::NumericArrayConverter{}; - // auto convert_bf16_to_accum = cutlass::NumericArrayConverter{}; - // tTR_rAcc_frag(_0{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_0{}))); - auto compute_frgs = reinterpret_cast *>(tTR_rAcc_frag.data()); auto output_frgs = reinterpret_cast *>(tDrC_frag.data()); CUTLASS_PRAGMA_UNROLL diff --git a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu index 26afaf8ae7..e4f13b49bc 100644 --- a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu @@ -22,7 +22,7 @@ #include "common/util/curanddx.hpp" #include "common/util/ptx.cuh" #include "common/utils.cuh" -#include "customized_pipeline.hpp" +#include "customized_pipeline.cuh" #include "cutlass/arch/barrier.h" #include "cutlass/arch/reg_reconfig.h" #include "cutlass/cluster_launch.hpp" @@ -77,13 +77,12 @@ struct MultiAmaxHadamardCastFusionArgs { __device__ __forceinline__ int GetGroupIdx(MultiAmaxHadamardCastFusionArgs *kernel_args_ptr, int offset) { - // check the kernel args and get the corresponding id - int group_idx = 0; - int num_tensors = kernel_args_ptr->num_tensors; - int boundary = kernel_args_ptr->split_sections_range[num_tensors]; - if (offset >= boundary) { + // Check the kernel args and get the corresponding id + const int num_tensors = kernel_args_ptr->num_tensors; + if (offset >= kernel_args_ptr->split_sections_range[num_tensors]) { return num_tensors - 1; } + int group_idx = 0; while (kernel_args_ptr->split_sections_range[group_idx + 1] <= offset) { ++group_idx; } @@ -283,14 +282,14 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( uint32_t atomic_offset; cutlass::FastDivmodU64 divmod_tiles_in_m; - CUTLASS_DEVICE TileScheduler(uint32_t tiles_m, uint32_t tiles_n, int kmax, uint32_t* atomic_tile_index_, uint32_t* smem_tile_counter) + CUTLASS_DEVICE TileScheduler(uint32_t tiles_m, uint32_t tiles_n, int kmax, uint32_t* atomic_tile_index, uint32_t* smem_tile_counter) : tiles_in_m(tiles_m), tiles_in_n(tiles_n), linear_idx(blockIdx.x), next_linear_idx(blockIdx.x), start_idx(blockIdx.x), k_tile_max(kmax), - atomic_tile_index_(atomic_tile_index_), + atomic_tile_index_(atomic_tile_index), smem_tile_counter(smem_tile_counter), atomic_offset(gridDim.x), divmod_tiles_in_m(uint64_t(tiles_m)) { @@ -460,8 +459,7 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( typename SchedPipeline::Params sched_pipeline_params; if (is_sched_warp) { sched_pipeline_params.role = SchedPipeline::ThreadCategory::ProducerConsumer; - } - else { + } else { sched_pipeline_params.role = SchedPipeline::ThreadCategory::Consumer; } sched_pipeline_params.producer_blockid = 0; diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu index 2a5fa53b8a..3e331547a1 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu @@ -478,11 +478,6 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile, ++accumulator_pipe_consumer_state; - // Cast data from FP32 to BF16 to FP32. - // auto convert_accum_to_bf16 = cutlass::NumericArrayConverter{}; - // auto convert_bf16_to_accum = cutlass::NumericArrayConverter{}; - // tTR_rAcc_frag(_0{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_0{}))); - auto compute_frgs = reinterpret_cast *>(tTR_rAcc_frag.data()); auto output_frgs = reinterpret_cast *>(tDrC_frag.data()); CUTLASS_PRAGMA_UNROLL From c14b1563070bd1818f9040fceab34e9fd1d54325 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 19 Dec 2025 03:26:53 +0000 Subject: [PATCH 28/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/csrc/extensions/cast.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 6927f8031c..aa9d800c7b 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -842,10 +842,10 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input, // and inconsistently implemented. const auto use_fast_math = transformer_engine::getenv("NVTE_USE_FAST_MATH"); if (use_fast_math) { - for (auto &config: quant_config_list) { + for (auto &config : quant_config_list) { config.set_use_fast_math(true); } - for (auto &config: quant_config_list_colwise) { + for (auto &config : quant_config_list_colwise) { config.set_use_fast_math(true); } } From 40ae64c80fbc1bec340b526c7445852285de8ce7 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Fri, 19 Dec 2025 03:57:50 +0000 Subject: [PATCH 29/36] Use TE license Signed-off-by: Tim Moon --- .../customized_pipeline.cuh | 39 +++++-------------- 1 file changed, 9 insertions(+), 30 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/customized_pipeline.cuh b/transformer_engine/common/hadamard_transform/customized_pipeline.cuh index c5d9e63138..b6f6799a49 100644 --- a/transformer_engine/common/hadamard_transform/customized_pipeline.cuh +++ b/transformer_engine/common/hadamard_transform/customized_pipeline.cuh @@ -1,34 +1,11 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_COMMON_HADAMARD_TRANSFORM_CUSTOMIZED_PIPELINE_CUH_ +#define TRANSFORMER_ENGINE_COMMON_HADAMARD_TRANSFORM_CUSTOMIZED_PIPELINE_CUH_ #include "cutlass/pipeline/sm100_pipeline.hpp" @@ -241,3 +218,5 @@ class CustomizedPipelineTmaUmmaAsync { }; } // namespace detail } // namespace cutlass + +#endif // TRANSFORMER_ENGINE_COMMON_HADAMARD_TRANSFORM_CUSTOMIZED_PIPELINE_CUH_ From 15e1edb374a80c96f15238c479032abd85c905ab Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Sat, 20 Dec 2025 00:19:10 +0000 Subject: [PATCH 30/36] Fix bug where kernel is always launched on stream Signed-off-by: Tim Moon --- .../nvfp4/group_quantize_transpose_nvfp4.cuh | 5 +++- .../group_hadamard_transform_cast_fusion.cu | 17 +++++------ ...cast_col_hadamard_transform_cast_fusion.cu | 30 ++++++++----------- .../hadamard_transform_cast_fusion.cu | 13 ++++---- 4 files changed, 29 insertions(+), 36 deletions(-) diff --git a/transformer_engine/common/cast/nvfp4/group_quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/group_quantize_transpose_nvfp4.cuh index b8eec22945..f81ce9b0de 100644 --- a/transformer_engine/common/cast/nvfp4/group_quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/group_quantize_transpose_nvfp4.cuh @@ -885,10 +885,13 @@ void group_quantize_transpose(const Tensor &input, const Tensor *noop, NVTE_ERROR("2D quantization is not supported for group quantize transpose."); } - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size) + ); kernel<<>>(tensor_map_input, tensor_map_output, scales_ptr, noop_ptr, rows, cols, scale_stride, rng_state, kernel_args); + NVTE_CHECK_CUDA(cudaGetLastError()); });); #else NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); diff --git a/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu index f7e4ab9b53..209d0cf2ea 100644 --- a/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu @@ -21,6 +21,7 @@ #include "common/util/cuda_runtime.h" #include "common/util/curanddx.hpp" #include "common/util/ptx.cuh" +#include "common/util/system.h" #include "common/utils.cuh" #include "cutlass/arch/barrier.h" #include "cutlass/cutlass.h" @@ -31,9 +32,6 @@ #include "cutlass/util/command_line.h" #include "cutlass/util/print_error.hpp" -// include utils for get system env -#include "../util/system.h" - // clang-format off namespace transformer_engine { @@ -851,14 +849,12 @@ group_rht_gemm_ntt_w_sfc(int m, int n, kEnableStochasticRounding, kUseFastMath>; - bool status = cudaFuncSetAttribute(*kernel_ptr, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size); + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(*kernel_ptr, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size) + ); - if (status != cudaSuccess) { - std::cerr << "Error: Failed to set Shared Memory size." << std::endl; - return; - } (*kernel_ptr) <<< dimGrid, dimBlock, smem_size, stream >>> (M, N, k_tile_size, cga_tile_shape, @@ -867,6 +863,7 @@ group_rht_gemm_ntt_w_sfc(int m, int n, sC, mma, *kernel_args_ptr, rng_state); + NVTE_CHECK_CUDA(cudaGetLastError()); } // this function is used to wrap the group_rht_gemm_ntt_w_sfc function diff --git a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu index e4f13b49bc..feb41233fe 100644 --- a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu @@ -1264,30 +1264,24 @@ void group_row_col_rht_gemm_ntt_w_sfc(int packed_sequence_length, int hidden_siz AccumulatorPipelineStageCount, SchedulerPipelineStageCount, kEnableStochasticRounding, kEnableRHTColQuant, kEnableRowQuant, kEnableSwizzleSFOutput, kUseFastMath>; - bool status_set_attr = - cudaFuncSetAttribute(*kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - - if (status_set_attr != cudaSuccess) { - std::cerr << "Error: Failed to set Shared Memory size." << std::endl; - return; - } + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(*kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) + ); + // Allocate workspace and set to zero void *tile_scheduler_workspace = nullptr; - cudaMallocAsync(&tile_scheduler_workspace, sizeof(uint32_t), stream); - // reset the tile_scheduler_workspace to 0 - cudaMemsetAsync(tile_scheduler_workspace, 0, sizeof(uint32_t), stream); - cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smem_size}; + NVTE_CHECK_CUDA(cudaMallocAsync(&tile_scheduler_workspace, sizeof(uint32_t), stream)); + NVTE_CHECK_CUDA(cudaMemsetAsync(tile_scheduler_workspace, 0, sizeof(uint32_t), stream)); + + // Launch kernel + cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smem_size, stream}; cutlass::Status status = cutlass::launch_kernel_on_cluster( params, (void const *)kernel_ptr, M, N, k_tile_size, cluster_shape, cluster_tile_shape, A, dA, sA, tma_load_a, B, dB, sB, tma_load_b, QA, dQA, SFA, sfa_layout, args, tile_scheduler_workspace, mma, rng_state); - // CUTE_CHECK_LAST(); + NVTE_CHECK_CUDA(cudaGetLastError()); + NVTE_CHECK(status == cutlass::Status::kSuccess, "Kernel launch failed."); - cudaFreeAsync(tile_scheduler_workspace, stream); - - if (status != cutlass::Status::kSuccess) { - std::cerr << "Error: Failed at kernel Launch" << std::endl; - return; - } + NVTE_CHECK_CUDA(cudaFreeAsync(tile_scheduler_workspace, stream)); } } // namespace diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu index 3e331547a1..716a0d197d 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu @@ -664,14 +664,12 @@ rht_gemm_ntt_w_sfc(int m, int n, kEnableStochasticRounding, kUseFastMath>; - bool status = cudaFuncSetAttribute(*kernel_ptr, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size); + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(*kernel_ptr, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size) + ); - if (status != cudaSuccess) { - std::cerr << "Error: Failed to set Shared Memory size." << std::endl; - return; - } (*kernel_ptr) <<< dimGrid, dimBlock, smem_size, stream >>> (M, N, k_tile_size, cga_tile_shape, @@ -681,6 +679,7 @@ rht_gemm_ntt_w_sfc(int m, int n, SFC, mma, global_amax, rng_state); + NVTE_CHECK_CUDA(cudaGetLastError()); } // this function is used to wrap the rht_gemm_ntt_w_sfc function From 79cc660fad48f67d0035e314c3b3bc7e5883d818 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 20 Dec 2025 00:20:02 +0000 Subject: [PATCH 31/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/cast/nvfp4/group_quantize_transpose_nvfp4.cuh | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/transformer_engine/common/cast/nvfp4/group_quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/group_quantize_transpose_nvfp4.cuh index f81ce9b0de..28b47e32d2 100644 --- a/transformer_engine/common/cast/nvfp4/group_quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/group_quantize_transpose_nvfp4.cuh @@ -886,8 +886,7 @@ void group_quantize_transpose(const Tensor &input, const Tensor *noop, } NVTE_CHECK_CUDA( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size) - ); + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); kernel<<>>(tensor_map_input, tensor_map_output, scales_ptr, noop_ptr, rows, cols, scale_stride, rng_state, kernel_args); From 8534c38e6e94d8be1c60f06b8ca8dbdc56a8bb17 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Sat, 20 Dec 2025 03:44:11 +0000 Subject: [PATCH 32/36] Restore BF16 intermediate downcast in fused RHT-cast kernels Signed-off-by: Tim Moon --- .../group_hadamard_transform_cast_fusion.cu | 8 ++++ ...cast_col_hadamard_transform_cast_fusion.cu | 8 ++++ .../hadamard_transform_cast_fusion.cu | 8 ++++ .../custom_recipes/quantization_nvfp4.py | 38 +------------------ 4 files changed, 26 insertions(+), 36 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu index 209d0cf2ea..d2bf16a460 100644 --- a/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu @@ -660,6 +660,14 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til ++accumulator_pipe_consumer_state; + if constexpr (!kUseFastMath) { + // Downcast to BF16 for bit-wise compatibility with unfused + // kernels + auto convert_accum_to_bf16 = cutlass::NumericArrayConverter{}; + auto convert_bf16_to_accum = cutlass::NumericArrayConverter{}; + tTR_rAcc_frag(_0{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_0{}))); + } + auto compute_frgs = reinterpret_cast *>(tTR_rAcc_frag.data()); auto output_frgs = reinterpret_cast *>(tDrC_frag.data()); CUTLASS_PRAGMA_UNROLL diff --git a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu index feb41233fe..177bfc3b41 100644 --- a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu @@ -834,6 +834,14 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( accumulator_pipeline.consumer_release(accumulator_pipe_consumer_state); ++accumulator_pipe_consumer_state; + if constexpr (!kUseFastMath) { + // Downcast to BF16 for bit-wise compatibility with + // unfused kernels + auto convert_accum_to_bf16 = cutlass::NumericArrayConverter{}; + auto convert_bf16_to_accum = cutlass::NumericArrayConverter{}; + tTR_rAcc_frag(_0{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_0{}))); + } + auto compute_frgs = reinterpret_cast *>( tTR_rAcc_frag.data()); auto output_frgs = reinterpret_cast *>(tDrD_frag.data()); diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu index 716a0d197d..2f76ea2cb2 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu @@ -478,6 +478,14 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile, ++accumulator_pipe_consumer_state; + if constexpr (!kUseFastMath) { + // Downcast to BF16 for bit-wise compatibility with unfused + // kernels + auto convert_accum_to_bf16 = cutlass::NumericArrayConverter{}; + auto convert_bf16_to_accum = cutlass::NumericArrayConverter{}; + tTR_rAcc_frag(_0{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_0{}))); + } + auto compute_frgs = reinterpret_cast *>(tTR_rAcc_frag.data()); auto output_frgs = reinterpret_cast *>(tDrC_frag.data()); CUTLASS_PRAGMA_UNROLL diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py index f878edece6..b371ca4842 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py @@ -391,27 +391,7 @@ def _build_hadamard_matrix( h = sign_mat @ h return h.to(dtype) - def _supports_rht_cast_fusion(self, x: torch.Tensor) -> bool: - """ - Check if RHT cast fusion is supported for the input tensor. - - When RHT cast fusion is supported, there is no intermediate bf16 tensor for RHT(x) results, - which means that we can directly cast from FP32 to FP4 without any intermediate bf16 tensor. - - For example, if x.shape is (128, 128), then RHT cast fusion is supported. - If x.shape is (128, 127), then RHT cast fusion is not supported. - - This function is to simulate this behavior in the reference implementation for numerical correctness. - - Args: - x: The input tensor. - - Returns: - True if RHT cast fusion is supported, False otherwise. - """ - return x.dtype == torch.bfloat16 and x.shape[0] % 64 == 0 and x.shape[1] % 128 == 0 - - def _apply_rht(self, x: torch.Tensor, with_rht_cast_fusion: bool = False) -> torch.Tensor: + def _apply_rht(self, x: torch.Tensor) -> torch.Tensor: """Apply randomized Hadamard transform without random signs (reference path). This matches the reference used in tests: x_reshaped @ (H * (1/sqrt(g))). @@ -435,12 +415,6 @@ def _apply_rht(self, x: torch.Tensor, with_rht_cast_fusion: bool = False) -> tor x_mat = x.contiguous().view(-1, rht_dim) # Random sign matrix is identity in this reference (no sign flipping) transform = H * scale - - # If RHT cast fusion is supported, we can directly cast from FP32 to FP4 without any intermediate bf16 tensor. - if with_rht_cast_fusion: - transform = transform.float() - x_mat = x_mat.float() - out = x_mat @ transform return out.view(original_shape) @@ -625,10 +599,9 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ ), "NVFP4 only supports 1x16 or 16x16 tile shape." # Prepare inputs once so we can reuse for both amax and quantization # Row-input will always be the original input. - with_rht_cast_fusion = self._supports_rht_cast_fusion(tensor) row_input = tensor col_input = ( - self._apply_rht(tensor.t().contiguous(), with_rht_cast_fusion) + self._apply_rht(tensor.t().contiguous()) if self.with_rht else tensor.t().contiguous() ) @@ -639,13 +612,6 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ if self.columnwise_usage else global_amax_row ) - # currently the amax of RHT transform still has a fp32 -> bf16 -> fp32 round-trip - # this is to simulate that behaviour and we will remove this once the amax of RHT transform is fixed - if self.columnwise_usage and with_rht_cast_fusion: - # global_amax_col = global_amax_col.to(torch.bfloat16).float() - global_amax_col = ( - torch.max(torch.abs(col_input.bfloat16())).to(torch.float32).view(1) - ) transpose_scales = False From 57db30f21d0bb13f9cfeddf800e45565c694eb8e Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Fri, 19 Dec 2025 21:02:59 -0800 Subject: [PATCH 33/36] fix numerical test of grouped kernel Signed-off-by: Zhongbo Zhu --- .../group_row_cast_col_hadamard_transform_cast_fusion.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu index 177bfc3b41..539fb54230 100644 --- a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu @@ -840,6 +840,7 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( auto convert_accum_to_bf16 = cutlass::NumericArrayConverter{}; auto convert_bf16_to_accum = cutlass::NumericArrayConverter{}; tTR_rAcc_frag(_0{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_0{}))); + tTR_rAcc_frag(_1{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_1{}))); } auto compute_frgs = reinterpret_cast *>( From b258ca9f81e9193ceb01cfa675b89824fa4d89fa Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Fri, 19 Dec 2025 21:31:12 -0800 Subject: [PATCH 34/36] Make sure row-wise and col-wise quantization use different RNG seeds Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- .../group_row_cast_col_hadamard_transform_cast_fusion.cu | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu index 539fb54230..75eb861ba4 100644 --- a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu @@ -1041,7 +1041,8 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; if constexpr (kEnableStochasticRounding) { const size_t rng_sequence = global_thread_idx + k_tile * 512 + - scheduler.get_linear_tile_idx() * K_TILE_MAX * 512; + scheduler.get_linear_tile_idx() * K_TILE_MAX * 512 + + tiles_in_m * tiles_in_n * K_TILE_MAX * 512; rng.init(rng_seed, rng_sequence, rng_offset); } CUTLASS_PRAGMA_UNROLL From 66ac7560b47db629167ecf986d2314ff73ed04c4 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Sat, 20 Dec 2025 05:35:25 +0000 Subject: [PATCH 35/36] Restore autoformatter Signed-off-by: Tim Moon --- .../group_hadamard_transform_cast_fusion.cu | 5 ----- .../group_row_cast_col_hadamard_transform_cast_fusion.cu | 5 ----- .../hadamard_transform/hadamard_transform_cast_fusion.cu | 3 --- 3 files changed, 13 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu index d2bf16a460..fff4ed1fe9 100644 --- a/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu @@ -21,7 +21,6 @@ #include "common/util/cuda_runtime.h" #include "common/util/curanddx.hpp" #include "common/util/ptx.cuh" -#include "common/util/system.h" #include "common/utils.cuh" #include "cutlass/arch/barrier.h" #include "cutlass/cutlass.h" @@ -32,8 +31,6 @@ #include "cutlass/util/command_line.h" #include "cutlass/util/print_error.hpp" -// clang-format off - namespace transformer_engine { namespace detail { namespace { @@ -909,8 +906,6 @@ group_rht_gemm_ttt_wrapper(int m, int n, } // namespace } // namespace detail -// clang-format on - void group_hadamard_transform_cast_fusion_columnwise( const Tensor &input_, std::vector &output_list, const size_t *split_sections, size_t num_tensors, const Tensor &hadamard_matrix_, QuantizationConfig &quant_config, diff --git a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu index 75eb861ba4..c5d4afbc91 100644 --- a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu @@ -40,11 +40,6 @@ #include "cutlass/util/command_line.h" #include "cutlass/util/print_error.hpp" -// include utils for get system env -#include "../util/system.h" - -// clang-format off - namespace transformer_engine { namespace detail { namespace { diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu index 2f76ea2cb2..11325041ae 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu @@ -31,9 +31,6 @@ #include "cutlass/util/command_line.h" #include "cutlass/util/print_error.hpp" -// include utils for get system env -#include "../util/system.h" - // clang-format off namespace transformer_engine { From 376687ca3956205362ae60c7b7d0b89c970505f3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 20 Dec 2025 05:37:51 +0000 Subject: [PATCH 36/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../group_hadamard_transform_cast_fusion.cu | 663 ++++++++---------- ...cast_col_hadamard_transform_cast_fusion.cu | 247 ++++--- 2 files changed, 432 insertions(+), 478 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu index fff4ed1fe9..6e071ec79f 100644 --- a/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu @@ -36,18 +36,19 @@ namespace detail { namespace { using namespace cute; -using cute::Tensor; // Ensure unqualified Tensor refers to cute::Tensor, not transformer_engine::Tensor +using cute:: + Tensor; // Ensure unqualified Tensor refers to cute::Tensor, not transformer_engine::Tensor using Stride2D = cute::Stride>; constexpr int kMaxTensorsPerKernel = 64; // Args must be <4 KB, expand 64 if needed struct MultiAmaxHadamardCastFusionArgs { // (output) Amax buffer for pre-RHT amax buffer - void* global_amax_list[kMaxTensorsPerKernel]; + void *global_amax_list[kMaxTensorsPerKernel]; // output C pointers for each tensor - void* output_colwise_list[kMaxTensorsPerKernel]; + void *output_colwise_list[kMaxTensorsPerKernel]; // output scale inverse pointers for each tensor - void* output_colwise_scale_inv_list[kMaxTensorsPerKernel]; + void *output_colwise_scale_inv_list[kMaxTensorsPerKernel]; // split sections of each tensor of input int split_sections[kMaxTensorsPerKernel]; // Prefix sum (with leading zero) of split_sections of each tensor of input @@ -58,16 +59,17 @@ struct MultiAmaxHadamardCastFusionArgs { int num_tensors; }; - -__device__ __forceinline__ float* GetGlobalAmaxPtrByTensorId(MultiAmaxHadamardCastFusionArgs* kernel_args_ptr, int tensor_id){ +__device__ __forceinline__ float *GetGlobalAmaxPtrByTensorId( + MultiAmaxHadamardCastFusionArgs *kernel_args_ptr, int tensor_id) { // directly returns the global amax pointer by tensor id if (tensor_id < 0 || tensor_id >= kernel_args_ptr->num_tensors) { return nullptr; } - return reinterpret_cast(kernel_args_ptr->global_amax_list[tensor_id]); + return reinterpret_cast(kernel_args_ptr->global_amax_list[tensor_id]); } -__device__ __forceinline__ int GetTensorId(MultiAmaxHadamardCastFusionArgs* kernel_args_ptr, int offset){ +__device__ __forceinline__ int GetTensorId(MultiAmaxHadamardCastFusionArgs *kernel_args_ptr, + int offset) { // Check the kernel args and get the corresponding id const int num_tensors = kernel_args_ptr->num_tensors; if (offset >= kernel_args_ptr->split_sections_range[num_tensors]) { @@ -86,27 +88,23 @@ __device__ __forceinline__ float ComputeGlobalEncodeScaleFP4(const float global_ constexpr float kFP4E2M1Max = 6.0f; // If scale is infinity, return max value of float32 float global_encode_scale = cutlass::minimum_with_nan_propagation{}( - kFP8E4M3Max * kFP4E2M1Max / global_amax, cutlass::platform::numeric_limits::max()); + kFP8E4M3Max * kFP4E2M1Max / global_amax, cutlass::platform::numeric_limits::max()); // If global amax is 0 or infinity, return 1 return (global_amax == 0.f || global_encode_scale == 0.f) ? 1.f : global_encode_scale; } -template +template struct SharedStorage { static constexpr int AccumulatorPipelineStageCount = 16; using AtomThrShapeMNK = cute::Shape<_1, _1, _1>; - using AccumulatorPipeline = cutlass::PipelineUmmaAsync; + using AccumulatorPipeline = + cutlass::PipelineUmmaAsync; using AccumulatorPipelineStorage = typename AccumulatorPipeline::SharedStorage; static constexpr int MainloopPipelineStageCount = size<3>(ASmemLayout{}); - using MainloopPipeline = cutlass::PipelineTmaUmmaAsync< - MainloopPipelineStageCount, - Shape<_1,_1,_1>, - AtomThrShapeMNK>; + using MainloopPipeline = + cutlass::PipelineTmaUmmaAsync, AtomThrShapeMNK>; using MainloopPipelineStorage = typename MainloopPipeline::SharedStorage; alignas(16) AccumulatorPipelineStorage accumulator; @@ -119,42 +117,43 @@ struct SharedStorage { cute::array_aligned> smem_A; cute::array_aligned> smem_B; } tensors; - }; CUTLASS_DEVICE -cutlass::Array -StochasticNumericConverterBase(cutlass::Array const &input, cutlass::Array const &rbits) { +cutlass::Array StochasticNumericConverterBase( + cutlass::Array const &input, cutlass::Array const &rbits) { using result_type = cutlass::Array; result_type output; constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING; if constexpr (has_rs) { auto output_ptr = reinterpret_cast(&output); - asm volatile( \ - "{\n" \ - "cvt.rs.satfinite.e2m1x4.f32 %0, {%5, %4, %3, %2}, %10;\n" \ - "cvt.rs.satfinite.e2m1x4.f32 %1, {%9, %8, %7, %6}, %11;\n" \ - "}" \ - : "=h"(output_ptr[0]), - "=h"(output_ptr[1]) - : "f"(input[0]), "f"(input[1]), "f"(input[2]), "f"(input[3]), - "f"(input[4]), "f"(input[5]), "f"(input[6]), "f"(input[7]), - "r"(rbits[0]), "r"(rbits[1])); + asm volatile( + "{\n" + "cvt.rs.satfinite.e2m1x4.f32 %0, {%5, %4, %3, %2}, %10;\n" + "cvt.rs.satfinite.e2m1x4.f32 %1, {%9, %8, %7, %6}, %11;\n" + "}" + : "=h"(output_ptr[0]), "=h"(output_ptr[1]) + : "f"(input[0]), "f"(input[1]), "f"(input[2]), "f"(input[3]), "f"(input[4]), "f"(input[5]), + "f"(input[6]), "f"(input[7]), "r"(rbits[0]), "r"(rbits[1])); } else { - NVTE_DEVICE_ERROR("FP4 cvt PTX instructions are architecture-specific. " + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. " "Try recompiling with sm_XXXa instead of sm_XXX."); } return output; } CUTLASS_DEVICE -cutlass::Array -StochasticNumericConverter(cutlass::Array const &input, cutlass::Array const *rbits) { +cutlass::Array StochasticNumericConverter( + cutlass::Array const &input, cutlass::Array const *rbits) { using result_type = cutlass::Array; result_type output; - cutlass::Array *result_ptr = reinterpret_cast *>(&output); - cutlass::Array const *source_ptr = reinterpret_cast const *>(&input); - cutlass::Array const *rbits_ptr = reinterpret_cast const *>(rbits); + cutlass::Array *result_ptr = + reinterpret_cast *>(&output); + cutlass::Array const *source_ptr = + reinterpret_cast const *>(&input); + cutlass::Array const *rbits_ptr = + reinterpret_cast const *>(rbits); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < 2; i++) { result_ptr[i] = StochasticNumericConverterBase(source_ptr[i], rbits_ptr[i]); @@ -162,42 +161,31 @@ StochasticNumericConverter(cutlass::Array const &input, cutlass::Arra return output; } -template -__global__ static -void -group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile, - TA const* A, AStride dA, ASmemLayout sAlayout, CUTE_GRID_CONSTANT TmaLoadA const tma_load_a, - TB const* B, BStride dB, BSmemLayout sBlayout, CUTE_GRID_CONSTANT TmaLoadB const tma_load_b, - CSmemLayout, - TiledMMA mma, - MultiAmaxHadamardCastFusionArgs kernel_args, - const size_t* rng_state) -{ +template +__global__ static void group_rht_gemm_device( + MShape M, NShape N, KShape K, ClusterTileShape cluster_tile, TA const *A, AStride dA, + ASmemLayout sAlayout, CUTE_GRID_CONSTANT TmaLoadA const tma_load_a, TB const *B, BStride dB, + BSmemLayout sBlayout, CUTE_GRID_CONSTANT TmaLoadB const tma_load_b, CSmemLayout, TiledMMA mma, + MultiAmaxHadamardCastFusionArgs kernel_args, const size_t *rng_state) { using namespace cute; using X = Underscore; // static constexpr bool kApplyStochasticRounding = true; using ElementAccumulator = float; static constexpr int K_PIPE_MAX = size<3>(ASmemLayout{}); using AtomThrShapeMNK = Shape(typename TiledMMA::ThrLayoutVMNK{})), _1, _1>; - static constexpr uint32_t kTmaTransactionBytes = - cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(ASmemLayout{})) * cute::sizeof_bits_v); + static constexpr uint32_t kTmaTransactionBytes = cutlass::bits_to_bytes( + size(AtomThrShapeMNK{}) * cosize(take<0, 3>(ASmemLayout{})) * cute::sizeof_bits_v); static constexpr int kTmaRhtTensorTransactionBytes = - cutlass::bits_to_bytes(16 * 16 * cute::sizeof_bits_v); + cutlass::bits_to_bytes(16 * 16 * cute::sizeof_bits_v); static constexpr int AccumulatorPipelineStageCount = 16; static constexpr int MainloopPipelineStageCount = size<3>(ASmemLayout{}); - using MainloopPipeline = cutlass::PipelineTmaUmmaAsync< - MainloopPipelineStageCount, - Shape<_1,_1,_1>, - AtomThrShapeMNK>; + using MainloopPipeline = + cutlass::PipelineTmaUmmaAsync, AtomThrShapeMNK>; using MainloopPipelineState = typename MainloopPipeline::PipelineState; using TmemAllocator = cute::TMEM::Allocator1Sm; @@ -210,124 +198,103 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til CUTE_STATIC_ASSERT(is_static::value); // Represent the full tensors - Tensor mA = tma_load_a.get_tma_tensor(make_shape(M,N)); - Tensor mB = tma_load_b.get_tma_tensor(make_shape(16,16)); - - using TensorC = decltype( - make_tensor( - subbyte_iterator(recast_ptr(nullptr)), // engine - make_shape(int{}, int{}), // (M, N_i) - Stride2D{} // stride (dM, dN) - ) - ); - - using TensorSFC = decltype( - make_tensor( - make_gmem_ptr(recast_ptr(nullptr)), - make_layout( - make_shape( - int{}, // M - make_shape( make_shape(Int<16>{}, _4{}), // (16, 4) - int{} ) // n_tiles = split / 64 - ), - make_stride( - int{}, // dM = (split / 16) - make_stride( make_stride(_0{}, _1{}), // inner (16,4) layout - _4{} ) // tiles stride - ) - ) - ) - ); - - auto cluster_shape = Shape< _1, _1, _1>{}; + Tensor mA = tma_load_a.get_tma_tensor(make_shape(M, N)); + Tensor mB = tma_load_b.get_tma_tensor(make_shape(16, 16)); + + using TensorC = decltype(make_tensor(subbyte_iterator(recast_ptr(nullptr)), // engine + make_shape(int{}, int{}), // (M, N_i) + Stride2D{} // stride (dM, dN) + )); + + using TensorSFC = decltype(make_tensor( + make_gmem_ptr(recast_ptr(nullptr)), + make_layout(make_shape(int{}, // M + make_shape(make_shape(Int<16>{}, _4{}), // (16, 4) + int{}) // n_tiles = split / 64 + ), + make_stride(int{}, // dM = (split / 16) + make_stride(make_stride(_0{}, _1{}), // inner (16,4) layout + _4{}) // tiles stride + )))); + + auto cluster_shape = Shape<_1, _1, _1>{}; // Get the appropriate blocks for this Cluster dim3 cluster_coord_in_grid = cluster_id_in_grid(); // Total number of k-tiles - const int K_TILE_MAX = min(N, K) / 64; + const int K_TILE_MAX = min(N, K) / 64; uint32_t tiles_in_m = (M + size<0>(cluster_tile) - 1) / size<0>(cluster_tile); uint32_t tiles_in_n = (N + 64 - 1) / 64; uint32_t linear_tile_idx = blockIdx.x; uint32_t tile_idx_m = linear_tile_idx % tiles_in_m; uint32_t tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX; - auto mainloop_tiler = Shape<_128,_16,_64>{}; - auto epilogue_tiler = Shape<_128,_64,_64>{}; - Tensor gA_mk = local_tile(mA, mainloop_tiler, make_coord(_,_, _), Step<_1, X,_1>{}); - Tensor gB_nk = local_tile(mB, cluster_tile, make_coord(_,_, _), Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) + auto mainloop_tiler = Shape<_128, _16, _64>{}; + auto epilogue_tiler = Shape<_128, _64, _64>{}; + Tensor gA_mk = local_tile(mA, mainloop_tiler, make_coord(_, _, _), Step<_1, X, _1>{}); + Tensor gB_nk = + local_tile(mB, cluster_tile, make_coord(_, _, _), Step{}); // (BLK_N,BLK_K,k) // Tensor gC_mn = local_tile(mC, epilogue_tiler, make_coord(_,_, _), Step<_1,_1, X>{}); // (BLK_M,BLK_N) - using TensorGC = decltype( - local_tile( - std::declval(), - decltype(epilogue_tiler){}, - make_coord(_, _, _), - Step<_1, _1, X>{} - ) - ); - - using TensorGSFC = decltype( - local_tile( - std::declval(), - decltype(epilogue_tiler){}, - make_coord(_, _, _), - Step<_1, _1, X>{} - ) - ); + using TensorGC = decltype(local_tile(std::declval(), decltype(epilogue_tiler){}, + make_coord(_, _, _), Step<_1, _1, X>{})); + + using TensorGSFC = decltype(local_tile(std::declval(), decltype(epilogue_tiler){}, + make_coord(_, _, _), Step<_1, _1, X>{})); // Allocate SMEM extern __shared__ char shared_memory[]; using SharedStorage = SharedStorage; - SharedStorage& shared_storage = *reinterpret_cast(shared_memory); - Tensor tCsA = make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), sAlayout); // (MMA,MMA_M,MMA_N,PIPE) - Tensor tCsB = make_tensor(make_smem_ptr(shared_storage.tensors.smem_B.data()), sBlayout); // (MMA,MMA_N,MMA_K,PIPE) - + SharedStorage &shared_storage = *reinterpret_cast(shared_memory); + Tensor tCsA = make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), + sAlayout); // (MMA,MMA_M,MMA_N,PIPE) + Tensor tCsB = make_tensor(make_smem_ptr(shared_storage.tensors.smem_B.data()), + sBlayout); // (MMA,MMA_N,MMA_K,PIPE) // // MMA: Define C accumulators and A/B partitioning // int block_rank_in_cluster = cute::block_rank_in_cluster(); - ThrMMA thr_mma = mma.get_slice(block_rank_in_cluster); // blk idx - Tensor tCgB = thr_mma.partition_B(gB_nk); // (MMA,MMA_N,MMA_K,k) + ThrMMA thr_mma = mma.get_slice(block_rank_in_cluster); // blk idx + Tensor tCgB = thr_mma.partition_B(gB_nk); // (MMA,MMA_N,MMA_K,k) - auto mma_epilogue = make_tiled_mma(SM100_MMA_F16BF16_SS{}, - Layout>{}); + auto mma_epilogue = make_tiled_mma( + SM100_MMA_F16BF16_SS{}, + Layout>{}); ThrMMA thr_mma_epilogue = mma_epilogue.get_slice(block_rank_in_cluster); - using TiledMmaEpilogue = decltype(mma_epilogue); Tensor tCgA = thr_mma.partition_A(gA_mk); // Allocate "fragments" -- these are actually umma smem descriptors - Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) - Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_M,MMA_K,PIPE) - auto acc_shape_mma = partition_shape_C(TiledMMA{}, take<0,2>(ClusterTileShape{})); - auto acc_shape_epilogue = partition_shape_C(TiledMmaEpilogue{}, take<0,2>(epilogue_tiler)); + auto acc_shape_mma = partition_shape_C(TiledMMA{}, take<0, 2>(ClusterTileShape{})); + auto acc_shape_epilogue = partition_shape_C(TiledMmaEpilogue{}, take<0, 2>(epilogue_tiler)); - auto bulk_tmem_mma = TiledMMA::make_fragment_C(append(acc_shape_mma, - Int{})); + auto bulk_tmem_mma = + TiledMMA::make_fragment_C(append(acc_shape_mma, Int{})); - auto bulk_tmem_epilogue = TiledMmaEpilogue::make_fragment_C(append(acc_shape_epilogue, - Int{})); + auto bulk_tmem_epilogue = TiledMmaEpilogue::make_fragment_C( + append(acc_shape_epilogue, Int{})); TmemAllocator tmem_allocator{}; - cutlass::arch::NamedBarrier tmem_allocation_result_barrier(32 + 128, cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier); + cutlass::arch::NamedBarrier tmem_allocation_result_barrier( + 32 + 128, cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier); - Layout cta_layout_mnk = make_layout(cluster_shape); + Layout cta_layout_mnk = make_layout(cluster_shape); Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMMA::AtomThrID{})); - auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster); - auto [tAgA, tAsA] = tma_partition(tma_load_a, - get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), - group_modes<0,3>(tCsA), group_modes<0,3>(tCgA)); + auto [tAgA, tAsA] = + tma_partition(tma_load_a, get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0, 3>(tCsA), group_modes<0, 3>(tCgA)); - auto [tBgB, tBsB] = tma_partition(tma_load_b, - get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), - group_modes<0,3>(tCsB), group_modes<0,3>(tCgB)); + auto [tBgB, tBsB] = + tma_partition(tma_load_b, get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), + group_modes<0, 3>(tCsB), group_modes<0, 3>(tCgB)); uint16_t tma_mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); uint16_t tma_mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); @@ -355,22 +322,21 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til mainloop_pipeline_params.is_leader = cute::elect_one_sync() && is_dma_warp; mainloop_pipeline_params.transaction_bytes = kTmaTransactionBytes; mainloop_pipeline_params.initializing_warp = 0; - MainloopPipeline mainloop_pipeline(shared_storage.mainloop, - mainloop_pipeline_params, - cluster_shape, - cute::true_type{}, // Perform barrier init - cute::true_type{}); // Delay mask calculation + MainloopPipeline mainloop_pipeline(shared_storage.mainloop, mainloop_pipeline_params, + cluster_shape, cute::true_type{}, // Perform barrier init + cute::true_type{}); // Delay mask calculation MainloopPipelineState mainloop_pipe_consumer_state; - MainloopPipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); - + MainloopPipelineState mainloop_pipe_producer_state = + cutlass::make_producer_start_state(); - - using AccumulatorPipeline = cutlass::PipelineUmmaAsync; + using AccumulatorPipeline = + cutlass::PipelineUmmaAsync; using AccumulatorPipelineState = typename AccumulatorPipeline::PipelineState; AccumulatorPipelineState accumulator_pipe_consumer_state; - AccumulatorPipelineState accumulator_pipe_producer_state = cutlass::make_producer_start_state(); + AccumulatorPipelineState accumulator_pipe_producer_state = + cutlass::make_producer_start_state(); typename AccumulatorPipeline::Params accumulator_pipeline_params; if (is_mma_warp) { @@ -383,11 +349,10 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til accumulator_pipeline_params.producer_arv_count = 1; accumulator_pipeline_params.consumer_arv_count = size(AtomThrShapeMNK{}) * 128; accumulator_pipeline_params.initializing_warp = 1; - AccumulatorPipeline accumulator_pipeline(shared_storage.accumulator, - accumulator_pipeline_params, + AccumulatorPipeline accumulator_pipeline(shared_storage.accumulator, accumulator_pipeline_params, cluster_shape, cute::true_type{}, // Perform barrier init - cute::true_type{}); // Delay mask calculation + cute::true_type{}); // Delay mask calculation if (warp_idx == 2 && elect_one_sync()) { cute::initialize_barrier(shared_storage.tma_barrier[0], /* num_threads */ 1); @@ -397,17 +362,19 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til if (is_dma_warp) { if (elect_one_sync()) { - cute::set_barrier_transaction_bytes(shared_storage.tma_barrier[0], kTmaRhtTensorTransactionBytes); - copy(tma_load_b.with(shared_storage.tma_barrier[0], tma_mcast_mask_b), tBgB(_,0,0), tBsB(_,0)); + cute::set_barrier_transaction_bytes(shared_storage.tma_barrier[0], + kTmaRhtTensorTransactionBytes); + copy(tma_load_b.with(shared_storage.tma_barrier[0], tma_mcast_mask_b), tBgB(_, 0, 0), + tBsB(_, 0)); } do { bool is_first_wave = linear_tile_idx == blockIdx.x; uint32_t skip_wait = is_first_wave; - auto tAgA_mk = tAgA(_,tile_idx_m,_); + auto tAgA_mk = tAgA(_, tile_idx_m, _); int k_tile = 0; - auto barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state, skip_wait); - + auto barrier_token = + mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state, skip_wait); CUTE_NO_UNROLL while (k_tile < K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n) { @@ -416,12 +383,15 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til skip_wait = (is_first_wave && k_tile < MainloopPipelineStageCount); mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token); using BarrierType = typename MainloopPipeline::ProducerBarrierType; - BarrierType* tma_barrier = mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state); + BarrierType *tma_barrier = + mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state); int write_stage = mainloop_pipe_producer_state.index(); ++mainloop_pipe_producer_state; - barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state, skip_wait); + barrier_token = + mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state, skip_wait); if (cute::elect_one_sync()) { - copy(tma_load_a.with(*tma_barrier, tma_mcast_mask_a), tAgA_mk(_,k_tile_idx_n), tAsA(_,write_stage)); + copy(tma_load_a.with(*tma_barrier, tma_mcast_mask_a), tAgA_mk(_, k_tile_idx_n), + tAsA(_, write_stage)); } } linear_tile_idx += gridDim.x; @@ -441,22 +411,22 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til cute::wait_barrier(shared_storage.tma_barrier[0], 0 /*tma_phase_bit*/); do { uint32_t skip_wait = K_TILE_MAX <= 0; - auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + auto barrier_token = + mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); CUTE_NO_UNROLL - for (int k_tile = 0; k_tile < K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n; ) - { + for (int k_tile = 0; k_tile < K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n;) { mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); int read_stage = mainloop_pipe_consumer_state.index(); - auto tCrA_mk = tCrA(_,_,_,read_stage); - auto tCrB_nk = tCrB(_,_,0,0); + auto tCrA_mk = tCrA(_, _, _, read_stage); + auto tCrB_nk = tCrB(_, _, 0, 0); CUTE_UNROLL - for (int k_block = 0; k_block < size<2>(tCrA) / 4; ++k_block) - { + for (int k_block = 0; k_block < size<2>(tCrA) / 4; ++k_block) { accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); CUTE_UNROLL for (int i = 0; i < 4; i++) { - auto accumulators = bulk_tmem_mma(_,_,_,accumulator_pipe_producer_state.index() * 4 + i); - gemm(mma, tCrA_mk(_,_,k_block * 4 + i), tCrB_nk, accumulators); + auto accumulators = + bulk_tmem_mma(_, _, _, accumulator_pipe_producer_state.index() * 4 + i); + gemm(mma, tCrA_mk(_, _, k_block * 4 + i), tCrB_nk, accumulators); } accumulator_pipeline.producer_commit(accumulator_pipe_producer_state); @@ -466,7 +436,8 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til ++mainloop_pipe_consumer_state; ++k_tile; skip_wait = k_tile >= K_TILE_MAX; - barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + barrier_token = + mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state); } @@ -485,9 +456,10 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til bulk_tmem_epilogue.data() = tmem_base_ptr; int thread_idx = threadIdx.x % 128; - auto tiled_t2r = make_tmem_copy(TMEM_LOAD_NEW{}, bulk_tmem_epilogue(_,_,_,_0{})); - auto tiled_r2g = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); - auto thr_t2r = tiled_t2r.get_slice(thread_idx); + auto tiled_t2r = make_tmem_copy(TMEM_LOAD_NEW{}, bulk_tmem_epilogue(_, _, _, _0{})); + auto tiled_r2g = + make_tiled_copy_D(Copy_Atom{}, tiled_t2r); + auto thr_t2r = tiled_t2r.get_slice(thread_idx); auto thr_r2g = tiled_r2g.get_slice(thread_idx); // NVFP4 non-E8 recipe constants and global scales @@ -496,46 +468,33 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til // get global amax pointer int tensor_id = GetTensorId(&kernel_args, tile_idx_n * 64); - float* global_amax_ptr = GetGlobalAmaxPtrByTensorId(&kernel_args, tensor_id); + float *global_amax_ptr = GetGlobalAmaxPtrByTensorId(&kernel_args, tensor_id); - TC* cur_output_colwise_ptr = reinterpret_cast(kernel_args.output_colwise_list[tensor_id]); - TSFC* cur_output_colwise_scale_inv_ptr = reinterpret_cast(kernel_args.output_colwise_scale_inv_list[tensor_id]); + TC *cur_output_colwise_ptr = reinterpret_cast(kernel_args.output_colwise_list[tensor_id]); + TSFC *cur_output_colwise_scale_inv_ptr = + reinterpret_cast(kernel_args.output_colwise_scale_inv_list[tensor_id]); int cur_output_colwise_n = kernel_args.split_sections[tensor_id]; - TensorC cur_mC = cute::make_tensor( - cute::subbyte_iterator(cur_output_colwise_ptr), - cute::make_shape(static_cast(M), cur_output_colwise_n), // (M, N_i) - kernel_args.output_stride2d_list[tensor_id] - ); - - auto cur_sfc_shape = make_shape( - M, - make_shape( make_shape(Int<16>{}, _4{}), cur_output_colwise_n / 64) - ); + TensorC cur_mC = + cute::make_tensor(cute::subbyte_iterator(cur_output_colwise_ptr), + cute::make_shape(static_cast(M), cur_output_colwise_n), // (M, N_i) + kernel_args.output_stride2d_list[tensor_id]); - auto cur_sfc_stride = make_stride( - cur_output_colwise_n / 16, - make_stride( make_stride(_0{}, _1{}), _4{} ) - ); + auto cur_sfc_shape = + make_shape(M, make_shape(make_shape(Int<16>{}, _4{}), cur_output_colwise_n / 64)); + auto cur_sfc_stride = + make_stride(cur_output_colwise_n / 16, make_stride(make_stride(_0{}, _1{}), _4{})); - TensorSFC cur_mSFC = cute::make_tensor( - make_gmem_ptr(cur_output_colwise_scale_inv_ptr), - make_layout(cur_sfc_shape, cur_sfc_stride) - ); + TensorSFC cur_mSFC = cute::make_tensor(make_gmem_ptr(cur_output_colwise_scale_inv_ptr), + make_layout(cur_sfc_shape, cur_sfc_stride)); - TensorGC cur_gC_mn = local_tile( - cur_mC, - epilogue_tiler, - make_coord(_, _, _), - Step<_1, _1, X>{} // (BLK_M, BLK_N) - ); + TensorGC cur_gC_mn = + local_tile(cur_mC, epilogue_tiler, make_coord(_, _, _), Step<_1, _1, X>{} // (BLK_M, BLK_N) + ); TensorGSFC cur_gSFC_mn = local_tile( - cur_mSFC, - epilogue_tiler, - make_coord(_, _, _), - Step<_1, _1, X>{} // (BLK_M, BLK_N-like) + cur_mSFC, epilogue_tiler, make_coord(_, _, _), Step<_1, _1, X>{} // (BLK_M, BLK_N-like) ); Tensor tCgC = thr_mma_epilogue.partition_C(cur_gC_mn); @@ -572,44 +531,32 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til global_decode_scale = 1.0f / global_encode_scale; tensor_id = new_tensor_id; // went through the cute operations to update the local tensors - cur_output_colwise_ptr = reinterpret_cast(kernel_args.output_colwise_list[tensor_id]); - cur_output_colwise_scale_inv_ptr = reinterpret_cast(kernel_args.output_colwise_scale_inv_list[tensor_id]); + cur_output_colwise_ptr = + reinterpret_cast(kernel_args.output_colwise_list[tensor_id]); + cur_output_colwise_scale_inv_ptr = + reinterpret_cast(kernel_args.output_colwise_scale_inv_list[tensor_id]); cur_output_colwise_n = kernel_args.split_sections[tensor_id]; cur_mC = cute::make_tensor( - cute::subbyte_iterator(cur_output_colwise_ptr), - cute::make_shape(static_cast(M), cur_output_colwise_n), // (M, N_i) - kernel_args.output_stride2d_list[tensor_id] - ); + cute::subbyte_iterator(cur_output_colwise_ptr), + cute::make_shape(static_cast(M), cur_output_colwise_n), // (M, N_i) + kernel_args.output_stride2d_list[tensor_id]); - cur_sfc_shape = make_shape( - M, - make_shape( make_shape(Int<16>{}, _4{}), cur_output_colwise_n / 64) - ); + cur_sfc_shape = + make_shape(M, make_shape(make_shape(Int<16>{}, _4{}), cur_output_colwise_n / 64)); - cur_sfc_stride = make_stride( - cur_output_colwise_n / 16, - make_stride( make_stride(_0{}, _1{}), _4{} ) - ); + cur_sfc_stride = + make_stride(cur_output_colwise_n / 16, make_stride(make_stride(_0{}, _1{}), _4{})); - - cur_mSFC = cute::make_tensor( - make_gmem_ptr(cur_output_colwise_scale_inv_ptr), - make_layout(cur_sfc_shape, cur_sfc_stride) - ); + cur_mSFC = cute::make_tensor(make_gmem_ptr(cur_output_colwise_scale_inv_ptr), + make_layout(cur_sfc_shape, cur_sfc_stride)); cur_gC_mn = local_tile( - cur_mC, - epilogue_tiler, - make_coord(_, _, _), - Step<_1, _1, X>{} // (BLK_M, BLK_N) + cur_mC, epilogue_tiler, make_coord(_, _, _), Step<_1, _1, X>{} // (BLK_M, BLK_N) ); - cur_gSFC_mn = local_tile( - cur_mSFC, - epilogue_tiler, - make_coord(_, _, _), - Step<_1, _1, X>{} // (BLK_M, BLK_N-like) + cur_gSFC_mn = local_tile(cur_mSFC, epilogue_tiler, make_coord(_, _, _), Step<_1, _1, X>{} + // (BLK_M, BLK_N-like) ); tCgC = thr_mma_epilogue.partition_C(cur_gC_mn); @@ -618,27 +565,28 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til int tensor_start_elem = kernel_args.split_sections_range[tensor_id]; int local_tile_idx_n = (cur_k_tile_global_elem_idx - tensor_start_elem) / 64; - Tensor tCgC_mn = tCgC(_,_,_,tile_idx_m,local_tile_idx_n); - Tensor tCgSFC_mn = cur_gSFC_mn(_,_,tile_idx_m,local_tile_idx_n); + Tensor tCgC_mn = tCgC(_, _, _, tile_idx_m, local_tile_idx_n); + Tensor tCgSFC_mn = cur_gSFC_mn(_, _, tile_idx_m, local_tile_idx_n); accumulator_pipeline.consumer_wait(accumulator_pipe_consumer_state); - auto tCtC = bulk_tmem_epilogue(_,_,_,accumulator_pipe_consumer_state.index()); - Tensor tDtC = thr_t2r.partition_S(tCtC); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) - Tensor tDgC = thr_t2r.partition_D(tCgC_mn); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + auto tCtC = bulk_tmem_epilogue(_, _, _, accumulator_pipe_consumer_state.index()); + Tensor tDtC = thr_t2r.partition_S(tCtC); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + Tensor tDgC = thr_t2r.partition_D(tCgC_mn); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) - Tensor tTR_rAcc = make_tensor(shape(tDgC)); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + Tensor tTR_rAcc = + make_tensor(shape(tDgC)); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) Tensor tDrC = make_tensor(shape(tDgC)); - Tensor tTR_rAcc_frag = recast>(coalesce(tTR_rAcc)); + Tensor tTR_rAcc_frag = + recast>(coalesce(tTR_rAcc)); Tensor tDrC_frag = recast>(coalesce(tDrC)); Tensor src = thr_r2g.retile_S(tDrC); Tensor dst = thr_r2g.retile_D(tDgC); - Tensor tCgSFC = make_tensor(tCgSFC_mn.data(), make_layout( - make_shape(shape(tCgSFC_mn), Int<1>{}, Int<1>{}), - make_stride(stride(tCgSFC_mn), Int<0>{}, Int<0>{}) - )); + Tensor tCgSFC = make_tensor( + tCgSFC_mn.data(), make_layout(make_shape(shape(tCgSFC_mn), Int<1>{}, Int<1>{}), + make_stride(stride(tCgSFC_mn), Int<0>{}, Int<0>{}))); Tensor tDgSFC = filter(thr_t2r.partition_D(tCgSFC)); Tensor tDrSFC = make_tensor(shape(tDgSFC)); @@ -646,7 +594,9 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til static constexpr int NumVecs = size(tDgC) / VectorSize; Tensor tC_rRowSFD_frg = recast>(tDrSFC); - cutlass::maximum_absolute_value_reduction, true> amax_reduction; + cutlass::maximum_absolute_value_reduction, + true> + amax_reduction; cutlass::Array vec_maxs; cutlass::Array pvscales; // TMEM_LOAD @@ -660,13 +610,18 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til if constexpr (!kUseFastMath) { // Downcast to BF16 for bit-wise compatibility with unfused // kernels - auto convert_accum_to_bf16 = cutlass::NumericArrayConverter{}; - auto convert_bf16_to_accum = cutlass::NumericArrayConverter{}; + auto convert_accum_to_bf16 = + cutlass::NumericArrayConverter{}; + auto convert_bf16_to_accum = + cutlass::NumericArrayConverter{}; tTR_rAcc_frag(_0{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_0{}))); } - auto compute_frgs = reinterpret_cast *>(tTR_rAcc_frag.data()); - auto output_frgs = reinterpret_cast *>(tDrC_frag.data()); + auto compute_frgs = reinterpret_cast *>( + tTR_rAcc_frag.data()); + auto output_frgs = reinterpret_cast *>(tDrC_frag.data()); CUTLASS_PRAGMA_UNROLL for (int v = 0; v < NumVecs; v++) { vec_maxs[v] = amax_reduction(ElementAccumulator(0), compute_frgs[v]); @@ -674,29 +629,36 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til if constexpr (kUseFastMath) { // Fast math: multiply with precomputed reciprocal - pvscales = cutlass::multiplies>{}(vec_maxs, global_encode_scale_multiplier); + pvscales = cutlass::multiplies>{}( + vec_maxs, global_encode_scale_multiplier); } else { // Accurate math: perform division - pvscales = cutlass::divides>{}(vec_maxs, fp4_max); - pvscales = cutlass::multiplies>{}(pvscales, global_encode_scale); + pvscales = + cutlass::divides>{}(vec_maxs, fp4_max); + pvscales = cutlass::multiplies>{}( + pvscales, global_encode_scale); } - auto pvscales_cvted = cutlass::NumericArrayConverter{}(pvscales); + auto pvscales_cvted = + cutlass::NumericArrayConverter{}(pvscales); tC_rRowSFD_frg(_0{}) = pvscales_cvted; - auto qpvscale_ups = cutlass::NumericArrayConverter{}(tC_rRowSFD_frg(_0{})); - auto qpvscale_scaled = cutlass::multiplies>{}(qpvscale_ups, global_decode_scale); + auto qpvscale_ups = cutlass::NumericArrayConverter{}( + tC_rRowSFD_frg(_0{})); + auto qpvscale_scaled = cutlass::multiplies>{}( + qpvscale_ups, global_decode_scale); cutlass::Array acc_scales; if constexpr (kUseFastMath) { // Fast math: compute approximate reciprocal - acc_scales = cutlass::reciprocal_approximate_ftz{}(qpvscale_scaled); + acc_scales = + cutlass::reciprocal_approximate_ftz{}(qpvscale_scaled); } else { // Accurate math: compute reciprocal with division - acc_scales = cutlass::divides>{}(1.0, qpvscale_scaled); + acc_scales = + cutlass::divides>{}(1.0, qpvscale_scaled); } // Initialize RNG for tile - const size_t rng_sequence - = thread_idx + k_tile * 256 + linear_tile_idx * K_TILE_MAX * 256; + const size_t rng_sequence = thread_idx + k_tile * 256 + linear_tile_idx * K_TILE_MAX * 256; transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; rng.init(rng_seed, rng_sequence, rng_offset); @@ -704,18 +666,19 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til CUTLASS_PRAGMA_UNROLL for (int v = 0; v < NumVecs; v++) { - auto acc_scale = cutlass::minimum_with_nan_propagation{}(acc_scales[v], cutlass::platform::numeric_limits::max()); + auto acc_scale = cutlass::minimum_with_nan_propagation{}( + acc_scales[v], cutlass::platform::numeric_limits::max()); // auto acc_scale = acc_scales[v]; if constexpr (kEnableStochasticRounding) { random_uint4 = rng.generate4(); output_frgs[v] = StochasticNumericConverter( - cutlass::multiplies>{}( - compute_frgs[v], - acc_scale - ), - reinterpret_cast*>(&random_uint4)); + cutlass::multiplies>{}( + compute_frgs[v], acc_scale), + reinterpret_cast *>(&random_uint4)); } else { - output_frgs[v] = cutlass::NumericArrayConverter{}(cutlass::multiplies>{}(compute_frgs[v], acc_scale)); + output_frgs[v] = cutlass::NumericArrayConverter{}( + cutlass::multiplies>{}( + compute_frgs[v], acc_scale)); } } @@ -724,7 +687,6 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til // copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tDrC, tDgC); copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tDrSFC, tDgSFC); - } linear_tile_idx += gridDim.x; tile_idx_m = linear_tile_idx % tiles_in_m; @@ -738,40 +700,34 @@ group_rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_til // B: 16 x 16: row-major // C: m x n: row-major // SFC: m x (n/16): row-major -template -void -group_rht_gemm_ntt_w_sfc(int m, int n, - TA const* A, - TB const* B, - MultiAmaxHadamardCastFusionArgs* kernel_args_ptr, - const size_t* rng_state, - uint32_t sm_count, - cudaStream_t stream, - int k_tile_size = 2048) -{ +template +void group_rht_gemm_ntt_w_sfc(int m, int n, TA const *A, TB const *B, + MultiAmaxHadamardCastFusionArgs *kernel_args_ptr, + const size_t *rng_state, uint32_t sm_count, cudaStream_t stream, + int k_tile_size = 2048) { using namespace cute; - // Define shapes (dynamic) auto M = static_cast(m); auto N = static_cast(n); // Define strides (mixed) - auto dA = make_stride(Int<1>{}, m); // (dM,dK) + auto dA = make_stride(Int<1>{}, m); // (dM,dK) auto dB = make_stride(Int<1>{}, 16); // (dN,dK) for (size_t i = 0; i < kernel_args_ptr->num_tensors; ++i) { - kernel_args_ptr->output_stride2d_list[i] = make_stride(kernel_args_ptr->split_sections[i], Int<1>{}); + kernel_args_ptr->output_stride2d_list[i] = + make_stride(kernel_args_ptr->split_sections[i], Int<1>{}); } - auto cga_shape = Shape< _1, _1, _1>{}; - auto cga_tile_shape = Shape<_128,_16,_16>{}; - auto cluster_tile_mainloop = Shape<_128,_16,_64>{}; + auto cga_shape = Shape<_1, _1, _1>{}; + auto cga_tile_shape = Shape<_128, _16, _16>{}; + auto cluster_tile_mainloop = Shape<_128, _16, _64>{}; // Construct the MMA - auto mma = make_tiled_mma(SM100_MMA_F16BF16_SS{}, - Layout>{}); + auto mma = make_tiled_mma( + SM100_MMA_F16BF16_SS{}, + Layout>{}); // MMA in CGA Layout XXX: Need to generalize synchro? {$nv-release-never} @@ -780,60 +736,64 @@ group_rht_gemm_ntt_w_sfc(int m, int n, CUTE_STATIC_ASSERT_V(evenly_divides(cga_tile_shape, tile_shape(mma))); // Determine the A and B shapes - auto mma_shape_B = partition_shape_B(mma, make_shape(size<1>(cga_tile_shape), size<2>(cga_tile_shape))); + auto mma_shape_B = + partition_shape_B(mma, make_shape(size<1>(cga_tile_shape), size<2>(cga_tile_shape))); using TiledMma = decltype(mma); using AtomThrID = typename TiledMma::AtomThrID; - using SmemShape_M = decltype(shape_div(shape<0>(cga_tile_shape), shape_div(shape<0>(cga_tile_shape), size<0>(cga_tile_shape) / size(AtomThrID{})))); - using SmemShape_N = decltype(shape_div(shape<1>(cga_tile_shape), shape_div(shape<1>(cga_tile_shape), size<1>(cga_tile_shape) / size(AtomThrID{})))); + using SmemShape_M = decltype(shape_div( + shape<0>(cga_tile_shape), + shape_div(shape<0>(cga_tile_shape), size<0>(cga_tile_shape) / size(AtomThrID{})))); + using SmemShape_N = decltype(shape_div( + shape<1>(cga_tile_shape), + shape_div(shape<1>(cga_tile_shape), size<1>(cga_tile_shape) / size(AtomThrID{})))); using SmemShape_K = decltype(cute::get<2>(cga_tile_shape)); - using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< - cute::UMMA::Major::MN, TB, SmemShape_N, SmemShape_K>()); + using SmemLayoutAtomB = + decltype(cutlass::gemm::collective::detail::sm100_smem_selector()); - auto mma_shape_A = partition_shape_A(mma, make_shape(size<0>(cluster_tile_mainloop), size<2>(cluster_tile_mainloop))); - using SmemShape_M_A = decltype(shape_div(shape<0>(cluster_tile_mainloop), shape_div(shape<0>(cluster_tile_mainloop), size<0>(cluster_tile_mainloop) / size(AtomThrID{})))); + auto mma_shape_A = partition_shape_A( + mma, make_shape(size<0>(cluster_tile_mainloop), size<2>(cluster_tile_mainloop))); + using SmemShape_M_A = + decltype(shape_div(shape<0>(cluster_tile_mainloop), + shape_div(shape<0>(cluster_tile_mainloop), + size<0>(cluster_tile_mainloop) / size(AtomThrID{})))); using SmemShape_K_A = decltype(cute::get<2>(cluster_tile_mainloop)); using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< - cute::UMMA::Major::MN, TA, SmemShape_M_A, SmemShape_K_A>()); + cute::UMMA::Major::MN, TA, SmemShape_M_A, SmemShape_K_A>()); // Define the smem layouts (static) // Calculate max pipeline stages based on Blackwell SM100's 232KB shared memory - constexpr int kBlackwellSmemSize = 232448; // 232KB in bytes - constexpr int kBytesPerStage = cute::size(mma_shape_A) * sizeof(TA) + cute::size(mma_shape_B) * sizeof(TB); - constexpr int kReservedBytes = 256; // Reserve for barriers and other uses + constexpr int kBlackwellSmemSize = 232448; // 232KB in bytes + constexpr int kBytesPerStage = + cute::size(mma_shape_A) * sizeof(TA) + cute::size(mma_shape_B) * sizeof(TB); + constexpr int kReservedBytes = 256; // Reserve for barriers and other uses constexpr int kMaxStages = (kBlackwellSmemSize - kReservedBytes) / kBytesPerStage; - auto sP = Int{}; // SMEM pipelines - auto sA = UMMA::tile_to_mma_shape(SmemLayoutAtomA{}, append(mma_shape_A, sP)); // (MMA,MMA_M,MMA_K,PIPE) - auto sB = UMMA::tile_to_mma_shape(SmemLayoutAtomB{}, append(mma_shape_B, sP)); // (MMA,MMA_N,MMA_K,PIPE) - auto sC = Layout<_1>{}; // XXX Dummy + auto sP = Int{}; // SMEM pipelines + auto sA = UMMA::tile_to_mma_shape(SmemLayoutAtomA{}, + append(mma_shape_A, sP)); // (MMA,MMA_M,MMA_K,PIPE) + auto sB = UMMA::tile_to_mma_shape(SmemLayoutAtomB{}, + append(mma_shape_B, sP)); // (MMA,MMA_N,MMA_K,PIPE) + auto sC = Layout<_1>{}; // XXX Dummy // Create GMEM tensors - Tensor tensorA = make_tensor(A, make_layout(make_shape(M,N), dA)); // (M,N) - Tensor tensorB = make_tensor(B, make_layout(make_shape(16,16), dB)); // (16,16) + Tensor tensorA = make_tensor(A, make_layout(make_shape(M, N), dA)); // (M,N) + Tensor tensorB = make_tensor(B, make_layout(make_shape(16, 16), dB)); // (16,16) // Create the TiledCopy - auto tma_load_a = make_tma_copy_A_sm100( - SM90_TMA_LOAD{}, - tensorA, - sA(_,_,_,0), - cluster_tile_mainloop, - mma); - auto tma_load_b = make_tma_copy_B_sm100( - SM90_TMA_LOAD{}, - tensorB, - sB(_,_,_,0), - cga_tile_shape, - mma); + auto tma_load_a = + make_tma_copy_A_sm100(SM90_TMA_LOAD{}, tensorA, sA(_, _, _, 0), cluster_tile_mainloop, mma); + auto tma_load_b = + make_tma_copy_B_sm100(SM90_TMA_LOAD{}, tensorB, sB(_, _, _, 0), cga_tile_shape, mma); // Assert checks on tile sizes -- no predication - NVTE_CHECK(M % size<0>(cga_tile_shape) == 0, - "Inner dimension must be divisible by ", static_cast(size<0>(cga_tile_shape)), " but got ", M, "."); - NVTE_CHECK(N % (4 * size<1>(cga_tile_shape)) == 0, - "Outer dimension must be divisible by ", 4 * static_cast(size<1>(cga_tile_shape)), - " but got ", N, "."); + NVTE_CHECK(M % size<0>(cga_tile_shape) == 0, "Inner dimension must be divisible by ", + static_cast(size<0>(cga_tile_shape)), " but got ", M, "."); + NVTE_CHECK(N % (4 * size<1>(cga_tile_shape)) == 0, "Outer dimension must be divisible by ", + 4 * static_cast(size<1>(cga_tile_shape)), " but got ", N, "."); uint32_t tiles = size(ceil_div(M, get<0>(cga_tile_shape))) * size(ceil_div(N, k_tile_size)); @@ -844,46 +804,28 @@ group_rht_gemm_ntt_w_sfc(int m, int n, dim3 dimGrid(tiles, 1, 1); int smem_size = sizeof(SharedStorage); - auto* kernel_ptr = &group_rht_gemm_device< - decltype(M), decltype(N), decltype(k_tile_size), decltype(cga_tile_shape), - TA, decltype(dA), decltype(sA), decltype(tma_load_a), - TB, decltype(dB), decltype(sB), decltype(tma_load_b), - TC, Stride2D, decltype(sC), - TSFC, - decltype(mma), - kEnableStochasticRounding, - kUseFastMath>; + auto *kernel_ptr = &group_rht_gemm_device< + decltype(M), decltype(N), decltype(k_tile_size), decltype(cga_tile_shape), TA, decltype(dA), + decltype(sA), decltype(tma_load_a), TB, decltype(dB), decltype(sB), decltype(tma_load_b), TC, + Stride2D, decltype(sC), TSFC, decltype(mma), kEnableStochasticRounding, kUseFastMath>; NVTE_CHECK_CUDA( - cudaFuncSetAttribute(*kernel_ptr, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size) - ); - - (*kernel_ptr) - <<< dimGrid, dimBlock, smem_size, stream >>> - (M, N, k_tile_size, cga_tile_shape, - A, dA, sA, tma_load_a, - B, dB, sB, tma_load_b, - sC, mma, - *kernel_args_ptr, - rng_state); + cudaFuncSetAttribute(*kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + (*kernel_ptr)<<>>(M, N, k_tile_size, cga_tile_shape, A, dA, + sA, tma_load_a, B, dB, sB, tma_load_b, sC, + mma, *kernel_args_ptr, rng_state); NVTE_CHECK_CUDA(cudaGetLastError()); } // this function is used to wrap the group_rht_gemm_ntt_w_sfc function // to transpose the input tensor A -template -void -group_rht_gemm_ttt_wrapper(int m, int n, - TA const* A, - TB const* B, - MultiAmaxHadamardCastFusionArgs* kernel_args_ptr, - const size_t* rng_state, - uint32_t sm_count, - cudaStream_t stream, - int k_tile_size = 1024) -{ +template +void group_rht_gemm_ttt_wrapper(int m, int n, TA const *A, TB const *B, + MultiAmaxHadamardCastFusionArgs *kernel_args_ptr, + const size_t *rng_state, uint32_t sm_count, cudaStream_t stream, + int k_tile_size = 1024) { // in addition to transpose the input tensor A // we also need to reshape m, n to at best // ultilize as many SMs as possible while keeping @@ -895,12 +837,7 @@ group_rht_gemm_ttt_wrapper(int m, int n, // C: n x m: row-major // SFC: n x (m/16): row-major group_rht_gemm_ntt_w_sfc( - n, m, - A, B, - kernel_args_ptr, - rng_state, - sm_count, stream, - k_tile_size); + n, m, A, B, kernel_args_ptr, rng_state, sm_count, stream, k_tile_size); } } // namespace diff --git a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu index c5d4afbc91..3932b328ae 100644 --- a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu @@ -49,7 +49,6 @@ using namespace cute; // Ensure Tensor refers to cute::Tensor, not transformer_engine::Tensor using cute::Tensor; - constexpr int kMaxTensorsPerKernel = 64; struct MultiAmaxHadamardCastFusionArgs { @@ -101,7 +100,8 @@ cutlass::Array StochasticNumericConverterBase( : "f"(input[0]), "f"(input[1]), "f"(input[2]), "f"(input[3]), "f"(input[4]), "f"(input[5]), "f"(input[6]), "f"(input[7]), "r"(rbits[0]), "r"(rbits[1])); } else { - NVTE_DEVICE_ERROR("FP4 cvt PTX instructions are architecture-specific. " + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. " "Try recompiling with sm_XXXa instead of sm_XXX."); } return output; @@ -179,17 +179,16 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( TA const *A, AStride dA, ASmemLayout sAlayout, CUTE_GRID_CONSTANT TmaLoadA const tma_load_a, TB const *B, BStride dB, BSmemLayout sBlayout, CUTE_GRID_CONSTANT TmaLoadB const tma_load_b, TQA *QA, QAStride dQA, TSFA *SFA, TSFALayout sfa_layout, MultiAmaxHadamardCastFusionArgs args, - uint32_t* tile_scheduler_workspace, - TiledMMA mma, - const size_t *rng_state) { + uint32_t *tile_scheduler_workspace, TiledMMA mma, const size_t *rng_state) { using namespace cute; // Abort immediately if compilation is not supported constexpr bool is_blackwell_arch = ARCH_BLACKWELL_FAMILY; if constexpr (!is_blackwell_arch) { - NVTE_DEVICE_ERROR("group_row_col_rht_gemm_device is only supported on Blackwell " - "with architecture-specific compilation. " - "Try recompiling with sm_100a or similar."); + NVTE_DEVICE_ERROR( + "group_row_col_rht_gemm_device is only supported on Blackwell " + "with architecture-specific compilation. " + "Try recompiling with sm_100a or similar."); return; } static_assert(kEnableRHTColQuant_ || kEnableRowQuant_, @@ -241,7 +240,7 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( using SchedThrottlePipeline = cutlass::PipelineAsync; using SchedThrottlePipelineState = typename SchedThrottlePipeline::PipelineState; - static_assert(ClusterShape{} == Shape<_1,_1,_1>{}, "ClusterShape must be Shape<_1,_1,_1>"); + static_assert(ClusterShape{} == Shape<_1, _1, _1>{}, "ClusterShape must be Shape<_1,_1,_1>"); using TmemAllocator = cute::TMEM::Allocator1Sm; static int constexpr VectorSize = RhtTensorSize; @@ -272,22 +271,23 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( uint32_t tile_m_idx = 0; uint32_t tile_n_idx = 0; int k_tile_max = 0; - uint32_t* atomic_tile_index_; - uint32_t* smem_tile_counter; + uint32_t *atomic_tile_index_; + uint32_t *smem_tile_counter; uint32_t atomic_offset; cutlass::FastDivmodU64 divmod_tiles_in_m; - CUTLASS_DEVICE TileScheduler(uint32_t tiles_m, uint32_t tiles_n, int kmax, uint32_t* atomic_tile_index, uint32_t* smem_tile_counter) - : tiles_in_m(tiles_m), - tiles_in_n(tiles_n), - linear_idx(blockIdx.x), - next_linear_idx(blockIdx.x), - start_idx(blockIdx.x), - k_tile_max(kmax), - atomic_tile_index_(atomic_tile_index), - smem_tile_counter(smem_tile_counter), - atomic_offset(gridDim.x), - divmod_tiles_in_m(uint64_t(tiles_m)) { + CUTLASS_DEVICE TileScheduler(uint32_t tiles_m, uint32_t tiles_n, int kmax, + uint32_t *atomic_tile_index, uint32_t *smem_tile_counter) + : tiles_in_m(tiles_m), + tiles_in_n(tiles_n), + linear_idx(blockIdx.x), + next_linear_idx(blockIdx.x), + start_idx(blockIdx.x), + k_tile_max(kmax), + atomic_tile_index_(atomic_tile_index), + smem_tile_counter(smem_tile_counter), + atomic_offset(gridDim.x), + divmod_tiles_in_m(uint64_t(tiles_m)) { update_tile_idx(); } CUTLASS_DEVICE void update_tile_idx() { @@ -296,18 +296,15 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( tile_m_idx = static_cast(r); tile_n_idx = static_cast(q) * uint32_t(k_tile_max); } - CUTLASS_DEVICE uint32_t tile_m() const { - return tile_m_idx; - } - CUTLASS_DEVICE uint32_t tile_n_base() const { - return tile_n_idx; - } + CUTLASS_DEVICE uint32_t tile_m() const { return tile_m_idx; } + CUTLASS_DEVICE uint32_t tile_n_base() const { return tile_n_idx; } CUTLASS_DEVICE uint32_t tiles_m() const { return tiles_in_m; } CUTLASS_DEVICE uint32_t tiles_n() const { return tiles_in_n; } CUTLASS_DEVICE bool is_valid() const { - return cute::elem_less(cute::make_coord(tile_m(), tile_n_base()), cute::make_coord(tiles_in_m, tiles_in_n)); + return cute::elem_less(cute::make_coord(tile_m(), tile_n_base()), + cute::make_coord(tiles_in_m, tiles_in_n)); } CUTLASS_DEVICE bool is_first_wave() const { return linear_idx == start_idx; } @@ -315,20 +312,22 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( CUTLASS_DEVICE uint32_t get_linear_tile_idx() const { return linear_idx; } // Fetch a new tile_id using atomics. - CUTLASS_DEVICE uint32_t fetch_tile_id_counter( int pred ) { - uint32_t tile_id_counter = 0; - asm volatile( "{\n\t" - ".reg .pred p;\n\t" - "setp.eq.u32 p, %2, 1;\n\t" - "@p atom.global.add.u32 %0, [%1], 1; \n\t" - "}" - : "=r"( tile_id_counter ) - : "l"( atomic_tile_index_ ), "r"( pred ) ); - - return tile_id_counter; + CUTLASS_DEVICE uint32_t fetch_tile_id_counter(int pred) { + uint32_t tile_id_counter = 0; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.eq.u32 p, %2, 1;\n\t" + "@p atom.global.add.u32 %0, [%1], 1; \n\t" + "}" + : "=r"(tile_id_counter) + : "l"(atomic_tile_index_), "r"(pred)); + + return tile_id_counter; } - CUTLASS_DEVICE auto fetch_next_work(SchedPipeline& sched_pipeline, SchedPipelineState sched_pipeline_consumer_state) { + CUTLASS_DEVICE auto fetch_next_work(SchedPipeline &sched_pipeline, + SchedPipelineState sched_pipeline_consumer_state) { sched_pipeline.consumer_wait(sched_pipeline_consumer_state); next_linear_idx = smem_tile_counter[sched_pipeline_consumer_state.index()]; cutlass::arch::fence_view_async_shared(); @@ -336,13 +335,15 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( return; } - CUTLASS_DEVICE auto advance_to_next_work(SchedPipeline& sched_pipeline, SchedPipelineState sched_pipeline_producer_state) { + CUTLASS_DEVICE auto advance_to_next_work(SchedPipeline &sched_pipeline, + SchedPipelineState sched_pipeline_producer_state) { uint32_t mbarrier_addr = sched_pipeline.producer_get_barrier(sched_pipeline_producer_state); // Wait for clcID buffer to become empty with a flipped phase sched_pipeline.producer_acquire(sched_pipeline_producer_state); auto is_leading_thread = cute::elect_one_sync(); uint32_t tile_id_counter = fetch_tile_id_counter(is_leading_thread) + atomic_offset; - uint32_t smem_addr = cute::cast_smem_ptr_to_uint(&smem_tile_counter[sched_pipeline_producer_state.index()]); + uint32_t smem_addr = + cute::cast_smem_ptr_to_uint(&smem_tile_counter[sched_pipeline_producer_state.index()]); if (is_leading_thread) { cute::store_shared_remote(tile_id_counter, smem_addr, mbarrier_addr, 0); } @@ -367,9 +368,11 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( // Compute the number of tiles in M and N after tiling and assign scheduler uint32_t tiles_in_m = uint32_t(size(ceil_div(M, size<0>(cluster_tile)))); - uint32_t tiles_in_n = uint32_t(size(ceil_div(args.split_sections_range[args.num_tensors], size<2>(epilogue_tiler)))); + uint32_t tiles_in_n = uint32_t( + size(ceil_div(args.split_sections_range[args.num_tensors], size<2>(epilogue_tiler)))); - TileScheduler scheduler(tiles_in_m, tiles_in_n, K_TILE_MAX, tile_scheduler_workspace, shared_storage.atomic_tile_counter); + TileScheduler scheduler(tiles_in_m, tiles_in_n, K_TILE_MAX, tile_scheduler_workspace, + shared_storage.atomic_tile_counter); int block_rank_in_cluster = cute::block_rank_in_cluster(); @@ -448,8 +451,7 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( size(AtomThrShapeMNK{}) * NumEpilogueColQuantThreadCount; accumulator_pipeline_params.initializing_warp = 1; AccumulatorPipeline accumulator_pipeline(shared_storage.accumulator, accumulator_pipeline_params, - cluster_shape, - AccumulatorPipelineInitBarriers{}, + cluster_shape, AccumulatorPipelineInitBarriers{}, cute::true_type{}); // Delay mask calculation typename SchedPipeline::Params sched_pipeline_params; if (is_sched_warp) { @@ -459,13 +461,15 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( } sched_pipeline_params.producer_blockid = 0; sched_pipeline_params.producer_arv_count = 1; - sched_pipeline_params.consumer_arv_count = NumSchedThreads + cluster_size * - (NumMainloopLoadThreads + NumEpilogueThreads + NumMmaThreadCount); + sched_pipeline_params.consumer_arv_count = + NumSchedThreads + + cluster_size * (NumMainloopLoadThreads + NumEpilogueThreads + NumMmaThreadCount); sched_pipeline_params.transaction_bytes = sizeof(uint32_t); sched_pipeline_params.initializing_warp = 3; SchedPipeline sched_pipeline(shared_storage.sched, sched_pipeline_params, cluster_shape); SchedPipelineState sched_pipeline_consumer_state; - SchedPipelineState sched_pipeline_producer_state = cutlass::make_producer_start_state(); + SchedPipelineState sched_pipeline_producer_state = + cutlass::make_producer_start_state(); typename SchedThrottlePipeline::Params sched_throttle_pipeline_params; if (is_dma_warp) { @@ -479,9 +483,11 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( sched_throttle_pipeline_params.dst_blockid = 0; sched_throttle_pipeline_params.initializing_warp = 4; - SchedThrottlePipeline sched_throttle_pipeline(shared_storage.sched_throttle, sched_throttle_pipeline_params); + SchedThrottlePipeline sched_throttle_pipeline(shared_storage.sched_throttle, + sched_throttle_pipeline_params); SchedThrottlePipelineState sched_pipeline_throttle_consumer_state; - SchedThrottlePipelineState sched_pipeline_throttle_producer_state = cutlass::make_producer_start_state(); + SchedThrottlePipelineState sched_pipeline_throttle_producer_state = + cutlass::make_producer_start_state(); if (warp_idx == 2 && elect_one_sync()) { cute::initialize_barrier(shared_storage.tma_barrier[0], /* num_threads */ 1); @@ -511,8 +517,8 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( int block_rank_in_cluster = cute::block_rank_in_cluster(); ThrMMA thr_mma = mma.get_slice(block_rank_in_cluster); // blk idx // Partition global to local fragments for A and B - Tensor tCgA = thr_mma.partition_A(gA_mk); // (MMA,MMA_M,MMA_K,k) - Tensor tCgB = thr_mma.partition_B(gB_nk); // (MMA,MMA_N,MMA_K,k) + Tensor tCgA = thr_mma.partition_A(gA_mk); // (MMA,MMA_M,MMA_K,k) + Tensor tCgB = thr_mma.partition_B(gB_nk); // (MMA,MMA_N,MMA_K,k) Layout cta_layout_mnk = make_layout(cluster_shape); Layout cta_layout_vmnk = @@ -599,13 +605,13 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( do { uint32_t skip_wait = K_TILE_MAX <= 0; - auto barrier_token = mainloop_pipeline.consumer_try_wait( - mainloop_pipe_consumer_state, - skip_wait); + auto barrier_token = + mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state); ++sched_pipeline_consumer_state; CUTLASS_PRAGMA_NO_UNROLL - for (int k_tile = 0; k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n(); ) { + for (int k_tile = 0; + k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n();) { mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); int read_stage = mainloop_pipe_consumer_state.index(); auto tCrA_mk = tCrA(_, _, _, read_stage); @@ -646,7 +652,8 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( sched_throttle_pipeline.consumer_wait(sched_pipeline_throttle_consumer_state); sched_throttle_pipeline.consumer_release(sched_pipeline_throttle_consumer_state); ++sched_pipeline_throttle_consumer_state; - sched_pipeline_producer_state = scheduler.advance_to_next_work(sched_pipeline, sched_pipeline_producer_state); + sched_pipeline_producer_state = + scheduler.advance_to_next_work(sched_pipeline, sched_pipeline_producer_state); scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state); ++sched_pipeline_consumer_state; scheduler.update_work_tile_info(); @@ -745,8 +752,10 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state); ++sched_pipeline_consumer_state; CUTLASS_PRAGMA_NO_UNROLL - for (int k_tile = 0; k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n(); ++k_tile) { - int global_tile_n_offset = (scheduler.tile_n_base() +k_tile) * size<1>(epilogue_tiler); + for (int k_tile = 0; + k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n(); + ++k_tile) { + int global_tile_n_offset = (scheduler.tile_n_base() + k_tile) * size<1>(epilogue_tiler); int cur_group_idx = GetGroupIdx(&args, global_tile_n_offset); @@ -832,8 +841,12 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( if constexpr (!kUseFastMath) { // Downcast to BF16 for bit-wise compatibility with // unfused kernels - auto convert_accum_to_bf16 = cutlass::NumericArrayConverter{}; - auto convert_bf16_to_accum = cutlass::NumericArrayConverter{}; + auto convert_accum_to_bf16 = + cutlass::NumericArrayConverter{}; + auto convert_bf16_to_accum = + cutlass::NumericArrayConverter{}; tTR_rAcc_frag(_0{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_0{}))); tTR_rAcc_frag(_1{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_1{}))); } @@ -1096,13 +1109,13 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state); ++sched_pipeline_consumer_state; scheduler.update_work_tile_info(); - }while (scheduler.is_valid()); + } while (scheduler.is_valid()); } } else { cutlass::arch::warpgroup_reg_dealloc<32>(); } -} // NOLINT(readability/fn_size) +} // NOLINT(readability/fn_size) template ; auto cluster_shape = ClusterShape{}; - auto cluster_tile_shape = Shape<_128,Int,Int>{}; + auto cluster_tile_shape = Shape<_128, Int, Int>{}; auto cluster_tile_mainloop = Shape<_128, Int, _128>{}; // Each mainloop / epilogue loads 128 x 64 tiles while each MMA proceeds with 128 x 16 tiles - static int constexpr EpilogueUnrollFactor = size<2>(cluster_tile_mainloop) / size<2>(cluster_tile_shape); + static int constexpr EpilogueUnrollFactor = + size<2>(cluster_tile_mainloop) / size<2>(cluster_tile_shape); // Construct the MMA - auto mma = make_tiled_mma(SM100_MMA_F16BF16_SS(cluster_tile_shape), size<1>(cluster_tile_shape), - UMMA::Major::MN, UMMA::Major::MN>{}, - Layout>{}); - + auto mma = make_tiled_mma( + SM100_MMA_F16BF16_SS(cluster_tile_shape), size<1>(cluster_tile_shape), + UMMA::Major::MN, UMMA::Major::MN>{}, + Layout>{}); // Assert that the TiledMMA uses all CTAs in the CGA. CUTE_STATIC_ASSERT_V(size(cluster_shape) == size(mma)); CUTE_STATIC_ASSERT_V(evenly_divides(cluster_tile_shape, tile_shape(mma))); // Determine the A and B shapes - auto mma_shape_B = partition_shape_B(mma, make_shape(size<1>(cluster_tile_shape), size<2>(cluster_tile_shape))); + auto mma_shape_B = + partition_shape_B(mma, make_shape(size<1>(cluster_tile_shape), size<2>(cluster_tile_shape))); using TiledMma = decltype(mma); using AtomThrID = typename TiledMma::AtomThrID; - using SmemShape_M = decltype(shape_div(shape<0>(cluster_tile_shape), shape_div(shape<0>(cluster_tile_shape), size<0>(cluster_tile_shape) / size(AtomThrID{})))); - using SmemShape_N = decltype(shape_div(shape<1>(cluster_tile_shape), shape_div(shape<1>(cluster_tile_shape), size<1>(cluster_tile_shape) / size(AtomThrID{})))); + using SmemShape_M = decltype(shape_div( + shape<0>(cluster_tile_shape), + shape_div(shape<0>(cluster_tile_shape), size<0>(cluster_tile_shape) / size(AtomThrID{})))); + using SmemShape_N = decltype(shape_div( + shape<1>(cluster_tile_shape), + shape_div(shape<1>(cluster_tile_shape), size<1>(cluster_tile_shape) / size(AtomThrID{})))); using SmemShape_K = decltype(cute::get<2>(cluster_tile_shape)); using SmemLayoutAtomB = @@ -1218,10 +1236,12 @@ void group_row_col_rht_gemm_ntt_w_sfc(int packed_sequence_length, int hidden_siz typename cutlass::detail::CustomizedPipelineTmaUmmaAsync<1, Shape<_1, _1, _1>, Shape<_1, _1, _1>>::SharedStorage); - static int constexpr SchedulerWorkspaceBytes = sizeof(int) * SchedulerPipelineStageCount; - static int constexpr SchedulerThrottlePipelineBytes = sizeof(typename cutlass::PipelineAsync::SharedStorage); - static int constexpr SchedulerPipelineBytes = sizeof(typename cutlass::PipelineCLCFetchAsync::SharedStorage); + static int constexpr SchedulerThrottlePipelineBytes = + sizeof(typename cutlass::PipelineAsync::SharedStorage); + static int constexpr SchedulerPipelineBytes = + sizeof(typename cutlass::PipelineCLCFetchAsync::SharedStorage); static int constexpr TmemDeallocBytes = sizeof(cutlass::arch::ClusterBarrier); static int constexpr BTensorBytes = cute::size(mma_shape_B) * sizeof(TB); @@ -1232,8 +1252,10 @@ void group_row_col_rht_gemm_ntt_w_sfc(int packed_sequence_length, int hidden_siz static int constexpr kBlackwellSmemSize = 232448; // 232KB in bytes static int constexpr kBytesPerStage = cute::size(mma_shape_A) * sizeof(TA) + MainloopPipelineBytes; - static int constexpr kReservedBytes = SchedulerWorkspaceBytes + SchedulerThrottlePipelineBytes + SchedulerPipelineBytes + - TmemBasePtrsBytes + TmemDeallocBytes+BTensorBytes + AccPipelineBytes; // Reserve for barriers and other uses + static int constexpr kReservedBytes = SchedulerWorkspaceBytes + SchedulerThrottlePipelineBytes + + SchedulerPipelineBytes + TmemBasePtrsBytes + + TmemDeallocBytes + BTensorBytes + + AccPipelineBytes; // Reserve for barriers and other uses static int constexpr kMaxStages = (kBlackwellSmemSize - kReservedBytes) / kBytesPerStage; auto sP = Int{}; // SMEM pipelines @@ -1252,7 +1274,6 @@ void group_row_col_rht_gemm_ntt_w_sfc(int packed_sequence_length, int hidden_siz assert(M % size<0>(cluster_tile_shape) == 0); assert(N % size<1>(cluster_tile_shape) == 0); - dim3 dimBlock(512); dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape), size<2>(cluster_shape)); dim3 dimGrid(sm_count, 1, 1); @@ -1270,8 +1291,7 @@ void group_row_col_rht_gemm_ntt_w_sfc(int packed_sequence_length, int hidden_siz kEnableRHTColQuant, kEnableRowQuant, kEnableSwizzleSFOutput, kUseFastMath>; NVTE_CHECK_CUDA( - cudaFuncSetAttribute(*kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) - ); + cudaFuncSetAttribute(*kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); // Allocate workspace and set to zero void *tile_scheduler_workspace = nullptr; @@ -1281,8 +1301,9 @@ void group_row_col_rht_gemm_ntt_w_sfc(int packed_sequence_length, int hidden_siz // Launch kernel cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smem_size, stream}; cutlass::Status status = cutlass::launch_kernel_on_cluster( - params, (void const *)kernel_ptr, M, N, k_tile_size, cluster_shape, cluster_tile_shape, A, dA, sA, - tma_load_a, B, dB, sB, tma_load_b, QA, dQA, SFA, sfa_layout, args, tile_scheduler_workspace, mma, rng_state); + params, (void const *)kernel_ptr, M, N, k_tile_size, cluster_shape, cluster_tile_shape, A, dA, + sA, tma_load_a, B, dB, sB, tma_load_b, QA, dQA, SFA, sfa_layout, args, + tile_scheduler_workspace, mma, rng_state); NVTE_CHECK_CUDA(cudaGetLastError()); NVTE_CHECK(status == cutlass::Status::kSuccess, "Kernel launch failed."); @@ -1418,38 +1439,34 @@ void group_hadamard_transform_cast_fusion(const Tensor &input_, std::vector( - /*packed_sequence_length=*/m, /*hidden_size=*/n, - /*A=*/reinterpret_cast(input.dptr), - /*B=*/reinterpret_cast(hadamard_matrix.dptr), - /*QA=*/reinterpret_cast(rowwise_data_base_ptr), - /*SFA=*/reinterpret_cast(rowwise_scale_inv_base_ptr), - /*args=*/kernel_args, - /*rng_state=*/rng_state, /*sm_count=*/sm_count, - /*stream=*/stream, /*k_tile_size=*/k_tile_size); - } else { - NVTE_ERROR("Invalid kernel configuration (kEnableRHTColQuant=", - kEnableRhtColQuant, ", kEnableRowQuant=", - kEnableRowQuant, ")."); - } - - ););););); + all_has_row_quant, kEnableRowQuant, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_swizzle_sf_output, kEnableSwizzleSFOutput, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + quant_config.use_fast_math, kUseFastMath, + + if constexpr (kEnableRhtColQuant || kEnableRowQuant) { + detail::group_row_col_rht_gemm_ntt_w_sfc< + kEnableStochasticRounding, kEnableRhtColQuant, kEnableRowQuant, + kEnableSwizzleSFOutput, TA, TB, TQA, TSFA, TD, TSFD, kUseFastMath>( + /*packed_sequence_length=*/m, /*hidden_size=*/n, + /*A=*/reinterpret_cast(input.dptr), + /*B=*/reinterpret_cast(hadamard_matrix.dptr), + /*QA=*/reinterpret_cast(rowwise_data_base_ptr), + /*SFA=*/reinterpret_cast(rowwise_scale_inv_base_ptr), + /*args=*/kernel_args, + /*rng_state=*/rng_state, /*sm_count=*/sm_count, + /*stream=*/stream, /*k_tile_size=*/k_tile_size); + } else { + NVTE_ERROR("Invalid kernel configuration (kEnableRHTColQuant=", + kEnableRhtColQuant, ", kEnableRowQuant=", kEnableRowQuant, ")."); + } + + ););););); } } // namespace transformer_engine